From 03f27cab75ade118167b307f6cefcf89b80fc9f0 Mon Sep 17 00:00:00 2001
From: Matt Graham
Date: Mon, 19 Sep 2016 12:18:44 +0100
Subject: [PATCH] Added ability to add additional monitoring channels during
 training.

---
 mlp/trainers.py | 25 ++++++++++++++++---------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/mlp/trainers.py b/mlp/trainers.py
index 68e5d79..7cfbf5e 100644
--- a/mlp/trainers.py
+++ b/mlp/trainers.py
@@ -13,13 +13,16 @@ logger = logging.getLogger(__name__)
 class Trainer(object):
 
     def __init__(self, model, cost, learning_rule, train_dataset,
-                 valid_dataset=None):
+                 valid_dataset=None, data_monitors=None):
         self.model = model
         self.cost = cost
         self.learning_rule = learning_rule
         self.learning_rule.initialise(self.model.params)
         self.train_dataset = train_dataset
         self.valid_dataset = valid_dataset
+        self.data_monitors = OrderedDict([('cost', cost)])
+        if data_monitors is not None:
+            self.data_monitors.update(data_monitors)
 
     def do_training_epoch(self):
         for inputs_batch, targets_batch in self.train_dataset:
@@ -29,19 +32,23 @@ class Trainer(object):
                 activations, grads_wrt_outputs)
             self.learning_rule.update_params(grads_wrt_params)
 
-    def data_cost(self, dataset):
-        cost = 0.
+    def monitors(self, dataset, label):
+        data_mon_vals = OrderedDict([(key + label, 0.) for key
+                                     in self.data_monitors.keys()])
         for inputs_batch, targets_batch in dataset:
             activations = self.model.fprop(inputs_batch)
-            cost += self.cost(activations[-1], targets_batch)
-        cost /= dataset.num_batches
-        return cost
+            for key, data_monitor in self.data_monitors.items():
+                data_mon_vals[key + label] += data_monitor(
+                    activations[-1], targets_batch)
+        for key, data_monitor in self.data_monitors.items():
+            data_mon_vals[key + label] /= dataset.num_batches
+        return data_mon_vals
 
     def get_epoch_stats(self):
         epoch_stats = OrderedDict()
-        epoch_stats['cost(train)'] = self.data_cost(self.train_dataset)
+        epoch_stats.update(self.monitors(self.train_dataset, '(train)'))
         if self.valid_dataset is not None:
-            epoch_stats['cost(valid)'] = self.data_cost(self.valid_dataset)
+            epoch_stats.update(self.monitors(self.valid_dataset, '(valid)'))
         epoch_stats['cost(param)'] = self.model.params_cost()
         return epoch_stats
 
@@ -61,4 +68,4 @@ class Trainer(object):
                 stats = self.get_epoch_stats()
                 self.log_stats(epoch, epoch_time, stats)
                 run_stats.append(stats.values())
-        return np.array(run_stats), stats.keys()
+        return np.array(run_stats), {k: i for i, k in enumerate(stats.keys())}
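
Usage sketch (not part of the patch): the snippet below shows how an extra monitoring
channel such as classification accuracy might be passed through the new data_monitors
argument. It assumes model, cost, learning_rule and the dataset objects are constructed
elsewhere with the interfaces Trainer expects, and that the training loop whose return
value the final hunk changes is exposed as train(num_epochs); the accuracy function and
all setup names are hypothetical placeholders rather than part of the patch.

    from collections import OrderedDict

    import numpy as np

    from mlp.trainers import Trainer


    def accuracy(outputs, targets):
        # Hypothetical monitor: proportion of correct predictions, assuming
        # 1-of-k encoded targets. Trainer.monitors calls each monitor as
        # data_monitor(activations[-1], targets_batch).
        return np.mean(outputs.argmax(-1) == targets.argmax(-1))


    # model, cost, learning_rule, train_data and valid_data are assumed to be
    # built elsewhere (placeholder names, not defined by this patch).
    trainer = Trainer(
        model, cost, learning_rule, train_data, valid_data,
        data_monitors=OrderedDict([('acc', accuracy)]))

    # Assumed training entry point; it returns the per-epoch stats array and,
    # after this patch, a dict mapping each stat name to its column index.
    stats, keys = trainer.train(num_epochs=20)

    # Channels can now be looked up by name, e.g. the final validation accuracy:
    final_valid_acc = stats[-1, keys['acc(valid)']]

With the keys returned as a name-to-column mapping, downstream plotting or logging code
can select channels like 'acc(valid)' by name instead of relying on their position in
the stats array.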