Probabilità TensorFlow su JAX

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza l'origine su GitHub Scarica quaderno

Tensorflow Probabilità (TFP) è una libreria per ragionamento probabilistico e l'analisi statistica che ora funziona anche su JAX ! Per chi non lo conoscesse, JAX è una libreria per il calcolo numerico accelerato basata su trasformazioni di funzioni componibili.

TFP su JAX supporta molte delle funzionalità più utili della normale TFP preservando le astrazioni e le API con cui molti utenti di TFP si trovano ora a proprio agio.

Impostare

TFP su JAX non dipende tensorflow; disinstalliamo completamente TensorFlow da questo Colab.

pip uninstall tensorflow -y -q

Possiamo installare TFP su JAX con le ultime build notturne di TFP.

pip install -Uq tfp-nightly[jax] > /dev/null

Importiamo alcune utili librerie Python.

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn import datasets
sns.set(style='white')
/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm

Importiamo anche alcune funzionalità JAX di base.

import jax.numpy as jnp
from jax import grad
from jax import jit
from jax import random
from jax import value_and_grad
from jax import vmap

Importare TFP su JAX

Per utilizzare TFP su JAX, è sufficiente importare il jax "substrato" e usarlo come la normale procedura tfp :

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

Demo: regressione logistica bayesiana

Per dimostrare cosa possiamo fare con il backend JAX, implementeremo la regressione logistica bayesiana applicata al classico set di dati Iris.

Per prima cosa, importiamo il set di dati Iris ed estraiamo alcuni metadati.

iris = datasets.load_iris()
features, labels = iris['data'], iris['target']

num_features = features.shape[-1]
num_classes = len(iris.target_names)

Possiamo definire il modello utilizzando tfd.JointDistributionCoroutine . Metteremo priori normali standard su entrambi i pesi e il termine bias di poi scrivere la target_log_prob funzione che i perni delle etichette campionati ai dati.

Root = tfd.JointDistributionCoroutine.Root
def model():
  w = yield Root(tfd.Sample(tfd.Normal(0., 1.),
                            sample_shape=(num_features, num_classes)))
  b = yield Root(
      tfd.Sample(tfd.Normal(0., 1.), sample_shape=(num_classes,)))
  logits = jnp.dot(features, w) + b
  yield tfd.Independent(tfd.Categorical(logits=logits),
                        reinterpreted_batch_ndims=1)


dist = tfd.JointDistributionCoroutine(model)
def target_log_prob(*params):
  return dist.log_prob(params + (labels,))

Noi campione dal dist per produrre uno stato iniziale per MCMC. Possiamo quindi definire una funzione che prende in una chiave casuale e uno stato iniziale e produce 500 campioni da un campionatore No-U-Turn (NUTS). Si noti che possiamo usare trasformazioni JAX come jit per compilare il nostro campionatore NUTS utilizzando XLA.

init_key, sample_key = random.split(random.PRNGKey(0))
init_params = tuple(dist.sample(seed=init_key)[:-1])

@jit
def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-3)
  return tfp.mcmc.sample_chain(500,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      num_burnin_steps=500,
      seed=key)

states, log_probs = run_chain(sample_key, init_params)
plt.figure()
plt.plot(log_probs)
plt.ylabel('Target Log Prob')
plt.xlabel('Iterations of NUTS')
plt.show()

png

Usiamo i nostri campioni per eseguire la media del modello bayesiano (BMA) calcolando la media delle probabilità previste di ciascun insieme di pesi.

Per prima cosa scriviamo una funzione che per un dato insieme di parametri produrrà le probabilità su ogni classe. Possiamo usare dist.sample_distributions per ottenere la distribuzione finale nel modello.

def classifier_probs(params):
  dists, _ = dist.sample_distributions(seed=random.PRNGKey(0),
                                       value=params + (None,))
  return dists[-1].distribution.probs_parameter()

Siamo in grado di vmap(classifier_probs) il gruppo di campioni per ottenere le probabilità di classe previsti per ciascuno dei nostri campioni. Quindi calcoliamo l'accuratezza media su ciascun campione e l'accuratezza dalla media del modello bayesiano.

