diff --git a/mlp/learning_rules.py b/mlp/learning_rules.py index c2cb013..1f7ca8c 100644 --- a/mlp/learning_rules.py +++ b/mlp/learning_rules.py @@ -160,3 +160,113 @@ class MomentumLearningRule(GradientDescentLearningRule): mom *= self.mom_coeff mom -= self.learning_rate * grad param += mom + + +class NesterovMomentumLearningRule(GradientDescentLearningRule): + """Gradient descent with Nesterov accelerated gradient learning rule. + + This again extends the basic gradient learning rule by introducing extra + momentum state variables for each parameter. These can help the learning + dynamic help overcome shallow local minima and speed convergence when + making multiple successive steps in a similar direction in parameter space. + + Compared to 'classical' momentum, Nesterov momentum [1] uses a slightly + different update rule where the momentum is effectively decremented by the + gradient evaluated at the parameters plus the momentum coefficient times + the current previous momentum. This corresponds to 'looking ahead' to + where the previous momentum would move the parameters to and using the + gradient evaluated at this look ahead point. This can give more responsive + and stable momentum updates in some cases [1]. + + To fit in with the learning rule framework used here we use a variant of + Nesterov momentum described in [2] where the updates are reparameterised + in terms of the 'look ahead' parameters, so as to allow the learning rule + to be passed the gradients evaluated at the current parameters as with the + other learning rules. + + For parameter p[i] and corresponding momentum m[i] the updates for a + scalar loss function `L` are of the form + + m_ := m[i] + m[i] := mom_coeff * m[i] - learning_rate * dL/dp[i] + p[i] := p[i] - mom_coeff * m_ + (1 + mom_coeff) * m[i] + + with `learning_rate` a positive scaling parameter for the gradient updates + and `mom_coeff` a value in [0, 1] that determines how much 'friction' there + is the system and so how quickly previous momentum contributions decay. + + References: + [1]: On the importance of initialization and momentum in deep learning + Sutskever, Martens, Dahl and Hinton (2013) + [2]: http://cs231n.github.io/neural-networks-3/#anneal + """ + + def __init__(self, learning_rate=1e-3, mom_coeff=0.9): + """Creates a new learning rule object. + + Args: + learning_rate: A postive scalar to scale gradient updates to the + parameters by. This needs to be carefully set - if too large + the learning dynamic will be unstable and may diverge, while + if set too small learning will proceed very slowly. + mom_coeff: A scalar in the range [0, 1] inclusive. This determines + the contribution of the previous momentum value to the value + after each update. If equal to 0 the momentum is set to exactly + the negative scaled gradient each update and so this rule + collapses to standard gradient descent. If equal to 1 the + momentum will just be decremented by the scaled gradient at + each update. This is equivalent to simulating the dynamic in + a frictionless system. Due to energy conservation the loss + of 'potential energy' as the dynamics moves down the loss + function surface will lead to an increasingly large 'kinetic + energy' and so speed, meaning the updates will become + increasingly large, potentially unstably so. Typically a value + less than but close to 1 will avoid these issues and cause the + dynamic to converge to a local minima where the gradients are + by definition zero. + """ + super(NesterovMomentumLearningRule, self).__init__(learning_rate) + assert mom_coeff >= 0. and mom_coeff <= 1., ( + 'mom_coeff should be in the range [0, 1].' + ) + self.mom_coeff = mom_coeff + + def initialise(self, params): + """Initialises the state of the learning rule for a set or parameters. + + This must be called before `update_params` is first called. + + Args: + params: A list of the parameters to be optimised. Note these will + be updated *in-place* to avoid reallocating arrays on each + update. + """ + super(NesterovMomentumLearningRule, self).initialise(params) + self.moms = [] + for param in self.params: + self.moms.append(np.zeros_like(param)) + + def reset(self): + """Resets any additional state variables to their initial values. + + For this learning rule this corresponds to zeroing all the momenta. + """ + for mom in zip(self.moms): + mom *= 0. + + def update_params(self, grads_wrt_params): + """Applies a single update to all parameters. + + All parameter updates are performed using in-place operations and so + nothing is returned. + + Args: + grads_wrt_params: A list of gradients of the scalar loss function + with respect to each of the parameters passed to `initialise` + previously, with this list expected to be in the same order. + """ + for param, mom, grad in zip(self.params, self.moms, grads_wrt_params): + mom_prev = mom.copy() + mom *= self.mom_coeff + mom -= self.learning_rate * grad + param += (1. + self.mom_coeff) * mom - self.mom_coeff * mom_prev