tfp.substrates.jax.math.fill_triangular

Creates a (batch of) triangular matrix from a vector of inputs.

Created matrix can be lower- or upper-triangular. (It is more efficient to create the matrix as upper or lower, rather than transpose.)

Triangular matrix elements are filled in a clockwise spiral. See example, below.

If x.shape is [b1, b2, ..., bB, d] then the output shape is [b1, b2, ..., bB, n, n] where n is such that d = n(n+1)/2, i.e., n = int(np.sqrt(0.25 + 2. * m) - 0.5).

Example:

fill_triangular([1, 2, 3, 4, 5, 6])
# ==> [[4, 0, 0],
#      [6, 5, 0],
#      [3, 2, 1]]

fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
# ==> [[1, 2, 3],
#      [0, 5, 6],
#      [0, 0, 4]]

The key trick is to create an upper triangular matrix by concatenating x and a tail of itself, then reshaping.

Suppose that we are filling the upper triangle of an n-by-n matrix M from a vector x. The matrix M contains n**2 entries total. The vector x contains n * (n+1) / 2 entries. For concreteness, we'll consider n = 5 (so x has 15 entries and M has 25). We'll concatenate x and x with the first (n = 5) elements removed and reversed:

x = np.arange(15) + 1
xc = np.concatenate([x, x[5:][::-1]])
# ==> array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13,
#            12, 11, 10, 9, 8, 7, 6])

# (We add one to the arange result to disambiguate the zeros below the
# diagonal of our upper-triangular matrix from the first entry in `x`.)

# Now, when reshapedlay this out as a matrix:
y = np.reshape(xc, [5, 5])
# ==> array([[ 1,  2,  3,  4,  5],
#            [ 6,  7,  8,  9, 10],
#            [11, 12, 13, 14, 15],
#            [15, 14, 13, 12, 11],
#            [10,  9,  8,  7,  6]])

# Finally, zero the elements below the diagonal:
y = np.triu(y, k=0)
# ==> array([[ 1,  2,  3,  4,  5],
#            [ 0,  7,  8,  9, 10],
#            [ 0,  0, 13, 14, 15],
#            [ 0,  0,  0, 12, 11],
#            [ 0,  0,  0,  0,  6]])

From this example we see that the resuting matrix is upper-triangular, and contains all the entries of x, as desired. The rest is details:

  • If n is even, x doesn't exactly fill an even number of rows (it fills n / 2 rows and half of an additional row), but the whole scheme still works.
  • If we want a lower triangular matrix instead of an upper triangular, we remove the first n elements from x rather than from the reversed x.

For additional comparisons, a pure numpy version of this function can be found in distribution_util_test.py, function _fill_triangular.

x Tensor representing lower (or upper) triangular elements.
upper Python bool representing whether output matrix should be upper triangular (True) or lower triangular (False, default).
name Python str. The name to give this op.

tril Tensor with lower (or upper) triangular elements filled from x.

ValueError if x cannot be mapped to a triangular matrix.