모델 체크포인트 마이그레이션하기

TensorFlow.org에서 보기 Google Colab에서 실행하기 GitHub에서 보기 노트북 다운로드하기

참고: tf.compat.v1.Saver를 사용하여 저장한 체크포인트는 종종 TF1 또는 이름 기반 체크포인트라고 합니다. tf.train.Checkpoint를 사용하여 저장한 체크포인트는 TF2 또는 객체 기반 체크포인트라고 합니다.

개요

이 가이드에서는 tf.compat.v1.Saver로 체크포인트를 저장하고 로드하는 모델이 있으며, TF2 tf.train.Checkpoint API를 사용하거나 TF2 모델의 기존 체크포인트를 사용하려는 경우를 가정합니다.

다음은 발생할 수 있는 몇 가지 일반적인 시나리오입니다.

시나리오 1

이전 훈련에서 실행하는 기존 TF1 체크포인트가 있으며 이를 TF2로 로드하거나 변환해야 합니다.

시나리오 2

변수 이름과 경로를 변경할 위험이 있는 방식으로 모델을 조정하고 있으며(예: get_variable에서 명시적 tf.Variable를 생성하도록 점진적으로 마이그레이션하는 경우) 작업 도중에 기존 체크포인트의 저장/로드를 유지하려고 합니다.

모델 마이그레이션을 진행하는 동안 체크포인트 호환성을 유지하는 방법 섹션을 참조하세요.

시나리오 3

훈련 코드와 체크포인트를 TF2로 마이그레이션하고 있지만 현재 추론 파이프라인은 프로덕션 안정성을 위해 계속해서 TF1 체크포인트를 필요로 합니다.

옵션 1

훈련을 진행할 때 TF1과 TF2 체크포인트를 모두 저장합니다.

옵션 2

TF2 체크포인트를 TF1로 변환합니다.


아래의 예제는 TF1/TF2에서 체크포인트를 저장하고 로드하는 모든 조합을 보여주기에 모델을 마이그레이션하는 방법을 결정할 때 약간의 유연성을 가질 수 있습니다.

설치하기

import tensorflow as tf
import tensorflow.compat.v1 as tf1

def print_checkpoint(save_path):
  reader = tf.train.load_checkpoint(save_path)
  shapes = reader.get_variable_to_shape_map()
  dtypes = reader.get_variable_to_dtype_map()
  print(f"Checkpoint at '{save_path}':")
  for key in shapes:
    print(f"  (key='{key}', shape={shapes[key]}, dtype={dtypes[key].name}, "
          f"value={reader.get_tensor(key)})")
2022-12-14 21:01:23.911576: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 21:01:23.911667: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 21:01:23.911676: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

TF1에서 TF2로 변경하기

이 섹션은 TF1과 TF2 사이에 변경된 사항과 "이름 기반"(TF1) 대비 "객체 기반"(TF2) 체크포인트의 의미에 대해 궁금해하는 내용을 포함합니다.

두 가지 유형의 체크포인트는 실제로 동일한 형식으로 저장되며 이는 본질적으로 키-값 테이블입니다. 키가 생성되는 방식만 다릅니다.

이름 기반 체크포인트의 키는 변수 이름입니다. 객체 기반 체크포인트의 키는 루트 객체에서 변수까지의 경로를 나타냅니다. 아래의 예제를 통해 이것이 의미하는 바를 더 잘 이해할 수 있습니다.

먼저 일부 체크포인트를 저장합니다.

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    saver = tf1.train.Saver()
    sess.run(a.assign(1))
    sess.run(b.assign(2))
    sess.run(c.assign(3))
    saver.save(sess, 'tf1-ckpt')

