tfp.experimental.auto_batching.Context

Context object for auto-batching multiple Python functions together.

Usage:

ctx = frontend.Context()

@ctx.batch(type_inference=lambda ...)
def my_single_example_function_1(args):
  ...

@ctx.batch(type_inference=lambda ...)
def my_single_example_function_2(args):
  ...

# etc

Then calling any of the decorated functions will execute a batch computation. The decorated functions may call each other, including mutually recursively. See also the batch method.

Limitations:

  • You must explicitly decorate every function to be auto-batched.
  • All calls to them must call them by name (no higher-order auto-batching).
  • Auto-batched functions must be defined with def, not lambda.
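
For concreteness, here is a minimal sketch of the pattern above. The package alias ab, the add_one function, and the identity type signature are illustrative assumptions, not part of the API; the NumpyBackend instance (one of the backends named under the batch method below) is passed explicitly so the sketch does not depend on TensorFlow graph construction.

import numpy as np
import tensorflow_probability as tfp

ab = tfp.experimental.auto_batching

ctx = ab.Context()

# Identity type signature: the single return value has the same type
# as the single argument.
@ctx.batch(type_inference=lambda arg_types: arg_types)
def add_one(x):
  return x + 1

# Calling the decorated function executes a batch computation; the
# leading dimension of the input is the batch dimension.
result = add_one(np.array([1, 2, 3]), backend=ab.NumpyBackend())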

Methods

batch


Decorates one function to auto-batch.

The decorated function will run in batch. It accepts the same arguments as the original, with the following differences (a usage sketch follows this list):

  • All arguments must have an additional leading dimension for the batch. (By special dispensation, scalar inputs are promoted to shape [1], which then leads to broadcasting.)
  • All the arguments' sizes in the batch dimension must be the same, or 1. The latter are broadcast.
  • The returned value will also have a leading batch dimension, whose size matches the (broadcast) batch size of the inputs.
  • The batched function accepts an additional bool keyword argument dry_run. If present and True, just calls the unbatched version, circumventing the auto-batching system. This can be useful for debugging the program subject to auto-batching.
  • The batched function accepts an additional bool keyword argument stackless. If present and True, invokes the stackless version of the auto-batching system. This can be useful for avoiding stack-maintenance overhead, but in general it recovers less batching and does not work in graph-mode TensorFlow.
  • The batched function accepts an additional int keyword argument max_stack_depth specifying the maximum stack depth (default 15). Ignored in stackless execution.
  • The batched function accepts an additional keyword argument backend specifying the backend to use. Must be an instance of auto_batching.TensorFlowBackend (default) or auto_batching.NumpyBackend.
  • The batched function accepts an additional keyword argument block_code_cache, a dict that allows caches of basic block rewrites (i.e., tf.function + XLA compilations) to persist across calls to the auto-batched function. The default value of None results in caching only within a given call to the batched function. Currently, stackless auto-batching ignores the cache completely.
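
As a sketch of these keyword arguments in use (reusing the hypothetical add_one function and the ab alias from the usage sketch above):

# Dry run: calls the original, unbatched Python function directly.
add_one(np.array(4), dry_run=True)

# Batched execution with an explicit backend, a smaller stack limit,
# and a block-rewrite cache that persists across calls.
cache = {}
result = add_one(
    np.array([1, 2, 3]),
    max_stack_depth=4,
    backend=ab.NumpyBackend(),
    block_code_cache=cache)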

Args
type_inference: A Python callable giving the type signature of the function being auto-batched. The callable will be invoked with a single argument giving the list of instructions.Type objects describing the arguments at a particular call site, and must return a list of instructions.Type objects describing the values that call site will return.

Returns
dec: A decorator that may be applied to a function with the given type signature to auto-batch it.

Raises
ValueError: If the decorated function predictably cannot be auto-batched, e.g., because its name clashes with that of another function already decorated in this Context.
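
As an illustration of the type_inference contract, here is a sketch of a signature callable for a hypothetical two-argument function whose single return value has the same type as its first argument. It only rearranges the instructions.Type objects it receives, so it assumes nothing beyond what this page states:

# arg_types is a list of instructions.Type objects, one per argument
# at the call site; the result is a list of instructions.Type objects,
# one per returned value.
def first_arg_type(arg_types):
  return [arg_types[0]]

@ctx.batch(type_inference=first_arg_type)
def scaled_add(x, y):  # hypothetical example function
  return x + 2 * y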

batch_uncurried


A non-decorator version of batch; see that method for details.
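
A sketch of the non-decorator form, under the assumption that batch_uncurried takes the function to be auto-batched and its type_inference callable as its two arguments (names reused from the sketches above):

def add_two(x):  # hypothetical, not decorated
  return x + 2

# Hypothetical registration without decorator syntax.
add_two_batched = ctx.batch_uncurried(add_two, lambda arg_types: arg_types)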

function_names


lowered_for_args


Helper for calling program_lowered that computes the type signature.
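
The exact parameters are not documented on this page; as a hedged sketch, the intent is that the type signature is computed from actual argument values rather than passed explicitly, roughly along these lines (the argument names and order are assumptions):

# Hypothetical call: lower the program for the 'add_one' entry point,
# inferring the type signature from the example arguments.
lowered = ctx.lowered_for_args('add_one', [np.array([1, 2, 3])], ab.NumpyBackend())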

module


Constructs an instructions.Module for this Context.

Returns
module: An instructions.Module representing the batched computation defined by all the functions decorated with batch in this Context so far.

program


Constructs an instructions.Program for this Context.

This is a helper method, equivalent to self.module().program(main).

Args
main: Python string name of the function that should be the entry point.

Returns
prog: An instructions.Program representing the batched computation defined by all the functions decorated with batch in this Context so far. Suitable for downstream compilation with other passes in auto_batching.

Raises
ValueError: If the intended main function was not decorated with batch.
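
Continuing the sketch above, with the hypothetical add_one function as the entry point:

# The Module collecting everything decorated so far.
mod = ctx.module()

# Equivalent, per the note above, to mod.program('add_one').
prog = ctx.program(main='add_one')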

program_compiled


Constructs a compiled instructions.Program for this Context.

This constructs the program with self.program(main), and then performs type inference and optimization, to emit a result that can be executed by the stackless auto-batching VM.

The point of having this as a method in its own right is that it caches the compilation on the types of the arguments.

If either sig or backend is omitted or None, type inference is skipped. The result is not executable, but it can be enlightening to inspect.

Args
main: Python string name of the function that should be the entry point.
sig: A list of (patterns of) instructions.TensorType aligned with the formal parameters to main.
backend: Backend implementation.

Returns
prog: An instructions.Program representing the batched computation defined by all the functions decorated with batch in this Context so far. Suitable for execution or staging on real data by the auto-batching VM.
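
A sketch for the hypothetical add_one entry point. The TensorType constructor is referenced on this page as instructions.TensorType; exposing it on the ab alias and the (dtype, shape) argument order are assumptions, with the shape describing a single batch member:

# With sig and backend omitted, type inference is skipped; the result
# can be inspected but not executed.
sketch = ctx.program_compiled(main='add_one')

# With a signature and a backend, the result is executable by the
# stackless auto-batching VM.
sig = [ab.TensorType(np.int64, ())]
compiled = ctx.program_compiled(main='add_one', sig=sig, backend=ab.NumpyBackend())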

program_lowered


Constructs a lowered instructions.Program for this Context.

This constructs the program with self.program(main), and then performs type inference, optimization, and lowering, to emit a result that can be executed (or staged) by the auto-batching VM.

The point of having this as a method in its own right is that it caches the compilation on the types of the arguments.

If either sig or backend is omitted or None, type inference is skipped. The result is not executable, but it can be enlightening to inspect.

Args
main: Python string name of the function that should be the entry point.
sig: A list of (patterns of) instructions.TensorType aligned with the formal parameters to main.
backend: Backend implementation.

Returns
prog: An instructions.Program representing the batched computation defined by all the functions decorated with batch in this Context so far. Suitable for execution or staging on real data by the auto-batching VM.
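
The lowered analogue of the compiled sketch above, under the same assumptions:

# Suitable for execution (or staging) on real data by the auto-batching VM.
lowered = ctx.program_lowered(
    main='add_one', sig=[ab.TensorType(np.int64, ())], backend=ab.NumpyBackend())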