diff --git a/mlp/optimisers.py b/mlp/optimisers.py index dcd4199..4ce9e4d 100644 --- a/mlp/optimisers.py +++ b/mlp/optimisers.py @@ -1,5 +1,9 @@ # -*- coding: utf-8 -*- -"""Model optimisers.""" +"""Model optimisers. + +This module contains objects implementing (batched) stochastic gradient descent +based optimisation of models. +""" import time import logging @@ -11,9 +15,25 @@ logger = logging.getLogger(__name__) class Optimiser(object): + """Basic model optimiser.""" def __init__(self, model, cost, learning_rule, train_dataset, valid_dataset=None, data_monitors=None): + """Create a new optimiser instance. + + Args: + model: The model to optimise. + cost: The scalar cost function to minimise. + learning_rule: Gradient based learning rule to use to minimise + cost. + train_dataset: Data provider for training set data batches. + valid_dataset: Data provider for validation set data batches. + data_monitors: Dictionary of functions evaluated on targets and + model outputs (averaged across both full training and + validation data sets) to monitor during training in addition + to the cost. Keys should correspond to a string label for + the statistic being evaluated. + """ self.model = model self.cost = cost self.learning_rule = learning_rule @@ -25,6 +45,13 @@ class Optimiser(object): self.data_monitors.update(data_monitors) def do_training_epoch(self): + """Do a single training epoch. + + This iterates through all batches in training dataset, for each + calculating the gradient of the estimated loss given the batch with + respect to all the model parameters and then updates the model + parameters according to the learning rule. + """ for inputs_batch, targets_batch in self.train_dataset: activations = self.model.fprop(inputs_batch) grads_wrt_outputs = self.cost.grad(activations[-1], targets_batch) @@ -32,7 +59,16 @@ class Optimiser(object): activations, grads_wrt_outputs) self.learning_rule.update_params(grads_wrt_params) - def monitors(self, dataset, label): + def eval_monitors(self, dataset, label): + """Evaluates the monitors for the given dataset. + + Args: + dataset: Dataset to perform evaluation with. + label: Tag to add to end of monitor keys to identify dataset. + + Returns: + OrderedDict of monitor values evaluated on dataset. + """ data_mon_vals = OrderedDict([(key + label, 0.) for key in self.data_monitors.keys()]) for inputs_batch, targets_batch in dataset: @@ -45,22 +81,49 @@ class Optimiser(object): return data_mon_vals def get_epoch_stats(self): + """Computes training statistics for an epoch. + + Returns: + An OrderedDict with keys corresponding to the statistic labels and + values corresponding to the value of the statistic. + """ epoch_stats = OrderedDict() - epoch_stats.update(self.monitors(self.train_dataset, '(train)')) + epoch_stats.update(self.eval_monitors(self.train_dataset, '(train)')) if self.valid_dataset is not None: - epoch_stats.update(self.monitors(self.valid_dataset, '(valid)')) + epoch_stats.update(self.eval_monitors( + self.valid_dataset, '(valid)')) epoch_stats['cost(param)'] = self.model.params_cost() return epoch_stats def log_stats(self, epoch, epoch_time, stats): + """Outputs stats for a training epoch to a logger. + + Args: + epoch (int): Epoch counter. + epoch_time: Time taken in seconds for the epoch to complete. + stats: Monitored stats for the epoch. + """ logger.info('Epoch {0}: {1:.1f}s to complete\n {2}'.format( epoch, epoch_time, ', '.join(['{0}={1:.2e}'.format(k, v) for (k, v) in stats.items()]) )) - def train(self, n_epochs, stats_interval=5): + def train(self, num_epochs, stats_interval=5): + """Trains a model for a set number of epochs. + + Args: + num_epochs: Number of epochs (complete passes through trainin + dataset) to train for. + stats_interval: Training statistics will be recorded and logged + every `stats_interval` epochs. + + Returns: + Tuple with first value being an array of training run statistics + and the second being a dict mapping the labels for the statistics + recorded to their column index in the array. + """ run_stats = [] - for epoch in range(1, n_epochs + 1): + for epoch in range(1, num_epochs + 1): start_time = time.clock() self.do_training_epoch() epoch_time = time.clock() - start_time