print_checkpoint('tf1-ckpt')
Checkpoint at 'tf1-ckpt':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)
a = tf.Variable(5.0, name='a')
b = tf.Variable(6.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(7.0, name='c')

ckpt = tf.train.Checkpoint(variables=[a, b, c])
save_path_v2 = ckpt.save('tf2-ckpt')
print_checkpoint(save_path_v2)
Checkpoint at 'tf2-ckpt-1':
  (key='variables/2/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=7.0)
  (key='variables/1/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=6.0)
  (key='variables/0/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=5.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n%\n\r\x08\x01\x12\tvariables\n\x10\x08\x02\x12\x0csave_counter*\x02\x08\x01\n\x19\n\x05\x08\x03\x12\x010\n\x05\x08\x04\x12\x011\n\x05\x08\x05\x12\x012*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nA\x12;\n\x0eVARIABLE_VALUE\x12\x01a\x1a&variables/0/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nA\x12;\n\x0eVARIABLE_VALUE\x12\x01b\x1a&variables/1/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nH\x12B\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a&variables/2/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

tf2-ckpt의 키를 보면 모두 각 변수의 개체 경로를 참조한다는 것을 알 수 있습니다. 예를 들어, 변수 avariables 목록의 첫 번째 요소이므로 해당 키는 variables/0/...이 됩니다(.ATTRIBUTES/VARIABLE_VALUE 상수는 무시해도 됨).

Checkpoint 객체를 자세히 살펴보면 아래와 같습니다.

a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
root = ckpt = tf.train.Checkpoint(variables=[a, b, c])
print("root type =", type(root).__name__)
print("root.variables =", root.variables)
print("root.variables[0] =", root.variables[0])
root type = Checkpoint
root.variables = ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>])
root.variables[0] = <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>

아래 스니펫으로 실험해보고 객체 구조에 따라 체크포인트 키가 어떻게 변경되는지 확인해 보세요.

module = tf.Module()
module.d = tf.Variable(0.)
test_ckpt = tf.train.Checkpoint(v={'a': a, 'b': b}, 
                                c=c,
                                module=module)
test_ckpt_path = test_ckpt.save('root-tf2-ckpt')
print_checkpoint(test_ckpt_path)
Checkpoint at 'root-tf2-ckpt-1':
  (key='v/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)
  (key='v/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)
  (key='module/d/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='c/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=0.0)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n0\n\x05\x08\x01\x12\x01c\n\n\x08\x02\x12\x06module\n\x05\x08\x03\x12\x01v\n\x10\x08\x04\x12\x0csave_counter*\x02\x08\x01\n>\x128\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a\x1cc/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n\x0b\n\x05\x08\x05\x12\x01d*\x02\x08\x01\n\x12\n\x05\x08\x06\x12\x01a\n\x05\x08\x07\x12\x01b*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nE\x12?\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a#module/d/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a\x1ev/a/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a\x1ev/b/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

TF2가 이 메커니즘을 사용하는 이유는 무엇인가요?

TF2에는 더 이상 전역 그래프가 없기 때문에 변수 이름을 신뢰할 수 없으며 프로그램이 서로 일관성이 없을 수 있습니다. TF2는 변수는 레이어가 소유하고, 레이어는 모델이 소유하는 객체 지향 모델링 접근 방식을 권장합니다.

variable = tf.Variable(...)
layer.variable_name = variable
model.layer_name = layer

모델 마이그레이션을 진행하는 동안 체크포인트 호환성을 유지하는 방법

모든 변수가 올바른 값으로 초기화되었는지 확인하면 작업/함수가 올바른 계산을 수행하고 있는지 검증할 수 있기에 마이그레이션 프로세스에서 중요합니다. 이렇게 하려면 다양한 마이그레이션 단계에서 모델 간의 체크포인트 호환성을 고려해야 합니다. 기본적으로 이 섹션은 모델을 변경하는 동안 어떻게 해야 동일한 체크포인트를 계속 사용할 수 있나요라는 질문에 답합니다.

다음은 유연성을 높이기 위해 체크포인트 호환성을 유지하는 세 가지 방법입니다.

  1. 모델이 이전과 동일한 변수 이름을 갖습니다.
  2. 모델이 다른 변수 이름을 가지며, 체크포인트의 변수 이름을 새 이름에 매핑하는 할당 매핑을 유지합니다.
  3. 모델이 다른 변수 이름을 가지며, 모든 변수를 저장하는 TF2 체크포인트 객체를 유지합니다.

변수 이름이 일치하는 경우

긴 제목: 변수 이름이 일치할 때 체크포인트를 다시 사용하는 방법

짧은 대답: tf1.train.Saver 또는 tf.train.Checkpoint를 사용하여 기존 체크포인트를 직접 로드할 수 있습니다.


tf.compat.v1.keras.utils.track_tf1_style_variables를 사용하는 경우 모델 변수 이름이 이전과 동일한지 확인합니다. 변수 이름이 일치하는지 수동으로 확인할 수도 있습니다.

마이그레이션된 모델의 변수 이름이 일치하면 tf.train.Checkpoint 또는 tf.compat.v1.train.Saver를 직접 사용하여 체크포인트를 로드할 수도 있습니다. 두 API 모두 Eager 모드 및 그래프 모드와 호환되므로 마이그레이션의 모든 단계에서 사용할 수 있습니다.

참고: tf.train.Checkpoint를 사용하여 TF1 체크포인트를 로드할 수 있지만 복잡한 이름 일치 작업 없이 tf.compat.v1.Saver를 사용하여 TF2 체크포인트를 로드할 수 없습니다.

다음은 다른 모델에 동일한 체크포인트를 사용하는 예제입니다. 먼저 tf1.train.Saver를 사용하여 TF1 체크포인트를 저장합니다.

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    saver = tf1.train.Saver()
    sess.run(a.assign(1))
    sess.run(b.assign(2))
    sess.run(c.assign(3))
    save_path = saver.save(sess, 'tf1-ckpt')
print_checkpoint(save_path)
Checkpoint at 'tf1-ckpt':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)

아래 예제는 tf.compat.v1.Saver를 사용하여 Eager 모드로 체크포인트를 로드합니다.

a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0.0, name='c')

# With the removal of collections in TF2, you must pass in the list of variables
# to the Saver object:
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path=save_path)
print(f"loaded values of [a, b, c]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}]")

# Saving also works in eager (sess must be None).
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.
INFO:tensorflow:Restoring parameters from tf1-ckpt
loaded values of [a, b, c]:  [1.0, 2.0, 3.0]
Checkpoint at 'tf1-ckpt-saved-in-eager':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)

