Adding schedulers to optimiser.

Matt Graham 2016-10-10 09:25:33 +01:00
parent 18f893d864
commit 259b000cba


@@ -18,7 +18,7 @@ class Optimiser(object):
     """Basic model optimiser."""
 
     def __init__(self, model, error, learning_rule, train_dataset,
-                 valid_dataset=None, data_monitors=None):
+                 valid_dataset=None, data_monitors=None, schedulers=[]):
         """Create a new optimiser instance.
 
         Args:
@@ -43,6 +43,7 @@ class Optimiser(object):
         self.data_monitors = OrderedDict([('error', error)])
         if data_monitors is not None:
             self.data_monitors.update(data_monitors)
+        self.schedulers = schedulers
 
     def do_training_epoch(self):
         """Do a single training epoch.
@@ -103,7 +104,7 @@ class Optimiser(object):
             epoch_time: Time taken in seconds for the epoch to complete.
             stats: Monitored stats for the epoch.
         """
-        logger.info('Epoch {0}: {1:.1f}s to complete\n {2}'.format(
+        logger.info('Epoch {0}: {1:.2f}s to complete\n {2}'.format(
            epoch, epoch_time,
            ', '.join(['{0}={1:.2e}'.format(k, v) for (k, v) in stats.items()])
        ))
@@ -118,17 +119,26 @@ class Optimiser(object):
                 every `stats_interval` epochs.
 
         Returns:
-            Tuple with first value being an array of training run statistics
-            and the second being a dict mapping the labels for the statistics
-            recorded to their column index in the array.
+            Tuple with first value being an array of training run statistics,
+            the second being a dict mapping the labels for the statistics
+            recorded to their column index in the array and the final value
+            being the total time elapsed in seconds during the training run.
         """
         run_stats = [self.get_epoch_stats().values()]
+        run_start_time = time.time()
         for epoch in range(1, num_epochs + 1):
-            start_time = time.clock()
+            for scheduler in self.schedulers:
+                scheduler.update_learning_rule(self.learning_rule, epoch - 1)
+            start_time = time.time()
             self.do_training_epoch()
-            epoch_time = time.clock() - start_time
+            epoch_time = time.time() - start_time
             if epoch % stats_interval == 0:
                 stats = self.get_epoch_stats()
                 self.log_stats(epoch, epoch_time, stats)
                 run_stats.append(stats.values())
-        return np.array(run_stats), {k: i for i, k in enumerate(stats.keys())}
+        run_time = time.time() - run_start_time
+        return (
+            np.array(run_stats),
+            {k: i for i, k in enumerate(stats.keys())},
+            run_time
+        )
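
A rough usage sketch under the same assumptions: model, error, learning_rule and the dataset providers are placeholders, ExponentialDecayScheduler is the hypothetical class above, and only the parameter names and the new three-element return value come from this commit. Switching from time.clock to time.time also means epoch_time and run_time now measure wall-clock rather than processor time.

schedulers = [ExponentialDecayScheduler(init_learning_rate=0.1, decay=0.95)]
optimiser = Optimiser(model, error, learning_rule, train_dataset,
                      valid_dataset=valid_dataset, schedulers=schedulers)

# train now returns the total wall-clock time of the run as a third value.
run_stats, stats_keys, run_time = optimiser.train(num_epochs=100,
                                                  stats_interval=5)
print('Training run took {0:.2f}s'.format(run_time))
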