Overview
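The tf.distribute.Strategy API is an abstraction for distributing training across multiple processing units. Its goal is to let you run distributed training with only minimal changes to your existing models and training code.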
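This tutorial uses tf.distribute.MirroredStrategy. This strategy performs in-graph replication with synchronous training on multiple GPUs in a single machine. In other words, it copies all of the model's variables to each processor. It then uses all-reduce to combine the gradients from all processors and applies the combined value to every processor's copy of the model.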
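The synchronous update described above can be made concrete with a toy simulation in plain Python. This sketch is not how MirroredStrategy is implemented. The linear model, the helper names (`local_gradient`, `all_reduce_mean`, `mirrored_step`) and the choice to average gradients in the "all-reduce" are illustrative assumptions.

Each replica holds an identical copy of the weight and computes a gradient on its own shard of the batch. The gradients are then combined, and every copy applies the same update, so the copies never drift apart.

```python
# Toy model shared by every replica: y_hat = w * x with squared-error loss.
# For one example, d/dw (w*x - y)**2 = 2 * (w*x - y) * x.


def local_gradient(w, shard):
    """Mean gradient of the squared error over one replica's shard of the batch."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values):
    """Stand-in for all-reduce: every replica ends up with the same averaged value."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def mirrored_step(weights, shards, lr):
    """One synchronous step: local gradients, all-reduce, identical update on every copy."""
    grads = [local_gradient(w, shard) for w, shard in zip(weights, shards)]
    reduced = all_reduce_mean(grads)
    return [w - lr * g for w, g in zip(weights, reduced)]


# Two replicas start from the same (mirrored) value of the variable.
weights = [0.0, 0.0]
# The global batch, drawn from y = 3x, is split into one shard per replica.
shards = [[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0), (4.0, 12.0)]]
for _ in range(50):
    weights = mirrored_step(weights, shards, lr=0.02)
print(weights)  # both copies hold the same value, close to 3.0
```

Both replicas receive the same reduced gradient, so the two entries of `weights` stay exactly equal after every step.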
MirroredStrategy is one of several distribution strategies built into TensorFlow. To read about the other strategies, see the distribution strategy guide.
Keras API
This example uses the tf.keras API to build the model and the training loop. To write your own training loop instead, see the Distributed training with custom training loops tutorial.
Import dependencies
# Import TensorFlow and TensorFlow Datasets
!pip install -q tensorflow-gpu==2.0.0-rc1
import os
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()
Download the dataset
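Download the MNIST dataset from TensorFlow Datasets and load it. This returns a dataset in tf.data format.
Setting with_info to True also loads metadata for the entire dataset, which is saved here in the info variable. Among other things, this metadata includes the number of training and test examples.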
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Define the distribution strategy
Create a MirroredStrategy object. It handles the distribution and provides a context manager (tf.distribute.MirroredStrategy.scope), inside which you must build your model.
strategy = tf.distribute.MirroredStrategy()
WARNING:tensorflow:There is non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1
Set up the input pipeline
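When training a model on multiple GPUs, increase the batch size so that the extra computing power is used effectively. In general, use the largest batch size that fits in GPU memory, and tune the learning rate accordingly.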
# You can also use info.splits.total_num_examples to get the total
# number of examples in the dataset.
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
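Here is a worked example of the batch-size arithmetic above. With `BATCH_SIZE_PER_REPLICA = 64`, the global batch grows with the number of replicas, and each replica still processes 64 examples per step.

One common heuristic for "tuning the learning rate accordingly" is to scale a single-replica learning rate linearly with the global batch size. This linear-scaling rule, the `BASE_LR` value and the helper names below are assumptions for illustration, not something this tutorial prescribes.

```python
BATCH_SIZE_PER_REPLICA = 64


def global_batch_size(num_replicas, per_replica=BATCH_SIZE_PER_REPLICA):
    """Same formula as BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync."""
    return per_replica * num_replicas


def linearly_scaled_lr(base_lr, num_replicas):
    """Hypothetical linear-scaling heuristic: the learning rate grows with the global batch."""
    return base_lr * num_replicas


BASE_LR = 1e-3  # assumed single-replica learning rate (Adam's default), for illustration only
for replicas in (1, 2, 4, 8):
    batch = global_batch_size(replicas)
    print(f"{replicas} replica(s): global batch {batch}, "
          f"{batch // replicas} examples per replica, lr {linearly_scaled_lr(BASE_LR, replicas):g}")
```

Whether linear scaling actually helps depends on the model and optimizer, so treat the scaled value as a starting point for tuning.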
Pixel values range from 0 to 255, so they need to be normalized to the 0-1 range. Define the normalization in a function.
def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label
Apply this function to the training and test data. Then shuffle the training data and batch both datasets.
train_dataset = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
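`shuffle(BUFFER_SIZE)` does not shuffle the whole dataset at once. Instead, it fills a buffer with `BUFFER_SIZE` elements, emits a randomly chosen element from that buffer, and refills the freed slot with the next input element.

The pure-Python sketch below imitates that behavior, together with `scale` and `batch`, on plain lists. The helper names, the fake `(pixel, label)` data and the use of Python's `random` module are assumptions for illustration; this is not tf.data's implementation.

```python
import itertools
import random


def scale(pixel, label):
    """Plain-Python analogue of scale(): map a 0-255 pixel value into [0, 1]."""
    return pixel / 255, label


def buffered_shuffle(elements, buffer_size, rng):
    """Buffer-limited shuffle in the spirit of Dataset.shuffle(buffer_size)."""
    it = iter(elements)
    buffer = list(itertools.islice(it, buffer_size))  # fill the buffer first
    for item in it:
        idx = rng.randrange(len(buffer))
        yield buffer[idx]       # emit a random buffered element...
        buffer[idx] = item      # ...and refill its slot with the next input element
    rng.shuffle(buffer)         # input exhausted: emit the rest in random order
    yield from buffer


def batch(elements, batch_size):
    """Group elements into lists; the final batch may be smaller, as with Dataset.batch()."""
    current = []
    for item in elements:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current


# A tiny fake dataset of (pixel value, label) pairs standing in for MNIST images.
raw = [(i * 25, i % 10) for i in range(10)]
rng = random.Random(0)
pipeline = batch(buffered_shuffle((scale(p, l) for p, l in raw), 4, rng), 3)
for b in pipeline:
    print(b)
```

A buffer of size 1 leaves the order unchanged. More generally, an element can come out at most `buffer_size - 1` positions earlier than it went in, so the shuffle is only as thorough as the buffer is large. This is one reason to use a fairly large `BUFFER_SIZE`, such as the 10000 used here.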
Create the model
Create and compile the Keras model inside the strategy.scope context.
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
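To see what the model above contains, the plain-Python sketch below works out each layer's output shape and parameter count by hand. It assumes the Keras defaults for these layers:

- Conv2D uses 'valid' padding and stride 1.
- MaxPooling2D uses a 2x2 pool with stride 2.
- Every Dense and Conv2D layer has a bias.

The result should agree with what `model.summary()` reports, but the helper functions are only a hand calculation, not a Keras API.

```python
def conv2d(shape, filters, kernel):
    """Conv2D, 'valid' padding, stride 1: each spatial size shrinks by kernel - 1."""
    h, w, c = shape
    params = kernel * kernel * c * filters + filters  # kernel weights + one bias per filter
    return (h - kernel + 1, w - kernel + 1, filters), params


def max_pool(shape, pool=2):
    """MaxPooling2D with pool_size=2, stride 2, 'valid': spatial size halved (floored)."""
    h, w, c = shape
    return (h // pool, w // pool, c), 0


def flatten(shape):
    size = 1
    for d in shape:
        size *= d
    return (size,), 0


def dense(shape, units):
    (n,) = shape
    return (units,), n * units + units  # weight matrix + bias vector


shape = (28, 28, 1)
total = 0
for name, layer in [
    ("Conv2D(32, 3)", lambda s: conv2d(s, 32, 3)),
    ("MaxPooling2D()", max_pool),
    ("Flatten()", flatten),
    ("Dense(64)", lambda s: dense(s, 64)),
    ("Dense(10)", lambda s: dense(s, 10)),
]:
    shape, params = layer(shape)
    total += params
    print(f"{name:15} -> {shape}, params={params}")
print("total params:", total)
```

Almost all of the 347,146 parameters sit in the first Dense layer (5408 x 64 weights plus 64 biases). Inside `strategy.scope()`, each of these variables is mirrored on every replica.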
Define the callbacks
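The callbacks used here are:
- TensorBoard: writes a log for TensorBoard, which lets you visualize the training graphs.
- Model Checkpoint: saves the model after every epoch.
- Learning Rate Scheduler: sets the learning rate at the beginning of every epoch from a schedule function.
To illustrate how to add your own callback, a callback that displays the learning rate in the notebook is added as well.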
# Directory where the checkpoints will be saved.
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5
# Callback that prints the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                          self.model.optimizer.lr.numpy()))
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
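Keras' `LearningRateScheduler` calls the schedule with a 0-based epoch index at the start of each epoch, while `PrintLR` prints `epoch + 1`. The plain-Python loop below needs no TensorFlow. It reproduces the `decay` function and lists the learning rate each 1-based epoch of a 12-epoch run should get, which is a quick way to sanity-check a schedule before training.

```python
def decay(epoch):
    """Same step schedule as above; epoch is the 0-based index Keras passes in."""
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5


EPOCHS = 12
# Keys are 1-based, matching what PrintLR prints.
schedule = {epoch + 1: decay(epoch) for epoch in range(EPOCHS)}
for shown_epoch, lr in schedule.items():
    print(f"epoch {shown_epoch:2d}: lr = {lr:g}")
```

The expected learning rates per epoch are:

- Epochs 1-3: 1e-3
- Epochs 4-7: 1e-4
- Epochs 8-12: 1e-5

These match the rates printed during training. The long decimals in the training log come from the optimizer storing the value as float32.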
Train and evaluate
Now train the model in the usual way: call fit on the model and pass in the dataset created at the beginning of the tutorial. This step is the same whether or not you distribute the training.
model.fit(train_dataset, epochs=12, callbacks=callbacks)
Epoch 1/12
938/Unknown - 17s 18ms/step - loss: 0.2076 - accuracy: 0.9396
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 17s 18ms/step - loss: 0.2076 - accuracy: 0.9396
Epoch 2/12
935/938 [============================>.] - ETA: 0s - loss: 0.0663 - accuracy: 0.9798
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 11s 12ms/step - loss: 0.0663 - accuracy: 0.9798
Epoch 3/12
936/938 [============================>.] - ETA: 0s - loss: 0.0457 - accuracy: 0.9861
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 11s 12ms/step - loss: 0.0457 - accuracy: 0.9862
Epoch 4/12
937/938 [============================>.] - ETA: 0s - loss: 0.0254 - accuracy: 0.9933
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 11s 12ms/step - loss: 0.0254 - accuracy: 0.9933
Epoch 5/12
936/938 [============================>.] - ETA: 0s - loss: 0.0228 - accuracy: 0.9941
Learning rate for epoch 5 is 9.999999747378752e-05
938/938 [==============================] - 11s 12ms/step - loss: 0.0228 - accuracy: 0.9941
Epoch 6/12
937/938 [============================>.] - ETA: 0s - loss: 0.0212 - accuracy: 0.9945
Learning rate for epoch 6 is 9.999999747378752e-05
938/938 [==============================] - 11s 12ms/step - loss: 0.0212 - accuracy: 0.9945
Epoch 7/12
937/938 [============================>.] - ETA: 0s - loss: 0.0196 - accuracy: 0.9952
Learning rate for epoch 7 is 9.999999747378752e-05
938/938 [==============================] - 12s 12ms/step - loss: 0.0196 - accuracy: 0.9952
Epoch 8/12
937/938 [============================>.] - ETA: 0s - loss: 0.0173 - accuracy: 0.9962
Learning rate for epoch 8 is 9.999999747378752e-06
938/938 [==============================] - 11s 12ms/step - loss: 0.0173 - accuracy: 0.9962
Epoch 9/12
936/938 [============================>.] - ETA: 0s - loss: 0.0170 - accuracy: 0.9962
Learning rate for epoch 9 is 9.999999747378752e-06
938/938 [==============================] - 11s 12ms/step - loss: 0.0170 - accuracy: 0.9962
Epoch 10/12
936/938 [============================>.] - ETA: 0s - loss: 0.0168 - accuracy: 0.9963
Learning rate for epoch 10 is 9.999999747378752e-06
938/938 [==============================] - 11s 12ms/step - loss: 0.0168 - accuracy: 0.9963
Epoch 11/12
936/938 [============================>.] - ETA: 0s - loss: 0.0167 - accuracy: 0.9963
Learning rate for epoch 11 is 9.999999747378752e-06
938/938 [==============================] - 11s 12ms/step - loss: 0.0167 - accuracy: 0.9963
Epoch 12/12
937/938 [============================>.] - ETA: 0s - loss: 0.0165 - accuracy: 0.9964
Learning rate for epoch 12 is 9.999999747378752e-06
938/938 [==============================] - 11s 12ms/step - loss: 0.0165 - accuracy: 0.9964
<tensorflow.python.keras.callbacks.History at 0x7ff1c05bf7f0>
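The step counts in the training log follow from the dataset sizes and the global batch size. MNIST has 60,000 training and 10,000 test examples. `batch()` keeps the final partial batch by default (`drop_remainder=False`), so the number of steps is the ceiling of the number of examples divided by the batch size. A quick check in plain Python (the constant and helper names are illustrative):

```python
import math

NUM_TRAIN_EXAMPLES = 60000  # info.splits['train'].num_examples for MNIST
NUM_TEST_EXAMPLES = 10000   # info.splits['test'].num_examples for MNIST


def steps_per_epoch(num_examples, global_batch_size):
    """Number of batches when the last, smaller batch is kept (drop_remainder=False)."""
    return math.ceil(num_examples / global_batch_size)


for replicas in (1, 2, 4):
    global_batch = 64 * replicas  # BATCH_SIZE_PER_REPLICA * num_replicas_in_sync
    print(replicas,
          steps_per_epoch(NUM_TRAIN_EXAMPLES, global_batch),
          steps_per_epoch(NUM_TEST_EXAMPLES, global_batch))
```

With one replica this gives 938 training steps per epoch and 157 evaluation steps, matching the logs in this tutorial. With more replicas, the global batch grows and each epoch needs proportionally fewer steps.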
As you can see below, the checkpoints are being saved.
# Check the checkpoint directory
ls {checkpoint_dir}
checkpoint ckpt_4.data-00000-of-00001 ckpt_1.data-00000-of-00001 ckpt_4.index ckpt_1.index ckpt_5.data-00000-of-00001 ckpt_10.data-00000-of-00001 ckpt_5.index ckpt_10.index ckpt_6.data-00000-of-00001 ckpt_11.data-00000-of-00001 ckpt_6.index ckpt_11.index ckpt_7.data-00000-of-00001 ckpt_12.data-00000-of-00001 ckpt_7.index ckpt_12.index ckpt_8.data-00000-of-00001 ckpt_2.data-00000-of-00001 ckpt_8.index ckpt_2.index ckpt_9.data-00000-of-00001 ckpt_3.data-00000-of-00001 ckpt_9.index ckpt_3.index
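`ModelCheckpoint` fills the `{epoch}` placeholder in `filepath` with a 1-based epoch number, which is why the files run from `ckpt_1` to `ckpt_12`. A plain lexicographic sort, which is what `ls` does, puts `ckpt_10` before `ckpt_2`. `tf.train.latest_checkpoint` avoids this pitfall by reading the `checkpoint` state file rather than sorting file names.

The pure-Python sketch below illustrates the naming and the sorting pitfall. `checkpoint_name` and the numeric sort key are hypothetical helpers, not TensorFlow internals.

```python
def checkpoint_name(template, epoch_index):
    """Mimic ModelCheckpoint's naming: {epoch} becomes the 1-based epoch number."""
    return template.format(epoch=epoch_index + 1)


names = [checkpoint_name("ckpt_{epoch}", i) for i in range(12)]
lexicographic = sorted(names)                                    # plain string sort, as ls does
numeric = sorted(names, key=lambda n: int(n.rsplit("_", 1)[1]))  # sort by the epoch number

print(names[0], names[-1])  # first and last checkpoint written
print(lexicographic[-1])    # a string sort does not end with the newest checkpoint
print(numeric[-1])          # sorting by number does
```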
To see how the model performs, load the latest checkpoint and call evaluate on the test data. As usual, simply call evaluate with the appropriate dataset.
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 3s 18ms/step - loss: 0.0382 - accuracy: 0.9868
Eval loss: 0.03817672325612586, Eval accuracy: 0.9868000149726868
To view the training results, download the TensorBoard logs and run TensorBoard from a terminal:
$ tensorboard --logdir=path/to/log-directory
ls -sh ./logs
total 4.0K 4.0K train
Export to SavedModel
Export the graph and the variables to SavedModel, a platform-agnostic format. Once the model is exported, you can load it either with or without a strategy scope.
path = 'saved_model/'
tf.keras.experimental.export_saved_model(model, path)
WARNING:tensorflow:From <ipython-input-1-7f22af6799f5>:1: export_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version. Instructions for updating: Please use `model.save(..., save_format="tf")` or `tf.keras.models.save_model(..., save_format="tf")`.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version. Instructions for updating: If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/saved_model/signature_def_utils_impl.py:253: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: None
INFO:tensorflow:Signatures INCLUDED in export for Train: ['train']
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
WARNING:tensorflow:Export includes no default signature!
INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: None
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: ['eval']
WARNING:tensorflow:Export includes no default signature!
INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: saved_model/saved_model.pb
Load the model without strategy.scope.
unreplicated_model = tf.keras.experimental.load_from_saved_model(path)
unreplicated_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])
eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)
print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))
WARNING:tensorflow:From <ipython-input-1-2f23d81b2b21>:1: load_from_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version. Instructions for updating: The experimental save and load functions have been deprecated. Please switch to `tf.keras.models.load_model`.
157/157 [==============================] - 2s 10ms/step - loss: 0.0382 - accuracy: 0.9868
Eval loss: 0.03817672325612586, Eval accuracy: 0.9868000149726868
Load the model with strategy.scope.
with strategy.scope():
    replicated_model = tf.keras.experimental.load_from_saved_model(path)
    replicated_model.compile(loss='sparse_categorical_crossentropy',
                             optimizer=tf.keras.optimizers.Adam(),
                             metrics=['accuracy'])
    eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
    print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))
157/157 [==============================] - 3s 17ms/step - loss: 0.0382 - accuracy: 0.9868
Eval loss: 0.03817672325612586, Eval accuracy: 0.9868000149726868
Examples and tutorials
Here are more examples of using a distribution strategy with Keras fit/compile:
- Transformer example trained using tf.distribute.MirroredStrategy.
- NCF example trained using tf.distribute.MirroredStrategy.
More examples are listed in the distribution strategy guide.
Next steps
- Read the distribution strategy guide.
- Read the Distributed training with custom training loops tutorial.