다음 스니펫은 TF2 API tf.train.Checkpoint를 사용하여 체크포인트를 로드합니다.

a = tf.Variable(0.0, name='a')
b = tf.Variable(0.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0.0, name='c')

# Without the name_scope, name="scoped/c" works too:
c_2 = tf.Variable(0.0, name='scoped/c')

print("Variable names: ")
print(f"  a.name = {a.name}")
print(f"  b.name = {b.name}")
print(f"  c.name = {c.name}")
print(f"  c_2.name = {c_2.name}")

# Restore the values with tf.train.Checkpoint
ckpt = tf.train.Checkpoint(variables=[a, b, c, c_2])
ckpt.restore(save_path)
print(f"loaded values of [a, b, c, c_2]:  [{a.numpy()}, {b.numpy()}, {c.numpy()}, {c_2.numpy()}]")
Variable names: 
  a.name = a:0
  b.name = b:0
  c.name = scoped/c:0
  c_2.name = scoped/c:0
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/checkpoint/checkpoint.py:1473: NameBasedSaverStatus.__init__ (from tensorflow.python.checkpoint.checkpoint) is deprecated and will be removed in a future version.
Instructions for updating:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.
loaded values of [a, b, c, c_2]:  [1.0, 2.0, 3.0, 3.0]

TF2의 변수 이름

  • 여전히 변수에는 모두 설정할 수 있는 name 인수가 있습니다.
  • 또한 Keras 모델은 변수의 접두사로 설정하는 name 인수를 사용합니다.
  • v1.name_scope 함수를 변수 이름의 접두어를 지정하는데 사용할 수 있습니다. 이 함수는 tf.variable_scope와는 매우 다릅니다. 이름에만 영향을 미치며 변수를 추적하거나 재사용을 관장하지 않습니다.

tf.compat.v1.keras.utils.track_tf1_style_variables 데코레이터는 tf.variable_scopetf.compat.v1.get_variable의 이름 지정 및 재사용 의미 체계가 변경되지 않도록 유지하여 변수 이름과 TF1 체크포인트 호환성을 유지하는 데 도움이 되는 shim입니다. 자세한 정보는 모델 매핑 가이드를 참조하세요.

참고 1: shim을 사용하는 경우 TF2 API를 사용하여 체크포인트를 로드합니다(사전 훈련된 TF1 체크포인트를 사용하는 경우에도).

체크포인트 Keras 섹션을 참조하세요.

참고 2: get_variable에서 tf.Variable로 마이그레이션하는 경우:

Shim으로 데코레이팅한 레이어 또는 모듈이 tf.compat.v1.get_variable 대신 tf.Variable을 사용하는 일부 변수(또는 Keras 레이어/모델)로 구성되어 있고 객체 지향 방식의 속성와 추적이 첨부된 경우 TF1.x 그래프/세션 대 Eager 실행을 진행하는 동안 다른 변수 이름 지정 의미 체계를 가질 수 있습니다.

즉, TF2에서 실행할 경우 이름이 예상과 다를 수 있습니다.

경고: Eager 실행에서 변수의 이름은 중복될 수 있으며 이는 이름 기반 체크포인트의 여러 변수를 동일한 이름에 매핑해야 하는 경우 문제를 일으킬 수 있습니다. tf.name_scope 및 레이어 생성자 또는 tf.Variable name 인수를 사용하여 레이어 및 변수 이름을 명시적으로 조정하여 변수 이름을 조정하고 중복이 없는지 확인할 수 있습니다.

할당 매핑 유지 관리하기

할당 매핑은 일반적으로 TF1 모델 사이에 가중치를 전이할 때 사용하며 변수 이름이 변경되는 경우 모델 마이그레이션을 진행할 때에도 사용할 수 있습니다.

