Wraps f with tpu.rewrite or xla.compile and propagates its output structure.
tfp.experimental.auto_batching.xla.compile_nested_output(
f, compile_fn=None
)
xla.compile insists that f output a flat list of Tensors or Ops, but it tolerates nested input arguments. This wrapper captures the structure of f's output during compilation and uses it to repack the flat compiled results into the original nesting.
Args:
  f: Callable to compile; may accept and return nested inputs and outputs.
  compile_fn: The function used to compile, e.g. xla.compile or tpu.rewrite. Accepts two arguments, f and inputs.

Returns:
  g: Callable wrapping f that returns XLA-compiled, nested outputs.
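The compile_fn argument follows a two-argument calling convention, compile_fn(f, inputs). As a sketch of that contract only, the hypothetical `eager_compile` below honours the same signature but simply runs f eagerly. Like xla.compile, it refuses nested outputs. It is not a TensorFlow API, and whether such a stand-in is useful for testing wrapped code without an XLA device is an assumption. The two hypothetical functions show which outputs pass the flat-output check and which do not.

```python
def eager_compile(f, inputs):
    """Hypothetical compile_fn: calls f(*inputs), requires a flat list/tuple."""
    outputs = f(*inputs)
    if not isinstance(outputs, (list, tuple)):
        raise TypeError(f"expected a list of outputs, got {outputs!r}")
    for out in outputs:
        if isinstance(out, (list, tuple, dict)):
            raise TypeError(f"nested output {out!r}: flatten it before compiling")
    return list(outputs)


def flat_step(x, y):
    return [x + y, x * y]  # already flat: accepted as-is


def nested_step(x, y):
    return [x + y, {"prod": x * y}]  # nested: rejected by the flat-output check


print(eager_compile(flat_step, (2, 3)))  # [5, 6]
try:
    eager_compile(nested_step, (2, 3))
except TypeError as err:
    print("rejected:", err)
```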