blob: e066146eb0c7d28c9c1cdb18c5b32d2409a06a4b [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.audio;
import static org.tensorflow.lite.support.common.internal.SupportPreconditions.checkArgument;
import static java.lang.System.arraycopy;
import android.media.AudioFormat;
import android.media.AudioRecord;
import android.os.Build;
import androidx.annotation.RequiresApi;
import com.google.auto.value.AutoValue;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
/**
* Defines a ring buffer and some utility functions to prepare the input audio samples.
*
* <p>It maintains a <a href="https://en.wikipedia.org/wiki/Circular_buffer">Ring Buffer</a> to hold
* input audio data. Clients could feed input audio data via `load` methods and access the
* aggregated audio samples via `getTensorBuffer` method.
*
* <p>Note that this class can only handle input audio in Float (in {@link
* android.media.AudioFormat#ENCODING_PCM_16BIT}) or Short (in {@link
* android.media.AudioFormat#ENCODING_PCM_FLOAT}). Internally it converts and stores all the audio
* samples in PCM Float encoding.
*
* <p>Typical usage in Kotlin
*
* <pre>
* val tensor = TensorAudio.create(format, modelInputLength)
* tensor.load(newData)
* interpreter.run(tensor.getTensorBuffer(), outputBuffer);
* </pre>
*
* <p>Another sample usage with {@link android.media.AudioRecord}
*
* <pre>
* val tensor = TensorAudio.create(format, modelInputLength)
* Timer().scheduleAtFixedRate(delay, period) {
* tensor.load(audioRecord)
* interpreter.run(tensor.getTensorBuffer(), outputBuffer)
* }
* </pre>
*/
public class TensorAudio {
private static final String TAG = TensorAudio.class.getSimpleName();
private final FloatRingBuffer buffer;
private final TensorAudioFormat format;
/**
* Creates a {@link android.media.AudioRecord} instance with a ring buffer whose size is {@code
* sampleCounts} * {@code format.getChannels()}.
*
* @param format the expected {@link TensorAudioFormat} of audio data loaded into this class.
* @param sampleCounts the number of samples to be fed into the model
*/
public static TensorAudio create(TensorAudioFormat format, int sampleCounts) {
return new TensorAudio(format, sampleCounts);
}
/**
* Creates a {@link TensorAudio} instance with a ring buffer whose size is {@code sampleCounts}
* *
* {@code format.getChannelCount()}.
*
* @param format the {@link android.media.AudioFormat} required by the TFLite model. It defines
* the number of channels and sample rate.
* @param sampleCounts the number of samples to be fed into the model
*/
public static TensorAudio create(AudioFormat format, int sampleCounts) {
return new TensorAudio(TensorAudioFormat.create(format), sampleCounts);
}
/**
* Wraps a few constants describing the format of the incoming audio samples, namely number of
* channels and the sample rate. By default, channels is set to 1.
*/
@AutoValue
public abstract static class TensorAudioFormat {
private static final int DEFAULT_CHANNELS = 1;
/** Creates a {@link TensorAudioFormat} instance from Android AudioFormat class. */
@RequiresApi(Build.VERSION_CODES.M)
public static TensorAudioFormat create(AudioFormat format) {
return TensorAudioFormat.builder()
.setChannels(format.getChannelCount())
.setSampleRate(format.getSampleRate())
.build();
}
public abstract int getChannels();
public abstract int getSampleRate();
public static Builder builder() {
return new AutoValue_TensorAudio_TensorAudioFormat.Builder().setChannels(
DEFAULT_CHANNELS);
}
/** Builder for {@link TensorAudioFormat} */
@AutoValue.Builder
public abstract static class Builder {
/* By default, it's set to have 1 channel. */
public abstract Builder setChannels(int value);
public abstract Builder setSampleRate(int value);
abstract TensorAudioFormat autoBuild();
public TensorAudioFormat build() {
TensorAudioFormat format = autoBuild();
checkArgument(
format.getChannels() > 0, "Number of channels should be greater than 0");
checkArgument(format.getSampleRate() > 0, "Sample rate should be greater than 0");
return format;
}
}
}
/**
* Stores the input audio samples {@code src} in the ring buffer.
*
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
* multi-channel input, the array is interleaved.
*/
public void load(float[] src) {
load(src, 0, src.length);
}
/**
* Stores the input audio samples {@code src} in the ring buffer.
*
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
* multi-channel input, the array is interleaved.
* @param offsetInFloat starting position in the {@code src} array
* @param sizeInFloat the number of float values to be copied
* @throws IllegalArgumentException for incompatible audio format or incorrect input size
*/
public void load(float[] src, int offsetInFloat, int sizeInFloat) {
checkArgument(sizeInFloat % format.getChannels() == 0,
String.format("Size (%d) needs to be a multiplier of the number of channels (%d)",
sizeInFloat, format.getChannels()));
buffer.load(src, offsetInFloat, sizeInFloat);
}
/**
* Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the
* ring buffer.
*
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
* multi-channel input, the array is interleaved.
*/
public void load(short[] src) {
load(src, 0, src.length);
}
/**
* Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the
* ring buffer.
*
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
* multi-channel input, the array is interleaved.
* @param offsetInShort starting position in the src array
* @param sizeInShort the number of short values to be copied
* @throws IllegalArgumentException if the source array can't be copied
*/
public void load(short[] src, int offsetInShort, int sizeInShort) {
checkArgument(offsetInShort + sizeInShort <= src.length,
String.format(
"Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
offsetInShort, sizeInShort, src.length));
float[] floatData = new float[sizeInShort];
for (int i = offsetInShort; i < sizeInShort; i++) {
// Convert the data to PCM Float encoding i.e. values between -1 and 1
floatData[i] = src[i] / Short.MAX_VALUE;
}
load(floatData);
}
/**
* Loads latest data from the {@link android.media.AudioRecord} in a non-blocking way. Only
* supporting ENCODING_PCM_16BIT and ENCODING_PCM_FLOAT.
*
* @param record an instance of {@link android.media.AudioRecord}
* @return number of captured audio values whose size is {@code channelCount * sampleCount}. If
* there was no new data in the AudioRecord or an error occurred, this method will return 0.
* @throws IllegalArgumentException for unsupported audio encoding format
* @throws IllegalStateException if reading from AudioRecord failed
*/
@RequiresApi(Build.VERSION_CODES.M)
public int load(AudioRecord record) {
checkArgument(this.format.equals(TensorAudioFormat.create(record.getFormat())),
"Incompatible audio format.");
int loadedValues = 0;
if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_FLOAT) {
float[] newData = new float[record.getChannelCount() * record.getBufferSizeInFrames()];
loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
if (loadedValues > 0) {
load(newData, 0, loadedValues);
return loadedValues;
}
} else if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_16BIT) {
short[] newData = new short[record.getChannelCount() * record.getBufferSizeInFrames()];
loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
if (loadedValues > 0) {
load(newData, 0, loadedValues);
return loadedValues;
}
} else {
throw new IllegalArgumentException(
"Unsupported encoding. Requires ENCODING_PCM_16BIT or ENCODING_PCM_FLOAT.");
}
switch (loadedValues) {
case AudioRecord.ERROR_INVALID_OPERATION:
throw new IllegalStateException("AudioRecord.ERROR_INVALID_OPERATION");
case AudioRecord.ERROR_BAD_VALUE:
throw new IllegalStateException("AudioRecord.ERROR_BAD_VALUE");
case AudioRecord.ERROR_DEAD_OBJECT:
throw new IllegalStateException("AudioRecord.ERROR_DEAD_OBJECT");
case AudioRecord.ERROR:
throw new IllegalStateException("AudioRecord.ERROR");
default:
return 0;
}
}
/**
* Returns a float {@link TensorBuffer} holding all the available audio samples in {@link
* android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1].
*/
public TensorBuffer getTensorBuffer() {
ByteBuffer byteBuffer = buffer.getBuffer();
TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(
new int[] {/* batch= */ 1,
/* modelInputLengthInFloat= */ byteBuffer.asFloatBuffer().limit()},
DataType.FLOAT32);
tensorBuffer.loadBuffer(byteBuffer);
return tensorBuffer;
}
/* Returns the {@link TensorAudioFormat} associated with the tensor. */
public TensorAudioFormat getFormat() {
return format;
}
private TensorAudio(TensorAudioFormat format, int sampleCounts) {
this.format = format;
this.buffer = new FloatRingBuffer(sampleCounts * format.getChannels());
}
/** Actual implementation of the ring buffer. */
private static class FloatRingBuffer {
private final float[] buffer;
private int nextIndex = 0;
public FloatRingBuffer(int flatSize) {
buffer = new float[flatSize];
}
/**
* Loads the entire float array to the ring buffer. If the float array is longer than ring
* buffer's capacity, samples with lower indices in the array will be ignored.
*/
public void load(float[] newData) {
load(newData, 0, newData.length);
}
/**
* Loads a slice of the float array to the ring buffer. If the float array is longer than
* ring buffer's capacity, samples with lower indices in the array will be ignored.
*/
public void load(float[] newData, int offset, int size) {
checkArgument(offset + size <= newData.length,
String.format(
"Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
offset, size, newData.length));
// If buffer can't hold all the data, only keep the most recent data of size
// buffer.length
if (size > buffer.length) {
offset = size - buffer.length;
size = buffer.length;
}
if (nextIndex + size < buffer.length) {
// No need to wrap nextIndex, just copy newData[offset:offset + size]
// to buffer[nextIndex:nextIndex+size]
arraycopy(newData, offset, buffer, nextIndex, size);
} else {
// Need to wrap nextIndex, perform copy in two chunks.
int firstChunkSize = buffer.length - nextIndex;
// First copy newData[offset:offset+firstChunkSize] to
// buffer[nextIndex:buffer.length]
arraycopy(newData, offset, buffer, nextIndex, firstChunkSize);
// Then copy newData[offset+firstChunkSize:offset+size] to
// buffer[0:size-firstChunkSize]
arraycopy(newData, offset + firstChunkSize, buffer, 0, size - firstChunkSize);
}
nextIndex = (nextIndex + size) % buffer.length;
}
public ByteBuffer getBuffer() {
// Create non-direct buffers. On Pixel 4, creating direct buffer costs around 0.1 ms,
// which can be 5x ~ 10x longer compared to non-direct buffer backed by arrays (around
// 0.01ms), so generally we don't create direct buffer for every invocation.
ByteBuffer byteBuffer =
ByteBuffer.allocate(DataType.FLOAT32.byteSize() * buffer.length);
byteBuffer.order(ByteOrder.nativeOrder());
FloatBuffer result = byteBuffer.asFloatBuffer();
result.put(buffer, nextIndex, buffer.length - nextIndex);
result.put(buffer, 0, nextIndex);
byteBuffer.rewind();
return byteBuffer;
}
public int getCapacity() {
return buffer.length;
}
}
}