From 8ac193d9065e957fc6103dc7eca314b6c7844696 Mon Sep 17 00:00:00 2001 From: AntreasAntoniou Date: Fri, 16 Feb 2018 23:44:04 +0000 Subject: [PATCH] Add support for retrieving best epoch and val score when reloading a model --- cifar100_network_trainer.py | 11 ++++++----- cifar10_network_trainer.py | 9 ++++++--- emnist_network_trainer.py | 11 ++++++----- msd10_network_trainer.py | 12 ++++++------ msd25_network_trainer.py | 11 ++++++----- utils/storage.py | 16 +++++++++++++++- 6 files changed, 45 insertions(+), 25 deletions(-) diff --git a/cifar100_network_trainer.py b/cifar100_network_trainer.py index 9d86344..ab6a227 100644 --- a/cifar100_network_trainer.py +++ b/cifar100_network_trainer.py @@ -5,7 +5,7 @@ import tqdm from data_providers import CIFAR100DataProvider from network_builder import ClassifierNetworkGraph from utils.parser_utils import ParserClass -from utils.storage import build_experiment_folder, save_statistics +from utils.storage import build_experiment_folder, save_statistics, get_best_validation_model_statistics tf.reset_default_graph() # resets any previous graphs to clear memory parser = argparse.ArgumentParser(description='Welcome to CNN experiments script') # generates an argument parser @@ -64,8 +64,6 @@ total_train_batches = train_data.num_batches total_val_batches = val_data.num_batches total_test_batches = test_data.num_batches -best_epoch = 0 - if tensorboard_enable: print("saved tensorboard file at", logs_filepath) writer = tf.summary.FileWriter(logs_filepath, graph=tf.get_default_graph()) @@ -76,13 +74,16 @@ with tf.Session() as sess: sess.run(init) # actually running the initialization op train_saver = tf.train.Saver() # saver object that will save our graph so we can reload it later for continuation of val_saver = tf.train.Saver() + best_val_accuracy = 0. + best_epoch = 0 # training or inference if continue_from_epoch != -1: train_saver.restore(sess, "{}/{}_{}.ckpt".format(saved_models_filepath, experiment_name, - continue_from_epoch)) # restore previous graph to continue operations + continue_from_epoch)) # restore previous graph to continue operations + best_val_accuracy, best_epoch = get_best_validation_model_statistics(logs_filepath, "result_summary_statistics") + print(best_val_accuracy, best_epoch) - best_val_accuracy = 0. with tqdm.tqdm(total=epochs - start_epoch) as epoch_pbar: for e in range(start_epoch, epochs): total_c_loss = 0. diff --git a/cifar10_network_trainer.py b/cifar10_network_trainer.py index 54f10c1..b6470b1 100644 --- a/cifar10_network_trainer.py +++ b/cifar10_network_trainer.py @@ -5,7 +5,7 @@ import tqdm from data_providers import CIFAR10DataProvider from network_builder import ClassifierNetworkGraph from utils.parser_utils import ParserClass -from utils.storage import build_experiment_folder, save_statistics +from utils.storage import build_experiment_folder, save_statistics, get_best_validation_model_statistics tf.reset_default_graph() # resets any previous graphs to clear memory parser = argparse.ArgumentParser(description='Welcome to CNN experiments script') # generates an argument parser @@ -64,7 +64,7 @@ total_train_batches = train_data.num_batches total_val_batches = val_data.num_batches total_test_batches = test_data.num_batches -best_epoch = 0 + if tensorboard_enable: print("saved tensorboard file at", logs_filepath) @@ -76,13 +76,16 @@ with tf.Session() as sess: sess.run(init) # actually running the initialization op train_saver = tf.train.Saver() # saver object that will save our graph so we can reload it later for continuation of val_saver = tf.train.Saver() + best_val_accuracy = 0. + best_epoch = 0 # training or inference if continue_from_epoch != -1: train_saver.restore(sess, "{}/{}_{}.ckpt".format(saved_models_filepath, experiment_name, continue_from_epoch)) # restore previous graph to continue operations + best_val_accuracy, best_epoch = get_best_validation_model_statistics(logs_filepath, "result_summary_statistics") + print(best_val_accuracy, best_epoch) - best_val_accuracy = 0. with tqdm.tqdm(total=epochs-start_epoch) as epoch_pbar: for e in range(start_epoch, epochs): total_c_loss = 0. diff --git a/emnist_network_trainer.py b/emnist_network_trainer.py index c175470..b848c29 100644 --- a/emnist_network_trainer.py +++ b/emnist_network_trainer.py @@ -5,7 +5,7 @@ import tqdm from data_providers import EMNISTDataProvider from network_builder import ClassifierNetworkGraph from utils.parser_utils import ParserClass -from utils.storage import build_experiment_folder, save_statistics +from utils.storage import build_experiment_folder, save_statistics, get_best_validation_model_statistics tf.reset_default_graph() # resets any previous graphs to clear memory parser = argparse.ArgumentParser(description='Welcome to CNN experiments script') # generates an argument parser @@ -64,8 +64,6 @@ total_train_batches = train_data.num_batches total_val_batches = val_data.num_batches total_test_batches = test_data.num_batches -best_epoch = 0 - if tensorboard_enable: print("saved tensorboard file at", logs_filepath) writer = tf.summary.FileWriter(logs_filepath, graph=tf.get_default_graph()) @@ -76,13 +74,16 @@ with tf.Session() as sess: sess.run(init) # actually running the initialization op train_saver = tf.train.Saver() # saver object that will save our graph so we can reload it later for continuation of val_saver = tf.train.Saver() + best_val_accuracy = 0. + best_epoch = 0 # training or inference if continue_from_epoch != -1: train_saver.restore(sess, "{}/{}_{}.ckpt".format(saved_models_filepath, experiment_name, - continue_from_epoch)) # restore previous graph to continue operations + continue_from_epoch)) # restore previous graph to continue operations + best_val_accuracy, best_epoch = get_best_validation_model_statistics(logs_filepath, "result_summary_statistics") + print(best_val_accuracy, best_epoch) - best_val_accuracy = 0. with tqdm.tqdm(total=epochs - start_epoch) as epoch_pbar: for e in range(start_epoch, epochs): total_c_loss = 0. diff --git a/msd10_network_trainer.py b/msd10_network_trainer.py index 57afdae..3e43939 100644 --- a/msd10_network_trainer.py +++ b/msd10_network_trainer.py @@ -5,7 +5,7 @@ import tqdm from data_providers import MSD10GenreDataProvider from network_builder import ClassifierNetworkGraph from utils.parser_utils import ParserClass -from utils.storage import build_experiment_folder, save_statistics +from utils.storage import build_experiment_folder, save_statistics, get_best_validation_model_statistics tf.reset_default_graph() # resets any previous graphs to clear memory parser = argparse.ArgumentParser(description='Welcome to CNN experiments script') # generates an argument parser @@ -64,8 +64,6 @@ total_train_batches = train_data.num_batches total_val_batches = val_data.num_batches total_test_batches = test_data.num_batches -best_epoch = 0 - if tensorboard_enable: print("saved tensorboard file at", logs_filepath) writer = tf.summary.FileWriter(logs_filepath, graph=tf.get_default_graph()) @@ -76,14 +74,16 @@ with tf.Session() as sess: sess.run(init) # actually running the initialization op train_saver = tf.train.Saver() # saver object that will save our graph so we can reload it later for continuation of val_saver = tf.train.Saver() + best_val_accuracy = 0. + best_epoch = 0 # training or inference - if continue_from_epoch != -1: train_saver.restore(sess, "{}/{}_{}.ckpt".format(saved_models_filepath, experiment_name, - continue_from_epoch)) # restore previous graph to continue operations + continue_from_epoch)) # restore previous graph to continue operations + best_val_accuracy, best_epoch = get_best_validation_model_statistics(logs_filepath, "result_summary_statistics") + print(best_val_accuracy, best_epoch) - best_val_accuracy = 0. with tqdm.tqdm(total=epochs - start_epoch) as epoch_pbar: for e in range(start_epoch, epochs): total_c_loss = 0. diff --git a/msd25_network_trainer.py b/msd25_network_trainer.py index 81247c6..3e43939 100644 --- a/msd25_network_trainer.py +++ b/msd25_network_trainer.py @@ -5,7 +5,7 @@ import tqdm from data_providers import MSD10GenreDataProvider from network_builder import ClassifierNetworkGraph from utils.parser_utils import ParserClass -from utils.storage import build_experiment_folder, save_statistics +from utils.storage import build_experiment_folder, save_statistics, get_best_validation_model_statistics tf.reset_default_graph() # resets any previous graphs to clear memory parser = argparse.ArgumentParser(description='Welcome to CNN experiments script') # generates an argument parser @@ -64,8 +64,6 @@ total_train_batches = train_data.num_batches total_val_batches = val_data.num_batches total_test_batches = test_data.num_batches -best_epoch = 0 - if tensorboard_enable: print("saved tensorboard file at", logs_filepath) writer = tf.summary.FileWriter(logs_filepath, graph=tf.get_default_graph()) @@ -76,13 +74,16 @@ with tf.Session() as sess: sess.run(init) # actually running the initialization op train_saver = tf.train.Saver() # saver object that will save our graph so we can reload it later for continuation of val_saver = tf.train.Saver() + best_val_accuracy = 0. + best_epoch = 0 # training or inference if continue_from_epoch != -1: train_saver.restore(sess, "{}/{}_{}.ckpt".format(saved_models_filepath, experiment_name, - continue_from_epoch)) # restore previous graph to continue operations + continue_from_epoch)) # restore previous graph to continue operations + best_val_accuracy, best_epoch = get_best_validation_model_statistics(logs_filepath, "result_summary_statistics") + print(best_val_accuracy, best_epoch) - best_val_accuracy = 0. with tqdm.tqdm(total=epochs - start_epoch) as epoch_pbar: for e in range(start_epoch, epochs): total_c_loss = 0. diff --git a/utils/storage.py b/utils/storage.py index 4d63efc..ed06814 100644 --- a/utils/storage.py +++ b/utils/storage.py @@ -1,5 +1,5 @@ import csv - +import numpy as np def save_statistics(log_dir, statistics_file_name, list_of_statistics, create=False): """ @@ -42,6 +42,20 @@ def load_statistics(log_dir, statistics_file_name): data_dict[key].append(item) return data_dict +def get_best_validation_model_statistics(log_dir, statistics_file_name): + """ + Returns the best val epoch and val accuracy from a log csv file + :param log_dir: The log directory the file is saved in + :param statistics_file_name: The log file name + :return: The best validation accuracy and the epoch at which it is produced + """ + log_file_dict = load_statistics(statistics_file_name=statistics_file_name, log_dir=log_dir) + val_acc = np.array(log_file_dict['val_c_accuracy'], dtype=np.float32) + best_val_acc = np.max(val_acc) + best_val_epoch = np.argmax(val_acc) + + return best_val_acc, best_val_epoch + def build_experiment_folder(experiment_name, log_path): saved_models_filepath = "{}/{}/{}".format(log_path, experiment_name.replace("%.%", "/"), "saved_models")