자동 배치 공동 분포: 부드러운 자습서

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기 노트북 다운로드

소개

TensorFlow 확률 (TFP)을 다수 제공 JointDistribution 쉽게 쉽게에 사용자가 거의 수학적 형태의 확률 그래픽 모델을 표현할 수 있도록하여 확률 적 추론을 추상화를; 추상화는 모델에서 샘플링하고 모델에서 샘플의 로그 확률을 평가하는 방법을 생성합니다. 이 튜토리얼에서, 우리는 원래 후에 개발 된 "autobatched"변종 검토 JointDistribution 추상화를. 자동 일괄 처리되지 않은 원래의 추상화에 비해 자동 일괄 처리 버전은 사용이 더 간편하고 인체공학적이어서 더 적은 상용구로 많은 모델을 표현할 수 있습니다. 이 colab에서 우리는 (아마도 지루할 수도 있는) 세부적인 간단한 모델을 탐구하여 자동 배치가 해결하는 문제를 명확히 하고 (바라건대) 독자들에게 그 과정에서 TFP 모양 개념에 대해 더 많이 가르칩니다.

autobatching의 도입에 앞서, 몇 가지 변종이 있었다 JointDistribution : 확률 모델을 표현하는 다른 구문 스타일에 해당하는 JointDistributionSequential , JointDistributionNamedJointDistributionCoroutine . 우리가 지금 그래서 Auobatching는 믹스 인으로 존재 AutoBatched 이 모든 변종. 이 튜토리얼에서, 우리는 차이점 탐구 JointDistributionSequentialJointDistributionSequentialAutoBatched ; 그러나 여기서 우리가 하는 모든 것은 본질적으로 변경 없이 다른 변형에 적용할 수 있습니다.

종속성 및 전제 조건

가져오기 및 설정

전제 조건: 베이지안 회귀 문제

매우 간단한 베이지안 회귀 시나리오를 고려할 것입니다.

\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]

이 모델에서, mb 표준 법선으로부터 인출되고, 관측 Y 그 평균 확률 변수에 따라 정규 분포로부터 그려 mb , 일부 (랜덤하지 공지)는 공변량 X . (간단함을 위해 이 예에서는 모든 확률 변수의 척도를 알고 있다고 가정합니다.)

이 모델에서 추론을 수행하기 위해, 우리는 모두 공변량 알아야 할 것 X 및 관찰 Y 하지만,이 튜토리얼의 목적을 위해, 우리는해야합니다 X 우리가 간단한 더미 정의 할 수 있도록, X :

X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])

데데라타

확률 추론에서 우리는 종종 두 가지 기본 작업을 수행하기를 원합니다.

  • sample : 모델에서 샘플을 그리기.
  • log_prob : 모델의 샘플의 로그 확률을 계산.

TFP의 주요 기여 JointDistribution 추상화 (뿐만 아니라의 확률 프로그램에 많은 다른 접근 방법) 사용자가 한번 모델을 작성하고 모두에 액세스하도록 허용하는 것입니다 samplelog_prob 계산을.

우리는 우리의 데이터 세트 (7 점을 가지고 주목 X.shape = (7,) ), 우리는 지금 훌륭한에 대한 desiderata 명시 할 수 JointDistribution :

  • sample() 의 목록을 생성한다 Tensors 형상을 갖는 [(), (), (7,) 를 각각 스칼라 기울기 바이어스 스칼라 및 벡터 관측에 대응].
  • log_prob(sample()) 는 특정 기울기 바이어스 및 관측 로그 확률 : 스칼라를 생성한다.
  • sample([5, 3]) 의 목록을 생성한다 Tensors 형상을 갖는 [(5, 3), (5, 3), (5, 3, 7)] 하는 표현 (5, 3) - 샘플의 배치로부터이 모델.
  • log_prob(sample([5, 3])) 농산물한다 Tensor 형상 (5, 3)과.

우리는 지금의 연속 살펴 보겠습니다 JointDistribution 모델 위의 desiderata을 달성하는 방법을보고, 총요소 생산성이 길을 따라 모양에 대한 희망이 조금 더 배우게됩니다.

스포일러 경고 : 추가 상용구없이 만족 위의 desiderata이되는 접근 방식 autobatching .

