# -*- coding: utf-8 -*-
"""Training schedulers.

This module contains classes implementing schedulers which control the
evolution of learning rule hyperparameters (such as learning rate) over a
training run.
"""

import numpy as np


class ConstantLearningRateScheduler(object):
    """Example of scheduler interface which sets a constant learning rate."""

    def __init__(self, learning_rate):
        """Construct a new constant learning rate scheduler object.

        Args:
            learning_rate: Learning rate to use in learning rule.
        """
        self.learning_rate = learning_rate

    def update_learning_rule(self, learning_rule, epoch_number):
        """Update the hyperparameters of the learning rule.

        Run at the beginning of each epoch.

        Args:
            learning_rule: Learning rule object being used in training run,
                any scheduled hyperparameters to be altered should be
                attributes of this object.
            epoch_number: Integer index of training epoch about to be run.
        """
        learning_rule.learning_rate = self.learning_rate
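

# An illustrative usage sketch (not part of the original module): a scheduler
# is called once per epoch with the learning rule in use. `_DummyLearningRule`
# is a hypothetical stand-in for any object exposing the `learning_rate`
# attribute documented in the interface above.
def _example_constant_scheduler_usage():
    class _DummyLearningRule(object):
        learning_rate = 0.0

    scheduler = ConstantLearningRateScheduler(learning_rate=1e-3)
    learning_rule = _DummyLearningRule()
    for epoch_number in range(5):
        scheduler.update_learning_rule(learning_rule, epoch_number)
        # The constant scheduler assigns the same rate at every epoch.
        assert learning_rule.learning_rate == 1e-3

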
class CosineAnnealingWithWarmRestarts(object):
    """Cosine annealing scheduler with warm restarts, implemented as in
    https://arxiv.org/pdf/1608.03983.pdf."""

    def __init__(self, min_learning_rate, max_learning_rate,
                 total_epochs_per_period, max_learning_rate_discount_factor,
                 period_iteration_expansion_factor):
        """Instantiate a new cosine annealing with warm restarts learning
        rate scheduler.

        :param min_learning_rate: The minimum learning rate the scheduler
            can assign.
        :param max_learning_rate: The maximum learning rate the scheduler
            can assign.
        :param total_epochs_per_period: The number of epochs in a period.
        :param max_learning_rate_discount_factor: The discount applied to
            the maximum learning rate after each restart, i.e. how many
            times smaller the maximum learning rate will be after a restart
            compared to the previous period.
        :param period_iteration_expansion_factor: The expansion rate of the
            period length in epochs. If set to 1, all periods have the same
            number of epochs; if larger than 1, each subsequent period has
            more epochs than the previous one, and vice versa.
        """
        self.min_learning_rate = min_learning_rate
        self.max_learning_rate = max_learning_rate
        self.total_epochs_per_period = total_epochs_per_period
        self.max_learning_rate_discount_factor = max_learning_rate_discount_factor
        self.period_iteration_expansion_factor = period_iteration_expansion_factor

    def update_learning_rule(self, learning_rule, epoch_number):
        """Update the hyperparameters of the learning rule.

        Run at the beginning of each epoch.

        Args:
            learning_rule: Learning rule object being used in training run,
                any scheduled hyperparameters to be altered should be
                attributes of this object.
            epoch_number: Integer index of training epoch about to be run.

        Returns:
            The effective learning rate at epoch `epoch_number`.
        """
        raise NotImplementedError
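

# A reference sketch (not the author's implementation) of the SGDR schedule
# from https://arxiv.org/pdf/1608.03983.pdf, kept outside the class so the
# stub above is left intact. Parameter names mirror the constructor above;
# the period bookkeeping is one plausible reading of the docstring, assuming
# `max_learning_rate_discount_factor` is a multiplier in (0, 1] and
# `period_iteration_expansion_factor` multiplies the period length after
# each warm restart.
def _cosine_annealing_sketch(epoch_number, min_learning_rate,
                             max_learning_rate, total_epochs_per_period,
                             max_learning_rate_discount_factor,
                             period_iteration_expansion_factor):
    """Return the learning rate at `epoch_number` under SGDR-style annealing."""
    period_length = float(total_epochs_per_period)
    epochs_into_period = float(epoch_number)
    current_max = max_learning_rate
    # Walk over completed periods: each warm restart expands the period
    # length and discounts the maximum learning rate.
    while epochs_into_period >= period_length:
        epochs_into_period -= period_length
        period_length *= period_iteration_expansion_factor
        current_max *= max_learning_rate_discount_factor
    # Cosine interpolation between maximum and minimum learning rates within
    # the current period (eq. 5 of Loshchilov & Hutter, 2017).
    return min_learning_rate + 0.5 * (current_max - min_learning_rate) * (
        1 + np.cos(np.pi * epochs_into_period / period_length))

# With min=0.0, max=1.0, a 10-epoch period and both factors equal to 1, the
# sketch returns 1.0 at epoch 0, decays toward 0.0 across the period, and
# restarts at 1.0 at epoch 10.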