คำนวณการแพร่กระจายกลับเซลล์ GRU สำหรับขั้นตอน 1 ครั้ง
Args x: อินพุตไปยังเซลล์ GRU h_prev: ระบุอินพุตจากเซลล์ GRU ก่อนหน้า w_ru: เมทริกซ์น้ำหนักสำหรับเกทรีเซ็ตและอัปเดต w_c: เมทริกซ์น้ำหนักสำหรับเกตการเชื่อมต่อเซลล์ b_ru: เวกเตอร์อคติสำหรับเกตรีเซ็ตและอัปเดต b_c: เวกเตอร์อคติสำหรับเกตการเชื่อมต่อเซลล์ r: เอาต์พุตของเกทรีเซ็ต u: เอาต์พุตของเกทอัพเดต c: เอาต์พุตของเกตการเชื่อมต่อเซลล์ d_h: การไล่ระดับของ h_new wrt เป็นฟังก์ชันวัตถุประสงค์
ส่งกลับ d_x: การไล่ระดับของ x wrt เป็นฟังก์ชันวัตถุประสงค์ d_h_prev: การไล่ระดับของ h wrt เป็นฟังก์ชันวัตถุประสงค์ d_c_bar การไล่ระดับสีของ c_bar wrt เป็นฟังก์ชันวัตถุประสงค์ d_r_bar_u_bar การไล่ระดับสีของ r_bar & u_bar wrt เป็นฟังก์ชันวัตถุประสงค์
เคอร์เนล op นี้ใช้สมการทางคณิตศาสตร์ต่อไปนี้:
หมายเหตุเกี่ยวกับสัญกรณ์ของตัวแปร:
การต่อกันของ a และ b ถูกแทนด้วย a_b Element-wise dot product ของ a และ b แทนด้วย ab Element-wise dot product ถูกแทนด้วย \circ การคูณเมทริกซ์ถูกแทนด้วย *
หมายเหตุเพิ่มเติมเพื่อความชัดเจน:
`w_ru` สามารถแบ่งได้เป็น 4 เมทริกซ์ที่แตกต่างกัน
w_ru = [w_r_x w_u_x
w_r_h_prev w_u_h_prev]
ในทำนองเดียวกัน `w_c` สามารถแบ่งออกเป็น 2 เมทริกซ์ที่แตกต่างกัน w_c = [w_c_x w_c_h_prevr]
กันไปสำหรับอคติ b_ru = [b_ru_x b_ru_h]
b_c = [b_c_x b_c_h]
หมายเหตุเกี่ยวกับสัญกรณ์อื่น ๆ : d_x = d_x_component_1 + d_x_component_2
where d_x_component_1 = d_r_bar * w_r_x^T + d_u_bar * w_r_x^T
and d_x_component_2 = d_c_bar * w_c_x^T
d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + d_h \circ u
where d_h_prev_componenet_1 = d_r_bar * w_r_h_prev^T + d_u_bar * w_r_h_prev^T
คณิตศาสตร์ที่อยู่เบื้องหลังการไล่ระดับสีดังต่อไปนี้: d_c_bar = d_h \circ (1-u) \circ (1-c \circ c)
d_u_bar = d_h \circ (h-c) \circ u \circ (1-u)
d_r_bar_u_bar = [d_r_bar d_u_bar]
[d_x_component_1 d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T
[d_x_component_2 d_h_prevr] = d_c_bar * w_c^T
d_x = d_x_component_1 + d_x_component_2
d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + u
(. ไม่ได้อยู่ในเคอร์เนลการไล่ระดับสี) ด้านล่างคำนวณจะดำเนินการในกระดาษห่อหลามสำหรับการไล่ระดับสี d_w_ru = x_h_prevr^T * d_c_bar
d_w_c = x_h_prev^T * d_r_bar_u_bar
d_b_ru = sum of d_r_bar_u_bar along axis = 0
d_b_c = sum of d_c_bar along axis = 0
วิธีการสาธารณะ
คง <T ขยายจำนวน> GRUBlockCellGrad <T> | |
เอาท์พุท <T> | dCBar () |
เอาท์พุท <T> | dHPrev () |
เอาท์พุท <T> | dRBarUBar () |
เอาท์พุท <T> | dX () |
วิธีการสืบทอด
วิธีการสาธารณะ
สาธารณะคง GRUBlockCellGrad <T> สร้าง ( ขอบเขต ขอบเขต Operand <T> x, Operand <T> hPrev, Operand <T> WRU, Operand <T> สุขา Operand <T> Bru, Operand <T> คริสตศักราช Operand <T > R, Operand <T> มึง Operand <T> C, Operand <T> DH)
วิธีการจากโรงงานเพื่อสร้างคลาสที่ปิดการดำเนินการ GRUBlockCellGrad ใหม่
พารามิเตอร์
ขอบเขต | ขอบเขตปัจจุบัน |
---|
คืนสินค้า
- อินสแตนซ์ใหม่ของ GRUBlockCellGrad