첫번째 시도; JointDistributionSequential

jds = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

이것은 모델을 코드로 직접 번역한 것입니다. 기울기 m 바이어스 b 간단하다. Y 하여 정의된다 lambda α- 함수를 생성한다 : 패턴이 있다는 것이다 lambda 의 α- 함수 \(k\) A의 인자 JointDistributionSequential (JDS)가 이전의 사용 \(k\) 모델에 분포. "역" 순서에 유의하십시오.

우리는 전화 할게 sample_distributions , 반환 샘플 샘플을 생성하는 데 사용 된 기본 "하위 분포"둘 다. (우리는 호출하여 바로 샘플을 생산 할 수 sample , 나중에 튜토리얼는 분포를 가지고 편리 할 것입니다뿐만 아니라.) 우리가 생산 샘플은 괜찮 :

dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

그러나 log_prob 바람직하지 않은 형상의 결과를 생성합니다 :

jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
       -4.4368567, -4.480562 ], dtype=float32)>

다중 샘플링이 작동하지 않습니다.

try:
  jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

무엇이 잘못되었는지 이해하려고 노력합시다.

간략한 검토: 배치 및 이벤트 형태

TFP에 통상 (하지 JointDistribution ) 확률 분포는 일정 형상배치 형태와의 차이를 이해하는 TFP의 효과적인 사용에 매우 중요 있습니다

  • 이벤트 모양은 분포에서 단일 무승부의 모양을 설명합니다. 무승부는 치수에 따라 다를 수 있습니다. 스칼라 분포의 경우 이벤트 모양은 []입니다. 5차원 MultivariateNormal의 경우 이벤트 모양은 [5]입니다.
  • 배치 형태는 분포의 "배치"라고도 하는 동일하게 분포되지 않은 독립적인 드로우를 나타냅니다. 단일 Python 개체에서 배포 배치를 나타내는 것은 TFP가 규모에 따라 효율성을 달성하는 주요 방법 중 하나입니다.

우리의 목적을 위해, 기억해야 할 중요한 사실은 우리가 호출하는 경우이다 log_prob 배포에서 단일 샘플, 결과가 항상 일치 (즉, 맨 오른쪽 차원으로이)있는 모양 배치 형태를 가질 것이다.

모양에 대한 자세한 설명은 다음을 참조 은 "이해 TensorFlow 분포 모양"자습서를 .

하지 않는 이유는 log_prob(sample()) 스칼라를 생산?

하자가 일어나는 일을 탐험 배치 및 이벤트 모양의 우리의 지식을 사용 log_prob(sample()) . 다음은 다시 샘플입니다.

sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([ 0.18573815, -1.79962   , -1.8106272 , -3.5971394 , -6.6625295 ,
        -7.308844  , -9.832693  ], dtype=float32)>]

배포본은 다음과 같습니다.

dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
 <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]

로그 확률은 부품의 (일치된) 요소에서 하위 분포의 로그 확률을 합산하여 계산됩니다.

log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
 <tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
 <tf.Tensor: shape=(7,), dtype=float32, numpy=
 array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
        -0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>

따라서, 설명의 하나의 레벨 3 부성분 때문에 대수 확률 계산 7 텐서 반환된다는 것이다 log_prob_parts 7 텐서이다. 하지만 왜?

그래서, 우리는의 마지막 요소 볼 dists 위에 우리의 분포에 대응하고, Y mathematial 제형하기, 보유 batch_shape[7] . 즉, 우리 위에 분포 Y (이 경우, 같은 크기의 다른 수단 등) 7 독립적 법선의 배치이다.

우리는 지금 잘못 이해 : JDS에 걸쳐 분포 Y 있다 batch_shape=[7] 의 JDS에서 샘플에 대한 스칼라 대표 mb 7 독립적 인 법선의 "배치"를. 및 log_prob 도면의 로그 확률을 나타내고, 각각의 7 개별 로그 확률, 계산 mb 번의 관찰 Y[i] 일부에서 X[i] .

고정 log_prob(sample())Independent

리콜 것을 dists[2] 갖는 event_shape=[]batch_shape=[7] :

dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>

