Documenting optimiser module.

This commit is contained in:
Matt Graham 2016-09-21 00:54:36 +01:00
parent dac0729324
commit 0bfe0c1a34

View File

@ -1,5 +1,9 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""Model optimisers.""" """Model optimisers.
This module contains objects implementing (batched) stochastic gradient descent
based optimisation of models.
"""
import time import time
import logging import logging
@ -11,9 +15,25 @@ logger = logging.getLogger(__name__)
class Optimiser(object): class Optimiser(object):
"""Basic model optimiser."""
def __init__(self, model, cost, learning_rule, train_dataset, def __init__(self, model, cost, learning_rule, train_dataset,
valid_dataset=None, data_monitors=None): valid_dataset=None, data_monitors=None):
"""Create a new optimiser instance.
Args:
model: The model to optimise.
cost: The scalar cost function to minimise.
learning_rule: Gradient based learning rule to use to minimise
cost.
train_dataset: Data provider for training set data batches.
valid_dataset: Data provider for validation set data batches.
data_monitors: Dictionary of functions evaluated on targets and
model outputs (averaged across both full training and
validation data sets) to monitor during training in addition
to the cost. Keys should correspond to a string label for
the statistic being evaluated.
"""
self.model = model self.model = model
self.cost = cost self.cost = cost
self.learning_rule = learning_rule self.learning_rule = learning_rule
@ -25,6 +45,13 @@ class Optimiser(object):
self.data_monitors.update(data_monitors) self.data_monitors.update(data_monitors)
def do_training_epoch(self): def do_training_epoch(self):
"""Do a single training epoch.
This iterates through all batches in the training dataset, for each batch
calculating the gradient of the estimated loss given the batch with
respect to all the model parameters and then updating the model
parameters according to the learning rule.
"""
for inputs_batch, targets_batch in self.train_dataset: for inputs_batch, targets_batch in self.train_dataset:
activations = self.model.fprop(inputs_batch) activations = self.model.fprop(inputs_batch)
grads_wrt_outputs = self.cost.grad(activations[-1], targets_batch) grads_wrt_outputs = self.cost.grad(activations[-1], targets_batch)
@ -32,7 +59,16 @@ class Optimiser(object):
activations, grads_wrt_outputs) activations, grads_wrt_outputs)
self.learning_rule.update_params(grads_wrt_params) self.learning_rule.update_params(grads_wrt_params)
def monitors(self, dataset, label): def eval_monitors(self, dataset, label):
"""Evaluates the monitors for the given dataset.
Args:
dataset: Dataset to perform evaluation with.
label: Tag to add to end of monitor keys to identify dataset.
Returns:
OrderedDict of monitor values evaluated on dataset.
"""
data_mon_vals = OrderedDict([(key + label, 0.) for key data_mon_vals = OrderedDict([(key + label, 0.) for key
in self.data_monitors.keys()]) in self.data_monitors.keys()])
for inputs_batch, targets_batch in dataset: for inputs_batch, targets_batch in dataset:
@ -45,22 +81,49 @@ class Optimiser(object):
return data_mon_vals return data_mon_vals
def get_epoch_stats(self): def get_epoch_stats(self):
"""Computes training statistics for an epoch.
Returns:
An OrderedDict with keys corresponding to the statistic labels and
values corresponding to the value of the statistic.
"""
epoch_stats = OrderedDict() epoch_stats = OrderedDict()
epoch_stats.update(self.monitors(self.train_dataset, '(train)')) epoch_stats.update(self.eval_monitors(self.train_dataset, '(train)'))
if self.valid_dataset is not None: if self.valid_dataset is not None:
epoch_stats.update(self.monitors(self.valid_dataset, '(valid)')) epoch_stats.update(self.eval_monitors(
self.valid_dataset, '(valid)'))
epoch_stats['cost(param)'] = self.model.params_cost() epoch_stats['cost(param)'] = self.model.params_cost()
return epoch_stats return epoch_stats
def log_stats(self, epoch, epoch_time, stats): def log_stats(self, epoch, epoch_time, stats):
"""Outputs stats for a training epoch to a logger.
Args:
epoch (int): Epoch counter.
epoch_time: Time taken in seconds for the epoch to complete.
stats: Monitored stats for the epoch.
"""
logger.info('Epoch {0}: {1:.1f}s to complete\n {2}'.format( logger.info('Epoch {0}: {1:.1f}s to complete\n {2}'.format(
epoch, epoch_time, epoch, epoch_time,
', '.join(['{0}={1:.2e}'.format(k, v) for (k, v) in stats.items()]) ', '.join(['{0}={1:.2e}'.format(k, v) for (k, v) in stats.items()])
)) ))
def train(self, n_epochs, stats_interval=5): def train(self, num_epochs, stats_interval=5):
"""Trains a model for a set number of epochs.
Args:
num_epochs: Number of epochs (complete passes through training
dataset) to train for.
stats_interval: Training statistics will be recorded and logged
every `stats_interval` epochs.
Returns:
Tuple with first value being an array of training run statistics
and the second being a dict mapping the labels for the statistics
recorded to their column index in the array.
"""
run_stats = [] run_stats = []
for epoch in range(1, n_epochs + 1): for epoch in range(1, num_epochs + 1):
start_time = time.clock() start_time = time.clock()
self.do_training_epoch() self.do_training_epoch()
epoch_time = time.clock() - start_time epoch_time = time.clock() - start_time