이 매핑은 tf.compat.v1.train.init_from_checkpoint, tf.compat.v1.train.Saver와 함께 사용할 수 있으며, tf.train.load_checkpoint를 사용하여 변수 또는 범위 이름이 변경되었을 수 있는 모델에 가중치를 로드할 수 있습니다.

이 섹션의 예제에서는 이전에 저장된 체크포인트를 사용합니다.

print_checkpoint('tf1-ckpt')
Checkpoint at 'tf1-ckpt':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)

init_from checkpoint를 사용하여 로드하기

tf1.train.init_from_checkpoint는 할당 연산을 생성하는 대신 변수 이니셜라이저에 값을 배치하기 때문에 그래프/세션에 있는 동안 호출해야 합니다.

assignment_map 인수를 사용하여 변수가 로드되는 방식을 구성할 수 있습니다. 다음 문서 내용을 확인하세요.

할당 매핑이 지원하는 구문:

  • 'checkpoint_scope_name/': 'scope_name/' - 텐서 이름이 일치하는 checkpoint_scope_name의 현재 scope_name에 있는 모든 변수를 로드합니다.
  • 'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name' - checkpoint_scope_name/some_other_variable에서 scope_name/variable_name 변수를 초기화합니다.
  • 'scope_variable_name': variable - 체크포인트에서 텐서 'scope_variable_name'을 사용하여 주어진 tf.Variable 객체를 초기화합니다.
  • 'scope_variable_name': list(variable) - 체크포인트에서 텐서 'scope_variable_name'을 사용하여 분할된 변수 목록을 초기화합니다.
  • '/': 'scope_name/' - 체크포인트의 루트에서 현재 scope_name의 모든 변수를 로드합니다(예: 범위 없음).
# Restoring with tf1.train.init_from_checkpoint:

# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
  with tf1.variable_scope('new_scope'):
    a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    # The assignment map will remap all variables in the checkpoint to the
    # new scope:
    tf1.train.init_from_checkpoint(
        'tf1-ckpt',
        assignment_map={'/': 'new_scope/'})
    # `init_from_checkpoint` adds the initializers to these variables.
    # Use `sess.run` to run these initializers.
    sess.run(tf1.global_variables_initializer())

    print("Restored [a, b, c]: ", sess.run([a, b, c]))
Restored [a, b, c]:  [1.0, 2.0, 3.0]

tf1.train.Saver을 사용하여 로드하기

init_from_checkpoint와 달리 tf.compat.v1.train.Saver는 그래프와 Eager 모드에서 모두 실행됩니다. var_list 인수는 변수 이름을 tf.Variable 객체에 매핑해야 한다는 점을 제외하고 선택적으로 사전을 허용합니다.

# Restoring with tf1.train.Saver (works in both graph and eager):

# A new model with a different scope for the variables.
with tf1.variable_scope('new_scope'):
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                      initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                      initializer=tf1.zeros_initializer())
  c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
# Initialize the saver with a dictionary with the original variable names:
saver = tf1.train.Saver({'a': a, 'b': b, 'scoped/c': c})
saver.restore(sess=None, save_path='tf1-ckpt')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.
INFO:tensorflow:Restoring parameters from tf1-ckpt
Restored [a, b, c]:  [1.0, 2.0, 3.0]

tf.train.load_checkpoint을 사용하여 로드하기

이 옵션은 변수 값을 정밀하게 제어해야 하는 경우에 적합합니다. 다시 말하지만 이 옵션은그래프 모드와 Eager 모드 모두에서 작동합니다.

# Restoring with tf.train.load_checkpoint (works in both graph and eager):

# A new model with a different scope for the variables.
with tf.Graph().as_default() as g:
  with tf1.variable_scope('new_scope'):
    a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
    c = tf1.get_variable('scoped/c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    # It may be easier writing a loop if your model has a lot of variables.
    reader = tf.train.load_checkpoint('tf1-ckpt')
    sess.run(a.assign(reader.get_tensor('a')))
    sess.run(b.assign(reader.get_tensor('b')))
    sess.run(c.assign(reader.get_tensor('scoped/c')))
    print("Restored [a, b, c]: ", sess.run([a, b, c]))
Restored [a, b, c]:  [1.0, 2.0, 3.0]

TF2 체크포인트 객체 유지 관리하기

마이그레이션을 진행하는 동안 변수 및 범위 이름이 많이 변경될 수 있는 경우 tf.train.Checkpoint 및 TF2 체크포인트를 사용합니다. TF2는 변수 이름 대신 객체 구조를 사용합니다(자세한 내용은 TF1에서 TF2로의 변경사항 참조).

즉, tf.train.Checkpoint를 생성하여 체크포인트를 저장하거나 복원할 때 동일한 순서(목록의 경우)와 (Checkpoint 이니셜라이저에 대한 사전 및 키워드 인수의 경우)를 사용하는지 확인합니다. 체크포인트 호환성의 몇 가지 예제는 다음과 같습니다.

ckpt = tf.train.Checkpoint(foo=[var_a, var_b])

# compatible with ckpt
tf.train.Checkpoint(foo=[var_a, var_b])

# not compatible with ckpt
tf.train.Checkpoint(foo=[var_b, var_a])
tf.train.Checkpoint(bar=[var_a, var_b])

아래 코드 샘플은 "동일한" tf.train.Checkpoint를 사용하여 다른 이름의 변수를 로드하는 방식을 보여줍니다. 먼저 TF2 체크포인트를 저장합니다.

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(1))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(2))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(3))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("[a, b, c]: ", sess.run([a, b, c]))

    # Save a TF2 checkpoint
    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
    tf2_ckpt_path = ckpt.save('tf2-ckpt')
    print_checkpoint(tf2_ckpt_path)
