2024-10-10 15:52:23 +02:00
# -*- coding: utf-8 -*-
"""Training schedulers.

This module contains classes implementing schedulers which control the
evolution of learning rule hyperparameters (such as learning rate) over a
training run.
"""
import numpy as np
class ConstantLearningRateScheduler(object):
    """Scheduler that keeps the learning rule at one fixed learning rate.

    This is a minimal example of the scheduler interface. The constructor
    holds the schedule configuration, and ``update_learning_rule`` is the
    hook run at the beginning of each epoch.
    """

    def __init__(self, learning_rate):
        """Create a constant learning rate scheduler.

        Args:
            learning_rate: Value written to the learning rule's
                ``learning_rate`` attribute on every update.
        """
        self.learning_rate = learning_rate

    def update_learning_rule(self, learning_rule, epoch_number):
        """Apply the fixed learning rate ahead of the upcoming epoch.

        Args:
            learning_rule: Learning rule object used in the training run.
                Its ``learning_rate`` attribute is overwritten.
            epoch_number: Integer index of the epoch about to be run. It is
                unused because the rate does not vary across epochs.
        """
        fixed_rate = self.learning_rate
        learning_rule.learning_rate = fixed_rate
2024-10-14 11:51:43 +02:00
class CosineAnnealingWithWarmRestarts(object):
    """Cosine annealing scheduler, implemented as in https://arxiv.org/pdf/1608.03983.pdf

    Within period ``i`` (0-based), the learning rate at epoch offset ``t``
    into a period of length ``T_i`` is::

        lr = min_lr + 0.5 * (max_lr_i - min_lr) * (1 + cos(pi * t / T_i))

    Period lengths and per-period maxima are defined as follows:

    - ``T_0`` is ``total_iters_per_period`` (counted in epochs).
    - ``T_{i+1} = T_i * period_iteration_expansion_factor``.
    - ``max_lr_i = max_learning_rate / max_learning_rate_discount_factor ** i``.
    - ``max_lr_i`` is clamped so it never drops below ``min_learning_rate``.
    """

    def __init__(self, min_learning_rate, max_learning_rate, total_iters_per_period,
                 max_learning_rate_discount_factor, period_iteration_expansion_factor):
        """Instantiate a new cosine annealing with warm restarts learning rate scheduler.

        Args:
            min_learning_rate: The minimum learning rate the scheduler can assign.
            max_learning_rate: The maximum learning rate the scheduler can
                assign. It is used at the start of the first period.
            total_iters_per_period: The number of epochs (calls to
                ``update_learning_rule``) in the first period. It is stored
                as ``self.total_epochs_per_period``.
            max_learning_rate_discount_factor: How many times smaller the
                maximum learning rate is after each restart compared to the
                previous one. For example, 2 halves it at every restart and
                1 keeps it constant.
            period_iteration_expansion_factor: The rate of expansion of the
                period length. With 1 all periods have the same number of
                epochs, with values larger than 1 each subsequent period is
                longer, and with values smaller than 1 each is shorter.

        Raises:
            ValueError: If the period length or either factor is not positive.
        """
        if total_iters_per_period <= 0:
            raise ValueError(
                'total_iters_per_period must be positive, got {0}'.format(total_iters_per_period))
        if max_learning_rate_discount_factor <= 0:
            raise ValueError(
                'max_learning_rate_discount_factor must be positive, got {0}'.format(
                    max_learning_rate_discount_factor))
        if period_iteration_expansion_factor <= 0:
            raise ValueError(
                'period_iteration_expansion_factor must be positive, got {0}'.format(
                    period_iteration_expansion_factor))
        self.min_learning_rate = min_learning_rate
        self.max_learning_rate = max_learning_rate
        self.total_epochs_per_period = total_iters_per_period
        self.max_learning_rate_discount_factor = max_learning_rate_discount_factor
        self.period_iteration_expansion_factor = period_iteration_expansion_factor

    def update_learning_rule(self, learning_rule, epoch_number):
        """Update the hyperparameters of the learning rule.

        Run at the beginning of each epoch.

        Args:
            learning_rule: Learning rule object being used in training run,
                any scheduled hyperparameters to be altered should be
                attributes of this object. Its ``learning_rate`` attribute
                is overwritten.
            epoch_number: Integer index of training epoch about to be run.

        Returns:
            effective_learning_rate at step 'epoch_number' (also assigned to
            ``learning_rule.learning_rate``).

        Raises:
            ValueError: If ``epoch_number`` is negative.
        """
        if epoch_number < 0:
            raise ValueError('epoch_number must be non-negative, got {0}'.format(epoch_number))

        # The schedule is sampled once per epoch, so a period shorter than one
        # epoch is unobservable. Clamping to 1 also guarantees the search
        # below advances (and terminates) when periods shrink
        # (expansion factor < 1).
        period_length = max(float(self.total_epochs_per_period), 1.0)
        period_start = 0.0
        num_restarts = 0
        # Walk the restart schedule forward to the period containing epoch_number.
        while epoch_number >= period_start + period_length:
            period_start += period_length
            num_restarts += 1
            period_length = max(period_length * self.period_iteration_expansion_factor, 1.0)

        # Each restart shrinks the peak learning rate by the discount factor.
        # The peak never goes below the configured minimum.
        period_max_learning_rate = (
            self.max_learning_rate / (self.max_learning_rate_discount_factor ** num_restarts))
        period_max_learning_rate = max(period_max_learning_rate, self.min_learning_rate)

        fraction_of_period = (epoch_number - period_start) / period_length
        effective_learning_rate = float(
            self.min_learning_rate
            + 0.5 * (period_max_learning_rate - self.min_learning_rate)
            * (1.0 + np.cos(np.pi * fraction_of_period)))

        learning_rule.learning_rate = effective_learning_rate
        return effective_learning_rate