![]() | ![]() | ![]() | ![]() | ![]() |
A biblioteca Criador TensorFlow Lite Modelo simplifica o processo de adaptação e conversão de um modelo neural-network TensorFlow a determinados dados de entrada ao implantar este modelo para aplicações ML no dispositivo.
Este notebook mostra um exemplo de ponta a ponta que utiliza esta biblioteca Model Maker para ilustrar a adaptação e conversão de um modelo de classificação de imagem comumente usado para classificar flores em um dispositivo móvel.
Pré-requisitos
Para executar esse exemplo, primeiro é necessário instalar vários pacotes necessários, incluindo pacote Fabricante Modelo que no GitHub repo .
pip install -q tflite-model-maker
Importe os pacotes necessários.
import os
import numpy as np
import tensorflow as tf
assert tf.__version__.startswith('2')
from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
import matplotlib.pyplot as plt
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/pkg_resources/__init__.py:119: PkgResourcesDeprecationWarning: 0.18ubuntu0.18.04.1 is an invalid version and will not be supported in a future release PkgResourcesDeprecationWarning, /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/numba/core/errors.py:168: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9 warnings.warn(msg)
Exemplo simples de ponta a ponta
Obtenha o caminho de dados
Vamos pegar algumas imagens para brincar com este exemplo simples de ponta a ponta. Centenas de imagens são um bom começo para o Model Maker, enquanto mais dados podem alcançar melhor precisão.
image_path = tf.keras.utils.get_file(
'flower_photos.tgz',
'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz 228818944/228813984 [==============================] - 1s 0us/step 228827136/228813984 [==============================] - 1s 0us/step
Você poderia substituir image_path
com suas próprias pastas de imagem. Quanto ao upload de dados para o colab, você pode encontrar o botão de upload na barra lateral esquerda mostrada na imagem abaixo com o retângulo vermelho. Basta tentar fazer o upload de um arquivo zip e descompactá-lo. O caminho do arquivo raiz é o caminho atual.
Se você preferir não fazer o upload de suas imagens para a nuvem, você pode tentar executar a biblioteca local seguindo o guia no GitHub.
Execute o exemplo
O exemplo consiste apenas em 4 linhas de código, conforme mostrado abaixo, cada uma representando uma etapa do processo geral.
Etapa 1. Carregar dados de entrada específicos para um aplicativo de ML no dispositivo. Divida-o em dados de treinamento e dados de teste.
data = DataLoader.from_folder(image_path)
train_data, test_data = data.split(0.9)
INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips.
Etapa 2. Personalize o modelo do TensorFlow.
model = image_classifier.create(train_data)
INFO:tensorflow:Retraining the models... Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= hub_keras_layer_v1v2 (HubKer (None, 1280) 3413024 _________________________________________________________________ dropout (Dropout) (None, 1280) 0 _________________________________________________________________ dense (Dense) (None, 5) 6405 ================================================================= Total params: 3,419,429 Trainable params: 6,405 Non-trainable params: 3,413,024 _________________________________________________________________ None Epoch 1/5 /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead. "The `lr` argument is deprecated, use `learning_rate` instead.") 103/103 [==============================] - 7s 35ms/step - loss: 0.8551 - accuracy: 0.7718 Epoch 2/5 103/103 [==============================] - 4s 35ms/step - loss: 0.6503 - accuracy: 0.8956 Epoch 3/5 103/103 [==============================] - 4s 34ms/step - loss: 0.6157 - accuracy: 0.9196 Epoch 4/5 103/103 [==============================] - 3s 33ms/step - loss: 0.6036 - accuracy: 0.9293 Epoch 5/5 103/103 [==============================] - 4s 34ms/step - loss: 0.5929 - accuracy: 0.9317
Etapa 3. Avalie o modelo.
loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 2s 40ms/step - loss: 0.6282 - accuracy: 0.9019
Etapa 4. Exporte para o modelo do TensorFlow Lite.
Aqui, nós exportamos modelo TensorFlow Lite com metadados que fornece um padrão para descrição de modelo. O arquivo de rótulo é incorporado aos metadados. A técnica de quantização pós-treinamento padrão é a quantização de inteiros completos para a tarefa de classificação de imagens.
Você pode baixá-lo na barra lateral esquerda da mesma forma que a parte de upload para seu próprio uso.
model.export(export_dir='.')
2021-11-02 11:34:05.568024: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. INFO:tensorflow:Assets written to: /tmp/tmpkqikzotp/assets INFO:tensorflow:Assets written to: /tmp/tmpkqikzotp/assets 2021-11-02 11:34:09.488041: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format. 2021-11-02 11:34:09.488090: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency. fully_quantize: 0, inference_type: 6, input_inference_type: 3, output_inference_type: 3 WARNING:absl:For model inputs containing unsupported operations which cannot be quantized, the `inference_input_type` attribute will default to the original type. INFO:tensorflow:Label file is inside the TFLite model with metadata. INFO:tensorflow:Label file is inside the TFLite model with metadata. INFO:tensorflow:Saving labels in /tmp/tmpoblx4ed5/labels.txt INFO:tensorflow:Saving labels in /tmp/tmpoblx4ed5/labels.txt INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite
Após estas simples 4 passos, poderíamos usar mais arquivo de modelo TensorFlow Lite em aplicações no dispositivo como na imagem classificação aplicativo de referência.
Processo Detalhado
Atualmente, oferecemos suporte a vários modelos, como modelos EfficientNet-Lite*, MobileNetV2, ResNet50 como modelos pré-treinados para classificação de imagens. Mas é muito flexível adicionar novos modelos pré-treinados a essa biblioteca com apenas algumas linhas de código.
Veja a seguir este exemplo de ponta a ponta passo a passo para mostrar mais detalhes.
Etapa 1: carregar dados de entrada específicos para um aplicativo de ML no dispositivo
O conjunto de dados de flores contém 3670 imagens pertencentes a 5 classes. Baixe a versão de arquivo do conjunto de dados e descompacte-o.
O conjunto de dados tem a seguinte estrutura de diretórios:
flower_photos |__ daisy |______ 100080576_f52e8ee070_n.jpg |______ 14167534527_781ceb1b7a_n.jpg |______ ... |__ dandelion |______ 10043234166_e6dd915111_n.jpg |______ 1426682852_e62169221f_m.jpg |______ ... |__ roses |______ 102501987_3cdb8e5394_n.jpg |______ 14982802401_a3dfb22afb.jpg |______ ... |__ sunflowers |______ 12471791574_bb1be83df4.jpg |______ 15122112402_cafa41934f.jpg |______ ... |__ tulips |______ 13976522214_ccec508fe7.jpg |______ 14487943607_651e8062a1_m.jpg |______ ...
image_path = tf.keras.utils.get_file(
'flower_photos.tgz',
'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')
Use DataLoader
classe para carregar dados.
Quanto from_folder()
método, pode carregar os dados a partir da pasta. Ele assume que os dados de imagem da mesma classe estão no mesmo subdiretório e o nome da subpasta é o nome da classe. Atualmente, há suporte para imagens codificadas em JPEG e imagens codificadas em PNG.
data = DataLoader.from_folder(image_path)
INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips. INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips.
Divida-o em dados de treinamento (80%), dados de validação (10%, opcional) e dados de teste (10%).
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)
Mostre 25 exemplos de imagens com rótulos.
plt.figure(figsize=(10,10))
for i, (image, label) in enumerate(data.gen_dataset().unbatch().take(25)):
plt.subplot(5,5,i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(image.numpy(), cmap=plt.cm.gray)
plt.xlabel(data.index_to_label[label.numpy()])
plt.show()
Etapa 2: personalizar o modelo do TensorFlow
Crie um modelo de classificador de imagem personalizado com base nos dados carregados. O modelo padrão é EfficientNet-Lite0.
model = image_classifier.create(train_data, validation_data=validation_data)
INFO:tensorflow:Retraining the models... INFO:tensorflow:Retraining the models... Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= hub_keras_layer_v1v2_1 (HubK (None, 1280) 3413024 _________________________________________________________________ dropout_1 (Dropout) (None, 1280) 0 _________________________________________________________________ dense_1 (Dense) (None, 5) 6405 ================================================================= Total params: 3,419,429 Trainable params: 6,405 Non-trainable params: 3,413,024 _________________________________________________________________ None Epoch 1/5 /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead. "The `lr` argument is deprecated, use `learning_rate` instead.") 91/91 [==============================] - 6s 54ms/step - loss: 0.8689 - accuracy: 0.7655 - val_loss: 0.6941 - val_accuracy: 0.8835 Epoch 2/5 91/91 [==============================] - 5s 50ms/step - loss: 0.6596 - accuracy: 0.8949 - val_loss: 0.6668 - val_accuracy: 0.8807 Epoch 3/5 91/91 [==============================] - 5s 50ms/step - loss: 0.6188 - accuracy: 0.9159 - val_loss: 0.6537 - val_accuracy: 0.8807 Epoch 4/5 91/91 [==============================] - 5s 52ms/step - loss: 0.6050 - accuracy: 0.9210 - val_loss: 0.6432 - val_accuracy: 0.8892 Epoch 5/5 91/91 [==============================] - 5s 52ms/step - loss: 0.5898 - accuracy: 0.9348 - val_loss: 0.6348 - val_accuracy: 0.8864
Dê uma olhada na estrutura detalhada do modelo.
model.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= hub_keras_layer_v1v2_1 (HubK (None, 1280) 3413024 _________________________________________________________________ dropout_1 (Dropout) (None, 1280) 0 _________________________________________________________________ dense_1 (Dense) (None, 5) 6405 ================================================================= Total params: 3,419,429 Trainable params: 6,405 Non-trainable params: 3,413,024 _________________________________________________________________
Etapa 3: avaliar o modelo personalizado
Avalie o resultado do modelo, obtenha a perda e a precisão do modelo.
loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 1s 27ms/step - loss: 0.6324 - accuracy: 0.8965
Poderíamos plotar os resultados previstos em 100 imagens de teste. Os rótulos previstos com cor vermelha são os resultados previstos errados, enquanto outros estão corretos.
# A helper function that returns 'red'/'black' depending on if its two input
# parameter matches or not.
def get_label_color(val1, val2):
if val1 == val2:
return 'black'
else:
return 'red'
# Then plot 100 test images and their predicted labels.
# If a prediction result is different from the label provided label in "test"
# dataset, we will highlight it in red color.
plt.figure(figsize=(20, 20))
predicts = model.predict_top_k(test_data)
for i, (image, label) in enumerate(test_data.gen_dataset().unbatch().take(100)):
ax = plt.subplot(10, 10, i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(image.numpy(), cmap=plt.cm.gray)
predict_label = predicts[i][0][0]
color = get_label_color(predict_label,
test_data.index_to_label[label.numpy()])
ax.xaxis.label.set_color(color)
plt.xlabel('Predicted: %s' % predict_label)
plt.show()
Se a precisão não cumprir a exigência de aplicação, pode-se referir a Uso Avançado para explorar alternativas, tais como a mudança para um modelo maior, ajustando os parâmetros re-treinamento etc.
Etapa 4: exportar para o modelo do TensorFlow Lite
Converter o modelo treinado para o formato modelo TensorFlow Lite com metadados de modo que você pode usar mais tarde em um aplicativo ML no dispositivo. O arquivo de rótulo e o arquivo de vocabulário são incorporados nos metadados. O nome do arquivo TFLite padrão é model.tflite
.
Em muitos aplicativos de ML no dispositivo, o tamanho do modelo é um fator importante. Portanto, é recomendável que você aplique quantize o modelo para torná-lo menor e potencialmente mais rápido. A técnica de quantização pós-treinamento padrão é a quantização de inteiros completos para a tarefa de classificação de imagens.
model.export(export_dir='.')
INFO:tensorflow:Assets written to: /tmp/tmp6tt5g8de/assets INFO:tensorflow:Assets written to: /tmp/tmp6tt5g8de/assets 2021-11-02 11:35:40.254046: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format. 2021-11-02 11:35:40.254099: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency. fully_quantize: 0, inference_type: 6, input_inference_type: 3, output_inference_type: 3 WARNING:absl:For model inputs containing unsupported operations which cannot be quantized, the `inference_input_type` attribute will default to the original type. INFO:tensorflow:Label file is inside the TFLite model with metadata. INFO:tensorflow:Label file is inside the TFLite model with metadata. INFO:tensorflow:Saving labels in /tmp/tmpf601xty1/labels.txt INFO:tensorflow:Saving labels in /tmp/tmpf601xty1/labels.txt INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite
Ver aplicações e guias de classificação de imagens de exemplo para obter mais detalhes sobre como integrar o modelo TensorFlow Lite em aplicativos móveis.
Este modelo pode ser integrado em um Android ou um app iOS usando o API ImageClassifier da Biblioteca de Tarefas Lite TensorFlow .
Os formatos de exportação permitidos podem ser um ou uma lista dos seguintes:
Por padrão, ele apenas exporta o modelo TensorFlow Lite com metadados. Você também pode exportar arquivos diferentes seletivamente. Por exemplo, exportando apenas o arquivo de etiqueta da seguinte forma:
model.export(export_dir='.', export_format=ExportFormat.LABEL)
INFO:tensorflow:Saving labels in ./labels.txt INFO:tensorflow:Saving labels in ./labels.txt
Você também pode avaliar o modelo tflite com o evaluate_tflite
método.
model.evaluate_tflite('model.tflite', test_data)
{'accuracy': 0.9019073569482289}
Uso avançado
A create
função é a parte crítica desta biblioteca. Ele usa o aprendizado de transferência com um modelo pré-treinado semelhante ao tutorial .
A create
função contém os seguintes passos:
- Dividir os dados em formação, validação, testes de dados de acordo com o parâmetro
validation_ratio
etest_ratio
. O valor padrão devalidation_ratio
etest_ratio
são0.1
e0.1
. - Baixar um vetor de imagens características como o modelo base TensorFlow Hub. O modelo pré-treinado padrão é EfficientNet-Lite0.
- Adicionar uma cabeça classificador com uma camada Dropout com
dropout_rate
entre a camada de cabeça e modelo pré-treinados. O padrãodropout_rate
é o padrãodropout_rate
valor de make_image_classifier_lib por TensorFlow Hub. - Pré-processe os dados brutos de entrada. Atualmente, as etapas de pré-processamento incluem a normalização do valor de cada pixel da imagem para modelar a escala de entrada e redimensioná-la para modelar o tamanho da entrada. EfficientNet-Lite0 tem a escala de entrada
[0, 1]
e o tamanho da imagem de entrada[224, 224, 3]
. - Alimente os dados no modelo do classificador. Por padrão, os parâmetros de formação, tais como épocas de treinamento, tamanho do lote, taxa de aprendizagem, momento são os valores padrão de make_image_classifier_lib por TensorFlow Hub. Apenas o chefe do classificador é treinado.
Nesta seção, descrevemos vários tópicos avançados, incluindo alternar para um modelo de classificação de imagem diferente, alterar os hiperparâmetros de treinamento etc.
Personalize a quantização pós-treinamento no modelo TensorFLow Lite
Quantização pós-treino é uma técnica de conversão que pode reduzir o tamanho do modelo e latência inferência, ao mesmo tempo melhorar a velocidade da CPU e acelerador de hardware inferência, com um pouco de degradação na precisão do modelo. Assim, é amplamente utilizado para otimizar o modelo.
A biblioteca Model Maker aplica uma técnica de quantização pós-treinamento padrão ao exportar o modelo. Se você quiser personalizar quantização pós-treino, Model Maker suporta múltiplas opções de pós-formação de quantização usando QuantizationConfig também. Vamos usar a quantização float16 como exemplo. Primeiro, defina a configuração de quantização.
config = QuantizationConfig.for_float16()
Em seguida, exportamos o modelo do TensorFlow Lite com essa configuração.
model.export(export_dir='.', tflite_filename='model_fp16.tflite', quantization_config=config)
INFO:tensorflow:Assets written to: /tmp/tmpa528qeqj/assets INFO:tensorflow:Assets written to: /tmp/tmpa528qeqj/assets INFO:tensorflow:Label file is inside the TFLite model with metadata. 2021-11-02 11:43:43.724165: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:351] Ignored output_format. 2021-11-02 11:43:43.724219: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored drop_control_dependency. INFO:tensorflow:Label file is inside the TFLite model with metadata. INFO:tensorflow:Saving labels in /tmp/tmpvlx_qa4j/labels.txt INFO:tensorflow:Saving labels in /tmp/tmpvlx_qa4j/labels.txt INFO:tensorflow:TensorFlow Lite model exported successfully: ./model_fp16.tflite INFO:tensorflow:TensorFlow Lite model exported successfully: ./model_fp16.tflite
Em Colab, você pode baixar o modelo chamado model_fp16.tflite
da barra lateral esquerda, mesmo que a parte upload mencionado acima.
Mude o modelo
Mude para o modelo com suporte nesta biblioteca.
Esta biblioteca suporta modelos EfficientNet-Lite, MobileNetV2, ResNet50 até agora. EfficientNet-Lite são uma família de modelos de classificação de imagem que pode atingir precisão o estado-da-arte e adequado para dispositivos de borda. O modelo padrão é EfficientNet-Lite0.
Poderíamos mudar modelo para MobileNetV2 por apenas ajustando o parâmetro model_spec
à especificação do modelo MobileNetV2 em create
método.
model = image_classifier.create(train_data, model_spec=model_spec.get('mobilenet_v2'), validation_data=validation_data)
INFO:tensorflow:Retraining the models... INFO:tensorflow:Retraining the models... Model: "sequential_2" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= hub_keras_layer_v1v2_2 (HubK (None, 1280) 2257984 _________________________________________________________________ dropout_2 (Dropout) (None, 1280) 0 _________________________________________________________________ dense_2 (Dense) (None, 5) 6405 ================================================================= Total params: 2,264,389 Trainable params: 6,405 Non-trainable params: 2,257,984 _________________________________________________________________ None Epoch 1/5 /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead. "The `lr` argument is deprecated, use `learning_rate` instead.") 91/91 [==============================] - 8s 53ms/step - loss: 0.9163 - accuracy: 0.7634 - val_loss: 0.7789 - val_accuracy: 0.8267 Epoch 2/5 91/91 [==============================] - 4s 50ms/step - loss: 0.6836 - accuracy: 0.8822 - val_loss: 0.7223 - val_accuracy: 0.8551 Epoch 3/5 91/91 [==============================] - 4s 50ms/step - loss: 0.6506 - accuracy: 0.9045 - val_loss: 0.7086 - val_accuracy: 0.8580 Epoch 4/5 91/91 [==============================] - 5s 50ms/step - loss: 0.6218 - accuracy: 0.9227 - val_loss: 0.7049 - val_accuracy: 0.8636 Epoch 5/5 91/91 [==============================] - 5s 52ms/step - loss: 0.6092 - accuracy: 0.9279 - val_loss: 0.7181 - val_accuracy: 0.8580
Avalie o modelo MobileNetV2 recém-treinado para ver a precisão e a perda nos dados de teste.
loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 2s 35ms/step - loss: 0.6866 - accuracy: 0.8747
Mude para o modelo no TensorFlow Hub
Além disso, também podemos mudar para outros novos modelos que inserem uma imagem e geram um vetor de recursos com o formato TensorFlow Hub.
Como Iniciação V3 modelo como um exemplo, podemos definir inception_v3_spec
que é um objecto do image_classifier.ModelSpec e contém a especificação do modelo de Iniciação V3.
Precisamos especificar o nome do modelo name
, o URL do modelo TensorFlow Hub uri
. Enquanto isso, o valor padrão de input_image_shape
é [224, 224]
. Precisamos mudá-lo para [299, 299]
para o modelo Inception V3.
inception_v3_spec = image_classifier.ModelSpec(
uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')
inception_v3_spec.input_image_shape = [299, 299]
Então, por parametrização model_spec
para inception_v3_spec
em create
método, poderíamos voltar a treinar o modelo Inception V3.
As etapas restantes são exatamente as mesmas e podemos obter um modelo personalizado do InceptionV3 TensorFlow Lite no final.
Altere seu próprio modelo personalizado
Se nós gostaríamos de usar o modelo personalizado que não está em TensorFlow Hub, devemos criar e exportar ModelSpec em TensorFlow Hub.
Em seguida, começa a definir ModelSpec
objecto parecido com o processo acima.
Alterar os hiperparâmetros de treinamento
Nós também poderia mudar os hiperparâmetros de treinamento como epochs
, dropout_rate
e batch_size
que poderiam afetar a precisão do modelo. Os parâmetros do modelo que você pode ajustar são:
-
epochs
: mais épocas poderia conseguir uma melhor precisão até que converge mas o treinamento para muitas épocas pode levar a overfitting. -
dropout_rate
: A taxa de abandono, overfitting evitar. Nenhum por padrão. -
batch_size
: número de amostras a utilizar em uma etapa de treinamento. Nenhum por padrão. -
validation_data
Validação de dados:. Se Nenhum, ignora o processo de validação. Nenhum por padrão. -
train_whole_model
: Se for verdade, o módulo de Hub é treinado em conjunto com a camada de classificação no topo. Caso contrário, treine apenas a camada de classificação superior. Nenhum por padrão. -
learning_rate
: taxa de aprendizagem Base. Nenhum por padrão. -
momentum
: uma bóia Python encaminhado para o otimizador. Apenas usada quandouse_hub_library
é True. Nenhum por padrão. -
shuffle
: booleano, se os dados devem ser embaralhadas. Falso por padrão. -
use_augmentation
: Boolean, o aumento do uso de dados para pré-processamento. Falso por padrão. -
use_hub_library
: booleana, usomake_image_classifier_lib
de cubo tensorflow para reconverter o modelo. Esse pipeline de treinamento pode obter melhor desempenho para conjuntos de dados complicados com muitas categorias. Verdadeiro por padrão. -
warmup_steps
: Número de passos de aquecimento para programar configurações de aquecimento na taxa de aprendizagem. Se Nenhum, o warmup_steps padrão é usado, que é o total de etapas de treinamento em duas épocas. Apenas usada quandouse_hub_library
é False. Nenhum por padrão. -
model_dir
: Opcional, a localização dos arquivos de modelo de ponto de verificação. Apenas usada quandouse_hub_library
é False. Nenhum por padrão.
Parâmetros que são None por padrão como epochs
irá obter os parâmetros padrão de concreto em make_image_classifier_lib da biblioteca TensorFlow Hub ou train_image_classifier_lib .
Por exemplo, poderíamos treinar com mais épocas.
model = image_classifier.create(train_data, validation_data=validation_data, epochs=10)
INFO:tensorflow:Retraining the models... INFO:tensorflow:Retraining the models... Model: "sequential_3" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= hub_keras_layer_v1v2_3 (HubK (None, 1280) 3413024 _________________________________________________________________ dropout_3 (Dropout) (None, 1280) 0 _________________________________________________________________ dense_3 (Dense) (None, 5) 6405 ================================================================= Total params: 3,419,429 Trainable params: 6,405 Non-trainable params: 3,413,024 _________________________________________________________________ None Epoch 1/10 91/91 [==============================] - 6s 53ms/step - loss: 0.8735 - accuracy: 0.7644 - val_loss: 0.6701 - val_accuracy: 0.8892 Epoch 2/10 91/91 [==============================] - 4s 49ms/step - loss: 0.6502 - accuracy: 0.8984 - val_loss: 0.6442 - val_accuracy: 0.8864 Epoch 3/10 91/91 [==============================] - 4s 49ms/step - loss: 0.6215 - accuracy: 0.9107 - val_loss: 0.6306 - val_accuracy: 0.8920 Epoch 4/10 91/91 [==============================] - 4s 49ms/step - loss: 0.5962 - accuracy: 0.9299 - val_loss: 0.6253 - val_accuracy: 0.8977 Epoch 5/10 91/91 [==============================] - 5s 52ms/step - loss: 0.5845 - accuracy: 0.9334 - val_loss: 0.6206 - val_accuracy: 0.9062 Epoch 6/10 91/91 [==============================] - 5s 50ms/step - loss: 0.5743 - accuracy: 0.9451 - val_loss: 0.6159 - val_accuracy: 0.9062 Epoch 7/10 91/91 [==============================] - 4s 48ms/step - loss: 0.5682 - accuracy: 0.9444 - val_loss: 0.6192 - val_accuracy: 0.9006 Epoch 8/10 91/91 [==============================] - 4s 49ms/step - loss: 0.5595 - accuracy: 0.9557 - val_loss: 0.6153 - val_accuracy: 0.9091 Epoch 9/10 91/91 [==============================] - 4s 47ms/step - loss: 0.5560 - accuracy: 0.9523 - val_loss: 0.6213 - val_accuracy: 0.9062 Epoch 10/10 91/91 [==============================] - 4s 45ms/step - loss: 0.5520 - accuracy: 0.9595 - val_loss: 0.6220 - val_accuracy: 0.8977
Avalie o modelo recém-treinado com 10 épocas de treinamento.
loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 1s 27ms/step - loss: 0.6417 - accuracy: 0.8883
Consulte Mais informação
Você pode ler a nossa imagem de classificação de exemplo para aprender detalhes técnicos. Para mais informações, consulte:
- TensorFlow Lite Modelo Criador guia e referência da API .
- Biblioteca de Tarefas: ImageClassifier para implantação.
- A referência end-to-end Apps: Android , iOS , e Raspberry PI .