Phân phối khớp tự động theo lô: Hướng dẫn nhẹ nhàng

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Giới thiệu

TensorFlow Xác suất (TFP) cung cấp một số JointDistribution trừu tượng mà làm cho suy luận xác suất dễ dàng hơn bằng cách cho phép người dùng dễ dàng thể hiện một mô hình đồ họa theo xác suất trong một hình thức gần như toán học; sự trừu tượng tạo ra các phương pháp lấy mẫu từ mô hình và đánh giá xác suất log của các mẫu từ mô hình. Trong hướng dẫn này, chúng tôi xem xét "autobatched" biến thể, được phát triển sau khi bản gốc JointDistribution trừu tượng. Liên quan đến các bản tóm tắt gốc, không được tự động khớp, các phiên bản được tự động khớp đơn giản hơn để sử dụng và tiện dụng hơn, cho phép nhiều mô hình được thể hiện với ít bảng soạn sẵn hơn. Trong chuyên mục này, chúng tôi khám phá một mô hình đơn giản đến từng chi tiết (có lẽ tẻ nhạt), làm rõ các vấn đề giải quyết vấn đề tự động so khớp và (hy vọng) dạy người đọc nhiều hơn về khái niệm hình dạng TFP trong quá trình thực hiện.

Trước sự ra đời của autobatching, đã có một vài biến thể khác nhau của JointDistribution , tương ứng với kiểu cú pháp khác nhau để thể hiện mô hình xác suất: JointDistributionSequential , JointDistributionNamed , và JointDistributionCoroutine . Auobatching tồn tại như một mixin, vì vậy bây giờ chúng ta có AutoBatched biến thể của tất cả các. Trong hướng dẫn này, chúng tôi khám phá sự khác biệt giữa JointDistributionSequentialJointDistributionSequentialAutoBatched ; tuy nhiên, mọi thứ chúng tôi làm ở đây đều có thể áp dụng cho các biến thể khác mà về cơ bản không có thay đổi.

Phụ thuộc & Điều kiện tiên quyết

Nhập và thiết lập

Điều kiện tiên quyết: Một vấn đề hồi quy Bayes

Chúng ta sẽ xem xét một kịch bản hồi quy Bayes rất đơn giản:

\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]

Trong mô hình này, mb được rút ra từ normals tiêu chuẩn, và các quan sát Y được rút ra từ một phân phối chuẩn có nghĩa là phụ thuộc vào các biến ngẫu nhiên mb , và một số (không ngẫu nhiên, được biết đến) đồng biến X . (Để đơn giản, trong ví dụ này, chúng tôi giả sử quy mô của tất cả các biến ngẫu nhiên đã biết.)

Thực hiện kết luận trong mô hình này, chúng tôi cần phải biết cả hai biến số X và các quan sát Y , nhưng đối với mục đích của hướng dẫn này, chúng tôi sẽ chỉ cần X , vì vậy chúng ta định nghĩa một hình nộm đơn giản X :

X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])

Desiderata

Trong suy luận xác suất, chúng ta thường muốn thực hiện hai phép toán cơ bản:

  • sample : Vẽ mẫu từ mô hình.
  • log_prob : Tính toán xác suất log của một mẫu từ mô hình.

Sự đóng góp quan trọng của TFP của JointDistribution trừu tượng (cũng như của nhiều cách tiếp cận khác để lập trình xác suất) là cho phép người dùng viết một mô hình một lần và được tiếp cận với cả hai samplelog_prob tính toán.

Lưu ý rằng chúng tôi có 7 điểm trong tập dữ liệu của chúng tôi ( X.shape = (7,) ), bây giờ chúng ta có thể nêu rõ ước nguyện cho một tuyệt vời JointDistribution :

  • sample() nên tạo ra một danh sách các Tensors có hình dạng [(), (), (7,) ], tương ứng với độ dốc vô hướng, thiên vị vô hướng, và quan sát véc tơ, tương ứng.
  • log_prob(sample()) nên sản xuất một vô hướng: xác suất log của một đặc biệt dốc, thiên vị, và quan sát.
  • sample([5, 3]) nên tạo ra một danh sách các Tensors có hình dạng [(5, 3), (5, 3), (5, 3, 7)] , đại diện cho một (5, 3) - hàng loạt mẫu từ ngươi mâu.
  • log_prob(sample([5, 3])) nên sản xuất một Tensor với hình dạng (5, 3).

