Source code for spear.Implyloss.checkmate

import os
import glob
import json
import numpy as np
# import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()


class BestCheckpointSaver(object):
    '''Maintains a directory containing only the best n checkpoints

    Inside the directory is a best_checkpoints JSON file containing a
    dictionary mapping of the best checkpoint file names to the values by
    which the checkpoints are compared. Only the best n checkpoints are
    contained in the directory and JSON file.

    This is a light-weight wrapper class only intended to work in simple,
    non-distributed settings. It is not intended to work with the
    tf.Estimator framework.
    '''

    def __init__(self, save_dir, num_to_keep=1, maximize=True, saver=None):
        '''
        Func Desc:
        Creates a `BestCheckpointSaver`

        `BestCheckpointSaver` acts as a wrapper class around a
        `tf.train.Saver`

        Input:
            save_dir: The directory in which the checkpoint files will be
                saved. It is created if it does not exist yet.
            num_to_keep: The number of best checkpoint files to retain
            maximize: Define 'best' values to be the highest values. For
                example, set this to True if selecting for the checkpoints
                with the highest given accuracy. Or set to False to select
                for checkpoints with the lowest given error rate.
            saver: A `tf.train.Saver` to use for saving checkpoints. A
                default `tf.train.Saver` will be created if none is
                provided.

        Output:
        '''
        self._num_to_keep = num_to_keep
        self._save_dir = save_dir
        self._save_path = os.path.join(save_dir, 'best.ckpt')
        self._maximize = maximize
        # max_to_keep=None: this class decides which checkpoint files get
        # deleted, so the wrapped saver is created without its own limit.
        self._saver = saver if saver else tf.train.Saver(
            max_to_keep=None,
            save_relative_paths=True
        )

        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs(save_dir, exist_ok=True)
        self.best_checkpoints_file = os.path.join(save_dir, 'best_checkpoints')

    def handle(self, value, sess, global_step_tensor):
        '''
        Func Desc:
        Updates the set of best checkpoints based on the given result.

        The checkpoint is saved when the best_checkpoints file does not
        exist yet, when fewer than `num_to_keep` checkpoints are recorded,
        or when `value` beats at least one recorded value. In the last case
        the worst recorded checkpoint is deleted first.

        Input:
            value: The value by which to rank the checkpoint.
            sess: A tf.Session to use to save the checkpoint
            global_step_tensor: A `tf.Tensor` represent the global step

        Output:
            True if a checkpoint was saved, False otherwise
        '''
        global_step = sess.run(global_step_tensor)
        current_ckpt = 'best.ckpt-{}'.format(global_step)
        value = float(value)

        # Very first checkpoint: nothing to compare against.
        if not os.path.exists(self.best_checkpoints_file):
            self._record_and_save({current_ckpt: value}, sess,
                                  global_step_tensor)
            return True

        best_checkpoints = self._load_best_checkpoints_file()

        # Still room left: keep the checkpoint regardless of its value.
        if len(best_checkpoints) < self._num_to_keep:
            best_checkpoints[current_ckpt] = value
            self._record_and_save(best_checkpoints, sess, global_step_tensor)
            return True

        # Directory is full: save only if `value` beats some recorded value.
        if self._maximize:
            should_save = not all(current_best >= value
                                  for current_best in best_checkpoints.values())
        else:
            should_save = not all(current_best <= value
                                  for current_best in best_checkpoints.values())
        if not should_save:
            return False

        # _sort() orders best-first, so the last entry is the one to evict.
        best_checkpoint_list = self._sort(best_checkpoints)
        worst_checkpoint = os.path.join(self._save_dir,
                                        best_checkpoint_list.pop(-1)[0])
        self._remove_outdated_checkpoint_files(worst_checkpoint)
        self._update_internal_saver_state(best_checkpoint_list)

        best_checkpoints = dict(best_checkpoint_list)
        best_checkpoints[current_ckpt] = value
        self._record_and_save(best_checkpoints, sess, global_step_tensor)
        return True

    def _record_and_save(self, best_checkpoints, sess, global_step_tensor):
        '''
        Func Desc:
        Write the given mapping to the best_checkpoints file, then save the
        session to a checkpoint through the wrapped saver.

        Input:
            best_checkpoints: dictionary mapping checkpoint file names to
                their values
            sess: A tf.Session to use to save the checkpoint
            global_step_tensor: A `tf.Tensor` represent the global step

        Output:
        '''
        self._save_best_checkpoints_file(best_checkpoints)
        self._saver.save(sess, self._save_path, global_step_tensor)

    def _save_best_checkpoints_file(self, updated_best_checkpoints):
        '''
        Func Desc:
        Overwrite the best_checkpoints JSON file with the given mapping

        Input:
            updated_best_checkpoints: dictionary mapping checkpoint file
                names to their values

        Output:
        '''
        with open(self.best_checkpoints_file, 'w') as f:
            json.dump(updated_best_checkpoints, f, indent=3)

    def _remove_outdated_checkpoint_files(self, worst_checkpoint):
        '''
        Func Desc:
        Remove the 'checkpoint' file in the save directory and every file
        whose name is `worst_checkpoint` followed by a '.' suffix.

        Input:
            worst_checkpoint: path prefix (directory + checkpoint name) of
                the checkpoint to delete

        Output:
        '''
        try:
            os.remove(os.path.join(self._save_dir, 'checkpoint'))
        except FileNotFoundError:
            # Already absent: the desired end state is reached, so this is
            # not an error.
            pass
        # Escape the prefix so glob metacharacters in the directory or file
        # name ('[', ']', '*', '?') are matched literally.
        for ckpt_file in glob.glob(glob.escape(worst_checkpoint) + '.*'):
            os.remove(ckpt_file)

    def _update_internal_saver_state(self, best_checkpoint_list):
        '''
        Func Desc:
        Hand the retained checkpoint names to the wrapped saver via
        `set_last_checkpoints_with_time`.

        Input:
            best_checkpoint_list: list of (checkpoint name, value) pairs;
                only the names are used

        Output:
        '''
        best_checkpoint_files = [
            (ckpt[0], np.inf)  # TODO: Try to use actual file timestamp
            for ckpt in best_checkpoint_list
        ]
        self._saver.set_last_checkpoints_with_time(best_checkpoint_files)

    def _load_best_checkpoints_file(self):
        '''
        Func Desc:
        Load the best_checkpoints JSON file

        Input:

        Output:
            the decoded content of the file (a dictionary mapping checkpoint
            file names to their values, as written by
            `_save_best_checkpoints_file`)
        '''
        with open(self.best_checkpoints_file, 'r') as f:
            best_checkpoints = json.load(f)
        return best_checkpoints

    def _sort(self, best_checkpoints):
        '''
        Func Desc:
        Order the checkpoints from best to worst: descending by value when
        maximizing, ascending otherwise.

        Input:
            best_checkpoints: dictionary mapping checkpoint file names to
                their values

        Output:
            list of (checkpoint name, value) tuples, best first
        '''
        best_checkpoints = [
            (ckpt, best_checkpoints[ckpt])
            for ckpt in sorted(best_checkpoints,
                               key=best_checkpoints.get,
                               reverse=self._maximize)
        ]
        return best_checkpoints
