Distribuzioni congiunte auto-batch: un tutorial delicato

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza la fonte su GitHub Scarica taccuino

introduzione

Tensorflow Probabilità (TFP) offre una serie di JointDistribution astrazioni che rendono inferenza probabilistica più semplice, consentendo ad un utente di facile esprimere un modello grafico probabilistico in una forma quasi matematica; l'astrazione genera metodi per campionare dal modello e valutare la probabilità logaritmica dei campioni dal modello. In questo tutorial, passiamo in rassegna le varianti "autobatched", che sono stati sviluppati dopo l'originale JointDistribution astrazioni. Rispetto alle astrazioni originali non autobatch, le versioni autobatch sono più semplici da usare e più ergonomiche, consentendo a molti modelli di essere espressi con meno boilerplate. In questa collaborazione, esploriamo un modello semplice nei dettagli (forse noiosi), chiarendo i problemi risolti dall'autobatch e (si spera) insegnando al lettore di più sui concetti di forma TFP lungo il percorso.

Prima dell'introduzione di autobatching, ci sono alcune varianti di JointDistribution , corrispondenti a differenti stili sintattiche per esprimere modelli probabilistici: JointDistributionSequential , JointDistributionNamed e JointDistributionCoroutine . Auobatching esiste come mixin, così ora abbiamo AutoBatched varianti di tutti questi. In questo tutorial, esploriamo le differenze tra JointDistributionSequential e JointDistributionSequentialAutoBatched ; tuttavia, tutto ciò che facciamo qui è applicabile alle altre varianti sostanzialmente senza modifiche.

Dipendenze e prerequisiti

Importa e configura

Prerequisito: un problema di regressione bayesiana

Considereremo uno scenario di regressione bayesiana molto semplice:

\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]

In questo modello, m e b sono tratti da normali standard e le osservazioni Y sono tratti da una distribuzione normale la cui media dipende dalle variabili casuali m e b , e alcuni (non casuale, noto) covariate X . (Per semplicità, in questo esempio, assumiamo che la scala di tutte le variabili casuali sia nota.)

Per eseguire l'inferenza in questo modello, avremmo bisogno di conoscere sia le covariate X e le osservazioni Y , ma per gli scopi di questo tutorial, avremo bisogno solo di X , quindi definiamo un semplice manichino X :

X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])

Desiderata

Nell'inferenza probabilistica, spesso vogliamo eseguire due operazioni di base:

  • sample : Disegno campioni dal modello.
  • log_prob : Il calcolo della probabilità di registro di un campione dal modello.

Il contributo fondamentale di TFP JointDistribution astrazioni (così come di molti altri approcci alla programmazione probabilistico) è quello di consentire agli utenti di scrivere un modello di una volta e avere accesso a entrambi i sample e log_prob calcoli.

Notando che abbiamo 7 punti nel nostro insieme di dati ( X.shape = (7,) ), possiamo ora affermare i desiderata per un ottimo JointDistribution :

  • sample() dovrebbe produrre un elenco dei Tensors aventi forma [(), (), (7,) ], corrispondente alla pendenza scalare, polarizzazione scalari e vettoriali osservazioni, rispettivamente.
  • log_prob(sample()) dovrebbe produrre uno scalare: la probabilità di log di un particolare pendenza, pregiudizi, e le osservazioni.
  • sample([5, 3]) dovrebbe produrre un elenco dei Tensors aventi forma [(5, 3), (5, 3), (5, 3, 7)] , che rappresenta una (5, 3) - lotto di campioni da il modello.
  • log_prob(sample([5, 3])) dovrebbe produrre un Tensor di forma (5, 3).

Noi ora guardiamo un susseguirsi di JointDistribution modelli, vediamo come realizzare i desiderata sopra, e, auspicabilmente, imparare un po 'di più su TFP forme lungo la strada.

Spoiler alert: L'approccio che soddisfa i desiderata sopra senza boilerplate aggiunto è autobatching .

