View on TensorFlow.org | Run in Google Colab | View source on GitHub | Download notebook |
The TFRecord format is a simple format for storing a sequence of binary records.
Protocol buffers are a cross-platform, cross-language library for efficient serialization of structured data.
Protocol messages are defined by .proto
files, these are often the easiest way to understand a message type.
The tf.train.Example
message (or protobuf) is a flexible message type that represents a {"string": value}
mapping. It is designed for use with TensorFlow and is used throughout the higher-level APIs such as TFX.
This notebook demonstrates how to create, parse, and use the tf.train.Example
message, and then serialize, write, and read tf.train.Example
messages to and from .tfrecord
files.
Setup
import tensorflow as tf
import numpy as np
import IPython.display as display
2024-08-16 07:04:00.479365: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-08-16 07:04:00.500427: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-08-16 07:04:00.506819: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
tf.train.Example
Data types for tf.train.Example
Fundamentally, a tf.train.Example
is a {"string": tf.train.Feature}
mapping.
The tf.train.Feature
message type can accept one of the following three types (See the .proto
file for reference). Most other generic types can be coerced into one of these:
tf.train.BytesList
(the following types can be coerced)string
byte
tf.train.FloatList
(the following types can be coerced)float
(float32
)double
(float64
)
tf.train.Int64List
(the following types can be coerced)bool
enum
int32
uint32
int64
uint64
In order to convert a standard TensorFlow type to a tf.train.Example
-compatible tf.train.Feature
, you can use the shortcut functions below. Note that each function takes a scalar input value and returns a tf.train.Feature
containing one of the three list
types above:
# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.
def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
if isinstance(value, type(tf.constant(0))):
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
"""Returns a float_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
Below are some examples of how these functions work. Note the varying input types and the standardized output types. If the input type for a function does not match one of the coercible types stated above, the function will raise an exception (e.g. _int64_feature(1.0)
will error out because 1.0
is a float—therefore, it should be used with the _float_feature
function instead):
print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))
print(_float_feature(np.exp(1)))
print(_int64_feature(True))
print(_int64_feature(1))
bytes_list { value: "test_string" } bytes_list { value: "test_bytes" } float_list { value: 2.7182817459106445 } int64_list { value: 1 } int64_list { value: 1 } WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1723791843.092175 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791843.095539 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791843.099144 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791843.102704 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791843.113978 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791843.117024 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791843.120462 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791843.123843 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791843.126718 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791843.129679 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791843.133016 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791843.136422 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.360431 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.362573 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.364572 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.366645 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.368643 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.370639 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.372538 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.374538 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.376473 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.378464 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.380358 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.382352 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.421137 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.423231 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.425187 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.427219 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.429165 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.431166 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.433071 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.435053 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.436984 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.439359 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.441678 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723791844.444034 194894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
All proto messages can be serialized to a binary-string using the .SerializeToString
method:
feature = _float_feature(np.exp(1))
feature.SerializeToString()
b'\x12\x06\n\x04T\xf8-@'
Creating a tf.train.Example
message
Suppose you want to create a tf.train.Example
message from existing data. In practice, the dataset may come from anywhere, but the procedure of creating the tf.train.Example
message from a single observation will be the same:
Within each observation, each value needs to be converted to a
tf.train.Feature
containing one of the 3 compatible types, using one of the functions above.You create a map (dictionary) from the feature name string to the encoded feature value produced in #1.
The map produced in step 2 is converted to a
Features
message.
In this notebook, you will create a dataset using NumPy.
This dataset will have 4 features:
- a boolean feature,
False
orTrue
with equal probability - an integer feature uniformly randomly chosen from
[0, 5]
- a string feature generated from a string table by using the integer feature as an index
- a float feature from a standard normal distribution
Consider a sample consisting of 10,000 independently and identically distributed observations from each of the above distributions:
# The number of observations in the dataset.
n_observations = int(1e4)
# Boolean feature, encoded as False or True.
feature0 = np.random.choice([False, True], n_observations)
# Integer feature, random from 0 to 4.
feature1 = np.random.randint(0, 5, n_observations)
# String feature.
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]
# Float feature, from a standard normal distribution.
feature3 = np.random.randn(n_observations)
Each of these features can be coerced into a tf.train.Example
-compatible type using one of _bytes_feature
, _float_feature
, _int64_feature
. You can then create a tf.train.Example
message from these encoded features:
@tf.py_function(Tout=tf.string)
def serialize_example(feature0, feature1, feature2, feature3):
"""
Creates a tf.train.Example message ready to be written to a file.
"""
# Create a dictionary mapping the feature name to the tf.train.Example-compatible
# data type.
feature = {
'feature0': _int64_feature(feature0),
'feature1': _int64_feature(feature1),
'feature2': _bytes_feature(feature2),
'feature3': _float_feature(feature3),
}
# Create a Features message using tf.train.Example.
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString()
For example, suppose you have a single observation from the dataset, [False, 4, bytes('goat'), 0.9876]
. You can create and print the tf.train.Example
message for this observation using serialize_example()
. Each single observation will be written as a Features
message as per the above. Note that the tf.train.Example
message is just a wrapper around the Features
message:
# This is an example observation from the dataset.
example_observation = [False, 4, b'goat', 0.9876]
serialized_example = serialize_example(*example_observation)
serialized_example
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat'>
To decode the message use the tf.train.Example.FromString
method.
example_proto = tf.train.Example.FromString(serialized_example.numpy())
example_proto
features { feature { key: "feature0" value { int64_list { value: 0 } } } feature { key: "feature1" value { int64_list { value: 4 } } } feature { key: "feature2" value { bytes_list { value: "goat" } } } feature { key: "feature3" value { float_list { value: 0.9876000285148621 } } } }
TFRecords format details
A TFRecord file contains a sequence of records. The file can only be read sequentially.
Each record contains a byte-string, for the data-payload, plus the data-length, and CRC-32C (32-bit CRC using the Castagnoli polynomial) hashes for integrity checking.
Each record is stored in the following formats:
uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data
The records are concatenated together to produce the file. CRCs are described here, and the mask of a CRC is:
masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul
Reading and writing TFRecord files
The tf.io
module also contains pure-Python functions for reading and writing TFRecord files.
Writing a TFRecord file
Next, write the 10,000 observations to the file test.tfrecord
. Each observation is converted to a tf.train.Example
message, then written to file. You can then verify that the file test.tfrecord
has been created:
filename = 'test.tfrecord'
# Write the `tf.train.Example` observations to the file.
with tf.io.TFRecordWriter(filename) as writer:
for i in range(n_observations):
example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
writer.write(example.numpy())
du -sh {filename}
984K test.tfrecord
Reading a TFRecord file in python
These serialized tensors can be easily parsed using tf.train.Example.ParseFromString
:
filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>
for raw_record in raw_dataset.take(1):
example = tf.train.Example()
example.ParseFromString(raw_record.numpy())
print(example)
features { feature { key: "feature0" value { int64_list { value: 0 } } } feature { key: "feature1" value { int64_list { value: 3 } } } feature { key: "feature2" value { bytes_list { value: "horse" } } } feature { key: "feature3" value { float_list { value: -0.6452258229255676 } } } }
That returns a tf.train.Example
proto which is dificult to use as is, but it's fundamentally a representation of a:
Dict[str,
Union[List[float],
List[int],
List[str]]]
The following code manually converts the Example
to a dictionary of NumPy arrays, without using TensorFlow Ops. Refer to the PROTO file for details.
result = {}
# example.features.feature is the dictionary
for key, feature in example.features.feature.items():
# The values are the Feature objects which contain a `kind` which contains:
# one of three fields: bytes_list, float_list, int64_list
kind = feature.WhichOneof('kind')
result[key] = np.array(getattr(feature, kind).value)
result
{'feature2': array([b'horse'], dtype='|S5'), 'feature1': array([3]), 'feature0': array([0]), 'feature3': array([-0.64522582])}
Reading a TFRecord file Using tf.data
You can also read the TFRecord file using the tf.data.TFRecordDataset
class.
More information on consuming TFRecord files using tf.data
can be found in the tf.data: Build TensorFlow input pipelines guide.
Using TFRecordDataset
s can be useful for standardizing input data and optimizing performance.
filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>
At this point the dataset contains serialized tf.train.Example
messages. When iterated over it returns these as scalar string tensors.
Use the .take
method to only show the first 10 records.
for raw_record in raw_dataset.take(10):
print(repr(raw_record))
<tf.Tensor: shape=(), dtype=string, numpy=b'\nS\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x03\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x15\n\x08feature2\x12\t\n\x07\n\x05horse\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x85-%\xbf'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xc0,\xec\xbe\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xd6O\xb8?\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xf3\xbf\xa0\xbf\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nS\n\x15\n\x08feature2\x12\t\n\x07\n\x05horse\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x97\x1aE\xbe\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x03'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04~<\xe0\xbe\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x0cW\x18>\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x8d\xdd\xcb\xbf\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x95\xdf=?\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x8dC\xb8\xbd\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00'>
These tensors can be parsed using the function below. Note that the feature_description
is necessary here because tf.data.Dataset
s use graph-execution, and need this description to build their shape and type signature:
# Create a description of the features.
feature_description = {
'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),
'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),
'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),
'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}
def _parse_function(example_proto):
# Parse the input `tf.train.Example` proto using the dictionary above.
return tf.io.parse_single_example(example_proto, feature_description)
Alternatively, use tf.parse_example
to parse the whole batch at once. Apply this function to each item in the dataset using the tf.data.Dataset.map
method:
parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset
<_MapDataset element_spec={'feature0': TensorSpec(shape=(), dtype=tf.int64, name=None), 'feature1': TensorSpec(shape=(), dtype=tf.int64, name=None), 'feature2': TensorSpec(shape=(), dtype=tf.string, name=None), 'feature3': TensorSpec(shape=(), dtype=tf.float32, name=None)}>
Use eager execution to display the observations in the dataset. There are 10,000 observations in this dataset, but you will only display the first 10. The data is displayed as a dictionary of features. Each item is a tf.Tensor
, and the numpy
element of this tensor displays the value of the feature:
for parsed_record in parsed_dataset.take(10):
print(repr(parsed_record))
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=3>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'horse'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.6452258>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.46127892>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=1.4399364>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-1.2558578>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=3>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'horse'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.19248424>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.43796152>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.14876956>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-1.5926987>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'dog'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.74169284>} {'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.08997259>}
Here, the tf.parse_example
function unpacks the tf.train.Example
fields into standard tensors.
Walkthrough: Reading and writing image data
This is an end-to-end example of how to read and write image data using TFRecords. Using an image as input data, you will write the data as a TFRecord file, then read the file back and display the image.
This can be useful if, for example, you want to use several models on the same input dataset. Instead of storing the image data raw, it can be preprocessed into the TFRecords format, and that can be used in all further processing and modelling.
First, let's download this image of a cat in the snow and this photo of the Williamsburg Bridge, NYC under construction.
Fetch the images
cat_in_snow = tf.keras.utils.get_file(
'320px-Felis_catus-cat_on_snow.jpg',
'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
williamsburg_bridge = tf.keras.utils.get_file(
'194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg',
'https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg 17858/17858 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg 15477/15477 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
display.display(display.Image(filename=cat_in_snow))
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))
display.display(display.Image(filename=williamsburg_bridge))
display.display(display.HTML('<a "href=https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg">From Wikimedia</a>'))
Write the TFRecord file
As before, encode the features as types compatible with tf.train.Example
. This stores the raw image string feature, as well as the height, width, depth, and arbitrary label
feature. The latter is used when you write the file to distinguish between the cat image and the bridge image. Use 0
for the cat image, and 1
for the bridge image:
image_labels = {
cat_in_snow : 0,
williamsburg_bridge : 1,
}
# This is an example, just using the cat image.
image_string = open(cat_in_snow, 'rb').read()
label = image_labels[cat_in_snow]
# Create a dictionary with features that may be relevant.
def image_example(image_string, label):
image_shape = tf.io.decode_jpeg(image_string).shape
feature = {
'height': _int64_feature(image_shape[0]),
'width': _int64_feature(image_shape[1]),
'depth': _int64_feature(image_shape[2]),
'label': _int64_feature(label),
'image_raw': _bytes_feature(image_string),
}
return tf.train.Example(features=tf.train.Features(feature=feature))
for line in str(image_example(image_string, label)).split('\n')[:15]:
print(line)
print('...')
features { feature { key: "depth" value { int64_list { value: 3 } } } feature { key: "height" value { int64_list { value: 213 } ...
Notice that all of the features are now stored in the tf.train.Example
message. Next, functionalize the code above and write the example messages to a file named images.tfrecords
:
# Write the raw image files to `images.tfrecords`.
# First, process the two images into `tf.train.Example` messages.
# Then, write to a `.tfrecords` file.
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
for filename, label in image_labels.items():
image_string = open(filename, 'rb').read()
tf_example = image_example(image_string, label)
writer.write(tf_example.SerializeToString())
du -sh {record_file}
36K images.tfrecords
Read the TFRecord file
You now have the file—images.tfrecords
—and can now iterate over the records in it to read back what you wrote. Given that in this example you will only reproduce the image, the only feature you will need is the raw image string. Extract it using the getters described above, namely example.features.feature['image_raw'].bytes_list.value[0]
. You can also use the labels to determine which record is the cat and which one is the bridge:
raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')
# Create a dictionary describing the features.
image_feature_description = {
'height': tf.io.FixedLenFeature([], tf.int64),
'width': tf.io.FixedLenFeature([], tf.int64),
'depth': tf.io.FixedLenFeature([], tf.int64),
'label': tf.io.FixedLenFeature([], tf.int64),
'image_raw': tf.io.FixedLenFeature([], tf.string),
}
def _parse_image_function(example_proto):
# Parse the input tf.train.Example proto using the dictionary above.
return tf.io.parse_single_example(example_proto, image_feature_description)
parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset
<_MapDataset element_spec={'depth': TensorSpec(shape=(), dtype=tf.int64, name=None), 'height': TensorSpec(shape=(), dtype=tf.int64, name=None), 'image_raw': TensorSpec(shape=(), dtype=tf.string, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None), 'width': TensorSpec(shape=(), dtype=tf.int64, name=None)}>
Recover the images from the TFRecord file:
for image_features in parsed_image_dataset:
image_raw = image_features['image_raw'].numpy()
display.display(display.Image(data=image_raw))