सहायता Kaggle पर TensorFlow साथ ग्रेट बैरियर रीफ की रक्षा चैलेंज में शामिल हों

tf.data: TensorFlow इनपुट पाइपलाइन बनाएँ

TensorFlow.org पर देखें GitHub पर स्रोत देखेंनोटबुक डाउनलोड करें

tf.data एपीआई आप सरल, पुन: प्रयोज्य टुकड़े से जटिल इनपुट पाइपलाइनों का निर्माण करने के लिए सक्षम बनाता है। उदाहरण के लिए, एक छवि मॉडल के लिए पाइपलाइन एक वितरित फ़ाइल सिस्टम में फ़ाइलों से डेटा एकत्र कर सकती है, प्रत्येक छवि पर यादृच्छिक गड़बड़ी लागू कर सकती है, और प्रशिक्षण के लिए यादृच्छिक रूप से चयनित छवियों को एक बैच में मर्ज कर सकती है। टेक्स्ट मॉडल के लिए पाइपलाइन में कच्चे टेक्स्ट डेटा से प्रतीकों को निकालना, उन्हें लुकअप टेबल के साथ एम्बेडिंग पहचानकर्ताओं में परिवर्तित करना और विभिन्न लंबाई के अनुक्रमों को एक साथ बैच करना शामिल हो सकता है। tf.data एपीआई यह संभव, डेटा की बड़ी मात्रा को संभालने के विभिन्न डेटा प्रारूपों से पढ़ा है, और जटिल परिवर्तनों को करने के लिए बनाता है।

tf.data API का परिचय एक tf.data.Dataset अमूर्त है कि तत्वों के एक दृश्य है, जिसमें प्रत्येक तत्व एक या अधिक घटक होते हैं प्रतिनिधित्व करता है। उदाहरण के लिए, एक छवि पाइपलाइन में, एक तत्व एक एकल प्रशिक्षण उदाहरण हो सकता है, जिसमें छवि और उसके लेबल का प्रतिनिधित्व करने वाले टेंसर घटकों की एक जोड़ी होती है।

डेटासेट बनाने के दो अलग-अलग तरीके हैं:

  • कोई डेटा स्रोत एक निर्माण करती Dataset डेटा स्मृति में संग्रहीत या एक या अधिक फ़ाइलों में से।

  • एक डाटा परिवर्तन एक या अधिक से एक डाटासेट निर्माण करती tf.data.Dataset वस्तुओं।

import tensorflow as tf
import pathlib
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

np.set_printoptions(precision=4)

बुनियादी यांत्रिकी

एक इनपुट पाइपलाइन बनाने के लिए, आप डेटा स्रोत से शुरू करनी चाहिए। उदाहरण के लिए, एक के निर्माण के लिए Dataset स्मृति में डेटा से, आप उपयोग कर सकते हैं tf.data.Dataset.from_tensors() या tf.data.Dataset.from_tensor_slices() । वैकल्पिक रूप से, अगर आपके इनपुट डेटा की सिफारिश की TFRecord प्रारूप में एक फ़ाइल में संग्रहीत है, तो आप उपयोग कर सकते हैं tf.data.TFRecordDataset()

एक बार जब आप एक राशि Dataset वस्तु, आप एक नए रूप में बदल सकता है Dataset पर विधि कॉल चेनिंग द्वारा tf.data.Dataset वस्तु। उदाहरण के लिए, आप प्रति-तत्व जैसे परिवर्तनों लागू कर सकते हैं Dataset.map() , और इस तरह के रूप में बहु तत्व परिवर्तनों Dataset.batch() । के लिए दस्तावेज़ देखें tf.data.Dataset परिवर्तनों की एक पूरी सूची के लिए।

Dataset वस्तु एक अजगर iterable है। यह लूप के लिए इसके तत्वों का उपभोग करना संभव बनाता है:

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
dataset
<TensorSliceDataset element_spec=TensorSpec(shape=(), dtype=tf.int32, name=None)>
for elem in dataset:
  print(elem.numpy())
8
3
0
8
2
1

या स्पष्ट रूप से एक अजगर बनाने इटरेटर का उपयोग करके iter और उसके तत्वों का उपयोग कर लेने वाली next :

it = iter(dataset)

print(next(it).numpy())
8

वैकल्पिक रूप से, डाटासेट तत्वों का उपयोग कर सेवन किया जा सकता reduce परिवर्तन है, जो एक एकल परिणाम का उत्पादन करने के सभी तत्वों को कम करता है। निम्न उदाहरण कैसे उपयोग करने के लिए दिखाता है reduce पूर्णांकों का एक डाटासेट की राशि की गणना करने के परिवर्तन।

print(dataset.reduce(0, lambda state, value: state + value).numpy())
22

डेटासेट संरचना

एक डाटासेट तत्वों, जहां प्रत्येक तत्व एक ही (नेस्ट) घटकों की संरचना है की एक दृश्य पैदा करता है। संरचना के व्यक्तिगत घटकों द्वारा किसी भी प्रकार का हो सकता है प्रदर्शनीय tf.TypeSpec , सहित tf.Tensor , tf.sparse.SparseTensor , tf.RaggedTensor , tf.TensorArray , या tf.data.Dataset

अजगर निर्माणों कि तत्वों की (नेस्टेड) संरचना व्यक्त करने के लिए इस्तेमाल किया जा सकता शामिल हैं tuple , dict , NamedTuple , और OrderedDict । विशेष रूप से, list डाटासेट तत्वों की संरचना को व्यक्त करने के लिए एक वैध निर्माण नहीं है। इसका कारण यह है जल्दी tf.data उपयोगकर्ताओं दृढ़ता के बारे में महसूस किया है list (उदाहरण के लिए पारित किया आदानों tf.data.Dataset.from_tensors ) स्वचालित रूप से tensors और के रूप में पैक किया जा रहा list outputs (उपयोगकर्ता परिभाषित कार्यों का उदाहरण वापसी मान) एक मजबूर किया जा रहा tuple । एक परिणाम के रूप में, यदि आप एक चाहते हैं list इनपुट एक संरचना के रूप में इलाज किया जाना है, तो आप इसे में बदलने की आवश्यकता tuple और अगर आप एक चाहते हैं list उत्पादन एक भी घटक के रूप में है, तो आप स्पष्ट रूप से उपयोग कर इसे पैक करने के लिए की जरूरत है tf.stack .

Dataset.element_spec संपत्ति आप प्रत्येक तत्व घटक के प्रकार के निरीक्षण करने के लिए अनुमति देता है। संपत्ति के लिए एक आंतरिक संरचना रिटर्न tf.TypeSpec वस्तुओं, तत्व है, जो एक एकल घटक, घटकों की एक टपल, या घटकों के एक नेस्टेड टपल हो सकता है की संरचना से मेल खाते। उदाहरण के लिए:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))

dataset1.element_spec
TensorSpec(shape=(10,), dtype=tf.float32, name=None)
dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2.element_spec
(TensorSpec(shape=(), dtype=tf.float32, name=None),
 TensorSpec(shape=(100,), dtype=tf.int32, name=None))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3.element_spec
(TensorSpec(shape=(10,), dtype=tf.float32, name=None),
 (TensorSpec(shape=(), dtype=tf.float32, name=None),
  TensorSpec(shape=(100,), dtype=tf.int32, name=None)))
# Dataset containing a sparse tensor.
dataset4 = tf.data.Dataset.from_tensors(tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))

dataset4.element_spec
SparseTensorSpec(TensorShape([3, 4]), tf.int32)
# Use value_type to see the type of value represented by the element spec
dataset4.element_spec.value_type
tensorflow.python.framework.sparse_tensor.SparseTensor

Dataset परिवर्तनों किसी भी संरचना के डेटासेट समर्थन करते हैं। का उपयोग करते समय Dataset.map() , और Dataset.filter() परिवर्तनों है, जो प्रत्येक तत्व के लिए एक समारोह लागू होते हैं, तत्व संरचना समारोह के तर्कों को निर्धारित करता है:

