הצג באתר TensorFlow.org | הפעל בגוגל קולאב | צפה במקור ב-GitHub | הורד מחברת |
מסמך זה מציג את tf.estimator
- API ברמה גבוהה של TensorFlow. אומדנים כוללים את הפעולות הבאות:
- הַדְרָכָה
- הַעֲרָכָה
- נְבוּאָה
- ייצוא להגשה
TensorFlow מיישמת מספר אומדנים מוכנים מראש. אומדנים מותאמים אישית עדיין נתמכים, אבל בעיקר כאמצעי תאימות לאחור. אין להשתמש באומדנים מותאמים אישית עבור קוד חדש . כל האומדנים - מוכנים מראש או מותאמים אישית - הם מחלקות המבוססות על המחלקה tf.estimator.Estimator
.
לדוגמא מהירה, נסה את מדריכי האומדן . לסקירה כללית של עיצוב ה-API, עיין בנייר הלבן .
להכין
pip install -U tensorflow_datasets
import tempfile
import os
import tensorflow as tf
import tensorflow_datasets as tfds
יתרונות
בדומה ל- tf.keras.Model
, estimator
הוא הפשטה ברמת המודל. ה- tf.estimator
מספק כמה יכולות כרגע עדיין בפיתוח עבור tf.keras
. אלו הם:
- אימון מבוסס שרת פרמטרים
- אינטגרציה מלאה של TFX
יכולות אומדנים
אומדנים מספקים את היתרונות הבאים:
- אתה יכול להפעיל מודלים מבוססי Estimator על מארח מקומי או על סביבת ריבוי שרתים מבוזרת מבלי לשנות את המודל שלך. יתרה מזאת, אתה יכול להריץ מודלים מבוססי Estimator על CPUs, GPUs או TPUs מבלי לקוד מחדש את המודל שלך.
- האומדנים מספקים לולאת אימון מבוזרת בטוחה השולטת כיצד ומתי:
- לטעון מידע
- לטפל בחריגים
- צור קבצי מחסום והתאושש מתקלות
- שמור סיכומים עבור TensorBoard
בעת כתיבת אפליקציה עם Estimators, עליך להפריד את צינור קלט הנתונים מהמודל. הפרדה זו מפשטת ניסויים עם מערכי נתונים שונים.
שימוש באומדנים מוכנים מראש
אומדנים מוכנים מראש מאפשרים לך לעבוד ברמה קונספטואלית גבוהה בהרבה מזו של ממשקי ה-API של TensorFlow הבסיסיים. אתה כבר לא צריך לדאוג לגבי יצירת הגרף החישובי או ההפעלות מכיוון שאומדנים מטפלים בכל ה"צנרת" עבורך. יתר על כן, אומדנים מוכנים מראש מאפשרים לך להתנסות בארכיטקטורות מודלים שונות על ידי ביצוע שינויים קוד מינימליים בלבד. tf.estimator.DNNClassifier
, למשל, היא מחלקת Estimator שהוכנה מראש, המאמנת מודלים של סיווג המבוססים על רשתות עצביות צפופות ומזינות קדימה.
תוכנית TensorFlow הנשענת על אומדן שהוכן מראש מורכבת בדרך כלל מארבעת השלבים הבאים:
1. כתוב פונקציות קלט
לדוגמה, תוכל ליצור פונקציה אחת לייבוא ערכת ההדרכה ופונקציה אחרת לייבא את ערכת הבדיקות. מעריכים מצפים שהקלטים שלהם יהיו מעוצבים כזוג אובייקטים:
- מילון שבו המפתחות הם שמות תכונה והערכים הם Tensors (או SparseTensors) המכילים את נתוני התכונה המתאימים
- טנסור המכיל תווית אחת או יותר
ה- input_fn
צריך להחזיר tf.data.Dataset
שמניב זוגות בפורמט הזה.
לדוגמה, הקוד הבא בונה tf.data.Dataset
מקובץ train.csv
של מערך הנתונים של Titanic:
def train_input_fn():
titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic = tf.data.experimental.make_csv_dataset(
titanic_file, batch_size=32,
label_name="survived")
titanic_batches = (
titanic.cache().repeat().shuffle(500)
.prefetch(tf.data.AUTOTUNE))
return titanic_batches
ה- input_fn
מבוצע ב- tf.Graph
ויכול גם להחזיר ישירות זוג (features_dics, labels)
המכיל טנסור גרפים, אבל זה נוטה לשגיאה מחוץ למקרים פשוטים כמו החזרת קבועים.
2. הגדר את עמודות התכונה.
כל tf.feature_column
מזהה שם תכונה, סוגה וכל עיבוד מקדים של קלט.
לדוגמה, הקטע הבא יוצר שלוש עמודות תכונה.
- הראשון משתמש בתכונת
age
ישירות כקלט של נקודה צפה. - השני משתמש בתכונת
class
כקלט קטגורי. - השלישי משתמש ב-
embark_town
כקלט קטגורי, אך משתמשhashing trick
כדי למנוע את הצורך למנות את האפשרויות ולקבוע את מספר האפשרויות.
למידע נוסף, עיין במדריך עמודות תכונה .
age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third'])
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)
3. הצג את האומד הרלוונטי המוכן מראש.
לדוגמה, הנה דוגמה לדוגמה של הערכה מוכנה מראש בשם LinearClassifier
:
model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
model_dir=model_dir,
feature_columns=[embark, cls, age],
n_classes=2
)
INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpl24pp3cp', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
למידע נוסף, תוכל לעבור למדריך המסווג הליניארי .
4. קרא לשיטת אימון, הערכה או מסקנות.
כל האומדנים מספקים שיטות train
, evaluate
predict
.
model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_v1.py:1684: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead. warnings.warn('`layer.add_variable` is deprecated and ' WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/ftrl.py:147: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpl24pp3cp/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100... INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpl24pp3cp/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100... INFO:tensorflow:Loss for final step: 0.6319582. 2021-09-22 20:49:10.453286: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
result = model.evaluate(train_input_fn, steps=10)
for key, value in result.items():
print(key, ":", value)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:11 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpl24pp3cp/model.ckpt-100 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.74609s INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:12 INFO:tensorflow:Saving dict for global step 100: accuracy = 0.734375, accuracy_baseline = 0.640625, auc = 0.7373913, auc_precision_recall = 0.64306235, average_loss = 0.563341, global_step = 100, label/mean = 0.359375, loss = 0.563341, precision = 0.734375, prediction/mean = 0.3463129, recall = 0.40869564 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpl24pp3cp/model.ckpt-100 accuracy : 0.734375 accuracy_baseline : 0.640625 auc : 0.7373913 auc_precision_recall : 0.64306235 average_loss : 0.563341 label/mean : 0.359375 loss : 0.563341 precision : 0.734375 prediction/mean : 0.3463129 recall : 0.40869564 global_step : 100 2021-09-22 20:49:12.168629: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
for pred in model.predict(train_input_fn):
for key, value in pred.items():
print(key, ":", value)
break
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpl24pp3cp/model.ckpt-100 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. logits : [-1.5173098] logistic : [0.17985801] probabilities : [0.820142 0.17985801] class_ids : [0] classes : [b'0'] all_class_ids : [0 1] all_classes : [b'0' b'1'] 2021-09-22 20:49:13.076528: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
היתרונות של אומדנים מוכנים מראש
מעריכים מוכנים מראש מקודדים שיטות עבודה מומלצות, ומספקים את היתרונות הבאים:
- שיטות עבודה מומלצות לקביעה היכן חלקים שונים של הגרף החישובי צריכים לפעול, יישום אסטרטגיות במכונה בודדת או באשכול.
- שיטות עבודה מומלצות לכתיבת אירועים (סיכום) וסיכומים שימושיים אוניברסליים.
אם אינך משתמש באומדנים מוכנים מראש, עליך ליישם את התכונות הקודמות בעצמך.
אומדנים מותאמים אישית
הלב של כל הערכה - בין אם הוא מוכן מראש או בהתאמה אישית - הוא פונקציית המודל שלו, model_fn
, שהיא שיטה שבונה גרפים לאימון, הערכה וחיזוי. כאשר אתה משתמש באומדן שהוכן מראש, מישהו אחר כבר יישם את פונקציית המודל. כאשר מסתמכים על אומדן מותאם אישית, עליך לכתוב את פונקציית המודל בעצמך.
צור מעריך מדגם Keras
אתה יכול להמיר מודלים קיימים של Keras ל-Estimators עם tf.keras.estimator.model_to_estimator
. זה מועיל אם ברצונך לחדש את קוד הדגם שלך, אך צינור ההדרכה שלך עדיין דורש מעריכים.
הצג דגם של Keras MobileNet V2 והרכב את המודל עם כלי האופטימיזציה, האובדן והמדדים להתאמן איתם:
keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False
estimator_model = tf.keras.Sequential([
keras_mobilenet_v2,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(1)
])
# Compile the model
estimator_model.compile(
optimizer='adam',
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5 9412608/9406464 [==============================] - 0s 0us/step 9420800/9406464 [==============================] - 0s 0us/step
צור Estimator
ממודל Keras המלוקט. מצב המודל הראשוני של מודל Estimator
נשמר באומדן שנוצר:
est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpmosnmied INFO:tensorflow:Using the Keras model provided. /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/backend.py:401: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model. warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and ' /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument. category=CustomMaskWarning) INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpmosnmied', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
התייחס Estimator
הנגזר כפי שהיית מתייחס לכל Estimator
אחר.
IMG_SIZE = 160 # All images will be resized to 160x160
def preprocess(image, label):
image = tf.cast(image, tf.float32)
image = (image/127.5) - 1
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
return image, label
def train_input_fn(batch_size):
data = tfds.load('cats_vs_dogs', as_supervised=True)
train_data = data['train']
train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
return train_data
כדי להתאמן, התקשר לפונקציית הרכבת של Estimator:
est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpmosnmied/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={}) INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpmosnmied/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={}) INFO:tensorflow:Warm-starting from: /tmp/tmpmosnmied/keras/keras_model.ckpt INFO:tensorflow:Warm-starting from: /tmp/tmpmosnmied/keras/keras_model.ckpt INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES. INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES. INFO:tensorflow:Warm-started 158 variables. INFO:tensorflow:Warm-started 158 variables. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpmosnmied/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpmosnmied/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6994096, step = 0 INFO:tensorflow:loss = 0.6994096, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpmosnmied/model.ckpt. INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpmosnmied/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Loss for final step: 0.68789804. INFO:tensorflow:Loss for final step: 0.68789804. <tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f4b1c1e9890>
באופן דומה, כדי להעריך, קרא לפונקציית ההערכה של האומד:
est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/training.py:2470: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically. warnings.warn('`Model.state_updates` will be removed in a future version. ' INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:36 INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:36 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpmosnmied/model.ckpt-50 INFO:tensorflow:Restoring parameters from /tmp/tmpmosnmied/model.ckpt-50 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 3.89658s INFO:tensorflow:Inference Time : 3.89658s INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:39 INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:39 INFO:tensorflow:Saving dict for global step 50: accuracy = 0.525, global_step = 50, loss = 0.6723582 INFO:tensorflow:Saving dict for global step 50: accuracy = 0.525, global_step = 50, loss = 0.6723582 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpmosnmied/model.ckpt-50 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpmosnmied/model.ckpt-50 {'accuracy': 0.525, 'loss': 0.6723582, 'global_step': 50}
לפרטים נוספים, עיין בתיעוד עבור tf.keras.estimator.model_to_estimator
.
שמירת מחסומים מבוססי אובייקטים עם Estimator
מעריכים כברירת מחדל שומרים נקודות ביקורת עם שמות משתנים במקום עם גרף האובייקטים המתואר במדריך המחסום . tf.train.Checkpoint
יקרא נקודות ביקורת מבוססות שמות, אך שמות משתנים עשויים להשתנות בעת הזזת חלקים של מודל מחוץ ל- model_fn
של האומד. עבור תאימות קדימה, שמירת נקודות ביקורת מבוססות אובייקטים מקלה על אימון מודל בתוך Estimator ולאחר מכן להשתמש בו מחוץ לאחד.
import tensorflow.compat.v1 as tf_compat
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
def model_fn(features, labels, mode):
net = Net()
opt = tf.keras.optimizers.Adam(0.1)
ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
optimizer=opt, net=net)
with tf.GradientTape() as tape:
output = net(features['x'])
loss = tf.reduce_mean(tf.abs(output - features['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
return tf.estimator.EstimatorSpec(
mode,
loss=loss,
train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
ckpt.step.assign_add(1)),
# Tell the Estimator to save "ckpt" in an object-based format.
scaffold=tf_compat.train.Scaffold(saver=ckpt))
tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config. INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 4.659403, step = 0 INFO:tensorflow:loss = 4.659403, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 39.58891. INFO:tensorflow:Loss for final step: 39.58891. <tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f4b7c451fd0>
לאחר מכן, tf.train.Checkpoint
יכול לטעון את נקודות הבידוק של האומד מה- model_dir
שלו.
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy() # From est.train(..., steps=10)
10
SavedModels מאת מעריכים
מעריכים מייצאים SavedModels דרך tf.Estimator.export_saved_model
.
input_column = tf.feature_column.numeric_column("x")
estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])
def input_fn():
return tf.data.Dataset.from_tensor_slices(
({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config. INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp30_d7xz6 WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp30_d7xz6 INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp30_d7xz6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp30_d7xz6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp30_d7xz6/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp30_d7xz6/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp30_d7xz6/model.ckpt. INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp30_d7xz6/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Loss for final step: 0.4022895. INFO:tensorflow:Loss for final step: 0.4022895. <tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f4b1c10fd10>
כדי לשמור Estimator
אתה צריך ליצור serving_input_receiver
. פונקציה זו בונה חלק מ- tf.Graph
את הנתונים הגולמיים שהתקבלו על ידי SavedModel.
מודול tf.estimator.export
מכיל פונקציות שיעזרו לבנות receivers
אלה.
הקוד הבא בונה מקלט, המבוסס על feature_columns
, שמקבל מאגרי פרוטוקול tf.Example
בסידרה, המשמשים לעתים קרובות עם tf-serving .
tmpdir = tempfile.mkdtemp()
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
tf.feature_column.make_parse_example_spec([input_column]))
estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification'] INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification'] INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression'] INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression'] INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict'] INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict'] INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Restoring parameters from /tmp/tmp30_d7xz6/model.ckpt-50 INFO:tensorflow:Restoring parameters from /tmp/tmp30_d7xz6/model.ckpt-50 INFO:tensorflow:Assets added to graph. INFO:tensorflow:Assets added to graph. INFO:tensorflow:No assets to write. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: /tmp/tmpi_szzuj1/from_estimator/temp-1632343781/saved_model.pb INFO:tensorflow:SavedModel written to: /tmp/tmpi_szzuj1/from_estimator/temp-1632343781/saved_model.pb
אתה יכול גם לטעון ולהפעיל את המודל הזה, מ-python:
imported = tf.saved_model.load(estimator_path)
def predict(x):
example = tf.train.Example()
example.features.feature["x"].float_list.value.extend([x])
return imported.signatures["predict"](
examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.2974025]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.5738074]], dtype=float32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.42619258, 0.5738074 ]], dtype=float32)>} {'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.1919093]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.23291764]], dtype=float32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7670824 , 0.23291762]], dtype=float32)>}
tf.estimator.export.build_raw_serving_input_receiver_fn
מאפשר לך ליצור פונקציות קלט שלוקחות טנזורים גולמיים ולא tf.train.Example
s.
שימוש ב- tf.distribute.Strategy
עם Estimator (תמיכה מוגבלת)
tf.estimator
הוא API של הדרכה מבוזר של TensorFlow שתמך במקור בגישת שרת פרמטרים אסינכרוניים. tf.estimator
תומך כעת ב- tf.distribute.Strategy
. אם אתה משתמש ב- tf.estimator
, אתה יכול לעבור לאימון מבוזר עם מעט מאוד שינויים בקוד שלך. עם זה, משתמשי Estimator יכולים כעת לבצע אימון מבוזר סינכרוני על מספר GPUs ומספר עובדים, כמו גם להשתמש ב-TPUs. עם זאת, תמיכה זו באומדן מוגבלת. עיין בסעיף מה נתמך עכשיו למטה לפרטים נוספים.
השימוש ב- tf.distribute.Strategy
עם Estimator שונה במקצת מאשר במקרה של Keras. במקום להשתמש ב-strategi.scope, כעת אתה מעביר את אובייקט strategy.scope
ל- RunConfig
עבור האומד.
אתה יכול לעיין במדריך ההדרכה המבוזר למידע נוסף.
להלן קטע קוד שמראה זאת עם Estimator LinearRegressor
ו- MirroredStrategy
שהוכן מראש:
mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
feature_columns=[tf.feature_column.numeric_column('feats')],
optimizer='SGD',
config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Initializing RunConfig with distribution strategies. INFO:tensorflow:Initializing RunConfig with distribution strategies. INFO:tensorflow:Not using Distribute Coordinator. INFO:tensorflow:Not using Distribute Coordinator. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpftw63jyd WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpftw63jyd INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpftw63jyd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None} INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpftw63jyd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
כאן, אתה משתמש באומד מראש, אבל אותו קוד עובד גם עם אומד מותאם אישית. train_distribute
קובע כיצד ההדרכה תחולק, ו- eval_distribute
קובעת כיצד תחולק הערכה. זהו הבדל נוסף מ-Keras שבו אתה משתמש באותה אסטרטגיה גם לאימון וגם לאיוואל.
כעת אתה יכול לאמן ולהעריך את האומד הזה עם פונקציית קלט:
def input_fn():
dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
INFO:tensorflow:Calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py:374: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options. warnings.warn("To make it possible to preserve tf.data options across " INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version. Instructions for updating: Use the iterator's `initializer` property instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version. Instructions for updating: Use the iterator's `initializer` property instead. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpftw63jyd/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpftw63jyd/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... 2021-09-22 20:49:45.706166: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} } . Registered: device='CPU' 2021-09-22 20:49:45.707521: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} } . Registered: device='CPU' INFO:tensorflow:loss = 1.0, step = 0 INFO:tensorflow:loss = 1.0, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpftw63jyd/model.ckpt. INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpftw63jyd/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 2.877698e-13. INFO:tensorflow:Loss for final step: 2.877698e-13. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:46 INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:46 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpftw63jyd/model.ckpt-10 INFO:tensorflow:Restoring parameters from /tmp/tmpftw63jyd/model.ckpt-10 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. 2021-09-22 20:49:46.680821: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} } . Registered: device='CPU' 2021-09-22 20:49:46.682161: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} } . Registered: device='CPU' INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.26514s INFO:tensorflow:Inference Time : 0.26514s INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:46 INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:46 INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994 INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmpftw63jyd/model.ckpt-10 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmpftw63jyd/model.ckpt-10 {'average_loss': 1.4210855e-14, 'label/mean': 1.0, 'loss': 1.4210855e-14, 'prediction/mean': 0.99999994, 'global_step': 10}
הבדל נוסף שיש להדגיש כאן בין Estimator ל-Keras הוא הטיפול בקלט. ב-Keras, כל אצווה של מערך הנתונים מפוצלת באופן אוטומטי על פני ההעתקים המרובים. עם זאת, ב-Estimator, אינך מבצע פיצול אצווה אוטומטי, ואינו מפיץ את הנתונים באופן אוטומטי בין עובדים שונים. יש לך שליטה מלאה על האופן שבו אתה רוצה שהנתונים שלך יופצו על פני עובדים ומכשירים, ועליך לספק input_fn
כדי לציין כיצד להפיץ את הנתונים שלך.
input_fn
שלך נקרא פעם אחת לכל עובד, ובכך נותן מערך נתונים אחד לכל עובד. אז אצווה אחת מאותו מערך נתונים מוזנת לעתק אחד באותו עובד, ובכך צורכת N אצווה עבור N העתקים בעובד אחד. במילים אחרות, מערך הנתונים המוחזר על ידי ה- input_fn
צריך לספק אצוות בגודל PER_REPLICA_BATCH_SIZE
. וניתן לקבל את גודל האצווה הגלובלי עבור שלב כ- PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync
.
בעת ביצוע הכשרה מרובה עובדים, עליך לפצל את הנתונים שלך בין העובדים, או לערבב עם סיד אקראי על כל אחד מהם. אתה יכול לבדוק דוגמה כיצד לעשות זאת במדריך הדרכה ריבוי עובדים עם הערכה .
ובאופן דומה, אתה יכול להשתמש גם באסטרטגיות ריבוי עובדים ושרת פרמטרים. הקוד נשאר זהה, אבל עליך להשתמש ב- tf.estimator.train_and_evaluate
ולהגדיר משתני סביבה TF_CONFIG
עבור כל בינארי הפועל באשכול שלך.
מה נתמך עכשיו?
קיימת תמיכה מוגבלת לאימון עם Estimator תוך שימוש בכל האסטרטגיות למעט TPUStrategy
. הכשרה והערכה בסיסית אמורות לעבוד, אך מספר תכונות מתקדמות כגון v1.train.Scaffold
לא. ייתכנו גם מספר באגים באינטגרציה הזו ואין תוכניות לשפר את התמיכה הזו באופן אקטיבי (ההתמקדות היא ב-Keras ותמיכה בלולאת אימון מותאמת אישית). אם זה אפשרי, עליך להעדיף להשתמש ב- tf.distribute
עם ממשקי API אלה במקום זאת.
הדרכה API | אסטרטגיית מראה | אסטרטגיה של TPUS | אסטרטגיית MultiWorkerMirrored | אסטרטגיית אחסון מרכזית | ParameterServerStrategy |
---|---|---|---|---|---|
API של Estimator | תמיכה מוגבלת | אינו נתמך | תמיכה מוגבלת | תמיכה מוגבלת | תמיכה מוגבלת |
דוגמאות והדרכות
הנה כמה דוגמאות מקצה לקצה שמראות כיצד להשתמש באסטרטגיות שונות עם אומדן:
- ערכת הדרכה מרובה עובדים עם אומדן מראה כיצד ניתן להתאמן עם מספר עובדים באמצעות
MultiWorkerMirroredStrategy
במערך הנתונים של MNIST. - דוגמה מקצה לקצה להפעלת הכשרה מרובת עובדים עם אסטרטגיות הפצה ב-
tensorflow/ecosystem
באמצעות תבניות Kubernetes. זה מתחיל במודל Keras וממיר אותו ל-Estimator באמצעותtf.keras.estimator.model_to_estimator
API. - הדגם הרשמי של ResNet50 , שניתן לאמן אותו באמצעות
MirroredStrategy
אוMultiWorkerMirroredStrategy
.