מסמך זה מספק עצות ביצועים ספציפיות ל-TensorFlow Datasets (TFDS). שים לב ש-TFDS מספק מערכי נתונים כאובייקטי tf.data.Dataset
, כך שהעצה ממדריך tf.data
עדיין חלה.
מערכי נתונים בהשוואה
השתמש ב- tfds.benchmark(ds)
כדי לסמן כל אובייקט tf.data.Dataset
.
הקפד לציין את batch_size=
כדי לנרמל את התוצאות (למשל 100 iter/sec -> 3200 ex/sec). זה עובד עם כל איטרציה (למשל tfds.benchmark(tfds.as_numpy(ds))
).
ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)
מערכי נתונים קטנים (פחות מ-1 GB)
כל מערכי הנתונים של TFDS מאחסנים את הנתונים בדיסק בפורמט TFRecord
. עבור מערכי נתונים קטנים (למשל MNIST, CIFAR-10/-100), קריאה מ- .tfrecord
יכולה להוסיף תקורה משמעותית.
כאשר מערכי הנתונים הללו משתלבים בזיכרון, ניתן לשפר משמעותית את הביצועים על ידי אחסון במטמון או טעינה מראש של מערך הנתונים. שים לב ש-TFDS מאחסן באופן אוטומטי מערכי נתונים קטנים (בחלק הבא יש את הפרטים).
שמירה במטמון של מערך הנתונים
הנה דוגמה לצינור נתונים ששומר במפורש את מערך הנתונים לאחר נרמול התמונות.
def normalize_img(image, label):
"""Normalizes images: `uint8` -> `float32`."""
return tf.cast(image, tf.float32) / 255., label
ds, ds_info = tfds.load(
'mnist',
split='train',
as_supervised=True, # returns `(img, label)` instead of dict(image=, ...)
with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
בעת איטרציה על מערך נתונים זה, האיטרציה השנייה תהיה מהירה בהרבה מהראשונה הודות לאחסון במטמון.
שמירה אוטומטית במטמון
כברירת מחדל, TFDS מטמון אוטומטי (עם ds.cache()
) מערכי נתונים העומדים באילוצים הבאים:
- גודל הנתונים הכולל (כל הפיצולים) מוגדר ופחות מ-250 MiB
-
shuffle_files
מושבת, או שרק רסיס בודד נקרא
אפשר לבטל את ההצטרפות לשמירת מטמון אוטומטית על ידי העברת try_autocaching=False
ל- tfds.ReadConfig
ב- tfds.load
. עיין בתיעוד של קטלוג הנתונים כדי לראות אם מערך נתונים ספציפי ישתמש במטמון אוטומטי.
טעינת הנתונים המלאים כטנזור יחיד
אם מערך הנתונים שלך מתאים לזיכרון, אתה יכול גם לטעון את מערך הנתונים המלא כמערך Tensor או NumPy יחיד. אפשר לעשות זאת על ידי הגדרת batch_size=-1
לאצווה את כל הדוגמאות ב- tf.Tensor
יחיד. לאחר מכן השתמש ב- tfds.as_numpy
להמרה מ- tf.Tensor
ל- np.array
.
(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
'mnist',
split=['train', 'test'],
batch_size=-1,
as_supervised=True,
))
מערכי נתונים גדולים
מערכי נתונים גדולים מחולקים (מפוצלים במספר קבצים) ובדרך כלל אינם מתאימים לזיכרון, לכן אין לשמור אותם במטמון.
ערבוב והדרכה
במהלך האימון, חשוב לערבב היטב את הנתונים - נתונים שמערבבים בצורה גרועה עלולים לגרום לדיוק האימון נמוך יותר.
בנוסף לשימוש ב- ds.shuffle
כדי לערבב רשומות, עליך להגדיר גם shuffle_files=True
כדי לקבל התנהגות ערבוב טובה עבור מערכי נתונים גדולים יותר שמחולקים למספר קבצים. אחרת, תקופות יקראו את הרסיסים באותו סדר, וכך הנתונים לא יהיו אקראי באמת.
ds = tfds.load('imagenet2012', split='train', shuffle_files=True)
בנוסף, כאשר shuffle_files=True
, TFDS משבית את options.deterministic
, מה שעשוי לתת שיפור קל בביצועים. כדי לקבל ערבוב דטרמיניסטי, אפשר לבטל את הסכמתו לתכונה זו עם tfds.ReadConfig
: על ידי הגדרת read_config.shuffle_seed
או החלפת read_config.options.deterministic
.
חלוקה אוטומטית של הנתונים שלך בין עובדים (TF)
בעת אימון על מספר עובדים, אתה יכול להשתמש בארגומנט input_context
של tfds.ReadConfig
, כך שכל עובד יקרא תת-קבוצה של הנתונים.
input_context = tf.distribute.InputContext(
input_pipeline_id=1, # Worker id
num_input_pipelines=4, # Total number of workers
)
read_config = tfds.ReadConfig(
input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)
זה משלים ל-API המשנה. ראשית, ממשק ה-API של חלוקת המשנה מוחל: train[:50%]
מומרת לרשימת קבצים לקריאה. לאחר מכן, הפעלה ds.shard()
מוחלת על קבצים אלה. לדוגמה, בעת שימוש train[:50%]
עם num_input_pipelines=2
, כל אחד משני העובדים יקרא 1/4 מהנתונים.
כאשר shuffle_files=True
, קבצים עוברים ערבוב בתוך עובד אחד, אך לא בין עובדים. כל עובד יקרא את אותה תת-קבוצה של קבצים בין תקופות.
חלוקה אוטומטית של הנתונים שלך בין עובדים (Jax)
עם Jax, אתה יכול להשתמש ב- tfds.split_for_jax_process
או tfds.even_splits
API כדי להפיץ את הנתונים שלך בין עובדים. עיין במדריך ה-API המפוצל .
split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)
tfds.split_for_jax_process
הוא כינוי פשוט עבור:
# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]
פענוח תמונות מהיר יותר
כברירת מחדל, TFDS מפענח תמונות באופן אוטומטי. עם זאת, ישנם מקרים שבהם זה יכול להיות יעיל יותר לדלג על פענוח התמונה עם tfds.decode.SkipDecoding
ולהחיל באופן ידני את ה- tf.io.decode_image
op:
- בעת סינון דוגמאות (עם
tf.data.Dataset.filter
), לפענח תמונות לאחר סינון של דוגמאות. - בעת חיתוך תמונות, כדי להשתמש ב-
tf.image.decode_and_crop_jpeg
fused.
הקוד עבור שתי הדוגמאות זמין במדריך הפענוח .
דלג על תכונות שאינן בשימוש
אם אתה משתמש רק בתת-קבוצה של התכונות, אפשר לדלג לחלוטין על חלק מהתכונות. אם מערך הנתונים שלך כולל תכונות רבות שאינן בשימוש, אי פענוחן יכול לשפר משמעותית את הביצועים. ראה https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features
tf.data משתמש בכל זיכרון ה-RAM שלי!
אם אתה מוגבל ב-RAM, או אם אתה טוען מערכי נתונים רבים במקביל תוך שימוש ב- tf.data
, הנה כמה אפשרויות שיכולות לעזור:
ביטול גודל המאגר
builder.as_dataset(
read_config=tfds.ReadConfig(
...
override_buffer_size=1024, # Save quite a bit of RAM.
),
...
)
זה עוקף את buffer_size
המועבר ל- TFRecordDataset
(או שווה ערך): https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#args
השתמש ב-tf.data.Dataset.with_options כדי לעצור התנהגויות קסם
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#with_options
options = tf.data.Options()
# Stop magic stuff that eats up RAM:
options.autotune.enabled = False
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF)
options.experimental_optimization.inject_prefetch = False
data = data.with_options(options)