dataset1 = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([4, 10], minval=1, maxval=10, dtype=tf.int32))

dataset1
<TensorSliceDataset element_spec=TensorSpec(shape=(10,), dtype=tf.int32, name=None)>
for z in dataset1:
  print(z.numpy())
[8 2 4 4 4 9 9 1 5 4]
[8 9 8 6 5 8 8 1 3 7]
[3 7 8 8 5 9 9 4 5 5]
[4 4 9 7 6 4 5 9 4 4]
dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2
<TensorSliceDataset element_spec=(TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(100,), dtype=tf.int32, name=None))>
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3
<ZipDataset element_spec=(TensorSpec(shape=(10,), dtype=tf.int32, name=None), (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(100,), dtype=tf.int32, name=None)))>
for a, (b,c) in dataset3:
  print('shapes: {a.shape}, {b.shape}, {c.shape}'.format(a=a, b=b, c=c))
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)

इनपुट डेटा पढ़ना

NumPy सरणियों का उपभोग

देखें NumPy सरणी लोड हो रहा है और उदाहरण के लिए।

स्मृति में अपने इनपुट डेटा फिट के सभी, एक बनाने के लिए सबसे आसान तरीका है तो Dataset उन लोगों से उन्हें कन्वर्ट करने के लिए है tf.Tensor वस्तुओं और उपयोग Dataset.from_tensor_slices()

train, test = tf.keras.datasets.fashion_mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
40960/29515 [=========================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
26435584/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
16384/5148 [===============================================================================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step
4431872/4422102 [==============================] - 0s 0us/step
images, labels = train
images = images/255

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset
<TensorSliceDataset element_spec=(TensorSpec(shape=(28, 28), dtype=tf.float64, name=None), TensorSpec(shape=(), dtype=tf.uint8, name=None))>

पायथन जनरेटर का सेवन

एक अन्य आम डेटा स्रोत है कि आसानी से एक के रूप में ग्रहण किया जा सकता tf.data.Dataset अजगर जनरेटर है।

def count(stop):
  i = 0
  while i<stop:
    yield i
    i += 1
for n in count(5):
  print(n)
0
1
2
3
4

Dataset.from_generator निर्माता एक पूरी तरह कार्यात्मक के लिए अजगर जनरेटर धर्मान्तरित tf.data.Dataset

कंस्ट्रक्टर एक कॉल करने योग्य इनपुट के रूप में लेता है, एक पुनरावर्तक नहीं। यह अंत तक पहुंचने पर जनरेटर को पुनरारंभ करने की अनुमति देता है। यह एक वैकल्पिक लेता है args तर्क है, जो प्रतिदेय के तर्कों के रूप में पारित कर दिया है।

output_types तर्क की आवश्यकता है क्योंकि tf.data एक बनाता है tf.Graph आंतरिक रूप से, और ग्राफ किनारों एक की आवश्यकता होती है tf.dtype

ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]

output_shapes तर्क की आवश्यकता नहीं है, लेकिन अत्यधिक कई tensorflow संचालन के रूप में सिफारिश की है अज्ञात रैंक के साथ tensors समर्थन नहीं करते। एक विशेष अक्ष की लंबाई अज्ञात या चर रहा है, तो के रूप में सेट None में output_shapes

यह भी ध्यान देना ज़रूरी है कि output_shapes और output_types अन्य डाटासेट तरीकों के रूप में ही घोंसले नियमों का पालन करें।

यहां एक उदाहरण जनरेटर है जो दोनों पहलुओं को प्रदर्शित करता है, यह सरणी के टुपल्स देता है, जहां दूसरी सरणी अज्ञात लंबाई वाला वेक्टर है।

def gen_series():
  i = 0
  while True:
    size = np.random.randint(0, 10)
    yield i, np.random.normal(size=(size,))
    i += 1
for i, series in gen_series():
  print(i, ":", str(series))
  if i > 5:
    break
0 : [-0.507   0.825  -0.9698  1.0904  1.1761 -0.9112 -0.0045 -0.8401 -0.2676]
1 : [ 0.621   1.5843 -0.4695]
2 : []
3 : [-0.5107]
4 : [-1.6201 -1.8984  0.6082  1.8105 -2.368   0.4142  0.2167]
5 : [0.9673]
6 : []

पहले उत्पादन एक है int32 दूसरा एक है float32

पहले आइटम एक अदिश, आकार है () , और दूसरा अज्ञात लंबाई, आकार का एक वेक्टर है (None,)

ds_series = tf.data.Dataset.from_generator(
    gen_series, 
    output_types=(tf.int32, tf.float32), 
    output_shapes=((), (None,)))

ds_series
<FlatMapDataset element_spec=(TensorSpec(shape=(), dtype=tf.int32, name=None), TensorSpec(shape=(None,), dtype=tf.float32, name=None))>

अब यह एक नियमित रूप से की तरह इस्तेमाल किया जा सकता tf.data.Dataset । ध्यान दें कि जब एक चर आकार के साथ एक डाटासेट batching, आप उपयोग करने की आवश्यकता Dataset.padded_batch

ds_series_batch = ds_series.shuffle(20).padded_batch(10)

ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())
[12 19  5 10  6 15  3 20 21 17]

[[ 0.      0.      0.      0.      0.      0.      0.      0.      0.    ]
 [ 0.483  -0.3454  0.      0.      0.      0.      0.      0.      0.    ]
 [ 0.0869 -1.5191 -1.9252 -0.6955  0.3542  1.7332 -0.084   0.      0.    ]
 [-1.0211  0.2689 -0.4805 -0.6755  0.6886  0.8313  0.      0.      0.    ]
 [ 0.5442  0.0539  1.4572  2.7313 -0.0386  1.2614 -0.0811 -0.5399  0.    ]
 [ 1.6386 -0.8331 -1.4722  0.0403  1.3425 -0.3833 -2.1371  0.901   0.9595]
 [ 0.1002  0.0705 -0.4418  0.0806 -1.4263 -0.1352  0.      0.      0.    ]
 [ 0.      0.      0.      0.      0.      0.      0.      0.      0.    ]
 [-0.8525  0.1426 -0.0869  2.163  -0.3666  0.      0.      0.      0.    ]
 [ 0.6796 -0.824  -0.0424  0.      0.      0.      0.      0.      0.    ]]

एक और अधिक यथार्थवादी उदाहरण के लिए, लपेटकर कोशिश preprocessing.image.ImageDataGenerator एक के रूप में tf.data.Dataset

पहले डेटा डाउनलोड करें:

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 1s 0us/step
228827136/228813984 [==============================] - 1s 0us/step

बनाएं image.ImageDataGenerator

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
images, labels = next(img_gen.flow_from_directory(flowers))
Found 3670 images belonging to 5 classes.
print(images.dtype, images.shape)
print(labels.dtype, labels.shape)
float32 (32, 256, 256, 3)
float32 (32, 5)
ds = tf.data.Dataset.from_generator(
    lambda: img_gen.flow_from_directory(flowers), 
    output_types=(tf.float32, tf.float32), 
    output_shapes=([32,256,256,3], [32,5])
)

