הצג באתר TensorFlow.org | הפעל בגוגל קולאב | הצג ב-GitHub | הורד מחברת |
סקירה כללית
מדריך זה מניח שיש לך מודל ששומר וטוען נקודות ביקורת עם tf.compat.v1.Saver
, וברצונך להעביר את הקוד השתמש ב-TF2 tf.train.Checkpoint
API, או השתמש בנקודות ביקורת קיימות במודל TF2 שלך.
להלן כמה תרחישים נפוצים שאתה עלול להיתקל בהם:
תרחיש 1
ישנן נקודות ביקורת קיימות של TF1 מריצות אימון קודמות שצריך לטעון או להמיר ל-TF2.
- כדי לטעון את נקודת ביקורת TF1 ב-TF2, ראה את הקטע טעינת נקודת ביקורת TF1 ב-TF2 .
- כדי להמיר את נקודת הבידוק ל-TF2, ראה המרת נקודת ביקורת.
תרחיש 2
אתה מתאים את המודל שלך באופן שמסתכן בשינוי שמות ונתיבים של משתנים (כגון בעת הגירה הדרגתית מ- get_variable
ליצירת tf.Variable
מפורשת), וברצונך לשמור על שמירה/טעינה של נקודות ביקורת קיימות לאורך הדרך.
עיין בסעיף כיצד לשמור על תאימות נקודות ביקורת במהלך העברת מודלים
תרחיש 3
אתה מעביר את קוד האימון ונקודות הבידוק שלך ל-TF2, אבל צינור ההסקנות שלך ממשיך לדרוש נקודות ביקורת TF1 לעת עתה (למען יציבות הייצור).
אופציה 1
שמור גם את מחסומי TF1 וגם TF2 בעת אימון.
אפשרות 2
המר את נקודת ביקורת TF2 ל-TF1.
הדוגמאות שלהלן מציגות את כל השילובים של שמירה וטעינה של נקודות ביקורת ב-TF1/TF2, כך שיש לך גמישות מסוימת בקביעה כיצד להעביר את המודל שלך.
להכין
import tensorflow as tf
import tensorflow.compat.v1 as tf1
def print_checkpoint(save_path):
reader = tf.train.load_checkpoint(save_path)
shapes = reader.get_variable_to_shape_map()
dtypes = reader.get_variable_to_dtype_map()
print(f"Checkpoint at '{save_path}':")
for key in shapes:
print(f" (key='{key}', shape={shapes[key]}, dtype={dtypes[key].name}, "
f"value={reader.get_tensor(key)})")
שינויים מ-TF1 ל-TF2
סעיף זה כלול אם אתה סקרן לגבי מה השתנה בין TF1 ל-TF2, ולמה אנו מתכוונים בנקודות "מבוססות שם" (TF1) לעומת "מבוססות אובייקטים" (TF2).
שני סוגי המחסומים נשמרים למעשה באותו פורמט, שהוא בעצם טבלת מפתח-ערך. ההבדל טמון באופן שבו המפתחות נוצרים.
המפתחות בנקודות ביקורת מבוססות שם הם שמות המשתנים . המפתחות בנקודות ביקורת מבוססות-אובייקט מתייחסים לנתיב מאובייקט השורש למשתנה (הדוגמאות למטה יעזרו להבין טוב יותר מה זה אומר).
ראשית, שמור כמה מחסומים:
with tf.Graph().as_default() as g:
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
with tf1.Session() as sess:
saver = tf1.train.Saver()
sess.run(a.assign(1))
sess.run(b.assign(2))
sess.run(c.assign(3))
saver.save(sess, 'tf1-ckpt')
print_checkpoint('tf1-ckpt')
Checkpoint at 'tf1-ckpt': (key='scoped/c', shape=[], dtype=float32, value=3.0) (key='a', shape=[], dtype=float32, value=1.0) (key='b', shape=[], dtype=float32, value=2.0)
a = tf.Variable(5.0, name='a')
b = tf.Variable(6.0, name='b')
with tf.name_scope('scoped'):
c = tf.Variable(7.0, name='c')
ckpt = tf.train.Checkpoint(variables=[a, b, c])
save_path_v2 = ckpt.save('tf2-ckpt')
print_checkpoint(save_path_v2)
Checkpoint at 'tf2-ckpt-1': (key='variables/2/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=7.0) (key='variables/0/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=5.0) (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n!\n\r\x08\x01\x12\tvariables\n\x10\x08\x02\x12\x0csave_counter\n\x15\n\x05\x08\x03\x12\x010\n\x05\x08\x04\x12\x011\n\x05\x08\x05\x12\x012\nI\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE\n=\x12;\n\x0eVARIABLE_VALUE\x12\x01a\x1a&variables/0/.ATTRIBUTES/VARIABLE_VALUE\n=\x12;\n\x0eVARIABLE_VALUE\x12\x01b\x1a&variables/1/.ATTRIBUTES/VARIABLE_VALUE\nD\x12B\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a&variables/2/.ATTRIBUTES/VARIABLE_VALUE") (key='variables/1/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=6.0) (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
אם אתה מסתכל על המפתחות ב- tf2-ckpt
, כולם מתייחסים לנתיבי האובייקט של כל משתנה. לדוגמה, משתנה a
הוא האלמנט הראשון ברשימת variables
, ולכן המפתח שלו הופך variables/0/...
(אתם מוזמנים להתעלם מהקבוע .ATTRIBUTES/VARIABLE_VALUE).
בדיקה מדוקדקת יותר של אובייקט Checkpoint
להלן:
a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
root = ckpt = tf.train.Checkpoint(variables=[a, b, c])
print("root type =", type(root).__name__)
print("root.variables =", root.variables)
print("root.variables[0] =", root.variables[0])
root type = Checkpoint root.variables = ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>]) root.variables[0] = <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>
נסה להתנסות עם הקטע שלהלן וראה כיצד מפתחות המחסום משתנים עם מבנה האובייקט:
module = tf.Module()
module.d = tf.Variable(0.)
test_ckpt = tf.train.Checkpoint(v={'a': a, 'b': b},
c=c,
module=module)
test_ckpt_path = test_ckpt.save('root-tf2-ckpt')
print_checkpoint(test_ckpt_path)
Checkpoint at 'root-tf2-ckpt-1': (key='v/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0) (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1) (key='v/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0) (key='module/d/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0) (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n,\n\x05\x08\x01\x12\x01c\n\n\x08\x02\x12\x06module\n\x05\x08\x03\x12\x01v\n\x10\x08\x04\x12\x0csave_counter\n:\x128\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a\x1cc/.ATTRIBUTES/VARIABLE_VALUE\n\x07\n\x05\x08\x05\x12\x01d\n\x0e\n\x05\x08\x06\x12\x01a\n\x05\x08\x07\x12\x01b\nI\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE\nA\x12?\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a#module/d/.ATTRIBUTES/VARIABLE_VALUE\n<\x12:\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a\x1ev/a/.ATTRIBUTES/VARIABLE_VALUE\n<\x12:\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a\x1ev/b/.ATTRIBUTES/VARIABLE_VALUE") (key='c/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)
מדוע TF2 משתמש במנגנון זה?
מכיוון שאין יותר גרף גלובלי ב-TF2, שמות משתנים אינם אמינים ויכולים להיות לא עקביים בין תוכניות. TF2 מעודד את גישת הדוגמנות מונחה עצמים שבה משתנים נמצאים בבעלות שכבות, ושכבות בבעלות מודל:
variable = tf.Variable(...)
layer.variable_name = variable
model.layer_name = layer
כיצד לשמור על תאימות נקודות ביקורת במהלך העברת מודלים
שלב חשוב אחד בתהליך ההעברה הוא להבטיח שכל המשתנים מאותחלים לערכים הנכונים , מה שבתורו מאפשר לך לאמת שהאופס/פונקציות מבצעות את החישובים הנכונים. כדי להשיג זאת, עליך לשקול את תאימות המחסום בין מודלים בשלבים השונים של ההגירה. בעיקרו של דבר, סעיף זה עונה על השאלה, איך אני ממשיך להשתמש באותה מחסום בזמן שינוי המודל .
להלן שלוש דרכים לשמירה על תאימות מחסום, על מנת להגביר את הגמישות:
- למודל יש את אותם שמות משתנים כמו קודם.
- למודל יש שמות משתנים שונים, והוא מחזיק מפת הקצאה הממפה את שמות המשתנים במחסום לשמות החדשים.
- למודל יש שמות משתנים שונים, והוא שומר על אובייקט TF2 Checkpoint המאחסן את כל המשתנים.
כאשר שמות המשתנים תואמים
כותרת ארוכה: כיצד לעשות שימוש חוזר בנקודות ביקורת כאשר שמות המשתנים תואמים.
תשובה קצרה: אתה יכול לטעון ישירות את המחסום הקיים עם tf1.train.Saver
או tf.train.Checkpoint
.
אם אתה משתמש ב- tf.compat.v1.keras.utils.track_tf1_style_variables
, זה יבטיח ששמות משתני המודל שלך יהיו זהים לקודמים. אתה יכול גם לוודא ידנית ששמות משתנים תואמים.
כאשר שמות המשתנים תואמים במודלים שהועברו, תוכל להשתמש ישירות ב- tf.train.Checkpoint
או tf.compat.v1.train.Saver
כדי לטעון את נקודת הבידוק. שני ממשקי ה-API תואמים למצב להוט ולמצב גרף, כך שתוכל להשתמש בהם בכל שלב של ההגירה.
להלן דוגמאות לשימוש באותו מחסום עם דגמים שונים. ראשית, שמור מחסום TF1 עם tf1.train.Saver
:
with tf.Graph().as_default() as g:
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
with tf1.Session() as sess:
saver = tf1.train.Saver()
sess.run(a.assign(1))
sess.run(b.assign(2))
sess.run(c.assign(3))
save_path = saver.save(sess, 'tf1-ckpt')
print_checkpoint(save_path)
Checkpoint at 'tf1-ckpt': (key='scoped/c', shape=[], dtype=float32, value=3.0) (key='a', shape=[], dtype=float32, value=1.0) (key='b', shape=[], dtype=float32, value=2.0)
הדוגמה שלהלן משתמשת ב- tf.compat.v1.Saver
כדי לטעון את נקודת הבידוק במצב להוט:
a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
c = tf.Variable(0.0, name='c')
# With the removal of collections in TF2, you must pass in the list of variables
# to the Saver object:
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path=save_path)
print(f"loaded values of [a, b, c]: [{a.numpy()}, {b.numpy()}, {c.numpy()}]")
# Saving also works in eager (sess must be None).
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone. INFO:tensorflow:Restoring parameters from tf1-ckpt loaded values of [a, b, c]: [1.0, 2.0, 3.0] Checkpoint at 'tf1-ckpt-saved-in-eager': (key='scoped/c', shape=[], dtype=float32, value=3.0) (key='a', shape=[], dtype=float32, value=1.0) (key='b', shape=[], dtype=float32, value=2.0)
הקטע הבא טוען את נקודת הבידוק באמצעות TF2 API tf.train.Checkpoint
:
a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
c = tf.Variable(0.0, name='c')
# Without the name_scope, name="scoped/c" works too:
c_2 = tf.Variable(0.0, name='scoped/c')
print("Variable names: ")
print(f" a.name = {a.name}")
print(f" b.name = {b.name}")
print(f" c.name = {c.name}")
print(f" c_2.name = {c_2.name}")
# Restore the values with tf.train.Checkpoint
ckpt = tf.train.Checkpoint(variables=[a, b, c, c_2])
ckpt.restore(save_path)
print(f"loaded values of [a, b, c, c_2]: [{a.numpy()}, {b.numpy()}, {c.numpy()}, {c_2.numpy()}]")
Variable names: a.name = a:0 b.name = b:0 c.name = scoped/c:0 c_2.name = scoped/c:0 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/tracking/util.py:1345: NameBasedSaverStatus.__init__ (from tensorflow.python.training.tracking.util) is deprecated and will be removed in a future version. Instructions for updating: Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future. loaded values of [a, b, c, c_2]: [1.0, 2.0, 3.0, 3.0]
שמות משתנים ב-TF2
- למשתנים עדיין יש ארגומנט
name
שאתה יכול להגדיר. - מודלים של Keras גם לוקחים ארגומנט
name
שהם מגדירים כתחילית למשתנים שלהם. - ניתן להשתמש בפונקציה
v1.name_scope
כדי להגדיר קידומות של שמות משתנים. זה שונה מאוד מ-tf.variable_scope
. זה משפיע רק על שמות, ואינו עוקב אחר משתנים ושימוש חוזר.
tf.compat.v1.keras.utils.track_tf1_style_variables
הוא תבנית שמסייעת לך לשמור על שמות משתנים ותאימות TF1 Checkpoint, על ידי שמירה על סמנטיקה של שמות ושימוש חוזר של tf.variable_scope
ו- tf.compat.v1.get_variable
ללא שינוי. עיין במדריך מיפוי מודלים למידע נוסף.
הערה 1: אם אתה משתמש ב-shim, השתמש בממשקי API של TF2 כדי לטעון את המחסומים שלך (אפילו בעת שימוש בנקודות ביקורת TF1 מאומנות מראש).
עיין בסעיף מחסום קרס .
הערה 2: בעת מעבר ל- tf.Variable
מ- get_variable
:
אם השכבה או המודול המעוטרים ב-shim שלך מורכבים ממשתנים מסוימים (או שכבות/מודלים של Keras) המשתמשים ב- tf.Variable
במקום tf.compat.v1.get_variable
/עקובים אחר באופן מונחה עצמים, ייתכן שיהיה להם שונה סמנטיקה של שמות משתנים בגרפים/הפעלות TF1.x לעומת במהלך ביצוע נלהב.
בקיצור, ייתכן שהשמות לא יהיו מה שאתה מצפה מהם להיות כאשר הם פועלים ב-TF2.
שמירה על מפות מטלות
מפות הקצאה משמשות בדרך כלל להעברת משקלים בין דגמי TF1, וניתן להשתמש בהן גם במהלך העברת המודלים שלך אם שמות המשתנים משתנים.
אתה יכול להשתמש במפות אלה עם tf.compat.v1.train.init_from_checkpoint
, tf.compat.v1.train.Saver
ו- tf.train.load_checkpoint
כדי לטעון משקלים למודלים שבהם ייתכן שהמשתנה או שמות ההיקף השתנו.
הדוגמאות בסעיף זה ישתמשו בנקודת ביקורת שנשמרה בעבר:
print_checkpoint('tf1-ckpt')
Checkpoint at 'tf1-ckpt': (key='scoped/c', shape=[], dtype=float32, value=3.0) (key='a', shape=[], dtype=float32, value=1.0) (key='b', shape=[], dtype=float32, value=2.0)
טוען עם init_from_checkpoint
יש לקרוא tf1.train.init_from_checkpoint
תוך כדי גרף/הפעלה, מכיוון שהוא מציב את הערכים במתחלי המשתנים במקום ליצור אופציה להקצות.
אתה יכול להשתמש בארגומנט assignment_map
כדי להגדיר את אופן טעינת המשתנים. מתוך התיעוד:
מפת הקצאות תומכת בתחביר הבא:
-
'checkpoint_scope_name/': 'scope_name/'
- יטען את כל המשתנים ב-scope_name
הנוכחי מ-checkpoint_scope_name
עם שמות טנסור תואמים. -
'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'
- יאתחלscope_name/variable_name
variable_name מ-checkpoint_scope_name/some_other_variable
. -
'scope_variable_name': variable
- יאתחל נתון אובייקטtf.Variable
עם הטנסור 'scope_variable_name' מנקודת הבידוק. -
'scope_variable_name': list(variable)
- יאתחל רשימה של משתנים מחולקים עם הטנסור 'scope_variable_name' מנקודת הבידוק. -
'/': 'scope_name/'
- יטען את כל המשתנים ב-scope_name
הנוכחי מהשורש של המחסום (למשל ללא היקף).
# Restoring with tf1.train.init_from_checkpoint:
# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
with tf1.variable_scope('new_scope'):
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
with tf1.Session() as sess:
# The assignment map will remap all variables in the checkpoint to the
# new scope:
tf1.train.init_from_checkpoint(
'tf1-ckpt',
assignment_map={'/': 'new_scope/'})
# `init_from_checkpoint` adds the initializers to these variables.
# Use `sess.run` to run these initializers.
sess.run(tf1.global_variables_initializer())
print("Restored [a, b, c]: ", sess.run([a, b, c]))
Restored [a, b, c]: [1.0, 2.0, 3.0]
טוען עם tf1.train.Saver
בניגוד init_from_checkpoint
, tf.compat.v1.train.Saver
פועל גם במצב גרף וגם במצב להוט. הארגומנט var_list
מקבל באופן אופציונלי מילון, אלא שהוא חייב למפות שמות משתנים לאובייקט tf.Variable
.
# Restoring with tf1.train.Saver (works in both graph and eager):
# A new model with a different scope for the variables.
with tf1.variable_scope('new_scope'):
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
# Initialize the saver with a dictionary with the original variable names:
saver = tf1.train.Saver({'a': a, 'b': b, 'scoped/c': c})
saver.restore(sess=None, save_path='tf1-ckpt')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone. INFO:tensorflow:Restoring parameters from tf1-ckpt Restored [a, b, c]: [1.0, 2.0, 3.0]
טוען עם tf.train.load_checkpoint
אפשרות זו מיועדת לך אם אתה זקוק לשליטה מדויקת על ערכי המשתנים. שוב, זה עובד גם במצבי גרף וגם במצבי להוט.
# Restoring with tf.train.load_checkpoint (works in both graph and eager):
# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
with tf1.variable_scope('new_scope'):
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
with tf1.Session() as sess:
# It may be easier writing a loop if your model has a lot of variables.
reader = tf.train.load_checkpoint('tf1-ckpt')
sess.run(a.assign(reader.get_tensor('a')))
sess.run(b.assign(reader.get_tensor('b')))
sess.run(c.assign(reader.get_tensor('scoped/c')))
print("Restored [a, b, c]: ", sess.run([a, b, c]))
Restored [a, b, c]: [1.0, 2.0, 3.0]
שמירה על אובייקט TF2 Checkpoint
אם שמות המשתנים וההיקף עשויים להשתנות הרבה במהלך ההעברה, השתמש בנקודות ביקורת tf.train.Checkpoint
ו-TF2. TF2 משתמש במבנה האובייקט במקום בשמות משתנים (פרטים נוספים בשינויים מ-TF1 ל-TF2 ).
בקיצור, בעת יצירת tf.train.Checkpoint
לשמירה או שחזור של נקודות ביקורת, ודא שהוא משתמש באותו סדר (עבור רשימות) ומפתחות (עבור מילונים וארגומנטים של מילות מפתח לאתחול Checkpoint
). כמה דוגמאות לתאימות נקודות ביקורת:
ckpt = tf.train.Checkpoint(foo=[var_a, var_b])
# compatible with ckpt
tf.train.Checkpoint(foo=[var_a, var_b])
# not compatible with ckpt
tf.train.Checkpoint(foo=[var_b, var_a])
tf.train.Checkpoint(bar=[var_a, var_b])
דוגמאות הקוד שלהלן מראות כיצד להשתמש ב-"אותו" tf.train.Checkpoint
כדי לטעון משתנים עם שמות שונים. ראשית, שמור נקודת ביקורת TF2:
with tf.Graph().as_default() as g:
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(1))
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(2))
with tf1.variable_scope('scoped'):
c = tf1.get_variable('c', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(3))
with tf1.Session() as sess:
sess.run(tf1.global_variables_initializer())
print("[a, b, c]: ", sess.run([a, b, c]))
# Save a TF2 checkpoint
ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
tf2_ckpt_path = ckpt.save('tf2-ckpt')
print_checkpoint(tf2_ckpt_path)
[a, b, c]: [1.0, 2.0, 3.0] Checkpoint at 'tf2-ckpt-1': (key='unscoped/1/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0) (key='unscoped/0/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0) (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n,\n\n\x08\x01\x12\x06scoped\n\x0c\x08\x02\x12\x08unscoped\n\x10\x08\x03\x12\x0csave_counter\n\x07\n\x05\x08\x04\x12\x010\n\x0e\n\x05\x08\x05\x12\x010\n\x05\x08\x06\x12\x011\nI\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE\nA\x12?\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a#scoped/0/.ATTRIBUTES/VARIABLE_VALUE\n<\x12:\n\x0eVARIABLE_VALUE\x12\x01a\x1a%unscoped/0/.ATTRIBUTES/VARIABLE_VALUE\n<\x12:\n\x0eVARIABLE_VALUE\x12\x01b\x1a%unscoped/1/.ATTRIBUTES/VARIABLE_VALUE") (key='scoped/0/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0) (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
אתה יכול להמשיך להשתמש ב- tf.train.Checkpoint
גם אם שמות המשתנים/ההיקף משתנים:
with tf.Graph().as_default() as g:
a = tf1.get_variable('a_different_name', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
b = tf1.get_variable('b_different_name', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
with tf1.variable_scope('different_scope'):
c = tf1.get_variable('c', shape=[], dtype=tf.float32,
initializer=tf1.zeros_initializer())
with tf1.Session() as sess:
sess.run(tf1.global_variables_initializer())
print("Initialized [a, b, c]: ", sess.run([a, b, c]))
ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
# `assert_consumed` validates that all checkpoint objects are restored from
# the checkpoint. `run_restore_ops` is required when running in a TF1
# session.
ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()
# Removing `assert_consumed` is fine if you want to skip the validation.
# ckpt.restore(tf2_ckpt_path).run_restore_ops()
print("Restored [a, b, c]: ", sess.run([a, b, c]))
Initialized [a, b, c]: [0.0, 0.0, 0.0] Restored [a, b, c]: [1.0, 2.0, 3.0]
ובמצב להוט:
a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
# The keys "scoped" and "unscoped" are no longer relevant, but are used to
# maintain compatibility with the saved checkpoints.
ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
Initialized [a, b, c]: [0.0, 0.0, 0.0] Restored [a, b, c]: [1.0, 2.0, 3.0]
מחסומי TF2 באומדן
הסעיפים שלמעלה מתארים כיצד לשמור על תאימות נקודות ביקורת בזמן העברת המודל שלך. מושגים אלה חלים גם על מודלים של Estimator, אם כי אופן השמירה/טעינת המחסום שונה במקצת. בזמן שאתה מעביר את מודל האומד שלך לשימוש בממשקי TF2 API, ייתכן שתרצה לעבור מנקודות ביקורת TF1 ל-TF2 בזמן שהמודל עדיין משתמש באומדן . חלק זה מראה כיצד לעשות זאת.
tf.estimator.Estimator
ול- MonitoredSession
יש מנגנון שמירה שנקרא scaffold
, אובייקט tf.compat.v1.train.Scaffold
. ה- Scaffold
יכול להכיל tf1.train.Saver
או tf.train.Checkpoint
, המאפשרים ל- Estimator
ול- MonitoredSession
לשמור נקודות ביקורת בסגנון TF1 או TF2.
# A model_fn that saves a TF1 checkpoint
def model_fn_tf1_ckpt(features, labels, mode):
# This model adds 2 to the variable `v` in every train step.
train_step = tf1.train.get_or_create_global_step()
v = tf1.get_variable('var', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(0))
return tf.estimator.EstimatorSpec(
mode,
predictions=v,
train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
loss=tf.constant(1.),
scaffold=None
)
!rm -rf est-tf1
est = tf.estimator.Estimator(model_fn_tf1_ckpt, 'est-tf1')
def train_fn():
return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)
latest_checkpoint = tf.train.latest_checkpoint('est-tf1')
print_checkpoint(latest_checkpoint)
INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': 'est-tf1', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:401: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into est-tf1/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 1.0, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1... INFO:tensorflow:Saving checkpoints for 1 into est-tf1/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1... INFO:tensorflow:Loss for final step: 1.0. Checkpoint at 'est-tf1/model.ckpt-1': (key='var', shape=[], dtype=float32, value=2.0) (key='global_step', shape=[], dtype=int64, value=1)
# A model_fn that saves a TF2 checkpoint
def model_fn_tf2_ckpt(features, labels, mode):
# This model adds 2 to the variable `v` in every train step.
train_step = tf1.train.get_or_create_global_step()
v = tf1.get_variable('var', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(0))
ckpt = tf.train.Checkpoint(var_list={'var': v}, step=train_step)
return tf.estimator.EstimatorSpec(
mode,
predictions=v,
train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
loss=tf.constant(1.),
scaffold=tf1.train.Scaffold(saver=ckpt)
)
!rm -rf est-tf2
est = tf.estimator.Estimator(model_fn_tf2_ckpt, 'est-tf2',
warm_start_from='est-tf1')
def train_fn():
return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)
latest_checkpoint = tf.train.latest_checkpoint('est-tf2')
print_checkpoint(latest_checkpoint)
assert est.get_variable_value('var_list/var/.ATTRIBUTES/VARIABLE_VALUE') == 4
INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': 'est-tf2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='est-tf1', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={}) INFO:tensorflow:Warm-starting from: est-tf1 INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES. INFO:tensorflow:Warm-started 1 variables. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into est-tf2/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 1.0, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1... INFO:tensorflow:Saving checkpoints for 1 into est-tf2/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1... INFO:tensorflow:Loss for final step: 1.0. Checkpoint at 'est-tf2/model.ckpt-1': (key='var_list/var/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=4.0) (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n\x18\n\x08\x08\x01\x12\x04step\n\x0c\x08\x02\x12\x08var_list\n@\x12>\n\x0eVARIABLE_VALUE\x12\x0bglobal_step\x1a\x1fstep/.ATTRIBUTES/VARIABLE_VALUE\n\t\n\x07\x08\x03\x12\x03var\n@\x12>\n\x0eVARIABLE_VALUE\x12\x03var\x1a'var_list/var/.ATTRIBUTES/VARIABLE_VALUE") (key='step/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
הערך הסופי של v
צריך להיות 16
, לאחר הפעלה חמה מ- est-tf1
, ולאחר מכן אימון ל-5 צעדים נוספים. ערך צעד הרכבת לא עובר ממחסום warm_start
.
מחסום קרס
דגמים שנבנו עם Keras עדיין משתמשים ב- tf1.train.Saver
וב- tf.train.Checkpoint
לטעינת משקלים קיימים. כאשר המודל שלך מועבר במלואו, עבור לשימוש ב- model.save_weights
ו- model.load_weights
, במיוחד אם אתה משתמש ב- ModelCheckpoint
callback בעת אימון.
כמה דברים שכדאי לדעת על מחסומים וקרס:
אתחול מול בנייה
מודלים ושכבות של Keras חייבים לעבור שני שלבים לפני שהם נוצרים במלואם. ראשית הוא האתחול של אובייקט Python: layer = tf.keras.layers.Dense(x)
. שנית הוא שלב הבנייה , שבו למעשה נוצרות רוב המשקולות: layer.build(input_shape)
. אתה יכול גם לבנות מודל על ידי קריאה אליו או הפעלת train
בודדת , eval
, או שלב predict
(בפעם הראשונה בלבד).
אם אתה מגלה ש- model.load_weights(path).assert_consumed()
מעלה שגיאה, סביר להניח שהמודל/השכבות לא נבנו.
קרס משתמש בנקודות ביקורת TF2
tf.train.Checkpoint(model).write
.write שווה ערך ל- model.save_weights
. אותו דבר עם tf.train.Checkpoint(model).read
ו- model.load_weights
. שימו לב ש- Checkpoint(model) != Checkpoint(model=model)
.
נקודות ביקורת TF2 פועלות עם שלב ה- build()
של Keras
ל- tf.train.Checkpoint.restore
יש מנגנון שנקרא שחזור דחוי המאפשר tf.Module
ו-Keras לאחסן ערכי משתנים אם המשתנה עדיין לא נוצר. זה מאפשר לדגמים מאותחלים להעמיס משקולות ולבנות לאחר מכן.
m = YourKerasModel()
status = m.load_weights(path)
# This call builds the model. The variables are created with the restored
# values.
m.predict(inputs)
status.assert_consumed()
בגלל מנגנון זה, אנו ממליצים בחום להשתמש בממשקי API של טעינת נקודות TF2 עם מודלים של Keras (אפילו בעת שחזור נקודות ביקורת TF1 קיימות מראש ל- shims של מיפוי המודל ). ראה עוד במדריך המחסום .
קטעי קוד
הקטעים למטה מציגים את תאימות גרסת TF1/TF2 בממשקי ה-API של שמירת נקודות המחסום.
שמור מחסום TF1 ב-TF2
a = tf.Variable(1.0, name='a')
b = tf.Variable(2.0, name='b')
with tf.name_scope('scoped'):
c = tf.Variable(3.0, name='c')
saver = tf1.train.Saver(var_list=[a, b, c])
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone. Checkpoint at 'tf1-ckpt-saved-in-eager': (key='scoped/c', shape=[], dtype=float32, value=3.0) (key='a', shape=[], dtype=float32, value=1.0) (key='b', shape=[], dtype=float32, value=2.0)
טען מחסום TF1 ב-TF2
a = tf.Variable(0., name='a')
b = tf.Variable(0., name='b')
with tf.name_scope('scoped'):
c = tf.Variable(0., name='c')
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path='tf1-ckpt-saved-in-eager')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
Initialized [a, b, c]: [0.0, 0.0, 0.0] WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone. INFO:tensorflow:Restoring parameters from tf1-ckpt-saved-in-eager Restored [a, b, c]: [1.0, 2.0, 3.0]
שמור מחסום TF2 ב-TF1
with tf.Graph().as_default() as g:
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(1))
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(2))
with tf1.variable_scope('scoped'):
c = tf1.get_variable('c', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(3))
with tf1.Session() as sess:
sess.run(tf1.global_variables_initializer())
ckpt = tf.train.Checkpoint(
var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
tf2_in_tf1_path = ckpt.save('tf2-ckpt-saved-in-session')
print_checkpoint(tf2_in_tf1_path)
Checkpoint at 'tf2-ckpt-saved-in-session-1': (key='var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0) (key='var_list/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0) (key='var_list/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0) (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n \n\x0c\x08\x01\x12\x08var_list\n\x10\x08\x02\x12\x0csave_counter\n\x1c\n\x05\x08\x03\x12\x01a\n\x05\x08\x04\x12\x01b\n\x0c\x08\x05\x12\x08scoped/c\nI\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE\n<\x12:\n\x0eVARIABLE_VALUE\x12\x01a\x1a%var_list/a/.ATTRIBUTES/VARIABLE_VALUE\n<\x12:\n\x0eVARIABLE_VALUE\x12\x01b\x1a%var_list/b/.ATTRIBUTES/VARIABLE_VALUE\nK\x12I\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a-var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE") (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
טען מחסום TF2 ב-TF1
with tf.Graph().as_default() as g:
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(0))
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(0))
with tf1.variable_scope('scoped'):
c = tf1.get_variable('c', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(0))
with tf1.Session() as sess:
sess.run(tf1.global_variables_initializer())
print("Initialized [a, b, c]: ", sess.run([a, b, c]))
ckpt = tf.train.Checkpoint(
var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
ckpt.restore('tf2-ckpt-saved-in-session-1').run_restore_ops()
print("Restored [a, b, c]: ", sess.run([a, b, c]))
Initialized [a, b, c]: [0.0, 0.0, 0.0] Restored [a, b, c]: [1.0, 2.0, 3.0]
המרת מחסום
ניתן להמיר מחסומים בין TF1 ל-TF2 על ידי טעינה ושמירת המחסומים מחדש. חלופה היא tf.train.load_checkpoint
, המוצגת בקוד למטה.
המרת נקודת ביקורת TF1 ל-TF2
def convert_tf1_to_tf2(checkpoint_path, output_prefix):
"""Converts a TF1 checkpoint to TF2.
To load the converted checkpoint, you must build a dictionary that maps
variable names to variable objects.
```
ckpt = tf.train.Checkpoint(vars={name: variable})
ckpt.restore(converted_ckpt_path)
```
Args:
checkpoint_path: Path to the TF1 checkpoint.
output_prefix: Path prefix to the converted checkpoint.
Returns:
Path to the converted checkpoint.
"""
vars = {}
reader = tf.train.load_checkpoint(checkpoint_path)
dtypes = reader.get_variable_to_dtype_map()
for key in dtypes.keys():
vars[key] = tf.Variable(reader.get_tensor(key))
return tf.train.Checkpoint(vars=vars).save(output_prefix)
```
Convert the checkpoint saved in the snippet `Save a TF1 checkpoint in TF2`:
```python
# Make sure to run the snippet in `Save a TF1 checkpoint in TF2`.
print_checkpoint('tf1-ckpt-saved-in-eager')
converted_path = convert_tf1_to_tf2('tf1-ckpt-saved-in-eager',
'converted-tf1-to-tf2')
print("\n[Converted]")
print_checkpoint(converted_path)
# Try loading the converted checkpoint.
a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
ckpt = tf.train.Checkpoint(vars={'a': a, 'b': b, 'scoped/c': c})
ckpt.restore(converted_path).assert_consumed()
print("\nRestored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
Checkpoint at 'tf1-ckpt-saved-in-eager': (key='scoped/c', shape=[], dtype=float32, value=3.0) (key='a', shape=[], dtype=float32, value=1.0) (key='b', shape=[], dtype=float32, value=2.0) [Converted] Checkpoint at 'converted-tf1-to-tf2-1': (key='vars/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0) (key='vars/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0) (key='vars/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0) (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n\x1c\n\x08\x08\x01\x12\x04vars\n\x10\x08\x02\x12\x0csave_counter\n\x1c\n\x0c\x08\x03\x12\x08scoped/c\n\x05\x08\x04\x12\x01a\n\x05\x08\x05\x12\x01b\nI\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE\nG\x12E\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a)vars/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE\n?\x12=\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a!vars/a/.ATTRIBUTES/VARIABLE_VALUE\n?\x12=\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a!vars/b/.ATTRIBUTES/VARIABLE_VALUE") (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1) Restored [a, b, c]: [1.0, 2.0, 3.0]
המרת נקודת ביקורת TF2 ל-TF1
def convert_tf2_to_tf1(checkpoint_path, output_prefix):
"""Converts a TF2 checkpoint to TF1.
The checkpoint must be saved using a
`tf.train.Checkpoint(var_list={name: variable})`
To load the converted checkpoint with `tf.compat.v1.Saver`:
```
saver = tf.compat.v1.train.Saver(var_list={name: variable})
# An alternative, if the variable names match the keys:
saver = tf.compat.v1.train.Saver(var_list=[variables])
saver.restore(sess, output_path)
```
"""
vars = {}
reader = tf.train.load_checkpoint(checkpoint_path)
dtypes = reader.get_variable_to_dtype_map()
for key in dtypes.keys():
# Get the "name" from the
if key.startswith('var_list/'):
var_name = key.split('/')[1]
# TF2 checkpoint keys use '/', so if they appear in the user-defined name,
# they are escaped to '.S'.
var_name = var_name.replace('.S', '/')
vars[var_name] = tf.Variable(reader.get_tensor(key))
return tf1.train.Saver(var_list=vars).save(sess=None, save_path=output_prefix)
```
Convert the checkpoint saved in the snippet `Save a TF2 checkpoint in TF1`:
```python
# Make sure to run the snippet in `Save a TF2 checkpoint in TF1`.
print_checkpoint('tf2-ckpt-saved-in-session-1')
converted_path = convert_tf2_to_tf1('tf2-ckpt-saved-in-session-1',
'converted-tf2-to-tf1')
print("\n[Converted]")
print_checkpoint(converted_path)
# Try loading the converted checkpoint.
with tf.Graph().as_default() as g:
a = tf1.get_variable('a', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(0))
b = tf1.get_variable('b', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(0))
with tf1.variable_scope('scoped'):
c = tf1.get_variable('c', shape=[], dtype=tf.float32,
initializer=tf1.constant_initializer(0))
with tf1.Session() as sess:
saver = tf1.train.Saver([a, b, c])
saver.restore(sess, converted_path)
print("\nRestored [a, b, c]: ", sess.run([a, b, c]))
Checkpoint at 'tf2-ckpt-saved-in-session-1': (key='var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0) (key='var_list/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0) (key='var_list/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0) (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n \n\x0c\x08\x01\x12\x08var_list\n\x10\x08\x02\x12\x0csave_counter\n\x1c\n\x05\x08\x03\x12\x01a\n\x05\x08\x04\x12\x01b\n\x0c\x08\x05\x12\x08scoped/c\nI\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE\n<\x12:\n\x0eVARIABLE_VALUE\x12\x01a\x1a%var_list/a/.ATTRIBUTES/VARIABLE_VALUE\n<\x12:\n\x0eVARIABLE_VALUE\x12\x01b\x1a%var_list/b/.ATTRIBUTES/VARIABLE_VALUE\nK\x12I\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a-var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE") (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1) WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone. [Converted] Checkpoint at 'converted-tf2-to-tf1': (key='scoped/c', shape=[], dtype=float32, value=3.0) (key='a', shape=[], dtype=float32, value=1.0) (key='b', shape=[], dtype=float32, value=2.0) INFO:tensorflow:Restoring parameters from converted-tf2-to-tf1 Restored [a, b, c]: [1.0, 2.0, 3.0]