बायेसियन स्विचपॉइंट विश्लेषण

TensorFlow.org पर देखें Google Colab में चलाएं GitHub पर स्रोत देखें नोटबुक डाउनलोड करें

इस नोटबुक reimplements और से बायेसियन "परिवर्तन बिंदु विश्लेषण" उदाहरण फैली pymc3 प्रलेखन

आवश्यक शर्तें

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (15,8)
%config InlineBackend.figure_format = 'retina'
import numpy as np
import pandas as pd

डेटासेट

डाटासेट से है यहाँ । ध्यान दें, वहाँ इस उदाहरण का एक और संस्करण है चारों ओर चल , लेकिन यह है डेटा "लापता" - जिस स्थिति में आप लापता मूल्यों लिए आरोपित करने की आवश्यकता होगी। (अन्यथा आपका मॉडल अपने प्रारंभिक मापदंडों को कभी नहीं छोड़ेगा क्योंकि संभावना समारोह अपरिभाषित होगा।)

disaster_data = np.array([ 4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
                           3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
                           2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0, 0,
                           1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
                           0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
                           3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
                           0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1])
years = np.arange(1851, 1962)
plt.plot(years, disaster_data, 'o', markersize=8);
plt.ylabel('Disaster count')
plt.xlabel('Year')
plt.title('Mining disaster data set')
plt.show()

पीएनजी

संभाव्य मॉडल

मॉडल एक "स्विच पॉइंट" (उदाहरण के लिए एक वर्ष जिसके दौरान सुरक्षा नियमों को बदल दिया गया है), और पॉइसन-वितरित आपदा दर उस स्विच पॉइंट से पहले और बाद में स्थिर (लेकिन संभावित रूप से भिन्न) दरों के साथ मानता है।

वास्तविक आपदा गणना निश्चित (मनाया गया) है; इस मॉडल के किसी भी नमूने को स्विचपॉइंट और आपदाओं की "प्रारंभिक" और "देर से" दर दोनों को निर्दिष्ट करने की आवश्यकता होगी।

से मूल मॉडल pymc3 प्रलेखन उदाहरण :

\[ \begin{align*} (D_t|s,e,l)&\sim \text{Poisson}(r_t), \\ & \,\quad\text{with}\; r_t = \begin{cases}e & \text{if}\; t < s\\l &\text{if}\; t \ge s\end{cases} \\ s&\sim\text{Discrete Uniform}(t_l,\,t_h) \\ e&\sim\text{Exponential}(r_e)\\ l&\sim\text{Exponential}(r_l) \end{align*} \]

हालांकि, इसका मतलब आपदा दर \(r_t\) switchpoint में एक अंतराल है \(s\)है, जो यह नहीं विभेदक बनाता है। इस प्रकार यह Hamiltonian मोंटे कार्लो (एचएमसी) कलन विधि के लिए कोई ढाल संकेत प्रदान करता है - लेकिन क्योंकि \(s\) पहले निरंतर है, एक यादृच्छिक टहलने के लिए एच एम सी के वापस आने काफी अच्छा इस उदाहरण में उच्च संभावना जन के क्षेत्रों मिल रहा है।

एक दूसरे मॉडल के रूप में, हम एक का उपयोग कर मूल मॉडल संशोधित अवग्रह "स्विच" और एल के बीच संक्रमण विभेदक बनाने के लिए, और switchpoint के लिए एक सतत समान वितरण का उपयोग करने के \(s\)। (कोई यह तर्क दे सकता है कि यह मॉडल वास्तविकता के लिए अधिक सत्य है, क्योंकि माध्य दर में "स्विच" कई वर्षों तक खिंच सकता है।) नया मॉडल इस प्रकार है:

\[ \begin{align*} (D_t|s,e,l)&\sim\text{Poisson}(r_t), \\ & \,\quad \text{with}\; r_t = e + \frac{1}{1+\exp(s-t)}(l-e) \\ s&\sim\text{Uniform}(t_l,\,t_h) \\ e&\sim\text{Exponential}(r_e)\\ l&\sim\text{Exponential}(r_l) \end{align*} \]

अधिक जानकारी के अभाव में हम यह मान \(r_e = r_l = 1\) महंतों के लिए मानकों के रूप में। हम दोनों मॉडल चलाएंगे और उनके अनुमान परिणामों की तुलना करेंगे।

def disaster_count_model(disaster_rate_fn):
  disaster_count = tfd.JointDistributionNamed(dict(
    e=tfd.Exponential(rate=1.),
    l=tfd.Exponential(rate=1.),
    s=tfd.Uniform(0., high=len(years)),
    d_t=lambda s, l, e: tfd.Independent(
        tfd.Poisson(rate=disaster_rate_fn(np.arange(len(years)), s, l, e)),
        reinterpreted_batch_ndims=1)
  ))
  return disaster_count

def disaster_rate_switch(ys, s, l, e):
  return tf.where(ys < s, e, l)

def disaster_rate_sigmoid(ys, s, l, e):
  return e + tf.sigmoid(ys - s) * (l - e)

model_switch = disaster_count_model(disaster_rate_switch)
model_sigmoid = disaster_count_model(disaster_rate_sigmoid)

उपरोक्त कोड संयुक्त वितरण अनुक्रमिक वितरण के माध्यम से मॉडल को परिभाषित करता है। disaster_rate कार्यों की एक सरणी के साथ कहा जाता है [0, ..., len(years)-1] का एक वेक्टर का निर्माण करने के len(years) यादृच्छिक परिवर्तनीय - से पहले के वर्षों switchpoint हैं early_disaster_rate , के बाद लोगों late_disaster_rate (सापेक्ष सिग्मॉइड संक्रमण)।