all_probs = jit(vmap(classifier_probs))(states)
print('Average accuracy:', jnp.mean(all_probs.argmax(axis=-1) == labels))
print('BMA accuracy:', jnp.mean(all_probs.mean(axis=0).argmax(axis=-1) == labels))
Average accuracy: 0.96952
BMA accuracy: 0.97999996

Sembra che BMA riduca il nostro tasso di errore di quasi un terzo!

Fondamenti

TFP su JAX ha un'API identica a TF dove invece di accettare oggetti TF come tf.Tensor s accetta l'analogico JAX. Ad esempio, laddove un tf.Tensor stato precedentemente utilizzato come input, l'API ora aspetta un JAX DeviceArray . Invece di restituire un tf.Tensor , metodi TFP torneranno DeviceArray s. TFP su JAX funziona anche con strutture annidate di oggetti JAX, come una lista o un dizionario di DeviceArray s.

distribuzioni

La maggior parte delle distribuzioni di TFP sono supportate in JAX con una semantica molto simile alle loro controparti TF. Essi sono anche registrati come JAX Pytrees , in modo che possano essere ingressi e le uscite delle funzioni JAX-trasformate.

distribuzioni di base

Il log_prob metodo per distribuzioni funziona allo stesso modo.

dist = tfd.Normal(0., 1.)
print(dist.log_prob(0.))
-0.9189385

Campionamento da una distribuzione richiede esplicitamente che passa in un PRNGKey (o un elenco di numeri interi) come il seed argomento parola chiave. Non riuscire a passare esplicitamente un seme genererà un errore.

tfd.Normal(0., 1.).sample(seed=random.PRNGKey(0))
DeviceArray(-0.20584226, dtype=float32)

La semantica forma per distribuzioni rimangono invariati in JAX, dove le distribuzioni hanno ciascuno un event_shape e batch_shape e disegno molti campioni aggiungerà ulteriori sample_shape dimensioni.

Ad esempio, un tfd.MultivariateNormalDiag con parametri vettore avrà una forma evento vettoriale e forma batch vuoto.

dist = tfd.MultivariateNormalDiag(
    loc=jnp.zeros(5),
    scale_diag=jnp.ones(5)
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: (5,)
Batch shape: ()

D'altra parte, un tfd.Normal parametrizzato con vettori avrà una forma forma evento e vettoriale lotto scalare.

dist = tfd.Normal(
    loc=jnp.ones(5),
    scale=jnp.ones(5),
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: ()
Batch shape: (5,)

La semantica di prendere log_prob di campioni funziona lo stesso in JAX troppo.

dist =  tfd.Normal(jnp.zeros(5), jnp.ones(5))
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)

dist =  tfd.Independent(tfd.Normal(jnp.zeros(5), jnp.ones(5)), 1)
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)
(10, 2, 5)
(10, 2)

Perché JAX DeviceArray s sono compatibili con le librerie come NumPy e Matplotlib, siamo in grado di nutrire i campioni direttamente in una funzione di stampa.

sns.distplot(tfd.Normal(0., 1.).sample(1000, seed=random.PRNGKey(0)))
plt.show()

png

Distribution metodi sono compatibili con le trasformazioni JAX.

sns.distplot(jit(vmap(lambda key: tfd.Normal(0., 1.).sample(seed=key)))(
    random.split(random.PRNGKey(0), 2000)))
plt.show()

png

x = jnp.linspace(-5., 5., 100)
plt.plot(x, jit(vmap(grad(tfd.Normal(0., 1.).prob)))(x))
plt.show()

png

Poiché distribuzioni TFP sono registrati come JAX nodi pytree, possiamo scrivere funzioni con distribuzioni come ingressi o uscite e trasformarli utilizzando jit , ma non sono ancora supportato come argomenti vmap funzioni -ed.

@jit
def random_distribution(key):
  loc_key, scale_key = random.split(key)
  loc, log_scale = random.normal(loc_key), random.normal(scale_key)
  return tfd.Normal(loc, jnp.exp(log_scale))
