import pickle
import os
import csv


def save_to_stats_pkl_file(experiment_log_filepath, filename, stats_dict):
    """
    Pickles a statistics dict to ``<experiment_log_filepath>/<filename>.pkl``.

    :param experiment_log_filepath: the log folder dir filepath
    :param filename: the name of the pickle file, without the ``.pkl`` extension (it is appended here)
    :param stats_dict: the object to serialise
    """
    summary_filename = os.path.join(experiment_log_filepath, filename)
    with open("{}.pkl".format(summary_filename), "wb") as file_writer:
        pickle.dump(stats_dict, file_writer)


def load_from_stats_pkl_file(experiment_log_filepath, filename):
    """
    Loads a statistics object from ``<experiment_log_filepath>/<filename>.pkl``.

    :param experiment_log_filepath: the log folder dir filepath
    :param filename: the name of the pickle file, without the ``.pkl`` extension (it is appended here)
    :return: The unpickled object
    """
    summary_filename = os.path.join(experiment_log_filepath, filename)
    # NOTE(security): unpickling can execute arbitrary code embedded in the file.
    # Only load pickle files that come from a trusted source.
    with open("{}.pkl".format(summary_filename), "rb") as file_reader:
        return pickle.load(file_reader)


def save_statistics(experiment_log_dir, filename, stats_dict, current_epoch, continue_from_mode=False,
                    save_full_dict=False):
    """
    Saves the statistics in stats dict into a csv file. Using the keys as the header entries and the values as the
    columns of a particular header entry
    :param experiment_log_dir: the log folder dir filepath
    :param filename: the name of the csv file
    :param stats_dict: the stats dict containing the data to be saved; every value must be indexable by row number
    :param current_epoch: the index of the row to write, i.e. the number of epochs since commencement of the current
    training session (if the experiment continued from 100 and this is epoch 105, then pass relative distance of 5.)
    Ignored when save_full_dict is True.
    :param continue_from_mode: if True, rows are appended to the existing file and no header is written; if False,
    the file is (re)created and the header row is written first
    :param save_full_dict: whether to save the full dict as is overriding any previous entries (might be useful if we
    want to overwrite a file)
    :return: The filepath to the summary file
    """
    summary_filename = os.path.join(experiment_log_dir, filename)
    mode = 'a' if continue_from_mode else 'w'
    columns = list(stats_dict.values())

    # newline='' lets the csv module control line endings itself; without it every row is followed by a
    # blank line on platforms that translate '\n' on write (e.g. Windows).
    with open(summary_filename, mode, newline='') as f:
        writer = csv.writer(f)
        if not continue_from_mode:
            writer.writerow(list(stats_dict.keys()))

        if save_full_dict:
            # The first column defines the number of rows; an empty dict simply has no rows to write.
            total_rows = len(columns[0]) if columns else 0
            for idx in range(total_rows):
                writer.writerow([column[idx] for column in columns])
        else:
            writer.writerow([column[current_epoch] for column in columns])

    return summary_filename


def load_statistics(experiment_log_dir, filename):
    """
    Loads a statistics csv file into a dictionary
    :param experiment_log_dir: the log folder dir filepath
    :param filename: the name of the csv file to load
    :return: A dictionary containing the stats in the csv file. Header entries are converted into keys and columns of
    a particular header are converted into values of a key in a list format. All values are returned as strings,
    exactly as stored in the file. An empty file yields an empty dictionary.
    """
    summary_filename = os.path.join(experiment_log_dir, filename)

    # Parse with csv.reader rather than str.split so that line terminators are not glued onto the last
    # key/value of every row and quoted fields containing commas are read back intact. Read-only mode is
    # enough here, so the file does not need to be writable.
    with open(summary_filename, 'r', newline='') as f:
        # Blank rows are skipped so files written with doubled line endings still load cleanly.
        rows = [row for row in csv.reader(f) if row]

    if not rows:
        return {}

    keys = rows[0]
    stats = {key: [] for key in keys}
    for row in rows[1:]:
        # Index by position (not zip) so a row wider than the header still fails loudly.
        for idx, value in enumerate(row):
            stats[keys[idx]].append(value)

    return stats