Generalization of tf.where for nested structures.
tf_agents.utils.nest_utils.where(
condition, true_outputs, false_outputs
)
This generalization applies tf.where across nested structures and handles the
special case where the rank of condition is smaller than the rank of
true_outputs and false_outputs.
Args:
  condition: A boolean Tensor of shape [B, ...]. The shape of condition must
    be equal to or a prefix of the shape of true_outputs and false_outputs. If
    condition's rank is smaller than the rank of true_outputs and
    false_outputs, trailing dimensions of size 1 are added to condition to
    make its rank match that of true_outputs and false_outputs in order to
    satisfy the requirements of tf.where.
  true_outputs: Tensor or nested tuple of Tensors of any dtype, each with
    shape [B, ...]; supplies the output values where condition is True.
  false_outputs: Tensor or nested tuple of Tensors of any dtype, each with
    shape [B, ...]; supplies the output values where condition is False.
Returns:
  Interleaved output from true_outputs and false_outputs based on
  condition.
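
The rank-padding and nested-structure behavior described above can be illustrated with a small NumPy sketch. This is not the TF-Agents implementation: `where_sketch` is a hypothetical name, NumPy stands in for TensorFlow, and only nested tuples are handled. It assumes that padding condition with trailing size-1 axes and then broadcasting is enough to reproduce the documented behavior.

```python
import numpy as np


def where_sketch(condition, true_outputs, false_outputs):
    """NumPy illustration of the rank-padding idea (hypothetical helper)."""
    condition = np.asarray(condition, dtype=bool)

    def _select(t, f):
        t, f = np.asarray(t), np.asarray(f)
        # Append size-1 axes so condition's rank matches the outputs' rank.
        padded_shape = condition.shape + (1,) * (t.ndim - condition.ndim)
        return np.where(condition.reshape(padded_shape), t, f)

    def _map(t, f):
        # Recurse through nested tuples; leaves are treated as arrays.
        if isinstance(t, tuple):
            return tuple(_map(ti, fi) for ti, fi in zip(t, f))
        return _select(t, f)

    return _map(true_outputs, false_outputs)


# Batch size B = 2: condition has shape [2]; the outputs have
# shapes [2, 3] and [2], so condition is a prefix of both.
condition = np.array([True, False])
true_outputs = (np.ones((2, 3)), np.array([1, 1]))
false_outputs = (np.zeros((2, 3)), np.array([0, 0]))

selected = where_sketch(condition, true_outputs, false_outputs)
# selected[0] takes row 0 from true_outputs and row 1 from false_outputs;
# selected[1] does the same element-wise along the batch axis.
```

In this sketch the first batch entry of every leaf comes from true_outputs and the second from false_outputs, which is the per-batch selection the condition argument describes.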