random_dist = random_distribution(random.PRNGKey(0))
print(random_dist.mean(), random_dist.variance())
0.14389051 0.081832744

Distribuzioni trasformate

Distribuzioni trasformati cioè distribuzioni cui campioni sono passati attraverso un Bijector opera anche fuori dalla scatola (bijectors lavorano troppo! Vedi sotto).

dist = tfd.TransformedDistribution(
    tfd.Normal(0., 1.),
    tfb.Sigmoid()
)
sns.distplot(dist.sample(1000, seed=random.PRNGKey(0)))
plt.show()

png

Distribuzioni congiunte

TFP offre JointDistribution s per consentire combinando distribuzioni componenti in un'unica distribuzione su più variabili casuali. Attualmente, TFP offerte tre varianti principali ( JointDistributionSequential , JointDistributionNamed e JointDistributionCoroutine ) i quali sono supportati in JAX. I AutoBatched varianti sono tutti supportati.

dist = tfd.JointDistributionSequential([
  tfd.Normal(0., 1.),
  lambda x: tfd.Normal(x, 1e-1)
])
plt.scatter(*dist.sample(1000, seed=random.PRNGKey(0)), alpha=0.5)
plt.show()

png

joint = tfd.JointDistributionNamed(dict(
    e=             tfd.Exponential(rate=1.),
    n=             tfd.Normal(loc=0., scale=2.),
    m=lambda n, e: tfd.Normal(loc=n, scale=e),
    x=lambda    m: tfd.Sample(tfd.Bernoulli(logits=m), 12),
))
joint.sample(seed=random.PRNGKey(0))
{'e': DeviceArray(3.376818, dtype=float32),
 'm': DeviceArray(2.5449684, dtype=float32),
 'n': DeviceArray(-0.6027825, dtype=float32),
 'x': DeviceArray([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)}
Root = tfd.JointDistributionCoroutine.Root
def model():
  e = yield Root(tfd.Exponential(rate=1.))
  n = yield Root(tfd.Normal(loc=0, scale=2.))
  m = yield tfd.Normal(loc=n, scale=e)
  x = yield tfd.Sample(tfd.Bernoulli(logits=m), 12)

joint = tfd.JointDistributionCoroutine(model)

joint.sample(seed=random.PRNGKey(0))
StructTuple(var0=DeviceArray(0.17315261, dtype=float32), var1=DeviceArray(-3.290489, dtype=float32), var2=DeviceArray(-3.1949058, dtype=float32), var3=DeviceArray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32))

Altre distribuzioni

I processi gaussiani funzionano anche in modalità JAX!

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
observation_noise_variance = 0.01
f = lambda x: jnp.sin(10*x[..., 0]) * jnp.exp(-x[..., 0]**2)
observation_index_points = random.uniform(
    k1, [50], minval=-1.,maxval= 1.)[..., jnp.newaxis]
observations = f(observation_index_points) + tfd.Normal(
    loc=0., scale=jnp.sqrt(observation_noise_variance)).sample(seed=k2)

index_points = jnp.linspace(-1., 1., 100)[..., jnp.newaxis]

kernel = tfpk.ExponentiatedQuadratic(length_scale=0.1)

gprm = tfd.GaussianProcessRegressionModel(
    kernel=kernel,
    index_points=index_points,
    observation_index_points=observation_index_points,
    observations=observations,
    observation_noise_variance=observation_noise_variance)

samples = gprm.sample(10, seed=k3)
for i in range(10):
  plt.plot(index_points, samples[i], alpha=0.5)
plt.plot(observation_index_points, observations, marker='o', linestyle='')
plt.show()

png

Sono supportati anche i modelli Markov nascosti.

initial_distribution = tfd.Categorical(probs=[0.8, 0.2])
transition_distribution = tfd.Categorical(probs=[[0.7, 0.3],
                                                 [0.2, 0.8]])

observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.])

