Menghitung propagasi balik sel GRU untuk 1 langkah waktu.
Args x: Masukan ke sel GRU. h_prev: Menyatakan input dari sel GRU sebelumnya. w_ru: Matriks bobot untuk gerbang reset dan pembaruan. w_c: Matriks bobot untuk gerbang koneksi sel. b_ru: Vektor bias untuk gerbang reset dan update. b_c: Vektor bias untuk gerbang koneksi sel. r: Output dari gerbang reset. u: Output dari gerbang pembaruan. c: Output dari gerbang koneksi sel. d_h: Gradien dari h_new wrt ke fungsi objektif.
Mengembalikan d_x: Gradien dari x wrt ke fungsi tujuan. d_h_prev: Gradien dari h wrt ke fungsi tujuan. d_c_bar Gradien dari c_bar wrt ke fungsi objektif. d_r_bar_u_bar Gradien dari r_bar & u_bar wrt ke fungsi objektif.
Operasi kernel ini mengimplementasikan persamaan matematika berikut:
Catatan tentang notasi variabel:
Penggabungan a dan b diwakili oleh a_b Hasil kali titik elemen a dan b diwakili oleh ab Produk titik elemen-bijaksana diwakili oleh \circ Perkalian matriks diwakili oleh *
Catatan tambahan untuk kejelasan:
`w_ru` dapat disegmentasikan menjadi 4 matriks yang berbeda.
w_ru = [w_r_x w_u_x
w_r_h_prev w_u_h_prev]
Demikian pula, `w_c` dapat dibagi menjadi 2 matriks yang berbeda. w_c = [w_c_x w_c_h_prevr]
Sama berlaku untuk bias. b_ru = [b_ru_x b_ru_h]
b_c = [b_c_x b_c_h]
lain catatan pada notasi: d_x = d_x_component_1 + d_x_component_2
where d_x_component_1 = d_r_bar * w_r_x^T + d_u_bar * w_r_x^T
and d_x_component_2 = d_c_bar * w_c_x^T
d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + d_h \circ u
where d_h_prev_componenet_1 = d_r_bar * w_r_h_prev^T + d_u_bar * w_r_h_prev^T
Matematika di balik Gradien bawah: d_c_bar = d_h \circ (1-u) \circ (1-c \circ c)
d_u_bar = d_h \circ (h-c) \circ u \circ (1-u)
d_r_bar_u_bar = [d_r_bar d_u_bar]
[d_x_component_1 d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T
[d_x_component_2 d_h_prevr] = d_c_bar * w_c^T
d_x = d_x_component_1 + d_x_component_2
d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + u
(. tidak dalam kernel gradien) Berikut perhitungan dilakukan dalam pembungkus python untuk Gradien d_w_ru = x_h_prevr^T * d_c_bar
d_w_c = x_h_prev^T * d_r_bar_u_bar
d_b_ru = sum of d_r_bar_u_bar along axis = 0
d_b_c = sum of d_c_bar along axis = 0
Metode Publik
statis <T meluas Nomor> GRUBlockCellGrad <T> | |
Output <T> | dCBar () |
Output <T> | dHPrev () |
Output <T> | dRBarUBar () |
Output <T> | dX () |
Metode yang Diwarisi
Metode Publik
public static GRUBlockCellGrad <T> membuat ( Lingkup lingkup, Operan <T> x, Operan <T> hPrev, Operan <T> WRU, Operan <T> WC, Operan <T> Bru, Operan <T> bC, Operan <T > r, Operan <T> u, Operan <T> c, Operan <T> dH)
Metode pabrik untuk membuat kelas yang membungkus operasi GRUBlockCellGrad baru.
Parameter
cakupan | lingkup saat ini |
---|
Kembali
- contoh baru GRUBlockCellGrad