# Gradient for reduce_sum
@RegisterGradient("reduce_sum")
def _reduce_sum_gradient(op, grad):
    """Backpropagate a gradient through a `reduce_sum` operation.

    The incoming gradient is reshaped so that the axis selected by
    `op.axis` has length 1, and is then repeated along that axis until it
    has the shape of the operation's input again.

    Args:
        op: The `reduce_sum` `Operation` being differentiated. Its first
            input (`op.inputs[0]`) and its `axis` attribute are read.
        grad: Gradient with respect to the output of the `reduce_sum` op.

    Returns:
        Gradient with respect to the input of `reduce_sum`; it has the
        same shape as `op.inputs[0]`.
    """
    input_shape = op.inputs[0].shape

    # Input shape with the reduced axis collapsed to length 1.
    collapsed_shape = np.array(input_shape)
    collapsed_shape[op.axis] = 1

    # Per-axis repeat counts: the original length along the reduced axis,
    # 1 along every other axis.
    repeats = input_shape // collapsed_shape

    return np.tile(np.reshape(grad, collapsed_shape), repeats)
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX