Merge remote branch 'upstream/master'

This commit is contained in:
pswietojanski 2015-12-13 20:08:12 +00:00
commit 57b7496083

View File

@ -15,7 +15,7 @@ class LearningRateScheduler(object):
def get_rate(self):
raise NotImplementedError()
def get_next_rate(self, current_error=None):
def get_next_rate(self, current_accuracy=None):
self.epoch += 1
@ -35,8 +35,8 @@ class LearningRateList(LearningRateScheduler):
return self.lr_list[self.epoch]
return 0.0
def get_next_rate(self, current_error=None):
super(LearningRateList, self).get_next_rate(current_error=None)
def get_next_rate(self, current_accuracy=None):
super(LearningRateList, self).get_next_rate(current_accuracy=None)
return self.get_rate()
@ -53,18 +53,21 @@ class LearningRateFixed(LearningRateList):
return self.lr_list[0]
return 0.0
def get_next_rate(self, current_error=None):
super(LearningRateFixed, self).get_next_rate(current_error=None)
def get_next_rate(self, current_accuracy=None):
super(LearningRateFixed, self).get_next_rate(current_accuracy=None)
return self.get_rate()
class LearningRateNewBob(LearningRateScheduler):
"""
Exponential learning rate schema
newbob learning rate schedule.
Fixed learning rate until validation set stops improving then exponential
decay.
"""
def __init__(self, start_rate, scale_by=.5, max_epochs=99, \
min_derror_ramp_start=.5, min_derror_stop=.5, init_error=100.0, \
def __init__(self, start_rate, scale_by=.5, max_epochs=99,
min_derror_ramp_start=.5, min_derror_stop=.5, init_error=100.0,
patience=0, zero_rate=None, ramping=False):
"""
:type start_rate: float
@ -84,8 +87,6 @@ class LearningRateNewBob(LearningRateScheduler):
:type init_error: float
:param init_error:
# deltas2 below are just deltas returned by linear Linear,bprop transform
# and are exactly the same as
"""
self.start_rate = start_rate
self.init_error = init_error
@ -115,13 +116,14 @@ class LearningRateNewBob(LearningRateScheduler):
return self.zero_rate
return self.rate
def get_next_rate(self, current_error):
def get_next_rate(self, current_accuracy):
"""
:type current_error: float
:param current_error: percentage error
:type current_accuracy: float
:param current_accuracy: current proportion correctly classified
"""
current_error = 1. - current_accuracy
diff_error = 0.0
if ( (self.max_epochs > 10000) or (self.epoch >= self.max_epochs) ):
@ -166,5 +168,5 @@ class DropoutFixed(LearningRateList):
def get_rate(self):
return self.lr_list[0]
def get_next_rate(self, current_error=None):
def get_next_rate(self, current_accuracy=None):
return self.get_rate()