Bây giờ chúng ta sẽ xem xét một loạt các JointDistribution mô hình, xem làm thế nào để đạt được ước nguyện trên, và hy vọng tìm hiểu một chút thêm về TFP hình dạng trên đường đi.

Spoiler cảnh báo: Cách tiếp cận đó thỏa mãn các ước nguyện trên mà không soạn thêm được autobatching .

Nỗ lực đầu tiên; JointDistributionSequential

jds = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

Đây ít nhiều là một bản dịch trực tiếp của mô hình thành mã. Độ dốc m và bias b là đơn giản. Y được xác định bằng cách sử dụng lambda -function: mô hình chung là một lambda -function của \(k\) đối số trong một JointDistributionSequential (JDS) sử dụng trước \(k\) phân phối trong mô hình. Lưu ý thứ tự "đảo ngược".

Chúng tôi sẽ gọi sample_distributions , mà lợi nhuận cả một mẫu tiềm ẩn "tiểu phân phối" đã được sử dụng để tạo ra mẫu. (Chúng tôi có thể sản xuất chỉ là mẫu bằng cách gọi sample ; sau trong hướng dẫn nó sẽ được thuận tiện để có các bản phân phối là tốt.) Các mẫu chúng tôi sản xuất là tốt:

dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

Nhưng log_prob tạo ra một kết quả với một hình dạng không mong muốn:

jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
       -4.4368567, -4.480562 ], dtype=float32)>

Và lấy mẫu nhiều lần không hoạt động:

try:
  jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Hãy cố gắng hiểu những gì đang xảy ra.

Đánh giá ngắn gọn: Hàng loạt và hình dạng sự kiện

Trong TFP, một (không phải là một bình thường JointDistribution ) phân bố xác suất có một hình dạng sự kiện và một hình hàng loạt, và hiểu được sự khác biệt là rất quan trọng để sử dụng hiệu quả TFP:

  • Hình dạng sự kiện mô tả hình dạng của một lần vẽ duy nhất từ ​​phân phối; việc rút thăm có thể phụ thuộc vào các kích thước. Đối với phân phối vô hướng, hình dạng sự kiện là []. Đối với Đa biến 5 chiều Bình thường, hình dạng sự kiện là [5].
  • Hình dạng lô mô tả các lần rút được phân phối độc lập, không giống nhau, còn được gọi là "lô" các bản phân phối. Trình bày một loạt các bản phân phối trong một đối tượng Python duy nhất là một trong những cách quan trọng để TFP đạt được hiệu quả trên quy mô lớn.

Đối với mục đích của chúng tôi, một thực tế quan trọng cần lưu ý là nếu chúng ta gọi là log_prob trên một mẫu duy nhất từ một phân phối, kết quả sẽ luôn luôn có một hình dạng mà các trận đấu (ví dụ, có như kích thước bìa phải) hình dạng hàng loạt.

Đối với một sâu sắc hơn cuộc thảo luận về hình dạng, xem các "Tìm hiểu TensorFlow phân phối Shapes" hướng dẫn .

Tại sao không log_prob(sample()) Sản xuất một vô hướng?

Hãy sử dụng kiến thức của chúng ta về hàng loạt và sự kiện hình dạng để khám phá những gì đang xảy ra với log_prob(sample()) . Đây là mẫu của chúng tôi một lần nữa:

sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

Và đây là các bản phân phối của chúng tôi:

dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]

Xác suất nhật ký được tính bằng cách tính tổng các xác suất nhật ký của các phân phối con tại các phần tử (phù hợp) của các bộ phận:

log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
        -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

Vì vậy, một trong những mức độ giải thích là các tính log khả năng được trả lại một 7-tensor vì tiểu hợp phần thứ ba của log_prob_parts là một 7-tensor. Nhưng tại sao?

Vâng, chúng ta thấy rằng yếu tố cuối cùng của dists , tương ứng với phân phối của chúng tôi qua Y trong việc xây dựng mathematial, có batch_shape của [7] . Nói cách khác, phân phối của chúng tôi qua Y là một loạt 7 normals độc lập (với các phương tiện khác nhau, và trong trường hợp này, quy mô như nhau).

Bây giờ chúng ta hiểu những gì là sai: trong JDS, sự phân bố trên Ybatch_shape=[7] , một mẫu từ JDS đại diện cho vô hướng cho mb và một "mẻ" của 7 normals độc lập. và log_prob tính 7 riêng log-xác suất, mỗi trong số đó thể hiện khả năng ghi vẽ mb và một quan sát đơn Y[i] tại một số X[i] .