ds.element_spec
(TensorSpec(shape=(32, 256, 256, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(32, 5), dtype=tf.float32, name=None))
for images, label in ds.take(1):
  print('images.shape: ', images.shape)
  print('labels.shape: ', labels.shape)
Found 3670 images belonging to 5 classes.
images.shape:  (32, 256, 256, 3)
labels.shape:  (32, 5)

TFRecord डेटा की खपत

देखें TFRecords लोड हो रहा है एंड-टू-एंड उदाहरण के लिए।

tf.data एपीआई है कि आप बड़े डेटासेट कि स्मृति में फिट नहीं संसाधित कर सकते हैं तो फ़ाइल स्वरूपों की एक किस्म का समर्थन करता है। उदाहरण के लिए, TFRecord फ़ाइल प्रारूप एक साधारण रिकॉर्ड-उन्मुख बाइनरी प्रारूप है जिसका उपयोग कई TensorFlow एप्लिकेशन प्रशिक्षण डेटा के लिए करते हैं। tf.data.TFRecordDataset वर्ग आप एक इनपुट पाइपलाइन के हिस्से के रूप में एक या अधिक TFRecord फ़ाइलों की सामग्री पर स्ट्रीम करने में सक्षम बनाता है।

यहां फ़्रेंच स्ट्रीट नेम साइन्स (FSNS) से परीक्षण फ़ाइल का उपयोग करते हुए एक उदाहरण दिया गया है।

# Creates a dataset that reads all of the examples from two files.
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
7905280/7904079 [==============================] - 0s 0us/step
7913472/7904079 [==============================] - 0s 0us/step

filenames करने के लिए तर्क TFRecordDataset प्रारंभकर्ता या तो एक स्ट्रिंग, स्ट्रिंग की एक सूची है, या एक हो सकता है tf.Tensor तार की। इसलिए यदि आपके पास प्रशिक्षण और सत्यापन उद्देश्यों के लिए फाइलों के दो सेट हैं, तो आप एक फ़ैक्टरी विधि बना सकते हैं जो डेटासेट का उत्पादन करती है, फ़ाइल नाम को इनपुट तर्क के रूप में लेती है:

dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset
<TFRecordDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>

कई TensorFlow परियोजनाओं धारावाहिक का उपयोग tf.train.Example उनके TFRecord फ़ाइलों में रिकॉर्ड। इनका निरीक्षण करने से पहले इन्हें डीकोड करने की आवश्यकता है:

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

parsed.features.feature['image/text']
bytes_list {
  value: "Rue Perreyon"
}

टेक्स्ट डेटा की खपत

देखें पाठ लोड हो रहा है अंत उदाहरण के लिए एक अंत के लिए।

कई डेटासेट एक या अधिक टेक्स्ट फ़ाइलों के रूप में वितरित किए जाते हैं। tf.data.TextLineDataset एक आसान तरीका एक या अधिक पाठ फ़ाइलों से लाइनों को निकालने के लिए प्रदान करता है। एक या अधिक फ़ाइल नाम को देखते हुए, एक TextLineDataset प्रति उन फ़ाइलों की लाइन एक स्ट्रिंग-मान तत्व का उत्पादन करेगा।

directory_url = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
file_names = ['cowper.txt', 'derby.txt', 'butler.txt']

file_paths = [
    tf.keras.utils.get_file(file_name, directory_url + file_name)
    for file_name in file_names
]
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt
819200/815980 [==============================] - 0s 0us/step
827392/815980 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt
811008/809730 [==============================] - 0s 0us/step
819200/809730 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt
811008/807992 [==============================] - 0s 0us/step
819200/807992 [==============================] - 0s 0us/step
dataset = tf.data.TextLineDataset(file_paths)

यहाँ पहली फ़ाइल की पहली कुछ पंक्तियाँ हैं:

for line in dataset.take(5):
  print(line.numpy())
b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b'His wrath pernicious, who ten thousand woes'
b"Caused to Achaia's host, sent many a soul"
b'Illustrious into Ades premature,'
b'And Heroes gave (so stood the will of Jove)'

फ़ाइलों के बीच वैकल्पिक लाइनों के लिए उपयोग Dataset.interleave । इससे फाइलों को एक साथ फेरबदल करना आसान हो जाता है। यहाँ प्रत्येक अनुवाद की पहली, दूसरी और तीसरी पंक्तियाँ हैं:

files_ds = tf.data.Dataset.from_tensor_slices(file_paths)
lines_ds = files_ds.interleave(tf.data.TextLineDataset, cycle_length=3)

for i, line in enumerate(lines_ds.take(9)):
  if i % 3 == 0:
    print()
  print(line.numpy())
b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b"\xef\xbb\xbfOf Peleus' son, Achilles, sing, O Muse,"
b'\xef\xbb\xbfSing, O goddess, the anger of Achilles son of Peleus, that brought'

b'His wrath pernicious, who ten thousand woes'
b'The vengeance, deep and deadly; whence to Greece'
b'countless ills upon the Achaeans. Many a brave soul did it send'

b"Caused to Achaia's host, sent many a soul"
b'Unnumbered ills arose; which many a soul'
b'hurrying down to Hades, and many a hero did it yield a prey to dogs and'

डिफ़ॉल्ट रूप से, एक TextLineDataset अगर फ़ाइल, एक हैडर लाइन के साथ शुरू होता है या टिप्पणियां हैं, प्रत्येक फ़ाइल है, जो वांछनीय नहीं हो सकता है, उदाहरण के लिए के हर लाइन अर्जित करता है। ये लाइनें का उपयोग कर हटाया जा सकता है Dataset.skip() या Dataset.filter() परिवर्तनों। यहां, आप पहली पंक्ति को छोड़ देते हैं, फिर केवल बचे लोगों को खोजने के लिए फ़िल्टर करते हैं।

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step
40960/30874 [=======================================] - 0s 0us/step
for line in titanic_lines.take(10):
  print(line.numpy())
b'survived,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone'
b'0,male,22.0,1,0,7.25,Third,unknown,Southampton,n'
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'0,male,28.0,0,0,8.4583,Third,unknown,Queenstown,y'
b'0,male,2.0,3,1,21.075,Third,unknown,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
def survived(line):
  return tf.not_equal(tf.strings.substr(line, 0, 1), "0")

survivors = titanic_lines.skip(1).filter(survived)
for line in survivors.take(10):
  print(line.numpy())
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
b'1,male,28.0,0,0,13.0,Second,unknown,Southampton,y'
b'1,female,28.0,0,0,7.225,Third,unknown,Cherbourg,y'
b'1,male,28.0,0,0,35.5,First,A,Southampton,y'
b'1,female,38.0,1,5,31.3875,Third,unknown,Southampton,n'

सीएसवी डेटा की खपत

देखें CSV फ़ाइलें लोड हो रहा है , और लोड हो रहा है पांडा DataFrames अधिक उदाहरण के लिए।

सादा पाठ में सारणीबद्ध डेटा संग्रहीत करने के लिए CSV फ़ाइल स्वरूप एक लोकप्रिय प्रारूप है।

उदाहरण के लिए:

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
df = pd.read_csv(titanic_file)
df.head()

यदि स्मृति में अपने डेटा को फिट ही Dataset.from_tensor_slices शब्दकोशों पर विधि काम करती है, इस डेटा की इजाजत दी आसानी से आयात करने के लिए:

titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df))

for feature_batch in titanic_slices.take(1):
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived'          : 0
  'sex'               : b'male'
  'age'               : 22.0
  'n_siblings_spouses': 1
  'parch'             : 0
  'fare'              : 7.25
  'class'             : b'Third'
  'deck'              : b'unknown'
  'embark_town'       : b'Southampton'
  'alone'             : b'n'

आवश्यकतानुसार डिस्क से लोड करना एक अधिक स्केलेबल दृष्टिकोण है।

tf.data मॉड्यूल एक या अधिक CSV फ़ाइलों से निकालने के रिकॉर्ड के तरीकों कि का अनुपालन प्रदान करता है आरएफसी 4180

