tfl.lattice_lib.evaluate_with_hypercube_interpolation

Evaluates a lattice using hypercube interpolation.

The lattice function is multilinearly interpolated between the 2^d vertices of the hypercube cell that contains the input point, where d = len(lattice_sizes). This method is typically slower than simplex interpolation because each output value is interpolated from 2^d hypercube corners rather than from d + 1 simplex corners. For details, see e.g. "Dissection of the hypercube into simplices", D.G. Mead, Proceedings of the AMS, 76:2, Sep. 1979.
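To make the 2^d cost concrete, here is a minimal pure-Python sketch of multilinear interpolation for one point on a single-unit lattice. It is not the library's implementation, which works on TensorFlow tensors. The helper name `hypercube_interpolate`, the row-major kernel layout (last input dimension varying fastest), and the per-dimension input range [0, lattice_sizes[i] - 1] are assumptions made for illustration. Every lattice size is assumed to be at least 2.

```python
import itertools
import math
from typing import List, Sequence


def hypercube_interpolate(point: Sequence[float],
                          kernel: Sequence[float],
                          lattice_sizes: Sequence[int],
                          clip_inputs: bool = True) -> float:
    """Multilinearly interpolates one point on a single-unit lattice.

    Hypothetical helper. Assumes `kernel` is a flat list of
    prod(lattice_sizes) vertex values in row-major order (the last input
    dimension varies fastest) and that dimension i spans
    [0, lattice_sizes[i] - 1].
    """
    d = len(lattice_sizes)
    lower: List[int] = []
    frac: List[float] = []
    for x, size in zip(point, lattice_sizes):
        if clip_inputs:
            # Clip the coordinate into the assumed input range.
            x = min(max(x, 0.0), size - 1.0)
        # Lower corner of the enclosing cell, clamped so that lower + 1 is a
        # valid vertex index even at the upper boundary.
        lo = min(max(int(math.floor(x)), 0), size - 2)
        lower.append(lo)
        # Without clipping this can leave [0, 1], which extrapolates linearly.
        frac.append(x - lo)

    # Row-major strides: stride[i] = product of the sizes after dimension i.
    strides = [math.prod(lattice_sizes[i + 1:]) for i in range(d)]

    result = 0.0
    # Visit all 2^d corners of the enclosing hypercube cell.
    for corner in itertools.product((0, 1), repeat=d):
        weight = 1.0
        index = 0
        for i, bit in enumerate(corner):
            # The corner weight is a product of per-dimension linear weights.
            weight *= frac[i] if bit else 1.0 - frac[i]
            index += (lower[i] + bit) * strides[i]
        result += weight * kernel[index]
    return result
```

For a 2 x 2 lattice with vertex values [0.0, 1.0, 2.0, 3.0], this layout gives vertex (x0, x1) the value 2 * x0 + x1, so `hypercube_interpolate([0.5, 0.5], [0.0, 1.0, 2.0, 3.0], [2, 2])` averages the four corners to 1.5. The inner loop runs over all 2^d corners, which is the cost the paragraph above contrasts with the d + 1 corners used by simplex interpolation. How the sketch behaves with `clip_inputs=False` (linear extrapolation from the boundary cell) is a choice of this sketch, not a documented property of the library.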

Args:
  inputs: Tensor of points at which to interpolate the lattice. If units = 1, it should have shape (batch_size, ..., len(lattice_sizes)), or be a list of len(lattice_sizes) tensors that each have shape (batch_size, ..., 1). If units > 1, it should have shape (batch_size, ..., units, len(lattice_sizes)), or be a list of len(lattice_sizes) tensors that each have shape (batch_size, ..., units, 1). A typical shape is (batch_size, len(lattice_sizes)).
  kernel: Lattice kernel of shape (num_params_per_lattice, units).
  units: Output dimension of the lattice.
  lattice_sizes: List or tuple of integers giving the lattice size along each input dimension of the layer whose interpolation is being computed.
  clip_inputs: Whether inputs should be clipped to the input range of the lattice.

Returns:
  Tensor of shape (batch_size, ..., units).
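The shape rules above can be summarized with a hypothetical helper, `lattice_shapes`, covering the typical case with no extra `...` axes. It assumes a full lattice, in which `num_params_per_lattice` equals the product of `lattice_sizes` (one parameter per vertex); that formula is an assumption of this sketch rather than something stated in the argument list.

```python
import math
from typing import Dict, Sequence, Tuple

Shape = Tuple[int, ...]


def lattice_shapes(batch_size: int,
                   lattice_sizes: Sequence[int],
                   units: int = 1) -> Dict[str, object]:
    """Hypothetical helper: expected shapes when there are no extra axes."""
    d = len(lattice_sizes)
    # A separate units axis appears in the inputs only when units > 1.
    unit_axis: Shape = () if units == 1 else (units,)
    return {
        # Single stacked tensor: one coordinate per lattice dimension.
        "inputs": (batch_size, *unit_axis, d),
        # Alternative form: d tensors, each with a trailing axis of size 1.
        "inputs_as_list": [(batch_size, *unit_axis, 1)] * d,
        # Full-lattice assumption: one kernel row per lattice vertex.
        "kernel": (math.prod(lattice_sizes), units),
        "output": (batch_size, units),
    }
```

For example, `lattice_shapes(32, [2, 3, 4], units=5)` expects inputs of shape (32, 5, 3), a kernel of shape (24, 5), and an output of shape (32, 5).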