Sửa log_prob(sample()) với Independent

Nhớ lại rằng dists[2]event_shape=[]batch_shape=[7] :

dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>

Bằng cách sử dụng TFP của Independent metadistribution, chuyển đổi kích thước hàng loạt để kích thước sự kiện, chúng ta có thể chuyển đổi này vào một phân phối với event_shape=[7]batch_shape=[] (chúng tôi sẽ đổi tên nó y_dist_i bởi vì nó là một bản phân phối trên Y , với _i đứng in cho chúng tôi Independent gói):

y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>

Bây giờ, log_prob của một 7-vector là một vô hướng:

y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>

Dưới sự bao trùm, Independent khoản tiền so với hàng loạt:

y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

Và quả thực, chúng ta có thể sử dụng để xây dựng một mới jds_i (các i lại tượng trưng cho Independent ), nơi log_prob trả về một đại lượng vô hướng:

jds_i = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m*X + b, scale=1.),
        reinterpreted_batch_ndims=1)
])

jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>

Một vài lưu ý:

  • jds_i.log_prob(s)không giống như tf.reduce_sum(jds.log_prob(s)) . Trước đây tạo ra xác suất nhật ký "đúng" của phân phối chung. Các khoản tiền sau hơn 7 tensor, mỗi phần tử trong số đó là tổng xác suất log của m , b , và một yếu tố duy nhất của xác suất log của Y , vì vậy nó overcounts mb . ( log_prob(m) + log_prob(b) + log_prob(Y) trả về một kết quả chứ không phải ném một ngoại lệ vì TFP sau TF và các quy tắc phát sóng NumPy của;. Thêm một vô hướng tới một vector tạo ra một kết quả vector có kích thước)
  • Trong trường hợp đặc biệt này, chúng ta có thể giải quyết vấn đề và đạt được kết quả tương tự sử dụng MultivariateNormalDiag thay vì Independent(Normal(...)) . MultivariateNormalDiag là một phân phối vector có giá trị (ví dụ, nó đã có vector sự kiện-shape). Indeeed MultivariateNormalDiag có thể (nhưng không được) thực hiện như một phần của IndependentNormal . Đó là đáng giá để nhớ mà đưa ra một vector V , mẫu từ n1 = Normal(loc=V) , và n2 = MultivariateNormalDiag(loc=V) không thể phân biệt; sự khác biệt beween các bản phân phối là n1.log_prob(n1.sample()) là một vector và n2.log_prob(n2.sample()) là một đại lượng vô hướng.

Nhiều mẫu?

Vẽ nhiều mẫu vẫn không hoạt động:

try:
  jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Hãy suy nghĩ về lý do tại sao. Khi chúng ta gọi là jds_i.sample([5, 3]) , chúng tôi sẽ đầu tiên vẽ mẫu cho mb , mỗi hình dạng (5, 3) . Tiếp theo, chúng ta sẽ cố gắng xây dựng một Normal phân phối thông qua:

tfd.Normal(loc=m*X + b, scale=1.)

Nhưng nếu m có hình dạng (5, 3)X có hình 7 , chúng ta không thể nhân chúng lại với nhau, và thực sự đây là lỗi chúng tôi đang đánh:

m = tfd.Normal(0., 1.).sample([5, 3])
try:
  m * X
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

Để giải quyết vấn đề này, chúng ta hãy suy nghĩ về những gì thuộc tính phân phối qua Y phải có. Nếu chúng ta đã gọi jds_i.sample([5, 3]) , sau đó chúng ta biết mb cả hai sẽ có hình dạng (5, 3) . Hình dạng gì nên một cuộc gọi đến sample trên Y phân phối sản phẩm? Câu trả lời rõ ràng là (5, 3, 7) : cho mỗi điểm hàng loạt, chúng tôi muốn có một mẫu với kích thước tương tự như X . Chúng tôi có thể đạt được điều này bằng cách sử dụng khả năng phát sóng của TensorFlow, thêm các thứ nguyên bổ sung:

m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])

