As tf.where, but only calls both x_fn and y_fn when condition is not statically known.
tfp.experimental.distributions.marginal_fns.ps.smart_where(
    condition, x_fn, y_fn
)
Args
condition | A bool Tensor.
x_fn | A callable returning a Tensor, for locations where condition is True.
y_fn | A callable returning a Tensor, for locations where condition is False.
Returns
A Tensor equivalent to tf.where(condition, x_fn(), y_fn()).
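The following is a minimal NumPy sketch of the lazy-branch idea. It is not the TFP implementation: smart_where_sketch, x_fn, and y_fn are hypothetical names, and a scalar (0-d) boolean stands in for a "statically known" condition. Any array condition takes the tf.where-style elementwise path, in which both callables run.

```python
import numpy as np


def smart_where_sketch(condition, x_fn, y_fn):
    """Hypothetical NumPy illustration of smart_where (not the TFP code).

    Assumption: a 0-d boolean plays the role of a statically known
    condition, so only the selected callable is evaluated. Any other
    condition is treated as unknown, so both callables are evaluated
    and the results are combined elementwise.
    """
    if np.ndim(condition) == 0:
        # Scalar condition: evaluate only the branch that is selected.
        return x_fn() if bool(condition) else y_fn()
    # Array condition: evaluate both branches, then select per element.
    return np.where(condition, x_fn(), y_fn())


calls = []


def x_fn():
    calls.append("x")  # record that the True branch was evaluated
    return np.array([1.0, 2.0, 3.0])


def y_fn():
    calls.append("y")  # record that the False branch was evaluated
    return np.array([10.0, 20.0, 30.0])


print(smart_where_sketch(True, x_fn, y_fn), calls)
calls.clear()
print(smart_where_sketch(np.array([True, False, True]), x_fn, y_fn), calls)
```

The recorded calls show the difference: the scalar condition evaluates only x_fn, whereas the array condition evaluates both x_fn and y_fn before selecting values per element.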