Gradient for add
@RegisterGradient("add")
def _add_gradient(op, grad):
    """Computes the gradients for `add`.

    Args:
        op: The `add` `Operation` that we are differentiating.
        grad: Gradient with respect to the output of the `add` op.

    Returns:
        Gradients with respect to the inputs of `add`.
    """
    a = op.inputs[0]
    b = op.inputs[1]

    # The forward pass may have broadcast a or b to a larger shape, so the
    # upstream gradient has the broadcast shape and must be summed back
    # down to each input's own shape.
    grad_wrt_a = grad
    # Collapse the leading axes that broadcasting prepended to a.
    while np.ndim(grad_wrt_a) > len(a.shape):
        grad_wrt_a = np.sum(grad_wrt_a, axis=0)
    # Sum over axes where a had size 1 and was stretched by broadcasting.
    for axis, size in enumerate(a.shape):
        if size == 1:
            grad_wrt_a = np.sum(grad_wrt_a, axis=axis, keepdims=True)

    grad_wrt_b = grad
    while np.ndim(grad_wrt_b) > len(b.shape):
        grad_wrt_b = np.sum(grad_wrt_b, axis=0)
    for axis, size in enumerate(b.shape):
        if size == 1:
            grad_wrt_b = np.sum(grad_wrt_b, axis=axis, keepdims=True)

    return [grad_wrt_a, grad_wrt_b]
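
The two sum loops undo NumPy broadcasting: in the forward pass `a + b` may have stretched the smaller operand, so the upstream gradient arrives with the broadcast shape and has to be reduced back to each input's own shape. Here is a minimal standalone sketch of just that reduction step, using plain NumPy and an illustrative helper name `unbroadcast` (not part of the library above):

import numpy as np

def unbroadcast(grad, shape):
    # Collapse the leading axes that broadcasting prepended.
    while np.ndim(grad) > len(shape):
        grad = np.sum(grad, axis=0)
    # Sum over axes that had size 1 and were stretched by broadcasting.
    for axis, size in enumerate(shape):
        if size == 1:
            grad = np.sum(grad, axis=axis, keepdims=True)
    return grad

a = np.ones(3)          # e.g. a bias vector, shape (3,)
b = np.ones((2, 3))     # e.g. a batch of activations, shape (2, 3)
grad = np.ones((2, 3))  # upstream gradient has the broadcast shape (2, 3)

print(unbroadcast(grad, a.shape).shape)  # (3,)  -- summed over the batch axis
print(unbroadcast(grad, b.shape).shape)  # (2, 3) -- already matches, unchanged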