[a, b, c]:  [1.0, 2.0, 3.0]
Checkpoint at 'tf2-ckpt-1':
  (key='unscoped/1/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)
  (key='unscoped/0/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)
  (key='scoped/0/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n0\n\n\x08\x01\x12\x06scoped\n\x0c\x08\x02\x12\x08unscoped\n\x10\x08\x03\x12\x0csave_counter*\x02\x08\x01\n\x0b\n\x05\x08\x04\x12\x010*\x02\x08\x01\n\x12\n\x05\x08\x05\x12\x010\n\x05\x08\x06\x12\x011*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nE\x12?\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a#scoped/0/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01a\x1a%unscoped/0/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01b\x1a%unscoped/1/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

변수/범위 이름이 변경되더라도 tf.train.Checkpoint를 계속 사용할 수 있습니다.

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a_different_name', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  b = tf1.get_variable('b_different_name', shape=[], dtype=tf.float32, 
                       initializer=tf1.zeros_initializer())
  with tf1.variable_scope('different_scope'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.zeros_initializer())
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("Initialized [a, b, c]: ", sess.run([a, b, c]))

    ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])
    # `assert_consumed` validates that all checkpoint objects are restored from
    # the checkpoint. `run_restore_ops` is required when running in a TF1
    # session.
    ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()

    # Removing `assert_consumed` is fine if you want to skip the validation.
    # ckpt.restore(tf2_ckpt_path).run_restore_ops()

    print("Restored [a, b, c]: ", sess.run([a, b, c]))
Initialized [a, b, c]:  [0.0, 0.0, 0.0]
Restored [a, b, c]:  [1.0, 2.0, 3.0]

Eager 모드에서는 다음과 같습니다.

a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])

# The keys "scoped" and "unscoped" are no longer relevant, but are used to
# maintain compatibility with the saved checkpoints.
ckpt = tf.train.Checkpoint(unscoped=[a, b], scoped=[c])

ckpt.restore(tf2_ckpt_path).assert_consumed().run_restore_ops()
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
Initialized [a, b, c]:  [0.0, 0.0, 0.0]
Restored [a, b, c]:  [1.0, 2.0, 3.0]

Estimator의 TF2 체크포인트

위의 섹션에서는 모델을 마이그레이션하는 동안 체크포인트 호환성을 유지하는 방식을 설명합니다. 체크포인트가 저장/로드되는 방식이 약간 다르지만 이러한 개념은 Estimator 모델에도 적용됩니다. TF2 API를 사용하도록 Estimator 모델을 마이그레이션할 때 모델이 계속 Estimator를 사용하는 동안 TF1에서 TF2 체크포인트로 전환하길 바랄 수 있습니다. 이 섹션은 그렇게 하는 방법을 보여줍니다.

tf.estimator.EstimatorMonitoredSession에는 scaffold라는 tf.compat.v1.train.Scaffold 객체 저장 메커니즘이 있습니다. ScaffoldEstimatorMonitoredSession을 사용하여 TF1 또는 TF2 스타일의 체크포인트를 저장합니다.

# A model_fn that saves a TF1 checkpoint
def model_fn_tf1_ckpt(features, labels, mode):
  # This model adds 2 to the variable `v` in every train step.
  train_step = tf1.train.get_or_create_global_step()
  v = tf1.get_variable('var', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  return tf.estimator.EstimatorSpec(
      mode,
      predictions=v,
      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
      loss=tf.constant(1.),
      scaffold=None
  )

!rm -rf est-tf1
est = tf.estimator.Estimator(model_fn_tf1_ckpt, 'est-tf1')

def train_fn():
  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)

