mlpractical/mlp/schedulers.py

35 lines
1.1 KiB
Python
Raw Normal View History

2016-10-10 10:23:53 +02:00
# -*- coding: utf-8 -*-
"""Training schedulers.

This module contains classes implementing schedulers which control the
evolution of learning rule hyperparameters (such as learning rate) over a
training run.
"""
2015-10-12 02:50:05 +02:00
2016-10-10 10:23:53 +02:00
import numpy as np
2015-10-12 02:50:05 +02:00
2016-10-10 10:23:53 +02:00
class ConstantLearningRateScheduler(object):
    """Scheduler that keeps the learning rate fixed for the whole run.

    Serves as a minimal example of the scheduler interface: a scheduler
    exposes ``update_learning_rule(learning_rule, epoch_number)``, which is
    meant to be run at the beginning of each epoch and may adjust any
    hyperparameters stored as attributes of the learning rule.
    """

    def __init__(self, learning_rate):
        """Create a scheduler that always applies ``learning_rate``.

        Args:
            learning_rate: Value assigned to the learning rule's
                ``learning_rate`` attribute on every update.
        """
        self.learning_rate = learning_rate

    def update_learning_rule(self, learning_rule, epoch_number):
        """Overwrite the learning rule's learning rate with the fixed value.

        Run at the beginning of each epoch.

        Args:
            learning_rule: Learning rule object being used in the training
                run; its ``learning_rate`` attribute is (re)assigned.
            epoch_number: Integer index of the training epoch about to be
                run. Unused here, since the rate never changes.
        """
        fixed_rate = self.learning_rate
        learning_rule.learning_rate = fixed_rate