TFP의 사용에 의해 Independent 이벤트 차원에 배치 크기를 변환 metadistribution을, 우리와 배포에이를 변환 할 수 있습니다 event_shape=[7]batch_shape=[] (우리는 그것을 이름을 바꿀 수 있습니다 y_dist_i 가에 배포이기 때문에 Y_i 서 우리 대한의 Independent ) 포장 :

y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>

이제, log_prob 7 벡터의 스칼라입니다 :

y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>

커버, 아래에서 Independent 배치를 통해 금액 :

y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

그리고 실제로, 우리는 새로운 구성이 사용할 수 jds_i 합니다 ( i 다시 의미 Independent ) 여기서 log_prob 스칼라를 반환

jds_i = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m*X + b, scale=1.),
        reinterpreted_batch_ndims=1)
])

jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>

몇 가지 참고 사항:

  • jds_i.log_prob(s)동일하지 tf.reduce_sum(jds.log_prob(s)) . 전자는 공동 분포의 "정확한" 로그 확률을 생성합니다. 7 텐서 통해 후자의 합은 각각의 요소는 그중의 로그 확률의 합 m , b , 및 로그 확률의 단일 원소 Y 가 overcounts 있도록, mb . ( log_prob(m) + log_prob(b) + log_prob(Y) TFP는 TF와 NumPy와의 방송 규칙 따르기 때문에 예외를 발생 아닌 결과를 반환. 벡터에 스칼라를 추가하여 벡터 크기의 결과를 생성)을
  • 이 특정한 경우에, 우리는 문제를 해결하고 사용하여 동일한 결과를 얻을 수 있었다 MultivariateNormalDiag 대신 Independent(Normal(...)) . MultivariateNormalDiag 벡터 값 분포 (즉, 이미 벡터 이벤트 형태를 갖는다). Indeeed MultivariateNormalDiag 될 수있다 (그러나되지 않음)의 구성으로 구현 IndependentNormal . 그것은 그 벡터 주어진 기억하는 보람 V ,에서 샘플 n1 = Normal(loc=V)n2 = MultivariateNormalDiag(loc=V) 구별을; 이러한 분포 beween 차이이다 n1.log_prob(n1.sample()) 벡터이고 n2.log_prob(n2.sample()) 스칼라이다.

여러 샘플?

여러 샘플을 그리는 것이 여전히 작동하지 않습니다.

try:
  jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

그 이유를 생각해 봅시다. 우리가 호출되면 jds_i.sample([5, 3]) , 우리는 먼저 위해 샘플을 그리는 것이다 mb , 각 형상 (5, 3) . 다음으로, 우리가 건설하려고하는거야 Normal 를 통해 배포 :

tfd.Normal(loc=m*X + b, scale=1.)

하지만 m 모양이있다 (5, 3)X 모양이 7 , 우리는 그들을 함께 곱셈 할 수 없으며, 실제로이 오류 우린 타격이다 :

m = tfd.Normal(0., 1.).sample([5, 3])
try:
  m * X
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]

이 문제를 해결하기 위해,이 이상 분포 특성 것에 대해 생각해 봅시다 Y 가지고 가지고 있습니다. 우리가 호출 한 경우 jds_i.sample([5, 3]) , 그 다음 우리가 알고있는 mb 모두 모양이됩니다 (5, 3) . 호출은 어떤 모양을 할 수 있어야 sample 상의 Y 분포를 생산? 확실한 대답은 (5, 3, 7) : 각 배치 지점에 대해, 우리는 같은 크기의 샘플을 원하는 X . TensorFlow의 브로드캐스팅 기능을 사용하여 추가 차원을 추가하여 이를 달성할 수 있습니다.

m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])

양쪽 축을 추가 mb , 우리는 새로운 JDS를 정의 할 수 지지체 여러 샘플이 :

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-1.1133379 ,  0.16390413, -0.24177533],
        [-1.1312429 , -0.6224666 , -1.8182136 ],
        [-0.31343174, -0.32932565,  0.5164407 ],
        [-0.0119963 , -0.9079621 ,  2.3655841 ],
        [-0.26293617,  0.8229698 ,  0.31098196]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[-0.02876974,  1.0872147 ,  1.0138507 ],
        [ 0.27367726, -1.331534  , -0.09084719],
        [ 1.3349475 , -0.68765205,  1.680652  ],
        [ 0.75436825,  1.3050154 , -0.9415123 ],
        [-1.2502679 , -0.25730947,  0.74611956]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
 array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
          -4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
         [ 3.6438566e-01,  1.0067395e+00,  1.4542470e+00,  8.1155670e-01,
           1.8868095e+00,  2.3877139e+00,  1.0195159e+00],
         [-8.3624744e-01,  1.2518480e+00,  1.0943471e+00,  1.3052304e+00,
          -4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
 
        [[-3.1788960e-01,  9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
          -3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
         [-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
          -4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
         [ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
          -7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
 
        [[ 2.0621397e+00,  3.4639853e-01,  7.0252883e-01, -1.4311566e+00,
           3.3790007e+00,  1.1619035e+00, -8.9105040e-01],
         [-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
          -2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
         [ 7.7627432e-01,  2.2401071e+00,  3.7601702e+00,  2.4245868e+00,
           4.0690269e+00,  4.0605016e+00,  5.1753912e+00]],
 
        [[ 1.4275590e+00,  3.3346462e+00,  1.5374103e+00, -2.2849756e-01,
           9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
         [-3.1910419e-02, -3.8848895e-01,  9.9946201e-02, -2.3619974e+00,
          -1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
         [-7.1941972e-02,  2.1602919e+00,  4.9575748e+00,  4.2317696e+00,
           9.3528280e+00,  1.0526063e+01,  1.5262107e+01]],
 
        [[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
          -3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
         [ 1.5006512e-02, -1.9866472e-01,  7.6781356e-01,  1.6228745e+00,
           1.4191239e+00,  2.6655579e+00,  4.4663467e+00],
         [ 2.6599693e+00,  1.2663836e+00,  1.7162113e+00,  1.4839669e+00,
           2.0559487e+00,  2.5976877e+00,  2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
       [-11.656767 , -17.201958 , -12.132455 ],
       [-17.838818 ,  -9.474525 , -11.24898  ],
       [-13.95219  , -12.490049 , -17.123957 ],
       [-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>

추가 검사로 단일 배치 포인트에 대한 로그 확률이 ​​이전과 일치하는지 확인할 것입니다.

(jds_ia.log_prob(shaped_sample)[3, 1] -
 jds_i.log_prob([shaped_sample[0][3, 1],
                 shaped_sample[1][3, 1],
                 shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

승리를 위한 자동 배치

훌륭한! 우리는 지금 JointDistribution의 버전이 그 핸들 우리의 모든 desiderata : log_prob 회수 방법에 스칼라 감사 tfd.Independent , 여러 샘플 우리는 여분의 축을 추가하여 방송 고정 이니까 작동합니다.

더 쉽고 좋은 방법이 있다고 말하면 어떨까요? 있다, 그것은이라고 JointDistributionSequentialAutoBatched (JDSAB) :

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885  , -16.371655 ],
       [-13.292994 , -11.97949  , -16.788685 ],
       [-15.987699 , -13.435732 , -10.6029   ],
       [-10.184758 , -11.969714 , -14.275676 ],
       [-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>

어떻게 작동합니까? 당신이 시도 할 수 있지만 코드 읽기 에 대한 깊은 이해를 들어, 우리는 대부분의 사용 사례에 대한 충분한 간략한 개요를 줄 것이다 :

  • 리콜 우리의 첫 번째 문제는 우리의 유통 있다는 것을 Y 있었다 batch_shape=[7]event_shape=[] , 우리가 사용하는 Independent 이벤트 차원에 배치 차원을 변환 할 수 있습니다. JDSAB는 구성 요소 분포의 배치 모양을 무시합니다. 대신는 것으로 가정되는 모델의 전체적인 특성으로 배치 형상을 취급 [] (설정에 의해 달리 명시되지 않는 batch_ndims > 0 ). 효과는 우리가 수동으로 위처럼, 이벤트 차원으로 구성 분포의 모든 배치 치수를 변환하는 tfd.Independent을 사용하는 것과 같습니다.
  • 우리의 두 번째 문제는의 모양 마사지 할 필요했다 mb 그들이 적절하게 방송 할 수 있도록 X 여러 샘플을 만들 때. JDSAB를 사용하면 모델이 하나의 샘플을 생성하기 위해 쓰기, 그리고 우리 "리프트"전체 모델은 TensorFlow의 사용하여 여러 샘플을 생성하는 vectorized_map을 . (이 기능은 JAX의에와 똑같이입니다 VMAP .)

더 상세하게 배치 형태의 문제를 탐구, 우리는 우리의 원래 "나쁜"공동 분배의 배치 모양 비교할 수 jds , 우리의 배치 고정 분포 jds_ijds_ia , 우리의 autobatched jds_ab :

jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])

우리는 원래 볼 jds 다른 배치 모양 subdistributions 있습니다. jds_ijds_ia 동일한 (빈) 배치 형태로 subdistributions을 생성하여이 문제를 해결. jds_ab 단 하나의 (빈) 배치 형태를 갖는다.

있다는 지적이의 가치 JointDistributionSequentialAutoBatched 무료로 몇 가지 추가 일반성을 제공합니다. 우리는 공변량을 만드는 가정 X (그리고 암시 적으로, 관찰 Y ) 두 차원 :

X = np.arange(14).reshape((2, 7))
X
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

우리 JointDistributionSequentialAutoBatched 변경없이 (우리의 모양 때문에 모델을 재정의 할 필요가 작동 X 캐시됩니다 jds_ab.log_prob ) :

jds_ab = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])

shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.1813647 , -0.85994506,  0.27593774],
        [-0.73323774,  1.1153806 ,  0.8841938 ],
        [ 0.5127983 , -0.29271227,  0.63733214],
        [ 0.2362284 , -0.919168  ,  1.6648189 ],
        [ 0.26317367,  0.73077047,  2.5395133 ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3), dtype=float32, numpy=
 array([[ 0.09636458,  2.0138032 , -0.5054413 ],
        [ 0.63941646, -1.0785882 , -0.6442188 ],
        [ 1.2310615 , -0.3293852 ,  0.77637213],
        [ 1.2115169 , -0.98906034, -0.07816773],
        [-1.1318136 ,  0.510014  ,  1.036522  ]], dtype=float32)>,
 <tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
 array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
            8.5992378e-01, -5.3123581e-01,  3.1584005e+00,
            2.9044402e+00],
          [-2.5645006e-01,  3.1554163e-01,  3.1186538e+00,
            1.4272424e+00,  1.2843871e+00,  1.2266440e+00,
            1.2798605e+00]],
 
         [[ 1.5973477e+00, -5.3631151e-01,  6.8143606e-03,
           -1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
           -3.1663666e+00],
          [-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
           -5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
           -9.3582869e+00]],
 
         [[-2.1233239e+00,  5.8853775e-02,  1.2024102e+00,
            1.6622503e+00, -1.9197327e-01,  1.8647723e+00,
            6.4322817e-01],
          [ 3.7549341e-01,  1.5853541e+00,  2.4594500e+00,
            2.1952972e+00,  1.7517658e+00,  2.9666045e+00,
            2.5468128e+00]]],
 
 
        [[[ 8.9906776e-01,  6.7375046e-01,  7.3354661e-01,
           -9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
           -4.4315586e+00],
          [-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
           -6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
           -8.6904278e+00]],
 
         [[-2.2462025e+00, -3.3060855e-01,  1.8974400e-01,
            3.1422038e+00,  4.1483402e+00,  3.5642972e+00,
            4.8709240e+00],
          [ 4.7880130e+00,  5.8790064e+00,  9.6695948e+00,
            7.8112822e+00,  1.2022618e+01,  1.2411858e+01,
            1.4323385e+01]],
 
         [[-1.0189297e+00, -7.8115642e-01,  1.6466728e+00,
            8.2378983e-01,  3.0765080e+00,  3.0170646e+00,
            5.1899948e+00],
          [ 6.5285158e+00,  7.8038850e+00,  6.4155884e+00,
            9.0899811e+00,  1.0040427e+01,  9.1404457e+00,
            1.0411951e+01]]],
 
 
        [[[ 4.5557004e-01,  1.4905317e+00,  1.4904103e+00,
            2.9777462e+00,  2.8620450e+00,  3.4745665e+00,
            3.8295493e+00],
          [ 3.9977460e+00,  5.7173767e+00,  7.8421035e+00,
            6.3180594e+00,  6.0838981e+00,  8.2257290e+00,
            9.6548376e+00]],
 
         [[-7.0750320e-01, -3.5972297e-01,  4.3136525e-01,
           -2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
           -3.4453444e+00],
          [-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
           -4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
           -4.1388311e+00]],
 
         [[ 2.1969774e+00,  2.4614549e+00,  2.2314475e+00,
            1.8392437e+00,  2.8367062e+00,  4.8600502e+00,
            4.2273531e+00],
          [ 6.1879644e+00,  5.1792760e+00,  6.1141996e+00,
            5.6517797e+00,  8.9979610e+00,  7.5938139e+00,
            9.7918644e+00]]],
 
 
        [[[ 1.5249090e+00,  1.1388919e+00,  8.6903995e-01,
            3.0762129e+00,  1.5128503e+00,  3.5204377e+00,
            2.4760864e+00],
          [ 3.4166217e+00,  3.5930209e+00,  3.1694956e+00,
            4.5797420e+00,  4.5271711e+00,  2.8774328e+00,
            4.7288942e+00]],
 
         [[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
           -3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
           -5.5345035e+00],
          [-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
           -1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
           -1.1333830e+01]],
 
         [[-7.0419472e-01,  1.4568675e+00,  3.7946482e+00,
            4.8489718e+00,  6.6498446e+00,  9.0224218e+00,
            1.1153137e+01],
          [ 1.0060651e+01,  1.1998097e+01,  1.5326431e+01,
            1.7957514e+01,  1.8323889e+01,  2.0160881e+01,
            2.1269085e+01]]],
 
 
        [[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
            2.3558271e-01, -1.0381399e+00,  1.9387857e+00,
           -3.3694571e-01],
          [ 1.6015106e-01,  1.5284677e+00, -4.8567140e-01,
           -1.7770648e-01,  2.1919653e+00,  1.3015286e+00,
            1.3877077e+00]],
 
         [[ 1.3688663e+00,  2.6602898e+00,  6.6657305e-01,
            4.6554832e+00,  5.7781887e+00,  4.9115267e+00,
            4.8446012e+00],
          [ 5.1983776e+00,  6.2297459e+00,  6.3848300e+00,
            8.4291229e+00,  7.1309576e+00,  1.0395646e+01,
            8.5736713e+00]],
 
         [[ 1.2675294e+00,  5.2844582e+00,  5.1331611e+00,
            8.9993315e+00,  1.0794343e+01,  1.4039831e+01,
            1.5731170e+01],
          [ 1.9084715e+01,  2.2191265e+01,  2.3481146e+01,
            2.5803375e+01,  2.8632090e+01,  3.0234968e+01,
            3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
       [-19.775568, -25.894997, -20.302256],
       [-21.10754 , -23.667885, -20.973007],
       [-19.249458, -20.87892 , -20.573763],
       [-22.351208, -25.457762, -24.648403]], dtype=float32)>

다른 한편으로, 우리는 신중하게 만들어진하지 JointDistributionSequential 더 이상 작동 :

jds_ia = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),   # m
    tfd.Normal(loc=0., scale=1.),   # b
    lambda b, m: tfd.Independent(   # Y
        tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
        reinterpreted_batch_ndims=1)
])

try:
  jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
  print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]

이 문제를 해결하려면, 우리는 두 번째 추가해야 할 것 tf.newaxis 모두 mb 모양 및 증가 일치 reinterpreted_batch_ndims 에 대한 호출에 2를 Independent . 이 경우 자동 배치 기계가 모양 문제를 처리하도록 하는 것이 더 짧고 쉽고 인체공학적입니다.

다시 한번, 우리는이 노트북을 탐험하는 동안주의 JointDistributionSequentialAutoBatched 의 다른 변종 JointDistribution 상당이 AutoBatched . (사용자의 경우 JointDistributionCoroutine , JointDistributionCoroutineAutoBatched 추가적인 이점이 당신 지정할 더 이상 필요하지 Root 노드를, 당신은 사용한 적이있는 경우 JointDistributionCoroutine . 당신이 안전하게 문을 무시할 수 있습니다)

결론

이 노트북에서는 도입 JointDistributionSequentialAutoBatched 상세하게 간단한 예제를했다. TFP 모양과 자동 배치에 대해 배웠기를 바랍니다!