نقاط تفتيش التدريب

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثب تحميل دفتر

عادةً ما تعني عبارة "حفظ نموذج TensorFlow" أحد أمرين:

  1. نقاط التفتيش ، أو
  2. نموذج.

تلتقط نقاط التحقق القيمة الدقيقة لجميع المعلمات (الكائنات tf.Variable ) المستخدمة بواسطة النموذج. لا تحتوي نقاط التحقق على أي وصف للحساب المحدد بواسطة النموذج ، وبالتالي فهي مفيدة فقط عندما يتوفر كود المصدر الذي سيستخدم قيم المعلمات المحفوظة.

من ناحية أخرى ، يتضمن تنسيق SavedModel وصفًا متسلسلًا للحساب المحدد بواسطة النموذج بالإضافة إلى قيم المعلمات (نقطة التحقق). النماذج في هذا التنسيق مستقلة عن التعليمات البرمجية المصدر التي أنشأت النموذج. وبالتالي فهي مناسبة للنشر عبر TensorFlow Serving أو TensorFlow Lite أو TensorFlow.js أو البرامج بلغات البرمجة الأخرى (C ، C ++ ، Java ، Go ، Rust ، C # إلخ. TensorFlow APIs).

يغطي هذا الدليل واجهات برمجة التطبيقات (APIs) لكتابة وقراءة نقاط التفتيش.

يثبت

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

الحفظ من واجهات برمجة تطبيقات التدريب tf.keras

راجع دليل tf.keras حول الحفظ والاستعادة.

يحفظ tf.keras.Model.save_weights نقطة تفتيش TensorFlow.

net.save_weights('easy_checkpoint')

كتابة نقاط التفتيش

يتم تخزين الحالة المستمرة لنموذج tf.Variable في كائنات متغيرة tf. يمكن إنشاء هذه بشكل مباشر ، ولكن غالبًا ما يتم إنشاؤها من خلال واجهات برمجة تطبيقات عالية المستوى مثل tf.keras.layers أو tf.keras.Model .

أسهل طريقة لإدارة المتغيرات هي إرفاقها بكائنات بايثون ، ثم الرجوع إلى تلك الكائنات.

تتتبع الفئات الفرعية لـ tf.train.Checkpoint و tf.keras.layers.Layer و tf.keras.Model المتغيرات المخصصة tf.keras.Model تلقائيًا. يُنشئ المثال التالي نموذجًا خطيًا بسيطًا ، ثم يكتب نقاط التحقق التي تحتوي على قيم لجميع متغيرات النموذج.

يمكنك بسهولة حفظ نقطة فحص النموذج باستخدام Model.save_weights .

التفتيش اليدوي

يثبت

للمساعدة في توضيح جميع ميزات tf.train.Checkpoint ، حدد مجموعة بيانات اللعبة وخطوة التحسين:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

إنشاء كائنات نقطة التفتيش

استخدم كائن tf.train.Checkpoint لإنشاء نقطة تحقق يدويًا ، حيث يتم تعيين الكائنات التي تريد تحديدها كسمات على الكائن.

يمكن أن يكون tf.train.CheckpointManager أيضًا في إدارة نقاط التفتيش المتعددة.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

تدريب وفحص النموذج

تُنشئ حلقة التدريب التالية مثيلاً للنموذج والمحسِّن ، ثم تجمعهما في كائن tf.train.Checkpoint . يستدعي خطوة التدريب في حلقة على كل دفعة من البيانات ، ويكتب بشكل دوري نقاط التفتيش على القرص.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 31.27
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 24.68
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 18.12
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 11.65
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 5.39

استعادة ومواصلة التدريب

بعد الدورة التدريبية الأولى ، يمكنك اجتياز نموذج ومدير جديدين ، ولكن يمكنك متابعة التدريب من حيث توقفت تمامًا:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.50
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.27
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.56
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.70
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.35

يحذف كائن tf.train.CheckpointManager نقاط التحقق القديمة. تم تكوينه أعلاه للاحتفاظ بنقاط التفتيش الثلاثة الأخيرة فقط.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

هذه المسارات ، مثل './tf_ckpts/ckpt-10' ، ليست ملفات على القرص. بدلاً من ذلك ، فهي بادئات لملف index وملف بيانات واحد أو أكثر يحتوي على القيم المتغيرة. يتم تجميع هذه البادئات معًا في ملف checkpoint تحقق واحد ( './tf_ckpts/checkpoint' ) حيث يحفظ CheckpointManager حالته.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

ميكانيكا التحميل

يطابق TensorFlow المتغيرات مع القيم المحددة من خلال اجتياز الرسم البياني الموجه ذي الحواف المسماة ، بدءًا من الكائن الذي يتم تحميله. تأتي أسماء الحواف عادةً من أسماء السمات في الكائنات ، على سبيل المثال "l1" في self.l1 = tf.keras.layers.Dense(5) . يستخدم tf.train.Checkpoint أسماء وسيطات الكلمات الرئيسية الخاصة به ، كما في "step" في tf.train.Checkpoint(step=...) .

يبدو الرسم البياني للتبعية من المثال أعلاه كما يلي:

تصور الرسم البياني للتبعية لمثال حلقة التدريب

المحسن باللون الأحمر ، والمتغيرات العادية باللون الأزرق ، ومتغيرات فتحة المحسن باللون البرتقالي. العقد الأخرى - على سبيل المثال ، التي تمثل tf.train.Checkpoint - هي باللون الأسود.

تعد متغيرات الفتحة جزءًا من حالة المحسن ، ولكن يتم إنشاؤها لمتغير معين. على سبيل المثال ، تتوافق حواف 'm' أعلاه مع الزخم ، الذي يتتبعه مُحسِّن آدم لكل متغير. يتم حفظ متغيرات الفتحة في نقطة فحص فقط إذا تم حفظ المتغير والمحسن ، وبالتالي الحواف المتقطعة.

يؤدي استدعاء restore على tf.train.Checkpoint الكائن في قائمة انتظار عمليات الاستعادة المطلوبة ، واستعادة القيم المتغيرة بمجرد وجود مسار مطابق من كائن Checkpoint . على سبيل المثال ، يمكنك تحميل التحيز فقط من النموذج الذي حددته أعلاه عن طريق إعادة بناء مسار واحد إليه عبر الشبكة والطبقة.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.7209885 3.7588918 4.421351  4.1466427 4.0712557]

