Add support for retrieving best epoch and val score when reloading a model
This commit is contained in:
parent
f707b10ef0
commit
8ac193d906
@ -5,7 +5,7 @@ import tqdm
|
|||||||
from data_providers import CIFAR100DataProvider
|
from data_providers import CIFAR100DataProvider
|
||||||
from network_builder import ClassifierNetworkGraph
|
from network_builder import ClassifierNetworkGraph
|
||||||
from utils.parser_utils import ParserClass
|
from utils.parser_utils import ParserClass
|
||||||
from utils.storage import build_experiment_folder, save_statistics
|
from utils.storage import build_experiment_folder, save_statistics, get_best_validation_model_statistics
|
||||||
|
|
||||||
tf.reset_default_graph() # resets any previous graphs to clear memory
|
tf.reset_default_graph() # resets any previous graphs to clear memory
|
||||||
parser = argparse.ArgumentParser(description='Welcome to CNN experiments script') # generates an argument parser
|
parser = argparse.ArgumentParser(description='Welcome to CNN experiments script') # generates an argument parser
|
||||||
@ -64,8 +64,6 @@ total_train_batches = train_data.num_batches
|
|||||||
total_val_batches = val_data.num_batches
|
total_val_batches = val_data.num_batches
|
||||||
total_test_batches = test_data.num_batches
|
total_test_batches = test_data.num_batches
|
||||||
|
|
||||||
best_epoch = 0
|
|
||||||
|
|
||||||
if tensorboard_enable:
|
if tensorboard_enable:
|
||||||
print("saved tensorboard file at", logs_filepath)
|
print("saved tensorboard file at", logs_filepath)
|
||||||
writer = tf.summary.FileWriter(logs_filepath, graph=tf.get_default_graph())
|
writer = tf.summary.FileWriter(logs_filepath, graph=tf.get_default_graph())
|
||||||
@ -76,13 +74,16 @@ with tf.Session() as sess:
|
|||||||
sess.run(init) # actually running the initialization op
|
sess.run(init) # actually running the initialization op
|
||||||
train_saver = tf.train.Saver() # saver object that will save our graph so we can reload it later for continuation of
|
train_saver = tf.train.Saver() # saver object that will save our graph so we can reload it later for continuation of
|
||||||
val_saver = tf.train.Saver()
|
val_saver = tf.train.Saver()
|
||||||
|
best_val_accuracy = 0.
|
||||||
|
best_epoch = 0
|
||||||
# training or inference
|
# training or inference
|
||||||
|
|
||||||
if continue_from_epoch != -1:
|
if continue_from_epoch != -1:
|
||||||
train_saver.restore(sess, "{}/{}_{}.ckpt".format(saved_models_filepath, experiment_name,
|
train_saver.restore(sess, "{}/{}_{}.ckpt".format(saved_models_filepath, experiment_name,
|
||||||
continue_from_epoch)) # restore previous graph to continue operations
|
continue_from_epoch)) # restore previous graph to continue operations
|
||||||
|
best_val_accuracy, best_epoch = get_best_validation_model_statistics(logs_filepath, "result_summary_statistics")
|
||||||
|
print(best_val_accuracy, best_epoch)
|
||||||
|
|
||||||
best_val_accuracy = 0.
|
|
||||||
with tqdm.tqdm(total=epochs - start_epoch) as epoch_pbar:
|
with tqdm.tqdm(total=epochs - start_epoch) as epoch_pbar:
|
||||||
for e in range(start_epoch, epochs):
|
for e in range(start_epoch, epochs):
|
||||||
total_c_loss = 0.
|
total_c_loss = 0.
|
||||||
|
@ -5,7 +5,7 @@ import tqdm
|
|||||||
from data_providers import CIFAR10DataProvider
|
from data_providers import CIFAR10DataProvider
|
||||||
from network_builder import ClassifierNetworkGraph
|
from network_builder import ClassifierNetworkGraph
|
||||||
from utils.parser_utils import ParserClass
|
from utils.parser_utils import ParserClass
|
||||||
from utils.storage import build_experiment_folder, save_statistics
|
from utils.storage import build_experiment_folder, save_statistics, get_best_validation_model_statistics
|
||||||
|
|
||||||
tf.reset_default_graph() # resets any previous graphs to clear memory
|
tf.reset_default_graph() # resets any previous graphs to clear memory
|
||||||
parser = argparse.ArgumentParser(description='Welcome to CNN experiments script') # generates an argument parser
|
parser = argparse.ArgumentParser(description='Welcome to CNN experiments script') # generates an argument parser
|
||||||
@ -64,7 +64,7 @@ total_train_batches = train_data.num_batches
|
|||||||
total_val_batches = val_data.num_batches
|
total_val_batches = val_data.num_batches
|
||||||
total_test_batches = test_data.num_batches
|
total_test_batches = test_data.num_batches
|
||||||
|
|
||||||
best_epoch = 0
|
|
||||||
|
|
||||||
if tensorboard_enable:
|
if tensorboard_enable:
|
||||||
print("saved tensorboard file at", logs_filepath)
|
print("saved tensorboard file at", logs_filepath)
|
||||||
@ -76,13 +76,16 @@ with tf.Session() as sess:
|
|||||||
sess.run(init) # actually running the initialization op
|
sess.run(init) # actually running the initialization op
|
||||||
train_saver = tf.train.Saver() # saver object that will save our graph so we can reload it later for continuation of
|
train_saver = tf.train.Saver() # saver object that will save our graph so we can reload it later for continuation of
|
||||||
val_saver = tf.train.Saver()
|
val_saver = tf.train.Saver()
|
||||||
|
best_val_accuracy = 0.
|
||||||
|
best_epoch = 0
|
||||||
# training or inference
|
# training or inference
|
||||||
|
|
||||||
if continue_from_epoch != -1:
|
if continue_from_epoch != -1:
|
||||||
train_saver.restore(sess, "{}/{}_{}.ckpt".format(saved_models_filepath, experiment_name,
|
train_saver.restore(sess, "{}/{}_{}.ckpt".format(saved_models_filepath, experiment_name,
|
||||||
continue_from_epoch)) # restore previous graph to continue operations
|
continue_from_epoch)) # restore previous graph to continue operations
|
||||||
|
best_val_accuracy, best_epoch = get_best_validation_model_statistics(logs_filepath, "result_summary_statistics")
|
||||||
|
print(best_val_accuracy, best_epoch)
|
||||||
|
|
||||||
best_val_accuracy = 0.
|
|
||||||
with tqdm.tqdm(total=epochs-start_epoch) as epoch_pbar:
|
with tqdm.tqdm(total=epochs-start_epoch) as epoch_pbar:
|
||||||
for e in range(start_epoch, epochs):
|
for e in range(start_epoch, epochs):
|
||||||
total_c_loss = 0.
|
total_c_loss = 0.
|
||||||
|
@ -5,7 +5,7 @@ import tqdm
|
|||||||
from data_providers import EMNISTDataProvider
|
from data_providers import EMNISTDataProvider
|
||||||
from network_builder import ClassifierNetworkGraph
|
from network_builder import ClassifierNetworkGraph
|
||||||
from utils.parser_utils import ParserClass
|
from utils.parser_utils import ParserClass
|
||||||
from utils.storage import build_experiment_folder, save_statistics
|
from utils.storage import build_experiment_folder, save_statistics, get_best_validation_model_statistics
|
||||||
|
|
||||||
tf.reset_default_graph() # resets any previous graphs to clear memory
|
tf.reset_default_graph() # resets any previous graphs to clear memory
|
||||||
parser = argparse.ArgumentParser(description='Welcome to CNN experiments script') # generates an argument parser
|
parser = argparse.ArgumentParser(description='Welcome to CNN experiments script') # generates an argument parser
|
||||||
@ -64,8 +64,6 @@ total_train_batches = train_data.num_batches
|
|||||||
total_val_batches = val_data.num_batches
|
total_val_batches = val_data.num_batches
|
||||||
total_test_batches = test_data.num_batches
|
total_test_batches = test_data.num_batches
|
||||||
|
|
||||||
best_epoch = 0
|
|
||||||
|
|
||||||
if tensorboard_enable:
|
if tensorboard_enable:
|
||||||
print("saved tensorboard file at", logs_filepath)
|
print("saved tensorboard file at", logs_filepath)
|
||||||
writer = tf.summary.FileWriter(logs_filepath, graph=tf.get_default_graph())
|
writer = tf.summary.FileWriter(logs_filepath, graph=tf.get_default_graph())
|
||||||
@ -76,13 +74,16 @@ with tf.Session() as sess:
|
|||||||
sess.run(init) # actually running the initialization op
|
sess.run(init) # actually running the initialization op
|
||||||
train_saver = tf.train.Saver() # saver object that will save our graph so we can reload it later for continuation of
|
train_saver = tf.train.Saver() # saver object that will save our graph so we can reload it later for continuation of
|
||||||
val_saver = tf.train.Saver()
|
val_saver = tf.train.Saver()
|
||||||
|
best_val_accuracy = 0.
|
||||||
|
best_epoch = 0
|
||||||
# training or inference
|
# training or inference
|
||||||
|
|
||||||
if continue_from_epoch != -1:
|
if continue_from_epoch != -1:
|
||||||
train_saver.restore(sess, "{}/{}_{}.ckpt".format(saved_models_filepath, experiment_name,
|
train_saver.restore(sess, "{}/{}_{}.ckpt".format(saved_models_filepath, experiment_name,
|
||||||
continue_from_epoch)) # restore previous graph to continue operations
|
continue_from_epoch)) # restore previous graph to continue operations
|
||||||
|
best_val_accuracy, best_epoch = get_best_validation_model_statistics(logs_filepath, "result_summary_statistics")
|
||||||
|
print(best_val_accuracy, best_epoch)
|
||||||
|
|
||||||
best_val_accuracy = 0.
|
|
||||||
with tqdm.tqdm(total=epochs - start_epoch) as epoch_pbar:
|
with tqdm.tqdm(total=epochs - start_epoch) as epoch_pbar:
|
||||||
for e in range(start_epoch, epochs):
|
for e in range(start_epoch, epochs):
|
||||||
total_c_loss = 0.
|
total_c_loss = 0.
|
||||||
|
@ -5,7 +5,7 @@ import tqdm
|
|||||||
from data_providers import MSD10GenreDataProvider
|
from data_providers import MSD10GenreDataProvider
|
||||||
from network_builder import ClassifierNetworkGraph
|
from network_builder import ClassifierNetworkGraph
|
||||||
from utils.parser_utils import ParserClass
|
from utils.parser_utils import ParserClass
|
||||||
from utils.storage import build_experiment_folder, save_statistics
|
from utils.storage import build_experiment_folder, save_statistics, get_best_validation_model_statistics
|
||||||
|
|
||||||
tf.reset_default_graph() # resets any previous graphs to clear memory
|
tf.reset_default_graph() # resets any previous graphs to clear memory
|
||||||
parser = argparse.ArgumentParser(description='Welcome to CNN experiments script') # generates an argument parser
|
parser = argparse.ArgumentParser(description='Welcome to CNN experiments script') # generates an argument parser
|
||||||
@ -64,8 +64,6 @@ total_train_batches = train_data.num_batches
|
|||||||
total_val_batches = val_data.num_batches
|
total_val_batches = val_data.num_batches
|
||||||
total_test_batches = test_data.num_batches
|
total_test_batches = test_data.num_batches
|
||||||
|
|
||||||
best_epoch = 0
|
|
||||||
|
|
||||||
if tensorboard_enable:
|
if tensorboard_enable:
|
||||||
print("saved tensorboard file at", logs_filepath)
|
print("saved tensorboard file at", logs_filepath)
|
||||||
writer = tf.summary.FileWriter(logs_filepath, graph=tf.get_default_graph())
|
writer = tf.summary.FileWriter(logs_filepath, graph=tf.get_default_graph())
|
||||||
@ -76,14 +74,16 @@ with tf.Session() as sess:
|
|||||||
sess.run(init) # actually running the initialization op
|
sess.run(init) # actually running the initialization op
|
||||||
train_saver = tf.train.Saver() # saver object that will save our graph so we can reload it later for continuation of
|
train_saver = tf.train.Saver() # saver object that will save our graph so we can reload it later for continuation of
|
||||||
val_saver = tf.train.Saver()
|
val_saver = tf.train.Saver()
|
||||||
|
best_val_accuracy = 0.
|
||||||
|
best_epoch = 0
|
||||||
# training or inference
|
# training or inference
|
||||||
|
|
||||||
|
|
||||||
if continue_from_epoch != -1:
|
if continue_from_epoch != -1:
|
||||||
train_saver.restore(sess, "{}/{}_{}.ckpt".format(saved_models_filepath, experiment_name,
|
train_saver.restore(sess, "{}/{}_{}.ckpt".format(saved_models_filepath, experiment_name,
|
||||||
continue_from_epoch)) # restore previous graph to continue operations
|
continue_from_epoch)) # restore previous graph to continue operations
|
||||||
|
best_val_accuracy, best_epoch = get_best_validation_model_statistics(logs_filepath, "result_summary_statistics")
|
||||||
|
print(best_val_accuracy, best_epoch)
|
||||||
|
|
||||||
best_val_accuracy = 0.
|
|
||||||
with tqdm.tqdm(total=epochs - start_epoch) as epoch_pbar:
|
with tqdm.tqdm(total=epochs - start_epoch) as epoch_pbar:
|
||||||
for e in range(start_epoch, epochs):
|
for e in range(start_epoch, epochs):
|
||||||
total_c_loss = 0.
|
total_c_loss = 0.
|
||||||
|
@ -5,7 +5,7 @@ import tqdm
|
|||||||
from data_providers import MSD10GenreDataProvider
|
from data_providers import MSD10GenreDataProvider
|
||||||
from network_builder import ClassifierNetworkGraph
|
from network_builder import ClassifierNetworkGraph
|
||||||
from utils.parser_utils import ParserClass
|
from utils.parser_utils import ParserClass
|
||||||
from utils.storage import build_experiment_folder, save_statistics
|
from utils.storage import build_experiment_folder, save_statistics, get_best_validation_model_statistics
|
||||||
|
|
||||||
tf.reset_default_graph() # resets any previous graphs to clear memory
|
tf.reset_default_graph() # resets any previous graphs to clear memory
|
||||||
parser = argparse.ArgumentParser(description='Welcome to CNN experiments script') # generates an argument parser
|
parser = argparse.ArgumentParser(description='Welcome to CNN experiments script') # generates an argument parser
|
||||||
@ -64,8 +64,6 @@ total_train_batches = train_data.num_batches
|
|||||||
total_val_batches = val_data.num_batches
|
total_val_batches = val_data.num_batches
|
||||||
total_test_batches = test_data.num_batches
|
total_test_batches = test_data.num_batches
|
||||||
|
|
||||||
best_epoch = 0
|
|
||||||
|
|
||||||
if tensorboard_enable:
|
if tensorboard_enable:
|
||||||
print("saved tensorboard file at", logs_filepath)
|
print("saved tensorboard file at", logs_filepath)
|
||||||
writer = tf.summary.FileWriter(logs_filepath, graph=tf.get_default_graph())
|
writer = tf.summary.FileWriter(logs_filepath, graph=tf.get_default_graph())
|
||||||
@ -76,13 +74,16 @@ with tf.Session() as sess:
|
|||||||
sess.run(init) # actually running the initialization op
|
sess.run(init) # actually running the initialization op
|
||||||
train_saver = tf.train.Saver() # saver object that will save our graph so we can reload it later for continuation of
|
train_saver = tf.train.Saver() # saver object that will save our graph so we can reload it later for continuation of
|
||||||
val_saver = tf.train.Saver()
|
val_saver = tf.train.Saver()
|
||||||
|
best_val_accuracy = 0.
|
||||||
|
best_epoch = 0
|
||||||
# training or inference
|
# training or inference
|
||||||
|
|
||||||
if continue_from_epoch != -1:
|
if continue_from_epoch != -1:
|
||||||
train_saver.restore(sess, "{}/{}_{}.ckpt".format(saved_models_filepath, experiment_name,
|
train_saver.restore(sess, "{}/{}_{}.ckpt".format(saved_models_filepath, experiment_name,
|
||||||
continue_from_epoch)) # restore previous graph to continue operations
|
continue_from_epoch)) # restore previous graph to continue operations
|
||||||
|
best_val_accuracy, best_epoch = get_best_validation_model_statistics(logs_filepath, "result_summary_statistics")
|
||||||
|
print(best_val_accuracy, best_epoch)
|
||||||
|
|
||||||
best_val_accuracy = 0.
|
|
||||||
with tqdm.tqdm(total=epochs - start_epoch) as epoch_pbar:
|
with tqdm.tqdm(total=epochs - start_epoch) as epoch_pbar:
|
||||||
for e in range(start_epoch, epochs):
|
for e in range(start_epoch, epochs):
|
||||||
total_c_loss = 0.
|
total_c_loss = 0.
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import csv
|
import csv
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
def save_statistics(log_dir, statistics_file_name, list_of_statistics, create=False):
|
def save_statistics(log_dir, statistics_file_name, list_of_statistics, create=False):
|
||||||
"""
|
"""
|
||||||
@ -42,6 +42,20 @@ def load_statistics(log_dir, statistics_file_name):
|
|||||||
data_dict[key].append(item)
|
data_dict[key].append(item)
|
||||||
return data_dict
|
return data_dict
|
||||||
|
|
||||||
|
def get_best_validation_model_statistics(log_dir, statistics_file_name):
|
||||||
|
"""
|
||||||
|
Returns the best val epoch and val accuracy from a log csv file
|
||||||
|
:param log_dir: The log directory the file is saved in
|
||||||
|
:param statistics_file_name: The log file name
|
||||||
|
:return: The best validation accuracy and the epoch at which it is produced
|
||||||
|
"""
|
||||||
|
log_file_dict = load_statistics(statistics_file_name=statistics_file_name, log_dir=log_dir)
|
||||||
|
val_acc = np.array(log_file_dict['val_c_accuracy'], dtype=np.float32)
|
||||||
|
best_val_acc = np.max(val_acc)
|
||||||
|
best_val_epoch = np.argmax(val_acc)
|
||||||
|
|
||||||
|
return best_val_acc, best_val_epoch
|
||||||
|
|
||||||
|
|
||||||
def build_experiment_folder(experiment_name, log_path):
|
def build_experiment_folder(experiment_name, log_path):
|
||||||
saved_models_filepath = "{}/{}/{}".format(log_path, experiment_name.replace("%.%", "/"), "saved_models")
|
saved_models_filepath = "{}/{}/{}".format(log_path, experiment_name.replace("%.%", "/"), "saved_models")
|
||||||
|
Loading…
Reference in New Issue
Block a user