
tfp.experimental.vi.util.build_trainable_linear_operator_block

Builds a trainable blockwise tf.linalg.LinearOperator.

This function returns a trainable blockwise LinearOperator. If operators is a flat list, it is interpreted as the blocks along the diagonal, and an instance of tf.linalg.LinearOperatorBlockDiag is returned. If operators is a nested list (a list of rows), a tf.linalg.LinearOperatorBlockLowerTriangular instance is returned, with the block in row i, column j (i >= j) given by operators[i][j]. The operators list may contain LinearOperator instances, LinearOperator subclasses, or callables that return LinearOperator instances. The block dimensions are given by block_dims; this argument may be omitted if operators contains only LinearOperator instances.
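To make the two layouts concrete, the NumPy sketch below assembles dense matrices with the same block placement: a flat list of square blocks goes on the diagonal, and in a nested list the entry rows[i][j] (j <= i) fills block row i, block column j. It is an illustration only, assuming square diagonal blocks; dense_block_diag and dense_block_lower_triangular are hypothetical helpers, not part of TFP, and the TFP operators are structured LinearOperators rather than dense arrays.

```python
import numpy as np


def dense_block_diag(blocks):
    """Hypothetical helper: place square blocks along the diagonal (flat-list case)."""
    dims = [b.shape[0] for b in blocks]
    n = sum(dims)
    out = np.zeros((n, n))
    start = 0
    for block, d in zip(blocks, dims):
        out[start:start + d, start:start + d] = block
        start += d
    return out


def dense_block_lower_triangular(rows):
    """Hypothetical helper: put rows[i][j] at block row i, block column j (nested case)."""
    # Block sizes come from the square diagonal blocks rows[i][i].
    dims = [row[i].shape[0] for i, row in enumerate(rows)]
    offsets = np.concatenate([[0], np.cumsum(dims)]).astype(int)
    n = offsets[-1]
    out = np.zeros((n, n))
    for i, row in enumerate(rows):
        for j, block in enumerate(row):
            out[offsets[i]:offsets[i + 1], offsets[j]:offsets[j + 1]] = block
    return out


a = np.diag([1.0, 2.0, 3.0])   # 3x3 diagonal block
b = np.tril(np.ones((2, 2)))   # 2x2 lower-triangular block
c = np.full((2, 3), 0.5)       # 2x3 off-diagonal block (row 1, column 0)

diag = dense_block_diag([a, b])
tril = dense_block_lower_triangular([[a], [c, b]])
print(diag.shape, tril.shape)                # (5, 5) (5, 5)
print(np.array_equal(tril, np.tril(tril)))   # True
```

Because only blocks with j <= i are supplied, everything above the block diagonal stays zero, which is why the nested layout yields a lower-triangular operator.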

Examples

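import tensorflow as tf
import tensorflow_probability as tfp

build_trainable_linear_operator_block = (
    tfp.experimental.vi.util.build_trainable_linear_operator_block)

# Build a 5x5 trainable `LinearOperatorBlockDiag` given `LinearOperator`
# subclasses and `block_dims`.
op = build_trainable_linear_operator_block(
  operators=(tf.linalg.LinearOperatorDiag,
             tf.linalg.LinearOperatorLowerTriangular),
  block_dims=[3, 2],
  dtype=tf.float32)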

# Build an 8x8 `LinearOperatorBlockLowerTriangular`, with a callable that
# returns a `LinearOperator` in the upper left block, and `LinearOperator`
# subclasses in the lower two blocks.
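op = build_trainable_linear_operator_block(
  operators=(
    # Row 0: a one-element tuple (note the trailing comma). The callable
    # accepts `shape`, `dtype` and `seed`, as the `operators` argument requires.
    (lambda shape, dtype, seed=None: tf.linalg.LinearOperatorScaledIdentity(
       num_rows=shape[-1], multiplier=tf.Variable(1., dtype=dtype)),),
    # Row 1: the off-diagonal block, then the diagonal block.
    (tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorLowerTriangular)),
  block_dims=[4, 4],
  dtype=tf.float64)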

# Build a 6x6 `LinearOperatorBlockDiag` with batch shape `(4,)`. Since
# `operators` contains only `LinearOperator` instances, the `block_dims`
# argument is not necessary.
op = build_trainable_linear_operator_block(
  operators=(tf.linalg.LinearOperatorDiag(tf.Variable(tf.ones((4, 3)))),
             tf.linalg.LinearOperatorFullMatrix([[4.]]),
             tf.linalg.LinearOperatorIdentity(2)))

Args

operators: A list or tuple containing LinearOperator subclasses, LinearOperator instances, or callables returning LinearOperator instances. If the list is flat, a tf.linalg.LinearOperatorBlockDiag instance is returned. Otherwise, the list must be nested one level deep, with the first inner list of length 1, the second of length 2, and so on; each inner list is interpreted as a row of a lower-triangular block structure, and a tf.linalg.LinearOperatorBlockLowerTriangular instance is returned. Callables contained in the lists must take three arguments: shape, the shape of the tf.Variable that parameterizes the LinearOperator; dtype, the tf.dtype of the LinearOperator; and seed, a seed for generating random values.
block_dims: A list or tuple of integers giving the sizes of the blocks along one dimension of the (square) blockwise LinearOperator. If operators contains only LinearOperator instances, block_dims may be None, in which case the dimensions are inferred.
batch_shape: Batch shape of the LinearOperator.
dtype: tf.dtype of the LinearOperator.
seed: Python integer used to seed the random number generator.
name: str, name for the tf.name_scope.
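The layout rules for operators and block_dims can be checked without TensorFlow. The plain-Python sketch below makes several assumptions: classify_operators and infer_block_dims are hypothetical helpers (not part of TFP, and not how TFP itself validates its arguments), strings stand in for operators, and block shapes are plain tuples of the form batch_shape + (n, n).

```python
def classify_operators(operators):
    """Hypothetical helper: name the block structure an `operators` layout describes.

    Follows the documented rules: a flat list gives a block-diagonal operator;
    a list of rows in which row i holds i + 1 entries gives a block
    lower-triangular operator.
    """
    is_row = [isinstance(entry, (list, tuple)) for entry in operators]
    if not any(is_row):
        return "LinearOperatorBlockDiag"
    if not all(is_row):
        raise ValueError("`operators` mixes rows with bare entries.")
    for i, row in enumerate(operators):
        if len(row) != i + 1:
            raise ValueError(f"Row {i} must hold {i + 1} entries, found {len(row)}.")
    return "LinearOperatorBlockLowerTriangular"


def infer_block_dims(block_shapes):
    """Hypothetical helper: recover `block_dims` from square diagonal-block shapes.

    Each shape is batch_shape + (n, n); only the trailing dimension is kept.
    """
    dims = []
    for shape in block_shapes:
        if len(shape) < 2 or shape[-1] != shape[-2]:
            raise ValueError(f"Diagonal blocks must be square, got shape {shape}.")
        dims.append(shape[-1])
    return dims


# Strings stand in for operator subclasses, instances, or callables.
print(classify_operators(["diag", "tril"]))
# LinearOperatorBlockDiag
print(classify_operators([("identity",), ("full", "tril")]))
# LinearOperatorBlockLowerTriangular
print(infer_block_dims([(4, 3, 3), (1, 1), (2, 2)]))
# [3, 1, 2]
```

Python only creates a one-element tuple when there is a trailing comma, so the first row of a lower-triangular layout must be written (op,); (op) is just op, and classify_operators would reject such a layout as mixing rows with bare entries. The inferred sizes [3, 1, 2] sum to 6, the dimension of the resulting square operator.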

Returns

Trainable instance of tf.linalg.LinearOperatorBlockDiag or tf.linalg.LinearOperatorBlockLowerTriangular.