ดูบน TensorFlow.org | ทำงานใน Google Colab | ดูแหล่งที่มาบน GitHub | ดาวน์โหลดโน๊ตบุ๊ค |
บทนำ
TensorFlow ความน่าจะเป็น (TFP) มีจำนวนของ JointDistribution นามธรรมที่ทำให้อนุมานน่าจะง่ายขึ้นโดยให้ผู้ใช้เพื่อให้ง่ายต่อแสดงรูปแบบกราฟิกที่น่าจะเป็นในรูปแบบที่ใกล้เคียงกับคณิตศาสตร์ สิ่งที่เป็นนามธรรมจะสร้างวิธีการสุ่มตัวอย่างจากแบบจำลองและประเมินความน่าจะเป็นของบันทึกของกลุ่มตัวอย่างจากแบบจำลอง ในการกวดวิชานี้เราจะตรวจสอบ "autobatched" สายพันธุ์ซึ่งได้รับการพัฒนาหลังจากที่เดิม JointDistribution นามธรรม เมื่อเทียบกับออโตแบทช์ที่เป็นนามธรรมแบบดั้งเดิม เวอร์ชันอัตโนมัตินั้นใช้งานง่ายกว่าและถูกหลักสรีรศาสตร์มากขึ้น ทำให้สามารถแสดงโมเดลจำนวนมากโดยใช้ต้นแบบที่น้อยลง ใน colab นี้ เราสำรวจโมเดลง่ายๆ ในรายละเอียด (บางทีอาจน่าเบื่อ) ทำให้เห็นปัญหาในการแก้ไขอัตโนมัติและ (หวังว่า) จะสอนผู้อ่านเพิ่มเติมเกี่ยวกับแนวคิดเกี่ยวกับรูปร่าง TFP ไปพร้อมกัน
ก่อนที่จะนำ autobatching ที่มีอยู่เป็นสายพันธุ์ที่แตกต่างกันไม่กี่ JointDistribution ที่สอดคล้องกับรูปแบบประโยคที่แตกต่างกันสำหรับการแสดงแบบจำลองความน่าจะเป็น: JointDistributionSequential , JointDistributionNamed และ JointDistributionCoroutine Auobatching อยู่เป็น mixin ดังนั้นตอนนี้เรามี AutoBatched สายพันธุ์ของสิ่งเหล่านี้ ในการกวดวิชานี้เราจะสำรวจความแตกต่างระหว่าง JointDistributionSequential และ JointDistributionSequentialAutoBatched ; อย่างไรก็ตาม ทุกสิ่งที่เราทำที่นี่ใช้ได้กับรุ่นอื่นๆ โดยไม่มีการเปลี่ยนแปลง
การพึ่งพาและข้อกำหนดเบื้องต้น
นำเข้าและตั้งค่า
import functools
import numpy as np
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
วิชาบังคับก่อน: ปัญหาการถดถอยแบบเบย์
เราจะพิจารณาสถานการณ์ถดถอยแบบเบเซียนอย่างง่าย:
\[ \begin{align*} m & \sim \text{Normal}(0, 1) \\ b & \sim \text{Normal}(0, 1) \\ Y & \sim \text{Normal}(mX + b, 1) \end{align*} \]
ในรูปแบบนี้ m และ b ถูกดึงมาจากภาวะปกติมาตรฐานและข้อสังเกต Y ถูกดึงมาจากการแจกแจงแบบปกติที่มีค่าเฉลี่ยขึ้นอยู่กับตัวแปรสุ่ม m และ b และบางส่วน (nonrandom ที่รู้จักกัน) ตัวแปร X (เพื่อความเรียบง่าย ในตัวอย่างนี้ เราถือว่ามาตราส่วนของตัวแปรสุ่มทั้งหมดเป็นที่รู้จัก)
เพื่อดำเนินการอนุมานในรูปแบบนี้เราจะต้องรู้ว่าตัวแปรทั้ง X และสังเกต Y แต่สำหรับวัตถุประสงค์ของการกวดวิชานี้เราจะต้องการเพียง X ดังนั้นเราจึงกำหนดหุ่นง่าย X :
X = np.arange(7)
X
array([0, 1, 2, 3, 4, 5, 6])
เดสิเดราตา
ในการอนุมานความน่าจะเป็น เรามักจะต้องการดำเนินการพื้นฐานสองอย่าง:
-
sample: ตัวอย่างการวาดภาพจากแบบจำลอง -
log_prob: การคำนวณความน่าจะเป็นบันทึกของตัวอย่างจากแบบจำลอง
ผลงานที่สำคัญของ TFP ของ JointDistribution นามธรรม (เช่นเดียวกับวิธีการอื่น ๆ อีกมากมายกับการเขียนโปรแกรมน่าจะเป็น) คือการอนุญาตให้ผู้ใช้สามารถเขียนรูปแบบครั้งเดียวและมีการเข้าถึงทั้ง sample และ log_prob คำนวณ
สังเกตว่าเรามี 7 คะแนนอยู่ในชุดข้อมูลของเรา ( X.shape = (7,) ) ตอนนี้เราสามารถระบุ Desiderata สำหรับการที่ดีเยี่ยม JointDistribution :
-
sample()ควรผลิตรายการTensorsที่มีรูปร่าง[(), (), (7,)] ซึ่งสอดคล้องกับความลาดชันเกลาอคติเกลาและข้อสังเกตเวกเตอร์ตามลำดับ -
log_prob(sample())ควรผลิตเกลา: ความน่าจะเป็นบันทึกการโดยเฉพาะอย่างยิ่งลาดอคติและข้อสังเกต -
sample([5, 3])ควรผลิตรายการTensorsที่มีรูปร่าง[(5, 3), (5, 3), (5, 3, 7)]คิดเป็น(5, 3)- ชุดของตัวอย่างจาก นางแบบ. -
log_prob(sample([5, 3]))ควรผลิตTensorที่มีรูปร่าง (5, 3)
ตอนนี้เราจะดูที่สืบทอดของ JointDistribution รุ่นดูวิธีการเพื่อให้บรรลุ Desiderata ข้างต้นและหวังว่าจะได้เรียนรู้เล็ก ๆ น้อย ๆ เพิ่มเติมเกี่ยวกับรูปทรง TFP ไปพร้อมกัน
แจ้งเตือนสปอยเลอร์: วิธีการที่ตรงกับ Desiderata ดังกล่าวข้างต้นโดยไม่ต้องเพิ่มสำเร็จรูปเป็น autobatching
ความพยายามครั้งแรก; JointDistributionSequential
jds = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
นี่เป็นการแปลโดยตรงของโมเดลเป็นโค้ดไม่มากก็น้อย ความลาดชัน m และอคติ b มีความตรงไปตรงมา Y ถูกกำหนดให้ใช้ lambda ฟังก์ชั่: รูปแบบทั่วไปคือว่า lambda ฟังก์ชั่ของ \(k\) ข้อโต้แย้งใน JointDistributionSequential (JDS) ใช้ก่อนหน้านี้ \(k\) กระจายในรูปแบบ สังเกตคำสั่ง "ย้อนกลับ"
เราจะเรียก sample_distributions ซึ่งผลตอบแทนทั้งตัวอย่างและต้นแบบ "ย่อยกระจาย" ที่ถูกนำมาใช้ในการสร้างตัวอย่าง (เราอาจจะมีการผลิตเพียงตัวอย่างโดยการเรียก sample ต่อมาในการกวดวิชามันจะสะดวกที่จะมีการแจกแจงได้เป็นอย่างดี.) กลุ่มตัวอย่างที่เราผลิตจะปรับ:
dists, sample = jds.sample_distributions()
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
<tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 ,
-7.308844 , -9.832693 ], dtype=float32)>]
แต่ log_prob ผลิตผลที่มีรูปร่างที่ไม่พึงประสงค์:
jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-4.4777603, -4.6775575, -4.7430477, -4.647725 , -4.5746684,
-4.4368567, -4.480562 ], dtype=float32)>
และการสุ่มตัวอย่างหลายครั้งใช้ไม่ได้:
try:
jds.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
เรามาลองทำความเข้าใจกันว่าจะเกิดอะไรขึ้น
บทวิจารณ์โดยย่อ: แบทช์และรูปร่างของเหตุการณ์
ใน TFP, สามัญ (ไม่ JointDistribution ) กระจายความน่าจะมีรูปร่างเหตุการณ์และรูปร่างชุดและความเข้าใจที่แตกต่างกันเป็นสิ่งสำคัญในการใช้งานที่มีประสิทธิภาพของ TFP:
- รูปร่างเหตุการณ์อธิบายรูปร่างของการวาดครั้งเดียวจากการแจกแจง การจับฉลากอาจขึ้นอยู่กับมิติต่างๆ สำหรับการแจกแจงสเกลาร์ รูปร่างเหตุการณ์คือ [] สำหรับ MultivariateNormal 5 มิติ รูปร่างเหตุการณ์คือ [5]
- รูปร่างแบทช์อธิบายการแจกแจงแบบอิสระ ไม่มีการแจกแจงแบบเดียวกัน หรือที่เรียกว่า "แบทช์" ของการแจกแจง การแสดงชุดของการแจกแจงในอ็อบเจ็กต์ Python เดียวเป็นหนึ่งในวิธีสำคัญที่ TFP บรรลุประสิทธิภาพตามขนาด
สำหรับวัตถุประสงค์ของเราเป็นความจริงสำคัญที่จะเก็บไว้ในใจก็คือว่าถ้าเราเรียก log_prob ในตัวอย่างเดียวจากการกระจายผลที่ได้มักจะมีรูปร่างที่การแข่งขัน (เช่นมีเป็นมิติขวาสุด) รูปร่างชุดที่
สำหรับการอภิปรายเพิ่มเติมในเชิงลึกของรูปทรงเห็น ว่า "ความเข้าใจ TensorFlow กระจายรูปร่าง" กวดวิชา
ทำไมไม่ log_prob(sample()) จัดทำเกลา?
ลองใช้ความรู้เกี่ยวกับชุดและรูปร่างเหตุการณ์ของเราในการสำรวจสิ่งที่เกิดขึ้นกับ log_prob(sample()) นี่คือตัวอย่างของเราอีกครั้ง:
sample
[<tf.Tensor: shape=(), dtype=float32, numpy=-1.668757>,
<tf.Tensor: shape=(), dtype=float32, numpy=0.6585061>,
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([ 0.18573815, -1.79962 , -1.8106272 , -3.5971394 , -6.6625295 ,
-7.308844 , -9.832693 ], dtype=float32)>]
และนี่คือการแจกแจงของเรา:
dists
[<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>]
ความน่าจะเป็นของบันทึกคำนวณโดยการรวมความน่าจะเป็นบันทึกของการแจกแจงย่อยที่องค์ประกอบ (ที่ตรงกัน) ของชิ้นส่วนต่างๆ:
log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]
log_prob_parts
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.3113134>,
<tf.Tensor: shape=(), dtype=float32, numpy=-1.1357536>,
<tf.Tensor: shape=(7,), dtype=float32, numpy=
array([-1.0306933, -1.2304904, -1.2959809, -1.200658 , -1.1276014,
-0.9897899, -1.0334952], dtype=float32)>]
np.sum(log_prob_parts) - jds.log_prob(sample)
<tf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)>
ดังนั้นระดับหนึ่งของคำอธิบายคือการคำนวณความน่าจะเข้าสู่ระบบจะกลับมา 7 Tensor เพราะย่อยที่สามของ log_prob_parts เป็น 7 Tensor แต่ทำไม?
ดีเราจะเห็นว่าองค์ประกอบสุดท้ายของ dists ซึ่งสอดคล้องกับการจัดจำหน่ายของเรามากกว่า Y ในการกำหนด mathematial มี batch_shape ของ [7] ในคำอื่น ๆ การจัดจำหน่ายของเรามากกว่า Y เป็นชุดที่ 7 ปกติอิสระ (ด้วยวิธีการที่แตกต่างกันและในกรณีนี้ระดับเดียวกัน)
ตอนนี้เราเข้าใจอะไรผิดปกติใน JDS กระจายมากกว่า Y มี batch_shape=[7] , ตัวอย่างจาก JDS หมายถึงสเกลาสำหรับ m และ b และ "ชุด" 7 ปกติอิสระ และ log_prob คำนวณ 7 เข้าสู่ระบบน่าจะแยกกันซึ่งแสดงให้เห็นถึงความน่าจะเป็นบันทึกของการวาด m และ b และสังเกตเดียว Y[i] ในบาง X[i]
แก้ไข log_prob(sample()) ที่มี Independent
จำได้ว่า dists[2] มี event_shape=[] และ batch_shape=[7] :
dists[2]
<tfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32>
โดยใช้ TFP ของ Independent metadistribution ซึ่งจะแปลงมิติชุดขนาดเหตุการณ์เราสามารถแปลงนี้ในการกระจายกับ event_shape=[7] และ batch_shape=[] (เราจะเปลี่ยนชื่อ y_dist_i เพราะมันเป็นการกระจายบน Y กับ _i ยืน สำหรับเรา Independent ห่อ):
y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)
y_dist_i
<tfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32>
ตอนนี้ log_prob ของ 7-เวกเตอร์เป็นสเกลาร์:
y_dist_i.log_prob(sample[2])
<tf.Tensor: shape=(), dtype=float32, numpy=-7.9087086>
ภายใต้ครอบคลุม Independent จำนวนเงินกว่าชุด:
y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
และแน่นอนเราสามารถใช้เพื่อสร้างใหม่ jds_i (คน i อีกครั้งหมายถึง Independent ) ซึ่ง log_prob ส่งกลับเกลา:
jds_i = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m*X + b, scale=1.),
reinterpreted_batch_ndims=1)
])
jds_i.log_prob(sample)
<tf.Tensor: shape=(), dtype=float32, numpy=-11.355776>
หมายเหตุสองสาม:
-
jds_i.log_prob(s)ไม่ได้เป็นเช่นเดียวกับtf.reduce_sum(jds.log_prob(s))อดีตสร้างความน่าจะเป็นบันทึก "ถูกต้อง" ของการแจกแจงร่วม จำนวนเงินที่หลังกว่า 7 Tensor แต่ละองค์ประกอบซึ่งเป็นผลรวมของความน่าจะเป็นบันทึกของm,bและองค์ประกอบหนึ่งของความน่าจะเข้าสู่ระบบของYจึง overcountsmและb(log_prob(m) + log_prob(b) + log_prob(Y)ผลตอบแทนมากกว่าการขว้างปายกเว้นเพราะ TFP ดังนี้ TF และกฎระเบียบของการออกอากาศ NumPy. เพิ่มเกลาเวกเตอร์ผลิตผลเวกเตอร์ขนาดใหญ่) - ในกรณีนี้โดยเฉพาะอย่างยิ่งที่เราจะได้มีการแก้ไขปัญหาที่เกิดขึ้นและประสบผลเดียวกันโดยใช้
MultivariateNormalDiagแทนIndependent(Normal(...))MultivariateNormalDiagคือการกระจายเวกเตอร์ (คือมันมีอยู่แล้วเวกเตอร์เหตุการณ์รูปร่าง) IndeeedMultivariateNormalDiagอาจจะ ( แต่ไม่) นำมาใช้เป็นองค์ประกอบของIndependentและNormalมันคุ้มค่าที่จะจำไว้ว่าให้เวกเตอร์V, ตัวอย่างจากn1 = Normal(loc=V)และn2 = MultivariateNormalDiag(loc=V)จะแยกไม่ออก; ความแตกต่าง beween กระจายเหล่านี้คือn1.log_prob(n1.sample())เป็นเวกเตอร์และn2.log_prob(n2.sample())เป็นสเกลา
หลายตัวอย่าง?
การวาดตัวอย่างหลายตัวอย่างยังคงใช้งานไม่ได้:
try:
jds_i.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
ลองคิดดูว่าทำไม เมื่อเราเรียก jds_i.sample([5, 3]) เราจะวาดครั้งแรกตัวอย่างสำหรับ m และ b แต่ละคนมีรูปร่าง (5, 3) ต่อไปเราจะพยายามที่จะสร้าง Normal กระจายผ่าน:
tfd.Normal(loc=m*X + b, scale=1.)
แต่ถ้า m มีรูปร่าง (5, 3) และ X มีรูปร่างที่ 7 เราไม่สามารถคูณพวกเขาร่วมกันและแน่นอนนี่เป็นข้อผิดพลาดที่เรากำลังชน:
m = tfd.Normal(0., 1.).sample([5, 3])
try:
m * X
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3] vs. [7] [Op:Mul]
เมื่อต้องการแก้ไขปัญหานี้ให้คิดเกี่ยวกับสิ่งที่คุณสมบัติการกระจายมากกว่า Y ต้องมี ถ้าเราได้เรียก jds_i.sample([5, 3]) แล้วเรารู้ว่า m และ b ทั้งสองจะมีรูปร่าง (5, 3) สิ่งที่รูปร่างควรเรียกร้องให้ sample ใน Y ผลิตการจัดจำหน่าย? คำตอบที่ชัดเจนคือ (5, 3, 7) : จุดชุดแต่ละที่เราต้องการตัวอย่างที่มีขนาดเดียวกับ X เราสามารถทำได้โดยใช้ความสามารถในการออกอากาศของ TensorFlow เพิ่มมิติพิเศษ:
m[..., tf.newaxis].shape
TensorShape([5, 3, 1])
(m[..., tf.newaxis] * X).shape
TensorShape([5, 3, 7])
เพิ่มแกนทั้ง m และ b เราสามารถกำหนด JDS ใหม่ที่รองรับหลายตัวอย่าง:
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
shaped_sample = jds_ia.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-1.1133379 , 0.16390413, -0.24177533],
[-1.1312429 , -0.6224666 , -1.8182136 ],
[-0.31343174, -0.32932565, 0.5164407 ],
[-0.0119963 , -0.9079621 , 2.3655841 ],
[-0.26293617, 0.8229698 , 0.31098196]], dtype=float32)>,
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-0.02876974, 1.0872147 , 1.0138507 ],
[ 0.27367726, -1.331534 , -0.09084719],
[ 1.3349475 , -0.68765205, 1.680652 ],
[ 0.75436825, 1.3050154 , -0.9415123 ],
[-1.2502679 , -0.25730947, 0.74611956]], dtype=float32)>,
<tf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=
array([[[-1.8258233e+00, -3.0641669e-01, -2.7595463e+00, -1.6952467e+00,
-4.8197951e+00, -5.2986512e+00, -6.6931367e+00],
[ 3.6438566e-01, 1.0067395e+00, 1.4542470e+00, 8.1155670e-01,
1.8868095e+00, 2.3877139e+00, 1.0195159e+00],
[-8.3624744e-01, 1.2518480e+00, 1.0943471e+00, 1.3052304e+00,
-4.5756745e-01, -1.0668410e-01, -7.0669651e-02]],
[[-3.1788960e-01, 9.2615485e-03, -3.0963073e+00, -2.2846246e+00,
-3.2269263e+00, -6.0213070e+00, -7.4806519e+00],
[-3.9149747e+00, -3.5155020e+00, -1.5669601e+00, -5.0759468e+00,
-4.5065498e+00, -5.6719379e+00, -4.8012795e+00],
[ 1.3053948e-01, -8.0493152e-01, -4.7845001e+00, -4.9721808e+00,
-7.1365709e+00, -9.6198196e+00, -9.7951422e+00]],
[[ 2.0621397e+00, 3.4639853e-01, 7.0252883e-01, -1.4311566e+00,
3.3790007e+00, 1.1619035e+00, -8.9105040e-01],
[-7.8956139e-01, -8.5023916e-01, -9.7148323e-01, -2.6229355e+00,
-2.7150445e+00, -2.4633870e+00, -2.1841538e+00],
[ 7.7627432e-01, 2.2401071e+00, 3.7601702e+00, 2.4245868e+00,
4.0690269e+00, 4.0605016e+00, 5.1753912e+00]],
[[ 1.4275590e+00, 3.3346462e+00, 1.5374103e+00, -2.2849756e-01,
9.1219616e-01, -3.1220305e-01, -3.2643962e-01],
[-3.1910419e-02, -3.8848895e-01, 9.9946201e-02, -2.3619974e+00,
-1.8507402e+00, -3.6830821e+00, -5.4907336e+00],
[-7.1941972e-02, 2.1602919e+00, 4.9575748e+00, 4.2317696e+00,
9.3528280e+00, 1.0526063e+01, 1.5262107e+01]],
[[-2.3257759e+00, -2.5343289e+00, -3.5342445e+00, -4.0423255e+00,
-3.2361765e+00, -3.3434000e+00, -2.6849220e+00],
[ 1.5006512e-02, -1.9866472e-01, 7.6781356e-01, 1.6228745e+00,
1.4191239e+00, 2.6655579e+00, 4.4663467e+00],
[ 2.6599693e+00, 1.2663836e+00, 1.7162113e+00, 1.4839669e+00,
2.0559487e+00, 2.5976877e+00, 2.5977583e+00]]], dtype=float32)>]
jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.483114 , -10.139662 , -11.514159 ],
[-11.656767 , -17.201958 , -12.132455 ],
[-17.838818 , -9.474525 , -11.24898 ],
[-13.95219 , -12.490049 , -17.123957 ],
[-14.487818 , -11.3755455, -10.576363 ]], dtype=float32)>
เพื่อเป็นการตรวจสอบเพิ่มเติม เราจะตรวจสอบว่าความน่าจะเป็นของบันทึกสำหรับจุดชุดเดียวตรงกับสิ่งที่เรามีก่อนหน้านี้:
(jds_ia.log_prob(shaped_sample)[3, 1] -
jds_i.log_prob([shaped_sample[0][3, 1],
shaped_sample[1][3, 1],
shaped_sample[2][3, 1, :]]))
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
แบทช์อัตโนมัติเพื่อชัยชนะ
ยอดเยี่ยม! ตอนนี้เรามีรุ่นของ JointDistribution ที่จับทุก Desiderata ของเรา: log_prob ผลตอบแทนขอบคุณเกลากับการใช้ tfd.Independent และตัวอย่างหลายทำงานในขณะนี้ว่าเราคงออกอากาศโดยการเพิ่มแกนพิเศษ
ถ้าฉันบอกคุณว่ามีวิธีที่ง่ายกว่าและดีกว่านี้ มีและก็เรียกว่า JointDistributionSequentialAutoBatched (JDSAB):
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
jds_ab.log_prob(jds.sample())
<tf.Tensor: shape=(), dtype=float32, numpy=-12.954952>
shaped_sample = jds_ab.sample([5, 3])
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-12.191533 , -10.43885 , -16.371655 ],
[-13.292994 , -11.97949 , -16.788685 ],
[-15.987699 , -13.435732 , -10.6029 ],
[-10.184758 , -11.969714 , -14.275676 ],
[-12.740775 , -11.5654125, -12.990162 ]], dtype=float32)>
jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float32)>
มันทำงานอย่างไร? ขณะที่คุณสามารถพยายามที่จะ อ่านรหัส เพื่อความเข้าใจที่ลึกเราจะให้ภาพรวมคร่าวๆซึ่งเพียงพอสำหรับกรณีการใช้งานมากที่สุด:
- จำได้ว่าปัญหาที่เกิดขึ้นครั้งแรกของเราคือการที่จัดจำหน่ายของเราสำหรับ
Yมีbatch_shape=[7]และevent_shape=[]และเราใช้Independentในการแปลงมิติชุดไปยังมิติเหตุการณ์ JDSAB ละเว้นรูปร่างแบทช์ของการแจกแจงส่วนประกอบ แทนที่จะให้การปฏิบัติต่อรูปร่างชุดเป็นคุณสมบัติโดยรวมของรูปแบบซึ่งจะถือว่าเป็น[](ยกเว้นกรณีที่ระบุไว้เป็นอย่างอื่นโดยการตั้งค่าbatch_ndims > 0) ผลที่ได้คือเทียบเท่ากับการใช้ tfd.Independent การแปลงขนาดชุดทั้งหมดของการแจกแจงองค์ประกอบเข้าไปในมิติเหตุการณ์ที่เราทำด้วยตนเองดังกล่าวข้างต้น - ปัญหาที่สองของเราคือความต้องการที่จะนวดรูปร่างของ
mและbเพื่อให้พวกเขาสามารถออกอากาศได้อย่างเหมาะสมกับXเมื่อมีการสร้างหลายตัวอย่าง ด้วย JDSAB คุณเขียนรูปแบบในการสร้างตัวอย่างเดียวและเรา "ยก" รูปแบบทั้งการสร้างตัวอย่างหลายคนโดยใช้ TensorFlow ของ vectorized_map (คุณลักษณะนี้คล้ายคลึงเพื่อ JAX ของ VMAP .)
สำรวจปัญหารูปร่างชุดในรายละเอียดมากขึ้นเราสามารถเปรียบเทียบรูปร่างชุดของเรา "ไม่ดี" เดิมร่วมกันจำหน่าย jds กระจายชุดถาวรของเรา jds_i และ jds_ia และ autobatched เรา jds_ab :
jds.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([7])]
jds_i.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ia.batch_shape
[TensorShape([]), TensorShape([]), TensorShape([])]
jds_ab.batch_shape
TensorShape([])
เราจะเห็นว่าเดิม jds มี subdistributions ที่มีรูปร่างที่แตกต่างกันชุด jds_i และ jds_ia แก้ไขปัญหานี้โดยการสร้าง subdistributions เดียวกับ (ว่าง) รูปร่างชุด jds_ab มีเพียง (ว่าง) รูปร่างชุดเดียว
มันน่าสังเกตว่า JointDistributionSequentialAutoBatched ข้อเสนอบางทั่วไปเพิ่มเติมฟรี สมมติว่าเราทำให้ตัวแปร X (และโดยปริยายสังเกต Y ) สองมิติ:
X = np.arange(14).reshape((2, 7))
X
array([[ 0, 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12, 13]])
เรา JointDistributionSequentialAutoBatched ทำงานโดยไม่มีการเปลี่ยนแปลง (เราจำเป็นที่จะกำหนดรูปแบบเพราะรูปร่างของ X จะถูกเก็บไว้โดย jds_ab.log_prob ):
jds_ab = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y
])
shaped_sample = jds_ab.sample([5, 3])
shaped_sample
[<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[ 0.1813647 , -0.85994506, 0.27593774],
[-0.73323774, 1.1153806 , 0.8841938 ],
[ 0.5127983 , -0.29271227, 0.63733214],
[ 0.2362284 , -0.919168 , 1.6648189 ],
[ 0.26317367, 0.73077047, 2.5395133 ]], dtype=float32)>,
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[ 0.09636458, 2.0138032 , -0.5054413 ],
[ 0.63941646, -1.0785882 , -0.6442188 ],
[ 1.2310615 , -0.3293852 , 0.77637213],
[ 1.2115169 , -0.98906034, -0.07816773],
[-1.1318136 , 0.510014 , 1.036522 ]], dtype=float32)>,
<tf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=
array([[[[-1.9685398e+00, -1.6832136e+00, -6.9127172e-01,
8.5992378e-01, -5.3123581e-01, 3.1584005e+00,
2.9044402e+00],
[-2.5645006e-01, 3.1554163e-01, 3.1186538e+00,
1.4272424e+00, 1.2843871e+00, 1.2266440e+00,
1.2798605e+00]],
[[ 1.5973477e+00, -5.3631151e-01, 6.8143606e-03,
-1.4910895e+00, -2.1568544e+00, -2.0513713e+00,
-3.1663666e+00],
[-4.9448099e+00, -2.8385928e+00, -6.9027486e+00,
-5.6543546e+00, -7.2378774e+00, -8.1577444e+00,
-9.3582869e+00]],
[[-2.1233239e+00, 5.8853775e-02, 1.2024102e+00,
1.6622503e+00, -1.9197327e-01, 1.8647723e+00,
6.4322817e-01],
[ 3.7549341e-01, 1.5853541e+00, 2.4594500e+00,
2.1952972e+00, 1.7517658e+00, 2.9666045e+00,
2.5468128e+00]]],
[[[ 8.9906776e-01, 6.7375046e-01, 7.3354661e-01,
-9.9894643e-01, -3.4606690e+00, -3.4810467e+00,
-4.4315586e+00],
[-3.0670738e+00, -6.3628020e+00, -6.2538433e+00,
-6.8091092e+00, -7.7134805e+00, -8.6319380e+00,
-8.6904278e+00]],
[[-2.2462025e+00, -3.3060855e-01, 1.8974400e-01,
3.1422038e+00, 4.1483402e+00, 3.5642972e+00,
4.8709240e+00],
[ 4.7880130e+00, 5.8790064e+00, 9.6695948e+00,
7.8112822e+00, 1.2022618e+01, 1.2411858e+01,
1.4323385e+01]],
[[-1.0189297e+00, -7.8115642e-01, 1.6466728e+00,
8.2378983e-01, 3.0765080e+00, 3.0170646e+00,
5.1899948e+00],
[ 6.5285158e+00, 7.8038850e+00, 6.4155884e+00,
9.0899811e+00, 1.0040427e+01, 9.1404457e+00,
1.0411951e+01]]],
[[[ 4.5557004e-01, 1.4905317e+00, 1.4904103e+00,
2.9777462e+00, 2.8620450e+00, 3.4745665e+00,
3.8295493e+00],
[ 3.9977460e+00, 5.7173767e+00, 7.8421035e+00,
6.3180594e+00, 6.0838981e+00, 8.2257290e+00,
9.6548376e+00]],
[[-7.0750320e-01, -3.5972297e-01, 4.3136525e-01,
-2.3301599e+00, -5.0374687e-01, -2.8338656e+00,
-3.4453444e+00],
[-3.1258626e+00, -3.4687450e+00, -1.2045374e+00,
-4.0196013e+00, -5.8831010e+00, -4.2965469e+00,
-4.1388311e+00]],
[[ 2.1969774e+00, 2.4614549e+00, 2.2314475e+00,
1.8392437e+00, 2.8367062e+00, 4.8600502e+00,
4.2273531e+00],
[ 6.1879644e+00, 5.1792760e+00, 6.1141996e+00,
5.6517797e+00, 8.9979610e+00, 7.5938139e+00,
9.7918644e+00]]],
[[[ 1.5249090e+00, 1.1388919e+00, 8.6903995e-01,
3.0762129e+00, 1.5128503e+00, 3.5204377e+00,
2.4760864e+00],
[ 3.4166217e+00, 3.5930209e+00, 3.1694956e+00,
4.5797420e+00, 4.5271711e+00, 2.8774328e+00,
4.7288942e+00]],
[[-2.3095846e+00, -2.0595703e+00, -3.0093951e+00,
-3.8594103e+00, -4.9681158e+00, -6.4256043e+00,
-5.5345035e+00],
[-6.4306297e+00, -7.0924540e+00, -8.4075985e+00,
-1.0417805e+01, -1.1727266e+01, -1.1196255e+01,
-1.1333830e+01]],
[[-7.0419472e-01, 1.4568675e+00, 3.7946482e+00,
4.8489718e+00, 6.6498446e+00, 9.0224218e+00,
1.1153137e+01],
[ 1.0060651e+01, 1.1998097e+01, 1.5326431e+01,
1.7957514e+01, 1.8323889e+01, 2.0160881e+01,
2.1269085e+01]]],
[[[-2.2360647e-01, -1.3632748e+00, -7.2704530e-01,
2.3558271e-01, -1.0381399e+00, 1.9387857e+00,
-3.3694571e-01],
[ 1.6015106e-01, 1.5284677e+00, -4.8567140e-01,
-1.7770648e-01, 2.1919653e+00, 1.3015286e+00,
1.3877077e+00]],
[[ 1.3688663e+00, 2.6602898e+00, 6.6657305e-01,
4.6554832e+00, 5.7781887e+00, 4.9115267e+00,
4.8446012e+00],
[ 5.1983776e+00, 6.2297459e+00, 6.3848300e+00,
8.4291229e+00, 7.1309576e+00, 1.0395646e+01,
8.5736713e+00]],
[[ 1.2675294e+00, 5.2844582e+00, 5.1331611e+00,
8.9993315e+00, 1.0794343e+01, 1.4039831e+01,
1.5731170e+01],
[ 1.9084715e+01, 2.2191265e+01, 2.3481146e+01,
2.5803375e+01, 2.8632090e+01, 3.0234968e+01,
3.1886738e+01]]]], dtype=float32)>]
jds_ab.log_prob(shaped_sample)
<tf.Tensor: shape=(5, 3), dtype=float32, numpy=
array([[-28.90071 , -23.052422, -19.851362],
[-19.775568, -25.894997, -20.302256],
[-21.10754 , -23.667885, -20.973007],
[-19.249458, -20.87892 , -20.573763],
[-22.351208, -25.457762, -24.648403]], dtype=float32)>
บนมืออื่น ๆ ที่เราสร้างขึ้นมาอย่างระมัดระวัง JointDistributionSequential ไม่ทำงาน:
jds_ia = tfd.JointDistributionSequential([
tfd.Normal(loc=0., scale=1.), # m
tfd.Normal(loc=0., scale=1.), # b
lambda b, m: tfd.Independent( # Y
tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),
reinterpreted_batch_ndims=1)
])
try:
jds_ia.sample([5, 3])
except tf.errors.InvalidArgumentError as e:
print(e)
Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]
เพื่อแก้ไขปัญหานี้เราจะต้องมีการเพิ่มเป็นครั้งที่สอง tf.newaxis ทั้ง m และ b ตรงกับรูปร่างและเพิ่ม reinterpreted_batch_ndims 2 ในการเรียกร้องให้ Independent ในกรณีนี้ การปล่อยให้เครื่องจักรอัตโนมัติจัดการปัญหารูปร่างนั้นสั้นลง ง่ายขึ้น และถูกหลักสรีรศาสตร์มากขึ้น
อีกครั้งหนึ่งที่เราทราบว่าในขณะนี้โน้ตบุ๊คสำรวจ JointDistributionSequentialAutoBatched , สายพันธุ์อื่น ๆ ของ JointDistribution มีเทียบเท่า AutoBatched (สำหรับผู้ใช้ JointDistributionCoroutine , JointDistributionCoroutineAutoBatched มีประโยชน์เพิ่มเติมที่คุณไม่จำเป็นอีกต่อไปเพื่อระบุ Root โหนดถ้าคุณไม่เคยใช้ JointDistributionCoroutine . คุณสามารถละเว้นคำสั่งนี้)
สรุปความคิด
ในสมุดบันทึกนี้เราแนะนำ JointDistributionSequentialAutoBatched และทำงานผ่านตัวอย่างง่ายๆในรายละเอียด หวังว่าคุณจะได้เรียนรู้บางอย่างเกี่ยวกับรูปร่าง TFP และเกี่ยวกับการทำแบทช์อัตโนมัติ!
ดูบน TensorFlow.org
ทำงานใน Google Colab
ดูแหล่งที่มาบน GitHub
ดาวน์โหลดโน๊ตบุ๊ค