2024-11-11 14:00:28 +01:00
import torch
import torch . nn as nn
import torch . optim as optim
import torch . nn . functional as F
import tqdm
import os
import numpy as np
import time
from pytorch_mlp_framework . storage_utils import save_statistics
from matplotlib import pyplot as plt
import matplotlib
# Use a small global font so long vertical layer names fit on the gradient-flow plots.
matplotlib.rcParams["font.size"] = 8
class ExperimentBuilder(nn.Module):
    """
    Runs training, validation and test evaluation of a network, checkpointing every epoch.

    The network is registered as the sub-module ``self.model``, so ``self.parameters()`` and
    ``self.state_dict()`` cover the model's parameters.
    """

    def __init__(
        self,
        network_model,
        experiment_name,
        num_epochs,
        train_data,
        val_data,
        test_data,
        weight_decay_coefficient,
        use_gpu,
        continue_from_epoch=-1,
    ):
        """
        Initialise an ExperimentBuilder.

        It trains and evaluates a deep net on a given dataset, saves a checkpoint every epoch,
        tracks the best validation model and uses it to compute the test set metrics.

        :param network_model: A pytorch nn.Module implementing the network architecture. It must
            provide a ``reset_parameters()`` method, which is called here.
        :param experiment_name: Name of the experiment. Used (relative to the current working
            directory) as the root folder for logs, saved models and gradient-flow plots.
        :param num_epochs: Total number of epochs to run the experiment for.
        :param train_data: Iterable of (inputs, targets) batches for the training set; must
            support ``len()``.
        :param val_data: Same as ``train_data``, for the validation set.
        :param test_data: Same as ``train_data``, for the test set.
        :param weight_decay_coefficient: Weight decay passed to the Adam optimizer.
        :param use_gpu: Whether to use a GPU when at least one is available.
        :param continue_from_epoch: -1 (or any other value below -2, or -2's neighbours other
            than -2 itself) starts from scratch; -2 resumes from the "latest" checkpoint; an
            epoch index >= 0 resumes from that epoch's checkpoint.
        """
        super().__init__()
        self.experiment_name = experiment_name
        self.model = network_model
        if torch.cuda.device_count() >= 1 and use_gpu:
            self.device = torch.device("cuda")
            self.model.to(self.device)  # move the model parameters to the GPU
            print("Use GPU", self.device)
        else:
            print("use CPU")
            self.device = torch.device("cpu")
            print(self.device)

        self.model.reset_parameters()  # re-initialise network parameters
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data

        print("System learnable parameters")
        num_conv_layers = 0
        num_linear_layers = 0
        total_num_parameters = 0
        for name, value in self.named_parameters():
            print(name, value.shape)
            if "conv" in name and "weight" in name:
                num_conv_layers += 1
            if "linear" in name and "weight" in name:
                num_linear_layers += 1
            # numel() is exact for every shape (np.prod(()) would give the float 1.0)
            total_num_parameters += value.numel()

        print("Total number of parameters", total_num_parameters)
        print("Total number of conv layers", num_conv_layers)
        print("Total number of linear layers", num_linear_layers)

        self.optimizer = optim.Adam(
            self.parameters(), amsgrad=False, weight_decay=weight_decay_coefficient
        )
        # T_max counts scheduler steps; run_experiment steps the scheduler once per epoch.
        self.learning_rate_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=num_epochs, eta_min=0.00002
        )

        # Directory layout: <experiment>/result_outputs and <experiment>/saved_models
        self.experiment_folder = os.path.abspath(experiment_name)
        self.experiment_logs = os.path.abspath(
            os.path.join(self.experiment_folder, "result_outputs")
        )
        self.experiment_saved_models = os.path.abspath(
            os.path.join(self.experiment_folder, "saved_models")
        )

        # No validation result yet
        self.best_val_model_idx = 0
        self.best_val_model_acc = 0.0

        # makedirs(exist_ok=True) also repairs an experiment folder whose sub-folders are missing
        for directory in (
            self.experiment_folder,
            self.experiment_logs,
            self.experiment_saved_models,
        ):
            os.makedirs(directory, exist_ok=True)

        self.num_epochs = num_epochs
        # Not used by the iteration methods below (they call F.cross_entropy); kept for callers.
        self.criterion = nn.CrossEntropyLoss().to(self.device)

        # NOTE(review): only the network weights are checkpointed, not the optimizer or the LR
        # scheduler, so a resumed run restarts both from their initial state.
        if continue_from_epoch == -2:
            # Resume from the most recent checkpoint
            self.state, self.best_val_model_idx, self.best_val_model_acc = self.load_model(
                model_save_dir=self.experiment_saved_models,
                model_save_name="train_model",
                model_idx="latest",
            )
            # NOTE(review): "model_epoch" is the last completed epoch, so it is trained again.
            self.starting_epoch = int(self.state["model_epoch"])
        elif continue_from_epoch > -1:
            # Resume from the checkpoint of a specific epoch
            self.state, self.best_val_model_idx, self.best_val_model_acc = self.load_model(
                model_save_dir=self.experiment_saved_models,
                model_save_name="train_model",
                model_idx=continue_from_epoch,
            )
            self.starting_epoch = continue_from_epoch
        else:
            self.state = dict()
            self.starting_epoch = 0

    def get_num_parameters(self):
        """Return the total number of scalar parameters held by this builder (i.e. the model)."""
        return sum(param.numel() for param in self.parameters())

    def plot_func_def(self, all_grads, layers):
        """
        Draw the average gradient of each layer on the current pyplot figure.

        :param all_grads: Mean absolute gradient for each layer, in model order.
        :param layers: Layer labels, one per entry of ``all_grads``.
        :return: the ``matplotlib.pyplot`` module, with the gradient-flow plot drawn
        """
        plt.plot(all_grads, alpha=0.3, color="b")
        plt.hlines(0, 0, len(all_grads) + 1, linewidth=1, color="k")
        plt.xticks(range(0, len(all_grads), 1), layers, rotation="vertical")
        plt.xlim(xmin=0, xmax=len(all_grads))
        plt.xlabel("Layers")
        plt.ylabel("Average Gradient")
        plt.title("Gradient flow")
        plt.grid(True)
        plt.tight_layout()

        return plt

    @staticmethod
    def _grad_layer_label(name):
        """
        Build a short x-axis label from a dotted parameter name.

        Names with at least five dot-separated parts map to "<part 1>_<part 3>" (0-based).
        Other dotted names map to "<everything after the first dot>_<first part>". Names
        without a dot are returned unchanged.
        """
        try:
            _, first, _, second, _ = name.split(".", 4)
        except ValueError:  # fewer than five dot-separated parts
            head, sep, rest = name.partition(".")
            if not sep:
                return name
            return f"{rest}_{head}"
        return f"{first}_{second}"

    def plot_grad_flow(self, named_parameters):
        """
        Plot the gradient flow of a model: the mean absolute gradient of every parameter.

        Called once per epoch from run_experiment with ``self.model.named_parameters()``.

        :param named_parameters: Iterable of (name, parameter) pairs. Parameters that do not
            require gradients, or have no gradient yet, are skipped.
        :return: the ``matplotlib.pyplot`` module, with the gradient-flow plot drawn
        """
        all_grads = []
        layers = []
        for name, param in named_parameters:
            if param.requires_grad and param.grad is not None:
                layers.append(self._grad_layer_label(name))
                all_grads.append(param.grad.abs().mean().item())

        return self.plot_func_def(all_grads, layers)

    @staticmethod
    def _batch_accuracy(out, y):
        """Return the fraction of rows of ``out`` whose argmax equals the target in ``y``."""
        predicted = out.detach().argmax(dim=1)
        # double() keeps the same float64 precision the numpy mean used to provide
        return predicted.eq(y).double().mean().item()

    def run_train_iter(self, x, y):
        """
        Run one training step (forward, backward, optimizer update) on a batch.

        The learning-rate scheduler is NOT stepped here; run_experiment steps it once per
        epoch, matching its T_max of ``num_epochs``.

        :param x: Batch of inputs (tensor); converted to float and moved to ``self.device``.
        :param y: Batch of class-index targets (tensor); converted to long and moved to
            ``self.device``.
        :return: (loss as a 0-d numpy array, accuracy as a float in [0, 1]) for this batch
        """
        self.train()  # training mode (batch norm, dropout, ... behave accordingly)
        x, y = x.float().to(device=self.device), y.long().to(device=self.device)
        out = self.model(x)
        loss = F.cross_entropy(input=out, target=y)
        self.optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()
        self.optimizer.step()
        return loss.detach().cpu().numpy(), self._batch_accuracy(out, y)

    def run_evaluation_iter(self, x, y):
        """
        Evaluate the model on a batch without tracking gradients.

        :param x: Batch of inputs (tensor); converted to float and moved to ``self.device``.
        :param y: Batch of class-index targets (tensor); converted to long and moved to
            ``self.device``.
        :return: (loss as a 0-d numpy array, accuracy as a float in [0, 1]) for this batch
        """
        self.eval()  # evaluation mode
        x, y = x.float().to(device=self.device), y.long().to(device=self.device)
        with torch.no_grad():  # no autograd graph is needed for evaluation
            out = self.model(x)
            loss = F.cross_entropy(input=out, target=y)
        return loss.cpu().numpy(), self._batch_accuracy(out, y)

    def save_model(
        self,
        model_save_dir,
        model_save_name,
        model_idx,
        best_validation_model_idx,
        best_validation_model_acc,
    ):
        """
        Save the network weights, the best validation epoch index and its accuracy.

        The extra entries already present in ``self.state`` (e.g. "model_epoch") are saved too.
        The file is written to ``<model_save_dir>/<model_save_name>_<model_idx>``.

        :param model_save_dir: Directory to write the checkpoint to.
        :param model_save_name: File name prefix, without the index.
        :param model_idx: Index (epoch number or "latest") appended to the file name.
        :param best_validation_model_idx: Epoch index of the best validation model so far.
        :param best_validation_model_acc: Best validation accuracy so far.
        """
        self.state["network"] = self.state_dict()
        # Store plain Python numbers: numpy scalars are rejected by torch.load(weights_only=True).
        self.state["best_val_model_idx"] = int(best_validation_model_idx)
        self.state["best_val_model_acc"] = float(best_validation_model_acc)
        torch.save(
            self.state,
            f=os.path.join(model_save_dir, "{}_{}".format(model_save_name, str(model_idx))),
        )

    def load_model(self, model_save_dir, model_save_name, model_idx):
        """
        Load a checkpoint written by save_model into this builder.

        :param model_save_dir: Directory the checkpoint was saved in.
        :param model_save_name: File name prefix, without the index.
        :param model_idx: Index (epoch number or "latest") appended to the file name.
        :return: (full saved state dict, best validation epoch index, best validation accuracy)
        """
        state = torch.load(
            f=os.path.join(model_save_dir, "{}_{}".format(model_save_name, str(model_idx))),
            # remap tensors so GPU-saved checkpoints also load on a CPU-only machine
            map_location=self.device,
        )
        self.load_state_dict(state_dict=state["network"])
        return state, state["best_val_model_idx"], state["best_val_model_acc"]

    def _save_grad_flow_plot(self, epoch_idx):
        """Plot the current gradient flow of self.model and save it as epoch<idx>.pdf."""
        print("Generating Gradient Flow Plot at epoch {}".format(epoch_idx))
        grad_plot = self.plot_grad_flow(self.model.named_parameters())
        plot_dir = os.path.join(self.experiment_saved_models, "gradient_flow_plots")
        os.makedirs(plot_dir, exist_ok=True)
        # NOTE(review): the pyplot figure is never cleared, so each saved PDF also contains
        # the curves drawn in all earlier epochs.
        grad_plot.savefig(os.path.join(plot_dir, "epoch{}.pdf".format(str(epoch_idx))))

    def run_experiment(self):
        """
        Train from ``starting_epoch`` to ``num_epochs - 1``, then evaluate on the test set.

        After each epoch it validates and appends the per-epoch means to
        result_outputs/summary.csv. It saves checkpoints "train_model_<epoch>" and
        "train_model_latest", and writes a gradient-flow plot. Finally it reloads the best
        validation checkpoint and writes the test metrics to result_outputs/test_summary.csv.

        :return: (total_losses, test_losses). ``total_losses`` maps "train_acc", "train_loss",
            "val_acc" and "val_loss" to one mean per epoch run here. ``test_losses`` maps
            "test_acc" and "test_loss" to single-element lists.
        """
        total_losses = {
            "train_acc": [],
            "train_loss": [],
            "val_acc": [],
            "val_loss": [],
        }  # per-epoch means
        for i, epoch_idx in enumerate(range(self.starting_epoch, self.num_epochs)):
            epoch_start_time = time.time()
            current_epoch_losses = {
                "train_acc": [],
                "train_loss": [],
                "val_acc": [],
                "val_loss": [],
            }  # per-batch values for this epoch
            self.current_epoch = epoch_idx

            with tqdm.tqdm(total=len(self.train_data)) as pbar_train:
                for x, y in self.train_data:
                    loss, accuracy = self.run_train_iter(x=x, y=y)
                    current_epoch_losses["train_loss"].append(loss)
                    current_epoch_losses["train_acc"].append(accuracy)
                    pbar_train.update(1)
                    pbar_train.set_description(
                        "loss: {:.4f}, accuracy: {:.4f}".format(loss, accuracy)
                    )
            # The cosine schedule's T_max is num_epochs, so advance it once per epoch.
            self.learning_rate_scheduler.step()

            with tqdm.tqdm(total=len(self.val_data)) as pbar_val:
                for x, y in self.val_data:
                    loss, accuracy = self.run_evaluation_iter(x=x, y=y)
                    current_epoch_losses["val_loss"].append(loss)
                    current_epoch_losses["val_acc"].append(accuracy)
                    pbar_val.update(1)
                    pbar_val.set_description(
                        "loss: {:.4f}, accuracy: {:.4f}".format(loss, accuracy)
                    )

            val_mean_accuracy = float(np.mean(current_epoch_losses["val_acc"]))
            if val_mean_accuracy > self.best_val_model_acc:
                self.best_val_model_acc = val_mean_accuracy
                self.best_val_model_idx = epoch_idx

            for key, value in current_epoch_losses.items():
                total_losses[key].append(np.mean(value))

            save_statistics(
                experiment_log_dir=self.experiment_logs,
                filename="summary.csv",
                stats_dict=total_losses,
                current_epoch=i,  # index into total_losses, which only holds this run's epochs
                continue_from_mode=self.starting_epoch != 0 or i > 0,
            )

            out_string = "_".join(
                "{}_{:.4f}".format(key, np.mean(value))
                for key, value in current_epoch_losses.items()
            )
            epoch_elapsed_time = "{:.4f}".format(time.time() - epoch_start_time)
            print(
                "Epoch {}:".format(epoch_idx),
                out_string,
                "epoch time",
                epoch_elapsed_time,
                "seconds",
            )

            self.state["model_epoch"] = epoch_idx
            # An epoch-indexed checkpoint plus "latest" (used by continue_from_epoch=-2)
            for model_idx in (epoch_idx, "latest"):
                self.save_model(
                    model_save_dir=self.experiment_saved_models,
                    model_save_name="train_model",
                    model_idx=model_idx,
                    best_validation_model_idx=self.best_val_model_idx,
                    best_validation_model_acc=self.best_val_model_acc,
                )

            self._save_grad_flow_plot(epoch_idx)

        print("Generating test set evaluation metrics")
        # Evaluate with the weights of the best validation epoch
        self.load_model(
            model_save_dir=self.experiment_saved_models,
            model_idx=self.best_val_model_idx,
            model_save_name="train_model",
        )
        current_epoch_losses = {"test_acc": [], "test_loss": []}
        with tqdm.tqdm(total=len(self.test_data)) as pbar_test:
            for x, y in self.test_data:
                loss, accuracy = self.run_evaluation_iter(x=x, y=y)
                current_epoch_losses["test_loss"].append(loss)
                current_epoch_losses["test_acc"].append(accuracy)
                pbar_test.update(1)
                pbar_test.set_description(
                    "loss: {:.4f}, accuracy: {:.4f}".format(loss, accuracy)
                )
        test_losses = {
            key: [np.mean(value)] for key, value in current_epoch_losses.items()
        }
        save_statistics(
            experiment_log_dir=self.experiment_logs,
            filename="test_summary.csv",
            stats_dict=test_losses,
            current_epoch=0,
            continue_from_mode=False,
        )
        return total_losses, test_losses