Xem trên TensorFlow.org | Chạy trong Google Colab | Xem nguồn trên GitHub | Tải xuống sổ ghi chép |
Trong TensorFlow 1, bạn sử dụng tf.estimator.LoggingTensorHook
để theo dõi và ghi lại các tensor, trong khi tf.estimator.StopAtStepHook
giúp dừng đào tạo ở một bước cụ thể khi đào tạo với tf.estimator.Estimator
. Sổ tay này trình bày cách di chuyển từ các API này sang các API tương đương của chúng trong TensorFlow 2 bằng cách sử dụng các lệnh gọi lại Keras tùy chỉnh ( tf.keras.callbacks.Callback
) với Model.fit
.
Các lệnh gọi lại Keras là các đối tượng được gọi tại các điểm khác nhau trong quá trình đào tạo / đánh giá / dự đoán trong các API Model.fit
/ Model.evaluate
/ Model.predict
dự đoán tích hợp sẵn. Bạn có thể tìm hiểu thêm về lệnh gọi lại trong tài liệu API tf.keras.callbacks.Callback
, cũng như hướng dẫn Viết lệnh gọi lại và Đào tạo và đánh giá của riêng bạn bằng các phương pháp tích hợp (phần Sử dụng lệnh gọi lại ). Để di chuyển từ lệnh gọi lại SessionRunHook
trong TensorFlow 1 sang Keras trong TensorFlow 2, hãy xem khóa đào tạo Di chuyển với hướng dẫn logic được hỗ trợ .
Thành lập
Bắt đầu với nhập khẩu và một tập dữ liệu đơn giản cho mục đích trình diễn:
import tensorflow as tf
import tensorflow.compat.v1 as tf1
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
# Define an input function.
def _input_fn():
return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)
TensorFlow 1: Trình căng nhật ký và dừng đào tạo với API tf.estimator
Trong TensorFlow 1, bạn xác định các móc khác nhau để kiểm soát hành vi đào tạo. Sau đó, bạn chuyển các hook này tới tf.estimator.EstimatorSpec
.
Trong ví dụ dưới đây:
- Để theo dõi / ghi nhật ký độ căng — ví dụ, trọng lượng hoặc tổn thất của mô hình — bạn sử dụng
tf.estimator.LoggingTensorHook
(tf.train.LoggingTensorHook
là bí danh của nó). - Để dừng đào tạo ở một bước cụ thể, bạn sử dụng
tf.estimator.StopAtStepHook
(tf.train.StopAtStepHook
là bí danh của nó).
def _model_fn(features, labels, mode):
dense = tf1.layers.Dense(1)
logits = dense(features)
loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
optimizer = tf1.train.AdagradOptimizer(0.05)
train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
# Define the stop hook.
stop_hook = tf1.train.StopAtStepHook(num_steps=2)
# Access tensors to be logged by names.
kernel_name = tf.identity(dense.weights[0])
bias_name = tf.identity(dense.weights[1])
logging_weight_hook = tf1.train.LoggingTensorHook(
tensors=[kernel_name, bias_name],
every_n_iter=1)
# Log the training loss by the tensor object.
logging_loss_hook = tf1.train.LoggingTensorHook(
{'loss from LoggingTensorHook': loss},
every_n_secs=3)
# Pass all hooks to `EstimatorSpec`.
return tf1.estimator.EstimatorSpec(mode,
loss=loss,
train_op=train_op,
training_hooks=[stop_hook,
logging_weight_hook,
logging_loss_hook])
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
estimator.train(_input_fn)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp3q__3yt7 INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp3q__3yt7', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp3q__3yt7/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.025395721, step = 0 INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-1.0769143] [ 1.0241832]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [0.] INFO:tensorflow:loss from LoggingTensorHook = 0.025395721 INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-1.1124082] [ 0.9824805]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [-0.03549388] (0.026 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2... INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmp3q__3yt7/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2... INFO:tensorflow:Loss for final step: 0.09248222. <tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f05ec414d10>
TensorFlow 2: Bộ căng dây ghi nhật ký và ngừng đào tạo với các lệnh gọi lại tùy chỉnh và Model.fit
Trong TensorFlow 2, khi bạn sử dụng Keras Model.fit
(hoặc Model.evaluate
) tích hợp để đào tạo / đánh giá, bạn có thể định cấu hình theo dõi tensor và dừng đào tạo bằng cách xác định Keras tf.keras.callbacks.Callback
s tùy chỉnh. Sau đó, bạn chuyển chúng đến tham số callbacks
của Model.fit
(hoặc Model.evaluate
). (Tìm hiểu thêm trong hướng dẫn Viết lệnh gọi lại của riêng bạn .)
Trong ví dụ dưới đây:
- Để tạo lại các chức năng của
StopAtStepHook
, hãy xác định một lệnh gọi lại tùy chỉnh (có tên làStopAtStepCallback
bên dưới) nơi bạn ghi đè phương thứcon_batch_end
để dừng đào tạo sau một số bước nhất định. - Để tạo lại hành vi
LoggingTensorHook
, hãy xác định một lệnh gọi lại tùy chỉnh (LoggingTensorCallback
) nơi bạn ghi lại và xuất các tensor đã ghi nhật ký theo cách thủ công, vì không hỗ trợ truy cập vào tensors theo tên. Bạn cũng có thể triển khai tần suất ghi nhật ký bên trong lệnh gọi lại tùy chỉnh. Ví dụ dưới đây sẽ in các trọng số sau mỗi hai bước. Các chiến lược khác như ghi nhật ký mỗi N giây cũng có thể thực hiện được.
class StopAtStepCallback(tf.keras.callbacks.Callback):
def __init__(self, stop_step=None):
super().__init__()
self._stop_step = stop_step
def on_batch_end(self, batch, logs=None):
if self.model.optimizer.iterations >= self._stop_step:
self.model.stop_training = True
print('\nstop training now')
class LoggingTensorCallback(tf.keras.callbacks.Callback):
def __init__(self, every_n_iter):
super().__init__()
self._every_n_iter = every_n_iter
self._log_count = every_n_iter
def on_batch_end(self, batch, logs=None):
if self._log_count > 0:
self._log_count -= 1
print("Logging Tensor Callback: dense/kernel:",
model.layers[0].weights[0])
print("Logging Tensor Callback: dense/bias:",
model.layers[0].weights[1])
print("Logging Tensor Callback loss:", logs["loss"])
else:
self._log_count -= self._every_n_iter
Khi hoàn tất, hãy chuyển các lệnh gọi lại StopAtStepCallback
và LoggingTensorCallback
— vào tham số callbacks
của Model.fit
:
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer, "mse")
# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
model.fit(dataset, callbacks=[StopAtStepCallback(stop_step=2),
LoggingTensorCallback(every_n_iter=2)])
1/3 [=========>....................] - ETA: 0s - loss: 3.2473Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy= array([[-0.27049014], [-0.73790836]], dtype=float32)> Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.04980864], dtype=float32)> Logging Tensor Callback loss: 3.2473244667053223 stop training now Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy= array([[-0.22285421], [-0.6911988 ]], dtype=float32)> Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.09196297], dtype=float32)> Logging Tensor Callback loss: 5.644947052001953 3/3 [==============================] - 0s 4ms/step - loss: 5.6449 <keras.callbacks.History at 0x7f053022be90>
Bước tiếp theo
Tìm hiểu thêm về gọi lại trong:
- Tài liệu API:
tf.keras.callbacks.Callback
- Hướng dẫn: Viết lệnh gọi lại của riêng bạn
- Hướng dẫn: Đào tạo và đánh giá bằng các phương pháp tích hợp (phần Sử dụng lệnh gọi lại )
Bạn cũng có thể thấy hữu ích các tài nguyên liên quan đến di chuyển sau:
- Hướng dẫn di chuyển dừng sớm :
tf.keras.callbacks.EarlyStopping
là một lệnh gọi lại dừng sớm được tích hợp sẵn - Hướng dẫn di chuyển TensorBoard: TensorBoard cho phép theo dõi và hiển thị các chỉ số
- Huấn luyện với hướng dẫn di chuyển logic được hỗ trợ : Từ
SessionRunHook
trong TensorFlow 1 đến các lệnh gọi lại Keras trong TensorFlow 2