Creates and initializes the requested tf.distribute strategy.
```python
tfr.keras.strategy_utils.get_strategy(
    strategy: str,
    cluster_resolver: Optional[tf.distribute.cluster_resolver.ClusterResolver] = None,
    variable_partitioner: Optional[tf.distribute.experimental.partitioners.Partitioner] = _USE_DEFAULT_VARIABLE_PARTITIONER,
    tpu: Optional[str] = ''
) -> Union[None, tf.distribute.MirroredStrategy,
           tf.distribute.MultiWorkerMirroredStrategy,
           tf.distribute.experimental.ParameterServerStrategy,
           tf.distribute.experimental.TPUStrategy]
```
Example usage:

```python
strategy = get_strategy("MirroredStrategy")
```
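The returned strategy can then be used to place model variables. A minimal sketch of the common pattern (the `Dense` model here is illustrative, not part of this API):

```python
import tensorflow as tf
import tensorflow_ranking as tfr

# Single-host, multi-GPU training with MirroredStrategy.
strategy = tfr.keras.strategy_utils.get_strategy("MirroredStrategy")

# Variables created inside the scope are replicated across local devices.
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer="adam", loss="mse")
```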
| Args | |
|---|---|
| `strategy` | Key for the `tf.distribute` strategy to be used to train the model. Choose from `["MirroredStrategy", "MultiWorkerMirroredStrategy", "ParameterServerStrategy", "TPUStrategy"]`. If `None`, no distributed strategy is used. |
| `cluster_resolver` | A cluster resolver used to build the strategy. |
| `variable_partitioner` | Variable partitioner to be used in `ParameterServerStrategy`. If the argument is not specified, a recommended `tf.distribute.experimental.partitioners.MinSizePartitioner` is used. If the argument is explicitly specified as `None`, no partitioner is used and variables are not partitioned. This argument is only used when `strategy` is `"ParameterServerStrategy"`. See the `tf.distribute.experimental.ParameterServerStrategy` class doc for more information. |
| `tpu` | TPU address for `TPUStrategy`. Not used for other strategies. |
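For `ParameterServerStrategy`, `cluster_resolver` and `variable_partitioner` come into play. A hedged sketch, assuming `TF_CONFIG` is already set for a cluster with workers and parameter servers (the shard sizes below are illustrative choices, not defaults):

```python
import tensorflow as tf
import tensorflow_ranking as tfr

# Resolve the cluster from the TF_CONFIG environment variable (assumed set).
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

# Partition variables into shards of at least 256 KiB, at most 2 shards each.
partitioner = tf.distribute.experimental.partitioners.MinSizePartitioner(
    min_shard_bytes=256 << 10, max_shards=2)

strategy = tfr.keras.strategy_utils.get_strategy(
    "ParameterServerStrategy",
    cluster_resolver=cluster_resolver,
    variable_partitioner=partitioner)
```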
| Returns | |
|---|---|
| A strategy to be used for distributed training, or `None` if `strategy` is `None`. |
| Raises | |
|---|---|
| `ValueError` | If `strategy` is not one of the supported keys. |
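An unsupported key fails fast. A small sketch (`CentralStorageStrategy` is a real `tf.distribute` strategy, but not among this function's supported keys):

```python
import tensorflow_ranking as tfr

try:
  tfr.keras.strategy_utils.get_strategy("CentralStorageStrategy")
except ValueError as err:
  print(err)  # The key is not among the supported strategies.
```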