1 समय चरण के लिए GRU सेल बैक-प्रचार की गणना करता है।
आर्ग्स x: GRU सेल में इनपुट। h_prev: पिछले GRU सेल से इनपुट बताएं। w_ru: रीसेट और अपडेट गेट के लिए वजन मैट्रिक्स। w_c: सेल कनेक्शन गेट के लिए वजन मैट्रिक्स। b_ru: रीसेट और अपडेट गेट के लिए बायस वेक्टर। b_c: सेल कनेक्शन गेट के लिए बायस वेक्टर। आर: रीसेट गेट का आउटपुट। यू: अद्यतन गेट का आउटपुट। सी: सेल कनेक्शन गेट का आउटपुट। d_h: h_new wrt का ऑब्जेक्टिव फ़ंक्शन में ग्रेडिएंट।
रिटर्न d_x: ऑब्जेक्टिव फ़ंक्शन के लिए x wrt का ग्रेडिएंट। d_h_prev: उद्देश्य फ़ंक्शन के लिए h wrt का ग्रेडिएंट। d_c_bar ऑब्जेक्टिव फ़ंक्शन के लिए c_bar के ग्रेडिएंट्स। d_r_bar_u_bar ऑब्जेक्टिव फ़ंक्शन के लिए r_bar और u_bar के ग्रेडियेंट।
यह कर्नेल ऑप निम्नलिखित गणितीय समीकरण लागू करता है:
चरों के अंकन पर ध्यान दें:
ए और बी के संयोजन को a_b द्वारा दर्शाया गया है ए और बी के तत्व-वार डॉट उत्पाद को एबी द्वारा दर्शाया गया है तत्व-वार डॉट उत्पाद को \circ द्वारा दर्शाया गया है मैट्रिक्स गुणन को * द्वारा दर्शाया गया है
स्पष्टता के लिए अतिरिक्त नोट्स:
`w_ru` को 4 अलग-अलग आव्यूहों में विभाजित किया जा सकता है।
w_ru = [w_r_x w_u_x
w_r_h_prev w_u_h_prev]
इसी तरह, `w_c` को 2 अलग-अलग आव्यूहों में विभाजित किया जा सकता है। w_c = [w_c_x w_c_h_prevr]
यही बात पक्षपात के लिए भी लागू होती है। b_ru = [b_ru_x b_ru_h]
b_c = [b_c_x b_c_h]
अंकन पर एक और नोट: d_x = d_x_component_1 + d_x_component_2
where d_x_component_1 = d_r_bar * w_r_x^T + d_u_bar * w_r_x^T
and d_x_component_2 = d_c_bar * w_c_x^T
d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + d_h \circ u
where d_h_prev_componenet_1 = d_r_bar * w_r_h_prev^T + d_u_bar * w_r_h_prev^T
नीचे दिए गए स्नातकों के पीछे का गणित: d_c_bar = d_h \circ (1-u) \circ (1-c \circ c)
d_u_bar = d_h \circ (h-c) \circ u \circ (1-u)
d_r_bar_u_bar = [d_r_bar d_u_bar]
[d_x_component_1 d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T
[d_x_component_2 d_h_prevr] = d_c_bar * w_c^T
d_x = d_x_component_1 + d_x_component_2
d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + u
नीचे दी गई गणना ग्रैडिएंट्स के लिए पायथन रैपर में की जाती है (ग्रेडिएंट कर्नेल में नहीं।) d_w_ru = x_h_prevr^T * d_c_bar
d_w_c = x_h_prev^T * d_r_bar_u_bar
d_b_ru = sum of d_r_bar_u_bar along axis = 0
d_b_c = sum of d_c_bar along axis = 0
सार्वजनिक तरीके
स्थिर <T संख्या बढ़ाता है> GRUBlockCellGrad <T> | |
आउटपुट <T> | डीसीबार () |
आउटपुट <T> | dHPrev () |
आउटपुट <T> | डीआरबारयूबार () |
आउटपुट <T> | डीएक्स () |
विरासत में मिले तरीके
सार्वजनिक तरीके
सार्वजनिक स्थैतिक GRUBlockCellGrad <T> बनाएं ( स्कोप स्कोप, ऑपरेंड <T> x, ऑपरेंड <T> hPrev, ऑपरेंड <T> wRu, ऑपरेंड <T> wC, ऑपरेंड <T> bRu, ऑपरेंड <T> bC, ऑपरेंड <T > आर, ऑपरेंड <टी> यू, ऑपरेंड <टी> सी, ऑपरेंड <टी> डीएच)
एक नया GRUBlockCellGrad ऑपरेशन लपेटकर क्लास बनाने की फ़ैक्टरी विधि।
पैरामीटर
दायरा | वर्तमान दायरा |
---|
रिटर्न
- GRUBlockCellGrad का एक नया उदाहरण