Variable tracking object which applies a bijector upon convert_to_tensor.
Inherits From: DeferredTensor
tfp.substrates.numpy.util.TransformedVariable(
initial_value, bijector, dtype=None, name=None, **kwargs
)
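The core idea can be sketched in plain Python, under the assumption that the object stores an unconstrained ("pretransformed") value and applies the bijector's forward map whenever it is read. The class ExpTransformedValue below is hypothetical, hard-codes the Exp bijector, and is an illustration of the concept rather than the TFP implementation.

```python
import math


class ExpTransformedValue:
    """Hypothetical stand-in for a TransformedVariable with an Exp bijector."""

    def __init__(self, initial_value):
        # Store the value in unconstrained space: log maps (0, inf) onto R.
        self.pretransformed_input = math.log(initial_value)

    def value(self):
        # Apply the forward transformation (exp) on every read, so the
        # result is positive whatever number is stored.
        return math.exp(self.pretransformed_input)


scale = ExpTransformedValue(1.0)
print(scale.pretransformed_input)  # 0.0
print(scale.value())  # 1.0

# An optimizer may move the unconstrained number anywhere...
scale.pretransformed_input -= 5.0
# ...yet the transformed value stays positive.
print(scale.value() > 0)  # True
```

This is why a positivity constraint such as a Normal scale can be optimized without clipping: the optimizer only ever sees the unconstrained number.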
Example
from tensorflow_probability.python.internal.backend.numpy.compat import v2 as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.numpy
tfb = tfp.bijectors
tfd = tfp.distributions
trainable_normal = tfd.Normal(
loc=tf.Variable(0.),
scale=tfp.util.TransformedVariable(1., bijector=tfb.Exp()))
trainable_normal.loc
# ==> <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>
trainable_normal.scale
# ==> <TransformedVariable: dtype=float32, shape=[], fn=exp>
tf.convert_to_tensor(trainable_normal.scale)
# ==> 1.
# Operators work with `TransformedVariable`.
trainable_normal.scale + 1.
# ==> 2.
with tf.GradientTape() as tape:
  negloglik = -trainable_normal.log_prob(0.5)
g = tape.gradient(negloglik, trainable_normal.trainable_variables)
# ==> (-0.5, 0.75)
We could then fit this distribution as follows:
opt = tf.optimizers.Adam(learning_rate=0.05)
loss = tf.function(lambda: -trainable_normal.log_prob(0.5))
for _ in range(int(1e3)):
  opt.minimize(loss, trainable_normal.trainable_variables)
trainable_normal.mean()
# ==> 0.5
trainable_normal.stddev()
# ==> (approximately) 0.0075
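The gradient (-0.5, 0.75) in the example above can be checked by hand. With loc μ = 0 and scale σ = exp(u), where u = 0 is the pretransformed input of the Exp-bijected scale, the negative log-likelihood at x = 0.5 is ½·log(2π) + u + (x − μ)²/(2σ²). Its derivative with respect to μ is −(x − μ)/σ², and with respect to u it is 1 − (x − μ)²/σ² (the chain rule through σ = exp(u) multiplies ∂/∂σ by σ). The sketch below evaluates both analytically and by central finite differences; it uses only the standard library and does not call TFP.

```python
import math


def negloglik(loc, u, x=0.5):
    """Negative log-density of Normal(loc, exp(u)) evaluated at x."""
    scale = math.exp(u)  # the Exp bijector applied to the pretransformed input
    z = (x - loc) / scale
    return 0.5 * math.log(2 * math.pi) + math.log(scale) + 0.5 * z * z


def analytic_grad(loc, u, x=0.5):
    scale = math.exp(u)
    d_loc = -(x - loc) / scale**2
    # d/du = (d/dscale) * (dscale/du) = (1/scale - (x-loc)^2/scale^3) * scale
    d_u = 1.0 - (x - loc) ** 2 / scale**2
    return d_loc, d_u


def finite_diff_grad(loc, u, eps=1e-6):
    # Central differences as an independent numerical check.
    d_loc = (negloglik(loc + eps, u) - negloglik(loc - eps, u)) / (2 * eps)
    d_u = (negloglik(loc, u + eps) - negloglik(loc, u - eps)) / (2 * eps)
    return d_loc, d_u


print(analytic_grad(0.0, 0.0))  # (-0.5, 0.75)
print(finite_diff_grad(0.0, 0.0))
```

Once loc reaches 0.5, the gradient with respect to u is exactly 1, so the optimizer keeps shrinking the scale; this is consistent with the small fitted stddev reported above.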
Values can also be assigned to a TransformedVariable, for example:
d = tfd.Normal(
loc=tf.Variable(0.),
scale=tfp.util.TransformedVariable([1., 2.], bijector=tfb.Softplus()))
d.stddev()
# ==> [1., 2.]
with tf.control_dependencies([d.scale.assign_add([0.5, 1.])]):
  d.stddev()
  # ==> [1.5, 3.]
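The [1.5, 3.] result suggests that assign_add adds value to the transformed (constrained) value and then stores the bijector's inverse of the sum in the underlying variable. The sketch below mimics that reading with a hypothetical SoftplusTransformedList class; softplus(x) = log(1 + eˣ) and its inverse log(eʸ − 1) are written out by hand. It illustrates the assumed semantics and is not the TFP code.

```python
import math


def softplus(x):
    return math.log1p(math.exp(x))


def softplus_inverse(y):
    # Valid for y > 0: log(exp(y) - 1).
    return math.log(math.expm1(y))


class SoftplusTransformedList:
    """Hypothetical TransformedVariable-like holder with a Softplus bijector."""

    def __init__(self, initial_value):
        # Store pretransformed (unconstrained) values.
        self.pretransformed_input = [softplus_inverse(v) for v in initial_value]

    def value(self):
        # Forward map on read.
        return [softplus(u) for u in self.pretransformed_input]

    def assign(self, value):
        # Assignment happens in constrained space; storage gets the inverse.
        self.pretransformed_input = [softplus_inverse(v) for v in value]

    def assign_add(self, delta):
        self.assign([v + d for v, d in zip(self.value(), delta)])

    def assign_sub(self, delta):
        self.assign([v - d for v, d in zip(self.value(), delta)])


scale = SoftplusTransformedList([1.0, 2.0])
scale.assign_add([0.5, 1.0])
print([round(v, 6) for v in scale.value()])  # [1.5, 3.0]
```

Working in constrained space means the caller never has to reason about softplus_inverse; the trade-off is that an assign_sub that would push a value to zero or below has no valid pretransformed counterpart.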
Args | |
---|---|
initial_value | A Tensor, or Python object convertible to a Tensor, which is the initial value for the Variable. Can also be a callable with no arguments that returns the initial value when called. Note: if initial_value is a TransformedVariable, the instantiated object does not create a new tf.Variable, but instead points to the underlying Variable and chains the bijector arg with the underlying bijector as tfb.Chain([bijector, initial_value.bijector]). |
bijector | A Bijector-like instance which defines the transformations applied to the underlying tf.Variable. |
dtype | tf.dtypes.DType instance or otherwise valid dtype value to tf.convert_to_tensor(..., dtype). Default value: None (i.e., bijector.dtype). |
name | Python str representing the underlying tf.Variable's name. Default value: None. |
**kwargs | Keyword arguments forwarded to tf.Variable. |
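The note on initial_value can be illustrated with a small model: wrapping an existing transformed object reuses its storage and composes the bijectors, with the inner bijector applied first on the forward pass (as in tfb.Chain([bijector, initial_value.bijector])). Every class and function below (Variable, Bijector, chain, TransformedValue) is a hypothetical stand-in written for this sketch, not a TFP API.

```python
import math


class Variable:
    """Hypothetical mutable cell standing in for the underlying tf.Variable."""

    def __init__(self, value):
        self.value = value


class Bijector:
    """Hypothetical bijector: a forward/inverse function pair."""

    def __init__(self, forward, inverse):
        self.forward = forward
        self.inverse = inverse


def chain(outer, inner):
    # Like tfb.Chain([outer, inner]): inner is applied first on forward,
    # and the inverses run in the opposite order.
    return Bijector(lambda x: outer.forward(inner.forward(x)),
                    lambda y: inner.inverse(outer.inverse(y)))


class TransformedValue:
    def __init__(self, initial_value, bijector):
        if isinstance(initial_value, TransformedValue):
            # Reuse the existing storage instead of creating a new Variable.
            self.variable = initial_value.variable
            self.bijector = chain(bijector, initial_value.bijector)
        else:
            self.variable = Variable(bijector.inverse(initial_value))
            self.bijector = bijector

    def value(self):
        return self.bijector.forward(self.variable.value)


exp = Bijector(math.exp, math.log)
scale_by_2 = Bijector(lambda x: 2.0 * x, lambda y: y / 2.0)

inner = TransformedValue(1.0, exp)           # stores log(1.0) = 0.0
outer = TransformedValue(inner, scale_by_2)  # no new storage is created
print(outer.variable is inner.variable)      # True
print(outer.value())                         # 2.0, i.e. 2 * exp(0.0)
```

Because the storage is shared, an update made through either object is visible through both.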
Attributes | |
---|---|
also_track | Additional variables tracked by tf.Module in self.trainable_variables. |
bijector | |
dtype | Represents the type of the elements in a Tensor. |
initializer | The initializer operation for the underlying variable. |
name | The string name of this object. |
pretransformed_input | Input to transform_fn. |
shape | Represents the shape of a Tensor. |
trainable_variables | |
transform_fn | Function which characterizes the Tensorization of this object. |
variables | |
Methods
assign
assign(
value, **_
)
assign_add
assign_add(
value, **_
)
assign_sub
assign_sub(
value, **_
)
get_shape
get_shape()
Legacy means of getting the Tensor shape, kept for compatibility with the 2.0.0 LinearOperator.
numpy
numpy()
Returns a copy of the deferred values as a NumPy array or scalar.
set_shape
set_shape(
shape
)
Updates the shape of this pretransformed_input.
This method can be called multiple times, and will merge the given shape
with the current shape of this object. It can be used to provide additional
information about the shape of this object that cannot be inferred from the
graph alone.
Args | |
---|---|
shape | A TensorShape representing the shape of this pretransformed_input, a TensorShapeProto, a list, a tuple, or None. |

Raises | |
---|---|
ValueError | If shape is not compatible with the current shape of this pretransformed_input. |
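The merging behaviour described for set_shape can be sketched with a hypothetical merge_shapes helper, assuming semantics in the spirit of TensorShape.merge_with: None stands for an unknown rank or an unknown dimension, known dimensions must agree, and a conflict raises ValueError. This is an illustration, not the function TFP calls internally.

```python
def merge_shapes(current, new):
    """Hypothetical merge of two shapes, in the spirit of set_shape.

    A shape is None (unknown rank) or a list whose entries are ints or
    None (unknown dimension). Raises ValueError on incompatibility.
    """
    if current is None:
        return None if new is None else list(new)
    if new is None:
        return list(current)
    if len(current) != len(new):
        raise ValueError(f"Shapes {current} and {new} have different ranks.")
    merged = []
    for a, b in zip(current, new):
        if a is None:
            merged.append(b)  # take whatever the new shape knows
        elif b is None or a == b:
            merged.append(a)  # keep the already-known dimension
        else:
            raise ValueError(f"Shapes {current} and {new} are incompatible.")
    return merged


shape = merge_shapes(None, [None, 3])   # rank becomes known: [None, 3]
shape = merge_shapes(shape, [5, None])  # both dims now known: [5, 3]
print(shape)  # [5, 3]
```

Calling it repeatedly only ever adds information, which matches the statement that set_shape can be called multiple times.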
__abs__
__abs__(
*args, **kwargs
)
__add__
__add__(
*args, **kwargs
)
__and__
__and__(
*args, **kwargs
)
__bool__
__bool__()
self != 0
__floordiv__
__floordiv__(
*args, **kwargs
)
__ge__
__ge__(
*args, **kwargs
)
__getitem__
__getitem__(
*args, **kwargs
)
__gt__
__gt__(
*args, **kwargs
)
__invert__
__invert__(
*args, **kwargs
)
__iter__
__iter__(
*args, **kwargs
)
__le__
__le__(
*args, **kwargs
)
__lt__
__lt__(
*args, **kwargs
)
__matmul__
__matmul__(
*args, **kwargs
)
__mod__
__mod__(
*args, **kwargs
)
__mul__
__mul__(
*args, **kwargs
)
__neg__
__neg__(
*args, **kwargs
)
__or__
__or__(
*args, **kwargs
)
__pow__
__pow__(
*args, **kwargs
)
__radd__
__radd__(
*args, **kwargs
)
__rand__
__rand__(
*args, **kwargs
)
__rfloordiv__
__rfloordiv__(
*args, **kwargs
)
__rmatmul__
__rmatmul__(
*args, **kwargs
)
__rmod__
__rmod__(
*args, **kwargs
)
__rmul__
__rmul__(
*args, **kwargs
)
__ror__
__ror__(
*args, **kwargs
)
__rpow__
__rpow__(
*args, **kwargs
)
__rsub__
__rsub__(
*args, **kwargs
)
__rtruediv__
__rtruediv__(
*args, **kwargs
)
__rxor__
__rxor__(
*args, **kwargs
)
__sub__
__sub__(
*args, **kwargs
)
__truediv__
__truediv__(
*args, **kwargs
)
__xor__
__xor__(
*args, **kwargs
)
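Each operator method above takes *args, **kwargs, and __bool__ is documented as self != 0; together with the example's trainable_normal.scale + 1. ==> 2., this suggests each operator is applied to the transformed (tensorized) value rather than to the stored variable. A minimal pure-Python sketch of that delegation pattern follows; Deferred, tensorize, _forward_op and _reverse_op are hypothetical names, and only a subset of the operators is wired up.

```python
import math
import operator


class Deferred:
    """Hypothetical deferred value: stores an input and a transform function."""

    def __init__(self, pretransformed_input, transform_fn):
        self.pretransformed_input = pretransformed_input
        self.transform_fn = transform_fn

    def tensorize(self):
        # Recompute the transformed value on every use.
        return self.transform_fn(self.pretransformed_input)

    def __bool__(self):
        return self.tensorize() != 0


def _forward_op(op):
    # Build a method that tensorizes self, then applies `op` to it.
    def method(self, *args, **kwargs):
        return op(self.tensorize(), *args, **kwargs)
    return method


def _reverse_op(op):
    # For __radd__ and friends the deferred object is the right operand.
    def method(self, other):
        return op(other, self.tensorize())
    return method


for name, op in [("__add__", operator.add), ("__sub__", operator.sub),
                 ("__mul__", operator.mul), ("__truediv__", operator.truediv),
                 ("__pow__", operator.pow), ("__neg__", operator.neg),
                 ("__lt__", operator.lt), ("__gt__", operator.gt)]:
    setattr(Deferred, name, _forward_op(op))

for name, op in [("__radd__", operator.add), ("__rsub__", operator.sub),
                 ("__rmul__", operator.mul), ("__rtruediv__", operator.truediv)]:
    setattr(Deferred, name, _reverse_op(op))


scale = Deferred(0.0, math.exp)  # transformed value: exp(0.0) = 1.0
print(scale + 1.0)   # 2.0
print(3.0 - scale)   # 2.0
print(-scale)        # -1.0
print(bool(scale))   # True
```

Generating the methods from a table keeps the long list of dunder methods uniform, which may explain why they all share the same *args, **kwargs signature.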