# Source code for neograd.autograd.node
class Node:
    '''Graph vertex that links a Tensor to the Tensors around it.

    Each Tensor is assigned a Node. The Node records the incoming edges
    (parents, i.e. the operands that produced the Tensor) and the outgoing
    edges (children, i.e. the Nodes that use the Tensor as an operand).

    Parameters:
        tens (Tensor): The Tensor corresponding to this Node
        children (list of Node): All Nodes which use the current Node as an
            operand in an Operation
        parents (list of Node): All Nodes (operands) that resulted in the
            creation of the current Node
        parent_broadcast_shape (tuple or None): If the parent needs to be
            broadcasted from one shape to another, the final broadcasted shape
            of the parent is stored here. If they cannot be broadcasted, it
            is None
        backward_fn (Operation.backward): Sets the grad_fn of the Tensor
            (operand) involved in the Operation
        visited (bool): Whether the Node has been visited or not
    '''

    def __init__(self, tens):
        '''
        Args:
            tens (Tensor): The Tensor corresponding to the Node
        '''
        self.tens = tens
        self.children = []
        self.parents = []
        self.parent_broadcast_shape = None
        self.backward_fn = None
        self.visited = False

    def top_sort(self) -> list:
        '''Performs topological sort of all Nodes starting from current Node

        Sorts the graph topologically so that, during the backward pass, the
        gradients of all the children are calculated before the current
        Node's gradient is calculated.

        If every child is already visited, the current Node is marked as
        visited, its Tensor is added to the result and the sort continues
        through the unvisited parents. Otherwise the sort is first performed
        on the unvisited children; the current Node is then reached again
        through them, since it is one of their parents.

        Returns:
            list of Tensor in the order in which backward should be run
        '''
        sorted_tensors = []
        if self.are_children_visited():
            self.visited = True
            sorted_tensors.append(self.tens)
            for parent in self.parents:
                # Re-check on every iteration: an earlier recursive call may
                # already have visited this parent
                if not parent.visited:
                    sorted_tensors += parent.top_sort()
        else:
            for child in self.children:
                if not child.visited:
                    sorted_tensors += child.top_sort()
        return sorted_tensors

    def backward(self, retain_graph) -> None:
        '''Initiates backward pass starting from current Node

        The visited flags of the graph are reset and the direct children of
        this Node are marked as visited, so that they aren't included in
        sorted_tensors; they aren't required as the backward pass is being
        initiated from the current Node.

        The Tensors are then topologically sorted starting from this Node and
        the visited flags are reset once more. This Node's own Tensor (the
        first entry) is popped from sorted_tensors and _backward is called on
        it with calculate_grads=False, so that grads aren't calculated for it
        but flushing is still allowed.

        Finally, for every remaining Tensor its Node is retrieved from the
        graph, marked as visited and the Tensor's backward pass is initiated.

        Args:
            retain_graph (bool): If the graph should be retained after backward
                pass or flushed after backward calculation
        '''
        from .utils import get_graph

        graph = get_graph()
        graph.reset_visited()
        # Allows gradient calculation from any intermediate node in the graph
        self.visit_all_children()
        sorted_tensors = self.top_sort()
        graph.reset_visited()

        # All children were marked visited above, so top_sort put this Node's
        # own Tensor first; remove it and handle it separately
        sorted_tensors.pop(0)
        self.visited = True
        self.tens._backward(self, retain_graph, calculate_grads=False)

        for tens in sorted_tensors:
            node = graph.get_node(tens)
            node.visited = True
            tens._backward(node, retain_graph)

    def visit_all_children(self) -> None:
        '''Marks all (direct) children as visited
        '''
        for child in self.children:
            child.visited = True

    def are_children_visited(self) -> bool:
        '''Checks if all children are visited

        Returns:
            True if all children are visited (or there are none) else False
        '''
        return all(child.visited for child in self.children)

    def are_parents_visited(self) -> bool:
        '''Checks if all parents are visited

        Returns:
            True if all parents are visited (or there are none) else False
        '''
        return all(parent.visited for parent in self.parents)

    def add_child(self, other) -> None:
        '''Adds a child to the Node

        Args:
            other (Node): The child Node
        '''
        self.children.append(other)

    def add_parent(self, other) -> None:
        '''Adds a parent to the Node

        Args:
            other (Node): The parent Node
        '''
        self.parents.append(other)

    def __repr__(self) -> str:
        return f'Node({self.tens})'

    def __str__(self) -> str:
        return f'Node( \n{self.tens}\nbackward_fn: {self.backward_fn}\nvisited: {self.visited}\n )'