Lihat di TensorFlow.org | Jalankan di Google Colab | Lihat di GitHub | Unduh buku catatan |
Ringkasan
Panduan ini memberikan daftar praktik terbaik untuk menulis kode menggunakan TensorFlow 2 (TF2), ini ditulis untuk pengguna yang baru saja beralih dari TensorFlow 1 (TF1). Lihat bagian migrasi dari panduan untuk info lebih lanjut tentang migrasi kode TF1 Anda ke TF2.
Mempersiapkan
Impor TensorFlow dan dependensi lainnya untuk contoh dalam panduan ini.
import tensorflow as tf
import tensorflow_datasets as tfds
Rekomendasi untuk TensorFlow idiomatik 2
Refactor kode Anda menjadi modul yang lebih kecil
Praktik yang baik adalah memfaktorkan ulang kode Anda menjadi fungsi yang lebih kecil yang dipanggil sesuai kebutuhan. Untuk kinerja terbaik, Anda harus mencoba mendekorasi blok komputasi terbesar yang Anda bisa dalam tf.function
(perhatikan bahwa fungsi python bersarang yang dipanggil oleh tf.function
tidak memerlukan dekorasi terpisah mereka sendiri, kecuali jika Anda ingin menggunakan jit_compile
yang berbeda pengaturan untuk tf.function
). Bergantung pada kasus penggunaan Anda, ini bisa berupa beberapa langkah pelatihan atau bahkan seluruh putaran pelatihan Anda. Untuk kasus penggunaan inferensi, ini mungkin model forward pass tunggal.
Sesuaikan kecepatan pembelajaran default untuk beberapa tf.keras.optimizer
s
Beberapa pengoptimal Keras memiliki tingkat pembelajaran yang berbeda di TF2. Jika Anda melihat perubahan dalam perilaku konvergensi untuk model Anda, periksa tingkat pembelajaran default.
Tidak ada perubahan untuk optimizers.SGD
, optimizers.Adam
, atau optimizers.RMSprop
.
Tingkat pembelajaran default berikut telah berubah:
-
optimizers.Adagrad
dari0.01
hingga0.001
-
optimizers.Adadelta
dari1.0
hingga0.001
-
optimizers.Adamax
dari0.002
hingga0.001
-
optimizers.Nadam
. Nadam dari0.002
hingga0.001
Gunakan tf.Module
s dan Keras untuk mengelola variabel
tf.Module
s dan tf.keras.layers.Layer
s menawarkan variables
yang mudah digunakan dan properti trainable_variables
, yang mengumpulkan semua variabel dependen secara rekursif. Ini membuatnya mudah untuk mengelola variabel secara lokal ke tempat mereka digunakan.
Lapisan/model Keras mewarisi dari tf.train.Checkpointable
dan terintegrasi dengan @tf.function
, yang memungkinkan untuk langsung memeriksa atau mengekspor SavedModels dari objek Keras. Anda tidak perlu menggunakan API Model.fit
Keras untuk memanfaatkan integrasi ini.
Baca bagian tentang pembelajaran transfer dan penyesuaian dalam panduan Keras untuk mempelajari cara mengumpulkan subset variabel yang relevan menggunakan Keras.
Gabungkan tf.data.Dataset
s dan tf.function
Paket TensorFlow Datasets ( tfds
) berisi utilitas untuk memuat set data yang telah ditentukan sebelumnya sebagai objek tf.data.Dataset
. Untuk contoh ini, Anda dapat memuat dataset MNIST menggunakan tfds
:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Kemudian siapkan data untuk pelatihan:
- Skala ulang setiap gambar.
- Acak urutan contoh.
- Kumpulkan kumpulan gambar dan label.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
Untuk mempersingkat contoh, pangkas kumpulan data agar hanya mengembalikan 5 kumpulan:
train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)
STEPS_PER_EPOCH = 5
train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2021-12-08 17:15:01.637157: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Gunakan iterasi Python reguler untuk mengulangi data pelatihan yang sesuai dengan memori. Jika tidak, tf.data.Dataset
adalah cara terbaik untuk mengalirkan data pelatihan dari disk. Kumpulan data adalah iterables (bukan iterator) , dan bekerja seperti iterables Python lainnya dalam eksekusi yang bersemangat. Anda dapat sepenuhnya memanfaatkan fitur prefetching/streaming async dataset dengan membungkus kode Anda dalam tf.function
, yang menggantikan iterasi Python dengan operasi grafik yang setara menggunakan AutoGraph.
@tf.function
def train(model, dataset, optimizer):
for x, y in dataset:
with tf.GradientTape() as tape:
# training=True is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
prediction = model(x, training=True)
loss = loss_fn(prediction, y)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
Jika Anda menggunakan Keras Model.fit
API, Anda tidak perlu khawatir tentang iterasi kumpulan data.
model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)
Gunakan loop pelatihan Keras
Jika Anda tidak memerlukan kontrol tingkat rendah dari proses pelatihan Anda, disarankan menggunakan metode built-in fit
, evaluate
, dan predict
dari Keras. Metode-metode ini menyediakan antarmuka yang seragam untuk melatih model terlepas dari implementasinya (berurutan, fungsional, atau sub-kelas).
Keuntungan dari metode ini meliputi:
- Mereka menerima array Numpy, generator Python dan,
tf.data.Datasets
. - Mereka menerapkan regularisasi, dan kerugian aktivasi secara otomatis.
- Mereka mendukung
tf.distribute
di mana kode pelatihan tetap sama terlepas dari konfigurasi perangkat kerasnya . - Mereka mendukung callable sewenang-wenang sebagai kerugian dan metrik.
- Mereka mendukung panggilan balik seperti
tf.keras.callbacks.TensorBoard
, dan panggilan balik khusus. - Mereka berkinerja baik, secara otomatis menggunakan grafik TensorFlow.
Berikut adalah contoh pelatihan model menggunakan Dataset
. Untuk detail tentang cara kerjanya, lihat tutorial .
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10)
])
# Model is the full model w/o custom layers
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5 5/5 [==============================] - 9s 7ms/step - loss: 1.5762 - accuracy: 0.4938 Epoch 2/5 2021-12-08 17:15:11.145429: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 6ms/step - loss: 0.5087 - accuracy: 0.8969 Epoch 3/5 2021-12-08 17:15:11.559374: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 2s 5ms/step - loss: 0.3348 - accuracy: 0.9469 Epoch 4/5 2021-12-08 17:15:13.860407: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 5ms/step - loss: 0.2445 - accuracy: 0.9688 Epoch 5/5 2021-12-08 17:15:14.269850: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 6ms/step - loss: 0.2006 - accuracy: 0.9719 2021-12-08 17:15:14.717552: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 1s 4ms/step - loss: 1.4553 - accuracy: 0.5781 Loss 1.4552843570709229, Accuracy 0.578125 2021-12-08 17:15:15.862684: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Sesuaikan pelatihan dan tulis loop Anda sendiri
Jika model Keras bekerja untuk Anda, tetapi Anda membutuhkan lebih banyak fleksibilitas dan kontrol dari langkah pelatihan atau loop pelatihan luar, Anda dapat menerapkan langkah-langkah pelatihan Anda sendiri atau bahkan seluruh loop pelatihan. Lihat panduan Keras tentang menyesuaikan fit
untuk mempelajari lebih lanjut.
Anda juga dapat mengimplementasikan banyak hal sebagai tf.keras.callbacks.Callback
.
Metode ini memiliki banyak keuntungan yang disebutkan sebelumnya , tetapi memberi Anda kendali atas langkah kereta dan bahkan putaran luar.
Ada tiga langkah untuk loop pelatihan standar:
- Ulangi generator Python atau
tf.data.Dataset
untuk mendapatkan kumpulan contoh. - Gunakan
tf.GradientTape
untuk mengumpulkan gradien. - Gunakan salah satu
tf.keras.optimizers
untuk menerapkan pembaruan bobot ke variabel model.
Ingat:
- Selalu sertakan argumen
training
pada metodecall
lapisan dan model subkelas. - Pastikan untuk memanggil model dengan argumen
training
yang disetel dengan benar. - Tergantung pada penggunaan, variabel model mungkin tidak ada sampai model dijalankan pada kumpulan data.
- Anda perlu menangani hal-hal seperti kerugian regularisasi untuk model secara manual.
Tidak perlu menjalankan inisialisasi variabel atau menambahkan dependensi kontrol manual. tf.function
menangani dependensi kontrol otomatis dan inisialisasi variabel pada pembuatan untuk Anda.
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
@tf.function
def train_step(inputs, labels):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
regularization_loss=tf.math.add_n(model.losses)
pred_loss=loss_fn(labels, predictions)
total_loss=pred_loss + regularization_loss
gradients = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
for epoch in range(NUM_EPOCHS):
for inputs, labels in train_data:
train_step(inputs, labels)
print("Finished epoch", epoch)
2021-12-08 17:15:16.714849: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 0 2021-12-08 17:15:17.097043: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 1 2021-12-08 17:15:17.502480: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 2 2021-12-08 17:15:17.873701: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 3 Finished epoch 4 2021-12-08 17:15:18.344196: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Manfaatkan tf.function
dengan aliran kontrol Python
tf.function
menyediakan cara untuk mengubah aliran kontrol yang bergantung pada data menjadi ekuivalen mode grafik seperti tf.cond
dan tf.while_loop
.
Satu tempat umum di mana aliran kontrol yang bergantung pada data muncul adalah dalam model urutan. tf.keras.layers.RNN
membungkus sel RNN, memungkinkan Anda untuk membuka gulungan perulangan secara statis atau dinamis. Sebagai contoh, Anda dapat mengimplementasikan kembali dynamic unroll sebagai berikut.
class DynamicRNN(tf.keras.Model):
def __init__(self, rnn_cell):
super(DynamicRNN, self).__init__(self)
self.cell = rnn_cell
@tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
def call(self, input_data):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2])
timesteps = tf.shape(input_data)[0]
batch_size = tf.shape(input_data)[1]
outputs = tf.TensorArray(tf.float32, timesteps)
state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
for i in tf.range(timesteps):
output, state = self.cell(input_data[i], state)
outputs = outputs.write(i, output)
return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)
my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)
Baca panduan tf.function
untuk informasi lebih lanjut.
Metrik dan kerugian gaya baru
Metrik dan kerugian adalah objek yang bekerja dengan penuh semangat dan di tf.function
s.
Objek loss dapat dipanggil, dan mengharapkan ( y_true
, y_pred
) sebagai argumen:
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815
Gunakan metrik untuk mengumpulkan dan menampilkan data
Anda dapat menggunakan tf.metrics
untuk menggabungkan data dan tf.summary
untuk mencatat ringkasan dan mengarahkannya ke penulis menggunakan pengelola konteks. Ringkasan dipancarkan langsung ke penulis yang berarti Anda harus memberikan nilai step
di situs panggilan.
summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
tf.summary.scalar('loss', 0.1, step=42)
Gunakan tf.metrics
untuk menggabungkan data sebelum mencatatnya sebagai ringkasan. Metrik bersifat stateful; mereka mengumpulkan nilai dan mengembalikan hasil kumulatif saat Anda memanggil metode result
(seperti Mean.result
). Hapus nilai akumulasi dengan Model.reset_states
.
def train(model, optimizer, dataset, log_freq=10):
avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
for images, labels in dataset:
loss = train_step(model, optimizer, images, labels)
avg_loss.update_state(loss)
if tf.equal(optimizer.iterations % log_freq, 0):
tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
avg_loss.reset_states()
def test(model, test_x, test_y, step_num):
# training=False is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
loss = loss_fn(model(test_x, training=False), test_y)
tf.summary.scalar('loss', loss, step=step_num)
train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')
with train_summary_writer.as_default():
train(model, optimizer, dataset)
with test_summary_writer.as_default():
test(model, test_x, test_y, optimizer.iterations)
Visualisasikan ringkasan yang dihasilkan dengan mengarahkan TensorBoard ke direktori log ringkasan:
tensorboard --logdir /tmp/summaries
Gunakan tf.summary
API untuk menulis data ringkasan untuk visualisasi di TensorBoard. Untuk info lebih lanjut, baca panduan tf.summary
.
# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
@tf.function
def train_step(inputs, labels):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
regularization_loss=tf.math.add_n(model.losses)
pred_loss=loss_fn(labels, predictions)
total_loss=pred_loss + regularization_loss
gradients = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
# Update the metrics
loss_metric.update_state(total_loss)
accuracy_metric.update_state(labels, predictions)
for epoch in range(NUM_EPOCHS):
# Reset the metrics
loss_metric.reset_states()
accuracy_metric.reset_states()
for inputs, labels in train_data:
train_step(inputs, labels)
# Get the metric results
mean_loss=loss_metric.result()
mean_accuracy = accuracy_metric.result()
print('Epoch: ', epoch)
print(' loss: {:.3f}'.format(mean_loss))
print(' accuracy: {:.3f}'.format(mean_accuracy))
2021-12-08 17:15:19.339736: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 0 loss: 0.142 accuracy: 0.991 2021-12-08 17:15:19.781743: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 1 loss: 0.125 accuracy: 0.997 2021-12-08 17:15:20.219033: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 2 loss: 0.110 accuracy: 0.997 2021-12-08 17:15:20.598085: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 3 loss: 0.099 accuracy: 0.997 Epoch: 4 loss: 0.085 accuracy: 1.000 2021-12-08 17:15:20.981787: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Nama metrik keras
Model Keras konsisten dalam menangani nama metrik. Saat Anda meneruskan string dalam daftar metrik, string persis tersebut digunakan sebagai name
metrik . Nama-nama ini terlihat di objek history yang dikembalikan oleh model.fit
, dan di log yang diteruskan ke keras.callbacks
. disetel ke string yang Anda berikan dalam daftar metrik.
model.compile(
optimizer = tf.keras.optimizers.Adam(0.001),
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 5ms/step - loss: 0.0963 - acc: 0.9969 - accuracy: 0.9969 - my_accuracy: 0.9969 2021-12-08 17:15:21.942940: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])
Men-debug
Gunakan eksekusi bersemangat untuk menjalankan kode Anda selangkah demi selangkah untuk memeriksa bentuk, tipe data, dan nilai. API tertentu, seperti tf.function
, tf.keras
, dll. dirancang untuk menggunakan eksekusi Graph, untuk kinerja dan portabilitas. Saat men-debug, gunakan tf.config.run_functions_eagerly(True)
untuk menggunakan eksekusi bersemangat di dalam kode ini.
Sebagai contoh:
@tf.function
def f(x):
if x > 0:
import pdb
pdb.set_trace()
x = x + 1
return x
tf.config.run_functions_eagerly(True)
f(tf.constant(1))
>>> f()
-> x = x + 1
(Pdb) l
6 @tf.function
7 def f(x):
8 if x > 0:
9 import pdb
10 pdb.set_trace()
11 -> x = x + 1
12 return x
13
14 tf.config.run_functions_eagerly(True)
15 f(tf.constant(1))
[EOF]
Ini juga berfungsi di dalam model Keras dan API lain yang mendukung eksekusi yang bersemangat:
class CustomModel(tf.keras.models.Model):
@tf.function
def call(self, input_data):
if tf.reduce_mean(input_data) > 0:
return input_data
else:
import pdb
pdb.set_trace()
return input_data // 2
tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
>>> call()
-> return input_data // 2
(Pdb) l
10 if tf.reduce_mean(input_data) > 0:
11 return input_data
12 else:
13 import pdb
14 pdb.set_trace()
15 -> return input_data // 2
16
17
18 tf.config.run_functions_eagerly(True)
19 model = CustomModel()
20 model(tf.constant([-2, -4]))
Catatan:
metode
tf.keras.Model
sepertifit
, evaluation , danevaluate
tf.function
sebagai grafik denganpredict
di bawah tenda.Saat menggunakan
tf.keras.Model.compile
, setelrun_eagerly = True
untuk menonaktifkan logikaModel
agar tidak dibungkus dengantf.function
.Gunakan
tf.data.experimental.enable_debug_mode
untuk mengaktifkan mode debug untuktf.data
. Baca dokumen API untuk detail selengkapnya.
Jangan simpan tf.Tensors
di objek Anda
Objek tensor ini mungkin dibuat baik dalam tf.function
atau dalam konteks bersemangat, dan tensor ini berperilaku berbeda. Selalu gunakan tf.Tensor
s hanya untuk nilai menengah.
Untuk melacak status, gunakan tf.Variable
s karena selalu dapat digunakan dari kedua konteks. Baca panduan tf.Variable
untuk mempelajari lebih lanjut.
Sumber daya dan bacaan lebih lanjut
Baca panduan dan tutorial TF2 untuk mempelajari lebih lanjut tentang cara menggunakan TF2.
Jika sebelumnya Anda menggunakan TF1.x, sangat disarankan agar Anda memigrasikan kode ke TF2. Baca panduan migrasi untuk mempelajari lebih lanjut.