Returns a subgradient of the MaximumSpanningTree op.
text.max_spanning_tree_gradient(
mst_op, d_loss_d_max_scores, *_
)
Note that MaximumSpanningTree is only differentiable w.r.t. its |scores| input
and its |max_scores| output.
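A short derivation shows why the result is a subgradient rather than a gradient. Assume, for illustration, that the score of a tree is the sum of the scores of its arcs. Write $\mathcal{T}_b$ for the set of valid trees of graph b, $y_t$ for the source that tree $y$ selects for node t, $a_{b,t}$ for the source selected by the maximal tree, and $\mathcal{L}$ for the network loss. This notation is introduced here and is not used by the op. Under that assumption, |max_scores| is a maximum of functions that are linear in |scores|. It is therefore convex and piecewise linear, and it is not differentiable where two trees tie. The indicator of the maximal tree is a subgradient, and it is the exact gradient wherever that tree is unique. The chain rule then scales the indicator by the upstream gradient:

$$
\mathrm{max\_scores}_b = \max_{y \in \mathcal{T}_b} \sum_{t} \mathrm{scores}_{b,t,y_t} = \sum_{t} \mathrm{scores}_{b,t,a_{b,t}},
\qquad
\frac{\partial \mathcal{L}}{\partial\, \mathrm{scores}_{b,t,s}} = \frac{\partial \mathcal{L}}{\partial\, \mathrm{max\_scores}_b}\, \mathbb{1}[s = a_{b,t}]
$$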
Args:
- `mst_op`: The MaximumSpanningTree op being differentiated.
- `d_loss_d_max_scores`: [B] vector, with B the batch size, where entry b is the gradient of the network loss w.r.t. entry b of the |max_scores| output of |mst_op|.
- `*_`: The gradients w.r.t. the other outputs; ignored.
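The following pure-Python sketch shows how the upstream gradient can be turned into a gradient w.r.t. |scores|. It assumes that the source selected for each node is available, for example from the op's |argmax_sources| output. In the sketch this appears as a hypothetical nested-list argument `argmax_sources`, not as `mst_op` itself. The function name `mst_scores_subgradient` and the padding convention (a negative source marks a padded node) are also assumptions made for illustration, not the library's code.

```python
from typing import List, Sequence


def mst_scores_subgradient(
    argmax_sources: Sequence[Sequence[int]],
    d_loss_d_max_scores: Sequence[float],
) -> List[List[List[float]]]:
    """Returns a [B][M][M] nested list holding a subgradient of the loss w.r.t. scores.

    argmax_sources[b][t] is the source s selected for node t in tree b
    (hypothetical input; assumed convention: s == t marks the root).
    """
    grads = []
    for b, sources in enumerate(argmax_sources):
        m = len(sources)
        # All-zero [M][M] slice: arcs outside the maximal tree get no gradient.
        slice_b = [[0.0] * m for _ in range(m)]
        for t, s in enumerate(sources):
            # Assumption: an out-of-range source (e.g. -1) marks a padded node
            # and contributes nothing, like an out-of-range index in tf.one_hot.
            if 0 <= s < m:
                # One-hot indicator of the selected arc s -> t, scaled by the
                # upstream gradient of max_scores[b] (chain rule).
                slice_b[t][s] = float(d_loss_d_max_scores[b])
        grads.append(slice_b)
    return grads


# Batch of two 3-node trees; the second tree has one padded node.
print(mst_scores_subgradient([[0, 0, 1], [1, 1, -1]], [2.0, -0.5]))
```

In each [M][M] slice, every row t for a real node holds a single nonzero entry. That entry sits in the column of the selected source and equals the batch entry's upstream gradient. The padded node's row stays all zeros.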
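Returns:
A pair of gradients, one per input of the op, in input order:
- None, since the op is not differentiable w.r.t. its |num_nodes| input.
- [B,M,M] tensor where entry b,t,s is a subgradient of the network loss w.r.t. entry b,t,s of the |scores| input, with the same dtype as |d_loss_d_max_scores|. Here B is the batch size and M is the (padded) number of nodes per graph.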