הצג באתר TensorFlow.org | הפעל בגוגל קולאב | צפה במקור ב-GitHub | הורד מחברת |
הביטוי "שמירת מודל TensorFlow" אומר בדרך כלל אחד משני דברים:
- מחסומים, OR
- SavedModel.
נקודות ביקורת לוכדות את הערך המדויק של כל הפרמטרים ( tf.Variable
objects) המשמשים את המודל. נקודות ביקורת אינן מכילות כל תיאור של החישוב שהוגדר על ידי המודל ולכן הן בדרך כלל שימושיות רק כאשר קוד מקור שישתמש בערכי הפרמטרים השמורים זמין.
פורמט SavedModel לעומת זאת כולל תיאור סדרתי של החישוב שהוגדר על ידי המודל בנוסף לערכי הפרמטרים (נקודת ביקורת). מודלים בפורמט זה אינם תלויים בקוד המקור שיצר את המודל. לפיכך הם מתאימים לפריסה באמצעות TensorFlow Serving, TensorFlow Lite, TensorFlow.js, או תוכניות בשפות תכנות אחרות (ה-C, C++, Java, Go, Rust, C# וכו'. API של TensorFlow).
מדריך זה מכסה ממשקי API לכתיבה וקריאה של מחסומים.
להכין
import tensorflow as tf
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
net = Net()
שמירת ממשקי API להדרכה של tf.keras
עיין במדריך tf.keras
בנושא שמירה ושחזור.
tf.keras.Model.save_weights
שומר נקודת ביקורת של TensorFlow.
net.save_weights('easy_checkpoint')
כתיבת מחסומים
המצב המתמשך של מודל TensorFlow מאוחסן באובייקטים tf.Variable
. ניתן לבנות אותם ישירות, אך לרוב נוצרים באמצעות ממשקי API ברמה גבוהה כמו tf.keras.layers
או tf.keras.Model
.
הדרך הקלה ביותר לנהל משתנים היא על ידי הצמדתם לאובייקטים של Python, ואז הפניה לאובייקטים הללו.
תת-מחלקות של tf.train.Checkpoint
, tf.keras.layers.Layer
ו- tf.keras.Model
באופן אוטומטי אחר משתנים המוקצים לתכונות שלהם. הדוגמה הבאה בונה מודל ליניארי פשוט, ולאחר מכן כותבת נקודות ביקורת המכילות ערכים עבור כל משתני המודל.
אתה יכול בקלות לשמור נקודת ביקורת מודל עם Model.save_weights
.
בדיקה ידנית
להכין
כדי לעזור להדגים את כל התכונות של tf.train.Checkpoint
, הגדר מערך צעצוע ושלב אופטימיזציה:
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
"""Trains `net` on `example` using `optimizer`."""
with tf.GradientTape() as tape:
output = net(example['x'])
loss = tf.reduce_mean(tf.abs(output - example['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return loss
צור את אובייקטי המחסום
השתמש באובייקט tf.train.Checkpoint
כדי ליצור נקודת ביקורת באופן ידני, כאשר האובייקטים שברצונך למחסום מוגדרים כמאפיינים על האובייקט.
tf.train.CheckpointManager
יכול להיות מועיל גם לניהול מחסומים מרובים.
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
הרכבת ונקודת ביקורת של הדגם
לולאת האימון הבאה יוצרת מופע של המודל ושל אופטימיזציה, ואז אוספת אותם לאובייקט tf.train.Checkpoint
. זה קורא לשלב האימון בלולאה על כל אצווה של נתונים, ומדי פעם כותב נקודות ביקורת לדיסק.
def train_and_checkpoint(net, manager):
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
print("Restored from {}".format(manager.latest_checkpoint))
else:
print("Initializing from scratch.")
for _ in range(50):
example = next(iterator)
loss = train_step(net, example, opt)
ckpt.step.assign_add(1)
if int(ckpt.step) % 10 == 0:
save_path = manager.save()
print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch. Saved checkpoint for step 10: ./tf_ckpts/ckpt-1 loss 31.27 Saved checkpoint for step 20: ./tf_ckpts/ckpt-2 loss 24.68 Saved checkpoint for step 30: ./tf_ckpts/ckpt-3 loss 18.12 Saved checkpoint for step 40: ./tf_ckpts/ckpt-4 loss 11.65 Saved checkpoint for step 50: ./tf_ckpts/ckpt-5 loss 5.39
שחזר והמשך אימון
לאחר מחזור ההכשרה הראשון אתה יכול לעבור מודל ומנהל חדש, אבל להמשיך את האימונים בדיוק מהנקודה בה הפסקתם:
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5 Saved checkpoint for step 60: ./tf_ckpts/ckpt-6 loss 1.50 Saved checkpoint for step 70: ./tf_ckpts/ckpt-7 loss 1.27 Saved checkpoint for step 80: ./tf_ckpts/ckpt-8 loss 0.56 Saved checkpoint for step 90: ./tf_ckpts/ckpt-9 loss 0.70 Saved checkpoint for step 100: ./tf_ckpts/ckpt-10 loss 0.35
האובייקט tf.train.CheckpointManager
מוחק נקודות ביקורת ישנות. מעל זה מוגדר לשמור רק את שלושת המחסומים העדכניים ביותר.
print(manager.checkpoints) # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']
נתיבים אלה, למשל './tf_ckpts/ckpt-10'
, אינם קבצים בדיסק. במקום זאת הם קידומות לקובץ index
וקובץ נתונים אחד או יותר המכילים את ערכי המשתנים. קידומות אלה מקובצות יחד בקובץ נקודת checkpoint
בודד ( './tf_ckpts/checkpoint'
) שבו ה- CheckpointManager
שומר את מצבו.
ls ./tf_ckpts
checkpoint ckpt-8.data-00000-of-00001 ckpt-9.index ckpt-10.data-00000-of-00001 ckpt-8.index ckpt-10.index ckpt-9.data-00000-of-00001
מכניקת טעינה
TensorFlow מתאים משתנים לערכי נקודת ביקורת על ידי חציית גרף מכוון עם קצוות בעלי שם, החל מהאובייקט הנטען. שמות קצה מגיעים בדרך כלל משמות תכונות באובייקטים, למשל ה- "l1"
ב- self.l1 = tf.keras.layers.Dense(5)
. tf.train.Checkpoint
משתמש בשמות הארגומנטים של מילות המפתח שלו, כמו ב"שלב "step"
ב- tf.train.Checkpoint(step=...)
.
גרף התלות מהדוגמה למעלה נראה כך:
כלי האופטימיזציה באדום, משתנים רגילים הם בכחול, ומשתני משבצת המיטוב הם בכתום. הצמתים האחרים - למשל, המייצגים את tf.train.Checkpoint
- הם בצבע שחור.
משתני משבצות הם חלק ממצב האופטימיזציה, אך נוצרים עבור משתנה ספציפי. לדוגמה, קצוות ה- 'm'
שלמעלה תואמים למומנטום, שאותו מייעל Adam עוקב אחר כל משתנה. משתני משבצת נשמרים בנקודת ביקורת רק אם המשתנה והמייעל יישמרו שניהם, ובכך הקצוות המקווקוים.
קריאה restore
באובייקט tf.train.Checkpoint
בתור את השחזורים המבוקשים, משחזר ערכי משתנים ברגע שיש נתיב תואם מאובייקט Checkpoint
. לדוגמה, אתה יכול לטעון רק את ההטיה מהמודל שהגדרת למעלה על ידי שחזור נתיב אחד אליו דרך הרשת והשכבה.
to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy()) # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy()) # This gets the restored value.
[0. 0. 0. 0. 0.] [2.7209885 3.7588918 4.421351 4.1466427 4.0712557]
גרף התלות של האובייקטים החדשים הללו הוא תת-גרף קטן בהרבה של המחסום הגדול יותר שכתבת למעלה. הוא כולל רק את ההטיה ומונה שמירה ש- tf.train.Checkpoint
משתמש בהם כדי למספר נקודות ביקורת.
restore
מחזיר אובייקט סטטוס, שיש לו קביעות אופציונליות. כל האובייקטים שנוצרו ב- Checkpoint
החדש שוחזרו, אז status.assert_existing_objects_matched
עובר.
status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>
ישנם אובייקטים רבים בנקודת הבידוק שלא התאימו, כולל ליבת השכבה והמשתנים של האופטימיזציה. status.assert_consumed
עובר רק אם המחסום והתוכנית תואמים במדויק, ותזרוק כאן חריגה.
שחזורים דחויים
אובייקטי Layer
ב-TensorFlow עשויים לדחות את יצירת המשתנים לקריאה הראשונה שלהם, כאשר צורות קלט זמינות. לדוגמה, הצורה של ליבת שכבה Dense
תלויה הן בצורות הקלט והן בצורות הפלט של השכבה, ולכן צורת הפלט הנדרשת כארגומנט הבנאי אינה מספיק מידע כדי ליצור את המשתנה בפני עצמו. מכיוון שקריאה Layer
קוראת גם את ערך המשתנה, שחזור חייב לקרות בין יצירת המשתנה לשימוש הראשון בו.
כדי לתמוך בביטוי זה, tf.train.Checkpoint
דוחה שחזורים שעדיין אין להם משתנה תואם.
deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy()) # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy()) # Restored
[[0. 0. 0. 0. 0.]] [[4.5854754 4.607731 4.649179 4.8474874 5.121 ]]
בדיקה ידנית של מחסומים
tf.train.load_checkpoint
מחזיר CheckpointReader
שנותן גישה ברמה נמוכה יותר לתוכן המחסום. הוא מכיל מיפויים מהמפתח של כל משתנה, לצורה ול-dtype עבור כל משתנה בנקודת הבידוק. המפתח של משתנה הוא נתיב האובייקט שלו, כמו בגרפים המוצגים למעלה.
reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()
sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH', 'iterator/.ATTRIBUTES/ITERATOR_STATE', 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', 'save_counter/.ATTRIBUTES/VARIABLE_VALUE', 'step/.ATTRIBUTES/VARIABLE_VALUE']
אז אם אתה מעוניין בערך של net.l1.kernel
אתה יכול לקבל את הערך עם הקוד הבא:
key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5] Dtype: float32
זה גם מספק שיטת get_tensor
המאפשרת לך לבדוק את הערך של משתנה:
reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121 ]], dtype=float32)
מעקב אחר אובייקטים
מחסומים שומרים ומשחזרים את הערכים של אובייקטי tf.Variable
על ידי "מעקב" אחר כל משתנה או אובייקט בר מעקב באחת מהתכונות שלו. בעת ביצוע שמירה, משתנים נאספים באופן רקורסיבי מכל האובייקטים הנגישים למעקב.
כמו בהקצאות תכונות ישירות כמו self.l1 = tf.keras.layers.Dense(5)
, הקצאת רשימות ומילון לתכונות תעקוב אחר תוכנן.
save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')
restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy() # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()
ייתכן שתבחין באובייקטי מעטפת עבור רשימות ומילונים. עטיפות אלו הן גרסאות הניתנות לבדיקה של מבני הנתונים הבסיסיים. בדיוק כמו הטעינה המבוססת על תכונה, עטיפות אלה משחזרות ערך של משתנה ברגע שהוא מתווסף למיכל.
restore.listed = []
print(restore.listed) # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1) # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])
אובייקטים שניתנים למעקב כוללים tf.train.Checkpoint
, tf.Module
המשנה שלו (למשל keras.layers.Layer
ו- keras.Model
), ומכולות Python מזוהות:
-
dict
(andcollections.OrderedDict
) -
list
-
tuple
(ו-collections.namedtuple
,typing.NamedTuple
)
סוגי מיכל אחרים אינם נתמכים , כולל:
-
collections.defaultdict
-
set
מתעלמים מכל שאר האובייקטים של Python, כולל:
-
int
-
string
-
float
סיכום
אובייקטי TensorFlow מספקים מנגנון אוטומטי קל לשמירה ושחזור הערכים של המשתנים שבהם הם משתמשים.