Source code for spear.Implyloss.checkpoints

# import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import os

# from .my_checkmate import BestCheckpointSaver, get_best_checkpoint
# from .my_data_types import train_modes
from .checkmate import BestCheckpointSaver, get_best_checkpoint
from .data_types import train_modes
checkpoint_dir = "./checkpoint"
if not os.path.exists(checkpoint_dir):
	os.makedirs(checkpoint_dir)

num_checkpoints = 1 # Number of checkpoints to keep around

# Keeps only the most recently saved checkpoint
#
# max_to_keep is deliberately set to 1 to handle the case where a more recent checkpoint
# has a smaller global_step, since tf.train.Saver() orders checkpoints by global_step.

class MRUCheckpoint():
	def __init__(self, path, session, variables):
		'''
		Func Desc:
		Initializes the class variables

		Input:
		self
		path - directory in which checkpoints are stored
		session - the tf.Session whose variables are saved and restored
		variables - the list of variables to checkpoint

		Output:

		'''
		self.ckpt_path = path
		self.ckpt_file = os.path.join(path, 'checkpoint')
		self.checkpoint_prefix = os.path.join(self.ckpt_path, 'hls-model')
		self.sess = session
		# max_to_keep
		self.saver = tf.train.Saver(variables, max_to_keep=1)
		# self.saver = tf.train.Saver()

	def save(self, global_step=None):
		'''
		Func Desc:
		Saves the current session state as a checkpoint

		Input:
		self
		global_step - step number appended to the checkpoint filename (Default - None)

		Output:

		'''
		path = self.saver.save(self.sess, self.checkpoint_prefix, global_step)
		print('Saved MRU checkpoint to path: ', path)
		
	def restore(self):
		'''
		Func Desc:
		Restores the last checkpoint

		Input:
		self

		Output:

		'''
		last_checkpoint = tf.train.latest_checkpoint(self.ckpt_path, 'checkpoint')
		#if self.saver.last_checkpoints:
		#    last_checkpoint = self.saver.last_checkpoints[0]
		#    print('All saved checkpoints: ', self.saver.last_checkpoints)
		#else:
		if not last_checkpoint:
			last_checkpoint = self.checkpoint_prefix

		print('Restoring checkpoint from path: ', last_checkpoint)
		self.saver.restore(self.sess, last_checkpoint)

	def restore_if_checkpoint_exists(self):
		'''
		Func Desc:
		Restores the latest checkpoint if one exists

		Input:
		self

		Output:
		Boolean - True if a checkpoint was found and restored, else False
		'''
		if os.path.exists(self.ckpt_file):
			self.restore()
			return True
		return False
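
# A minimal usage sketch for MRUCheckpoint, assuming `train_op` is a
# caller-supplied training op (a hypothetical placeholder, not part of
# this module's API).
def _example_mru_usage(sess, train_op, num_epochs):
	ckpt = MRUCheckpoint(checkpoint_dir, sess, tf.global_variables())
	# Resume from the last saved state if a checkpoint file exists
	ckpt.restore_if_checkpoint_exists()
	for _ in range(num_epochs):
		sess.run(train_op)
		# Overwrites the previous checkpoint since max_to_keep=1
		ckpt.save()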

def test_mru_checkpoints(num_to_keep):
	'''
	Func Desc:
	Saves and restores MRU checkpoints while incrementing a global step,
	verifying that restore rolls the step back to the last saved value

	Input:
	num_to_keep(int) - used only to give the global step variable a unique name

	Output:

	'''
	global_step = tf.get_variable(name='mru_global_step_%d' % num_to_keep, initializer=10, dtype=tf.int32)
	inc = tf.assign_add(global_step, 1)
	sess = tf.Session()
	sess.run(tf.global_variables_initializer())
	assert sess.run(global_step) == 10
	sess.run(inc)
	assert sess.run(global_step) == 11

	ckpt_path = '/tmp/checkpoints_%.6f' % np.random.rand()
	ckpt = MRUCheckpoint(ckpt_path, sess, tf.global_variables())
	ckpt.save(global_step)

	sess.run(inc)
	assert sess.run(global_step) == 12
	ckpt.restore_if_checkpoint_exists()
	assert sess.run(global_step) == 11

	assgn_op = tf.assign(global_step, 5)
	sess.run(assgn_op)
	assert sess.run(global_step) == 5
	ckpt.save(global_step)

	sess.run(inc)
	sess.run(inc)
	assert sess.run(global_step) == 7
	ckpt.restore_if_checkpoint_exists()
	assert sess.run(global_step) == 5

def test_checkpoint():
	'''
	Func Desc:
	tests that a checkpoint saved from one session can be restored in another

	Input:

	Output:

	'''
	v = tf.get_variable(name='v', initializer=12, dtype=tf.int32)
	v1 = tf.assign_add(v, 1)
	sess = tf.Session()
	sess.run(tf.global_variables_initializer())
	assert sess.run(v) == 12
	sess.run(v1)
	assert sess.run(v) == 13

	ckpt = MRUCheckpoint('/tmp/checkpoints', sess, tf.global_variables())
	ckpt.save()

	sess1 = tf.Session()
	sess1.run(tf.global_variables_initializer())
	assert sess1.run(v) == 12

	ckpt1 = MRUCheckpoint('/tmp/checkpoints', sess1, tf.global_variables())
	ckpt1.restore()
	assert sess1.run(v) == 13
	sess1.run(v1)
	assert sess1.run(v) == 14

	ckpt2 = MRUCheckpoint('/tmp/bad-ckpt-path', sess1, tf.global_variables())
	restored = ckpt2.restore_if_checkpoint_exists()
	assert restored == False

	restored = ckpt1.restore_if_checkpoint_exists()
	assert restored
	assert sess1.run(v) == 13

class BestCheckpoint():
	def __init__(self, path, prefix, session, num_checkpoints, variables, global_step):
		'''
		Func Desc:
		Initializes the class member variables to track the best checkpoint so far

		Input:
		self
		path - directory in which checkpoints are stored
		prefix - subdirectory name for this saver's checkpoints
		session
		num_checkpoints - number of best checkpoints to keep
		variables
		global_step

		Output:

		'''
		self.ckpt_path = os.path.join(path, prefix)
		#self.ckpt_file = os.path.join(self.ckpt_path, 'checkpoint')
		#self.checkpoint_prefix = os.path.join(self.ckpt_path, prefix)
		self.sess = session
		# max_to_keep is None. Number of checkpoints is handled separately by BestCheckpointSaver
		self.saver = tf.train.Saver(variables, max_to_keep=None, save_relative_paths=True)
		# self.saver = tf.train.Saver()
		self.best_ckpt_saver = BestCheckpointSaver(
			save_dir=self.ckpt_path,
			num_to_keep=num_checkpoints,
			maximize=True,
			saver=self.saver
		)
		self.global_step = global_step

	def save_if_best(self, metric):
		'''
		Func Desc:
		save a checkpoint only if the given metric is the best so far

		Input:
		self
		metric - the value used to rank checkpoints

		Output:

		'''
		saved = self.best_ckpt_saver.handle(metric, self.sess, self.global_step)
		path = tf.train.latest_checkpoint(self.ckpt_path, 'checkpoint')
		if saved:
			print('Saved new best checkpoint to path: ', path)
		else:
			print('Metric did not improve. Did not save a new best checkpoint. Last checkpointed file: ', path)

	def restore_best_checkpoint(self):
		'''
		Func Desc:
		Restore the best checkpoint so far

		Input:
		self

		Output:

		'''
		best_ckpt_file = get_best_checkpoint(self.ckpt_path, select_maximum_value=True)
		print('Restoring best checkpoint from path: ', best_ckpt_file)
		self.saver.restore(self.sess, best_ckpt_file)

	def restore_best_checkpoint_if_exists(self):
		'''
		Func Desc:
		Restore the best checkpoint so far only if it exists

		Input:
		self

		Output:
		Boolean - True if a best checkpoint was restored, else False
		'''
		try:
			self.restore_best_checkpoint()
			return True
		except ValueError as e:
			print(str(e))
			return False
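
# A minimal usage sketch for BestCheckpoint, assuming `train_op` is a
# caller-supplied training op and `metric_fn` a caller-supplied callable
# returning the validation metric to maximize (both hypothetical placeholders).
def _example_best_checkpoint_usage(sess, train_op, global_step, metric_fn, num_epochs):
	best = BestCheckpoint(checkpoint_dir, 'example-mode', sess,
	                      num_checkpoints, tf.global_variables(), global_step)
	for _ in range(num_epochs):
		sess.run(train_op)
		# Keep the checkpoint only if the validation metric improves
		best.save_if_best(metric_fn())
	# Roll back to the best-scoring weights before final evaluation
	best.restore_best_checkpoint_if_exists()
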
def test_best_ckpt():
	'''
	Func Desc:
	test for the best checkpoint so far

	Input:

	Output:

	'''
	global_step = tf.get_variable(name='global_step', initializer=50, dtype=tf.int32)
	inc_global_step = tf.assign_add(global_step, 1)

	sess1 = tf.Session()
	sess2 = tf.Session()
	sess1.run(tf.global_variables_initializer())
	sess2.run(tf.global_variables_initializer())

	# We'll save using sess1 and restore in sess2
	best_checkpoint_dir = '/tmp/best_ckpt_%.6f' % np.random.rand()
	best1 = BestCheckpoint(best_checkpoint_dir, 'foo-bar', sess1, 3, tf.trainable_variables(), global_step)
	best2 = BestCheckpoint(best_checkpoint_dir, 'foo-bar', sess2, 3, tf.trainable_variables(), global_step)

	restored = best2.restore_best_checkpoint_if_exists()
	assert not restored

	sess1.run(inc_global_step) ## 51
	best1.save_if_best(0.1)
	assert sess2.run(global_step) == 50
	restored = best2.restore_best_checkpoint_if_exists()
	assert restored
	assert sess2.run(global_step) == 51

	sess1.run(inc_global_step) ## 52
	best1.save_if_best(0.05)
	sess2.run(inc_global_step) # 52
	sess2.run(inc_global_step) # 53
	sess2.run(inc_global_step) # 54
	assert sess2.run(global_step) == 54
	restored = best2.restore_best_checkpoint_if_exists()
	assert restored
	assert sess2.run(global_step) == 51

	sess1.run(inc_global_step) ## 53
	best1.save_if_best(0.2)
	sess1.run(inc_global_step) ## 54
	best1.save_if_best(0.15)
	sess2.run(inc_global_step) # 52
	sess2.run(inc_global_step) # 53
	sess2.run(inc_global_step) # 54
	sess2.run(inc_global_step) # 55
	assert sess2.run(global_step) == 55
	restored = best2.restore_best_checkpoint_if_exists()
	assert restored
	assert sess2.run(global_step) == 53

def test_checkmate():
	'''
	Func Desc:
	test whether the checkmate saver is working fine

	Input:

	Output:

	'''
	global_step = tf.get_variable(name='checkmate_global_step', initializer=12, dtype=tf.int32)
	inc_global_step_op = tf.assign_add(global_step, 1)

	sess = tf.Session()
	sess1 = tf.Session()
	sess.run(tf.global_variables_initializer())

	saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=None)
	# saver = tf.train.Saver()
	best_checkpoint_dir = '/tmp/best_ckpt_%.6f' % np.random.rand()
	best_ckpt_saver = BestCheckpointSaver(
		save_dir=best_checkpoint_dir,
		num_to_keep=3,
		maximize=True,
		saver=saver
	)

	accuracy = 0.1 # 12
	best_ckpt_saver.handle(accuracy, sess, global_step)

	accuracy = 0.2
	sess.run(inc_global_step_op) # 13
	best_ckpt_saver.handle(accuracy, sess, global_step)

	accuracy = 0.05
	sess.run(inc_global_step_op) # 14
	best_ckpt_saver.handle(accuracy, sess, global_step)

	ckpt_path = get_best_checkpoint(best_checkpoint_dir, select_maximum_value=True)
	print('Best ckpt path: ', ckpt_path)
	saver.restore(sess1, ckpt_path)
	assert sess1.run(global_step) == 13

	accuracy = 0.12
	sess.run(inc_global_step_op) # 15
	best_ckpt_saver.handle(accuracy, sess, global_step)
	ckpt_path = get_best_checkpoint(best_checkpoint_dir, select_maximum_value=True)
	print('Best ckpt path: ', ckpt_path)
	saver.restore(sess1, ckpt_path)
	assert sess1.run(global_step) == 13

	accuracy = 0.45
	sess.run(inc_global_step_op) # 16
	best_ckpt_saver.handle(accuracy, sess, global_step)
	ckpt_path = get_best_checkpoint(best_checkpoint_dir, select_maximum_value=True)
	print('Best ckpt path: ', ckpt_path)
	saver.restore(sess1, ckpt_path)
	assert sess1.run(global_step) == 16

	# Now select lowest value
	ckpt_path = get_best_checkpoint(best_checkpoint_dir, select_maximum_value=False)
	print('Best ckpt path: ', ckpt_path)
	saver.restore(sess1, ckpt_path)
	assert sess1.run(global_step) == 15

# Loading of checkpoints happens only once - at the end of HLSModel initialization.
#
# Saving of checkpoints happens during training. We have only one MRU checkpoint saver.
# We have one best checkpoint saver per train mode type.

class CheckpointsFactory:
	def __init__(self, sess, global_steps):
		'''
		Func Desc:
		Initializes the class with the arguments

		Input:
		self
		sess
		global_steps

		Output:

		'''
		self.best_savers = {}
		self.initialize_savers(sess, global_steps)

	def get_best_saver(self, train_mode):
		'''
		Func Desc:
		Returns the best checkpoint saver for the given train mode

		Input:
		self
		train_mode - the mode of training

		Output:
		The BestCheckpoint saver registered for train_mode
		'''
		return self.best_savers[train_mode]

	def initialize_savers(self, sess, global_steps):
		'''
		Func Desc:
		Initialize the required savers, one per train mode

		Input:
		self
		sess
		global_steps

		Output:

		'''
		for mode in train_modes:
			self.init_saver(sess, mode, global_steps)

	def init_saver(self, sess, mode, global_steps):
		'''
		Func Desc:
		Initialize the saver for the given mode

		Input:
		self
		sess
		mode
		global_steps

		Output:

		'''
		ckpt_dir = checkpoint_dir
		self.best_savers[mode] = BestCheckpoint(ckpt_dir, mode, sess, num_checkpoints, tf.global_variables(), global_steps[mode])
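
# A minimal usage sketch for CheckpointsFactory, assuming `train_modes`
# (imported above) is an iterable of mode name strings and that the caller
# maintains one global-step variable per mode.
def _example_factory_usage(sess):
	# One global step per train mode, keyed the way init_saver expects
	global_steps = {mode: tf.get_variable(name='gs_%s' % mode, initializer=0, dtype=tf.int32)
	                for mode in train_modes}
	factory = CheckpointsFactory(sess, global_steps)
	# Each train mode gets its own BestCheckpoint saver under ./checkpoint/<mode>
	for mode in train_modes:
		saver = factory.get_best_saver(mode)
		assert isinstance(saver, BestCheckpoint)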