Các điểm kiểm tra đào tạo

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Cụm từ "Lưu mô hình TensorFlow" thường có nghĩa là một trong hai điều:

  1. Trạm kiểm soát, HOẶC
  2. SavedModel.

Các điểm kiểm tra nắm bắt giá trị chính xác của tất cả các tham số (đối tượng tf.Variable ) được sử dụng bởi một mô hình. Các điểm kiểm tra không chứa bất kỳ mô tả nào về tính toán được xác định bởi mô hình và do đó thường chỉ hữu ích khi có sẵn mã nguồn sử dụng các giá trị tham số đã lưu.

Mặt khác, định dạng SavedModel bao gồm mô tả tuần tự của phép tính được xác định bởi mô hình cùng với các giá trị tham số (điểm kiểm tra). Các mô hình ở định dạng này độc lập với mã nguồn đã tạo ra mô hình. Do đó, chúng phù hợp để triển khai thông qua TensorFlow Serving, TensorFlow Lite, TensorFlow.js hoặc các chương trình bằng các ngôn ngữ lập trình khác (C, C ++, Java, Go, Rust, C #, v.v. TensorFlow API).

Hướng dẫn này bao gồm các API để ghi và đọc các điểm kiểm tra.

Thành lập

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Tiết kiệm từ các API đào tạo tf.keras

Xem hướng dẫn của tf.keras về cách lưu và khôi phục.

tf.keras.Model.save_weights lưu một điểm kiểm tra TensorFlow.

net.save_weights('easy_checkpoint')

Viết các trạm kiểm soát

Trạng thái liên tục của một mô hình TensorFlow được lưu trữ trong các đối tượng tf.Variable . Chúng có thể được tạo trực tiếp, nhưng thường được tạo thông qua các API cấp cao như tf.keras.layers hoặc tf.keras.Model .

Cách dễ nhất để quản lý các biến là gắn chúng vào các đối tượng Python, sau đó tham chiếu đến các đối tượng đó.

Các lớp con của tf.train.Checkpoint , tf.keras.layers.Layertf.keras.Model tự động theo dõi các biến được gán cho thuộc tính của chúng. Ví dụ sau đây xây dựng một mô hình tuyến tính đơn giản, sau đó viết các điểm kiểm tra chứa các giá trị cho tất cả các biến của mô hình.

Bạn có thể dễ dàng lưu điểm kiểm tra mô hình với Model.save_weights .

Kiểm tra thủ công

Thành lập

Để giúp chứng minh tất cả các tính năng của tf.train.Checkpoint , hãy xác định tập dữ liệu đồ chơi và bước tối ưu hóa:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Tạo các đối tượng điểm kiểm tra

Sử dụng đối tượng tf.train.Checkpoint để tạo một điểm kiểm tra theo cách thủ công, trong đó các đối tượng bạn muốn điểm kiểm tra được đặt làm thuộc tính trên đối tượng.

Một tf.train.CheckpointManager cũng có thể hữu ích để quản lý nhiều trạm kiểm soát.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Huấn luyện và kiểm tra mô hình

Vòng lặp đào tạo sau đây tạo một phiên bản của mô hình và của một trình tối ưu hóa, sau đó tập hợp chúng thành một đối tượng tf.train.Checkpoint . Nó gọi bước huấn luyện trong một vòng lặp trên mỗi lô dữ liệu và định kỳ ghi các điểm kiểm tra vào đĩa.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 31.27
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 24.68
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 18.12
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 11.65
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 5.39

Khôi phục và tiếp tục đào tạo

Sau chu kỳ đào tạo đầu tiên, bạn có thể vượt qua một mô hình và người quản lý mới, nhưng hãy tiếp tục đào tạo chính xác nơi bạn đã dừng lại:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.50
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.27
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.56
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.70
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.35

Đối tượng tf.train.CheckpointManager xóa các điểm kiểm tra cũ. Ở trên, nó được định cấu hình để chỉ giữ ba điểm kiểm tra gần đây nhất.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Các đường dẫn này, ví dụ: './tf_ckpts/ckpt-10' , không phải là tệp trên đĩa. Thay vào đó, chúng là tiền tố cho một tệp index và một hoặc nhiều tệp dữ liệu chứa các giá trị biến. Các tiền tố này được nhóm lại với nhau trong một tệp checkpoint ( './tf_ckpts/checkpoint' ) nơi CheckpointManager lưu trạng thái của nó.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Cơ khí tải

TensorFlow đối sánh các biến với các giá trị đã kiểm tra bằng cách duyệt qua một biểu đồ có hướng với các cạnh được đặt tên, bắt đầu từ đối tượng đang được tải. Tên cạnh thường đến từ tên thuộc tính trong các đối tượng, ví dụ như "l1" trong self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint sử dụng tên đối số từ khóa của nó, như trong "step" trong tf.train.Checkpoint(step=...) .

Biểu đồ phụ thuộc từ ví dụ trên trông như sau:

Trực quan hóa biểu đồ phụ thuộc cho vòng lặp đào tạo ví dụ

Trình tối ưu hóa có màu đỏ, các biến thông thường có màu xanh lam và các biến vị trí của trình tối ưu hóa có màu cam. Các nút khác — ví dụ, đại diện cho tf.train.Checkpoint — có màu đen.

Các biến vị trí là một phần của trạng thái của trình tối ưu hóa, nhưng được tạo cho một biến cụ thể. Ví dụ: các cạnh 'm' ở trên tương ứng với động lượng mà trình tối ưu hóa Adam theo dõi cho mỗi biến. Các biến vị trí chỉ được lưu trong một trạm kiểm soát nếu cả biến và trình tối ưu hóa đều được lưu, do đó các cạnh gạch ngang.

Gọi restore trên đối tượng tf.train.Checkpoint xếp hàng các khôi phục được yêu cầu, khôi phục các giá trị biến ngay khi có một đường dẫn phù hợp từ đối tượng Checkpoint . Ví dụ: bạn có thể tải chỉ độ lệch từ mô hình bạn đã xác định ở trên bằng cách tạo lại một đường dẫn đến nó thông qua mạng và lớp.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.7209885 3.7588918 4.421351  4.1466427 4.0712557]