model = tfd.HiddenMarkovModel(
    initial_distribution=initial_distribution,
    transition_distribution=transition_distribution,
    observation_distribution=observation_distribution,
    num_steps=7)

print(model.mean())
print(model.log_prob(jnp.zeros(7)))
print(model.sample(seed=random.PRNGKey(0)))
[3.       6.       7.5      8.249999 8.625001 8.812501 8.90625 ]
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/substrates/jax/distributions/hidden_markov_model.py:483: UserWarning: HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug in which the transition model was applied prior to the initial step. This bug has been fixed. You may observe a slight change in behavior.
  'HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug '
-19.855635
[ 1.3641367  0.505798   1.3626463  3.6541772  2.272286  15.10309
 22.794212 ]

Poche distribuzioni come PixelCNN non sono ancora supportati a causa di severe dipendenze tensorflow o XLA incompatibilità.

Biiettori

La maggior parte dei biiettori di TFP sono supportati in JAX oggi!

tfb.Exp().inverse(1.)
DeviceArray(0., dtype=float32)
bij = tfb.Shift(1.)(tfb.Scale(3.))
print(bij.forward(jnp.ones(5)))
print(bij.inverse(jnp.ones(5)))
[4. 4. 4. 4. 4.]
[0. 0. 0. 0. 0.]
b = tfb.FillScaleTriL(diag_bijector=tfb.Exp(), diag_shift=None)
print(b.forward(x=[0., 0., 0.]))
print(b.inverse(y=[[1., 0], [.5, 2]]))
[[1. 0.]
 [0. 1.]]
[0.6931472 0.5       0.       ]
b = tfb.Chain([tfb.Exp(), tfb.Softplus()])
# or:
# b = tfb.Exp()(tfb.Softplus())
print(b.forward(-jnp.ones(5)))
[1.3678794 1.3678794 1.3678794 1.3678794 1.3678794]

Bijectors sono compatibili con le trasformazioni JAX come jit , grad e vmap .

jit(vmap(tfb.Exp().inverse))(jnp.arange(4.))
DeviceArray([     -inf, 0.       , 0.6931472, 1.0986123], dtype=float32)
x = jnp.linspace(0., 1., 100)
plt.plot(x, jit(grad(lambda x: vmap(tfb.Sigmoid().inverse)(x).sum()))(x))
plt.show()

png

Alcuni bijectors, come RealNVP e FFJORD non sono ancora supportati.

MCMC

Abbiamo porting tfp.mcmc a JAX pure, in modo da poter eseguire algoritmi come Hamiltoniana Monte Carlo (HMC) e il No-U-Turn-Sampler (NUTS) in JAX.

target_log_prob = tfd.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2)).log_prob

A differenza di TFP il TF, ci viene richiesto di passare un PRNGKey in sample_chain utilizzando il seed argomento chiave.

def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-1)
  return tfp.mcmc.sample_chain(1000,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      seed=key)
states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros(2))
plt.figure()
plt.scatter(*states.T, alpha=0.5)
plt.figure()
plt.plot(log_probs)
plt.show()

png

png

Per eseguire più catene, possiamo sia passare una serie di Stati in sample_chain o l'uso vmap (anche se non abbiamo ancora esplorato differenze di prestazioni tra i due approcci).

states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros([10, 2]))
plt.figure()
for i in range(10):
  plt.scatter(*states[:, i].T, alpha=0.5)
plt.figure()
for i in range(10):
  plt.plot(log_probs[:, i], alpha=0.5)
plt.show()

png

png

ottimizzatori

TFP su JAX supporta alcuni importanti ottimizzatori come BFGS e L-BFGS. Impostiamo una semplice funzione di perdita quadratica in scala.

minimum = jnp.array([1.0, 1.0])  # The center of the quadratic bowl.
scales = jnp.array([2.0, 3.0])  # The scales along the two axes.

# The objective function and the gradient.
def quadratic_loss(x):
  return jnp.sum(scales * jnp.square(x - minimum))

start = jnp.array([0.6, 0.8])  # Starting point for the search.

BFGS può trovare il minimo di questa perdita.

