Audio classifier for training, inference, and exporting.
tflite_model_maker.audio_classifier.AudioClassifier(
model_spec, index_to_label, shuffle, train_whole_model
)
Args:
  model_spec: Specification for the model.
  index_to_label: A list that maps each class index to its label name.
  shuffle: Whether the data should be shuffled.
  train_whole_model: If True, the Hub module is trained together with the
    classification layer on top. Otherwise, only the top classification
    layer is trained.
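The `index_to_label` list is positional: entry *i* names class *i*. A minimal plain-Python sketch of turning one clip's per-class scores into a label through such a list; the label names, scores, and the `label_for_scores` helper are hypothetical illustrations, not part of the library:

```python
from typing import Sequence

# Hypothetical label list: position i holds the name of class i.
index_to_label = ["background", "bird", "dog"]


def label_for_scores(scores: Sequence[float], labels: Sequence[str]) -> str:
    """Hypothetical helper: return the label whose index has the highest score."""
    if len(scores) != len(labels):
        raise ValueError("expected exactly one score per class")
    best_index = max(range(len(scores)), key=lambda i: scores[i])
    return labels[best_index]


print(label_for_scores([0.1, 0.7, 0.2], index_to_label))  # bird
```

Because the mapping is purely positional, the order of `index_to_label` must match the order of the model's output units.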
Methods
confusion_matrix
confusion_matrix(
data, batch_size=32
)
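A confusion matrix counts how often each true class is predicted as each class. The sketch below shows that counting in plain Python, independent of the library; the `confusion_counts` name is hypothetical, and the layout (rows are true classes, columns are predicted classes) is a common convention assumed here, not a statement about what this method returns:

```python
from typing import List, Sequence


def confusion_counts(y_true: Sequence[int], y_pred: Sequence[int],
                     num_classes: int) -> List[List[int]]:
    """Hypothetical helper: matrix[t][p] counts clips of true class t predicted as p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


# Made-up true and predicted class indices for five clips over three classes.
cm = confusion_counts([0, 1, 2, 2, 1], [0, 2, 2, 2, 1], num_classes=3)
print(cm)  # [[1, 0, 0], [0, 1, 1], [0, 0, 2]]
```

The diagonal holds correct predictions; off-diagonal cells show which classes are confused with each other.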
create
@classmethod
create(
train_data,
model_spec,
validation_data=None,
batch_size=32,
epochs=5,
model_dir=None,
do_train=True,
train_whole_model=False
)
Loads data and retrains the model.
Args:
  train_data: An instance of the audio_dataloader.DataLoader class.
  model_spec: Specification for the model.
  validation_data: Validation DataLoader. If None, validation is skipped.
  batch_size: Number of samples per training step.
  epochs: Number of epochs for training.
  model_dir: The location of the model checkpoint files.
  do_train: Whether to run training.
  train_whole_model: Boolean. By default, only the classification head is
    trained. When True, the base model is also trained.

Returns:
  An AudioClassifier instance.
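`batch_size` and `epochs` together determine how many training steps run. A back-of-the-envelope sketch using the defaults from the `create` signature (`batch_size=32`, `epochs=5`), assuming every sample is visited once per epoch and a final partial batch still counts as a step; the library may batch differently, and `training_steps` is a hypothetical helper:

```python
import math


def training_steps(num_samples: int, batch_size: int = 32, epochs: int = 5) -> int:
    """Hypothetical helper: total steps if each epoch uses ceil(n / batch_size) batches."""
    if num_samples <= 0 or batch_size <= 0 or epochs <= 0:
        raise ValueError("all arguments must be positive")
    steps_per_epoch = math.ceil(num_samples / batch_size)  # last batch may be partial
    return steps_per_epoch * epochs


print(training_steps(1000))  # 160 (32 steps per epoch x 5 epochs)
```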
create_model
create_model(
num_classes, train_whole_model
)
create_serving_model
create_serving_model()
Returns the underlying Keras model for serving.
evaluate
evaluate(
data, batch_size=32
)
Evaluates the model.
Args:
  data: Data to be evaluated.
  batch_size: Number of samples per evaluation step.

Returns:
  The loss value and accuracy.
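`evaluate` returns a loss value and an accuracy. The loss function is not named here; assuming categorical cross-entropy (a common choice for classifiers, but an assumption), this plain-Python sketch shows how both numbers relate to per-clip class probabilities. `loss_and_accuracy` and the sample data are hypothetical:

```python
import math
from typing import Sequence, Tuple


def loss_and_accuracy(probs: Sequence[Sequence[float]],
                      labels: Sequence[int]) -> Tuple[float, float]:
    """Hypothetical helper: mean cross-entropy (assumed loss) and top-1 accuracy."""
    if len(probs) != len(labels) or not labels:
        raise ValueError("need one label per probability vector, and at least one")
    total_loss = 0.0
    correct = 0
    for p, y in zip(probs, labels):
        total_loss += -math.log(p[y])  # cross-entropy term for the true class
        predicted = max(range(len(p)), key=lambda i: p[i])
        correct += int(predicted == y)
    return total_loss / len(labels), correct / len(labels)


# Two made-up clips, both with true class 0.
loss, acc = loss_and_accuracy([[0.8, 0.2], [0.4, 0.6]], [0, 0])
print(round(loss, 4), acc)  # 0.5697 0.5
```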
evaluate_tflite
evaluate_tflite(
tflite_filepath, data, postprocess_fn=None
)
Evaluates the TFLite model.
Args:
  tflite_filepath: File path to the TFLite model.
  data: Data to be evaluated.
  postprocess_fn: Postprocessing function that is applied to the output of
    lite_runner.run before the probabilities are calculated.

Returns:
  The evaluation result of the TFLite model (accuracy).
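`postprocess_fn` is applied to the raw output of `lite_runner.run` before probabilities are computed. As a hedged illustration, suppose the TFLite model emits unnormalized logits and a hypothetical `postprocess_fn` applies softmax; the sketch then scores accuracy as the Returns entry describes. The `softmax` and `tflite_accuracy` helpers and the raw outputs are stand-ins, not library code:

```python
import math
from typing import Callable, List, Optional, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Hypothetical postprocess_fn: numerically stable softmax over one output vector."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def tflite_accuracy(raw_outputs: Sequence[Sequence[float]], labels: Sequence[int],
                    postprocess_fn: Optional[Callable[[Sequence[float]], Sequence[float]]] = None
                    ) -> float:
    """Apply postprocess_fn (if given) to each raw output, then compute top-1 accuracy."""
    if not labels:
        raise ValueError("need at least one labeled example")
    correct = 0
    for output, label in zip(raw_outputs, labels):
        scores = postprocess_fn(output) if postprocess_fn is not None else output
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        correct += int(predicted == label)
    return correct / len(labels)


# Made-up stand-ins for lite_runner.run outputs on three clips.
raw = [[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5]]
print(tflite_accuracy(raw, [0, 1, 0], postprocess_fn=softmax))  # 2 of 3 correct
```

Softmax does not change which class scores highest, so here it affects probabilities but not accuracy; a postprocess_fn that, for example, aggregates several outputs could change the result.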
export
export(
export_dir,
tflite_filename='model.tflite',
label_filename='labels.txt',
vocab_filename='vocab.txt',
saved_model_filename='saved_model',
tfjs_folder_name='tfjs',
export_format=None,
**kwargs
)
Converts the retrained model based on export_format.

Args:
  export_dir: The directory in which to save exported files.
  tflite_filename: File name for the TFLite model. The full export path is
    {export_dir}/{tflite_filename}.
  label_filename: File name for the labels. The full export path is
    {export_dir}/{label_filename}.
  vocab_filename: File name for the vocabulary. The full export path is
    {export_dir}/{vocab_filename}.
  saved_model_filename: Path to the SavedModel or H5 file to save the model.
    The full export path is
    {export_dir}/{saved_model_filename}/{saved_model.pb|assets|variables}.
  tfjs_folder_name: Folder name for the TensorFlow.js model. The full export
    path is {export_dir}/{tfjs_folder_name}.
  export_format: List of export formats, each of which can be saved_model,
    tflite, label, or vocab.
  **kwargs: Other parameters, such as quantized_config for the TFLite model.
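The path patterns above compose each output location from `export_dir` plus a file or folder name. This sketch only computes those paths with the default names from the `export` signature; it writes nothing and does not call the library. `export_paths` is a hypothetical helper:

```python
import os
from typing import Dict


def export_paths(export_dir: str,
                 tflite_filename: str = "model.tflite",
                 label_filename: str = "labels.txt",
                 vocab_filename: str = "vocab.txt",
                 saved_model_filename: str = "saved_model",
                 tfjs_folder_name: str = "tfjs") -> Dict[str, str]:
    """Hypothetical helper: where each artifact would land under export_dir."""
    return {
        "tflite": os.path.join(export_dir, tflite_filename),
        "label": os.path.join(export_dir, label_filename),
        "vocab": os.path.join(export_dir, vocab_filename),
        # SavedModel is a directory holding saved_model.pb, assets/, and variables/.
        "saved_model": os.path.join(export_dir, saved_model_filename),
        "tfjs": os.path.join(export_dir, tfjs_folder_name),
    }


paths = export_paths("/tmp/audio_export")
print(paths["tflite"])  # /tmp/audio_export/model.tflite on POSIX systems
```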
predict_top_k
predict_top_k(
data, k=1, batch_size=32
)
Returns the top-k predictions.
Args:
  data: Data to be evaluated. Either an instance of DataLoader or raw data
    entries such as a TF tensor or NumPy array.
  k: Number of top results to return.
  batch_size: Number of samples per evaluation step.

Returns:
  Top-k results, each a (label, probability) pair.
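Each returned result pairs a label with its probability. A plain-Python sketch of selecting the top-k pairs for one clip from a probability vector and an `index_to_label` list; `top_k` and the sample values are hypothetical and not the library's implementation:

```python
import heapq
from typing import List, Sequence, Tuple


def top_k(probs: Sequence[float], index_to_label: Sequence[str],
          k: int = 1) -> List[Tuple[str, float]]:
    """Hypothetical helper: the k highest-probability (label, probability) pairs, best first."""
    best = heapq.nlargest(k, range(len(probs)), key=lambda i: probs[i])
    return [(index_to_label[i], probs[i]) for i in best]


labels = ["background", "bird", "dog"]  # hypothetical label list
print(top_k([0.1, 0.6, 0.3], labels, k=2))  # [('bird', 0.6), ('dog', 0.3)]
```

`heapq.nlargest` avoids fully sorting the class list, which matters little for a handful of classes but scales better when there are many.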
summary
summary()
train
train(
train_data, validation_data, epochs, batch_size
)
Class Variables:
  ALLOWED_EXPORT_FORMAT: (<ExportFormat.LABEL: 'LABEL'>,
    <ExportFormat.TFLITE: 'TFLITE'>,
    <ExportFormat.SAVED_MODEL: 'SAVED_MODEL'>)
  DEFAULT_EXPORT_FORMAT: <ExportFormat.TFLITE: 'TFLITE'>
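The two class variables suggest how a requested `export_format` might be checked: formats outside `ALLOWED_EXPORT_FORMAT` are not supported by this class, and TFLite is the default. The sketch below is a stand-in, not the library's code. Its `ExportFormat` enum mirrors the values shown above plus an assumed `VOCAB` member (mentioned in the `export_format` description). The fallback to the default when `export_format` is None is also an assumption:

```python
from enum import Enum
from typing import Iterable, Optional, Tuple, Union


class ExportFormat(Enum):
    """Stand-in enum mirroring the values above; VOCAB is an assumed extra member."""
    LABEL = "LABEL"
    TFLITE = "TFLITE"
    SAVED_MODEL = "SAVED_MODEL"
    VOCAB = "VOCAB"


ALLOWED_EXPORT_FORMAT = (ExportFormat.LABEL, ExportFormat.TFLITE, ExportFormat.SAVED_MODEL)
DEFAULT_EXPORT_FORMAT = ExportFormat.TFLITE


def resolve_formats(
    export_format: Optional[Union[ExportFormat, Iterable[ExportFormat]]] = None,
) -> Tuple[ExportFormat, ...]:
    """Hypothetical helper: default to TFLite when None, reject disallowed formats."""
    if export_format is None:
        return (DEFAULT_EXPORT_FORMAT,)
    if isinstance(export_format, ExportFormat):
        export_format = (export_format,)
    formats = tuple(export_format)
    for fmt in formats:
        if fmt not in ALLOWED_EXPORT_FORMAT:
            raise ValueError(f"{fmt.name} is not an allowed export format")
    return formats


print(resolve_formats())  # (<ExportFormat.TFLITE: 'TFLITE'>,)
```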