experimental.make_csv_dataset समारोह csv फ़ाइलें के सेट पढ़ने के लिए उच्च स्तरीय इंटरफेस है। यह उपयोग को सरल बनाने के लिए कॉलम प्रकार के अनुमान और कई अन्य सुविधाओं का समर्थन करता है, जैसे बैचिंग और शफलिंग।

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived")
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  print("features:")
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived': [1 0 1 0]
features:
  'sex'               : [b'female' b'male' b'female' b'male']
  'age'               : [ 4. 50. 28. 28.]
  'n_siblings_spouses': [0 0 0 1]
  'parch'             : [2 0 0 0]
  'fare'              : [22.025 13.     7.75  16.1  ]
  'class'             : [b'Third' b'Second' b'Third' b'Third']
  'deck'              : [b'unknown' b'unknown' b'unknown' b'unknown']
  'embark_town'       : [b'Southampton' b'Southampton' b'Queenstown' b'Southampton']
  'alone'             : [b'n' b'y' b'y' b'n']

आप उपयोग कर सकते हैं select_columns तर्क अगर आप केवल स्तंभों की एक सबसेट की जरूरत है।

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived", select_columns=['class', 'fare', 'survived'])
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived': [1 1 1 1]
  'fare'              : [ 12.  108.9  23.   26. ]
  'class'             : [b'Second' b'First' b'Second' b'Second']

वहाँ भी एक निचले स्तर है experimental.CsvDataset वर्ग जो महीन बेहतर नियंत्रण प्रदान करता है। यह कॉलम प्रकार के अनुमान का समर्थन नहीं करता है। इसके बजाय आपको प्रत्येक कॉलम का प्रकार निर्दिष्ट करना होगा।

titanic_types  = [tf.int32, tf.string, tf.float32, tf.int32, tf.int32, tf.float32, tf.string, tf.string, tf.string, tf.string] 
dataset = tf.data.experimental.CsvDataset(titanic_file, titanic_types , header=True)

for line in dataset.take(10):
  print([item.numpy() for item in line])
[0, b'male', 22.0, 1, 0, 7.25, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 38.0, 1, 0, 71.2833, b'First', b'C', b'Cherbourg', b'n']
[1, b'female', 26.0, 0, 0, 7.925, b'Third', b'unknown', b'Southampton', b'y']
[1, b'female', 35.0, 1, 0, 53.1, b'First', b'C', b'Southampton', b'n']
[0, b'male', 28.0, 0, 0, 8.4583, b'Third', b'unknown', b'Queenstown', b'y']
[0, b'male', 2.0, 3, 1, 21.075, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 27.0, 0, 2, 11.1333, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 14.0, 1, 0, 30.0708, b'Second', b'unknown', b'Cherbourg', b'n']
[1, b'female', 4.0, 1, 1, 16.7, b'Third', b'G', b'Southampton', b'n']
[0, b'male', 20.0, 0, 0, 8.05, b'Third', b'unknown', b'Southampton', b'y']

यदि कुछ कॉलम खाली हैं, तो यह निम्न-स्तरीय इंटरफ़ेस आपको कॉलम प्रकारों के बजाय डिफ़ॉल्ट मान प्रदान करने की अनुमति देता है।

%%writefile missing.csv
1,2,3,4
,2,3,4
1,,3,4
1,2,,4
1,2,3,
,,,
Writing missing.csv
# Creates a dataset that reads all of the records from two CSV files, each with
# four float columns which may have missing values.

record_defaults = [999,999,999,999]
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults)
dataset = dataset.map(lambda *items: tf.stack(items))
dataset
<MapDataset element_spec=TensorSpec(shape=(4,), dtype=tf.int32, name=None)>
for line in dataset:
  print(line.numpy())
[1 2 3 4]
[999   2   3   4]
[  1 999   3   4]
[  1   2 999   4]
[  1   2   3 999]
[999 999 999 999]

डिफ़ॉल्ट रूप से, एक CsvDataset फ़ाइल के हर लाइन है, जो वांछनीय नहीं हो सकता है, उदाहरण के लिए यदि फ़ाइल, एक हैडर लाइन है कि अनदेखा किया जाना चाहिए के साथ शुरू होता है या कुछ स्तंभ इनपुट में की आवश्यकता नहीं है, तो के हर स्तंभ अर्जित करता है। ये लाइनें और क्षेत्रों के साथ हटाया जा सकता है header और select_cols क्रमशः तर्क।

# Creates a dataset that reads all of the records from two CSV files with
# headers, extracting float data from columns 2 and 4.
record_defaults = [999, 999] # Only provide defaults for the selected columns
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults, select_cols=[1, 3])
dataset = dataset.map(lambda *items: tf.stack(items))
dataset
<MapDataset element_spec=TensorSpec(shape=(2,), dtype=tf.int32, name=None)>
for line in dataset:
  print(line.numpy())
[2 4]
[2 4]
[999   4]
[2 4]
[  2 999]
[999 999]

फाइलों का उपभोग करना

फाइलों के एक सेट के रूप में वितरित कई डेटासेट हैं, जहां प्रत्येक फाइल एक उदाहरण है।

flowers_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
flowers_root = pathlib.Path(flowers_root)

रूट निर्देशिका में प्रत्येक वर्ग के लिए एक निर्देशिका होती है:

for item in flowers_root.glob("*"):
  print(item.name)
sunflowers
daisy
LICENSE.txt
roses
tulips
dandelion

प्रत्येक वर्ग निर्देशिका में फ़ाइलें उदाहरण हैं:

list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

for f in list_ds.take(5):
  print(f.numpy())
b'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/6145005439_ef6e07f9c6_n.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/dandelion/5762590366_5cf7a32b87_n.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/dandelion/10443973_aeb97513fc_m.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/roses/921138131_9e1393eb2b_m.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/dandelion/3730618647_5725c692c3_m.jpg'

का उपयोग कर डेटा पढ़ें tf.io.read_file समारोह और पथ से लेबल निकालने, लौटने (image, label) जोड़े:

def process_path(file_path):
  label = tf.strings.split(file_path, os.sep)[-2]
  return tf.io.read_file(file_path), label

labeled_ds = list_ds.map(process_path)
for image_raw, label_text in labeled_ds.take(1):
  print(repr(image_raw.numpy()[:100]))
  print()
  print(label_text.numpy())
b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x03\x02\x02\x03\x02\x02\x03\x03\x03\x03\x04\x03\x03\x04\x05\x08\x05\x05\x04\x04\x05\n\x07\x07\x06\x08\x0c\n\x0c\x0c\x0b\n\x0b\x0b\r\x0e\x12\x10\r\x0e\x11\x0e\x0b\x0b\x10\x16\x10\x11\x13\x14\x15\x15\x15\x0c\x0f\x17\x18\x16\x14\x18\x12\x14\x15\x14\xff\xdb\x00C\x01\x03\x04\x04\x05\x04\x05'

b'dandelion'

डेटासेट तत्वों को बैचना

सरल बैचिंग

ढेर batching का सबसे सरल रूप n एक भी तत्व में किसी डेटासेट के लगातार तत्वों। Dataset.batch() परिवर्तन वास्तव में यह करता है, के रूप में ही बाधाओं के साथ tf.stack() ऑपरेटर, तत्वों के प्रत्येक घटक के लिए लागू: यानी प्रत्येक घटक मैं के लिए, सभी तत्वों को ठीक उसी आकार का एक टेन्सर होना आवश्यक है।

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

for batch in batched_dataset.take(4):
  print([arr.numpy() for arr in batch])
[array([0, 1, 2, 3]), array([ 0, -1, -2, -3])]
[array([4, 5, 6, 7]), array([-4, -5, -6, -7])]
[array([ 8,  9, 10, 11]), array([ -8,  -9, -10, -11])]
[array([12, 13, 14, 15]), array([-12, -13, -14, -15])]

जबकि tf.data की कोशिश करता आकार जानकारी का प्रचार करने, की डिफ़ॉल्ट सेटिंग्स Dataset.batch एक अज्ञात बैच आकार में परिणाम है क्योंकि अंतिम बैच नहीं पूरा हो सकता है। नोट None आकार में रों:

batched_dataset
<BatchDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.int64, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>

