
tfl.pwl_calibration_sonnet_module.PWLCalibration

Piecewise linear calibration layer.

The module takes input of shape (batch_size, units) or (batch_size, 1) and transforms it with a number of piecewise linear functions equal to units, subject to monotonicity, convexity, and bounds constraints if specified. If multi-dimensional input is provided, each output is computed from the corresponding input; otherwise all PWL functions act on the same input. All units share the same configuration, but each has its own set of trained parameters.
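The transformation described above can be sketched in plain Python as linear interpolation between keypoints, one function per unit. This is a minimal illustration, not the TFL implementation: it assumes inputs outside the keypoint range are clamped to the first or last output, it stores each unit's outputs at the keypoints directly, and the helper names pwl_eval and pwl_calibrate are hypothetical.

```python
from bisect import bisect_right
from typing import List, Sequence


def pwl_eval(x: float, keypoints: Sequence[float], outputs: Sequence[float]) -> float:
    """Evaluate one piecewise linear function given its outputs at the keypoints.

    Assumption: inputs outside [keypoints[0], keypoints[-1]] are clamped.
    """
    if x <= keypoints[0]:
        return outputs[0]
    if x >= keypoints[-1]:
        return outputs[-1]
    # Index i of the segment [keypoints[i], keypoints[i + 1]] containing x.
    i = bisect_right(keypoints, x) - 1
    t = (x - keypoints[i]) / (keypoints[i + 1] - keypoints[i])
    return outputs[i] + t * (outputs[i + 1] - outputs[i])


def pwl_calibrate(batch: List[List[float]], keypoints: Sequence[float],
                  unit_outputs: List[Sequence[float]]) -> List[List[float]]:
    """Apply len(unit_outputs) PWL functions to a (batch_size, units) or (batch_size, 1) batch."""
    units = len(unit_outputs)
    result = []
    for row in batch:
        # A single input column is shared by every unit; otherwise unit j reads column j.
        if len(row) == 1:
            row = row * units
        if len(row) != units:
            raise ValueError("each row must have 1 or `units` columns")
        result.append([pwl_eval(x, keypoints, unit_outputs[j]) for j, x in enumerate(row)])
    return result


if __name__ == "__main__":
    keypoints = [1.0, 2.0, 3.0, 4.0]
    # Two units, each with its own (hypothetical) outputs at the keypoints.
    unit_outputs = [[0.0, 0.5, 1.5, 2.0], [2.0, 1.0, 0.5, 0.0]]
    # A (batch_size, 1) input is fed to both units.
    print(pwl_calibrate([[2.5]], keypoints, unit_outputs))  # [[1.0, 0.75]]
```

Because every unit shares the same keypoints but keeps its own outputs, the two units above respond differently to the same input, which mirrors "same configuration, separate parameters".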

Input shape:

A single input should be a rank-2 tensor with shape (batch_size, units) or (batch_size, 1). The input can also be a list of two tensors of the same shape, where the first tensor is the regular input tensor and the second is the is_missing tensor. In the is_missing tensor, 1.0 represents a missing input and 0.0 represents an available input.
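The effect of the is_missing tensor can be sketched as a per-element mask: where it is 1.0, the calibrated value is replaced by a separate "missing" output for that unit. This plain-Python sketch is illustrative only; the helper names, the arithmetic blend, and deriving the mask by comparing inputs against a missing_input_value are assumptions, not the library's code.

```python
from typing import List, Sequence


def missing_mask(inputs: List[List[float]], missing_input_value: float) -> List[List[float]]:
    """Build an is_missing tensor: 1.0 where the input equals missing_input_value."""
    return [[1.0 if x == missing_input_value else 0.0 for x in row] for row in inputs]


def impute(calibrated: List[List[float]], is_missing: List[List[float]],
           missing_output: Sequence[float]) -> List[List[float]]:
    """Use missing_output[j] for unit j wherever the mask is 1.0 (illustrative blend)."""
    result = []
    for cal_row, miss_row in zip(calibrated, is_missing):
        # A (batch_size, 1) mask applies to every unit, like a (batch_size, 1) input.
        if len(miss_row) == 1:
            miss_row = miss_row * len(cal_row)
        result.append([m * missing_output[j] + (1.0 - m) * c
                       for j, (c, m) in enumerate(zip(cal_row, miss_row))])
    return result


if __name__ == "__main__":
    mask = missing_mask([[-1.0], [3.0]], missing_input_value=-1.0)
    # Hypothetical calibrated outputs for two units and per-unit missing outputs.
    print(impute([[0.7, 0.2], [1.5, 0.4]], mask, missing_output=[0.3, 0.9]))
```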

Output shape:

Rank-2 tensor with shape: (batch_size, units).

Example:

import numpy as np
import tensorflow_lattice as tfl

calibrator = tfl.sonnet_modules.PWLCalibration(
    # Key-points of piecewise-linear function.
    input_keypoints=np.linspace(1., 4., num=4),
    # Output can be bounded, e.g. when this layer feeds into a lattice.
    output_min=0.0,
    output_max=2.0,
    # You can specify monotonicity and other shape constraints for the layer.
    monotonicity='increasing',
)
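For the example above, the two string options for kernel_init can be illustrated by the outputs they might place at each keypoint before training. This sketch assumes an increasing function whose initial outputs run from output_min to output_max; the helper init_outputs is hypothetical and does not reflect how TFL actually stores its kernel.

```python
from typing import List, Sequence


def init_outputs(keypoints: Sequence[float], output_min: float,
                 output_max: float, mode: str) -> List[float]:
    """Hypothetical initial outputs at each keypoint for an increasing function."""
    n = len(keypoints)
    if mode == "equal_heights":
        # Every piece rises by the same amount.
        rises = [1.0] * (n - 1)
    elif mode == "equal_slopes":
        # Every piece has the same slope, so its rise is proportional to its width.
        rises = [keypoints[i + 1] - keypoints[i] for i in range(n - 1)]
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    total = sum(rises)
    outputs, cumulative = [output_min], 0.0
    for r in rises:
        cumulative += r
        # Scale the cumulative rise so the last keypoint lands on output_max.
        outputs.append(output_min + (output_max - output_min) * cumulative / total)
    return outputs


if __name__ == "__main__":
    # Keypoints and bounds from the example: np.linspace(1., 4., num=4), [0, 2].
    print(init_outputs([1.0, 2.0, 3.0, 4.0], 0.0, 2.0, "equal_heights"))
    print(init_outputs([0.0, 1.0, 3.0], 0.0, 1.0, "equal_slopes"))
```

With the example's evenly spaced keypoints both modes give roughly [0, 0.667, 1.333, 2.0]; they only diverge for unevenly spaced keypoints, e.g. [0, 1, 3] on [0, 1] gives [0, 0.5, 1] for equal heights versus [0, 0.333, 1] for equal slopes.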

Args
input_keypoints Ordered list of keypoints of the piecewise linear function. Can be anything accepted by tf.convert_to_tensor().
units Output dimension of the layer. See class comments for details.
output_min Minimum output of calibrator.
output_max Maximum output of calibrator.
clamp_min For monotonic calibrators, ensures that output_min is reached.
clamp_max For monotonic calibrators, ensures that output_max is reached.
monotonicity Constrains the piecewise linear function to be monotonic: use 'increasing' or 1 to indicate increasing monotonicity, 'decreasing' or -1 to indicate decreasing monotonicity, and 'none' or 0 to indicate no monotonicity constraint.
convexity Constrains the piecewise linear function to be convex or concave. Convexity is indicated by 'convex' or 1, concavity by 'concave' or -1, and 'none' or 0 indicates no convexity/concavity constraint. Concavity together with increasing monotonicity, or convexity together with decreasing monotonicity, results in a diminishing returns constraint. Consider increasing num_projection_iterations if convexity is specified, especially with a larger number of keypoints.
is_cyclic Whether the output for the last keypoint should be identical to the output for the first keypoint. This is useful for features such as "time of day" or "degree of turn". If inputs are discrete and exactly match keypoints, then is_cyclic has an effect only if TFL regularizers are being used.
kernel_init None or one of:

  • String "equal_heights": initializes the pieces of the PWL function to have equal heights.
  • String "equal_slopes": initializes the pieces of the PWL function to have equal slopes.
  • Any Sonnet initializer object. If you pass such an object, make sure that you know how this module uses the variables.
impute_missing Whether to learn an output for cases where input data is missing. If set to True, either missing_input_value should be set, or the module should be called with a pair of tensors. See the class input shape description for more details.
missing_input_value If set, all inputs equal to this value are considered missing. Cannot be set if impute_missing is False.
missing_output_value If set, missing inputs are mapped to this value instead of to a learned output. Cannot be set if impute_missing is False.
num_projection_iterations Number of iterations of Dykstra's projection algorithm. Constraints are strictly satisfied at the end of each update, but the update is closer to a true L2 projection with a higher number of iterations. See tfl.pwl_calibration_lib.project_all_constraints for more details.
**kwargs Other args passed to snt.Module initializer.
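The num_projection_iterations argument refers to Dykstra's algorithm, which alternates L2 projections onto several convex constraint sets while carrying a correction term per set, so that the iterates converge to the L2 projection onto their intersection. The sketch below applies it to two hypothetical sets over the heights (rises) of the pieces: non-negative heights, standing in for increasing monotonicity, and a total rise of at most output_max - output_min, standing in for bounds. This parameterization and the helper names are assumptions for illustration, not tfl.pwl_calibration_lib.project_all_constraints. Also note that a plain Dykstra iterate is only approximately feasible after a few iterations, whereas the description above states the library satisfies its constraints strictly after every update.

```python
from typing import List


def project_nonnegative(v: List[float]) -> List[float]:
    """L2 projection onto {d : d_i >= 0}, i.e. a non-decreasing function."""
    return [max(x, 0.0) for x in v]


def project_total_rise(v: List[float], max_rise: float) -> List[float]:
    """L2 projection onto the half-space {d : sum(d) <= max_rise}."""
    excess = sum(v) - max_rise
    if excess <= 0.0:
        return list(v)
    # The half-space normal is the all-ones vector, so the excess is spread equally.
    shift = excess / len(v)
    return [x - shift for x in v]


def dykstra(heights: List[float], max_rise: float, iterations: int) -> List[float]:
    """Dykstra's algorithm for the intersection of the two sets above."""
    x = list(heights)
    p = [0.0] * len(x)  # correction carried for the non-negativity projection
    q = [0.0] * len(x)  # correction carried for the half-space projection
    for _ in range(iterations):
        y = project_nonnegative([a + b for a, b in zip(x, p)])
        p = [a + b - c for a, b, c in zip(x, p, y)]
        x = project_total_rise([a + b for a, b in zip(y, q)], max_rise)
        q = [a + b - c for a, b, c in zip(y, q, x)]
    return x


if __name__ == "__main__":
    heights = [1.5, -0.5, 1.0]  # hypothetical unconstrained update
    for n in (1, 5, 50):
        print(n, dykstra(heights, max_rise=2.0, iterations=n))
```

For this input the exact L2 projection onto both sets is [1.25, 0.0, 0.75]. After one iteration the middle height is still about -0.167, outside the non-negativity set; with more iterations the result approaches the exact projection, which is the trade-off num_projection_iterations controls.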

Raises
ValueError If layer hyperparameters are invalid.
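The argument descriptions imply a few checks that would raise this ValueError. The following plain-Python sketch is hypothetical (the real module may validate more, or differently): it canonicalizes the string or integer forms of monotonicity and convexity and rejects missing_input_value or missing_output_value when impute_missing is False.

```python
from typing import Dict, Optional, Tuple, Union

# Accepted spellings from the argument descriptions, mapped to canonical ints.
_MONOTONICITY: Dict[Union[str, int], int] = {
    "increasing": 1, 1: 1, "decreasing": -1, -1: -1, "none": 0, 0: 0}
_CONVEXITY: Dict[Union[str, int], int] = {
    "convex": 1, 1: 1, "concave": -1, -1: -1, "none": 0, 0: 0}


def check_hparams(monotonicity: Union[str, int] = "none",
                  convexity: Union[str, int] = "none",
                  impute_missing: bool = False,
                  missing_input_value: Optional[float] = None,
                  missing_output_value: Optional[float] = None) -> Tuple[int, int]:
    """Hypothetical validation; returns canonical (monotonicity, convexity)."""
    if monotonicity not in _MONOTONICITY:
        raise ValueError(f"invalid monotonicity: {monotonicity!r}")
    if convexity not in _CONVEXITY:
        raise ValueError(f"invalid convexity: {convexity!r}")
    if not impute_missing and (missing_input_value is not None
                               or missing_output_value is not None):
        raise ValueError("missing_input_value and missing_output_value "
                         "require impute_missing=True")
    return _MONOTONICITY[monotonicity], _CONVEXITY[convexity]
```

For instance, check_hparams("increasing", "concave") would return (1, -1), the diminishing returns combination described for convexity.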

Attributes
  • All __init__ arguments.
kernel TF variable which stores the weights of the piecewise linear function.
missing_output TF variable which stores the output learned for missing input, or a TF constant which stores missing_output_value if one is provided. Available only if impute_missing is True.
name Returns the name of this module as passed or determined in the ctor.
name_scope Returns a tf.name_scope instance for this class.
non_trainable_variables Sequence of non-trainable variables owned by this module and its submodules.
submodules Sequence of all sub-modules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

>>> import tensorflow as tf
>>> a = tf.Module()
>>> b = tf.Module()
>>> c = tf.Module()
>>> a.b = b
>>> b.c = c
>>> list(a.submodules) == [b, c]
True
>>> list(b.submodules) == [c]
True
>>> list(c.submodules) == []
True

trainable_variables Sequence of tf.Variable objects owned by this module and its submodules.

See tf.Module.trainable_variables for implementation details.

variables Sequence of tf.Variable objects owned by this module and its submodules.

See tf.Module.variables for implementation details.

Methods

with_name_scope

Decorator to automatically enter the module name scope.

import tensorflow as tf

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)

Using the above module would produce tf.Variables and tf.Tensors whose names include the module name:

>>> mod = MyModule()
>>> mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=...>
>>> mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=...>

Args
method The method to wrap.

Returns
The original method wrapped such that it enters the module's name scope.

__call__