Thêm một trục cho cả mb , chúng ta có thể định nghĩa một JDS mới hỗ trợ nhiều mẫu:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-1.1133379 ,  0.16390413, -0.24177533],
        [-1.1312429 , -0.6224666 , -1.8182136 ],
        [-0.31343174, -0.32932565,  0.5164407 ],
        [-0.0119963 , -0.9079621 ,  2.3655841 ],
        [-0.26293617,  0.8229698 ,  0.31098196]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-0.02876974,  1.0872147 ,  1.0138507 ],
        [ 0.27367726, -1.331534  , -0.09084719],
        [ 1.3349475 , -0.68765205,  1.680652  ],
        [ 0.75436825,  1.3050154 , -0.9415123 ],
        [-1.2502679 , -0.25730947,  0.74611956]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
 array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
          -4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
         [ 3.6438566e-01,  1.0067395e+00,  1.4542470e+00,  8.1155670e-01,
           1.8868095e+00,  2.3877139e+00,  1.0195159e+00],
         [-8.3624744e-01,  1.2518480e+00,  1.0943471e+00,  1.3052304e+00,
          -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
 
        [[-3.1788960e-01,  9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
          -3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
         [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
          -4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
         [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
          -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
 
        [[ 2.0621397e+00,  3.4639853e-01,  7.0252883e-01, -1.4311566e+00,
           3.3790007e+00,  1.1619035e+00, -8.9105040e-01],
         [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
          -2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
         [ 7.7627432e-01,  2.2401071e+00,  3.7601702e+00,  2.4245868e+00,
           4.0690269e+00,  4.0605016e+00,  5.1753912e+00]],
 
        [[ 1.4275590e+00,  3.3346462e+00,  1.5374103e+00, -2.2849756e-01,
           9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
         [-3.1910419e-02, -3.8848895e-01,  9.9946201e-02, -2.3619974e+00,
          -1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
         [-7.1941972e-02,  2.1602919e+00,  4.9575748e+00,  4.2317696e+00,
           9.3528280e+00,  1.0526063e+01,  1.5262107e+01]],
 
        [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
          -3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
         [ 1.5006512e-02, -1.9866472e-01,  7.6781356e-01,  1.6228745e+00,
           1.4191239e+00,  2.6655579e+00,  4.4663467e+00],
         [ 2.6599693e+00,  1.2663836e+00,  1.7162113e+00,  1.4839669e+00,
           2.0559487e+00,  2.5976877e+00,  2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
       [-11.656767 , -17.201958 , -12.132455 ],
       [-17.838818 ,  -9.474525 , -11.24898  ],
       [-13.95219  , -12.490049 , -17.123957 ],
       [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>

Để kiểm tra thêm, chúng tôi sẽ xác minh rằng xác suất nhật ký cho một điểm lô phù hợp với những gì chúng tôi đã có trước đây:

(jds_ia.log_prob(shaped_sample)[3, 1] -
 jds_i.log_prob([shaped_sample[0][3, 1],
                 shaped_sample[1][3, 1],
                 shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

AutoBatching để giành chiến thắng

Thông minh! Bây giờ chúng ta có một phiên bản của JointDistribution mà xử lý tất cả chúng tôi ước nguyện: log_prob lợi nhuận một nhờ vô hướng đến việc sử dụng tfd.Independent , và nhiều mẫu làm việc bây giờ mà chúng tôi cố định phát sóng bằng cách thêm các trục phụ.

Điều gì sẽ xảy ra nếu tôi nói với bạn rằng có một cách dễ dàng hơn, tốt hơn? Có, và nó được gọi là JointDistributionSequentialAutoBatched (JDSAB):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885  , -16.371655 ],
       [-13.292994 , -11.97949  , -16.788685 ],
       [-15.987699 , -13.435732 , -10.6029   ],
       [-10.184758 , -11.969714 , -14.275676 ],
       [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

Cái này hoạt động ra sao? Trong khi bạn có thể cố gắng để đọc mã cho một sự hiểu biết sâu sắc, chúng tôi sẽ cung cấp cho một cái nhìn tổng quan ngắn gọn đó là đủ cho hầu hết các trường hợp sử dụng:

  • Nhớ lại rằng vấn đề đầu tiên của chúng tôi là phân phối của chúng tôi cho Ybatch_shape=[7]event_shape=[] , và chúng tôi sử dụng Independent để chuyển đổi kích thước hàng loạt đến một khía cạnh sự kiện. JDSAB bỏ qua các hình dạng lô của các bản phân phối thành phần; thay vào đó nó đối xử với hình dạng thực thi như một tài sản chung của các mô hình, mà được giả định là [] (trừ khi có quy định khác bằng cách thiết lập batch_ndims > 0 ). Hiệu quả tương đương với sử dụng tfd.Independent để chuyển đổi tất cả các kích thước hàng loạt các bản phân phối phần vào kích thước sự kiện, như chúng ta đã làm bằng tay trên.
  • Vấn đề thứ hai của chúng tôi là một nhu cầu để xoa bóp các hình dạng của mb để họ có thể phát sóng phù hợp với X khi tạo nhiều mẫu. Với JDSAB, bạn viết một mô hình để tạo ra một mẫu duy nhất, và chúng ta "nâng đỡ" toàn bộ mô hình để tạo ra nhiều mẫu sử dụng TensorFlow của vectorized_map . (Tính năng này analagous để JAX của VMAP .)

Khám phá các vấn đề hình hàng loạt chi tiết hơn, chúng ta có thể so sánh các hình dạng lô gốc "xấu" của chúng tôi phân phối chung jds , phân phối hàng loạt cố định của chúng tôi jds_ijds_ia , và autobatched của chúng tôi jds_ab :

jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])

Chúng ta thấy rằng bản gốc jds có subdistributions với hình dạng batch khác nhau. jds_ijds_ia khắc phục điều này bằng cách tạo ra subdistributions với cùng (trống) hình hàng loạt. jds_ab có duy nhất một (trống) hình hàng loạt.

Nó đáng chú ý là JointDistributionSequentialAutoBatched cung cấp một số tính tổng quát bổ sung miễn phí. Giả sử chúng ta làm cho biến số X (và, mặc nhiên, các quan sát Y ) hai chiều:

X = np.arange(14).reshape((2, 7))
X
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

Chúng tôi JointDistributionSequentialAutoBatched làm việc không có thay đổi (chúng ta cần phải xác định lại mô hình vì hình dạng của X là cache của jds_ab.log_prob ):

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.1813647 , -0.85994506,  0.27593774],
        [-0.73323774,  1.1153806 ,  0.8841938 ],
        [ 0.5127983 , -0.29271227,  0.63733214],
        [ 0.2362284 , -0.919168  ,  1.6648189 ],
        [ 0.26317367,  0.73077047,  2.5395133 ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.09636458,  2.0138032 , -0.5054413 ],
        [ 0.63941646, -1.0785882 , -0.6442188 ],
        [ 1.2310615 , -0.3293852 ,  0.77637213],
        [ 1.2115169 , -0.98906034, -0.07816773],
        [-1.1318136 ,  0.510014  ,  1.036522  ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
 array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
            8.5992378e-01, -5.3123581e-01,  3.1584005e+00,
            2.9044402e+00],
          [-2.5645006e-01,  3.1554163e-01,  3.1186538e+00,
            1.4272424e+00,  1.2843871e+00,  1.2266440e+00,
            1.2798605e+00]],
 
         [[ 1.5973477e+00, -5.3631151e-01,  6.8143606e-03,
           -1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
           -3.1663666e+00],
          [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
           -5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
           -9.3582869e+00]],
 
         [[-2.1233239e+00,  5.8853775e-02,  1.2024102e+00,
            1.6622503e+00, -1.9197327e-01,  1.8647723e+00,
            6.4322817e-01],
          [ 3.7549341e-01,  1.5853541e+00,  2.4594500e+00,
            2.1952972e+00,  1.7517658e+00,  2.9666045e+00,
            2.5468128e+00]]],
 
 
        [[[ 8.9906776e-01,  6.7375046e-01,  7.3354661e-01,
           -9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
           -4.4315586e+00],
          [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
           -6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
           -8.6904278e+00]],
 
         [[-2.2462025e+00, -3.3060855e-01,  1.8974400e-01,
            3.1422038e+00,  4.1483402e+00,  3.5642972e+00,
            4.8709240e+00],
          [ 4.7880130e+00,  5.8790064e+00,  9.6695948e+00,
            7.8112822e+00,  1.2022618e+01,  1.2411858e+01,
            1.4323385e+01]],
 
         [[-1.0189297e+00, -7.8115642e-01,  1.6466728e+00,
            8.2378983e-01,  3.0765080e+00,  3.0170646e+00,
            5.1899948e+00],
          [ 6.5285158e+00,  7.8038850e+00,  6.4155884e+00,
            9.0899811e+00,  1.0040427e+01,  9.1404457e+00,
            1.0411951e+01]]],
 
 
        [[[ 4.5557004e-01,  1.4905317e+00,  1.4904103e+00,
            2.9777462e+00,  2.8620450e+00,  3.4745665e+00,
            3.8295493e+00],
          [ 3.9977460e+00,  5.7173767e+00,  7.8421035e+00,
            6.3180594e+00,  6.0838981e+00,  8.2257290e+00,
            9.6548376e+00]],
 
         [[-7.0750320e-01, -3.5972297e-01,  4.3136525e-01,
           -2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
           -3.4453444e+00],
          [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
           -4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
           -4.1388311e+00]],
 
         [[ 2.1969774e+00,  2.4614549e+00,  2.2314475e+00,
            1.8392437e+00,  2.8367062e+00,  4.8600502e+00,
            4.2273531e+00],
          [ 6.1879644e+00,  5.1792760e+00,  6.1141996e+00,
            5.6517797e+00,  8.9979610e+00,  7.5938139e+00,
            9.7918644e+00]]],
 
 
        [[[ 1.5249090e+00,  1.1388919e+00,  8.6903995e-01,
            3.0762129e+00,  1.5128503e+00,  3.5204377e+00,
            2.4760864e+00],
          [ 3.4166217e+00,  3.5930209e+00,  3.1694956e+00,
            4.5797420e+00,  4.5271711e+00,  2.8774328e+00,
            4.7288942e+00]],
 
         [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
           -3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
           -5.5345035e+00],
          [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
           -1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
           -1.1333830e+01]],
 
         [[-7.0419472e-01,  1.4568675e+00,  3.7946482e+00,
            4.8489718e+00,  6.6498446e+00,  9.0224218e+00,
            1.1153137e+01],
          [ 1.0060651e+01,  1.1998097e+01,  1.5326431e+01,
            1.7957514e+01,  1.8323889e+01,  2.0160881e+01,
            2.1269085e+01]]],
 
 
        [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
            2.3558271e-01, -1.0381399e+00,  1.9387857e+00,
           -3.3694571e-01],
          [ 1.6015106e-01,  1.5284677e+00, -4.8567140e-01,
           -1.7770648e-01,  2.1919653e+00,  1.3015286e+00,
            1.3877077e+00]],
 
         [[ 1.3688663e+00,  2.6602898e+00,  6.6657305e-01,
            4.6554832e+00,  5.7781887e+00,  4.9115267e+00,
            4.8446012e+00],
          [ 5.1983776e+00,  6.2297459e+00,  6.3848300e+00,
            8.4291229e+00,  7.1309576e+00,  1.0395646e+01,
            8.5736713e+00]],
 
         [[ 1.2675294e+00,  5.2844582e+00,  5.1331611e+00,
            8.9993315e+00,  1.0794343e+01,  1.4039831e+01,
            1.5731170e+01],
          [ 1.9084715e+01,  2.2191265e+01,  2.3481146e+01,
            2.5803375e+01,  2.8632090e+01,  3.0234968e+01,
            3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
       [-19.775568, -25.894997, -20.302256],
       [-21.10754 , -23.667885, -20.973007],
       [-19.249458, -20.87892 , -20.573763],
       [-22.351208, -25.457762, -24.648403]], dtype=float32)>

Mặt khác, chúng tôi cẩn thận crafted JointDistributionSequential không còn hoạt động:

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

try:
  jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]

Để khắc phục điều này, chúng ta sẽ phải thêm một giây tf.newaxis cho cả mb phù hợp với hình dạng, và tăng reinterpreted_batch_ndims đến 2 trong cuộc gọi đến Independent . Trong trường hợp này, việc để máy trộn tự động xử lý các vấn đề về hình dạng sẽ ngắn hơn, dễ dàng hơn và tiện dụng hơn.

Một lần nữa, chúng tôi lưu ý rằng trong khi máy tính xách tay này đã tìm hiểu JointDistributionSequentialAutoBatched , các biến thể khác của JointDistribution có tương đương AutoBatched . (Đối với người dùng của JointDistributionCoroutine , JointDistributionCoroutineAutoBatched có lợi ích bổ sung mà bạn không còn cần phải xác định Root node, nếu bạn chưa từng sử dụng JointDistributionCoroutine . Bạn có thể yên tâm bỏ qua tuyên bố này)

Suy nghĩ kết luận

Trong máy tính xách tay này, chúng tôi giới thiệu JointDistributionSequentialAutoBatched và làm việc thông qua một ví dụ đơn giản một cách chi tiết. Hy vọng rằng bạn đã học được điều gì đó về hình dạng TFP và về tính năng tự động so khớp!