diff --git a/mlp/schedulers.py b/mlp/schedulers.py index 914ea29..6ae9597 100644 --- a/mlp/schedulers.py +++ b/mlp/schedulers.py @@ -15,7 +15,7 @@ class LearningRateScheduler(object): def get_rate(self): raise NotImplementedError() - def get_next_rate(self, current_error=None): + def get_next_rate(self, current_accuracy=None): self.epoch += 1 @@ -35,8 +35,8 @@ class LearningRateList(LearningRateScheduler): return self.lr_list[self.epoch] return 0.0 - def get_next_rate(self, current_error=None): - super(LearningRateList, self).get_next_rate(current_error=None) + def get_next_rate(self, current_accuracy=None): + super(LearningRateList, self).get_next_rate(current_accuracy=None) return self.get_rate() @@ -53,18 +53,21 @@ class LearningRateFixed(LearningRateList): return self.lr_list[0] return 0.0 - def get_next_rate(self, current_error=None): - super(LearningRateFixed, self).get_next_rate(current_error=None) + def get_next_rate(self, current_accuracy=None): + super(LearningRateFixed, self).get_next_rate(current_accuracy=None) return self.get_rate() class LearningRateNewBob(LearningRateScheduler): """ - Exponential learning rate schema + newbob learning rate schedule. + + Fixed learning rate until validation set stops improving then exponential + decay. """ - def __init__(self, start_rate, scale_by=.5, max_epochs=99, \ - min_derror_ramp_start=.5, min_derror_stop=.5, init_error=100.0, \ + def __init__(self, start_rate, scale_by=.5, max_epochs=99, + min_derror_ramp_start=.5, min_derror_stop=.5, init_error=100.0, patience=0, zero_rate=None, ramping=False): """ :type start_rate: float @@ -84,8 +87,6 @@ class LearningRateNewBob(LearningRateScheduler): :type init_error: float :param init_error: - # deltas2 below are just deltas returned by linear Linear,bprop transform - # and are exactly the same as """ self.start_rate = start_rate self.init_error = init_error @@ -115,13 +116,14 @@ class LearningRateNewBob(LearningRateScheduler): return self.zero_rate return self.rate - def get_next_rate(self, current_error): + def get_next_rate(self, current_accuracy): """ - :type current_error: float - :param current_error: percentage error + :type current_accuracy: float + :param current_accuracy: current proportion correctly classified """ + current_error = 1. - current_accuracy diff_error = 0.0 if ( (self.max_epochs > 10000) or (self.epoch >= self.max_epochs) ): @@ -166,5 +168,5 @@ class DropoutFixed(LearningRateList): def get_rate(self): return self.lr_list[0] - def get_next_rate(self, current_error=None): - return self.get_rate() \ No newline at end of file + def get_next_rate(self, current_accuracy=None): + return self.get_rate()