Gradient for softmax
For one row of logits x with softmax output s, the Jacobian is ∂s_i/∂x_j = s_i(δ_ij − s_j). Multiplying the upstream gradient g by this Jacobian gives component i as s_i · (g_i − Σ_j g_j s_j), i.e. softmax * (grad − sum(grad * softmax)) applied row by row; the reshape below broadcasts each row's dot product back across that row.

import numpy as np

@RegisterGradient("softmax")
def _softmax_gradient(op, grad):
    """Computes the gradients for `softmax`.

    Args:
      op: The `softmax` `Operation` that we are differentiating.
      grad: Gradient with respect to the output of the `softmax` op.

    Returns:
      Gradients with respect to the input of `softmax`.
    """
    softmax = op.output

    # Row-wise Jacobian-vector product:
    # softmax * (grad - sum(grad * softmax, axis=1)), with the per-row
    # sum reshaped to a column so it broadcasts across each row.
    return (grad - np.reshape(
        np.sum(grad * softmax, 1),
        [-1, 1]
    )) * softmax
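As a quick sanity check, the vectorized expression can be verified against the explicit per-row Jacobian diag(s) − s sᵀ using plain NumPy. This is a minimal standalone sketch, not part of the tutorial's framework; the names x, g, e, s and the batch shape are illustrative.

# Numerical check: vectorized softmax gradient vs. explicit Jacobian.
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))            # a batch of 4 logit vectors
g = rng.normal(size=(4, 3))            # an arbitrary upstream gradient

# Row-wise softmax (subtract the row max for numerical stability).
e = np.exp(x - x.max(axis=1, keepdims=True))
s = e / e.sum(axis=1, keepdims=True)

# Vectorized gradient, exactly as in _softmax_gradient above.
vec = (g - np.reshape(np.sum(g * s, 1), [-1, 1])) * s

# Reference: apply each row's full Jacobian diag(s) - s s^T explicitly.
ref = np.empty_like(vec)
for i in range(x.shape[0]):
    J = np.diag(s[i]) - np.outer(s[i], s[i])
    ref[i] = J @ g[i]                  # J is symmetric, so J^T g == J g

assert np.allclose(vec, ref)           # both paths agree

Because the Jacobian is symmetric, multiplying by J or Jᵀ gives the same result, which is why the vectorized form needs no transpose.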