Primo tentativo; JointDistributionSequential

jds = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

Questa è più o meno una traduzione diretta del modello in codice. La pendenza m e polarizzazione b sono semplici. Y viene definito con un lambda -funzione: il modello generale è che un lambda -funzione di \(k\) argomenti in JointDistributionSequential (JDS) utilizza le precedenti \(k\) distribuzioni nel modello. Nota l'ordine "inverso".

Chiameremo sample_distributions , che restituisce sia un campione e il fondo "sub-distribuzioni" che sono stati utilizzati per generare il campione. (Avremmo potuto prodotta solo il campione chiamando sample , in seguito nel tutorial sarà conveniente avere le distribuzioni.) Il campione che produciamo è bene:

dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

Ma log_prob produce un risultato indesiderato di forma:

jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
       -4.4368567, -4.480562 ], dtype=float32)>

E il campionamento multiplo non funziona:

try:
  jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Cerchiamo di capire cosa non va.

Una breve rassegna: forma batch ed evento

In TFP, un normale (non un JointDistribution ) distribuzione di probabilità ha una forma evento e una forma batch, e comprendere la differenza è fondamentale per un uso efficace della TFP:

  • La forma dell'evento descrive la forma di una singola estrazione dalla distribuzione; il sorteggio può dipendere dalle dimensioni. Per le distribuzioni scalari, la forma dell'evento è []. Per un Multivariato Normale a 5 dimensioni, la forma dell'evento è [5].
  • La forma batch descrive i sorteggi indipendenti, non distribuiti in modo identico, ovvero un "batch" di distribuzioni. Rappresentare un batch di distribuzioni in un singolo oggetto Python è uno dei modi principali in cui TFP raggiunge l'efficienza su larga scala.

Per i nostri scopi, un fatto importante da tenere a mente è che se chiamiamo log_prob su uno stesso campione da una distribuzione, il risultato sarà sempre una forma che partite (cioè, ha come dimensioni più a destra) la forma batch.

Per una discussione più approfondita di forme, consultare il tutorial "Comprendere tensorflow Distribuzioni Forme" .

Perché non log_prob(sample()) Produrre uno scalare?

Usiamo la nostra conoscenza della forma batch e manifestazione per esplorare ciò che sta accadendo con log_prob(sample()) . Ecco di nuovo il nostro campione:

sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

Ed ecco le nostre distribuzioni:

dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]

La probabilità logaritmica viene calcolata sommando le probabilità logaritmiche delle sotto-distribuzioni agli elementi (accoppiati) delle parti:

log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
        -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

Così, un livello di spiegazione è che il calcolo delle probabilità registro sta tornando a 7 Tensor perché il terzo sottocomponente di log_prob_parts è un 7-Tensor. Ma perché?

Ebbene, vediamo che l'ultimo elemento di dists , che corrisponde alla nostra distribuzione su Y nella formulazione mathematial, ha una batch_shape di [7] . In altre parole, la nostra distribuzione su Y è un gruppo di 7 normali indipendenti (con mezzi diversi e, in questo caso, nella stessa scala).

Ora capire cosa c'è che non va: nel JDS, la distribuzione su Y ha batch_shape=[7] , un campione dal JDS rappresenta scalari per m e b e un "batch" di 7 normali indipendenti. e log_prob calcola 7 log-probabilità separati, ciascuno dei quali rappresenta la probabilità registro di disegno m e b ed una singola osservazione Y[i] ad un certo X[i] .

Fissaggio log_prob(sample()) con Independent

Ricordiamo che dists[2] ha event_shape=[] e batch_shape=[7] :

dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>

Utilizzando del TFP Independent metadistribuzione, che converte le dimensioni dei lotti alle dimensioni degli eventi, siamo in grado di convertire questo in una distribuzione con event_shape=[7] e batch_shape=[] (ci rinominiamo y_dist_i perché è una distribuzione su Y , con la _i in piedi in per il nostro Independent wrapping):

