ดูบน TensorFlow.org | ทำงานใน Google Colab | ดูแหล่งที่มาบน GitHub | ดาวน์โหลดโน๊ตบุ๊ค |
คู่มือนี้สาธิตวิธีโยกย้ายการฝึกอบรมการฝังบน TPU จาก embedding_column API ของ TensorFlow 1 พร้อม TPUEstimator ไปยัง API เลเยอร์ TPUEmbedding ของ TensorFlow 2 ด้วย TPUStrategy
การฝังเป็นเมทริกซ์ (ใหญ่) เป็นตารางค้นหาที่แมปจากพื้นที่คุณสมบัติกระจัดกระจายไปจนถึงเวกเตอร์หนาแน่น การฝังให้การแสดงที่มีประสิทธิภาพและหนาแน่น โดยจับความคล้ายคลึงที่ซับซ้อนและความสัมพันธ์ระหว่างคุณลักษณะต่างๆ
TensorFlow มีการสนับสนุนเฉพาะสำหรับการฝังการฝึกอบรมบน TPU การรองรับการฝังเฉพาะสำหรับ TPU นี้ช่วยให้คุณฝึกการฝังที่มีขนาดใหญ่กว่าหน่วยความจำของอุปกรณ์ TPU เครื่องเดียว และใช้อินพุตแบบกระจัดกระจายและขาดช่วงบน TPU
- ใน TensorFlow 1
tf.compat.v1.estimator.tpu.TPUEstimatorเป็น API ระดับสูงที่สรุปการฝึกอบรม การประเมิน การคาดคะเน และการส่งออกสำหรับการให้บริการด้วย TPU มีการสนับสนุนพิเศษสำหรับtf.compat.v1.tpu.experimental.embedding_column - ในการดำเนินการนี้ใน TensorFlow 2 ให้ใช้เลเยอร์ tfrs.layers.embedding.TPUEmbedding ของ
tfrs.layers.embedding.TPUEmbeddingRecommenders สำหรับการฝึกอบรมและการประเมิน ให้ใช้กลยุทธ์การแจกจ่ายtf.distribute.TPUStrategy— ซึ่งเข้ากันได้กับ Keras API เช่น การสร้างแบบจำลอง (tf.keras.Model) เครื่องมือเพิ่มประสิทธิภาพ (tf.keras.optimizers.Optimizer) และการฝึกด้วยModel.fitหรือการฝึกวนซ้ำแบบกำหนดเองด้วยtf.functionและtf.GradientTape
สำหรับข้อมูลเพิ่มเติม โปรดดูเอกสาร API ของเลเยอร์ tfrs.layers.embedding.TPUEmbedding รวมถึงเอกสาร tf.tpu.experimental.embedding.TableConfig และ tf.tpu.experimental.embedding.FeatureConfig สำหรับข้อมูลเพิ่มเติม สำหรับภาพรวมของ tf.distribute.TPUStrategy โปรดดูคู่มือ การฝึกอบรมแบบกระจาย และคู่มือการ ใช้ TPU หากคุณกำลังย้ายจาก TPUEstimator เป็น TPUStrategy โปรดดู คู่มือการย้าย TPU
ติดตั้ง
เริ่มต้นด้วยการติดตั้ง TensorFlow Recommenders และนำเข้าแพ็คเกจที่จำเป็น:
pip install tensorflow-recommenders
import tensorflow as tf
import tensorflow.compat.v1 as tf1
# TPUEmbedding layer is not part of TensorFlow.
import tensorflow_recommenders as tfrs
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/requests/__init__.py:104: RequestsDependencyWarning: urllib3 (1.26.8) or chardet (2.3.0)/charset_normalizer (2.0.11) doesn't match a supported version! RequestsDependencyWarning)
และเตรียมชุดข้อมูลอย่างง่ายเพื่อการสาธิต:
features = [[1., 1.5]]
embedding_features_indices = [[0, 0], [0, 1]]
embedding_features_values = [0, 5]
labels = [[0.3]]
eval_features = [[4., 4.5]]
eval_embedding_features_indices = [[0, 0], [0, 1]]
eval_embedding_features_values = [4, 3]
eval_labels = [[0.8]]
TensorFlow 1: ฝึกการฝังบน TPU ด้วย TPUEstimator
ใน TensorFlow 1 คุณตั้งค่าการฝัง TPU โดยใช้ tf.compat.v1.tpu.experimental.embedding_column API และฝึกฝน/ประเมินโมเดลบน TPU ด้วย tf.compat.v1.estimator.tpu.TPUEstimator
อินพุตเป็นจำนวนเต็มตั้งแต่ศูนย์จนถึงขนาดคำศัพท์สำหรับตารางการฝัง TPU เริ่มต้นด้วยการเข้ารหัสอินพุตไปยัง categorical ID ด้วย tf.feature_column.categorical_column_with_identity ใช้ "sparse_feature" สำหรับพารามิเตอร์ key เนื่องจากคุณลักษณะอินพุตเป็นค่าจำนวนเต็ม ขณะที่ num_buckets คือขนาดคำศัพท์สำหรับตารางการฝัง ( 10 )
embedding_id_column = (
tf1.feature_column.categorical_column_with_identity(
key="sparse_feature", num_buckets=10))
ถัดไป ให้แปลงอินพุตการจัดหมวดหมู่แบบกระจายเป็นการแสดงแบบหนาแน่นด้วย tpu.experimental.embedding_column โดยที่ dimension คือความกว้างของตารางการฝัง จะเก็บเวกเตอร์ฝังสำหรับแต่ละ num_buckets
embedding_column = tf1.tpu.experimental.embedding_column(
embedding_id_column, dimension=5)
ตอนนี้ กำหนดการกำหนดค่าการฝังเฉพาะสำหรับ TPU ผ่าน tf.estimator.tpu.experimental.EmbeddingConfigSpec คุณจะส่งต่อไปยัง tf.estimator.tpu.TPUEstimator เป็นพารามิเตอร์ embedding_config_spec
embedding_config_spec = tf1.estimator.tpu.experimental.EmbeddingConfigSpec(
feature_columns=(embedding_column,),
optimization_parameters=(
tf1.tpu.experimental.AdagradParameters(0.05)))
ถัดไป เพื่อใช้ TPUEstimator ให้กำหนด:
- ฟังก์ชันอินพุตสำหรับข้อมูลการฝึก
- ฟังก์ชันอินพุตการประเมินสำหรับข้อมูลการประเมิน
- ฟังก์ชัน model สำหรับสั่งสอน
TPUEstimatorว่า Training op ถูกกำหนดด้วยคุณสมบัติและป้ายกำกับอย่างไร
def _input_fn(params):
dataset = tf1.data.Dataset.from_tensor_slices((
{"dense_feature": features,
"sparse_feature": tf1.SparseTensor(
embedding_features_indices,
embedding_features_values, [1, 2])},
labels))
dataset = dataset.repeat()
return dataset.batch(params['batch_size'], drop_remainder=True)
def _eval_input_fn(params):
dataset = tf1.data.Dataset.from_tensor_slices((
{"dense_feature": eval_features,
"sparse_feature": tf1.SparseTensor(
eval_embedding_features_indices,
eval_embedding_features_values, [1, 2])},
eval_labels))
dataset = dataset.repeat()
return dataset.batch(params['batch_size'], drop_remainder=True)
def _model_fn(features, labels, mode, params):
embedding_features = tf1.keras.layers.DenseFeatures(embedding_column)(features)
concatenated_features = tf1.keras.layers.Concatenate(axis=1)(
[embedding_features, features["dense_feature"]])
logits = tf1.layers.Dense(1)(concatenated_features)
loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
optimizer = tf1.train.AdagradOptimizer(0.05)
optimizer = tf1.tpu.CrossShardOptimizer(optimizer)
train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
return tf1.estimator.tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)
ด้วยฟังก์ชันที่กำหนดไว้ ให้สร้าง tf.distribute.cluster_resolver.TPUClusterResolver ที่จัดเตรียมข้อมูลคลัสเตอร์ และอ็อบเจ็กต์ tf.compat.v1.estimator.tpu.RunConfig
นอกจากฟังก์ชัน model ที่คุณกำหนดแล้ว คุณสามารถสร้าง TPUEstimator ได้แล้ว ที่นี่ คุณจะลดความซับซ้อนของโฟลว์โดยข้ามการประหยัดด่าน จากนั้น คุณจะต้องระบุขนาดแบทช์สำหรับทั้งการฝึกอบรมและการประเมินสำหรับ TPUEstimator
cluster_resolver = tf1.distribute.cluster_resolver.TPUClusterResolver(tpu='')
print("All devices: ", tf1.config.list_logical_devices('TPU'))
All devices: []
tpu_config = tf1.estimator.tpu.TPUConfig(
iterations_per_loop=10,
per_host_input_for_training=tf1.estimator.tpu.InputPipelineConfig
.PER_HOST_V2)
config = tf1.estimator.tpu.RunConfig(
cluster=cluster_resolver,
save_checkpoints_steps=None,
tpu_config=tpu_config)
estimator = tf1.estimator.tpu.TPUEstimator(
model_fn=_model_fn, config=config, train_batch_size=8, eval_batch_size=8,
embedding_config_spec=embedding_config_spec)
WARNING:tensorflow:Estimator's model_fn (<function _model_fn at 0x7eff1dbf4ae8>) includes params argument, but params are not passed to Estimator.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpc68an8jx
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpc68an8jx', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
job {
name: "worker"
tasks {
key: 0
value: "10.240.1.2:8470"
}
}
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.240.1.2:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.240.1.2:8470', '_evaluation_master': 'grpc://10.240.1.2:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=10, num_shards=None, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None, eval_training_input_configuration=2, experimental_host_call_every_n_steps=1, experimental_allow_per_host_v2_parallel_get_next=False, experimental_feed_hook=None), '_cluster': <tensorflow.python.distribute.cluster_resolver.tpu.tpu_cluster_resolver.TPUClusterResolver object at 0x7eff1dbfa2b0>}
INFO:tensorflow:_TPUContext: eval_on_tpu True
โทร TPUEstimator.train เพื่อเริ่มฝึกโมเดล:
estimator.train(_input_fn, steps=1)
INFO:tensorflow:Querying Tensorflow master (grpc://10.240.1.2:8470) for TPU system metadata. INFO:tensorflow:Found TPU system: INFO:tensorflow:*** Num TPU Cores: 8 INFO:tensorflow:*** Num TPU Workers: 1 INFO:tensorflow:*** Num TPU Cores Per Worker: 8 INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, -3018931587863375246) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 1249032734884062775) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, -3881759543008185868) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, -3421771184935649663) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 8872583169621331661) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, -1222373804129613329) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 6258068298163390748) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 5190265587768274342) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 3073578684150069836) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 2071242092327503173) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, -1319360343564144287) WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/tpu/feature_column_v2.py:479: IdentityCategoricalColumn._num_buckets (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version. Instructions for updating: The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead. INFO:tensorflow:Querying Tensorflow master (grpc://10.240.1.2:8470) for TPU system metadata. INFO:tensorflow:Found TPU system: INFO:tensorflow:*** Num TPU Cores: 8 INFO:tensorflow:*** Num TPU Workers: 1 INFO:tensorflow:*** Num TPU Cores Per Worker: 8 INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, -3018931587863375246) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 1249032734884062775) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, -3881759543008185868) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, -3421771184935649663) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 8872583169621331661) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, -1222373804129613329) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 6258068298163390748) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 5190265587768274342) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 3073578684150069836) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 2071242092327503173) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, -1319360343564144287) WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Bypassing TPUEstimator hook INFO:tensorflow:Done calling model_fn. INFO:tensorflow:TPU job name worker INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:758: Variable.load (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Prefer Variable.assign which has equivalent behavior in 2.X. INFO:tensorflow:Initialized dataset iterators in 0 seconds INFO:tensorflow:Installing graceful shutdown hook. INFO:tensorflow:Creating heartbeat manager for ['/job:worker/replica:0/task:0/device:CPU:0'] INFO:tensorflow:Configuring worker heartbeat: shutdown_mode: WAIT_FOR_COORDINATOR INFO:tensorflow:Init TPU system INFO:tensorflow:Initialized TPU in 9 seconds INFO:tensorflow:Starting infeed thread controller. INFO:tensorflow:Starting outfeed thread controller. INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed. INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed. INFO:tensorflow:Outfeed finished for iteration (0, 0) INFO:tensorflow:loss = 0.5212165, step = 1 INFO:tensorflow:Stop infeed thread controller INFO:tensorflow:Shutting down InfeedController thread. INFO:tensorflow:InfeedController received shutdown signal, stopping. INFO:tensorflow:Infeed thread finished, shutting down. INFO:tensorflow:infeed marked as finished INFO:tensorflow:Stop output thread controller INFO:tensorflow:Shutting down OutfeedController thread. INFO:tensorflow:OutfeedController received shutdown signal, stopping. INFO:tensorflow:Outfeed thread finished, shutting down. INFO:tensorflow:outfeed marked as finished INFO:tensorflow:Shutdown TPU system. INFO:tensorflow:Loss for final step: 0.5212165. INFO:tensorflow:training_loop marked as finished <tensorflow_estimator.python.estimator.tpu.tpu_estimator.TPUEstimator at 0x7eff1dbfa7b8>
จากนั้นเรียก TPUEstimator.evaluate เพื่อประเมินโมเดลโดยใช้ข้อมูลการประเมิน:
estimator.evaluate(_eval_input_fn, steps=1)
INFO:tensorflow:Could not find trained model in model_dir: /tmp/tmpc68an8jx, running initialization to evaluate.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Querying Tensorflow master (grpc://10.240.1.2:8470) for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, -3018931587863375246)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 1249032734884062775)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, -3881759543008185868)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, -3421771184935649663)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 8872583169621331661)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, -1222373804129613329)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 6258068298163390748)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 5190265587768274342)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 3073578684150069836)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 2071242092327503173)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, -1319360343564144287)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:3406: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-02-05T13:21:42
INFO:tensorflow:TPU job name worker
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Init TPU system
INFO:tensorflow:Initialized TPU in 11 seconds
INFO:tensorflow:Starting infeed thread controller.
INFO:tensorflow:Starting outfeed thread controller.
INFO:tensorflow:Initialized dataset iterators in 0 seconds
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Outfeed finished for iteration (0, 0)
INFO:tensorflow:Evaluation [1/1]
INFO:tensorflow:Stop infeed thread controller
INFO:tensorflow:Shutting down InfeedController thread.
INFO:tensorflow:InfeedController received shutdown signal, stopping.
INFO:tensorflow:Infeed thread finished, shutting down.
INFO:tensorflow:infeed marked as finished
INFO:tensorflow:Stop output thread controller
INFO:tensorflow:Shutting down OutfeedController thread.
INFO:tensorflow:OutfeedController received shutdown signal, stopping.
INFO:tensorflow:Outfeed thread finished, shutting down.
INFO:tensorflow:outfeed marked as finished
INFO:tensorflow:Shutdown TPU system.
INFO:tensorflow:Inference Time : 12.50468s
INFO:tensorflow:Finished evaluation at 2022-02-05-13:21:54
INFO:tensorflow:Saving dict for global step 1: global_step = 1, loss = 36.28813
INFO:tensorflow:evaluation_loop marked as finished
{'loss': 36.28813, 'global_step': 1}
TensorFlow 2: ฝึกการฝังบน TPU ด้วย TPUStrategy
ใน TensorFlow 2 เพื่อฝึกอบรมผู้ปฏิบัติงาน TPU ให้ใช้ tf.distribute.TPUStrategy ร่วมกับ Keras API สำหรับการกำหนดโมเดลและการฝึกอบรม/การประเมิน (โปรดดูคู่มือการ ใช้ TPU สำหรับตัวอย่างเพิ่มเติมของการฝึกด้วย Keras Model.fit และการฝึกวนซ้ำแบบกำหนดเอง (ด้วย tf.function และ tf.GradientTape ))
เนื่องจากคุณจำเป็นต้องดำเนินการเริ่มต้นบางอย่างเพื่อเชื่อมต่อกับคลัสเตอร์ระยะไกลและเริ่มต้นผู้ปฏิบัติงาน TPU ให้เริ่มต้นด้วยการสร้าง TPUClusterResolver เพื่อให้ข้อมูลคลัสเตอร์และเชื่อมต่อกับคลัสเตอร์ (ดูข้อมูลเพิ่มเติมในส่วนการ เริ่มต้น TPU ของคู่มือการ ใช้ TPU)
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))
INFO:tensorflow:Clearing out eager caches INFO:tensorflow:Clearing out eager caches INFO:tensorflow:Initializing the TPU system: grpc://10.240.1.2:8470 INFO:tensorflow:Initializing the TPU system: grpc://10.240.1.2:8470 INFO:tensorflow:Finished initializing TPU system. INFO:tensorflow:Finished initializing TPU system. All devices: [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]
ถัดไป เตรียมข้อมูลของคุณ สิ่งนี้คล้ายกับวิธีที่คุณสร้างชุดข้อมูลในตัวอย่าง TensorFlow 1 ยกเว้นว่าตอนนี้ฟังก์ชันชุดข้อมูลถูกส่งผ่านอ็อบเจ็กต์ tf.distribute.InputContext แทนที่จะเป็น params dict คุณสามารถใช้อ็อบเจ็กต์นี้เพื่อกำหนดขนาดแบตช์ในเครื่องได้ (และไปป์ไลน์นี้มีไว้สำหรับโฮสต์ใด เพื่อให้คุณสามารถแบ่งพาร์ติชั่นข้อมูลของคุณได้อย่างถูกต้อง)
- เมื่อใช้
tfrs.layers.embedding.TPUEmbeddingAPI สิ่งสำคัญคือต้องรวมตัวเลือกdrop_remainder=Trueเมื่อทำการแบทช์ชุดข้อมูลด้วยDataset.batchเนื่องจากTPUEmbeddingต้องการขนาดแบตช์คงที่ - นอกจากนี้ ต้องใช้ขนาดแบทช์เดียวกันสำหรับการประเมินและการฝึกอบรม หากเกิดขึ้นในอุปกรณ์ชุดเดียวกัน
- สุดท้าย คุณควรใช้
tf.keras.utils.experimental.DatasetCreatorร่วมกับตัวเลือกอินพุตพิเศษ —experimental_fetch_to_device=False— ในtf.distribute.InputOptions(ซึ่งมีการกำหนดค่าเฉพาะกลยุทธ์) นี่แสดงให้เห็นด้านล่าง:
global_batch_size = 8
def _input_dataset(context: tf.distribute.InputContext):
dataset = tf.data.Dataset.from_tensor_slices((
{"dense_feature": features,
"sparse_feature": tf.SparseTensor(
embedding_features_indices,
embedding_features_values, [1, 2])},
labels))
dataset = dataset.shuffle(10).repeat()
dataset = dataset.batch(
context.get_per_replica_batch_size(global_batch_size),
drop_remainder=True)
return dataset.prefetch(2)
def _eval_dataset(context: tf.distribute.InputContext):
dataset = tf.data.Dataset.from_tensor_slices((
{"dense_feature": eval_features,
"sparse_feature": tf.SparseTensor(
eval_embedding_features_indices,
eval_embedding_features_values, [1, 2])},
eval_labels))
dataset = dataset.repeat()
dataset = dataset.batch(
context.get_per_replica_batch_size(global_batch_size),
drop_remainder=True)
return dataset.prefetch(2)
input_options = tf.distribute.InputOptions(
experimental_fetch_to_device=False)
input_dataset = tf.keras.utils.experimental.DatasetCreator(
_input_dataset, input_options=input_options)
eval_dataset = tf.keras.utils.experimental.DatasetCreator(
_eval_dataset, input_options=input_options)
ถัดไป เมื่อเตรียมข้อมูลแล้ว คุณจะต้องสร้าง TPUStrategy และกำหนดโมเดล หน่วยวัด และเครื่องมือเพิ่มประสิทธิภาพภายใต้ขอบเขตของกลยุทธ์นี้ ( Strategy.scope )
คุณควรเลือกหมายเลขสำหรับ steps_per_execution ใน Model.compile เนื่องจากจะระบุจำนวนแบตช์ที่จะรันในระหว่างการเรียก tf.function แต่ละครั้ง และมีความสำคัญต่อประสิทธิภาพ อาร์กิวเมนต์นี้คล้ายกับ iterations_per_loop ที่ใช้ใน TPUEstimator
คุณลักษณะและการกำหนดค่าตารางที่ระบุใน TensorFlow 1 ผ่าน tf.tpu.experimental.embedding_column (และ tf.tpu.experimental.shared_embedding_column ) สามารถระบุได้โดยตรงใน TensorFlow 2 ผ่านคู่ของอ็อบเจ็กต์การกำหนดค่า:
(ดูรายละเอียดเพิ่มเติมในเอกสารประกอบ API ที่เกี่ยวข้อง)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
with strategy.scope():
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
dense_input = tf.keras.Input(shape=(2,), dtype=tf.float32, batch_size=global_batch_size)
sparse_input = tf.keras.Input(shape=(), dtype=tf.int32, batch_size=global_batch_size)
embedded_input = tfrs.layers.embedding.TPUEmbedding(
feature_config=tf.tpu.experimental.embedding.FeatureConfig(
table=tf.tpu.experimental.embedding.TableConfig(
vocabulary_size=10,
dim=5,
initializer=tf.initializers.TruncatedNormal(mean=0.0, stddev=1)),
name="sparse_input"),
optimizer=optimizer)(sparse_input)
input = tf.keras.layers.Concatenate(axis=1)([dense_input, embedded_input])
result = tf.keras.layers.Dense(1)(input)
model = tf.keras.Model(inputs={"dense_feature": dense_input, "sparse_feature": sparse_input}, outputs=result)
model.compile(optimizer, "mse", steps_per_execution=10)
INFO:tensorflow:Found TPU system: INFO:tensorflow:Found TPU system: INFO:tensorflow:*** Num TPU Cores: 8 INFO:tensorflow:*** Num TPU Cores: 8 INFO:tensorflow:*** Num TPU Workers: 1 INFO:tensorflow:*** Num TPU Workers: 1 INFO:tensorflow:*** Num TPU Cores Per Worker: 8 INFO:tensorflow:*** Num TPU Cores Per Worker: 8 INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
ด้วยเหตุนี้ คุณจึงพร้อมที่จะฝึกโมเดลด้วยชุดข้อมูลการฝึก:
model.fit(input_dataset, epochs=5, steps_per_epoch=10)
Epoch 1/5 10/10 [==============================] - 2s 164ms/step - loss: 0.4005 Epoch 2/5 10/10 [==============================] - 0s 3ms/step - loss: 0.0036 Epoch 3/5 10/10 [==============================] - 0s 3ms/step - loss: 3.0932e-05 Epoch 4/5 10/10 [==============================] - 0s 3ms/step - loss: 2.5767e-07 Epoch 5/5 10/10 [==============================] - 0s 3ms/step - loss: 2.1366e-09 <keras.callbacks.History at 0x7efd8c461c18>ตัวยึดตำแหน่ง23
สุดท้าย ประเมินแบบจำลองโดยใช้ชุดข้อมูลการประเมิน:
model.evaluate(eval_dataset, steps=1, return_dict=True)
1/1 [==============================] - 1s 1s/step - loss: 15.3952
{'loss': 15.395216941833496}
ขั้นตอนถัดไป
เรียนรู้เพิ่มเติมเกี่ยวกับการตั้งค่าการฝังเฉพาะของ TPU ในเอกสาร API:
-
tfrs.layers.embedding.TPUEmbedding: โดยเฉพาะเกี่ยวกับคุณลักษณะและการกำหนดค่าตาราง การตั้งค่าตัวเพิ่มประสิทธิภาพ การสร้างแบบจำลอง (โดยใช้ Keras functional API หรือผ่าน คลาสย่อยtf.keras.Model) การฝึกอบรม/การประเมิน และรูปแบบที่ให้บริการด้วยtf.saved_model -
tf.tpu.experimental.embedding.TableConfig -
tf.tpu.experimental.embedding.FeatureConfig
สำหรับข้อมูลเพิ่มเติมเกี่ยวกับ TPUStrategy ใน TensorFlow 2 ให้พิจารณาแหล่งข้อมูลต่อไปนี้:
- คำแนะนำ: ใช้ TPU (ครอบคลุมการฝึกด้วย Keras
Model.fit/a ลูปการฝึกแบบกำหนดเองด้วยtf.distribute.TPUStrategyตลอดจนเคล็ดลับในการปรับปรุงประสิทธิภาพด้วยtf.function) - คู่มือ: การฝึกอบรมแบบกระจายด้วย TensorFlow
- คำแนะนำ: ย้ายจาก TPUEstimator เป็น TPUStrategy
หากต้องการเรียนรู้เพิ่มเติมเกี่ยวกับการปรับแต่งการฝึกของคุณ โปรดดูที่:
- คำแนะนำ: ปรับแต่งสิ่งที่เกิดขึ้นใน Model.fit
- คำแนะนำ: การเขียนวงจรการฝึกตั้งแต่เริ่มต้น
TPU—ASIC เฉพาะของ Google สำหรับแมชชีนเลิร์นนิง—มีให้ใช้งานผ่าน Google Colab , TPU Research Cloud และ Cloud TPU
ดูบน TensorFlow.org
ทำงานใน Google Colab
ดูแหล่งที่มาบน GitHub
ดาวน์โหลดโน๊ตบุ๊ค