mlpractical/notebooks/07_Autoencoders.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Autoencoders\n",
"\n",
"In this notebook we will explore autoencoder models. These are models in which the inputs are *encoded* to some intermediate representation before this representation is then *decoded* to try to reconstruct the inputs. They are an example of a model trained with an unsupervised method, and they are interesting both as models in their own right and as a way of pre-training useful representations for supervised tasks such as classification. Autoencoders were covered as a pre-training method in the [sixth lecture slides](http://www.inf.ed.ac.uk/teaching/courses/mlp/2016/mlp06-enc.pdf).\n",
"\n",
"__*Correction: The original version of this notebook used the term 'contractive autoencoder' to refer to an autoencoder where the encoder 'contracts' the input to a lower-dimensional hidden representation. This is non-standard usage: 'contractive autoencoder' more commonly refers to an autoencoder variant with a specific form of regularisation described in the paper [Contractive Auto-Encoders: Explicit Invariance During Feature Extraction](http://www.icml-2011.org/papers/455_icmlpaper.pdf). Apologies for any confusion and thanks to Iain Murray for pointing out this error.*__"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 1: Linear <s>contractive</s> autoencoders\n",
"\n",
"For the first exercise we will consider training a simple autoencoder where the hidden representation is smaller in dimension than the input and the objective is to minimise the mean squared error between the original inputs and reconstructed inputs. To begin with we will consider models in which the encoder and decoder are both simple affine transformations.\n",
"\n",
"When training an autoencoder, the target outputs for the model are the original inputs. A simple way to integrate this into our `mlp` framework is to define a new data provider that inherits from a base data provider (e.g. `MNISTDataProvider`) and overrides the `next` method so that it returns the inputs batch as both the inputs and the targets for the model. A data provider of this form has been provided for you in `mlp.data_providers` as `MNISTAutoencoderDataProvider`.\n",
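"\n",
"The override can be illustrated with a minimal sketch. `SimpleDataProvider` and `AutoencoderDataProvider` below are hypothetical stand-ins written for this illustration, not the classes in `mlp.data_providers`. The only assumption is a base provider whose `next` method returns an `(inputs, targets)` batch.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"class SimpleDataProvider(object):\n",
"    # Hypothetical minimal base provider (not the mlp framework class)\n",
"    def __init__(self, inputs, targets, batch_size, rng):\n",
"        self.inputs, self.targets = inputs, targets\n",
"        self.batch_size, self.rng = batch_size, rng\n",
"        self.reset()\n",
"\n",
"    def reset(self):\n",
"        # Shuffle the example order and start again from the first batch\n",
"        self.order = self.rng.permutation(self.inputs.shape[0])\n",
"        self.position = 0\n",
"\n",
"    def next(self):\n",
"        # Signal the end of an epoch once every example has been returned\n",
"        if self.position >= self.inputs.shape[0]:\n",
"            self.reset()\n",
"            raise StopIteration()\n",
"        idx = self.order[self.position:self.position + self.batch_size]\n",
"        self.position += self.batch_size\n",
"        return self.inputs[idx], self.targets[idx]\n",
"\n",
"\n",
"class AutoencoderDataProvider(SimpleDataProvider):\n",
"    # Override next so that the inputs batch is also returned as the targets\n",
"    def next(self):\n",
"        inputs_batch, _ = super(AutoencoderDataProvider, self).next()\n",
"        return inputs_batch, inputs_batch\n",
"```\n",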
"\n",
"Use this data provider to train an autoencoder model with a 50-dimensional hidden representation, with both the encoder and the decoder defined by affine transformations. You should use a sum of squared differences error and a basic gradient descent learning rule with learning rate 0.01. Initialise the biases to zero and use a uniform Glorot initialisation for the weights of both layers. Train the model for 25 epochs with a batch size of 50."
]
},
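{
"cell_type": "markdown",
"metadata": {},
"source": [
"The minimal NumPy sketch below makes the training set-up above concrete. It trains an affine encoder and an affine decoder by mini-batch gradient descent on a squared-error reconstruction objective. This is an illustration only and is not the `mlp` framework's implementation. The function names are hypothetical. The error is assumed to be half the sum of squared differences averaged over the batch, which may not match the exact convention used by `errors.SumOfSquaredDiffsError`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def glorot_uniform(rng, fan_in, fan_out):\n",
"    # Uniform Glorot initialisation: U(-a, a) with a = sqrt(6 / (fan_in + fan_out))\n",
"    scale = np.sqrt(6. / (fan_in + fan_out))\n",
"    return rng.uniform(-scale, scale, size=(fan_out, fan_in))\n",
"\n",
"\n",
"def ssd_error(outputs, targets):\n",
"    # Assumed convention: 0.5 * sum of squared differences, averaged over the batch\n",
"    return 0.5 * np.sum((outputs - targets) ** 2) / outputs.shape[0]\n",
"\n",
"\n",
"def reconstruct(params, inputs):\n",
"    W1, b1, W2, b2 = params\n",
"    hidden = inputs.dot(W1.T) + b1  # affine encoder\n",
"    return hidden.dot(W2.T) + b2  # affine decoder\n",
"\n",
"\n",
"def train_linear_autoencoder(data, hidden_dim, learning_rate=0.01,\n",
"                             num_epochs=25, batch_size=50, seed=0):\n",
"    rng = np.random.RandomState(seed)\n",
"    input_dim = data.shape[1]\n",
"    # Glorot-initialised weights and zero biases for both layers\n",
"    W1, b1 = glorot_uniform(rng, input_dim, hidden_dim), np.zeros(hidden_dim)\n",
"    W2, b2 = glorot_uniform(rng, hidden_dim, input_dim), np.zeros(input_dim)\n",
"    for epoch in range(num_epochs):\n",
"        order = rng.permutation(data.shape[0])\n",
"        for start in range(0, data.shape[0], batch_size):\n",
"            # Autoencoder training: the targets are the inputs themselves\n",
"            inputs = data[order[start:start + batch_size]]\n",
"            hidden = inputs.dot(W1.T) + b1\n",
"            recons = hidden.dot(W2.T) + b2\n",
"            # Gradient of the batch-averaged error with respect to the reconstructions\n",
"            grad_recons = (recons - inputs) / inputs.shape[0]\n",
"            # Back-propagate to the hidden layer before W2 is updated\n",
"            grad_hidden = grad_recons.dot(W2)\n",
"            W2 -= learning_rate * grad_recons.T.dot(hidden)\n",
"            b2 -= learning_rate * grad_recons.sum(0)\n",
"            W1 -= learning_rate * grad_hidden.T.dot(inputs)\n",
"            b1 -= learning_rate * grad_hidden.sum(0)\n",
"    return W1, b1, W2, b2\n",
"```\n",
"\n",
"Passing `num_epochs=0` returns the initial parameters for the same `seed`. These give a baseline error against which the trained model can be compared."
]
},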
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"import logging\n",
"import mlp.layers as layers\n",
"import mlp.models as models\n",
"import mlp.optimisers as optimisers\n",
"import mlp.errors as errors\n",
"import mlp.learning_rules as learning_rules\n",
"import mlp.data_providers as data_providers\n",
"import mlp.initialisers as initialisers\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Seed a random number generator\n",
"seed = 10102016 \n",
"rng = np.random.RandomState(seed)\n",
"\n",
"# Set up a logger object to print info about the training run to stdout\n",
"logger = logging.getLogger()\n",
"logger.setLevel(logging.INFO)\n",
"logger.handlers = [logging.StreamHandler()]\n",
"\n",
"# Create data provider objects for the MNIST data set\n",
"train_data = data_providers.MNISTAutoencoderDataProvider('train', batch_size=50, rng=rng)\n",
"valid_data = data_providers.MNISTAutoencoderDataProvider('valid', batch_size=50, rng=rng)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"input_dim, output_dim, hidden_dim = 784, 784, 50\n",
"\n",
"weights_init = initialisers.GlorotUniformInit(rng=rng)\n",
"biases_init = initialisers.ConstantInit(0.)\n",
"\n",
"model = models.MultipleLayerModel([\n",
" layers.AffineLayer(input_dim, hidden_dim, weights_init, biases_init), \n",
" layers.AffineLayer(hidden_dim, output_dim, weights_init, biases_init),\n",
"])\n",
"\n",
"error = errors.SumOfSquaredDiffsError()\n",
"\n",
"learning_rule = learning_rules.GradientDescentLearningRule(0.01)\n",
"\n",
"num_epochs = 25\n",
"stats_interval = 1\n",
"optimiser = optimisers.Optimiser(\n",
" model, error, learning_rule, train_data, valid_data)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 0:\n",
" error(train)=5.30e+01, error(valid)=5.20e+01, params_penalty=0.00e+00\n",
"Epoch 1: 2.22s to complete\n",
" error(train)=7.66e+00, error(valid)=7.63e+00, params_penalty=0.00e+00\n",
"Epoch 2: 2.19s to complete\n",
" error(train)=5.76e+00, error(valid)=5.74e+00, params_penalty=0.00e+00\n",
"Epoch 3: 2.24s to complete\n",
" error(train)=5.11e+00, error(valid)=5.09e+00, params_penalty=0.00e+00\n",
"Epoch 4: 2.13s to complete\n",
" error(train)=4.87e+00, error(valid)=4.85e+00, params_penalty=0.00e+00\n",
"Epoch 5: 2.14s to complete\n",
" error(train)=4.77e+00, error(valid)=4.75e+00, params_penalty=0.00e+00\n",
"Epoch 6: 2.17s to complete\n",
" error(train)=4.72e+00, error(valid)=4.71e+00, params_penalty=0.00e+00\n",
"Epoch 7: 2.32s to complete\n",
" error(train)=4.69e+00, error(valid)=4.68e+00, params_penalty=0.00e+00\n",
"Epoch 8: 2.15s to complete\n",
" error(train)=4.68e+00, error(valid)=4.66e+00, params_penalty=0.00e+00\n",
"Epoch 9: 2.15s to complete\n",
" error(train)=4.66e+00, error(valid)=4.65e+00, params_penalty=0.00e+00\n",
"Epoch 10: 2.15s to complete\n",
" error(train)=4.65e+00, error(valid)=4.64e+00, params_penalty=0.00e+00\n",
"Epoch 11: 2.18s to complete\n",
" error(train)=4.65e+00, error(valid)=4.63e+00, params_penalty=0.00e+00\n",
"Epoch 12: 2.15s to complete\n",
" error(train)=4.64e+00, error(valid)=4.63e+00, params_penalty=0.00e+00\n",
"Epoch 13: 2.15s to complete\n",
" error(train)=4.64e+00, error(valid)=4.62e+00, params_penalty=0.00e+00\n",
"Epoch 14: 2.44s to complete\n",
" error(train)=4.63e+00, error(valid)=4.62e+00, params_penalty=0.00e+00\n",
"Epoch 15: 2.39s to complete\n",
" error(train)=4.63e+00, error(valid)=4.61e+00, params_penalty=0.00e+00\n",
"Epoch 16: 2.34s to complete\n",
" error(train)=4.63e+00, error(valid)=4.61e+00, params_penalty=0.00e+00\n",
"Epoch 17: 2.52s to complete\n",
" error(train)=4.62e+00, error(valid)=4.61e+00, params_penalty=0.00e+00\n",
"Epoch 18: 1.59s to complete\n",
" error(train)=4.62e+00, error(valid)=4.61e+00, params_penalty=0.00e+00\n",
"Epoch 19: 1.83s to complete\n",
" error(train)=4.62e+00, error(valid)=4.61e+00, params_penalty=0.00e+00\n",
"Epoch 20: 2.29s to complete\n",
" error(train)=4.63e+00, error(valid)=4.62e+00, params_penalty=0.00e+00\n",
"Epoch 21: 2.39s to complete\n",
" error(train)=4.62e+00, error(valid)=4.61e+00, params_penalty=0.00e+00\n",
"Epoch 22: 2.19s to complete\n",
" error(train)=4.62e+00, error(valid)=4.60e+00, params_penalty=0.00e+00\n",
"Epoch 23: 2.19s to complete\n",
" error(train)=4.62e+00, error(valid)=4.60e+00, params_penalty=0.00e+00\n",
"Epoch 24: 2.21s to complete\n",
" error(train)=4.61e+00, error(valid)=4.60e+00, params_penalty=0.00e+00\n",
"Epoch 25: 2.13s to complete\n",
" error(train)=4.61e+00, error(valid)=4.60e+00, params_penalty=0.00e+00\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.text.Text at 0x7f8f8728fb10>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f8fa854bf50>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"stats, keys, run_time = optimiser.train(num_epochs=num_epochs, stats_interval=stats_interval)\n",
"\n",
"# Plot the change in the validation and training set error over training.\n",
"fig_1 = plt.figure(figsize=(8, 4))\n",
"ax_1 = fig_1.add_subplot(111)\n",
"for k in ['error(train)', 'error(valid)']:\n",
"    ax_1.plot(np.arange(1, stats.shape[0]) * stats_interval,\n",
"              stats[1:, keys[k]], label=k)\n",
"ax_1.legend(loc=0)\n",
"ax_1.set_xlabel('Epoch number')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using the function defined in the cell below (from the first lab notebook), plot a batch of the original images and the autoencoder reconstructions."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def show_batch_of_images(img_batch, fig_size=(3, 3), num_rows=None):\n",
"    fig = plt.figure(figsize=fig_size)\n",
"    batch_size, im_height, im_width = img_batch.shape\n",
"    if num_rows is None:\n",
"        # calculate grid dimensions to give square(ish) grid\n",
"        num_rows = int(batch_size**0.5)\n",
"    num_cols = int(batch_size * 1. / num_rows)\n",
"    if num_rows * num_cols < batch_size:\n",
"        num_cols += 1\n",
"    # initialise empty array to tile image grid into\n",
"    tiled = np.zeros((im_height * num_rows, im_width * num_cols))\n",
"    # iterate over images in batch + indexes within batch\n",
"    for i, img in enumerate(img_batch):\n",
"        # calculate grid row and column indices\n",
"        r, c = i % num_rows, i // num_rows\n",
"        # column offsets use the image width so non-square images tile correctly\n",
"        tiled[r * im_height:(r + 1) * im_height,\n",
"              c * im_width:(c + 1) * im_width] = img\n",
"    ax = fig.add_subplot(111)\n",
"    ax.imshow(tiled, cmap='Greys', vmin=0., vmax=1.)\n",
"    ax.axis('off')\n",
"    fig.tight_layout()\n",
"    plt.show()\n",
"    return fig, ax"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f8f8753bd90>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f8f8753b910>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"inputs, targets = valid_data.next()\n",
"recons = model.fprop(inputs)[-1]\n",
"_ = show_batch_of_images(inputs.reshape((-1, 28, 28)), (4, 2), 5)\n",
"_ = show_batch_of_images(recons.reshape((-1, 28, 28)), (4, 2), 5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Optional extension: principal components analysis\n",
"\n",
"*This section is provided for the interest of those also sitting MLPR or otherwise already familiar with eigendecompositions and PCA. Feel free to skip over if this doesn't apply to you (or even if it does).*\n",
"\n",
"For a linear (affine) autoencoder model trained with a sum of squared differences error function, the optimal model parameters can be found analytically, and the solution corresponds to [principal components analysis](https://en.wikipedia.org/wiki/Principal_component_analysis).\n",
"\n",
"If we have a training dataset of $N$ $D$-dimensional vectors $\\left\\lbrace \\boldsymbol{x}^{(n)} \\right\\rbrace_{n=1}^N$, then we can calculate the empirical mean and covariance of the training data using\n",
"\n",
"\\begin{equation}\n",
" \\boldsymbol{\\mu} = \\frac{1}{N} \\sum_{n=1}^N \\left[ \\boldsymbol{x}^{(n)} \\right]\n",
" \\qquad\n",
" \\text{and}\n",
" \\qquad\n",
" \\mathbf{\\Sigma} = \\frac{1}{N} \n",
" \\sum_{n=1}^N \\left[ \n",
" \\left(\\boldsymbol{x}^{(n)} - \\boldsymbol{\\mu} \\right)\n",
" \\left(\\boldsymbol{x}^{(n)} - \\boldsymbol{\\mu} \\right)^{\\rm T}\n",
" \\right].\n",
"\\end{equation}\n",
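"\n",
"As a quick numerical check of these definitions, the NumPy sketch below computes $\\boldsymbol{\\mu}$ and $\\mathbf{\\Sigma}$ for a data matrix with one example per row. The function name is hypothetical. Note the $1/N$ normalisation, which corresponds to `np.cov(..., rowvar=False, bias=True)` rather than the default unbiased $1/(N-1)$ estimate.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def empirical_mean_and_covariance(inputs):\n",
"    # inputs has shape (N, D), one example per row\n",
"    mean = inputs.mean(axis=0)\n",
"    centred = inputs - mean\n",
"    # (1 / N) * sum_n (x_n - mu)(x_n - mu)^T written as one matrix product\n",
"    covar = centred.T.dot(centred) / inputs.shape[0]\n",
"    return mean, covar\n",
"```\n",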
"\n",
"We can then calculate an [eigendecomposition](https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix) of the covariance matrix \n",
"\\begin{equation}\n",
" \\mathbf{\\Sigma} = \\mathbf{Q} \\mathbf{\\Lambda} \\mathbf{Q}^{\\rm T}\n",
" \\qquad\n",
" \\mathbf{Q} = \\left[ \n",
" \\begin{array}{cccc}\n",
" \\uparrow & \\uparrow & \\cdots & \\uparrow \\\\\n",
" \\boldsymbol{q}_1 & \\boldsymbol{q}_2 & \\cdots & \\boldsymbol{q}_D \\\\\n",
" \\downarrow & \\downarrow & \\cdots & \\downarrow \\\\\n",
" \\end{array}\n",
" \\right]\n",
" \\qquad\n",
" \\mathbf{\\Lambda} = \\left[ \n",
" \\begin{array}{cccc} \n",
" \\lambda_1 & 0 & \\cdots & 0 \\\\\n",
" 0 & \\lambda_2 & \\cdots & \\vdots \\\\\n",
" \\vdots & \\vdots & \\ddots & 0 \\\\ \n",
" 0 & 0 & \\cdots & \\lambda_D \\\\ \n",
" \\end{array} \\right]\n",
"\\end{equation}\n",
"\n",
"with $\\mathbf{Q}$ an orthogonal matrix, $\\mathbf{Q}\\mathbf{Q}^{\\rm T} = \\mathbf{I}$, with columns $\\left\\lbrace \\boldsymbol{q}_d \\right\\rbrace_{d=1}^D$ corresponding to the eigenvectors of $\\mathbf{\\Sigma}$ and $\\mathbf{\\Lambda}$ a diagonal matrix with diagonal elements $\\left\\lbrace \\lambda_d \\right\\rbrace_{d=1}^D$ the corresponding eigenvalues of $\\mathbf{\\Sigma}$. \n",
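"\n",
"For a symmetric matrix such as $\\mathbf{\\Sigma}$, this decomposition can be computed with `np.linalg.eigh`. That function returns the eigenvalues in ascending order, together with a matrix whose columns are the corresponding unit-norm eigenvectors. The sketch below uses a hypothetical helper name. It forms $\\mathbf{Q}$ and $\\mathbf{\\Lambda}$ and then checks the two properties just stated: that $\\mathbf{Q}$ is orthogonal and that $\\mathbf{Q} \\mathbf{\\Lambda} \\mathbf{Q}^{\\rm T}$ recovers the original matrix.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def eigendecompose(covar):\n",
"    # eigh assumes a symmetric matrix and returns eigenvalues in ascending order\n",
"    eigvals, Q = np.linalg.eigh(covar)\n",
"    Lambda = np.diag(eigvals)\n",
"    # Sanity checks: Q is orthogonal and Q Lambda Q^T reconstructs covar\n",
"    assert np.allclose(Q.dot(Q.T), np.eye(covar.shape[0]))\n",
"    assert np.allclose(Q.dot(Lambda).dot(Q.T), covar)\n",
"    return Q, Lambda\n",
"```\n",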
"\n",
"Assuming the eigenvalues are ordered such that $\\lambda_1 \\leq \\lambda_2 \\leq \\dots \\leq \\lambda_D$, the top $K$ principal components of the inputs (the eigenvectors with the largest eigenvalues) are $\\left\\lbrace \\boldsymbol{q}_d \\right\\rbrace_{d=D + 1 - K}^D$. If we define a $D \\times K$ matrix $\\mathbf{V} = \\left[ \\boldsymbol{q}_{D + 1 - K} ~ \\boldsymbol{q}_{D + 2 - K} ~\\cdots~ \\boldsymbol{q}_D \\right]$, then we can find the projections of a mean-centred input vector onto the selected $K$ principal components as $\\boldsymbol{h} = \\mathbf{V}^{\\rm T}\\left( \\boldsymbol{x} - \\boldsymbol{\\mu}\\right)$. We can then use these projections to form a reconstruction of the original input in terms of only the top $K$ principal components, $\\boldsymbol{r} = \\mathbf{V} \\boldsymbol{h} + \\boldsymbol{\\mu}$. This is just a sequence of two affine transformations. It is therefore directly analogous to a model with two affine layers, in which the first layer has $K$-dimensional outputs and the second layer has $K$-dimensional inputs.\n",
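"\n",
"A minimal NumPy sketch of this encode/decode pair, written directly from the formulas above, is shown below. The function name is hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def pca_encode_decode(inputs, num_components):\n",
"    # inputs has shape (N, D), one example per row\n",
"    mean = inputs.mean(axis=0)\n",
"    centred = inputs - mean\n",
"    covar = centred.T.dot(centred) / inputs.shape[0]\n",
"    # eigh returns ascending eigenvalues, so the last K columns are the top K components\n",
"    eigvals, eigvecs = np.linalg.eigh(covar)\n",
"    V = eigvecs[:, -num_components:]  # D x K matrix\n",
"    hidden = centred.dot(V)  # each row is h = V^T (x - mu)\n",
"    recons = hidden.dot(V.T) + mean  # each row is r = V h + mu\n",
"    return hidden, recons\n",
"```\n",
"\n",
"When `num_components` equals $D$, the reconstructions are exact. For data lying in a $K$-dimensional affine subspace, the top $K$ components already reconstruct the inputs exactly, up to floating-point error.\n",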
"\n",
"The function defined in the cell below calculates the PCA solution for a set of input vectors and a given number of components $K$. Use it to calculate the top 50 principal components of the MNIST training data. Use the returned matrix and mean vector to calculate the PCA-based reconstructions of a batch of 50 MNIST images, and use the `show_batch_of_images` function to plot the original and reconstructed inputs alongside each other. Also calculate the sum of squared differences error for the PCA solution on the MNIST training set, and compare it to the value you obtained with gradient-descent-based training above. Will the gradient-based training produce the same hidden representations as the PCA solution if it is trained to convergence?"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
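"def get_pca_parameters(inputs, num_components=50):\n",
"    # empirical mean of the inputs (one example per row)\n",
"    mean = inputs.mean(0)\n",
"    inputs_zm = inputs - mean[None, :]\n",
"    # unnormalised covariance sum_n (x_n - mu)(x_n - mu)^T; omitting the 1 / N\n",
"    # factor rescales the eigenvalues but leaves the eigenvectors unchanged\n",
"    covar = np.einsum('ij,ik', inputs_zm, inputs_zm)\n",
"    # eigh returns eigenvalues in ascending order, so the last num_components\n",
"    # columns of eigvecs are the top principal components\n",
"    eigvals, eigvecs = np.linalg.eigh(covar)\n",
"    return eigvecs[:, -num_components:], mean"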
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f8f86fc7150>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f8fae837690>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"V, mu = get_pca_parameters(train_data.inputs)\n",
"hiddens = (inputs - mu[None, :]).dot(V)\n",
"recons = hiddens.dot(V.T) + mu[None, :]\n",
"_ = show_batch_of_images(inputs.reshape((-1, 28, 28)), (4, 2), 5)\n",
"_ = show_batch_of_images(recons.reshape((-1, 28, 28)), (4, 2), 5)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.59056138992\n"
]
}
],
"source": [
"hiddens = (train_data.inputs - mu[None, :]).dot(V)\n",
"recons = hiddens.dot(V.T) + mu[None, :]\n",
"print(error(train_data.inputs, recons))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 2: Non-linear <s>contractive</s> autoencoders\n",
"\n",
"Those who did the extension in the previous exercise will have just seen that, for an autoencoder with a linear / affine encoder and decoder, there is an analytic solution (the principal component analysis solution computed above) for the parameters which minimise a sum of squared differences error.\n",
"\n",
"In general, the advantage of gradient-based training methods is that they allow us to use non-linear models for which there is no analytic solution for the optimal parameters. The hope is that interleaving non-linear transformations between the affine transformation layers will increase the representational power of the model (a sequence of affine transformations with no non-linear operations between them can always be represented by a single affine transformation).\n",
"\n",
"Train an autoencoder with an initial affine layer (output dimension again 50), followed by a rectified linear layer, then an affine transformation projecting to outputs of the same dimension as the original inputs, and finally a logistic sigmoid layer at the output. As the only layers with parameters are the two affine layers, which have the same dimensions as in the fully affine model above, the overall model here has the same number of parameters as before.\n",
"\n",
"Again train for 25 epochs with 50 training examples per batch, and use a uniform Glorot initialisation for the weights and zero initialisation for the biases. Use our implementation of the 'Adam' adaptive moments learning rule (available in `mlp.learning_rules` as `AdamLearningRule`) rather than basic gradient descent here (the adaptivity helps deal with the varying appropriate scale of updates induced by the non-linear transformations in this model)."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"input_dim, output_dim, hidden_dim = 784, 784, 50\n",
"\n",
"weights_init = initialisers.GlorotUniformInit(rng=rng)\n",
"biases_init = initialisers.ConstantInit(0.)\n",
"\n",
"model = models.MultipleLayerModel([\n",
" layers.AffineLayer(input_dim, hidden_dim, weights_init, biases_init),\n",
" layers.ReluLayer(),\n",
" layers.AffineLayer(hidden_dim, output_dim, weights_init, biases_init),\n",
" layers.SigmoidLayer()\n",
"])\n",
"\n",
"error = errors.SumOfSquaredDiffsError()\n",
"\n",
"learning_rule = learning_rules.AdamLearningRule()\n",
"\n",
"num_epochs = 25\n",
"stats_interval = 1\n",
"optimiser = optimisers.Optimiser(\n",
" model, error, learning_rule, train_data, valid_data)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 0:\n",
" error(train)=9.11e+01, error(valid)=9.06e+01, params_penalty=0.00e+00\n",
"Epoch 1: 5.30s to complete\n",
" error(train)=6.90e+00, error(valid)=6.79e+00, params_penalty=0.00e+00\n",
"Epoch 2: 4.22s to complete\n",
" error(train)=3.79e+00, error(valid)=3.79e+00, params_penalty=0.00e+00\n",
"Epoch 3: 4.58s to complete\n",
" error(train)=2.82e+00, error(valid)=2.84e+00, params_penalty=0.00e+00\n",
"Epoch 4: 4.22s to complete\n",
" error(train)=2.50e+00, error(valid)=2.52e+00, params_penalty=0.00e+00\n",
"Epoch 5: 4.22s to complete\n",
" error(train)=2.39e+00, error(valid)=2.41e+00, params_penalty=0.00e+00\n",
"Epoch 6: 4.32s to complete\n",
" error(train)=2.32e+00, error(valid)=2.35e+00, params_penalty=0.00e+00\n",
"Epoch 7: 4.38s to complete\n",
" error(train)=2.28e+00, error(valid)=2.31e+00, params_penalty=0.00e+00\n",
"Epoch 8: 4.31s to complete\n",
" error(train)=2.23e+00, error(valid)=2.27e+00, params_penalty=0.00e+00\n",
"Epoch 9: 4.26s to complete\n",
" error(train)=2.22e+00, error(valid)=2.26e+00, params_penalty=0.00e+00\n",
"Epoch 10: 4.25s to complete\n",
" error(train)=2.21e+00, error(valid)=2.25e+00, params_penalty=0.00e+00\n",
"Epoch 11: 4.51s to complete\n",
" error(train)=2.19e+00, error(valid)=2.23e+00, params_penalty=0.00e+00\n",
"Epoch 12: 4.59s to complete\n",
" error(train)=2.15e+00, error(valid)=2.20e+00, params_penalty=0.00e+00\n",
"Epoch 13: 4.38s to complete\n",
" error(train)=2.15e+00, error(valid)=2.20e+00, params_penalty=0.00e+00\n",
"Epoch 14: 4.22s to complete\n",
" error(train)=2.14e+00, error(valid)=2.19e+00, params_penalty=0.00e+00\n",
"Epoch 15: 4.40s to complete\n",
" error(train)=2.11e+00, error(valid)=2.16e+00, params_penalty=0.00e+00\n",
"Epoch 16: 4.35s to complete\n",
" error(train)=2.13e+00, error(valid)=2.17e+00, params_penalty=0.00e+00\n",
"Epoch 17: 4.22s to complete\n",
" error(train)=2.12e+00, error(valid)=2.17e+00, params_penalty=0.00e+00\n",
"Epoch 18: 4.36s to complete\n",
" error(train)=2.11e+00, error(valid)=2.16e+00, params_penalty=0.00e+00\n",
"Epoch 19: 4.31s to complete\n",
" error(train)=2.12e+00, error(valid)=2.18e+00, params_penalty=0.00e+00\n",
"Epoch 20: 4.38s to complete\n",
" error(train)=2.10e+00, error(valid)=2.16e+00, params_penalty=0.00e+00\n",
"Epoch 21: 4.37s to complete\n",
" error(train)=2.10e+00, error(valid)=2.16e+00, params_penalty=0.00e+00\n",
"Epoch 22: 4.26s to complete\n",
" error(train)=2.07e+00, error(valid)=2.13e+00, params_penalty=0.00e+00\n",
"Epoch 23: 4.31s to complete\n",
" error(train)=2.09e+00, error(valid)=2.15e+00, params_penalty=0.00e+00\n",
"Epoch 24: 4.37s to complete\n",
" error(train)=2.08e+00, error(valid)=2.14e+00, params_penalty=0.00e+00\n",
"Epoch 25: 4.42s to complete\n",
" error(train)=2.07e+00, error(valid)=2.13e+00, params_penalty=0.00e+00\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.text.Text at 0x7f8f8728fa90>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f8f86fc7b50>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"stats, keys, run_time = optimiser.train(num_epochs=num_epochs, stats_interval=stats_interval)\n",
"\n",
"# Plot the change in the validation and training set error over training.\n",
"fig_1 = plt.figure(figsize=(8, 4))\n",
"ax_1 = fig_1.add_subplot(111)\n",
"for k in ['error(train)', 'error(valid)']:\n",
" ax_1.plot(np.arange(1, stats.shape[0]) * stats_interval, \n",
" stats[1:, keys[k]], label=k)\n",
"ax_1.legend(loc=0)\n",
"ax_1.set_xlabel('Epoch number')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Plot batches of the inputs and reconstructed inputs for this non-linear autoencoder model and compare them to the corresponding plots for the linear models above."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f8f86fd2350>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f8f87149c90>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"inputs, targets = valid_data.next()\n",
"recons = model.fprop(inputs)[-1]\n",
"_ = show_batch_of_images(inputs.reshape((-1, 28, 28)), (4, 2), 5)\n",
"_ = show_batch_of_images(recons.reshape((-1, 28, 28)), (4, 2), 5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 3: Denoising autoencoders\n",
"\n",
"So far we have only considered autoencoders that try to reconstruct the input vector fed into them via some intermediate lower-dimensional 'contracted' representation. The contraction is important because, if we were to maintain the input dimensionality in all layers of the model, a trivial optimum for the model to learn would be to apply an identity transformation at each layer.\n",
"\n",
"It can be desirable for the intermediate hidden representation to be robust to noise in the input. The intuition is that this will force the model to learn to preserve in the hidden representation the 'important structure' in the input (the structure needed to reconstruct the input). This also removes the requirement to have a contracted hidden representation (as the model can no longer simply learn to apply an identity transformation), though in practice we will still often use a lower-dimensional hidden representation, as we believe there is a certain level of redundancy in the input data and so the important structure can be captured by a lower-dimensional representation.\n",
"\n",
"Create a new data provider object which adds noise to the inputs to an autoencoder in each batch it returns. There are various different ways you could introduce noise. The three suggested in the lecture slides are:\n",
"\n",
" * *Gaussian*: add independent, zero-mean Gaussian noise of a fixed standard-deviation to each dimension of the input vectors.\n",
" * *Masking*: generate a random binary mask and perform an elementwise multiplication with each input (forcing some subset of the values to zero).\n",
" * *Salt and pepper*: select a random subset of values in each input and randomly assign either zero or one to them.\n",
" \n",
"You should choose one of these noising schemes to implement. It may help to know that the base `DataProvider` object already has access to a random number generator object as its `self.rng` attribute."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class MNISTDenoisingAutoencoderDataProvider(data_providers.MNISTDataProvider):\n",
" \"\"\"Simple wrapper data provider for training a denoising autoencoder on MNIST.\"\"\"\n",
"\n",
" def next(self):\n",
" \"\"\"Returns next data batch or raises `StopIteration` if at end.\"\"\"\n",
" inputs, targets = super(\n",
" MNISTDenoisingAutoencoderDataProvider, self).next()\n",
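"        # Masking noise: keep each pixel with probability 0.75 and set it to zero otherwise.\n",
"        noised_inputs = (self.rng.uniform(size=inputs.shape) < 0.75) * inputs\n",
"        # The noisy inputs are fed to the model; the clean inputs are the targets.\n",
"        return noised_inputs, inputs"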
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once you have implemented your chosen scheme, use the new data provider object to train a denoising autoencoder with the same model architecture as in exercise 2."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Create data provider objects for the MNIST data set\n",
"train_data = MNISTDenoisingAutoencoderDataProvider('train', batch_size=50, rng=rng)\n",
"valid_data = MNISTDenoisingAutoencoderDataProvider('valid', batch_size=50, rng=rng)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"input_dim, output_dim, hidden_dim = 784, 784, 50\n",
"\n",
"weights_init = initialisers.GlorotUniformInit(rng=rng)\n",
"biases_init = initialisers.ConstantInit(0.)\n",
"\n",
"model = models.MultipleLayerModel([\n",
" layers.AffineLayer(input_dim, hidden_dim, weights_init, biases_init),\n",
" layers.ReluLayer(),\n",
" layers.AffineLayer(hidden_dim, output_dim, weights_init, biases_init),\n",
" layers.SigmoidLayer()\n",
"])\n",
"\n",
"error = errors.SumOfSquaredDiffsError()\n",
"\n",
"learning_rule = learning_rules.AdamLearningRule()\n",
"\n",
"num_epochs = 25\n",
"stats_interval = 1\n",
"optimiser = optimisers.Optimiser(\n",
" model, error, learning_rule, train_data, valid_data)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 0:\n",
" error(train)=9.08e+01, error(valid)=9.08e+01, params_penalty=0.00e+00\n",
"Epoch 1: 6.40s to complete\n",
" error(train)=8.95e+00, error(valid)=8.80e+00, params_penalty=0.00e+00\n",
"Epoch 2: 6.40s to complete\n",
" error(train)=6.16e+00, error(valid)=6.13e+00, params_penalty=0.00e+00\n",
"Epoch 3: 5.04s to complete\n",
" error(train)=5.23e+00, error(valid)=5.23e+00, params_penalty=0.00e+00\n",
"Epoch 4: 5.03s to complete\n",
" error(train)=4.93e+00, error(valid)=4.95e+00, params_penalty=0.00e+00\n",
"Epoch 5: 5.16s to complete\n",
" error(train)=4.84e+00, error(valid)=4.86e+00, params_penalty=0.00e+00\n",
"Epoch 6: 6.21s to complete\n",
" error(train)=4.78e+00, error(valid)=4.80e+00, params_penalty=0.00e+00\n",
"Epoch 7: 5.03s to complete\n",
" error(train)=4.72e+00, error(valid)=4.74e+00, params_penalty=0.00e+00\n",
"Epoch 8: 5.02s to complete\n",
" error(train)=4.71e+00, error(valid)=4.71e+00, params_penalty=0.00e+00\n",
"Epoch 9: 5.01s to complete\n",
" error(train)=4.68e+00, error(valid)=4.71e+00, params_penalty=0.00e+00\n",
"Epoch 10: 5.05s to complete\n",
" error(train)=4.69e+00, error(valid)=4.70e+00, params_penalty=0.00e+00\n",
"Epoch 11: 5.12s to complete\n",
" error(train)=4.64e+00, error(valid)=4.68e+00, params_penalty=0.00e+00\n",
"Epoch 12: 5.08s to complete\n",
" error(train)=4.64e+00, error(valid)=4.68e+00, params_penalty=0.00e+00\n",
"Epoch 13: 5.61s to complete\n",
" error(train)=4.64e+00, error(valid)=4.67e+00, params_penalty=0.00e+00\n",
"Epoch 14: 5.17s to complete\n",
" error(train)=4.63e+00, error(valid)=4.66e+00, params_penalty=0.00e+00\n",
"Epoch 15: 5.74s to complete\n",
" error(train)=4.60e+00, error(valid)=4.62e+00, params_penalty=0.00e+00\n",
"Epoch 16: 5.21s to complete\n",
" error(train)=4.60e+00, error(valid)=4.63e+00, params_penalty=0.00e+00\n",
"Epoch 17: 5.03s to complete\n",
" error(train)=4.58e+00, error(valid)=4.61e+00, params_penalty=0.00e+00\n",
"Epoch 18: 5.07s to complete\n",
" error(train)=4.57e+00, error(valid)=4.61e+00, params_penalty=0.00e+00\n",
"Epoch 19: 5.17s to complete\n",
" error(train)=4.57e+00, error(valid)=4.60e+00, params_penalty=0.00e+00\n",
"Epoch 20: 5.03s to complete\n",
" error(train)=4.57e+00, error(valid)=4.62e+00, params_penalty=0.00e+00\n",
"Epoch 21: 5.01s to complete\n",
" error(train)=4.57e+00, error(valid)=4.62e+00, params_penalty=0.00e+00\n",
"Epoch 22: 5.12s to complete\n",
" error(train)=4.56e+00, error(valid)=4.62e+00, params_penalty=0.00e+00\n",
"Epoch 23: 5.01s to complete\n",
" error(train)=4.55e+00, error(valid)=4.59e+00, params_penalty=0.00e+00\n",
"Epoch 24: 5.05s to complete\n",
" error(train)=4.56e+00, error(valid)=4.61e+00, params_penalty=0.00e+00\n",
"Epoch 25: 5.06s to complete\n",
" error(train)=4.55e+00, error(valid)=4.59e+00, params_penalty=0.00e+00\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.text.Text at 0x7f8f870f31d0>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f8f870f3850>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"stats, keys, run_time = optimiser.train(num_epochs=num_epochs, stats_interval=stats_interval)\n",
"\n",
"# Plot the change in the validation and training set error over training.\n",
"fig_1 = plt.figure(figsize=(8, 4))\n",
"ax_1 = fig_1.add_subplot(111)\n",
"for k in ['error(train)', 'error(valid)']:\n",
" ax_1.plot(np.arange(1, stats.shape[0]) * stats_interval, \n",
" stats[1:, keys[k]], label=k)\n",
"ax_1.legend(loc=0)\n",
"ax_1.set_xlabel('Epoch number')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the `show_batch_of_images` function from above to visualise a batch of noisy inputs from your data provider implementation and the denoised reconstructions from your trained denoising autoencoder."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAVQAAAC+CAYAAABqOvflAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAIABJREFUeJzsnXdcVNfWv58Z6tCG3rsCoiKKBcGCWKKiokGNMZZYYhJj\nkptebqo39yYx5d7EvInRGI0RS+xRUZGIIigggoLSBEWkSC9Dm4Epvz/8Ma8mmkQ4mFfvPH/lkxn3\nOmc457vXXnuvtUQajQYdOnTo0NF9xH/1BejQoUPHg4JOUHXo0KFDIHSCqkOHDh0CoRNUHTp06BAI\nnaDq0KFDh0DoBFWHDh06BEInqDp06NAhEDpB1aFDhw6B0AmqDh06dAiETlB16NChQyB0gqpDhw4d\nAqETVB06dOgQCJ2g6tChQ4dA6ARVhw4dOgRCJ6g6dOjQIRD6f/UF3AW6wq06dOj4qxH93ocPvIea\nkZHB448/TlBQEGvXrv1LriEnJ4eXXnqJl19++S+xr0OHjv9FoVDw2Wef8frrryOXywUd+4EW1Kqq\nKrZv346dnR0LFizgypUrtLa23tNraG5uJikpicuXLzN37lzBx29ra0OhUPDNN9+wYsUKsrKyaG9v\np62tTXBbOnQ8CBgaGtLR0YFMJkPojiX305L/rikuLmbChAmkpaURHR3NnDlzMDY2vmf21Wo1xsbG\n9OnTB4VCQf/+/QUd//jx43z77bdMmDCB6upqjI2N8fHx4dChQyQnJ7NixQqkUilSqVRQuzruHSqV\nitjYWC5dusRzzz2HWq3GwMDgr76s+5aGhgZEIhHLli2jo6MDIyMjQcd/oAXV39+f7777jtbWVn74\n4Qfs7e0Ri++dU56Xl8eGDRsAePPNNwX747W2tlJQUICjoyPm5ubU19fz5JNPsnHjRlauXImpqSmb\nN28mJSWFmTNnsmLFCvT09ASxfS/45z//SVxcHC0tLQC8/fbbFBQU8NNPP2k9irlz57JkyRJaWlpw\nc3P7Ky+3R8nPz2fPnj0cOnSI69ev89FHHwluIycnh02bNnHy5EkiIiJYvnw5tra2t3xHJpNx7do1\nPD09MTMz65Y9pVJJSUkJubm5HDx4kMrKSmbNmsVDDz2EVCpFT08Pkeh3Q5VdxszMjLVr1yKXy3n2\n2WcF1wO9999/X9ABe5D37/Yf7Ny5k6+++orW1lZCQkLo06dPD1zWnTlw4ADbt29nxIgRjBs3TrCH\n5Nq1a3z++ed8++235ObmEhoaSnh4OM7OzvTq1QsrKyt8fX0ZPXo0ra2tDB8+XDDbH3/8MW+88QZK\npRJ/f/87ekvt7e00NzdjYGBAa2srhoaGf9qGi4sLP//8M2fOnGHZsmVcunQJCwsLVCoVbm5ueHh4\nsGnTJg4ePEhmZiY2NjY4ODigr98z/kF6ejpffPEFn3zyCWfPnsXNzQ17e/sesXUzSqUSIyMjAgMD\nmTZtGiNGjMDOzk6w8WtqalCpVFhaWhIXF4eNjQ1WVlbk5+cTHBxMfHw8zz//PL/88gtisRgvLy8y\nMzPp1atXl+wVFxeza9cuVq1ahUQiITs7m/Xr15Ofn09qaipqtRqpVEpsbCzOzs6YmpoKdq+dxMTE\nIJfLmTJlCtbW1l15L1b+3ocPrIcaExMDQHBwMAYGBvj4+Nwz22q1mpMnT7Jv3z4GDhzIo48+Kuj4\nCoWC0tJS2traePLJJ4mMjNQ+8BqNhtOnT2NlZcXAgQPJzMwUTEyPHTuGn58fQ4cOxc3N7Y5eb3V1\nNXv37iUzM5P33nuPo0ePMn/+/D9t58KFC9TV1SESidizZw/19fW8/vrruLi4sGvXLsRiMRUVFVRU\nVJCfn095eTmLFy/mscceE+Q+ATo6Ojh48CDW1tacO3eOtrY23nrrLZKSkvj+++95++23sbGxEcze\n7SgrK6O9vZ2CggIyMjJ44YUXBBu7sbERAwMD
jI2NWbt2Lfb29sydO5ddu3bxxRdfcODAAaqqqsjP\nz8fIyIiUlBTs7e2RSqX06tULb2/vu7JXXV2tHff69etcuHCB9vZ2JBIJixcvxszMjJiYGA4fPoxY\nLMbc3JyZM2cKdr+dDB48mOzsbBobG3tktfpACmp7ezt9+vRh586d1NXVERUVhbW19T2xrVar0Wg0\nZGVloaenx5IlS3BxcRHUhpOTEy+++CL6+vr0798fOzs7RCKR1kOrqanh8uXLzJo1S7D7rq2tZdeu\nXYSHhzNu3DjUavVtv1dVVcXOnTv58ssvcXBwwNzcnOHDh9+Vrf79+/Ovf/2L5uZm8vPzUavVhIaG\nIhKJGDBgALm5uajVamQyGWvWrEEsFhMcHCzEbWpJTEzk+++/p6KiAqVSSVRUFG1tbfz888+oVCpG\njhzZIy98J0qlkoSEBKKjo5HJZAwaNOiuvPzfQ6VScejQIY4cOUJ9fT0XL15EIpGQkpLCwIEDWb16\nNbGxscTHxwM3Jpf8/HwKCgoYPHjwb8IBfwZDQ0OuX7+OUqlkwIABTJkyBX9/f0xMTOjbty8GBgYM\nGjSInTt3cuTIERITE3vk942NjaWxsZE5c+YIPjY8oIIaExODo6Mjcrkcb29vQZfbf0RxcTF79uyh\nsbGRN954g4CAAMFtWFhYEBYWhr6+/i3L3Pb2do4dO8aaNWswMjIiJyeHYcOGCWLTxMSEhoYGoqOj\ncXd3v6NImpqaMnLkSKqrq4mOjmbfvn13fbrB29sbd3d3NBoNtbW1aDQarK2tMTAwoG/fvly9epWC\nggJiY2MBsLS07PIy9Nd0xmh/+eUXfH196dOnDwYGBkRFRSGRSBg6dChHjhyhtLRUEHt3Qq1WM3jw\nYMzNzbl27RpSqVRQQb18+TLx8fGUl5cD4OrqSu/evRk1ahQeHh64ubkxaNAgKisraW1tpaOjgyFD\nhmBjY4OFhcVd2zQzM2PFihUsWLAAS0tLevfurfXwO5/hgIAAfvnlF0xNTXsspJKbm4tEIumx8R84\nQS0vLycmJoYrV64glUqJiorCycnpnthua2sjNjaW6OhoIiMjGThwoOC7iAB6enq/WW53dHRQWVmJ\nkZERUVFRiMViHB0dBbO5YcMGgoKCkMlk2s2iX1NWVkZMTAzJyclYWlryj3/8A2dn57u2JRaLteLx\n63/f0NBAQkICMplM6yldunSJPXv2EBUVdde2bsfatWvZtWsXHh4eLFiwgDFjxuDq6kp9fT329va0\ntLRQXV2NSqXqsc2+2tpaTp06hampKbNnzxZ0iaqnp8fgwYNZuXIlmzdv5uTJk9jb2zN69GiCg4NR\nKBR4eHjg5+dHU1MTHR0d2vh1d2yGhIQgEokQi8W3dXDy8vI4ceIEZmZmXXpu/ojNmzeTkZHBrFmz\neuzv9sAJanFxMSNGjKCoqAgHBwdGjhx5z2zn5eVx/Phx+vbtq41r3gvq6ur45Zdf0Gg0GBgYoKen\nx6hRo5DJZCiVym5t1rS1tWFgYMCJEyeoq6tj0aJF9O3bF6VSqQ1vaDQaLl26RHR0NPv378fY2Ji3\n336b4ODgLi0Pb0en52hkZERJSQkHDhzA0tKS4OBg5s2bJ9gKpHMDpjNc4e/vj7u7O3DjdIVMJsPf\n35+IiIgeXfUUFhbS2tpKY2MjJ0+eJDIyUrCx9fT0GD58ODk5OdqJq7W1lebmZkxMTLSbQZ277XK5\nHD09Paqqqrrl2d3uOdRoNNTX1xMfH8+6desoLCxk9OjRDBkyBLhxjru0tLTbG8pVVVXExcUhFovx\n9/fv1li/xwMnqOnp6Vy4cIHw8HAmT56Ml5fXPbFbU1PDzz//TGlpKU888QR9+/altLRU60kJHUe9\nmdTUVK5du4ZKpeLIkSPU1dXR0dHB1KlTAbrsSTU1NREXF0dGRgZZWVkUFBQgFosxMjLCwsKC8+fP\nU15ejkql
orS0VBvrrKmpYdq0aTQ3NyORSLp1b3K5nKysLEpLS+nXrx9XrlzRxqaTk5ORy+WMHz8e\nKyurbtmBG3FLKysrMjI
"text/plain": [
"<matplotlib.figure.Figure at 0x7f8f870f3750>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAVQAAAC+CAYAAABqOvflAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAIABJREFUeJzsnVdwVGea938d1epWauWcAEVAoECOJmNM8jisAzUz652t\nWVdtzeXebNVe7vXs3OxMfTM7Ho9tHLCNwQQLMCCCEFYgSCAJtSLKqVutzqe/C+qcOd1qQXdLHi9s\n/6soidbp87zxye/zKrxeLxFEEEEEESwcyp+6ARFEEEEELwoiDDWCCCKIYJEQYagRRBBBBIuECEON\nIIIIIlgkRBhqBBFEEMEiIcJQI4gggggWCRGGGkEEEUSwSIgw1AgiiCCCRUKEoUYQQQQRLBIiDDWC\nCCKIYJEQYagRRBBBBIuECEONIIIIIlgkRBhqBBFEEMEiIcJQI4gggggWCRGGGkEEEUSwSFD/1A0I\nAQsq3Cqv+6pQKBbcmGfREWkIgoAgCHi9XtxuN9HR0T8a7QgiiODp8K//LAgCAErlE90yCN7w1Ade\nKA3V6/VKAyYysOvXr/PP//zPHDhwgD/84Q9MT08HfH6hdN1uN1arFbvdLjFQr9eL1Wrls88+Y9eu\nXfzbv/0bY2Nji0LzfxPEvr5o/fopEWhM5Z9Fxjo0yMdtZmaGlpYWmpqasNlsAZ8LF8+ThjovAg2C\nIAj09PTwySef8Omnn+L1elmyZAmvv/46Xq/XRxL5/z8c+oIgoFAoUCqVPtppa2sr33zzDaOjoyQl\nJRETEyP9Df4mGcOBIAg4nU6ampq4desWKSkp1NTUkJycjF6vR6lUolKpFkQjgr8//BUD+U9xjUUQ\nHrxeL2NjYwwODhIbG4vH4wn4jDjeoY71C8FQ58Po6Cijo6MolUry8/OpqKiY1+T2N9VDhUKhQKPR\n+DDUmZkZBgYGiIuL41e/+hVHjhxBq9X6fE9kxM+iK2+fqBHfuXOHzz//nKtXr2K328nIyODUqVNY\nrVYAqqureeuttygoKAhrE4biJlno+AXThmDb8mNAvsmCma9Q3x0IgiAwPT1NS0sLDx48ID8/nzVr\n1mA0GqX5XIxxlzPtvwfjlmuL4lj+PebU6/Xi8XjQ6/UsW7YMvV5PVFSURFvcWyLCadMLw1D9O+/x\neBgaGmJ8fJysrCwOHz7M3r17UaufdHmxTCZROxUEwYeZWq1WLl26xKeffkpqaiqbN28mPT193skL\nFh6Ph5mZGb788ks++ugj7t27h8PhoKKigpqaGux2O7W1tXR2dmIymdDr9bz//vs+CycYiG4L+cad\nj6nIPxcFhPzzYDenOJYul4vGxkYuXLhAW1sbU1NTKJVKqQ8qlYqUlBRKS0uprq6mpKQEg8GASqWS\n3hVKX5/GlMQ+uN1uXC4XXq8XtVqNVqsNaOWEugkDWVfiOzweD4ODgzQ0NPDo0SO0Wq00L/J2hwuv\n14vNZqO1tZWmpibsdjtLliyhtLRUYtperxeXywWATqdDq9WiUql8xjoUek6nk4sXL3LmzBlMJhMe\nj4ecnBw2btxIVVUVSUlJxMbGEhUV5UNDvhZDpSn+FAQBu92O0+kkOjqauLg4iR/A36zGheCFYKj+\nC1sQBIaHh7l27Rqtra1kZGSQnZ1NUlLSHEYm/z1ciSxf5AAul4va2lp++9vf0tXVxYEDB8jIyJiz\nCIPdfP5+4QsXLvDhhx9y69Yt3G43GzZs4B//8R9Zt24djx8/Znh4mIcPH2Kz2Xx8uqFsdjlzUCgU\n0mJsb2/HbDaTk5NDeno6Go1mjsvD5XIxMjJCb28viYmJ5OTkzGF4gfooCAIOh4P/+q//4uTJk3R1\ndTEzM4Pb7ZbeLY6FVqslNjaWvLw8XnvtNX75y1+i1+vD1nTkzF8+1h0dHdy6dYvu7m7MZjMOh4Pl\ny5dz5MgRUlNTfdq0EPhbICL98fFxHj9+jNvt
JiEhQerjfEw4FHoej4f6+nr++te/cvv2bdxuN1lZ\nWaSkpDA9Pc3k5CRut1tSFtLS0ti/fz/79u0jMzPThxk9i5ZoUX3wwQdcvHiRgYEBZmdn8Xq9REdH\nc+HCBRITE4mNjaW8vJwjR46wYcMGNBqNz3tCFZT+e8dsNjM9PU18fDxqtXqOpu/PG+RKUjB4IRiq\n/+LyeDzU1tZSW1vLxMQEy5cvJzs7O+CGli/gcCFqTeLkdHZ2cv78ee7du0d2djYVFRXEx8fPmZhw\nzDVBEBgfH8dsNqPX69m0aRPHjh1j48aNxMXFodFoKCsro7S0lNLSUnbt2iUtylDMmUBt/eGHHzh5\n8iQul4sDBw6QnJyMQqGQ/FBqtVrSLn/3u9/R39/P0qVLefPNN9m+fXtQi3NiYoLGxkbu37/PzMwM\nKpUKo9FISkoKSqWS0dFRxsfHsdvtTE9PMzY2RnJyMjt27KC0tNSnvcGMq7+FIWJ6eprjx4/z1Vdf\n0dfXh91ux+1243Q6qa+vx+Vy8d5770kuJPE9KpUqLEHpPzeCIGA2m+nu7mZwcJDMzEzy8vJ8NGO5\nEAimv/7rzWKxcOnSJWpraxkZGUGpVDIyMoJKpcJsNkvCWPyOwWBgenqahIQEDh06FFRfRUbW2trK\nn/70J77++msGBwd93ut2u7FYLExNTeF0Ouns7ESlUrFy5UqMRmNQY+kPfyVH/MxqtTI7O0t8fPwz\n+YHcjxrsnD73DFXcDHKNanp6mqamJkwmE2lpaWzatInly5cH1EDlWli4kPuB3G439+7do62tjfLy\ncl5//XX27Nkzr8kttv1p2ptcI1GpVFRWVnLgwAGio6PZtGkTZWVlxMbGSotIr9dTWlrKtm3bKC0t\nDWmTzwe3283du3fp7OxkzZo15ObmSn0S/ymVSsbGxrhy5Qrnz5/H4XCgUChITk4OioZCoUCv17N7\n927p/6tXr6akpISEhAQAzGYzDQ0NnDx5ktbWVhwOByaTiTt37lBSUuKzIZ7lahA3tf/6cTqdnDx5\nkj/+8Y+0traSmZlJcXExExMTtLW10dnZye3btzl69ChZWVnSO+Q05xtv/00utlXeBpEJdXV10dDQ\nwOjoKCtWrCArK0taJ/4R/2e5G0SNVHzO4/Hw9ddfc/r0aQYHB3G73RgMBmJiYjAYDOh0OkZHR7HZ\nbJKwMRgM5OTksGrVqqC1U3FsxsbGePz4MXa7ndjYWBISEqioqGDLli0sWbIElUpFXV0dH3zwAaOj\no/T09IQdLJanKsrXgN1ul7T9nJwcH8EEvrzAn6kGi+eaoYoL0e12o1QqUavVeL1empqaaG1tBWDj\nxo3s3r1b0qbk8Gem4UyefDF7PB7q6ur4+uuvSUhI4NVXX2XHjh2SaSj/jrwNwdAVn1GpVCxdupTX\nXnsNg8FAamoqWq1W8k8NDAwwMDBAeno65eXlPmZwuH4+QRB4/Pgx3d3d5OXlsWXLFnJyciRzX2Qg\ngiBgMpmora1lZmYGo9HI8uXLWbp0aVA+N4VCQUxMDLt27WLFihVoNBoyMjIwGo3SBna73ZSVlSEI\nAgMDA0xMTOBwOJiZmfHZBP4a59Mgtl/cSN3d3Zw7d47W1lZSUlJ488032b17N83Nzfzxj3/k4cOH\nOJ3OgH0K19Lx90lbrVZMJhMmkwmDwcCSJUtISEiYM3fB9lFkqGL7+vr6+O6776S+qNVqioqKOHr0\nKBUVFXg8HkZGRuju7qanpweXy0VFRQVr1qyhoKAgJCGtUCjIz8/nwIED1NTUkJqaSlpaGhkZGeTl\n5ZGQkCDRczqdUjvl5n6opr6/kBQzbmpra8nJyaG0tDTgO+W8QKlU/t+L8ouDL3bcZrNx8+ZNurq6\nWL16NYcOHZK0NDkWI6ool2yCIHDjxg3+3//7fzQ2NnLkyBFeeukl0tLS5jwrb0Oo7RC1uPz8fCk4\nIAYXHjx4
wMWLF7l37x7FxcWSZig3LcPxtY2Pj/PVV1/R2NjIzp07Je1UzsA8Hg8PHz7kk08+4c6d\nOyQlJbFv3z7effddDAb
"text/plain": [
"<matplotlib.figure.Figure at 0x7f8f86ccc3d0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
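"# Get a batch of noisy inputs (and clean targets) from the validation data provider.\n",
"inputs, targets = valid_data.next()\n",
"# The last element of the list returned by fprop is the model output, i.e. the reconstructions.\n",
"recons = model.fprop(inputs)[-1]\n",
"# Plot the noisy inputs, then the corresponding denoised reconstructions.\n",
"_ = show_batch_of_images(inputs.reshape((-1, 28, 28)), (4, 2), 5)\n",
"_ = show_batch_of_images(recons.reshape((-1, 28, 28)), (4, 2), 5)"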
]
},
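{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional quantitative complement to the visual check above, the cell below compares two mean squared errors, each averaged over all pixels and examples: the error of the noisy inputs against the clean images, and the error of the reconstructions against the clean images. It is a minimal, self-contained sketch on synthetic data. It assumes your denoising data provider returns corrupted images as `inputs` and the uncorrupted images as `targets`. The helper `mean_squared_error` is a hypothetical name and is not part of the `mlp` package."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def mean_squared_error(outputs, targets):\n",
"    # Hypothetical helper: mean of the squared differences over all pixels and examples.\n",
"    return np.mean((outputs - targets) ** 2)\n",
"\n",
"# Separate random state so the notebook's own `rng` stream is left untouched.\n",
"rng_demo = np.random.RandomState(0)\n",
"# Synthetic stand-ins for a batch of 8 clean images with pixel values in [0, 1].\n",
"clean = rng_demo.uniform(size=(8, 784))\n",
"# Corrupt the images by setting roughly 25% of the pixels to zero.\n",
"mask = rng_demo.uniform(size=clean.shape) < 0.25\n",
"noisy = np.where(mask, 0., clean)\n",
"# Stand-in for a good reconstruction: the clean images plus a little Gaussian noise.\n",
"recons_demo = clean + 0.05 * rng_demo.normal(size=clean.shape)\n",
"\n",
"print('MSE(noisy, clean) = {0:.4f}'.format(mean_squared_error(noisy, clean)))\n",
"print('MSE(recons, clean) = {0:.4f}'.format(mean_squared_error(recons_demo, clean)))\n",
"# With the arrays from the previous cell you could instead compare\n",
"# mean_squared_error(inputs, targets) with mean_squared_error(recons, targets)."
]
},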
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 4: Using an autoencoder as an initialisation for supervised training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a final exercise we will use the first layer of an autoencoder trained on MNIST digit images as a layer within a multiple-layer model trained to do digit classification. The intuition behind pretraining methods like this is that the hidden representations learnt by an autoencoder should be more useful inputs for a classifier than the raw pixel values themselves. We could keep the parameters of the layers taken from the autoencoder fixed, but we can generally get better performance by training the whole model end-to-end on the supervised task. The learnt autoencoder parameters then act as a potentially more informed initialisation than randomly sampled parameters, which can help ease some of the optimisation issues caused by poorly initialising a model.\n",
"\n",
"You can either use one of the autoencoder models you trained in the previous exercises, or train a new autoencoder model specifically for this exercise. Create a new model object (an instance of `mlp.models.MultipleLayerModel`) whose list of layers passed to the model constructor begins with the trained first layer(s) from your autoencoder model. These can be accessed via the `layers` attribute, which is a list of all the layers in a model. Add any additional layers you wish on top of the pretrained layers. At the very least you will need to add an output layer with output dimension 10 so that the model can be used to predict class labels. Train this new model on the original MNIST (image, digit label) pairs with a cross-entropy error."
]
},
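{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison with the pretrained first layer, the new output layer in the model below is initialised randomly using `initialisers.GlorotUniformInit`. The cell that follows is a minimal sketch of the standard Glorot (Xavier) uniform scheme. It draws each weight independently from a uniform distribution on `[-a, a]` with `a = sqrt(6 / (fan_in + fan_out))`. The helper `glorot_uniform` is a hypothetical name. The `mlp` implementation may differ in details, for example an optional gain factor or a different weight-shape convention."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def glorot_uniform(fan_in, fan_out, rng):\n",
"    # Hypothetical helper: sample a (fan_out, fan_in) weight matrix from U(-a, a)\n",
"    # with a = sqrt(6 / (fan_in + fan_out)), i.e. the Glorot uniform scheme.\n",
"    a = np.sqrt(6. / (fan_in + fan_out))\n",
"    return rng.uniform(low=-a, high=a, size=(fan_out, fan_in))\n",
"\n",
"# Separate random state so the notebook's own `rng` stream is left untouched.\n",
"rng_demo = np.random.RandomState(1234)\n",
"w_demo = glorot_uniform(784, 50, rng_demo)\n",
"print(w_demo.shape)  # (50, 784)\n",
"print(np.abs(w_demo).max() <= np.sqrt(6. / (784 + 50)))  # True"
]
},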
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
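"# Keep a reference to the trained autoencoder before `model` is reassigned below.\n",
"ae_model = model\n",
"# Switch back to the standard MNIST providers, which return (image, one-hot digit label) pairs.\n",
"train_data = data_providers.MNISTDataProvider('train', batch_size=50, rng=rng)\n",
"valid_data = data_providers.MNISTDataProvider('valid', batch_size=50, rng=rng)"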
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# hidden_dim must match the output dimension of the reused autoencoder layer.\n",
"input_dim, output_dim, hidden_dim = 784, 10, 50\n",
"\n",
"weights_init = initialisers.GlorotUniformInit(rng=rng)\n",
"biases_init = initialisers.ConstantInit(0.)\n",
"\n",
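"# Reuse the trained first (affine) layer of the autoencoder as the first layer of the classifier.\n",
"# This shares the layer object, so with in-place parameter updates fine-tuning also changes ae_model;\n",
"# pass copy.deepcopy(ae_model.layers[0]) instead to keep the autoencoder intact.\n",
"model = models.MultipleLayerModel([\n",
" ae_model.layers[0],\n",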
" layers.ReluLayer(),\n",
" layers.AffineLayer(hidden_dim, output_dim, weights_init, biases_init)\n",
"])\n",
"\n",
"error = errors.CrossEntropySoftmaxError()\n",
"\n",
"learning_rule = learning_rules.AdamLearningRule()\n",
"\n",
"num_epochs = 25\n",
"stats_interval = 1\n",
"data_monitors={'acc': lambda y, t: (y.argmax(-1) == t.argmax(-1)).mean()}\n",
"optimiser = optimisers.Optimiser(\n",
" model, error, learning_rule, train_data, valid_data, data_monitors)"
]
},
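{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `acc` data monitor above computes classification accuracy. For each example it checks whether the index of the largest model output matches the index of the 1 in the one-hot target vector. The cell below evaluates the same expression on a small hand-made batch so you can check what it returns. The arrays are illustrative only and are not MNIST data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Same expression as the 'acc' data monitor defined above.\n",
"accuracy = lambda y, t: (y.argmax(-1) == t.argmax(-1)).mean()\n",
"\n",
"# Illustrative model outputs for a batch of 4 examples over 3 classes.\n",
"y_demo = np.array([[0.1, 0.7, 0.2],\n",
"                   [0.8, 0.1, 0.1],\n",
"                   [0.3, 0.3, 0.4],\n",
"                   [0.2, 0.5, 0.3]])\n",
"# One-hot targets for classes 1, 0, 0 and 1.\n",
"t_demo = np.eye(3)[[1, 0, 0, 1]]\n",
"\n",
"# Predicted classes are [1, 0, 2, 1], so 3 of the 4 predictions are correct.\n",
"print(accuracy(y_demo, t_demo))  # 0.75"
]
},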
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 0:\n",
" error(train)=1.09e+01, acc(train)=1.03e-01, error(valid)=1.07e+01, acc(valid)=1.05e-01, params_penalty=0.00e+00\n",
"Epoch 1: 1.38s to complete\n",
" error(train)=3.31e-01, acc(train)=9.05e-01, error(valid)=3.11e-01, acc(valid)=9.12e-01, params_penalty=0.00e+00\n",
"Epoch 2: 1.34s to complete\n",
" error(train)=2.61e-01, acc(train)=9.25e-01, error(valid)=2.51e-01, acc(valid)=9.30e-01, params_penalty=0.00e+00\n",
"Epoch 3: 1.34s to complete\n",
" error(train)=2.17e-01, acc(train)=9.37e-01, error(valid)=2.12e-01, acc(valid)=9.41e-01, params_penalty=0.00e+00\n",
"Epoch 4: 1.36s to complete\n",
" error(train)=1.88e-01, acc(train)=9.46e-01, error(valid)=1.94e-01, acc(valid)=9.46e-01, params_penalty=0.00e+00\n",
"Epoch 5: 1.40s to complete\n",
" error(train)=1.73e-01, acc(train)=9.47e-01, error(valid)=1.88e-01, acc(valid)=9.45e-01, params_penalty=0.00e+00\n",
"Epoch 6: 1.40s to complete\n",
" error(train)=1.41e-01, acc(train)=9.59e-01, error(valid)=1.64e-01, acc(valid)=9.54e-01, params_penalty=0.00e+00\n",
"Epoch 7: 1.35s to complete\n",
" error(train)=1.39e-01, acc(train)=9.58e-01, error(valid)=1.66e-01, acc(valid)=9.53e-01, params_penalty=0.00e+00\n",
"Epoch 8: 1.41s to complete\n",
" error(train)=1.12e-01, acc(train)=9.66e-01, error(valid)=1.45e-01, acc(valid)=9.59e-01, params_penalty=0.00e+00\n",
"Epoch 9: 1.46s to complete\n",
" error(train)=1.09e-01, acc(train)=9.67e-01, error(valid)=1.48e-01, acc(valid)=9.58e-01, params_penalty=0.00e+00\n",
"Epoch 10: 1.40s to complete\n",
" error(train)=8.54e-02, acc(train)=9.75e-01, error(valid)=1.29e-01, acc(valid)=9.64e-01, params_penalty=0.00e+00\n",
"Epoch 11: 1.58s to complete\n",
" error(train)=8.34e-02, acc(train)=9.75e-01, error(valid)=1.34e-01, acc(valid)=9.62e-01, params_penalty=0.00e+00\n",
"Epoch 12: 1.44s to complete\n",
" error(train)=7.08e-02, acc(train)=9.79e-01, error(valid)=1.18e-01, acc(valid)=9.66e-01, params_penalty=0.00e+00\n",
"Epoch 13: 1.50s to complete\n",
" error(train)=6.54e-02, acc(train)=9.81e-01, error(valid)=1.25e-01, acc(valid)=9.64e-01, params_penalty=0.00e+00\n",
"Epoch 14: 1.63s to complete\n",
" error(train)=6.31e-02, acc(train)=9.82e-01, error(valid)=1.22e-01, acc(valid)=9.65e-01, params_penalty=0.00e+00\n",
"Epoch 15: 1.63s to complete\n",
" error(train)=5.89e-02, acc(train)=9.83e-01, error(valid)=1.19e-01, acc(valid)=9.67e-01, params_penalty=0.00e+00\n",
"Epoch 16: 1.39s to complete\n",
" error(train)=5.33e-02, acc(train)=9.85e-01, error(valid)=1.23e-01, acc(valid)=9.66e-01, params_penalty=0.00e+00\n",
"Epoch 17: 1.43s to complete\n",
" error(train)=4.60e-02, acc(train)=9.88e-01, error(valid)=1.14e-01, acc(valid)=9.70e-01, params_penalty=0.00e+00\n",
"Epoch 18: 1.39s to complete\n",
" error(train)=4.64e-02, acc(train)=9.86e-01, error(valid)=1.26e-01, acc(valid)=9.65e-01, params_penalty=0.00e+00\n",
"Epoch 19: 1.54s to complete\n",
" error(train)=4.28e-02, acc(train)=9.88e-01, error(valid)=1.30e-01, acc(valid)=9.64e-01, params_penalty=0.00e+00\n",
"Epoch 20: 1.40s to complete\n",
" error(train)=3.49e-02, acc(train)=9.91e-01, error(valid)=1.17e-01, acc(valid)=9.67e-01, params_penalty=0.00e+00\n",
"Epoch 21: 1.44s to complete\n",
" error(train)=3.74e-02, acc(train)=9.89e-01, error(valid)=1.26e-01, acc(valid)=9.67e-01, params_penalty=0.00e+00\n",
"Epoch 22: 1.52s to complete\n",
" error(train)=2.74e-02, acc(train)=9.93e-01, error(valid)=1.18e-01, acc(valid)=9.69e-01, params_penalty=0.00e+00\n",
"Epoch 23: 1.40s to complete\n",
" error(train)=2.71e-02, acc(train)=9.93e-01, error(valid)=1.19e-01, acc(valid)=9.69e-01, params_penalty=0.00e+00\n",
"Epoch 24: 1.41s to complete\n",
" error(train)=2.36e-02, acc(train)=9.94e-01, error(valid)=1.20e-01, acc(valid)=9.69e-01, params_penalty=0.00e+00\n",
"Epoch 25: 1.53s to complete\n",
" error(train)=2.44e-02, acc(train)=9.93e-01, error(valid)=1.29e-01, acc(valid)=9.68e-01, params_penalty=0.00e+00\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.text.Text at 0x7f8f86c4f210>"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAq0AAAF5CAYAAAC1GxMPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAIABJREFUeJzs3XlYVdX+x/H3OiAiKjiAYipqas4TaIZT5lRamlMq5k2l\nHK6VZaWpaaZWTs3TVbMcyrhmTnnLLMuhHLJAsxzrl7OpYM44wvr9cYBARTlMB/Dzep7zAGuvvfb3\niD19XGfttY21FhERERGRnMzh7gJERERERG5EoVVEREREcjyFVhERERHJ8RRaRURERCTHU2gVERER\nkRxPoVVEREREcjyFVhERERHJ8RRaRURERCTHU2gVERERkRxPoVVEREREcrx0hVZjzKPGmN3GmHPG\nmA3GmAbX6dvYGPODMSbGGBNrjNlujHnyij69jTHxxpi4hK/xxpjY9NQmIiIiInmPp6snGGO6A68C\n/YGNwBBguTHmNmttzDVOOQu8DWxJ+L4JMN0Yc8ZaOyNZv5PAbYBJ+Nm6WpuIiIiI5E3GWteyoTFm\nA/CjtfaJhJ8NsB94y1o7OY1jLADOWGt7J/zcG3jdWlvMpWJERERE5Kbg0vIAY0w+IAT4NrHNOlPv\nCiA0jWPUS+i76opDhYwxe4wx+4wxi40x1V2pTURERETyLlfXtPoDHsCRK9qPAIHXO9EYs98Ycx7n\nkoJ3rbUzkx3eCYQDHYAHE+paZ4y5xcX6RERERCQPcnlNawY0AQoBdwCTjDF/WGvnAVhrNwAbEjsa\nY9YD24EBwJhrDWaMKQ7cDewBzmdp5SIiIiKSHt5AeWC5tfZYRgZyNbTGAHFAySvaSwKHr3eitXZv\nwrdbjTGBwAvAvFT6XjbGbAIqXWfIu4G5aahZRERERNzrQeCTjAzgUmi11l4yxkQCLYHPIelGrJbA\nWy4M5QHkT+2gMcYB1AK+uM4YewA+/vhjqlWr5sKlJbcaMmQIr7/+urvLkGyi3/fNRb/vm4t+3zeP\n7du306tXL0jIbRmRnuUBrwGzEsJr4pZXPsAsAGPMBOCWZDsDDAL2ATsSzr8TeBp4I3FAY8xonMsD\n/gCKAMOAICD5llhXOg9QrVo1goOD0/E2JLfx8/PT7/omot/3zUW/75uLft83pQwv5XQ5tFprPzXG\n+APjcC4L2Azcba2NTugSCJRNdooDmIBzPcNl4P+Aodba6cn6FAWmJ5x7HIgEQq21OxARERGRm166\nbsSy1r4HvJfKsb5X/PwO8M4NxnsKeCo9tYiIiIhI3peux7iKiIiIiGQnhVbJNcLCwtxdgmQj/b5v\nLvp931z0+5b0cPkxrjmFMSYYiIyMjNRibhERkWvYt28fMTEx7i5D8jB/f3+CgoJSPR4VFUVISAhA\niLU2KiPXys6HC4iIiEg22bdvH9WqVSM2NtbdpUge5uPjw/bt268bXDOLQquIiEgeFBMTQ2xsrPYz\nlyyTuAdrTEyMQquIiIhkjPYzl7xCN2KJiIiISI6n0CoiIiIiOZ5Cq4iIiIjkeAqtIiIiIpLjKbSm\n4rffYNgwyKXb2IqIiEgOMHnyZKpXr55t13vhhRdwONIX76ZNm0a5cuW4dOlSJleVORRaU3HkCEyZ\nAmvXursSERERyY1Onz7N5MmTGT58eFLbuXPnGDt2LGvWrMmSaxpj0h1a+/Tpw8WLF5k2bVomV5U5\nFFpTcdddULEi5NDfm4iIiORwH3zwAXFxcfTo0SOpLTY2lrFjx7Jq1aosuebo0aPT/UCJ/Pnz07t3\nb1577bVMripzKLSmwuGA/v1h/nz4+293VyMiIiKZ4cKFC6T2CPvMeHpY8jFmzZpFhw4d8PLySmpL\n7dppGS8tHA5Hiuu5qlu3buzZsyfLQnVGKLReR58+EB8Pc+a4uxIRERG50qFDhwgPDycwMBBvb29q\n1qzJzJkzk46vXr0ah8PBvHnzGDVqFGXKlKFg
wYKcPn2aWbNm4XA4WLNmDYMGDaJkyZKULVs26dxN\nmzbRtm1b/Pz8KFy4MK1ateLHH39Mcf3Zs2enOsbu3bvZsmULrVq1Suq/d+9eSpQogTEmae2pw+Fg\n3LhxgPPj+cKFC/Pnn3/Srl07fH196dWrFwA//PAD3bp1o1y5cnh7exMUFMRTTz3F+fPnU9R0rTWt\nDoeDwYMHs2TJEmrVqpX0Z7V8+fKr/kyDg4MpVqwYS5YsSc+vJEvpiVjXUaIEdOrkXCLwxBNgjLsr\nEhEREYCjR4/SsGFDPDw8GDx4MP7+/ixbtoyHH36Y06dPM3jw4KS+48ePJ3/+/AwdOpQLFy7g5eWF\nSfif+qBBgyhRogRjxozh7NmzAGzdupVmzZrh5+fH8OHD8fT0ZNq0aTRv3pw1a9bQoEGDFLVca4x1\n69ZhjEnxNLKAgACmTp3KwIED6dy5M507dwagdu3agHM96uXLl7n77rtp2rQpr776Kj4+PgDMnz+f\nc+fOMWjQIIoXL87GjRt5++23OXjwIPPmzUu6hjEm6b0l9/3337Nw4UIGDRpE4cKFeeutt+jatSv7\n9u2jaNGiKfoGBwezNgfe1KPQegMDBkDLlvDDD9C0qburEREREYCRI0dirWXz5s0UKVIEgP79+9Oz\nZ09eeOEFBgwYkNT3woULREVFXfNjc39/f7799tsUQW/UqFFcvnyZtWvXUq5cOQD+9a9/UaVKFYYN\nG8bKlStvOMbOnTsBqFChQlKbj48PXbp0YeDAgdSuXZuePXteVc/Fixfp3r07L774Yor2yZMnkz9/\n/qSfH3nkESpWrMhzzz3HgQMHKFOmzHX/vHbs2MH27dspX748AM2bN6dOnTpEREQwaNCgFH1vvfVW\nPv744+uO5w4KrTfQvDlUquScbVVoFRGRvCg2FnbsyPrrVK0KCROHGbZw4UK6d+9OXFwcx44dS2pv\n06YN8+bNIyoqKqmtT58+1wysxhj69euXImzGx8fzzTff0KlTp6TAChAYGEjPnj2ZMWMGZ86coVCh\nQqmOAXDs2DE8PT2TZkpdMXDgwKvakgfW2NhYzp07R2hoKPHx8WzatOmGobV169ZJgRWgVq1a+Pr6\n8ueff17Vt2jRopw7d47z58/j7e3tcv1ZRaH1BhJvyBo9Gt58E4oXd3dFIiIimWvHDggJyfrrREZC\nsk/L0y06OpoTJ04wffr0a27PZIzh6NGjSTOwycPala48Fh0dTWxsLLfddttVfatVq0Z8fDz79++n\nWrVqqY6REZ6entcMoPv372f06NEsXbqU48ePJ7UbYzh58uQNx02+XjdR0aJFU4yVKPFmsWstM3An\nhdY06NMHnnvOeUPWkCHurkZERCRzVa3qDJTZcZ3MEB8fD0CvXr3o3bv3NfvUrl2brVu3AlCgQIFU\nx7resbS61hjFixfn8uXLnD17loIFC6Z5rOQzqoni4+Np1aoVJ06cYMSIEVSpUoWCBQty8OBBevfu\nnfTncT0eHh7XbL/WbgbHjx/Hx8fnmrW4k0JrGgQEQOfOziUCTz6pG7JERCRv8fHJnBnQ7BIQEEDh\nwoWJi4ujRYsWmT62j49P0prU5LZv347D4bjmrOWVqiYk9N27d1OzZs2k9vTMXv7666/8/vvvfPTR\nRzz44INJ7StWrHB5rLTYvXt3ipnknEJbXqXRgAGwcydk0QMsREREJI0cDgddunRhwYIFSbOpycXE\nxGRo7DZt2rBkyRL27duX1H7kyBEiIiJo2rRp0nrW6wkNDcVay88//5yiPXGN64kTJ9JcU+Is6ZUz\nqm+88UaWfIQfFRVFo0aNMn3cjNJMaxo1bw6VK8P06XDnne6uRkRE5OY2ceJEVq1aRcOGDenXrx/V\nq1fn77//JjIyku+++y5NwTW1jf5ffPFFVqxYQePGjRk0aBAeHh5Mnz6dixcvMnny5DSNUaFCBWrW\nrMmKFSvo
06dPUru3tzfVq1dn3rx5VK5cmWLFilGzZk1q1KiRap1Vq1alYsWKPP300xw4cABfX18W\nLFjgUvBNq8jISP7++28
"text/plain": [
"<matplotlib.figure.Figure at 0x7f8f870c25d0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"stats, keys, run_time = optimiser.train(num_epochs=num_epochs, stats_interval=stats_interval)\n",
"\n",
"# Plot the change in the validation and training set error over training.\n",
"fig_1 = plt.figure(figsize=(8, 4))\n",
"ax_1 = fig_1.add_subplot(111)\n",
"for k in ['error(train)', 'error(valid)']:\n",
" ax_1.plot(np.arange(1, stats.shape[0]) * stats_interval, \n",
" stats[1:, keys[k]], label=k)\n",
"ax_1.legend(loc=0)\n",
"ax_1.set_xlabel('Epoch number')"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [conda env:mlp]",
"language": "python",
"name": "conda-env-mlp-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 1
}