Normalizes Tensor ranks for use in if conditions.
tfp.experimental.auto_batching.truthy(
x
)
This enables dry-runs of programs with control flow. Usage: program the conditions of if statements and while loops to have a batch dimension, and then wrap them with this function. Example:
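from tensorflow_probability.python.experimental.auto_batching import frontend

# Context is a class; @ctx.batch needs an instance, hence the call.
ctx = frontend.Context()
truthy = frontend.truthy

# `...` stands for a type-inference function, elided in the original.
@ctx.batch(type_inference=...)
def my_abs(x):
  # x > 0 has a leading batch dimension; truthy normalizes its rank.
  if truthy(x > 0):
    return x
  else:
    return -x

my_abs([-5], dry_run=True)
# returns [5] in Eager mode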
This is necessary because auto-batched programs still have a leading batch dimension (of size 1) even in dry-run mode, and a Tensor of shape [1] is not acceptable as the condition of an if or while. However, the leading dimension is critical during batched execution, so conditions of ifs need to have rank 1 when running batched and rank 0 when running unbatched (i.e., in a dry run). The truthy function arranges for this to happen by detecting whether it is in dry-run mode or not.
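The rank convention described above can be illustrated with a small plain-NumPy sketch. Everything in it is hypothetical: truthy_sketch, my_abs_dry_run and my_abs_batched are illustrative names, dry-run detection is replaced by an explicit dry_run flag, and np.where only emulates the result of per-member branching. It is not how TFP implements truthy or the auto-batcher.

```python
import numpy as np


def truthy_sketch(cond, dry_run):
    """Hypothetical stand-in for truthy; not TFP's implementation.

    Batched execution keeps the leading batch dimension (rank 1), while
    a dry run drops the size-1 batch dimension so that `cond` becomes a
    rank-0 scalar that a Python `if` can consume.
    """
    cond = np.asarray(cond, dtype=bool)
    if dry_run:
        if cond.shape != (1,):
            raise ValueError(
                "dry-run condition must have shape (1,), got %s" % (cond.shape,))
        return cond[0]  # rank 0: a single scalar decision
    if cond.ndim != 1:
        raise ValueError(
            "batched condition must have rank 1, got rank %d" % cond.ndim)
    return cond  # rank 1: one decision per batch member


def my_abs_dry_run(x):
    # Dry run: the batch has size 1, so the normalized condition is a scalar.
    x = np.asarray(x)
    if truthy_sketch(x > 0, dry_run=True):
        return x
    else:
        return -x


def my_abs_batched(x):
    # Batched: the rank-1 condition selects a branch per member.
    # np.where only emulates the outcome of that selection.
    x = np.asarray(x)
    cond = truthy_sketch(x > 0, dry_run=False)
    return np.where(cond, x, -x)


print(my_abs_dry_run([-5]))         # [5]
print(my_abs_batched([-5, 3, -2]))  # [5 3 2]
```

The same condition expression, x > 0, is used in both modes; only the rank of what the normalizer returns differs, which mirrors the division of labor the paragraph above assigns to truthy.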
If you missed a spot where you should have used truthy, the error message will say "Non-scalar tensor <Tensor ...> cannot be converted to boolean."
| Args | |
|---|---|
| x | A Tensor. |