Lihat di TensorFlow.org | Jalankan di Google Colab | Lihat sumber di GitHub | Unduh buku catatan |
Di TensorFlow 1, Anda menggunakan tf.estimator.LoggingTensorHook
untuk memantau dan mencatat tensor, sedangkan tf.estimator.StopAtStepHook
membantu menghentikan pelatihan pada langkah tertentu saat berlatih dengan tf.estimator.Estimator
. Notebook ini menunjukkan cara bermigrasi dari API ini ke yang setara di TensorFlow 2 menggunakan panggilan balik Keras khusus ( tf.keras.callbacks.Callback
) dengan Model.fit
.
Callback Keras adalah objek yang dipanggil pada titik yang berbeda selama pelatihan/evaluasi/prediksi dalam Keras Model.fit
/ Model.evaluate
/ Model.predict
API bawaan. Anda dapat mempelajari lebih lanjut tentang panggilan balik di dokumen tf.keras.callbacks.Callback
API, serta panduan Menulis panggilan balik Anda sendiri dan Pelatihan dan evaluasi dengan metode bawaan (bagian Menggunakan panggilan balik ). Untuk bermigrasi dari SessionRunHook
di TensorFlow 1 ke callback Keras di TensorFlow 2, lihat pelatihan Migrasi dengan panduan logika terbantu .
Mempersiapkan
Mulailah dengan impor dan kumpulan data sederhana untuk tujuan demonstrasi:
import tensorflow as tf
import tensorflow.compat.v1 as tf1
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
# Define an input function.
def _input_fn():
return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)
TensorFlow 1: Catat tensor dan hentikan pelatihan dengan tf.estimator API
Di TensorFlow 1, Anda menentukan berbagai kait untuk mengontrol perilaku pelatihan. Kemudian, Anda meneruskan kait ini ke tf.estimator.EstimatorSpec
.
Dalam contoh di bawah ini:
- Untuk memantau/mencatat tensor—misalnya, bobot model atau kerugian—Anda menggunakan
tf.estimator.LoggingTensorHook
(tf.train.LoggingTensorHook
adalah aliasnya). - Untuk menghentikan pelatihan pada langkah tertentu, Anda menggunakan
tf.estimator.StopAtStepHook
(tf.train.StopAtStepHook
adalah aliasnya).
def _model_fn(features, labels, mode):
dense = tf1.layers.Dense(1)
logits = dense(features)
loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
optimizer = tf1.train.AdagradOptimizer(0.05)
train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
# Define the stop hook.
stop_hook = tf1.train.StopAtStepHook(num_steps=2)
# Access tensors to be logged by names.
kernel_name = tf.identity(dense.weights[0])
bias_name = tf.identity(dense.weights[1])
logging_weight_hook = tf1.train.LoggingTensorHook(
tensors=[kernel_name, bias_name],
every_n_iter=1)
# Log the training loss by the tensor object.
logging_loss_hook = tf1.train.LoggingTensorHook(
{'loss from LoggingTensorHook': loss},
every_n_secs=3)
# Pass all hooks to `EstimatorSpec`.
return tf1.estimator.EstimatorSpec(mode,
loss=loss,
train_op=train_op,
training_hooks=[stop_hook,
logging_weight_hook,
logging_loss_hook])
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
estimator.train(_input_fn)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp3q__3yt7 INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp3q__3yt7', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp3q__3yt7/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.025395721, step = 0 INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-1.0769143] [ 1.0241832]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [0.] INFO:tensorflow:loss from LoggingTensorHook = 0.025395721 INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-1.1124082] [ 0.9824805]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [-0.03549388] (0.026 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2... INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmp3q__3yt7/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2... INFO:tensorflow:Loss for final step: 0.09248222. <tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f05ec414d10>
TensorFlow 2: Catat tensor dan hentikan pelatihan dengan callback khusus dan Model.fit
Di TensorFlow 2, saat Anda menggunakan Keras Model.fit
(atau Model.evaluate
) bawaan untuk pelatihan/evaluasi, Anda dapat mengonfigurasi pemantauan tensor dan penghentian pelatihan dengan mendefinisikan Keras tf.keras.callbacks.Callback
s kustom. Kemudian, Anda meneruskannya ke parameter callbacks
Model.fit
(atau Model.evaluate
). (Pelajari lebih lanjut di panduan Menulis panggilan balik Anda sendiri .)
Dalam contoh di bawah ini:
- Untuk membuat ulang fungsionalitas
StopAtStepHook
, tentukan panggilan balik khusus (bernamaStopAtStepCallback
di bawah) tempat Anda mengganti metodeon_batch_end
untuk menghentikan pelatihan setelah sejumlah langkah tertentu. - Untuk membuat ulang perilaku
LoggingTensorHook
, tentukan panggilan balik khusus (LoggingTensorCallback
) tempat Anda merekam dan mengeluarkan tensor yang dicatat secara manual, karena mengakses tensor dengan nama tidak didukung. Anda juga dapat menerapkan frekuensi logging di dalam panggilan balik khusus. Contoh di bawah ini akan mencetak bobot setiap dua langkah. Strategi lain seperti logging setiap N detik juga dimungkinkan.
class StopAtStepCallback(tf.keras.callbacks.Callback):
def __init__(self, stop_step=None):
super().__init__()
self._stop_step = stop_step
def on_batch_end(self, batch, logs=None):
if self.model.optimizer.iterations >= self._stop_step:
self.model.stop_training = True
print('\nstop training now')
class LoggingTensorCallback(tf.keras.callbacks.Callback):
def __init__(self, every_n_iter):
super().__init__()
self._every_n_iter = every_n_iter
self._log_count = every_n_iter
def on_batch_end(self, batch, logs=None):
if self._log_count > 0:
self._log_count -= 1
print("Logging Tensor Callback: dense/kernel:",
model.layers[0].weights[0])
print("Logging Tensor Callback: dense/bias:",
model.layers[0].weights[1])
print("Logging Tensor Callback loss:", logs["loss"])
else:
self._log_count -= self._every_n_iter
Setelah selesai, teruskan callback StopAtStepCallback
dan LoggingTensorCallback
—ke parameter callbacks
Model.fit
:
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer, "mse")
# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
model.fit(dataset, callbacks=[StopAtStepCallback(stop_step=2),
LoggingTensorCallback(every_n_iter=2)])
1/3 [=========>....................] - ETA: 0s - loss: 3.2473Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy= array([[-0.27049014], [-0.73790836]], dtype=float32)> Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.04980864], dtype=float32)> Logging Tensor Callback loss: 3.2473244667053223 stop training now Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy= array([[-0.22285421], [-0.6911988 ]], dtype=float32)> Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.09196297], dtype=float32)> Logging Tensor Callback loss: 5.644947052001953 3/3 [==============================] - 0s 4ms/step - loss: 5.6449 <keras.callbacks.History at 0x7f053022be90>
Langkah selanjutnya
Pelajari lebih lanjut tentang panggilan balik di:
- Dokumen API:
tf.keras.callbacks.Callback
- Panduan: Menulis panggilan balik Anda sendiri
- Panduan: Pelatihan dan evaluasi dengan metode bawaan (bagian Menggunakan panggilan balik )
Anda mungkin juga menemukan sumber daya terkait migrasi berikut ini berguna:
- Panduan migrasi penghentian awal :
tf.keras.callbacks.EarlyStopping
adalah panggilan balik penghentian awal bawaan - Panduan migrasi TensorBoard : TensorBoard memungkinkan pelacakan dan menampilkan metrik
- Pelatihan dengan panduan migrasi logika terbantu : Dari
SessionRunHook
di TensorFlow 1 hingga callback Keras di TensorFlow 2