{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$\\newcommand{\\vct}[1]{\\boldsymbol{#1}}\n",
"\\newcommand{\\mtx}[1]{\\mathbf{#1}}\n",
"\\newcommand{\\tr}{^\\mathrm{T}}\n",
"\\newcommand{\\reals}{\\mathbb{R}}\n",
"\\newcommand{\\lpa}{\\left(}\n",
"\\newcommand{\\rpa}{\\right)}\n",
"\\newcommand{\\lsb}{\\left[}\n",
"\\newcommand{\\rsb}{\\right]}\n",
"\\newcommand{\\lbr}{\\left\\lbrace}\n",
"\\newcommand{\\rbr}{\\right\\rbrace}\n",
"\\newcommand{\\fset}[1]{\\lbr #1 \\rbr}\n",
"\\newcommand{\\pd}[2]{\\frac{\\partial #1}{\\partial #2}}$\n",
"\n",
"# Multiple layer models\n",
"\n",
"In this notebook we will explore network models with multiple layers of transformations. This will build upon the single-layer affine model we looked at in the previous notebook and use material covered in the [second](http://www.inf.ed.ac.uk/teaching/courses/mlp/2017-18/mlp02-sln.pdf) and [third](http://www.inf.ed.ac.uk/teaching/courses/mlp/2017-18/mlp03-mlp.pdf) lectures.\n",
"\n",
"You will need to use these models for the experiments you will be running in the first coursework, so part of the aim of this lab is to get you familiar with how to construct multiple-layer models in our framework and how to train them.\n",
"\n",
"## What is a layer?\n",
"\n",
"Often when discussing (neural) network models, a network layer is taken to mean an input to output transformation of the form\n",
"\n",
"\\begin{equation}\n",
" \\vct{y} = \\vct{f}(\\mtx{W} \\vct{x} + \\vct{b})\n",
" \\qquad\n",
" \\Leftrightarrow\n",
" \\qquad\n",
" y_k = f\\lpa\\sum_{d=1}^D \\lpa W_{kd} x_d \\rpa + b_k \\rpa\n",
"\\end{equation}\n",
"\n",
"where $\\mtx{W}$ and $\\vct{b}$ parameterise an affine transformation as discussed in the previous notebook, and $f$ is a function applied elementwise to the result of the affine transformation. For example, a common choice for $f$ is the logistic sigmoid function\n",
"\\begin{equation}\n",
" f(u) = \\frac{1}{1 + \\exp(-u)}.\n",
"\\end{equation}\n",
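"\n",
"As a rough illustration, here is a standalone NumPy sketch of such a layer applied to a whole batch at once (this is not the `mlp` framework code, and the function name `sigmoid_layer_fprop` is hypothetical). Storing the $B$ inputs as the rows of an array turns $\\mtx{W} \\vct{x} + \\vct{b}$ into `inputs.dot(weights.T) + biases`:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid_layer_fprop(inputs, weights, biases):\n",
"    # inputs: (batch_size, input_dim); weights: (output_dim, input_dim);\n",
"    # biases: (output_dim,). Returns an array of shape (batch_size, output_dim).\n",
"    affine_outputs = inputs.dot(weights.T) + biases\n",
"    # Apply the logistic sigmoid f(u) = 1 / (1 + exp(-u)) elementwise.\n",
"    return 1. / (1. + np.exp(-affine_outputs))\n",
"\n",
"demo_rng = np.random.RandomState(0)\n",
"x = demo_rng.normal(size=(4, 3))  # B = 4 inputs, each of dimension D = 3\n",
"W = demo_rng.normal(size=(2, 3))  # K = 2 outputs\n",
"b = np.zeros(2)\n",
"y = sigmoid_layer_fprop(x, W, b)\n",
"print(y.shape)  # (4, 2)\n",
"```\n",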
"\n",
"In the second lecture slides you were shown how to train a model consisting of an affine transformation followed by the elementwise logistic sigmoid using gradient descent. This was referred to as a 'sigmoid single-layer network'.\n",
"\n",
"In the previous notebook we also referred to single-layer models. There the layer was an affine transformation: you implemented the various necessary methods for the `AffineLayer` class and then used an instance of that class within a `SingleLayerModel` on a regression problem. In that case we could consider the function $f$ to simply be the identity function $f(u) = u$. In the code for the labs, however, we will use a slightly different convention: we will consider the affine transformation and the subsequent elementwise function $f$ to each be a separate transformation layer.\n",
"\n",
"This means we can combine our already implemented `AffineLayer` class with any non-linear function applied to its outputs just by implementing a layer object for the relevant non-linearity and then stacking the two layers together. An alternative would be to have our new layer objects inherit from `AffineLayer` and call the relevant parent class methods in the child class; however, this would mean including a lot of the same boilerplate code in every new class.\n",
"\n",
"To give a concrete example, the `mlp.layers` module contains a definition of a `SigmoidLayer` equivalent to the following (documentation strings have been removed here for brevity):\n",
"\n",
"```python\n",
"class SigmoidLayer(Layer):\n",
"\n",
" def fprop(self, inputs):\n",
" return 1. / (1. + np.exp(-inputs))\n",
"\n",
" def bprop(self, inputs, outputs, grads_wrt_outputs):\n",
" return grads_wrt_outputs * outputs * (1. - outputs)\n",
"```\n",
"\n",
"As you can see this `SigmoidLayer` class has a very lightweight definition, defining just two key methods:\n",
"\n",
" * `fprop`, which takes a batch of activations at the input to the layer and forward propagates them to produce activations at the outputs (directly equivalent to the `fprop` method you implemented for the `AffineLayer` in the previous notebook),\n",
" * `bprop`, which takes a batch of gradients with respect to the outputs of the layer and back-propagates them to calculate gradients with respect to the inputs of the layer (explained in more detail below).\n",
" \n",
"This `SigmoidLayer` class only implements the logistic sigmoid non-linearity and so does not have any parameters. Therefore, unlike `AffineLayer`, it is derived directly from the base `Layer` class rather than `LayerWithParameters` and does not need to implement the `grads_wrt_params` or `params` methods.\n",
"\n",
"To create a model consisting of an affine transformation followed by applying an elementwise logistic sigmoid transformation we first create a list of the two layer objects (in the order they are applied from inputs to outputs) and then use this to instantiate a new `MultipleLayerModel` object:\n",
"\n",
"```python\n",
"from mlp.layers import AffineLayer, SigmoidLayer\n",
"from mlp.models import MultipleLayerModel\n",
"\n",
"layers = [AffineLayer(input_dim, output_dim), SigmoidLayer()]\n",
"model = MultipleLayerModel(layers)\n",
"```\n",
"\n",
"Because of the modular way in which the layers are defined we can also stack an arbitrarily long sequence of layers together to produce deeper models. For instance the following would define a model consisting of three pairs of affine and logistic sigmoid transformations.\n",
"\n",
"```python\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, hidden_dim), SigmoidLayer(),\n",
" AffineLayer(hidden_dim, hidden_dim), SigmoidLayer(),\n",
" AffineLayer(hidden_dim, output_dim), SigmoidLayer(),\n",
"])\n",
"```\n",
"\n",
"## Back-propagation of gradients\n",
" \n",
"To allow training models consisting of a stack of multiple layers, all layers need to implement a `bprop` method in addition to the `fprop` method we encountered last week.\n",
"\n",
"The `bprop` method takes gradients of an error function with respect to the *outputs* of a layer and uses these gradients to calculate gradients of the error function with respect to the *inputs* of the layer. As the inputs to a non-input layer in a multiple-layer model consist of the outputs of the previous layer, this means we can calculate the gradients of the error function with respect to the outputs of every layer in the model by iteratively propagating the gradients backwards through the layers of the model (i.e. from the last layer to the first), hence the term 'back-propagation', or 'bprop' for short. A block diagram illustrating this for a three-layer model is shown below.\n",
"\n",
"<img src='res/fprop-bprop-block-diagram.png' />\n",
"\n",
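"The iterative scheme in the diagram can be sketched in a few lines of standalone NumPy. This is a minimal sketch of the idea only, not the actual `MultipleLayerModel` implementation; the classes `Affine` and `Sigmoid` and the helpers `model_fprop` and `model_bprop` are hypothetical names introduced here for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"class Affine(object):\n",
"    # Minimal affine layer: outputs = inputs . W^T + b for a batch of row vectors.\n",
"    def __init__(self, weights, biases):\n",
"        self.weights, self.biases = weights, biases\n",
"    def fprop(self, inputs):\n",
"        return inputs.dot(self.weights.T) + self.biases\n",
"    def bprop(self, inputs, outputs, grads_wrt_outputs):\n",
"        return grads_wrt_outputs.dot(self.weights)\n",
"\n",
"class Sigmoid(object):\n",
"    def fprop(self, inputs):\n",
"        return 1. / (1. + np.exp(-inputs))\n",
"    def bprop(self, inputs, outputs, grads_wrt_outputs):\n",
"        return grads_wrt_outputs * outputs * (1. - outputs)\n",
"\n",
"def model_fprop(layers, inputs):\n",
"    # Returns [x, h_1, ..., y]: the inputs followed by each layer's outputs.\n",
"    activations = [inputs]\n",
"    for layer in layers:\n",
"        activations.append(layer.fprop(activations[-1]))\n",
"    return activations\n",
"\n",
"def model_bprop(layers, activations, grads_wrt_outputs):\n",
"    # Visit layers from last to first; layer i has inputs activations[i] and\n",
"    # outputs activations[i + 1]. Returns gradients w.r.t. the model inputs.\n",
"    grads = grads_wrt_outputs\n",
"    for i in reversed(range(len(layers))):\n",
"        grads = layers[i].bprop(activations[i], activations[i + 1], grads)\n",
"    return grads\n",
"\n",
"demo_rng = np.random.RandomState(1)\n",
"layers = [Affine(demo_rng.normal(size=(5, 3)), np.zeros(5)), Sigmoid(),\n",
"          Affine(demo_rng.normal(size=(2, 5)), np.zeros(2))]\n",
"activations = model_fprop(layers, demo_rng.normal(size=(4, 3)))\n",
"# Gradients of E = sum of all model outputs, so dE/dy = 1 everywhere\n",
"grads_wrt_inputs = model_bprop(layers, activations, np.ones_like(activations[-1]))\n",
"print(grads_wrt_inputs.shape)  # (4, 3)\n",
"```\n",
"\n",
"Storing every intermediate activation during the forward pass is what lets `bprop` reuse the layer outputs (as the sigmoid layer does) instead of recomputing them.\n",
"\n",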
"For a layer with parameters, the gradients with respect to the layer outputs are required to calculate gradients with respect to the layer parameters. Therefore by combining back-propagation of gradients through the model with computing the gradients with respect to parameters in the relevant layers we can calculate gradients of the error function with respect to all of the parameters of a multiple-layer model in a very efficient manner (in fact the computational cost of computing gradients with respect to all of the parameters of the model using this method will only be a constant factor times the cost of calculating the model outputs in the forward pass).\n",
"\n",
"We so far have abstractly talked about calculating gradients with respect to the inputs of a layer using gradients with respect to the layer outputs. More concretely we will be using the chain rule for derivatives to do this, similarly to how we used the chain rule in exercise 4 of the previous notebook to calculate gradients with respect to the parameters of an affine layer given gradients with respect to the outputs of the layer.\n",
"\n",
"In particular if our layer has a batch of $B$ vector inputs each of dimension $D$, $\\fset{\\vct{x}^{(b)}}_{b=1}^B$, and produces a batch of $B$ vector outputs each of dimension $K$, $\\fset{\\vct{y}^{(b)}}_{b=1}^B$, then we can calculate the gradient with respect to the $d^\\textrm{th}$ dimension of the $b^{\\textrm{th}}$ input given the gradients with respect to the $b^{\\textrm{th}}$ output using\n",
"\n",
"\\begin{equation}\n",
" \\pd{E}{x^{(b)}_d} = \\sum_{k=1}^K \\lpa \\pd{E}{y^{(b)}_k} \\pd{y^{(b)}_k}{x^{(b)}_d} \\rpa.\n",
"\\end{equation}\n",
"\n",
"Mathematically therefore the `bprop` method takes an array of gradients with respect to the outputs $\\pd{E}{y^{(b)}_k}$ and applies a sum-product operation with the partial derivatives of each output with respect to each input $\\pd{y^{(b)}_k}{x^{(b)}_d}$ to produce gradients with respect to the inputs of the layer $\\pd{E}{x^{(b)}_d}$.\n",
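"\n",
"To make the sum-product concrete, the following standalone sketch (the helper names `bprop_with_jacobian` and `numerical_jacobians` are hypothetical, not part of `mlp`) estimates each example's Jacobian $\\pd{y^{(b)}_k}{x^{(b)}_d}$ by central finite differences and then contracts it with the output gradients using `np.einsum`. It is far too slow for training, but comparing its result against a hand-written `bprop` is a useful check. It assumes, as for all the layers here, that each output row depends only on the corresponding input row.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def bprop_with_jacobian(grads_wrt_outputs, jacobians):\n",
"    # grads_wrt_outputs: (B, K) array of dE/dy_k for each example b.\n",
"    # jacobians: (B, K, D) array of dy_k/dx_d for each example b.\n",
"    # Returns the (B, D) array of dE/dx_d = sum_k dE/dy_k * dy_k/dx_d.\n",
"    return np.einsum('bk,bkd->bd', grads_wrt_outputs, jacobians)\n",
"\n",
"def numerical_jacobians(fprop, inputs, eps=1e-6):\n",
"    # Central-difference estimate of dy_k/dx_d, perturbing input dimension d\n",
"    # of every example in the batch at once.\n",
"    batch_size, input_dim = inputs.shape\n",
"    output_dim = fprop(inputs).shape[1]\n",
"    jacobians = np.zeros((batch_size, output_dim, input_dim))\n",
"    for d in range(input_dim):\n",
"        shift = np.zeros_like(inputs)\n",
"        shift[:, d] = eps\n",
"        jacobians[:, :, d] = (fprop(inputs + shift) - fprop(inputs - shift)) / (2 * eps)\n",
"    return jacobians\n",
"\n",
"def softmax(inputs):\n",
"    # Softmax over each row: every output depends on every input in that row.\n",
"    exp_inputs = np.exp(inputs - inputs.max(-1, keepdims=True))\n",
"    return exp_inputs / exp_inputs.sum(-1, keepdims=True)\n",
"\n",
"demo_rng = np.random.RandomState(0)\n",
"x = demo_rng.normal(size=(4, 3))\n",
"grads_wrt_outputs = demo_rng.normal(size=(4, 3))\n",
"grads_wrt_inputs = bprop_with_jacobian(grads_wrt_outputs, numerical_jacobians(softmax, x))\n",
"print(grads_wrt_inputs.shape)  # (4, 3)\n",
"```\n",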
"\n",
"For the affine transformation used in the `AffineLayer` implemented last week, i.e. a forward propagation corresponding to\n",
"\n",
"\\begin{equation}\n",
" y^{(b)}_k = \\sum_{d=1}^D \\lpa W_{kd} x^{(b)}_d \\rpa + b_k\n",
"\\end{equation}\n",
"\n",
"the corresponding partial derivatives of layer outputs with respect to inputs are\n",
"\n",
"\\begin{equation}\n",
" \\pd{y^{(b)}_k}{x^{(b)}_d} = W_{kd}\n",
"\\end{equation}\n",
"\n",
"and so the back-propagation method for the `AffineLayer` takes the following form\n",
"\n",
"\\begin{equation}\n",
" \\pd{E}{x^{(b)}_d} = \\sum_{k=1}^K \\lpa \\pd{E}{y^{(b)}_k} W_{kd} \\rpa.\n",
"\\end{equation}\n",
"\n",
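"This can be efficiently implemented in NumPy using the `dot` method. With the weights stored as an array of shape `(output_dim, input_dim)`, so that `self.weights[k, d]` corresponds to $W_{kd}$, a `(batch_size, output_dim)` array of output gradients is mapped to a `(batch_size, input_dim)` array of input gradients:\n",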
"\n",
"```python\n",
"class AffineLayer(LayerWithParameters):\n",
"\n",
" # ... [implementation of remaining methods from previous week] ...\n",
" \n",
" def bprop(self, inputs, outputs, grads_wrt_outputs):\n",
" return grads_wrt_outputs.dot(self.weights)\n",
"```\n",
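"\n",
"A quick standalone check (with arbitrary illustrative sizes $B=5$, $D=4$, $K=3$) that this single `dot` call computes the same sum over $k$ as the equation above, assuming the weights are stored as an array of shape `(output_dim, input_dim)` so that `weights[k, d]` is $W_{kd}$:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"demo_rng = np.random.RandomState(0)\n",
"B, D, K = 5, 4, 3\n",
"weights = demo_rng.normal(size=(K, D))            # W_kd\n",
"grads_wrt_outputs = demo_rng.normal(size=(B, K))  # dE/dy_k for each example b\n",
"\n",
"# Vectorised form used in AffineLayer.bprop: (B, K) . (K, D) -> (B, D)\n",
"grads_wrt_inputs = grads_wrt_outputs.dot(weights)\n",
"\n",
"# Explicit loops over b, d and the summation index k, for comparison\n",
"grads_loop = np.zeros((B, D))\n",
"for b in range(B):\n",
"    for d in range(D):\n",
"        for k in range(K):\n",
"            grads_loop[b, d] += grads_wrt_outputs[b, k] * weights[k, d]\n",
"\n",
"print(np.allclose(grads_wrt_inputs, grads_loop))  # True\n",
"```\n",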
"\n",
"An important special case applies when the outputs of a layer are an elementwise function of the inputs such that $y^{(b)}_k$ only depends on $x^{(b)}_d$ when $d = k$. In this case the partial derivatives $\\pd{y^{(b)}_k}{x^{(b)}_d}$ will be zero for $k \\neq d$ and so the above summation collapses to a single term, giving\n",
"\n",
"\\begin{equation}\n",
" \\pd{E}{x^{(b)}_d} = \\pd{E}{y^{(b)}_d} \\pd{y^{(b)}_d}{x^{(b)}_d}\n",
"\\end{equation}\n",
"\n",
"i.e. to calculate the gradient with respect to the $b^{\\textrm{th}}$ input vector we just perform an elementwise multiplication of the gradient with respect to the $b^{\\textrm{th}}$ output vector with the vector of derivatives of the outputs with respect to the inputs. This case applies to the `SigmoidLayer` and to all other layers applying an elementwise function to their inputs.\n",
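"\n",
"The collapse of the sum can be checked numerically in a short standalone sketch. Building each example's full (diagonal) Jacobian for the logistic sigmoid, whose elementwise derivative is $y^{(b)}_d \\lsb 1 - y^{(b)}_d \\rsb$, and applying the general sum-product gives the same result as the elementwise product:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"demo_rng = np.random.RandomState(0)\n",
"B, D = 3, 4\n",
"inputs = demo_rng.normal(size=(B, D))\n",
"outputs = 1. / (1. + np.exp(-inputs))  # elementwise logistic sigmoid\n",
"grads_wrt_outputs = demo_rng.normal(size=(B, D))\n",
"\n",
"# Full Jacobians dy_k/dx_d, shape (B, D, D): zero off the diagonal, with\n",
"# y_d * (1 - y_d) on the diagonal for each example.\n",
"jacobians = np.zeros((B, D, D))\n",
"for b in range(B):\n",
"    jacobians[b] = np.diag(outputs[b] * (1. - outputs[b]))\n",
"\n",
"# General sum-product over k versus the collapsed elementwise product\n",
"full = np.einsum('bk,bkd->bd', grads_wrt_outputs, jacobians)\n",
"elementwise = grads_wrt_outputs * outputs * (1. - outputs)\n",
"print(np.allclose(full, elementwise))  # True\n",
"```\n",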
"\n",
"For the logistic sigmoid layer we have that\n",
"\n",
"\\begin{equation}\n",
" y^{(b)}_d = \\frac{1}{1 + \\exp(-x^{(b)}_d)}\n",
" \\qquad\n",
" \\Rightarrow\n",
" \\qquad\n",
" \\pd{y^{(b)}_d}{x^{(b)}_d} = \n",
" \\frac{\\exp(-x^{(b)}_d)}{\\lsb 1 + \\exp(-x^{(b)}_d) \\rsb^2} =\n",
" y^{(b)}_d \\lsb 1 - y^{(b)}_d \\rsb\n",
"\\end{equation}\n",
"\n",
"which you should now be able to relate to the implementation of `SigmoidLayer.bprop` given earlier:\n",
"\n",
"```python\n",
"class SigmoidLayer(Layer):\n",
"\n",
" def fprop(self, inputs):\n",
" return 1. / (1. + np.exp(-inputs))\n",
"\n",
" def bprop(self, inputs, outputs, grads_wrt_outputs):\n",
" return grads_wrt_outputs * outputs * (1. - outputs)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 1: training a softmax model on MNIST\n",
"\n",
"For this first exercise we will train a model consisting of an affine transformation plus softmax on a multiclass classification task: classifying the digit labels for handwritten digit images from the MNIST data set introduced in the first notebook.\n",
"\n",
"First run the cell below to import the necessary modules and classes and to load the MNIST data provider objects. As it takes a little while to load the MNIST data from disk into memory it is worth loading the data providers just once in a separate cell like this rather than recreating the objects for every training run.\n",
"\n",
"We are loading two data provider objects here: one corresponding to the training data set and a second to use as a *validation* data set. The model is not trained on the validation data; instead we measure the performance of the trained model on it to assess the model's ability to *generalise* to unseen data.\n",
"\n",
"If you are in the Monday or Tuesday lab sessions you will not yet have had the lecture introducing the concepts of generalisation and validation data sets (though those doing MLPR alongside this course should already be familiar with these ideas). As you will need to report both training and validation set performances in your experiments for the first coursework assignment we are providing code here to give an example of how to do this."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import logging\n",
"from mlp.layers import AffineLayer, SoftmaxLayer, SigmoidLayer\n",
"from mlp.errors import CrossEntropyError, CrossEntropySoftmaxError\n",
"from mlp.models import SingleLayerModel, MultipleLayerModel\n",
"from mlp.initialisers import UniformInit\n",
"from mlp.learning_rules import GradientDescentLearningRule\n",
"from mlp.data_providers import MNISTDataProvider\n",
"from mlp.optimisers import Optimiser\n",
"%matplotlib inline\n",
"plt.style.use('ggplot')\n",
"\n",
"# Seed a random number generator\n",
"seed = 6102016 \n",
"rng = np.random.RandomState(seed)\n",
"\n",
"# Set up a logger object to print info about the training run to stdout\n",
"logger = logging.getLogger()\n",
"logger.setLevel(logging.INFO)\n",
"logger.handlers = [logging.StreamHandler()]\n",
"\n",
"# Create data provider objects for the MNIST data set\n",
"train_data = MNISTDataProvider('train', rng=rng)\n",
"valid_data = MNISTDataProvider('valid', rng=rng)\n",
"input_dim, output_dim = 784, 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To minimise replication of code and allow you to run experiments more quickly a helper function is provided below which trains a model and plots the evolution of the error and classification accuracy of the model (on both training and validation sets) over training."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval):\n",
"\n",
" # As well as monitoring the error over training also monitor classification\n",
" # accuracy i.e. proportion of most-probable predicted classes being equal to targets\n",
" data_monitors={'acc': lambda y, t: (y.argmax(-1) == t.argmax(-1)).mean()}\n",
"\n",
" # Use the created objects to initialise a new Optimiser instance.\n",
" optimiser = Optimiser(\n",
" model, error, learning_rule, train_data, valid_data, data_monitors)\n",
"\n",
"    # Run the optimiser for num_epochs epochs (full passes through the training set),\n",
"    # printing statistics every stats_interval epochs.\n",
" stats, keys, run_time = optimiser.train(num_epochs=num_epochs, stats_interval=stats_interval)\n",
"\n",
" # Plot the change in the validation and training set error over training.\n",
" fig_1 = plt.figure(figsize=(8, 4))\n",
" ax_1 = fig_1.add_subplot(111)\n",
" for k in ['error(train)', 'error(valid)']:\n",
" ax_1.plot(np.arange(1, stats.shape[0]) * stats_interval, \n",
" stats[1:, keys[k]], label=k)\n",
" ax_1.legend(loc=0)\n",
" ax_1.set_xlabel('Epoch number')\n",
"\n",
" # Plot the change in the validation and training set accuracy over training.\n",
" fig_2 = plt.figure(figsize=(8, 4))\n",
" ax_2 = fig_2.add_subplot(111)\n",
" for k in ['acc(train)', 'acc(valid)']:\n",
" ax_2.plot(np.arange(1, stats.shape[0]) * stats_interval, \n",
" stats[1:, keys[k]], label=k)\n",
" ax_2.legend(loc=0)\n",
" ax_2.set_xlabel('Epoch number')\n",
" \n",
" return stats, keys, run_time, fig_1, ax_1, fig_2, ax_2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Running the cell below will create a model consisting of an affine layer followed by a softmax transformation and train it on the MNIST data set by minimising the multi-class cross entropy error function using a basic gradient descent learning rule. Because it uses the helper function defined above, curves showing the evolution of the error function and the classification accuracy of the model over the training epochs will be plotted at the end of training.\n",
"\n",
"You should try running the code for various settings of the training hyperparameters defined at the beginning of the cell to get a feel for how these affect how training proceeds. You may wish to create multiple copies of the cell below to allow you to keep track of and compare the results across different hyperparameter settings."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Varying initialisation scale\n",
"\n",
"<span style=\"color:red\">First try a few different parameter initialisation scales</span>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `init_scale = 0.01`"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 6.4s to complete\n",
" error(train)=3.10e-01, acc(train)=9.14e-01, error(valid)=2.91e-01, acc(valid)=9.18e-01\n",
"Epoch 10: 5.4s to complete\n",
" error(train)=2.88e-01, acc(train)=9.20e-01, error(valid)=2.76e-01, acc(valid)=9.23e-01\n",
"Epoch 15: 3.5s to complete\n",
" error(train)=2.78e-01, acc(train)=9.23e-01, error(valid)=2.69e-01, acc(valid)=9.24e-01\n",
"Epoch 20: 4.2s to complete\n",
" error(train)=2.71e-01, acc(train)=9.25e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 25: 4.7s to complete\n",
" error(train)=2.68e-01, acc(train)=9.25e-01, error(valid)=2.65e-01, acc(valid)=9.26e-01\n",
"Epoch 30: 3.7s to complete\n",
" error(train)=2.63e-01, acc(train)=9.27e-01, error(valid)=2.62e-01, acc(valid)=9.26e-01\n",
"Epoch 35: 3.8s to complete\n",
" error(train)=2.60e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 40: 4.2s to complete\n",
" error(train)=2.59e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.29e-01\n",
"Epoch 45: 4.3s to complete\n",
" error(train)=2.55e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n",
"Epoch 50: 4.0s to complete\n",
" error(train)=2.54e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n",
"Epoch 55: 3.6s to complete\n",
" error(train)=2.52e-01, acc(train)=9.30e-01, error(valid)=2.58e-01, acc(valid)=9.29e-01\n",
"Epoch 60: 4.4s to complete\n",
" error(train)=2.52e-01, acc(train)=9.30e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 65: 3.8s to complete\n",
" error(train)=2.50e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 70: 4.3s to complete\n",
" error(train)=2.49e-01, acc(train)=9.31e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 75: 3.5s to complete\n",
" error(train)=2.47e-01, acc(train)=9.31e-01, error(valid)=2.57e-01, acc(valid)=9.30e-01\n",
"Epoch 80: 4.0s to complete\n",
" error(train)=2.46e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.31e-01\n",
"Epoch 85: 3.9s to complete\n",
" error(train)=2.45e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 90: 3.6s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.57e-01, acc(valid)=9.29e-01\n",
"Epoch 95: 4.6s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.29e-01\n",
"Epoch 100: 3.8s to complete\n",
" error(train)=2.43e-01, acc(train)=9.33e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n"
]
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"init_scale = 0.01 # scale for random parameter initialisation\n",
"learning_rate = 0.1 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size \n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine + softmax model\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, output_dim, param_init, param_init),\n",
" SoftmaxLayer()\n",
"])\n",
"\n",
"# Initialise a cross entropy error object\n",
"error = CrossEntropyError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `init_scale = 0.1`"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 3.8s to complete\n",
" error(train)=3.11e-01, acc(train)=9.13e-01, error(valid)=2.92e-01, acc(valid)=9.18e-01\n",
"Epoch 10: 4.0s to complete\n",
" error(train)=2.89e-01, acc(train)=9.20e-01, error(valid)=2.77e-01, acc(valid)=9.23e-01\n",
"Epoch 15: 3.7s to complete\n",
" error(train)=2.79e-01, acc(train)=9.22e-01, error(valid)=2.70e-01, acc(valid)=9.24e-01\n",
"Epoch 20: 5.4s to complete\n",
" error(train)=2.72e-01, acc(train)=9.24e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 25: 4.7s to complete\n",
" error(train)=2.68e-01, acc(train)=9.25e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 30: 4.2s to complete\n",
" error(train)=2.63e-01, acc(train)=9.27e-01, error(valid)=2.62e-01, acc(valid)=9.26e-01\n",
"Epoch 35: 4.0s to complete\n",
" error(train)=2.60e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 40: 4.3s to complete\n",
" error(train)=2.59e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 45: 4.5s to complete\n",
" error(train)=2.55e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n",
"Epoch 50: 3.7s to complete\n",
" error(train)=2.54e-01, acc(train)=9.30e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 55: 3.7s to complete\n",
" error(train)=2.52e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 60: 4.6s to complete\n",
" error(train)=2.52e-01, acc(train)=9.29e-01, error(valid)=2.60e-01, acc(valid)=9.29e-01\n",
"Epoch 65: 4.3s to complete\n",
" error(train)=2.50e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 70: 4.9s to complete\n",
" error(train)=2.49e-01, acc(train)=9.31e-01, error(valid)=2.59e-01, acc(valid)=9.31e-01\n",
"Epoch 75: 4.7s to complete\n",
" error(train)=2.47e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 80: 4.7s to complete\n",
" error(train)=2.46e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.31e-01\n",
"Epoch 85: 4.2s to complete\n",
" error(train)=2.45e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.31e-01\n",
"Epoch 90: 4.4s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 95: 4.1s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 100: 4.2s to complete\n",
" error(train)=2.43e-01, acc(train)=9.33e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xl8VNX9+P/XmSWTBbLMTEgCJEBCWAMBjYiICxAVFRX3\nWtt+FPupP6xL9dMPCFpL64eWVkT7betHpZRq+bRNK+ICCjRgRYnYKCD7EhL2QMhMFrJNMnPv748b\nBgKBBLhDtvfz8chj5s6cuffcN0PeOeeee47SdV1HCCGEEO2epa0rIIQQQojWkaQthBBCdBCStIUQ\nQogOQpK2EEII0UFI0hZCCCE6CEnaQgghRAchSVsIIYToICRpCyGEEB2EJG0hhBCig5CkLYQQQnQQ\ntrauQHMOHz7c1lXoVNxuN6WlpW1djU5FYhoaElfzSUxDw+y49uzZs1XlpKUthBBCdBCStIUQQogO\nQpK2EEII0UG0y2vaQgghLg1d16mrq0PTNJRSbV2dDuPo0aP4fL7z+oyu61gsFsLDwy841pK0hRCi\nC6urq8Nut2OzSTo4HzabDavVet6f8/v91NXVERERcUHHle5xIYTowjRNk4R9CdlsNjRNu+DPS9IW\nQoguTLrEL72LiXmnTtp6XS3ax++gF2xv66oIIYQQF61VfSIbN25k4cKFaJrGhAkTmDx5cpP3V65c\nyYoVK4IX2B999FF69+7N8ePHmTdvHgUFBVx//fU88sgjITmJs7La0D9+B44eRvUffGmPLYQQot2Y\nP38+sbGx3HvvveTk5HDdddeRmJh4Xvt4++23iYiI4N577z1rme3bt/PGG2/w6quvXmyVm9Vi0tY0\njQULFvD888/jcrmYMWMGWVlZ9O7dO1hm7Nix3HjjjQB89dVXvPXWWzz33HPY7Xbuv/9+9u/fz4ED\nB0JyAuei7HZU5ij0jV+i+/0ouW4jhBAdUiAQaDLw6/Tts/H7/QDk5OSwfPlyAP7xj38waNCgZpP2\nufb7ve99r8XjDR48mOLiYg4dOkSvXr1aLH++WuweLygoIDExkYSEBGw2G2PGjCE/P79JmcjIyODz\nurq6YH99eHg4gwYNIiwszORqt566/GqoPg47NrVZHYQQQpzb4sWLufXWW7nhhhuYNm0agUCA9PR0\nfvazn5Gdnc3XX3/NlVdeyezZs7nppptYunQpW7ZsYdKkSWRnZ/PII49QXl4OwD333MMLL7zAzTff\nzB/+8AfWrl1LRkYGNpuNpUuX8s033/D4449zww03UFtbe8Z+/+///o9bbrmF7Oxs/vM//5Pa2loA\nXn75ZV5//XUA7rzzTmbPns2tt97K2LFj+fLLL4PncsMNN/D++++HJE4tNj29Xi8ulyu47XK52L17\n9xnlli9fzrJly/D7/bzwwgvm1vJiDB0J4RHoX69FZVzW1rURQoh2S/vbfPQDRabuUyX3w/Kt/zxn\nmd27d/PBBx/w3nvvYbfbmTFjBu+++y41NTWMHDmSn/70p8GycXFxrFixAoDs7GxefPFFrrrqKl56\n6SXmzZvHz3/+cwAaGhr4+OOPAZg7dy7Dhw8HYNKkSfzpT3/iJz/5CZmZmc3u1+v18uCDDwLwq1/9\nir/+9a9MmTLljHr7/X6WLVvGqlWrmDdvHjk5OQBkZmbyu9/9jscee+yCYnYupvUXT5w4kYkTJ/L5\n55+zePFiHn/88VZ/Njc3l9zcXADmzJmD2+02q1oAVIy6Bt/6dbie+kmX7CK32Wymx7Srk5iGhsTV\nfC3F9OjRo8FbvvwWC5rJo8ktFkuLt5Tl5eWxefNmbr31VsDose3RowdWq5U77rgj2F2tlOLOO+/E\nZrNRWVlJZWUl11xzDQAPPPAA3//+97HZbE3KARw7doyBAwcGt5VSWK3WJtunli8oKGDOnDlUVFRQ\nXV3NuHHjsNlsWCyWJudz2223YbPZGDlyJAcP
Hgy+npCQQElJyVnP2+FwXPD3vMUM5nQ68Xg8wW2P\nx4PT6Txr+TFjxjB//vzzqkR2djbZ2dnBbbNXpNGHZaGvWUnp56u7ZGtbVvkxn8Q0NCSu5msppj6f\n7+Q13PseCcktRSeuK59NIBDg3nvvZcaMGU1ef+2119B1Pfh5XddxOBz4/X78fn+T907dPrUcGEmy\npqamyX4CgUCz+wV48sknWbBgAUOHDiUnJ4cvvvgCv9+PpmlomhYsZ7Vamxz/xPPq6uom+zudz+c7\n49/EtFW+0tLSKC4upqSkBL/fT15eHllZWU3KFBcXB5+vX7+epKSkVh38khl6WbCLXAghRPsyduxY\nli5dGkxkZWVlHDx48JyfiY6OJiYmJngtefHixYwePbrZsv3792fv3r3B7aioKKqqqs6676qqKhIS\nEmhoaGDJkiXneTZQWFjIwIEDz/tzrdFiS9tqtTJlyhRmz56NpmmMGzeO5ORkcnJySEtLIysri+XL\nl7N582asVivdunXjhz/8YfDzP/zhD4N/4eTn5/P88883GXl+KSh7GGr4KPQN69AfnNolu8iFEKK9\nGjBgANOmTeOBBx5A13VsNhuzZ89u8XOvvvoqzz77LHV1daSkpDBv3rxmy40fP54nn3wyuH3ffffx\n7LPPEh4ezgcffHBG+f/+7/9m0qRJuFwuRo4cec4E35y8vDwmTJhwXp9pLaXruh6SPV+Ew4cPm75P\nfeM6tN//AsuPfoYaOtL0/bdn0uVoPolpaEhczddSTGtqaprcAdRZPfLIIzz33HOkpqaasj+bzdZs\n97fP5+Puu+/mvffeO+s17eZiblr3eKchXeRCCNFlzZgxg5KSkpAf59ChQ8ycOTNk87l3maR9sov8\nC/QWBkUIIYToXPr373/Wa95mSk1NZcyYMSHbf5dJ2gAq62qoOg67Nrd1VYQQQojz1qWSNkNHgiMC\n/SvpIhdCCNHxdKmkrcIcqMwrjFHkgUBbV0cIIYQ4L10qaUPjXORVlbBTusiFEEJ0LF0uaZNxmdFF\nLqPIhRCiy5g/fz7/+Mc/LuizP/rRj1i6dCkAP/7xj9m1a9cZZXJycnjuuecAWLhwIX/7298uvLLn\n0OWStgpzoIZnoa//QrrIhRCigwic9vv69O2zOTG9aE5ODnfeeedF12Pu3LkMGDDgnGW+9a1v8cc/\n/vGij9WcLpe04cQo8krYtaWtqyKEEIJLtzRnQUFBcGESgAMHDgRnL3vllVe45ZZbGD9+PNOmTaO5\nucfuuecevvnmG8BoXY8dO5Zbb72Vr776KlgmIiKC5ORkNmzYYHqcuuZ8nhmXgyMc/au1qMGZLZcX\nQogu4A9fHaWorM7UffaLC+f7WQnnLHMpl+bs378/9fX17N+/n5SUFD744ANuu+02AB566CGefvpp\nAJ544gn++c9/cuONNzZb56NHjzJ37lyWL19O9+7duffee8nIyAi+P3z4cL788ktGjjR3Bs6u2dIO\nc6CGX2FMtCJd5EII0aY+//xzNm/ezC233MINN9zA559/zv79+7FarU1axQC33347AJWVlVRUVHDV\nVVcBcO+99wYXDzm1HEBJSQkulyu4fdtttwXnHP/ggw+CZfPy8pg0aRITJkwgLy+v2WvXJ3z99ddc\nddVVuFwuwsLCmhwPjOljjx49eiHhOKeu2dLGGEWu539mdJFLa1sIIVpsEYeKruvNLs35+uuvn1w2\ntFFr50k/tVx4eDh1dSd7EG6//XYeffRRbr75ZpRSpKamUldXx8yZM/noo4/o1asXL7/8Mj6f74LP\nyefzER4efsGfP5su2dIGjC7yMIdMtCKEEG3sUi/N2bdvX6xWK6+++mqwhXwiQTudTqqrq1m2bNk5\nj3/55Zezbt06vF4vDQ0NwdHlJxQWFjJo0KBz7uNCdN2WtuOULvJvP4o67a85IYQQl8alXpoTjNb2\niy++yLp1
6wCIiYnh29/+NhMmTCA+Pp7MzHP3wCYkJPBf//Vf3H777cTExDB06NAm7+fn5/PMM8+0\neA7nq8sszdkc/eu1aK/
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc438001780>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfoAAAENCAYAAADngqfoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xl8VNX9//HXnZns+0xCFhJCCIsBBIlRKGokJG6gliLa\nuuACrfUnotalVYtLXVpa+LpRt68FFyxfrQvWjaoBsWiEhE2FBEhYZElClkkyk2Rmkpl7fn8MDETA\nQJiQ7fN8PPJIZubMveceEt5z7j33HE0ppRBCCCFEr2To6goIIYQQovNI0AshhBC9mAS9EEII0YtJ\n0AshhBC9mAS9EEII0YtJ0AshhBC9mAS9EEII0YtJ0AshhBC9mAS9EEII0YtJ0AshhBC9mKmrK+Av\n5eXlXV2FXiU2NpaampqurkavIm3qf9KmnUPa1f86o02TkpKOq5z06IUQQoheTIJeCCGE6MUk6IUQ\nQoherNdco/8xpRROpxNd19E0raur0+Ps378fl8sFeNvSYDAQHBwsbSmEED1Mrw16p9NJQEAAJlOv\nPcROZTKZMBqNvsdutxun00lISEgX1koIIcSJ6rWn7nVdl5D3I5PJhK7rXV0NIYQQJ6jXBr2cYvY/\naVMhhOh5em3QCyGEEN2B2lWK/fXnUEp1yf4l6Ls5pRRXXnkldrudhoYGXn311Q5tZ/r06TQ0NPxk\nmUcffZSvvvqqQ9sXQgjRltq2Cc9TD6M/cTeOzz8Aa3WX1EOCvptbvnw5w4cPJyIiApvNxuuvv37U\ncm63+ye3s3jxYqKion6yzIwZM3juuec6XFchhOjrlFKoTevw/PU+9HkPwJ4daFfcQOz/vodm6dcl\ndTqu0WobN27klVdeQdd1cnNzmTJlSpvXq6ureeGFF7DZbISHhzN79mwsFgvV1dXMnz8fXdfxeDxc\nfPHFXHjhhbhcLp588kn279+PwWDgzDPP5NprrwVg5cqVLF68GLPZDMDFF19Mbm6unw/71JkxYwbl\n5eW4XC5mzpzJddddxxdffMHcuXPxeDyYzWb+9a9/0dTUxJw5c/juu+/QNI3f/e53TJ48maVLl/ra\n5s9//jM//PADF1xwAdnZ2eTm5jJv3jyioqIoKyvjq6++Our+AMaOHcuyZctoamriuuuu4+yzz2bt\n2rUkJCSwaNEiQkJCSE5Opq6ujqqqquOeWlEIIQQoXYeNa9A/eRt+KIOYWLRf3Yx23gVogUEYQsKg\nydEldWs36HVdZ+HChcyZMweLxcL9999PVlYWycnJvjKLFy8mOzubCRMmsGnTJpYsWcLs2bOJiYnh\n8ccfJyAgAKfTyd13301WVhZhYWFcdtlljBw5ErfbzaOPPsqGDRsYM2YMAOPHj2fmzJl+O0j9zZdR\ne3b6bXsAWkoahl/9pt1y//M//0NMTAwOh4PJkydz0UUXce+99/Lee+8xYMAA6urqAHj66aeJiIhg\n+fLlANTX1wNQVFTEX//6VwAeeOABtm7dyueffw5AQUEB33//PStWrGDAgAFH3d+kSZN8H5oO2rlz\nJ8899xzz5s3jt7/9LZ988glXXHEFAKeffjpFRUX8/Oc/90MrCSFE76Y8HlTRKtSyd6B8N/RLRLv+\nNrSf5aCZArq6esBxBH1ZWRkJCQnEx8cD3hAuKipqE/R79+7l+uuvB2DEiBHMmzfPu/HDbm9rbW31\n3Z4VFBTEyJEjfWXS0tKora310yF1L4sWLWLZsmWAd+GdN954g3HjxvmCOSYmBoBVq1bx/PPP+94X\nHR0NeAM/PDz8mNs/44wzfNs62v527tx5RNCnpKT42n/UqFHs2bPH95rFYmH//v0dPl4hhOgLVGsr\n6psVqP+8C9WVkDQA7dd3o2Wdi3bYHCTdQbtBb7VasVgsvscWi4XS0tI2ZVJTUyksLGTSpEkUFhbi\ncDiw2+1ERERQU1PD3Llzqays5LrrrjsidJqa
mli3bh2TJk3yPbdmzRpKSkpITEzkhhtuIDY29qQO\n8nh63p2hoKCAVatW8eGHHxISEsK0adMYMWIE27dvP+5tHLx/3WA4+nCK0NDQn9zfwdntDhcUFOT7\n2Wg04nQ6fY9dLhfBwcHHXT8hhOhLlMuF+uoz1KdLoa4GUgdjuPUBGH022jH+n+5qfplRZvr06Sxa\ntIiVK1eSkZGB2Wz2BVNsbCzz58/HarUyb948xo0b5+utejwennnmGS655BLfGYMzzzyTc845h4CA\nAD7//HOee+45Hn744SP2mZ+fT35+PgBz58494sPA/v37u3zCnKamJqKjo4mIiKC0tJT169fjdrtZ\ns2YN+/btIzU1lbq6OmJiYjj//PN5/fXXefzxxwFvTz46Opr09HT27dtHWloaUVFRNDU1+Y7LaDSi\naZrv8dH2ZzQaMZlMaJqG0Wj0zXZ38D0GgwGDweB7vHPnTt9p+x+3X1BQ0El/6OrLTCaTtJ+fSZt2\nDn+2q6feStM7r9G65XtMKQMxDRxCwKChmAYOwRAR6Zd9nAp6cxOOZe/S9MGbKFs9AcPPIOz2PxI4\n+uzjmmOkK39X201Cs9nc5rR6bW3tEb1ys9nMPffcA3innl2zZg1hYWFHlElJSWHLli2MGzcOgJde\neomEhAQmT57sKxcREeH7OTc3lzfeeOOo9crLyyMvL8/3+Mfr/LpcrjZTuHaF7OxsXnvtNc455xzS\n09PJzMwkOjqav/71r9x0003ouk5sbCxvvvkmt99+Ow888ADZ2dkYDAbuuusuJk2axMSJE1m1ahUp\nKSlERkaSlZVFdnY2OTk55ObmopTyjbg/2v48Hg9utxulFB6PB4/HAxwapa/rOrqu43a7aW1tZefO\nnb7T+j8eye9yuWSN6pMga3z7n7Rp5/BHu6rmJtSnS1HLP4DWFhicgXtjIaz8z6FC5jgYMAgtJQ0t\nZRCkpIGlX7eanEs12lDLP0Kt+BCam2DEGAyTrkIfOgI7wHFedu7K9ejbDfr09HQqKiqoqqrCbDZT\nUFDA7bff3qbMwdH2BoOBpUuXkpOTA3g/FERERBAYGEhjYyNbt27l0ksvBeDNN9+kubmZW265pc22\nDvZwAdauXdtmLEBPExQUdMwPKhMnTmzzOCwsjGeeeeaIctdccw133HEH11xzDcARt7+NHz/+uPa3\nZs0awPuBa8WKFb7nD2///Px8Jk+e3OVnQoQQPZdyuVBffIRa9i40N6KddR7a5degJfT3vm6rhz07\nUXt2eL/v3oH6tvDQZDKhYZDiDX9SBqENSIOEFLRT/P+SaqhDffY+6stl4HLCmHEYJl2JNnDIKa2H\nP7TbckajkRkzZvDEE0+g6zo5OTmkpKTw1ltvkZ6eTlZWFsXFxSxZsgRN08jIyPCNmN+3bx+vv/46\nmqahlOKyyy5jwIAB1NbW8t5779G/f3/+8Ic/AIduo1u2bBlr167FaDQSHh7Orbfe2rkt0M3Fx8dz\nzTXX+MY8dCa3281vf/vbTt2HEKJ3Um436ut81EdvQr0VRmZi+MV0tAHpbcppkdEwYgzaiDGH3uty\nwr4fULt3+D4EqP/+B1paUAAmk3ewW8qgwz4EpKGFhNIepRToOnjc4PEc5bun7ePWFtTaVahVn4PH\ng3b2eWiXXInWf0C7++quNNVVc/L5WXl5eZvHzc3NbQaqiRNjMpmOOHUvbXpy5DSz/0mbdo4TaVel\n697byz5YAlUVkH4ahqnXow0deVJ1ULoH9pe3CX9274BG26FCln5gCjgU0voxgvxEGU1o4yeiXTwV\nrZ9/5hTp1qfuhRBCiB9TSsGmdejvLYa9O6F/KobbHoRRWX65xq4ZjJCYgpaYAmPPP7TPBqvvlD/l\nu0EpMBoPfJkOfTcc5Tmj4UePD5XVDn+ufypajKWdGvYcEvRCCNEDqZ2lqLWroF8S2pDhkJB8ym7v\nUqXF6O+9
DmXFEJeANvMutLOzO33/mqZBtAWiLWinZ3XqvnoTCXohhOhB1LZN6B//C4o3gsEAuu69\njh0WAYMz0IYMRxs8HFL
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435dd92e8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"init_scale = 0.1 # scale for random parameter initialisation\n",
"learning_rate = 0.1 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size \n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine + softmax model\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, output_dim, param_init, param_init),\n",
" SoftmaxLayer()\n",
"])\n",
"\n",
"# Initialise a cross entropy error object\n",
"error = CrossEntropyError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `init_scale = 0.5`"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 3.7s to complete\n",
" error(train)=3.38e-01, acc(train)=9.03e-01, error(valid)=3.17e-01, acc(valid)=9.11e-01\n",
"Epoch 10: 4.3s to complete\n",
" error(train)=3.06e-01, acc(train)=9.13e-01, error(valid)=2.94e-01, acc(valid)=9.17e-01\n",
"Epoch 15: 3.3s to complete\n",
" error(train)=2.92e-01, acc(train)=9.17e-01, error(valid)=2.83e-01, acc(valid)=9.20e-01\n",
"Epoch 20: 5.5s to complete\n",
" error(train)=2.82e-01, acc(train)=9.20e-01, error(valid)=2.77e-01, acc(valid)=9.22e-01\n",
"Epoch 25: 3.8s to complete\n",
" error(train)=2.77e-01, acc(train)=9.22e-01, error(valid)=2.75e-01, acc(valid)=9.22e-01\n",
"Epoch 30: 4.3s to complete\n",
" error(train)=2.71e-01, acc(train)=9.24e-01, error(valid)=2.71e-01, acc(valid)=9.25e-01\n",
"Epoch 35: 3.9s to complete\n",
" error(train)=2.67e-01, acc(train)=9.25e-01, error(valid)=2.69e-01, acc(valid)=9.26e-01\n",
"Epoch 40: 4.4s to complete\n",
" error(train)=2.65e-01, acc(train)=9.27e-01, error(valid)=2.68e-01, acc(valid)=9.26e-01\n",
"Epoch 45: 4.2s to complete\n",
" error(train)=2.61e-01, acc(train)=9.27e-01, error(valid)=2.66e-01, acc(valid)=9.27e-01\n",
"Epoch 50: 4.2s to complete\n",
" error(train)=2.59e-01, acc(train)=9.28e-01, error(valid)=2.65e-01, acc(valid)=9.27e-01\n",
"Epoch 55: 4.2s to complete\n",
" error(train)=2.57e-01, acc(train)=9.29e-01, error(valid)=2.64e-01, acc(valid)=9.29e-01\n",
"Epoch 60: 3.6s to complete\n",
" error(train)=2.56e-01, acc(train)=9.28e-01, error(valid)=2.65e-01, acc(valid)=9.28e-01\n",
"Epoch 65: 4.6s to complete\n",
" error(train)=2.54e-01, acc(train)=9.30e-01, error(valid)=2.63e-01, acc(valid)=9.28e-01\n",
"Epoch 70: 3.7s to complete\n",
" error(train)=2.52e-01, acc(train)=9.30e-01, error(valid)=2.64e-01, acc(valid)=9.28e-01\n",
"Epoch 75: 5.0s to complete\n",
" error(train)=2.50e-01, acc(train)=9.31e-01, error(valid)=2.62e-01, acc(valid)=9.29e-01\n",
"Epoch 80: 3.7s to complete\n",
" error(train)=2.49e-01, acc(train)=9.31e-01, error(valid)=2.63e-01, acc(valid)=9.28e-01\n",
"Epoch 85: 3.8s to complete\n",
" error(train)=2.48e-01, acc(train)=9.31e-01, error(valid)=2.62e-01, acc(valid)=9.30e-01\n",
"Epoch 90: 4.1s to complete\n",
" error(train)=2.47e-01, acc(train)=9.31e-01, error(valid)=2.62e-01, acc(valid)=9.28e-01\n",
"Epoch 95: 4.3s to complete\n",
" error(train)=2.47e-01, acc(train)=9.31e-01, error(valid)=2.62e-01, acc(valid)=9.29e-01\n",
"Epoch 100: 3.7s to complete\n",
" error(train)=2.45e-01, acc(train)=9.32e-01, error(valid)=2.62e-01, acc(valid)=9.28e-01\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xl8VNX9//HXmTV7yEwgISSYFZAdGhUQsUBUZLFqhWr9\n1qp8W76u1bYuiD9rtbRYceliv1CKSL/SFpSKigoUULbIDgLKvgcCIZmE7Mtkzu+PmwyJLFmYYbJ8\nno/HPHJn5t47Zw6XvHPPPfccpbXWCCGEEKLFMwW6AEIIIYRoHAltIYQQopWQ0BZCCCFaCQltIYQQ\nopWQ0BZCCCFaCQltIYQQopWQ0BZCCCFaCQltIYQQopWQ0BZCCCFaCQltIYQQopWwNGal7du3M2fO\nHDweDyNHjuT222+v9/6yZctYunQpJpOJoKAgJk2aRHx8vPf93NxcnnzyScaPH89tt93W4OedPHmy\niV9DXEp0dDS5ubmBLkabInXqH1Kvvid16h++rte4uLhGrddgaHs8HmbPns3zzz+P0+lk8uTJpKen\n1wvloUOHcvPNNwOwefNm5s6dy5QpU7zvz507lwEDBjT1OwghhBCijgabxw8cOEBsbCwxMTFYLBaG\nDBnCpk2b6q0TEhLiXS4vL0cp5X2+ceNGOnXqVC/khRBCCNF0DYa2y+XC6XR6nzudTlwu13nrLVmy\nhMcee4x58+bxwAMPAEaAf/jhh4wfP96HRRZCCCHap0Zd026MUaNGMWrUKNauXcvChQt59NFHWbBg\nAWPGjCEoKOiS2y5fvpzly5cDMG3aNKKjo31VLAFYLBapUx+TOvUPqVffa6hOtda4XC7cbvcVLFXr\nl5OTQ3NmtrZYLDgcjnot0k3avqEVHA4HeXl53ud5eXk4HI6Lrj9kyBBmzZoFGE3rGzZsYN68eZSU\nlKCUwmazMWrUqHrbZGRkkJGR4X0unSZ8Szqi+J7UqX9IvfpeQ3VaVlaG1WrFYvHZOVy7YLFYmvWH\nTlVVFVlZWQQHB9d73Wcd0VJSUsjOziYnJweHw0FmZiaPP/54vXWys7Pp3LkzAFu3bvUuv/TSS951\nFixYQFBQ0HmBLYQQInA8Ho8E9hVksVioqKho/vYNrWA2m3nwwQeZOnUqHo+H4cOHk5CQwPz580lJ\nSSE9PZ0lS5awc+dOzGYzYWFhPPLII80ukBBCiCunuc20ovkup86Vbk6jvJ/56j5tXV6K/vxTVFpP\nVGpPn+yzNZImR9+TOvUPqVffa6hOS0tL690BJBqnuc3jcOE6b2zzeNseEc1sRX/6HnrdikCXRAgh\nRADNmjWL9957D4D58+dz6tSpJu/j73//u3cfF7N7926eeOKJZpWxMdp0aCurFdX3GvT2DWhPdaCL\nI4QQopmqq6sv+fxi3G43breb+fPnc8cddwDw3nvvcfr06UZ9Tl333Xdfg7cwX3311WRnZ3PixIlG\nla+p2nRoA6iBg6G4EPbvDnRRhBBCXMTChQsZM2YMN910E08//TTV1dWkpaXx61//moyMDLZs2cJ1\n113H1KlTueWWW1i8eDG7du1i7NixZGRkMHHiRAoKCgC46667eOGFF7j11lv529/+xrp16+jduzcW\ni4XFixfz1Vdf8eijj3LTTTdRVlZ23n7nzZvH6NGjycjI4Cc/+QllZWUAvPbaa8yYMQOAO+64g6lT\npzJmzBiGDh3Khg0bvN/lpptu4sMPP/RLPbX9LoO9BoLFit72Jap770CXRgghWizPv2ahjx/26T5V\nQhKmu39yyXX279/PRx99xKJFi7BarUyePJl///vflJaWMmDAAH71q195142KimLp0qWAcbvwyy+/\nzODBg3n11Vd5/fXXvXctVVVV8dlnnwEwffp0+vbtC8DYsWN55513+H//7//Rr1+/C+7X5XJx7733\nAvDKK6/wz3/+kwcffPC8crvdbj755BNWrFjB
66+/zvz58wHo168ff/7zn3n44YebVWeX0uZDWwUF\nQ68B6K1fon/w39JTUgghWpi1a9eyc+dORo8eDRijaUZHR2M2mxkzZky9dWsnnSosLOTs2bMMHjwY\ngPHjxzNp0qTz1gNjIJS0tLRLlqHu+nv37uX3v/89hYWFlJSUcOONN15wm9ry9u3bl6ysLO/rTqfz\nos3vl6vNhzYYTeT6q41w5AAkXfofTggh2quGzoj9RWvN+PHjmTx5cr3XZ8yYgdlsrvdaY3u6110v\nKCiI8vLyRq//5JNPMnv2bHr16sX8+fP58ssvL7iNzWYDjFuj6/Ykr6ioaHAk0OZq89e0AVS/a8Fk\nQm+7cMULIYQInKFDh7J48WLvrWn5+fn1zlwvJCIigsjISO+15IULFzJo0KALrpuamsqRI0e8z0ND\nQykuLr7ovouLi4mJiaGqqooPPvigid8GDh06RPfu3Zu8XWO0jzPt0HDo3sdoIr/jR9JELoQQLUi3\nbt14+umnueeee9BaY7FYmDp1aoPbvfnmmzz77LOUl5fTtWtXXn/99QuuN2LEiHojeU6YMIFnn32W\noKAgPvroo/PWf+qppxg7dixOp5MBAwZcMuAvJDMzk5EjRzZpm8Zq04Or1OX5/FP0P2Zg+vWfUXFd\nfb7/lkwGrPA9qVP/kHr1PRlcxTBx4kSmTJlCcnKyT/Z3scFVKioq+P73v8+iRYsuOjysDK7SCGrA\ndQDordJELoQQ7c3kyZPJycnx++ecOHGC5557zm/jubef0O7ghJQe6G3rA10UIYQQV1hqaupFr3n7\nUnJyMkOGDPHb/ttNaAOoAYPh2EF0rn+64gshhBD+1M5C2/grS862hRBCtEbtK7Q7dYb4RLmuLYQQ\nolVqV6ENNU3kB3ejC/MDXRQhhBCiSdpfaA8cDFqjt29oeGUhhBBtQt2pOZvqiSeeYPHixQD88pe/\nZN++feetM3/+fKZMmQLAnDlz+Ne//tX8wl5CuwttulwFnTpLE7kQQrQivpya83JMnz6dbt26XXKd\nu+++m7fffvuyP+tC2l1oK6WMJvI9O9ClTRvlRgghhH9cqak5Dxw4UG8SkuPHj3tHL3vjjTcYPXo0\nI0aM4Omnn+ZCY4/dddddfPXVV4Bxdj106FDGjBnD5s2bvesEBweTkJDAtm3bfF5P7WIY029TAwah\nl/4bvWMTatDwQBdHCCFahL9tPs3h/EtPrNFUSVFB/Hd6zCXXuZJTc6amplJZWcmxY8fo2rUrH330\nEePGjQPg/vvv58knnwTgscce4z//+Q8333zzBct8+vRppk+fzpIlSwgPD2f8+PH07n1u+ue+ffuy\nYcMGBgwY0Jxqu6h2d6YNQFI36OCQJnIhhGgB6k7NedNNN7F27VqOHTvW5Kk5aycPqbseGFNzOp1O\n7/Nx48Z5xxz/6KOPvOtmZmYyduxYRo4cSWZm5gWvXdfasmULgwcPxul0YrPZ6n0eGMPH+mN6zvZ5\npm0yGWfb65ajKypQdnugiySEEAHX0Bmxv1zpqTlvu+02Jk2axK233opSiuTkZMrLy3nuuef49NNP\n6dKlC6+99hoVFRXN/k7+mp6zfZ5pU3PrV2UlfL010EURQoh27UpPzZmYmIjZbObNN9/0niHXBrTD\n4aCkpIRPPvnkkp//ne98h/Xr1+NyuaiqqvL2Lq916NAhevToccl9NEe7PNMGoFtvCA1Hb/vSuA1M\nCCFEQFzpqTnBONt++eWXWb/eGCEzMjKSH/7wh4wcOZKOHTvSr1+/S352TEwMv/jFL7jtttuIjIyk\nV69e9d7ftGkTP//5zxv8Dk3VbqbmvBDPnD+gt63H9PrfURbrFfnMQJDpDn1P6tQ/pF59T6bmNFyp\nqTkBdu3axcyZM/nTn/50wfdlas5mUgMHQ1kJ7N0V6KIIIYTwoys1NSeAy+Xi6aef9su+22/zOEDP\n/mAPQm/9
EtXLt93yhRBCtBypqamkpqZekc8aNmyY3/bdvs+0rTZUn3T09vVoT+NG1xFCiLakBV4h\nbfMup87bdWgDMHAwFBb
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc438001160>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfMAAAENCAYAAAAbl4wiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xl8VPW9//HXmSXbTEgyE0gICQECSCCExSiIGAnEpaCW\nIl1c0Ba91aqx1xb7w5Zau9Byr1q9WuWqF2hFcamVWmvVGhFREQKyEwSCLAnZ92Qyk8zM+f7+mBCI\nLAnJhMkkn+fjwSNM5syZz/kS8p7v95zz/WpKKYUQQgghgpYh0AUIIYQQonskzIUQQoggJ2EuhBBC\nBDkJcyGEECLISZgLIYQQQU7CXAghhAhyEuZCCCFEkJMwF0IIIYKchLkQQggR5CTMhRBCiCBnCnQB\n56u4uDjQJfQZsbGxVFZWBrqMPkfa1f+kTf1P2rRn+LtdExISOrWd9MyFEEKIICdhLoQQQgQ5CXMh\nhBAiyAXdOfOvU0rhcrnQdR1N0wJdTlApKyujubm57bFSCoPBQFhYmLSlEEIEkaAPc5fLhdlsxmQK\n+kO54EwmE0ajsd33PB4PLpeL8PDwAFUlhBDifAX9MLuu6xLkfmQymdB1PdBlCCGEOA9BH+YyHOx/\n0qZCCBFcpEsrhBBCdJFSCmqqoPgYqvgYzsShMHbyBa9DwrwXUErxne98h5UrV6LrOmvXruX73//+\nee9nwYIF/OlPfyIqKuqs2/zmN79h5syZTJ8+vRsVCyFE/6KUgvrattDm+FHf1+JCcDratmuecqWE\neX/14YcfMnbsWCIjIyksLOTFF188Y5h7PJ5zXh+wevXqDt9r4cKFPPjggxLmQghxFqqxHo63hnbx\nMVTxUSg+Bo0NJzeyREJCEtqUTEgYipaQDAlDiR4+IiAz60mY+8HChQspLi6mubmZO+64g1tvvZWP\nPvqIZcuW4fV6sdlsvP766zgcDpYsWcKuXbvQNI0HHniAOXPmsHbtWm655RYAfv/733P06FGuuuoq\nMjMzmTVrFo8++ihRUVEUFBTw6aefnvH9AKZMmcK7776Lw+Hg1ltv5dJLL2Xr1q3Ex8ezcuVKwsPD\nSUxMpKamhvLy8k5PEyiEEGejdB0qS1FHv4LCQ+DxQLgFwiMg3IIWHuH7e8TJ7xEegWYyB7p0VJPj\nlLAuRB1vDe362pMbhUf4wnrSZa2hPRSGJMOA6F51fVGfCnP91RdQhYf9uk8taTiG7/3HObd5/PHH\niYmJwel0MmfOHK655hoefPBB3nzzTYYOHUpNTQ0ATz75JJGRkXz44YcA1Nb6fmC2bNnCf/3XfwHw\n85//nP379/PBBx8AsHHjRnbv3s26desYOnToGd9v9uzZ2Gy2djUdPnyYZ555hkcffZS77rqLf/3r\nX9x4440AjB8/ni1btvDNb37TT60khOgPlO6F0uOoY4fg6Fe+r4VfgbPJt4HRCEYTtJwyf8XZdmYO\naRfu7cO/9XsRrX83h/g+JHjcrX86+rsHdabn3G7wtn6v2dU+tEPDYHAS2viLT4Z2QjLE2HtVaJ9N\nnwrzQFm5ciXvvvsu4FsI5qWXXmLq1Klt4RsTEwPAJ598wrPPPtv2uujoaMAX6lar9az7nzhxYtu+\nzvR+hw8fPi3Mk5KSSEtLAyA9PZ3CwsK25+x2O2VlZV0+XiFE36c8HigpbA3uQ63BffhkUJtDIHEY\n2pQrYWgK2tAUXwiazb7Xupp8Ie90tH1VTad/D2cT6sTj2uqTf292da5QoxFMZt+HCLP59L+bTL6v\nEaFtjzWTGUJCYeDg1p72ULANRDME7w1efSrMO+pB94SNGzfyySef8PbbbxMeHs78+fMZN24chw4d\n6vQ+TtzbbTjLD1JERMQ53+/UWdxOCA0Nbfu70WjE5Tr5H6O5uZmwsLBO1yeE6NuUuwWKjvoC+9gh\n1NFDcPyIrzcLEBoOQ4ejXXG1L7iTUyA+Ee1r
k06doJlMYB3g+3Pq98+nJq/X94GgyQHultYgPiWc\nzb7QDuYA9qc+FeaB0NDQQFRUFOHh4RQUFLBt2zaam5vZtGkTx44daxtmj4mJITMzkz//+c/85je/\nAXw98ujoaEaMGMHRo0cZPnw4FouFxsbG83q/8/XVV19x3XXXdfmYhRAXllIKvN7WP54uf1WnPHYY\nDej79/oCvPgYnJgsKsIKySlos64/2eMeNPiCh6ZmNPouMrNEXtD3DVYS5t00Y8YMVq9ezZVXXklK\nSgqTJ0/Gbrfz3//939x5553ouk5sbCyvvvoqP/7xj/n5z3/OzJkzMRgM/OQnP2H27NnMmjWLzz//\nnOHDh2Oz2bjkkkuYOXMmWVlZzJo1q8P3Ox9ut5sjR44wYcIEfzaDEMKPlK7D4QOonXmonXm+sPWz\nRoDIKF9wp1/iC+2hIyA2LijOEYv2NKXUWa9P6I2Ki4vbPW5qamo3DB2MysrK+PGPf8yrr77a4+/1\n7rvvsnv3bn72s59hMpnwnBhGO0VfaNNAio2NDcitKX1Zf2hT1dIM+3aeDPD6WjAYYHQa2ogxbcPK\nbReZGY3t/q6Zvv7cub/a4wdT1eyW4PYzf/+sdvauo071zHfs2MGqVavQdZ1Zs2Yxd+7cds9XVFSw\nfPly6uvrsVqt5OTkYLfbqaio4LHHHkPXdbxeL9deey1XX3014BvqfeaZZ2hpaWHSpEn84Ac/6Lc/\nVHFxcdx88800NDQQGdmzQ0oej4e77rqrR99DCNE5qr4WtXsrasdmyN8OLS2+27bSLoYJl6KlXYxm\nOfvFsd1hGBCN1sc/IPUnHYa5ruusWLGCJUuWYLfbeeihh8jIyCAxMbFtm9WrV5OZmcmMGTPYs2cP\na9asIScnh5iYGH73u99hNptxuVz89Kc/JSMjA5vNxgsvvMBdd93FqFGj+MMf/sCOHTuYNGlSjx5s\nb3bDDTdckPe5/vrrL8j7CCHOTJUUoXZu9gX4V/tBKbDFol2ejTZxiq8n3gvuwRbBpcMwLygoID4+\nnri4OACmTZvGli1b2oV5UVERt912GwDjxo3j0Ucf9e38lNnK3G5322pcNTU1OJ1ORo8eDUBmZiZb\ntmzp12EuhOiblO6FQ/tROzb7hs/LjvueGJqCdt33fAGeNLzfjkwK/+gwzKurq7Hb7W2P7XY7Bw8e\nbLdNcnIyeXl5zJ49m7y8PJxOZ9uQcWVlJcuWLaO0tJRbb70Vm83GoUOHTttndXX1Gd8/NzeX3Nxc\nAJYtW0ZsbGy758vKymQJ1G44U9uFhoae1s6i80wmk7SfnwVbmyqXk+YdeTTnfULzFxtR9bVgMhEy\n/mJCv/k9Qi+ZjjE2LqA1BlubBotAtatfUnDBggWsXLmS9evXk5qais1ma7tnOjY2lscee4zq6moe\nffRRpk6del77zs7OJjs7u+3x1y8saG5uxniWex3FuZ3tArjm5uY+f7FRT+oPF2tdaL2tTZXu9U1s\n0tgAjgZwNKIc9VBfh/pyF+zb6ZtlLMKCNj4Dw8QpMG4y3vAImoAmgAAfT29r076i114AZ7PZqKqq\nantcVVV12mxjNpuNRYsWAeByudi8eTMWi+W0bZKSkvjyyy+56KKLOtynEEL0NKXrvlnIWgOZxgbU\nib876lu/nvq9Bl+AOx2+c91nEhuHNuMbaBMuhZFjfVeZ9wPNHp3NRY3ERpgYO0juhrnQOvwpS0lJ\noaSkhPLycmw2Gxs3buT+++9vt82Jq9gNBgNr164lKysL8IV0ZGQkISEhNDY2sn//fq677jpiYmII\nDw/nwIEDjBo1ig0bNnDttdf2zBEGgVOXQO3K1eyjRo3i4MGDlJaW8stf/pIXXnjhtG3mz5/PL3/5\nSyZMmMB3v/tdnnvuORliE/2SampEvfc31Ke50Fh/9lAG3+IgJyYusVjRBsa3e4w1Es0S6ZtoxTrA\n9z1LZL86
/13W2MK/DtSSe6iWxhbfdVHjBoXznbRYJsRH9Ku2CKQOw9xoNLJw4UKWLl2KrutkZWWR\nlJTEa6+9RkpKChkZGeT
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435c4db70>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"init_scale = 0.5 # scale for random parameter initialisation\n",
"learning_rate = 0.1 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size \n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine + softmax model\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, output_dim, param_init, param_init),\n",
" SoftmaxLayer()\n",
"])\n",
"\n",
"# Initialise a cross entropy error object\n",
"error = CrossEntropyError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"|`init_scale`| Final `error(train)` | Final `error(valid)` |\n",
"|------------|----------------------|----------------------|\n",
"| 0.01 | 2.43e-01 | 2.58e-01 |\n",
"| 0.1 | 2.43e-01 | 2.59e-01 |\n",
"| 0.5 | 2.45e-01 | 2.62e-01 |\n",
"\n",
"<span style=\"color:red\">\n",
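"The larger initialisation scale of 0.5 gives slightly slower initial learning than the smaller scales of 0.1 and 0.01 (for example, after 5 epochs `error(train)` is 3.38e-01 with `init_scale = 0.5` versus 3.11e-01 with `init_scale = 0.1`). However, the difference is small and the final errors are very similar, suggesting that for this shallow architecture training performance is not particularly sensitive to the initialisation scale.\n",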
"</span>"
]
},
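{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough guide to what `init_scale` controls, assume (as the comment in the code cells above states) that `UniformInit(-init_scale, init_scale, rng=rng)` draws each parameter independently and uniformly from $[-s, s]$, with $s$ = `init_scale`. Each initial weight then has zero mean and variance\n",
"\n",
"\\begin{equation}\n",
" \\mathrm{Var}[W_{kd}] = \\frac{(2s)^2}{12} = \\frac{s^2}{3},\n",
"\\end{equation}\n",
"\n",
"so the initial standard deviations $s / \\sqrt{3}$ for $s = 0.01$, $0.1$ and $0.5$ are roughly $0.0058$, $0.058$ and $0.29$. This only describes the spread of the initial weights. It does not on its own predict how quickly training will progress."
]
},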
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Varying learning rate\n",
"\n",
"<span style=\"color:red\">Now let's try some different values for the learning rate.</span>"
]
},
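{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, a basic gradient descent learning rule moves each parameter a step of size $\\eta$ (`learning_rate`) against the gradient of the error $E$. Assuming `GradientDescentLearningRule` implements this standard update, the parameters of the affine layer are updated on each batch as\n",
"\n",
"\\begin{equation}\n",
" \\mtx{W} \\leftarrow \\mtx{W} - \\eta \\pd{E}{\\mtx{W}},\n",
" \\qquad\n",
" \\vct{b} \\leftarrow \\vct{b} - \\eta \\pd{E}{\\vct{b}}.\n",
"\\end{equation}\n",
"\n",
"For the same gradients, halving $\\eta$ from 0.1 to 0.05 therefore halves the size of every parameter update."
]
},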
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `learning_rate = 0.05`"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 4.0s to complete\n",
" error(train)=3.41e-01, acc(train)=9.05e-01, error(valid)=3.16e-01, acc(valid)=9.12e-01\n",
"Epoch 10: 4.0s to complete\n",
" error(train)=3.10e-01, acc(train)=9.14e-01, error(valid)=2.92e-01, acc(valid)=9.18e-01\n",
"Epoch 15: 5.1s to complete\n",
" error(train)=2.97e-01, acc(train)=9.18e-01, error(valid)=2.82e-01, acc(valid)=9.21e-01\n",
"Epoch 20: 4.1s to complete\n",
" error(train)=2.88e-01, acc(train)=9.20e-01, error(valid)=2.76e-01, acc(valid)=9.23e-01\n",
"Epoch 25: 4.4s to complete\n",
" error(train)=2.83e-01, acc(train)=9.21e-01, error(valid)=2.73e-01, acc(valid)=9.24e-01\n",
"Epoch 30: 3.9s to complete\n",
" error(train)=2.77e-01, acc(train)=9.22e-01, error(valid)=2.69e-01, acc(valid)=9.24e-01\n",
"Epoch 35: 3.7s to complete\n",
" error(train)=2.74e-01, acc(train)=9.24e-01, error(valid)=2.67e-01, acc(valid)=9.25e-01\n",
"Epoch 40: 4.0s to complete\n",
" error(train)=2.72e-01, acc(train)=9.24e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 45: 3.7s to complete\n",
" error(train)=2.68e-01, acc(train)=9.26e-01, error(valid)=2.64e-01, acc(valid)=9.27e-01\n",
"Epoch 50: 4.7s to complete\n",
" error(train)=2.66e-01, acc(train)=9.26e-01, error(valid)=2.63e-01, acc(valid)=9.28e-01\n",
"Epoch 55: 3.7s to complete\n",
" error(train)=2.64e-01, acc(train)=9.26e-01, error(valid)=2.62e-01, acc(valid)=9.29e-01\n",
"Epoch 60: 4.8s to complete\n",
" error(train)=2.63e-01, acc(train)=9.26e-01, error(valid)=2.62e-01, acc(valid)=9.28e-01\n",
"Epoch 65: 3.8s to complete\n",
" error(train)=2.61e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.27e-01\n",
"Epoch 70: 4.2s to complete\n",
" error(train)=2.60e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 75: 4.3s to complete\n",
" error(train)=2.58e-01, acc(train)=9.29e-01, error(valid)=2.60e-01, acc(valid)=9.29e-01\n",
"Epoch 80: 4.5s to complete\n",
" error(train)=2.57e-01, acc(train)=9.29e-01, error(valid)=2.60e-01, acc(valid)=9.29e-01\n",
"Epoch 85: 4.2s to complete\n",
" error(train)=2.56e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 90: 4.5s to complete\n",
" error(train)=2.55e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n",
"Epoch 95: 4.0s to complete\n",
" error(train)=2.54e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n",
"Epoch 100: 3.4s to complete\n",
" error(train)=2.53e-01, acc(train)=9.30e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xl8VNX9//HXubNlX2YmC1lYsiDKImiUpbggQRC0LlW/\najerba1rtf26INa6FKUK6ret/YKKlD7qr0XlKyoqICgqohBAFhEhYYckhOz7MnPv748bhgQCCTCT\nyfJ5Ph55zHbnzpnDDe+cc889RxmGYSCEEEKILk8LdgGEEEII0TES2kIIIUQ3IaEthBBCdBMS2kII\nIUQ3IaEthBBCdBMS2kIIIUQ3IaEthBBCdBMS2kIIIUQ3IaEthBBCdBMS2kIIIUQ3YQ12AdqSn58f\n7CL0KG63m+Li4mAXo0eROg0MqVf/kzoNDH/Xa1JSUoe2k5a2EEII0U1IaAshhBDdhIS2EEII0U10\nyXPaQgghOodhGNTX16PrOkqpYBen2zh06BANDQ2n9B7DMNA0jZCQkNOuawltIYToxerr67HZbFit\nEgenwmq1YrFYTvl9Ho+H+vp6QkNDT+tzpXtcCCF6MV3XJbA7kdVqRdf1036/hLYQQvRi0iXe+c6k\nznt0aBv1degfvoWR+12wiyKEEEKcsQ6F9saNG/ntb3/Lvffey6JFi457fdmyZfz+97/nwQcf5A9/\n+AMHDhxo9XpxcTE//elPee+99/xT6o6yWDE+ehvjq08693OFEEJ0Ka+++ipvvfUWAAsWLKCwsPCU\n9/HPf/7Tt48T2bZtG/fff/9plbEj2j2Roes6c+fO5bHHHsPlcjF16lSysrJISUnxbTN27Fguv/xy\nANatW8f8+fOZNm2a7/X58+czYsSIABT/5JTNhhpyPsamtRi6jtJ6dMeCEEL0WF6vt9XAr2Mfn4jH\n4wHMoF6yZAkAb731FoMGDSIxMbHdz2npZz/7Wbufd/bZZ1NQUMDBgwdJTk5ud/tT1W6K5eXlkZiY\nSEJCAlarlTFjxpCTk9Nqm7CwMN/9+vr6Vv31a9euJT4+vlXId6rhI6GyHHbvCM7nCyGEaNfChQuZ\nMmUKEyZM4KGHHsLr9ZKZmcmTTz5JdnY269evZ+TIkUyfPp2JEyeyePFivv32W6688kqys7O5/fbb\nKS8vB+D666/n8ccf54orruC1117jyy+/ZMiQIVitVhYvXsymTZu45557mDBhAnV1dcft94033mDy\n5MlkZ2fzq1/9irq6OgBmzZrF7NmzAbj22muZPn06U6ZMYezYsaxZs8b3XSZMmMC7774bkHpqt6Vd\nWlqKy+XyPXa5XOTm5h633ZIlS/jggw/weDw8/vjjgBng7777Ln/4wx9O2jW+fPlyli9fDsCMGTNw\nu92n/EVORL9kAofnvUTIji1Ejhzrt/12J1ar1a91KqROA0Xq1f/aq9NDhw75Ro97/t8c9H27/Pr5\nWt80rLfccdJtduzYwfvvv8/ixYux2Ww8/PDDvPvuu9TW1pKVlcXTTz8NmAO4XC4XK1asAODSSy/l\nmWeeYcyYMfz5z3/mpZde4k9/+hNKKbxeLx9//DEAzz33HMOHD8dqtXLNNdcwf/58/vjHPzJ8+PA2\n91taWsrPf/5zAJ599lkWLFjAL3/5SzRNQ9M0X33pus7SpUtZvnw5L774Im+//TYA5513Hn/5y1+4\n77772vy+DofjtI9zv43znzRpEpMmTWLVqlUsXLiQe+65hzfffJMpU6YQEhJy0vdmZ2eTnZ3te+z3\nye0zB1P71UoarrjBv/vtJmTBAP+TOg0MqVf/a69OGxoafN3Buq5jGIZfP1/XdV8X9Yl89tlnbN68\n2Xeatb6+HqfTicViYdKkSb73G4bBlVdeicfjobKykoqKCi688EI8Hg8/+tGPuOOOO/B4PK22Aygs\nLCQ9Pb3Vfrxeb5v7Bdi6dSvPPfcclZWV1NTU
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435d77940>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435aaef60>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"init_scale = 0.1 # scale for random parameter initialisation\n",
"learning_rate = 0.05 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size\n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine + softmax model\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, output_dim, param_init, param_init),\n",
" SoftmaxLayer()\n",
"])\n",
"\n",
"# Initialise a cross entropy error object\n",
"error = CrossEntropyError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)"
]
},
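{
"cell_type": "markdown",
"metadata": {},
"source": [
"The runs in this section differ only in `learning_rate`. Assuming `GradientDescentLearningRule` performs plain gradient descent, as its name suggests, each parameter $\\vct{\\theta}$ of the model (for example the weights $\\mtx{W}$ and biases $\\vct{b}$ of the affine layer) is updated after every batch as\n",
"\n",
"\\begin{equation}\n",
"  \\vct{\\theta} \\leftarrow \\vct{\\theta} - \\eta \\, \\pd{E}{\\vct{\\theta}}\n",
"\\end{equation}\n",
"\n",
"where $\\eta$ is `learning_rate` and $E$ is the error evaluated on the current batch. A larger $\\eta$ takes larger steps along the negative gradient; in the logged statistics this tends to lower the training error more quickly, but the validation error can also fluctuate more between epochs (compare the `learning_rate = 0.5` run with the `learning_rate = 0.1` run)."
]
},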
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `learning_rate = 0.1`"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 5.1s to complete\n",
" error(train)=3.11e-01, acc(train)=9.13e-01, error(valid)=2.92e-01, acc(valid)=9.18e-01\n",
"Epoch 10: 3.4s to complete\n",
" error(train)=2.89e-01, acc(train)=9.20e-01, error(valid)=2.77e-01, acc(valid)=9.23e-01\n",
"Epoch 15: 3.8s to complete\n",
" error(train)=2.79e-01, acc(train)=9.22e-01, error(valid)=2.70e-01, acc(valid)=9.24e-01\n",
"Epoch 20: 3.4s to complete\n",
" error(train)=2.72e-01, acc(train)=9.24e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 25: 4.3s to complete\n",
" error(train)=2.68e-01, acc(train)=9.25e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 30: 3.9s to complete\n",
" error(train)=2.63e-01, acc(train)=9.27e-01, error(valid)=2.62e-01, acc(valid)=9.26e-01\n",
"Epoch 35: 4.5s to complete\n",
" error(train)=2.60e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 40: 5.9s to complete\n",
" error(train)=2.59e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 45: 4.0s to complete\n",
" error(train)=2.55e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n",
"Epoch 50: 4.1s to complete\n",
" error(train)=2.54e-01, acc(train)=9.30e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 55: 5.4s to complete\n",
" error(train)=2.52e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 60: 4.4s to complete\n",
" error(train)=2.52e-01, acc(train)=9.29e-01, error(valid)=2.60e-01, acc(valid)=9.29e-01\n",
"Epoch 65: 3.4s to complete\n",
" error(train)=2.50e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 70: 5.4s to complete\n",
" error(train)=2.49e-01, acc(train)=9.31e-01, error(valid)=2.59e-01, acc(valid)=9.31e-01\n",
"Epoch 75: 3.7s to complete\n",
" error(train)=2.47e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 80: 4.4s to complete\n",
" error(train)=2.46e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.31e-01\n",
"Epoch 85: 4.0s to complete\n",
" error(train)=2.45e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.31e-01\n",
"Epoch 90: 5.1s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 95: 3.6s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 100: 3.9s to complete\n",
" error(train)=2.43e-01, acc(train)=9.33e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc4685a60f0>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435b63588>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"init_scale = 0.1 # scale for random parameter initialisation\n",
"learning_rate = 0.1 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size\n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine + softmax model\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, output_dim, param_init, param_init),\n",
" SoftmaxLayer()\n",
"])\n",
"\n",
"# Initialise a cross entropy error object\n",
"error = CrossEntropyError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `learning_rate = 0.2`"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 3.5s to complete\n",
" error(train)=2.90e-01, acc(train)=9.19e-01, error(valid)=2.77e-01, acc(valid)=9.22e-01\n",
"Epoch 10: 3.8s to complete\n",
" error(train)=2.75e-01, acc(train)=9.23e-01, error(valid)=2.69e-01, acc(valid)=9.25e-01\n",
"Epoch 15: 5.3s to complete\n",
" error(train)=2.66e-01, acc(train)=9.26e-01, error(valid)=2.64e-01, acc(valid)=9.26e-01\n",
"Epoch 20: 4.3s to complete\n",
" error(train)=2.60e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 25: 4.9s to complete\n",
" error(train)=2.57e-01, acc(train)=9.28e-01, error(valid)=2.64e-01, acc(valid)=9.27e-01\n",
"Epoch 30: 5.0s to complete\n",
" error(train)=2.53e-01, acc(train)=9.29e-01, error(valid)=2.61e-01, acc(valid)=9.30e-01\n",
"Epoch 35: 4.2s to complete\n",
" error(train)=2.50e-01, acc(train)=9.30e-01, error(valid)=2.60e-01, acc(valid)=9.30e-01\n",
"Epoch 40: 4.0s to complete\n",
" error(train)=2.49e-01, acc(train)=9.31e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 45: 4.4s to complete\n",
" error(train)=2.45e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 50: 3.8s to complete\n",
" error(train)=2.45e-01, acc(train)=9.32e-01, error(valid)=2.60e-01, acc(valid)=9.31e-01\n",
"Epoch 55: 3.9s to complete\n",
" error(train)=2.43e-01, acc(train)=9.32e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n",
"Epoch 60: 3.9s to complete\n",
" error(train)=2.44e-01, acc(train)=9.31e-01, error(valid)=2.63e-01, acc(valid)=9.29e-01\n",
"Epoch 65: 3.8s to complete\n",
" error(train)=2.41e-01, acc(train)=9.34e-01, error(valid)=2.60e-01, acc(valid)=9.30e-01\n",
"Epoch 70: 4.2s to complete\n",
" error(train)=2.40e-01, acc(train)=9.34e-01, error(valid)=2.62e-01, acc(valid)=9.29e-01\n",
"Epoch 75: 3.7s to complete\n",
" error(train)=2.38e-01, acc(train)=9.34e-01, error(valid)=2.60e-01, acc(valid)=9.30e-01\n",
"Epoch 80: 4.3s to complete\n",
" error(train)=2.38e-01, acc(train)=9.33e-01, error(valid)=2.62e-01, acc(valid)=9.29e-01\n",
"Epoch 85: 3.2s to complete\n",
" error(train)=2.36e-01, acc(train)=9.35e-01, error(valid)=2.61e-01, acc(valid)=9.30e-01\n",
"Epoch 90: 4.1s to complete\n",
" error(train)=2.36e-01, acc(train)=9.34e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 95: 3.0s to complete\n",
" error(train)=2.37e-01, acc(train)=9.34e-01, error(valid)=2.63e-01, acc(valid)=9.29e-01\n",
"Epoch 100: 3.5s to complete\n",
" error(train)=2.35e-01, acc(train)=9.35e-01, error(valid)=2.63e-01, acc(valid)=9.29e-01\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435de5dd8>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc43599eba8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"init_scale = 0.1 # scale for random parameter initialisation\n",
"learning_rate = 0.2 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size\n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine + softmax model\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, output_dim, param_init, param_init),\n",
" SoftmaxLayer()\n",
"])\n",
"\n",
"# Initialise a cross entropy error object\n",
"error = CrossEntropyError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `learning_rate = 0.5`"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 3.5s to complete\n",
" error(train)=2.79e-01, acc(train)=9.20e-01, error(valid)=2.74e-01, acc(valid)=9.23e-01\n",
"Epoch 10: 3.7s to complete\n",
" error(train)=2.68e-01, acc(train)=9.24e-01, error(valid)=2.72e-01, acc(valid)=9.26e-01\n",
"Epoch 15: 3.9s to complete\n",
" error(train)=2.55e-01, acc(train)=9.28e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 20: 3.5s to complete\n",
" error(train)=2.49e-01, acc(train)=9.31e-01, error(valid)=2.61e-01, acc(valid)=9.29e-01\n",
"Epoch 25: 4.3s to complete\n",
" error(train)=2.52e-01, acc(train)=9.29e-01, error(valid)=2.73e-01, acc(valid)=9.25e-01\n",
"Epoch 30: 3.5s to complete\n",
" error(train)=2.47e-01, acc(train)=9.31e-01, error(valid)=2.70e-01, acc(valid)=9.27e-01\n",
"Epoch 35: 3.2s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.69e-01, acc(valid)=9.27e-01\n",
"Epoch 40: 4.1s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.72e-01, acc(valid)=9.26e-01\n",
"Epoch 45: 4.3s to complete\n",
" error(train)=2.36e-01, acc(train)=9.35e-01, error(valid)=2.66e-01, acc(valid)=9.29e-01\n",
"Epoch 50: 3.7s to complete\n",
" error(train)=2.38e-01, acc(train)=9.33e-01, error(valid)=2.69e-01, acc(valid)=9.28e-01\n",
"Epoch 55: 3.9s to complete\n",
" error(train)=2.36e-01, acc(train)=9.34e-01, error(valid)=2.68e-01, acc(valid)=9.26e-01\n",
"Epoch 60: 4.0s to complete\n",
" error(train)=2.46e-01, acc(train)=9.29e-01, error(valid)=2.81e-01, acc(valid)=9.22e-01\n",
"Epoch 65: 4.1s to complete\n",
" error(train)=2.33e-01, acc(train)=9.35e-01, error(valid)=2.70e-01, acc(valid)=9.28e-01\n",
"Epoch 70: 3.6s to complete\n",
" error(train)=2.35e-01, acc(train)=9.36e-01, error(valid)=2.75e-01, acc(valid)=9.27e-01\n",
"Epoch 75: 4.4s to complete\n",
" error(train)=2.31e-01, acc(train)=9.36e-01, error(valid)=2.70e-01, acc(valid)=9.26e-01\n",
"Epoch 80: 3.6s to complete\n",
" error(train)=2.35e-01, acc(train)=9.34e-01, error(valid)=2.76e-01, acc(valid)=9.25e-01\n",
"Epoch 85: 4.0s to complete\n",
" error(train)=2.32e-01, acc(train)=9.35e-01, error(valid)=2.75e-01, acc(valid)=9.26e-01\n",
"Epoch 90: 3.6s to complete\n",
" error(train)=2.29e-01, acc(train)=9.37e-01, error(valid)=2.74e-01, acc(valid)=9.26e-01\n",
"Epoch 95: 4.2s to complete\n",
" error(train)=2.31e-01, acc(train)=9.35e-01, error(valid)=2.76e-01, acc(valid)=9.27e-01\n",
"Epoch 100: 3.4s to complete\n",
" error(train)=2.31e-01, acc(train)=9.36e-01, error(valid)=2.77e-01, acc(valid)=9.27e-01\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xl4VOXZx/Hvc2ay78lkIRtL2Hcw7MgaUFRU1KqtS6u2\nVq1LbV2xVtxxw9pa9dWiVVGLFlE2FcIuoKIIiEDYwhKSELKTPZnzvH8MpkSWbDOZGXJ/rquXJHPm\nnHueTuY355xnUVprjRBCCCE8nuHuAoQQQgjRNBLaQgghhJeQ0BZCCCG8hIS2EEII4SUktIUQQggv\nIaEthBBCeAkJbSGEEMJLSGgLIYQQXkJCWwghhPASEtpCCCGEl7C6u4BTyc7OdncJZxWbzUZ+fr67\nyzirSJu6hrSr80mbuoaz2zU+Pr5J28mZthBCCOElJLSFEEIILyGhLYQQQngJCW0hhBDCS0hoCyGE\nEF5CQlsIIYTwEhLaQgghhJeQ0BZCuJ3WmuV7iykor3F3KUJ4NAltIYTbbc+r5O9f5fLSmn3uLkUI\nj9akGdE2b97MW2+9hWmaTJw4kUsvvbTB44sWLWL58uVYLBZCQ0O59dZbiY6OBmDOnDls2rQJrTX9\n+vXjhhtuQCnl/FcihPBaCzOKAFi+K5+LuwbTKcLfzRUJ4ZkaPdM2TZPZs2czffp0XnzxRdatW0dW\nVlaDbTp16sTMmTN5/vnnGT58OHPmzAEgIyODjIwMnn/+eV544QX27t3L9u3bXfNKhBBe6Wh5LV9n\nHWNy1zCCfC385weZclOI02k0tPfs2UNcXByxsbFYrVZGjhzJxo0bG2zTt29f/Pz8AOjWrRuFhYUA\nKKWoqamhrq6O2tpa7HY7YWFhLngZQghvtWSX4yz7yr42rh6UwIZDZewrrHJzVUJ4pkZDu7CwkKio\nqPqfo6Ki6kP5VFasWMHAgQMB6N69O3369OHmm2/m5ptvZsCAASQmJjqhbCHE2aC6zmTpnmKGJYYQ\nHeTDlYPiCfI1+EDOtoU4Jaeu8rVmzRr27dvHjBkzAMjNzeXw4cO89tprADz++OPs2LGDXr16NXhe\neno66enpAMycORObzebMsto9q9Uqbepk0qbOsWBbLmU1JtcO64TNFobVauWac5J4fcMBjtr96BUb\n4u4SvZ68V13DXe3aaGhHRkZSUFBQ/3NBQQGRkZEnbbd161bmz5/PjBkz8PHxAeCbb76hW7du+Ps7\nOpUMGjSIXbt2nRTaaWlppKWl1f8sy8g5lyzN53zSpq2nteY/3x6ic4QfCb415OfnY7PZGJ/kywff\nGby6Zg9/HZ/k7jK9nrxXXcNjl+ZMSUkhJyeHvLw86urqWL9+PampqQ22yczM5I033uC+++5rcM/a\nZrOxY8cO7HY7dXV1bN++nYSEhGa+FCHE2eiHIxUcKKnmoh4RDUaUBPpYmNY7iu+yy9l5tNKNFQrh\neRo907ZYLNx44408+eSTmKbJ+PHjSUpKYu7cuaSkpJCamsqcOXOoqqpi1qxZgCOs77//foYPH862\nbdu45557ABg4cOBJge9KuroavWw+ql8qqmPXNjuuEKJxizKKCPWzMKZT6EmPXdA9gk93FPLB1qM8\nOjHZDdUJ4ZmU1lq7u4ify87Odsp+dEU55kM3Q1IXjLsfa7fjw+XymPNJm7bOkbIafv/pPi7vE8V1\nA6Prf39iu36yo4C3Nh3lqUnJ9IkJdFepXk/eq67hsZfHvZkKDEJdeBXs2ALbN7u7HCHEcUt2FaMU\nXNA9/LTbTOkWQYS/hQ+2SuAI8ZOzOrQB1NgpEBWDOe/faNN0dzlCtHuVtSbL9hQzMjmEqECf027n\nZzW4vE8UPxypYGtueRtWKITnOvtD28cHNe06OJSJ/maNu8sRot1blVlCea3JRT0iGt32vG7hRAZY\n+WBrPh54J0+INnfWhzaAGnIuJHdBfzIHXVvr
7nKEaLe01izKKKJrpD89bQGNbu9rMfhF3yi2H61k\nS25FG1QoROPspsZuuudLZPsIbcPAuPzXUJCHXr3E3eUI0W5tya0gq7TmpGFeZzIpJQxboJX3tx6V\ns23hNoWVdSzfW8xzXx7m+nm72Zpd6pY6nDojmidTvQdB74HoxR+iR6ahAoPcXZIQ7c7CnYWE+VsY\n3bHpM535WAyu7GvjlW9y2ZRdzjkJwS6sUAiHWrtmZ34Fm7LL+T6nnMyiagAi/C0MTQwh2M8CtP2V\n23YT2gDGZb/GfOJu9BcfO+5zCyHaTHZpDd9ml3NVvyh8LM27yDehSxj//bGA97fmMzg+qN0O3xSu\ndaSspj6kt+RWUFVnYlHQKyaQ6wdGMzg+iE7hfiilsNmCyc9v+4Vt2lVoq44pqKFj0emfosddgIqI\navxJQginWLKrCKsB53drvAPaz/lYFFf1i+IfX+Wy8XAZQxNlTnLRetV1Jj/mOc6mN+WUc7i0BoCY\nICvjOocyuEMQ/eICCfSxuLnS/2lXoQ2gLr0G/d069MIPUNff7u5yhGgXKmrtpO8tYVRyKJEBLfvY\nGd85jI+2Oc62hyQEy9m2aDatNYdLa9iUU86m7HJ+zKugxq7xtSj6xgQypVs4g+KDSAjx9dj3V/sL\n7eg41Lgp6BWL0ZMuQXWQBQmEcLUV+0qorGvaMK/TsRiKq/vZ+NuGHL7KKmNEkpxtN0WdvX3PT1FR\na2dr7k/3psvIK68DIDHUl/O6hTO4QxB9YgLxs3pHv+x2F9oA6sIr0evSMT9+F8sfpru7HCHOaqbW\nLM4oonuUP92bMMzrTMZ0CuWjHwv4YEs+wxKDMTz0bMgT1NhNnllzmO+ydxIT7ENiqC9JYX4khvqS\nGOZLUqjf8c5UZ5/qOpOvs8pYlVnC5pxy7Br8rQYD4gK5vE8QgzoEERvs6+4yW6R9hnZIGOr8yx3j\ntvdsR3Xt7e6ShDhrfZ9dTvaxWv48KrrxjRvx09n2C+uyWX/wGKM7nrzYiIBau8nMNYfZlF3OJf3i\nyC8tJ6ukhq25FdSeML44zN9CUqgvicfDPCnMj8QwX6ICrB57efh0TK3ZdqSClZmlbDh4jMo6E1ug\nlYt7RpKaEEwPWwA+Fu96TafSLkMbQKVdjF65BHPe2xj3zfS6N6gQ3mJRRhERAVanXc4elRzCh9t8\n+WBrPiOSQrAY8rd7olq75pm12XyXXc4fhsXxq+Fd6xe2sJuavPJaskpqOFRaTVZJDVmlNaw9UEp5\nzf8uo/tbjQZn5Ilhjn/HBfti9bD2PlhSzap9JazeX0p+RR0BVoORySGM6xxK39jAs+5qTPsNbT9/\n1MVXo999BbZ8DQOHu7skIc46WSXVbMop51f9bU47y7EYil/2s/Hsl9l8eaCUsZ3DnLLf5tBae+QX\n/TpT8/y6w2w8XMYtQ2KZ3LXhgiwWQ9EhxJcOIb4M4X/j3bXWFFfZOVRSTVZpDVnH//tDbgWrMv83\niYjVgLhgxxl51yh/ukf50zXKv817VxdX1rHmQCmrMkvYW1iNoWBQhyB+PSiGYYnBXnN/uiXabWgD\nqFGT0Ms+xZz3Dka/ISjL2Xl/Rwh3WbyrCKuhOK/r6VfzaokRySF0CvfjPz8UMLpjaJudbZta8+G2\nAhZnFHHLkFhGedDlebupmbUum68OlfG71BimdG96pz+lFBEBViICrPSPazjxVEWtvf6M/FBJNYdL\na8gsqmLDoWOO5wKJYb50iwqo77fQMdzP6WfkJ96n/j6nHFNDSqQfN50Tw5iOoYS3cFSCtznrX+WZ\nvhEriwVj2vWYrz6NXr8cde7kNq5OiLNXeY2dFftKGNMpxOkfqIZS/LK/jafXHGb1/lImdHH92XZV\nnclLG3JYf/AYEQFWnv0ym1+W1HBlvyi3X4K1m5q/bchh3cFj3Dg4hot6RDpt34E+FrrbAk7qRFha\nbWdPQSW7
CqrYnV/Jt4fLWLGvBABfi6JzhH99iHeL8icu2KfZVyd+uk+9KrOU9cfvU0cFWpnWK5Jx\nXcJIDvNz2uv0Fmd1aNt
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc4685a61d0>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfMAAAENCAYAAAAbl4wiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3WdgVMXawPH/nE0lfTeQUEILCIGAlAgRIRIIRZqAXDsW\n0NfCRUWwoNgBUQL2ckVAilywgAVUNAKCIIQWihGkQ4CQ3vueeT+sRHMJpO1mN8n8vmjIKc9ONnn2\nTHlGSCkliqIoiqLUWZq9A1AURVEUpWZUMlcURVGUOk4lc0VRFEWp41QyVxRFUZQ6TiVzRVEURanj\nVDJXFEVRlDpOJXNFURRFqeNUMlcURVGUOk4lc0VRFEWp41QyVxRFUZQ6zsneAVTVuXPn7B1CveHv\n709KSoq9w6h3VLtan2pT61NtahvWbtdmzZpV6jj1ZK4oiqIodZxK5oqiKIpSx6lkriiKoih1XJ0b\nM/9fUkoKCgrQdR0hhL3DqVMuXLhAYWFh6ddSSjRNw83NTbWloihKHVLnk3lBQQHOzs44OdX5l1Lr\nnJycMBgMZf6tpKSEgoIC3N3d7RSVoiiKUlV1vptd13WVyK3IyckJXdftHYaiKIpSBXU+mavuYOtT\nbaooilK31PlkriiKotRfMvEs+rYNSNVjeEUqmTsAKSX/+te/yM7OJjMzk08++aRa1xk/fjyZmZlX\nPObll1/m119/rdb1FUVRapNMTkSfOx25+E30d2cic3PsHZLDUsncAfz888906tQJLy8vsrKyWLp0\nabnHlZSUXPE6y5Ytw8fH54rHTJgwgffee6/asSqKotQGmZmO/sbzYDYjRt4K8XHorzyGPHXM3qE5\nJJXMrWDChAkMHTqUyMhIli9fDsDGjRsZMmQIUVFR3HzzzQDk5uYyZcoUBg4cSFRUFOvWrQNgzZo1\nDBkyBIDZs2dz6tQpBg0axCuvvMK2bdsYM2YM99xzD/3797/s/QB69+5NWloaZ86c4frrr+eJJ54g\nMjKS2267jfz8fABatGhBeno6SUlJtdU8iqIoVSLzctHfehEy09EmP4c26na0J18FXUef8yT6lh/t\nHaLDqdQ08Li4OBYvXoyu6wwcOJDRo0eX+X5ycjIffPABWVlZeHp6MnnyZEwmE8nJyURHR6PrOmaz\nmaFDhzJ48GDA8pS5cOFC4uPjEUJw6623Eh4eXqMXo69cgDxzokbX+F8iqA3arfdf8Zh58+bh5+dH\nfn4+w4cPZ8iQITzxxBOsXr2ali1bkp6eDsCbb76Jl5cXP//8MwAZGRkA7Ny5k9deew2AZ555hsOH\nD/PTTz8BsG3bNg4cOMCGDRto2bJlufcbNmwYRqOxTEwnTpzgvffeY+7cuTzwwAN899133HTTTQB0\n6dKFnTt3cuONN1qplRRFsRVZUmJ5OnV1tXcotUIWF6G/NwvOnUb79wxEcEcARNsOaM+9gb4gGrn0\nXfTjhxG3/R/CpWG0S0UqTOa6rrNw4UJmzJiByWRi+vTphIWF0aJFi9Jjli1bRkREBP379+fgwYOs\nWLGCyZMn4+fnx8yZM3F2dqagoICpU6cSFhaG0Whk9erV+Pj48NZbb6HrOjk5dXcsZNGiRXz//feA\nZSOY5cuXEx4eXpp8/fz8ANiyZQvvv/9+6Xm+vr6AJal7enpe9vrdunUrvVZ59ztx4sQlyTwoKIjQ\n0FAAunbtypkzZ0q/ZzKZuHDhQrVfr6IotifPn0H++hPyt43g7Iz23JsIT297h2VT0mxG/yga/jyI\nuG8qIrRnme8LLx+0x15Efv1f5HefIU8fQ3vwaUTjQDtF7DgqTOZHjx4lMDCQgIAAAPr06cPOnTvL\nJPOEhATuuusuADp37szcuXMtF//H+u/i4uIy65c3btzIG2+8AYCmaXh71/xNWtETtC1s27aNLVu2\n8O233+Lu7s64cePo3Lkzx45Vflzn4tpuTSt/
1KNRo0ZXvN8/q7hd5PqPT/EGg4GCgoLSrwsLC3Fz\nc6t0fIqi1A5ZkI/cuQW5NQaOHQKDAUJ7wsE96EveRXt4er1dOiqlRH76AcRtR9x6P1rv68s9TmgG\nxJg7kW07oC+ajz5zCtrExxFdr6nliMuSUsL+XRT6+UHLdrV+/wrHzNPS0jCZTKVfm0wm0tLSyhzT\nqlUrYmNjAYiNjSU/P5/s7GwAUlJSmDZtGg899BA33ngjRqOR3NxcAFatWsVTTz3F/PnzS7uc65rs\n7Gx8fHxwd3fn6NGj7Nmzh8LCQrZv387p06cBSrvZIyIiysxUv/ia27Zty6lTpwDw8PC4Yi9Fefer\nquPHj9OhQ4cqn6coivVJKZHHDqEveQd92j3Ipe9Cbg7iX/eivb4Yw79nIMaOh7jtyHo8Viy/Wo7c\n8iNi2M1oA0dWeLy4+hq0GW+AqQn6O6+gf/0pUjfXQqRlSV1H7tmGPnMK+ruvkLd2Va3HAFYq5zp+\n/HgWLVrEpk2bCAkJwWg0lj5l+vv7Ex0dTVpaGnPnziU8PBxN00hNTaVDhw7cfffdrF27lmXLljF5\n8uRLrh0TE0NMTAwAc+bMwd/fv8z3L1y4YNcKcFFRUSxfvpz+/fsTHBxMz549adKkCfPmzeP+++9H\n13X8/f35/PPPmTp1Kk8//TQDBgzAYDAwbdo0hg8fzuDBg9mxYwft27enSZMm9OrViwEDBpROlBNC\nlL7G8u5nMBhwcnJCCIHBYCgt0XrxHE3T0DQNJycniouLOXXqFD179ixzzD+5urpe0s5K5Tk5Oan2\ns7L62KZ6Zjr5m34gP+ZbzAknEW7uuPUdiHvUSJw7hJZ5Ape3TiTjz4MUfbYQ3959cWreqsb3d6Q2\nzf12FTnffY77oFF43fdo5Xsf/P2RcxeR9VE0BWtX4XzmOD6Pv4Tm7WvbgLEMCRT+tpHczz+h5PRx\nDE1b4DH5WTwHDKf2P1KAkFLKKx3w559/8vnnn/Pss88ClpnXAGPGjCn3+IKCAh577DE+/PDDS773\n/vvv06NHD3r37s1dd93FkiVL0DSNlJQUZs+ezfz58ysM+Ny5c2W+zsvLK9MNXRdduHCBRx99lJUr\nV9r8Xt9//z0HDhzgySefxMnJqdzlbvWhTe3J39+flJQUe4dRr9SXNpW6GX6PQ//1J9gXC+YSaNsB\n0XcQ4pq+CLfL/97JjFT0lx4BYxO06a8jnJxrFIujtKm+fRNy4XzocS3aA08iNEPFJ5V3nS0/Ilf8\nB7x9LOPoba6ycqQW0my2DIWs+wwSEyCwBWL4zYhr+iEMBqu3a7NmzSp1XIWPtMHBwZw/f56kpCSM\nRiPbtm3jkUceKXPMxVnsmqaxZs0aIiMjAUhNTcXLywsXFxdycnI4fPgwI0aMQAhBz549iY+PJzQ0\nlIMHD5YZg29oAgICuP3228nOzsbLy8um9yopKeGBBx6w6T0URSlLJicit/2M3PozpKeApzdiwHBL\nEm/WsuILAMLXhHb3ZPT3ZiO/+hQx7h7bBl0L5IHdyE/egg5d0O6bWu1EDqD1G4xs2Rb9gznorz2N\nuPU+xPU3WG2OgSwpQe7YhPzuc0g6D81bIf7vSUTPa2sUt7VU+GQOsGfPHpYsWYKu60RGRjJ27FhW\nrVpFcHAwYWFhbN++nRUrViCEICQkhIkTJ+Ls7Mz+/ftZunQpQgiklAwdOpSoqCjAspzt3XffJTc3\nF29vbx5++OFKdfnUxydze1FP5rbhKE889UldbFNZXITc85tlMtsf+0AI6Nwdre8guLpXtZ+s9WXv\nI7esR5vyMiLk6mrHZ+82lccOoc+fAYEt0KbNRrhb52+OzM1GX/gGHNiFCO+PuHNSjZb1yZJiywex\n776A1CRo2RZt+C3QrTeinEnL9noyr1QydyQqmVuPSua2Ye8/kvVRXWpTeeaEZUnZ9k2QlwOmJoi+\nUYg+AxHG
xjW/fmEB+swpUJCP9sLb1V6uZs82lWdPo7/+NHh4oj39GsLbz7rX13XL0rVv/gvNWqI9\nNB0RULmkWHqN4iLLz/G
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435c37438>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"init_scale = 0.1 # scale for random parameter initialisation\n",
"learning_rate = 0.5 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size \n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine + softmax model\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, output_dim, param_init, param_init),\n",
" SoftmaxLayer()\n",
"])\n",
"\n",
"# Initialise a cross entropy error object\n",
"error = CrossEntropyError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"|`learning_rate`| Final `error(train)` | Final `error(valid)` |\n",
"|---------------|----------------------|----------------------|\n",
"| 0.05 | $2.53\\times 10^{-1}$ | $2.59\\times 10^{-1}$|\n",
"| 0.1 | $2.43\\times 10^{-1}$ | $2.59\\times 10^{-1}$|\n",
"| 0.2 | $2.35\\times 10^{-1}$ | $2.63\\times 10^{-1}$|\n",
"| 0.5 | $2.31\\times 10^{-1}$ | $2.77\\times 10^{-1}$|\n",
"\n",
"<span style=\"color:red\">\n",
"Increasing the learning rate, as would be expected, increases the speed of learning: over the learning rates tested, the final training error decreases monotonically as the learning rate increases. Note however that the final validation set error increases for the larger learning rates. This suggests the model is overfitting to the training data, with larger learning rates causing the model to begin overfitting sooner; in these cases we could have afforded to halt learning earlier, once there was no further improvement in the validation set error. Notice also that the error curves for the largest learning rate are much noisier, suggesting that learning becomes quite unstable with this large a step size, with many of the gradient descent steps overshooting and causing the error function value to increase.\n",
"</span>"
]
},
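{
"cell_type": "markdown",
"metadata": {},
"source": [
"The validation error curves above suggest halting training once the validation error stops improving, a strategy usually called early stopping. As a minimal sketch, the hypothetical helper `early_stopping_epoch` below scans validation errors recorded every `stats_interval` epochs and returns the epoch at which a simple patience rule would stop training, together with the epoch that had the lowest validation error. The values are the `error(valid)` figures printed during the `learning_rate = 0.5` run above; the patience of three recordings is an arbitrary choice for illustration.\n",
"\n",
"```python\n",
"def early_stopping_epoch(valid_errors, stats_interval, patience):\n",
"    # Returns (stop_epoch, best_epoch) for a simple patience-based rule:\n",
"    # halt once `patience` consecutive recordings fail to improve on the best error\n",
"    best_error = float('inf')\n",
"    best_epoch = None\n",
"    bad_recordings = 0\n",
"    for i, err in enumerate(valid_errors):\n",
"        epoch = (i + 1) * stats_interval\n",
"        if err < best_error:\n",
"            best_error, best_epoch, bad_recordings = err, epoch, 0\n",
"        else:\n",
"            bad_recordings += 1\n",
"            if bad_recordings >= patience:\n",
"                return epoch, best_epoch\n",
"    # Rule never triggered: training runs to the final recorded epoch\n",
"    return len(valid_errors) * stats_interval, best_epoch\n",
"\n",
"# error(valid) recorded every 5 epochs in the learning_rate = 0.5 run above\n",
"valid_errors = [0.274, 0.272, 0.266, 0.261, 0.273, 0.270, 0.269, 0.272, 0.266, 0.269,\n",
"                0.268, 0.281, 0.270, 0.275, 0.270, 0.276, 0.275, 0.274, 0.276, 0.277]\n",
"print(early_stopping_epoch(valid_errors, stats_interval=5, patience=3))  # (35, 20)\n",
"```"
]
},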
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Optional extra: more efficient softmax gradient evaluation\n",
"\n",
"In the lectures you were shown that, for certain combinations of error function and final output layer, the expressions for the gradients take particularly simple forms.\n",
"\n",
"In particular it can be shown that the combinations of \n",
"\n",
" * logistic sigmoid output layer and binary cross entropy error function\n",
" * softmax output layer and cross entropy error function\n",
" \n",
"lead to particularly simple forms for the gradients of the error function with respect to the inputs to the final layer. Specifically, for the latter (softmax and cross entropy error) case we have that\n",
"\n",
"\\begin{equation}\n",
" y^{(b)}_k = \\textrm{Softmax}_k\\lpa\\vct{x}^{(b)}\\rpa = \\frac{\\exp(x^{(b)}_k)}{\\sum_{d=1}^D \\lbr \\exp(x^{(b)}_d) \\rbr}\n",
" \\qquad\n",
" E^{(b)} = \\textrm{CrossEntropy}\\lpa\\vct{y}^{(b)},\\,\\vct{t}^{(b)}\\rpa = -\\sum_{d=1}^D \\lbr t^{(b)}_d \\log(y^{(b)}_d) \\rbr\n",
"\\end{equation}\n",
"\n",
"and it can be shown (this is an instructive mathematical exercise if you want a challenge!) that\n",
"\n",
"\\begin{equation}\n",
" \\pd{E^{(b)}}{x^{(b)}_d} = y^{(b)}_d - t^{(b)}_d.\n",
"\\end{equation}\n",
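"\n",
"One route to this result is sketched below for anyone wanting to check their working. Dropping the batch superscript $(b)$ for readability and assuming the targets sum to one, $\\sum_{k=1}^D t_k = 1$ (as one-of-K encoded targets do), the partial derivatives of the softmax outputs with respect to its inputs are\n",
"\n",
"\\begin{equation}\n",
"  \\pd{y_k}{x_d} = y_k \\lpa \\delta_{kd} - y_d \\rpa,\n",
"\\end{equation}\n",
"\n",
"where $\\delta_{kd}$ is $1$ if $k = d$ and $0$ otherwise. Applying the chain rule to the cross entropy error then gives\n",
"\n",
"\\begin{equation}\n",
"  \\pd{E}{x_d} = -\\sum_{k=1}^D \\frac{t_k}{y_k} \\pd{y_k}{x_d} = -\\sum_{k=1}^D t_k \\lpa \\delta_{kd} - y_d \\rpa = y_d \\sum_{k=1}^D t_k - t_d = y_d - t_d.\n",
"\\end{equation}\n",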
"\n",
"The combination of `CrossEntropyError` and `SoftmaxLayer` used to train the model above calculates this gradient less directly, by first calculating the gradient of the error with respect to the model outputs in `CrossEntropyError.grad` and then back-propagating this gradient to the inputs of the softmax layer using `SoftmaxLayer.bprop`.\n",
"\n",
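"To see that the two routes agree, the following NumPy sketch computes the gradient both ways for a small random batch. The functions `softmax`, `cross_entropy_grad` and `softmax_bprop` are illustrative stand-ins written for this comparison, not the framework's own `CrossEntropyError.grad` and `SoftmaxLayer.bprop` methods, and the sketch assumes the error is averaged over the batch (hence the division by the batch size) and that the targets are one-of-K encoded.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def softmax(x):\n",
"    # Subtract the row-wise maximum for numerical stability; this does not change the result\n",
"    exp_x = np.exp(x - x.max(axis=-1, keepdims=True))\n",
"    return exp_x / exp_x.sum(axis=-1, keepdims=True)\n",
"\n",
"def cross_entropy_grad(outputs, targets):\n",
"    # Gradient of the batch-averaged cross entropy with respect to the softmax outputs\n",
"    return -(targets / outputs) / outputs.shape[0]\n",
"\n",
"def softmax_bprop(outputs, grads_wrt_outputs):\n",
"    # Back-propagate through softmax using dy_k/dx_d = y_k (delta_kd - y_d)\n",
"    weighted_sum = (grads_wrt_outputs * outputs).sum(axis=-1, keepdims=True)\n",
"    return outputs * (grads_wrt_outputs - weighted_sum)\n",
"\n",
"rng = np.random.RandomState(0)\n",
"inputs = rng.randn(4, 10)                       # batch of 4 pre-softmax activations\n",
"targets = np.eye(10)[rng.randint(10, size=4)]   # one-of-K encoded targets\n",
"\n",
"y = softmax(inputs)\n",
"two_step = softmax_bprop(y, cross_entropy_grad(y, targets))  # grad then bprop\n",
"direct = (y - targets) / inputs.shape[0]                     # simplified expression\n",
"print(np.allclose(two_step, direct))  # True\n",
"```\n",
"\n",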
"Rather than computing the gradient in two steps like this, we can instead wrap the softmax transformation into the definition of the error function and make use of the simpler gradient expression above. More explicitly, we define an error function as follows\n",
"\n",
"\\begin{equation}\n",
" E^{(b)} = \\textrm{CrossEntropySoftmax}\\lpa\\vct{y}^{(b)},\\,\\vct{t}^{(b)}\\rpa = -\\sum_{d=1}^D \\lbr t^{(b)}_d \\log\\lsb\\textrm{Softmax}_d\\lpa \\vct{y}^{(b)}\\rpa\\rsb\\rbr\n",
"\\end{equation}\n",
"\n",
"with corresponding gradient\n",
"\n",
"\\begin{equation}\n",
" \\pd{E^{(b)}}{y^{(b)}_d} = \\textrm{Softmax}_d\\lpa \\vct{y}^{(b)}\\rpa - t^{(b)}_d.\n",
"\\end{equation}\n",
"\n",
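"As a further sanity check on this gradient expression, the sketch below implements the combined error for a batch in NumPy and compares its analytic gradient with central finite differences. The functions `log_softmax`, `cross_entropy_softmax` and `cross_entropy_softmax_grad` are hypothetical helpers written for illustration (not the provided `CrossEntropySoftmaxError` class), and the error is assumed to be averaged over the batch.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def log_softmax(y):\n",
"    # Stable log-softmax: log Softmax_d(y) = y_d - max(y) - log sum_k exp(y_k - max(y))\n",
"    shifted = y - y.max(axis=-1, keepdims=True)\n",
"    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))\n",
"\n",
"def cross_entropy_softmax(y, t):\n",
"    # Batch-averaged error E = -(1/B) sum_b sum_d t_d log Softmax_d(y)\n",
"    return -np.mean(np.sum(t * log_softmax(y), axis=-1))\n",
"\n",
"def cross_entropy_softmax_grad(y, t):\n",
"    # Analytic gradient (Softmax(y) - t) / B, the 1/B coming from batch averaging\n",
"    return (np.exp(log_softmax(y)) - t) / y.shape[0]\n",
"\n",
"rng = np.random.RandomState(1)\n",
"y = rng.randn(3, 5)\n",
"t = np.eye(5)[rng.randint(5, size=3)]\n",
"\n",
"eps = 1e-6\n",
"numeric = np.zeros_like(y)\n",
"for idx in np.ndindex(*y.shape):\n",
"    step = np.zeros_like(y)\n",
"    step[idx] = eps\n",
"    # Central difference approximation to dE/dy at this entry\n",
"    numeric[idx] = (cross_entropy_softmax(y + step, t) -\n",
"                    cross_entropy_softmax(y - step, t)) / (2 * eps)\n",
"\n",
"# Largest absolute discrepancy; should be tiny (finite difference error only)\n",
"print(np.max(np.abs(numeric - cross_entropy_softmax_grad(y, t))))\n",
"```\n",
"\n",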
"The final layer of the model will then be an affine transformation which produces unbounded output values corresponding to the logarithms of the unnormalised predicted class probabilities. An implementation of this error function is provided in `CrossEntropySoftmaxError`. The cell below sets up a model with a single affine transformation layer and trains it on MNIST using this new error function. If you run it with hyperparameters equivalent to one of your runs with the alternative formulation above, you should get identical error and accuracy curves (up to floating point error) but with a minor improvement in training speed.\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 0.4s to complete\n",
" error(train)=3.11e-01, acc(train)=9.13e-01, error(valid)=2.92e-01, acc(valid)=9.18e-01\n",
"Epoch 10: 0.4s to complete\n",
" error(train)=2.89e-01, acc(train)=9.20e-01, error(valid)=2.77e-01, acc(valid)=9.23e-01\n",
"Epoch 15: 0.4s to complete\n",
" error(train)=2.79e-01, acc(train)=9.22e-01, error(valid)=2.70e-01, acc(valid)=9.24e-01\n",
"Epoch 20: 0.4s to complete\n",
" error(train)=2.72e-01, acc(train)=9.24e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 25: 0.4s to complete\n",
" error(train)=2.68e-01, acc(train)=9.25e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 30: 0.4s to complete\n",
" error(train)=2.63e-01, acc(train)=9.27e-01, error(valid)=2.62e-01, acc(valid)=9.26e-01\n",
"Epoch 35: 0.4s to complete\n",
" error(train)=2.60e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 40: 0.4s to complete\n",
" error(train)=2.59e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 45: 0.4s to complete\n",
" error(train)=2.55e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n",
"Epoch 50: 0.4s to complete\n",
" error(train)=2.54e-01, acc(train)=9.30e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 55: 0.4s to complete\n",
" error(train)=2.52e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 60: 0.4s to complete\n",
" error(train)=2.52e-01, acc(train)=9.29e-01, error(valid)=2.60e-01, acc(valid)=9.29e-01\n",
"Epoch 65: 0.4s to complete\n",
" error(train)=2.50e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 70: 0.4s to complete\n",
" error(train)=2.49e-01, acc(train)=9.31e-01, error(valid)=2.59e-01, acc(valid)=9.31e-01\n",
"Epoch 75: 0.4s to complete\n",
" error(train)=2.47e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 80: 0.4s to complete\n",
" error(train)=2.46e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.31e-01\n",
"Epoch 85: 0.4s to complete\n",
" error(train)=2.45e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.31e-01\n",
"Epoch 90: 0.4s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 95: 0.4s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 100: 0.4s to complete\n",
" error(train)=2.43e-01, acc(train)=9.33e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xl8VNX9+P/XmSWTBbLMTEgCJEBCWAMBjYiICxAVFRX3\nWtt+FPupP6xL9dMPCFpL64eWVkT7betHpZRq+bRNK+ICCjRgRYnYKCD7EhL2QMhMFrJNMnPv748b\nBgKBBLhDtvfz8chj5s6cuffcN0PeOeeee47SdV1HCCGEEO2epa0rIIQQQojWkaQthBBCdBCStIUQ\nQogOQpK2EEII0UFI0hZCCCE6CEnaQgghRAchSVsIIYToICRpCyGEEB2EJG0hhBCig5CkLYQQQnQQ\ntrauQHMOHz7c1lXoVNxuN6WlpW1djU5FYhoaElfzSUxDw+y49uzZs1XlpKUthBBCdBCStIUQQogO\nQpK2EEII0UG0y2vaQgghLg1d16mrq0PTNJRSbV2dDuPo0aP4fL7z+oyu61gsFsLDwy841pK0hRCi\nC6urq8Nut2OzSTo4HzabDavVet6f8/v91NXVERERcUHHle5xIYTowjRNk4R9CdlsNjRNu+DPS9IW\nQoguTLrEL72LiXmnTtp6XS3ax++gF2xv66oIIYQQF61VfSIbN25k4cKFaJrGhAkTmDx5cpP3V65c\nyYoVK4IX2B999FF69+7N8ePHmTdvHgUFBVx//fU88sgjITmJs7La0D9+B44eRvUffGmPLYQQot2Y\nP38+sbGx3HvvveTk5HDdddeRmJh4Xvt4++23iYiI4N577z1rme3bt/PGG2/w6quvXmyVm9Vi0tY0\njQULFvD888/jcrmYMWMGWVlZ9O7dO1hm7Nix3HjjjQB89dVXvPXWWzz33HPY7Xbuv/9+9u/fz4ED\nB0JyAuei7HZU5ij0jV+i+/0ouW4jhBAdUiAQaDLw6/Tts/H7/QDk5OSwfPlyAP7xj38waNCgZpP2\nufb7ve99r8XjDR48mOLiYg4dOkSvXr1aLH++WuweLygoIDExkYSEBGw2G2PGjCE/P79JmcjIyODz\nurq6YH99eHg4gwYNIiwszORqt566/GqoPg47NrVZHYQQQpzb4sWLufXWW7nhhhuYNm0agUCA9PR0\nfvazn5Gdnc3XX3/NlVdeyezZs7nppptYunQpW7ZsYdKkSWRnZ/PII49QXl4OwD333MMLL7zAzTff\nzB/+8AfWrl1LRkYGNpuNpUuX8s033/D4449zww03UFtbe8Z+/+///o9bbrmF7Oxs/vM//5Pa2loA\nXn75ZV5//XUA7rzzTmbPns2tt97K2LFj+fLLL4PncsMNN/D++++HJE4tNj29Xi8ulyu47XK52L17\n9xnlli9fzrJly/D7/bzwwgvm1vJiDB0J4RHoX69FZVzW1rURQoh2S/vbfPQDRabuUyX3w/Kt/zxn\nmd27d/PBBx/w3nvvYbfbmTFjBu+++y41NTWMHDmSn/70p8GycXFxrFixAoDs7GxefPFFrrrqKl56\n6SXmzZvHz3/+cwAaGhr4+OOPAZg7dy7Dhw8HYNKkSfzpT3/iJz/5CZmZmc3u1+v18uCDDwLwq1/9\nir/+9a9MmTLljHr7/X6WLVvGqlWrmDdvHjk5OQBkZmbyu9/9jscee+yCYnYupvUXT5w4kYkTJ/L5\n55+zePFiHn/88VZ/Njc3l9zcXADmzJmD2+02q1oAVIy6Bt/6dbie+kmX7CK32Wymx7Srk5iGhsTV\nfC3F9OjRo8FbvvwWC5rJo8ktFkuLt5Tl5eWxefNmbr31VsDose3RowdWq5U77rgj2F2tlOLOO+/E\nZrNRWVlJZWUl11xzDQAPPPAA3//+97HZbE3KARw7doyBAwcGt5VSWK3WJtunli8oKGDOnDlUVFRQ\nXV3NuHHjsNlsWCyWJudz2223YbPZGDlyJAcP
Hgy+npCQQElJyVnP2+FwXPD3vMUM5nQ68Xg8wW2P\nx4PT6Txr+TFjxjB//vzzqkR2djbZ2dnBbbNXpNGHZaGvWUnp56u7ZGtbVvkxn8Q0NCSu5msppj6f\n7+Q13PseCcktRSeuK59NIBDg3nvvZcaMGU1ef+2119B1Pfh5XddxOBz4/X78fn+T907dPrUcGEmy\npqamyX4CgUCz+wV48sknWbBgAUOHDiUnJ4cvvvgCv9+PpmlomhYsZ7Vamxz/xPPq6uom+zudz+c7\n49/EtFW+0tLSKC4upqSkBL/fT15eHllZWU3KFBcXB5+vX7+epKSkVh38khl6WbCLXAghRPsyduxY\nli5dGkxkZWVlHDx48JyfiY6OJiYmJngtefHixYwePbrZsv3792fv3r3B7aioKKqqqs6676qqKhIS\nEmhoaGDJkiXneTZQWFjIwIEDz/tzrdFiS9tqtTJlyhRmz56NpmmMGzeO5ORkcnJySEtLIysri+XL\nl7N582asVivdunXjhz/8YfDzP/zhD4N/4eTn5/P88883GXl+KSh7GGr4KPQN69AfnNolu8iFEKK9\nGjBgANOmTeOBBx5A13VsNhuzZ89u8XOvvvoqzz77LHV1daSkpDBv3rxmy40fP54nn3wyuH3ffffx\n7LPPEh4ezgcffHBG+f/+7/9m0qRJuFwuRo4cec4E35y8vDwmTJhwXp9pLaXruh6SPV+Ew4cPm75P\nfeM6tN//AsuPfoYaOtL0/bdn0uVoPolpaEhczddSTGtqaprcAdRZPfLIIzz33HOkpqaasj+bzdZs\n97fP5+Puu+/mvffeO+s17eZiblr3eKchXeRCCNFlzZgxg5KSkpAf59ChQ8ycOTNk87l3maR9sov8\nC/QWBkUIIYToXPr373/Wa95mSk1NZcyYMSHbf5dJ2gAq62qoOg67Nrd1VYQQQojz1qWSNkNHgiMC\n/SvpIhdCCNHxdKmkrcIcqMwrjFHkgUBbV0cIIYQ4L10qaUPjXORVlbBTusiFEEJ0LF0uaZNxmdFF\nLqPIhRCiy5g/fz7/+Mc/LuizP/rRj1i6dCkAP/7xj9m1a9cZZXJycnjuuecAWLhwIX/7298uvLLn\n0OWStgpzoIZnoa//QrrIhRCigwic9vv69O2zOTG9aE5ODnfeeedF12Pu3LkMGDDgnGW+9a1v8cc/\n/vGij9WcLpe04cQo8krYtaWtqyKEEIJLtzRnQUFBcGESgAMHDgRnL3vllVe45ZZbGD9+PNOmTaO5\nucfuuecevvnmG8BoXY8dO5Zbb72Vr776KlgmIiKC5ORkNmzYYHqcuuZ8nhmXgyMc/au1qMGZLZcX\nQogu4A9fHaWorM7UffaLC+f7WQnnLHMpl+bs378/9fX17N+/n5SUFD744ANuu+02AB566CGefvpp\nAJ544gn++c9/cuONNzZb56NHjzJ37lyWL19O9+7duffee8nIyAi+P3z4cL788ktGjjR3Bs6u2dIO\nc6CGX2FMtCJd5EII0aY+//xzNm/ezC233MINN9zA559/zv79+7FarU1axQC33347AJWVlVRUVHDV\nVVcBcO+99wYXDzm1HEBJSQkulyu4fdtttwXnHP/ggw+CZfPy8pg0aRITJkwgLy+v2WvXJ3z99ddc\nddVVuFwuwsLCmhwPjOljjx49eiHhOKeu2dLGGEWu539mdJFLa1sIIVpsEYeKruvNLs35+uuvn1w2\ntFFr50k/tVx4eDh1dSd7EG6//XYeffRRbr75ZpRSpKamUldXx8yZM/noo4/o1asXL7/8Mj6f74LP\nyefzER4efsGfP5su2dIGjC7yMIdMtCKEEG3sUi/N2bdvX6xWK6+++mqwhXwiQTudTqqrq1m2bNk5\nj3/55Zezbt06vF4vDQ0NwdHlJxQWFjJo0KBz7uNCdN2WtuOULvJvP4o67a85IYQQl8alXpoTjNb2\niy++yLp1
6wCIiYnh29/+NhMmTCA+Pp7MzHP3wCYkJPBf//Vf3H777cTExDB06NAm7+fn5/PMM8+0\neA7nq8sszdkc/eu1aK/
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc4359d8080>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfoAAAENCAYAAADngqfoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xl8VNX9//HXnZns+0xCFhJCCIsBBIlRKGokJG6gliLa\nuuACrfUnotalVYtLXVpa+LpRt68FFyxfrQvWjaoBsWiEhE2FBEhYZElClkkyk2Rmkpl7fn8MDETA\nQJiQ7fN8PPJIZubMveceEt5z7j33HE0ppRBCCCFEr2To6goIIYQQovNI0AshhBC9mAS9EEII0YtJ\n0AshhBC9mAS9EEII0YtJ0AshhBC9mAS9EEII0YtJ0AshhBC9mAS9EEII0YtJ0AshhBC9mKmrK+Av\n5eXlXV2FXiU2NpaampqurkavIm3qf9KmnUPa1f86o02TkpKOq5z06IUQQoheTIJeCCGE6MUk6IUQ\nQoherNdco/8xpRROpxNd19E0raur0+Ps378fl8sFeNvSYDAQHBwsbSmEED1Mrw16p9NJQEAAJlOv\nPcROZTKZMBqNvsdutxun00lISEgX1koIIcSJ6rWn7nVdl5D3I5PJhK7rXV0NIYQQJ6jXBr2cYvY/\naVMhhOh5em3QCyGEEN2B2lWK/fXnUEp1yf4l6Ls5pRRXXnkldrudhoYGXn311Q5tZ/r06TQ0NPxk\nmUcffZSvvvqqQ9sXQgjRltq2Cc9TD6M/cTeOzz8Aa3WX1EOCvptbvnw5w4cPJyIiApvNxuuvv37U\ncm63+ye3s3jxYqKion6yzIwZM3juuec6XFchhOjrlFKoTevw/PU+9HkPwJ4daFfcQOz/vodm6dcl\ndTqu0WobN27klVdeQdd1cnNzmTJlSpvXq6ureeGFF7DZbISHhzN79mwsFgvV1dXMnz8fXdfxeDxc\nfPHFXHjhhbhcLp588kn279+PwWDgzDPP5NprrwVg5cqVLF68GLPZDMDFF19Mbm6unw/71JkxYwbl\n5eW4XC5mzpzJddddxxdffMHcuXPxeDyYzWb+9a9/0dTUxJw5c/juu+/QNI3f/e53TJ48maVLl/ra\n5s9//jM//PADF1xwAdnZ2eTm5jJv3jyioqIoKyvjq6++Our+AMaOHcuyZctoamriuuuu4+yzz2bt\n2rUkJCSwaNEiQkJCSE5Opq6ujqqqquOeWlEIIQQoXYeNa9A/eRt+KIOYWLRf3Yx23gVogUEYQsKg\nydEldWs36HVdZ+HChcyZMweLxcL9999PVlYWycnJvjKLFy8mOzubCRMmsGnTJpYsWcLs2bOJiYnh\n8ccfJyAgAKfTyd13301WVhZhYWFcdtlljBw5ErfbzaOPPsqGDRsYM2YMAOPHj2fmzJl+O0j9zZdR\ne3b6bXsAWkoahl/9pt1y//M//0NMTAwOh4PJkydz0UUXce+99/Lee+8xYMAA6urqAHj66aeJiIhg\n+fLlANTX1wNQVFTEX//6VwAeeOABtm7dyueffw5AQUEB33//PStWrGDAgAFH3d+kSZN8H5oO2rlz\nJ8899xzz5s3jt7/9LZ988glXXHEFAKeffjpFRUX8/Oc/90MrCSFE76Y8HlTRKtSyd6B8N/RLRLv+\nNrSf5aCZArq6esBxBH1ZWRkJCQnEx8cD3hAuKipqE/R79+7l+uuvB2DEiBHMmzfPu/HDbm9rbW31\n3Z4VFBTEyJEjfWXS0tKora310yF1L4sWLWLZsmWAd+GdN954g3HjxvmCOSYmBoBVq1bx/PPP+94X\nHR0NeAM/PDz8mNs/44wzfNs62v527tx5RNCnpKT42n/UqFHs2bPH95rFYmH//v0dPl4hhOgLVGsr\n6psVqP+8C9WVkDQA7dd3o2Wdi3bYHCTdQbtBb7VasVgsvscWi4XS0tI2ZVJTUyksLGTSpEkUFhbi\ncDiw2+1ERERQU1PD3Llzqays5LrrrjsidJqa
mli3bh2TJk3yPbdmzRpKSkpITEzkhhtuIDY29qQO\n8nh63p2hoKCAVatW8eGHHxISEsK0adMYMWIE27dvP+5tHLx/3WA4+nCK0NDQn9zfwdntDhcUFOT7\n2Wg04nQ6fY9dLhfBwcHHXT8hhOhLlMuF+uoz1KdLoa4GUgdjuPUBGH022jH+n+5qfplRZvr06Sxa\ntIiVK1eSkZGB2Wz2BVNsbCzz58/HarUyb948xo0b5+utejwennnmGS655BLfGYMzzzyTc845h4CA\nAD7//HOee+45Hn744SP2mZ+fT35+PgBz58494sPA/v37u3zCnKamJqKjo4mIiKC0tJT169fjdrtZ\ns2YN+/btIzU1lbq6OmJiYjj//PN5/fXXefzxxwFvTz46Opr09HT27dtHWloaUVFRNDU1+Y7LaDSi\naZrv8dH2ZzQaMZlMaJqG0Wj0zXZ38D0GgwGDweB7vHPnTt9p+x+3X1BQ0El/6OrLTCaTtJ+fSZt2\nDn+2q6feStM7r9G65XtMKQMxDRxCwKChmAYOwRAR6Zd9nAp6cxOOZe/S9MGbKFs9AcPPIOz2PxI4\n+uzjmmOkK39X201Cs9nc5rR6bW3tEb1ys9nMPffcA3innl2zZg1hYWFHlElJSWHLli2MGzcOgJde\neomEhAQmT57sKxcREeH7OTc3lzfeeOOo9crLyyMvL8/3+Mfr/LpcrjZTuHaF7OxsXnvtNc455xzS\n09PJzMwkOjqav/71r9x0003ouk5sbCxvvvkmt99+Ow888ADZ2dkYDAbuuusuJk2axMSJE1m1ahUp\nKSlERkaSlZVFdnY2OTk55ObmopTyjbg/2v48Hg9utxulFB6PB4/HAxwapa/rOrqu43a7aW1tZefO\nnb7T+j8eye9yuWSN6pMga3z7n7Rp5/BHu6rmJtSnS1HLP4DWFhicgXtjIaz8z6FC5jgYMAgtJQ0t\nZRCkpIGlX7eanEs12lDLP0Kt+BCam2DEGAyTrkIfOgI7wHFedu7K9ejbDfr09HQqKiqoqqrCbDZT\nUFDA7bff3qbMwdH2BoOBpUuXkpOTA3g/FERERBAYGEhjYyNbt27l0ksvBeDNN9+kubmZW265pc22\nDvZwAdauXdtmLEBPExQUdMwPKhMnTmzzOCwsjGeeeeaIctdccw133HEH11xzDcARt7+NHz/+uPa3\nZs0awPuBa8WKFb7nD2///Px8Jk+e3OVnQoQQPZdyuVBffIRa9i40N6KddR7a5degJfT3vm6rhz07\nUXt2eL/v3oH6tvDQZDKhYZDiDX9SBqENSIOEFLRT/P+SaqhDffY+6stl4HLCmHEYJl2JNnDIKa2H\nP7TbckajkRkzZvDEE0+g6zo5OTmkpKTw1ltvkZ6eTlZWFsXFxSxZsgRN08jIyPCNmN+3bx+vv/46\nmqahlOKyyy5jwIAB1NbW8t5779G/f3/+8Ic/AIduo1u2bBlr167FaDQSHh7Orbfe2rkt0M3Fx8dz\nzTXX+MY8dCa3281vf/vbTt2HEKJ3Um436ut81EdvQr0VRmZi+MV0tAHpbcppkdEwYgzaiDGH3uty\nwr4fULt3+D4EqP/+B1paUAAmk3ewW8qgwz4EpKGFhNIepRToOnjc4PEc5bun7ePWFtTaVahVn4PH\ng3b2eWiXXInWf0C7++quNNVVc/L5WXl5eZvHzc3NbQaqiRNjMpmOOHUvbXpy5DSz/0mbdo4TaVel\n697byz5YAlUVkH4ahqnXow0deVJ1ULoH9pe3CX9274BG26FCln5gCjgU0voxgvxEGU1o4yeiXTwV\nrZ9/5hTp1qfuhRBCiB9TSsGmdejvLYa9O6F/KobbHoRRWX65xq4ZjJCYgpaYAmPPP7TPBqvvlD/l\nu0EpMBoPfJkOfTcc5Tmj4UePD5XVDn+ufypajKWdGvYcEvRCCNEDqZ2lqLWroF8S2pDhkJB8ym7v\nUqXF6O+9
DmXFEJeANvMutLOzO33/mqZBtAWiLWinZ3XqvnoTCXohhOhB1LZN6B//C4o3gsEAuu69\njh0WAYMz0IYMRxs8HFL
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc42c4099b0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"init_scale = 0.1 # scale for random parameter initialisation\n",
"learning_rate = 0.1 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size \n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine model (outputs are logs of unnormalised class probabilities)\n",
"model = SingleLayerModel(\n",
" AffineLayer(input_dim, output_dim, param_init, param_init)\n",
")\n",
"\n",
"# Initialise the error object\n",
"error = CrossEntropySoftmaxError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<span style=\"color:red\">\n",
"This gives exactly the same training curves (and error / accuracy values over training) as the two runs with equivalent parameters above (second `init_scale` experiment and second `learning_rate` experiment).\n",
"</span>\n",
"\n",
"<span style=\"color:red\">\n",
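"The time per epoch seems to be slightly lower on average (0.20s compared to 0.22s), suggesting the reformulation gives a small efficiency gain, though this will become less apparent in deeper architectures as the benefit only applies to the final layer. Absolute per-epoch times depend on the machine and its load, so such comparisons are only meaningful between runs made under the same conditions.\n",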
"</span>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 2: training deeper models on MNIST\n",
"\n",
"We are now going to investigate using deeper multiple-layer model architectures for the MNIST classification task. You should experiment with training models with two to five `AffineLayer` transformations interleaved with `SigmoidLayer` nonlinear transformations. Intermediate hidden layers between the input and output should have a dimension of 100. For example, the `layers` definition of a model with two `AffineLayer` transformations would be\n",
"\n",
"```python\n",
"layers = [\n",
" AffineLayer(input_dim, 100),\n",
" SigmoidLayer(),\n",
" AffineLayer(100, output_dim),\n",
" SoftmaxLayer()\n",
"]\n",
"```\n",
"\n",
"If you worked through the optional extra to the first exercise, you may wish to use `CrossEntropySoftmaxError` and omit the final `SoftmaxLayer`.\n",
"\n",
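"Because the architectures differ only in the number of affine layers, it can be convenient to generate the `layers` list programmatically. The hypothetical helper `build_layers` below is one way to do this. It takes the layer classes as arguments so the sketch is self-contained; in this notebook you would pass `AffineLayer`, `SigmoidLayer` and `SoftmaxLayer` (or leave `output_cls` as `None` when using `CrossEntropySoftmaxError`). Like the example above, it assumes `AffineLayer` can be constructed from just its input and output dimensions.\n",
"\n",
"```python\n",
"def build_layers(num_affine, input_dim, output_dim, hidden_dim,\n",
"                 affine_cls, nonlinearity_cls, output_cls=None):\n",
"    # Dimensions of each affine transformation: input -> hidden -> ... -> output\n",
"    dims = [input_dim] + [hidden_dim] * (num_affine - 1) + [output_dim]\n",
"    layers = []\n",
"    for i in range(num_affine):\n",
"        layers.append(affine_cls(dims[i], dims[i + 1]))\n",
"        # Interleave a nonlinearity between (but not after) the affine layers\n",
"        if i < num_affine - 1:\n",
"            layers.append(nonlinearity_cls())\n",
"    # Optional final output layer; omit it when the error function includes the softmax\n",
"    if output_cls is not None:\n",
"        layers.append(output_cls())\n",
"    return layers\n",
"\n",
"# Usage in this notebook (not run here):\n",
"# model = MultipleLayerModel(build_layers(3, input_dim, output_dim, 100,\n",
"#                                         AffineLayer, SigmoidLayer, SoftmaxLayer))\n",
"```\n",
"\n",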
"Use the code from the first exercise as a starting point and start with training hyperparameters which gave reasonable performance for the shallow architecture trained previously.\n",
"\n",
"Some questions to investigate:\n",
"\n",
" * How does increasing the number of layers affect the model's performance on the training data set? And on the validation data set?\n",
" * Do deeper models seem to be harder or easier to train (e.g. in terms of ease of choosing training hyperparameters to give good final performance and/or quick convergence)?\n",
" * Do the models seem to be sensitive to the choice of the parameter initialisation range? Can you think of any reasons for why setting individual parameter initialisation scales for each `AffineLayer` in a model might be useful? Can you come up with (or find) any heuristics for setting the parameter initialisation scales?\n",
" \n",
"You do not need to come up with explanations for all of these (though if you can, that's great!); they are meant as prompts to get you thinking about the various issues involved in training multiple-layer models.\n",
"\n",
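"As one example of an initialisation heuristic from the literature, Glorot and Bengio (2010) proposed sampling each layer's weights uniformly from $[-s, s]$ with $s = \\sqrt{6 / (D_\\textrm{in} + D_\\textrm{out})}$, where $D_\\textrm{in}$ and $D_\\textrm{out}$ are the input and output dimensions of the layer. The hypothetical helper below computes this scale for each `AffineLayer` of a model; each value could then be used to create a separate `UniformInit(-scale, scale, rng=rng)` initialiser. It is offered as one heuristic to try, not as the recommended setting for these experiments.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def glorot_uniform_scales(dims):\n",
"    # Half-width s = sqrt(6 / (fan_in + fan_out)) of the uniform range for each affine layer\n",
"    return [np.sqrt(6. / (fan_in + fan_out)) for fan_in, fan_out in zip(dims[:-1], dims[1:])]\n",
"\n",
"# Example: three affine layers, 784 -> 100 -> 100 -> 10\n",
"scales = glorot_uniform_scales([784, 100, 100, 10])\n",
"print(['{0:.3f}'.format(s) for s in scales])  # ['0.082', '0.173', '0.234']\n",
"```\n",
"\n",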
"You may wish to start with shorter pilot training runs (by decreasing the number of training epochs) for each of the model architectures to get an initial idea of appropriate hyperparameter settings before doing one or two longer training runs to assess the final performance of the architectures."
]
},
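{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to organise the short pilot runs suggested above is a small grid over learning rates and initialisation scales. In the sketch below, `run_pilot` stands for a function you would write yourself (for example, wrapping the training code from the first exercise with a reduced `num_epochs`) that returns the final validation error for one setting. The grid values are placeholders, and `fake_pilot` is a stand-in that exists only so the sketch runs on its own.\n",
"\n",
"```python\n",
"import itertools\n",
"\n",
"def grid_search(run_pilot, learning_rates, init_scales):\n",
"    # Evaluate every (learning_rate, init_scale) pair and keep the lowest validation error\n",
"    results = {}\n",
"    for learning_rate, init_scale in itertools.product(learning_rates, init_scales):\n",
"        print('-' * 80)\n",
"        print('learning_rate={0:.2f} init_scale={1:.2f}'.format(learning_rate, init_scale))\n",
"        print('-' * 80)\n",
"        results[(learning_rate, init_scale)] = run_pilot(learning_rate, init_scale)\n",
"    best = min(results, key=results.get)\n",
"    return best, results\n",
"\n",
"# Stand-in for a real short training run, used only so this sketch is self-contained\n",
"def fake_pilot(learning_rate, init_scale):\n",
"    return (learning_rate - 0.2) ** 2 + (init_scale - 0.1) ** 2\n",
"\n",
"best, results = grid_search(fake_pilot, [0.1, 0.2, 0.5], [0.05, 0.1, 0.5])\n",
"print(best)  # (0.2, 0.1)\n",
"```"
]
},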
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# disable logging by setting handler to dummy object\n",
"logger.handlers = [logging.NullHandler()]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Models with two affine layers"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.10\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc4359d8780>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435999ef0>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 1.75e-01\n",
" final error(valid) = 1.72e-01\n",
" final acc(train) = 9.50e-01\n",
" final acc(valid) = 9.53e-01\n",
" run time per epoch = 9.98\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.20\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc43594c5f8>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435c3b828>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 1.70e-01\n",
" final error(valid) = 1.69e-01\n",
" final acc(train) = 9.52e-01\n",
" final acc(valid) = 9.55e-01\n",
" run time per epoch = 10.36\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.50\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435cf9e10>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435a3feb8>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 1.69e-01\n",
" final error(valid) = 1.71e-01\n",
" final acc(train) = 9.51e-01\n",
" final acc(valid) = 9.52e-01\n",
" run time per epoch = 10.14\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=1.00\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435f2d9e8>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfMAAAENCAYAAAAbl4wiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xl8VGWa//3PXdlTqSxVgQTIAoTFsC8BwpJIJCCyKA2o\n7YIL2NrTNvhoa0/r+Pymn3m1M8xo/6afnkd9zfhTZ7SHaVxAWyGKUZCECAkqArIvKQhJyApVSSqV\nVJ37+aMwSosmkIJKJdf7Lyt16pzr3JL65mzXrbTWGiGEEEIELVOgCxBCCCFE90iYCyGEEEFOwlwI\nIYQIchLmQgghRJCTMBdCCCGCnIS5EEIIEeQkzIUQQoggJ2EuhBBCBDkJcyGEECLISZgLIYQQQS40\n0AVcrsrKSr+tKzExkbq6Or+try+SMew+GUP/kHHsPhnD7vP3GA4cOLBLy8mRuRBCCBHkJMyFEEKI\nICdhLoQQQgS5oLtm/te01rS2tmIYBkqpy/rs2bNncbvdV6my4KO1xmQyERkZedljKYQQInCCPsxb\nW1sJCwsjNPTydyU0NJSQkJCrUFXw8ng8tLa2EhUVFehShBBCdFGXEnDPnj28+uqrGIbBnDlzWLJk\nyUXv19bW8uKLL+JwOIiJiWH16tXYbLaO91taWnjssceYMmUKq1atAuDEiRM8//zztLW1MXHiRO6/\n//4rOho0DOOKglxcWmhoqJytEEKIINPpNXPDMHj55Zd56qmn+Nd//Vd27NhBRUXFRcu8/vrr5Obm\n8txzz7F8+XLWrVt30fvr168nMzPzop+99NJLPPTQQ/zxj3+kurqaPXv2XNEOyOlg/5MxFUKI4NJp\nmB87dozk5GSSkpIIDQ1lxowZlJWVXbRMRUUFY8aMAWD06NHs3r27470TJ05w/vx5xo8f3/GzxsZG\nXC4XI0aMQClFbm7u99YphBBCBAvt8WDsKKR1x8cB2X6n56cbGhouOmVus9k4evToRcukp6dTWlrK\nggULKC0txeVy4XQ6MZvNvPbaa6xevZp9+/b96DobGhouuf3CwkIKCwsBWLt2LYmJiRe9f/bs2W6d\nZu8Jp+i11ixbtoz/+q//wjAMNmzYwP3333/Z67nzzjt58cUXiYuL+8Flfvvb3zJnzhxycnJ+cJmI\niIjvjfMPCQ0N7fKy4tJkDP1DxrH7ZAwvn25vw/XxJpo3vI6urcY9fTaJM+dc8zr8kmQrVqzglVde\nYdu2bWRmZmK1WjGZTGzZsoWJEydeFNyXKz8/n/z8/I7Xf91Zx+12X/FNbKGhoXg8niuuzV8KCwvJ\nzMwkKiqK06dP8+qrr7JixYrvLefxeH70j4/XXnutY7kfct999/HEE08wffr0H1zG7XZ3uYORdIzq\nPhlD/5Bx7D4Zw67Tbje66EP0hxvgXAMMHYnpjgeJnX1jQDrAdRrmVquV+vr6jtf19fVYrdbvLfP4\n448DvrvLd+3ahdls5siRIxw8eJAtW7bQ2tqKx+MhMjKSBQsWdLrOYLJy5UoqKytxu92sWrWKu+++\nm61bt7J27Vq8Xi9Wq5U33niD5uZmnn76afbu3YtSikcffZSFCxeyceNG7rrrLgD+8R//Ebvdzty5\nc8nNzWXOnDk8++yzxMXFcezYMYqLiy+5PYBp06ZRUFBAc3Mzd999N1OnTmX37t0kJyfzyiuvEBUV\nRUpKCo2NjdTU1NC/f/9ADpsQQgQd3epCf1qA/nAjOM/DiDGYVj4K141DKRWwe446DfOMjAyqqqqo\nqanBarVSUlLCmjVrLlrmm7vYTSYTGzduJC8vD+Ci5bZt28bx48c7QisqKoojR44wfPhwtm/fzvz5\n87u9M8afX0KfPtn15ZVCa/2jy6jUIZh++rMfXeb3v/89CQkJuFwuFi5cyI033sgTTzzBhg0bSEtL\no7GxEYA//OEPWCwWPv7Yd03l3LlzAJSVlfHP
//zPADz11FMcPnyYjz76CICSkhL27dvHJ598Qlpa\n2iW3t2DBgu/9MXTy5Emef/55nn32WR566CE2b97MsmXLABg7dixlZWUsXLiwy2MlhBB9mW5pRn/y\nPrrwL9DshFETMS28DTVidKBLA7oQ5iEhIaxcuZJnnnkGwzDIy8sjNTWV9evXk5GRQVZWFgcOHGDd\nunUopcjMzOx4/OzHPPDAA7zwwgu0tbUxYcIEJk6c6JcdCoRXXnmFgoICwDcRzJ/+9Ceys7M7wjch\nIQGAoqIiXnjhhY7PxcfHA75Qj4mJ+cH1T5gwoWNdl9reyZMnvxfmqampHTcljhs3jtOnT3e8Z7PZ\nOHv27BXvrxBC9BW6yYEu/Av6k03gaoZxU3whPnRkoEu7SJeumU+aNIlJkyZd9LPbb7+947+zs7PJ\nzs7+0XXMnj2b2bNnd7zOyMjg97///WWU2rnOjqD/mj+umZeUlFBUVMR7771HVFQUy5cvZ/To0Rw/\nfvyy6jAMA5Pp0g8XREdH/+j2LvVceERERMd/h4SE0Nra2vHa7XYTGRnZ5fqEEKKv0Y5G9JZ30dsK\nwO2CSTMwLbwVlZYR6NIuSXqzd5PT6SQuLo6oqCiOHTvGF198gdvtZufOnZw6dQqg4zR7bm4u//mf\n/9nx2W9Osw8dOhS73Q6A2WymqanpsrZ3uU6cOMHIkT3rr0ohhOgJdGM9xp9fwnjyZ+gt76DGT8H0\n2/+PkL/5TY8NcugF7VwDbfbs2bz++utcf/31ZGRkMGnSJGw2G//yL//CAw88gGEYJCYm8uc//5lH\nHnmEp556ihtuuAGTycRjjz3GggULmDNnDp999hlDhgzBarUyZcoUbrjhBvLy8pgzZ06n27sc7e3t\nlJeXX/TcvxBC9HW6vgZd8BZ6RyEYBio7D3XTclTyoC6vw9AapzswT0gp3dkdYD1MZWXlRa9bWlou\nOg19OXrKo2lnz57lkUce4c9//vNV31ZBQQH79u3j17/+9Q8uczljKo+ydJ+MoX/IOHZfXxxDXVOJ\n3vwWeudWQKFm5qPmL0X1S+7a57XmaH0rRXYHO+xOJqcl8HDWlT+O/df89miauPqSkpK48847cTqd\nWCyWq7otj8fDQw89dFW3IYQQPZ2uOo3e9Aa6tAhCQ1HX34S6cSnK2nnTHK019nNuiuxOiu0Oqpva\nCTUpJg80M3uY/4L8ckiY9xA333zzNdnO4sWLr8l2hBCiJ9KnT/pC/IsSCI9Azb0FNW8JKi6h08+e\ncbRRbHdQZHdw+nwbJgXjks3cOsZGdqqFmPCQgJ3dkDAXQgjR6+mTRzE2rYevSiEyync9PP8WlCX2\nRz9X29xOkd1Bsd3B8QY3ChjVP4qfT0liepqF+MieEaM9owohhBDiKtDHDmC8vx6+/hKiY1A334m6\nYRHK/MO9PRpdHnacclBU7uRQnQuA4bZIVk7qz8x0C4nRYdeq/C6TMBdCCNGraK3h0F6MTW/A4X1g\niUMtvRc1+yZU1KVv7nW6vXx22kmR3cH+sy0YGgbHR7BifD9mpVtItoRf4724PBLmQgghegWtNXz9\nhe9I/PghiLOibluFyr0RFfH9Rlkt7V5KK5ooKnfwZVUzXg0DLWEsH20jZ3AsaXERl9hKzyRNY3oA\nrTW33norTqfzij4/fPhwAKqrq/nZzy7dBW/58uV89dVXgK973zcNa4QQIthprdF7dmI88yuM//f/\ngcY61J0/x/RP/4Fp7i0XBbnbY7DjlIO1289w79vH+NeSKuzn3Nx8nZX/fdNgXlg8lLvG9wuqIAc5\nMu8RPv74Y0aNGtXtx9KSk5N56aWXOl3um7nTH3nkkW5tTwghAkkbBnxR4judXlEO/ZJR9/wSNT0P\nFfrtde12r2ZPVTNFdge7Kppo9RjER4Ywd1g8OekWRiZGYQrQbGf+ImHuB/6eAnXgwIHcd999gG+G\nNLPZzIoV
K7j//vs5f/48Ho+HX//619x4440X1XH69GnuvfdePvnkE1wuF4899hgHDhxg2LBhF/Vm\nnzdvHkuXLpUwF0IEJe3
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435de5588>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 2.05e-01\n",
" final error(valid) = 2.07e-01\n",
" final acc(train) = 9.40e-01\n",
" final acc(valid) = 9.39e-01\n",
" run time per epoch = 10.99\n"
]
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"num_epochs = 10 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"learning_rate = 0.2 # learning rate for gradient descent\n",
"\n",
"init_scales = [0.1, 0.2, 0.5, 1.] # scales for random parameter initialisation\n",
"final_errors_train = []\n",
"final_errors_valid = []\n",
"final_accs_train = []\n",
"final_accs_valid = []\n",
"\n",
"for init_scale in init_scales:\n",
"\n",
" print('-' * 80)\n",
" print('learning_rate={0:.2f} init_scale={1:.2f}'\n",
" .format(learning_rate, init_scale))\n",
" print('-' * 80)\n",
" # Reset random number generator and data provider states on each run\n",
" # to ensure reproducibility of results\n",
" rng.seed(seed)\n",
" train_data.reset()\n",
" valid_data.reset()\n",
"\n",
" # Alter data-provider batch size\n",
" train_data.batch_size = batch_size\n",
" valid_data.batch_size = batch_size\n",
"\n",
" # Create a parameter initialiser which will sample random uniform values\n",
" # from [-init_scale, init_scale]\n",
" param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
" # Create a model with two affine layers\n",
" hidden_dim = 100\n",
" model = MultipleLayerModel([\n",
" AffineLayer(input_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, output_dim, param_init, param_init)\n",
" ])\n",
"\n",
" # Initialise a cross entropy error object\n",
" error = CrossEntropySoftmaxError()\n",
"\n",
" # Use a basic gradient descent learning rule\n",
" learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
" stats, keys, run_time, fig_1, ax_1, fig_2, ax_2 = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)\n",
"\n",
" plt.show()\n",
"\n",
" print(' final error(train) = {0:.2e}'.format(stats[-1, keys['error(train)']]))\n",
" print(' final error(valid) = {0:.2e}'.format(stats[-1, keys['error(valid)']]))\n",
" print(' final acc(train) = {0:.2e}'.format(stats[-1, keys['acc(train)']]))\n",
" print(' final acc(valid) = {0:.2e}'.format(stats[-1, keys['acc(valid)']]))\n",
" print(' run time per epoch = {0:.2f}'.format(run_time * 1. / num_epochs))\n",
"\n",
" final_errors_train.append(stats[-1, keys['error(train)']])\n",
" final_errors_valid.append(stats[-1, keys['error(valid)']])\n",
" final_accs_train.append(stats[-1, keys['acc(train)']])\n",
" final_accs_valid.append(stats[-1, keys['acc(valid)']])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| init_scale | final error(train) | final error(valid) | final acc(train) | final acc(valid) |\n",
"|------------|--------------------|--------------------|------------------|------------------|\n",
"| 0.1 | 1.75e-01 | 1.72e-01 | 0.95 | 0.95 |\n",
"| 0.2 | 1.70e-01 | 1.69e-01 | 0.95 | 0.95 |\n",
"| 0.5 | 1.69e-01 | 1.71e-01 | 0.95 | 0.95 |\n",
"| 1.0 | 2.05e-01 | 2.07e-01 | 0.94 | 0.94 |\n"
]
}
],
"source": [
"# Print a markdown table summarising the final statistics for each init_scale\n",
"print('| init_scale | final error(train) | final error(valid) | final acc(train) | final acc(valid) |')\n",
"print('|------------|--------------------|--------------------|------------------|------------------|')\n",
"for init_scale, err_train, err_valid, acc_train, acc_valid in zip(\n",
"        init_scales, final_errors_train, final_errors_valid,\n",
"        final_accs_train, final_accs_valid):\n",
"    print('| {0:.1f} | {1:.2e} | {2:.2e} | {3:.2f} | {4:.2f} |'\n",
"          .format(init_scale, err_train, err_valid, acc_train, acc_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Models with three affine layers"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.10\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xl8U9ed9/HPkeTdluSF3ay2zL6bfTUYsCAkbGnTpu3M\nJDNt0nXS6SQhSTt00nSYNkvb6TzN89CUtkPbMCwmhCDbmD0QEhJCCAmJZcy+OWDL+67z/KHEhUKw\n8ZUtyf69X6++XjG+0j369trH597zO0dprTVCCCGECHqmQDdACCGEEK0jnbYQQggRIqTTFkIIIUKE\ndNpCCCFEiJBOWwghhAgR0mkLIYQQIUI6bSGEECJESKcthBBChAjptIUQQogQIZ22EEIIESIsgW7A\nrVy8eNGv75eUlMTVq1f9+p5djWRonGRonGRonGToH/7OsXfv3q06TkbaQgghRIiQTlsIIYQIEdJp\nCyGEECEiKJ9pCyGE6Bhaa2pra/F6vSilAt2ckHHlyhXq6uru6DVaa0wmE5GRkW3OWjptIYTowmpr\nawkLC8Nike7gTlgsFsxm8x2/rrGxkdraWqKiotp0Xrk9LoQQXZjX65UOuwNZLBa8Xm+bXy+dthBC\ndGFyS7zjGcm803fa3i3rqD9xLNDNEEIIIQzr1J22Li9F73VR+sRDNP3scfT776C1DnSzhBBCdLA1\na9awYcMGANavX8/ly5fv+D3++Mc/Nr/H5zlx4gT//M//3KY2tkanfpChrPGYVr9EzJEDVGT/Ce+v\nfgx9B6KcK1Djp6JMdz6JQAghRMdramq6YeLX3379eRobGwFfR52TkwPAhg0bGDJkCD179mzxPNf7\n2te+1uL5hg4dyqVLl7hw4QJ9+vRp8fg71ak7bQAVEUn04i9SNWEm+s296JxN6P/3c3T3XqgFy1BT\n5qDCwgLdTCGE6NI2bdrE7373O+rr6xk7diz/8R//wZAhQ/jKV77C/v37+elPf8p3vvMd7r77bvbt\n28c3v/lNUlJSePzxx6mtraV///4899xz2O12VqxYwbBhwzh8+DD33HMPQ4cOZcSIEVgsFrZt28Z7\n773Ht7/9bSIjI9m6dSuzZ8++4X0rKyv505/+RH19PQMHDuRXv/oVUVFRPPfcc8TExPDQQw+xdOlS\nxowZw8GDBykrK+O5555j0qRJAMybN49XXnmFb37zm37PqdN32p9RljDUtEz0lAx49028ro3o//lv\n9Kt/Qc27BzUzCxXZtin4QgjRGXhfXoM+d8qv76n6DsR03z/d9hi3283WrVvZsmULYWFhrFy5ks2b\nN1NdXc3YsWP5t3/7t+Zj4+Pjyc3NBSAzM5Onn36aKVOm8POf/5znn3+ef//3fwegoaEBl8sFwLPP\nPsuoUaMAuOuuu/j973/PD3/4Q0aPHn3L9y0pKeH+++8H4D//8z/5y1/+wgMPPHBTuxsbG3nttdfY\nuXMnzz//POvXrwdg9OjR/PrXv5ZO2x+UyQzjp2IaNwVOvOfrvDesRb+2ATXnLtTcu1Cx1kA3Uwgh\nuozXX3+d999/n4ULFwK+2vGkpCTMZjOLFi264di7774bgPLycsrKypgyZQoA9957L9/4xjduOg6g\nuLgYh8Nx2zZcf/zHH3/Mz372M8rLy6mqqmLWrFm3fM1n7R01ahTnz59v/vfExESuXLnS4uduiy7X\naX9GKQXDxmAeNgZd9DFe1yb0tpfRedmomQtQ85agEpIC3UwhhOgwLY2I24vWmnvvvZeVK1fe8O8v\nvvjiTc+Xo6OjW/We1x8XGRlJbW1tq49/5JFHeOmllxg+fDjr16/njTfeuOVrwsPDATCbzc3PzgHq\n6uqIjIxsVTvvVKeePd5aatBgzN96AtOPf40aPxW9axveJ76O9/e/Ql8+3/IbCCGEaLPp06ezbdu2\n5q0uS0tLbxi53orVasVms/Hmm28CvmfikydP
vuWxqampnD59uvnrmJgYKisrP/e9Kysr6dGjBw0N\nDWRnZ9/hp4GioiIGDx58x69rjS470r4V1bsf6oFH0Pfcj87NRr++A31wJ4ybgsm5AtU/NdBNFEKI\nTictLY1HH32UL33pS2itsVgsPPPMMy2+7he/+EXzRLR+/frx/PPP3/K4OXPm8N3vfrf56y984Qs8\n/vjjzRPR/ta//uu/ctddd5GYmMjYsWNv28HfysGDB5k7d+4dvaa1lA7CwuWLFy/69f3aulm5Lveg\nd76K3r0daqpg2FhMC1dA2ogut4qQvzd874okQ+MkQ+P+NsPq6upW33IOZQ8++CBPPvkkgwYN8sv7\nWSyWG26Jf6auro7ly5ezZcuWz10e9laZ9+7du1Xnldvjt6GsdkxLv4pp9W9Ry/4Ozp/C++yTeFc/\nij76JtrA+rFCCCE6zsqVKykuLm7381y4cIEnnnii3dZzl5H2HdD1deiDO9G52XD1CvTuh8pajpow\nA9XJF9yXEY5xkqFxkqFxXXWk7W+fN9JuDRlpdxAVHoFp9kJMP3kR9eD3QSn0717A+9RDeHe/hq6/\ns71VhRBCiDvRuYeH7USZzajJs9ETZ8L7b/tqvf/8f9GvvozKvBs1eyEqOibQzRRCCNHJSKdtgDKZ\nYPRETKMmQMEHeF0b0Nn/g87ZhJrt9HXg1vhAN1MIIUQnIZ22HyilYPAIzINHoM+cRLs2onM2o/Nf\nRU3LRC1YikrqEehmCiGECHHyTNvPVP8UTA89hunp3/huoe/Pw/vkN/C+9Dz6wtlAN08IIbqk67fm\nvFP//M//zLZt2wD4wQ9+QEFBwU3HrF+/nieffBKAtWvX8vLLL7e9sbchI+12onr0Rn3t2+jFX0Lv\n2ILel4s+tMd3O925ApUyJNBNFEKIkOHPrTmNePbZZ1s85r777uOee+7hvvvuM3y+vyUj7Xam4hMx\nfeFBX6334i9B4Qm8qx+l6dkn0R+8SxBW3AkhRIfbtGkTixYtYt68eTz66KM0NTXhcDj48Y9/TGZm\nJu+88w6TJk3imWeeYcGCBWzbto3jx49z1113kZmZyYMPPojH4wFgxYoV/OhHP8LpdPLb3/6WAwcO\nNG/NWVhYeMMmJOfOnWteveyFF15g4cKFzJkzh0cfffSWv59XrFjBe++9B/j+EJg+fTqLFi3i7bff\nbj4mKiqKvn378u677/o9JxlpdxAVa0Xd/SX0/CXo/XnovC14f/Fv0C8Fk3M5jJvi24FMCCEC5Ldv\nX+FU6e031rhTA+Mj+cf028/p6citOVNTU6mvr+fs2bP069ePrVu3snjxYgD+/u//nkceeQSA73zn\nO+zYsYP58+ffss1Xrlzh2WefJScnh7i4OO69915GjBjR/P1Ro0bx5ptvMnbs2LbE9rla1WkfPXqU\ntWvX4vV6mTt3LkuWLLnh+9u2bWPnzp2YzWasVisPP/ww3bp1A2DdunUcOXIErTUjR47kH/7hH7rc\nEqDXU5FRqHn3oGcvRB/ajc7ZjPf//gx69PFNWJuSgbKEBbqZQgjRYTp6a87FixezdetWvv3tb7N1\n61Z+85vfAL41w3/zm99QU1ODx+Nh8ODBn9tpv/POO0yZMoXExMTm8xUVFTV/PykpicLCwjZn8nla\n7LS9Xi8vvfQSTz31FImJiaxcuZL09HSSk5ObjxkwYACrV68mIiKCvLw81q1bxyOPPMLHH3/Mxx9/\n3PwM4Ic//CEffvghw4cP9/sHCTUqLAw1Yz562lw48oav1vuPv0Zv/Qtq/hLf9qAR7bO1mxBC3EpL\nI+L20tFbc95999184xvfwOl0opRi0KBB1NbW8sQTT7B9+3b69OnDc889R11d2xfMaq/tOVt8pl1Y\nWEjPnj3p0aMHFouFqVOncvjw4RuOGTFiBBEREQA4HA5KSkoAXylUfX09jY2NNDQ00NTUhM1m8/uH\nCGXKZEal
T8f01AuYvrcKuvdC/+9LeB97EO/Wv6CrKgLdRCGEaFcdvTXngAEDMJvN/OIXv2gekX/W\nQSckJFBVVcVrr7122/O
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435b57b00>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfMAAAENCAYAAAAbl4wiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xt4VOW5///3mpmczzOBBEg4JaDhDIaDiIGQECyIxVOt\nIliBqlWD31raS3vx6+63u7R0g1d7tZe6u93irnSnpd/WYLXahAARNEiCFgVBzYRjCJDjTCaHSTKz\nnt8fEwIRJECGzExyv/5ynMWaZ25C7tyzPnmWppRSCCGEECJgGXy9ACGEEEL0jjRzIYQQIsBJMxdC\nCCECnDRzIYQQIsBJMxdCCCECnDRzIYQQIsBJMxdCCCECnDRzIYQQIsBJMxdCCCECnDRzIYQQIsCZ\nfL2Aa1VVVeW1c8XHx1NbW+u18w1EUsPekxp6h9Sx96SGveftGg4dOvSqjpPJXAghhAhw0syFEEKI\nACfNXAghhAhwAXfN/KuUUjidTnRdR9O0a/qz586do62t7QatLPAopTAYDISGhl5zLYUQQvhOwDdz\np9NJUFAQJtO1vxWTyYTRaLwBqwpcLpcLp9NJWFiYr5cihBDiKgX8x+y6rl9XIxeXZzKZ0HXd18sQ\nQghxDQK+mcvHwd4nNRVCiMAiI60QQgjRSy0dbt4/4WBwvWKKue8HooCfzPsDpRT3338/DocDu93O\n//zP/1zXeZYvX47dbr/iMT/72c94//33r+v8QgghurPWOXlp31kefaOCF/edZeeXvtl0RyZzP7Bj\nxw7GjRtHVFQUp06d4vXXX+c73/nOJce5XK4r5gO2bNnS42utXLmSH/7wh8yZM6c3SxZCiAGrpcPN\ne8ca2V5ho6K+jWCjxu0jolk4JpbZNyVRV1fX52uSZu4FK1eupKqqira2NlatWsXDDz/Mrl272LBh\nA263G7PZzF/+8heam5tZt24dn376KZqm8f3vf5/FixeTn5/PsmXLAPjFL37BiRMnWLBgARkZGWRl\nZbFx40ZiYmKwWq28//77l309gJkzZ/Luu+/S3NzMww8/zIwZM9i/fz+JiYls3ryZsLAwkpKSaGho\noLq6msGDB/uybEIIETCUUpTXOSmw2thzvJE2t2JkbAiPT08gY2Q0kcGe34zyVeaoXzVz/c+voE4d\nu/rjNQ2l1BWP0ZJHYfj2d694zAsvvEBcXBytra0sXryYhQsX8sMf/pA33niD4cOH09DQAMBvfvMb\noqKi2LFjBwA2mw2AsrIyfvWrXwHw4x//mC+++ILt27cDUFJSwsGDB9m5cyfDhw+/7OstWrQIs9nc\nbU3Hjh3jxRdfZOPGjTz++OO888473HvvvQBMnDiRsrIyFi9efNW1EkKIgai53c17xxsptNo41tBG\nqEkjY2Q0OamxjLH4z54c/aqZ+8rmzZt59913Ac+NYP74xz8ya9asruYbFxcHwJ49e3jppZe6/lxs\nbCzgaeqRkZFfe/4pU6Z0netyr3fs2LFLmnlycjITJkwAYNKkSZw6darrOYvFwrlz5677/QohRH+m\nlOKLWs8U/v6JRtrditFxIXxvhmcKDw/yv/1J+lUz72mC/iqTyYTL5erVa5aUlLBnzx7eeustwsLC\nuO+++xg/fjwVFRXXtA5d1zEYLp9HDA8Pv+LrXW4Xu5CQkK7/NhqNOJ3OrsdtbW2EhoZe9fqEEGIg\naGpzU3zcTmG5nRP2NkJNBjJHxZCTGkuqxb+/Z0qavZccDgcxMTGEhYVhtVr5+OOPaWtr48MPP+Tk\nyZMAXR+zZ2RkdEuqn/+YffTo0Zw4cQKAiIgImpqarun1rtXRo0e56aabrvnPCSFEf6OU4nB1C78u\nqeLRfCuv7K8m2KTx1MxE/ueeVJ6cmej3jRz62WTuC/PmzWPLli3MnTuXlJQUpk2bhsVi4T/+4z9Y\nvXo1uq4THx/Pn//8Z5555hl+/OMfM3/+fAwG
A88++yyLFi0iKyuLvXv3MmrUKMxmM9OnT2f+/Plk\nZmaSlZXV4+tdi46ODo4fP87kyZO9WQYhhAgojW1uio/ZKSi3UdnYTpjJQNZozxQ+2uz/zfurNNVT\nAszPVFVVdXvc0tLS7WPoa+GNj9m94dy5czzzzDP8+c9/vuGv9e6773Lw4EF+9KMffe0x11LT+Ph4\namt983uV/YXU0Dukjr3X32uolOKz6lYKrDb2nnTQoStuig8lJzWWOSOiCTX1/sNqb9dw6NChV3Wc\nTOZ+ICEhgYceegiHw0FUVNQNfS2Xy8Xjjz9+Q19DCCH8SaPTxc5jdgqtdk43thMRZCAn1TOFj4wL\nvCn8cqSZ+4m77rqrT15nyZIlffI6QgjhS7pSHDrXQoHVxoenmnDpirRBYdx36xBuGx5FiBemcH8i\nzVwIIUS/YXO62Flhp7DCxhlHB5HBBr4xJpac1FiGx4b0fIIAJc1cCCFEQNOV4tOzLRRabeyrdODS\nYfzgML49MZ5bk/tmCldKweef0h4TA0NH3vDX+6qrauYHDhzgtddeQ9d1srKyWLp0abfna2pqePnl\nl2lsbCQyMpLc3FwsFkvX8y0tLTz77LNMnz6dVatWAfDTn/6UhoYGgoODAVi3bh0xMTHeel9CCCH6\nuYZWFzs6p/BzTR1EhRhZPDaOBamxJMf0zRSuXC7URx+gCrfByQqaJ0+Hp/+/Pnnti/XYzHVd59VX\nX2XdunVYLBaef/550tPTSUpK6jpmy5YtZGRkMG/ePA4dOkReXh65ubldz2/dupW0tLRLzr1mzRpS\nUlK89FaEEEL0d7pSHDjTTKHVRmllE24FExLCeXjyIGYlRxJs7Jtr4crZgtqzHVX0d6ivgcRhaMuf\nIvbO+6hrdPTJGi7W47u2Wq0kJiaSkJCAyWRi9uzZlJWVdTumsrKya+vQ8ePHs3///q7njh49it1u\nl99rvoKLb4F6PcaMGQPA2bNn+e53L78L3n333ccnn3wCwAMPPNC1YY0QQgSCupYO/nKwlsffrOD/\n7qrks+pW7rrZzEtLRrM+ezgZI6P7pJErWx363/6A/qNVqL+8CpZBGJ5eh+H/voghYyFasG+uy/c4\nmdfX13f7yNxisVBeXt7tmBEjRlBaWsqiRYsoLS2ltbUVh8NBREQEr7/+Orm5uRw8ePCSc7/00ksY\nDAZmzpzJvffe6zcb1ve1i2+B2huJiYm88sorPR5377338oc//IFnnnmmV68nhBA3kltX/KtzCi87\n3YSuYFJiOI9MHczMpEiC+mgKB1CnT6AKt6H2vQe6DtNmYci5G220f+ym6ZUA3PLly9m8eTPFxcWk\npaVhNpsxGAwUFhYyderUbj8MnLdmzRrMZjOtra288MIL7N69m7lz515yXFFREUVFRQBs2LCB+Pj4\nbs+fO3fuivf47klv/ux5jzzySNctSVevXs2KFSvYuXMnv/jFL7pugfq3v/2N5uZmfvzjH3PgwAE0\nTWPt2rXceeedbNu2jeXLl2Mymfj3f/93hg0bxsqVKwHYuHEjERERPPLII6xYsQK73U5HRwfPPfcc\n3/jGN7q9j5MnT/Lwww+ze/duWltbeeaZZzh8+DCpqak4nU6MRiMmk4lFixZx11138YMf/OCy7yck\nJOSSOl+pfld7rLg8qaF3SB17z19qeM7Rxj8+O8fbn53jXFMbcWFBPHRLEkvGJ5AUG9Zn61BK0XHw\nI5rf/BPtH++F4BDCFi4l/M5vYRqSdNk/46sa9tjJzGZztxut19XVXXKHLrPZzNq1awFwOp3s27eP\niIgIvvzyS44cOUJhYSFOpxOXy0VoaCjLli3rOkdYWBhz5szBarVetplnZ2eTnZ3d9firO+u0tbVh\nNHruYPPf+89xrMHJ1dKu4haoo+JCWZ2ecMVjNm3a1O2WpAsWLODZZ5/tdgtUl8vFpk2biIiI6HYL\nVJfLRWlp
KRs2bMDlcrFkyRL+7d/+jRUrVgDw5ptv8r//+78YjUb++7//m6ioKOrr61myZAnZ2dld\nn2a4XC7cbnfXf2/evJn
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435d50cc0>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 1.86e-01\n",
" final error(valid) = 1.83e-01\n",
" final acc(train) = 9.46e-01\n",
" final acc(valid) = 9.48e-01\n",
" run time per epoch = 13.08\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.20\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xl8lOW9///XNTPZk5khCRAgbJkBAZHNsEW2kIUdAQW0\n7bfnfGt7Tm1Ptcs5LmhbWr62tKL22HN6eko9VH/2tAEVi2ELCbJGFEVAcIFJAgHZEzLZ1/v6/TE1\nTcqSwD1JJsnn+ZeT3HPPNe/HJZ987u1SWmuNEEIIIQKepaMHIIQQQojWkaIthBBCdBJStIUQQohO\nQoq2EEII0UlI0RZCCCE6CSnaQgghRCchRVsIIYToJKRoCyGEEJ2EFG0hhBCik5CiLYQQQnQSto4e\nwPWcO3fOr/uLjY3lypUrft1ndyMZmicZmicZmicZ+oe/c+zbt2+rtpNOWwghhOgkpGgLIYQQnYQU\nbSGEEKKTCMhz2kIIIdqH1prq6moMw0Ap1dHD6TQuXrxITU3NLb1Ha43FYiE0NPS2s5aiLYQQ3Vh1\ndTVBQUHYbFIOboXNZsNqtd7y++rr66muriYsLOy2PlcOjwshRDdmGIYU7HZks9kwDOO23y9FWwgh\nujE5JN7+zGTe5Yu28ear1H76UUcPQwghhDCtSxdtXXoVvXsbV5/8Zxqe/yH6s2MdPSQhhBAdYO3a\ntWzYsAGAjIwMLly4cMv7eOWVVxr3cSOffPIJ3/3ud29rjK3RqhMZhw8fZt26dRiGQUpKCosWLWr2\n+8zMTHJycrBardjtdh5++GF69uwJwPLlyxkwYADge4LM448/7uevcGPK3gPL6t8T/v4eyt94FWPN\nChgyAsu85TBijBwWEkKITqKhoaHZhV9///pG6uvrAV+h3rZtGwAbNmxg2LBhxMXFtfg5TX31q19t\n8fOGDx/O+fPn+fzzz+nXr1+L29+qFou2YRi89NJLPP3008TExPDkk0+SmJhIfHx84zaDBg1i9erV\nhISEkJWVxauvvsr3vvc9AIKDg3n22Wf9PvDWUiGhRNz7JSrHT0fv24He9gbGr34Mg4f6iveoRCne\nQgjRwV5//XX+53/+h9raWsaOHcvPf/5zhg0bxle+8hX27t3Lz372M77zne+wcOFC9uzZw7e+9S1c\nLhdPPPEE1dXVDBw4kOeeew6n08n999/PiBEjOHjwIPfeey/Dhw9n5MiR2Gw2MjMzOXLkCP/yL/9C\naGgomzZtYsaMGc32W15ezh//+Edqa2sZPHgwL774ImFhYTz33HNERETwzW9+k8WLFzNmzBhyc3Px\ner0899xzTJw4EYC0tDT+8pe/8K1vfcvvObVYtD0eD3FxcfTu3RuApKQkDh482Kxojxw5svG/hwwZ\nwt69e/0+ULNUcAhq5nz01Fnod3LQW17D+I9V0H+wr3iPnYSydOmzBUIIcVPGn9eizxT4dZ+q/2As\nD3zjptucPHmSTZs28eabbxIUFMSTTz7JG2+8QWVlJWPHjuXHP/5x47Y9evRg+/btAKSmprJq1Som\nT57Ms88+y/PPP89Pf/pTAOrq6ti6dSsAa9asYdSoUQDMnz+fP/zhD/zwhz9k9OjR191vcXExX/7y\nlwH4xS9+wZ/+9Ce+9rWvXTPu+vp6Nm/eTE5ODs8//zwZGRkAjB49mv/4j//omKJdXFxMTExM4+uY\nmBhOnjx5w+137tzJmDFjGl/X1dXxxBNPYLVauffee5kwYcI178nOziY7OxuA1atXExsbe0tfoiU2\nm635Ppd8Bb3wAar3ZFHx+is0/HY11v6Dibj/Hwi9JwV1G/fedXXXZChumWRonmRo3t9nePHixcZb\nvuotFgw/H3m0WCwt3lKWm5vLRx99xLx58wDfveO9evVqrBtfHK5WSrF48WJsNhulpaWUlpYydepU\nAB588EG+/vWvY7PZmm0HcPnyZe64447G10op
rFZrs9dNt/d4PKxevRqv10tFRQXJycnYbDYsFkuz\n77NgwQJsNhtjx47l7NmzjT/v3bs3ly5duuH3DgkJue157Neb8/bs2UN+fj4rV65s/NlvfvMboqOj\nuXjxIj/96U8ZMGDANecRUlNTSU1NbXzt7xVobrgay6gJ6JF3o97fT8Pm9ZS+sJLSP/4ONXcpauJ0\nlNy72EhWBjJPMjRPMjTv7zOsqan52zncZQ+1ydXJX5xXvpGGhgaWLl3Kk08+2eznv/nNb9BaN75f\na01ISAj19fXU19c3+13T1023A1+RrKysbLafhoaG6+4X4JFHHuGll17izjvvJCMjg3feeYf6+noM\nw8AwjMbtrFZrs8//4r8rKiqa7e/v1dTUXDOP/bbKV3R0NEVFRY2vi4qKiI6Ovma7o0ePsnHjRh57\n7DGCgoKavR98f3mMGDGCU6dOtWpg7UVZrFgmTMPy4xexPPwEhIai//DvGE9/E2P3NnRdXUcPUQgh\nurQpU6aQmZnZWMiuXr3K2bNnb/oeu92Ow+Hg3XffBXznxCdNmnTdbd1ud7PaExERQXl5+Q33XV5e\nTu/evamrq2Pjxo23+G0gPz+fO+6445bf1xotFm2Xy8X58+e5dOkS9fX15ObmkpiY2GybgoIC1q5d\ny2OPPYbD4Wj8eXl5OXV/LXqlpaV89tlnzc6FBxJlsaDGJWF5+gUs3/kh2J3oV3+DseKfMHLeQtfe\n2jNmhRBCtM7QoUN57LHHePDBB0lNTeXBBx/k4sWLLb7vV7/6FatWrSI1NZXjx483XgD992bOnNlY\n3AGWLVvGE088QVpaGlVVVdds/2//9m/Mnz+fRYsW4Xa7b/n75ObmkpKScsvvaw2ltdYtbXTo0CFe\nfvllDMMgOTmZJUuWkJGRgcvlIjExkVWrVlFYWIjT6QT+dmvXZ599xu9+9zssFguGYTBv3jxmzpzZ\n4qDOnTtn/ps1cTuH1LTW8MkRjM0ZcOI42J2o9MWo6bNRobf3zNjOTA5LmicZmicZmvf3GVZWVhIe\nHt6BI2ofDz30EE899RQJCQl+2Z/NZrvu4e+amhruu+8+3nzzzRue075e5q09PN6qot3eAqFoN6VP\nHMPIzIBPjkBkFCr1XlTyPFR4hB9HGdjkH0vzJEPzJEPzumvR9ng8XLly5YaH0G/VjYp2fn4+Fy5c\nICkp6YbvNVO05UqrVlBDR2L9/kh03qcYm9ej33wVnbURNXMBKnUBKiKqo4cohBDiJtxu920d6r5V\nCQkJfuvmr0eK9i1QrmFYH/kR+nQexuYMdOaf0Tv+gkqei0q7F2V3dvQQhRBCdGFStG+DGujC+q0V\n6M9PozevR29/A73zLdS02ahZi1HOmJZ3IoQQQtwiKdomqH4DUf/0b+iFD6K3vIbemYnetRU1JQ01\newkqpldHD1EIIUQXIkXbD1RcPOpr30UveAC99TX03iz03u2oyTNRc+5H9erT0UMUQgjRBcjDtv1I\n9YzD8tV/wfKz/0ZNm40+sAvjhw9jvPQC+vzNHxQghBCi7TRdmvNWffe73yUzMxOAf/3Xf+XEiRPX\nbJORkcFTTz0FwLp16/jzn/98+4O9Cem024CK7on60j+j5y5F73gTvWsr+t1dqLvvQc1bhoof1NFD\nFEKITsWfS3OasWbNmha3eeCBB7j33nt54IEHTH/e35NOuw0pZzSWpV/Dsvr3qNn3oY99gPGTR2j4\nz2fQpz0dPTwhhAgYr7/+OvPmzSMtLY3HHnuMhoYGhgwZwk9+8hNSU1P54IMPmDhxIs888wyzZs0i\nMzOTY8eOMX/+fFJTU3nooYcoKSkB4P777+dHP/oRc+bM4fe//z379+9vXJrT4/E0LkwCcObMmcan\nl73wwgvMnTuXmTNn8thjj3G9x5jcf//9HDlyBPD9ITBlyhTmzZvH+++/37hNWFgY/fv358MPP/R7\nTtJptwMV
5UAt+Sp61hJ0zlvonE0Y/+9dGHk3lnnLUO7hHT1EIYTg9+9fpOBqtV/3ObhHKF9P7H3T\nbdpzaU63201tbS2FhYU
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435c54358>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfMAAAENCAYAAAAbl4wiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xt0VFWe//33rsr9nqqQhJCAEAIGBLmEi4hAIISbKN76\nIoIKqH0Du9XuaXt4ftO/Xu00M+ha/TxrqTPdjzqjPTzt9G8avCYk4S4gQRQBAVNJuCeQW1VSSaoq\nqTr7+aMk3SgaMCFVlXxff5lVJ6e+tQ355Jyzv3srrbVGCCGEECHLFOgChBBCCNEzEuZCCCFEiJMw\nF0IIIUKchLkQQggR4iTMhRBCiBAnYS6EEEKEOAlzIYQQIsRJmAshhBAhTsJcCCGECHES5kIIIUSI\nCwt0Aderpqam186VkpJCQ0NDr51vIJIx7DkZw56TMewdMo4919tjmJGRcU3HyZW5EEIIEeIkzIUQ\nQogQJ2EuhBBChLiQe2b+ZVpr3G43hmGglLqu77106RIej+cGVRZ6tNaYTCaioqKueyyFEEIETsiH\nudvtJjw8nLCw6/8oYWFhmM3mG1BV6PJ6vbjdbqKjowNdihBCiGsU8rfZDcP4VkEuri4sLAzDMAJd\nhhBCiOsQ8mEut4N7n4ypEEKElpAPcyGEECLQnB4fb51o4qOzjoC8v4R5ENBa88ADD+B0OmlubuY/\n/uM/vtV5VqxYQXNz8zce85vf/IYPPvjgW51fCCHE32itqWhw8X/vr2XV5kpe/biOfaebAlKLhHkQ\n2LZtG2PGjCE+Pp6WlhZef/31qx7n9Xq/8TxvvPEGiYmJ33jMqlWrePHFF791rUIIMdC5vQYllQ6e\nLj7Nz7eeYd/ZFvKHJ/L7xTexbtaIgNQkM8d6wapVq6ipqcHj8bB69WoeeughduzYwYYNG/D5fFgs\nFv77v/+btrY21q9fz5EjR1BK8bOf/YwlS5awefNmli9fDsA///M/c+bMGebPn8+sWbOYN28eGzdu\nJDExkcrKSj744IOrvh/AtGnTKCoqoq2tjYceeoipU6fy0UcfkZ6ezquvvkp0dDSZmZnY7Xbq6upI\nTU0N5LAJIURIOdfsodjmYEd1M22dBsMSI3liShpzhicQEx7YzqhrCvPDhw/z2muvYRgG8+bNY9my\nZVe8Xl9fz8svv0xLSwtxcXGsXbsWq9UKwHe/+12GDh0K+Nes/Yd/+AcA6urq+P3vf4/T6WTEiBGs\nXbu2x7PSjT//EX3u1LUfrxRa6288RmUNx/S9x77xmBdeeIHk5GRcLhdLlixhwYIF/PznP+evf/0r\nQ4cOxW63A/D73/+e+Ph4tm3bBoDD4X+2cvDgQf7lX/4FgF/96ld8/vnnlJaWArBv3z6OHj3K9u3b\nu8bxy++3ePFiLBbLFTWdOnWKF198kY0bN/LEE0/w/vvvc9999wEwbtw4Dh48yJIlS655rIQQYiDq\n9GkOnHdSZHNw7FI7YSaYkZXAolFJ5A6KDpoJw92mp2EYvPLKK6xfvx6r1cqzzz5LXl4emZmZXce8\n8cYbzJo1izlz5nDs2DE2bdrE2rVrAYiIiGDjxo1fOe+f/vQnlixZwu23384f/vAHtm/fTmFhYS9+\ntL7z6quvUlRUBPg3gvnTn/7E9OnTu8I3OTkZgD179vDSSy91fV9SUhLgD/W4uLivPf+ECRO6znW1\n9zt16tRXwjwrK4tbbrkFgPHjx3Pu3Lmu16xWK5cuXfrWn1cIIfq7+rZOSiodlFY6sLt9pMaGs2LC\nIAqyE0mKCr6b2t1WVFlZSXp6OmlpaQDMmDGDgwcPXhHm58+fZ+XKlQCMHTv2quH997TWfPbZZzz5\n5JMAzJkzh7/85S89DvPurqC/LCwsrNvn0N3Zt28fe/bs4Z133iE6Opr777+fsWPHUlVVdV11GIaB\nyXT1KQwxMTHf+H5XW8UuMjKy67/NZjNut7vr
a4/HQ1RU1DXXJ4QQA4GhNYdr2yiyOfjoQitaQ96Q\nWBbmJDNxcCxmU3BchV9NtxPgmpqaum6Zg/+qrqnpytl6w4YNo7y8HIDy8nJcLhdOpxOAzs5OfvnL\nX/KP//iPXcc4nU5iYmK6Vl+zWCxfOWeocDqdJCYmEh0dTWVlJR9//DEej4cPP/yQs2fPAnTdZp81\na9YVM9Uv32YfMWIEZ86cASA2NpbW1tbrer/rVV1dzejRo6/7+4QQoj9qcXv56/FGfvh2Nf97x3k+\nb3Bx7xgr/373CNbPySJvSFxQBzn00gS4FStW8Oqrr7Jz505yc3OxWCxdV5kvvfQSFouFS5cu8Zvf\n/IahQ4decaXZnbKyMsrKygDYsGEDKSkpV7x+6dKlHj1r7+lz+oKCAv70pz8xZ84csrOzmTx5Mqmp\nqbzwwgs89thjGIZBSkoKf/nLX3j66af55S9/ydy5czGbzTzzzDMsWbKEwsJCDhw4QE5ODqmpqUyd\nOpW5c+cyb948CgoKUEp11Xm19zObzYSFhaGUwmw2d/2RdPl7TCYTJpOJsLAwOjs7OXPmDJMnT/7a\nzx4ZGfmVcf6m8bvWY8XVyRj2nIxh7xhI46i15litk81Ha9lha6DDp5kwJIEfzhzB7JFWws3frtkr\nUGOodDczwCoqKvjLX/7CP/7jPwKwefNmAO65556rHu92u/npT3/Kv/3bv33ltRdffJHJkyczbdo0\n1qxZwx/+8AfMZvNX3uOb1NTUXPF1e3v7df1x8Pd64zZ7b7h06RJPPvkkf/7zn2/4exUVFXH06FF+\n8YtffO0x1zOmKSkpNDQ09FZ5A5KMYc/JGPaOgTCOrk6DXaebKbY5OGX3EB1mIn9EAgtzkhmWFNn9\nCbrR22OYkZFxTcd1+6dHdnY2tbW11NXV4fV62bdvH3l5eVcc09LS0rWe9+bNm8nPzwegtbWVzs7O\nrmM+//xzMjMzUUoxduxYPvzwQwB27tz5lXMOJGlpaTz44INdjyZuJK/XyxNPPHHD30cIIYLJGYeH\nfyu/yKN/reTlcv8E4B9NTee1e0fyxJT0XgnyQOr2HrPZbGbVqlU899xzGIZBfn4+WVlZvPnmm2Rn\nZ5OXl8fx48fZtGkTSilyc3NZvXo1ABcuXOAPf/gDJpMJwzBYtmxZ18S55cuX8/vf/54///nPDB8+\nnLlz597YTxrk7rrrrj55n6VLl/bJ+wghRKB1+gz2n2ulqMLO8XoX4SbF7cPiWZSTzOiU/rXVc7e3\n2YNNf7zNHmzkNnvfkjHsORnD3tFfxvFSawdbbQ7Kqppp9vhIjwtnYU4S80YkknCD28oCdZs9+Jrl\nhBBCiOvkMzSf1LZRVGHnUE0bSsGUIXEszEliwuBYTP3oKvxqJMyFEEKELIfbS1llM1sr7dS1eUmO\nMvPALVYKRyYxKDa8z+rQjib0jvfwTJgKw/u+9VfCXAghREjRWnO83kVRhZ3955x4DRiXFsMjE1OZ\nlhVPWB/2hOuas+iSLegDO8HnozM2NiBhLrumBYG/3wL128jJyQHg4sWLPPbY1VfBu//++/n0008B\n/3r5lxesEUKIUNHe6eO9z+2se+8Uvyo9y8c1bSzKSebFO4fz24Kh3D4soU+CXGuNPnkE3//zG4x/\n+gn64G7UHYWYfvsycQ8+fsPf/2rkyjwI/P0WqD2Rnp7OH//4x26Pu++++/jP//zPruV0hRAimFU3\nuSm2Odh1uhm3VzPSEsXa6encMSyByLC+uybVXi/60F50yRY4WwXxiai7l6PmLELFJfRZHVcjYd4L\nensL1IyMDB555BHAv0NabGwsK1as4NFHH6W5uRmv18svfvELFixYcEUd586d4+GHH2b79u24XC6e\neuopjh8/zsiRI69Ym72wsJB7771XwlwIEbQ6fAZ7z/h3K/u8wUWEWXHHMP9uZTnW6D6tRbvb0XtK\n0WVvQ1M9
pGeiVv4ENX0OKjyiT2v5Ov0qzP/fjy5xyu7u/sAvqGvYAnV4chRr8tK+8Zje3AL1rrvu\n4p/+6Z+6wvydd97hv/7
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc43577cc88>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 1.74e-01\n",
" final error(valid) = 1.75e-01\n",
" final acc(train) = 9.49e-01\n",
" final acc(valid) = 9.51e-01\n",
" run time per epoch = 15.16\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.50\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3XtcVNe9///XmhnuMIMMIArijHi/gmJUooKKd02Midq0\n/ban7WnTtKf5Jr3k3h570qS2NTk9bb+/0zax1p70tESN1uA1aAAN0ZioMRc1GgfwjoIyXOS61++P\naSg2KugeYNDP8/Ho49GBPXvWfmfJZ/bea6+ltNYaIYQQQgQ8S1c3QAghhBDtI0VbCCGE6CakaAsh\nhBDdhBRtIYQQopuQoi2EEEJ0E1K0hRBCiG5CirYQQgjRTUjRFkIIIboJKdpCCCFENyFFWwghhOgm\nbF3dgKs5ffq0X/cXGxvLhQsX/LrP241kaJ5kaJ5kaJ5k6B/+zrF3797t2k7OtIUQQohuQoq2EEII\n0U206/L4gQMHWLlyJYZhMG3aNBYsWHDF73Nzc9m+fTtWqxW73c6DDz5IXFwc58+fZ/ny5RiGQXNz\nM7NmzWLGjBkdciBCCCHEra7Nom0YBitWrODpp5/G6XTyxBNPkJ6eTlJSUss2LpeLZcuWERISwrZt\n23j55Zd55JFH6NGjBz/5yU8ICgqirq6O733ve6SnpxMTE9OhByWEEKJ9tNbU1dVhGAZKqa5uTrdx\n7tw56uvrb+g9WmssFguhoaE3nXWbRfvYsWMkJCTQs2dPADIyMti7d+8VRXv48OEt/3/AgAHs3LnT\nt3PbP3bf2NiIYRg31UghhBAdo66ujqCgoCv+Xou22Ww2rFbrDb+vqamJuro6wsLCbu5z29qgoqIC\np9PZ8trpdHL06NFrbr9jxw5SU1NbXl+4cIFly5Zx9uxZvvjFL8pZthBCBBDDMKRgdyKbzXbDZ+hX\nvN+PbaGwsJDjx4+zdOnSlp/FxsayfPlyKioq+MUvfsH48eOJjo6+4n15eXnk5eUBsGzZMmJjY/3Z\nLGw2m9/3ebuRDM2TDM2TDM375wybm5ulaN+km80tNDT0pvtxm58YExNDeXl5y+vy8vKrni0fPHiQ\ndevWsXTpUoKCgq66nz59+nD48GHGjx9/xe+ys7PJzs5uee3PZ9+M9S8TPXEa3theftvn7Uie7TRP\nMjRPMjTvnzOsr6+/qcu8tzubzUZTU9NNvbe+vv4z/dhvz2mnpKRw5swZysrKaGpqoqioiPT09Cu2\n8Xg8vPjiizz66KM4HI6Wn5eXl9PQ0ABAdXU1R44caXfD/EF7L6ILtnDxiQdofuGH6CMfdNpnCyGE\nCBwvvvgiq1evBiAnJ4ezZ8/e8D7+9Kc/tezjWg4dOsTDDz98U21sjzbPtK1WK1/96ld59tlnMQyD\nKVOm0KdPH3JyckhJSSE9PZ2XX36Zuro6XnjhBcD3Te6xxx7j1KlT/OlPf0Iphdaa+fPnk5yc3GEH\n88+UvQeWZS8R/k4h1a++jLH8SRgwFMvcJTA0VUZKCiFEN9Hc3HzFFYF/fn0tn54N5+TksGXLFgBW\nr17N4MGDSUhIaPNzWvvSl77U5ucNGTKEM2fOcOrUKRITE9vc/kYprbX2+15N6ohpTM+fPoXe9Tp6\ny6tw8QK4B/qK98h0Kd7tIJclzZMMzZMMzfvnDGtrawkPD+/CFvmsXbuWP/zhDzQ0NJCWlsZPf/pT\nBg8ezBe/+EV27tzJc889x3e+8x3uuusuCgsL+da3vkVKSgqPP/44dXV19O3bl+eff57o6Gjuu+8+\nhg4dyt69e7n77rsZMmQI69at45e//CW5ubk88sgjJCQkEBoayoYNG8jKyrpiv9XV1fz5z3+moaEB\nt9vNr371K8LCwnj++eeJiIjgm9/8JosWLSI1NZWioiIqKyt5/vnnGTduHAAvvfQSDQ0NfOtb37rq\nsV4t8/Zehb5tRh+o4BDU1HnoSTPRb21Hb1qD
8ZtnoI/bV7zTxqMsMkGcEOL2Zfz1RfQJj1/3qfq4\nsXzu69fd5ujRo2zYsIH169cTFBTEE088wauvvkptbS1paWn8+7//e8u2PXr0YOvWrYBvPNQzzzzD\nhAkT+MUvfsELL7zAf/zHfwC+x4w3b94MwPLlyxk5ciQA8+bN449//CM//OEPGTVq1FX3W1FRwRe+\n8AUAfvazn/GXv/yFr371q59pd1NTExs3bmT79u288MIL5OTkADBq1Ch+85vfXLNom3HbFO1PqaAg\n1ORZ6Ixs9NsFvuL922XQqw9q7mLU2IkoiwzKEEKIzrJr1y7ef/995syZA/ieHY+NjcVqtTJ37twr\ntr3rrrsA8Hq9VFZWMmHCBAAWLVrEAw888JntAMrKyhgwYMB129B6+yNHjvDzn/8cr9dLTU0NmZmZ\nV33Pp+0dOXIkJ0+ebPm50+nk3LlzbR73zbjtivanlM2GypiGHp+FfudN9MZX0C89j97wF9ScRahx\nmSh5DEIIcRtp64y4o2itWbRoEU888cQVP//tb3/7mfvL7b2U33q70NBQ6urq2r39I488wooVKxg2\nbBg5OTm89dZbV31PcHAw4Bv71XokeX19PaGhoe1q54267a8HK4sVyx2Tsfz7r7A8+DiEhqL/+F8Y\nT38To2ALurGxq5sohBC3tIkTJ5Kbm9tyr/3ixYtXnLlejd1ux+FwsGfPHsB3T/yfHyf+VP/+/Sku\nLm55HRERQXV19TX3XV1dTc+ePWlsbGTdunU3eDRw/PhxBg0adMPvaw85lfw7ZbHA6AwsaRPg/Xcw\ncnPQL/9/6Nwc1KyFqEkzUMEhXd1MIYS45QwcOJBHH32U+++/H601NpuNZ599ts33/fKXv2wZiJac\nnNzyBNM/mzp1Kg899FDL68WLF/P444+3DET7Zz/4wQ+YN28eTqeTtLS06xb4qykqKmLatGk39J72\num1Gj9/oiFOtNRx6D2NjDnz8IdijUTPuQWXOQoXe3Jyx3ZmM2jVPMjRPMjQvUEePd7Svfe1rPPXU\nU/Tr188v+7vW5Cr19fXce++9rF+//pozpsno8Q6glIKhqViHpqI//sB35r1mJXrLGlT23agpc1Hh\nEV3dTCGEEO3wxBNPUFZW5reifS2nTp3iySef7LCpYaVot4MaOBzrd4ejPzmMsfEV9PqX0dvWoabO\nR2XPR0VEdXUThRBCXEf//v3p379/h39Ov379OvSLgRTtG6BSBmN96Efokk8wNuagc/+Kfv1vqClz\nUNPvRtmj296JEEIIcZOkaN8E1TcF67eeRJ8q8T0qtvVV9I7XUJNnoWbeg4p2tr0TIYQQ4gZJ0TZB\nJfZFfeMH6LvuR29ag96Ri87fjJo43Tfi3Bnf1U0UQghxC5Gi7QcqIQn11YfR8z+H3rwGvXMbeudW\n1ISpqNn3oeJlWVAhhBDm3faTq/iTikvA8qV/w/Lc73xTpe7Ox/jhgxgr/hN95voTBQghhOg4rZfm\nvFEPP/wwubm5AHz/+9/n448//sw2OTk5PPXUUwCsXLmSv/71rzff2OuQM+0OoGLiUJ9/AD1nEfr1\n9ej8zeg9+agxd/rmN09ydXUThRCiW/Hn0pxmLF++vM1tPve5z3H33Xfzuc99zvTn/TM50+5AKjoG\ny6KvYln2EmrWvegP3sX48UM0/79n0SXHurp5QggRMNauXcvcuXOZPn06jz76KM3NzQwYMIAf//jH\nZGdn8+677zJu3DieffZZZs6cSW5uLh988AHz5s0jOzubr33ta1y6dAmA++67jx/96EfMnj2bl156\niTfffJPhw4djs9k4duzYFYuQnDhxomX2sv/8z/9kzpw5TJ06lUcffZSrzT1233338d577wG+LwIT\nJ05k7ty5vPPOOy3bhIWF0adPH/bv3+/3nORMuxOoKAdq4ZfQMxeit7+G3r4B4yd7YPgYLHMXo/oP\n6eomCiEE
L71zDs/F6y+scaPcPUL51/Se192mM5fm7N+/Pw0NDZSWlpKcnMyGDRuYP38+AP/yL//C\nI488AsB3vvMdXn/9dWb
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435b3abe0>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfoAAAENCAYAAADngqfoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3XlcVFee///XLYp9rwJBRYlCjLgvuK8IEoNZ7Kw9bYyJ\nZjrpTGMSJ91p++u3M9/MpMcZe5tfZ3lM98R0x7SPNpvpLJoAKkbFgJqYaNRoAS4EFQRqYSmoqnt+\nf1RSacYFlQIK/Dz/6vJu556u8K5zl/PRlFIKIYQQQvRJhp5ugBBCCCG6jgS9EEII0YdJ0AshhBB9\nmAS9EEII0YdJ0AshhBB9mAS9EEII0YdJ0AshhBB9mAS9EEII0YdJ0AshhBB9mAS9EEII0YcZe7oB\n/lJdXe23fSUkJHD+/Hm/7e96Jf3YedKHnSd92HnSh53n7z4cMGDAFa8rI3ohhBCiD5OgF0IIIfow\nCXohhBCiD+sz9+j/N6UUTqcTXdfRNO2qtj137hytra1d1LLeRymFwWAgLCzsqvtSCCFEz+qzQe90\nOgkODsZovPpTNBqNBAUFdUGrei+3243T6SQ8PLynmyKEEOIq9NlL97quX1PIi4szGo3out7TzRBC\nCHGV+mzQyyVm/5M+FUKI3qfPBr0QQgjR05RSqCOf0/Tmn3usDRL0AU4pxT333IPD4cBms/GnP/3p\nmvazZMkSbDbbZdd59tln2bVr1zXtXwghxHeUUqiD+9H/42n03/xfmj/chHK29Ehbrugm9oEDB3jl\nlVfQdZ3s7GwWLVrUbnltbS0vvfQSdrudqKgo8vPzMZvNANx3330MHjwY8M4M9PTTTwNQU1PD7373\nOxwOB0OHDiU/Px+j0YjL5eL555+noqKC6OhonnjiCfr16+fPc+5Vtm7dyogRI4iOjub06dO8+uqr\nPPjggxes53a7L/tMwvr16zs81rJly/jJT37CzJkzO9NkIYS4bildh8/L0D94HU5awJSItvhREm6/\njzq7o0fa1GHQ67rOyy+/zOrVqzGbzaxatYrMzExSUlJ866xfv57Zs2czd+5cDh06xIYNG8jPzwcg\nJCSEtWvXXrDf1157jYULFzJjxgz+8Ic/sG3bNnJzc9m2bRuRkZH8/ve/Z/fu3fzlL3/hySef9OMp\nd69ly5ZRXV1Na2sry5cv5/7772f79u2sWbMGj8eDyWTi9ddfp6mpidWrV/PFF1+gaRpPPvkkCxcu\nZNOmTSxevBiAX/7yl5w8eZL58+cze/ZssrOzWbt2LbGxsVgsFnbt2nXR4wFMmTKFLVu20NTUxP33\n38/kyZPZt28fycnJrFu3jvDwcFJSUmhoaKCmpua6/nElhBBXS+ke1P49qA82wtcnITEZbWk+2tS5\naMZgtJBQIECD3mKxkJycTFJSEgDTp09n79697YK+qqqKBx54AICRI0deNNj/nlKKL7/8kscffxyA\nuXPn8sYbb5Cbm8u+ffu45557AJg6dSrr1q1DKdWpB8H0v/4RdbryytfXNJRSl11HGzQEw/f/scN9\n/frXvyY+Pp6WlhYWLlzIzTffzE9+8hPefvttBg8eTENDAwC/+93viI6OZuvWrQBYrVYA9u7dy3/8\nx38A8POf/5yvvvqKwsJCAEpKSjh48CDbtm3zXTX538fLy8vDZDK1a1NlZSUvvPACa9eu5ZFHHmHz\n5s3cddddAIwePZq9e/eycOHCK+0uIYS4bimPB1X2MWrzG3C2CpJT0JY/iTZpNlqAvKbdYdDX19f7\nLsMDmM1mjh8/3m6d1NRUysrKyMvLo6ysjJaWFhwOB9HR0bhcLn72s58RFBTEHXfcweTJk3E4HERE\nRPjeVTeZTNTX119wvKCgICIiInA4HMTExPjtpLvTunXr2LJlC+AtvPPaa68xdepUXzDHx8cDsHPn\nTl588UXfdnFxcYA38KOioi65/3Hjxvn2dbHj
VVZWXhD0gwYNYtSoUQCMGTOG06dP+5aZzWbOnTt3\nzecrhBDXA+V2ofZsR215E2rPQsoNGB75KUyYhmYIjID/ll9eNF+yZAnr1q2juLiYjIwMTCYTBoP3\nOb8XX3wRk8nEuXPnePbZZxk8eDARERGdPmZRURFFRUUArFmzhoSEhHbLz50799096/t/1OnjXYvd\nu3eza9cuNm/eTEREBN/73vcYM2YMFRUVF9xP1zQNo9F4wb8bjUYMBgMGg8H3w+jbdYKCgoiMjPR9\nvtjxvr13r2kaQUFBBAUFERoa6tsmODiYtrY232eXy0VERMRF7/eHhoZe0M+XYzQar2p9cSHpw86T\nPuw86cPvqLZWWra+T9Om11C15zCmDydy+ROETpqJZrj08+092YcdBr3JZKKurs73ua6u7oIRoslk\n4qmnngK8M9KVlpYSGRnpWwaQlJTEiBEjOHHiBFOmTKG5uRmPx0NQUBD19fW+9b49ntlsxuPx0Nzc\nTHR09AXtysnJIScnx/f5f5f/a21tvebZ7YxGI263+5q2/XtWq5WYmBhCQkI4evQo+/fvp7m5mT17\n9lBRUeG7dB8fH8+sWbP4n//5H5599lnftnFxcQwdOpTy8nKGDBlCWFgYjY2NvrZ5PB6UUr7PFzue\nx+PB7XajlMLj8eDxeAB82+i6jq7rvs8Wi4W8vLyLnn9ra+tVlVmU0padJ33YedKHnSd9CKq1FfXx\nh6iPNoGtHtKGY/jBo+gjJ9CoaTR+c1X6UgK6TG1aWhpnzpyhpqYGt9tNSUkJmZmZ7dax2+2+WdM2\nbdpEVlYWAI2NjbhcLt86X331FSkpKWiaxsiRI/nkk08AKC4u9u1z4sSJFBcXA/DJJ58wcuTIXjtR\ny9y5c/F4PMyZM4df/vKXTJgwAbPZzH/+53/y8MMPk5OTw49+5L3a8Pjjj2Oz2Zg3bx45OTmUlJQA\nkJ2dzZ49ewDvj6BJkyYxb948/vVf//WKjnc1XC4XJ06cYOzYsZ08cyGE6BuUsxl9y1voqx5Gvf4y\nJA/EsPJfMTz9H2ijJnaYT25d8clpB6+UnuqmFl9IUx09dQZ8+umn/PnPf0bXdbKysrjzzjvZuHEj\naWlpZGZm8sknn7BhwwY0TSMjI4Ply5cTHBzMV199xR/+8AcMBgO6rrNw4ULmzZsHeC+t/+53v6Ox\nsZEhQ4aQn5/vu4z8/PPPU1lZSVRUFE888YTvQcDLqa6ubve5ubn5mm8R+GtE7w/nzp3j8ccf569/\n/WuXH2vLli0cPHiQn/70pxddfrV9KqOAzpM+7Dzpw867HvtQNTeitr6PKnoXmhth1AQMC+9FSx9x\nRdufcbRRVG5ja7mVBqeHpKhQfr8wlVCjf6avuZoR/RUFfW/QV4Me4N133yUrK+uitzD86b333mP2\n7NnExsZedLkEffeTPuw86cPOu576UDnsqKK/obZ/AC3NMHYyhoX3oQ25scNtXR6dPacbKbRY+eJc\nMwYNJg6IYn56LDePuQFrfV2H+7hSVxP0UvWlF7j99tu75Ti33XZbtxxHCCECjbI1oAreQe3YAm2t\naBOmoy28F23QkA63PWVrpdBiZXulHUerh36RwSwek0B2WizmiGAAjIaeuwUtQS+EEOK6perPoz56\nG7WzANxutMmz0PLuQRsw+LLbOd06u0/aKbDYOHq+BaMBpqREk5sex5jkCAwB9GyZBL0QQojrjqo9\ni/rwLdTurYBCm5qFlnc3Wr/LXxIvr3dSYLHy8Qk7zS6dgTEhPDg+kayhscSFBWakBmarhBBCiC6g\nzlWjNr+B+mQ7GAxoM3PQFtyFlnDph76bXR52VNopLLdSXt9KSJDG9MHe0fuIxPCAfzNMgl4IIUSf\np74+hdr8OmrvLjAa0bIWot18J1q8+eLrK8XR8y0UWGzsPmmn1aMYEh/KDzOTmHNDDFGhgTX73eVI\nmdoA9/dl
aq/FjTd6nxQ9e/Ys//iPF5+b/+677+bzzz8HvNUGv51nXwghejt1qhzPS2vQ/+XHqM/L\n0HIXYVjzRwzf/8eLhry
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435b957f0>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 1.68e-01\n",
" final error(valid) = 1.75e-01\n",
" final acc(train) = 9.50e-01\n",
" final acc(valid) = 9.50e-01\n",
" run time per epoch = 14.17\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=1.00\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3XlclOe9///XNTMwbAPIgKACLkjc0KCiAu6KK8YtUdPl\ne35tcnoS2yZp2nNiNGlr60nrabP0tGmbxuZYc9LTEk00CW4EjcEIGqIxahYVMYIrCiqDyHpfvz8G\nJxgX0BlgwM/z8cjj4TD33Pc1n9z65t6uj9Jaa4QQQgjh9UxtPQAhhBBCNI+EthBCCNFOSGgLIYQQ\n7YSEthBCCNFOSGgLIYQQ7YSEthBCCNFOWJqz0N69e1m5ciWGYTBx4kRmz5591fuZmZls2bIFs9lM\ncHAwCxcuJCIiggMHDrBq1SrXcidPnuSxxx5j+PDhnv0WQgghxB1ANfWctmEYPPbYYzz99NPY7XYW\nL17MY489RnR0tGuZAwcOEB8fj9VqJSsri08//ZTHH3/8qvVUVFTwyCOP8NJLL2G1Wlvm2wghhBAd\nWJOnxwsKCoiKiiIyMhKLxUJqair5+flXLZOQkOAK4vj4eMrKyq5Zz86dOxk8eLAEthBCCHGbmgzt\nsrIy7Ha767Xdbr9uKF+xdetWEhMTr/n5jh07GDly5G0OUwghhBDNuqbdXDk5ORQWFrJ06dKrfn7+\n/HmKioq4++67r/u57OxssrOzAVi+fDk1NTWeHBYWi4W6ujqPrvNOIzV0n9TQfVJD90kNPcPTdfT1\n9W3edptaICwsjNLSUtfr0tJSwsLCrllu3759rF27lqVLl+Lj43PVe3l5eQwfPhyL5fqbS0tLIy0t\nzfX63LlzzRp8c4WHh3t8nXcaqaH7pIbukxq6T2roGZ6uY9euXZu1XJOnx+Pi4jh16hQlJSXU1dWR\nm5tLUlLSVcscPXqUFStW8MQTTxASEnLNOuTUuBBCCOG+Jo+0zWYzDzzwAM888wyGYTB+/HhiYmLI\nyMggLi6OpKQkXnvtNaqqqnj++ecB528gixYtAqCkpIRz587Rv3//lv0mQgghRAfX5CNfbeHkyZMe\nXZ+cDnKf1NB9UkP3SQ3d9/Uaaq2pqqrCMAyUUm04svbFarVSXV19S5/RWmMymfDz87um1s09Pe7R\nG9GEEEK0L1VVVfj4+NzwniNxfRaLBbPZfMufq6uro6qqCn9//9varkxjKoQQdzDDMCSwW5HFYsEw\njNv+vIS2EELcweSUeOtzp+YdPrSNda9R88X+th6GEEII4bYOHdq6/Dz6/U2cX/wQ9c//FH3wQFsP\nSQghRBtYsWIFq1evBiAjI4PTp0/f8jpeffVV1zpu5PPPP+dHP/rRbY2xOTr0hQwV3AnT8r8S8FEO\nFW++hvHsEojvjyl9AfRPlNNCQgjRTtTX119149fXX9/IlVnLMjIy2LRpEwCrV6+mb9++REVFNbmd\nxv7lX/6lye3169ePU6dOceLECbp169bk8reqQ4c2gLL6ETjrm1QOG4v+4F30pjcxfvdz6HmXM7wH\nJUl4CyFEG3vjjTf4n//5H2pqahg8eDC//vWv6du3L9/+9rfZvn07v/rVr3jkkUeYOXMmOTk5fP/7\n3ycuLo4nn3ySqqoqunfvznPPPUdoaCj33Xcf/fv3Jz8/n1mzZtGvXz8SEhKwWCxkZmbyySef8MMf\n/hA/Pz/efvttxo0bd9V6Kyoq+Pvf/05NTQ09e/bk97//Pf7+/jz33HMEBgby8MMPM2fOHBITE8nN\nzeXixYs899xzjBgxAoBJkybx1ltv8f3vf9/jderwoX2F8rWiJsxAj56CztuC3rAG48VlENPTGd6D\nk1GmDn21QAghbsr45wp08VGPrlPF9MR0//duuszhw4d5++23WbduHT4+PixevJg333yTyspKBg8e\nzM9//nPXsp06dWLz5s2AcwrsZcuWkZKSwm9/
+1uef/55fvnLXwJQW1vLxo0bAXj22WcZNGgQADNm\nzOBvf/sbP/3pT6/qh9F4vWVlZXzrW98C4L/+67/4xz/+wQMPPHDNuOvq6li/fj1btmzh+eefJyMj\nA4C7776bF198UULbE5SPD2rMVHRqGvrD953h/dJy6BKDSp+PGjYKZbr1Z++EEELcng8++ID9+/cz\nffp0wPnseHh4OGazmfT09KuWnTlzJgDl5eVcvHiRlJQUAObNm8dDDz10zXLgnJkzPj7+pmNovPzB\ngwf5zW9+Q3l5OZcuXWLs2LHX/cyV8Q4aNIjjx4+7fm632zlz5kyT3/t23HGhfYWyWFCpE9HJ49Af\n7UCvfx391+fQb/8DNX0easRYlDy7KIS4gzR1RNxStNbMmzePxYsXX/Xzl1566ZrrywEBAc1aZ+Pl\n/Pz8qKqqavbyjz/+OK+88goDBgwgIyODvLy8637mSmcus9l8Vcev6upq/Pz8mjXOW3XHnw9WJjOm\n4WMw/fz3mBY+CX5+6L/9N8bTD2O8vwldW9vWQxRCiA5t1KhRZGZmuqZXPX/+/FVHrtcTHBxMSEgI\nu3btApzXxJOTk6+7bO/evfnyyy9drwMDA6moqLjhuisqKoiMjKS2tpa1a9fe4reBwsJC+vTpc8uf\naw45lGygTCYYkoppcArs/wgjMwP92p/QmRmoqXNRoyejfK1tPUwhhOhw7rrrLp544gm+8Y1voLXG\nYrHwzDPPNPm53/3ud64b0WJjY11Nq75uwoQJPProo67X8+fP58knn3TdiPZ1//Ef/8GMGTOw2+0M\nHjz4pgF/Pbm5uUycOPGWPtNc0jDkBrTW8PknGOsz4NCnEByKmjwHNXYqyu/25oxtz6RRg/ukhu6T\nGrrv6zWsrKxs9inn9uzBBx/kqaeeolevXh5Zn8ViueqU+BXV1dXce++9rFu37obTw16v5tIwxE1K\nKeifiLl/IvrQAeeR95qV6E1rUGmzUOPTUQGBbT1MIYQQzbB48WJKSko8Fto3cuLECZYsWdJi87lL\naDeDuisB848T0Ee+wFj/Onrda+istagJ96DS7kEF2tp6iEIIIW6id+/e9O7du8W306tXrxb9xUBC\n+xaouL6YH/0Z+tgRjPUZ6Mx/ot99CzV+OmrSLFRwaFsPUQghRAcmoX0bVPc4zN9fgj5xzPmo2OY3\n0VvfQY2ZipoyBxVqb+shCiGE6IAktN2gunVH/dt/oGd+A71hDXprJnrbRtSoSc47zu2d23qIQggh\nOhAJbQ9QUdGoB36Evud+9MY16O1Z6O2bUSkTUNPuQ3Xu0tZDFEII0QHc8ZOreJKKiML0Lz/E9Ku/\nOKdK3bkN46cLMV55AX3q5hMFCCGEaDmNW3Peqh/96EdkZmYC8O///u8cOnTommUyMjJ46qmnAFi5\nciX//Oc/b3+wN9GsI+29e/eycuVKDMNg4sSJzJ49+6r3MzMz2bJlC2azmeDgYBYuXEhERAQA586d\n46WXXqK0tBRw3nbfuXPHPm2swiJQ33wIPX0e+t116G0b0bu2oYaOdM5vHt2jrYcohBDtiidbc7rj\n2WefbXKZ+++/n1mzZnH//fe7vb2va/JI2zAMXnnlFZYsWcILL7zAjh07rplerkePHixfvpxnn32W\n5ORkXnvtNdd7L774IjNnzuSFF17g17/+NSEhIR7/Et5KhYZhmvcApuV/RU29F31gN8YvHqX+j8+g\njxW09fCEEMJrvPHGG6SnpzNp0iSeeOIJ6uvriY+P5xe/+AVpaWns3r2bESNG8MwzzzBlyhQyMzM5\ncOAAM2bMIC0tjQcffJALFy4AcN999/Gzn/2MadOm8de//pUdO3a4WnMWFBRc1YSkuLjYNXvZCy+8\nwPTp05kwYQJPPPEE15t77L777uOTTz4BnL8IjBo1ivT0dD766CPXMv7+/sTExPDxxx97vE5NHmkX\nFBQQFRVF
ZGQkAKmpqeTn5xMdHe1aJiEhwfXn+Ph4tm/fDsDx48epr693tURrqQnUvZ2yhaDm/gt6\nylz0lnfQW97G+M9dkDA
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435ca8908>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfMAAAENCAYAAAAbl4wiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xl4VGWe//33fVJZK1UJVQkJkJAAAiI7RnYSIAgIqIiK\n0yI6gq3dbaOPttO/xvHqmad/7TQz2NfVz/Slzoyj9mgP3bihrRDFKCRsJiigArJTgQAhIVtVlqpU\n1bmfP0qjKMqSSirL9/VfUafO+Z6bwCfnnG/dt9Jaa4QQQgjRZRmRLkAIIYQQbSNhLoQQQnRxEuZC\nCCFEFydhLoQQQnRxEuZCCCFEFydhLoQQQnRxEuZCCCFEFydhLoQQQnRxEuZCCCFEFydhLoQQQnRx\nlkgXcLlOnz4dtn2lpKRw7ty5sO2vJ5IxbDsZw/CQcWw7GcO2C/cY9u3b95K2kytzIYQQoouTMBdC\nCCG6OAlzIYQQoovrcs/Mv01rjdfrxTRNlFKX9dmzZ8/i8/naqbKuR2uNYRjExcVd9lgKIYSInC4f\n5l6vl+joaCyWyz8Vi8VCVFRUO1TVdQUCAbxeL/Hx8ZEuRQghxCXq8rfZTdO8oiAXF2axWDBNM9Jl\nCCGEuAxdPszldnD4yZgKIUTXIpe0QgghRBvpQABdshlvSioMHd3hx+/yV+bdgdaa22+/HY/HQ319\nPX/605+uaD9Lly6lvr7+B7f5zW9+w9atW69o/0IIIc6n/X7MzQWYT/wE/ad/x7ulMCJ1SJh3Ah98\n8AHXXHMNNpsNt9vNSy+9dMHtAoHAD+7n5ZdfJikp6Qe3WbZsGU8//fQV1yqEEAK0z4dZ+DfMx3+M\n/t9nIakXxkO/Jun//EtE6rmk2+x79uzhxRdfxDRN8vPzWbhw4XnvV1VV8eyzz+J2u0lMTGTFihU4\nnc7W95uamnj00Ue57rrrWL58OQDHjh3j6aefpqWlhbFjx3Lvvfd22We1y5Yt4/Tp0/h8PpYvX85d\nd93Fpk2bWLVqFcFgEIfDwSuvvEJjYyNPPPEEn332GUopHnnkEebPn8+6detYsmQJAP/yL/9CWVkZ\n119/Pbm5ueTn57N69WqSkpI4cuQIW7duveDxACZMmEBBQQGNjY3cddddjB8/no8//pj09HReeOEF\n4uPjycjIoLa2lsrKSnr37h3JYRNCiC5He5vRRQXo99aBpx6GjMBY9ghcPQqlVMRy7KJhbpomzz//\nPE888QROp5OVK1eSk5NDRkZG6zYvv/wyubm5TJ8+nb1797JmzRpWrFjR+v7atWsZNmzYeft97rnn\neOCBBxg8eDC/+93v2LNnD2PHjm3TyZh/fQ598vilb68UWusf3EZlDsD4ux//4Da///3v6dWrF83N\nzcyfP585c+bwD//wD7zxxhv079+f2tpaAP7whz9gs9n44IMPAKirqwNg586d/Ou//isAjz/+OAcP\nHuT9998HYPv27Xz++ed8+OGH9O/f/4LHmzdvHg6H47yajh8/ztNPP83q1at54IEH2LBhA7feeisA\nI0eOZOfOncyfP/+Sx0oIIXoy3dSI/vAddOHfoNED14zFmL8YNWR4pEsDLiHMjxw5Qnp6OmlpaQBM\nnjyZnTt3nhfm5eXl3H333QAMHz6c1atXt7537Ngx6uvrGTNmDEePHgWgtraW5uZmhgwZAkBubi47\nd+5sc5hHygsvvEBBQQEQWgjmz3/+MxMnTmwN3169egGwZcsWnnnmmdbPJScnA6FQT0xM/N79jxkz\npnVfFzre8ePHvxPmmZmZjBgxAoBRo0Zx8uTJ1vecTidnz5694vMVQoieQje40YV/Q3+4HpobYdR1\noRAfODTSpZ3nomFeU1Nz3i1zp9PJ4cOHz9smKyuL0tJS5s2bR2lpKc3NzXg8HqxWKy+99BIrVqzg\n888//8F91tTUXPD4hYWFFBaGGgpWrVpFSkrK
ee+fPXv26++Z3/XTi51O2G3bto2tW7eyYcMGEhIS\nuOWWWxg1ahTHjh37zvfflVJYLJbv/LnFYsEwDAzDaJ3E5qttoqKisFqtra8vdLxAIIDFYkEpRVRU\nFFFRUcTGxrZ+Jjo6mpaWltbXfr+fhISE7/1+fmxs7HfG+ftYLJZL3lZcmIxheMg4tp2M4deCdTU0\n/e0vNBesQ3ubiJ00Hett9xB9kRCP1BiG5atpS5cu5YUXXmDz5s0MGzYMh8OBYRhs3LiRsWPHnhfc\nl2vWrFnMmjWr9fW3l5bz+XxXPIubxWK5aFPZxdTV1WG324mJieHAgQN88sknNDU1sWPHDo4dO9Z6\nm71Xr15MmzaN//7v/+Y3v/lN62eTk5MZOHAgR48eZcCAAcTFxdHQ0NBaVzAYRGvd+vpCxwsGgwQC\nAbTWBINBgsEg8HXDnGmamKbZ+vrIkSPMmzfve8/d5/Nd8hJ+smRi28kYhoeMY9vJGIKurUa/9wZ6\ny3vgD6Cum4oxbzGBfv2pB7jI+ERqCdSLhrnD4aC6urr1dXV19Xdu6TocDh577DEgNL1qSUkJVquV\nQ4cO8cUXX7Bx40a8Xi+BQIC4uDjmzZt30X12FdOnT+fll18mLy+PQYMGMW7cOJxOJ//2b//Gfffd\nh2mapKSk8Ne//pWHH36Yxx9/nJkzZ2IYBo8++ijz5s0jPz+fHTt2MGDAABwOB9dddx0zZ85kxowZ\n5OfnX/R4l8Pv9+NyuRg9uuO/BymEEJ2Vrq5EF7yG3lYIpomaOAN1w22o9H6RLu2SKH2RDrBgMMjD\nDz/Mr3/9axwOBytXruShhx4iMzOzdZuvutgNw+Avf/kLhmFwxx13nLefzZs3c/To0dZu9pUrV3Lv\nvfe2NsDNnTv3koLp9OnT571uamoiISHhkk/4m8JxZR4OZ8+e5eGHH+avf/1rux+roKCAzz//nF/+\n8pffu83ljKn8Jt92MobhIePYdj1xDHXlafSG19AfbQIUasos1NxFqNT0K9pfp70yj4qKYtmyZTz5\n5JOYpsmMGTPIzMxk7dq1DBo0iJycHPbv38+aNWtQSjFs2LDWwP4h9913H8888wwtLS2MGTOmyza/\nhUNaWhp33nknHo8Hm83WrscKBAI88MAD7XoMIYTo7PSZk+j1r6BLt4DFgsq7ATVnEcrRNXsGLnpl\n3tl0xyvzzkauzDuWjGF4yDi2XU8YQ33yeCjEd22HmNhQiM9eiErqFZb9d9orcyGEEKKr08cPY65f\nC5+WQlx86Hn4rJtRNnukSwsLCXMhhBDdlj6yH/OdtbBvNyQkom66EzVzAcr6/XN7dEUS5kIIIboV\nrTUc+Axz/Stw8HOwJaEW3YOafgMq/soey3Z2EuZCCCG6Ba017NsVuhI/egCSHKjFy1G5c1CxcZEu\nr13JqmmdwDeXQL0SgwcPBqCiooIf//jC88jfdtttfPrppwDccccdrfPCCyFEV6e1Ru/5CPPJX2D+\nf/8v1J5D3fkTjN/9F8b1N3f7IAe5Mu8UvrkEalukp6fz3HPPXXS7W2+9lf/5n//h4YcfbtPxhBAi\nkrRpwq7todvp5S5ITUfd/XPUpBkoS3Sky+tQcmUeBsuWLWPu3LnMmDGDP//5zwBs2rSJOXPmMGvW\nLBYvXgxAY2MjjzzyCPn5+cyaNYv169cDsG7dOubMmQOElkD905/+1Lrv3//+9/zHf/wHjY2NLF68\nmDlz5pCfn8977733nTpOnjzJzJkzAWhubuanP/0peXl5LF++HK/X27rd7Nmzeeutt9plLIQQor3p\nYBDzo02Y/7wC8z//DQJ+1LJHMP7vsxjTZve4IIdudmX+3x+f5Xit9+IbfkldwhKoA3rFcV9O2g9u\nE84lUG+66Sb+6Z/+ib//+78H4O233+Z///d/iY2N5fnnn8dms1FTU8ONN97I7Nmzv3ft3Jdeeon4\n+HiKiorY
v38/c+fObX0vOTkZn89HTU1Nl51GVwjR8+iAH71jE7rgNaiqgH5ZqPt/ibp2Esq4sjU6\nuotuFeaREs4lUEeMGMG
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435fae828>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 1.98e-01\n",
" final error(valid) = 2.10e-01\n",
" final acc(train) = 9.41e-01\n",
" final acc(valid) = 9.38e-01\n",
" run time per epoch = 16.21\n"
]
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"num_epochs = 10 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"learning_rate = 0.2 # learning rate for gradient descent\n",
"\n",
"init_scales = [0.1, 0.2, 0.5, 1.]  # scales to try for random parameter initialisation (one run each)\n",
"final_errors_train = []\n",
"final_errors_valid = []\n",
"final_accs_train = []\n",
"final_accs_valid = []\n",
"\n",
"for init_scale in init_scales:\n",
"\n",
" print('-' * 80)\n",
" print('learning_rate={0:.2f} init_scale={1:.2f}'\n",
" .format(learning_rate, init_scale))\n",
" print('-' * 80)\n",
" # Reset random number generator and data provider states on each run\n",
" # to ensure reproducibility of results\n",
" rng.seed(seed)\n",
" train_data.reset()\n",
" valid_data.reset()\n",
"\n",
" # Alter data-provider batch size\n",
" train_data.batch_size = batch_size \n",
" valid_data.batch_size = batch_size\n",
"\n",
" # Create a parameter initialiser which will sample random uniform values\n",
" # from [-init_scale, init_scale]\n",
" param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
" # Create a model with three affine layers\n",
" hidden_dim = 100\n",
" model = MultipleLayerModel([\n",
" AffineLayer(input_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, output_dim, param_init, param_init)\n",
" ])\n",
"\n",
" # Initialise a cross entropy error object\n",
" error = CrossEntropySoftmaxError()\n",
"\n",
" # Use a basic gradient descent learning rule\n",
" learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
" stats, keys, run_time, fig_1, ax_1, fig_2, ax_2 = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)\n",
"\n",
" plt.show()\n",
"\n",
" print(' final error(train) = {0:.2e}'.format(stats[-1, keys['error(train)']]))\n",
" print(' final error(valid) = {0:.2e}'.format(stats[-1, keys['error(valid)']]))\n",
" print(' final acc(train) = {0:.2e}'.format(stats[-1, keys['acc(train)']]))\n",
" print(' final acc(valid) = {0:.2e}'.format(stats[-1, keys['acc(valid)']]))\n",
" print(' run time per epoch = {0:.2f}'.format(run_time * 1. / num_epochs))\n",
" \n",
" final_errors_train.append(stats[-1, keys['error(train)']])\n",
" final_errors_valid.append(stats[-1, keys['error(valid)']])\n",
" final_accs_train.append(stats[-1, keys['acc(train)']])\n",
" final_accs_valid.append(stats[-1, keys['acc(valid)']])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| init_scale | final error(train) | final error(valid) | final acc(train) | final acc(valid) |\n",
"|------------|--------------------|--------------------|------------------|------------------|\n",
"| 0.1 | 1.86e-01 | 1.83e-01 | 0.95 | 0.95 |\n",
"| 0.2 | 1.74e-01 | 1.75e-01 | 0.95 | 0.95 |\n",
"| 0.5 | 1.68e-01 | 1.75e-01 | 0.95 | 0.95 |\n",
"| 1.0 | 1.98e-01 | 2.10e-01 | 0.94 | 0.94 |\n"
]
}
],
"source": [
"# Print a markdown table summarising the final statistics for each init_scale\n",
"print('| init_scale | final error(train) | final error(valid) | final acc(train) | final acc(valid) |')\n",
"print('|------------|--------------------|--------------------|------------------|------------------|')\n",
"for init_scale, err_train, err_valid, acc_train, acc_valid in zip(\n",
"        init_scales, final_errors_train, final_errors_valid,\n",
"        final_accs_train, final_accs_valid):\n",
"    print('| {0:.1f} | {1:.2e} | {2:.2e} | {3:.2f} | {4:.2f} |'\n",
"          .format(init_scale, err_train, err_valid, acc_train, acc_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Models with four affine layers"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.10\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAecAAAENCAYAAAA10q2wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl8VNXh///XnSX7QmYmCQTCFiAoooKxIlIViBvgUqto\n14+V9mPFpahFRf340VpaVFQ+WlrQItqvXaI/KyrWpeBWwQUEFFwCkUWWQMgChOwz9/7+mMmQkEAC\nmclMkvfz8ZjH3HvnzL1nDkPec+8991zDsiwLERERiRq2SFdAREREmlM4i4iIRBmFs4iISJRROIuI\niEQZhbOIiEiUUTiLiIhEGYWziIhIlFE4i4iIRBmFs4iISJRROIuIiEQZRyQ3vmvXrkhuvlvxeDyU\nlpZGuhrdjto19NSm4aF2Db1Qt2lWVla7y2rPWUREJMoonEVERKKMwllERCTKRPScs4iIdA7Lsqit\nrcU0TQzDiHR1uoQ9e/ZQV1d3TO+xLAubzUZcXFyH2lnhLCLSA9TW1uJ0OnE49Ge/vRwOB3a7/Zjf\n5/V6qa2tJT4+/ri3rcPaIiI9gGmaCuZO4nA4ME2zQ+tQOIuI9AA6lN25Otre3SKcrW3fYL7wNJbX\nG+mqiIiIdFj3COfdO7DeWgK7t0e6KiIiEkFPPfUUL7zwAgAFBQXs3r37mNfxl7/8JbiOI/nqq6+Y\nMWPGcdWxPbpFOBsDhgBgbS2KcE1ERKQjfD7fUeePxOv14vV6KSgo4Hvf+x4AL7zwAnv27GnXdpr6\n6U9/ypVXXnnU7Z1wwgkUFxezc+fOdtXvWHWLcCajD8TFw7ZvIl0TERE5ihdffJHJkydz3nnncfvt\nt+Pz+Rg6dCj3338/+fn5fPrpp5xxxhnMnj2bCy64gKVLl7JhwwamTJlCfn4+06ZNY9++fQBcccUV\n3HvvvVx00UX8+c9/ZsWKFZx00kk4HA6WLl3KZ599xo033sh5551HTU1Ni/X+9a9/ZdKkSeTn5/OL\nX/yCmpoaAB555BEWLFgQ3Mbs2bOZPHky48aN4+OPPw5+lvPOO4+XX345LO3ULbruGTYb9M/B2qY9\nZxGRtpj/eApr+5aQrtPIHoTt6l8ctcymTZt45ZVXWLJkCU6nk1mzZvHPf/6T6upqRo0axf/+7/8G\ny6alpfHmm28CkJ+fzwMPPMCZZ57Jww8/zKOPPspvfvMbABoaGnj99dcBmDt3LieffDIAU6ZM4Zln\nnuF//ud/OOWUU1pdb3l5OT/60Y8AePDBB/n73//Otdde26LeXq+X1157jeXLl/Poo49SUFAAwCmn\nnMIf/vAHpk+fflxtdjTdIpwBjIFDsN5+DcvrxdDlAiIiUeeDDz5g/fr1TJo0CfBfe+3xeLDb7Uye\nPLlZ2UsuuQSAAwcOsH//fs4880wArrzySq677roW5QBKSkoYOnToUevQtHxhYSEPPfQQBw4coKqq\ninPOOafV9zTW9+STT2bHjh3B5W63+4iHzTuq+6RY/xzwNkDxdsgeFOnaiIhErbb2cMPFsiyuvPJK\nZs2a1Wz5ggULWgz2kZCQ0K51Ni0XFxdHbW1tu8vfcsstLFq0iBEjRlBQUMCHH37Y6ntiYmIAsNvt\neJtcFVRXV0dcXFy76nmsusc5Z5p0CtOhbRGRqDRu3DiWLl0avA1jRUVFsz3R1qSkpJCamho81/vi\niy8yZsyYVssOGTKErVu3BucTExM5ePDgEdd98OBBMjMzaWho4KWXXjrGTwObN28mNzf3mN/XHt1n\nzzmjD8QnwLYiGHdepGsjIiKHGTZsGLfffjs/+MEPsCwLh8PB7Nmz23zfvHnzuPPOO6mtraV///48\n+uijrZabMGECN998c3B+6tSp3HnnncTFxfHKK6+0KD9z5kymTJmC2+1m1KhRRw3y1qxcuZKJEyce\n03vay7AsywrLmtth165dIV2fb+7dUF+H/a65
IV1vV6AbrYeH2jX01Kbh0Va7VldXt/tQcVc2bdo0\n7r77bgYPHtzhdTkcjmaHsZuqq6vj+9//PkuWLGl1WNTW2jsrK6vd2+42h7UBjAE5sH2LRgoTEemh\nZs2aRUlJSdi3s3PnTu66666wjVfefQ5rAwwY4u8Ututb6N/xX00iItK1DBkyhCFDhoR9O4MHDw7J\n3vmRdLM9Z3UKExGRrq9bhPOaXQe569/bqEvL8HcK+1YjhYmISNfVLcLZa1p8UVLD5n31/pHCNMa2\niIh0Yd0inHM98QB8vbfGf2h7x1Z1ChMRkS6rXeG8bt06fvWrX3HTTTexZMmSFq+/++67TJs2jZkz\nZzJz5kyWL18e8ooeTWqcgz7JTr4urYEBOYc6hYmISI/S9JaRx2rGjBksXboUgF//+tcUFha2KFNQ\nUMDdd98NwOLFi/nHP/5x/JU9ijZ7a5umyaJFi7jnnntwu93MmjWLvLw8+vXr16zc2LFjmTZtWlgq\n2R65nnjWFVfBuBzA3ynMUI9tEZEuxefzNRvK8/D5I2m8HrmgoIA33nijw/WYO3fuUa9zBrj66qu5\n9NJLufrqqzu8vcO1uedcVFRE7969yczMxOFwMHbsWFatWhXyinTUcE88+2p9lCR4ID7RP1KYiIhE\nlc66ZWRRUVGzm2ls3749OJrXY489xqRJk5gwYQK33347rY3FdcUVV7Bu3TrAH/jjxo1j8uTJrF69\nOlgmPj6e7Oxs1q5dG/J2anPPuby8HLfbHZx3u91s2rSpRbmPP/6Yr776ij59+vBf//VfeDyeFmWW\nLVvGsmXLAJgzZ06rZY7XGVYcC1btYVedk6ycXKyd23CHcP3RzuFwhLQ9xU/tGnpq0/Boq1337NkT\nHDDjyU+K2VxeE9LtD3bF89/f6XPUMhs3buTVV19l6dKlOJ1O7rjjDl5++WWqq6vJy8vjgQceAMAw\nDNxud/AU6bnnnsvvfvc7xo4dy4MPPsi8efP47W9/i2EY+Hw+/v3vfwPw0EMPceqpp+JwOBg+fDgN\nDQ3s3LmTAQMGsHTpUi699FIcDgc///nPmTlzJgA33HADb7/9NhdccAE2mw273Y7D4cAwDADKysp4\n5JFHeOutt0hJSeHyyy9n5MiRwbY89dRTWb16NaeffnqzzxobG9uh73lIBiE57bTTOOuss3A6nfz7\n3/9m/vz5ze7L2Sg/P5/8/PzgfCiH8Eu1LOIcBqu3lHBq3wFYy19l7+5iDIczZNuIZhoSMTzUrqGn\nNg2Pttq1rq4ueHjYNM1W9xY7wjTNox4CBnjvvff4/PPPOf/88wH/LSNdLhd2u50LL7ww+H7Lspgy\nZQperzd4y8jvfOc7eL1evv/973Pdddfh9XqblQPYvXs3OTk5wfkpU6bw0ksvceONN7JkyRL+9Kc/\n4fV6ef/99/nTn/5ETU0N+/btY+jQoUycOBHTNPH5fMF1A6xatYoxY8bQq1cvAC6++GI2b94c3IbL\n5aKoqKjFZ6+rq2vx73Esw3e2Gc4ul4uysrLgfFlZGS6Xq1mZ5OTk4PTEiRN57rnn2l2BULHbDIa6\n4/m6tDYwUpg3MFJYTqfXRUQkmv08LzMi2+3sW0ZecsklXHfddVx00UUYhsHgwYOpra3lrrvu4l//\n+hd9+/blkUceoa6u7rg/U7huG9nmOeecnByKi4spKSnB6/WycuVK8vLympWpqKgITq9evbpFZ7HO\nkuuJZ2tFLfX9/B3BrG0ajEREJFp09i0jBw4ciN1uZ968eVxyySUAwSB2uVxUVVXx2muvHXX7o0aN\n4qOPPqK8vJyGhoZgb+5GmzdvZvjw4Uddx/Foc8/Zbrdz7bXXMnv2bEzTZPz48WRnZ1NQUEBOTg55\neXm8/vrrrF69GrvdTlJSEtOnTw95RdtjuCcenwVFtl6cEJ8IW4vgu+dHpC4iItJcZ98yEvx7zw88\n8AAfffQR
AKmpqfzwhz9k4sSJpKenc8oppxx125mZmdx2221ccsklpKamMmLEiGavr1q1iltvvbXN\nz3CsutUtIw/UevnJi0X
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc437ffeeb8>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xt8k+X9+P/XnaTnQtsk0FJaKBTqypmuKKDWlnboQJSh\nouJwDJxOHDpRN/GnfjYdh0346HQo6ABPY8Ovx49zuq0om6MTqBwUQaAcW1ooaVp6TJrkvn9/pA0N\nFNpC0jTt+/l49NH7cCV53xeh71zXfeW6FE3TNIQQQgjR5ekCHYAQQggh2keSthBCCBEkJGkLIYQQ\nQUKSthBCCBEkJGkLIYQQQUKSthBCCBEkJGkLIYQQQUKSthBCCBEkJGkLIYQQQUKSthBCCBEkDIEO\noDWlpaWBDqFbMZvNWCyWQIfRrUid+ofUq+9JnfqHr+s1MTGxXeWkpS2EEEIECUnaQgghRJCQpC2E\nEEIEiS55T/tsmqZhs9lQVRVFUQIdTtA5efIkdrsdcNelTqcjPDxc6lIIIYJMUCRtm81GSEgIBkNQ\nhNvlGAwG9Hq9Z9/pdGKz2YiIiAhgVEIIITqqzSz44osvsn37dmJiYlixYsU55zVNY926dezYsYOw\nsDDmz5/P4MGDAdi0aRPvvvsuADNmzCA7O/uiglRVVRK2DxkMBk/LWwghRPBo8552dnY2jz322HnP\n79ixgxMnTvD8889z991388c//hGA2tpa3n77bZYsWcKSJUt4++23qa2tvaggpRvX96ROhRAi+LTZ\nfB02bBjl5eXnPV9YWEhWVhaKopCWlkZdXR2VlZV88803jBo1iujoaABGjRrFzp07ueqqq3wXvRBC\ndCJNdYGr6UfTWvyoZ36rLY9xgXOtPF690Lmm32htPo/W4nkaoqNRa2qar6DFxbS8MK+d1st4HdfO\nPKY5puYizfta0wGtxTZnnT+7bGvxnC+Gc8qdda4j19au6/c+4Jx8I0REt/6afnTJfc5WqxWz2ezZ\nN5lMWK1WrFYrJpPJc9xoNGK1Wlt9jvz8fPLz8wFYtmyZ1/OBeyBVd+ge1zSNm266iddeew1VVXn3\n3Xf58Y9/3OHnmTVrFi+99BIxMTHnLfOrX/2K3Nxcrr76aoBz6i8sLOycehbtZzAYpP58QHO5cJWX\n4So5grP0GPUOB+EuJ7hU0FS0lslKU0FV0dSzktjZx5v20TQ0TfUkWc3lBKcTrTnpupq3m347naC6\n0JzOc8+5nGcSdZCpDnQA3U1TL6U2OhNz5pWd/vJdIhPm5eWRl5fn2T97lhm73e41kCpY5efnk56e\nTkREBMXFxaxbt47Zs2efU87pdF7wQ8rrr7/uKXc+c+bM4ZFHHmHChAkYDIZzytrtdpkl6RLILFMd\nozkcUF4KZcVopcVwogStrBhOHAeno/UH6XSg6Nx/JHXKmW1F13ROaTp3nuOK7szj9AbQ68/86PRg\nCIWwM8cUnb6pXFN5XdPxVo6h1595Da/XannsrP2zrkFR8D6H0uKa8X6e5uuilbo4+3nOev24OCOV\nVVVn6tXrzliLnfPdMlPOV0Zpenjza7U8fta55rKex7f2WOXc528t5nPiPE+5811bO8q05/ZhSIBm\nRLvkpG00Gr0Cr6iowGg0YjQa2bNnj+e41Wpl2LBhl/pyATV37lxKS0ux2+3MmzePH/7wh3z22Wcs\nW7YMl8uF0Wjkrbfeoq6ujscff5yvvvoKRVF48MEHmTp1Ku+99x533HEHAEuWLOHo0aN873vfIysr\ni9zcXJ555hliYmIoKiriP//5T6uvB3DFFVfw8ccfU1dXxw9/+EMuv/xyCgsLSUhIYO3atURERJCU\nlERlZSXl5eXtfjMIcak0Wz2UHW9KyMVoZSVQWgyWE+7WL7j/OJrjISEJZdhY6JeE0i8ZEpIwJyVj\nsVplzIUPGcxmlLDIQIchfOSSk3ZmZiaffPIJ
V155JQcOHCAyMpK4uDjGjBnDn//8Z8/gs127djFr\n1qxLDlj9yytoxYcv+XlaUpIHobvtJ22WW7FiBXFxcTQ0NDB16lSuvfZaHnnkEd59910GDBhAZWUl\nAM899xy9evVi48aNAFQ1fcrdtm0bv/3tbwF47LHH2LdvH//85z8BKCgo4Ouvv+bTTz9lwIABrb7e\nlClTMBqNXjEdPnyYlStX8swzz3DPPffwt7/9jZtuugmAkSNHsm3bNm688UYf1JIQbprqgrpaKGtq\nLZc1JecTxWBt0fLQG6BvP0hOQbn8aneS7pcMCf1RQsNafW5Fr5eELcQFtJm0n3vuOfbs2UNNTQ0/\n/elPmTlzpqerdfLkyYwdO5bt27dz//33Exoayvz58wGIjo7mpptuYtGiRQDcfPPNnkFpwWrt2rV8\n/PHHgHtRkzfffJPx48d7kmxcXBwAn3/+OS+++KLncbGxsYA7eV+oDsaMGeN5rtZe7/Dhw+ck7eTk\nZEaMGAG4B/sVFxd7zplMJk6ePHnR1yuCm6ZpUFsN1aeh0QaNdrDb0Ox2r33O2vc+31Sm5b6j0fuF\nwsLdCTltBPRLdifmfklgTkDpBmNRhOhK2vwf9fOf//yC5xVF4a677mr13KRJk5g0adLFRXYe7WkR\n+0NBQQGff/45H374IREREdx8880MHz6cgwcPtvs5DAYDqqqi07X+TbvIyDNdWK29XmvfrQ4LO9Ni\n0ev12Gw2z77dbic8PLzd8YngpNXVQnkp2slS933jk83bZdBQ174nCQ1zJ9/QMO/t3rHuVnHzflgY\nhIZDZBRKQn/oNwDiTCjneU8LIXxLPga3U01NDTExMURERFBUVMT27dux2+188cUXHDt2zNM9HhcX\nR1ZWFq+++ipPPfUU4G5hx8bGMnjwYI4ePcqgQYOIioq64PfWW3u9jjp06BDXX3/9RV+z6Do0Wz2U\nl7mTcVNy1srL3Nu1LcYHKwoY+0B8f5TxadA3EWKMKGEtk3K4O/k274eEStIVIkhI0m6n7Oxs3njj\nDa655hpSU1PJyMjAZDLxu9/9jrvuugtVVTGbzfzlL3/hgQce4LHHHmPSpEnodDoWLlzIlClTyM3N\n5b///S+DBg3CaDQybtw4Jk2aRE5ODrm5uW2+Xkc4HA6OHDnC6NGjfVkNwo+0Rru7dVxeinay+fdx\n97HTld6F48zQtx9KxgTom4gSnwjxie4u6ZCQwFyAEMLvFE3rel88LC0t9dqvr6/36joOVidPnuSB\nBx7gL3/5i99f6+OPP+brr7/mF7/4Ratf+eoudRool/qVL62uBo4dQjt2yP27+JD760+aeqZQ71iv\nhKz0TYT4ftAn0d1y7obkq3S+J3XqH76u1077ypdov/j4eGbNmkVNTQ29evXy62s5nU7uuecev76G\naJumaVBZAcXuBK0dOwTFh6CixSyDRjMkD0bJvNI9kKtvorsVHSEfqoQQ3iRpd7IbbrihU15n2rRp\nnfI64gxNVd33nYsPnWlFFx+CmtPuAoribjEPvgyyv48yYDAkp6L06h3YwIUQQUOStujWNLvNPYL6\nfDNXtZxNq8WMUm19V1hzONwtZ68EfRjsDe4CegP0H4AyahwMGOxO0EmDUMJlOVQhxMWTpC26Ha2h\nHm3XFrTCzfDNjvNPk9mWltNgnjV1ZLnD7p6rGtyjsJNSUCZOOpOgEwegGGRAmBDCtyRpi25Bq69D\n27UV7cvN8M12d0KNNaFccx0kJreyKtKZRSW8V1Zqx3FNI6J3DA3mBHeC7tvPPW+1EEL4mSRtEbS0\n+lq0nU2Jes8Od6KOM6NkT0H57pUw+DK/ff+4l9mMXUbkCiE6mcyo0Ik0TeOWW26hxrO2bccMHToU\ngBMnTvCTn7Q+M9zNN9/Mrl27ALj11ls98553F1pdLWrBRlzPP4W68E60dc9ByWGU7KnoHv0dumV/\nRHfrXShD
0mXCECFEtyMt7U60ceNGhg0bdslf90pISOCVV15ps1zz2t0PPfTQJb1eoGl1tWg7v3Df\no967y722sbEPyqSp7hb
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435cf5080>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 2.03e-03\n",
" final error(valid) = 1.35e-01\n",
" final acc(train) = 1.00e+00\n",
" final acc(valid) = 9.73e-01\n",
" run time per epoch = 19.51\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.20\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl8VNX9//HXubNkm2wzk4UsGAhBZFdGBUQFE60iWqqC\n669fC7UW1H7BBUGtdcNiRbQVLYiI9qtVrFZUXBsRt6gNSxBFMJFFlphAJmQly+Te3x8ThgQCSWAm\nk+XzfDzymJk759459zjynnvuPecqwzAMhBBCCNHpacGugBBCCCHaRkJbCCGE6CIktIUQQoguQkJb\nCCGE6CIktIUQQoguQkJbCCGE6CIktIUQQoguQkJbCCGE6CIktIUQQoguQkJbCCGE6CLMwa5AS/bs\n2RPsKnQrTqeTffv2Bbsa3Yq0aWBIu/qftGlg+Ltdk5KS2lSuTaGdl5fHsmXL0HWdzMxMJk6c2Oz9\nDz/8kA8++ABN0wgNDeXGG28kJSUFgDfeeINVq1ahaRq/+c1vGD58eDt3RQghhBDQhtDWdZ2lS5dy\nzz334HA4mDNnDi6XyxfKAGPGjOGCCy4AYM2aNbzwwgvcfffd7Nq1i5ycHBYsWEBpaSkPPvggf/3r\nX9E06ZUXQggh2qvV9CwoKCAxMZGEhATMZjOjR48mNze3WZnw8HDf85qaGpRSAOTm5jJ69GgsFgvx\n8fEkJiZSUFDg510QQggheoZWj7TdbjcOh8P32uFwkJ+ff0S5999/n3feeQePx8O9997rWzcjI8NX\nxm6343a7j1g3Ozub7OxsAObNm4fT6Wz/noijMpvN0qZ+Jm0aGNKu/tdamxqGgdvtxuPxdGCtur7i\n4mKO587WZrMZu93uO7ht9/rHtVYLLrzwQi688EI+//xzXn/9dW6++eY2r5uVlUVWVpbvtVw04V9y\nIYr/SZsGhrSr/7XWpgcOHMBisWA2d8rrkjsts9l8XD906uvr2bVrF2FhYc2Wt/VCtFa7x+12OyUl\nJb7XJSUl2O32o5Zv2n1++Lput/uY6wohhOhYuq5LYHcgs9mMruvHvX6roZ2enk5hYSHFxcV4PB5y\ncnJwuVzNyhQWFvqer1u3jl69egHgcrnIycmhvr6e4uJiCgsL6dev33FXVgghhH8dbzetOH4n0uat\n/rwymUxMmTKFuXPnous648aNIzU1leXLl5Oeno7L5eL9999n48aNmEwmbDYbN910EwCpqamMGjWK\nW2+9FU3TmDp1aodeOW7UHMBYtRLVbyCq/6AO+1whhBAiEJRxPGfSA8xfk6sYHg/6jGtQo85Du/b3\nftlmVyTnCf1P2jQwpF39r7U2ra6ubjYCqLtasmQJMTExTJo0ieXLl3PuueeSmJjYrm384x//ICws\njEmTJh31nPb333/P4sWLeeKJJ466nZba3K+Tq3RVymyGjEEYm78JdlWEEEKcgIaGBkwm01FfH83B\nYF2+fDnvv/8+AP/6178YMGBAi6F9rO3++te/bvXzTjnlFAoLC9m9ezfJycmtlm+vbj/LiRowFH7e\nhbG/pPXCQgghguL111/n4osv5vzzz2fWrFk0NDSQkZHB/fffT1ZWFmvXruXMM89k7ty5/OIXv2Dl\nypV8++23TJgwgaysLKZOncr+/fsBuOKKK7j33nu56KKLePbZZ/niiy8YPHgwZrOZlStXsmHDBm6+\n+WbOP/98Dhw4cMR2X3rpJcaPH09WVhY33HADBw4cAOCxxx5j0aJFAPzqV79i7ty5XHzxxYwZM4av\nv/7aty/nn38+b775ZkDaqVsfaYM3tA3A2LwRNXJssKsjhBCdlv7KEoyd2/y6TZXaB+2qG45ZJj8/\nn7feeosVK1ZgsViYM2cO//73v6murubUU0/lT3/6k69sbGwsH3zwAeAdLvzggw8yatQoHn30URYs\nWMADDzwAeIdWvffeewDMnz+foUOHAjBhwgSe
f/55/vjHPzJs2LAWt+t2u7n22msBeOSRR3j55ZeZ\nMmXKEfX2eDy88847fPTRRyxYsIDly5cDMGzYMBYuXMj06dOPq82OpduHNqlpEG6Dzd+AhLYQQnQ6\nn3/+ORs3bmT8+PGAd2ZNp9OJyWTi4osvblb20ksvBaC8vJyysjJGjRoFwKRJk7jxxhuPKAfeiVCa\nTvTVkqblt2zZwl/+8hfKy8upqqri3HPPbXGdg/UdOnQou3bt8i13OBwUFRW1ut/Ho9uHttJMMGCI\nnNcWQohWtHZEHCiGYTBp0iTmzJnTbPmiRYuOOL/c1ovmmpYLDQ2lpqamzeVnzpzJ0qVLGTRoEMuX\nL+fLL79scR2r1Qp4R1k1vSittraW0NDQNtWzvbr9OW1oPK9dUoyx9+dgV0UIIcRhxowZw8qVK31X\nuZeWljY7cm1JVFQU0dHRvnPJr7/+OiNHjmyxbL9+/di+fbvvdUREBJWVlUfddmVlJQkJCdTX1/PG\nG2+0c29g69atnHzyye1ery26/ZE2ND2v/Q0qrn2X+AshhAis/v37M2vWLK6++moMw8BsNjN37txW\n13viiSeYPXs2NTU19O7dmwULFrRY7rzzzuMPf/iD7/XkyZOZPXs2oaGhvPXWW0eUv+OOO5gwYQIO\nh4NTTz31mAHfkpycHDIzM9u1Tlt163HaBxmGgX7H9aiTh6DdcLtft90VyNhX/5M2DQxpV/+Tcdpe\nU6dO5e6776Zv375+2d7RxmnX1tZy+eWXs2LFiqNOD3si47R7Rve4UqiTh2Js/ua47soihBCia5sz\nZw7FxcUB/5zdu3dz1113BWw+924d2h7dYN2eSnaV18KAIVC+Hwp3BrtaQgghOli/fv2Oes7bn/r2\n7cvo0aMDtv1uHdp1DTpzP9lFdkGZ92I0kKvIhRBCdFndOrTDLSYGxoezZk+l9wI0R7yEthBCiC6r\nW4c2gCvJxs6yOooq67xH21u+xdAbgl0tIYQQot26f2gn2wBYs7sKBgyF6krYuT24lRJCCCGOQ7cP\n7eQoK70iLazdUynntYUQoodasmQJ//rXv45r3RkzZrBy5UoAbr/9dn744Ycjyixfvpy7774bgGXL\nlvHKK68cf2WPoduHNni7yDcWVVNni4FeqRLaQgjRxTQ0NBzz9dF4PB48Hg/Lly/nV7/61QnXY/78\n+fTv3/+YZa666iqee+65E/6slvSI0B6RbKOuweCbn6tRA4ZA/ncYLQyKF0IIERwddWvOgoKCZjch\n2blzp2/2sscff5zx48dz3nnnMWvWrBbn9bjiiivYsGED4D26HjNmDBdffDFr1qzxlQkLCyM1NZX1\n69f7vZ16xDSmg+PDCDUr1uypxDVgKMbH78L2H6DfwGBXTQghOo1n1xSxrfTYN9Zorz6xofzWlXDM\nMh15a85+/fpRV1fHTz/9RO/evXnrrbe45JJLALj++uuZOXMmALfccgv/+c9/uOCCC1qsc1FREfPn\nz+f9998nMjKSSZMmMXjwYN/7Q4cO5euvv+bUU089nmY7qh5xpG0xaQxLjGDN7kqMjEGglHSRCyFE\nJ9H01pznn38+n3/+OT/99FO7b8158OYhTcuB99acDofD9/qSSy7xzTn+1ltv+crm5OQwYcIEMjMz\nycnJafHc9UFr165l1KhROBwOrFZrs88D7/Sxgbg9Z4840gbvVeRf76pkZ0MoKal9MDZvhAlXBbta\nQgjRabR2RBwoHX1rzksvvZQbb7yRiy66CKUUffv2paamhrvuuot3332X5ORkHnvsMWpra497nwJ1\ne84ecaQNMCIpAoA1uxuvIv/xe4y64/8PIoQQwj86+tacaWlpmEwmnnjiCd8R8sGAttvtVFVV8c47\n7xzz80eMGMFXX32F2+2mvr7ed3X5QVu3bmXAgAHH3Mbx6DFH2o5wC31iQ1izu5LLBgzF+HAF/LgZ\nThkW7KoJ
IUSP1tG35gTv0faDDz7IV199BUB0dDTXXHMNmZmZxMXFMWzYsbMhISGB2267jUsvvZTo\n6GgGDRrU7P3c3FxuvfX
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc42c324278>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl4VOX9///nmZnsgSSTQELYVw0giwZBVGSrVtTWD1pa\nFxTB1opFW5dW/Vnby1aLVX/aBcVawKWl1Y9V+7EubUHRVhTCqhBAQLYQICSTbbLPnPv7x4SBQCAB\nZphM8npcV66ZOedk5j03Q15z3+ec+1jGGIOIiIi0eY5IFyAiIiKto9AWERGJEgptERGRKKHQFhER\niRIKbRERkSih0BYREYkSCm0REZEoodAWERGJEgptERGRKKHQFhERiRKuSBfQnMLCwkiX0K5kZGRQ\nXFwc6TLaFbVpeKhdQ09tGh6hbtfs7OxWbaeetoiISJRQaIuIiEQJhbaIiEiUaJP7tI9mjKG2thbb\ntrEsK9LlRJ0DBw5QV1cHBNrS4XAQHx+vthQRiTJREdq1tbXExMTgckVFuW2Oy+XC6XQGH/t8Pmpr\na0lISIhgVSIicrJaTMFnn32WNWvWkJKSwlNPPXXMemMMixYtYu3atcTFxTF79mz69esHwLJly3jj\njTcAmDp1KuPHjz+lIm3bVmCHkMvlCva8RUQkerS4T3v8+PE8+OCDx12/du1a9u/fz29/+1u+973v\n8cc//hEAr9fL66+/zmOPPcZjjz3G66+/jtfrPaUiNYwbempTEZHo02L3dfDgwRQVFR13/apVqxg3\nbhyWZTFo0CCqqqooLS1l48aNDBs2jOTkZACGDRvGunXruOiii0JXvYhIFDHGgG0f/jF+sA2YQ4/t\nwGNM4/JD9+0j7puj7jeuC/7YYAgury9OwZSWHrP8eNs3t12wbmhao+FwfYE3eHi5OaJOc8S2R752\nk+cwh1rpiOc64rbJfdPkJvhcRy7jiN8LA9+l34SE5LC+RnNOe8zZ4/GQkZERfJyeno7H48Hj8ZCe\nnh5c7na78Xg8zT7HkiVLWLJkCQBz585t8nwQOJCqPQyPG2O45ppreOmll7BtmzfeeINbbrnlpJ/n\n+uuv57nnniMlJeW42/z85z9n0qRJXHzxxQDHtF9cXNwx7Syt53K51H4hZIzBLjmI2VdAan0dxu8D\nvx/j94PfF7zF58fYTW+PXG/8fvD5wPZjfI3PYfvh0Prg/cbnto+4f+S2dtNtDq0/5n4wgP1g20c8\nthsf+48IY7tpAJ0hpWf8FduQMI4omuG5ZOReGLbnP542kYSTJ09m8uTJwcdHzzJTV1fX5ECqaLVk\nyRJycnJISEhgz549LFq0iOnTpx+znc/nO+GXlJdffjm43fHMmDGD++67jwsuuACXy3XMtnV1dZol\n6TRolqlTZyrKYO8uTOHuwO3eXVC4G2prwvOCTmfgx9H4c+i+03HUMscR6xofH7rvig0usxyH1h35\n4wwEhMMB1uHllnXUdpYDHFbj7VHPc+j3sQ5vA4fXNbPcsqzAOqtx+TGPoXNKChWV3mOWN7u9o/F1\nLOuI+8dbbh2uDyvwnIeWB5//iGUcZ/mhZcHn4IiwtY64sZquC+bx4eVncrdfTIRmRDvt0Ha73U0K\nLykpwe1243a7yc/PDy73eDwMHjz4dF8uombOnElhYSF1dXXMmjWLG2+8kQ8//JC5c+fi9/txu928\n9tprVFVV8dBDD/H5559jWRY/+tGPuOKKK3jzzTe54YYbAHjsscfYtWsXX/va1xg3bhyTJk3iiSee\nICUlhW3btvHf//632dcDGD16NO+99x5VVVXceOONnH/++axatYqsrCwWLlxIQkICPXr0oLS0lKKi\nolZ/GERCydRUQ+FuzN6dsHf34XCuLD+8UXIn6N4H64KJkN2TzpndqKyuPiI8XY1h6moavk5XIHSd\nrsPB2sx6yxH9X/ZPV1xGBpa+YLYbpx3aubm5
vP/++1x44YVs3bqVxMRE0tLSGDFiBH/5y1+CB5+t\nX7+e66+//rQLtv/6AmbPjtN+niNZPfvi+M53W9zuqaeeIi0tjZqaGq644gouu+wy7rvvPt544w16\n9epFaWlgIOqZZ56hU6dOLF26FICysjIA8vLyePzxxwF48MEH2bJlC//+978BWL58OV988QUffPAB\nvXr1avb1pkyZgtvtblLTjh07mDdvHk888QS33XYb7777Ltdccw0A55xzDnl5eXzzm98MQSuJNM80\n1MO+PZi9u5v0oPEcPLxRXDxk98Iafj5074XVvQ907wWdUpv0juIzMvAqYESOq8XQfuaZZ8jPz6ey\nspLvf//7TJs2LTjUeumllzJy5EjWrFnDnXfeSWxsLLNnzwYgOTmZa665hgceeACAa6+9NnhQWrRa\nuHAh7733HhC4qMmf/vQnxowZEwzZtLQ0AP7zn//w7LPPBn8vNTUVCIT3idpgxIgRwedq7vV27Nhx\nTGj37NmToUOHAoGD/fbs2RNcl56ezoEDB075/UrHZvx+qKmCai9UBW5NtReqvFBeGgjnwl1wYF9g\nny2AywVZPbAGDIYevbGyewfC2d0Fy6EJGEVOV4uh/cMf/vCE6y3L4tZbb2123cSJE5k4ceKpVXYc\nrekRh8Py5cv5z3/+w9tvv01CQgLXXnstQ4YMYfv27a1+DpfLhW3bOI7zxysxMfGEr9fcudVxcXHB\n+06nk9ra2uDjuro64uPjW12ftD/G9gdCtsobCN9qL6bKC9VVwcdUNYZxdVWT7U64j9lyQNdugV5z\n7sVY3XtB997QpRtWOzhoVKSt0v+uVqqsrCQlJYWEhAS2bdvGmjVrqKur47PPPmP37t3B4fG0tDTG\njRvHiy++yCOPPAIEetipqan069ePXbt20bdvX5KSkk543npzr3eyvvrqK6688spTfs8SPYwxUO6B\ngsYDuw4d4LVvDzTUH/8XY2MhMfnwT3oXrJ59Dj9OSoaEJKykI7ZJSoKkzlgxMWfs/YlIgEK7lcaP\nH88rr7zCJZdcQv/+/Tn33HNJT0/n17/+Nbfeeiu2bZORkcFf//pX7rrrLh588EEmTpyIw+Hg7rvv\nZsqUKUyaNIlPP/2Uvn374na7GTVqFBMnTmTChAlMmjSpxdc7GQ0NDezcuZPhw4eHshmkDTDVVVC4\nq3Ef8s7GkN4NVZWHN0pxQ/feWOMvh/RMSErCOjKIG+8reEWii2VMBE4cbEFhYWGTx9XV1U2GjqPV\ngQMHuOuuu/jrX/8a9td67733+OKLL/jxj3/c7Clf7aVNI+VMnPJlGhrgQAGm4Iie89EHeMUnBMK5\ne++mt8mdw1pbuOhUutBTm4ZHqNv1jJ3yJa2XmZnJ9ddfT2VlJZ06dQrra/l8Pm677bawvoaEhrFt\nKClq7DU3HoFdsBOKCgOTeEDgFKas7oEDvA4dfd2jd+AAL01JK9JhKLTPsG984xtn5HWuuuqqM/I6\ncnKCE4scud+5cDfUHT6AkIzMQG955JjG3nMfyMzWAV4iotAWCQdTW9M4schRQ9tNJhbpHAjli74W\nuM3uFehFx2u3hYg0T6EtchqMzwcHCjGFuxqP3N4Z6Dkf3H94o9i4wMQiw0YdPne5R2+szmkRq1tE\nopNCW+QkmIoyzPqVlO/8Ev9XX8L+gsAFKiAw3WZmd6zeA2DspOBBYWRkamIREQkJhbZIC8zB/Zi1\nn2LWroDtmwKXO3RnBObMHnre4QPDsnroFCoRCSt9/T+DjDF861vforKysuWNmzFw4EAA9u/fz3e/\n2/zMcNdeey3r168H4Nvf/nZw3nNpPWMMZvd27L//Gf/P52A/+D3M/y6C2hqsK7+N4+HfkPHHv+O8\n82Ec19yMY8wErJ59FdgiEnbqaZ9BS5cuZfDgwad9uldWVhYvvPBCi9sdunb3Pffcc1qv1xEYvx+2\nbsSsW4FZ
+1ngXGjLAQNzsKbNwhoxGqtLVnB7nWYlIpGg0D4Job40Z3Z2NjNmzAACV/RKSkpi+vTp\n3HLLLZSXl+Pz+fjxj3/
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc43596be80>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 1.99e-03\n",
" final error(valid) = 1.17e-01\n",
" final acc(train) = 1.00e+00\n",
" final acc(valid) = 9.75e-01\n",
" run time per epoch = 20.58\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.50\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl4VNX9x/H3uTOTyb7MDEkIJCxhkx0JW0QRiVo2pSpq\n9Vdrpa1Kq4J1AWut1WJRAVcs1ALaaltQK1Vci4oiEQ1gUAQUZIdAyALZl5l7fn9MGBIJJMBMJsv3\n9Tx5Zrv3zncOQz45dzlHaa01QgghhGj2jGAXIIQQQojGkdAWQgghWggJbSGEEKKFkNAWQgghWggJ\nbSGEEKKFkNAWQgghWggJbSGEEKKFkNAWQgghWggJbSGEEKKFkNAWQgghWghrsAuoz4EDB4JdQqvi\ncrnIy8sLdhmtirRpYEi7+p+0aWD4u12TkpIatZz0tIUQQogWQkJbCCGEaCEktIUQQogWolke0xZC\nCNE0tNZUVFRgmiZKqWCX02IcOnSIysrK01pHa41hGISGhp5xW0toCyFEG1ZRUYHNZsNqlTg4HVar\nFYvFctrrud1uKioqCAsLO6P3ld3jQgjRhpmmKYHdhKxWK6ZpnvH6EtpCCNGGyS7xpnc2bd6qQ1tX\nlGG+8yp6+5ZglyKEEEKctVYd2lhs6HdeRa9+P9iVCCGECKLnn3+eV155BYClS5dy8ODB097G3//+\nd982TmbLli1MmzbtjGpsjFYd2spmQw0chv5yLbq6OtjlCCGEOEMej+eUj0/G7XbjdrtZunQpP/7x\njwF45ZVXOHToUKPep7YbbriByZMnn/L9zjnnHHJycti/f3+j6jtdrTq0AdSQC6C8FDZ/GexShBBC\nnMRrr73G+PHjufjii7nnnnvweDx0796dP/7xj2RkZLB+/XqGDRvGrFmzuPTSS1mxYgWbNm1iwoQJ\nZGRkMGXKFI4cOQLAVVddxQMPPMDYsWP529/+xpo1a+jbty9Wq5UVK1awceNGfvOb33DxxRdTXl5+\nwnZffvllxo0bR0ZGBr/85S8pLy8HYO7cuSxYsACAH//4x8yaNYvx48czcuRIPv/8c99nufjii/nv\nf/8bkHZq/acMnjMAIqLQWatRA4YGuxohhGi2zH8/j96706/bVMldMK795SmX2bZtG2+88QbLly/H\nZrMxc+ZM/vOf/1BWVsagQYP4wx/+4Fs2Li6O9957D4CMjAwefvhhRowYweOPP868efN46KGHAKiu\nruadd94BYM6cOfTv3x+ACRMm8MILL/D73/+eAQMG1LvdgoICrr/+egAeffRR/vWvf3HTTTedULfb\n7eatt97igw8+YN68eSxduhSAAQMG8OyzzzJ16tQzarNTafWhraxW1Lkj0F+sRldVokLswS5JCCFE\nLZ9++ilff/0148aNA7zXjrtcLiwWC+PHj6+z7GWXXQZAUVERR48eZcSIEQBMnjyZm2+++YTlAHJz\nc+nevfspa6i9/Lfffstjjz1GUVERpaWljBo1qt51jtXbv39/9u3b53ve6XSedPf72Wr1oQ2g0kZ6\nT0bbtB7OTQ92OUII0Sw11CMOFK01kydPZubMmXWeX7BgwQkDmISHhzdqm7WXCw0NpaKiotHLT58+\nnUWLFtGnTx+WLl3KZ599Vu86ISEhAFgsFtxut+/5yspKQkNDG1Xn6WrVx7RLKj0s+OIgG2NSISoG\n/cXqYJckhBDiB0aOHMmKFSt8U10WFhbW6bnWJzo6mpiYGN+x5Ndee43hw4fXu2y3bt3YtWuX73FE\nRAQlJSUn3XZJSQkJCQlUV1fz+uuvn+angR07dtCzZ8/TXq8xWnVP2241WLuvhJySavoPTkdnfoCu\nKEeFntnwcUIIIfyvR48e3HPPPfzkJz9Ba43VamXWrFkNrvfkk08yY8YMKioqSElJYd68efUud9FF\nF3H77bf7Hl999dXMmDGD0NBQ3njjjROWv/vu
u5kwYQJOp5NBgwadMuDrk5mZyZgxY05rncZSWmsd\nkC2fhQMHDvhtW69syuOljXk81cdN8vz7UL+8C2PoBX7bfkvg78nahbRpoEi7+l9DbVpWVtboXc4t\n2ZQpU/jd735H165d/bI9q9VaZ5f4MZWVlVx55ZUsX778pMPD1tfmSUlJjXrfVr17HODS7nGEWBRv\nljsgxoHO+jTYJQkhhGhiM2fOJDc3N+Dvs3//fu67776Ajefe6kM72m7hoq4xfLyrmKODL4RN69Hl\nZcEuSwghRBPq1q3bSY95+1PXrl1JTw/cCc+tPrQBJvaKo9rUvNt+GLir0dmfN7ySEEII0cy0idDu\nGG1nSIcI3s2zUuVMQGfJWeRCCCFanjYR2gCX9XJwtNLD6oGTYHM2uvT0zgYUQgghgq3NhHa/hHC6\nxNl5056K9rjRX9Z/sbwQQgjRXLWZ0FZKcVkvB3vKYGPnYXIWuRBCtCG1p+Y8XdOmTWPFihUA3HXX\nXXz33XcnLLN06VJ+97vfAbBkyRL+/e9/n3mxp9BmQhvg/E7RxIVaeLPrGNi6EV18NNglCSGEaAR/\nTs15NubMmUOPHj1Oucy1117L4sWLz/q96tOo0M7OzuaOO+7gtttuY/ny5Se8vmLFCqZPn85dd93F\nQw89xOHDh32vrVq1ittvv53bb7+dVatW+a3wM2GzKMb1jONLM5Y9oe3Q6zODWo8QQgivppqac/v2\n7XUmIdm7d69v9LInnniCcePGcdFFF3HPPfdQ39hjV111FRs3bgS8veuRI0cyfvx41q1b51smLCyM\n5ORkvvzS/1NCN3j1t2maLFq0iPvvvx+n08nMmTNJS0ujY8eOvmU6d+7M7NmzsdvtvP/++7z00ktM\nnz6dkpISXn31VWbPng3AjBkzSEtLIzIy0u8fpLF+1C2WVzbls6LHpUxd9ylcODZotQghRHPyt3WH\n2Fl46ok1TleXuFB+kZZwymWacmrObt26UVVVxZ49e0hJSeGNN95g4sSJANx4441Mnz4dgNtuu43/\n/e9/XHLJJfXWfOjQIebMmcO7775LVFQUkydPpm/fvr7X+/fvz+eff86gQYPOpNlOqsGe9vbt20lM\nTCQhIQGr1Up6ejpZWVl1lunbty92u3fKy+7du1NQUAB4e+j9+/cnMjKSyMhI+vfvT3Z2tl8/wOmK\nDrUyuksMH8f25sjOnegjBUGtRwgh2rraU3NefPHFfPrpp+zZs+e0p+Y8NnlI7eXAOzWn0+n0PZ44\ncaJvzPE33njDt2xmZiYTJkxgzJgxZGZm1nvs+pj169czYsQInE4nISEhdd4PvMPHBmJ6zgZ72gUF\nBXU+rNPpZNu2bSdd/sMPP2TgwIH1rutwOHyBHkyX9Yrjve1HeL/9cK5Zn4kaMyHYJQkhRNA11CMO\nlKaemvOyyy7j5ptvZuzYsSil6Nq1KxUVFdx33328/fbbdOjQgblz51JZWXnGnylQ03P6dXDUTz75\nhB07dvDggw+e1norV65k5cqVAMyePRuXy+XPsk7gcsGIzkd4xz2SydnLib/mxoC+X7BZrdaAt2lb\nI20aGNKu/tdQmx46dChg42Q31qhRo/jZz37GLbfcQrt27SgsLPTNrFW7NqUUFosFq9WKw+EgNjaW\ndevWMXz4cF5//XXS09OxWq11lgPo2bMne/bs8T3u1q0bVquVp59+mkmTJmG1Wn0ntsXHx1NZWclb\nb73FxIkTsVqtGIbh296xbQ8ePJjf//73FBUVERUVxVtvvUWfPn1877Fz506GDh1ab9va7fYz/p43\n+C/lcDjIz8/3Pc7Pz8fhcJyw3FdffcXrr7/Ogw8+iM1m8627efNm3zIFBQX07t37hHUzMjLIyMjw\nPW6KWX7Gpkby2a5wPjxi5eLvtqAc7QL+nsEiMyf5n7RpYEi7+l9DbVpZWXlCb7appaamcvfdd3P1\n1VefMDVn
7Zm0tNZ4PB7fc0888cQJU3O63e4Tlhs1ahS33357nW1NnDiRhx9+mLVr1+J2u4mIiOC6\n665j1KhRtGvXjgEDBmC
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435c0f518>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xt8k+X9//HXnaRND+kpCW0plFMBLUfBclQLhYqIhzF0\nHocHcHOyL56mTpw/59dNx0Snc1PQDVBxbPhVcXPzWBB1gFAEFCinIqdaSg/puUma5L5+f6QEysEW\nSEnTfp6PByZ37ivNJ1dj37nuw3VrSimFEEIIIdo9Q6gLEEIIIUTrSGgLIYQQYUJCWwghhAgTEtpC\nCCFEmJDQFkIIIcKEhLYQQggRJiS0hRBCiDAhoS2EEEKECQltIYQQIkxIaAshhBBhwhTqAk6muLg4\n1CV0KHa7nfLy8lCX0aFIn7YN6dfgkz5tG8Hu17S0tFa1k5G2EEIIESYktIUQQogwIaEthBBChIl2\nuU/7eEopXC4Xuq6jaVqoywk7hw8fxu12A/6+NBgMREVFSV8KIUSYaTG0X3rpJTZu3EhCQgLPPvvs\nCeuVUixevJhNmzZhNpuZNWsWffr0AWDVqlW88847AEybNo3x48efUZEul4uIiAhMprD4jtHumEwm\njEZjYNnr9eJyuYiOjg5hVUIIIU5Xi5vHx48fzyOPPHLK9Zs2baKkpIQXXniBn/70p/z1r38FoK6u\njrfeeounnnqKp556irfeeou6urozKlLXdQnsIDKZTOi6HuoyhBBCnKYWQ3vAgAFYLJZTrt+wYQPZ\n2dlomkb//v2pr6+nsrKSzZs3M2TIECwWCxaLhSFDhrB58+YzKlI24waf9KkQQoSfsx6+OhwO7HZ7\nYNlms+FwOHA4HNhstsDjVqsVh8Nxti8nhBCilZSuozyNKLcblA5KHb3Vdf99XR19XD+mzZHHj7Q7\n/rm6Dj4f6L6m5WNufTrqhMdO3u7osg6opso10JpuAbSTLJ9wqwVWn7CsAv9peh8cXaGa/gUeUses\nO8ly05O9k34A0ace0LaVdrHNOS8vj7y8PADmzp3b7EsA+A+k6gibx5VSXHPNNbz22mvous4777zD\n7bfffto/56abbmL+/PkkJCScss3jjz/OxIkTueSSSwBO6D+z2XxCP4vWM5lM0n9toDP2q/L5wOdF\neTwot8v/z9WAch1760S5nSins/l9d9PykX9uF8rZcPS+y0lpqN/g6dC05gHaXmkaamgW9qyLzvlL\nn3USWq3WZrPCVFRUYLVasVqtFBQUBB53OBwMGDDgpD8jNzeX3NzcwPLxs8y43e5mB1KFq7y8PDIz\nM4mOjubgwYMsXryY6dOnn9DO6/V+75eU119/PdDuVG677TYefPBBxowZg8lkOqGt2+2WWZLOgswy\n1TZC3a9K18HtgoZ6cNaDswGc9aimW//jDeBygs8LXg94vYHgPbLMKZe94PP4l71Ny+oMji8xR0Gk\nGaKim98m2tFSzGCOBrMZzRxNbEIC9U4XGAz+UDRooB25bzh6/9jlY9toBrSTPtcIRqP/OYZjbo2G\nkzxm9D/PeFw77Zj2mqHZbjuljhkZc8zo+PjlwGj5uOXAyJqjo/TAz/+eUfmx7Y573rH1RYRoRrSz\nDu2srCw+/PBDLrroInbv3k1MTAxJSUlccMEF/P3vfw8cfPb1119z0003ne3LhdSMGTMoLi7G7XYz\nc+ZMfvzjH/Ppp58yd+5cfD4fVquVN998k/r6eh599FG++eYbNE3jvvvu44orrmD58uXcfPPNADz1\n1FPs37+fSy+9lOzsbCZOnMi8efNISEigsLCQ//73vyd9PYBRo0bxwQcfUF9fz49//GNGjhzJhg0b\nSE1NZdGiRURHR9O9e3cqKyspLS1t9YdBiHCnlPKHan0t1NVCfQ2qvq4pfBuaB3FD/XGPNYCroeWR\nntHkD8mICP99owlMTf+OXY6KCtzXTmgX4Q+w
45cjzRAVhRYZ5X9+U/g2u42I9IdoK8Xa7TjD8Aum\ndvwmcAG0IrSff/55CgoKqK2t5Wc/+xnXXXddYNQ2adIkhg0bxsaNG7n77ruJjIxk1qxZAFgsFq65\n5hrmzJkDwLXXXvu9B7S1lv6Pv6AO7j3rn3MsLb03hht+0mK7Z599lqSkJJxOJ1dccQWXXXYZDz74\nIO+88w49evSgsrIS8PdZXFwcK1asAKCqqgqA/Px8fv/73wPwyCOPsHPnTj755BMA1qxZw5YtW1i5\nciU9evQ46etNmTIFq9XarKa9e/fy4osvMm/ePO68807ef/99rrnmGgAGDx5Mfn4+P/jBD4LQS0Kc\nW6rR7Q/ehiMBXItquj0SyuqY+4HHv+/MCKMRomMgOvbobZdUtOgYiLE0PeZ/XDu2XUzs0fsRkXIg\npwiZFkP73nvv/d71mqZxxx13nHTdhAkTmDBhwplV1g4tWrSIDz74APBf1OSNN95g9OjRgZBNSkoC\n4IsvvuCll14KPC8xMRHwh/f3fXG54IILAj/rZK+3d+/eE0I7PT2dQYMGATBkyBAOHjwYWGez2Th8\n+PAZv18h2oLSdaiphPJSVPlhqCiF8sOoilIqnPX4qqugvgYaG0/9QyIjITYeYuMg1gLdeqDFxoOl\naTk2Hs0Sd3T9kQCONEvgirAWdkd3tWZE3BbWrFnDF198wXvvvUd0dDTXXnstAwcOZM+ePa3+GUfO\njzacYtNWTEzM977ekVnNjmU2mwP3jUYjLpcrsOx2u4mKimp1fUIEg1IK6mr8QVzuD2QqDh8T0KX+\nfb3HiksAewqGLqloaT394RtjAcsx4WuJgxj/rRZpPvmLC9HBhV1oh0ptbS0JCQlER0dTWFjIxo0b\ncbvdfPnllxw4cCCweTwpKYns7GxeffVVnnjiCcA/wk5MTKRPnz7s37+f3r17Exsb+72TzZzs9U7X\nt99+y5VXXnnG71mIU1H1dc3D+NhRc0Wp/2CuY1niwJYC3XqiDRkJ9hQ0ezLYksGWgtb05TNJDvAT\n4ntJaLfS+PHjWbJkCePGjSMjI4Phw4djs9l4+umnueOOO9B1Hbvdzj/+8Q/uueceHnnkESZMmIDB\nYOD+++9nypQpTJw4kbVr19K7d2+sVisjRoxgwoQJ5OTkMHHixBZf73R4PB727dvH0KFDg9kNopNR\nug9KD/mPIzn4LergPji4F6qPm3MhOsYfyslpaAMu8IeyLRnsKWBL9u8fFkKcNU2p9ndSXHFxcbPl\nhoaGZpuOw9Xhw4e55557+Mc//tHmr/XBBx+wZcsWHnrooZOe8tVR+jRUQn1qUltQLid8tx918Fs4\nuM9/+91+aGzaLWM0QtceaOm9oFsvtC6pYPcHsxYTnEkmOmK/hpr0adsIdr+es1O+ROulpKRw0003\nUVtbS1xcXJu+ltfr5c4772zT1xDhSSkFlRVQtLdpBN10W3bo6OlOMRZI742WfZn/tntv6JqOFhER\n2uKF6OQktM+xq6+++py8zlVXXXVOXke0b8rrhZKDTZu1v0UV+W+pqz3aqEuqP5jHjEdL7wPde4PV\nLkdZC9EOSWgL0cGomkrUlo2oLflQsNk/aQhARCSk9UAbNubo6Ll7L9nfLEQYkdAWIswpXfePor/Z\ngNqyAfbt9m/mTrSiZV0M/Qeh9egDKd3QOsB0wEJ0ZhLaQoQh5WqAgs3+oN76FVRX+qd77N0f7eqb\n0IZkQXof2cQtRAcjoS1EmFAl36G2NI2md23zX4AiOhZt4DAYnIU2+EK0uFNf+U0IEf5aP+u8OGtK\nKX70ox9RW1vbcuOT6NevHwAlJSX85Ccnnxnu2muv5euvvwbg+uuvD8x7LsKP8npQBZvRl/0V36/u\nRP9/d6HeXAhVDrTcqzA88BSGPyzBcOdDGMZOkMAWohOQkfY5tGLFCgYMGHDWp3ulpqbyl7/8pcV2\nR67d/Ytf
/OKsXk+cO6rKcXQ0XfA1uJ1gioDzh6DlXo02OAvNnhLqMoUQISKhfRqCfWnOtLQ0brvt\nNsB/Ra/Y2FimT5/O7bf
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc4355f6198>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 3.07e-03\n",
" final error(valid) = 1.34e-01\n",
" final acc(train) = 1.00e+00\n",
" final acc(valid) = 9.71e-01\n",
" run time per epoch = 20.54\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=1.00\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xl8VNXdx/HPubNlX2YCCZAgJICyCAgBAQFBolZARVzb\nPm1dalux+kBVBLTWaqFUQKjbA1pAWm0FN6q4tZG6lIiEHQQRDCB7yAJZJ8nknuePGwYiSwLMZLL8\n3q9XXpk7c5ffHCZ859ztKK21RgghhBCNnhHqAoQQQghRPxLaQgghRBMhoS2EEEI0ERLaQgghRBMh\noS2EEEI0ERLaQgghRBMhoS2EEEI0ERLaQgghRBMhoS2EEEI0ERLaQgghRBNhD3UBp7J///5Ql9Cs\nJCQkkJeXF+oymhVp0+CQdg08adPgCHS7tm3btl7zSU9bCCGEaCIktIUQQogmQkJbCCGEaCIa5TFt\nIYQQDUNrjdfrxTRNlFKhLqfJOHToEBUVFWe1jNYawzAICws757aW0BZCiBbM6/XicDiw2yUOzobd\nbsdms531cj6fD6/XS3h4+DltV3aPCyFEC2aapgR2A7Lb7Zimec7LS2gLIUQLJrvEG975tHm9vl6t\nX7+ehQsXYpomI0aMYMyYMbVeX7ZsGR9//DE2m42YmBjuueceWrVqBcCtt95K+/btAeu6tocffvic\niz1b2luO/s97qM7dUZ26Nth2hRBCiGCoM7RN02T+/Pk8+uijeDweJk+eTHp6OsnJyf55OnTowPTp\n03G5XPzrX//ilVdeYcKECQA4nU5mzJgRvHdwJjY7+oM34eA+CW0hhGjBXnrpJeLi4rj55ptZvHgx\nl19+OUlJSWe1jr/+9a+Eh4dz8803n3aerVu3Mm/ePObMmXO+JZ9SnbvHd+zYQVJSEomJidjtdgYN\nGkR2dnateXr06IHL5QKgc+fOFBQUBKXYs6UcDtQlA9DrvkBXVYW6HCGEEOeourr6jNOn4/P58Pl8\nLF68mBtuuAGA119/nUOHDtVrOyf66U9/esbABujatSsHDhxg37599arvbNUZ2gUFBXg8Hv+0x+M5\nYygvX76c3r17+6erqqqYNGkSjzzyCKtWrTrPcs+e6jcEysvgqzUNvm0hhBD18+abbzJq1CiuvPJK\nJk6cSHV1NZ07d+b3v/89GRkZrFmzhksvvZSpU6dy9dVXs2zZMjZv3szo0aPJyMjgrrvu4siRIwDc\ndNNNPPbYY1xzzTX85S9/YcWKFfTo0QO73c6yZcvYsGEDv/71r7nyyispLy8/ab2vvvoqI0eOJCMj\ng7vvvpvy8nIAZs2axdy5cwG44YYbmDp1KqNGjWLw4MF8+eWX/vdy5ZVX8s9//jMo7RTQUwY/++wz\ncnJyePzxx/3PvfDCC7jdbg4dOsQTTzxB+/btT9olkZmZSWZmJgDTp08nISEhYDXpwVdweOEcHBtW\nEZcxOmDrbUrsdntA21RImwaLtGvg1dWmhw4d8p897vv7PMzvcgK6faN9KvYf/fKM83zzzTe8++67\nLFu2DIfDwcMPP8w///lPysrKSE9P58knnwSsE7g8Hg8ff/wxAMOGDWPatGkMGjSIP/3pT8yZM4c/\n/OEPKKWorq7m3//+NwBPPfUUvXv3xm63M2bMGBYtWsTvfvc7fwfz++stKCjgZz/7GQB//OMfWbx4\nMT//+c8xDAPDMPztZZomH330EZmZmcyePZs33ngDgD59+vDMM89w//33n/L9ulyuc/6c1xnabreb\n/Px8/3R+fj5ut/uk+TZu3Mjbb7/N448/jsPhqLU8QGJiIt26dWPXrl0nhXZGRgYZGRn+6YDf3P6S\nAVR88R8O79uLcoUFdt1NgAwYEHjSpsEh7Rp4dbVpRUWF/3pj0zTRWgd0+6Zp4vP5zjjPp59+ysaN\nG7nqqqsA69pxt9uNzWbjBz/4gX95rTWjR4/G
5/NRVFTE0aNH6d+/Pz6fjxtvvJFf/vKX+Hy+WvMB\nHDx4kLS0tFrrqa6uPuV6Ab766iueeuopioqKKC0t5fLLL8fn82GaZq33c6y27t27s2fPHv/zcXFx\nHDx48LTvu6Ki4qR/k/oOGFJnaKelpXHgwAFyc3Nxu91kZWWd9O1h586dvPTSS0yZMoXY2Fj/8yUl\nJbhcLhwOB0VFRWzbto3rr7++XoUFQnFFNQvW5jKky+X0+vRD9MZsa3e5EEKIkxi33R2S7Wqtufnm\nm5k8eXKt5+fOnXvSDUwiIiLqtc4T5wsLC8Pr9dZ7/gkTJjB//ny6d+/O4sWL+eKLL065jNPpBMBm\ns9UK6IqKCsLCgtNBrDO0bTYbd955J1OnTsU0TYYPH05KSgqLFy8mLS2N9PR0XnnlFbxeL08//TRw\n/NKuffv28eKLL2IYBqZpMmbMmFpnnQdbmN1g08FS9oVH0jPODas+BwltIYRoVAYPHswdd9zB3Xff\nTUJCAoWFhZSWlp5xmZiYGGJjY/nyyy+59NJLefPNNxkwYMAp5+3UqRO7du3yT0dGRlJSUnLadZeU\nlJCYmEhVVRVvv/32WZ9lnpOTw4UXXnhWy9RXvY5p9+nThz59+tR67tZbb/U//u1vf3vK5S688EJm\nzZp1HuWdH4dNcXOPBF5YdZANfUbT+7O/o8tKURGRIatJCCFEbV26dGHixIn88Ic/RGuN3W5n6tSp\ndS43Z84cJk2ahNfrpX379v6O4/ddccUVtfYQ33LLLUyaNImwsDDeeeedk+Z/6KGHGD16NB6Ph0su\nueSMAX8qWVlZjBgx4qyWqS+lA30AIwD2798fsHVVVWvueedb3IaPacsmYtzxvxiDgtOYjZUcJww8\nadPgkHYNvLratKysrN67nJuyu+66i0ceeYTU1NSArM9ut5/ymHVFRQU33ngjS5cuPe3tYU/V5vU9\npt3sb2PqsClu6uFhWwmsv6A/OvvzUJckhBCigU2ePJnc3Nygb2ffvn1MmTIlaPdzb/ahDTAiNY6E\nCDtL0q5Gb1mPLi4KdUlCCCEaUKdOnU57zDuQUlNTGTRoUNDW3yJC2zq27WGbGc2G2DT02qxQlySE\nEEKctRYR2nC8t72482hM2UUuhBCiCWoxoe2wKW7q7mFbRBs2HK5AH8mveyEhhBCiEWkxoQ2QkRZL\ngkux+IIMzOwVoS5HCCGEOCstKrQdNoOberZmW2wH1m/+NtTlCCGEaCAvvfQSr7/++jktO378eJYt\nWwbAgw8+yDfffHPSPIsXL+aRRx4BYOHChbz22mvnXuwZtKjQhpretlHJYldXzNwDoS5HCCFEPQRy\naM7zMXPmTLp06XLGeW677TYWLFhw3ts6lRYX2g6bwY0Xxlm97ZUbQl2OEEIIGm5ozh07djBq1Cj/\ndvfs2eO/e9ns2bMZOXIkV1xxBRMnTjzl4Ck33XQTGzZY2bF48WIGDx7MqFGjWL16tX+e8PBwUlJS\nWLduXcDbKThXfzdyV/ZK5o1Na3kt18UlWqOUCnVJQggRcn9ZfYidhWceWONsdYwP4+fpiWecZ/v2\n7bzzzjssXboUh8PB5MmTeeuttygrK+OSSy7hd7/7nX/e+Ph4PvroI8AaIfLJJ59k4MCBzJgxg6ef\nfponnngCgKqqKj744APA6h337NkTsK7Xrqys5LvvvqN9+/a88847XHvttQDcfvvtTJgwAYD77ruP\nf//73/6Rx77v0KFDzJw5kw8//JDo6GhuvvlmevTo4X+9Z8+efPnll1xyySXn0myn1eJ62lBzbDu+\nlG3hbdiwZXeoyxFCiBbtv//9L5s2bWLkyJFceeWV/Pe//+W7777DZrPV6hUDXHfddQD+oTkHDhwI\nwM0338yXX3550nwAubm5eDwe//S1117rv+f4O++84583KyuL0aNHM2LECLKysk557PqYNWvWMHDg\nQDweD06n
s9b2wLp97KFDh86lOc6oRfa0ATIu684bS3fwj012enW7QHrbQogWr64ecbA09NCc1113\nHb/85S+55pprUEqRmpq
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435674e80>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAe0AAAENCAYAAAAi8D15AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xt8FPW9//HX7G7uCUl2FxIC4RaghnA3AqIiIfFSFKVo\nsdXiBWyt9MQresQfes6xR0sFj7anirQlqFgq1gqnarU1aNSKkCB3AhgQgZBAkt2E3HaT3Z3v748N\nC+GWALvZbPJ5Ph55zO7OJPPZL0veme98Z76aUkohhBBCiE7PEOwChBBCCNE+EtpCCCFEiJDQFkII\nIUKEhLYQQggRIiS0hRBCiBAhoS2EEEKECAltIYQQIkRIaAshhBAhQkJbCCGECBES2kIIIUSIMAW7\ngDMpKysLdglditVqpaqqKthldCnSpoEh7ep/0qaB4e92TUlJadd2cqQthBBChAgJbSGEECJESGgL\nIYQQIaJTntM+lVIKp9OJrutomhbsckLO0aNHaWpqArxtaTAYiIyMlLYUQogQ02Zov/LKK2zatIn4\n+HheeOGF09YrpVi+fDmbN28mIiKCuXPnMmjQIAAKCgp49913AZgxYwaTJ0++oCKdTidhYWGYTCHx\nN0anYzKZMBqNvudutxun00lUVFQQqxJCCHG+2uwenzx5Mk8++eRZ12/evJkjR47w29/+lp/97Gf8\n8Y9/BKC+vp533nmH5557jueee4533nmH+vr6CypS13UJbD8ymUzouh7sMoQQQpynNkN72LBhxMbG\nnnX9xo0bmTRpEpqmMXToUBoaGqiurmbLli2MHDmS2NhYYmNjGTlyJFu2bLmgIqUb1/+kTYUQIvRc\n9OGr3W7HarX6nlssFux2O3a7HYvF4nvdbDZjt9vP+DPy8/PJz88HYOHCha1+HnjPycqR9sU5tf0i\nIiJOa2fRfiaTSdovAKRd/U/aFJTHA243yt0MLhfK7Ua5mltec7W85gK3C+V7fOo2zd7vO7591vex\n9u7b4e+lUyRhTk4OOTk5vuenXrDe1NTU6pxsqFJKMXPmTPLy8tB1ndWrV3P33Xef98+ZNWsWv/vd\n74iPjz/rNs888wxTpkzhyiuvxGQy4Xa7W61vamqSGy5cBLlhRWBIu/pfsNtU6Tp4PODxBiFud+vn\nLYGKxxuGNDeDqwnV3Ayuk75aXqf5xHPlajppXfMpj49v6wLl/9OBprTvUR8W6bef196bq1x0aJvN\n5lYfCJvNhtlsxmw2U1xc7HvdbrczbNiwi91dSFu7di3Dhg0jLi6OQ4cO8cYbb5wxtN1u9zl7Flas\nWNHmvmbPns1jjz3GlVdeeTElCyG6OKUUNDmgsREcjeBoAEcDqrGh5XkjOOq9y8ZGlKMBnI0tQes+\nKYhPen7yY3+OnwkL936Fh594fPx5VAzEJ6L51kd4l6awk75MLV9hYPQ+1s74+hm2N5lOrDMaiezZ\nk/og/DF00aGdmZnJRx99xBVXXEFJSQnR0dEkJiYyevRo/vznP/sGn23dupXbb7/9ogsOptmzZ1NW\nVkZTUxNz5szhJz/5CZ9++ikLFy7E4/FgNpt5++23aWhoYMGCBWzbtg1N03j44Ye54YYbWL16NXfc\ncQcAzz33HAcOHOCaa65h0qRJZGdns2jRIuLj49m7dy//+te/zrg/gPHjx/Phhx/S0NDAT37yE8aN\nG8fGjRtJTk4mLy+PqKgo+vbtS3V1NRUVFe3+C04IEVqU2+0NUEcjOB0ty0aU09HyuoN6TaFXVYKj\nJXAdjdDY0BLO3u3bDFajEaKivcEYFQORURAd4wswzXhy6Bl9gYjRdOLxqc+Ph+bJz8PCToRtq8fe\n8NUMcmuRNkP7pZdeori4mLq6On7+858zc+ZMX1frtddey5gxY9i0aRMPPPAA4eHhzJ07F4DY2Fhu\nueUW5s+fD8Ctt956zgFt7aW/9QfUof0X/XNO
pqUOxPCjn7a53QsvvEBiYiIOh4MbbriB6667jsce\ne4x3332Xfv36UV1dDXjbLC4ujrVr1wJQU1MDQFFREb/+9a8BePLJJ9mzZw8ff/wxAOvWrWP79u18\n8skn9OvX74z7mzp1KmazuVVN+/fv5+WXX2bRokXcd999/P3vf+eWW24BYMSIERQVFXHzzTf7oZWE\nEP6gdB2aneBweIP2eMA6HSeFbaMvdHE2ok4JZV9Iu5rb3F+DweAN2eOBGx0Nlp5oUQNOBHF0dMvj\nWLSolsfRMSe+JzxcBq92Em2G9kMPPXTO9Zqmce+9955x3ZQpU5gyZcqFVdYJ5eXl8eGHHwLeSU3e\nfPNNJkyY4AvZxMREAL744gteeeUV3/clJCQA3vA+1x8uo0eP9v2sM+1v//79p4V2amoqw4cPB2Dk\nyJEcOnTIt85isXD06NELfr9CiNaU2wXHqr1fLaGqTgne4499rzsavd3PToc3hJudoFTbOzOaWoI0\nuiV0oyHBjBbZF6KiIPL4Ou967dTXWp5b+/TFZrMFvnFEh+gUA9HOR3uOiANh3bp1fPHFF7z33ntE\nRUVx6623kpGRwb59+9r9M45fH204SxdPdHT0Ofd3/K5mJ4uIiPA9NhqNOJ1O3/OmpiYiI/03UEKI\nrkw5GqHGBtVVqGq793GNDVVtg2rvY+qOnTtwjaYTgRoZ5f2Ki0frmdzy/KTXo6IgIgrNF7CtQ1cL\nC/PL+5Ij5K4l5EI7WOrq6oiPjycqKoq9e/eyadMmmpqaWL9+PQcPHvR1jycmJjJp0iRee+01nnnm\nGcB7hJ2QkMCgQYM4cOAAAwcOJCYm5pw3mznT/s7Xt99+y4033njB71mIrkDpHqg9dsYQVjV2qK6C\narv3aPhUMXGQaIEEC1r/NEiwQKIFLT7xxLndk8LYX0ErxNlIaLfT5MmTWbFiBVdffTVpaWmMHTsW\ni8XC888/z7333ouu61itVt566y0efPBBnnzySaZMmYLBYOCRRx5h6tSpZGdn89VXXzFw4EDMZjOX\nXXYZU6ZMISsri+zs7Db3dz5cLhffffcdo0aN8mczCNHpqRo7qqQYSnagvtkJR0q9lxWdzGCABLM3\nhFP6o2WMPRHOLUsSzGjhEWfeiRBBoinVnpMrHausrKzV88bGxlZdx6Hq6NGjPPjgg7z11lsB39eH\nH37I9u3befzxx894nXZXadNgCfa1r13VhbSrslV4w/mblpCuaPn9EREJaelo/QdBovWkMLZAj3g0\nQ+jf+6E95LMaGP5u1w67Tlu0X1JSErfffjt1dXXExcUFdF9ut5v77rsvoPsQoqMppeBoGeqbHVCy\n0xvS9krvyugYGJKBdvV1aEOGQ79BaF3gpkxCnExCu4PddNNNHbKfadOmdch+hAgkpetQdqDlSHon\nqmQn1HovoSQuHm3ocLjuB2hDM7zd3HIdr+jiJLSFEJ2G8rhR+0tQx89HlxRDY8uATbMVbdho79H0\n0AxI6iMjo0W3I6EthOhwSimoq4Hyw6jyQ3CkFFV2kMr9JShno3ejXiloYy/3hbRmTQpu0UJ0AhLa\nQoiAUboHqiqgvBR1pBTKD7UsS08cQYN30FhyXyInX4+zXxrakGFoCZaz/2AhuikJbSHERVPNTd4B\nYi1HzZSXeh8fLQO368SGPRIguS/aZVd6l71TIbmv99png4EeVivNMtJZiLOS0O5AJ0/NeSGjx4cM\nGUJJSQlHjhzhqaee4g9/+MNp29x666089dRTjBo1ittuu42lS5d2+7l0hf8oZyMcPogqO+jt0i73\nHj1jqzhxpzDNANZe3lDOGAu9+6Il9/UuYwJ71YQQXZ2Edgc6eWrOi5GcnHzGwD7VLbfcwuuvv86j\njz56UfsT3Y9yu7yhfPggHP6uZXnAG87HhYV7B4MNHAoTs1uOnPtCUop3ekQhhN9JaJ8Hf0/NmZKS\n4ptP+4UX
XiAmJoZZs2Zxzz33cOzYMdxuN48//jjXXXddqzoOHTrEXXfdxSeffILD4eCRRx6huLiY\nwYMHt7r3+LXXXsuMGTM
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc435c960f0>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 7.72e-03\n",
" final error(valid) = 1.62e-01\n",
" final acc(train) = 1.00e+00\n",
" final acc(valid) = 9.62e-01\n",
" run time per epoch = 21.37\n"
]
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"learning_rate = 0.2 # learning rate for gradient descent\n",
"\n",
"init_scales = [0.1, 0.2, 0.5, 1.] # scale for random parameter initialisation\n",
"final_errors_train = []\n",
"final_errors_valid = []\n",
"final_accs_train = []\n",
"final_accs_valid = []\n",
"\n",
"for init_scale in init_scales:\n",
"\n",
" print('-' * 80)\n",
" print('learning_rate={0:.2f} init_scale={1:.2f}'\n",
" .format(learning_rate, init_scale))\n",
" print('-' * 80)\n",
" # Reset random number generator and data provider states on each run\n",
" # to ensure reproducibility of results\n",
" rng.seed(seed)\n",
" train_data.reset()\n",
" valid_data.reset()\n",
"\n",
" # Alter data-provider batch size\n",
" train_data.batch_size = batch_size \n",
" valid_data.batch_size = batch_size\n",
"\n",
" # Create a parameter initialiser which will sample random uniform values\n",
" # from [-init_scale, init_scale]\n",
" param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
" # Create a model with four affine layers\n",
" hidden_dim = 100\n",
" model = MultipleLayerModel([\n",
" AffineLayer(input_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, output_dim, param_init, param_init)\n",
" ])\n",
"\n",
" # Initialise a cross entropy error object\n",
" error = CrossEntropySoftmaxError()\n",
"\n",
" # Use a basic gradient descent learning rule\n",
" learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
" stats, keys, run_time, fig_1, ax_1, fig_2, ax_2 = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)\n",
"\n",
" plt.show()\n",
"\n",
" print(' final error(train) = {0:.2e}'.format(stats[-1, keys['error(train)']]))\n",
" print(' final error(valid) = {0:.2e}'.format(stats[-1, keys['error(valid)']]))\n",
" print(' final acc(train) = {0:.2e}'.format(stats[-1, keys['acc(train)']]))\n",
" print(' final acc(valid) = {0:.2e}'.format(stats[-1, keys['acc(valid)']]))\n",
" print(' run time per epoch = {0:.2f}'.format(run_time * 1. / num_epochs))\n",
"\n",
" final_errors_train.append(stats[-1, keys['error(train)']])\n",
" final_errors_valid.append(stats[-1, keys['error(valid)']])\n",
" final_accs_train.append(stats[-1, keys['acc(train)']])\n",
" final_accs_valid.append(stats[-1, keys['acc(valid)']])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| init_scale | final error(train) | final error(valid) | final acc(train) | final acc(valid) |\n",
"|------------|--------------------|--------------------|------------------|------------------|\n",
"| 0.1 | 2.03e-03 | 1.35e-01 | 1.00 | 0.97 |\n",
"| 0.2 | 1.99e-03 | 1.17e-01 | 1.00 | 0.97 |\n",
"| 0.5 | 3.07e-03 | 1.34e-01 | 1.00 | 0.97 |\n",
"| 1.0 | 7.72e-03 | 1.62e-01 | 1.00 | 0.96 |\n"
]
}
],
"source": [
"j = 0\n",
"print('| init_scale | final error(train) | final error(valid) | final acc(train) | final acc(valid) |')\n",
"print('|------------|--------------------|--------------------|------------------|------------------|')\n",
"for init_scale in init_scales:\n",
" print('| {0:.1f} | {1:.2e} | {2:.2e} | {3:.2f} | {4:.2f} |'\n",
" .format(init_scale, \n",
" final_errors_train[j], final_errors_valid[j],\n",
" final_accs_train[j], final_accs_valid[j]))\n",
" j += 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Models with five affine layers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"learning_rate = 0.2 # learning rate for gradient descent\n",
"\n",
"init_scales = [0.1, 0.2, 0.5, 1.] # scale for random parameter initialisation\n",
"final_errors_train = []\n",
"final_errors_valid = []\n",
"final_accs_train = []\n",
"final_accs_valid = []\n",
"\n",
"for init_scale in init_scales:\n",
"\n",
" print('-' * 80)\n",
" print('learning_rate={0:.2f} init_scale={1:.2f}'\n",
" .format(learning_rate, init_scale))\n",
" print('-' * 80)\n",
" # Reset random number generator and data provider states on each run\n",
" # to ensure reproducibility of results\n",
" rng.seed(seed)\n",
" train_data.reset()\n",
" valid_data.reset()\n",
"\n",
" # Alter data-provider batch size\n",
" train_data.batch_size = batch_size \n",
" valid_data.batch_size = batch_size\n",
"\n",
" # Create a parameter initialiser which will sample random uniform values\n",
" # from [-init_scale, init_scale]\n",
" param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
" # Create a model with five affine layers\n",
" hidden_dim = 100\n",
" model = MultipleLayerModel([\n",
" AffineLayer(input_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, output_dim, param_init, param_init)\n",
" ])\n",
"\n",
" # Initialise a cross entropy error object\n",
" error = CrossEntropySoftmaxError()\n",
"\n",
" # Use a basic gradient descent learning rule\n",
" learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
" stats, keys, run_time, fig_1, ax_1, fig_2, ax_2 = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)\n",
"\n",
" plt.show()\n",
"\n",
" print(' final error(train) = {0:.2e}'.format(stats[-1, keys['error(train)']]))\n",
" print(' final error(valid) = {0:.2e}'.format(stats[-1, keys['error(valid)']]))\n",
" print(' final acc(train) = {0:.2e}'.format(stats[-1, keys['acc(train)']]))\n",
" print(' final acc(valid) = {0:.2e}'.format(stats[-1, keys['acc(valid)']]))\n",
" print(' run time per epoch = {0:.2f}'.format(run_time * 1. / num_epochs))\n",
"\n",
" final_errors_train.append(stats[-1, keys['error(train)']])\n",
" final_errors_valid.append(stats[-1, keys['error(valid)']])\n",
" final_accs_train.append(stats[-1, keys['acc(train)']])\n",
" final_accs_valid.append(stats[-1, keys['acc(valid)']])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"j = 0\n",
"print('| init_scale | final error(train) | final error(valid) | final acc(train) | final acc(valid) |')\n",
"print('|------------|--------------------|--------------------|------------------|------------------|')\n",
"for init_scale in init_scales:\n",
" print('| {0:.1f} | {1:.2e} | {2:.2e} | {3:.2f} | {4:.2f} |'\n",
" .format(init_scale, \n",
" final_errors_train[j], final_errors_valid[j],\n",
" final_accs_train[j], final_accs_valid[j]))\n",
" j += 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> How does increasing the number of layers affect the model's performance on the training data set? And on the validation data set?\n",
"\n",
"<span style='color: red;'>\n",
"For each model architecture, the best final training set error across the four initialisation scales used above consistently decreases as we increase the number of layers.\n",
"</span>\n",
"\n",
"| Number of affine layers | Best final training set error |\n",
"|-------------------------|-------------------------------|\n",
"| 2 | $1.85 \\times 10^{-2}$ |\n",
"| 3 | $5.21 \\times 10^{-3}$ |\n",
"| 4 | $1.99 \\times 10^{-3}$ |\n",
"| 5 | $1.14 \\times 10^{-3}$ |\n",
"\n",
"<span style='color: red;'>\n",
"This makes sense because, for a fixed hidden layer width, the total number of free parameters in the model increases with the number of layers, so we would expect the model to be able to fit the training data better.\n",
"</span>\n",
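"\n",
"As a rough illustration of this growth, consider a model with $L \\geq 2$ affine layers, $D$ inputs, $K$ outputs and $L - 1$ hidden layers all of width $H$. Counting one weight matrix and one bias vector per affine layer, the number of free parameters is\n",
"\n",
"\\begin{equation}\n",
"  P(L) = (D H + H) + (L - 2)(H^2 + H) + (H K + K),\n",
"\\end{equation}\n",
"\n",
"so each additional hidden layer adds $H^2 + H$ parameters. Assuming MNIST-style inputs with $D = 784$ and $K = 10$ (an assumption here: the actual values are whatever `input_dim` and `output_dim` were set to earlier in the notebook) and the $H = 100$ used above, this gives $P(2) = 79\\,510$, $P(3) = 89\\,610$, $P(4) = 99\\,710$ and $P(5) = 109\\,810$.\n",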
"\n",
"<span style='color: red;'>\n",
"If we look at the validation set, however, we see the opposite trend: as the number of layers increases, the best final validation set error also increases.\n",
"</span>\n",
"\n",
"| Number of affine layers | Best final validation set error |\n",
"|-------------------------|---------------------------------|\n",
"| 2 | $7.47 \\times 10^{-2}$ |\n",
"| 3 | $8.77 \\times 10^{-2}$ |\n",
"| 4 | $1.17 \\times 10^{-1}$ |\n",
"| 5 | $1.47 \\times 10^{-1}$ |\n",
"\n",
"<span style='color: red;'>\n",
"If we look more closely at the training curves for the models with more layers, we can see what is happening here. For the models with three or more layers, after a certain number of epochs the validation set error begins to *increase* even as the training set error continues to decrease. This indicates that these models have begun *overfitting* to the training data. We could get a better validation set error in these cases by stopping the training early. *Early stopping* like this is one way of trying to overcome overfitting; in later labs we will consider other methods for reducing overfitting and so improving generalisation.\n",
"</span>\n",
"\n",
"> Do deeper models seem to be harder or easier to train (e.g. in terms of ease of choosing training hyperparameters to give good final performance and/or quick convergence)?\n",
"\n",
"> Do the models seem to be sensitive to the choice of the parameter initialisation range? Can you think of any reasons for why setting individual parameter initialisation scales for each AffineLayer in a model might be useful? Can you come up with (or find) any heuristics for setting the parameter initialisation scales?\n",
"\n",
"<span style='color: red;'>\n",
"The final performance of the deeper models becomes increasingly sensitive to the choice of parameter initialisation. For the models with two affine layers, the final training errors for initialisation scales 0.1, 0.2 and 0.5 are all within approximately 10% of each other, while for the models with five affine layers the final training error increases by approximately 400% when moving from an initialisation scale of 0.2 to 0.1, and by approximately 50% when moving from 0.2 to 0.5. The smaller parameter initialisation scales in particular seem to give poorer initial performance for the deeper models (the error curves start from higher values), and for the five affine layer model the run with the smallest parameter initialisation scale shows a pronounced flat section at the start of training, with around 15 epochs passing before the error starts to decrease significantly.\n",
"</span>\n",
"\n",
"<span style='color: red;'>\n",
"In general the models with more layers also take longer to train per epoch, so on top of the issues of potential overfitting and the difficulty of choosing parameter initialisations, we also need to factor in the potentially slower training of deeper models if computational time is a key constraint.\n",
"</span>\n",
"\n",
"<span style='color: red;'>\n",
"We might expect the appropriate initialisation scale for a given affine layer to depend on its input and output dimensionalities. Each output is calculated as a weighted sum of all the inputs, so for a larger number of inputs the typical magnitude of the output activations will become larger, as each will be calculated from a sum over more values. Similarly, the backpropagated gradient at each input is calculated as a weighted sum over the gradients at each output, so for a larger number of outputs the typical magnitude of the backpropagated gradients will become larger.\n",
"</span>\n",
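"\n",
"A rough version of this argument can be made quantitative. Assume, as a simplification, that the inputs $x_d$ to an affine layer with $D$ inputs and $K$ outputs are independent with zero mean and variance $v_x$, and that the weights $W_{kd}$ are independent of the inputs and of each other, with zero mean and variance $v_W$. Ignoring the biases, each output then has variance\n",
"\n",
"\\begin{equation}\n",
"  \\mathrm{Var}\\left[ \\sum_{d=1}^D W_{kd} x_d \\right] = D \\, v_W \\, v_x,\n",
"\\end{equation}\n",
"\n",
"so keeping the forward activation variance unchanged suggests $v_W = 1 / D$. The same argument applied to the backpropagated gradients, each of which is a sum over $K$ terms, suggests $v_W = 1 / K$. A compromise between the two conditions is\n",
"\n",
"\\begin{equation}\n",
"  v_W = \\frac{2}{D + K}.\n",
"\\end{equation}\n",
"\n",
"Since a uniform distribution on $[-s,\\,s]$ has variance $s^2 / 3$, matching this variance with the `UniformInit(-init_scale, init_scale)` initialiser used above would correspond to $s = \\sqrt{6 / (D + K)}$. This sketch ignores the effect of the elementwise non-linearities, so it is best treated as a heuristic starting point rather than an exact prescription.\n",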
"\n",
"<span style='color: red;'>\n",
"If we wish to keep some measure of the typical magnitude of the activations and backpropagated gradients roughly constant through the network, we may therefore wish to set the parameter initialisation of each layer in a way that depends on its dimensionality. One heuristic, based on trying to achieve a roughly constant variance in activations and backpropagated gradients through the network, is to initialise the weights for a layer from a distribution with variance inversely proportional to the sum of the input and output dimensions of the layer. This is sometimes known as Glorot or Xavier initialisation, after Xavier Glorot, the first author of [the paper](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf) in which this scheme was proposed. It is discussed in [lecture 4](http://www.inf.ed.ac.uk/teaching/courses/mlp/2017-18/mlp04-learn.pdf).\n",
"</span>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 3: Hyperbolic tangent and rectified linear layers\n",
"\n",
"In the models we have been investigating so far we have been applying elementwise logistic sigmoid transformations to the outputs of intermediate (affine) layers. The logistic sigmoid is just one particular choice of an elementwise non-linearity we can use. \n",
"\n",
"As mentioned in [lecture 3](http://www.inf.ed.ac.uk/teaching/courses/mlp/2017-18/mlp03-mlp.pdf), although the logistic sigmoid has some favourable properties in terms of interpretability, it also has disadvantages from a computational perspective. In particular, the gradients of the sigmoid become very close to zero (and may actually become exactly zero to finite numerical precision) for very positive or very negative inputs, and the outputs are non-centred: they cover the interval $[0,\\,1]$, so negative outputs are never produced.\n",
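"\n",
"To make the point about vanishing gradients concrete, the derivative of the logistic sigmoid $\\sigma(x) = 1 / (1 + \\exp(-x))$ can be written in terms of its output as\n",
"\n",
"\\begin{equation}\n",
"  \\frac{\\mathrm{d}\\sigma}{\\mathrm{d}x} = \\sigma(x) \\left( 1 - \\sigma(x) \\right) \\leq \\frac{1}{4},\n",
"\\end{equation}\n",
"\n",
"with the maximum of $1/4$ attained at $x = 0$. For an input of $x = 10$ the derivative is already only about $4.5 \\times 10^{-5}$, so gradients passed back through saturated sigmoid units are heavily attenuated.\n",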
"\n",
"Two alternative elementwise non-linearities which are often used in multiple layer models are the hyperbolic tangent and the rectified linear function.\n",
"\n",
"For a hyperbolic tangent (`Tanh`) layer the forward propagation corresponds to\n",
"\n",
"\\begin{equation}\n",
" y^{(b)}_k = \n",
" \\tanh\\left(x^{(b)}_k\\right) = \n",
" \\frac{\\exp\\left(x^{(b)}_k\\right) - \\exp\\left(-x^{(b)}_k\\right)}{\\exp\\left(x^{(b)}_k\\right) + \\exp\\left(-x^{(b)}_k\\right)}\n",
"\\end{equation}\n",
"\n",
"which has corresponding partial derivatives\n",
"\n",
"\\begin{equation}\n",
" \\frac{\\partial y^{(b)}_k}{\\partial x^{(b)}_d} = \n",
" \\begin{cases} \n",
" 1 - \\left(y^{(b)}_k\\right)^2 & \\quad k = d \\\\\n",
" 0 & \\quad k \\neq d\n",
" \\end{cases}.\n",
"\\end{equation}\n",
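"\n",
"The partial derivative above follows from the quotient rule. Writing $u = \\exp(x)$ and $v = \\exp(-x)$, so that $\\tanh(x) = (u - v)/(u + v)$ with $\\mathrm{d}u/\\mathrm{d}x = u$ and $\\mathrm{d}v/\\mathrm{d}x = -v$,\n",
"\n",
"\\begin{equation}\n",
"  \\frac{\\mathrm{d}}{\\mathrm{d}x}\\tanh(x) =\n",
"  \\frac{(u + v)^2 - (u - v)^2}{(u + v)^2} =\n",
"  1 - \\tanh^2(x).\n",
"\\end{equation}\n",
"\n",
"The derivative can therefore be computed from the layer outputs alone, which is why `bprop` can use `outputs` rather than recomputing $\\tanh$. The hyperbolic tangent is also closely related to the logistic sigmoid $\\sigma$, since $\\tanh(x) = 2\\sigma(2x) - 1$: it is effectively a rescaled and shifted sigmoid whose outputs lie in $(-1,\\,1)$ and are centred on zero, with a maximum derivative of $1$ at $x = 0$.\n",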
"\n",
"For a rectified linear (`Relu`) layer the forward propagation corresponds to\n",
"\n",
"\\begin{equation}\n",
" y^{(b)}_k = \n",
" \\max\\left(0,\\,x^{(b)}_k\\right)\n",
"\\end{equation}\n",
"\n",
"which has corresponding partial derivatives\n",
"\n",
"\\begin{equation}\n",
" \\frac{\\partial y^{(b)}_k}{\\partial x^{(b)}_d} = \n",
" \\begin{cases} \n",
" 1 & \\quad k = d \\quad\\textrm{and}\\quad x^{(b)}_d > 0 \\\\\n",
" 0 & \\quad k \\neq d \\quad\\textrm{or}\\quad x^{(b)}_d < 0\n",
" \\end{cases}.\n",
"\\end{equation}\n",
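"\n",
"Because both layers act elementwise, the only non-zero partial derivatives are those with $k = d$, so for an error function $E$ the sum in the chain rule used in `bprop` collapses to a single term per input:\n",
"\n",
"\\begin{equation}\n",
"  \\frac{\\partial E}{\\partial x^{(b)}_d} =\n",
"  \\sum_{k} \\frac{\\partial E}{\\partial y^{(b)}_k} \\frac{\\partial y^{(b)}_k}{\\partial x^{(b)}_d} =\n",
"  \\frac{\\partial E}{\\partial y^{(b)}_d} \\frac{\\partial y^{(b)}_d}{\\partial x^{(b)}_d}.\n",
"\\end{equation}\n",
"\n",
"The gradients with respect to the inputs are therefore an elementwise product of `grads_wrt_outputs` with $1 - \\left(y^{(b)}_d\\right)^2$ for the `Tanh` layer, and with an indicator of $x^{(b)}_d > 0$ (equivalently $y^{(b)}_d > 0$) for the `Relu` layer. The rectified linear function is not differentiable at $x^{(b)}_d = 0$; testing `outputs > 0` amounts to the common convention of using a gradient of $0$ at that point.\n",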
"\n",
"Using these definitions, implement the `fprop` and `bprop` methods for the skeleton `TanhLayer` and `ReluLayer` class definitions below."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from mlp.layers import Layer\n",
"\n",
"class TanhLayer(Layer):\n",
" \"\"\"Layer implementing an element-wise hyperbolic tangent transformation.\"\"\"\n",
"\n",
" def fprop(self, inputs):\n",
" \"\"\"Forward propagates activations through the layer transformation.\n",
"\n",
" For inputs `x` and outputs `y` this corresponds to `y = tanh(x)`.\n",
" \"\"\"\n",
" return np.tanh(inputs)\n",
"\n",
" def bprop(self, inputs, outputs, grads_wrt_outputs):\n",
" \"\"\"Back propagates gradients through a layer.\n",
"\n",
" Given gradients with respect to the outputs of the layer calculates the\n",
" gradients with respect to the layer inputs.\n",
" \"\"\"\n",
" return (1. - outputs**2) * grads_wrt_outputs\n",
"\n",
" def __repr__(self):\n",
" return 'TanhLayer'\n",
" \n",
"\n",
"class ReluLayer(Layer):\n",
" \"\"\"Layer implementing an element-wise rectified linear transformation.\"\"\"\n",
"\n",
" def fprop(self, inputs):\n",
" \"\"\"Forward propagates activations through the layer transformation.\n",
"\n",
" For inputs `x` and outputs `y` this corresponds to `y = max(0, x)`.\n",
" \"\"\"\n",
" return np.maximum(inputs, 0.)\n",
"\n",
" def bprop(self, inputs, outputs, grads_wrt_outputs):\n",
" \"\"\"Back propagates gradients through a layer.\n",
"\n",
" Given gradients with respect to the outputs of the layer calculates the\n",
" gradients with respect to the layer inputs.\n",
" \"\"\"\n",
"        # Gradients pass through only where the unit is active (y > 0); zero at x = 0 by convention.\n",
"        return (outputs > 0) * grads_wrt_outputs\n",
"\n",
" def __repr__(self):\n",
" return 'ReluLayer'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Test your implementations by running the cells below."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Outputs and gradients calculated correctly for TanhLayer.\n"
]
}
],
"source": [
"test_inputs = np.array([[0.1, -0.2, 0.3], [-0.4, 0.5, -0.6]])\n",
"test_grads_wrt_outputs = np.array([[5., 10., -10.], [-5., 0., 10.]])\n",
"test_tanh_outputs = np.array(\n",
" [[ 0.09966799, -0.19737532, 0.29131261],\n",
" [-0.37994896, 0.46211716, -0.53704957]])\n",
"test_tanh_grads_wrt_inputs = np.array(\n",
" [[ 4.95033145, 9.61042983, -9.15136962],\n",
" [-4.27819393, 0., 7.11577763]])\n",
"tanh_layer = TanhLayer()\n",
"tanh_outputs = tanh_layer.fprop(test_inputs)\n",
"all_correct = True\n",
"if not tanh_outputs.shape == test_tanh_outputs.shape:\n",
" print('TanhLayer.fprop returned array with wrong shape.')\n",
" all_correct = False\n",
"elif not np.allclose(test_tanh_outputs, tanh_outputs):\n",
" print('TanhLayer.fprop calculated incorrect outputs.')\n",
" all_correct = False\n",
"tanh_grads_wrt_inputs = tanh_layer.bprop(\n",
" test_inputs, tanh_outputs, test_grads_wrt_outputs)\n",
"if not tanh_grads_wrt_inputs.shape == test_tanh_grads_wrt_inputs.shape:\n",
" print('TanhLayer.bprop returned array with wrong shape.')\n",
" all_correct = False\n",
"elif not np.allclose(tanh_grads_wrt_inputs, test_tanh_grads_wrt_inputs):\n",
" print('TanhLayer.bprop calculated incorrect gradients with respect to inputs.')\n",
" all_correct = False\n",
"if all_correct:\n",
" print('Outputs and gradients calculated correctly for TanhLayer.')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Outputs and gradients calculated correctly for ReluLayer.\n"
]
}
],
"source": [
"test_inputs = np.array([[0.1, -0.2, 0.3], [-0.4, 0.5, -0.6]])\n",
"test_grads_wrt_outputs = np.array([[5., 10., -10.], [-5., 0., 10.]])\n",
"test_relu_outputs = np.array([[0.1, 0., 0.3], [0., 0.5, 0.]])\n",
"test_relu_grads_wrt_inputs = np.array([[5., 0., -10.], [-0., 0., 0.]])\n",
"relu_layer = ReluLayer()\n",
"relu_outputs = relu_layer.fprop(test_inputs)\n",
"all_correct = True\n",
"if not relu_outputs.shape == test_relu_outputs.shape:\n",
" print('ReluLayer.fprop returned array with wrong shape.')\n",
" all_correct = False\n",
"elif not np.allclose(test_relu_outputs, relu_outputs):\n",
" print('ReluLayer.fprop calculated incorrect outputs.')\n",
" all_correct = False\n",
"relu_grads_wrt_inputs = relu_layer.bprop(\n",
" test_inputs, relu_outputs, test_grads_wrt_outputs)\n",
"if not relu_grads_wrt_inputs.shape == test_relu_grads_wrt_inputs.shape:\n",
" print('ReluLayer.bprop returned array with wrong shape.')\n",
" all_correct = False\n",
"elif not np.allclose(relu_grads_wrt_inputs, test_relu_grads_wrt_inputs):\n",
" print('ReluLayer.bprop calculated incorrect gradients with respect to inputs.')\n",
" all_correct = False\n",
"if all_correct:\n",
" print('Outputs and gradients calculated correctly for ReluLayer.')"
]
},
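 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
 "The tests above compare against hard-coded expected values. An independent check is to compare `bprop` against a central finite-difference estimate of the gradient of $\\sum_{b,k} g^{(b)}_k\\, y^{(b)}_k$ with respect to the inputs, where $g^{(b)}_k$ are fixed gradients with respect to the outputs. The sketch below is self-contained, so it re-declares the tanh forward and backward passes as plain functions; `numerical_grads_wrt_inputs` is a hypothetical helper written for illustration, not part of the `mlp` framework:\n",
 "\n",
 "```python\n",
 "import numpy as np\n",
 "\n",
 "def fprop(inputs):\n",
 "    # Element-wise tanh forward pass (same as TanhLayer.fprop).\n",
 "    return np.tanh(inputs)\n",
 "\n",
 "def bprop(inputs, outputs, grads_wrt_outputs):\n",
 "    # Analytic backward pass (same as TanhLayer.bprop).\n",
 "    return (1. - outputs**2) * grads_wrt_outputs\n",
 "\n",
 "def numerical_grads_wrt_inputs(fprop, inputs, grads_wrt_outputs, eps=1e-6):\n",
 "    # Central-difference estimate of d/dx of sum(grads_wrt_outputs * fprop(x)).\n",
 "    grads = np.zeros_like(inputs)\n",
 "    for idx in np.ndindex(*inputs.shape):\n",
 "        shift = np.zeros_like(inputs)\n",
 "        shift[idx] = eps\n",
 "        obj_plus = np.sum(grads_wrt_outputs * fprop(inputs + shift))\n",
 "        obj_minus = np.sum(grads_wrt_outputs * fprop(inputs - shift))\n",
 "        grads[idx] = (obj_plus - obj_minus) / (2. * eps)\n",
 "    return grads\n",
 "\n",
 "# Random inputs and output gradients with a fixed seed for reproducibility.\n",
 "rng = np.random.RandomState(0)\n",
 "inputs = rng.normal(size=(2, 3))\n",
 "grads_wrt_outputs = rng.normal(size=(2, 3))\n",
 "\n",
 "analytic = bprop(inputs, fprop(inputs), grads_wrt_outputs)\n",
 "numeric = numerical_grads_wrt_inputs(fprop, inputs, grads_wrt_outputs)\n",
 "print('max abs difference:', np.max(np.abs(analytic - numeric)))\n",
 "```\n",
 "\n",
 "In the notebook itself you could instead pass `tanh_layer.fprop` and `tanh_layer.bprop` (or the `ReluLayer` equivalents). For `ReluLayer`, keep the test inputs away from zero. If an input lies within `eps` of zero, the finite-difference estimate straddles the kink and will not match the analytic gradient."
 ]
 },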
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}