y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>

Ora, la log_prob di un 7-vettore è uno scalare:

y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>

Sotto le coperte, Independent somme sopra il lotto:

y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

E in effetti, possiamo usare questo per costruire una nuova jds_i (l' i si distingue ancora una volta per Independent ), dove log_prob ritorna uno scalare:

jds_i = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m*X + b, scale=1.),
        reinterpreted_batch_ndims=1)
])

jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>

Un paio di note:

  • jds_i.log_prob(s) non è la stessa tf.reduce_sum(jds.log_prob(s)) . Il primo produce la probabilità logaritmica "corretta" della distribuzione congiunta. Quest'ultimo importo superiore a 7 Tensor, ciascun elemento del quale è la somma della probabilità di log m , b , e un singolo elemento della probabilità registro di Y , quindi overcounts m e b . ( log_prob(m) + log_prob(b) + log_prob(Y) restituisce un risultato piuttosto che un'eccezione perché TFP segue TF e regole di radiodiffusione NumPy,. Aggiungendo uno scalare per un vettore produce un risultato vettore dimensioni)
  • In questo caso particolare, abbiamo potuto risolvere il problema e ottenere lo stesso risultato utilizzando MultivariateNormalDiag anziché Independent(Normal(...)) . MultivariateNormalDiag è una distribuzione vettoriale valori (cioè non ha già event-forma vettoriale). Indeeed MultivariateNormalDiag potrebbe essere (ma non è) implementato come una composizione di Independent e Normal . Vale la pena ricordare che, dato un vettore V , campioni da n1 = Normal(loc=V) , e n2 = MultivariateNormalDiag(loc=V) sono indistinguibili; la differenza beween queste distribuzioni è che n1.log_prob(n1.sample()) è un vettore e n2.log_prob(n2.sample()) è uno scalare.

Campioni multipli?

Il disegno di più campioni continua a non funzionare:

try:
  jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Pensiamo al perché. Quando chiamiamo jds_i.sample([5, 3]) , faremo prima disegnare i campioni per m e b , ciascuno con la forma (5, 3) . Avanti, andiamo a cercare di costruire una Normal distribuzione via:

tfd.Normal(loc=m*X + b, scale=1.)

Ma se m ha forma (5, 3) e X ha forma 7 , non possiamo moltiplicare insieme, e in effetti questo è l'errore che sta colpendo:

m = tfd.Normal(0., 1.).sample([5, 3])
try:
  m * X
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Per risolvere questo problema, pensiamo a quali proprietà, distribuiti su Y deve avere. Se abbiamo chiamato jds_i.sample([5, 3]) , allora sappiamo m e b avranno entrambi di forma (5, 3) . Che forma dovrebbe una chiamata al sample sulla Y prodotti distribuzione? La risposta è ovvia (5, 3, 7) : per ogni punto lotto, vogliamo un campione con la stessa dimensione X . Possiamo raggiungere questo obiettivo utilizzando le capacità di trasmissione di TensorFlow, aggiungendo dimensioni extra:

m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])

