Source code for neograd.nn.checkpoint

import os
import json
import numpy as np
import datetime as dt
import secrets
from hashlib import sha256


[docs]class Checkpoint: '''Creates and initializes files for checkpoints A JSON file checkpoints.json is created which contains the dict which has the tracked values and also the params file name at the time of adding a checkpoint, operates in append mode A new params file is created at each checkpoint and the JSON file is updated Parameters: session (str): Current session that is in use dirpath (str): Directory in which checkpoints must be saved model (Model): Model to be checkpointed hash_length (int): Character length of session identifiers. Defaults to 16 ''' def __init__(self, model, dirpath, hash_length=16): ''' Raises: AssertionError: if hash_length<0 and hash_length>64 ''' self.session = None self.dirpath = None assert hash_length>0 and hash_length<=64, 'Hash length must be between 1 and 64' self.hash_length = hash_length self._init_files(dirpath) self.model = model
[docs] def new_session(self): '''Creates a new session in checkpoints.json Also creates a new directory for the session Returns: self ''' self.session = self._generate_hash() os.mkdir(f'{self.dirpath}/{self.session}') return self
[docs] def specify_session(self, session): '''Used to specify a particular session to use for checkpoints.json Args: session (str): The session to be used Raises: ValueError: If session is not already in checkpoints.json ''' with open(f'{self.dirpath}/checkpoints.json', 'r') as checkpoints_fp: existing_checkpoints = json.load(checkpoints_fp) if session in existing_checkpoints: self.session = session else: raise ValueError(f"Invalid session {session} specified!")
[docs] def add(self, **tracked): '''Adds a new checkpoint Args: **tracked: All the data that needs to be tracked in checkpoints.json Raises: ValueError: If forbidden_attrs ('datetime') are used as keys in tracked, because the same key is used by neograd to add key of the same value, which might get overwritten ValueError: If values in tracked aren't serializable and don't belong to builtin classes ''' forbidden_attrs = ('datetime') allowed_types = (int, float, str, dict, list) for attr, val in tracked.items(): if not(isinstance(val, allowed_types)): raise ValueError(f'Only {allowed_types} can be tracked, if using a Tensor, use data attribute and convert it into native python objects or strings!') if attr in forbidden_attrs: raise ValueError(f'Attribute {attr} must not be present in forbidden_attrs {forbidden_attrs}, as neograd uses them internally!') updated_checkpoint, params_fname_hash = self._update(tracked) self._save(updated_checkpoint, params_fname_hash)
[docs] def _update(self, new_checkpoint): '''Updates the new_checkpoint Updates the checkpoint by including the datetime of adding new checkpoint Generates the hash that'll be used as the filename for the params file that'll be saved Args: new_checkpoint (Checkpoint): New Checkpoint to be updated Returns: New Checkpoint, hash that is used as fname ''' curr_time = str(dt.datetime.now()) fname_hash = self._generate_hash() new_checkpoint['datetime'] = curr_time return new_checkpoint, fname_hash
[docs] def _save(self, updated_checkpoint, params_fname_hash): '''Saves the checkpoint Saves the checkpoint details onto checkpoints.json and creates a new file with the params of the model if self.session isn't already in existing checkpoints, then it creates a new dict and adds the checkpoint there. Args: updated_checkpoint (Checkpoint): Checkpoint that is updated params_fname_hash (str): Hash that is generated to be the name of filename ''' with open(f'{self.dirpath}/checkpoints.json', 'r') as checkpoints_fp: existing_checkpoints = json.load(checkpoints_fp) with open(f'{self.dirpath}/checkpoints.json', 'w') as checkpoints_fp: if existing_checkpoints.get(self.session) is None: existing_checkpoints[self.session] = {} existing_checkpoints[self.session][params_fname_hash] = updated_checkpoint json.dump(existing_checkpoints, checkpoints_fp, indent=4) self.model.save(f'{self.dirpath}/{self.session}/{params_fname_hash}.hkl')
[docs] def load(self, params_fname, load_params=True): '''Retrieves the Checkpoint Returns the checkpoint based on the params_fname and loads the params onto the model if load_params is True Args: params_fname (str): Filename to load params from load_params (bool): Whether params should be loaded from the file onto the model Returns: Checkpoint desired Raises: ValueError: If the current session is not present in checkpoints.py ''' params_fname_hash = params_fname.rstrip('.hkl') with open(f'{self.dirpath}/checkpoints.json') as checkpoints_fp: session = json.load(checkpoints_fp).get(self.session) if session is None: raise ValueError(f"Invalid session {self.session}") if params_fname_hash not in session.keys(): raise ValueError(f"File {params_fname} not in current session {self.session} directory! Please specify the session using Checkpoint.specify_session") checkpoint = session[params_fname_hash] if load_params: self.model.load(f'{self.dirpath}/{self.session}/{params_fname}') return checkpoint
[docs] def _init_files(self, dirpath): '''Initializes files required for Checkpoint Creates a new folder at dirpath, if it doesn't exist A checkpoint.json file is created, if it is empty, then a new session is created if self.session is None, then automatically the last session is initialized as self.session Args: dirpath (str): Directory in which checkpoints must be saved ''' dirpath = dirpath.rstrip("/'\'") try: os.mkdir(dirpath) except FileExistsError: pass self.dirpath = dirpath # All validation passed, can be assigned to self try: with open(f'{dirpath}/checkpoints.json', 'x') as checkpoints_fp: # Create checkpoints.json or do nothing if already exists pass except FileExistsError: pass with open(f'{dirpath}/checkpoints.json') as checkpoints_fp: contents = checkpoints_fp.read() if contents.strip()!='': sessions = json.loads(contents) #json.JSONDecodeError is raised if JSON file is invalid if self.session is None: if len(sessions)==0: self.new_session() else: self.session = list(sessions.keys())[-1] else: # checkpoints.json is empty sessions = {} self.new_session() with open(f'{dirpath}/checkpoints.json', 'w') as checkpoints_fp: json.dump(sessions, checkpoints_fp, indent=4)
[docs] def _generate_hash(self): '''Generates 64 hex digit sha256 hash of a random number Returns: sha256 hash ''' return sha256(secrets.token_hex(32).encode('utf-8')).hexdigest()[:self.hash_length]