का प्रयोग करें drop_remainder कि अंतिम बैच की उपेक्षा, और पूर्ण आकार प्रचार पाने के लिए तर्क:

batched_dataset = dataset.batch(7, drop_remainder=True)
batched_dataset
<BatchDataset element_spec=(TensorSpec(shape=(7,), dtype=tf.int64, name=None), TensorSpec(shape=(7,), dtype=tf.int64, name=None))>

पैडिंग के साथ बैचिंग टेंसर

उपरोक्त नुस्खा उन टेंसरों के लिए काम करता है जिनका आकार समान होता है। हालांकि, कई मॉडल (जैसे अनुक्रम मॉडल) इनपुट डेटा के साथ काम करते हैं जिनका आकार अलग-अलग हो सकता है (उदाहरण के लिए अलग-अलग लंबाई के अनुक्रम)। इस मामले को संभालने के लिए Dataset.padded_batch परिवर्तन एक या अधिक आयाम जिसमें वे गद्देदार किया जा सकता है निर्दिष्ट करने के द्वारा विभिन्न आकार के बैच tensors के लिए सक्षम बनाता है।

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=(None,))

for batch in dataset.take(2):
  print(batch.numpy())
  print()
[[0 0 0]
 [1 0 0]
 [2 2 0]
 [3 3 3]]

[[4 4 4 4 0 0 0]
 [5 5 5 5 5 0 0]
 [6 6 6 6 6 6 0]
 [7 7 7 7 7 7 7]]

Dataset.padded_batch परिवर्तन आप प्रत्येक घटक के प्रत्येक आयाम के लिए विभिन्न गद्दी स्थापित करने के लिए अनुमति देता है, और यह चर लंबाई (द्वारा संकेतित हो सकता है None या निरंतर लंबाई ऊपर के उदाहरण में)। पैडिंग मान को ओवरराइड करना भी संभव है, जो डिफ़ॉल्ट रूप से 0 है।

प्रशिक्षण कार्यप्रवाह

कई युगों को संसाधित करना

tf.data दो एपीआई प्रदान करता है मुख्य तरीके एक ही डेटा के कई अवधियों को कार्रवाई करने के लिए।

कई अवधियों में एक डाटासेट से अधिक पुनरावृति करने के लिए सबसे आसान तरीका उपयोग करने के लिए है Dataset.repeat() परिवर्तन। सबसे पहले, टाइटैनिक डेटा का डेटासेट बनाएं:

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
def plot_batch_sizes(ds):
  batch_sizes = [batch.shape[0] for batch in ds]
  plt.bar(range(len(batch_sizes)), batch_sizes)
  plt.xlabel('Batch number')
  plt.ylabel('Batch size')

लागू करने Dataset.repeat() कोई तर्क के साथ परिवर्तन इनपुट अनिश्चित काल के दोहराया जाएगा।

Dataset.repeat परिवर्तन एक युग के अंत और अगले युग की शुरुआत का संकेत के बिना अपने तर्कों कोनकैटेनेट्स किया गया। इस वजह से एक Dataset.batch के बाद लागू Dataset.repeat बैचों निकलेगा कि पैर फैलाकर बैठना युग सीमाओं:

titanic_batches = titanic_lines.repeat(3).batch(128)
plot_batch_sizes(titanic_batches)

पीएनजी

आप स्पष्ट युग जुदाई की जरूरत है, डाल Dataset.batch दोहराने से पहले:

titanic_batches = titanic_lines.batch(128).repeat(3)

plot_batch_sizes(titanic_batches)

पीएनजी

यदि आप प्रत्येक युग के अंत में एक कस्टम गणना (उदाहरण के लिए आंकड़े एकत्र करने के लिए) करना चाहते हैं तो प्रत्येक युग पर डेटासेट पुनरावृत्ति को पुनरारंभ करना सबसे आसान है:

epochs = 3
dataset = titanic_lines.batch(128)

for epoch in range(epochs):
  for batch in dataset:
    print(batch.shape)
  print("End of epoch: ", epoch)
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  0
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  1
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  2

इनपुट डेटा को बेतरतीब ढंग से फेरबदल करना

Dataset.shuffle() परिवर्तन एक निश्चित-आकार बफर रखता है और है कि बफर से यादृच्छिक पर समान रूप से अगले तत्व चुनता है।

डेटासेट में एक इंडेक्स जोड़ें ताकि आप प्रभाव देख सकें:

lines = tf.data.TextLineDataset(titanic_file)
counter = tf.data.experimental.Counter()

dataset = tf.data.Dataset.zip((counter, lines))
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(20)
dataset
<BatchDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.int64, name=None), TensorSpec(shape=(None,), dtype=tf.string, name=None))>

चूंकि buffer_size 100 है, और बैच का आकार 20 है, पहले बैच के 120 से अधिक एक सूचकांक के साथ कोई तत्व शामिल हैं।

n,line_batch = next(iter(dataset))
print(n.numpy())
[26 31 11 82 21 52 98  7 79 60 56 81 46 28  8 74 44 16  1  3]

साथ के रूप में Dataset.batch आदेश के सापेक्ष Dataset.repeat मायने रखती है।

Dataset.shuffle एक युग के अंत का संकेत नहीं है जब तक फेरबदल बफर खाली है। इसलिए दोहराव से पहले रखा गया फेरबदल अगले युग में जाने से पहले एक युग के प्रत्येक तत्व को दिखाएगा:

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.shuffle(buffer_size=100).batch(10).repeat(2)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(60).take(5):
  print(n.numpy())
Here are the item ID's near the epoch boundary:

[539 272 610 626 515 615 304 499 547 580]
[565 367 511 513 595 589 576 584 415 588]
[567 463 608 554 619 596 523 573]
[ 88  94  72  59  92   4  69 100  67   3]
[ 57  16  93  38  45 104  52  90  26 114]
shuffle_repeat = [n.numpy().mean() for n, line_batch in shuffled]
plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.ylabel("Mean item ID")
plt.legend()
<matplotlib.legend.Legend at 0x7f26501d9b10>

पीएनजी

लेकिन एक फेरबदल से पहले एक दोहराव युग की सीमाओं को एक साथ मिलाता है:

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.repeat(2).shuffle(buffer_size=100).batch(10)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(55).take(15):
  print(n.numpy())
Here are the item ID's near the epoch boundary:

[ 19 603 478 559 480 611 516   3 402  30]
[596 554 495 586 564 571 510 477 583 576]
[ 24 508 419 616 474 515  26   1  31  38]
[562  27 592 461 456  32  53  46 509  48]
[  8 578  21  16  57  20 621 608 580  58]
[ 54  42  62   7  69 594 622  35 421  41]
[605  28  11  13 574   2  66  67 560  72]
[ 61 617  36  44   5  51  77 537  78   6]
[ 63 607  43  56 604 530  91 593  88 104]
[102 557 539  60 115  52 582  68  81  47]
[ 74 534  83  97 119  80 626 114 577 563]
[130  55 121 316  40 136  90 111   9  14]
[107 470 106  64 122 615 113 129  18  50]
[ 98  92  45 148 327  29 120 151 381 112]
[159 511 455 127 153  86 619 128 100 117]
repeat_shuffle = [n.numpy().mean() for n, line_batch in shuffled]

plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.plot(repeat_shuffle, label="repeat().shuffle()")
plt.ylabel("Mean item ID")
plt.legend()
<matplotlib.legend.Legend at 0x7f25fc63bf10>

पीएनजी

प्रीप्रोसेसिंग डेटा

Dataset.map(f) परिवर्तन किसी दिए गए समारोह लगाने से एक नया डाटासेट का उत्पादन f इनपुट डाटासेट के प्रत्येक तत्व के लिए। यह पर आधारित है map() समारोह है कि आमतौर पर कार्यात्मक प्रोग्रामिंग भाषाओं में सूची (और अन्य संरचनाओं) के लिए आवेदन किया है। समारोह f लेता tf.Tensor वस्तुओं है कि इनपुट में एक भी तत्व का प्रतिनिधित्व, और रिटर्न tf.Tensor वस्तुओं है कि नए डाटासेट में एक भी तत्व का प्रतिनिधित्व करेंगे। इसका कार्यान्वयन एक तत्व को दूसरे में बदलने के लिए मानक TensorFlow संचालन का उपयोग करता है।

