diff --git a/mlp/layers.py b/mlp/layers.py index 60fabca..e050b25 100644 --- a/mlp/layers.py +++ b/mlp/layers.py @@ -366,3 +366,64 @@ class SoftmaxLayer(Layer): def __repr__(self): return 'SoftmaxLayer' + + +class RadialBasisFunctionLayer(Layer): + """Layer implementing projection to a grid of radial basis functions.""" + + def __init__(self, grid_dim, intervals=[[0., 1.]]): + """Creates a radial basis function layer object. + + Args: + grid_dim: Integer specifying how many basis function to use in + grid across input space per dimension (so total number of + basis functions will be grid_dim**input_dim) + intervals: List of intervals (two element lists or tuples) + specifying extents of axis-aligned region in input-space to + tile basis functions in grid across. For example for a 2D input + space spanning [0, 1] x [0, 1] use intervals=[[0, 1], [0, 1]]. + """ + num_basis = grid_dim**len(intervals) + self.centres = np.array(np.meshgrid(*[ + np.linspace(low, high, grid_dim) for (low, high) in intervals]) + ).reshape((len(intervals), -1)) + self.scales = np.array([ + [(high - low) * 1. / grid_dim] for (low, high) in intervals]) + + def fprop(self, inputs): + """Forward propagates activations through the layer transformation. + + Args: + inputs: Array of layer inputs of shape (batch_size, input_dim). + + Returns: + outputs: Array of layer outputs of shape (batch_size, output_dim). + """ + return np.exp(-(inputs[..., None] - self.centres[None, ...])**2 / + self.scales**2).reshape((inputs.shape[0], -1)) + + def bprop(self, inputs, outputs, grads_wrt_outputs): + """Back propagates gradients through a layer. + + Given gradients with respect to the outputs of the layer calculates the + gradients with respect to the layer inputs. + + Args: + inputs: Array of layer inputs of shape (batch_size, input_dim). + outputs: Array of layer outputs calculated in forward pass of + shape (batch_size, output_dim). + grads_wrt_outputs: Array of gradients with respect to the layer + outputs of shape (batch_size, output_dim). + + Returns: + Array of gradients with respect to the layer inputs of shape + (batch_size, input_dim). + """ + num_basis = self.centres.shape[1] + return -2 * ( + ((inputs[..., None] - self.centres[None, ...]) / self.scales**2) * + grads_wrt_outputs.reshape((inputs.shape[0], -1, num_basis)) + ).sum(-1) + + def __repr__(self): + return 'RadialBasisFunctionLayer(grid_dim={0})'.format(grid_dim)