यहां एक विवेक-जांच है कि लक्ष्य लॉग प्रोब फ़ंक्शन समझदार है:

def target_log_prob_fn(model, s, e, l):
  return model.log_prob(s=s, e=e, l=l, d_t=disaster_data)

models = [model_switch, model_sigmoid]
print([target_log_prob_fn(m, 40., 3., .9).numpy() for m in models])  # Somewhat likely result
print([target_log_prob_fn(m, 60., 1., 5.).numpy() for m in models])  # Rather unlikely result
print([target_log_prob_fn(m, -10., 1., 1.).numpy() for m in models]) # Impossible result
[-176.94559, -176.28717]
[-371.3125, -366.8816]
[-inf, -inf]

एचएमसी बायेसियन इंट्रेंस करेगी

हम आवश्यक परिणामों की संख्या और बर्न-इन चरणों को परिभाषित करते हैं; कोड ज्यादातर के बाद मॉडलिंग की है tfp.mcmc.HamiltonianMonteCarlo के प्रलेखन । यह एक अनुकूली चरण आकार का उपयोग करता है (अन्यथा परिणाम चुने गए चरण आकार मान के प्रति बहुत संवेदनशील होता है)। हम श्रृंखला की प्रारंभिक स्थिति के रूप में एक के मूल्यों का उपयोग करते हैं।

हालांकि यह पूरी कहानी नहीं है। यदि आप उपरोक्त मॉडल परिभाषा पर वापस जाते हैं, तो आप देखेंगे कि कुछ संभाव्यता वितरण पूरी वास्तविक संख्या रेखा पर अच्छी तरह से परिभाषित नहीं हैं। इसलिए हम अंतरिक्ष कि एच एम सी एक साथ एच एम सी गिरी लपेटकर द्वारा जांच करेगा विवश TransformedTransitionKernel कि निर्दिष्ट करता है आगे bijectors डोमेन कि संभावना वितरण पर परिभाषित किया गया है (नीचे दिए गए कोड में टिप्पणी देखें) पर वास्तविक संख्या को बदलने के लिए।

num_results = 10000
num_burnin_steps = 3000

@tf.function(autograph=False, jit_compile=True)
def make_chain(target_log_prob_fn):
   kernel = tfp.mcmc.TransformedTransitionKernel(
       inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=target_log_prob_fn,
          step_size=0.05,
          num_leapfrog_steps=3),
       bijector=[
          # The switchpoint is constrained between zero and len(years).
          # Hence we supply a bijector that maps the real numbers (in a
          # differentiable way) to the interval (0;len(yers))
          tfb.Sigmoid(low=0., high=tf.cast(len(years), dtype=tf.float32)),
          # Early and late disaster rate: The exponential distribution is
          # defined on the positive real numbers
          tfb.Softplus(),
          tfb.Softplus(),
      ])
   kernel = tfp.mcmc.SimpleStepSizeAdaptation(
        inner_kernel=kernel,
        num_adaptation_steps=int(0.8*num_burnin_steps))

   states = tfp.mcmc.sample_chain(
      num_results=num_results,
      num_burnin_steps=num_burnin_steps,
      current_state=[
          # The three latent variables
          tf.ones([], name='init_switchpoint'),
          tf.ones([], name='init_early_disaster_rate'),
          tf.ones([], name='init_late_disaster_rate'),
      ],
      trace_fn=None,
      kernel=kernel)
   return states

switch_samples = [s.numpy() for s in make_chain(
    lambda *args: target_log_prob_fn(model_switch, *args))]
sigmoid_samples = [s.numpy() for s in make_chain(
    lambda *args: target_log_prob_fn(model_sigmoid, *args))]

switchpoint, early_disaster_rate, late_disaster_rate = zip(
    switch_samples, sigmoid_samples)

दोनों मॉडलों को समानांतर में चलाएँ:

परिणाम की कल्पना करें

हम प्रारंभिक और देर से आपदा दर के साथ-साथ स्विचपॉइंट के लिए पश्च वितरण के नमूनों के हिस्टोग्राम के रूप में परिणाम की कल्पना करते हैं। हिस्टोग्राम को नमूना माध्यिका का प्रतिनिधित्व करने वाली एक ठोस रेखा के साथ मढ़ा जाता है, साथ ही धराशायी रेखाओं के रूप में 95% ile विश्वसनीय अंतराल सीमा।

def _desc(v):
  return '(median: {}; 95%ile CI: $[{}, {}]$)'.format(
      *np.round(np.percentile(v, [50, 2.5, 97.5]), 2))

for t, v in [
    ('Early disaster rate ($e$) posterior samples', early_disaster_rate),
    ('Late disaster rate ($l$) posterior samples', late_disaster_rate),
    ('Switch point ($s$) posterior samples', years[0] + switchpoint),
]:
  fig, ax = plt.subplots(nrows=1, ncols=2, sharex=True)
  for (m, i) in (('Switch', 0), ('Sigmoid', 1)):
    a = ax[i]
    a.hist(v[i], bins=50)
    a.axvline(x=np.percentile(v[i], 50), color='k')
    a.axvline(x=np.percentile(v[i], 2.5), color='k', ls='dashed', alpha=.5)
    a.axvline(x=np.percentile(v[i], 97.5), color='k', ls='dashed', alpha=.5)
    a.set_title(m + ' model ' + _desc(v[i]))
  fig.suptitle(t)
  plt.show()

पीएनजी

पीएनजी

पीएनजी