tf.distribute.InputContext
A class wrapping information needed by an input function.
tf.distribute.InputContext(
    num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=1
)
This is a context class that is passed to the user's input function and
contains information about the compute replicas and input pipelines. The
number of compute replicas (in synchronous training) is used to compute the
per-replica batch size from the desired global batch size. The input pipeline
information can be used to return a different subset of the input in each
replica (e.g., to shard the input pipeline or to use a different input
source).
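In practice you usually do not construct an `InputContext` yourself; a strategy builds one and passes it to your input function. The sketch below assumes `tf.distribute.MirroredStrategy` and `Strategy.distribute_datasets_from_function`; the names `dataset_fn`, `seen_contexts`, and `GLOBAL_BATCH_SIZE` are hypothetical, chosen only for illustration. The global batch size is scaled by the replica count so that it always divides evenly.

```python
import tensorflow as tf

# MirroredStrategy uses one replica per visible GPU, or a single CPU replica.
strategy = tf.distribute.MirroredStrategy()

# Hypothetical global batch size, scaled so it is divisible by the replica count.
GLOBAL_BATCH_SIZE = 8 * strategy.num_replicas_in_sync

seen_contexts = []  # hypothetical: records the contexts the strategy passes in


def dataset_fn(input_context: tf.distribute.InputContext) -> tf.data.Dataset:
    seen_contexts.append(input_context)
    # Per-replica batch size derived from the global batch size.
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    dataset = tf.data.Dataset.range(64)
    # Each input pipeline reads a disjoint shard of the data.
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    return dataset.batch(batch_size)


dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
first_batch = next(iter(dist_dataset))  # one batch per replica
```

Batching with the per-replica size (rather than the global size) is the point of the context: each replica then receives its share, and together the replicas consume one global batch per step.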
| Args | Description |
|---|---|
| `num_input_pipelines` | The number of input pipelines in a cluster. |
| `input_pipeline_id` | The current input pipeline ID; should be an int in `[0, num_input_pipelines)`. |
| `num_replicas_in_sync` | The number of replicas that are in sync. |
| Attributes | Description |
|---|---|
| `input_pipeline_id` | Returns the input pipeline ID. |
| `num_input_pipelines` | Returns the number of input pipelines. |
| `num_replicas_in_sync` | Returns the number of compute replicas in sync. |
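To see how the `num_input_pipelines` and `input_pipeline_id` attributes partition the data, the sketch below constructs one context per pipeline by hand, as a two-worker cluster might, and shards a small dataset with `tf.data.Dataset.shard`. The cluster shape (two pipelines, four replicas) is a hypothetical choice for illustration.

```python
import tensorflow as tf

NUM_PIPELINES = 2  # hypothetical: e.g., two workers with one pipeline each

shards = []
for pipeline_id in range(NUM_PIPELINES):
    ctx = tf.distribute.InputContext(
        num_input_pipelines=NUM_PIPELINES,
        input_pipeline_id=pipeline_id,
        num_replicas_in_sync=4,
    )
    # shard(n, i) keeps every n-th element starting at index i.
    ds = tf.data.Dataset.range(8).shard(ctx.num_input_pipelines,
                                        ctx.input_pipeline_id)
    shards.append([int(x) for x in ds.as_numpy_iterator()])

print(shards)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```

The shards are disjoint and together cover the whole dataset, so no example is read twice across pipelines.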
Methods
get_per_replica_batch_size
get_per_replica_batch_size(
    global_batch_size
)
Returns the per-replica batch size.
| Args | Description |
|---|---|
| `global_batch_size` | The global batch size, which should be divisible by `num_replicas_in_sync`. |

| Returns |
|---|
| The per-replica batch size. |

| Raises | Description |
|---|---|
| `ValueError` | If `global_batch_size` is not divisible by `num_replicas_in_sync`. |
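A short sketch of `get_per_replica_batch_size`: with four replicas in sync, a global batch of 64 splits into 16 per replica, while a global batch of 30 cannot be split evenly and raises `ValueError`. The replica count and batch sizes are arbitrary illustration values.

```python
import tensorflow as tf

ctx = tf.distribute.InputContext(num_replicas_in_sync=4)

# 64 is divisible by 4, so each replica gets 64 // 4 == 16 examples.
per_replica = ctx.get_per_replica_batch_size(64)

# 30 is not divisible by 4, so the call raises ValueError.
raised = False
try:
    ctx.get_per_replica_batch_size(30)
except ValueError:
    raised = True

print(per_replica, raised)  # 16 True
```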
  Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
  Last updated 2023-03-17 UTC.