Biểu đồ phụ thuộc cho các đối tượng mới này là một đồ thị con nhỏ hơn nhiều của trạm kiểm soát lớn hơn mà bạn đã viết ở trên. Nó chỉ bao gồm bias và một bộ đếm lưu mà tf.train.Checkpoint sử dụng để đánh số các điểm kiểm tra.

Hình dung một đồ thị con cho biến thiên vị

restore trả về một đối tượng trạng thái, đối tượng này có các xác nhận tùy chọn. Tất cả các đối tượng được tạo trong Checkpoint mới đã được khôi phục, do đó, status.assert_existing_objects_matched vượt qua.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>

Có nhiều đối tượng trong trạm kiểm soát chưa khớp, bao gồm nhân của lớp và các biến của trình tối ưu hóa. status.assert_consumed chỉ vượt qua nếu điểm kiểm tra và chương trình khớp chính xác và sẽ đưa ra một ngoại lệ ở đây.

Phục hồi hoãn lại

Các đối tượng Layer trong TensorFlow có thể trì hoãn việc tạo các biến cho lần gọi đầu tiên của chúng, khi các hình dạng đầu vào có sẵn. Ví dụ, hình dạng của nhân của lớp Dense phụ thuộc vào cả hình dạng đầu vào và đầu ra của lớp, và do đó hình dạng đầu ra được yêu cầu như một đối số của hàm tạo không đủ thông tin để tự tạo biến. Vì việc gọi một Layer cũng đọc giá trị của biến, nên việc khôi phục phải xảy ra giữa lần tạo biến và lần sử dụng đầu tiên.

Để hỗ trợ thành ngữ này, tf.train.Checkpoint khôi phục chưa có biến phù hợp.

deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.5854754 4.607731  4.649179  4.8474874 5.121    ]]

Kiểm tra các trạm kiểm soát theo cách thủ công

tf.train.load_checkpoint trả về CheckpointReader cấp quyền truy cập cấp thấp hơn vào nội dung điểm kiểm tra. Nó chứa các ánh xạ từ khóa của mỗi biến, đến hình dạng và kiểu cho mỗi biến trong trạm kiểm soát. Chìa khóa của một biến là đường dẫn đối tượng của nó, giống như trong đồ thị được hiển thị ở trên.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Vì vậy, nếu bạn quan tâm đến giá trị của net.l1.kernel , bạn có thể lấy giá trị bằng đoạn mã sau:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Nó cũng cung cấp một phương thức get_tensor cho phép bạn kiểm tra giá trị của một biến:

reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121    ]],
      dtype=float32)

Theo dõi đối tượng

Các trạm kiểm soát lưu và khôi phục các giá trị của các đối tượng tf.Variable bằng cách "theo dõi" bất kỳ biến hoặc đối tượng có thể theo dõi nào được đặt trong một trong các thuộc tính của nó. Khi thực hiện lưu, các biến được thu thập một cách đệ quy từ tất cả các đối tượng được theo dõi có thể truy cập được.

Như với các phép gán thuộc tính trực tiếp như self.l1 = tf.keras.layers.Dense(5) , việc gán danh sách và từ điển cho các thuộc tính sẽ theo dõi nội dung của chúng.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Bạn có thể nhận thấy các đối tượng trình bao bọc cho danh sách và từ điển. Các trình bao bọc này là các phiên bản có thể kiểm tra của cấu trúc dữ liệu cơ bản. Cũng giống như cách tải dựa trên thuộc tính, các trình bao bọc này khôi phục giá trị của một biến ngay sau khi nó được thêm vào vùng chứa.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Các đối tượng có thể theo dõi bao gồm tf.train.Checkpoint , tf.Module và các lớp con của nó (ví dụ: keras.layers.Layerkeras.Model ) và các vùng chứa Python được công nhận:

  • dict (và collections.OrderedDict tập.OrderedDict)
  • list
  • tuple (và collections.namedtuple , typing.NamedTuple )

Các loại vùng chứa khác không được hỗ trợ , bao gồm:

  • collections.defaultdict
  • set

Tất cả các đối tượng Python khác đều bị bỏ qua , bao gồm:

  • int
  • string
  • float

Bản tóm tắt

Các đối tượng TensorFlow cung cấp một cơ chế tự động dễ dàng để lưu và khôi phục các giá trị của các biến mà chúng sử dụng.