Adding schedulers to optimiser.
parent 18f893d864
commit 259b000cba
@@ -18,7 +18,7 @@ class Optimiser(object):
     """Basic model optimiser."""
 
     def __init__(self, model, error, learning_rule, train_dataset,
-                 valid_dataset=None, data_monitors=None):
+                 valid_dataset=None, data_monitors=None, schedulers=[]):
        """Create a new optimiser instance.
 
        Args:
@@ -43,6 +43,7 @@ class Optimiser(object):
         self.data_monitors = OrderedDict([('error', error)])
         if data_monitors is not None:
             self.data_monitors.update(data_monitors)
+        self.schedulers = schedulers
 
     def do_training_epoch(self):
         """Do a single training epoch.
@@ -103,7 +104,7 @@ class Optimiser(object):
             epoch_time: Time taken in seconds for the epoch to complete.
             stats: Monitored stats for the epoch.
         """
-        logger.info('Epoch {0}: {1:.1f}s to complete\n {2}'.format(
+        logger.info('Epoch {0}: {1:.2f}s to complete\n {2}'.format(
             epoch, epoch_time,
             ', '.join(['{0}={1:.2e}'.format(k, v) for (k, v) in stats.items()])
         ))
@@ -118,17 +119,26 @@ class Optimiser(object):
             every `stats_interval` epochs.
 
         Returns:
-            Tuple with first value being an array of training run statistics
-            and the second being a dict mapping the labels for the statistics
-            recorded to their column index in the array.
+            Tuple with first value being an array of training run statistics,
+            the second being a dict mapping the labels for the statistics
+            recorded to their column index in the array and the final value
+            being the total time elapsed in seconds during the training run.
         """
         run_stats = [self.get_epoch_stats().values()]
+        run_start_time = time.time()
         for epoch in range(1, num_epochs + 1):
-            start_time = time.clock()
+            for scheduler in self.schedulers:
+                scheduler.update_learning_rule(self.learning_rule, epoch - 1)
+            start_time = time.time()
             self.do_training_epoch()
-            epoch_time = time.clock() - start_time
+            epoch_time = time.time() - start_time
             if epoch % stats_interval == 0:
                 stats = self.get_epoch_stats()
                 self.log_stats(epoch, epoch_time, stats)
                 run_stats.append(stats.values())
-        return np.array(run_stats), {k: i for i, k in enumerate(stats.keys())}
+        run_time = time.time() - run_start_time
+        return (
+            np.array(run_stats),
+            {k: i for i, k in enumerate(stats.keys())},
+            run_time
+        )
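
The training loop above only assumes that each scheduler exposes an `update_learning_rule(learning_rule, epoch_number)` method which mutates the learning rule in place at the start of each (zero-based) epoch. A minimal sketch of a compatible scheduler is shown below; the class name, constructor arguments and the `learning_rate` attribute on the learning rule are illustrative assumptions, not part of this commit.

```python
class ExponentialLearningRateScheduler(object):
    """Illustrative scheduler: decays the learning rate by a fixed factor each epoch.

    Only the `update_learning_rule` signature is required by `Optimiser.train`;
    everything else here is an assumption made for the sake of the example.
    """

    def __init__(self, init_learning_rate, decay_factor):
        self.init_learning_rate = init_learning_rate
        self.decay_factor = decay_factor

    def update_learning_rule(self, learning_rule, epoch_number):
        # `epoch_number` is zero-based: the training loop passes `epoch - 1`.
        # Assumes the learning rule exposes a mutable `learning_rate` attribute.
        learning_rule.learning_rate = (
            self.init_learning_rate * self.decay_factor ** epoch_number)
```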
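With the new constructor signature, a training run could then be wired up roughly as in the sketch below; `model`, `error`, `learning_rule` and the data providers are placeholders for objects constructed elsewhere in the codebase, the scheduler class is the sketch above, and the `train` arguments are inferred from the names used in the loop.

```python
# Hypothetical usage sketch; model, error, learning_rule, train_data and
# valid_data stand in for objects defined elsewhere in the codebase.
scheduler = ExponentialLearningRateScheduler(init_learning_rate=0.1,
                                             decay_factor=0.95)

optimiser = Optimiser(
    model, error, learning_rule, train_data,
    valid_dataset=valid_data, schedulers=[scheduler])

# train(...) now returns a third value: the wall-clock time of the whole run.
stats, keys, run_time = optimiser.train(num_epochs=100, stats_interval=5)
print('Training took {0:.2f}s in total'.format(run_time))
```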