Gradient for sigmoid
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
@RegisterGradient("sigmoid")
def _sigmoid_gradient(op, grad):
    """Backpropagate an incoming gradient through a `sigmoid` op.

    Relies on the identity sigma'(x) = sigma(x) * (1 - sigma(x)), so the
    local derivative can be formed from the op's output value alone.

    Args:
        op: The `sigmoid` `Operation` being differentiated. Its ``output``
            attribute is assumed to hold the sigmoid value sigma(x) from the
            forward pass.
        grad: Gradient with respect to the output of the `sigmoid` op.

    Returns:
        Gradient with respect to the input of `sigmoid`.
    """
    s = op.output
    # Multiply in the same order as grad * s * (1 - s) so floating-point
    # rounding is unchanged.
    scaled = grad * s
    complement = 1 - s
    return scaled * complement
Enter to Rename, Shift+Enter to Preview