Dataset

public abstract class Dataset
Known Direct Subclasses

Represents a potentially large list of independent elements (samples), and allows iteration and transformations to be performed across these elements.

Public Constructors

Dataset(Ops tf, Operand<?> variant, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes)

Public Methods

final Dataset
batch(long batchSize, boolean dropLastBatch)
Groups elements of this dataset into batches.
final Dataset
batch(long batchSize)
Groups elements of this dataset into batches.
static Dataset
fromTensorSlices(Ops tf, List<Operand<?>> tensors, List<Class<? extends TType>> outputTypes)
Creates an in-memory `Dataset` whose elements are slices of the given tensors.
Ops
getOpsInstance()
List<Shape>
getOutputShapes()
Get a list of shapes for each component of this dataset.
List<Class<? extends TType>>
getOutputTypes()
Get a list of output types for each component of this dataset.
Operand<?>
getVariant()
Get the variant tensor representing this dataset.
Iterator<List<Operand<?>>>
iterator()
Creates an iterator which iterates through all batches of this Dataset in an eager fashion.
DatasetIterator
makeInitializeableIterator()
Creates a `DatasetIterator` that can be used to iterate over elements of this dataset.
DatasetIterator
makeOneShotIterator()
Creates a `DatasetIterator` that can be used to iterate over elements of this dataset.
Dataset
map(Function<List<Operand<?>>, List<Operand<?>>> mapper)
Returns a new Dataset which maps a function over all elements returned by this dataset.
Dataset
mapAllComponents(Function<Operand<?>, Operand<?>> mapper)
Returns a new Dataset which maps a function across all elements from this dataset, on all components of each element.
Dataset
mapOneComponent(int index, Function<Operand<?>, Operand<?>> mapper)
Returns a new Dataset which maps a function across all elements from this dataset, on a single component of each element.
final Dataset
skip(long count)
Returns a new `Dataset` that skips the first `count` elements of this dataset.
final Dataset
take(long count)
Returns a new `Dataset` with only the first `count` elements from this dataset.
static Dataset
textLineDataset(Ops tf, String filename, String compressionType, long bufferSize)
static Dataset
tfRecordDataset(Ops tf, String filename, String compressionType, long bufferSize)
String
toString()


Public Constructors

public Dataset (Ops tf, Operand<?> variant, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes)

Public Methods

public final Dataset batch (long batchSize, boolean dropLastBatch)

Groups elements of this dataset into batches.

Parameters
batchSize The desired number of elements per batch.
dropLastBatch Whether to leave out the final batch if it has fewer than `batchSize` elements.
Returns
  • A batched Dataset

public final Dataset batch (long batchSize)

Groups elements of this dataset into batches. Includes the last batch, even if it has fewer than `batchSize` elements.

Parameters
batchSize The desired number of elements per batch.
Returns
  • A batched Dataset
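To make the two overloads concrete, here is a minimal plain-Java model in which a dataset is an ordinary `List` and each element is one list entry. It sketches the documented batching semantics only; it is not the TF Java implementation, and the class `BatchSketch` is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical plain-Java model of Dataset.batch semantics; not part of TF Java.
final class BatchSketch {

    // Groups consecutive elements into lists of size batchSize.
    // If dropLastBatch is true, a trailing partial batch is discarded.
    static <T> List<List<T>> batch(List<T> elements, long batchSize, boolean dropLastBatch) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<List<T>> batches = new ArrayList<>();
        List<T> current = new ArrayList<>();
        for (T element : elements) {
            current.add(element);
            if (current.size() == batchSize) {
                batches.add(current);
                current = new ArrayList<>();
            }
        }
        // Leftover elements form a smaller final batch unless dropLastBatch is set.
        if (!current.isEmpty() && !dropLastBatch) {
            batches.add(current);
        }
        return batches;
    }

    // Mirrors the single-argument overload, which keeps the final partial batch.
    static <T> List<List<T>> batch(List<T> elements, long batchSize) {
        return batch(elements, batchSize, false);
    }

    public static void main(String[] args) {
        List<Integer> elements = List.of(0, 1, 2, 3, 4);
        System.out.println(batch(elements, 2));       // [[0, 1], [2, 3], [4]]
        System.out.println(batch(elements, 2, true)); // [[0, 1], [2, 3]]
    }
}
```

With five elements and a batch size of 2, only the final batch differs between the two overloads; when the element count is an exact multiple of the batch size, both produce the same result.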

public static Dataset fromTensorSlices (Ops tf, List<Operand<?>> tensors, List<Class<? extends TType>> outputTypes)

Creates an in-memory `Dataset` whose elements are slices of the given tensors. Each element of this dataset will be a List<Operand<?>> holding one slice of each provided tensor, taken along its first dimension (for example, one row of features together with its label).

Parameters
tf The `Ops` accessor.
tensors A list of Operand<?> representing components of this dataset (e.g. features, labels). All tensors must have the same size in their first dimension.
outputTypes A list of tensor type classes representing the data type of each component of this dataset.
Returns
  • A new `Dataset`
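The slicing rule can be illustrated without TensorFlow by letting arrays stand in for tensors: a 2-D `int` array plays the features tensor and a 1-D array the labels tensor. This is a hypothetical sketch (`SliceSketch` is not a TF Java class) of how element i pairs row i of every component.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical plain-Java model of Dataset.fromTensorSlices; not part of TF Java.
// A 2-D int array stands in for a features tensor, a 1-D array for a labels tensor.
final class SliceSketch {

    // Element i pairs row i of features with entry i of labels,
    // i.e. both components are sliced along their first dimension.
    static List<List<Object>> fromTensorSlices(int[][] features, int[] labels) {
        if (features.length != labels.length) {
            throw new IllegalArgumentException(
                    "all components must have the same size in their first dimension");
        }
        List<List<Object>> elements = new ArrayList<>();
        for (int i = 0; i < features.length; i++) {
            elements.add(List.<Object>of(features[i], labels[i]));
        }
        return elements;
    }

    public static void main(String[] args) {
        int[][] features = {{1, 2}, {3, 4}, {5, 6}};
        int[] labels = {0, 1, 0};
        for (List<Object> element : fromTensorSlices(features, labels)) {
            // Prints "[1, 2] -> 0", then "[3, 4] -> 1", then "[5, 6] -> 0".
            System.out.println(Arrays.toString((int[]) element.get(0)) + " -> " + element.get(1));
        }
    }
}
```

A features tensor of shape [3, 2] and a labels tensor of shape [3] therefore yield three two-component elements.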

public Ops getOpsInstance ()

public List<Shape> getOutputShapes ()

Get a list of shapes for each component of this dataset.

public List<Class<? extends TType>> getOutputTypes ()

Get a list of output types for each component of this dataset.

public Operand<?> getVariant ()

Get the variant tensor representing this dataset.

public Iterator<List<Operand<?>>> iterator ()

Creates an iterator which iterates through all batches of this Dataset in an eager fashion. Each batch is a list of components, returned as `Output` objects.

This method enables for-each iteration through batches when running in eager mode. For batch iteration in graph mode, see `makeOneShotIterator`.

Returns
  • an Iterator through batches of this dataset.

public DatasetIterator makeInitializeableIterator ()

Creates a `DatasetIterator` that can be used to iterate over elements of this dataset.

This iterator must be initialized with a call to `iterator.makeInitializer(Dataset)` before elements can be retrieved in a loop.

Returns
  • A new `DatasetIterator` based on this dataset's structure.

public DatasetIterator makeOneShotIterator ()

Creates a `DatasetIterator` that can be used to iterate over elements of this dataset. Using `makeOneShotIterator` ensures that the iterator is automatically initialized on this dataset. In graph mode, the initializer op will be added to the Graph's initializer list, which must be run via `tf.init()`:

Ex:

     try (Session session = new Session(graph)) {
         // Immediately run initializers
         session.run(tf.init());
     }
 

In eager mode, the initializer will be run automatically as a result of this call.

Returns
  • A new `DatasetIterator` based on this dataset's structure.

public Dataset map (Function<List<Operand<?>>, List<Operand<?>>> mapper)

Returns a new Dataset which maps a function over all elements returned by this dataset.

For example, suppose each element is a List<Operand<?>> with 2 components: (features, labels).

Calling

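dataset.map(components -> {
      // Unchecked casts: assumes both components hold TInt32 values,
      // matching the TInt32 constants they are multiplied by
      Operand<TInt32> features = (Operand<TInt32>) components.get(0);
      Operand<TInt32> labels   = (Operand<TInt32>) components.get(1);

      return Arrays.asList(
        tf.math.mul(features, tf.constant(2)),
        tf.math.mul(labels, tf.constant(5))
      );
});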
will map the function over the `features` and `labels` components, multiplying features by 2, and multiplying the labels by 5.

Parameters
mapper The function to apply to each element of this dataset.
Returns
  • A new Dataset applying `mapper` to each element of this dataset.
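The per-element contract of `map` can be modelled in plain Java by representing each element as a `List<Integer>` of scalar components instead of a list of `Operand`s. This is a hypothetical sketch (`MapSketch` is not a TF Java class) that mirrors the features-times-2, labels-times-5 example without building any TensorFlow ops.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical plain-Java model of Dataset.map; each element is a list of integer components.
final class MapSketch {

    // Applies mapper to every element; the mapper sees all components of one element at once.
    static List<List<Integer>> map(List<List<Integer>> elements,
                                   Function<List<Integer>, List<Integer>> mapper) {
        List<List<Integer>> result = new ArrayList<>();
        for (List<Integer> element : elements) {
            result.add(mapper.apply(element));
        }
        return result;
    }

    public static void main(String[] args) {
        // Two elements, each with components (feature, label).
        List<List<Integer>> elements = List.of(List.of(1, 0), List.of(3, 1));
        List<List<Integer>> mapped = map(elements, components -> List.of(
                components.get(0) * 2,   // features multiplied by 2
                components.get(1) * 5)); // labels multiplied by 5
        System.out.println(mapped); // [[2, 0], [6, 5]]
    }
}
```

Because the mapper receives the whole element, it can treat each component differently, which the per-component variants below cannot.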

public Dataset mapAllComponents (Function<Operand<?>, Operand<?>> mapper)

Returns a new Dataset which maps a function across all elements from this dataset, on all components of each element.

For example, suppose each element is a List<Operand<?>> with 2 components: (features, labels).

Calling dataset.mapAllComponents(component -> tf.math.mul(component, tf.constant(2))) will map the function over both the `features` and `labels` components of each element, multiplying them all by 2.

Parameters
mapper The function to apply to each component
Returns
  • A new Dataset applying `mapper` to all components of each element.
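In a hypothetical plain-Java model (the class `MapAllSketch` is not part of TF Java), `mapAllComponents` applies one per-component function to every component of every element, so the mapper never needs to know how many components an element has.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical plain-Java model of Dataset.mapAllComponents; not part of TF Java.
final class MapAllSketch {

    // Applies the same per-component mapper to every component of every element.
    static List<List<Integer>> mapAllComponents(List<List<Integer>> elements,
                                                Function<Integer, Integer> mapper) {
        List<List<Integer>> result = new ArrayList<>();
        for (List<Integer> element : elements) {
            List<Integer> mapped = new ArrayList<>();
            for (Integer component : element) {
                mapped.add(mapper.apply(component));
            }
            result.add(mapped);
        }
        return result;
    }

    public static void main(String[] args) {
        List<List<Integer>> elements = List.of(List.of(1, 0), List.of(3, 1));
        System.out.println(mapAllComponents(elements, c -> c * 2)); // [[2, 0], [6, 2]]
    }
}
```

Since the same function sees features and labels alike, it has to be valid for every component's type; in the TensorFlow example above that means multiplying by a constant every component can accept.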

public Dataset mapOneComponent (int index, Function<Operand<?>, Operand<?>> mapper)

Returns a new Dataset which maps a function across all elements from this dataset, on a single component of each element.

For example, suppose each element is a List<Operand<?>> with 2 components: (features, labels).

Calling dataset.mapOneComponent(0, features -> tf.math.mul(features, tf.constant(2))) will map the function over the `features` component of each element, multiplying each by 2.

Parameters
index The index of the component to transform.
mapper The function to apply to the target component.
Returns
  • A new Dataset applying `mapper` to the component at the chosen index.
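A comparable hypothetical sketch for `mapOneComponent`, again in plain Java rather than TF Java (`MapOneSketch` is illustrative only): only the component at `index` is transformed, and the remaining components pass through unchanged.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical plain-Java model of Dataset.mapOneComponent; not part of TF Java.
final class MapOneSketch {

    // Transforms only the component at index; other components are copied as-is.
    static List<List<Integer>> mapOneComponent(List<List<Integer>> elements, int index,
                                               Function<Integer, Integer> mapper) {
        List<List<Integer>> result = new ArrayList<>();
        for (List<Integer> element : elements) {
            // Copy the element so the input lists are never modified.
            List<Integer> mapped = new ArrayList<>(element);
            mapped.set(index, mapper.apply(element.get(index)));
            result.add(mapped);
        }
        return result;
    }

    public static void main(String[] args) {
        List<List<Integer>> elements = List.of(List.of(1, 0), List.of(3, 1));
        // Double the features (index 0); labels stay the same.
        System.out.println(mapOneComponent(elements, 0, f -> f * 2)); // [[2, 0], [6, 1]]
    }
}
```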

public final Dataset skip (long count)

Returns a new `Dataset` that skips the first `count` elements of this dataset.

Parameters
count The number of elements to `skip` to form the new dataset.
Returns
  • A new Dataset with the first `count` elements removed.
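A minimal plain-Java model of `skip` over an ordinary `List`. The class `SkipSketch` is hypothetical, and two edge cases are assumptions of this sketch rather than statements from the description above: a `count` larger than the list yields an empty result, and negative counts are simply rejected.

```java
import java.util.List;

// Hypothetical plain-Java model of Dataset.skip; not part of TF Java.
final class SkipSketch {

    // Returns the elements after the first count ones. Only count >= 0 is modelled;
    // a count larger than the list yields an empty result (an assumption of this sketch).
    static <T> List<T> skip(List<T> elements, long count) {
        if (count < 0) {
            throw new IllegalArgumentException("this sketch only models count >= 0");
        }
        int from = (int) Math.min(count, elements.size());
        return elements.subList(from, elements.size());
    }

    public static void main(String[] args) {
        List<Integer> elements = List.of(10, 20, 30, 40, 50);
        System.out.println(skip(elements, 2));  // [30, 40, 50]
        System.out.println(skip(elements, 10)); // []
    }
}
```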

public final Dataset take (long count)

Returns a new `Dataset` with only the first `count` elements from this dataset.

Parameters
count The number of elements to "take" from this dataset.
Returns
  • A new Dataset containing the first `count` elements from this dataset.
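The matching plain-Java model of `take`, under the same assumptions: `TakeSketch` is hypothetical, negative counts are rejected, and a `count` larger than the list returns every element.

```java
import java.util.List;

// Hypothetical plain-Java model of Dataset.take; not part of TF Java.
final class TakeSketch {

    // Returns at most the first count elements. Only count >= 0 is modelled;
    // a count larger than the list returns all of it (an assumption of this sketch).
    static <T> List<T> take(List<T> elements, long count) {
        if (count < 0) {
            throw new IllegalArgumentException("this sketch only models count >= 0");
        }
        int to = (int) Math.min(count, elements.size());
        return elements.subList(0, to);
    }

    public static void main(String[] args) {
        List<Integer> elements = List.of(10, 20, 30, 40, 50);
        System.out.println(take(elements, 2));  // [10, 20]
        System.out.println(take(elements, 10)); // [10, 20, 30, 40, 50]
    }
}
```

Pairing `take(n)` with `skip(n)` on the same dataset splits it into two disjoint parts, for example a small held-out set and the remainder.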

public static Dataset textLineDataset (Ops tf, String filename, String compressionType, long bufferSize)

public static Dataset tfRecordDataset (Ops tf, String filename, String compressionType, long bufferSize)

public String toString ()