# Gradient for the `multiply` op.
@RegisterGradient("multiply")
def _multiply_gradient(op, grad):
    """Backpropagates the incoming gradient through a `multiply` op.

    Each input receives the incoming gradient multiplied by the *other*
    input of the op.

    Args:
      op: The `multiply` `Operation` being differentiated. Only its first
        two inputs are read.
      grad: Gradient with respect to the output of the `multiply` op.

    Returns:
      A two-element list `[grad_wrt_input0, grad_wrt_input1]`, ordered to
      match `op.inputs`.
    """
    lhs, rhs = op.inputs[0], op.inputs[1]

    # NOTE(review): no reduction is applied to either product, so each result
    # has whatever shape `grad * input` produces. If the op can broadcast its
    # inputs, confirm that this is what callers expect.
    grad_wrt_lhs = grad * rhs
    grad_wrt_rhs = grad * lhs
    return [grad_wrt_lhs, grad_wrt_rhs]
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX