GRUBlockCellGrad

שיעור הגמר הציבורי GRUBlockCellGrad

מחשב את ההפצה לאחור של תא GRU עבור שלב אחד.

Args x: קלט לתא GRU. h_prev: קלט מצב מתא GRU הקודם. w_ru: מטריצת משקל עבור שער האיפוס והעדכון. w_c: מטריצת משקל עבור שער חיבור התא. b_ru: וקטור הטיה עבור שער האיפוס והעדכון. b_c: וקטור הטיה עבור שער חיבור התא. r: פלט של שער האיפוס. u: פלט של שער העדכון. ג: פלט של שער חיבור התא. d_h: הדרגות של הפונקציה h_new wrt לאובייקטיבית.

מחזירה d_x: מעברי צבע של x wrt לפונקציה אובייקטיבית. d_h_prev: הדרגות של ה-h wrt לפונקציה אובייקטיבית. d_c_bar שיפועים של c_bar wrt לפונקציה אובייקטיבית. d_r_bar_u_bar שיפועים של r_bar & u_bar עם פונקציה אובייקטיבית.

קרנל אופ זה מיישם את המשוואות המתמטיות הבאות:

הערה לגבי סימון המשתנים:

שרשור של a ו-b מיוצג על ידי a_b מכפלת נקודה מבחינת היסודות של a ו-b מיוצגת על ידי ab מכפלת נקודה מבחינה אלמנט מיוצגת על ידי \circ כפל מטריצה ​​מיוצג על ידי *

הערות נוספות לבהירות:

ניתן לפלח את `w_ru` ל-4 מטריצות שונות.

w_ru = [w_r_x w_u_x
         w_r_h_prev w_u_h_prev]
 
באופן דומה, 'w_c' ניתן לפלח ל-2 מטריצות שונות.
w_c = [w_c_x w_c_h_prevr]
 
כנ"ל לגבי הטיות.
b_ru = [b_ru_x b_ru_h]
 b_c = [b_c_x b_c_h]
 
הערה נוספת על סימון:
d_x = d_x_component_1 + d_x_component_2
 
 where d_x_component_1 = d_r_bar * w_r_x^T + d_u_bar * w_r_x^T
 and d_x_component_2 = d_c_bar * w_c_x^T
 
 d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + d_h \circ u
 where d_h_prev_componenet_1 = d_r_bar * w_r_h_prev^T + d_u_bar * w_r_h_prev^T
 
מתמטיקה מאחורי הדרגתיים למטה:
d_c_bar = d_h \circ (1-u) \circ (1-c \circ c)
 d_u_bar = d_h \circ (h-c) \circ u \circ (1-u)
 
 d_r_bar_u_bar = [d_r_bar d_u_bar]
 
 [d_x_component_1 d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T
 
 [d_x_component_2 d_h_prevr] = d_c_bar * w_c^T
 
 d_x = d_x_component_1 + d_x_component_2
 
 d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + u
 
החישוב להלן מתבצע במעטפת הפיתון עבור ה- Gradients (לא בליבת ה-gradient.)
d_w_ru = x_h_prevr^T * d_c_bar
 
 d_w_c = x_h_prev^T * d_r_bar_u_bar
 
 d_b_ru = sum of d_r_bar_u_bar along axis = 0
 
 d_b_c = sum of d_c_bar along axis = 0
 

שיטות ציבוריות

סטטי <T מרחיב מספר> GRUBlockCellGrad <T>
צור ( Scope scope, Operand <T> x, Operand <T> hPrev, Operand <T> wRu, Operand <T> wC, Operand <T> bRu, Operand <T> bC, Operand <T> r, Operand <T > u, Operand <T> c, Operand <T> dH)
שיטת מפעל ליצירת מחלקה העוטפת פעולת GRUBlockCellGrad חדשה.
פלט <T>
dCBar ()
פלט <T>
dHPrev ()
פלט <T>
פלט <T>
dX ()

שיטות בירושה

שיטות ציבוריות

ציבורי סטטי GRUBlockCellGrad <T> create ( Scope scope, Operand <T> x, Operand <T> hPrev, Operand <T> wRu, Operand <T> wC, Operand <T> bRu, Operand <T> bC, Operand <T > r, Operand <T> u, Operand <T> c, Operand <T> dH)

שיטת מפעל ליצירת מחלקה העוטפת פעולת GRUBlockCellGrad חדשה.

פרמטרים
תְחוּם ההיקף הנוכחי
החזרות
  • מופע חדש של GRUBlockCellGrad

פלט ציבורי <T> dCBar ()

פלט ציבורי <T> dHPrev ()

פלט ציבורי <T> dRBarUBar ()

פלט ציבורי <T> dX ()