tfp.experimental.as_composite

Returns a CompositeTensor equivalent to the given object.

Note that any Variable, tfp.util.DeferredTensor, or tfp.util.TransformedVariable references that the object closes over are converted to tensors at the time this function is called. Later updates to those variables are therefore not reflected in the returned object. The type of the returned object is a subclass of both CompositeTensor and type(obj). Because of this snapshotting, be careful when using as_composite(), especially with tf.Module objects.

For example, when the composite tensor is created in a tf.Module's constructor, it "fixes" the values of the DeferredTensor and tf.Variable objects it uses:

import tensorflow as tf
import tensorflow_probability as tfp

class M(tf.Module):
  def __init__(self):
    self._v = tf.Variable(1.)
    # loc is a DeferredTensor that evaluates to self._v + 1.
    self._d = tfp.distributions.Normal(
      tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10.)
    # Built once in the constructor, so the current value of self._v is baked in.
    self._dct = tfp.experimental.as_composite(self._d)

  @tf.function
  def mean(self):
    return self._dct.mean()

m = M()
m.mean()
# ==> <tf.Tensor: numpy=2.0>
m._v.assign(2.)  # Doesn't update the CompositeTensor distribution.
m.mean()
# ==> <tf.Tensor: numpy=2.0>

If, however, creation of the composite is deferred to a method call, the Variable and DeferredTensor are read each time the method runs, so they are properly captured and respected by the Module and by its SavedModel (if it is serialized).

import tensorflow as tf
import tensorflow_probability as tfp

class M(tf.Module):
  def __init__(self):
    self._v = tf.Variable(1.)
    self._d = tfp.distributions.Normal(
      tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10.)

  @tf.function
  def d(self):
    # The composite is built on each call, so it sees the current self._v.
    return tfp.experimental.as_composite(self._d)

m = M()
m.d().mean()
# ==> <tf.Tensor: numpy=2.0>
m._v.assign(2.)
m.d().mean()
# ==> <tf.Tensor: numpy=3.0>

Args:
  obj: A tfp.distributions.Distribution.

Returns:
  obj: A tfp.distributions.Distribution that extends CompositeTensor.