Lihat di TensorFlow.org | Jalankan di Google Colab | Lihat sumber di GitHub | Unduh buku catatan |
Ungkapan "Menyimpan model TensorFlow" biasanya berarti salah satu dari dua hal:
- Pos pemeriksaan, OR
- Model Tersimpan.
Pos pemeriksaan menangkap nilai yang tepat dari semua parameter ( objek tf.Variable
) yang digunakan oleh model. Pos pemeriksaan tidak berisi deskripsi komputasi yang ditentukan oleh model dan dengan demikian biasanya hanya berguna ketika kode sumber yang akan menggunakan nilai parameter yang disimpan tersedia.
Format SavedModel di sisi lain mencakup deskripsi serial dari perhitungan yang ditentukan oleh model di samping nilai parameter (pos pemeriksaan). Model dalam format ini tidak bergantung pada kode sumber yang membuat model. Oleh karena itu, mereka cocok untuk diterapkan melalui TensorFlow Serving, TensorFlow Lite, TensorFlow.js, atau program dalam bahasa pemrograman lain (API C, C++, Java, Go, Rust, C# dll. TensorFlow).
Panduan ini mencakup API untuk menulis dan membaca pos pemeriksaan.
Mempersiapkan
import tensorflow as tf
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
net = Net()
Menyimpan dari API pelatihan tf.keras
Lihat panduan tf.keras
tentang menyimpan dan memulihkan .
tf.keras.Model.save_weights
menyimpan pos pemeriksaan TensorFlow.
net.save_weights('easy_checkpoint')
Menulis pos pemeriksaan
Status persisten model TensorFlow disimpan di objek tf.Variable
. Ini dapat dibuat secara langsung, tetapi sering dibuat melalui API tingkat tinggi seperti tf.keras.layers
atau tf.keras.Model
.
Cara termudah untuk mengelola variabel adalah dengan melampirkannya ke objek Python, lalu mereferensikan objek tersebut.
Subkelas tf.train.Checkpoint
, tf.keras.layers.Layer
, dan tf.keras.Model
secara otomatis melacak variabel yang ditetapkan ke atributnya. Contoh berikut membangun model linier sederhana, kemudian menulis pos pemeriksaan yang berisi nilai untuk semua variabel model.
Anda dapat dengan mudah menyimpan model-checkpoint dengan Model.save_weights
.
Pos pemeriksaan manual
Mempersiapkan
Untuk membantu mendemonstrasikan semua fitur tf.train.Checkpoint
, tentukan kumpulan data mainan dan langkah pengoptimalan:
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
"""Trains `net` on `example` using `optimizer`."""
with tf.GradientTape() as tape:
output = net(example['x'])
loss = tf.reduce_mean(tf.abs(output - example['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return loss
Buat objek pos pemeriksaan
Gunakan objek tf.train.Checkpoint
untuk membuat pos pemeriksaan secara manual, di mana objek yang ingin Anda periksa ditetapkan sebagai atribut pada objek.
Sebuah tf.train.CheckpointManager
juga dapat membantu untuk mengelola beberapa pos pemeriksaan.
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
Latih dan pos pemeriksaan model
Loop pelatihan berikut membuat instance model dan pengoptimal, lalu mengumpulkannya menjadi objek tf.train.Checkpoint
. Ini memanggil langkah pelatihan dalam satu lingkaran pada setiap kumpulan data, dan secara berkala menulis pos pemeriksaan ke disk.
def train_and_checkpoint(net, manager):
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
print("Restored from {}".format(manager.latest_checkpoint))
else:
print("Initializing from scratch.")
for _ in range(50):
example = next(iterator)
loss = train_step(net, example, opt)
ckpt.step.assign_add(1)
if int(ckpt.step) % 10 == 0:
save_path = manager.save()
print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch. Saved checkpoint for step 10: ./tf_ckpts/ckpt-1 loss 31.27 Saved checkpoint for step 20: ./tf_ckpts/ckpt-2 loss 24.68 Saved checkpoint for step 30: ./tf_ckpts/ckpt-3 loss 18.12 Saved checkpoint for step 40: ./tf_ckpts/ckpt-4 loss 11.65 Saved checkpoint for step 50: ./tf_ckpts/ckpt-5 loss 5.39
Pulihkan dan lanjutkan pelatihan
Setelah siklus pelatihan pertama, Anda dapat melewati model dan manajer baru, tetapi melanjutkan pelatihan tepat di tempat terakhir Anda tinggalkan:
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5 Saved checkpoint for step 60: ./tf_ckpts/ckpt-6 loss 1.50 Saved checkpoint for step 70: ./tf_ckpts/ckpt-7 loss 1.27 Saved checkpoint for step 80: ./tf_ckpts/ckpt-8 loss 0.56 Saved checkpoint for step 90: ./tf_ckpts/ckpt-9 loss 0.70 Saved checkpoint for step 100: ./tf_ckpts/ckpt-10 loss 0.35
Objek tf.train.CheckpointManager
menghapus pos pemeriksaan lama. Di atasnya dikonfigurasi untuk hanya menyimpan tiga pos pemeriksaan terbaru.
print(manager.checkpoints) # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']
Jalur ini, misalnya './tf_ckpts/ckpt-10'
, bukan file di disk. Sebaliknya mereka adalah awalan untuk file index
dan satu atau lebih file data yang berisi nilai variabel. Awalan ini dikelompokkan bersama dalam satu file checkpoint
( './tf_ckpts/checkpoint'
) tempat CheckpointManager
menyimpan statusnya.
ls ./tf_ckpts
checkpoint ckpt-8.data-00000-of-00001 ckpt-9.index ckpt-10.data-00000-of-00001 ckpt-8.index ckpt-10.index ckpt-9.data-00000-of-00001
Mekanika pemuatan
TensorFlow mencocokkan variabel dengan nilai checkpoint dengan melintasi grafik berarah dengan tepi bernama, mulai dari objek yang dimuat. Nama tepi biasanya berasal dari nama atribut di objek, misalnya "l1"
di self.l1 = tf.keras.layers.Dense(5)
. tf.train.Checkpoint
menggunakan nama argumen kata kuncinya, seperti pada "step"
di tf.train.Checkpoint(step=...)
.
Grafik ketergantungan dari contoh di atas terlihat seperti ini:
Pengoptimal berwarna merah, variabel reguler berwarna biru, dan variabel slot pengoptimal berwarna oranye. Node lain—misalnya, mewakili tf.train.Checkpoint
—berwarna hitam.
Variabel slot adalah bagian dari status pengoptimal, tetapi dibuat untuk variabel tertentu. Misalnya, tepi 'm'
di atas sesuai dengan momentum, yang dilacak oleh pengoptimal Adam untuk setiap variabel. Variabel slot hanya disimpan di pos pemeriksaan jika variabel dan pengoptimal keduanya akan disimpan, sehingga tepi putus-putus.
Memanggil restore
pada objek tf.train.Checkpoint
mengantre pemulihan yang diminta, memulihkan nilai variabel segera setelah ada jalur yang cocok dari objek Checkpoint
. Misalnya, Anda dapat memuat hanya bias dari model yang Anda definisikan di atas dengan merekonstruksi satu jalur ke sana melalui jaringan dan lapisan.
to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy()) # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy()) # This gets the restored value.
[0. 0. 0. 0. 0.] [2.7209885 3.7588918 4.421351 4.1466427 4.0712557]
Grafik ketergantungan untuk objek baru ini adalah subgraf yang jauh lebih kecil dari pos pemeriksaan yang lebih besar yang Anda tulis di atas. Ini hanya mencakup bias dan penghitung simpan yang digunakan tf.train.Checkpoint
untuk memberi nomor pos pemeriksaan.
restore
mengembalikan objek status, yang memiliki pernyataan opsional. Semua objek yang dibuat di Checkpoint
baru telah dipulihkan, jadi status.assert_existing_objects_matched
lolos.
status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>
Ada banyak objek di pos pemeriksaan yang belum cocok, termasuk kernel layer dan variabel pengoptimal. status.assert_consumed
hanya lolos jika pos pemeriksaan dan program sama persis, dan akan mengeluarkan pengecualian di sini.
Restorasi tertunda
Objek Layer
di TensorFlow dapat menunda pembuatan variabel ke panggilan pertama, saat bentuk input tersedia. Misalnya, bentuk kernel layer Dense
bergantung pada bentuk input dan output layer, sehingga bentuk output yang diperlukan sebagai argumen konstruktor bukanlah informasi yang cukup untuk membuat variabel sendiri. Karena memanggil Layer
juga membaca nilai variabel, pemulihan harus terjadi antara pembuatan variabel dan penggunaan pertama.
Untuk mendukung idiom ini, tf.train.Checkpoint
yang belum memiliki variabel yang cocok.
deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy()) # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy()) # Restored
[[0. 0. 0. 0. 0.]] [[4.5854754 4.607731 4.649179 4.8474874 5.121 ]]
Memeriksa pos pemeriksaan secara manual
tf.train.load_checkpoint
mengembalikan CheckpointReader
yang memberikan akses tingkat rendah ke konten pos pemeriksaan. Ini berisi pemetaan dari setiap kunci variabel, ke bentuk dan tipe d untuk setiap variabel di pos pemeriksaan. Kunci variabel adalah jalur objeknya, seperti pada grafik yang ditampilkan di atas.
reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()
sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH', 'iterator/.ATTRIBUTES/ITERATOR_STATE', 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', 'save_counter/.ATTRIBUTES/VARIABLE_VALUE', 'step/.ATTRIBUTES/VARIABLE_VALUE']
Jadi jika Anda tertarik dengan nilai net.l1.kernel
Anda bisa mendapatkan nilainya dengan kode berikut:
key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5] Dtype: float32
Ini juga menyediakan metode get_tensor
yang memungkinkan Anda memeriksa nilai variabel:
reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121 ]], dtype=float32)
Pelacakan objek
Pos pemeriksaan menyimpan dan memulihkan nilai objek tf.Variable
dengan "melacak" setiap variabel atau objek yang dapat dilacak yang ditetapkan dalam salah satu atributnya. Saat menjalankan penyimpanan, variabel dikumpulkan secara rekursif dari semua objek terlacak yang dapat dijangkau.
Seperti penetapan atribut langsung seperti self.l1 = tf.keras.layers.Dense(5)
, menetapkan daftar dan kamus ke atribut akan melacak isinya.
save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')
restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy() # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()
Anda mungkin melihat objek pembungkus untuk daftar dan kamus. Pembungkus ini adalah versi yang dapat diperiksa dari struktur data yang mendasarinya. Sama seperti pemuatan berbasis atribut, pembungkus ini mengembalikan nilai variabel segera setelah ditambahkan ke penampung.
restore.listed = []
print(restore.listed) # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1) # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])
Objek yang dapat dilacak termasuk tf.train.Checkpoint
, tf.Module
dan subkelasnya (misalnya keras.layers.Layer
dan keras.Model
), dan wadah Python yang dikenali:
-
dict
(dancollections.OrderedDict
) -
list
-
tuple
(dancollections.namedtuple
,typing.NamedTuple
)
Jenis penampung lainnya tidak didukung , termasuk:
-
collections.defaultdict
-
set
Semua objek Python lainnya diabaikan , termasuk:
-
int
-
string
-
float
Ringkasan
Objek TensorFlow menyediakan mekanisme otomatis yang mudah untuk menyimpan dan memulihkan nilai variabel yang mereka gunakan.