الرسم البياني للتبعية لهذه الكائنات الجديدة هو رسم بياني فرعي أصغر بكثير لنقطة التحقق الأكبر التي كتبتها أعلاه. يتضمن فقط التحيز وعداد الحفظ الذي يستخدمه tf.train.Checkpoint نقاط التفتيش.

تصور الرسم البياني الفرعي لمتغير التحيز

restore كائن الحالة ، الذي يحتوي على تأكيدات اختيارية. تمت استعادة جميع الكائنات التي تم إنشاؤها في Checkpoint التحقق الجديدة ، لذا فإن status.assert_existing_objects_matched يمر.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>

هناك العديد من الكائنات في نقطة التحقق غير المتطابقة ، بما في ذلك نواة الطبقة ومتغيرات المحسن. status.assert_consumed يمر فقط إذا كانت نقطة التفتيش والبرنامج متطابقتين تمامًا ، وسوف يطرح استثناء هنا.

الترميمات المؤجلة

قد تؤجل كائنات Layer في TensorFlow إنشاء المتغيرات لاستدعائها الأول ، عندما تكون أشكال الإدخال متاحة. على سبيل المثال ، يعتمد شكل نواة الطبقة Dense على كلٍ من أشكال الإدخال والإخراج للطبقة ، وبالتالي فإن شكل الإخراج المطلوب كوسيطة مُنشئ ليس معلومات كافية لإنشاء المتغير بمفرده. نظرًا لأن استدعاء Layer يقرأ أيضًا قيمة المتغير ، يجب أن تحدث استعادة بين إنشاء المتغير واستخدامه لأول مرة.

لدعم هذا المصطلح ، يستعيد tf.train.Checkpoint التي لا تحتوي بعد على متغير مطابق.

deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.5854754 4.607731  4.649179  4.8474874 5.121    ]]

التفتيش اليدوي على نقاط التفتيش

tf.train.load_checkpoint بإرجاع CheckpointReader الذي يوفر وصولاً منخفض المستوى لمحتويات نقطة التفتيش. يحتوي على تعيينات من مفتاح كل متغير ، إلى الشكل والنوع لكل متغير في نقطة التحقق. مفتاح المتغير هو مسار الكائن الخاص به ، كما هو الحال في الرسوم البيانية المعروضة أعلاه.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

لذلك إذا كنت مهتمًا بقيمة net.l1.kernel ، يمكنك الحصول على القيمة باستخدام الكود التالي:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

يوفر أيضًا طريقة get_tensor تسمح لك بفحص قيمة المتغير:

reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121    ]],
      dtype=float32)

تتبع الكائن

تقوم نقاط التحقق بحفظ واستعادة قيم tf.Variable الكائنات عن طريق "تتبع" أي متغير أو كائن قابل للتتبع تم تعيينه في إحدى سماته. عند تنفيذ حفظ ، يتم جمع المتغيرات بشكل متكرر من جميع الكائنات المتعقبة التي يمكن الوصول إليها.

كما هو الحال مع تعيينات السمات المباشرة مثل self.l1 = tf.keras.layers.Dense(5) ، فإن تخصيص القوائم والقواميس للسمات سيؤدي إلى تتبع محتوياتها.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

قد تلاحظ كائنات مجمعة للقوائم والقواميس. هذه الأغلفة هي إصدارات يمكن التحقق منها من هياكل البيانات الأساسية. تمامًا مثل التحميل المستند إلى السمة ، تستعيد هذه الأغلفة قيمة المتغير بمجرد إضافته إلى الحاوية.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

تتضمن الكائنات tf.train.Checkpoint و tf.Module الفرعية (مثل keras.layers.Layer و keras.Model ) وحاويات Python المعترف بها:

  • dict ( collections.OrderedDict .
  • list
  • tuple ( collections.namedtuple . التي تسمى tuple ، الكتابة. typing.NamedTuple )

أنواع الحاويات الأخرى غير مدعومة ، بما في ذلك:

  • collections.defaultdict
  • set

يتم تجاهل جميع كائنات Python الأخرى ، بما في ذلك:

  • int
  • string
  • float

ملخص

توفر كائنات TensorFlow آلية تلقائية سهلة لحفظ واستعادة قيم المتغيرات التي تستخدمها.