71 lines
2.9 KiB
Python
71 lines
2.9 KiB
Python
|
import pickle
|
||
|
import os
|
||
|
import csv
|
||
|
|
||
|
|
||
|
def save_to_stats_pkl_file(experiment_log_filepath, filename, stats_dict):
    """
    Serializes stats_dict with pickle into "<experiment_log_filepath>/<filename>.pkl".

    :param experiment_log_filepath: the directory in which the pickle file is written
    :param filename: the base name of the file; the ".pkl" suffix is appended by this function
    :param stats_dict: the object to serialize (any picklable object)
    :return: None
    """
    pkl_path = "{}.pkl".format(os.path.join(experiment_log_filepath, filename))
    # Binary mode is required by pickle; "wb" truncates any previous dump at this path.
    with open(pkl_path, "wb") as pkl_out:
        pickle.dump(stats_dict, pkl_out)
|
||
|
|
||
|
|
||
|
def load_from_stats_pkl_file(experiment_log_filepath, filename):
    """
    Loads and returns the object pickled in "<experiment_log_filepath>/<filename>.pkl".

    :param experiment_log_filepath: the directory containing the pickle file
    :param filename: the base name of the file; the ".pkl" suffix is appended by this function
    :return: the unpickled object
    """
    pkl_path = "{}.pkl".format(os.path.join(experiment_log_filepath, filename))
    # NOTE(review): pickle.load can execute arbitrary code embedded in the file; only load
    # files produced by this project, never ones obtained from an untrusted source.
    with open(pkl_path, "rb") as pkl_in:
        return pickle.load(pkl_in)
|
||
|
|
||
|
|
||
|
def save_statistics(experiment_log_dir, filename, stats_dict, current_epoch, continue_from_mode=False, save_full_dict=False):
    """
    Saves the statistics in stats dict into a csv file, using the keys as the header entries and the values as the
    columns of a particular header entry.

    :param experiment_log_dir: the log folder dir filepath
    :param filename: the name of the csv file
    :param stats_dict: the stats dict containing the data to be saved; every value must be indexable (e.g. a list)
    :param current_epoch: the index into each value of the row to write when save_full_dict is False, i.e. the
        number of epochs since commencement of the current training session (if the experiment continued from 100
        and this is epoch 105, then pass relative distance of 5.)
    :param continue_from_mode: if True, open the file in append mode and do not write a header row; if False,
        truncate the file and write the header row first
    :param save_full_dict: if True, write every row of stats_dict instead of only row current_epoch. The file is
        only overwritten when continue_from_mode is False; otherwise all rows are appended to the existing file.
    :return: The filepath to the summary file
    """
    summary_filename = os.path.join(experiment_log_dir, filename)
    mode = 'a' if continue_from_mode else 'w'
    # Materialize the columns once instead of once per written row.
    columns = list(stats_dict.values())

    # newline='' is required by the csv module: otherwise the writer's '\r\n' terminator is translated
    # again on Windows, producing a blank line after every row.
    with open(summary_filename, mode, newline='') as f:
        writer = csv.writer(f)
        if not continue_from_mode:
            writer.writerow(list(stats_dict.keys()))

        if save_full_dict:
            # The row count comes from the first column; an empty dict yields no data rows.
            total_rows = len(columns[0]) if columns else 0
            for idx in range(total_rows):
                writer.writerow([column[idx] for column in columns])
        else:
            writer.writerow([column[current_epoch] for column in columns])

    return summary_filename
|
||
|
|
||
|
|
||
|
def load_statistics(experiment_log_dir, filename):
    """
    Loads a statistics csv file into a dictionary.

    :param experiment_log_dir: the log folder dir filepath
    :param filename: the name of the csv file to load
    :return: A dictionary containing the stats in the csv file. Header entries are converted into keys and columns of
        a particular header are converted into values of a key in a list format. Values are returned as strings
        exactly as they appear in the file (no numeric conversion). An empty file yields an empty dict.
    """
    summary_filename = os.path.join(experiment_log_dir, filename)

    # Parse with the csv module rather than str.split so that line terminators are stripped from the last
    # column and quoted fields containing commas stay intact. The file is only read, so open it read-only.
    with open(summary_filename, 'r', newline='') as f:
        reader = csv.reader(f)
        keys = next(reader, None)
        if keys is None:
            return {}

        stats = {key: [] for key in keys}
        for row in reader:
            for idx, value in enumerate(row):
                stats[keys[idx]].append(value)

    return stats
|