Source code for neograd.nn.utils
import dill
from .model import Model
[docs]def save_model(fpath, model):
'''Saves the model
Saves the model by pickling it onto a file
Args:
fpath (str): Path in which to save the model
model (Model): Model to be saved
Raises:
TypeError: if model isn't an instance of Model
'''
if not isinstance(model, Model):
raise TypeError(f'Expected Model object, instead got {type(model)}')
with open(fpath,'wb') as fp:
dill.dump(model, fp)
print(f'MODEL SAVED at {fpath}')
[docs]def load_model(fpath):
'''Loads the model
Args:
fpath (str): Path from which to load the model
Returns:
Model object that is loaded
'''
with open(fpath,'rb') as fp:
model = dill.load(fp)
print(f'MODEL LOADED from {fpath}')
return model
[docs]def get_batches(inputs, targets=None, batch_size=None):
'''Returns batches of inputs and targets
Split the inputs and their corresponding targets into batches for efficient
training
Args:
inputs (Tensor): Inputs to be batched
targets (Tensor): Targets to be batched. Defaults to None
batch_size (int): Size of the batches. Defaults to None meaning batch_size
will be same as number of examples
Yields:
Batches of inputs and their corresponding targets
Raises:
AssertionError: If first dimensions of inputs and targets don't match
ValueError: If batch_size is greater than number of examples
ValueError: If batch_size is negative
ValueError: If batch_size is 0
'''
if targets is not None:
assert inputs.shape[0]==targets.shape[0], '0th dim should be number of examples and should match'
num_examples = inputs.shape[0]
if batch_size is not None:
if batch_size > num_examples:
raise ValueError("Batch size cannot be greater than the number of examples")
elif batch_size<0:
raise ValueError("Batch size cannot be negative")
elif batch_size==0:
raise ValueError("Batch size cannot be zero")
else:
batch_size = num_examples
start = 0
while start<num_examples:
end = start+batch_size if start+batch_size<num_examples else num_examples
if targets is not None:
yield inputs[start:end], targets[start:end]
else:
yield inputs[start:end]
start = end