
AudioClassifier

public final class AudioClassifier

Performs classification on audio waveforms.

The API expects a TFLite model with TFLite Model Metadata.

The API supports models with one audio input tensor and one classification output tensor. The requirements are as follows:

  • Input audio tensor (kTfLiteFloat32)
    • input audio buffer of size [batch x samples].
    • batch inference is not supported (batch is required to be 1).
  • Output score tensor (kTfLiteFloat32)
    • with N classes of either 2 or 4 dimensions, such as [1 x N] or [1 x 1 x 1 x N]
    • a label file should be packed into the metadata. See the example of creating metadata for an image classifier. If no label file is packed, the class index is used as the label in the result.
See an example of such a model and a CLI demo tool for trying out this API.
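The tensor requirements above can be restated as a small validation routine. The sketch below uses a hypothetical helper, `ModelShapeCheck`, which is not part of the Task Library; it assumes the tensor shapes are available as `int[]` arrays, checks for a batch size of 1 and the two accepted output layouts, and mimics the index-as-label fallback used when no label file is packed.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper restating the documented tensor requirements; not a Task Library API. */
final class ModelShapeCheck {
    private ModelShapeCheck() {}

    /** Returns the sample count for an input shape [1 x samples]; batch sizes other than 1 are rejected. */
    static int inputSamples(int[] inputShape) {
        if (inputShape.length != 2 || inputShape[0] != 1) {
            throw new IllegalArgumentException(
                "Expected input shape [1 x samples], got " + Arrays.toString(inputShape));
        }
        return inputShape[1];
    }

    /** Returns N for an output shape of [1 x N] or [1 x 1 x 1 x N]. */
    static int numClasses(int[] outputShape) {
        boolean twoDims = outputShape.length == 2 && outputShape[0] == 1;
        boolean fourDims = outputShape.length == 4
            && outputShape[0] == 1 && outputShape[1] == 1 && outputShape[2] == 1;
        if (!twoDims && !fourDims) {
            throw new IllegalArgumentException(
                "Expected output shape [1 x N] or [1 x 1 x 1 x N], got " + Arrays.toString(outputShape));
        }
        return outputShape[outputShape.length - 1];
    }

    /** Uses the packed labels if present; otherwise each class index becomes its label. */
    static List<String> labelsOrIndices(List<String> packedLabels, int numClasses) {
        if (packedLabels != null && !packedLabels.isEmpty()) {
            return packedLabels;
        }
        List<String> labels = new ArrayList<>(numClasses);
        for (int i = 0; i < numClasses; i++) {
            labels.add(String.valueOf(i));
        }
        return labels;
    }
}
```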

Nested Classes

class AudioClassifier.AudioClassifierOptions Options for setting up an AudioClassifier

Public Methods

List<Classifications>
classify(TensorAudio tensor)
Performs actual classification on the provided audio tensor.
AudioRecord
createAudioRecord()
Creates an AudioRecord instance to record audio stream.
static AudioClassifier
createFromBuffer(ByteBuffer modelBuffer)
Creates an AudioClassifier instance with a model buffer and the default AudioClassifier.AudioClassifierOptions.
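static AudioClassifier
createFromBufferAndOptions(ByteBuffer modelBuffer, AudioClassifier.AudioClassifierOptions options)
Creates an AudioClassifier instance with a model buffer and AudioClassifier.AudioClassifierOptions.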
static AudioClassifier
createFromFile(Context context, String modelPath)
Creates an AudioClassifier instance from a model in the assets, using the default AudioClassifier.AudioClassifierOptions.
static AudioClassifier
createFromFile(File modelFile)
Creates an AudioClassifier instance from a model File, using the default AudioClassifier.AudioClassifierOptions.
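static AudioClassifier
createFromFileAndOptions(Context context, String modelPath, AudioClassifier.AudioClassifierOptions options)
Creates an AudioClassifier instance from a model in the assets and AudioClassifier.AudioClassifierOptions.
static AudioClassifier
createFromFileAndOptions(File modelFile, AudioClassifier.AudioClassifierOptions options)
Creates an AudioClassifier instance from a model File and AudioClassifier.AudioClassifierOptions.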
TensorAudio
createInputTensorAudio()
Creates a TensorAudio instance to store input audio samples.
long
getRequiredInputBufferSize()
Returns the required input buffer size in number of float elements.
TensorAudio.TensorAudioFormat
getRequiredTensorAudioFormat()
Returns the TensorAudio.TensorAudioFormat required by the model.


Public Methods

public List<Classifications> classify (TensorAudio tensor)

Performs actual classification on the provided audio tensor.

Parameters
tensor a TensorAudio containing the input audio clip as float values in the range [-1, 1). The tensor should have the same flat size as the TFLite model's input tensor; creating it with the createInputTensorAudio method is recommended.
Throws
IllegalArgumentException if an argument is invalid
IllegalStateException if an error occurs in the native code while classifying the audio clip
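The classify method expects float samples in [-1, 1). If your audio arrives as signed 16-bit PCM, dividing each sample by 32768 is one common way to reach that range: -32768 maps to exactly -1.0 and 32767 maps to just below 1.0. The `PcmToFloat` helper below is a hypothetical sketch, not a Task Library API, and `checkFlatSize` assumes you compare against the value returned by getRequiredInputBufferSize().

```java
/** Hypothetical helper for preparing float input in [-1, 1); not a Task Library API. */
final class PcmToFloat {
    private PcmToFloat() {}

    /** Scales signed 16-bit PCM into [-1, 1): -32768 -> -1.0f, 32767 -> 1 - 2^-15. */
    static float[] toFloat(short[] pcm) {
        float[] out = new float[pcm.length];
        for (int i = 0; i < pcm.length; i++) {
            out[i] = pcm[i] / 32768f;
        }
        return out;
    }

    /** Throws if the flat size differs from the model's required input buffer size. */
    static void checkFlatSize(float[] samples, long requiredInputBufferSize) {
        if (samples.length != requiredInputBufferSize) {
            throw new IllegalArgumentException(
                "Expected " + requiredInputBufferSize + " floats, got " + samples.length);
        }
    }
}
```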

public AudioRecord createAudioRecord ()

Creates an AudioRecord instance to record an audio stream. The returned AudioRecord instance is already initialized; the client needs to call AudioRecord.startRecording() to start recording.

Throws
IllegalArgumentException if the channel count required by the model is unsupported
IllegalStateException if the AudioRecord instance fails to initialize

public static AudioClassifier createFromBuffer (ByteBuffer modelBuffer)

Creates an AudioClassifier instance with a model buffer and the default AudioClassifier.AudioClassifierOptions.

Parameters
modelBuffer a direct ByteBuffer or a MappedByteBuffer of the classification model
Throws
IllegalStateException if there is an internal error
RuntimeException if there is an otherwise unspecified error
IllegalArgumentException if the model buffer is not a direct ByteBuffer or a MappedByteBuffer
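On a plain JVM, one way to obtain a MappedByteBuffer for createFromBuffer is to memory-map the model file read-only with java.nio. The `ModelBuffers.mapModel` helper below is hypothetical; only the FileChannel calls are standard JDK APIs.

```java
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

/** Hypothetical helper that maps a .tflite file read-only; not a Task Library API. */
final class ModelBuffers {
    private ModelBuffers() {}

    static MappedByteBuffer mapModel(Path modelPath) throws IOException {
        try (FileChannel channel = FileChannel.open(modelPath, StandardOpenOption.READ)) {
            // A mapping stays valid after the channel that created it is closed.
            return channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
        }
    }
}
```

The returned buffer could then be passed to `AudioClassifier.createFromBuffer(buffer)`; that step is not shown because it needs the Task Library on the classpath.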

public static AudioClassifier createFromBufferAndOptions (ByteBuffer modelBuffer, AudioClassifier.AudioClassifierOptions options)

Creates an AudioClassifier instance with a model buffer and AudioClassifier.AudioClassifierOptions.

Parameters
modelBuffer a direct ByteBuffer or a MappedByteBuffer of the classification model
options the AudioClassifier.AudioClassifierOptions used to configure the classifier
Throws
IllegalStateException if there is an internal error
RuntimeException if there is an otherwise unspecified error
IllegalArgumentException if the model buffer is not a direct ByteBuffer or a MappedByteBuffer
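A heap buffer, such as the one returned by ByteBuffer.wrap, is neither direct nor mapped, so according to the Throws clause above it would be rejected. The hypothetical `DirectModelBuffer` helper below copies model bytes into a direct buffer and includes a precondition check that mirrors the documented behavior; the check is an illustration, not the library's own code.

```java
import java.nio.ByteBuffer;

/** Hypothetical helper; not a Task Library API. */
final class DirectModelBuffer {
    private DirectModelBuffer() {}

    /** Copies model bytes into a newly allocated direct buffer, positioned at 0. */
    static ByteBuffer toDirect(byte[] modelBytes) {
        ByteBuffer direct = ByteBuffer.allocateDirect(modelBytes.length);
        direct.put(modelBytes);
        direct.rewind();
        return direct;
    }

    /** Mirrors the documented precondition: non-direct (heap) buffers are rejected. */
    static ByteBuffer requireDirect(ByteBuffer buffer) {
        if (!buffer.isDirect()) {
            throw new IllegalArgumentException(
                "The model buffer must be a direct ByteBuffer or a MappedByteBuffer");
        }
        return buffer;
    }
}
```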

public static AudioClassifier createFromFile (Context context, String modelPath)

Creates an AudioClassifier instance from a model in the assets, using the default AudioClassifier.AudioClassifierOptions.

Parameters
context the Android Context used to load the model from the assets
modelPath path of the classification model with metadata in the assets
Throws
IOException if an I/O error occurs when loading the tflite model
IllegalArgumentException if an argument is invalid
IllegalStateException if there is an internal error
RuntimeException if there is an otherwise unspecified error

public static AudioClassifier createFromFile (File modelFile)

Creates an AudioClassifier instance from a model File, using the default AudioClassifier.AudioClassifierOptions.

Parameters
modelFile the classification model File instance
Throws
IOException if an I/O error occurs when loading the tflite model
IllegalArgumentException if an argument is invalid
IllegalStateException if there is an internal error
RuntimeException if there is an otherwise unspecified error

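public static AudioClassifier createFromFileAndOptions (Context context, String modelPath, AudioClassifier.AudioClassifierOptions options)

Creates an AudioClassifier instance from a model in the assets and AudioClassifier.AudioClassifierOptions.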

Parameters
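context the Android Context used to load the model from the assets
modelPath path of the classification model with metadata in the assets
options the AudioClassifier.AudioClassifierOptions used to configure the classifier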
Throws
IOException if an I/O error occurs when loading the tflite model
IllegalArgumentException if an argument is invalid
IllegalStateException if there is an internal error
RuntimeException if there is an otherwise unspecified error

public static AudioClassifier createFromFileAndOptions (File modelFile, AudioClassifier.AudioClassifierOptions options)

Creates an AudioClassifier instance from a model File and AudioClassifier.AudioClassifierOptions.

Parameters
modelFile the classification model File instance
options the AudioClassifier.AudioClassifierOptions used to configure the classifier
Throws
IOException if an I/O error occurs when loading the tflite model
IllegalArgumentException if an argument is invalid
IllegalStateException if there is an internal error
RuntimeException if there is an otherwise unspecified error

public TensorAudio createInputTensorAudio ()

Creates a TensorAudio instance to store input audio samples.

Returns
a TensorAudio with the same size as the model's input tensor
Throws
IllegalArgumentException if the model is not compatible

public long getRequiredInputBufferSize ()

Returns the required input buffer size in number of float elements.

public TensorAudio.TensorAudioFormat getRequiredTensorAudioFormat ()

Returns the TensorAudio.TensorAudioFormat required by the model.
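getRequiredInputBufferSize() and getRequiredTensorAudioFormat() together determine how much audio one inference consumes. The sketch below assumes that the float buffer holds (channel count x samples per channel) values and that the format reports a channel count and a sample rate in Hz; `InputWindow` is a hypothetical helper, not a Task Library API.

```java
/** Hypothetical helper relating the required buffer size to the required audio format. */
final class InputWindow {
    private InputWindow() {}

    /** Samples per channel, assuming the buffer holds channels x samples float values. */
    static long samplesPerChannel(long requiredInputBufferSize, int channels) {
        if (channels <= 0 || requiredInputBufferSize % channels != 0) {
            throw new IllegalArgumentException(
                "Buffer size " + requiredInputBufferSize + " is not a multiple of " + channels + " channels");
        }
        return requiredInputBufferSize / channels;
    }

    /** Duration of audio, in seconds, covered by one model input. */
    static double windowSeconds(long requiredInputBufferSize, int channels, int sampleRateHz) {
        return (double) samplesPerChannel(requiredInputBufferSize, channels) / sampleRateHz;
    }
}
```

For example, a hypothetical mono model with a required buffer size of 15600 floats at 16000 Hz would consume 15600 / 16000 = 0.975 seconds of audio per call to classify.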