Ver no TensorFlow.org | Executar no Google Colab | Ver fonte no GitHub | Baixar caderno |
Introdução
TensorFlow Probabilidade (TFP) oferece uma série de JointDistribution
abstrações que fazem inferência probabilística mais fácil, permitindo que um usuário facilmente expressar um modelo gráfico probabilística de uma forma quase matemática; a abstração gera métodos para amostragem do modelo e avaliação da probabilidade de log de amostras do modelo. Neste tutorial, revisamos variantes "autobatched", que foram desenvolvidos após os originais JointDistribution
abstrações. Em relação às abstrações originais não autobatched, as versões autobatched são mais simples de usar e mais ergonômicas, permitindo que muitos modelos sejam expressos com menos clichês. Neste colab, exploramos um modelo simples em detalhes (talvez tediosos), deixando claro os problemas que o autobatching resolve e (espero) ensinando o leitor mais sobre os conceitos de forma da TFP ao longo do caminho.
Antes da introdução de autobatching, houve algumas variantes diferentes de JointDistribution
, correspondendo a diferentes estilos sintáticas para expressar modelos probabilísticos: JointDistributionSequential
, JointDistributionNamed
e JointDistributionCoroutine
. Auobatching existe como um mixin, então agora temos AutoBatched
variantes de todos estes. Neste tutorial, vamos explorar as diferenças entre JointDistributionSequential
e JointDistributionSequentialAutoBatched
; no entanto, tudo o que fazemos aqui é aplicável às outras variantes, essencialmente sem alterações.
Dependências e pré-requisitos
Importar e configurar
import functools
import numpy as np
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
Pré-requisito: Um Problema de Regressão Bayesiana
Vamos considerar um cenário de regressão bayesiana muito simples:
\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]
Neste modelo, m
e b
são retirados de indivíduos normais padrão, e as observações Y
são desenhados a partir de uma distribuição normal com média depende das variáveis aleatórias m
e b
, e alguns (não aleatória, conhecido) covariáveis X
. (Para simplificar, neste exemplo, assumimos que a escala de todas as variáveis aleatórias é conhecida.)
Para realizar inferência neste modelo, precisaríamos saber ambas as covariáveis X
e as observações Y
, mas para os fins deste tutorial, nós só precisa de X
, então definimos um simples manequim X
:
X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])
Desiderata
Na inferência probabilística, geralmente queremos realizar duas operações básicas:
-
sample
: a tiragem de amostras a partir do modelo. -
log_prob
: Calculando a probabilidade log de uma amostra a partir do modelo.
A contribuição fundamental de da TFP JointDistribution
abstrações (bem como de muitas outras abordagens para a programação probabilística) é permitir aos usuários escrever um modelo de uma vez e ter acesso a ambas as sample
e log_prob
cálculos.
Notando que temos 7 pontos em nosso conjunto de dados ( X.shape = (7,)
), podemos agora afirmar os desideratos para um excelente JointDistribution
:
-
sample()
deve produzir uma lista dosTensors
que têm forma[(), (), (7,)
], correspondente à inclinação escalar, viés escalar, e observações vector, respectivamente. -
log_prob(sample())
deverá produzir um escalar: a probabilidade de log de uma determinada inclinação, de polarização, e observações. -
sample([5, 3])
deve produzir uma lista dosTensors
que têm forma[(5, 3), (5, 3), (5, 3, 7)]
, que representa um(5, 3)
- em lotes de amostras a partir de o modelo. -
log_prob(sample([5, 3]))
deve produzir umaTensor
com forma (5, 3).
Vamos agora olhar para uma sucessão de JointDistribution
modelos, ver como alcançar os desideratos acima, e espero aprender um pouco mais sobre TFP molda ao longo do caminho.
Alerta de spoiler: A abordagem que satisfaz os desideratos acima sem clichê adicionado é autobatching .
Primeira tentativa; JointDistributionSequential
jds = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
Esta é mais ou menos uma tradução direta do modelo em código. O declive m
e viés b
são simples. Y
é definida usando um lambda
-função: o padrão geral é que um lambda
-função de \(k\) argumentos em um JointDistributionSequential
(JDS) utiliza os anteriores \(k\) distribuições no modelo. Observe a ordem "reversa".
Vamos chamar sample_distributions
, que retorna tanto uma amostra e as subjacentes "sub-distribuições" que foram usados para gerar a amostra. (Poderíamos ter produzido apenas a amostra chamando sample
, mais tarde no tutorial, será conveniente ter as distribuições também.) A amostra que produzimos é muito bem:
dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>, <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 , -7.308844 , -9.832693 ], dtype=float32)>]
Mas log_prob
produz um resultado com uma forma indesejada:
jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy= array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684, -4.4368567, -4.480562 ], dtype=float32)>
E a amostragem múltipla não funciona:
try:
jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
Vamos tentar entender o que está errado.
Uma breve revisão: forma de lote e evento
Em TFP, uma (não um comum JointDistribution
) distribuição de probabilidade tem uma forma evento e uma forma de lote, e compreender a diferença é crucial para o uso eficaz da TFP:
- A forma do evento descreve a forma de um único desenho da distribuição; o desenho pode ser dependente das dimensões. Para distribuições escalares, a forma do evento é []. Para um MultivariateNormal 5-dimensional, a forma do evento é [5].
- A forma de lote descreve desenhos independentes, não distribuídos de forma idêntica, também conhecido como "lote" de distribuições. Representar um lote de distribuições em um único objeto Python é uma das principais maneiras pelas quais o TFP alcança eficiência em escala.
Para nossos propósitos, um fato fundamental para manter em mente é que se nós chamamos log_prob
em uma única amostra de uma distribuição, o resultado terá sempre uma forma que jogos (ou seja, tem como dimensões mais à direita) a forma de lote.
Para uma discussão mais aprofundada das formas, consulte "Compreender TensorFlow Distribuições Shapes" tutorial .
Por que não log_prob(sample())
Produzir um escalar?
Vamos usar o nosso conhecimento da forma de lote e evento para explorar o que está acontecendo com log_prob(sample())
. Aqui está nosso exemplo novamente:
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>, <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 , -7.308844 , -9.832693 ], dtype=float32)>]
E aqui estão nossas distribuições:
dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]
A probabilidade de log é calculada somando as probabilidades de log das sub-distribuições nos elementos (combinados) das partes:
log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>, <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>, <tf.Tensor: shape=(7,), dtype=float32, numpy= array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014, -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>
Assim, um nível de explicação é que o cálculo de probabilidade log está retornando a 7 Tensor porque a terceira subcomponente do log_prob_parts
fica a 7 Tensor. Mas por que?
Bem, vemos que o último elemento de dists
, o que corresponde a nossa distribuição de mais de Y
na formulação mathematial, tem uma batch_shape
de [7]
. Em outras palavras, a distribuição ao longo Y
é um lote de 7 normais independentes (com diferentes meios e, neste caso, a mesma escala).
Compreendemos agora o que há de errado: na JDS, a distribuição ao longo do Y
tem batch_shape=[7]
, uma amostra do JDS representa escalares para m
e b
e um "lote" de 7 normais independentes. e log_prob
calcula 7 log-probabilidades separadas, cada uma das quais representa a probabilidade de registo de tiragem m
e b
e uma única observação Y[i]
em algum X[i]
.
Fixação log_prob(sample())
com Independent
Recorde-se que dists[2]
tem event_shape=[]
e batch_shape=[7]
:
dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>
Através da utilização de PTF Independent
metadistribuição, que converte as dimensões de lote para dimensões de eventos, que pode converter este em uma distribuição com event_shape=[7]
e batch_shape=[]
(que vai mudar o nome y_dist_i
porque é uma distribuição em Y
, com o _i
permanente in para o nosso Independent
embalagem):
y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>
Agora, a log_prob
de um 7-vector é um escalar:
y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>
Debaixo das cobertas, Independent
somas sobre o lote:
y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
E, de fato, podemos usar isso para construir um novo jds_i
(o i
novamente significa Independent
), onde log_prob
retorna um escalar:
jds_i = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m*X + b, scale=1.),
reinterpreted_batch_ndims=1)
])
jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>
Algumas notas:
-
jds_i.log_prob(s)
não é o mesmo quetf.reduce_sum(jds.log_prob(s))
. O primeiro produz o log de probabilidade "correto" da distribuição conjunta. Os últimos somas sobre um 7-Tensor, cada elemento de que é a soma da probabilidade de log dem
,b
, e um único elemento da probabilidade log deY
, de modo que overcountsm
eb
. (log_prob(m) + log_prob(b) + log_prob(Y)
retorna um resultado em vez de gerar uma excepção porque TFP segue TF e regras de radiodifusão de Numpy;. Adicionando um escalar para um vector produz um resultado do tamanho do vetor-) - Neste caso particular, poderíamos ter resolvido o problema e obteve o mesmo resultado usando
MultivariateNormalDiag
vez deIndependent(Normal(...))
.MultivariateNormalDiag
é uma distribuição vectorial (isto é, que já tem vector de evento-forma). IndeeedMultivariateNormalDiag
poderia ser (mas não é) implementado como uma composiçãoIndependent
eNormal
. Vale a pena lembrar que dado um vector deV
, a partir de amostrasn1 = Normal(loc=V)
, en2 = MultivariateNormalDiag(loc=V)
são indistinguíveis; a diferença beween estas distribuições é quen1.log_prob(n1.sample())
é um vector en2.log_prob(n2.sample())
é um escalar.
Amostras múltiplas?
Desenhar várias amostras ainda não funciona:
try:
jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
Vamos pensar no porquê. Quando chamamos jds_i.sample([5, 3])
, que vai primeiro extrair amostras para m
e b
, cada um com forma (5, 3)
. Em seguida, vamos tentar construir uma Normal
distribuição via:
tfd.Normal(loc=m*X + b, scale=1.)
Mas se m
tem forma (5, 3)
e X
tem forma 7
, não podemos multiplicá-los juntos, e na verdade este é o erro que está batendo:
m = tfd.Normal(0., 1.).sample([5, 3])
try:
m * X
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
Para resolver esse problema, vamos pensar sobre o que propriedades a distribuição ao longo Y
tem que ter. Se nós chamamos jds_i.sample([5, 3])
, então sabemos m
e b
vai ambos têm forma (5, 3)
. Que forma deve uma chamada para sample
na Y
produtos distribuição? A resposta óbvia é (5, 3, 7)
: para cada ponto de lote, queremos uma amostra com o mesmo tamanho que X
. Podemos conseguir isso usando os recursos de transmissão do TensorFlow, adicionando dimensões extras:
m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])
Adicionando um eixo de ambos m
e b
, pode-se definir uma nova JDS que suporta múltiplas amostras:
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-1.1133379 , 0.16390413, -0.24177533], [-1.1312429 , -0.6224666 , -1.8182136 ], [-0.31343174, -0.32932565, 0.5164407 ], [-0.0119963 , -0.9079621 , 2.3655841 ], [-0.26293617, 0.8229698 , 0.31098196]], dtype=float32)>, <tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-0.02876974, 1.0872147 , 1.0138507 ], [ 0.27367726, -1.331534 , -0.09084719], [ 1.3349475 , -0.68765205, 1.680652 ], [ 0.75436825, 1.3050154 , -0.9415123 ], [-1.2502679 , -0.25730947, 0.74611956]], dtype=float32)>, <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy= array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00, -4.8197951e+00, -5.2986512e+00, -6.6931367e+00], [ 3.6438566e-01, 1.0067395e+00, 1.4542470e+00, 8.1155670e-01, 1.8868095e+00, 2.3877139e+00, 1.0195159e+00], [-8.3624744e-01, 1.2518480e+00, 1.0943471e+00, 1.3052304e+00, -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]], [[-3.1788960e-01, 9.2615485e-03, -3.0963073e+00, -2.2846246e+00, -3.2269263e+00, -6.0213070e+00, -7.4806519e+00], [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00, -4.5065498e+00, -5.6719379e+00, -4.8012795e+00], [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00, -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]], [[ 2.0621397e+00, 3.4639853e-01, 7.0252883e-01, -1.4311566e+00, 3.3790007e+00, 1.1619035e+00, -8.9105040e-01], [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00, -2.7150445e+00, -2.4633870e+00, -2.1841538e+00], [ 7.7627432e-01, 2.2401071e+00, 3.7601702e+00, 2.4245868e+00, 4.0690269e+00, 4.0605016e+00, 5.1753912e+00]], [[ 1.4275590e+00, 3.3346462e+00, 1.5374103e+00, -2.2849756e-01, 9.1219616e-01, -3.1220305e-01, -3.2643962e-01], [-3.1910419e-02, -3.8848895e-01, 9.9946201e-02, -2.3619974e+00, -1.8507402e+00, -3.6830821e+00, -5.4907336e+00], [-7.1941972e-02, 2.1602919e+00, 4.9575748e+00, 4.2317696e+00, 9.3528280e+00, 1.0526063e+01, 1.5262107e+01]], [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00, -3.2361765e+00, -3.3434000e+00, -2.6849220e+00], [ 1.5006512e-02, -1.9866472e-01, 7.6781356e-01, 1.6228745e+00, 1.4191239e+00, 2.6655579e+00, 4.4663467e+00], [ 2.6599693e+00, 1.2663836e+00, 1.7162113e+00, 1.4839669e+00, 2.0559487e+00, 2.5976877e+00, 2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-12.483114 , -10.139662 , -11.514159 ], [-11.656767 , -17.201958 , -12.132455 ], [-17.838818 , -9.474525 , -11.24898 ], [-13.95219 , -12.490049 , -17.123957 ], [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>
Como uma verificação extra, verificaremos se a probabilidade de log para um único ponto de lote corresponde ao que tínhamos antes:
(jds_ia.log_prob(shaped_sample)[3, 1] -
jds_i.log_prob([shaped_sample[0][3, 1],
shaped_sample[1][3, 1],
shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
AutoBatching For The Win
Excelente! Temos agora uma versão do JointDistribution que lida com toda a nossa desideratos: log_prob
retorna um escalares graças ao uso de tfd.Independent
e várias amostras de trabalhar agora que fixa transmitindo adicionando eixos extras.
E se eu dissesse que existe uma maneira mais fácil e melhor? Há, e é chamado JointDistributionSequentialAutoBatched
(JDSAB):
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-12.191533 , -10.43885 , -16.371655 ], [-13.292994 , -11.97949 , -16.788685 ], [-15.987699 , -13.435732 , -10.6029 ], [-10.184758 , -11.969714 , -14.275676 ], [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=float32)>
Como é que isso funciona? Enquanto você poderia tentar ler o código para uma compreensão profunda, vamos dar uma visão breve o que é suficiente para a maioria dos casos de uso:
- Recorde-se que o nosso primeiro problema é que a nossa distribuição para
Y
tinhabatch_shape=[7]
eevent_shape=[]
, e utilizou-seIndependent
para converter a dimensão do lote para um evento de dimensão. O JDSAB ignora as formas de lote das distribuições de componentes; em vez disso, trata forma de lote como uma propriedade geral do modelo, o qual é assumido como sendo[]
(a menos que especificado de outra forma pela definiçãobatch_ndims > 0
). O efeito é equivalente a usar tfd.Independent para converter todas as dimensões de lote de distribuições de componentes em dimensões de eventos, como fizemos manualmente acima. - A segunda foi um problema a necessidade de massagem as formas de
m
eb
, para que pudessem difundir adequadamente comX
quando a criação de várias amostras. Com JDSAB, você escreve um modelo para gerar uma única amostra, e nós "levantar" todo o modelo para gerar várias amostras usando de TensorFlow vectorized_map . (Este recurso é análogo ao do JAX VMAP .)
Explorando a questão de forma lote com mais detalhes, podemos comparar as formas de lote do nosso "mau" original de distribuição conjuntas jds
, nossas distribuições fixo-lote jds_i
e jds_ia
, e nossa autobatched jds_ab
:
jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])
Vemos que os originais jds
tem subdistributions com diferentes formas de lote. jds_i
e jds_ia
corrigir isto criando subdistributions com a mesma forma de lote (vazio). jds_ab
tem apenas uma única forma de lote (vazio).
É importante notar que JointDistributionSequentialAutoBatched
oferece alguma generalidade adicional de graça. Suponha-se que fazer o co-variáveis X
(e, implicitamente, as observações Y
) bidimensional:
X = np.arange(14).reshape((2, 7))
X
array([[ 0, 1, 2, 3, 4, 5, 6], [ 7, 8, 9, 10, 11, 12, 13]])
Nossa JointDistributionSequentialAutoBatched
funciona sem alterações (precisamos redefinir o modelo porque a forma de X
é armazenada em cache pelo jds_ab.log_prob
):
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[ 0.1813647 , -0.85994506, 0.27593774], [-0.73323774, 1.1153806 , 0.8841938 ], [ 0.5127983 , -0.29271227, 0.63733214], [ 0.2362284 , -0.919168 , 1.6648189 ], [ 0.26317367, 0.73077047, 2.5395133 ]], dtype=float32)>, <tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[ 0.09636458, 2.0138032 , -0.5054413 ], [ 0.63941646, -1.0785882 , -0.6442188 ], [ 1.2310615 , -0.3293852 , 0.77637213], [ 1.2115169 , -0.98906034, -0.07816773], [-1.1318136 , 0.510014 , 1.036522 ]], dtype=float32)>, <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy= array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01, 8.5992378e-01, -5.3123581e-01, 3.1584005e+00, 2.9044402e+00], [-2.5645006e-01, 3.1554163e-01, 3.1186538e+00, 1.4272424e+00, 1.2843871e+00, 1.2266440e+00, 1.2798605e+00]], [[ 1.5973477e+00, -5.3631151e-01, 6.8143606e-03, -1.4910895e+00, -2.1568544e+00, -2.0513713e+00, -3.1663666e+00], [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00, -5.6543546e+00, -7.2378774e+00, -8.1577444e+00, -9.3582869e+00]], [[-2.1233239e+00, 5.8853775e-02, 1.2024102e+00, 1.6622503e+00, -1.9197327e-01, 1.8647723e+00, 6.4322817e-01], [ 3.7549341e-01, 1.5853541e+00, 2.4594500e+00, 2.1952972e+00, 1.7517658e+00, 2.9666045e+00, 2.5468128e+00]]], [[[ 8.9906776e-01, 6.7375046e-01, 7.3354661e-01, -9.9894643e-01, -3.4606690e+00, -3.4810467e+00, -4.4315586e+00], [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00, -6.8091092e+00, -7.7134805e+00, -8.6319380e+00, -8.6904278e+00]], [[-2.2462025e+00, -3.3060855e-01, 1.8974400e-01, 3.1422038e+00, 4.1483402e+00, 3.5642972e+00, 4.8709240e+00], [ 4.7880130e+00, 5.8790064e+00, 9.6695948e+00, 7.8112822e+00, 1.2022618e+01, 1.2411858e+01, 1.4323385e+01]], [[-1.0189297e+00, -7.8115642e-01, 1.6466728e+00, 8.2378983e-01, 3.0765080e+00, 3.0170646e+00, 5.1899948e+00], [ 6.5285158e+00, 7.8038850e+00, 6.4155884e+00, 9.0899811e+00, 1.0040427e+01, 9.1404457e+00, 1.0411951e+01]]], [[[ 4.5557004e-01, 1.4905317e+00, 1.4904103e+00, 2.9777462e+00, 2.8620450e+00, 3.4745665e+00, 3.8295493e+00], [ 3.9977460e+00, 5.7173767e+00, 7.8421035e+00, 6.3180594e+00, 6.0838981e+00, 8.2257290e+00, 9.6548376e+00]], [[-7.0750320e-01, -3.5972297e-01, 4.3136525e-01, -2.3301599e+00, -5.0374687e-01, -2.8338656e+00, -3.4453444e+00], [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00, -4.0196013e+00, -5.8831010e+00, -4.2965469e+00, -4.1388311e+00]], [[ 2.1969774e+00, 2.4614549e+00, 2.2314475e+00, 1.8392437e+00, 2.8367062e+00, 4.8600502e+00, 4.2273531e+00], [ 6.1879644e+00, 5.1792760e+00, 6.1141996e+00, 5.6517797e+00, 8.9979610e+00, 7.5938139e+00, 9.7918644e+00]]], [[[ 1.5249090e+00, 1.1388919e+00, 8.6903995e-01, 3.0762129e+00, 1.5128503e+00, 3.5204377e+00, 2.4760864e+00], [ 3.4166217e+00, 3.5930209e+00, 3.1694956e+00, 4.5797420e+00, 4.5271711e+00, 2.8774328e+00, 4.7288942e+00]], [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00, -3.8594103e+00, -4.9681158e+00, -6.4256043e+00, -5.5345035e+00], [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00, -1.0417805e+01, -1.1727266e+01, -1.1196255e+01, -1.1333830e+01]], [[-7.0419472e-01, 1.4568675e+00, 3.7946482e+00, 4.8489718e+00, 6.6498446e+00, 9.0224218e+00, 1.1153137e+01], [ 1.0060651e+01, 1.1998097e+01, 1.5326431e+01, 1.7957514e+01, 1.8323889e+01, 2.0160881e+01, 2.1269085e+01]]], [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01, 2.3558271e-01, -1.0381399e+00, 1.9387857e+00, -3.3694571e-01], [ 1.6015106e-01, 1.5284677e+00, -4.8567140e-01, -1.7770648e-01, 2.1919653e+00, 1.3015286e+00, 1.3877077e+00]], [[ 1.3688663e+00, 2.6602898e+00, 6.6657305e-01, 4.6554832e+00, 5.7781887e+00, 4.9115267e+00, 4.8446012e+00], [ 5.1983776e+00, 6.2297459e+00, 6.3848300e+00, 8.4291229e+00, 7.1309576e+00, 1.0395646e+01, 8.5736713e+00]], [[ 1.2675294e+00, 5.2844582e+00, 5.1331611e+00, 8.9993315e+00, 1.0794343e+01, 1.4039831e+01, 1.5731170e+01], [ 1.9084715e+01, 2.2191265e+01, 2.3481146e+01, 2.5803375e+01, 2.8632090e+01, 3.0234968e+01, 3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy= array([[-28.90071 , -23.052422, -19.851362], [-19.775568, -25.894997, -20.302256], [-21.10754 , -23.667885, -20.973007], [-19.249458, -20.87892 , -20.573763], [-22.351208, -25.457762, -24.648403]], dtype=float32)>
Por outro lado, a nossa cuidadosamente elaborado JointDistributionSequential
não funciona mais:
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
try:
jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]
Para corrigir isso, teríamos que adicionar um segundo tf.newaxis
a ambos m
e b
combinar a forma e aumentar reinterpreted_batch_ndims
a 2 na chamada para Independent
. Nesse caso, deixar a máquina de lote automático lidar com os problemas de formato é mais curto, mais fácil e mais ergonômico.
Mais uma vez, notamos que, enquanto este notebook explorado JointDistributionSequentialAutoBatched
, as outras variantes de JointDistribution
tem equivalente AutoBatched
. (Para usuários de JointDistributionCoroutine
, JointDistributionCoroutineAutoBatched
tem a vantagem adicional de que você não precisa mais especificar Root
nós, se você nunca usou JointDistributionCoroutine
. Você pode seguramente ignorar esta declaração)
Pensamentos Finais
Neste caderno, introduzimos JointDistributionSequentialAutoBatched
e trabalhou através de um exemplo simples em detalhe. Espero que você tenha aprendido algo sobre formas TFP e autobatching!