Calcola la retro-propagazione delle celle GRU per 1 fase temporale.
Args x: input per la cella GRU. h_prev: input di stato dalla cella GRU precedente. w_ru: matrice di peso per il gate di ripristino e aggiornamento. w_c: matrice di peso per il gate di connessione della cella. b_ru: vettore di polarizzazione per il gate di ripristino e aggiornamento. b_c: vettore di polarizzazione per la porta di connessione della cella. r: Uscita del gate di reset. u: output del gate di aggiornamento. c: Uscita del gate di connessione della cella. d_h: gradienti della funzione h_new rispetto all'obiettivo.
Restituisce d_x: gradienti della x rispetto alla funzione obiettivo. d_h_prev: gradienti della h rispetto alla funzione obiettivo. d_c_bar Gradienti della c_bar rispetto alla funzione obiettivo. d_r_bar_u_bar Gradienti della r_bar & u_bar rispetto alla funzione obiettivo.
Questo kernel operativo implementa le seguenti equazioni matematiche:
Nota sulla notazione delle variabili:
La concatenazione di aeb è rappresentata da a_b Il prodotto scalare in termini di elementi di a e b è rappresentato da ab Il prodotto scalare in termini di elementi è rappresentato da \ circ La moltiplicazione della matrice è rappresentata da *
Note aggiuntive per chiarezza:
"w_ru" può essere segmentato in 4 matrici differenti.
w_ru = [w_r_x w_u_x
w_r_h_prev w_u_h_prev]
Allo stesso modo, `w_c` può essere segmentato in 2 matrici differenti. w_c = [w_c_x w_c_h_prevr]
Lo stesso vale per i pregiudizi. b_ru = [b_ru_x b_ru_h]
b_c = [b_c_x b_c_h]
Un'altra nota sulla notazione: d_x = d_x_component_1 + d_x_component_2
where d_x_component_1 = d_r_bar * w_r_x^T + d_u_bar * w_r_x^T
and d_x_component_2 = d_c_bar * w_c_x^T
d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + d_h \circ u
where d_h_prev_componenet_1 = d_r_bar * w_r_h_prev^T + d_u_bar * w_r_h_prev^T
Matematica dietro i gradienti di seguito: d_c_bar = d_h \circ (1-u) \circ (1-c \circ c)
d_u_bar = d_h \circ (h-c) \circ u \circ (1-u)
d_r_bar_u_bar = [d_r_bar d_u_bar]
[d_x_component_1 d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T
[d_x_component_2 d_h_prevr] = d_c_bar * w_c^T
d_x = d_x_component_1 + d_x_component_2
d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + u
Il calcolo sotto viene eseguito nel wrapper python per i gradienti (non nel kernel dei gradienti). d_w_ru = x_h_prevr^T * d_c_bar
d_w_c = x_h_prev^T * d_r_bar_u_bar
d_b_ru = sum of d_r_bar_u_bar along axis = 0
d_b_c = sum of d_c_bar along axis = 0
Metodi pubblici
static <T estende Number> GRUBlockCellGrad <T> | |
Uscita <T> | dCBar () |
Uscita <T> | dHPrev () |
Uscita <T> | dRBarUBar () |
Uscita <T> | dX () |
Metodi ereditati
Metodi pubblici
public static GRUBlockCellGrad <T> create ( Scope scope, Operand <T> x, Operand <T> hPrev, Operand <T> wRu, Operand <T> wC, Operand <T> bRu, Operand <T> bC, Operand <T > r, Operando <T> u, Operando <T> c, Operando <T> dH)
Metodo Factory per creare una classe che racchiude una nuova operazione GRUBlockCellGrad.
Parametri
scopo | ambito attuale |
---|
ritorna
- una nuova istanza di GRUBlockCellGrad