# Register Gradient Decorator
# A dictionary that will map operations to gradient functions.
# Keys are the op-type objects that `RegisterGradient` resolves from the
# names it is given; values are the decorated gradient functions.
_gradient_registry = {}


class RegisterGradient:
    """A decorator for registering the gradient function for an op type.

    Example::

        @RegisterGradient("negative")
        def _negative_gradient(op, grad):
            ...

    The op name is resolved to the object bound to that name in this
    module's global namespace, and the decorated function is stored in
    ``_gradient_registry`` under that object. Registering a second function
    for the same op type silently replaces the first.
    """

    def __init__(self, op_type):
        """Creates a new decorator with `op_type` as the Operation type.

        Args:
          op_type: The name of an operation, looked up (never evaluated) in
            this module's global namespace, falling back to builtins. A
            dotted name such as ``"ops.negative"`` follows attributes from
            the first component. The Operation type itself may also be
            passed directly.

        Raises:
          NameError: If the (first component of the) name is not defined.
          AttributeError: If a later component of a dotted name is missing.
          TypeError: If `op_type` is neither a string nor a class.
        """
        if isinstance(op_type, str):
            op_type = self._resolve(op_type)
        elif not isinstance(op_type, type):
            raise TypeError(
                "op_type must be an operation name or an operation type, "
                f"got {type(op_type).__name__}"
            )
        self._op_type = op_type

    @staticmethod
    def _resolve(name):
        """Return the object bound to `name` in this module, without running code."""
        # Deliberately not eval(): eval would execute any expression passed
        # as an "op name". A plain name/attribute lookup yields the same
        # object for valid names and cannot execute anything.
        import builtins

        head, *attrs = (part.strip() for part in name.split("."))
        namespace = globals()  # this module's globals, where eval looked too
        if head in namespace:
            obj = namespace[head]
        elif hasattr(builtins, head):
            obj = getattr(builtins, head)
        else:
            raise NameError(f"name {head!r} is not defined")
        for attr in attrs:
            obj = getattr(obj, attr)
        return obj

    def __call__(self, f):
        """Registers the function `f` as gradient function for `op_type`.

        Returns `f` unchanged so the decorated name stays bound to it.
        """
        _gradient_registry[self._op_type] = f
        return f
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX