44 lines
1.1 KiB
Python
44 lines
1.1 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Learning rules that update model parameters in place from their gradients."""
|
|
|
|
import numpy as np
|
|
|
|
|
|
class GradientDescentLearningRule(object):
    """Plain gradient descent.

    Each call to ``update_params`` moves every parameter one step of size
    ``learning_rate`` against its gradient::

        param <- param - learning_rate * grad
    """

    def __init__(self, learning_rate=1e-3):
        """Store the step size.

        Args:
            learning_rate: Scalar multiplier applied to each gradient before
                it is subtracted from the matching parameter.
        """
        self.learning_rate = learning_rate

    def initialise(self, params):
        """Bind the rule to the parameters it will update.

        Args:
            params: Sequence of parameters, stored by reference. Presumably
                NumPy arrays, so the in-place ``-=`` in ``update_params``
                mutates the caller's arrays directly (TODO confirm with
                callers).
        """
        self.params = params

    def reset(self):
        """Reset learning-rule state; nothing to do, as this rule keeps none."""

    def update_params(self, grads_wrt_params):
        """Apply one descent step to every bound parameter.

        Args:
            grads_wrt_params: Gradients paired positionally with
                ``self.params``; ``zip`` stops at the shorter of the two.
        """
        for weights, gradient in zip(self.params, grads_wrt_params):
            # In-place subtraction so the stored arrays themselves change.
            weights -= self.learning_rate * gradient
|
|
|
|
|
|
class MomentumLearningRule(object):
    """Gradient descent with momentum.

    A momentum buffer ``mom`` is kept per parameter. Each call to
    ``update_params`` performs, in place::

        mom   <- mom_coeff * mom - learning_rate * grad
        param <- param + mom
    """

    def __init__(self, learning_rate=1e-3, mom_coeff=0.9):
        """Store the hyperparameters.

        Args:
            learning_rate: Scalar multiplier applied to each gradient.
            mom_coeff: Factor by which the previous momentum is scaled
                before the new gradient step is added to it.
        """
        self.learning_rate = learning_rate
        self.mom_coeff = mom_coeff

    def initialise(self, params):
        """Bind the rule to its parameters and create zeroed momentum buffers.

        Args:
            params: Sequence of parameters, stored by reference. Presumably
                NumPy arrays, so the in-place updates in ``update_params``
                mutate the caller's arrays (TODO confirm with callers).
        """
        self.params = params
        # One buffer per parameter, with the same shape and dtype.
        self.moms = [np.zeros_like(param) for param in self.params]

    def reset(self):
        """Zero every momentum buffer in place; parameters are left untouched.

        Must be called after ``initialise`` (``self.moms`` is created there).
        """
        # Iterate the buffers directly: ``zip(self.moms)`` yields 1-tuples,
        # and ``tuple *= 0.`` raises TypeError instead of clearing the array.
        # ``fill`` also clears NaN/inf entries, which ``*= 0.`` would keep.
        for mom in self.moms:
            mom.fill(0.)

    def update_params(self, grads_wrt_params):
        """Apply one momentum step to every bound parameter.

        Args:
            grads_wrt_params: Gradients paired positionally with
                ``self.params``; ``zip`` stops at the shortest sequence.
        """
        for param, mom, grad in zip(self.params, self.moms, grads_wrt_params):
            # Decay the old velocity, then add the new gradient step.
            mom *= self.mom_coeff
            mom -= self.learning_rate * grad
            param += mom
|