A class wrapping information needed by an input function.
tf.distribute.InputContext(
num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=1
)
This context class is passed to the user's input function and contains information about the compute replicas and input pipelines. The number of compute replicas (in synchronous training) is used to compute the per-replica batch size from the desired global batch size. The input pipeline information can be used to return a different subset of the input in each replica (e.g. to shard the input pipeline or to use a different input source).
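As a minimal sketch of the description above, an input function can use the context both to size its batches and to shard its data. The names `dataset_fn` and `GLOBAL_BATCH_SIZE`, the `range` dataset, and the hand-built context are assumptions made for illustration; a distribution strategy normally constructs the `InputContext` itself.

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 64  # hypothetical value chosen for illustration


def dataset_fn(input_context):
    """Builds the dataset for a single input pipeline."""
    # Split the global batch evenly across the replicas in sync.
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    # Stand-in for a real input source.
    dataset = tf.data.Dataset.range(1024)
    # Give each input pipeline a disjoint shard of the data.
    dataset = dataset.shard(
        input_context.num_input_pipelines, input_context.input_pipeline_id
    )
    return dataset.batch(batch_size)


# Built by hand here only to exercise dataset_fn outside a strategy.
context = tf.distribute.InputContext(
    num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=4
)
first_batch = next(iter(dataset_fn(context)))
```

In a real training setup, a function like this would typically be passed to `tf.distribute.Strategy.distribute_datasets_from_function`, which supplies the context for each input pipeline.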
Attributes | |
---|---|
input_pipeline_id | Returns the input pipeline ID. |
num_input_pipelines | Returns the number of input pipelines. |
num_replicas_in_sync | Returns the number of compute replicas in sync. |
Methods
get_per_replica_batch_size
get_per_replica_batch_size(
global_batch_size
)
Returns the per-replica batch size.
Args | |
---|---|
global_batch_size | The global batch size, which should be divisible by num_replicas_in_sync. |
Returns | |
---|---|
The per-replica batch size. |
Raises | |
---|---|
ValueError | If global_batch_size is not divisible by num_replicas_in_sync. |
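The divisibility requirement can be illustrated with a short sketch. The batch sizes 64 and 30 and the replica count of 4 are arbitrary values chosen for this example, and the variable names are hypothetical.

```python
import tensorflow as tf

# A context with 4 replicas in sync; other arguments keep their defaults.
context = tf.distribute.InputContext(num_replicas_in_sync=4)

# 64 divides evenly across 4 replicas.
per_replica = context.get_per_replica_batch_size(64)

# 30 is not a multiple of 4, so the call is expected to raise ValueError.
try:
    context.get_per_replica_batch_size(30)
    raised = False
except ValueError:
    raised = True
```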