optim_results = tfp.optimizer.bfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

Così può L-BFGS.

optim_results = tfp.optimizer.lbfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

Per vmap L-BFGS, facciamo impostare una funzione che ottimizza la perdita di un singolo punto di partenza.

def optimize_single(start):
  return tfp.optimizer.lbfgs_minimize(
      value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

all_results = jit(vmap(optimize_single))(
    random.normal(random.PRNGKey(0), (10, 2)))
assert all(all_results.converged)
for i in range(10):
  np.testing.assert_allclose(optim_results.position[i], minimum)
print("Function evaluations: %s" % all_results.num_objective_evaluations)
Function evaluations: [6 6 9 6 6 8 6 8 5 9]

Avvertenze

Esistono alcune differenze fondamentali tra TF e JAX, alcuni comportamenti TFP saranno diversi tra i due substrati e non tutte le funzionalità sono supportate. Per esempio,

  • TFP su JAX non supporta nulla di simile tf.Variable poiché nulla di simile esiste in JAX. Questo significa anche utilità come tfp.util.TransformedVariable non sono supportati neanche.
  • tfp.layers non è ancora supportato nel backend, a causa della sua dipendenza Keras e tf.Variable s.
  • tfp.math.minimize non funziona in TFP su JAX a causa della sua dipendenza da tf.Variable .
  • Con TFP su JAX, le forme tensoriali sono sempre valori interi concreti e non sono mai sconosciute/dinamiche come in TFP su TF.
  • La pseudocasualità è gestita in modo diverso in TF e JAX (vedi appendice).
  • Biblioteche in tfp.experimental non sono garantiti esistere nel substrato JAX.
  • Le regole di promozione Dtype sono diverse tra TF e JAX. TFP su JAX cerca di rispettare la semantica dtype di TF internamente, per coerenza.
  • I biiettori non sono ancora stati registrati come pytree JAX.

Per visualizzare l'elenco completo di ciò che è supportato in TFP su JAX, si prega di fare riferimento alla documentazione API .

Conclusione

Abbiamo portato molte delle funzionalità di TFP su JAX e non vediamo l'ora di vedere cosa costruiranno tutti. Alcune funzionalità non sono ancora supportate; se abbiamo perso qualcosa di importante per voi (o se si trova un bug!) rivolgiti a noi - si può e-mail tfprobability@tensorflow.org o file un problema sulla nostra repo Github .

Appendice: pseudocasualità in JAX

Il modello di JAX generazione di numeri pseudo (PRNG) è senza stato. A differenza di un modello stateful, non esiste uno stato globale mutevole che si evolve dopo ogni estrazione casuale. Nel modello di JAX, iniziamo con una chiave PRNG, che agisce come un paio di interi a 32 bit. Siamo in grado di costruire questi chiavi utilizzando jax.random.PRNGKey .

key = random.PRNGKey(0)  # Creates a key with value [0, 0]
print(key)
[0 0]

Funzioni casuali in JAX consumano una chiave per deterministicamente produrre una variata a caso, nel senso che non devono essere utilizzati di nuovo. Per esempio, possiamo usare key per campionare un valore distribuito normalmente, ma non dobbiamo usare key di nuovo altrove. Inoltre, superato lo stesso valore in random.normal produrrà lo stesso valore.

print(random.normal(key))
-0.20584226

Quindi, come possiamo disegnare più campioni da una singola chiave? La risposta è la divisione chiave. L'idea di base è che possiamo dividere un PRNGKey in multiplo, e ciascuna delle nuove chiavi può essere trattata come fonte indipendente di casualità.

key1, key2 = random.split(key, num=2)
print(key1, key2)
[4146024105  967050713] [2718843009 1272950319]

La suddivisione delle chiavi è deterministica ma caotica, quindi ogni nuova chiave può ora essere utilizzata per disegnare un campione casuale distinto.

print(random.normal(key1), random.normal(key2))
0.14389051 -1.2515389

Per maggiori dettagli su deterministica del modello chiave di scissione di JAX, consultare questa guida .