TPUPartitionedOutputV2

public final class TPUPartitionedOutputV2

An op that demultiplexes a tensor to be sharded by XLA into a list of partitioned outputs outside the XLA computation. Supports ND sharding.
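The following plain-Java sketch makes the demultiplexing concrete without a TPU. It models a 2-D case in which a matrix is cut into equal tiles, one per partitioned output. `TileSplitDemo` and `split2d` are hypothetical names, not TensorFlow Java APIs. The sketch also assumes that every dimension divides evenly and that shards are enumerated in row-major order over the partition grid. The real op works on device tensors, so check its actual shard ordering against the TensorFlow documentation.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical, TPU-free model of 2-D demultiplexing: splits a matrix into
 * rowParts x colParts equal tiles. Assumes each dimension divides evenly and
 * that shards are listed in row-major order over the partition grid.
 */
public final class TileSplitDemo {

  static List<float[][]> split2d(float[][] input, int rowParts, int colParts) {
    int rows = input.length;
    int cols = input[0].length;
    if (rows % rowParts != 0 || cols % colParts != 0) {
      throw new IllegalArgumentException("dimensions must divide evenly");
    }
    int tileRows = rows / rowParts;
    int tileCols = cols / colParts;
    List<float[][]> shards = new ArrayList<>();
    // Outer loop over row partitions, inner loop over column partitions.
    for (int pr = 0; pr < rowParts; pr++) {
      for (int pc = 0; pc < colParts; pc++) {
        float[][] tile = new float[tileRows][tileCols];
        // Copy the block whose top-left corner is (pr * tileRows, pc * tileCols).
        for (int r = 0; r < tileRows; r++) {
          for (int c = 0; c < tileCols; c++) {
            tile[r][c] = input[pr * tileRows + r][pc * tileCols + c];
          }
        }
        shards.add(tile);
      }
    }
    return shards;
  }

  public static void main(String[] args) {
    // A 4x4 matrix holding the values 0..15 in row-major order.
    float[][] input = new float[4][4];
    for (int r = 0; r < 4; r++) {
      for (int c = 0; c < 4; c++) {
        input[r][c] = r * 4 + c;
      }
    }
    List<float[][]> shards = split2d(input, 2, 2);
    System.out.println(shards.size() + " shards, first element of last shard = " + shards.get(3)[0][0]);
  }
}
```

With a 2x2 partition grid, `main` prints `4 shards, first element of last shard = 10.0`, because the last tile starts at row 2, column 2 of the input.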

Public Methods

static <T> TPUPartitionedOutputV2<T>
create(Scope scope, Operand<T> inputs, Long numSplits, List<Long> partitionDims)
Factory method to create a class wrapping a new TPUPartitionedOutputV2 operation.

Iterator<Operand<T>>
iterator()

List<Output<T>>
output()
A list of partitioned outputs which have the same shape.

Public Methods

public static <T> TPUPartitionedOutputV2<T> create (Scope scope, Operand<T> inputs, Long numSplits, List<Long> partitionDims)

Factory method to create a class wrapping a new TPUPartitionedOutputV2 operation.

Parameters
scope current scope
inputs A tensor which represents the full shape of partitioned tensors.
numSplits The number of partitioned outputs to produce.
partitionDims A list of integers describing how each dimension is partitioned. An empty list indicates that the inputs are replicated.
Returns
  • a new instance of TPUPartitionedOutputV2
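The parameters are related by simple shape arithmetic. The sketch below rests on several assumptions that should be confirmed against the op's own documentation:

- `partitionDims` has one entry per input dimension.
- Each dimension must divide evenly by its entry.
- A non-empty `partitionDims` implies that `numSplits` equals the product of its entries.
- An empty list means every output has the full input shape.

`PartitionShapes`, `shardShape`, and `impliedNumSplits` are hypothetical helpers, not part of TensorFlow Java.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical helpers (not part of TensorFlow Java) modelling how numSplits
 * and partitionDims relate to the input shape. Assumes even splitting and
 * that an empty partitionDims means the input is replicated.
 */
public final class PartitionShapes {

  /** Shape of each partitioned output, given the full input shape. */
  static long[] shardShape(long[] inputShape, List<Long> partitionDims) {
    if (partitionDims.isEmpty()) {
      // Replicated: every output keeps the full input shape.
      return inputShape.clone();
    }
    if (partitionDims.size() != inputShape.length) {
      throw new IllegalArgumentException("partitionDims needs one entry per input dimension");
    }
    long[] shard = new long[inputShape.length];
    for (int i = 0; i < inputShape.length; i++) {
      long parts = partitionDims.get(i);
      if (parts < 1 || inputShape[i] % parts != 0) {
        throw new IllegalArgumentException("dimension " + i + " is not evenly divisible by " + parts);
      }
      shard[i] = inputShape[i] / parts;
    }
    return shard;
  }

  /** Number of outputs implied by a non-empty partitionDims: the product of its entries. */
  static long impliedNumSplits(List<Long> partitionDims) {
    long product = 1;
    for (long parts : partitionDims) {
      product *= parts;
    }
    return product;
  }

  public static void main(String[] args) {
    long[] inputShape = {8, 6};
    List<Long> partitionDims = Arrays.asList(2L, 3L);
    System.out.println(Arrays.toString(shardShape(inputShape, partitionDims))
        + " x " + impliedNumSplits(partitionDims));
  }
}
```

For an input of shape `[8, 6]` and `partitionDims` of `[2, 3]`, `main` prints `[4, 2] x 6`, that is, six outputs that each have shape `[4, 2]`.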

public Iterator<Operand<T>> iterator ()

public List<Output<T>> output ()

A list of partitioned outputs which have the same shape.
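Because the class exposes both `output()` and `iterator()`, a `for` loop over the op visits the partitioned outputs directly. The stand-in below is a hypothetical `PartitionedOutputs` class, not the TensorFlow Java implementation. It assumes that `iterator()` walks the same list that `output()` returns, which is common for generated TensorFlow Java ops that implement `Iterable`.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/**
 * Hypothetical stand-in (not the TensorFlow Java class) mirroring the assumed
 * relationship between output() and iterator(): iterating the object walks
 * the same ordered list that output() returns.
 */
public final class PartitionedOutputs<T> implements Iterable<T> {
  private final List<T> output;

  PartitionedOutputs(List<T> output) {
    // Read-only view so callers cannot reorder or drop shards.
    this.output = Collections.unmodifiableList(output);
  }

  /** All partitioned outputs, in order. */
  public List<T> output() {
    return output;
  }

  /** Iterates over exactly the elements returned by output(). */
  @Override
  public Iterator<T> iterator() {
    return output.iterator();
  }

  public static void main(String[] args) {
    PartitionedOutputs<String> op = new PartitionedOutputs<>(Arrays.asList("shard0", "shard1"));
    // for-each works because the class implements Iterable.
    for (String shard : op) {
      System.out.println(shard);
    }
  }
}
```

With the real op, the loop variable would have type `Operand<T>`, matching the `Iterator<Operand<T>>` returned by `iterator()`.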