TensorFlow.org에서 보기 | Google Colab에서 실행 | GitHub에서 소스 보기 | 노트북 다운로드 |
이 가이드는 TPU 에서 실행되는 워크플로를 TensorFlow 1의 TPUEstimator
API에서 TensorFlow 2의 TPUStrategy
API로 마이그레이션하는 방법을 보여줍니다.
- TensorFlow 1에서
tf.compat.v1.estimator.tpu.TPUEstimator
API를 사용하면 모델을 훈련 및 평가할 수 있을 뿐만 아니라 추론을 수행하고 (클라우드) TPU에서 모델(제공용)을 저장할 수 있습니다. - TensorFlow 2에서 TPU 및 TPU Pod(전용 고속 네트워크 인터페이스로 연결된 TPU 장치 모음)에서 동기식 훈련을 수행하려면 TPU 배포 전략
tf.distribute.TPUStrategy
를 사용해야 합니다. 이 전략은 모델 구축(tf.keras.Model
), 옵티마이저(tf.keras.optimizers.Optimizer
) 및 교육(Model.fit
)을 포함한 Keras API와 사용자 지정 교육 루프(tf.function
포함)와 함께 작동할 수 있습니다.tf.function
및tf.GradientTape
).
종단 간 TensorFlow 2 예제는 TPU 사용 가이드(즉, TPU의 분류 섹션)와 TPU에서 BERT를 사용하여 GLUE 작업 해결 자습서를 확인하세요. TPUStrategy
를 포함한 모든 TensorFlow 배포 전략을 다루는 Distributed training guide가 유용할 수도 있습니다.
설정
데모용으로 가져오기 및 간단한 데이터세트로 시작합니다.
import tensorflow as tf
import tensorflow.compat.v1 as tf1
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/requests/__init__.py:104: RequestsDependencyWarning: urllib3 (1.26.8) or chardet (2.3.0)/charset_normalizer (2.0.11) doesn't match a supported version! RequestsDependencyWarning)
features = [[1., 1.5]]
labels = [[0.3]]
eval_features = [[4., 4.5]]
eval_labels = [[0.8]]
TensorFlow 1: TPUEstimator를 사용하여 TPU에서 모델 구동
이 가이드 섹션에서는 TensorFlow 1에서 tf.compat.v1.estimator.tpu.TPUEstimator
를 사용하여 교육 및 평가를 수행하는 방법을 보여줍니다.
TPUEstimator
를 사용하려면 먼저 몇 가지 함수를 정의합니다. 학습 데이터에 대한 입력 함수, 평가 데이터에 대한 평가 입력 함수, 학습 작업이 기능 및 레이블로 정의되는 방식을 TPUEstimator
에 알려주는 모델 함수:
def _input_fn(params):
dataset = tf1.data.Dataset.from_tensor_slices((features, labels))
dataset = dataset.repeat()
return dataset.batch(params['batch_size'], drop_remainder=True)
def _eval_input_fn(params):
dataset = tf1.data.Dataset.from_tensor_slices((eval_features, eval_labels))
dataset = dataset.repeat()
return dataset.batch(params['batch_size'], drop_remainder=True)
def _model_fn(features, labels, mode, params):
logits = tf1.layers.Dense(1)(features)
loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
optimizer = tf1.train.AdagradOptimizer(0.05)
train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
return tf1.estimator.tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)
이러한 기능을 정의하고 클러스터 정보를 제공하는 tf.distribute.cluster_resolver.TPUClusterResolver
와 tf.compat.v1.estimator.tpu.RunConfig
객체를 생성합니다. 정의한 모델 함수와 함께 이제 TPUEstimator
를 만들 수 있습니다. 여기에서는 체크포인트 절약을 건너뛰어 흐름을 단순화합니다. 그런 다음 TPUEstimator
에 대한 훈련 및 평가 모두에 대한 배치 크기를 지정합니다.
cluster_resolver = tf1.distribute.cluster_resolver.TPUClusterResolver(tpu='')
print("All devices: ", tf1.config.list_logical_devices('TPU'))
All devices: []
tpu_config = tf1.estimator.tpu.TPUConfig(iterations_per_loop=10)
config = tf1.estimator.tpu.RunConfig(
cluster=cluster_resolver,
save_checkpoints_steps=None,
tpu_config=tpu_config)
estimator = tf1.estimator.tpu.TPUEstimator(
model_fn=_model_fn,
config=config,
train_batch_size=8,
eval_batch_size=8)
WARNING:tensorflow:Estimator's model_fn (<function _model_fn at 0x7fef73ae76a8>) includes params argument, but params are not passed to Estimator. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp_bkua7zf INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp_bkua7zf', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true cluster_def { job { name: "worker" tasks { key: 0 value: "10.240.1.2:8470" } } } isolate_session_state: true , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.240.1.2:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.240.1.2:8470', '_evaluation_master': 'grpc://10.240.1.2:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=10, num_shards=None, num_cores_per_replica=None, per_host_input_for_training=2, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None, eval_training_input_configuration=2, experimental_host_call_every_n_steps=1, experimental_allow_per_host_v2_parallel_get_next=False, experimental_feed_hook=None), '_cluster': <tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver.TPUClusterResolver object at 0x7ff288b6aa20>} INFO:tensorflow:_TPUContext: eval_on_tpu True
TPUEstimator.train
을 호출하여 모델 학습을 시작합니다.
estimator.train(_input_fn, steps=1)
INFO:tensorflow:Querying Tensorflow master (grpc://10.240.1.2:8470) for TPU system metadata. INFO:tensorflow:Found TPU system: INFO:tensorflow:*** Num TPU Cores: 8 INFO:tensorflow:*** Num TPU Workers: 1 INFO:tensorflow:*** Num TPU Cores Per Worker: 8 INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 2562214468325910549) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 7806191887455116208) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 4935096526614797404) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 6208852770722846295) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, -4484747666522931072) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, -8715412538518264422) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, -3521027846460785533) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, -6534172152637582552) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 4735861352635655596) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, -411508280321075475) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 2431932884271560631) WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Bypassing TPUEstimator hook INFO:tensorflow:Done calling model_fn. INFO:tensorflow:TPU job name worker INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:758: Variable.load (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Prefer Variable.assign which has equivalent behavior in 2.X. INFO:tensorflow:Initialized dataset iterators in 0 seconds INFO:tensorflow:Installing graceful shutdown hook. INFO:tensorflow:Creating heartbeat manager for ['/job:worker/replica:0/task:0/device:CPU:0'] INFO:tensorflow:Configuring worker heartbeat: shutdown_mode: WAIT_FOR_COORDINATOR INFO:tensorflow:Init TPU system INFO:tensorflow:Initialized TPU in 7 seconds INFO:tensorflow:Starting infeed thread controller. INFO:tensorflow:Starting outfeed thread controller. INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed. INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed. INFO:tensorflow:Outfeed finished for iteration (0, 0) INFO:tensorflow:loss = 4.462118, step = 1 INFO:tensorflow:Stop infeed thread controller INFO:tensorflow:Shutting down InfeedController thread. INFO:tensorflow:InfeedController received shutdown signal, stopping. INFO:tensorflow:Infeed thread finished, shutting down. INFO:tensorflow:infeed marked as finished INFO:tensorflow:Stop output thread controller INFO:tensorflow:Shutting down OutfeedController thread. INFO:tensorflow:OutfeedController received shutdown signal, stopping. INFO:tensorflow:Outfeed thread finished, shutting down. INFO:tensorflow:outfeed marked as finished INFO:tensorflow:Shutdown TPU system. INFO:tensorflow:Loss for final step: 4.462118. INFO:tensorflow:training_loop marked as finished <tensorflow_estimator.python.estimator.tpu.tpu_estimator.TPUEstimator at 0x7fec59ef9d68>
그런 다음 TPUEstimator.evaluate
를 호출하여 평가 데이터를 사용하여 모델을 평가합니다.
estimator.evaluate(_eval_input_fn, steps=1)
INFO:tensorflow:Could not find trained model in model_dir: /tmp/tmp_bkua7zf, running initialization to evaluate. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:3406: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version. Instructions for updating: Deprecated in favor of operator or tf.math.divide. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-02-05T13:15:25 INFO:tensorflow:TPU job name worker INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Init TPU system INFO:tensorflow:Initialized TPU in 10 seconds INFO:tensorflow:Starting infeed thread controller. INFO:tensorflow:Starting outfeed thread controller. INFO:tensorflow:Initialized dataset iterators in 0 seconds INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed. INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed. INFO:tensorflow:Outfeed finished for iteration (0, 0) INFO:tensorflow:Evaluation [1/1] INFO:tensorflow:Stop infeed thread controller INFO:tensorflow:Shutting down InfeedController thread. INFO:tensorflow:InfeedController received shutdown signal, stopping. INFO:tensorflow:Infeed thread finished, shutting down. INFO:tensorflow:infeed marked as finished INFO:tensorflow:Stop output thread controller INFO:tensorflow:Shutting down OutfeedController thread. INFO:tensorflow:OutfeedController received shutdown signal, stopping. INFO:tensorflow:Outfeed thread finished, shutting down. INFO:tensorflow:outfeed marked as finished INFO:tensorflow:Shutdown TPU system. INFO:tensorflow:Inference Time : 10.80091s INFO:tensorflow:Finished evaluation at 2022-02-05-13:15:36 INFO:tensorflow:Saving dict for global step 1: global_step = 1, loss = 116.58184 INFO:tensorflow:evaluation_loop marked as finished {'loss': 116.58184, 'global_step': 1}
TensorFlow 2: Keras Model.fit 및 TPUStrategy를 사용하여 TPU에서 모델 구동
TensorFlow 2에서 TPU 작업자를 교육하려면 모델 정의 및 교육/평가를 위해 tf.distribute.TPUStrategy
API와 함께 tf.distribute.TPUStrategy를 사용하세요. ( Model.fit
및 맞춤형 훈련 루프( tf.function
및 tf.GradientTape
)를 사용한 훈련에 대한 더 많은 예는 TPU 사용 가이드를 참조하십시오.)
원격 클러스터에 접속하고 TPU 워커를 초기화하기 위해서는 초기화 작업이 필요하므로 먼저 TPUClusterResolver
를 생성하여 클러스터 정보를 제공하고 클러스터에 접속한다. (TPU 사용 가이드의 TPU 초기화 섹션에서 자세히 알아보세요.)
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))
INFO:tensorflow:Clearing out eager caches INFO:tensorflow:Clearing out eager caches INFO:tensorflow:Initializing the TPU system: grpc://10.240.1.2:8470 INFO:tensorflow:Initializing the TPU system: grpc://10.240.1.2:8470 INFO:tensorflow:Finished initializing TPU system. INFO:tensorflow:Finished initializing TPU system. All devices: [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]
다음으로 데이터가 준비되면 TPUStrategy
를 만들고 이 전략의 범위에서 모델, 메트릭 및 옵티마이저를 정의합니다.
TPUStrategy
로 비슷한 훈련 속도를 얻으려면 각 tf.function
호출 중에 실행할 배치 수를 지정하고 성능에 중요하기 때문에 steps_per_execution
에서 Model.compile
에 대한 숫자를 선택해야 합니다. 이 인수는 TPUEstimator
에서 사용되는 iterations_per_loop
와 유사합니다. 사용자 지정 훈련 루프를 사용하는 경우 tf.function
-ed 훈련 함수 내에서 여러 단계를 실행해야 합니다. 자세한 내용은 TPU 사용 가이드의 tf.function 섹션에서 여러 단계를 통해 성능 향상 을 참조하세요.
tf.distribute.TPUStrategy
는 제한된 동적 모양을 지원할 수 있으며, 이는 동적 모양 계산의 상한을 유추할 수 있는 경우입니다. 그러나 동적 모양은 정적 모양에 비해 약간의 성능 오버헤드를 유발할 수 있습니다. 따라서, 특히 훈련에서 가능하면 입력 모양을 정적으로 만드는 것이 일반적으로 권장됩니다. 스트림에 남아 있는 샘플 수가 배치 크기보다 작을 수 있으므로 동적 모양을 반환하는 일반적인 작업 중 하나는 tf.data.Dataset.batch(batch_size)
입니다. 따라서 TPU에서 훈련할 때 최상의 훈련 성능을 위해 tf.data.Dataset.batch(..., drop_remainder=True)
를 사용해야 합니다.
dataset = tf.data.Dataset.from_tensor_slices(
(features, labels)).shuffle(10).repeat().batch(
8, drop_remainder=True).prefetch(2)
eval_dataset = tf.data.Dataset.from_tensor_slices(
(eval_features, eval_labels)).batch(1, drop_remainder=True)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
with strategy.scope():
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer, "mse", steps_per_execution=10)
INFO:tensorflow:Found TPU system: INFO:tensorflow:Found TPU system: INFO:tensorflow:*** Num TPU Cores: 8 INFO:tensorflow:*** Num TPU Cores: 8 INFO:tensorflow:*** Num TPU Workers: 1 INFO:tensorflow:*** Num TPU Workers: 1 INFO:tensorflow:*** Num TPU Cores Per Worker: 8 INFO:tensorflow:*** Num TPU Cores Per Worker: 8 INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
이것으로 훈련 데이터 세트로 모델을 훈련할 준비가 되었습니다.
model.fit(dataset, epochs=5, steps_per_epoch=10)
Epoch 1/5 10/10 [==============================] - 2s 151ms/step - loss: 0.0840 Epoch 2/5 10/10 [==============================] - 0s 3ms/step - loss: 9.6915e-04 Epoch 3/5 10/10 [==============================] - 0s 3ms/step - loss: 1.5100e-05 Epoch 4/5 10/10 [==============================] - 0s 3ms/step - loss: 2.3593e-07 Epoch 5/5 10/10 [==============================] - 0s 3ms/step - loss: 3.7059e-09 <keras.callbacks.History at 0x7fec58275438>
마지막으로 평가 데이터 세트를 사용하여 모델을 평가합니다.
model.evaluate(eval_dataset, return_dict=True)
1/1 [==============================] - 2s 2s/step - loss: 0.6127 {'loss': 0.6127181053161621}
다음 단계
TensorFlow 2의 TPUStrategy
에 대해 자세히 알아보려면 다음 리소스를 고려하세요.
- 가이드: TPU 사용 (
Model.fit
을 사용한 훈련 /a tf.distribute.TPUStrategy를 사용한 맞춤형 훈련 루프 및tf.distribute.TPUStrategy
으로 성능 향상에 대한 팁tf.function
) - 가이드: TensorFlow를 사용한 분산 교육
훈련 사용자 지정에 대한 자세한 내용은 다음을 참조하십시오.
- 가이드: Model.fit에서 발생하는 작업 사용자 지정
- 가이드: 처음부터 훈련 루프 작성하기
기계 학습을 위한 Google의 특수 ASIC인 TPU는 Google Colab , TPU Research Cloud 및 Cloud TPU 를 통해 사용할 수 있습니다.