इस अनुभाग में बताया उपयोग करने के लिए के सामान्य उदाहरण शामिल किया गया Dataset.map()

छवि डेटा को डिकोड करना और उसका आकार बदलना

वास्तविक दुनिया छवि डेटा पर एक तंत्रिका नेटवर्क को प्रशिक्षित करते समय, विभिन्न आकारों की छवियों को एक सामान्य आकार में परिवर्तित करना अक्सर आवश्यक होता है, ताकि उन्हें एक निश्चित आकार में बैच किया जा सके।

फूल फ़ाइल नाम डेटासेट का पुनर्निर्माण करें:

list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

डेटासेट तत्वों में हेरफेर करने वाला एक फ़ंक्शन लिखें।

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def parse_image(filename):
  parts = tf.strings.split(filename, os.sep)
  label = parts[-2]

  image = tf.io.read_file(filename)
  image = tf.io.decode_jpeg(image)
  image = tf.image.convert_image_dtype(image, tf.float32)
  image = tf.image.resize(image, [128, 128])
  return image, label

परीक्षण करें कि यह काम करता है।

file_path = next(iter(list_ds))
image, label = parse_image(file_path)

def show(image, label):
  plt.figure()
  plt.imshow(image)
  plt.title(label.numpy().decode('utf-8'))
  plt.axis('off')

show(image, label)

पीएनजी

इसे डेटासेट पर मैप करें।

images_ds = list_ds.map(parse_image)

for image, label in images_ds.take(2):
  show(image, label)

पीएनजी

पीएनजी

मनमाना पायथन तर्क लागू करना

प्रदर्शन कारणों से, जब भी संभव हो अपने डेटा को प्रीप्रोसेस करने के लिए TensorFlow संचालन का उपयोग करें। हालांकि, कभी-कभी आपके इनपुट डेटा को पार्स करते समय बाहरी पायथन पुस्तकालयों को कॉल करना उपयोगी होता है। आप उपयोग कर सकते हैं tf.py_function() एक में ऑपरेशन Dataset.map() परिवर्तन।

उदाहरण के लिए, यदि आप एक यादृच्छिक रोटेशन लागू करना चाहते हैं, tf.image मॉड्यूल केवल tf.image.rot90 जो छवि वृद्धि के लिए बहुत उपयोगी नहीं है।

प्रदर्शित करने के लिए tf.py_function , का उपयोग करके देखें scipy.ndimage.rotate बजाय समारोह:

import scipy.ndimage as ndimage

def random_rotate_image(image):
  image = ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)
  return image
image, label = next(iter(images_ds))
image = random_rotate_image(image)
show(image, label)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

पीएनजी

के साथ इस समारोह का उपयोग करने के Dataset.map ही कैविएट्स के साथ के रूप में लागू Dataset.from_generator , तो आपको समारोह लागू वापसी आकृति और प्रकारों का वर्णन करने की जरूरत है:

def tf_random_rotate_image(image, label):
  im_shape = image.shape
  [image,] = tf.py_function(random_rotate_image, [image], [tf.float32])
  image.set_shape(im_shape)
  return image, label
rot_ds = images_ds.map(tf_random_rotate_image)

for image, label in rot_ds.take(2):
  show(image, label)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

पीएनजी

पीएनजी

पार्सिंग tf.Example प्रोटोकॉल बफ़र संदेशों

कई इनपुट पाइपलाइनों निकालने tf.train.Example एक TFRecord प्रारूप से प्रोटोकॉल बफ़र संदेशों। प्रत्येक tf.train.Example रिकॉर्ड एक या अधिक "सुविधाओं" शामिल है, और इनपुट पाइपलाइन आम तौर पर tensors में इन सुविधाओं को बदल देता है।

fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset
<TFRecordDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>

आप के साथ काम कर सकते हैं tf.train.Example एक की protos बाहर tf.data.Dataset डेटा को समझने के लिए:

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

feature = parsed.features.feature
raw_img = feature['image/encoded'].bytes_list.value[0]
img = tf.image.decode_png(raw_img)
plt.imshow(img)
plt.axis('off')
_ = plt.title(feature["image/text"].bytes_list.value[0])

पीएनजी

raw_example = next(iter(dataset))
def tf_parse(eg):
  example = tf.io.parse_example(
      eg[tf.newaxis], {
          'image/encoded': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
          'image/text': tf.io.FixedLenFeature(shape=(), dtype=tf.string)
      })
  return example['image/encoded'][0], example['image/text'][0]
img, txt = tf_parse(raw_example)
print(txt.numpy())
print(repr(img.numpy()[:20]), "...")
b'Rue Perreyon'
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02X' ...
decoded = dataset.map(tf_parse)
decoded
<MapDataset element_spec=(TensorSpec(shape=(), dtype=tf.string, name=None), TensorSpec(shape=(), dtype=tf.string, name=None))>
image_batch, text_batch = next(iter(decoded.batch(10)))
image_batch.shape
TensorShape([10])

समय श्रृंखला विंडोिंग

अंत समय श्रृंखला उदाहरण के लिए एक अंत के लिए देखें: समय श्रृंखला भविष्यवाणी

समय श्रृंखला डेटा को अक्सर समय अक्ष के साथ व्यवस्थित किया जाता है।

एक सरल प्रयोग करें Dataset.range प्रदर्शित करने के लिए:

range_ds = tf.data.Dataset.range(100000)

आमतौर पर, इस प्रकार के डेटा पर आधारित मॉडल एक सन्निहित समय टुकड़ा चाहते हैं।

डेटा बैच करने का सबसे आसान तरीका होगा:

का उपयोग करते हुए batch

batches = range_ds.batch(10, drop_remainder=True)

for batch in batches.take(5):
  print(batch.numpy())
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]
[40 41 42 43 44 45 46 47 48 49]

या भविष्य में एक कदम आगे की भविष्यवाणी करने के लिए, आप सुविधाओं और लेबल को एक दूसरे के सापेक्ष एक कदम से स्थानांतरित कर सकते हैं:

def dense_1_step(batch):
  # Shift features and labels one step relative to each other.
  return batch[:-1], batch[1:]

predict_dense_1_step = batches.map(dense_1_step)

for features, label in predict_dense_1_step.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8]  =>  [1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18]  =>  [11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28]  =>  [21 22 23 24 25 26 27 28 29]

एक निश्चित ऑफसेट के बजाय पूरी विंडो की भविष्यवाणी करने के लिए आप बैचों को दो भागों में विभाजित कर सकते हैं:

batches = range_ds.batch(15, drop_remainder=True)

def label_next_5_steps(batch):
  return (batch[:-5],   # Inputs: All except the last 5 steps
          batch[-5:])   # Labels: The last 5 steps

predict_5_steps = batches.map(label_next_5_steps)

for features, label in predict_5_steps.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8 9]  =>  [10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]  =>  [25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]  =>  [40 41 42 43 44]

एक बैच की सुविधाओं और एक अन्य के लेबल के बीच कुछ ओवरलैप करने के लिए, का उपयोग Dataset.zip :

feature_length = 10
label_length = 3

features = range_ds.batch(feature_length, drop_remainder=True)
labels = range_ds.batch(feature_length).skip(1).map(lambda labels: labels[:label_length])

predicted_steps = tf.data.Dataset.zip((features, labels))

