Source code for neograd.autograd.graph
from .node import Node
class Graph:
    '''Used to keep track of nodes and tensors

    The graph is constructed during the forward pass, and used by the backward
    pass to calculate gradients through automatic differentiation

    Attributes:
        graph (Graph or None): Graph object that's currently in use. If None, then the global
            _NG_GRAPH is used, else a specific graph object is used. Defaults to None
        nodes_dict (dict): Stores key-value pairs of tensors and their corresponding
            nodes in the graph
        track (bool): Whether the graph must track the tensor operations or not, ie if True, when any
            operation happens and a new result tensor is created, then the operands of the operation
            are added as parents to the result tensor and the result tensor is added as child to the
            operands, if False, none of this happens. Defaults to True
    '''
    graph = None

    def __init__(self):
        '''Initializes nodes_dict to an empty dict and track to True
        '''
        self.nodes_dict = {}
        self.track = True

    def add_edge(self, result_node, operands):
        '''Creates edges between a result node and every operand that produced it

        Adds the result_node, which is created during an Operation, to the graph, then
        links it with each operand that produced the result: the result_node is added
        as a child of each operand's node, and each operand's node is added as a parent
        of the result_node. Operands that are not yet in the graph are added first.

        Args:
            result_node (Node): node that is created in Operation.get_result_tensor
            operands (list of Tensor): All the operands for an Operation
        '''
        self.add_node(result_node)
        for operand in operands:
            # Single lookup in the common case where the operand is already tracked;
            # only re-fetch after inserting a fresh node for an untracked operand.
            operand_node = self.get_node(operand)
            if operand_node is None:
                self.add_tensor(operand)
                operand_node = self.get_node(operand)
            result_node.add_parent(operand_node)
            operand_node.add_child(result_node)

    def add_node(self, node):
        '''Adds a Node to the graph

        Creates a key-value pair in nodes_dict with the specified node as the value
        and its tens attribute as the key. An existing entry for the same tensor
        is overwritten.

        Args:
            node (Node): Node to be added to the graph
        '''
        self.nodes_dict[node.tens] = node

    def get_node(self, tens):
        '''Returns the Node corresponding to the Tensor

        Args:
            tens (Tensor): Tensor whose node is to be fetched

        Returns:
            Node if found, else None
        '''
        return self.nodes_dict.get(tens)

    def add_tensor(self, tens):
        '''Adds a Tensor to the graph

        A new node is created for the Tensor and the corresponding entry is made
        in nodes_dict (replacing any existing entry for that Tensor)

        Args:
            tens (Tensor): Tensor to be added
        '''
        self.nodes_dict[tens] = Node(tens)

    def remove_tensor(self, tens):
        '''Removes a Tensor from the graph

        Pops the Tensor from nodes_dict

        Args:
            tens (Tensor): Tensor to be removed

        Raises:
            KeyError: if tens is not in the graph
        '''
        self.nodes_dict.pop(tens)

    def reset_visited(self):
        '''Sets visited=False for each Node in the graph
        '''
        for node in self.nodes_dict.values():
            node.visited = False

    def reset_graph(self):
        '''Resets the whole graph

        This is accomplished by rebinding nodes_dict to a new empty dictionary.
        Doing so removes all the Tensors and their Nodes from the graph
        '''
        self.nodes_dict = {}

    def zero_grad(self):
        '''Performs zero_grad on all the tensors in the graph

        Iterates through the tensors (keys) of nodes_dict and calls zero_grad on each
        '''
        for tensor in self.nodes_dict:
            tensor.zero_grad()

    def __repr__(self):
        return 'Graph()'

    def __str__(self):
        return f'Graph( {self.nodes_dict} )'