# Register Gradient Decorator
# Module-level lookup table: each key is an operation type and each value is
# the gradient function registered for it. It starts out empty.
_gradient_registry = dict()
class RegisterGradient:
    """A decorator for registering the gradient function for an op type.

    Constructing the decorator resolves `op_type` to an object. Applying the
    decorator to a function stores that function in the module-level
    `_gradient_registry` under the resolved object and hands the function
    back unchanged.
    """

    def __init__(self, op_type):
        """Creates a new decorator with `op_type` as the Operation type.

        Args:
          op_type: Either the operation class itself, or the name of an
            operation given as a string. A string is evaluated as a Python
            expression from inside this method, so the name has to be
            resolvable from this module.

        Raises:
          NameError: If `op_type` is a string naming something that cannot
            be resolved here.
          TypeError: If `op_type` is neither a class nor something `eval`
            accepts.
        """
        if isinstance(op_type, type):
            # Already a class: use it as the registry key directly, no
            # evaluation needed.
            self._op_type = op_type
        else:
            # NOTE(security): eval() executes arbitrary expressions. Only
            # pass trusted, hard-coded names here; never forward untrusted
            # input. Prefer passing the class object itself where possible.
            self._op_type = eval(op_type)

    def __call__(self, f):
        """Registers the function `f` as gradient function for `op_type`.

        An existing entry for the same op type is overwritten.

        Args:
          f: The gradient function to register.

        Returns:
          `f` itself, so the decorated name still refers to the original
          function.
        """
        _gradient_registry[self._op_type] = f
        return f
Enter to Rename, Shift+Enter to Preview