
tf_agents.distributions.utils.Params

The (recursive) parameters of objects exposing the parameters property.

This includes TFP Distribution, Bijector, and TF LinearOperator.

Params objects are created with tf_agents.distributions.utils.get_parameters; Params can be converted back to original objects via tf_agents.distributions.utils.make_from_parameters.

In-place edits of fields are allowed and do not modify the original objects. The exception is reference objects such as tf.Variable: they are shared rather than copied, so modifying one in place also affects the original.
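The round trip and the copy semantics can be made concrete with a minimal pure-Python sketch. It does not use TF-Agents. Params, get_parameters, and make_from_parameters below are toy stand-ins, and Scale, Shifted, and Ref are hypothetical classes (Ref plays the role of tf.Variable). For simplicity, the toy keeps every constructor argument rather than only the non-default ones. It also assumes that reference objects are shared while plain Python values are deep-copied.

```python
import copy
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class Params:
    """Toy stand-in for tf_agents' Params: the object's type plus its arguments."""
    type_: type
    params: Dict[str, Any] = field(default_factory=dict)


class Ref:
    """Hypothetical mutable reference object, playing the role of tf.Variable."""
    def __init__(self, value):
        self.value = value


class Scale:
    """Hypothetical class exposing a `parameters` property, like a TFP Bijector."""
    def __init__(self, factor=1.0):
        self.factor = factor
        self._parameters = dict(factor=factor)

    @property
    def parameters(self):
        return dict(self._parameters)


class Shifted:
    """Hypothetical class taking another parameterized object as an argument."""
    def __init__(self, base, shift=0.0):
        self.base, self.shift = base, shift
        self._parameters = dict(base=base, shift=shift)

    @property
    def parameters(self):
        return dict(self._parameters)


def get_parameters(obj) -> Params:
    """Toy version: recursively converts an object exposing `parameters` into Params."""
    params = {}
    for name, value in obj.parameters.items():
        if hasattr(value, "parameters"):
            params[name] = get_parameters(value)  # nested Param-representable object
        elif isinstance(value, Ref):
            params[name] = value  # reference objects are shared, not copied
        else:
            params[name] = copy.deepcopy(value)  # plain Python nests are copied
    return Params(type(obj), params)


def make_from_parameters(p: Params):
    """Toy version: rebuilds an object, recursing into nested Params values first."""
    kwargs = {name: make_from_parameters(v) if isinstance(v, Params) else v
              for name, v in p.params.items()}
    return p.type_(**kwargs)


original = Shifted(Scale(factor=[2.0, 3.0]), shift=Ref(1.0))
p = get_parameters(original)
p.params["base"].params["factor"][0] = 0.0  # edits the copy held by p only
rebuilt = make_from_parameters(p)
print(original.base.factor, rebuilt.base.factor)  # [2.0, 3.0] [0.0, 3.0]
p.params["shift"].value = 5.0  # the Ref is the same object as original.shift
print(original.shift.value)  # 5.0
```

The first edit changes only the copy held by p. Mutating the shared Ref, by contrast, is visible through original, which mirrors the tf.Variable exception.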

A Params object has two components, type_ and params:

  • type_ is the type of the object.
  • params is a dict of the (non-default) non-tensor arguments passed to the object's __init__. It includes nests of Python objects, as well as other Params values representing "Param-representable" objects passed to __init__.
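The page does not say how "non-default" arguments are identified. One plausible approach is to compare each passed argument with the default declared in the constructor's signature, using Python's inspect module. This is an assumption, not necessarily what TF-Agents does internally. non_default_arguments and ToyNormal are hypothetical names used only for illustration.

```python
import inspect
from typing import Any, Dict


def _equals_default(value: Any, default: Any) -> bool:
    """True only if `value` has the default's exact type and compares equal."""
    if type(value) is not type(default):
        return False  # avoids treating e.g. 0 as equal to a False default
    try:
        return bool(value == default)
    except Exception:  # e.g. array-likes whose `==` has no single truth value
        return False


def non_default_arguments(cls: type, parameters: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: keeps only entries that differ from cls.__init__ defaults."""
    signature = inspect.signature(cls.__init__)
    kept = {}
    for name, value in parameters.items():
        param = signature.parameters.get(name)
        has_default = param is not None and param.default is not inspect.Parameter.empty
        if has_default and _equals_default(value, param.default):
            continue  # matches the constructor default, so it can be omitted
        kept[name] = value
    return kept


class ToyNormal:
    """Hypothetical class with a TFP-style constructor signature."""
    def __init__(self, loc, scale, validate_args=False, allow_nan_stats=True,
                 name="ToyNormal"):
        self.loc, self.scale = loc, scale


passed = dict(loc=[1.0, 1.0], scale=[2.0, 3.0], validate_args=True,
              allow_nan_stats=True, name="ToyNormal")
print(non_default_arguments(ToyNormal, passed))
# {'loc': [1.0, 1.0], 'scale': [2.0, 3.0], 'validate_args': True}
```

Arguments without a default (loc, scale) are always kept. allow_nan_stats and name are dropped because they match their declared defaults.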

A non-trivial example:

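import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.distributions import utils

scale_matrix = tf.Variable([[1.0, 2.0], [-1.0, 0.0]])
d = tfp.distributions.MultivariateNormalDiag(
    loc=[1.0, 1.0], scale_diag=[2.0, 3.0], validate_args=True)
b = tfp.bijectors.ScaleMatvecLinearOperator(
    scale=tf.linalg.LinearOperatorFullMatrix(matrix=scale_matrix),
    adjoint=True)
# Calling a bijector on a distribution yields a TransformedDistribution.
b_d = b(d)
p = utils.get_parameters(b_d)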

Then p is:

Params(
    tfp.distributions.TransformedDistribution,
    params={
        "bijector": Params(
            tfp.bijectors.ScaleMatvecLinearOperator,
            params={"adjoint": True,
                    "scale": Params(
                        tf.linalg.LinearOperatorFullMatrix,
                        params={"matrix": scale_matrix})}),
        "distribution": Params(
            tfp.distributions.MultivariateNormalDiag,
            params={"validate_args": True,
                    "scale_diag": [2.0, 3.0],
                    "loc": [1.0, 1.0]})})

This structure can be manipulated and/or converted back to a Distribution instance via make_from_parameters:

p.params["distribution"].params["loc"] = [0.0, 0.0]

# `new_b_d` is an MVN centered at `(0, 0)`, transformed by the
# `ScaleMatvecLinearOperator` bijector.
new_b_d = utils.make_from_parameters(p)
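Edits like the one to p.params["distribution"].params["loc"] reach into nested params dicts by hand. For deeper structures, a small path-based helper can keep such edits readable. The sketch below assumes only that a Params value exposes type_ and params as described on this page. set_nested_param is a hypothetical helper, Params is a toy dataclass stand-in, and the type_ entries are placeholder strings instead of the real TFP classes, so the example runs without TensorFlow.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Sequence


@dataclass
class Params:
    """Toy stand-in with the same two fields as tf_agents' Params."""
    type_: Any
    params: Dict[str, Any] = field(default_factory=dict)


def set_nested_param(p: Params, path: Sequence[str], value: Any) -> None:
    """Hypothetical helper: follows `path` through nested Params, then sets the last key in place."""
    *parents, leaf = path  # raises ValueError if `path` is empty
    node = p
    for key in parents:
        child = node.params[key]
        if not isinstance(child, Params):
            raise TypeError(f"{key!r} does not hold a nested Params value")
        node = child
    node.params[leaf] = value


# Placeholder strings stand in for the real TFP/TF classes in type_.
p = Params("TransformedDistribution", {
    "bijector": Params("ScaleMatvecLinearOperator", {"adjoint": True}),
    "distribution": Params("MultivariateNormalDiag",
                           {"validate_args": True, "scale_diag": [2.0, 3.0],
                            "loc": [1.0, 1.0]}),
})
set_nested_param(p, ["distribution", "loc"], [0.0, 0.0])
print(p.params["distribution"].params["loc"])  # [0.0, 0.0]
```

Raising a TypeError on a non-Params intermediate keeps a mistyped path from silently writing into a plain Python value.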

Methods

__eq__

Return self==value.
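The entry above only states that __eq__ returns self==value. Assuming equality is structural (same type_ and equal params, with nested Params values compared recursively), a toy Python dataclass shows the intended behavior. The class below is a stand-in for illustration, not the TF-Agents implementation, and it uses plain Python values only.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class Params:
    """Toy stand-in: @dataclass generates __eq__ comparing (type_, params) in order."""
    type_: Any
    params: Dict[str, Any]


# Nested Params inside `params` are compared through dict equality,
# which calls Params.__eq__ recursively.
a = Params(dict, {"inner": Params(list, {"x": 1})})
b = Params(dict, {"inner": Params(list, {"x": 1})})
c = Params(dict, {"inner": Params(list, {"x": 2})})
print(a == b)  # True
print(a == c)  # False
```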