Gradient for matmul
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
@RegisterGradient("matmul")
def _matmul_gradient(op, grad):
    """Backpropagate a gradient through a `matmul` operation.

    Args:
      op: The `matmul` `Operation` being differentiated. `op.inputs[0]`
        is read as the left operand and `op.inputs[1]` as the right one.
      grad: Gradient with respect to the output of the `matmul` op.

    Returns:
      A two-element list: the gradient with respect to the left operand,
      `grad.dot(right.T)`, followed by the gradient with respect to the
      right operand, `left.T.dot(grad)`.
    """
    # NOTE(review): presumably both operands and `grad` are 2-D arrays
    # exposing `.T` and `.dot` (e.g. NumPy) -- confirm against the caller.
    left = op.inputs[0]
    right = op.inputs[1]

    wrt_left = grad.dot(right.T)
    wrt_right = left.T.dot(grad)
    return [wrt_left, wrt_right]
Enter to Rename, Shift+Enter to Preview