

Multi-linear interpolation on a rectilinear grid.

Given [a batch of] reference values, this function computes a multi-linear interpolant and evaluates it on [a batch of] new x values. This is a multi-dimensional generalization of bilinear interpolation.

The interpolant is built from reference values indexed by nd dimensions of y_ref, starting at axis.

The x grid is defined by 1-D points along each dimension. These points must be sorted, but may have unequal spacing.
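Because each dimension's grid points are sorted, the cell that contains a query coordinate can be found by binary search. Unequal spacing only changes the interpolation weights. The following is a minimal pure-Python sketch for a single dimension. It assumes strictly increasing grid points and clamps out-of-range queries to mimic 'constant_extension'. The name interp_1d is hypothetical and is not part of TensorFlow Probability.

```python
import bisect


def interp_1d(x, grid, values):
    """Linearly interpolate `values` (given at sorted `grid` points) at `x`.

    Hypothetical helper. Assumes strictly increasing `grid`. Queries outside
    [grid[0], grid[-1]] are clamped, mimicking 'constant_extension'.
    """
    if len(grid) != len(values) or len(grid) < 2:
        raise ValueError("need at least two grid points with matching values")
    # Constant extension: clamp the query to the grid range.
    x = min(max(x, grid[0]), grid[-1])
    # Binary search for the right end point i of the cell [grid[i - 1], grid[i]].
    i = min(max(bisect.bisect_right(grid, x), 1), len(grid) - 1)
    x0, x1 = grid[i - 1], grid[i]
    # Fractional position inside the (possibly unequal) cell.
    t = (x - x0) / (x1 - x0)
    return (1.0 - t) * values[i - 1] + t * values[i]


# Unequally spaced grid for f(x) = 2x + 1, which linear interpolation reproduces.
grid = [0.0, 0.1, 0.5, 2.0]
values = [2.0 * g + 1.0 for g in grid]
print(interp_1d(1.0, grid, values))   # approximately 3.0
print(interp_1d(-5.0, grid, values))  # 1.0, clamped to grid[0]
print(interp_1d(9.0, grid, values))   # 5.0, clamped to grid[-1]
```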

For example, take the case of a 2-D scalar-valued function with no leading batch dimensions. In this case, y_ref.shape = [C1, C2], and y_ref[i, j] is the reference value corresponding to the grid point

[x_grid_points[0][i], x_grid_points[1][j]]
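In this 2-D scalar case the method reduces to ordinary bilinear interpolation: locate the cell along each axis, then blend the four corner values y_ref[i][j], y_ref[i][j + 1], y_ref[i + 1][j] and y_ref[i + 1][j + 1] with products of the 1-D weights. A minimal pure-Python sketch, assuming strictly increasing grids and clamping outside them; bilinear and _cell are hypothetical helpers, not TFP functions.

```python
import bisect


def _cell(x, grid):
    """Hypothetical helper: (lower index, weight of upper point) for x on a sorted grid."""
    x = min(max(x, grid[0]), grid[-1])  # constant extension by clamping
    i = min(max(bisect.bisect_right(grid, x), 1), len(grid) - 1)
    t = (x - grid[i - 1]) / (grid[i] - grid[i - 1])
    return i - 1, t


def bilinear(point, x_grid_points, y_ref):
    """Interpolate y_ref, where y_ref[i][j] sits at (x_grid_points[0][i], x_grid_points[1][j])."""
    i, s = _cell(point[0], x_grid_points[0])
    j, t = _cell(point[1], x_grid_points[1])
    # Weighted sum of the four corners of cell (i, j).
    return ((1 - s) * (1 - t) * y_ref[i][j]
            + (1 - s) * t * y_ref[i][j + 1]
            + s * (1 - t) * y_ref[i + 1][j]
            + s * t * y_ref[i + 1][j + 1])


# f(x0, x1) = x0 * x1 is bilinear, so it is reproduced on any rectilinear grid.
g0 = [0.0, 1.0, 3.0]
g1 = [0.0, 0.5, 2.0]
y_ref = [[a * b for b in g1] for a in g0]  # y_ref.shape = [C1, C2] = [3, 3]
print(bilinear((2.0, 1.0), (g0, g1), y_ref))  # approximately 2.0
```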

In the general case, the dimensions of y_ref to the left of axis are broadcast against the leading (batch) dimensions of x and of x_grid_points[k], k = 0, ..., nd - 1.

Args:
x: Numeric Tensor. The x-coordinates of the interpolated output values for each batch. Shape [..., D, nd], designating [a batch of] D coordinates in nd-dimensional space. D must be >= 1 and is not a batch dimension.
x_grid_points: Tuple of dimension points. x_grid_points[k] is a Tensor of shape [..., Ck] with the same dtype as x, and must be sorted along the innermost (-1) axis. It holds [a batch of] points defining the values of the kth dimension.
y_ref: Tensor of the same dtype as x. The reference output values, of shape [..., C1, ..., Cnd, B1, ..., BM], designating [a batch of] reference values indexed by nd dimensions, for a function whose values have shape [B1, ..., BM] (M >= 0).
axis: Scalar integer Tensor. Dimensions [axis, axis + nd) of y_ref index the interpolation table. For example, 3-D interpolation of a scalar-valued function requires axis=-3, and 3-D interpolation of a matrix-valued function requires axis=-5.
fill_value: Determines the output for x values that are below/above the min/max values in x_grid_points. 'constant_extension' ==> extend as a constant function. Default value: 'constant_extension'.
name: A name to prepend to created ops. Default value: 'batch_interp_rectilinear_nd_grid'.
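The axis convention splits y_ref's shape into three parts: leading batch dimensions, the nd table dimensions C1, ..., Cnd, and the M value dimensions B1, ..., BM. That is why a value of rank M needs axis = -(nd + M). A small sketch on plain shape lists; split_y_ref_shape is a hypothetical helper, not a TFP function.

```python
def split_y_ref_shape(shape, axis, nd):
    """Hypothetical helper: split y_ref's shape into (batch, table C1..Cnd, value B1..BM)."""
    rank = len(shape)
    start = axis + rank if axis < 0 else axis  # normalize a negative axis
    if start < 0 or start + nd > rank:
        raise ValueError("axis + nd must not exceed rank(y_ref)")
    return shape[:start], shape[start:start + nd], shape[start + nd:]


# 3-D table of a scalar-valued function: axis = -(3 + 0) = -3.
print(split_y_ref_shape([7, 10, 11, 12], axis=-3, nd=3))
# ([7], [10, 11, 12], [])

# 3-D table of a 2 x 2 matrix-valued function: axis = -(3 + 2) = -5.
print(split_y_ref_shape([7, 10, 11, 12, 2, 2], axis=-5, nd=3))
# ([7], [10, 11, 12], [2, 2])
```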

Returns:
y_interp: Interpolation between members of y_ref at the points x. A Tensor of the same dtype as x, with shape [..., D, B1, ..., BM].

Raises:
Exceptions will be raised if shapes are statically determined to be wrong.

ValueError: If rank(x) < 2.
ValueError: If axis is not a scalar.
ValueError: If axis + nd > rank(y_ref).
ValueError: If x_grid_points[k].shape[-1] != y_ref.shape[axis + k].
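The static checks listed above can be mirrored in plain Python on shape lists, which also shows how the returned shape [..., D, B1, ..., BM] is assembled. This is a sketch under stated assumptions, not TFP's source: check_shapes is a hypothetical name, and it ignores batch broadcasting by assuming y_ref and the grids carry no batch dimensions beyond those of x.

```python
def check_shapes(x_shape, grid_shapes, y_ref_shape, axis):
    """Hypothetical sketch of the documented ValueError conditions.

    Returns the output shape [..., D, B1, ..., BM]. Assumes the leading
    batch dims of the result are simply those of x (no broadcasting).
    """
    if not isinstance(axis, int):
        raise ValueError("axis must be a scalar")
    if len(x_shape) < 2:
        raise ValueError("rank(x) must be >= 2")
    nd = x_shape[-1]
    rank = len(y_ref_shape)
    start = axis + rank if axis < 0 else axis  # normalize a negative axis
    if start < 0 or start + nd > rank:
        raise ValueError("axis + nd > rank(y_ref)")
    for k, g in enumerate(grid_shapes):
        if g[-1] != y_ref_shape[start + k]:
            raise ValueError(f"x_grid_points[{k}].shape[-1] != y_ref.shape[axis + {k}]")
    value_shape = y_ref_shape[start + nd:]
    return x_shape[:-1] + value_shape  # [..., D] + [B1, ..., BM]


# 10 query points in 2-D against a 100 x 100 scalar table.
print(check_shapes([10, 2], [[100], [100]], [100, 100], axis=-2))  # [10]
# Batch of 4, D = 5 points in 3-D, 2 x 2 matrix-valued table.
print(check_shapes([4, 5, 3], [[6], [7], [8]], [6, 7, 8, 2, 2], axis=-5))  # [4, 5, 2, 2]
```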


Interpolate a function of one variable.

import tensorflow as tf
import tensorflow_probability as tfp

x_grid = 10. * tf.linspace(0., 1., 20)**2   # Nonlinearly spaced on [0, 10]
y_ref = tf.exp(x_grid)

tfp.math.batch_interp_rectilinear_nd_grid(
    # x.shape = [3, 1], with the trailing `1` for `1-D`.
    x=[[6.0], [0.5], [3.3]], x_grid_points=(x_grid,), y_ref=y_ref, axis=0)
==> approx [exp(6.0), exp(0.5), exp(3.3)]

Interpolate a scalar function of two variables.

import numpy as np

x0_grid = tf.linspace(0., 2 * np.pi, num=100)
x1_grid = tf.linspace(0., 2 * np.pi, num=100)

# Build y_ref.
x0s, x1s = tf.meshgrid(x0_grid, x1_grid, indexing='ij')

def func(x0, x1):
  return tf.sin(x0) * tf.cos(x1)

y_ref = func(x0s, x1s)

x = np.pi * tf.random.uniform(shape=(10, 2))

tfp.math.batch_interp_rectilinear_nd_grid(x, x_grid_points=(x0_grid, x1_grid),
                                          y_ref=y_ref, axis=-2)
==> approx tf.sin(x[:, 0]) * tf.cos(x[:, 1])
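The general nd-dimensional interpolant can be sketched as a weighted sum over the 2**nd corners of the cell that contains each query point, where each corner's weight is the product of its 1-D weights. The pure-Python version below is a minimal reference sketch, not TFP's implementation. It assumes strictly increasing grids, clamps out-of-range coordinates (mimicking 'constant_extension'), stores y_ref as nested lists of scalars, and has no batch dimensions. The name multilinear is hypothetical.

```python
import bisect
import itertools
import math


def multilinear(point, x_grid_points, y_ref):
    """Hypothetical sketch: evaluate the multi-linear interpolant of nested-list y_ref at one point."""
    cells = []
    for x, grid in zip(point, x_grid_points):
        x = min(max(x, grid[0]), grid[-1])  # constant extension by clamping
        i = min(max(bisect.bisect_right(grid, x), 1), len(grid) - 1)
        t = (x - grid[i - 1]) / (grid[i] - grid[i - 1])
        cells.append((i - 1, t))
    total = 0.0
    # Visit the 2**nd corners; bit k selects the lower (0) or upper (1) point on axis k.
    for corner in itertools.product((0, 1), repeat=len(cells)):
        weight = 1.0
        value = y_ref
        for (lo, t), bit in zip(cells, corner):
            weight *= t if bit else 1.0 - t
            value = value[lo + bit]  # descend one nesting level per axis
        total += weight * value
    return total


# Same setup as the 2-D example above: 100 points per axis on [0, 2*pi].
n = 100
grid = [2 * math.pi * k / (n - 1) for k in range(n)]
y_ref = [[math.sin(a) * math.cos(b) for b in grid] for a in grid]
print(multilinear((1.0, 2.5), (grid, grid), y_ref), math.sin(1.0) * math.cos(2.5))
```

With 100 points per axis, the two printed numbers should agree to within about 1e-3. Each query reads 2**nd table entries, so the cost grows quickly with the number of dimensions.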