for features, label in predicted_steps.take(5):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8 9]  =>  [10 11 12]
[10 11 12 13 14 15 16 17 18 19]  =>  [20 21 22]
[20 21 22 23 24 25 26 27 28 29]  =>  [30 31 32]
[30 31 32 33 34 35 36 37 38 39]  =>  [40 41 42]
[40 41 42 43 44 45 46 47 48 49]  =>  [50 51 52]

का उपयोग करते हुए window

का उपयोग करते समय Dataset.batch काम करता है, ऐसी परिस्थितियाँ होती है जहाँ आप बेहतर नियंत्रण की आवश्यकता हो सकती है। Dataset.window विधि आप पूरा नियंत्रण देता है, लेकिन कुछ देखभाल की आवश्यकता होती है: यह एक रिटर्न Dataset के Datasets । देखें डेटासेट संरचना जानकारी के लिए।

window_size = 5

windows = range_ds.window(window_size, shift=1)
for sub_ds in windows.take(5):
  print(sub_ds)
<_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
<_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
<_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
<_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>
<_VariantDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>

Dataset.flat_map विधि डेटासेट के एक डाटासेट लेने के लिए और एक भी डाटासेट में समतल कर सकते हैं:

for x in windows.flat_map(lambda x: x).take(30):
   print(x.numpy(), end=' ')
0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 3 4 5 6 7 4 5 6 7 8 5 6 7 8 9

लगभग सभी मामलों में, आप चाहते हैं .batch पहले डाटासेट:

def sub_to_batch(sub):
  return sub.batch(window_size, drop_remainder=True)

for example in windows.flat_map(sub_to_batch).take(5):
  print(example.numpy())
[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]
[3 4 5 6 7]
[4 5 6 7 8]

अब, आप देख सकते हैं कि shift तर्क नियंत्रण कितना प्रत्येक विंडो चाल से अधिक।

इसे एक साथ रखकर आप यह फ़ंक्शन लिख सकते हैं:

def make_window_dataset(ds, window_size=5, shift=1, stride=1):
  windows = ds.window(window_size, shift=shift, stride=stride)

  def sub_to_batch(sub):
    return sub.batch(window_size, drop_remainder=True)

  windows = windows.flat_map(sub_to_batch)
  return windows
ds = make_window_dataset(range_ds, window_size=10, shift = 5, stride=3)

for example in ds.take(10):
  print(example.numpy())
[ 0  3  6  9 12 15 18 21 24 27]
[ 5  8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34 37]
[15 18 21 24 27 30 33 36 39 42]
[20 23 26 29 32 35 38 41 44 47]
[25 28 31 34 37 40 43 46 49 52]
[30 33 36 39 42 45 48 51 54 57]
[35 38 41 44 47 50 53 56 59 62]
[40 43 46 49 52 55 58 61 64 67]
[45 48 51 54 57 60 63 66 69 72]

फिर पहले की तरह लेबल निकालना आसान है:

dense_labels_ds = ds.map(dense_1_step)

for inputs,labels in dense_labels_ds.take(3):
  print(inputs.numpy(), "=>", labels.numpy())
[ 0  3  6  9 12 15 18 21 24] => [ 3  6  9 12 15 18 21 24 27]
[ 5  8 11 14 17 20 23 26 29] => [ 8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34] => [13 16 19 22 25 28 31 34 37]

रीसेंपलिंग

बहुत वर्ग-असंतुलित डेटासेट के साथ काम करते समय, आप डेटासेट को फिर से नमूना करना चाह सकते हैं। tf.data दो तरीकों यह करने के लिए प्रदान करता है। क्रेडिट कार्ड धोखाधड़ी डेटासेट इस तरह की समस्या का एक अच्छा उदाहरण है।

zip_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip',
    fname='creditcard.zip',
    extract=True)

csv_path = zip_path.replace('.zip', '.csv')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip
69156864/69155632 [==============================] - 3s 0us/step
69165056/69155632 [==============================] - 3s 0us/step
creditcard_ds = tf.data.experimental.make_csv_dataset(
    csv_path, batch_size=1024, label_name="Class",
    # Set the column types: 30 floats and an int.
    column_defaults=[float()]*30+[int()])

अब, वर्गों के वितरण की जाँच करें, यह अत्यधिक विषम है:

def count(counts, batch):
  features, labels = batch
  class_1 = labels == 1
  class_1 = tf.cast(class_1, tf.int32)

  class_0 = labels == 0
  class_0 = tf.cast(class_0, tf.int32)

  counts['class_0'] += tf.reduce_sum(class_0)
  counts['class_1'] += tf.reduce_sum(class_1)

  return counts
counts = creditcard_ds.take(10).reduce(
    initial_state={'class_0': 0, 'class_1': 0},
    reduce_func = count)

counts = np.array([counts['class_0'].numpy(),
                   counts['class_1'].numpy()]).astype(np.float32)

fractions = counts/counts.sum()
print(fractions)
[0.9951 0.0049]

असंतुलित डेटासेट के साथ प्रशिक्षण का एक सामान्य तरीका इसे संतुलित करना है। tf.data कुछ तरीकों जो इस कार्यप्रवाह सक्षम शामिल हैं:

डेटासेट नमूनाकरण

एक डाटासेट resampling के लिए एक दृष्टिकोण का उपयोग है sample_from_datasets । तो आपको एक अलग होने पर ही ये लागू है data.Dataset प्रत्येक वर्ग के लिए।

यहां, क्रेडिट कार्ड धोखाधड़ी डेटा से उन्हें उत्पन्न करने के लिए बस फ़िल्टर का उपयोग करें:

negative_ds = (
  creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==0)
    .repeat())
positive_ds = (
  creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==1)
    .repeat())
for features, label in positive_ds.batch(10).take(1):
  print(label.numpy())
[1 1 1 1 1 1 1 1 1 1]

उपयोग करने के लिए tf.data.Dataset.sample_from_datasets डेटासेट, और प्रत्येक के लिए वजन पारित:

balanced_ds = tf.data.Dataset.sample_from_datasets(
    [negative_ds, positive_ds], [0.5, 0.5]).batch(10)

अब डेटासेट प्रत्येक वर्ग के उदाहरण 50/50 संभावना के साथ तैयार करता है:

for features, labels in balanced_ds.take(10):
  print(labels.numpy())
[0 1 0 0 1 0 1 1 0 0]
[0 1 0 1 0 0 1 0 1 0]
[1 0 0 0 0 1 0 1 0 1]
[1 0 0 0 1 1 0 0 1 1]
[0 1 0 1 0 1 0 1 1 1]
[1 0 0 0 0 1 0 0 0 1]
[0 1 0 0 1 1 0 0 0 0]
[0 0 0 0 0 0 1 0 1 1]
[0 1 1 1 0 1 0 1 0 1]
[0 0 1 1 0 0 1 0 1 1]

अस्वीकृति पुन: नमूनाकरण

इसके बाद के संस्करण के साथ एक समस्या Dataset.sample_from_datasets दृष्टिकोण है कि यह एक अलग की जरूरत है tf.data.Dataset वर्ग प्रति। आप इस्तेमाल कर सकते हैं Dataset.filter उन दो डेटासेट बनाने के लिए, लेकिन सभी डेटा में है कि परिणाम दो बार लोड किए जा रहे।

data.Dataset.rejection_resample विधि यह संतुलित है, जबकि केवल एक बार लोड हो रहा है एक डाटासेट के लिए लागू किया जा सकता है। संतुलन हासिल करने के लिए तत्वों को डेटासेट से हटा दिया जाएगा।

data.Dataset.rejection_resample एक लेता है class_func तर्क। यह class_func प्रत्येक डेटासेट तत्व को लागू किया जाता है, और निर्धारित करने के लिए जो वर्ग एक उदाहरण संतुलन के प्रयोजनों के लिए के अंतर्गत आता है प्रयोग किया जाता है।

यहां लक्ष्य करने के लिए सक्षम वितरण संतुलित करने के लिए है, और के तत्वों creditcard_ds पहले से ही कर रहे हैं (features, label) जोड़े। तो class_func सिर्फ उन लेबल वापस जाने के लिए की जरूरत है:

