71 lines
2.7 KiB
Python
71 lines
2.7 KiB
Python
import csv
|
|
import numpy as np
|
|
|
|
def save_statistics(log_dir, statistics_file_name, list_of_statistics, create=False):
    """
    Writes one row of statistics to a .csv file.

    :param log_dir: Directory in which the .csv file lives
    :param statistics_file_name: Name of the .csv file, without the ``.csv`` extension
    :param list_of_statistics: A list of values written as a single csv row
        (for example the header labels, or one set of statistics)
    :param create: If True, create (or truncate) the file so the row becomes its
        first line; if False, append the row to the existing file
    """
    # The two cases only differ by the open mode, so choose it once.
    mode = 'w' if create else 'a'
    # newline='' is required by the csv module: csv.writer emits its own '\r\n'
    # terminator, and text-mode newline translation would otherwise turn it into
    # '\r\r\n' on Windows, which reads back as an extra blank row.
    with open("{}/{}.csv".format(log_dir, statistics_file_name), mode, newline='') as f:
        writer = csv.writer(f)
        writer.writerow(list_of_statistics)
|
|
|
|
|
|
def load_statistics(log_dir, statistics_file_name):
    """
    Loads a statistics .csv file into a dictionary of columns.

    The first row is the header. Every following row is split into cells, and
    each cell is appended (as a string) to the list of the header label in the
    same position.

    :param log_dir: The directory in which the log is saved
    :param statistics_file_name: The name of the statistics file, without ``.csv``
    :return: A dict mapping each header label to the list of its string values.
        An empty file yields an empty dict.
    """
    # Parse with csv.reader (newline='' as the csv module requires) so quoted
    # fields and '\r\n' / '\r' terminators written by csv.writer are handled
    # correctly, instead of a naive split on ','.
    with open("{}/{}.csv".format(log_dir, statistics_file_name), 'r', newline='') as f:
        reader = csv.reader(f)
        data_labels = next(reader, None)
        if data_labels is None:
            # Empty file: no header row, hence no columns.
            return {}

        data_dict = {label: [] for label in data_labels}
        # Set for O(1) membership checks in the per-cell filter below.
        header = set(data_labels)

        for row in reader:
            # Blank lines parse as [] and therefore add nothing to any column.
            for key, item in zip(data_labels, row):
                # Cells equal to a header label are skipped, presumably so that a
                # repeated header row does not leak into the data.
                # NOTE(review): this also drops a genuine value that happens to
                # equal a label name, misaligning that column - confirm intent.
                if item not in header:
                    data_dict[key].append(item)

    return data_dict
|
|
|
|
def get_best_validation_model_statistics(log_dir, statistics_file_name):
    """
    Finds the highest validation accuracy recorded in a statistics .csv file.

    :param log_dir: The log directory the file is saved in
    :param statistics_file_name: The log file name (without ``.csv``)
    :return: A tuple ``(best_val_acc, best_val_epoch)``: the maximum of the
        ``val_c_accuracy`` column as a float32, and the zero-based position in
        that column at which the maximum first occurs
    """
    statistics = load_statistics(log_dir=log_dir, statistics_file_name=statistics_file_name)
    # The column holds strings read from the csv; numpy parses them to float32.
    accuracies = np.array(statistics['val_c_accuracy'], dtype=np.float32)
    return np.max(accuracies), np.argmax(accuracies)
|
|
|
|
|
|
def build_experiment_folder(experiment_name, log_path):
    """
    Creates the folder tree for an experiment and returns its two sub-folders.

    Every ``"%.%"`` in ``experiment_name`` is replaced by ``"/"``, so a name such
    as ``"group%.%run1"`` becomes the nested folder ``group/run1`` under ``log_path``.

    :param experiment_name: Name of the experiment; ``"%.%"`` marks nesting levels
    :param log_path: Root directory under which the experiment folder is created
    :return: A tuple ``(saved_models_filepath, logs_filepath)`` for the
        ``saved_models`` and ``summary_logs`` directories, both existing on return
    :raises FileExistsError: if one of the two paths exists but is not a directory
    """
    import os

    experiment_dir = "{}/{}".format(log_path, experiment_name.replace("%.%", "/"))
    saved_models_filepath = "{}/{}".format(experiment_dir, "saved_models")
    logs_filepath = "{}/{}".format(experiment_dir, "summary_logs")

    # exist_ok=True replaces the racy os.path.exists() + os.makedirs() pair, which
    # fails if another process creates the folder in between the two calls.
    for path in (logs_filepath, saved_models_filepath):
        os.makedirs(path, exist_ok=True)

    return saved_models_filepath, logs_filepath
|