Overview
The objective of the Avro Dataset API is to load Avro formatted data natively into TensorFlow as a TensorFlow dataset. Avro is a data serialization system similar to Protocol Buffers. It is widely used in Apache Hadoop, where it can provide both a serialization format for persistent data and a wire format for communication between Hadoop nodes. Avro data is a row-oriented, compact binary data format. It relies on a schema that is stored as a separate JSON file. For the specification of the Avro format and the schema declaration, please refer to the official manual.
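For illustration only, a minimal Avro schema describing a record with an int field and a nullable string field might look like the example below. This example schema is not used in this tutorial; the schema that is used (train.avsc) is shown further down.

{
    "type": "record",
    "name": "Example",
    "fields": [
        {"name": "id", "type": "int"},
        {"name": "name", "type": ["string", "null"]}
    ]
}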
Setup package
Install the required tensorflow-io package
pip install tensorflow-io
Import packages
import tensorflow as tf
import tensorflow_io as tfio
Validate tf and tfio imports
print("tensorflow-io version: {}".format(tfio.__version__))
print("tensorflow version: {}".format(tf.__version__))
tensorflow-io version: 0.18.0
tensorflow version: 2.5.0
Usage
Explore the dataset
For the purpose of this tutorial, let's download the sample Avro dataset.
Download a sample Avro file:
curl -OL https://github.com/tensorflow/io/raw/master/docs/tutorials/avro/train.avro
ls -l train.avro
% Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 151 100 151 0 0 1268 0 --:--:-- --:--:-- --:--:-- 1268 100 369 100 369 0 0 1255 0 --:--:-- --:--:-- --:--:-- 1255
-rw-rw-r-- 1 kbuilder kokoro 369 May 25 22:23 train.avro
Download the corresponding schema file of the sample Avro file:
curl -OL https://github.com/tensorflow/io/raw/master/docs/tutorials/avro/train.avsc
ls -l train.avsc
% Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 151 100 151 0 0 1247 0 --:--:-- --:--:-- --:--:-- 1247 100 271 100 271 0 0 780 0 --:--:-- --:--:-- --:--:-- 780
-rw-rw-r-- 1 kbuilder kokoro 271 May 25 22:23 train.avsc
In the above example, a testing Avro dataset was created based on the mnist dataset. The original mnist dataset in TFRecord format is generated from the TF named dataset. However, the mnist dataset is too large for a demo, so for simplicity most of it was trimmed and only the first few records were kept. Moreover, additional trimming was applied to the image field of the original mnist dataset, and it was mapped to the features field in Avro. As a result, the Avro file train.avro has 4 records, each of which has 3 fields: features, an array of int; label, an int or null; and dataType, an enum. To view the decoded train.avro (note that the original Avro data file is not human readable, as Avro is a compact binary format):
Install the required package to read Avro files:
pip install avro
To read and print an Avro file in a human-readable format:
from avro.io import DatumReader
from avro.datafile import DataFileReader
import json
def print_avro(avro_file, max_record_num=None):
    if max_record_num is not None and max_record_num <= 0:
        return

    with open(avro_file, 'rb') as avro_handler:
        reader = DataFileReader(avro_handler, DatumReader())
        record_count = 0
        for record in reader:
            record_count = record_count + 1
            print(record)
            if max_record_num is not None and record_count == max_record_num:
                break
print_avro(avro_file='train.avro')
{'features': [0, 0, 0, 1, 4], 'label': None, 'dataType': 'TRAINING'}
{'features': [0, 0], 'label': 2, 'dataType': 'TRAINING'}
{'features': [0], 'label': 3, 'dataType': 'VALIDATION'}
{'features': [1], 'label': 4, 'dataType': 'VALIDATION'}
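The max_record_num argument of the print_avro function defined above can limit how many records are printed, e.g. to show only the first two records:

print_avro(avro_file='train.avro', max_record_num=2)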
The schema of train.avro, represented by train.avsc, is a JSON-formatted file. To view train.avsc:
def print_schema(avro_schema_file):
    with open(avro_schema_file, 'r') as handle:
        parsed = json.load(handle)
        print(json.dumps(parsed, indent=4, sort_keys=True))
print_schema('train.avsc')
{ "fields": [ { "name": "features", "type": { "items": "int", "type": "array" } }, { "name": "label", "type": [ "int", "null" ] }, { "name": "dataType", "type": { "name": "dataTypes", "symbols": [ "TRAINING", "VALIDATION" ], "type": "enum" } } ], "name": "ImageDataset", "type": "record" }
Prepare the dataset
Load train.avro as a TensorFlow dataset with the Avro dataset API:
features = {
    'features[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.int32),
    'label': tf.io.FixedLenFeature(shape=[], dtype=tf.int32, default_value=-100),
    'dataType': tf.io.FixedLenFeature(shape=[], dtype=tf.string)
}

schema = tf.io.gfile.GFile('train.avsc').read()

dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'],
                                                              reader_schema=schema,
                                                              features=features,
                                                              shuffle=False,
                                                              batch_size=3,
                                                              num_epochs=1)
for record in dataset:
    print(record['features[*]'])
    print(record['label'])
    print(record['dataType'])
    print("--------------------")
SparseTensor(indices=tf.Tensor( [[0 0] [0 1] [0 2] [0 3] [0 4] [1 0] [1 1] [2 0]], shape=(8, 2), dtype=int64), values=tf.Tensor([0 0 0 1 4 0 0 0], shape=(8,), dtype=int32), dense_shape=tf.Tensor([3 5], shape=(2,), dtype=int64))
tf.Tensor([-100 2 3], shape=(3,), dtype=int32)
tf.Tensor([b'TRAINING' b'TRAINING' b'VALIDATION'], shape=(3,), dtype=string)
--------------------
SparseTensor(indices=tf.Tensor([[0 0]], shape=(1, 2), dtype=int64), values=tf.Tensor([1], shape=(1,), dtype=int32), dense_shape=tf.Tensor([1 1], shape=(2,), dtype=int64))
tf.Tensor([4], shape=(1,), dtype=int32)
tf.Tensor([b'VALIDATION'], shape=(1,), dtype=string)
--------------------
The above example converts train.avro into a TensorFlow dataset. Each element of the dataset is a dictionary whose keys are the feature names and whose values are the converted sparse or dense tensors. E.g., it converts the features, label, and dataType fields to a VarLenFeature (SparseTensor), a FixedLenFeature (DenseTensor), and a FixedLenFeature (DenseTensor), respectively. Since batch_size is 3, it coerces 3 records from train.avro into one element of the result dataset. For the first record in train.avro, whose label is null, the Avro reader replaces it with the specified default value (-100).
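As a side note (a minimal sketch using the dataset defined above, not part of the original flow), if a dense view of the variable-length features column is more convenient downstream, the batched SparseTensor can be converted with tf.sparse.to_dense, which pads the shorter rows with zeros:

for record in dataset:
    # Convert the batched SparseTensor into a dense [batch_size, max_length] tensor.
    dense_features = tf.sparse.to_dense(record['features[*]'])
    print(dense_features)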
In this example, there are 4 records in total in train.avro. Since the batch size is 3, the result dataset contains 2 elements, the last of which has a batch size of 1. The last batch can also be dropped if its size is smaller than the batch size by enabling drop_final_batch, e.g.:
dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'],
                                                              reader_schema=schema,
                                                              features=features,
                                                              shuffle=False,
                                                              batch_size=3,
                                                              drop_final_batch=True,
                                                              num_epochs=1)

for record in dataset:
    print(record)
{'features[*]': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f97656423d0>, 'dataType': <tf.Tensor: shape=(3,), dtype=string, numpy=array([b'TRAINING', b'TRAINING', b'VALIDATION'], dtype=object)>, 'label': <tf.Tensor: shape=(3,), dtype=int32, numpy=array([-100, 2, 3], dtype=int32)>}
One can also increase num_parallel_reads to expedite Avro data processing by increasing the Avro parse/read parallelism.
dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'],
                                                              reader_schema=schema,
                                                              features=features,
                                                              shuffle=False,
                                                              num_parallel_reads=16,
                                                              batch_size=3,
                                                              drop_final_batch=True,
                                                              num_epochs=1)

for record in dataset:
    print(record)
{'features[*]': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f9765693990>, 'dataType': <tf.Tensor: shape=(3,), dtype=string, numpy=array([b'TRAINING', b'TRAINING', b'VALIDATION'], dtype=object)>, 'label': <tf.Tensor: shape=(3,), dtype=int32, numpy=array([-100, 2, 3], dtype=int32)>}
For detailed usage of make_avro_record_dataset, please refer to the API documentation.
Train tf.keras models with Avro dataset
Now let's walk through an end-to-end example of training a tf.keras model with the Avro dataset based on the mnist dataset.

Load train.avro as a TensorFlow dataset with the Avro dataset API:
features = {
    'features[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.int32)
}

schema = tf.io.gfile.GFile('train.avsc').read()

dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'],
                                                              reader_schema=schema,
                                                              features=features,
                                                              shuffle=False,
                                                              batch_size=1,
                                                              num_epochs=1)
Define a simple keras model:
def build_and_compile_cnn_model():
    model = tf.keras.Sequential()
    model.compile(optimizer='sgd', loss='mse')
    return model
model = build_and_compile_cnn_model()
Train the keras model with the Avro dataset:
model.fit(x=dataset, epochs=1, steps_per_epoch=1, verbose=1)
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'dict'> input: {'features[*]': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f94b00645d0>} Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'dict'> input: {'features[*]': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f976476ca90>} Consider rewriting this model with the Functional API.
1/1 [==============================] - 0s 60ms/step - loss: 0.0000e+00
<tensorflow.python.keras.callbacks.History at 0x7f94ec08c6d0>
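The model above is intentionally trivial (an empty Sequential model), so the warnings and the zero loss mainly demonstrate that the Avro dataset can be fed to Keras. Below is a slightly more realistic sketch, under the assumption that one wants to regress the label from the variable-length features column; the helper names (to_xy, build_dense_model), the label default of 0, and the fixed feature width of 5 are illustrative assumptions and are not part of the original tutorial.

features = {
    'features[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.int32),
    'label': tf.io.FixedLenFeature(shape=[], dtype=tf.int32, default_value=0),
}

schema = tf.io.gfile.GFile('train.avsc').read()

dataset = tfio.experimental.columnar.make_avro_record_dataset(file_pattern=['train.avro'],
                                                              reader_schema=schema,
                                                              features=features,
                                                              shuffle=False,
                                                              batch_size=2,
                                                              num_epochs=1)

def to_xy(record, width=5):
    # Densify the variable-length feature, then crop/pad every batch to the
    # same fixed width so that a Dense layer can consume it.
    x = tf.sparse.to_dense(record['features[*]'])[:, :width]
    x = tf.pad(x, [[0, 0], [0, width - tf.shape(x)[1]]])
    x = tf.reshape(x, [-1, width])  # give the last dimension a static shape
    y = tf.reshape(tf.cast(record['label'], tf.float32), [-1, 1])
    return tf.cast(x, tf.float32), y

def build_dense_model(width=5):
    # A small fully connected regression model (illustrative only).
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(8, activation='relu', input_shape=(width,)),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer='sgd', loss='mse')
    return model

model = build_dense_model()
model.fit(dataset.map(to_xy), epochs=1, verbose=1)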
The Avro dataset can parse and coerce any Avro data into TensorFlow tensors, including records in records, maps, arrays, branches, and enumerations. The parsing information is passed into the Avro dataset implementation as a map where the keys encode how to parse the data and the values encode how to coerce the data into TensorFlow tensors – deciding the primitive type (e.g. bool, int, long, float, double, string) as well as the tensor type (e.g. sparse or dense). A listing of TensorFlow's parser types (see Table 1) and the coercions of primitive types (Table 2) is provided.
Table 1: The supported TensorFlow parser types:
TensorFlow Parser Types | TensorFlow Tensors | Explanation |
---|---|---|
tf.FixedLenFeature([], tf.int32) | dense tensor | Parse a fixed-length feature; that is, all rows have the same constant number of elements, e.g. just one element, or an array that always has the same number of elements for each row |
tf.SparseFeature(index_key=['key_1st_index', 'key_2nd_index'], value_key='key_value', dtype=tf.int64, size=[20, 50]) | sparse tensor | Parse a sparse feature where each row has a variable-length list of indices and values. The 'index_key' identifies the indices. The 'value_key' identifies the values. The 'dtype' is the data type. The 'size' is the expected maximum index value for each index entry |
tfio.experimental.columnar.VarLenFeatureWithRank([], tf.int64) | sparse tensor | Parse a variable-length feature; that means each data row can have a variable number of elements, e.g. the 1st row has 5 elements while the 2nd row has 7 elements |
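To make Table 1 concrete, a hypothetical features map combining the three parser types might look as follows; the field names (weight, indices, values, scores) are invented for illustration and are not fields of train.avro:

features = {
    # fixed-length scalar per record -> dense tensor
    'weight': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
    # sparse feature assembled from an index field and a value field -> sparse tensor
    'sparse_scores': tf.io.SparseFeature(index_key=['indices'],
                                         value_key='values',
                                         dtype=tf.int64,
                                         size=[50]),
    # variable-length array per record -> sparse tensor
    'scores[*]': tfio.experimental.columnar.VarLenFeatureWithRank(dtype=tf.float32),
}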
Table 2: The supported conversions from Avro types to TensorFlow types:
Avro Primitive Type | TensorFlow Primitive Type |
---|---|
boolean: a binary value | tf.bool |
bytes: a sequence of 8-bit unsigned bytes | tf.string |
double: double precision 64-bit IEEE floating point number | tf.float64 |
enum: enumeration type | tf.string using the symbol name |
float: single precision 32-bit IEEE floating point number | tf.float32 |
int: 32-bit signed integer | tf.int32 |
long: 64-bit signed integer | tf.int64 |
null: no value | uses default value |
string: unicode character sequence | tf.string |
A comprehensive set of examples of the Avro dataset API is provided within the tests.