Build a trainable LinearOperatorDiag instance.
tfp.experimental.vi.util.build_trainable_linear_operator_diag(
shape, scale_initializer=0.01, diag_bijector=None, dtype=None, name=None
)
| Args | Description |
|---|---|
| shape | Shape of the LinearOperator, equal to [b0, ..., bn, d], where b0...bn are batch dimensions and d is the length of the diagonal. |
| scale_initializer | Variables are initialized with samples from Normal(0, scale_initializer). |
| diag_bijector | Bijector to apply to the diagonal of the operator. |
| dtype | tf.dtype of the LinearOperator. |
| name | str, name for tf.name_scope. |