def get_best_checkpoint(best_checkpoint_dir, select_maximum_value=True):
    '''
    Func Desc:
    Reads the best_checkpoints file in the best_checkpoint_dir directory.
    Returns the filepath in the best_checkpoints file associated with the
    highest value if select_maximum_value is True, or the filepath
    associated with the lowest value if select_maximum_value is False.
    On ties, the entry appearing first in the file wins.

    Input:
        best_checkpoint_dir: Directory containing best_checkpoints JSON file
        select_maximum_value: If True, select the filepath associated with
            the highest value. Otherwise, select the filepath associated
            with the lowest value.

    Output:
        The full path to the best checkpoint file

    Raises:
        ValueError: if the best_checkpoints file does not exist or lists
            no checkpoints
    '''
    best_checkpoints_file = os.path.join(best_checkpoint_dir,
                                         'best_checkpoints')
    if not os.path.exists(best_checkpoints_file):
        raise ValueError('Checkpoint file does not exist')

    with open(best_checkpoints_file, 'r') as f:
        best_checkpoints = json.load(f)

    # An empty mapping would otherwise surface as an opaque IndexError.
    if not best_checkpoints:
        raise ValueError('Checkpoint file contains no checkpoints')

    # max()/min() return the first extreme key, which is the same element a
    # stable sort would have placed first, without sorting the whole mapping.
    select = max if select_maximum_value else min
    best_checkpoint = select(best_checkpoints, key=best_checkpoints.get)
    return os.path.join(best_checkpoint_dir, best_checkpoint)