Overview
It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. To learn more about SavedModel and serialization in general, read the SavedModel guide and the Keras model serialization guide. Let's start with a simple example:
Import dependencies:
import tensorflow_datasets as tfds
import tensorflow as tf
Prepare the data and the model using tf.distribute.Strategy:
mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
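Both helpers above depend on the strategy. get_data() scales the global batch size by num_replicas_in_sync, and get_model() builds and compiles the model inside mirrored_strategy.scope(), so its variables are created as distributed variables. The following is a minimal sketch of what the scope changes. It uses bare tf.Variable objects instead of the tutorial's model, and the variable names are only illustrative:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Outside any scope, only the default (no-op) strategy is active.
outside = tf.Variable(1.0)
print(tf.distribute.has_strategy())  # False

with strategy.scope():
  # Inside the scope, this strategy is active and controls variable creation.
  inside = tf.Variable(1.0)
  print(tf.distribute.get_strategy() is strategy)  # True

# The strategy can report which variables were created under its scope.
print(strategy.extended.variable_created_in_scope(inside))   # True
print(strategy.extended.variable_created_in_scope(outside))  # False

# Global batch size as computed in get_data(), assuming 64 examples per replica.
global_batch_size = 64 * strategy.num_replicas_in_sync
print(global_batch_size)
```

This is also why get_model() calls model.compile() inside the same scope. The optimizer's slot variables are created under the strategy as well.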
Train the model:
model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1/2
2022-01-26 05:41:11.916000: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
938/938 [==============================] - 11s 5ms/step - loss: 0.1873 - sparse_categorical_accuracy: 0.9451
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0641 - sparse_categorical_accuracy: 0.9807
<keras.callbacks.History at 0x7f3b900396d0>
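The 938/938 in the progress bar follows from the dataset size and the global batch size. Dataset.batch() keeps the final partial batch by default, so the step count is rounded up. The sketch below shows this arithmetic. It assumes the 60,000-example MNIST training split, and steps_per_epoch is a hypothetical helper, not a TensorFlow API:

```python
import math

def steps_per_epoch(num_examples, batch_size_per_replica, num_replicas):
  """Batches per epoch when the last partial batch is kept.

  Mirrors get_data(): the global batch size is the per-replica batch size
  times the number of replicas in sync.
  """
  global_batch_size = batch_size_per_replica * num_replicas
  return math.ceil(num_examples / global_batch_size)

# One replica (a single GPU, as in the log above): 60000 / 64 = 937.5.
print(steps_per_epoch(60000, 64, 1))  # 938
# Two replicas double the global batch: 60000 / 128 = 468.75.
print(steps_per_epoch(60000, 64, 2))  # 469
```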
Save and load the model
Now that you have a simple model to work with, let's take a look at the saving/loading APIs. There are two sets of APIs available:
- High-level (Keras): model.save and tf.keras.models.load_model
- Low-level: tf.saved_model.save and tf.saved_model.load
The Keras API
Here is an example of saving and loading a model with the Keras API:
keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
2022-01-26 05:41:26.593570: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
Restore the model without tf.distribute.Strategy:
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0476 - sparse_categorical_accuracy: 0.9859
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0334 - sparse_categorical_accuracy: 0.9895
<keras.callbacks.History at 0x7f3b187b7150>
After restoring the model, you can continue training it without calling compile() again, because it was already compiled before it was saved. The model is saved in TensorFlow's standard SavedModel proto format. For more information, refer to the guide to the saved_model format.
Now load the model and train it using a tf.distribute.Strategy:
another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
2022-01-26 05:41:33.036733: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2022-01-26 05:41:33.083001: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
938/938 [==============================] - 10s 10ms/step - loss: 0.0474 - sparse_categorical_accuracy: 0.9860
Epoch 2/2
938/938 [==============================] - 10s 10ms/step - loss: 0.0327 - sparse_categorical_accuracy: 0.9903
As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy that was used before saving.
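The low-level API behaves the same way. Below is a minimal sketch that uses a plain tf.Module (the hypothetical Counter class) instead of the tutorial's Keras model. Its variable is created under MirroredStrategy, saved, and then restored under a OneDeviceStrategy on the CPU:

```python
import tempfile
import tensorflow as tf

class Counter(tf.Module):
  """Hypothetical module holding a single variable."""

  def __init__(self):
    super().__init__()
    self.count = tf.Variable(5.0)

save_strategy = tf.distribute.MirroredStrategy()
with save_strategy.scope():
  # Created inside the scope, so this is a distributed (mirrored) variable.
  module = Counter()

path = tempfile.mkdtemp()
tf.saved_model.save(module, path)

# Restore under a different strategy than the one used for saving.
load_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with load_strategy.scope():
  restored = tf.saved_model.load(path)

print(restored.count.numpy())  # 5.0
```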
The tf.saved_model API
Now let's take a look at the lower-level API. Saving the model is similar to the Keras API:
model = get_model() # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
Loading can be done with tf.saved_model.load(). However, because it is a lower-level API with a wider range of use cases, it does not return a Keras model. Instead, it returns an object containing functions that can be used for inference. For example:
DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To run inference with this function:
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy= array([[-1.18789300e-01, -1.78404614e-01, 4.92432676e-02, -9.37875658e-02, 1.14302970e-01, -8.99422392e-02, 9.47709680e-02, -7.75382966e-02, 4.04430032e-02, 2.41404288e-02], [-2.35370561e-01, -3.39397341e-02, 2.73427293e-02, -1.08200148e-01, 5.10682352e-02, 1.36142194e-01, 9.28785652e-02, -5.35808355e-02, 2.56292164e-01, 1.05301209e-01], [-1.91031799e-01, -7.72745535e-02, -7.23153427e-02, -1.99329913e-01, -7.45072216e-02, 2.42738128e-02, 2.07733169e-01, -3.15396488e-03, 4.95976806e-02, 2.14848563e-01], [-9.82482210e-02, -6.13910556e-02, 1.00815810e-01, -1.87558904e-01, 1.14685424e-01, 1.53835595e-01, 1.85714245e-01, -8.74890238e-02, 1.07493028e-01, 1.57510787e-02], [-8.56257528e-02, 3.23683321e-02, -3.66768315e-02, -1.47201523e-01, -5.31517603e-02, 1.52744055e-02, 1.69184029e-01, -5.42814359e-02, 1.11524366e-01, 5.65215349e-02], [-1.50604844e-01, -7.87255913e-03, 1.26651973e-01, -1.24476865e-01, 6.94983900e-02, 4.27672639e-03, 1.86136231e-01, -4.54714149e-03, 9.12746191e-02, 6.12779632e-02], [-2.79157639e-01, -4.61089313e-02, 2.51544192e-02, -1.79003477e-01, 3.83432880e-02, 2.05054253e-01, -8.25636461e-03, -8.25546682e-03, 2.41342247e-01, 8.24805871e-02], [-1.42795354e-01, 6.54597580e-02, 2.05058958e-02, -1.28471941e-01, 1.10977650e-01, 4.51317504e-02, 2.44124904e-01, 1.90523565e-02, 3.11958641e-02, 6.49511665e-02], [-1.33037239e-01, -2.72594951e-02, 8.09026062e-02, -1.95883229e-01, 1.84634060e-01, 1.00822970e-01, 4.40884084e-02, -6.43826872e-02, 1.47807434e-01, -1.92791894e-02], [-1.43770471e-01, -2.53150351e-02, 4.18904647e-02, -1.02573663e-01, 6.15917407e-02, 7.95702711e-02, 9.27314460e-02, -4.31537181e-02, 4.59018350e-02, 1.02965936e-01], [-1.90395206e-01, 2.93233991e-03, 1.48900077e-02, -1.15877971e-01, 1.06598288e-02, 1.40121073e-01, 6.86443001e-02, -4.61921766e-02, 1.27470195e-01, 6.73005953e-02], [-2.60747373e-01, -1.45188004e-01, 7.10044056e-04, -1.04602516e-01, 5.00324890e-02, 2.96664417e-01, 
8.57191086e-02, 6.65097907e-02, 1.31302923e-01, -1.84605196e-02], [-1.62942797e-01, -3.63466889e-02, -1.33987352e-01, -1.34576231e-01, -8.19503814e-02, 1.30840242e-02, 6.16783127e-02, -3.64837795e-02, 3.18005830e-02, 1.98420882e-01], [-1.25772715e-01, -6.94367215e-02, -1.35144517e-02, -6.30265176e-02, 8.36028308e-02, 2.96559408e-02, 2.19864860e-01, -7.08417147e-02, 4.76131588e-02, 1.15781695e-01], [-1.55139655e-01, -1.27863720e-01, 9.67459157e-02, -1.48635745e-01, 1.25129193e-01, 4.04443927e-02, 2.94884086e-01, -7.66484886e-02, 1.18753463e-01, 2.93397382e-02], [-1.59221828e-01, -9.30457860e-02, 9.18259323e-02, -1.72857821e-01, 8.09611157e-02, 1.11391053e-01, 1.66679412e-01, 3.52456123e-02, 9.05358568e-02, 9.89414975e-02], [-2.01425552e-01, -4.67008501e-02, -1.62331611e-02, -9.73629057e-02, 1.36456266e-01, 1.30628154e-01, 1.53577864e-01, -6.73157908e-03, 9.31103677e-02, 1.50734074e-02], [-1.29348308e-01, -3.03804129e-03, 2.82487050e-02, -2.02886015e-01, 7.09105879e-02, 1.74542382e-01, 2.57992335e-02, -1.63579211e-02, 2.30892301e-02, 6.69767857e-02], [-1.56857669e-01, 5.46110943e-02, -5.93251809e-02, -1.04585059e-01, 2.61763521e-02, 1.43062070e-01, 1.57771498e-01, -6.19823262e-02, 3.59585434e-02, 6.62322640e-02], [-8.64257440e-02, -1.33483298e-03, 7.46414512e-02, -1.82848468e-01, 1.21074423e-01, 1.55276239e-01, 1.46483868e-01, -6.22515939e-03, 1.91641584e-01, -9.95825827e-02], [-2.52117336e-01, -6.92471862e-02, 1.09911412e-01, -3.73112522e-02, 3.76211852e-03, 5.23591004e-02, 9.16506499e-02, 6.80204183e-02, -4.27842364e-02, 7.91264027e-02], [-2.11018056e-01, 5.97522780e-03, 8.47486481e-02, -7.27925971e-02, 9.36664082e-03, 1.62506998e-01, 5.32426499e-02, 1.78599171e-02, -2.30420940e-02, 4.07365486e-02], [-1.35342121e-01, -4.06659022e-02, -2.09493563e-02, -1.64699793e-01, 8.35808069e-02, 7.68100768e-02, -7.14773983e-02, -3.43702435e-02, 9.47649628e-02, 9.36352089e-02], [-1.20486066e-01, 3.77080180e-02, 1.14158325e-01, -6.50681928e-02, 1.03382617e-02, 1.17891498e-01, 
1.13154747e-01, -1.49052702e-02, 1.28893867e-01, 1.12219512e-01], [-2.23867983e-01, -9.79400948e-02, 7.37103820e-02, -1.05197895e-02, 3.75595838e-02, 1.80490598e-01, 6.83145374e-02, -3.09509300e-02, 1.42565176e-01, 8.05927664e-02], [-2.32092351e-01, -3.42734642e-02, -5.15977889e-02, -1.75458089e-01, 1.46448284e-01, 1.80426955e-01, 1.52164772e-01, -2.57370695e-02, 1.26812875e-01, 1.22049123e-01], [-9.45013613e-02, 5.85526973e-02, 1.47456676e-02, -4.40606587e-02, 4.86647561e-02, 6.28624633e-02, 3.69989276e-02, -3.68277319e-02, 3.56127135e-02, 3.10502797e-02], [-1.02712311e-01, 3.16979140e-02, 1.88253060e-01, -5.99608906e-02, 3.73450294e-02, 6.38176724e-02, 1.12240583e-01, 2.42183693e-02, 1.45670772e-02, -9.52028483e-03], [-1.62333213e-02, -1.42737105e-02, -5.79352975e-02, -1.01807326e-01, -7.93362781e-03, -7.22003728e-02, 1.49934232e-01, -1.19943202e-01, 9.22369361e-02, 1.46321565e-01], [-1.32534593e-01, 1.18380897e-02, 2.23980099e-03, -9.28303748e-02, -2.20538303e-02, 7.68908709e-02, 5.29715866e-02, -3.43324393e-02, -1.27909705e-02, -7.04141408e-02], [-8.10261145e-02, -8.95578321e-03, 3.96864787e-02, -1.21861629e-01, 7.98310041e-02, 1.56087667e-01, 9.11872089e-02, -2.29295418e-02, 5.64432219e-02, -3.55931222e-02], [-1.76416740e-01, 1.12043694e-02, -1.80068091e-02, -1.88012689e-01, 8.68914276e-02, 1.57958359e-01, 5.77907935e-02, -2.12088451e-02, 5.33877537e-02, 2.19271183e-02], [-2.70012528e-01, -1.26611829e-01, 3.10387388e-02, -7.24840909e-02, 1.03253610e-01, 8.91268626e-02, 1.38662308e-01, -6.25240132e-02, 2.36210316e-01, 1.40534222e-01], [-8.52961093e-02, -1.15273651e-02, -2.88792588e-02, -2.01282576e-02, 5.43357767e-02, 7.14191943e-02, 3.46604213e-02, -6.00920171e-02, 5.11362031e-02, 3.58160883e-02], [-1.63262367e-01, 2.44849995e-02, 3.81964818e-02, -3.93010303e-02, 3.95263731e-03, 9.11088511e-02, 3.88236046e-02, 1.33745335e-02, 1.00076631e-01, 6.05135933e-02], [-3.01809371e-01, -1.58440098e-01, 4.65333983e-02, -1.63946241e-01, -6.42775744e-02, 3.93286347e-04, 
2.82839835e-01, -8.93663988e-02, 1.97781295e-01, 2.87044942e-01], [-2.15368003e-01, -4.83291782e-02, -8.29075277e-03, -1.01776704e-01, 1.43144801e-02, 1.82002857e-02, 2.76539754e-02, -1.94141679e-02, 8.87098238e-02, 6.60644472e-02], [-2.20715180e-01, -7.20694065e-02, -6.08972833e-02, -4.82957587e-02, 1.28858402e-01, 1.30042464e-01, 1.32807568e-01, -7.52742141e-02, 9.51702446e-02, 3.10119465e-02], [-1.09407350e-01, -5.27948700e-03, 1.29588693e-03, -2.61662379e-02, 3.01920641e-02, 1.13487415e-01, 8.23267922e-02, 1.92574020e-02, 2.31986474e-02, 4.13139611e-02], [-2.12277412e-01, -1.35507256e-01, 4.22930568e-02, -1.34565741e-01, 1.17879853e-01, 1.30573064e-01, 1.81054786e-01, -1.70722306e-01, 1.05854876e-01, 7.36362934e-02], [-1.78249478e-01, -7.55607188e-02, 7.75147527e-02, -2.14659080e-01, 3.26948166e-02, 7.76198730e-02, 1.08791113e-01, -2.38809325e-02, 1.79410487e-01, 1.94452941e-01], [-1.92162693e-01, -1.50472090e-01, -8.24331492e-02, -1.40473023e-02, 3.60646360e-02, -9.39090401e-02, 1.83859855e-01, -1.09493822e-01, -3.09051797e-02, 1.36017531e-01], [-9.21519399e-02, -1.53335631e-02, -5.56742400e-02, -9.68495384e-02, 2.35293470e-02, 2.53665410e-02, 1.79999322e-01, -7.10204691e-02, -7.29817525e-02, 4.50368747e-02], [-1.22261971e-01, -6.94630146e-02, -7.97796808e-03, -1.03088826e-01, -7.38603100e-02, 1.84892826e-02, 9.76646394e-02, -3.29037756e-02, -1.77134499e-02, 1.62288889e-01], [-6.78652674e-02, -1.08500615e-01, 5.66991530e-02, -9.52370912e-02, 5.28126955e-02, 1.05176866e-02, 1.73085481e-01, -1.37753151e-02, 1.95556954e-02, 1.38068855e-01], [-2.02808753e-01, -3.39423120e-02, 1.82233751e-03, -5.71424365e-02, 3.40205729e-02, 8.74454305e-02, 8.47227685e-03, -2.52498202e-02, 4.66104299e-02, 1.10718749e-01], [-9.52449068e-02, -3.35062481e-02, -1.00178778e-01, -9.72513855e-02, -3.58061343e-02, 3.04423086e-02, 5.70362583e-02, -4.03833576e-02, -4.28436548e-02, 9.73245874e-02], [-2.06081957e-01, -1.71493232e-01, 2.52560824e-02, -1.55212343e-01, -4.33478206e-02, 
2.34177694e-01, 8.46128762e-02, 1.75322518e-02, 2.04347119e-01, 1.54971585e-01], [-1.95310384e-01, 1.30968075e-02, -9.68117267e-03, -7.31432810e-02, 1.02618083e-01, 1.59629256e-01, 1.66028887e-01, -7.12903216e-03, 1.78021699e-01, -2.17130631e-02], [-1.59163624e-01, -1.77137554e-05, 1.75410658e-02, -9.08103511e-02, 7.25786015e-02, 9.21041369e-02, 1.24915361e-01, -6.55939505e-02, -1.13440230e-02, 1.03661232e-01], [-1.93366870e-01, -4.36344892e-02, 1.37750164e-01, -1.91939399e-01, -1.50268525e-03, 8.03942382e-02, 2.15812266e-01, 5.38492575e-02, 1.36685073e-01, 2.22119391e-01], [-1.65946245e-01, 7.89588690e-03, -1.65037125e-01, -1.23690292e-01, -8.57629776e-02, -2.55736727e-02, 1.67541012e-01, -6.63827211e-02, 2.98694819e-02, 1.71927184e-01], [-1.56264767e-01, -1.72245800e-02, -4.98924702e-02, -2.98387632e-02, 2.80477256e-02, 4.94132042e-02, 4.89805043e-02, 1.96998678e-02, -4.14144360e-02, -5.05549274e-02], [-1.46449029e-01, -1.12528354e-01, -4.66653258e-02, -3.78398523e-02, 7.60737807e-03, -2.70657167e-02, 1.11277811e-01, 6.37479573e-02, -2.39458829e-02, 1.22067556e-01], [-1.92323536e-01, -1.43002480e-01, 5.29062748e-03, -1.70663983e-01, 8.39572400e-03, 6.37906119e-02, 1.24084033e-01, 6.02792688e-02, 7.18353763e-02, 5.03963791e-03], [-1.70977920e-01, 1.04207098e-02, 1.18544906e-01, -4.29532528e-02, -3.53983864e-02, 1.80302024e-01, 8.08775946e-02, 3.19045782e-02, 2.52931342e-02, 1.29424319e-01], [-2.13301033e-01, -6.96119964e-02, 2.32847631e-02, -7.73920864e-02, 1.10387571e-01, 1.13307782e-01, 1.41805351e-01, -5.19381016e-02, 1.15313083e-01, 1.40049949e-01], [-1.71651557e-01, -5.98860830e-02, -3.92800570e-03, -1.04376137e-01, 7.78115019e-02, 6.84583709e-02, 2.51923770e-01, -1.05199262e-01, 1.64517179e-01, 2.18875334e-01], [-2.60777414e-01, -8.93031508e-02, 1.27723843e-01, -1.97950065e-01, 1.19145498e-01, 7.30907321e-02, 2.23771721e-01, -6.83849230e-02, 3.68930906e-01, 1.86811388e-01], [-2.38028213e-01, 1.11199915e-03, 2.25015372e-01, 8.22724327e-02, -1.14511400e-01, 
1.57513067e-01, 5.22858277e-02, 2.13724375e-03, 3.15639377e-02, 2.08704025e-01], [-1.46687120e-01, -1.10313833e-01, -1.16352811e-02, -1.44550815e-01, 2.09794566e-02, 1.47883072e-02, 3.96856442e-02, -2.15019658e-03, -4.90810722e-02, 1.34708211e-01], [-2.02591017e-01, -2.29728431e-01, 6.73423260e-02, -1.24901496e-01, -1.38434023e-02, 8.64367038e-02, 1.22342721e-01, 1.67826824e-02, 1.65354639e-01, 1.83434993e-01], [-2.25799978e-01, -1.02682747e-01, 9.48531851e-02, -9.38871950e-02, 1.03806734e-01, 2.04695478e-01, 8.09893832e-02, -1.45416632e-02, 1.33486420e-01, -6.27665371e-02], [-1.19375348e-01, 2.23235339e-02, 1.04302749e-01, -1.11149743e-01, 6.12434298e-02, 6.89433664e-02, 2.08741099e-01, -3.81497070e-02, -1.42122135e-02, 7.65201449e-03]], dtype=float32)>} 2022-01-26 05:41:53.590742: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
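The signatures map used above is a general SavedModel feature, not something specific to Keras. Below is a minimal sketch with a hypothetical Scaler module. It exports a traced function under the default key, which tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY holds as "serving_default". When the function returns a dict, the dict keys become the signature's output names, as with 'dense_3' in the output above:

```python
import tempfile
import tensorflow as tf

class Scaler(tf.Module):
  """Hypothetical module whose __call__ is exported as a signature."""

  def __init__(self):
    super().__init__()
    self.factor = tf.Variable(2.0)

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
  def __call__(self, x):
    # The dict key "scaled" becomes the signature's output name.
    return {"scaled": x * self.factor}

module = Scaler()
path = tempfile.mkdtemp()
# A single concrete function passed as `signatures` is stored under "serving_default".
tf.saved_model.save(module, path, signatures=module.__call__.get_concrete_function())

loaded = tf.saved_model.load(path)
print(list(loaded.signatures.keys()))  # ['serving_default']

infer = loaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
# Signature functions are called with keyword arguments named after the inputs.
result = infer(x=tf.constant([1.0, 2.0]))
print(result["scaled"].numpy())  # [2. 4.]
```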
You can also load the model and run inference in a distributed manner:
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
2022-01-26 05:41:53.931428: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
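The warning above suggests wrapping run inside a tf.function. Below is a minimal sketch of that pattern. A hypothetical Doubler module stands in for the MNIST model, and strategy.gather concatenates the per-replica outputs back along the batch axis. Both choices are illustrative, not what the tutorial itself uses:

```python
import tempfile
import tensorflow as tf

class Doubler(tf.Module):
  """Hypothetical stand-in for the saved model: doubles its input."""

  @tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
  def __call__(self, x):
    return {"out": 2.0 * x}

module = Doubler()
path = tempfile.mkdtemp()
tf.saved_model.save(module, path, signatures=module.__call__.get_concrete_function())

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  loaded = tf.saved_model.load(path)
  infer = loaded.signatures["serving_default"]

# 8 rows of ones, split into two global batches of 4.
dataset = tf.data.Dataset.from_tensor_slices(tf.ones([8, 3])).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def distributed_infer(batch):
  # Each replica runs the signature on its shard of the global batch.
  per_replica = strategy.run(lambda x: infer(x=x)["out"], args=(batch,))
  # Concatenate per-replica results along the batch dimension.
  return strategy.gather(per_replica, axis=0)

outputs = [distributed_infer(batch) for batch in dist_dataset]
total = sum(float(tf.reduce_sum(out)) for out in outputs)
print(total)  # 48.0
```

Tracing the call once as a graph removes the per-step eager overhead the warning refers to.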
Calling the restored function is just a forward pass on the saved model (prediction). What if you want to continue training the loaded function, or embed it in a larger model? A common practice is to wrap the loaded object in a Keras layer. Fortunately, TF Hub provides hub.KerasLayer for this purpose, shown here:
import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])

  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
2022-01-26 05:41:55.594317: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
938/938 [==============================] - 6s 3ms/step - loss: 0.1910 - sparse_categorical_accuracy: 0.9442
Epoch 2/2
938/938 [==============================] - 3s 4ms/step - loss: 0.0633 - sparse_categorical_accuracy: 0.9813
As you can see, hub.KerasLayer wraps the object returned by tf.saved_model.load() in a Keras layer that can be used to build another model. This is very useful for transfer learning.
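hub.KerasLayer is one option. Restored functions also support gradients directly, so a loaded object's variables can be fine-tuned with tf.GradientTape without TF Hub. Below is a minimal sketch of that alternative. It assumes a hypothetical one-parameter Linear module and plain gradient descent, not the tutorial's Keras model and Adam optimizer:

```python
import tempfile
import tensorflow as tf

class Linear(tf.Module):
  """Hypothetical model y = w * x with a single trainable weight."""

  def __init__(self):
    super().__init__()
    self.w = tf.Variable(0.0)

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
  def __call__(self, x):
    return self.w * x

path = tempfile.mkdtemp()
tf.saved_model.save(Linear(), path)
loaded = tf.saved_model.load(path)

# Continue training the restored weight towards y = 3 * x.
x = tf.constant([1.0, 2.0, 3.0])
y = 3.0 * x
learning_rate = 0.05
for _ in range(50):
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean((loaded(x) - y) ** 2)
  grad = tape.gradient(loss, loaded.w)
  loaded.w.assign_sub(learning_rate * grad)  # plain gradient-descent step

print(loaded.w.numpy())  # approximately 3.0
```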
Which API should I use?
For saving, if you are working with a Keras model, it is almost always recommended to use the Keras model.save() API. If what you are saving is not a Keras model, the lower-level API is your only choice.
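For example, a plain tf.Module has no save() method, so tf.saved_model.save is the way to export it. Below is a minimal sketch with a hypothetical Adder module that holds one variable and one traced function. The loaded object exposes both, but it is not a Keras model:

```python
import tempfile
import tensorflow as tf

class Adder(tf.Module):
  """Hypothetical non-Keras object: a variable plus one traced function."""

  def __init__(self):
    super().__init__()
    self.bias = tf.Variable(10.0)

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
  def add(self, x):
    return x + self.bias

path = tempfile.mkdtemp()
tf.saved_model.save(Adder(), path)  # tf.Module has no model.save()

loaded = tf.saved_model.load(path)
print(loaded.add(tf.constant(1.0)).numpy())  # 11.0
print(loaded.bias.numpy())  # 10.0
```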
For loading, which API you use depends on what you want to get back. If you cannot (or do not want to) get a Keras model, use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.
It is possible to mix and match the APIs. You can save a Keras model with model.save, and load it back as a non-Keras object with the low-level API, tf.saved_model.load.
model = get_model()
# Saving the model using Keras's save() API
model.save(keras_model_path)
another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Saving/Loading from a local device
When saving to and loading from a local I/O device while running remotely, for example on a Cloud TPU, you must use the experimental_io_device option to set the I/O device to localhost.
model = get_model()
# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)
# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
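The low-level API accepts the same options: tf.saved_model.save takes a SaveOptions object and tf.saved_model.load takes a LoadOptions object. Below is a minimal sketch with a hypothetical Holder module. It runs on a single local machine, where '/job:localhost' is simply the local host:

```python
import tempfile
import tensorflow as tf

class Holder(tf.Module):
  """Hypothetical module holding one variable."""

  def __init__(self):
    super().__init__()
    self.value = tf.Variable([1.0, 2.0, 3.0])

path = tempfile.mkdtemp()

# Route file reads/writes through the local host, even if computation
# runs on remote devices (for example, a Cloud TPU worker).
save_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
tf.saved_model.save(Holder(), path, options=save_options)

load_options = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
restored = tf.saved_model.load(path, options=load_options)
print(restored.value.numpy())  # [1. 2. 3.]
```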
Caveats
A special case is a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shape (Sequential([Dense(3), ...])). Subclassed models also lack well-defined inputs after initialization. In this case, use the lower-level APIs for both saving and loading; otherwise, you will get an error.
To check whether your model has well-defined inputs, check whether model.inputs is None. If it is not None, you are all good. Input shapes are defined automatically when the model is used in .fit, .evaluate, or .predict, or when the model is called (model(inputs)).
Here is an example:
class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR!
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f3ad00f3510>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7f3ad00f3e90>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets