tfp.experimental.auto_composite_tensor

Automagically create a CompositeTensor class for cls.

CompositeTensor objects can be passed into and returned from tf.function and tf.while_loop, and can serve as part of the signature of a TF SavedModel.

The basic contract is that each argument to __init__ must have a corresponding public attribute (or property) or private attribute on the instance; for an argument x, that means an attribute named x or _x. Each of these attributes is inspected to determine whether it holds a Tensor or non-Tensor metadata. Lists and tuples of objects are supported, provided that their items are either all Tensors/CompositeTensors or all non-Tensors.
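The attribute-lookup contract above can be illustrated in pure Python. This is a minimal sketch under stated assumptions, not TFP's implementation: split_init_args, FakeTensor and is_fake_tensor are hypothetical names, and FakeTensor merely stands in for tf.Tensor so the example runs without TensorFlow.

```python
import inspect
from typing import Any, Callable, Dict, Tuple


def split_init_args(obj: Any,
                    is_tensor: Callable[[Any], bool],
                    omit_kwargs: Tuple[str, ...] = ()
                    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical sketch: classify each __init__ argument of `obj`.

    For every parameter of type(obj).__init__ (except `self` and any name in
    `omit_kwargs`), look up the attribute `name`, falling back to `_name`,
    and sort its value into tensor-like parts or non-tensor metadata.
    """
    tensors, metadata = {}, {}
    params = inspect.signature(type(obj).__init__).parameters
    for name in params:
        if name == 'self' or name in omit_kwargs:
            continue
        # Public attribute/property first, then the private `_name` form.
        if hasattr(obj, name):
            value = getattr(obj, name)
        elif hasattr(obj, '_' + name):
            value = getattr(obj, '_' + name)
        else:
            raise ValueError(f'No attribute found for __init__ argument {name!r}.')
        if isinstance(value, (list, tuple)):
            # Lists/tuples must be homogeneous: all tensors or all non-tensors.
            kinds = {is_tensor(v) for v in value}
            if len(kinds) > 1:
                raise ValueError(f'Argument {name!r} mixes tensor and non-tensor items.')
            target = tensors if kinds == {True} else metadata
        else:
            target = tensors if is_tensor(value) else metadata
        target[name] = value
    return tensors, metadata


class FakeTensor:
    """Hypothetical stand-in for tf.Tensor, used only for illustration."""

    def __init__(self, value):
        self.value = value


def is_fake_tensor(v: Any) -> bool:
    return isinstance(v, FakeTensor)


class Adder:
    def __init__(self, x, y, name=None):
        self._x = FakeTensor(x)
        self._y = FakeTensor(y)
        self._name = name or 'Adder'


tensors, metadata = split_init_args(Adder(1., 2.), is_fake_tensor)
print(sorted(tensors), metadata)
```

Here x and y are found through the private attributes _x and _y and classified as tensors, while name ends up as metadata unless it is listed in omit_kwargs.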

Example

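import tensorflow as tf
import tensorflow_probability as tfp

# 'name' is excluded from the spec via omit_kwargs.
@tfp.experimental.auto_composite_tensor(omit_kwargs=('name',))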
class Adder(object):
  def __init__(self, x, y, name=None):
    with tf.name_scope(name or 'Adder') as name:
      self._x = tf.convert_to_tensor(x)
      self._y = tf.convert_to_tensor(y)
      self._name = name

  def xpy(self):
    return self._x + self._y

# Each iteration builds a new Adder from the previous one's sum.
def body(obj):
  return Adder(obj.xpy(), 1.),

result, = tf.while_loop(
    cond=lambda _: True,
    body=body,
    loop_vars=(Adder(1., 1.),),
    maximum_iterations=3)

result.xpy()  # => 5.0 (starts at 1 + 1 = 2, then adds 1 on each of 3 iterations)

Args

cls: The class for which to create a CompositeTensor subclass.
omit_kwargs: Optional sequence of kwarg names to be omitted from the spec.

Returns

ctcls: A subclass of cls and TF CompositeTensor.