latest_checkpoint = tf.train.latest_checkpoint('est-tf1')
print_checkpoint(latest_checkpoint)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'est-tf1', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into est-tf1/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into est-tf1/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:Loss for final step: 1.0.
Checkpoint at 'est-tf1/model.ckpt-1':
  (key='var', shape=[], dtype=float32, value=2.0)
  (key='global_step', shape=[], dtype=int64, value=1)
# A model_fn that saves a TF2 checkpoint
def model_fn_tf2_ckpt(features, labels, mode):
  # This model adds 2 to the variable `v` in every train step.
  train_step = tf1.train.get_or_create_global_step()
  v = tf1.get_variable('var', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  ckpt = tf.train.Checkpoint(var_list={'var': v}, step=train_step)
  return tf.estimator.EstimatorSpec(
      mode,
      predictions=v,
      train_op=tf.group(v.assign_add(2), train_step.assign_add(1)),
      loss=tf.constant(1.),
      scaffold=tf1.train.Scaffold(saver=ckpt)
  )

!rm -rf est-tf2
est = tf.estimator.Estimator(model_fn_tf2_ckpt, 'est-tf2',
                             warm_start_from='est-tf1')

def train_fn():
  return tf.data.Dataset.from_tensor_slices(([1,2,3], [4,5,6]))
est.train(train_fn, steps=1)

latest_checkpoint = tf.train.latest_checkpoint('est-tf2')
print_checkpoint(latest_checkpoint)  

assert est.get_variable_value('var_list/var/.ATTRIBUTES/VARIABLE_VALUE') == 4
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'est-tf2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='est-tf1', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: est-tf1
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 1 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into est-tf2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into est-tf2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:Loss for final step: 1.0.
Checkpoint at 'est-tf2/model.ckpt-1':
  (key='var_list/var/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=4.0)
  (key='step/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n\x1c\n\x08\x08\x01\x12\x04step\n\x0c\x08\x02\x12\x08var_list*\x02\x08\x01\nD\x12>\n\x0eVARIABLE_VALUE\x12\x0bglobal_step\x1a\x1fstep/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n\r\n\x07\x08\x03\x12\x03var*\x02\x08\x01\nD\x12>\n\x0eVARIABLE_VALUE\x12\x03var\x1a'var_list/var/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

v의 최종 값은 est-tf1에서 웜 스타트된 후 추가 5단계를 훈련하여 16가 되어야 합니다. 훈련 단계 값은 warm_start 체크포인트로부터 전달되지 않습니다.

Keras 체크포인트하기

Keras로 빌드한 모델은 여전히 tf1.train.Savertf.train.Checkpoint를 사용하여 기존 가중치를 로드합니다. 모델이 완전히 마이그레이션되면 특히 ModelCheckpoint 콜백을 사용하여 훈련하는 경우 model.save_weightsmodel.load_weights를 사용하도록 전환합니다.

체크포인트와 Keras에 대해 알아야 할 몇 가지 사항:

초기화 vs 빌드

Keras 모델과 레이어는 완전히 생성되기 전에 2단계를 거쳐야 합니다. 첫 번째 단계는 Python 객체의 초기화layer = tf.keras.layers.Dense(x)입니다. 두 번째 단계는 대부분의 가중치가 실제로 생성되는 빌드 단계인 layer.build(input_shape)입니다. 모델을 호출하거나 단일 train, eval 또는 predict 단계(처음에만 해당)를 실행하여 모델을 빌드할 수도 있습니다.

model.load_weights(path).assert_consumed()에서 오류가 발생하는 것이 확인되면 모델/레이어가 빌드되지 않았을 가능성이 높습니다.

TF2 체크포인트를 사용하는 Keras

tf.train.Checkpoint(model).writemodel.save_weights와 동일합니다. tf.train.Checkpoint(model).readmodel.load_weights도 동일합니다. Checkpoint(model) != Checkpoint(model=model)에 유의해야 합니다.

Keras의 build() 단계와 함께 동작하는 TF2 체크포인트

tf.train.Checkpoint.restore에는 지연된 복원이라는 메커니즘이 있으며 변수가 아직 생성되지 않은 경우 tf.Module과 Keras 객체를 사용하여 변수 값을 저장합니다. 이렇게 하면 초기화된 모델로 가중치를 로드하고 나중에 빌드할 수 있습니다.

m = YourKerasModel()
status = m.load_weights(path)

# This call builds the model. The variables are created with the restored
# values.
m.predict(inputs)

status.assert_consumed()

이 메커니즘 때문에 Keras 모델과 함께 TF2 체크포인트 로드 API를 사용하는 것이 좋습니다(기존 TF1 체크포인트를 모델 매핑 shim으로 복원하는 경우도 동일). 자세한 내용은 체크포인트 가이드를 참조합니다.

코드 스니펫

아래의 스니펫은 체크포인트 저장 API의 TF1/TF2 버전 호환성을 보여줍니다.

TF2에서 TF1 체크포인트 저장하기

a = tf.Variable(1.0, name='a')
b = tf.Variable(2.0, name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(3.0, name='c')

saver = tf1.train.Saver(var_list=[a, b, c])
path = saver.save(sess=None, save_path='tf1-ckpt-saved-in-eager')
print_checkpoint(path)
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.
Checkpoint at 'tf1-ckpt-saved-in-eager':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)

TF2에서 TF1 체크포인트 로드하기

a = tf.Variable(0., name='a')
b = tf.Variable(0., name='b')
with tf.name_scope('scoped'):
  c = tf.Variable(0., name='c')
print("Initialized [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
saver = tf1.train.Saver(var_list=[a, b, c])
saver.restore(sess=None, save_path='tf1-ckpt-saved-in-eager')
print("Restored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
Initialized [a, b, c]:  [0.0, 0.0, 0.0]
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.
INFO:tensorflow:Restoring parameters from tf1-ckpt-saved-in-eager
Restored [a, b, c]:  [1.0, 2.0, 3.0]

TF1에서 TF2 체크포인트 저장하기

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(1))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(2))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(3))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    ckpt = tf.train.Checkpoint(
        var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
    tf2_in_tf1_path = ckpt.save('tf2-ckpt-saved-in-session')
    print_checkpoint(tf2_in_tf1_path)
Checkpoint at 'tf2-ckpt-saved-in-session-1':
  (key='var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)
  (key='var_list/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)
  (key='var_list/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n$\n\x0c\x08\x01\x12\x08var_list\n\x10\x08\x02\x12\x0csave_counter*\x02\x08\x01\n \n\x05\x08\x03\x12\x01a\n\x05\x08\x04\x12\x01b\n\x0c\x08\x05\x12\x08scoped/c*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01a\x1a%var_list/a/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01b\x1a%var_list/b/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nO\x12I\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a-var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

TF1에서 TF2 체크포인트 로드하기

with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(0))
  with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    print("Initialized [a, b, c]: ", sess.run([a, b, c]))
    ckpt = tf.train.Checkpoint(
        var_list={v.name.split(':')[0]: v for v in tf1.global_variables()})
    ckpt.restore('tf2-ckpt-saved-in-session-1').run_restore_ops()
    print("Restored [a, b, c]: ", sess.run([a, b, c]))
Initialized [a, b, c]:  [0.0, 0.0, 0.0]
Restored [a, b, c]:  [1.0, 2.0, 3.0]

체크포인트 변환하기

체크포인트를 로드하고 다시 저장하는 방식으로 TF1과 TF2 사이의 체크포인트를 변환할 수 있습니다. 대안은 아래 코드에 표시된 tf.train.load_checkpoint입니다.

TF1 체크포인트를 TF2로 변환하기

def convert_tf1_to_tf2(checkpoint_path, output_prefix):
  """Converts a TF1 checkpoint to TF2.

  To load the converted checkpoint, you must build a dictionary that maps
  variable names to variable objects.
  ```
  ckpt = tf.train.Checkpoint(vars={name: variable})  
  ckpt.restore(converted_ckpt_path)

    ```

    Args:
      checkpoint_path: Path to the TF1 checkpoint.
      output_prefix: Path prefix to the converted checkpoint.

    Returns:
      Path to the converted checkpoint.
    """
    vars = {}
    reader = tf.train.load_checkpoint(checkpoint_path)
    dtypes = reader.get_variable_to_dtype_map()
    for key in dtypes.keys():
      vars[key] = tf.Variable(reader.get_tensor(key))
    return tf.train.Checkpoint(vars=vars).save(output_prefix)
  ```

스니펫 `Save a TF1 checkpoint in TF2`에 저장된 체크포인트를 변환합니다.


```python
# Make sure to run the snippet in `Save a TF1 checkpoint in TF2`.
print_checkpoint('tf1-ckpt-saved-in-eager')
converted_path = convert_tf1_to_tf2('tf1-ckpt-saved-in-eager', 
                                     'converted-tf1-to-tf2')
print("\n[Converted]")
print_checkpoint(converted_path)

# Try loading the converted checkpoint.
a = tf.Variable(0.)
b = tf.Variable(0.)
c = tf.Variable(0.)
ckpt = tf.train.Checkpoint(vars={'a': a, 'b': b, 'scoped/c': c})
ckpt.restore(converted_path).assert_consumed()
print("\nRestored [a, b, c]: ", [a.numpy(), b.numpy(), c.numpy()])
Checkpoint at 'tf1-ckpt-saved-in-eager':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)

[Converted]
Checkpoint at 'converted-tf1-to-tf2-1':
  (key='vars/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)
  (key='vars/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)
  (key='vars/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n \n\x08\x08\x01\x12\x04vars\n\x10\x08\x02\x12\x0csave_counter*\x02\x08\x01\n \n\x0c\x08\x03\x12\x08scoped/c\n\x05\x08\x04\x12\x01b\n\x05\x08\x05\x12\x01a*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nK\x12E\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a)vars/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nC\x12=\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a!vars/b/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nC\x12=\n\x0eVARIABLE_VALUE\x12\x08Variable\x1a!vars/a/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")

Restored [a, b, c]:  [1.0, 2.0, 3.0]

TF2 체크포인트를 TF1로 변환하기

def convert_tf2_to_tf1(checkpoint_path, output_prefix):
  """Converts a TF2 checkpoint to TF1.

  The checkpoint must be saved using a 
  `tf.train.Checkpoint(var_list={name: variable})`

  To load the converted checkpoint with `tf.compat.v1.Saver`:
  ```
  saver = tf.compat.v1.train.Saver(var_list={name: variable}) 

  # An alternative, if the variable names match the keys:
  saver = tf.compat.v1.train.Saver(var_list=[variables]) 
  saver.restore(sess, output_path)

    ```
    """
    vars = {}
    reader = tf.train.load_checkpoint(checkpoint_path)
    dtypes = reader.get_variable_to_dtype_map()
    for key in dtypes.keys():
      # Get the "name" from the 
      if key.startswith('var_list/'):
        var_name = key.split('/')[1]
        # TF2 checkpoint keys use '/', so if they appear in the user-defined name,
        # they are escaped to '.S'.
        var_name = var_name.replace('.S', '/')
        vars[var_name] = tf.Variable(reader.get_tensor(key))

    return tf1.train.Saver(var_list=vars).save(sess=None, save_path=output_prefix)
  ```

스니펫 `Save a TF2 checkpoint in TF1`에 저장된 체크포인트를 변환합니다.


```python
# Make sure to run the snippet in `Save a TF2 checkpoint in TF1`.
print_checkpoint('tf2-ckpt-saved-in-session-1')
converted_path = convert_tf2_to_tf1('tf2-ckpt-saved-in-session-1',
                                    'converted-tf2-to-tf1')
print("\n[Converted]")
print_checkpoint(converted_path)

# Try loading the converted checkpoint.
with tf.Graph().as_default() as g:
  a = tf1.get_variable('a', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  b = tf1.get_variable('b', shape=[], dtype=tf.float32, 
                       initializer=tf1.constant_initializer(0))
  with tf1.variable_scope('scoped'):
    c = tf1.get_variable('c', shape=[], dtype=tf.float32, 
                        initializer=tf1.constant_initializer(0))
  with tf1.Session() as sess:
    saver = tf1.train.Saver([a, b, c])
    saver.restore(sess, converted_path)
    print("\nRestored [a, b, c]: ", sess.run([a, b, c]))
Checkpoint at 'tf2-ckpt-saved-in-session-1':
  (key='var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=3.0)
  (key='var_list/b/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=2.0)
  (key='var_list/a/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=float32, value=1.0)
  (key='save_counter/.ATTRIBUTES/VARIABLE_VALUE', shape=[], dtype=int64, value=1)
  (key='_CHECKPOINTABLE_OBJECT_GRAPH', shape=[], dtype=string, value=b"\n$\n\x0c\x08\x01\x12\x08var_list\n\x10\x08\x02\x12\x0csave_counter*\x02\x08\x01\n \n\x05\x08\x03\x12\x01a\n\x05\x08\x04\x12\x01b\n\x0c\x08\x05\x12\x08scoped/c*\x02\x08\x01\nM\x12G\n\x0eVARIABLE_VALUE\x12\x0csave_counter\x1a'save_counter/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01a\x1a%var_list/a/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\n@\x12:\n\x0eVARIABLE_VALUE\x12\x01b\x1a%var_list/b/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01\nO\x12I\n\x0eVARIABLE_VALUE\x12\x08scoped/c\x1a-var_list/scoped.Sc/.ATTRIBUTES/VARIABLE_VALUE*\x02\x08\x01")
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.

[Converted]
Checkpoint at 'converted-tf2-to-tf1':
  (key='scoped/c', shape=[], dtype=float32, value=3.0)
  (key='b', shape=[], dtype=float32, value=2.0)
  (key='a', shape=[], dtype=float32, value=1.0)
INFO:tensorflow:Restoring parameters from converted-tf2-to-tf1

Restored [a, b, c]:  [1.0, 2.0, 3.0]

관련 가이드