Merge remote branch 'upstream/master'
This commit is contained in:
commit
57b7496083
@ -15,7 +15,7 @@ class LearningRateScheduler(object):
|
||||
def get_rate(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_next_rate(self, current_error=None):
|
||||
def get_next_rate(self, current_accuracy=None):
|
||||
self.epoch += 1
|
||||
|
||||
|
||||
@ -35,8 +35,8 @@ class LearningRateList(LearningRateScheduler):
|
||||
return self.lr_list[self.epoch]
|
||||
return 0.0
|
||||
|
||||
def get_next_rate(self, current_error=None):
|
||||
super(LearningRateList, self).get_next_rate(current_error=None)
|
||||
def get_next_rate(self, current_accuracy=None):
|
||||
super(LearningRateList, self).get_next_rate(current_accuracy=None)
|
||||
return self.get_rate()
|
||||
|
||||
|
||||
@ -53,18 +53,21 @@ class LearningRateFixed(LearningRateList):
|
||||
return self.lr_list[0]
|
||||
return 0.0
|
||||
|
||||
def get_next_rate(self, current_error=None):
|
||||
super(LearningRateFixed, self).get_next_rate(current_error=None)
|
||||
def get_next_rate(self, current_accuracy=None):
|
||||
super(LearningRateFixed, self).get_next_rate(current_accuracy=None)
|
||||
return self.get_rate()
|
||||
|
||||
|
||||
class LearningRateNewBob(LearningRateScheduler):
|
||||
"""
|
||||
Exponential learning rate schema
|
||||
newbob learning rate schedule.
|
||||
|
||||
Fixed learning rate until validation set stops improving then exponential
|
||||
decay.
|
||||
"""
|
||||
|
||||
def __init__(self, start_rate, scale_by=.5, max_epochs=99, \
|
||||
min_derror_ramp_start=.5, min_derror_stop=.5, init_error=100.0, \
|
||||
def __init__(self, start_rate, scale_by=.5, max_epochs=99,
|
||||
min_derror_ramp_start=.5, min_derror_stop=.5, init_error=100.0,
|
||||
patience=0, zero_rate=None, ramping=False):
|
||||
"""
|
||||
:type start_rate: float
|
||||
@ -84,8 +87,6 @@ class LearningRateNewBob(LearningRateScheduler):
|
||||
|
||||
:type init_error: float
|
||||
:param init_error:
|
||||
# deltas2 below are just deltas returned by linear Linear,bprop transform
|
||||
# and are exactly the same as
|
||||
"""
|
||||
self.start_rate = start_rate
|
||||
self.init_error = init_error
|
||||
@ -115,13 +116,14 @@ class LearningRateNewBob(LearningRateScheduler):
|
||||
return self.zero_rate
|
||||
return self.rate
|
||||
|
||||
def get_next_rate(self, current_error):
|
||||
def get_next_rate(self, current_accuracy):
|
||||
"""
|
||||
:type current_error: float
|
||||
:param current_error: percentage error
|
||||
:type current_accuracy: float
|
||||
:param current_accuracy: current proportion correctly classified
|
||||
|
||||
"""
|
||||
|
||||
current_error = 1. - current_accuracy
|
||||
diff_error = 0.0
|
||||
|
||||
if ( (self.max_epochs > 10000) or (self.epoch >= self.max_epochs) ):
|
||||
@ -166,5 +168,5 @@ class DropoutFixed(LearningRateList):
|
||||
def get_rate(self):
|
||||
return self.lr_list[0]
|
||||
|
||||
def get_next_rate(self, current_error=None):
|
||||
return self.get_rate()
|
||||
def get_next_rate(self, current_accuracy=None):
|
||||
return self.get_rate()
|
||||
|
Loading…
Reference in New Issue
Block a user