Aggiunta di un asse sia m e b , si può definire un nuovo JDS che supporta più campioni:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-1.1133379 ,  0.16390413, -0.24177533],
        [-1.1312429 , -0.6224666 , -1.8182136 ],
        [-0.31343174, -0.32932565,  0.5164407 ],
        [-0.0119963 , -0.9079621 ,  2.3655841 ],
        [-0.26293617,  0.8229698 ,  0.31098196]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-0.02876974,  1.0872147 ,  1.0138507 ],
        [ 0.27367726, -1.331534  , -0.09084719],
        [ 1.3349475 , -0.68765205,  1.680652  ],
        [ 0.75436825,  1.3050154 , -0.9415123 ],
        [-1.2502679 , -0.25730947,  0.74611956]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
 array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
          -4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
         [ 3.6438566e-01,  1.0067395e+00,  1.4542470e+00,  8.1155670e-01,
           1.8868095e+00,  2.3877139e+00,  1.0195159e+00],
         [-8.3624744e-01,  1.2518480e+00,  1.0943471e+00,  1.3052304e+00,
          -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
 
        [[-3.1788960e-01,  9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
          -3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
         [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
          -4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
         [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
          -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
 
        [[ 2.0621397e+00,  3.4639853e-01,  7.0252883e-01, -1.4311566e+00,
           3.3790007e+00,  1.1619035e+00, -8.9105040e-01],
         [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
          -2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
         [ 7.7627432e-01,  2.2401071e+00,  3.7601702e+00,  2.4245868e+00,
           4.0690269e+00,  4.0605016e+00,  5.1753912e+00]],
 
        [[ 1.4275590e+00,  3.3346462e+00,  1.5374103e+00, -2.2849756e-01,
           9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
         [-3.1910419e-02, -3.8848895e-01,  9.9946201e-02, -2.3619974e+00,
          -1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
         [-7.1941972e-02,  2.1602919e+00,  4.9575748e+00,  4.2317696e+00,
           9.3528280e+00,  1.0526063e+01,  1.5262107e+01]],
 
        [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
          -3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
         [ 1.5006512e-02, -1.9866472e-01,  7.6781356e-01,  1.6228745e+00,
           1.4191239e+00,  2.6655579e+00,  4.4663467e+00],
         [ 2.6599693e+00,  1.2663836e+00,  1.7162113e+00,  1.4839669e+00,
           2.0559487e+00,  2.5976877e+00,  2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
       [-11.656767 , -17.201958 , -12.132455 ],
       [-17.838818 ,  -9.474525 , -11.24898  ],
       [-13.95219  , -12.490049 , -17.123957 ],
       [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>

Come ulteriore controllo, verificheremo che la probabilità di log per un singolo punto batch corrisponda a quella che avevamo prima:

(jds_ia.log_prob(shaped_sample)[3, 1] -
 jds_i.log_prob([shaped_sample[0][3, 1],
                 shaped_sample[1][3, 1],
                 shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

Autobatching per la vittoria

Eccellente! Ora abbiamo una versione di JointDistribution che gestisce tutti i nostri desiderata: log_prob restituisce uno scalare grazie all'utilizzo di tfd.Independent , e campioni più lavoro ora che abbiamo fissato a trasmettere con l'aggiunta di assi in più.

E se ti dicessi che esiste un modo migliore e più semplice? C'è, e si chiama JointDistributionSequentialAutoBatched (JDSAB):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885  , -16.371655 ],
       [-13.292994 , -11.97949  , -16.788685 ],
       [-15.987699 , -13.435732 , -10.6029   ],
       [-10.184758 , -11.969714 , -14.275676 ],
       [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

Come funziona? Mentre si potrebbe tentare di leggere il codice per una comprensione profonda, daremo una panoramica breve che è sufficiente per la maggior parte dei casi di utilizzo:

  • Ricordiamo che il nostro primo problema era che la nostra distribuzione di Y aveva batch_shape=[7] e event_shape=[] , e abbiamo usato Independent per convertire la dimensione batch per una dimensione evento. JDSAB ignora le forme batch delle distribuzioni dei componenti; invece si tratta di forma batch come una proprietà complessiva del modello, che si presume essere [] (se non diversamente specificato dalla regolazione batch_ndims > 0 ). L'effetto è equivalente all'utilizzo tfd.Independent per convertire tutte le dimensioni dei lotti di distribuzioni componenti in dimensioni di evento, come abbiamo fatto manualmente sopra.
  • Il nostro secondo problema era la necessità di massaggiare le forme di m e b in modo che potessero trasmettere adeguatamente con X durante la creazione di campioni multipli. Con JDSAB, si scrive un modello per generare un singolo campione, e noi "lift" l'intero modello di generare campioni multipli usando del tensorflow vectorized_map . (Questa funzione è analoga a di JAX vmap .)

Esplorando la questione forma lotto più in dettaglio, possiamo confrontare le forme lotti del nostro "cattivo" originale distribuzione congiunta jds , le nostre distribuzioni batch fisso jds_i e jds_ia , e la nostra autobatched jds_ab :

jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])

Vediamo che le originali jds ha subdistributions con diverse forme di batch. jds_i e jds_ia risolvere questo creando subdistributions con la stessa forma batch (vuoto). jds_ab ha soltanto una singola forma batch (vuoto).

Vale la pena di notare che JointDistributionSequentialAutoBatched offre alcune generalità ulteriore gratuitamente. Supponiamo facciamo la covariate X (e, implicitamente, le osservazioni Y ) bidimensionale:

X = np.arange(14).reshape((2, 7))
X
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

Il nostro JointDistributionSequentialAutoBatched funziona senza modifiche (abbiamo bisogno di ridefinire il modello, perché la forma di X viene memorizzato nella cache da jds_ab.log_prob ):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.1813647 , -0.85994506,  0.27593774],
        [-0.73323774,  1.1153806 ,  0.8841938 ],
        [ 0.5127983 , -0.29271227,  0.63733214],
        [ 0.2362284 , -0.919168  ,  1.6648189 ],
        [ 0.26317367,  0.73077047,  2.5395133 ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.09636458,  2.0138032 , -0.5054413 ],
        [ 0.63941646, -1.0785882 , -0.6442188 ],
        [ 1.2310615 , -0.3293852 ,  0.77637213],
        [ 1.2115169 , -0.98906034, -0.07816773],
        [-1.1318136 ,  0.510014  ,  1.036522  ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
 array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
            8.5992378e-01, -5.3123581e-01,  3.1584005e+00,
            2.9044402e+00],
          [-2.5645006e-01,  3.1554163e-01,  3.1186538e+00,
            1.4272424e+00,  1.2843871e+00,  1.2266440e+00,
            1.2798605e+00]],
 
         [[ 1.5973477e+00, -5.3631151e-01,  6.8143606e-03,
           -1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
           -3.1663666e+00],
          [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
           -5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
           -9.3582869e+00]],
 
         [[-2.1233239e+00,  5.8853775e-02,  1.2024102e+00,
            1.6622503e+00, -1.9197327e-01,  1.8647723e+00,
            6.4322817e-01],
          [ 3.7549341e-01,  1.5853541e+00,  2.4594500e+00,
            2.1952972e+00,  1.7517658e+00,  2.9666045e+00,
            2.5468128e+00]]],
 
 
        [[[ 8.9906776e-01,  6.7375046e-01,  7.3354661e-01,
           -9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
           -4.4315586e+00],
          [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
           -6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
           -8.6904278e+00]],
 
         [[-2.2462025e+00, -3.3060855e-01,  1.8974400e-01,
            3.1422038e+00,  4.1483402e+00,  3.5642972e+00,
            4.8709240e+00],
          [ 4.7880130e+00,  5.8790064e+00,  9.6695948e+00,
            7.8112822e+00,  1.2022618e+01,  1.2411858e+01,
            1.4323385e+01]],
 
         [[-1.0189297e+00, -7.8115642e-01,  1.6466728e+00,
            8.2378983e-01,  3.0765080e+00,  3.0170646e+00,
            5.1899948e+00],
          [ 6.5285158e+00,  7.8038850e+00,  6.4155884e+00,
            9.0899811e+00,  1.0040427e+01,  9.1404457e+00,
            1.0411951e+01]]],
 
 
        [[[ 4.5557004e-01,  1.4905317e+00,  1.4904103e+00,
            2.9777462e+00,  2.8620450e+00,  3.4745665e+00,
            3.8295493e+00],
          [ 3.9977460e+00,  5.7173767e+00,  7.8421035e+00,
            6.3180594e+00,  6.0838981e+00,  8.2257290e+00,
            9.6548376e+00]],
 
         [[-7.0750320e-01, -3.5972297e-01,  4.3136525e-01,
           -2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
           -3.4453444e+00],
          [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
           -4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
           -4.1388311e+00]],
 
         [[ 2.1969774e+00,  2.4614549e+00,  2.2314475e+00,
            1.8392437e+00,  2.8367062e+00,  4.8600502e+00,
            4.2273531e+00],
          [ 6.1879644e+00,  5.1792760e+00,  6.1141996e+00,
            5.6517797e+00,  8.9979610e+00,  7.5938139e+00,
            9.7918644e+00]]],
 
 
        [[[ 1.5249090e+00,  1.1388919e+00,  8.6903995e-01,
            3.0762129e+00,  1.5128503e+00,  3.5204377e+00,
            2.4760864e+00],
          [ 3.4166217e+00,  3.5930209e+00,  3.1694956e+00,
            4.5797420e+00,  4.5271711e+00,  2.8774328e+00,
            4.7288942e+00]],
 
         [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
           -3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
           -5.5345035e+00],
          [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
           -1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
           -1.1333830e+01]],
 
         [[-7.0419472e-01,  1.4568675e+00,  3.7946482e+00,
            4.8489718e+00,  6.6498446e+00,  9.0224218e+00,
            1.1153137e+01],
          [ 1.0060651e+01,  1.1998097e+01,  1.5326431e+01,
            1.7957514e+01,  1.8323889e+01,  2.0160881e+01,
            2.1269085e+01]]],
 
 
        [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
            2.3558271e-01, -1.0381399e+00,  1.9387857e+00,
           -3.3694571e-01],
          [ 1.6015106e-01,  1.5284677e+00, -4.8567140e-01,
           -1.7770648e-01,  2.1919653e+00,  1.3015286e+00,
            1.3877077e+00]],
 
         [[ 1.3688663e+00,  2.6602898e+00,  6.6657305e-01,
            4.6554832e+00,  5.7781887e+00,  4.9115267e+00,
            4.8446012e+00],
          [ 5.1983776e+00,  6.2297459e+00,  6.3848300e+00,
            8.4291229e+00,  7.1309576e+00,  1.0395646e+01,
            8.5736713e+00]],
 
         [[ 1.2675294e+00,  5.2844582e+00,  5.1331611e+00,
            8.9993315e+00,  1.0794343e+01,  1.4039831e+01,
            1.5731170e+01],
          [ 1.9084715e+01,  2.2191265e+01,  2.3481146e+01,
            2.5803375e+01,  2.8632090e+01,  3.0234968e+01,
            3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
       [-19.775568, -25.894997, -20.302256],
       [-21.10754 , -23.667885, -20.973007],
       [-19.249458, -20.87892 , -20.573763],
       [-22.351208, -25.457762, -24.648403]], dtype=float32)>

D'altra parte, la nostra cura artigianale JointDistributionSequential non funziona più:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

try:
  jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]

Per risolvere questo problema, avremmo dovuto aggiungere un secondo tf.newaxis sia m e b abbinare la forma, e aumentare reinterpreted_batch_ndims a 2 nella chiamata a Independent . In questo caso, lasciare che il macchinario di dosaggio automatico gestisca i problemi di forma è più breve, più facile e più ergonomico.

Ancora una volta, notiamo che, mentre questo notebook esplorato JointDistributionSequentialAutoBatched , le altre varianti di JointDistribution hanno equivalente AutoBatched . (Per gli utenti di JointDistributionCoroutine , JointDistributionCoroutineAutoBatched ha l'ulteriore vantaggio che non hanno più bisogno di specificare Root nodi, se non hai mai usato JointDistributionCoroutine è possibile ignorare questa informativa.)

Pensieri conclusivi

In questo notebook, abbiamo introdotto JointDistributionSequentialAutoBatched e lavorato attraverso un semplice esempio in dettaglio. Spero che tu abbia imparato qualcosa sulle forme TFP e sull'autobatch!