Adding reshape layer.

This commit is contained in:
Matt Graham 2016-11-02 01:13:22 +00:00
parent a6f775b1ff
commit da9ee58609

View File

@ -261,6 +261,62 @@ class AffineLayer(LayerWithParameters):
self.input_dim, self.output_dim)
class ReshapeLayer(layers.Layer):
"""Layer which reshapes dimensions of inputs."""
def __init__(self, output_shape=None):
"""Create a new reshape layer object.
Args:
output_shape: Tuple specifying shape each input in batch should
be reshaped to in outputs. This **excludes** the batch size
so the shape of the final output array will be
(batch_size, ) + output_shape
Similarly to numpy.reshape, one shape dimension can be -1. In
this case, the value is inferred from the size of the input
array and remaining dimensions. The shape specified must be
compatible with the input array shape - i.e. the total number
of values in the array cannot be changed. If set to `None` the
output shape will be set to
(batch_size, -1)
which will flatten all the inputs to vectors.
"""
self.output_shape = (-1,) if output_shape is None else output_shape
def fprop(self, inputs):
"""Forward propagates activations through the layer transformation.
Args:
inputs: Array of layer inputs of shape (batch_size, input_dim).
Returns:
outputs: Array of layer outputs of shape (batch_size, output_dim).
"""
return inputs.reshape((inputs.shape[0],) + self.output_shape)
def bprop(self, inputs, outputs, grads_wrt_outputs):
"""Back propagates gradients through a layer.
Given gradients with respect to the outputs of the layer calculates the
gradients with respect to the layer inputs.
Args:
inputs: Array of layer inputs of shape (batch_size, input_dim).
outputs: Array of layer outputs calculated in forward pass of
shape (batch_size, output_dim).
grads_wrt_outputs: Array of gradients with respect to the layer
outputs of shape (batch_size, output_dim).
Returns:
Array of gradients with respect to the layer inputs of shape
(batch_size, input_dim).
"""
return grads_wrt_outputs.reshape(inputs.shape)
def __repr__(self):
return 'ReshapeLayer(output_shape={0})'.format(self.output_shape)
class SigmoidLayer(Layer):
"""Layer implementing an element-wise logistic sigmoid transformation."""