Merge remote branch 'upstream/master'

This commit is contained in:
pswietojanski 2015-12-13 20:08:12 +00:00
commit 57b7496083

View File

@ -15,7 +15,7 @@ class LearningRateScheduler(object):
def get_rate(self): def get_rate(self):
raise NotImplementedError() raise NotImplementedError()
def get_next_rate(self, current_error=None): def get_next_rate(self, current_accuracy=None):
self.epoch += 1 self.epoch += 1
@ -35,8 +35,8 @@ class LearningRateList(LearningRateScheduler):
return self.lr_list[self.epoch] return self.lr_list[self.epoch]
return 0.0 return 0.0
def get_next_rate(self, current_error=None): def get_next_rate(self, current_accuracy=None):
super(LearningRateList, self).get_next_rate(current_error=None) super(LearningRateList, self).get_next_rate(current_accuracy=None)
return self.get_rate() return self.get_rate()
@ -53,18 +53,21 @@ class LearningRateFixed(LearningRateList):
return self.lr_list[0] return self.lr_list[0]
return 0.0 return 0.0
def get_next_rate(self, current_error=None): def get_next_rate(self, current_accuracy=None):
super(LearningRateFixed, self).get_next_rate(current_error=None) super(LearningRateFixed, self).get_next_rate(current_accuracy=None)
return self.get_rate() return self.get_rate()
class LearningRateNewBob(LearningRateScheduler): class LearningRateNewBob(LearningRateScheduler):
""" """
Exponential learning rate schema newbob learning rate schedule.
Fixed learning rate until validation set stops improving then exponential
decay.
""" """
def __init__(self, start_rate, scale_by=.5, max_epochs=99, \ def __init__(self, start_rate, scale_by=.5, max_epochs=99,
min_derror_ramp_start=.5, min_derror_stop=.5, init_error=100.0, \ min_derror_ramp_start=.5, min_derror_stop=.5, init_error=100.0,
patience=0, zero_rate=None, ramping=False): patience=0, zero_rate=None, ramping=False):
""" """
:type start_rate: float :type start_rate: float
@ -84,8 +87,6 @@ class LearningRateNewBob(LearningRateScheduler):
:type init_error: float :type init_error: float
:param init_error: :param init_error:
# deltas2 below are just deltas returned by linear Linear,bprop transform
# and are exactly the same as
""" """
self.start_rate = start_rate self.start_rate = start_rate
self.init_error = init_error self.init_error = init_error
@ -115,13 +116,14 @@ class LearningRateNewBob(LearningRateScheduler):
return self.zero_rate return self.zero_rate
return self.rate return self.rate
def get_next_rate(self, current_error): def get_next_rate(self, current_accuracy):
""" """
:type current_error: float :type current_accuracy: float
:param current_error: percentage error :param current_accuracy: current proportion correctly classified
""" """
current_error = 1. - current_accuracy
diff_error = 0.0 diff_error = 0.0
if ( (self.max_epochs > 10000) or (self.epoch >= self.max_epochs) ): if ( (self.max_epochs > 10000) or (self.epoch >= self.max_epochs) ):
@ -166,5 +168,5 @@ class DropoutFixed(LearningRateList):
def get_rate(self): def get_rate(self):
return self.lr_list[0] return self.lr_list[0]
def get_next_rate(self, current_error=None): def get_next_rate(self, current_accuracy=None):
return self.get_rate() return self.get_rate()