def class_func(features, label):
  return label

व्यक्तिगत उदाहरण के साथ resampling विधि सौदों, इसलिए इस मामले में आप चाहिए unbatch डाटासेट कि विधि लागू करने से पहले।

विधि को लक्ष्य वितरण की आवश्यकता होती है, और वैकल्पिक रूप से इनपुट के रूप में प्रारंभिक वितरण अनुमान की आवश्यकता होती है।

resample_ds = (
    creditcard_ds
    .unbatch()
    .rejection_resample(class_func, target_dist=[0.5,0.5],
                        initial_dist=fractions)
    .batch(10))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py:5797: Print (from tensorflow.python.ops.logging_ops) is deprecated and will be removed after 2018-08-20.
Instructions for updating:
Use tf.print instead of tf.Print. Note that tf.print returns a no-output operator that directly prints the output. Outside of defuns or eager mode, this operator will not be executed unless it is directly specified in session.run or used as a control dependency for other operators. This is only a concern in graph mode. Below is an example of how to ensure tf.print executes in graph mode:

rejection_resample विधि रिटर्न (class, example) जोड़े जहां class के उत्पादन में है class_func । इस मामले में, example पहले से ही एक था (feature, label) जोड़ी है, तो उपयोग map लेबल के अतिरिक्त प्रति ड्रॉप करने:

balanced_ds = resample_ds.map(lambda extra_label, features_and_label: features_and_label)

अब डेटासेट प्रत्येक वर्ग के उदाहरण 50/50 संभावना के साथ तैयार करता है:

for features, labels in balanced_ds.take(10):
  print(labels.numpy())
Proportion of examples rejected by sampler is high: [0.995117188][0.995117188 0.0048828125][0 1]
Proportion of examples rejected by sampler is high: [0.995117188][0.995117188 0.0048828125][0 1]
Proportion of examples rejected by sampler is high: [0.995117188][0.995117188 0.0048828125][0 1]
Proportion of examples rejected by sampler is high: [0.995117188][0.995117188 0.0048828125][0 1]
Proportion of examples rejected by sampler is high: [0.995117188][0.995117188 0.0048828125][0 1]
Proportion of examples rejected by sampler is high: [0.995117188][0.995117188 0.0048828125][0 1]
Proportion of examples rejected by sampler is high: [0.995117188][0.995117188 0.0048828125][0 1]
Proportion of examples rejected by sampler is high: [0.995117188][0.995117188 0.0048828125][0 1]
Proportion of examples rejected by sampler is high: [0.995117188][0.995117188 0.0048828125][0 1]
Proportion of examples rejected by sampler is high: [0.995117188][0.995117188 0.0048828125][0 1]
[0 1 1 0 1 0 0 0 0 1]
[1 0 1 0 0 1 1 1 1 1]
[1 1 0 0 1 0 0 0 1 1]
[1 0 0 1 1 1 1 1 0 0]
[0 1 0 1 0 0 0 1 0 0]
[0 0 1 0 0 1 1 0 1 1]
[0 1 1 1 0 1 0 0 1 0]
[1 0 0 1 0 0 0 1 1 1]
[0 1 1 1 1 0 0 0 1 1]
[1 0 0 1 0 1 0 0 1 1]

इटरेटर चेकपॉइंटिंग

Tensorflow समर्थन चौकियों लेने इसलिए जब अपने प्रशिक्षण प्रक्रिया पुन: प्रारंभ हो वह अपनी प्रगति के सबसे ठीक करने के लिए नवीनतम चौकी बहाल कर सकते हैं कि। मॉडल चर को चेकपॉइंट करने के अलावा, आप डेटासेट इटरेटर की प्रगति की जांच भी कर सकते हैं। यह उपयोगी हो सकता है यदि आपके पास एक बड़ा डेटासेट है और प्रत्येक पुनरारंभ पर शुरुआत से डेटासेट प्रारंभ नहीं करना चाहते हैं। नोट हालांकि इटरेटर चौकियों बड़ी हो सकती है कि, इस तरह के रूप परिवर्तनों के बाद से shuffle और prefetch इटरेटर भीतर बफरिंग तत्वों की आवश्यकता होती है।

एक चौकी में अपने इटरेटर को शामिल करने के लिए इटरेटर पारित tf.train.Checkpoint निर्माता।

range_ds = tf.data.Dataset.range(20)

iterator = iter(range_ds)
ckpt = tf.train.Checkpoint(step=tf.Variable(0), iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, '/tmp/my_ckpt', max_to_keep=3)

print([next(iterator).numpy() for _ in range(5)])

save_path = manager.save()

print([next(iterator).numpy() for _ in range(5)])

ckpt.restore(manager.latest_checkpoint)

print([next(iterator).numpy() for _ in range(5)])
[0, 1, 2, 3, 4]
[5, 6, 7, 8, 9]
[5, 6, 7, 8, 9]

tf.keras के साथ tf.data का उपयोग करना

tf.keras एपीआई सरल बनाने और शिक्षण मॉडेल को क्रियान्वित करने के कई पहलुओं। इसके .fit() और .evaluate() और .predict() एपीआई इनपुट के रूप में डेटासेट समर्थन करते हैं। यहाँ एक त्वरित डेटासेट और मॉडल सेटअप है:

train, test = tf.keras.datasets.fashion_mnist.load_data()

images, labels = train
images = images/255.0
labels = labels.astype(np.int32)
fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)

model = tf.keras.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 
              metrics=['accuracy'])

की एक डाटासेट पासिंग (feature, label) जोड़े सब है कि के लिए आवश्यक है Model.fit और Model.evaluate :

model.fit(fmnist_train_ds, epochs=2)
Epoch 1/2
1875/1875 [==============================] - 4s 2ms/step - loss: 0.6053 - accuracy: 0.7952
Epoch 2/2
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4620 - accuracy: 0.8425
<keras.callbacks.History at 0x7f25fc2e4e10>

आपको कॉल करके एक अनंत डाटासेट, उदाहरण के लिए पार कर लेते हैं Dataset.repeat() , तुम बस की जरूरत भी पारित steps_per_epoch तर्क:

model.fit(fmnist_train_ds.repeat(), epochs=2, steps_per_epoch=20)
Epoch 1/2
20/20 [==============================] - 0s 2ms/step - loss: 0.4028 - accuracy: 0.8516
Epoch 2/2
20/20 [==============================] - 0s 2ms/step - loss: 0.4591 - accuracy: 0.8344
<keras.callbacks.History at 0x7f25fc04cad0>

मूल्यांकन के लिए आप मूल्यांकन चरणों की संख्या पास कर सकते हैं:

loss, accuracy = model.evaluate(fmnist_train_ds)
print("Loss :", loss)
print("Accuracy :", accuracy)
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4340 - accuracy: 0.8518
Loss : 0.43400809168815613
Accuracy : 0.8517833352088928

लंबे डेटासेट के लिए, मूल्यांकन करने के लिए चरणों की संख्या सेट करें:

loss, accuracy = model.evaluate(fmnist_train_ds.repeat(), steps=10)
print("Loss :", loss)
print("Accuracy :", accuracy)
10/10 [==============================] - 0s 2ms/step - loss: 0.3548 - accuracy: 0.8750
Loss : 0.3548365533351898
Accuracy : 0.875

लेबल जब बुला में की आवश्यकता नहीं है Model.predict

predict_ds = tf.data.Dataset.from_tensor_slices(images).batch(32)
result = model.predict(predict_ds, steps = 10)
print(result.shape)
(320, 10)

लेकिन यदि आप एक डेटासेट पास करते हैं तो लेबल को अनदेखा कर दिया जाता है:

result = model.predict(fmnist_train_ds, steps = 10)
print(result.shape)
(320, 10)