tfp.substrates.jax.distributions.mvn_conjugate_linear_update
Computes a conjugate normal posterior for a Bayesian linear regression.
tfp.substrates.jax.distributions.mvn_conjugate_linear_update(
    prior_scale,
    linear_transformation,
    likelihood_scale,
    observation,
    prior_mean=None,
    name=None
)
We assume the following model:
    latent ~ MVN(loc=prior_mean, scale=prior_scale)
    observation ~ MVN(loc=linear_transformation.matvec(latent),
                      scale=likelihood_scale)
For Bayesian linear regression, the latent represents the weights, and the provided linear_transformation is the design matrix.
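To make the model concrete, here is a small NumPy sketch that draws synthetic regression data from it. It uses dense NumPy arrays in place of LinearOperators, and the problem sizes, the seed, and names such as design_matrix are illustrative assumptions, not part of the API.

```python
import numpy as np

# Hypothetical problem sizes and seed, chosen only for illustration.
num_features, num_outputs = 3, 50
rng = np.random.default_rng(0)

# Prior on the weights: latent ~ MVN(prior_mean, prior_scale @ prior_scale.T).
prior_mean = np.zeros(num_features)
prior_scale = 2.0 * np.eye(num_features)

# Design matrix (plays the role of linear_transformation):
# one row per observation, one column per feature.
design_matrix = rng.normal(size=(num_outputs, num_features))

# Likelihood noise scale: covariance is likelihood_scale @ likelihood_scale.T.
likelihood_scale = 0.5 * np.eye(num_outputs)

# Draw the regression weights from the prior, then the noisy targets.
latent = prior_mean + prior_scale @ rng.standard_normal(num_features)
observation = design_matrix @ latent + likelihood_scale @ rng.standard_normal(num_outputs)
```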
This method computes the multivariate normal posterior p(latent | observation), using LinearOperators to perform computations efficiently when the matrices involved have special structure.
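As an illustration of why structure helps (not a description of TFP's internal code path), the NumPy sketch below assumes both scale matrices are diagonal, as they would be if passed as, for example, a tf.linalg.LinearOperatorDiag. The precision of a diagonal scale with entries d is diag(1 / d**2), so the posterior precision A' likelihood_prec A + prior_prec can be formed without any dense matrix inverse. The structured result is checked against a dense computation; all sizes and values are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
num_features, num_outputs = 4, 6
A = rng.normal(size=(num_outputs, num_features))  # hypothetical design matrix

# Diagonal scale matrices, stored only as their diagonals.
prior_diag = rng.uniform(0.5, 2.0, size=num_features)
lik_diag = rng.uniform(0.5, 2.0, size=num_outputs)

# Structured route: the precision of a diagonal scale is 1 / d**2 elementwise.
prior_prec_diag = 1.0 / prior_diag**2
lik_prec_diag = 1.0 / lik_diag**2
# A.T @ diag(w) @ A, computed by scaling the rows of A instead of forming diag(w).
posterior_prec_structured = A.T @ (lik_prec_diag[:, None] * A) + np.diag(prior_prec_diag)

# Dense route: build full matrices and invert them explicitly.
prior_scale = np.diag(prior_diag)
lik_scale = np.diag(lik_diag)
prior_prec = np.linalg.inv(prior_scale @ prior_scale.T)
lik_prec = np.linalg.inv(lik_scale @ lik_scale.T)
posterior_prec_dense = A.T @ lik_prec @ A + prior_prec
```

The structured route costs only elementwise work plus one matrix product, while the dense route inverts an num_outputs x num_outputs matrix; the gap grows with the number of observations.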
Args:
  prior_scale: Instance of tf.linalg.LinearOperator of shape
    [..., num_features, num_features], specifying a scale matrix (any matrix
    L such that LL' = Q, where Q is the covariance) for the prior on the
    regression weights. May optionally be a float Tensor.
  linear_transformation: Instance of tf.linalg.LinearOperator of shape
    [..., num_outputs, num_features], specifying a transformation of the
    latent values. May optionally be a float Tensor.
  likelihood_scale: Instance of tf.linalg.LinearOperator of shape
    [..., num_outputs, num_outputs], specifying a scale matrix (any matrix
    L such that LL' = Q, where Q is the covariance) for the likelihood of
    the observed targets. May optionally be a float Tensor.
  observation: Float Tensor of shape [..., num_outputs], specifying the
    observed values or regression targets.
  prior_mean: Optional float Tensor of shape [..., num_features], specifying
    the prior mean. If None, the prior mean is assumed to be zero and some
    computation is avoided.
    Default value: None.
  name: Optional Python str name given to ops created by this function.
    Default value: 'mvn_conjugate_linear_update'.
Returns:
  posterior_mean: Float Tensor of shape [..., num_features], giving the
    mean of the multivariate normal posterior on the latent value.
  posterior_prec: Instance of tf.linalg.LinearOperator of shape
    [..., num_features, num_features], giving the posterior precision
    (inverse covariance) matrix.
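Because the posterior is returned in precision form, a caller who needs the posterior covariance, marginal standard deviations, or samples has to derive them from the precision. The NumPy sketch below shows one way to do this through a Cholesky factor of a hypothetical 2x2 precision matrix and mean; with the real return value, posterior_prec is a LinearOperator rather than a NumPy array.

```python
import numpy as np

# Hypothetical posterior precision (symmetric positive definite) and mean.
posterior_prec = np.array([[5.0, 1.0],
                           [1.0, 3.0]])
posterior_mean = np.array([0.2, -1.0])

# Factor the precision: posterior_prec = C @ C.T with C lower triangular.
C = np.linalg.cholesky(posterior_prec)

# Covariance = inverse precision = C^{-T} C^{-1}.
C_inv = np.linalg.solve(C, np.eye(2))
posterior_cov = C_inv.T @ C_inv
posterior_stddev = np.sqrt(np.diag(posterior_cov))

# Sampling without forming the covariance: x = mean + C^{-T} z with z ~ N(0, I)
# has covariance C^{-T} C^{-1} = inverse(posterior_prec).
rng = np.random.default_rng(0)
z = rng.standard_normal((2, 100_000))
samples = posterior_mean[:, None] + np.linalg.solve(C.T, z)
```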
Mathematical details
Let the prior precision be denoted by
    prior_prec = prior_scale.matmul(prior_scale, adjoint_arg=True).inverse()
and the likelihood precision by
    likelihood_prec = likelihood_scale.matmul(likelihood_scale, adjoint_arg=True).inverse()
Then the posterior p(latent | observation) is multivariate normal with precision
    posterior_prec = (
        linear_transformation.matmul(
            likelihood_prec.matmul(linear_transformation), adjoint=True) +
        prior_prec)
and mean
    posterior_mean = posterior_prec.solvevec(
        linear_transformation.matvec(
            likelihood_prec.matvec(observation), adjoint=True) +
        prior_prec.matvec(prior_mean))
where the prior_prec.matvec(prior_mean) term is zero when prior_mean is None.
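The formulas above can be checked numerically with a dense NumPy sketch. The helper mvn_conjugate_linear_update_dense is hypothetical: it ignores batch dimensions and LinearOperator structure, returns plain arrays, and is not TFP's implementation. The usage example compares it against the equivalent covariance ("gain") form of the same Gaussian posterior, mean = prior_mean + K (observation - A prior_mean) with K = Sigma0 A' (A Sigma0 A' + R)^-1, where Sigma0 and R are the prior and noise covariances.

```python
import numpy as np

def mvn_conjugate_linear_update_dense(prior_scale, linear_transformation,
                                      likelihood_scale, observation,
                                      prior_mean=None):
    """Hypothetical dense sketch of the posterior formulas (no batching)."""
    prior_prec = np.linalg.inv(prior_scale @ prior_scale.T)
    likelihood_prec = np.linalg.inv(likelihood_scale @ likelihood_scale.T)
    A = linear_transformation
    # posterior_prec = A' likelihood_prec A + prior_prec
    posterior_prec = A.T @ likelihood_prec @ A + prior_prec
    # Right-hand side: A' likelihood_prec observation (+ prior_prec prior_mean).
    rhs = A.T @ (likelihood_prec @ observation)
    if prior_mean is not None:
        rhs = rhs + prior_prec @ prior_mean
    posterior_mean = np.linalg.solve(posterior_prec, rhs)
    return posterior_mean, posterior_prec

# Hypothetical inputs: 3 features, 5 observations.
rng = np.random.default_rng(0)
prior_scale = np.linalg.cholesky(np.eye(3) + 0.5 * np.ones((3, 3)))
A = rng.normal(size=(5, 3))
likelihood_scale = 0.3 * np.eye(5)
prior_mean = np.array([1.0, -1.0, 0.5])
observation = rng.normal(size=5)

mean, prec = mvn_conjugate_linear_update_dense(
    prior_scale, A, likelihood_scale, observation, prior_mean)

# Independent check with the covariance (gain) form of the same posterior.
prior_cov = prior_scale @ prior_scale.T
noise_cov = likelihood_scale @ likelihood_scale.T
S = A @ prior_cov @ A.T + noise_cov
gain = prior_cov @ A.T @ np.linalg.inv(S)
mean_gain = prior_mean + gain @ (observation - A @ prior_mean)
cov_gain = prior_cov - gain @ A @ prior_cov
```

The precision form needs one solve in feature space, while the gain form needs one in observation space, which is one reason the precision form is attractive when num_outputs is much larger than num_features.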
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2023-11-21 UTC.