mlpractical/notebooks/03_Multiple_layer_models.ipynb

2024-10-03 15:53:33 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$\n",
"\\newcommand{\\vct}[1]{\\boldsymbol{#1}}\n",
"\\newcommand{\\mtx}[1]{\\mathbf{#1}}\n",
"\\newcommand{\\tr}{^\\mathrm{T}}\n",
"\\newcommand{\\reals}{\\mathbb{R}}\n",
"\\newcommand{\\lpa}{\\left(}\n",
"\\newcommand{\\rpa}{\\right)}\n",
"\\newcommand{\\lsb}{\\left[}\n",
"\\newcommand{\\rsb}{\\right]}\n",
"\\newcommand{\\lbr}{\\left\\lbrace}\n",
"\\newcommand{\\rbr}{\\right\\rbrace}\n",
"\\newcommand{\\fset}[1]{\\lbr #1 \\rbr}\n",
"\\newcommand{\\pd}[2]{\\frac{\\partial #1}{\\partial #2}}\n",
"$\n",
"\n",
"\n",
"# Multiple layer models and Activation Functions\n",
"\n",
"In this notebook we will explore network models with multiple layers of transformations. This will build upon the single-layer affine model we looked at in the previous notebook and use material covered in the second and third lectures.\n",
"\n",
"You will need to use these models for the experiments you will be running in the first coursework so part of the aim of this lab will be to get you familiar with how to construct multiple layer models in our framework and how to train them.\n",
"\n",
"## What is a layer?\n",
"\n",
"Often when discussing (neural) network models, a network layer is taken to mean an input to output transformation of the form\n",
"\n",
"\\begin{equation}\n",
" \\boldsymbol{y} = \\boldsymbol{f}(\\mathbf{W} \\boldsymbol{x} + \\boldsymbol{b})\n",
" \\qquad\n",
" \\Leftrightarrow\n",
" \\qquad\n",
" y_k = f\\left(\\sum_{d=1}^D \\left( W_{kd} x_d \\right) + b_k \\right)\n",
"\\end{equation}\n",
"\n",
"where $\\mathbf{W}$ and $\\boldsymbol{b}$ parameterise an affine transformation as discussed in the previous notebook, and $f$ is a function applied elementwise to the result of the affine transformation (sometimes called the activation function). For example a common choice for $f$ is the logistic sigmoid function \n",
"\\begin{equation}\n",
" f(u) = \\frac{1}{1 + \\exp(-u)}.\n",
"\\end{equation}\n",
"\n",
"In the second lecture slides you were shown how to train a model consisting of an affine transformation followed by the elementwise logistic sigmoid using gradient descent. This was referred to as a 'sigmoid single-layer network'.\n",
"\n",
"In the previous notebook we also referred to single-layer models, where in that case the layer was an affine transformation, with you implementing the various necessary methods for the `AffineLayer` class before using an instance of that class within a `SingleLayerModel` on a regression problem. We could in that case consider the activation function $f$ to be the identity function $f(u) = u$. In the code for the labs we will however use a slightly different convention. Here we will consider the affine transformation and the subsequent activation function $f$ to be two separate transformation layers. \n",
"\n",
"This allows us to combine our already implemented `AffineLayer` class with any non-linear activation function applied to the outputs by simply implementing a layer object for the relevant non-linearity and then stacking the two layers together. An alternative would be to have our new layer objects inherit from `AffineLayer` and then call the relevant parent class methods in the child class; however, this would mean we need to duplicate a lot of the same boilerplate code in every new class.\n",
"\n",
"To give a concrete example, in the `mlp.layers` module there is a definition for a `SigmoidLayer` equivalent to the following (documentation strings have been removed here for brevity)\n",
"\n",
"```python\n",
"class SigmoidLayer(Layer):\n",
"\n",
" def fprop(self, inputs):\n",
" return 1. / (1. + np.exp(-inputs))\n",
"\n",
" def bprop(self, inputs, outputs, grads_wrt_outputs):\n",
" return grads_wrt_outputs * outputs * (1. - outputs)\n",
"```\n",
"\n",
"As you can see this `SigmoidLayer` class has a very lightweight definition, defining just two key methods:\n",
"\n",
" * `fprop` which takes a batch of values at the input to the layer and forward propagates them to produce activations at the outputs (directly equivalently to the `fprop` method you implemented for then `AffineLayer` in the previous notebook),\n",
" * `brop` which takes a batch of gradients with respect to the outputs of the layer and backward propagates them to calculate gradients with respect to the inputs of the layer (explained in more detail below).\n",
" \n",
"This `SigmoidLayer` class only implements the logistic sigmoid non-linearity transformation and so does not have any parameters. Therefore unlike `AffineLayer` it is derived directly from the base `Layer` class rather than `LayerWithParameters` and does not need to implement `grads_wrt_params` or `params` methods. \n",
"\n",
"To create a model consisting of an affine transformation followed by applying an elementwise logistic sigmoid transformation we first create a list of the two layer objects (in the order they are applied from inputs to outputs) and then use this to instantiate a new `MultipleLayerModel` object:\n",
"\n",
"```python\n",
"from mlp.layers import AffineLayer, SigmoidLayer\n",
"from mlp.models import MultipleLayerModel\n",
"\n",
"layers = [AffineLayer(input_dim, output_dim), SigmoidLayer()]\n",
"model = MultipleLayerModel(layers)\n",
"```\n",
"\n",
"Because of the modular way in which the layers are defined we can also stack an arbitrarily long sequence of layers together to produce deeper models. For instance the following would define a model consisting of three pairs of affine and logistic sigmoid transformations.\n",
"\n",
"```python\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, hidden_dim), SigmoidLayer(),\n",
" AffineLayer(hidden_dim, hidden_dim), SigmoidLayer(),\n",
" AffineLayer(hidden_dim, output_dim), SigmoidLayer(),\n",
"])\n",
"```\n",
"\n",
"## Back-propagation of gradients\n",
" \n",
"To allow training models consisting of a stack of multiple layers, all layers need to implement a `bprop` method in addition to the `fprop` we encountered in the previous week. \n",
"\n",
"The `bprop` method takes gradients of an error function with respect to the *outputs* of a layer and uses these gradients to calculate gradients of the error function with respect to the *inputs* of a layer. As the inputs to a hidden layer in a multiple-layer model consist of the outputs of the previous layer, this means we can calculate the gradients of the error function with respect to the outputs of every layer in the model by iteratively propagating the gradients backwards through the layers of the model (i.e. from the last to first layer), hence the term 'back-propagation' or 'bprop' for short. A block diagram illustrating this is shown for a three layer model below.\n",
"\n",
"<img src='res/fprop-bprop-block-diagram.png' />\n",
"\n",
"For a layer with parameters, the gradients with respect to the layer outputs are required to calculate gradients with respect to the layer parameters. Therefore by combining backward propagation of gradients through the model with computing the gradients with respect to parameters in the relevant layers, we can calculate gradients of the error function with respect to all of the parameters of a multiple-layer model in a very efficient manner. In fact the computational cost of computing gradients with respect to all of the parameters of the model using this method will only be a constant factor times the cost of calculating the model outputs in the forwards pass.\n",
"\n",
"So far, we have abstractly talked about calculating gradients with respect to the inputs of a layer using gradients with respect to the layer outputs. More concretely we will be using the chain rule for derivatives to do this, similarly to how we used the chain rule in exercise 4 of the previous notebook to calculate gradients with respect to the parameters of an affine layer given gradients with respect to the outputs of the layer.\n",
"\n",
"In particular if our layer has a batch of $B$ vector inputs each of dimension $D$, $\\left\\lbrace \\boldsymbol{x}^{(b)} \\right\\rbrace_{b=1}^B$, and produces a batch of $B$ vector outputs each of dimension $K$, $\\left\\lbrace \\boldsymbol{y}^{(b)}\\right\\rbrace_{b=1}^B$, then we can calculate the gradient with respect to the $d^\\textrm{th}$ dimension of the $b^{\\textrm{th}}$ input using the gradients with respect to the $b^{\\textrm{th}}$ output\n",
"\n",
"\\begin{equation}\n",
" \\frac{\\partial \\bar{E}}{\\partial x^{(b)}_d} = \n",
" \\sum_{k=1}^K \\left( \n",
" \\frac{\\partial \\bar{E}}{\\partial y^{(b)}_k} \\frac{\\partial y^{(b)}_k}{\\partial x^{(b)}_d} \n",
" \\right).\n",
"\\end{equation}\n",
"\n",
"The `bprop` method takes an array of gradients with respect to the outputs $\\frac{\\partial \\bar{E}}{\\partial y^{(b)}_k}$ and applies a sum-product operation with the partial derivatives of each output with respect to each input $\\frac{\\partial y^{(b)}_k}{\\partial x^{(b)}_d}$, producing gradients with respect to the inputs of the layer $\\frac{\\partial \\bar{E}}{\\partial x^{(b)}_d}$.\n",
"\n",
"For the affine transformation used in the `AffineLayer` implemented in lab 2, i.e. a forward propagation corresponding to \n",
"\n",
"\\begin{equation}\n",
" y^{(b)}_k = \\sum_{d=1}^D \\left( W_{kd} x^{(b)}_d \\right) + b_k\n",
"\\end{equation}\n",
"\n",
"then the corresponding partial derivatives of layer outputs with respect to inputs are\n",
"\n",
"\\begin{equation}\n",
" \\frac{\\partial y^{(b)}_k}{\\partial x^{(b)}_d} = W_{kd}\n",
"\\end{equation}\n",
"\n",
"and so the backwards-propagation method for the `AffineLayer` takes the following form\n",
"\n",
"\\begin{equation}\n",
" \\frac{\\partial \\bar{E}}{\\partial x^{(b)}_d} = \n",
" \\sum_{k=1}^K \\left( \\frac{\\partial \\bar{E}}{\\partial y^{(b)}_k} W_{kd} \\right).\n",
"\\end{equation}\n",
"\n",
"This can be efficiently implemented in NumPy using the `dot` function\n",
"\n",
"```python\n",
"class AffineLayer(LayerWithParameters):\n",
"\n",
" # ... [implementation of remaining methods from previous week] ...\n",
" \n",
" def bprop(self, inputs, outputs, grads_wrt_outputs):\n",
" return grads_wrt_outputs.dot(self.weights)\n",
"```\n",
"\n",
"An important special case applies when the outputs of a layer are an elementwise function of the inputs such that $y^{(b)}_k$ only depends on $x^{(b)}_d$ when $d = k$. In this case the partial derivatives $\\frac{\\partial y^{(b)}_k}{\\partial x^{(b)}_d}$ will be zero when $k \\neq d$ and the above summation reduces to a single term,\n",
"\n",
"\\begin{equation}\n",
" \\frac{\\partial \\bar{E}}{\\partial x^{(b)}_d} = \n",
" \\frac{\\partial \\bar{E}}{\\partial y^{(b)}_d} \\frac{\\partial y^{(b)}_d}{\\partial x^{(b)}_d}\n",
"\\end{equation}\n",
"\n",
"i.e. to calculate the gradient with respect to the $b^{\\textrm{th}}$ input vector we just perform an elementwise multiplication of the gradient with respect to the $b^{\\textrm{th}}$ output vector with the vector of derivatives of the outputs with respect to the inputs. This case applies to the `SigmoidLayer` and to all other layers applying an elementwise function to their inputs.\n",
"\n",
"For the logistic sigmoid layer we have that\n",
"\n",
"\\begin{equation}\n",
" y^{(b)}_d = \\frac{1}{1 + \\exp(-x^{(b)}_d)}\n",
" \\qquad\n",
" \\Rightarrow\n",
" \\qquad\n",
" \\frac{\\partial y^{(b)}_d}{\\partial x^{(b)}_d} = \n",
" \\frac{\\exp(-x^{(b)}_d)}{\\left[ 1 + \\exp(-x^{(b)}_d) \\right]^2} =\n",
" y^{(b)}_d \\left[ 1 - y^{(b)}_d \\right]\n",
"\\end{equation}\n",
"\n",
"which you should now be able relate to the implementation of `SigmoidLayer.bprop` given earlier:\n",
"\n",
"```python\n",
"class SigmoidLayer(Layer):\n",
"\n",
" def fprop(self, inputs):\n",
" return 1. / (1. + np.exp(-inputs))\n",
"\n",
" def bprop(self, inputs, outputs, grads_wrt_outputs):\n",
" return grads_wrt_outputs * outputs * (1. - outputs)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 1: training a softmax model on MNIST\n",
"\n",
"For this first exercise we will train a model consisting of an affine transformation plus softmax on a multiclass classification task: classifying the digit labels for handwritten digit images from the MNIST data set introduced in the first notebook.\n",
"\n",
"First run the cell below to import the necessary modules and classes and to load the MNIST data provider objects. As it takes a little while to load the MNIST data from disk into memory it is worth loading the data providers just once in a separate cell like this rather than recreating the objects for every training run.\n",
"\n",
"We are loading two data provider objects here - one corresponding to the training data set and a second to use as a *validation* data set. This is data we do not train the model on but measure the performance of the trained model on to assess its ability to *generalise* to unseen data. \n",
"\n",
"The concept of training, validation, and test data sets was introduced in lecture one, and the concept of generalisation is discussed in more detail in lecture five. As you will need to report both training and validation set performances in your experiments for the first coursework assignment we are providing code here to give an example of how to do this."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import logging\n",
"from mlp.layers import AffineLayer, SoftmaxLayer, SigmoidLayer\n",
"from mlp.errors import CrossEntropyError, CrossEntropySoftmaxError\n",
"from mlp.models import SingleLayerModel, MultipleLayerModel\n",
"from mlp.initialisers import UniformInit\n",
"from mlp.learning_rules import GradientDescentLearningRule\n",
"from mlp.data_providers import MNISTDataProvider\n",
"from mlp.optimisers import Optimiser\n",
"\n",
"plt.style.use('ggplot')\n",
"\n",
"# Seed a random number generator\n",
"seed = 6102016 \n",
"rng = np.random.RandomState(seed)\n",
"\n",
"# Set up a logger object to print info about the training run to stdout\n",
"logger = logging.getLogger()\n",
"logger.setLevel(logging.INFO)\n",
"logger.handlers = [logging.StreamHandler()]\n",
"\n",
"# Create data provider objects for the MNIST data set\n",
"train_data = MNISTDataProvider('train', rng=rng)\n",
"valid_data = MNISTDataProvider('valid', rng=rng)\n",
"input_dim, output_dim = 784, 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To minimise replication of code and allow you to run experiments more quickly, a helper function is provided below which trains a model and plots the evolution of the error and classification accuracy of the model (on both training and validation sets) over training."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval):\n",
"\n",
" # As well as monitoring the error over training also monitor classification\n",
" # accuracy i.e. proportion of most-probable predicted classes being equal to targets\n",
" data_monitors={'acc': lambda y, t: (y.argmax(-1) == t.argmax(-1)).mean()}\n",
"\n",
" # Use the created objects to initialise a new Optimiser instance.\n",
" optimiser = Optimiser(\n",
" model, error, learning_rule, train_data, valid_data, data_monitors)\n",
"\n",
" # Run the optimiser for 5 epochs (full passes through the training set)\n",
" # printing statistics every epoch.\n",
" stats, keys, run_time = optimiser.train(num_epochs=num_epochs, stats_interval=stats_interval)\n",
"\n",
" # Plot the change in the validation and training set error over training.\n",
" fig_1 = plt.figure(figsize=(8, 4))\n",
" ax_1 = fig_1.add_subplot(111)\n",
" for k in ['error(train)', 'error(valid)']:\n",
" ax_1.plot(np.arange(1, stats.shape[0]) * stats_interval, \n",
" stats[1:, keys[k]], label=k)\n",
" ax_1.legend(loc=0)\n",
" ax_1.set_xlabel('Epoch number')\n",
"\n",
" # Plot the change in the validation and training set accuracy over training.\n",
" fig_2 = plt.figure(figsize=(8, 4))\n",
" ax_2 = fig_2.add_subplot(111)\n",
" for k in ['acc(train)', 'acc(valid)']:\n",
" ax_2.plot(np.arange(1, stats.shape[0]) * stats_interval, \n",
" stats[1:, keys[k]], label=k)\n",
" ax_2.legend(loc=0)\n",
" ax_2.set_xlabel('Epoch number')\n",
" \n",
" return stats, keys, run_time, fig_1, ax_1, fig_2, ax_2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Running the cell below will create a model consisting of an affine layer followed by a softmax transformation and train it on the MNIST data set by minimising the multi-class cross entropy error function using a basic gradient descent learning rule. By using the helper function defined above, at the end of training, the evolution of the error function and also classification accuracy of the model over the training epochs will be plotted.\n",
"\n",
"**Your Tasks:**\n",
"- Try running the code for various settings of the training hyperparameters defined at the beginning of the cell to get a feel for how these affect how training proceeds. You may wish to create multiple copies of the cell below to allow you to keep track of and compare the results across different hyperparameter settings."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Varying initialisation scale\n",
"\n",
"First try a few different parameter initialisation scales:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `init_scale = 0.01`"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 0.5s to complete\n",
" error(train)=3.10e-01, acc(train)=9.14e-01, error(valid)=2.91e-01, acc(valid)=9.18e-01\n",
"Epoch 10: 0.6s to complete\n",
" error(train)=2.88e-01, acc(train)=9.20e-01, error(valid)=2.76e-01, acc(valid)=9.23e-01\n",
"Epoch 15: 0.6s to complete\n",
" error(train)=2.78e-01, acc(train)=9.23e-01, error(valid)=2.69e-01, acc(valid)=9.24e-01\n",
"Epoch 20: 0.6s to complete\n",
" error(train)=2.71e-01, acc(train)=9.25e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 25: 0.5s to complete\n",
" error(train)=2.68e-01, acc(train)=9.25e-01, error(valid)=2.65e-01, acc(valid)=9.26e-01\n",
"Epoch 30: 0.5s to complete\n",
" error(train)=2.63e-01, acc(train)=9.27e-01, error(valid)=2.62e-01, acc(valid)=9.26e-01\n",
"Epoch 35: 0.5s to complete\n",
" error(train)=2.60e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 40: 0.5s to complete\n",
" error(train)=2.59e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.29e-01\n",
"Epoch 45: 0.7s to complete\n",
" error(train)=2.55e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n",
"Epoch 50: 0.6s to complete\n",
" error(train)=2.54e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n",
"Epoch 55: 0.6s to complete\n",
" error(train)=2.52e-01, acc(train)=9.30e-01, error(valid)=2.58e-01, acc(valid)=9.29e-01\n",
"Epoch 60: 0.6s to complete\n",
" error(train)=2.52e-01, acc(train)=9.30e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 65: 0.6s to complete\n",
" error(train)=2.50e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 70: 0.6s to complete\n",
" error(train)=2.49e-01, acc(train)=9.31e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 75: 0.6s to complete\n",
" error(train)=2.47e-01, acc(train)=9.31e-01, error(valid)=2.57e-01, acc(valid)=9.30e-01\n",
"Epoch 80: 0.5s to complete\n",
" error(train)=2.46e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.31e-01\n",
"Epoch 85: 0.5s to complete\n",
" error(train)=2.45e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 90: 0.6s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.57e-01, acc(valid)=9.29e-01\n",
"Epoch 95: 0.5s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.29e-01\n",
"Epoch 100: 0.6s to complete\n",
" error(train)=2.43e-01, acc(train)=9.33e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABy3UlEQVR4nO3deXxU1f3/8deZzGQnC1lIIAkkhLAaBMQFrSCLVqECVrTu1Fq1ahFtsWLVql9Fsa3Wn1ptLS2tFkVQFHABcakiCoJKRDbZBRJIIPs+mfv7Y5KBIUESSDKT4f18PPLIzJ27nJsPk7w5c+65xrIsCxERERERP2HzdQNERERERA6ngCoiIiIifkUBVURERET8igKqiIiIiPgVBVQRERER8SsKqCIiIiLiVxRQRURERMSvKKCKiIiIiF9RQBURERERv6KAKiIiIiJ+xe7rBrS2wsJCnE6nr5shrSQhIYH8/HxfN0PagGobmFTXwKXaBq72rK3dbic2NvbY67VDW9qV0+mktrbW182QVmCMAdw1tSzLx62R1qTaBibVNXCptoHLX2urj/hFRERExK8ooIqIiIiIX1FAFRERERG/ooAqIiIiIn4l4C6SEhEREd+zLIuysjK/uvBGmlZZWUlNTU2r7S8kJISQkJAT2ocCqoiIiLS6srIyQkJCCA4O9nVT5BgcDkerzYBkWRaVlZWUl5cTERFx3PvRR/wiIiLS6izLUjg9CRljCA8PP+E56RVQRURERKRVNcyveryO6yP+JUuWsHDhQoqKikhJSWHy5Mn07du3yXU3btzIf//7X/bs2UN1dTUJCQmMHj2acePGedb5/vvvmTt3Ltu3byc/P5/rrruOsWPHHt8ZiYiIiEiH1uKAumLFCmbPns0NN9xA7969WbZsGTNmzODJJ58kPj6+0fohISFccMEFdO/enZCQEDZu3MgLL7xAaGgoo0ePBqC6upouXbpw1lln8e9///vEz6qNWa46+G4DVvFBbKef6+vmiIiIiASUFn/Ev3jxYkaOHMmoUaM8vafx8fEsXbq0yfXT09M555xzSE1NJTExkXPPPZeBAweyYcMGzzqZmZlcc801nH322TgcjuM/m/ayaR2uP92D9coLWHV1vm6NiIiIBJjly5dz7rnn4nK52mT/U6dO5frrr2/2+tXV1QwdOpScnJw2ac+RWtSD6nQ62bZtGxMmTPBanp2dzaZNm5q1j+3bt7Np0yZ+9rOfteTQjdTW1npdcWaMISwszPO4TWUNgPBIKC2GLesxfbLb9ngnqYY6tnk9pd2ptoFJdQ1cqm37e+SRR5gyZQo2m7sv8c9//jPvvvsu7733Xqvs/6GHHmrRFGAhISHcfPPNPPLII8ydO7dZ25zIv5cWBdSSkhJcLhfR0dFey6OjoykqKvrBbW+++WZKSkqoq6tj0qRJjBo1qsWNPdyCBQuYP3++53l6ejozZ84kISHhhPbbXAeGnUfFskWEb/ya2PMuaJdjnqySkpJ83QRpI6ptYFJdA1dLaltZWdkxPhU9AZZlUVdXh93uHadqamqOawaDhu1WrVrF9u3bmThxoudnaLPZMMYc82daW1vbrJ97XFyc1/PmbHPZZZfx8MMPs337drKysn5w3eDgYJKTk4+5z6M5roukmkrEx0rJDz30EFVVVWzevJk5c+aQlJTEOeecczyHB2DixIleF1o1HD8/P/+EpzZoDlffU2HZIso+eZ/Ki6/G2DQhQmszxpCUlEReXp4meg4wqm1gUl0D1/HUtqamxvNJp2VZUFPdlk08uuCQFvXkWZbFc889x4svvsj+/ftJT09n6tSpjBs3jhUrVjBp0iT++9//MnPmTDZs2MB///tfnnzySXr37o3D4WD+/Pn07t2b1157jc8++4yHH36Y9evXExMTw6RJk7jrrrs8gfbSSy9tcrvXX3+dc889l6CgIGpra5k7dy5/+tOfAEhMTATgiSee4PLLL6dbt248+u
ijfPjhh3zyySfcfPPN3HHHHdx11118+umn5Ofn07VrV6677jpuuOEGz3lOnTqVkpIS/vnPf+JwOBg/fjx9+/YlJCSEl19+GYfDwTXXXMNvfvMbzzadOnViyJAhzJ8/n2nTpv3gz7Gmpobc3NxGy+12e7M6E1sUUKOiorDZbI16S4uLixv1qh6p4QealpZGcXEx8+bNO6GA6nA4jpr22+UXY99TITQMig9ibd0ImU3PYiAnzrIs/bELUKptYFJdA9dx17amGtdtl7V+g5rB9syrEBLa7PVnzpzJO++8w6OPPkp6ejqff/45U6ZM8epxfPjhh7n//vtJS0sjKioKgHnz5nHttdfyxhtvAJCbm8s111zDZZddxlNPPcWWLVuYNm0aISEhXqHvyO0APv/8c6/hlBdffDGbNm3io48+4pVXXgHcYbHBn//8Z6ZPn84DDzxAUFAQLpeL5ORknn/+eTp37szq1au56667SExM5OKLLz7quc+bN48bb7yRRYsWsWbNGu644w6GDh3KueceuiB80KBBrFy5slk/yxP5PdCigGq328nIyCAnJ4fTTz/dszwnJ4ehQ4c2ez+WZbVLL2dbMg4HJnso1qqPsb5cgVFAFRER6dAqKip44YUXmDt3LqeddhoA3bt354svvuCll17iqquuAmDatGleoQ2gR48e3HvvvZ7njz32GF27duWRRx7BGENmZiZ5eXnMmDGDO+64wzO29MjtAHbv3k2XLl08z8PCwoiIiCAoKMjT4Xe4CRMmNLq257e//a3ncVpaGqtXr2bRokU/GFD79u3LnXfeCUBGRgazZ8/2XKzVICkpid27dx91H62lxR/xjxs3jqeffpqMjAyysrJYtmwZBQUFjBkzBoA5c+Zw8OBBbrvtNgDeffdd4uPj6datG+CeF3XRokVceOGFnn06nU7PyTqdTg4ePMiOHTsIDQ3167FMZsiw+oD6Gdak6zV4XEREpCnBIe6eTB8du7k2b95MVVUVV1xxhdfy2tpaBgwY4Hmend344uiBAwd6Pd+yZQtDhgzxygZDhw6lvLyc3NxcTy46cjuAqqqqFt3Lvql9/Oc//+Hll19m9+7dVFVVUVtbS//+/X9wP0fOaZ+YmEhBQYHXstDQUCorK5vdtuPV4oA6bNgwSktLee211ygsLCQ1NZXp06d7xhMUFhZ6nYxlWbz88svs378fm81GUlISV111lWcOVICDBw9y1113eZ4vWrSIRYsW0a9fPx544IETOL021n+I+x/+gf2wayt0z/R1i0RERPyOMaZFH7P7SsOUTv/5z38adZAFBwezc+dOAMLDwxtt2zCTUAPLshp1XDX1kfeR2wF07tyZ4uLiZrf7yPYsXLiQBx98kPvuu4/TTjuNiIgInnvuOb766qsf3M+RF3sZYxpNc1VUVNToAqu2cFwXSV1wwQVccEHTV67feuutXs8vvPBCr97SpiQmJvLqqz76n9UJMCEhMGAIfLkCa80KjAKqiIhIh5WVlUVISAh79uzhrLPOavR6Q0Btjl69evH22297BdXVq1cTGRl5zKvb+/fvz+bNm72WORyOZs+JumrVKoYMGcLkyZOPq+0/ZOPGjcfsiW0NuvT8BJnB7n/A1pef6aIAERGRDiwyMpKbbrqJBx54gFdffZUdO3awbt06Zs+e3eKOtOuuu469e/dy7733smXLFpYsWcKf//xnbrzxRs/406MZMWIEX3zxhdey1NRUdu3axbp16zh48CDV1UefFaFHjx7k5OTw0UcfsXXrVh5//HHWrl3bovYfzapVqxg+fHir7OuHKKCeIJM9FOx22LcH9u7ydXNERETkBNx1113ccccdPPPMM4wYMYIrr7yS9957j7S0tBbtJzk5mRdffJGvv/6aMWPGcPfdd3PFFVdw++23H3PbSy65hM2bN7NlyxbPsosuuogRI0Zw2WWXccopp3hd9X+ka665hgsvvJBf/epX/OQnP6GwsJDrrruuRe1vyurVqy
ktLWXs2LEnvK9jMVaAdfvl5+d73WGqPdQ9/X+Q8wXmJ1dgu/iKY28gzWKMITk5mdzcXPVOBxjVNjCproHreGpbUlL
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAroAAAF0CAYAAADM95pAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB6dElEQVR4nO3deXhU1eH/8fedTPZ9T1iCSQir7MhqDQLWVmgRS221VqlSsUEp3fzVUq0LaumCWr+2tXXrBqJimhgVYgRFwIKACghGYkBiEkxC9j0zc39/DAwMCUIg6+Tzep48ydx77r3nzknIh5NzzzFM0zQREREREfEwlu6ugIiIiIhIZ1DQFRERERGPpKArIiIiIh5JQVdEREREPJKCroiIiIh4JAVdEREREfFICroiIiIi4pEUdEVERETEIynoioiIiIhHUtAVEREREY9k7e4K9EQVFRXYbLburoZ0oOjoaEpLS7u7GtLB1K6eS23rudS2nqkr29VqtRIeHn5uZTu5Lr2SzWajpaWlu6shHcQwDMDZrqZpdnNtpKOoXT2X2tZzqW09U09uVw1dEBERERGPpKArIiIiIh5JQVdEREREPJKCroiIiIh4JD2M1k5NTU00NTV1dzWknRoaGmhubm5zn6+vL76+vl1cIxEREelsCrrtUFdXh2EYBAcHu54wlN7B29u7zZk0TNOkoaGBuro6AgMDu6FmIiIi0lk0dKEdbDYbAQEBCrkexDAMAgICNG+yiIiIB1LQbQcFXM+lthUREfE8CroiIiIi4pEUdEVERETkvJnlZVSv+ydmD1xVVkFXOkReXh5jx46ltra2U87/xz/+kSuuuKJdx1x11VW89tprnVIfERGRvsxsasTxv03YV92N/f/dTNUzf8Lc8153V6sVBV3pECtXruSmm24iKCgIgLVr1zJ8+PAOO/9tt93G2rVr23XMsmXLeOihh3A4HB1WDxERkb7KdDgwc/fiePYxHD+7CfPpR+DAh2Ca+F48HvwDuruKrWh6MblgRUVFvPHGG9x3333tPra5uRkfH5+zlgsMDGz39F+zZs3iF7/4BW+99RZXXnllu+smIiIiYJYUYW7biPm/t+BYyckd0XEYU2dimXo5MaPHUVxcjGma3VbPtijoXgDTNKG5mxaP8PFt10wBmzZt4rHHHiM3NxeLxcKECRO4//77ueiiiwBnWH3ggQfYvHkzTU1NpKSk8OCDDzJ+/HgAsrOzeeSRR8jNzSUgIIApU6bw1FNPAfDKK68wYsQI+vXrB8C2bdv46U9/CkD//v0B+OlPf8rPfvYzJk+ezHXXXcfhw4dZv349V155JY899hgPPvggr7/+OsXFxcTExDB//nx+8pOf4O3tDTiHLqxfv5433ngDcPbWVldXM2nSJJ588kmam5uZN28e9913n+sYLy8vZs6cyX//+18FXRERkXYw62sxd27B3LYRPv345A7/AIyJl2JMnQmDh2MYRo+euUhB90I0N+G4/dpuubTl/14AX79zLl9fX8+tt97KsGHDqK+v5w9/+AOLFi0iOzubhoYGFixYQFxcHM8++yzR0dHs3bvX9Sf/nJwcFi1axNKlS/nTn/5Ec3Mzb775puvc27dvZ/To0a7XEydO5L777uMPf/gDmzdvBnDrjf3rX//KsmXL+PGPf+zaFhgYyCOPPEJcXBwHDhzgzjvvJCgoiLS0tDPe07Zt24iJieHFF1/k0KFD/OhHP2LkyJF873vfc5UZO3Ysf/nLX875fRIREemrTLsd9r/v7L39YDvYjj9cZlhg5FiMqTMxxk7G8Ok9q4meV9DdsGEDmZmZVFZWMmDAABYuXPil4zHXr1/Phg0bKCkpISoqimuuuYbU1FTX/u3bt5Oens7Ro0ex2+3ExcXxjW98g8suu8xVJj09nR07dlBYWIiPjw9DhgzhhhtucPUiAjzxxBO8/fbbbtc+0TPZ182ZM8ft9R//+EdGjx7NJ598ws6dOz
l27Bivvvoq4eHhACQmJrrK/ulPf2LevHn8/Oc/d20bOXKk6+uCggJGjRrleu3j4+NaPS4mJqZVXaZPn85tt93mtm3ZsmWurwcOHMinn35KZmbmlwbd0NBQHnzwQby8vBg8eDCzZs1iy5YtbkE3Pj6ewsJCjdMVERE5A/PzQ85wu/1tqK48uaP/IGe4nXwZRlhkt9XvQrQ76G7bto3nnnuORYsWMXToUHJycnjooYd45JFHiIqKalU+OzubNWvWsHjxYpKTk8nLy+PJJ58kMDCQiRMnAhAUFMQ111xDv379sFqt7N69mz//+c+EhIQwduxYAPbv38+VV15JcnIydrud559/nhUrVrBq1Sr8/E72bI4dO9YtHFmtndhp7ePr7FntDu3839Thw4f5/e9/z+7duykvL3cFv8LCQj766CMuvvhiV8g93UcffeQWHk/X2Njo1gZnc2rv7wlZWVk89dRTHD58mLq6Oux2u+vBtjMZMmQIXl5ertexsbEcOHDArYyfnx8Oh4OmpqbO/V4QERHpRczqSswdbzuHJhQcOrkjKARjcqpzaEJCUo8elnAu2v2bPysri5kzZzJr1iwAFi5cyIcffkh2djbXX399q/KbN29m9uzZTJs2DXCGkYMHD5KRkeEKuqf2DoJzWqi3336bjz/+2BV0ly9f7lYmLS2NRYsWkZ+fz4gRI07ekNVKWFhYe2/rvBiG0a7hA91p4cKF9OvXj9/97nfExcXhcDiYOXMmLS0tZw2pZ9sfERFBZWXlOdclIMD9qcxdu3aRlpbGz372M2bMmEFwcDAZGRn87W9/+9LznBiLe6rTB8FXVFTg7++Pv78/LT1wfj8REZGuYra0wJ4dOLZthH274MRfO72sMOYSLFNnwsXjMaytf7/2Vu0Kujabjfz8fK6++mq37aNHjyY3N7fNY1paWloFEh8fH/Ly8rDZbK162UzTZN++fRQVFX1pL2J9fT1Aq16//fv3s2jRIgIDAxk+fDjXXXcdoaGhZ6zbqeHHMAz8/f1dX3uK8vJyDh48yMqVK5k8eTIAO3bscO0fPnw4a9asoaKios1e3eHDh7Nlyxa+853vtHn+iy++mIMHD7pt8/HxwW63n1P93nvvPQYMGOA2ZrewsPCcjj2b3Nxct2EVX8aT2rwvONFeajfPo7b1XGrbrmeaJuTn4nh3I+aOd6D+lPnuE4dgmToTY9JXMIJCzvsaPbld2xV0q6urcTgcrYJjaGjoGXv0xowZw8aNG5k0aRKJiYnk5+ezadMm7HY7NTU1rmBVX1/P4sWLsdlsWCwWbrnlljb/xA3ORvvHP/7BsGHDSEhIcG0fN24cU6dOJSoqipKSEtauXcv999/Pb3/72zZ7/9LT03nppZdcrxMTE1m5ciXR0dFtXrehoaHN8/R00dHRREREsHr1avr160dhYSErVqwAnDMTfPvb3+b//u//WLRoEcuXLyc2Npa9e/cSFxfHJZdcwp133sm3vvUtEhMTmT9/PjabjTfffJM77rgDcE7j9ZOf/ASLxeIaSnDRRRdRV1fHu+++y8iRI/H39ycgIADDMPDy8nJ7HwcPHkxhYSFZWVmMHTuWnJwc1q9fD5zstbVYLBiGccbXJ+7l9G3vvfcel19+udu52uLj40N8fPwFv9fS9eLi4rq7CtJJ1LaeS23b+WylR6nf+Bp1G1/F9vlnru1ekTEEzLyKwJlz8E5I/JIztF9PbNfzGrTYVmI/U4pfsGABlZWVLF++HNM0CQ0NJTU1lczMTCyWk+tV+Pn58fvf/57Gxkb27t3LP//5T2JjY1sNawB4+umnOXLkCPfff7/b9hPDIwASEhJITk4mLS2N3bt3u3oyTzV//nzmzp3b6h5KS0ux2Wytyjc3N/faP38/8cQT3HPPPaSmppKUlMQDDzzAggULsNvtGIbB6tWrue+++7j++uux2WwMGTKEBx98kJaWFtcUXo8++iiPP/44QUFBTJkyxfVepKamYrVa2bhxIz
NmzACc/+n4/ve/zw9/+EMqKipc04uZpondbnd7H2fPns0Pf/hDfvnLX9Lc3MysWbP48Y9/zKpVq1zlHA4Hpmme8TW
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"init_scale = 0.01 # scale for random parameter initialisation\n",
"learning_rate = 0.1 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size \n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine + softmax model\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, output_dim, param_init, param_init),\n",
" SoftmaxLayer()\n",
"])\n",
"\n",
"# Initialise a cross entropy error object\n",
"error = CrossEntropyError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `init_scale = 0.1`"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 0.5s to complete\n",
" error(train)=3.22e-01, acc(train)=9.11e-01, error(valid)=3.00e-01, acc(valid)=9.16e-01\n",
"Epoch 10: 0.5s to complete\n",
" error(train)=2.95e-01, acc(train)=9.18e-01, error(valid)=2.81e-01, acc(valid)=9.21e-01\n",
"Epoch 15: 0.5s to complete\n",
" error(train)=2.85e-01, acc(train)=9.21e-01, error(valid)=2.74e-01, acc(valid)=9.23e-01\n",
"Epoch 20: 0.5s to complete\n",
" error(train)=2.78e-01, acc(train)=9.22e-01, error(valid)=2.70e-01, acc(valid)=9.25e-01\n",
"Epoch 25: 0.5s to complete\n",
" error(train)=2.72e-01, acc(train)=9.24e-01, error(valid)=2.68e-01, acc(valid)=9.25e-01\n",
"Epoch 30: 0.5s to complete\n",
" error(train)=2.68e-01, acc(train)=9.25e-01, error(valid)=2.64e-01, acc(valid)=9.27e-01\n",
"Epoch 35: 0.5s to complete\n",
" error(train)=2.64e-01, acc(train)=9.27e-01, error(valid)=2.62e-01, acc(valid)=9.28e-01\n",
"Epoch 40: 0.5s to complete\n",
" error(train)=2.63e-01, acc(train)=9.27e-01, error(valid)=2.63e-01, acc(valid)=9.27e-01\n",
"Epoch 45: 0.5s to complete\n",
" error(train)=2.60e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.29e-01\n",
"Epoch 50: 0.5s to complete\n",
" error(train)=2.58e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.29e-01\n",
"Epoch 55: 0.5s to complete\n",
" error(train)=2.56e-01, acc(train)=9.29e-01, error(valid)=2.60e-01, acc(valid)=9.29e-01\n",
"Epoch 60: 0.5s to complete\n",
" error(train)=2.55e-01, acc(train)=9.28e-01, error(valid)=2.60e-01, acc(valid)=9.29e-01\n",
"Epoch 65: 0.5s to complete\n",
" error(train)=2.54e-01, acc(train)=9.30e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 70: 0.5s to complete\n",
" error(train)=2.52e-01, acc(train)=9.30e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 75: 0.5s to complete\n",
" error(train)=2.51e-01, acc(train)=9.30e-01, error(valid)=2.58e-01, acc(valid)=9.29e-01\n",
"Epoch 80: 0.5s to complete\n",
" error(train)=2.50e-01, acc(train)=9.30e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 85: 0.5s to complete\n",
" error(train)=2.49e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 90: 0.6s to complete\n",
" error(train)=2.48e-01, acc(train)=9.31e-01, error(valid)=2.57e-01, acc(valid)=9.29e-01\n",
"Epoch 95: 0.5s to complete\n",
" error(train)=2.47e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.29e-01\n",
"Epoch 100: 0.5s to complete\n",
" error(train)=2.47e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.29e-01\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAByX0lEQVR4nO3deXxU1cH/8c+dJftG9pAQCISwGnYVN6iAKxWwaCtWRWvVqg9u1dYqVq2i2KqPP5faWh5pVRRRUcAKiIoVUBCqBGQJ+5pAAtnIPjP398ckA0OCJJBkJpPv+/XKK5k75945N4eEb86cxTBN00RERERExE9YfF0BEREREZFjKaCKiIiIiF9RQBURERERv6KAKiIiIiJ+RQFVRERERPyKAqqIiIiI+BUFVBERERHxKwqoIiIiIuJXFFBFRERExK8ooIqIiIiIX7H5ugItraioCIfD4etqSAtJSEigoKDA19WQVqC2DUxq18Cltg1cbdm2NpuNTp06nbxcG9SlTTkcDmpra31dDWkBhmEA7jY1TdPHtZGWpLYNTGrXwKW2DVz+2rZ6i19ERERE/IoCqoiIiIj4FQVUEREREfErCqgiIiIi4lcCbpKUiIiI+J5pmhw5csSvJt5I4yorK6mpqWmx6wUHBxMcHHxa11BAFRERkRZ35MgRgoODCQoK8nVV5CTsdnuLrYBkmiaVlZWUl5cTHh5+ytfRW/wiIiLS4kzTVDjtgAzDICws7LTXpFdAFREREZEWVb++6qlSQBURERERv6KAegpMpxNz83pcq/7j66qIiIiIBBwF1FOx5Qdcf/kD5juvYbqcvq6NiIiIBJhly5ZxwQUX4HK5WuX6d999NzfddFOTy1dXVzNs2DBycnJapT7HU0A9FZl9ITQcykpgxxZf10ZEREQCzJNPPsmUKVOwWNxR7dlnn2XMmDEtdv3HH3+c559/vsnlg4ODue2223jyySdbrA4/5pSWmVq0aBHz5s2juLiYtLQ0Jk+eTJ8+fRotu2nTJt566y327dtHdXU1CQkJjB49mrFjx3rKLFmyhP/85z/s2bMHgO7du3PNNdeQmZl5KtVrdYbNhtF/MOa3X2GuXYXRo7evqyQiIiJtzDRNnE4nNpt3nKqpqTmlFQzqz/v222/ZsWOHV1ZqqtraWux2+0nLRUVFNfvaEyZM4IknnmDLli307Nmz2ec3R7N7UFesWMHMmTO58sormT59On369GHatGkUFhY2Wj44OJiLL76Yxx57jOeff54rr7yS2bNns2TJEk+ZDRs2cO655/LHP/6RJ554gri4OJ544gkOHz586nfW2rKHAWDmfOvjioiIiPg30zQxq6t889HMjQJM0+SVV15h+PDh9OjRg9GjR7NgwQLAnYFSU1NZunQpl156KRkZGaxcuZKJEyfy0EMP8eijj9K/f3+uueYaAL7++msuv/xyMjIyGDRoENOmTfNafulE582bN48LLriAkJAQAGbPns1zzz3Hhg0bSE1NJTU1ldmzZwOQmprKv/71L2688UYyMzN54YUXcDqd3HfffZx99tn06NGD888/n3/84x9e93n8W/wTJ05k6tSpPPHEE/Tr14+BAwfy7LPPep0TGxvLkCFD+PDDD5v1PT0Vze5BXbBgARdeeCGjRo0CYPLkyaxdu5bFixczadKkBuUzMjLIyMjwPE5MTGTVqlVs3LiR0aNHAzBlyhSvc2677TZWrlzJunXrGDFiRHOr2CaMM4ZgWiywbxdm4QGM+CRfV0lERMQ/1VTjuvNqn7y05aV3ITikyeWnT5/OJ598wlNPPUVGRgbffPMNU6ZMIS4uzlPmiSee4JFHHiE9Pd3TEzlnzhyuv/56T3jLy8vjuuuu4+qrr+aFF15g69at3H///QQHB3Pfffd5rnX8eQDffPMN48eP9zy+4oor2Lx5M0uXLuWdd94BIDIy0vP8s88+y4MPPsijjz6K1WrF5XKRkpLCq6++SmxsLK
tXr+aBBx4gMTGRK6644oT3PmfOHG655Rbmz5/PmjVruOeeexg2bBgXXHCBp8ygQYNYuXJlk7+fp6pZAdXhcLB9+3avbxpAdnY2mzdvbtI1duzYwebNm/nFL35xwjLV1dU4HA4iIiJOWKa2ttZr1wPDMAgNDfV83dqMiChcmX0g9wfI+RZj1E9b/TU7mvp2bIv2lLaltg1MatfA1VHatqKigtdee43Zs2czdOhQALp27cq3337Lm2++ybXXXgvA/fff7xXaALp168bDDz/sefz000/TuXNnnnzySQzDIDMzk/z8fKZNm8Y999zjGVt6/HkAe/fuJSnpaMdXaGgo4eHhWK1WEhMTG9R7/PjxDXLVb3/7W8/X6enprF69mvnz5/9oQO3Tpw/33nsv4B5uOXPmTM9krXrJycns3bv3hNc41un8e2lWQC0tLcXlchEdHe11PDo6muLi4h8997bbbqO0tBSn08lVV13l6YFtzFtvvUVsbCxnnHHGCcvMnTuX9957z/M4IyOD6dOnk5CQ0LSbaQGl542mJPcHgjbnkPDLW9rsdTua5ORkX1dBWonaNjCpXQNXc9q2srLSMxbStNngb3Nbq1o/Lii4yUFp+/btVFVVed5qr1dbW8sZZ5zhGWs6ZMgQr3GehmEwaNAgr2Pbtm1j2LBhXmNRhw8fTnl5OQUFBaSlpTV6HkBVVRXh4eFexy0WC4ZhNDq+dPDgwQ2Oz5w5k7feeou9e/dSWVlJbW0t/fv395Q7/nqGYdCvXz+v6yQnJ3P48GGvY+Hh4V5teyJBQUGkpKT8aJkfc0qTpBpr6JM1/uOPP05VVRW5ubnMmjWL5ORkzjvvvAblPvroI5YvX86jjz76owOMJ0yY4DV4uP71CwoKTnt7raYyu7snR1XlrGb/jm0YIWFt8rodhWEYJCcnk5+f3+wxROLf1LaBSe0auE6lbWtqarz3d7dYW6l2J9GMTFBTUwPAv/71rwZhPCgoiF27dgEN9643TZPg4GCvYy6XC9M0vY7Vf+1wOKitrW30PHCP9Tx06NBJr1fv+GvMmzePRx55hKlTpzJ06FDCw8P561//ynfffecpd+z17HY7pmlisVga3Fd9XesdOnSIuLi4Rutx/PcyLy+vwXGbzdakzsRmBdSoqCgsFkuD3tKSkpIGvarHq++STk9Pp6SkhDlz5jQIqPPmzWPu3LlMnTqVrl27/uj17Hb7CdN7W/1iNBM7Q2IKHMzDXP8dDDmnTV63ozFNU//ZBSi1bWBSuwauQG/brKwsgoOD2bdvH8OHD2/wfH1AbYqePXvy73//G9M0PZ1oq1evJiIi4qQ9i/369SM3N9frmN1ub/KaqKtWrWLIkCFMnjz5lOr+YzZt2kS/fv2aVPZ0/q00axa/zWaje/fuDRZpzcnJoVevXk2+Tn0iP9a8efN4//33+cMf/kCPHj2aUy2fMQwDI/tMAMy1q3xcGxERETkdERER3HrrrTz66KO8++677Ny5k/Xr1zNz5kzefffdZl3rhhtuYP/+/Tz88MNs3bqVRYsW8eyzz3LLLbd4xp+eyMiRI/n2W+9Vgrp06cLu3btZv349hw8fprq6+oTnd+vWjZycHJYuXcq2bdt45plnWLt2bbPqfyKrVq1qkwnszV5mauzYsXz22Wd8/vnn7N27l5kzZ1JYWOhZPHbWrFm89NJLnvILFy5k9erV5OXlkZeXxxdffMH8+fM5//zzPWU++ugj3nnnHX7zm9+QmJhIcXExxcXFVFVVtcAtti5jQN1yU+tWa1cpERGRdu6BBx7gnnvu4aWXXmLkyJFMmjSJTz/9lPT09GZdJyUlhTfeeIPvv/+eMWPG8Pvf/55rrrmGu+6666TnXnnlleTm5rJ161bPscsuu4yRI0dy9dVXc8YZZ/zoUk/XXXcdl156Kb/5zW/46U9/SlFRETfccEOz6t+Y1atXU1ZWxuWXX37a1zoZwzyF/tf6hfqLioro0qULN9xwA3379gXg5Z
dfpqCggEcffRSATz75hCVLlnDw4EEsFgvJycmMGjWK0aNHe/6CuOOOOygoKGjwOhMnTuTqq5u3LEVBQcFJx0W0JNP
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAArEAAAF0CAYAAAA0F2G3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABtaklEQVR4nO3deXxU9b3/8deZTCb7vkMCJOz7IoKCFQQsVqmoRa3WKq1WW7RW29veWpe6oF5si7XeLv6utlatilUpSCsgiiKg4sYma4AA2SAh+56Z+f7+mGQgJgiBJLPwfj4ePJI5c+acz+RL4J1vvotljDGIiIiIiAQQm68LEBERERHpLIVYEREREQk4CrEiIiIiEnAUYkVEREQk4CjEioiIiEjAUYgVERERkYCjECsiIiIiAUchVkREREQCjkKsiIiIiAQchVgRERERCTh2XxfgC+Xl5TidTl+XIV0kJSWFkpISX5chXUztGrzUtsFLbRucerpd7XY7CQkJJz6vB2rxO06nk+bmZl+XIV3AsizA06bGGB9XI11F7Rq81LbBS20bnPy5XTWcQEREREQCjkKsiIiIiAQchVgRERERCTgKsSIiIiIScM7IiV3HY4yhpqbG7wYuy1err6+nqampw+fCwsIICwvr4YpERESkuynEHqOmpoawsDAcDoevS5FOCA0N7XC1CWMM9fX11NbWEhUV5YPKREREpLtoOMExjDEKsEHEsiwiIyO1JrCIiEgQUoiVoNe6xp2IiIgED4VYEREREQk4GhMrIiIiIu0YpxMOFVC7YyMmIRVSM3xdUhsKsXJCubm5zJkzh7Vr1xIdHd3l1//d737H8uXLeeutt076NRdffDG33XYbF198cZfXIyIicqYxVeVwMA+Tnwf5LR+LDoLLSRlg+/YPsKZ/08dVtqUQKye0YMECbrjhBm+AXbRoEffffz/bt2/vkuv/8Ic/5Hvf+16nXnPHHXfw4IMPctFFF3VJDSIiImcC09wMRQcx+fuOhtX8PKiu7PgF4RE4cgbhjIrpyTJPikKsfKXCwkLeeustHnjggU6/tqmp6aRWe4iKiur0EljTp0/n5z//Oe+++y4zZ87sdG0iIiLBzBgDFWUtQfWYwFqcD253+xdYFqT2gsy+WJnZWJn9ILMfVnIaab16UVRU5Hfr6CvEHocxBpoafXNzR1inZtSvXr2aJ554gp07d2Kz2TjrrLN48MEH6devH+AJog899BBr1qyhsbGRgQMH8vDDDzNu3DgAVq5cyeOPP87OnTuJjIzknHPO4emnnwbgjTfeYNiwYfTq1QuA9evX89Of/hSA3r17A/DTn/6Un/3sZ0ycOJFrrrmGvLw8li9fzsyZM3niiSd4+OGHefPNNykqKiI1NZXLL7+cO++8k9DQUKD9cII77riDqqoqJkyYwFNPPUVTUxOzZ8/mgQce8L4mJCSEadOm8a9//UshVkREzmimqREKD7QdCpCfB7XVHb8gMtoTUFuDamY29OqD1cHmQP68wo9C7PE0NeK+7Sqf3Nr2v69AWPhJn19XV8fNN9/MkCFDqKur47e//S033XQTK1eupL6+njlz5pCens7f/vY3UlJS2LJlC+6Wn8JWrVrFTTfdxO23384f/vAHmpqaePvtt73X/uijjxg1apT38fjx43nggQf47W9/y5o1awDa9KL+5S9/4Y477uAnP/mJ91hUVBSPP/446enpbN++nV/84hdER0czb968476n9evXk5qayj//+U/27dvHj370I4YPH853vvMd7zljxozhz3/+80l/nURERAKZMQbKStoEVZO/Dw4Vgemgd9Vmg7Tex4TVfpCZDQlJfh1OT5ZCbBC45JJL2jz+3e9+x6hRo9i1axeffPIJR44c4d///jcJCQkAZGdne8/9wx/+wOzZs/mv//ov77Hhw4d7Pz948CAjR470PnY4HMTExGBZFqmpqe1qmTx5Mj/84Q/bHLvjjju8n2dlZb
Fnzx6WLl36lSE2Li6Ohx9+mJCQEAYMGMD06dNZu3ZtmxCbkZFBQUGBN5CLiIgEG2MM5OViPngb8/FaqKnq+MToWMjKxurd72hg7ZWFFRq8mzgpxB6PI8zTI+qje3dGXl4ev/nNb/jss88oKyvzhrqCggK++OILRowY4Q2wX/bFF1+0CYZf1tDQQHj4yfcKH9tr22rZsmU8/fTT5OXlUVtbi8vlOuEqB4MGDSIkJMT7OC0trd1EsvDwcNxuN42Njdjt+qssIiLBw5SVYj56F/PBas8qAa1C7JCRebR3tSW0EpcQFL2rnaH/+Y/DsqxO/Urfl+bOnUuvXr147LHHSE9Px+12M23aNJqbm08YQE/0fGJiIhUVFSddS2RkZJvHn376KfPmzeNnP/sZU6dOJSYmhiVLlvD//t//+8rrtI59PdaXB5SXl5cTERFBREQEzc3NJ12jiIiIPzKNDZjPP8R88A5s3wSt/++FOrDGnoN17gUwZBSWvf3/kWcihdgAV1ZWxu7du1mwYAETJ04EYMOGDd7nhw4dyksvvUR5eXmHvbFDhw5l7dq1XH311R1ef8SIEezevbvNMYfDgcvlOqn6Pv74YzIzM9uMkS0oKDip157Izp072wx1EBGRM4vZvweTux0rsy/kDA7IX50btxt2b/MMF/hkPTTWH31y4DCsc6dhnTUZK7Jzq/icCRRiA1x8fDwJCQm88MILpKamUlBQwKOPPup9/rLLLuPJJ5/kxhtv5K677iI1NZWtW7eSlpbG+PHj+elPf8rVV19N3759mT17Nk6nk9WrV3vHq06ZMoWf//znuFwu76/3MzMzqa2t5f3332f48OHe3tCOZGdnU1BQwJIlSxg9ejRvv/02b775Zpe89w0bNnD++ed3ybVERCQwmOYmzCfrMO/+B/bu9BwDCHV4guyQkViDRkLOIL/usTSHCzEfrPYMFzhy+OgTyWlY516Adc4FWH62Q5a/UYgNcDabjT/96U/cd999TJ8+nZycHB566CHmzJkDeHpNX3rpJR544AG++93v4nQ6GTRoEA8//DAAkyZN4qmnnuL3v/89f/zjH4mOjuacc87xXn/69OnY7Xbef/99pk6dCsDZZ5/Nd7/7XX70ox9RXl7uXWKrIzNnzuQHP/gBd999N01NTUyfPp077riDhQsXntb7Lioq4pNPPuEPf/jDaV1HREQCgykpxry3HLPuLahpWToqxA6DhkPBfqiqgJ1bMDu3eEKtwwH9h2INHok1eCT0G4jl4/kTpq4G88lazPp3YM+Oo0+ER2CNPw/r3GkwYCiWzea7IgOIZfxt5doeUFJS0uEYyqqqKmJjY31QkX979tlnWblyJS+++KKvS/F66KGHqK6u5rHHHiM0NPQrx8SqXQOPZVlkZGT45eLacnrUtsGrO9rWuN3wxWe4V/8Htn56dIxoYjLW+RdhnXchVlyC537FBZidm2HHFsyure13oHKEwYBhnp7awSOh7wCsYyYQdxfjcsG2zz29rp9/CM6W/68sGwwb7RkuMOacDtdo9Qe++J4NDQ0lJSXlhOepJ1ZO6LrrrqOyspKampoTrirQU5KTk9st5SUiIsHBVFdh1q/CvPsmlB46+sSwMdguuBhGnt0mgFqW5Zmxn5EJUy/2hK3Cg5idmzE7t8KuLZ7e222fY7Z97umpDYvwjDltHX7QJ6dLQ63J3+cJrh+9B5XlR5/o1Qdr0jSsiVOw4pO67H5nIvXEHkM9doFJPbHBR711wUttG7xOt22NMbBvF+bd/3jWQ23tsYyMwpo0A2vKRVjpvU+pNuN2Q+F+zM6tmB1bYNdWqKtpe1JE5DE9taMgqx+WrXOh1lRVYDa85xkucHDf0SeiYz2h9dwLoE//gFoKSz2xIiIiIh0wjY2Yj9d4el335x59ok8O1tSLsSZMOe1ftVs2G2Rme7ZXnf5NT6jNz/OMn925BXZ9AfW1sOUTzJZPPD21kVEwcPjRntrMfh2OVTXNzbB5A+
7173iGPLRuwBNih9FnYzv3Ahhxll9PMgtUCrEiIiLS48yhQsx7b2LWvX20V9QeinX2eVhTL4bsQd3WY2nZbJ6Q3Cc
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 128 # number of data points in a batch\n",
"init_scale = 0.1 # scale for random parameter initialisation\n",
"learning_rate = 0.1 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size \n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine + softmax model\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, output_dim, param_init, param_init),\n",
" SoftmaxLayer()\n",
"])\n",
"\n",
"# Initialise a cross entropy error object\n",
"error = CrossEntropyError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `init_scale = 0.5`"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 0.6s to complete\n",
" error(train)=3.38e-01, acc(train)=9.03e-01, error(valid)=3.17e-01, acc(valid)=9.11e-01\n",
"Epoch 10: 0.5s to complete\n",
" error(train)=3.06e-01, acc(train)=9.13e-01, error(valid)=2.94e-01, acc(valid)=9.17e-01\n",
"Epoch 15: 0.5s to complete\n",
" error(train)=2.92e-01, acc(train)=9.17e-01, error(valid)=2.83e-01, acc(valid)=9.20e-01\n",
"Epoch 20: 0.6s to complete\n",
" error(train)=2.82e-01, acc(train)=9.20e-01, error(valid)=2.77e-01, acc(valid)=9.22e-01\n",
"Epoch 25: 0.5s to complete\n",
" error(train)=2.77e-01, acc(train)=9.22e-01, error(valid)=2.75e-01, acc(valid)=9.22e-01\n",
"Epoch 30: 0.6s to complete\n",
" error(train)=2.71e-01, acc(train)=9.24e-01, error(valid)=2.71e-01, acc(valid)=9.25e-01\n",
"Epoch 35: 0.5s to complete\n",
" error(train)=2.67e-01, acc(train)=9.25e-01, error(valid)=2.69e-01, acc(valid)=9.26e-01\n",
"Epoch 40: 0.5s to complete\n",
" error(train)=2.65e-01, acc(train)=9.27e-01, error(valid)=2.68e-01, acc(valid)=9.26e-01\n",
"Epoch 45: 0.5s to complete\n",
" error(train)=2.61e-01, acc(train)=9.27e-01, error(valid)=2.66e-01, acc(valid)=9.27e-01\n",
"Epoch 50: 0.6s to complete\n",
" error(train)=2.59e-01, acc(train)=9.28e-01, error(valid)=2.65e-01, acc(valid)=9.27e-01\n",
"Epoch 55: 0.5s to complete\n",
" error(train)=2.57e-01, acc(train)=9.29e-01, error(valid)=2.64e-01, acc(valid)=9.29e-01\n",
"Epoch 60: 0.5s to complete\n",
" error(train)=2.56e-01, acc(train)=9.28e-01, error(valid)=2.65e-01, acc(valid)=9.28e-01\n",
"Epoch 65: 0.5s to complete\n",
" error(train)=2.54e-01, acc(train)=9.30e-01, error(valid)=2.63e-01, acc(valid)=9.28e-01\n",
"Epoch 70: 0.5s to complete\n",
" error(train)=2.52e-01, acc(train)=9.30e-01, error(valid)=2.64e-01, acc(valid)=9.28e-01\n",
"Epoch 75: 0.5s to complete\n",
" error(train)=2.50e-01, acc(train)=9.31e-01, error(valid)=2.62e-01, acc(valid)=9.29e-01\n",
"Epoch 80: 0.6s to complete\n",
" error(train)=2.49e-01, acc(train)=9.31e-01, error(valid)=2.63e-01, acc(valid)=9.28e-01\n",
"Epoch 85: 0.5s to complete\n",
" error(train)=2.48e-01, acc(train)=9.31e-01, error(valid)=2.62e-01, acc(valid)=9.30e-01\n",
"Epoch 90: 0.6s to complete\n",
" error(train)=2.47e-01, acc(train)=9.31e-01, error(valid)=2.62e-01, acc(valid)=9.28e-01\n",
"Epoch 95: 0.6s to complete\n",
" error(train)=2.47e-01, acc(train)=9.31e-01, error(valid)=2.62e-01, acc(valid)=9.29e-01\n",
"Epoch 100: 0.6s to complete\n",
" error(train)=2.45e-01, acc(train)=9.32e-01, error(valid)=2.62e-01, acc(valid)=9.28e-01\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABoBklEQVR4nO3deXxU1cH/8c+ZTDLZE0ISEkgCSVgFgoCIIiIqFBWqYNHHpSq1Vq1a1Lb6aCtWrUuxVeujrT4/61OqLUpRUcQNcakCAoJiQJYQ9kACCVnJPpn7+2OSgSEsSZgkk+H7fr3ympk79557bo4Tv5w551xjWZaFiIiIiIifsHV2BUREREREDqeAKiIiIiJ+RQFVRERERPyKAqqIiIiI+BUFVBERERHxKwqoIiIiIuJXFFBFRERExK8ooIqIiIiIX1FAFRERERG/ooAqIiIiIn7F3tkV8LWSkhKcTmdnV0N8JCEhgcLCws6uhrQDtW1gUrsGLrVt4OrItrXb7XTr1u3E+7Wl8I8++oiFCxdSWlpKSkoKM2bMYNCgQUfdd9OmTfzrX/9iz5491NbWkpCQwIQJE5gyZcpR91+2bBnPPvssZ5xxBvfee2+r6+Z0Oqmvr2/1ceJ/jDGAu00ty+rk2ogvqW0Dk9o1cKltA5e/tm2rA+ry5cuZM2cON910EwMGDGDJkiU8/vjjPPPMM8THxzfb3+FwMGnSJHr37o3D4WDTpk289NJLhIaGMmHCBK99CwsLefXVV48ZdkVEREQk8LV6DOqiRYu44IILuPDCCz29p/Hx8SxevPio+6enpzN27FhSU1NJTExk3LhxDBs2jI0bN3rt53K5+J//+R+uvPJKEhMT23Y1IiIiItLltaoH1el0sm3bNqZOneq1PSsri82bN7eojO3bt7N582auuuoqr+1vvPEG0dHRXHDBBc3C69HU19d7fZVvjCEsLMzzXLq+pnZUewYetW1gUrsGLrVt4PLXtm1VQC0vL8flchETE+O1PSYmhtLS0uMee+utt1JeXk5DQwNXXHEFF154oee9TZs28emnn/Lkk0+2uC4LFizgjTfe8LxOT09n9uzZJCQktLgM6RqSkpI6uwrSTtS2gUntGrjUtoHL39q2TZOkjpayT5S8H3nkEWpqasjJyWHu3LkkJSUxduxYqquree6557jllluIjo5ucR2mTZvmNdGq6fyFhYWaxR8gjDEkJSVRUFDgVwO35eSpbQOT2jVwtaVtLcuioqJC/y10ASEhIdTV1fmsPIfDQWho6FHfs9vtLepMbFVAjY6OxmazNestLSsra9areqSmcaVpaWmUlZUxf/58xo4dy759+ygsLGT27NmefZv+Y77qqqv485//fNRUHxwcTHBw8FHPpQ9DYLEsS20aoNS2gUntGrha07YVFRU4HA5CQkLauVZysoKDg322ApJlWVRXV3Pw4EEiIiLaXE6rAqrdbicjI4Ps7GzOPPNMz/bs7GxGjRrV4nIsy/L0cvbs2ZM//elPXu+//vrr1NTUeCZgiYiISNdiWZbC6SnIGEN4eDhlZWUnVU6rv+KfMmUKzz33HBkZGfTv358lS5ZQVFTExIkTAZg7dy7FxcXccccdAHz44YfEx8fTq1cvwD3e9N133+Xiiy8G3N3KaWlpXudoStxHbhcRERER/3eyk65aHVDHjBlDRUUFb775JiUlJaSmpnL//fd7xhOUlJRQVFTk2d+yLF577TX279+PzWYjKSmJa6+9ttkaqCIiIiIiAMYKsIFChYWF7X4nKauhAXI3YpUVYztzXLue61RmjCE5OZn8/HyNZwswatvApHYNXG1p2/Ly8lZNfpbO48sxqE2O1f7BwcEtmiTV6oX6BchZj+tPv8F6/SUsV0Nn10ZEREQCzNKlSxk3bhwul6tdyr/rrru48cYbW7x/bW0to0aNIjs7u13qcyQF1LboNxjCI6GiDHI3dXZtREREJMA89thjzJw5E5vNHd
Weeuopz3wfX3jkkUd45plnWry/w+Hg1ltv5bHHHvNZHY5HAbUNjN2OyXKvWmCtXdHJtREREZHOcPiqRIdr65qiTcd9/fXXbN++3Wu995Zq6Vf10dHRJ1wi9EjTpk1j1apVbNmypdX1ai0F1DYyw0cDYK1dqbFWIiIix2FZFlZtTef8tPL/0ZZl8de//pWzzz6bzMxMJkyYwKJFiwBYvnw5vXr14vPPP+fiiy8mPT2dlStXMn36dH7729/y0EMPMWTIEK6++moAvvrqKyZPnkx6ejrDhw/n8ccf9wq0xzpu4cKFjBs3zrPY/bx583j66afZsGEDvXr1olevXsybNw+AXr168corr/CTn/yEvn378uyzz9LQ0MCvfvUrzjrrLDIzMzn33HP529/+5nWdR37FP336dGbNmsWjjz7K4MGDOf3003nqqae8jomLi2PkyJG8/fbbrfqdtkWb7iQlwOAREBwChQWwZyek9OnsGomIiPinulpcd1zZKae2Pf9vcBz9rkZHM3v2bD744AOeeOIJ0tPTWbFiBTNnzqR79+6efR599FEefPBB0tLSPBOB5s+fz/XXX+8Jb/n5+Vx33XVceeWVPPvss+Tm5nLPPffgcDj41a9+5SnryOMAVqxYwdSpUz2vL730UjZv3sznn3/O66+/DkBUVJTn/aeeeor777+fhx56iKCgIFwuF8nJybz44ovExcWxevVq7r33XhITE7n00kuPee3z58/n5ptv5t1332XNmjXcfffdjBo1inHjDk0IHz58OCtXrmzx77OtFFDbyDhC4bTT4btVWGtXYBRQRUREurSqqipeeukl5s2bxxlnnAFA7969+frrr/nnP//JtddeC8A999zjFdoA+vTpwwMPPOB5/Yc//IGePXvy2GOPYYyhb9++FBQU8Pjjj3P33Xd7xpYeeRxAXl4ePXr08LwOCwsjIiKCoKAgz505Dzd16lSuuuoqr22//vWvPc/T0tJYvXo177777nED6qBBg/jlL38JQEZGBnPmzPFM1mqSlJREXl7eMcvwFQXUk2CGn4X13Sqsb1fAlKtOfICIiMipKMTh7snspHO3VE5ODjU1NZ6v2pvU19czZMgQz+usrKxmxw4bNszrdW5uLiNHjvRasH7UqFFUVlaSn5/vuYHRkccB1NTU4HC0vN5HK+OVV17htddeIy8vj5qaGurr6xk8ePBxyxk0aJDX68TERK+17QFCQ0Oprq5ucd3aSgH1JJisUVjGBru2YR3Yj+ne/F81IiIipzpjTKu+Zu8sTUs6vfLKKyQlJXm9FxISws6dOwEIDw9vdmxYWJjXa8uymt1N6WjjYY88DtxjPVtzq9Aj67Nw4UIefvhhZs2axRlnnEFERAQvvPAC33777XHLsdu9Y6ExptkyV6WlpV7DHdqLAupJMFEx0G8Q5HyPtXYl5sIfdnaVREREpI369++Pw+Fgz549nH322c3ebwqoLdGvXz/ef/99r6C6evVqIiMjSU5OPu6xgwcPJicnx2tbcHBwi9dEXbVqFSNHjmTGjBltqvvxbNq06YQ9sb6gWfwnyZx+FoD7a34RERHpsiIjI7nlllt46KGH+Pe//82OHTtYv349c+bM4d//bt0QhRtuuIG9e/fywAMPkJuby0cffcRTTz3FzTff7Bl/eizjx4/n66+/9tqWmprKrl27WL9+PcXFxdTW1h7z+D59+pCdnc3nn3/O1q1befLJJ/nuu+9aVf9jWbVqFeedd55PyjoeBdSTZE53LzfFlu+xDpZ3bmVERETkpNx7773cfffdPP/884wfP55rrrmGjz/+mLS0tFaVk5yczKuvvsratWuZOHEi9913H1dffTV33nnnCY+9/PLLycnJITc317PtkksuYfz48Vx55ZUMHTr0uEs9XXfddVx88cX8/Oc/54c//CElJSXccMMNrar/0axevZqKigomT5580mWdiLECbBHPwsJCn99P9kQaHp4JeTswP7kT25gLO/TcgU
z39Q5catvApHYNXG1p22Pdi11a5tFHH6W8vJwnn3yy3c8VHBzcoux08803M2TIEGbOnHnCfY/V/sHBwSQkJJzwePW
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAArEAAAF0CAYAAAA0F2G3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABtCklEQVR4nO3deXxU9b3/8deZZCb7SnaSQFhlDSCb0AqCllZpEUrdbqu0Um3BWrreehVb9x+2xe22vd7aSperYkUEsQJGURQXxIXdQFhDFkjIvs9yfn9MMjAmQIJJZuH9fDx4JHPmzDnfyZeQN598zvcYpmmaiIiIiIgEEIuvByAiIiIi0lUKsSIiIiIScBRiRURERCTgKMSKiIiISMBRiBURERGRgKMQKyIiIiIBRyFWRERERAKOQqyIiIiIBByFWBEREREJOAqxIiIiIhJwQn09AF+orKzE4XD4ehjSTZKTkykrK/P1MKSbaV6Dl+Y2eGlug1Nvz2toaCgJCQnn3q8XxuJ3HA4Hdrvd18OQbmAYBuCeU9M0fTwa6S6a1+CluQ1emtvg5M/zqnYCEREREQk4CrEiIiIiEnAUYkVEREQk4CjEioiIiEjAuSAv7DoT0zSpq6vzu8ZlObvGxkZaWlo6fC4sLIywsLBeHpGIiIj0NIXY09TV1REWFobNZvP1UKQLrFZrh6tNmKZJY2Mj9fX1REVF+WBkIiIi0lPUTnAa0zQVYIOIYRhERkZqTWAREZEgpBArQa9tjTsREREJHgqxIiIiIhJw1BMrIiIiIu2YDjuUFlGfvx0zPhlS0n09JC8KsXJOBQUFzJ8/n3feeYfo6OhuP/7vf/971q9fz2uvvdbp11x55ZXcdtttXHnlld0+HhERkQuJ6XJC+XEoOopZdLj14xE4UQxOJxWA5brvY8z8uq+H6kUhVs5p2bJl3HTTTZ4Au3LlSn7zm9+wd+/ebjn+D37wA7773e926TVLlizh3nvv5atf/Wq3jEFERCTYmaYJVRVQdASz+AgcO4JZfBRKjsIZlqokIgrbgME4omJ6d7CdoBArZ1VcXMxrr73GPffc0+XXtrS0dGq1h6ioqC4vgTVz5kx+8Ytf8OabbzJr1qwuj01ERCSYmfW17rBadBSKj7grq0VHoaGu4xdYbZCehdG3H/Tth9E3GzL6YSQmkZqRQUlJid+to68QewamaUJLs29Obgvr0hX1mzZt4rHHHiM/Px+LxcLFF1/MvffeS//+/QF3EL3vvvvYvHkzzc3NDB48mAceeIBx48YBsHHjRh555BHy8/OJjIxk8uTJPPXUUwC8/PLLDB8+nIyMDADeffddfvrTnwLQt29fAH7605/ys5/9jEmTJnH99ddz+PBh1q9fz6xZs3jsscd44IEHePXVVykpKSElJYW5c+fyk5/8BKvVCrRvJ1iyZAk1NTVMnDiRJ598kpaWFubMmcM999zjeU1ISAgzZszgpZdeUogVEblAmc3NUHECbOEQEQnhERiWC+uadbO5CUoKW0PqqdBKVUXHL7BYILUvRkZ2a1h1h1aSUzEsIe129+cVfhRiz6SlGddt1/jk1Jb/fh7Cwju9f0NDA7fccgsXXXQRDQ0N/O53v2PhwoVs3LiRxsZG5s+fT1paGk8//TTJycns3LkTl8sFQF5eHgsXLuT222/n8ccfp6Wlhddff91z7A8++IDRo0d7Ho8fP5577rmH3/3ud2zevBnAq4r6P//zPyxZsoQf//jHnm1RUVE88sgjpKWlsXfvXn75y18SHR3NokWLzvie3n33XVJSUvjXv/7FoUOH+OEPf8iIESP4j//4D88+Y8aM4U9/+lOnv04iIhK4zOZmOHYI80gBHDng/lhSCK0/zwAwDAiPdAfaiEiIjIKIKIyISIiIan186nPP9tP27Wohqdvfp8sFTgfY7eCwg8PR+tEO9hbMEyWtbQDu0Er5cThThbRPindVNbMfpGZitBaEAp1CbBC46q
qrvB7//ve/Z/To0ezbt49t27Zx8uRJXnnlFRISEgDIycnx7Pv4448zZ84cfv7zn3u2jRgxwvN5YWEho0aN8jy22WzExMRgGAYpKSntxjJ16lR+8IMfeG1bsmSJ5/OsrCwOHDjA2rVrzxpi4+LieOCBBwgJCWHQoEHMnDmTd955xyvEpqenU1RU5AnkIiISHMzmZig8iHnkABwpwDx6AIoLwezg3/uISLC3uMOeaUJjvfvP6cc703k62hgSciroej62Bt62oBsRCaGh3gHTYQe7A5ytwbM1hJrtgmjrR09QPf0YDvf2roqJ86qqGhnZkJHtHnMQU4g9E1uYuyLqo3N3xeHDh/ntb3/Lxx9/TEVFhSfUFRUVsXv3bkaOHOkJsJ+3e/dur2D4eU1NTYSHd74qfHrVts26det46qmnOHz4MPX19TidznOucjBkyBBCQk79WiM1NbXdhWTh4eG4XC6am5sJDdVfZRGRQGQ2N3kH1iMHoORYx4E1Nh76DcLoNxCj30DoNxjiEzEMA9Pe4g6vDQ2ngmxjA2bDqc9pbICGOszTPj+1vd59TqcT6mrdf04fZ+98OdoLCYVQqzs0h1ohKcW7FSAjGyM23lej8yn95D8DwzC69Ct9X1qwYAEZGRk8/PDDpKWl4XK5mDFjBna7/ZwB9FzPJyYmUlVV1emxREZ6/6/vo48+YtGiRfzsZz9j+vTpxMTEsGbNGv73f//3rMexdvCrjs83lFdWVhIREUFERAR2u73TYxQREd/wBNbDp7UElBZ1HFjjEiB7IEZraKXfIE9g7YhhtbkvTor1Ltp0tjHANE1objoVaD0h+LSg29j6eUMDptOOEWr1DpihVrB2sK31sWd/a6jXdq/9rFZ3cG39eKH1+HaFQmyAq6ioYP/+/SxbtoxJkyYBsHXrVs/zw4YN49lnn6WysrLDauywYcN45513uPbaazs8/siRI9m/f7/XNpvNhtPp7NT4PvzwQzIzM716ZIuKijr12nPJz8/3anUQERH/YTY1QuHneljPGFgTobW6avQb5P48vk+vjtcwDAiPcP9JOHVu/72sSRRiA1x8fDwJCQn885//JCUlhaKiIh566CHP81dffTVPPPEEN998M3fccQcpKSns2rWL1NRUxo8fz09/+lOuvfZa+vXrx5w5c3A4HGzatMnTrzpt2jR+8Ytf4HQ6Pb/ez8zMpL6+nrfffpsRI0Z4qqEdycnJoaioiDVr1pCbm8vrr7/Oq6++2i3vfevWrVx66aXdciwRETnFNE13b6bT2f5jW9/m554zHXZqP6jGufNjd0tA6bGOLzjyBNZBpwXWxN5/kxLwFGIDnMVi4Y9//CN33303M2fOZMCAAdx3333Mnz8fcFdNn332We655x6+853v4HA4GDJkCA888AAAU6ZM4cknn+TRRx/lD3/4A9HR0UyePNlz/JkzZxIaGsrbb7/N9OnTAZgwYQLf+c53+OEPf0hlZaVnia2OzJo1i+9///vceeedtLS0MHPmTJYsWcLy5cu/0PsuKSlh27ZtPP7441/oOCIiwcJ0OeHgPsxPP8A8Xvy5oHmGQHr6R8dpn5/nBbNVn98Qn+juYc0eqMAq3c4w/W3l2l5QVlbWYQ9lTU0NsbGxPhiRf1uxYgUbN27kmWee8fVQPO677z5qa2t5+OGHsVqtZ+2J1bwGHsMwSE9P98vFteWL0dx2L9PeAp/twPzkfcztW6GmqudOZrG4ezVDQtp/DA2FkFDCM/vRkpYJ2e4eViOu44uKJXD44nvWarWSnJx8zv1UiZVz+va3v011dTV1dXXnXFWgtyQlJbVbyktE5EJgNtRh7vwIPnkfc9fH0Nx46smIKIxR42HQRacuFAoJwThT+AwJbQ2gZ3iuLaBaQs55gZFhGCTrPyjSixRi5ZxCQ0O9LszyBz/84Q99PQQRkV5jVpRjbt+K+en7kL/T/Wv/NvGJGGMmY4ydBENGuq+AF7kAnFeI3bBhA2vXrqWqqo
rMzEwWLFjAsGHDzrj/+vXr2bBhAydOnCApKYl58+Yxbdo0z/MffPABq1evprS0FKfTSVpaGl//+tfbXbTT1fOKiIg
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"init_scale = 0.5 # scale for random parameter initialisation\n",
"learning_rate = 0.1 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size \n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine + softmax model\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, output_dim, param_init, param_init),\n",
" SoftmaxLayer()\n",
"])\n",
"\n",
"# Initialise a cross entropy error object\n",
"error = CrossEntropyError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"|`init_scale`| Final `error(train)` | Final `error(valid)` |\n",
"|------------|----------------------|----------------------|\n",
"| 0.01 | 2.43e-01 | 2.58e-01 |\n",
"| 0.1 | 2.43e-01 | 2.59e-01 |\n",
"| 0.5 | 2.45e-01 | 2.62e-01 |\n",
"\n",
"<span style=\"color:green\">\n",
"Larger initialisation scale of 0.5 seems to give slightly slower initial learning than smaller scales of 0.1 and 0.01 however difference is only slight suggesting for this shallow architecure training performance is not particularly sensitive to initialisation scale.\n",
"</span>"
]
},
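{
"cell_type": "markdown",
"metadata": {},
"source": [
"A rough sketch of why the initialisation scale changes the starting point of training, assuming (as the `UniformInit(-init_scale, init_scale, rng=rng)` calls above suggest) that every weight and bias is drawn independently from a uniform distribution on $[-s, s]$ with $s$ = `init_scale`. Each parameter then has zero mean and variance\n",
"\n",
"\\begin{equation}\n",
"    \\mathrm{Var}(W_{kd}) = \\mathrm{Var}(b_k) = \\frac{(2s)^2}{12} = \\frac{s^2}{3},\n",
"\\end{equation}\n",
"\n",
"so for a fixed input $\\boldsymbol{x}$ the input to the softmax, $a_k = \\sum_{d=1}^D W_{kd} x_d + b_k$, has variance\n",
"\n",
"\\begin{equation}\n",
"    \\mathrm{Var}(a_k) = \\frac{s^2}{3} \\left( \\sum_{d=1}^D x_d^2 + 1 \\right).\n",
"\\end{equation}\n",
"\n",
"The standard deviation of the initial softmax inputs therefore grows linearly with $s$: under these assumptions it is 50 times larger for `init_scale = 0.5` than for `init_scale = 0.01`. Larger initial softmax inputs give more peaked, and so more often confidently wrong, initial predictions, which is one plausible reason for the slightly slower initial learning at the largest scale."
]
},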
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Varying learning rate\n",
"\n",
"<span style=\"color:green\">Now let's try some different values for learning rate.</span>"
]
},
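{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reminder of what this hyperparameter controls, here is the update being applied, assuming `GradientDescentLearningRule` implements the basic gradient descent step that its name and the code comments suggest. After each batch, every parameter $\\theta_i$ would be updated as\n",
"\n",
"\\begin{equation}\n",
"    \\theta_i \\leftarrow \\theta_i - \\eta \\frac{\\partial \\bar{E}}{\\partial \\theta_i}\n",
"\\end{equation}\n",
"\n",
"where $\\eta$ is `learning_rate` and $\\bar{E}$ is the error on the current batch (assumed here to be averaged over the batch). For the same gradient, halving $\\eta$ halves the size of every step, so a smaller learning rate would be expected to make slower but smoother progress per epoch, while a larger one moves faster at the risk of overshooting."
]
},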
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `learning_rate = 0.05`"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 0.6s to complete\n",
" error(train)=3.41e-01, acc(train)=9.05e-01, error(valid)=3.16e-01, acc(valid)=9.12e-01\n",
"Epoch 10: 0.6s to complete\n",
" error(train)=3.10e-01, acc(train)=9.14e-01, error(valid)=2.92e-01, acc(valid)=9.18e-01\n",
"Epoch 15: 0.6s to complete\n",
" error(train)=2.97e-01, acc(train)=9.18e-01, error(valid)=2.82e-01, acc(valid)=9.21e-01\n",
"Epoch 20: 0.5s to complete\n",
" error(train)=2.88e-01, acc(train)=9.20e-01, error(valid)=2.76e-01, acc(valid)=9.23e-01\n",
"Epoch 25: 0.5s to complete\n",
" error(train)=2.83e-01, acc(train)=9.21e-01, error(valid)=2.73e-01, acc(valid)=9.24e-01\n",
"Epoch 30: 0.5s to complete\n",
" error(train)=2.77e-01, acc(train)=9.22e-01, error(valid)=2.69e-01, acc(valid)=9.24e-01\n",
"Epoch 35: 0.5s to complete\n",
" error(train)=2.74e-01, acc(train)=9.24e-01, error(valid)=2.67e-01, acc(valid)=9.25e-01\n",
"Epoch 40: 0.6s to complete\n",
" error(train)=2.72e-01, acc(train)=9.24e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 45: 0.6s to complete\n",
" error(train)=2.68e-01, acc(train)=9.26e-01, error(valid)=2.64e-01, acc(valid)=9.27e-01\n",
"Epoch 50: 0.5s to complete\n",
" error(train)=2.66e-01, acc(train)=9.26e-01, error(valid)=2.63e-01, acc(valid)=9.28e-01\n",
"Epoch 55: 0.6s to complete\n",
" error(train)=2.64e-01, acc(train)=9.26e-01, error(valid)=2.62e-01, acc(valid)=9.29e-01\n",
"Epoch 60: 0.5s to complete\n",
" error(train)=2.63e-01, acc(train)=9.26e-01, error(valid)=2.62e-01, acc(valid)=9.28e-01\n",
"Epoch 65: 0.6s to complete\n",
" error(train)=2.61e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.27e-01\n",
"Epoch 70: 0.5s to complete\n",
" error(train)=2.60e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 75: 0.5s to complete\n",
" error(train)=2.58e-01, acc(train)=9.29e-01, error(valid)=2.60e-01, acc(valid)=9.29e-01\n",
"Epoch 80: 0.5s to complete\n",
" error(train)=2.57e-01, acc(train)=9.29e-01, error(valid)=2.60e-01, acc(valid)=9.29e-01\n",
"Epoch 85: 0.5s to complete\n",
" error(train)=2.56e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 90: 0.6s to complete\n",
" error(train)=2.55e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n",
"Epoch 95: 0.6s to complete\n",
" error(train)=2.54e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n",
"Epoch 100: 0.6s to complete\n",
" error(train)=2.53e-01, acc(train)=9.30e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABqfElEQVR4nO3deXhU5cH+8e85M5PJnpCNhCwQCCBb2FERcQFc+VWxaF2qUmvVVovaVqt1qVrUYqvWV7u9ra9YLYqoKKJFRGsroCC4RPZ9CSQkgez7zJzfH5MMDAmYQJKZDPfnunJl5sxZnsnDJDfPeRbDsiwLEREREZEgYQa6ACIiIiIih1NAFREREZGgooAqIiIiIkFFAVVEREREgooCqoiIiIgEFQVUEREREQkqCqgiIiIiElQUUEVEREQkqCigioiIiEhQUUAVERERkaBiD3QBOlppaSkulyvQxZAOkpycTHFxcaCLIZ1AdRuaVK+hS3Uburqybu12Oz169Pj2/bqgLF3K5XLR2NgY6GJIBzAMA/DWqWVZAS6NdCTVbWhSvYYu1W3oCta61S1+EREREQkqCqgiIiIiElQUUEVEREQkqCigioiIiEhQCblBUiIiIhJ4lmVRVVUVVANvpHW1tbU0NDR02PmcTidOp/OEzqGAKiIiIh2uqqoKp9NJWFhYoIsi38LhcHTYDEiWZVFbW0t1dTVRUVHHfR7d4hcREZEOZ1mWwulJyDAMIiMjT3hOegVUEREREelQzfOrHi8FVBEREREJKgqox8Fyu7E2rcWz6r+BLoqIiIhIyFFAPR5b1uH5/a+wXvlfLI870KURERGRELNs2TImTpyIx+PplPPfcccd3HDDDW3ev76+nrFjx5KXl9cp5TnScY3if//991m4cCFlZWVkZGQwY8YMBg0a1Oq+Gzdu5J///Cd79+6lvr6e5ORkJk+ezNSpU1vdf/ny5TzzzDOMGTOGu++++3iK1/lyBkNEFFRVwPbNkNP6excRERE5Ho8++igzZ87ENL1tiU8++SSLFy/mgw8+6JDzP/LII+2aAszpdHLLLbfw6KOPMm/evA4pw7G0uwV1xYoVzJkzh8suu4zZs2czaNAgHnvsMUpKSlrd3+l0cv755/Pwww/z9NNPc9lllzFv3jyWLl3aYt/i4mJeeumlo4bdYGHY7RhDRwFg5a0KcGlEREQkECzLanW0+vHOKdp83Oeff86OHTuO2ph3LG2dLio2Npa4uLh2nXvatGmsWrWKLVu2tLtc7dXugLpo0SLOPfdcJk2a5Gs9TUpKYsmSJa3un52dzYQJE8jMzCQlJYWJEycyfPhwNmzY4Lefx+Phf/7nf7jiiitISUk5vnfTlYaPA8D6+vMAF0RERCS4WZaFVV8XmK92LhRgWRZ/+tOfOP300+nXrx+TJ09m0aJFgLeRLj09nY8//pgLL7yQ7OxsVq5cyfTp07nvvvt46KGHGDp0KFdddRUAn376KRdffDHZ2dmMHDmSxx57zC/QHu24hQsXMnHiRMLDwwGYN28eTz31FOvXryc9PZ309HRfK2Z6ejr/+Mc/+MEPfkBOTg7PPPMMbrebn//855x22mn069ePM888k7///e9+7/PIW/zTp0/ngQceYNasWQwZMoQRI0bw5JNP+h2TkJDA6NGjeeutt9r1Mz0e7brF73K52L59O5deeqnf9tzcXDZt2tSmc+zYsYNNmzZx5ZVX+m1//fXXiY2N5dxzz20RXlvT2Njo978EwzCIiIjwPe5s5rAxuE0T9u2Gkv0Yyamdfs2TTXM9dkV9StdS3YYm1WvoOuG6bajHc9sVHViitjOfew2c4W3ef/bs2fzrX//i8ccfJzs7m88++4yZM2eSmJjo22fWrFk8+OCDZGVlERsbC8D8+fO57rrrfOGtoKCAa6+9liuuuIJnnnmGrVu3ctddd+F0Ovn5z3/uO9eRxwF89tlnflnrO9/5Dps2beLjjz/m1VdfBS
AmJsb3+pNPPsm9997LQw89hM1mw+PxkJaWxl/+8hcSEhJYvXo1d999NykpKXznO9856nufP38+N910E++88w5r1qzhzjvvZOzYsUycONG3z8iRI1m5cmWbfpYn8rugXQG1oqICj8fTokk4Li6OsrKyYx57yy23UFFRgdvt5vLLL2fSpEm+1zZu3MhHH33EE0880eayLFiwgNdff933PDs7m9mzZ5OcnNzmc5yooiEjqP/mC2J2bCQmd2SXXfdkk5qq8B+qVLehSfUautpTt7W1tTgcDgAsj5v6zirUt3A4HBhN5fg21dXV/O1vf+ONN95g7NixAOTk5LBmzRrmzp3LtddeC8A999zjl2MMwyA7O5uHH37Yt+2xxx4jPT2dJ554AsMwGDRoEMXFxfzmN7/h7rvvxjTNVo8DyM/Pp1evXr6fn8PhICYmBrvdTnp6eotyf/e73/WVrdm9997re9yvXz+++OIL3n33Xb773e8C+K7ffA3DMBg8eDC//OUvARg4cCAvvvgiK1as8Huv6enpLFq0yHfc0YSFhZGWlnbMfY7luAZJtZaIvy0lP/LII9TV1bF582bmzp1LamoqEyZMoLa2lmeffZabb77Z97+Qtpg2bZpf34zm6xcXF5/w6gVt5Rk0Er75gvJPllI17uwuuebJxDAMUlNTKSws1FrOIUZ1G5pUr6HreOq2oaHBd6fTMkxvS2YANBomRhv7Za5fv566ujouv/xy/3M0NjJ06FBfvhgyZIjfXVzLssjNzfXbtmnTJkaNGuWXSUaNGkV1dTW7d+8mPT291eMA6urqsNvtfts9Hg+WZbXax3To0KEttv/jH//glVdeIT8/n7q6OhobG/3Kffj5HA4HlmVxyimn+J0nOTmZoqIiv20Oh4Oamppv7eva0NBAQUFBi+12u71NjYntCqixsbGYptmitbS8vPxbO9o29yvNysqivLyc+fPnM2HCBPbv309xcTGzZ8/27dv8j//KK6/kD3/4Q6v/Y3M4HEdN7132izF3LLz2PNbmtXiqqzAij3/NWTk6y7L0xy5EqW5Dk+o1dB1v3RqG0a7b7IHSPKXTP/7xjxbZIywsjF27dgEQGRnZ4tjmbobNLMtq0XjX2s/uyOPA29ezvLy8zeU+sjwLFy7k4Ycf5oEHHmDMmDFERUXx5z//mS+//PKY57Hb/WOhYRgtprkqKyvz6+5wLCfye6BdAdVut9O3b1/y8vIYN26cb3teXp6vKbwtDh/11qtXL37/+9/7vf7qq69SV1fnG4AVrIyevSA1HQr3Yq37EmPshEAXSURERI7TgAEDcDqd7N27l9NPP73F680BtS369+/Pe++95xdUV69eTXR09Lfe+h4yZAibN2/22+ZwONo8J+qqVasYPXo0M2bMOK6yH8vGjRsZMmRIh5zrWNp9i3/q1Kk8++yz9O3blwEDBrB06VJKSkqYMmUKAHPnzuXgwYPcdtttACxevJikpCRfn4mNGzfyzjvvcOGFFwLe/5FkZWX5XSMqytsSeeT2YGTkjsMqXAB5q0ABVUREpNuKjo7m5ptv5qGHHsLj8TBu3DiqqqpYvXo1kZGRZGRktPlc119/PX//+9+5//77+cEPfsC2bdt48sknuemmm3xzmx7N2Wefzfz58/22ZWZmsnv3btauXUuvXr2IiorC6XS2enyfPn14/fXX+fjjj8nMzOSNN97g66+/JjMzs83lP5pVq1Zx1113nfB5vk27A+r48eOprKzkjTfeoLS0lMzMTO69915ff4LS0lK/OVEty+KVV16hqKgI0zRJTU3lmmuuYfLkyR33LgLIGD4Wa8kCrG/WYLndGDZboIskIiIix+nuu+8mKSmJ5557jt27dxMbG8uwYcP46U9/2q5VndLS0njppZeYNWsWU6ZMIT4+nquuuorbb7/9W4+97LLLePTRR9m6dSs5OTkAXHTRRbz33ntcccUVlJeX89RTT/G9732v1eOvvfZa1q1bx49//GMMw+
CSSy7h+uuv56OPPmpz+VuzevVqKisrufjii0/oPG1hWCHWUai4uLjNk9R2BMvtxvOza6GmCvOuxzAGDO2ya4c6wzB
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAArEAAAF0CAYAAAA0F2G3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABu5ElEQVR4nO3deXxU9b3/8deZzGTf95AESCAgi0F2BRUFLa1SES/VulRpteoFtdQuv3pdWhXtRW9R621vbbWlrVWxIoKogAiKgMqisgvEsIQkkIRskz0zc35/TDIQE5ZAkll4Px+PPJI5c+ac78yXIe9853O+X8M0TRMRERERET9i8XYDREREREQ6SyFWRERERPyOQqyIiIiI+B2FWBERERHxOwqxIiIiIuJ3FGJFRERExO8oxIqIiIiI31GIFRERERG/oxArIiIiIn5HIVZERERE/I7V2w3whoqKChwOh7ebIV0kKSmJ0tJSbzdDupj6NXCpbwOX+jYw9XS/Wq1W4uLiTr1fD7TF5zgcDpqbm73dDOkChmEA7j41TdPLrZGuon4NXOrbwKW+DUy+3K8qJxARERERv6MQKyIiIiJ+RyFWRERERPyOQqyIiIiI+J1z8sKuEzFNk5qaGp8rXJaTq6+vp6mpqcP7QkJCCAkJ6eEWiYiISHdTiD1OTU0NISEhBAcHe7sp0gk2m63D2SZM06S+vp7a2loiIiK80DIRERHpLionOI5pmgqwAcQwDMLDwzUnsIiISABSiJWA1zrHnYiIiAQOhVgRERER8TsKsSIiIuKX6ptdbDtSi8OlC7LPRWd0Ydfy5ctZsmQJlZWVZGRkMGPGDAYNGnTC/ZctW8by5cspKSkhMTGR6667jgkTJnju/+yzz1i0aBGHDx/G6XSSmprKd7/7XS699NKzOq90jby8PKZPn87atWuJjIzs8uP/7ne/Y9myZbz//vun/ZirrrqKe+65h6uuuqrL2yMiIr7NNE0+PmDnb5+XUF7voHdMMD8elUJuqi7iPZd0eiR2/fr1zJ8/n+uuu465c+cyaNAgnnzyScrKyjrcf8WKFbz66qt873vfY968eVx//fW89NJLbNq0ybNPZGQk1113HXPmzOHpp5/m8ssv549//CNffvnlGZ9Xus7cuXO57bbbPAF2wYIFXfrHw913382CBQs69ZjZs2fz5JNP4nK5uqwdIiLi+w5UNvLQyoP8bl0R5fXuC3cPVjXx8AcFPL22kLK69rPVSGDqdIhdunQpEydOZNKkSZ7R0MTERFasWNHh/mvWrOGKK65g3LhxpKSkMH78eCZOnMjixYs9+wwZMoQxY8aQkZFBamoqV111FX369OGrr7464/NK1ygqKuL999/nhhtu6PRjTzR36zdFREQQHx/fqWNPmjQJu93Ohx9+2Ol2iYiI/6ltcvLipiPMfncf20vqCQ4yuHlYIvOv689VA2KxGLD2gJ1Zb+fz5o6jNDtVYhDoOlVO4HA4yM/P59prr22zPTc3l927d3f4mObmZmw2W5ttwcHB5OXl4XA4sFrbNsE0TbZv305RURE333zzGZ+39dzHzx9qGAZhYWGen0/GNE1oajzpPt0mOKRTV9SvXr2a5557jt27d2OxWBg5ciSPPfYYffv2BdxB9PHHH2fNmjU0NjaSk5PDE088wYgRIwD3aPkzzzzD7t27CQ8P58ILL+TFF18E4O2332bw4MH06tULcI+I33///QCkp6cDcP/99/Ozn/2MsWPHcuONN7J//36WLVvG5MmTee6553jiiSd47733KC4uJjk5mWnTpvHTn/7U8+/im+UEs2fPprq6mjFjxvDCCy/Q1NTE1KlTefTRRz2PCQoKYuLEibz11ltMnjz5lK+RZijwL639pX4LPOrbwNVdfWuaJqv3VTH/8xIqG5wAXJQZxe0jU0iOdP9OuHtMGt/qH8efNh7mq9J6/v5lKSvzq7hzdArD07q+DO5c4svv2U6F2Orqal
wuFzExMW22x8TEUFlZ2eFjhg0bxqpVqxgzZgxZWVnk5+ezevVqnE4ndruduLg4AOrq6rjrrrtwOBxYLBZuv/12cnNzz/i8AIsWLeKNN97w3M7KymLu3LkkJSV1uH99fb0nJJmNDTTec/1JX4/uEvLCIozg0NPev7Gxkf/8z/9k8ODB1NbW8tRTT3HHHXewevVq6urqmD59OmlpafzjH/8gOTmZbdu2YbFYsNlsvP/++9xxxx3Mnj2bP/7xjzQ1NbFy5UrP67BhwwYuuOACz+2LLrqIOXPmMHfuXNavXw+4R1JtNhuGYfCnP/3JE2rBvRBBdHQ0zz//PKmpqezatYv777+f6Oho7r33XgAsFguGYXjOYbFYWL9+PampqSxatIh9+/Zx5513kpubyw9+8APP8x45ciR/+MMfPOc5keDgYNLS0k779RTfkZqa6u0mSDdR3wauruzb3UfsPP3BHrYUVgHQJz6cn0/M4cKshHb7pqXBuMF9eXfHYX7/0dcUVjfx6w8KmDggiZ9enkNq9On/XpX2fPE9e0YXdnWUxk+U0KdPn05lZSUPPvggpmkSExPDhAkTWLJkCRbLsWqG0NBQnn76aRoaGti2bRv/+Mc/SElJYciQIWd0XoBp06YxZcqUdvuWlpZ2OAF+U1OTZ+TW7GAFqJ7S3NyMYQk67f2//e1vt7n99NNPk5uby44dO9i0aRNHjx7lnXfe8fzBkJmZ6TnPvHnzmDp1qmd0FWDgwIGe1+HgwYMMHTrUc7t1AQHDMNqUADQ3N2OaJuPHj+fOO+9ss701rAKkpaVx1113sXjxYu6++24AXC4Xpml6ztH6B8vjjz9OUFAQffv2ZeLEiXz00Ud8//vf9xwrOTmZQ4cO4XK5cDqdJ3x9mpqaKC4uPu3XU7zPMAxSU1M5fPiwloEOMOrbwNWVfVvT6OTlLaUs21uBy4RQq8EN5ydyzXkJ2IJO/n/6iAT4w5S+vLq1jKW7y1m1p5S1X5dx/dBEpg2OxxakiZlOxayrhYJ8zANfYxbkE1R0APPq72OMuKhHzm+1Wk844Nhmv84cNDo6GovF0m70s6qqqt0oaavg4GBmzpzJnXfeSVVVFXFxcaxcuZKwsDCioqI8+1ksFk/K79u3L4WFhbz11lsMGTLkjM4L7tG5E43QnfINFhyC5X9fP/k+3SU4pFO779+/n6effprPP/+c8vJyz8VOhYWF7Nixg6FDh3oC7Dft2LHDU7bRkYaGBkJDT/+v19bR8+MtXbqUF198kf3791NbW4vT6TzlLAcDBgwgKOhYkE9JSWHXrl1t9gkNDcXlctHY2NiuLOWb9MvSP5mmqb4LUOrbwHU2fesyTT74uop/fFlKdaN7cOLiPlH8cEQyieE2z/FPJdxm4faRyUzKjubPm46wo6Sel7eU8kF+JXeMTGFUukoMWpnVlXAwH/Pg18e+lx5us48DMPbvxTL8Qq+08UQ6FWKtVivZ2dls3bqVMWPGeLZv3bqV0aNHn/KxCQnu4f9169YxYsSINiOx33T8yNzZnPdMGYYBIf7x0cOMGTPo1asXTz31FKmpqbhcLiZOnEhzc/MpA+ip7o+Pjz9pycY3hYeHt7m9efNmZs6cyc9+9jMuu+wyoqKiWLx4MX/+859PepyO/vj45n9cFRUVhIWFERYW1qb2WURE/M/eo/X8eeMR9hxtACAzJpg7z3LarL5xoTxxRW8+PmDnr5+XUGxv5vEPDzEmI5I7RiaTEnnuLDVvmiaUl0HB15gH8jEL8uHA11B5tOMHJCRD72wsvfsRf8FoKmLal3B4W6fLCaZMmcLzzz9PdnY2AwYMYOXKlZSVlXHllVcC8Morr1BeXs4999wDuC8qysvLIycnh9raWpYuXUpBQQGzZs3yHHPRokX069ePlJQUHA4HX3zxBWvWrOGOO+447fOeq8rLy9m7dy9z585l7NixgLuOtdWgQYN49dVXqaio6HA0dtCgQaxdu/aEsw8MHTqUvXv3tt
kWHBx80o/vj7dx40YyMjL4yU9+4tlWWFh4Wo89ld27d3P++ed3ybFERMQ7qhudvPxlKSvyKjGBMKuFG3MTuXpgHFb
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"init_scale = 0.1 # scale for random parameter initialisation\n",
"learning_rate = 0.05 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size \n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine + softmax model\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, output_dim, param_init, param_init),\n",
" SoftmaxLayer()\n",
"])\n",
"\n",
"# Initialise a cross entropy error object\n",
"error = CrossEntropyError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `learning_rate = 0.1`"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 0.5s to complete\n",
" error(train)=3.11e-01, acc(train)=9.13e-01, error(valid)=2.92e-01, acc(valid)=9.18e-01\n",
"Epoch 10: 0.5s to complete\n",
" error(train)=2.89e-01, acc(train)=9.20e-01, error(valid)=2.77e-01, acc(valid)=9.23e-01\n",
"Epoch 15: 0.6s to complete\n",
" error(train)=2.79e-01, acc(train)=9.22e-01, error(valid)=2.70e-01, acc(valid)=9.24e-01\n",
"Epoch 20: 0.6s to complete\n",
" error(train)=2.72e-01, acc(train)=9.24e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 25: 0.6s to complete\n",
" error(train)=2.68e-01, acc(train)=9.25e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 30: 0.5s to complete\n",
" error(train)=2.63e-01, acc(train)=9.27e-01, error(valid)=2.62e-01, acc(valid)=9.26e-01\n",
"Epoch 35: 0.5s to complete\n",
" error(train)=2.60e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 40: 0.5s to complete\n",
" error(train)=2.59e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 45: 0.6s to complete\n",
" error(train)=2.55e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n",
"Epoch 50: 0.6s to complete\n",
" error(train)=2.54e-01, acc(train)=9.30e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 55: 0.6s to complete\n",
" error(train)=2.52e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 60: 0.6s to complete\n",
" error(train)=2.52e-01, acc(train)=9.29e-01, error(valid)=2.60e-01, acc(valid)=9.29e-01\n",
"Epoch 65: 0.6s to complete\n",
" error(train)=2.50e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 70: 0.5s to complete\n",
" error(train)=2.49e-01, acc(train)=9.31e-01, error(valid)=2.59e-01, acc(valid)=9.31e-01\n",
"Epoch 75: 0.6s to complete\n",
" error(train)=2.47e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 80: 0.6s to complete\n",
" error(train)=2.46e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.31e-01\n",
"Epoch 85: 0.5s to complete\n",
" error(train)=2.45e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.31e-01\n",
"Epoch 90: 0.6s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 95: 0.5s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 100: 0.5s to complete\n",
" error(train)=2.43e-01, acc(train)=9.33e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABxc0lEQVR4nO3deXhU5f3+8feZzGQnC9khCQRCWA2boOCCsmitVMCirahI1apVf7hVK61atYpiq9aq1dbylapFERQF3BCQVgRBUInsa1gTyL5vkzm/PyYZGBIkCUlmMtyv68o1M2fO8px8GLh55jnPMUzTNBERERER8RIWTzdAREREROR4CqgiIiIi4lUUUEVERETEqyigioiIiIhXUUAVEREREa+igCoiIiIiXkUBVURERES8igKqiIiIiHgVBVQRERER8SoKqCIiIiLiVayebkBrKygowG63e7oZ0kpiYmLIycnxdDOkDai2vkl19V2qre9qz9parVYiIyNPvV47tKVd2e12ampqPN0MaQWGYQDOmpqm6eHWSGtSbX2T6uq7VFvf5a211Vf8IiIiIuJVFFBFRERExKsooIqIiIiIV1FAFRERERGv4nMXSYmIiIjnmaZJaWmpV114I42rqKigurq61fYXEBBAQEDAae1DAVVERERaXWlpKQEBAfj7+3u6KXIKNput1WZAMk2TiooKysrKCAkJafF+9BW/iIiItDrTNBVOz0CGYRAcHHzac9IroIqIiIhIq6qfX7WlFFBFRERExKsooLaA6ajF3L4Jx7r/ebopIiIiIj5HAbUltm/C8ZffY77zGqaj1tOtERERER+zatUqLrzwQhwOR5vs/+677+bGG29s8vpVVVUMGzaMjIyMNmnPiRRQW6JXfwjpBCVFsHOLp1sjIiIiPubJJ59k+vTpWCzOqPbss88ybty4Vtv/448/zvPPP9/k9QMCArjtttt48sknW60NP6ZF00x99tlnLFq0iMLCQhITE5k2bRp9+/ZtdN1t27bxn//8h0OHDlFVVUVMTAxjx45l/PjxrnUOHDjAvHnz2Lt3Lzk5Odxwww1cfvnlLTujdmBYrRiDhmN+tRxzw2qM3md5ukkiIiLSzkzTpLa2FqvVPU5VV1e3aAaD+u2++eYb9u7d65aVmqqmpgabzXbK9cLCwpq970mTJvHEE0+wc+dOevXq1eztm6PZPairV69mzpw5XHnllcyaNYu+ffsyc+ZMcnNzG10/ICCASy+9lMcee4znn3+eK6+8knnz5rFs2TLXOlVVVcTFxTFlyhQiIiJafDLtyRh6HgDmt2sw26j7XURExBeYpolZVemZn2beKMA0Tf7+978zYsQIevbsydixY1myZAngzEBdu3Zl5cqVXHbZZaSkpLB27VomT57MH/7wBx599FEGDBjANddcA8CaNWu4/PLLSUlJYfDgwcycOdNt+qWTbbdo0SIuvPBCAgMDAZg3bx7PPfccW7ZsoWvXrnTt2pV58+YB0LVrV9544w1+9atfkZqaygsvvEBtbS333Xcf5557Lj179uSCCy7gX//6l9t5nvgV/+TJk3n44Yd54okn6N+/P4MGDeLZZ59126Zz584MHTqUDz74oFm/05Zodg/qkiVLGD16NGPGjAFg2rRpbNy4kaVLlzJlypQG66ekpJCSkuJ6HRsby7p169i6dStjx44FIDU1ldTUVADmzp3bohNpd30GQlAwFOXDnu2Q2ngPsoiIyBmvugrHnVd75NCWl96FgMAmrz9r1iw++eQTnnrqKVJSUvj666+ZPn06UVFRrnWeeOIJHnnkEZKTk109kfPnz2fq1Kmu8JaVlcX111/P1VdfzQsvvMCuXbu4//77CQgI4L777nPt68TtAL7++msmTpzoen3FFVewfft2Vq5cyTvvvANAp06dXO8/++yzzJgxg0cffRQ/Pz8cDgcJCQm8+uqrdO7cmfXr1/PAAw8QGxvLFVdccdJznz9/Pr
fccguLFy9mw4YN3HPPPQwbNowLL7zQtc7gwYNZu3Ztk3+fLdWsgGq329mzZ4/bLw0gPT2d7du3N2kfe/fuZfv27fzyl79szqEbqKmpcbvrgWEYBAUFuZ63NcPfHzN9OObalZjfrsbSq1+bH/NMU1/H9qintC/V1jeprr7rTKlteXk5r732GvPmzePss88GoFu3bnzzzTe89dZbXHvttQDcf//9bqENoHv37jz00EOu108//TRdunThySefxDAMUlNTyc7OZubMmdxzzz2usaUnbgdw8OBB4uLiXK+DgoIICQnBz8+P2NjYBu2eOHFig1z129/+1vU8OTmZ9evXs3jx4h8NqH379uXee+8FoEePHsyZM8d1sVa9+Ph4Dh48eNJ9HO90/rw0K6AWFxfjcDgIDw93Wx4eHk5hYeGPbnvbbbdRXFxMbW0tV111lasHtqUWLlzIggULXK9TUlKYNWsWMTExp7Xf5igfN568tSuxbFxH/F0P+fwH11Pi4+M93QRpI6qtb1JdfVdzaltRUeEaC2larfCPhW3VrB/nH9Dkf5/37NlDZWWl66v2ejU1NZx11lmusaZDhw51G+dpGAaDBw92W7Z7926GDRvmNhZ1xIgRlJWVkZOTQ2JiYqPbAVRWVhISEuK23GKxYBhGo+NLhwwZ0mD5nDlz+M9//sPBgwepqKigpqaGAQMGuNY7cX+GYdC/f3+3/cTHx5Ofn++2LCQkxK22J+Pv709CQsKPrvNjWnSRVGOFPlXxH3/8cSorK9mxYwdz584lPj6e888/vyWHB5wDdY8fPFx//JycnNO+vVZTmV26Q0AgtUezyPr6S4zubTtg+ExjGAbx8fFkZ2c3ewyReDfV1jeprr6rJbWtrq52v7+7xa+NWncKzcgE1dXVALzxxhsNwri/vz/79u0DGt673jRNAgIC3JY5HA5M03RbVv/cbrdTU1PT6HbgHOuZl5d3yv3VO3EfixYt4pFHHuHhhx/m7LPPJiQkhFdeeYXvvvvOtd7x+7PZbJimicViaXBe9W2tl5eXR1RUVKPtOPF3mZWV1WC51WptUmdiswJqWFgYFoulQW9pUVFRg17VE9V3SScnJ1NUVMT8+fNPK6DabLaTpvd2+4vR5o8xYCjmhq9wrP8KS7fU9jnuGcY0Tf1j56NUW9+kuvouX69tWloaAQEBHDp0iBEjRjR4vz6gNkWvXr34+OOPMU3T1Ym2fv16QkNDT9mz2L9/f3bs2OG2zGazNXlO1HXr1jF06FCmTZvWorb/mG3bttG/f/8mrXs6f1aadRW/1WqlR48eDSZpzcjIoHfv3k3eT30i9wmuq/lX+/SHVkRExNeFhoZy66238uijj/Luu++SmZnJpk2bmDNnDu+++26z9nXDDTdw+PBhHnroIXbt2sVnn33Gs88+yy233OIaf3oyF110Ed98843bsqSkJPbv38+mTZvIz8+nqqrqpNt3796djIwMVq5cye7du3nmmWfYuHFjs9p/MuvWrWPUqFGtsq8f0+xppsaPH8/y5ctZsWIFBw8eZM6cOeTm5romj507dy4vvfSSa/1PP/2U9evXk5WVRVZWFl988QWLFy/mggsucK1jt9vJzMwkMzMTu91Ofn4+mZmZZGdnt8Ipti3jrKFg84ejWXAw09PNERERkdPwwAMPcM899/DSSy9x0UUXMWXKFD7//HOSk5ObtZ+EhATefPNNvv/+e8aNG8eDDz7INddcw1133XXKba+88kp27NjBrl27XMt++tOfctFFF3H11Vdz1lln/ehUT9dffz2XXXYZv/nNb/jZz35GQUEBN9xwQ7Pa35j169dTUlLSLnPVG2YLuv3qJ+ovKCggKSmJG264gX79nFexv/zyy+Tk5PDoo48C8Mknn7Bs2TKOHj2KxWIhPj6eMWPGMHbsWNf/II4ePcqdd97Z4Dj9+vVz7aepcnJyTjkuorXVvjwTvv8aY/wvsEy4tl2P7csMwyAhIYGsrCz1TvsY1d
Y3qa6+qyW1LS4ubtFk8OL0xBNPUFxczDPPPNPmxzpxTO3J3HLLLQwYMIDp06efct2T1d9mszVpDGqLAqo380RAdXz
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAroAAAF0CAYAAADM95pAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAACFXElEQVR4nOzdeXxU1f3/8dedLftKViCBJCRAwi6yaoMstQqWpVTrUuWrKDagpa31V8UFEbe2otVqa6sWawuiYEwMKjGAImBBQAEJAiEgIQkmIZnsyyzn98fA4JggCWaZTD7PxyMPZu6ce++5c5Lwzplzz9GUUgohhBBCCCE8jK6rKyCEEEIIIURHkKArhBBCCCE8kgRdIYQQQgjhkSToCiGEEEIIjyRBVwghhBBCeCQJukIIIYQQwiNJ0BVCCCGEEB5Jgq4QQgghhPBIEnSFEEIIIYRHkqArhBBCCCE8kqGrK+COKioqsFqtXV0N0Y7Cw8MpLS3t6mqIdibt6rmkbT2XtK1n6sx2NRgMhISEtK5sB9elW7JarVgslq6uhmgnmqYBjnZVSnVxbUR7kXb1XNK2nkva1jO5c7vK0AUhhBBCCOGRJOgKIYQQQgiPJEFXCCGEEEJ4JAm6QgghhBDCI8nNaG3U2NhIY2NjV1dDtFF9fT1NTU0tvubl5YWXl1cn10gIIYQQHU2CbhvU1taiaRoBAQHOOwxF92A0GlucSUMpRX19PbW1tfj5+XVBzYQQQgjRUWToQhtYrVZ8fX0l5HoQTdPw9fWVeZOFEEIIDyRBtw0k4HouaVshhBDC80jQFUIIIYQQHkmCrhBCCCGEuGiqsoKqN/+Fstu6uirNSNAV7SIvL48RI0ZQU1PTIcd/+umnmTZtWpv2ufrqq3nvvfc6pD5CCCFET6cqTmN/45/Y/jCfytdeQO3a3tVVakaCrmgXTz31FLfccgv+/v4ArFmzhsGDB7fb8e+8807WrFnTpn0WL17M448/jt1ub7d6CCGEED2dOl2K/b9/x37/7aiN74KlCdPAIWjBIV1dtWYk6IofrKioiA8//JDrrruuzfueb27b7/Lz8yM0NLRNx54yZQrV1dV89NFHba6XEEIIIVyp0lPY//1X7EsWoD56D6xWGJCM7jfLiHj6X2hJQ7q6is1I0P0BlFKoxoau+VKqTXXdvHkzs2bNYvDgwaSkpHDzzTdz/Phx5+tFRUX86le/IiUlhQEDBnDVVVexZ88e5+vZ2dlcddVVxMfHM2TIEObPn+987d133yU5OZnevXsDsH37dn77299SVVVFnz596NOnD08//TQAY8eO5dlnn2Xx4sUMGjSI3//+9wA89thjXHbZZSQkJDB+/Hj++Mc/usx7+92hC4sXL+bWW2/l73//OyNHjiQlJYX777/fZR+9Xs/kyZN555132vReCSGEEOIc9U0R9n/9BfsDd6I+yQabFQYORXfPY+jufQJdyki3nb3oohaM2LBhA5mZmZjNZvr27cu8efO+92PqDz74gA0bNlBSUkJYWBhz5swhNTXV+fqOHTtIT0/n1KlT2Gw2oqKiuOaaa/jRj37kLJOens7OnTspLCzEZDKRlJTETTfd5AxXAC+88AIff/yxy7kTExN57LHHLuYyL6ypEfuiazvm2Beg++ub4OXd6vJ1dXXccccdDBo0iLq6Ov785z8zf/58srOzqa+vZ+7cuURFRfGvf/2L8PBw9u/f7/zIPycnh/nz53P33Xfz3HPP0dTUxMaNG53H3rFjB8OGDXM+Hz16NI888gh//vOf2bJlC4DLYgx///vfWbx4Mb/+9a+d2/z8/HjmmWeIiori4MGD3Hvvvfj7+5OWlnbea9q+fTsRERG89dZbHDt2zBnUb7zxRmeZESNG8Le//a3V75MQQgghHFTxSdR7b6J2bAF1Zhhg8kh0M65DS0zu2sq1UpuD7vbt21m5ciXz589n4MCB5OTk8Pjjj/PMM88QFhbWrHx2dj
arV69mwYIFJCQkkJeXx0svvYSfnx+jR48GwN/fnzlz5tC7d28MBgN79uzhxRdfJDAwkBEjRgCQm5vLlVdeSUJCAjabjTfeeIPly5ezYsUKvL3PBb4RI0a4hCODQRZ/A5g+fbrL86effpphw4Zx+PBhdu3axenTp1m/fj0hIY7xNXFxcc6yzz33HDNnzuSee+5xbktJSXE+LigoYOjQoc7nJpPJuXpcREREs7pMnDiRO++802Xb4sWLnY9jYmI4evQomZmZ3xt0g4KCeOyxx9Dr9QwYMIApU6awdetWl6AbHR1NYWGhjNMVQgghWkkVfo1a/yZq11Y4+wny0NGOgBs/sGsr10ZtToFZWVlMnjyZKVOmADBv3jz27t1LdnY2N9xwQ7PyW7ZsYerUqUyYMAGAyMhIjhw5QkZGhjPofjs0geNu+Y8//pivvvrKGXSXLFniUiYtLY358+eTn59PcvK5vyoMBgPBwcFtvayLY/Jy9Kx2BZNXm4ofP36cP/3pT+zZs4fy8nJn8CssLOTAgQMMGTLEGXK/68CBAy7h8bsaGhpc/ti4kG/3/p6VlZXFyy+/zPHjx6mtrcVmszlvbDufpKQk9Hq983lkZCQHDx50KePt7Y3dbqexsVH+6BFCCCG+hzqRj339Gtjz6bmNI8ahm3EtWr8BXVexH6BN//NbrVby8/OZNWuWy/Zhw4Zx6NChFvexWCwYjUaXbSaTiby8PKxWa7PwoZTiyy+/pKio6HvDVV1dHUCzMJSbm8v8+fPx8/Nj8ODBXH/99QQFBZ23bt8e06lpGj4+Ps7HF6JpWpuGD3SlefPm0bt3b/74xz8SFRWF3W5n8uTJWCyWC4bUC70eGhqK2WxudV18fX1dnu/evZu0tDR+97vfMWnSJAICAsjIyOAf//jH9x7nu99XQLOxyxUVFfj4+ODj4+PS1i1x1/FFomVn20vazfNI23ouaVv3pI4fwZ61BvXFDscGTUMbNcHRgxsT9/07497t2qagW1VVhd1ubxYcg4KCzht0hg8fzqZNmxgzZgxxcXHk5+ezefNmbDYb1dXVzl7Euro6FixYgNVqRafTcdttt7XY8weOMPPaa68xaNAgYmNjndtHjhzJ+PHjCQsLo6SkhDVr1rBs2TKefPLJFkNReno6a9eudT6Pi4vjqaeeIjw8vMXz1tfXt3gcd1deXs6RI0d4+umnGTduHAD/+9//AMcNW0OGDGH16tXU1NS02KubkpLCtm3buOmmm1o8/rBhw8jLy3N5b3x8fLDZbM3eL03T0Ov1Ltv37NlD3759XYZGFBcXA+fCrE6nQ9O08z4/ey3f3ZaXl+f8Pvq+tjOZTERHR5/3deG+oqKiuroKooNI23ouaVv30HhwH1VvvEzD2flvNQ3fH/2YwOtuxdgvoc3Hc8d2vajPcltK7OdL8XPnzsVsNrNkyRKUUgQFBZGamkpmZiY63blJH7y9vfnTn/5EQ0MD+/fv59///jeRkZHNhjUAvPLKK5w4cYJly5a5bD87PAIgNjaWhIQE0tLS2LNnD2PHjm12nNmzZzNjxoxm11BaWorVam1Wvqmp6YK9gu7Iz8+PkJAQVq5cSWhoKIWFhTzxxBMA2Gw2rrnmGp599lluvvlm7rvvPiIiIvjyyy+JjIxk9OjRLF68mOuuu47Y2FhmzpyJ1Wpl8+bNzvGzl19+Ob///e9paGhwDiWIjo6mtraWTZs2kZKS4uxVVUphs9lc3sfY2FgKCwtZu3Ytw4cPZ+PGjaxfvx7AWc5ut6OUOu/zs9fy3W2ffvopl19+ucuxWtLU1OQM16J70DSNqKgoTp061eZZSIR7k7b1XJ3dtqq8FHvmalTeQbSoPhATjxYbj9YvAULC3LIHsjOow19if3cN6uAXjg06HdrYSeim/5ymqL6UAbTh/8TObleDwXDeTslmZdty4MDAQHQ6XbPe28rKyvMODzCZTKSlpXHHHXdQWVlJSEgIOTk5+Pj4EBAQ4Cyn0+mcfw
n079+fwsJC3nnnnWZB99VXX2X37t088sgj9OrV63vrGxISQnh4+HkDjNFoPG8vnyf9ctXpdLz44os89NBDTJkyhfj
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"init_scale = 0.1 # scale for random parameter initialisation\n",
"learning_rate = 0.1 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size \n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine + softmax model\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, output_dim, param_init, param_init),\n",
" SoftmaxLayer()\n",
"])\n",
"\n",
"# Initialise a cross entropy error object\n",
"error = CrossEntropyError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `learning_rate = 0.2`"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 0.5s to complete\n",
" error(train)=2.90e-01, acc(train)=9.19e-01, error(valid)=2.77e-01, acc(valid)=9.22e-01\n",
"Epoch 10: 0.5s to complete\n",
" error(train)=2.75e-01, acc(train)=9.23e-01, error(valid)=2.69e-01, acc(valid)=9.25e-01\n",
"Epoch 15: 0.6s to complete\n",
" error(train)=2.66e-01, acc(train)=9.26e-01, error(valid)=2.64e-01, acc(valid)=9.26e-01\n",
"Epoch 20: 0.5s to complete\n",
" error(train)=2.60e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 25: 0.6s to complete\n",
" error(train)=2.57e-01, acc(train)=9.28e-01, error(valid)=2.64e-01, acc(valid)=9.27e-01\n",
"Epoch 30: 0.6s to complete\n",
" error(train)=2.53e-01, acc(train)=9.29e-01, error(valid)=2.61e-01, acc(valid)=9.30e-01\n",
"Epoch 35: 0.5s to complete\n",
" error(train)=2.50e-01, acc(train)=9.30e-01, error(valid)=2.60e-01, acc(valid)=9.30e-01\n",
"Epoch 40: 0.5s to complete\n",
" error(train)=2.49e-01, acc(train)=9.31e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 45: 0.5s to complete\n",
" error(train)=2.45e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 50: 0.5s to complete\n",
" error(train)=2.45e-01, acc(train)=9.32e-01, error(valid)=2.60e-01, acc(valid)=9.31e-01\n",
"Epoch 55: 0.5s to complete\n",
" error(train)=2.43e-01, acc(train)=9.32e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n",
"Epoch 60: 0.6s to complete\n",
" error(train)=2.44e-01, acc(train)=9.31e-01, error(valid)=2.63e-01, acc(valid)=9.29e-01\n",
"Epoch 65: 0.6s to complete\n",
" error(train)=2.41e-01, acc(train)=9.34e-01, error(valid)=2.60e-01, acc(valid)=9.30e-01\n",
"Epoch 70: 0.6s to complete\n",
" error(train)=2.40e-01, acc(train)=9.34e-01, error(valid)=2.62e-01, acc(valid)=9.29e-01\n",
"Epoch 75: 0.6s to complete\n",
" error(train)=2.38e-01, acc(train)=9.34e-01, error(valid)=2.60e-01, acc(valid)=9.30e-01\n",
"Epoch 80: 0.6s to complete\n",
" error(train)=2.38e-01, acc(train)=9.33e-01, error(valid)=2.62e-01, acc(valid)=9.29e-01\n",
"Epoch 85: 0.5s to complete\n",
" error(train)=2.36e-01, acc(train)=9.35e-01, error(valid)=2.61e-01, acc(valid)=9.30e-01\n",
"Epoch 90: 0.6s to complete\n",
" error(train)=2.36e-01, acc(train)=9.34e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 95: 0.6s to complete\n",
" error(train)=2.37e-01, acc(train)=9.34e-01, error(valid)=2.63e-01, acc(valid)=9.29e-01\n",
"Epoch 100: 0.6s to complete\n",
" error(train)=2.35e-01, acc(train)=9.35e-01, error(valid)=2.63e-01, acc(valid)=9.29e-01\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABv9ElEQVR4nO3deXxTVf7/8ddNk+77RgtlKUtZLbsILiCKG4wsA6i4Meqoow6uODouo34RRUcdf67jDDOMOiiCoogKiKhjAdkUC4Igu0CB7i3d09zfH2kDoUVaSJs0vJ+PRx9tbm5uTvoh6Ztz7jnXME3TRERERETER1i83QARERERkaMpoIqIiIiIT1FAFRERERGfooAqIiIiIj5FAVVEREREfIoCqoiIiIj4FAVUEREREfEpCqgiIiIi4lMUUEVERETEpyigioiIiIhPsXq7AZ6Wn5+P3W73djPEQxISEsjOzvZ2M6QJqLb+SXX1X6qt/2rO2lqtVmJiYk68XzO0pVnZ7Xaqqqq83QzxAMMwAGdNTdP0cmvEk1Rb/6S6+i/V1n/5am01xC8iIiIiPkUBVURERER8igKqiIiIiPgUBVQRERER8Sl+N0lKREREvM80TQ4fPuxTE2+kfmVlZVRWVnrseEFBQQQFBZ3SMRRQRURExOMOHz5MUFAQgYGB3m6KnIDNZvPYCkimaVJWVkZJSQlhYWEnfRwN8YuIiIjHmaapcHoaMgyD0NDQU16TXgFVRERERDyqdn3Vk3VSQ/yLFy9mwYIFFBQUkJKSwuTJk+nevXu9+65atYolS5awa9cu7HY7KSkpTJgwgT59+rj2sdvtfPjhh3z99dfk5eXRunVrrr76ard9REREROT00Oge1BUrVjBr1izGjRvHjBkz6N69O9OnTycnJ6fe/Tdv3kx6ejoPPvggTz/9ND179mTGjBns3LnTtc+7777L559/zu9+9zuef/55RowYwbPPPuu2jy8xTRNz60Ycq//n7aaIiIiI+J1GB9SFCxcyfPhwLrjgAlfvaXx8PEuWLKl3/8mTJzN69Gg6d+5McnIykyZNIjk5mXXr1rn2+eabbxg7diz9+vWjVatWXHTRRfTu3ZuPP/745F9ZU/rxexzP/hnz3X9g2nVZVREREfGsjIwMzjvvPBwOR5Mc/6677uKGG25o8P4VFRUMHDiQzMzMJmnPsRo1xG+329mxYwdjxoxx256ens6WLVsadAyHw0FZWRnh4eGubVVVVXVOpA4MDPzVY1ZVVbnNODMMg5CQENfPTapHH4iKgcJ8yFyL0X9I0z7faaq2jk1eT2l2qq1/Ul39l2rb/J588kmmTJmCxeLsS3zuuedYtGgRn3/+uUeO/8QTTzRqCbCgoCBuvfVWnnzySebMmdOgx5zKv5dGBdSioiIcDgdRUVFu26OioigoKGjQMRYuXEhFRQWDBw92bevduzcLFy6ke/futGrVio0bN7J27dpf/V/D/PnzmTdvnut2amoqM2bMICEhoTEv6aQVjLic4nn/IXDtNySM+m2zPOfpKikpydtNkCai2von1dV/Naa2ZWVl2Gy2JmyN95mmSXV1NVare5yqrKw8qRUMah+3evVqdu7cydixY12/Q4vFgmEYJ/ydVlVVNej3HhcX53a7IY+ZOHEi06ZNY+fOnaSlpf3qvoGBgSQnJ5/wmMdzUpOk6kvEDUnJGRkZzJ07l6lTp7qF3N/97ne8/vrr3HXXXRiGQatWrRg2bBhfffXVcY81duxYRo0aVef5s7OzT3lpg4Yw+5wF8/5D+doV7N+8ESM67sQPkkYxDIOkpCQOHDighZ79jGrrn1RX/3Uyta2srHSNdJqmCZUVTdnE4wsMalRPnmmavPbaa7z11lscOnSI1NRU7rrrLkaNGsWKFSuYMGEC//3vf5kxYwabN2/mv//9Ly+88AJdu3bFZrMxb948unbtyvvvv8/KlSuZNm0amzZtIjo6mgkTJnD//fe7Au
348ePrfdwHH3zAeeedR0BAAFVVVcyZM4e//vWvACQmJgLw/PPPc8UVV9CmTRueeuopvvzyS7755htuvfVW7r77bu6//36WL19OdnY2rVu35vrrr+emm25yvc677rqLoqIi/vWvf2Gz2Rg9ejTdu3cnKCiId955B5vNxrXXXsu9997rekxERAT9+/dn3rx5TJ069Vd/j5WVlWRlZdXZbrVaG9SZ2KiAGhkZicViqdNbWlhYWKdX9VgrVqzg9ddf55577iE9Pb3Oce+//34qKys5fPgwMTEx/Pe//3UVoT42m+24ab9ZPhhbtYHOPWDbJhzLv8By2YSmf87TlGma+mPnp1Rb/6S6+q+Trm1lBY47Jnq+QQ1gefk9CApu8P4zZszgs88+46mnniI1NZVvv/2WKVOmuPU4Tps2jUcffZR27doRGRkJwNy5c7nuuuv48MMPAcjKyuLaa69l4sSJvPjii2zbto2pU6cSFBTkFvqOfRzAt99+63Y65eWXX86WLVv46quvePfddwFnWKz13HPP8eCDD/LYY48REBCAw+EgOTmZ119/ndjYWNauXcv9999PYmIil19++XFf+9y5c7n55pv5+OOPWbduHXfffTcDBw7kvPPOc+3Tt29fVq1a1aDf5al8DjQqoFqtVjp27EhmZiZnnnmma3tmZiYDBw487uMyMjJ47bXXuPPOO+nXr99x9wsMDCQ2Nha73c6qVavcTgPwRcY5F2Ju24S5fCnmpeN1bo6IiEgLVlpayj/+8Q/mzJnDgAEDAGjfvj1r1qzh7bff5uqrrwZg6tSpbqENoEOHDjz88MOu208//TStW7fmySefxDAMOnfuzIEDB5g+fTp3332369zSYx8HsHfvXlq1auW6HRISQlhYGAEBAfV23o0ZM4Yrr7zSbdt9993n+rldu3asXbuWjz/++FcDavfu3bnnnnsA6NixI7NmzXJN1qqVlJTE3r17j3sMT2n0EP+oUaN46aWX6NixI2lpaSxdupScnBxGjBgBwOzZs8nLy+OOO+4AnOH0lVdeYfLkyaSlpbl6XwMDAwkNDQXg559/Ji8vjw4dOpCXl8fcuXMxTZPRo0d76GU2DaP/2Zjv/AMOZcHPmyCtp7ebJCIi4nsCg5w9mV567obaunUr5eXlXHXVVW7bq6qq6NWrl+v2sSPB4JxPc7Rt27bRv39/t86rgQMHUlJSQlZWFm3atKn3cQDl5eWNupZ9fcd48803eeedd9i7dy/l5eVUVVXRs+ev55Rj17RPTEyss4xocHAwZWVlDW7byWp0QB0yZAjFxcW8//775Ofn07ZtWx588EHX+QT5+fluL2bp0qVUV1czc+ZMZs6c6do+dOhQbr/9dsBZ+HfffZdDhw4RHBxM3759ueOOO07pGq7NwQgOwRh4DmbG55gZn2MooIqIiNRhGEajhtm9pXZy9ptvvllnQlhgYCC7d+8GcHWwHa12JaFapmnWGVmtb8j72McBxMbGUlhY2OB2H9ueBQsW8Pjjj/PII48wYMAAwsLCeO211/j+++9/9TjHTvYyDKPOhPWCgoI6E6yawklNkrr44ou5+OKL672vNnTWeuyxx054vB49evDCCy+cTFO8zjj7QmdAXbcc86qbMULq/qMVERER35eWlkZQUBD79u2r9zTD2oDaEF26dOHTTz91C6pr164lPDz8hLPbe/bsydatW9222Wy2Bq+Junr1avr378/kyZNPqu2/5qeffjphT6wnNHqhfjlGp26Q1AYqKzDXZni7NSIiInKSwsPDueWWW3jsscd477332LVrFxs3bmTWrFm8917jTlG4/vrr2b9/Pw8//DDbtm1j8eLFPPfcc9x8882u80+PZ9iwYaxZs8ZtW9u2bdmzZw8bN24kLy+Piorjr4rQoUMHMjMz+eqrr9i+fTvPPPMMP/zwQ6PafzyrV69m6NChHjnWr1FAPUWGYWCcfSEAZoZnFs8VERER77j//vu5++67efnllxk2bB
iTJk3i888/p127do06TnJyMm+99Rbr169nxIgRPPDAA1x11VXceeedJ3zsuHHj2Lp1K9u2bXNtu+yyyxg2bBgTJ07
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAArEAAAF4CAYAAABD1aHMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAACM3UlEQVR4nOzdeXhU1f3H8fedzEz2hOwLCZCEfd9RUEEWccGylLovVBEUkFprF6W2LqiltuCvWq0tKloFrUJMjAohgIZFQTbZAzEsIQvZ93Xmnt8fA4MxbAlJZjJ8X8/DQ3Ln3jtn5ptJPnPm3HM0pZRCCCGEEEKIdsTg6AYIIYQQQgjRVBJihRBCCCFEuyMhVgghhBBCtDsSYoUQQgghRLsjIVYIIYQQQrQ7EmKFEEIIIUS7IyFWCCGEEEK0OxJihRBCCCFEuyMhVgghhBBCtDsSYoUQQgghRLtjbM5Ba9asITExkZKSEqKiopgxYwa9evU67/6rV69mzZo15OXlERwczLRp0xg9erT99q1btxIfH09ubi5Wq5Xw8HBuvfVWrrvuugbnKSoq4v3332f37t3U1dURERHBI488QmxsbHMehhBCCCGEaKeaHGK3bNnCsmXLmDlzJj169CAlJYUXX3yRJUuWEBwc3Gj/5ORkVqxYwezZs4mLiyM9PZ0333wTb29vhg4dCoCPjw/Tpk0jMjISo9HIzp07ef311/Hz82PgwIEAVFRU8PTTT9OnTx+eeuop/Pz8OHXqFF5eXpf3DAghhBBCiHanySE2KSmJsWPHMm7cOABmzJjB999/T3JyMnfddVej/VNTUxk/fjwjR44EICwsjCNHjpCQkGAPsX369GlwzM0338zXX3/NoUOH7CE2ISGBoKAg5syZY98vNDS0qc0XQgghhBAuoEkh1mKxkJGRwZQpUxps79+/P2lpaec8pr6+HpPJ1GCb2WwmPT0di8WC0diwCUop9u3bR3Z2Nnfffbd9+/bt2xkwYACLFy/mwIEDBAYGcsMNNzB+/Pjztre+vp76+nr79waDAQ8Pj0t9uEIIIYQQwkk1KcSWlZWh6zr+/v4Ntvv7+1NSUnLOYwYMGMD69esZPnw4MTExZGRksGHDBqxWK+Xl5QQEBABQVVXF7NmzsVgsGAwGHnzwQfr3728/T15eHmvXruWWW25h6tSppKen884772AymRqMr/2x+Ph4PvnkE/v33bt3Z+HChU15yEIIIYQQwgk168IuTdMuaRvA9OnTKSkpYcGCBSil8Pf3Z/To0SQmJmIwnJ0cwcPDg5dffpmamhr27t3Le++9R1hYmH2oga7rxMXF2YcsxMTEkJmZSXJy8nlD7NSpU5k0aVKjNubn52OxWJrz0IWT0TSN8PBwcnNzUUo5ujmihUhdXZfU1nVJbV2TI+pqNBoJCQm5+H5NOamfnx8Gg6FRr2tpaWmj3tkzzGYzc+bMYdasWZSWlhIQEEBKSgqenp74+vra9zMYDISHhwPQpUsXsrKy+PTTT+0hNiAggKioqAbnjoqKYuvWredtr8lkajSU4Qx5gbkWpZTU1AVJXV2X1NZ1SW1dkzPWtUnzxBqNRmJjY9mzZ0+D7Xv27KFHjx4XPTYoKAiDwcDmzZsZPHhwg57Yn1JKNRjP2qNHD7Kzsxvsk52dfUlJXQghhBBCuJYmL3YwadIk1q1bx/r16zl58iTLli2joKCACRMmALB8+XJee+01+/7Z2dmkpqaSk5NDeno6r7zyCpmZmdx55532feLj49mzZw+nTp0iKyuLpKQkUlNTufbaa+373HLLLRw5coRVq1aRm5vLpk2bWLduHRMnTrycxy+EEEIIIdqhJo+JHTlyJOXl5axcuZLi4mKio6N58skn7T2ixcXFFBQU2PfXdZ2kpCSys7Nxc3OjT58+LFy4sMH0WLW1tSxdupTCwkLMZjMdO3bk0UcftU/LBdC1a1eeeOIJli9fzsqVKwkNDeX+++9vEHSFEE
IIIcSVQVPONsChDeTn5zcYqnCGUoqKigqnG/MhLsxsNlNXV3fO29zd3XF3d2/jFonLpWkaERER5OTkyOvRxUhtXZfU1jU5oq4mk6nlL+xydRUVFbi7u2M2mx3dFNEEJpPpvG9KqqurqaysxNvb2wEtE0IIIURrafKYWFemlJIA60I0TcPLy0umUxNCCCFckIRY4fLON4exEEIIIdovCbFCCCGEEKLdkRArhBBCCCHOSVVXUfbR26jaGkc3pREJseKi0tPTGThwIBUVFa1y/r///e/2eYYv1c0338wXX3zRKu0RQgghrnSqthZ9zSqsf5hJ6Xuvo9Z95ugmNSIhVlzUokWLuP/++/Hx8QHgo48+olevXi12/ocffpiPPvqoScc89thjvPjii+i63mLtEEIIIa50ylKP/tUX6Atmoz5ZBpXlGKO6QGQnRzetEQmx4oKys7NZu3Ytt99+e5OPPd/crT/l7e1NYGBgk849btw4ysvL+eqrr5rcLiGEEEI0pHQr+jcb0J+eg/rgX1BaBEGhGH75K8Jf/xDDwBGObmIjEmLPQymFqq1xzL8mTia8YcMGpkyZQq9evejTpw/33Xcfx44ds9+enZ3NI488Qp8+fejatSs33XQTO3futN+enJzMTTfdRGxsLH379mXmzJn22z777DN69+5NZGQkAFu2bOHxxx+nrKyMjh070rFjR/7+978DMGLECF555RUee+wxevbsyW9/+1sAXnjhBa655hri4uK4+uqr+etf/9pgXtefDid47LHHeOCBB/jXv/7FoEGD6NOnD0899VSDY9zc3Bg7diyffvppk54rIYQQQpyllELt/Ab9mfmot5dAwSnw64B25ywMz7+BYdR4NDfnXFbAOVvlDOpq0efd5pC7Nrz2P3D3uOT9q6qqmDVrFj179qSqqoq//e1vzJw5k+TkZKqrq5k+fTrh4eG88847hISEsHfvXvvH8CkpKcycOZP58+fzj3/8g7q6OtatW2c/99atW+nfv7/9+6FDh/Lss8/yt7/9jdTUVIAGCwn861//4rHHHuNXv/qVfZu3tzdLliwhPDycgwcP8rvf/Q4fHx/mzJlz3se0ZcsWQkND+fjjjzl69Kg9hN999932fQYOHMgbb7xxyc+TEEIIIWyUUnBwN/qq/8LxdNtGL2+0G3+ONnYSWhNyiKNIiHUBt9xyS4Pv//73v9O/f38OHz7M9u3bKSws5PPPPycgIACAmJgY+77/+Mc/mDx5Mk888YR9W58+fexfZ2Zm0q9fP/v3ZrMZX19fNE0jNDS0UVtGjRrFww8/3GDbY489Zv86OjqaH374gcTExAuGWH9/f1544QXc3Nzo2rUr48aNY9OmTQ1CbEREBFlZWTIuVgghhGgClX4Q/dP3IW2vbYO7B9r4n6HdMAXNy8exjWsCCbHnY3a39Yg66L6b4tixY7z88svs3LmToqIie6jLyspi//799O3b1x5gf2r//v0NguFP1dTU4OFx6e/Gftxre0ZSUhJLly7l2LFjVFZWYrVa7ReJnU/37t1xc3Ozfx8WFsbBgwcb7OPh4YGu69TW1mI0yo+yEEIIcSHqRIYtvO7dbttgNKKNuRntpulofh0c2bRmkb/856FpWpM+0nekGTNmEBkZyV//+lfCw8PRdZ2xY8dSX19/0QB6sdsDAwMpKSm55LZ4eXk1+H7Hjh3MmTOH3/zmN4wZMwZfX18SEhL497//fcHzmEymRtt+Ola4uLgYT09PPD09G4yXFUIIIcRZKjcLlbgc9d1G2waDAW3UeLRbbkcLCnFs4y6DhNh2rqioiCNHjrBo0SJGjLBdObht2zb77b169WLFihUUFxefsze2V69ebNq06byzD/Tt25cjR4402GY2m7FarZfUvu+++46oqKgGY2SzsrIu6diLSUtLazDUQQghxJVFHdiN2vo1BAajdYqDznEQECzLjZ+mCvNRSR+itqyD05/SasOuRfvZXWjhHR3cussnIbad69
ChAwEBAbz//vuEhoaSlZXFSy+9ZL99ypQpvPrqqzz44IM8+eSThIaGsm/fPsLCwhg6dCiPP/44t99+O507d2by5Ml
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"init_scale = 0.1 # scale for random parameter initialisation\n",
"learning_rate = 0.2 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size \n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine + softmax model\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, output_dim, param_init, param_init),\n",
" SoftmaxLayer()\n",
"])\n",
"\n",
"# Initialise a cross entropy error object\n",
"error = CrossEntropyError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `learning_rate = 0.5`"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 0.5s to complete\n",
" error(train)=2.79e-01, acc(train)=9.20e-01, error(valid)=2.74e-01, acc(valid)=9.23e-01\n",
"Epoch 10: 0.5s to complete\n",
" error(train)=2.68e-01, acc(train)=9.24e-01, error(valid)=2.72e-01, acc(valid)=9.26e-01\n",
"Epoch 15: 0.6s to complete\n",
" error(train)=2.55e-01, acc(train)=9.28e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 20: 0.5s to complete\n",
" error(train)=2.49e-01, acc(train)=9.31e-01, error(valid)=2.61e-01, acc(valid)=9.29e-01\n",
"Epoch 25: 0.6s to complete\n",
" error(train)=2.52e-01, acc(train)=9.29e-01, error(valid)=2.73e-01, acc(valid)=9.25e-01\n",
"Epoch 30: 0.6s to complete\n",
" error(train)=2.47e-01, acc(train)=9.31e-01, error(valid)=2.70e-01, acc(valid)=9.27e-01\n",
"Epoch 35: 0.6s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.69e-01, acc(valid)=9.27e-01\n",
"Epoch 40: 0.6s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.72e-01, acc(valid)=9.26e-01\n",
"Epoch 45: 0.6s to complete\n",
" error(train)=2.36e-01, acc(train)=9.35e-01, error(valid)=2.66e-01, acc(valid)=9.29e-01\n",
"Epoch 50: 0.5s to complete\n",
" error(train)=2.38e-01, acc(train)=9.33e-01, error(valid)=2.69e-01, acc(valid)=9.28e-01\n",
"Epoch 55: 0.6s to complete\n",
" error(train)=2.36e-01, acc(train)=9.34e-01, error(valid)=2.68e-01, acc(valid)=9.26e-01\n",
"Epoch 60: 0.6s to complete\n",
" error(train)=2.46e-01, acc(train)=9.29e-01, error(valid)=2.81e-01, acc(valid)=9.22e-01\n",
"Epoch 65: 0.6s to complete\n",
" error(train)=2.33e-01, acc(train)=9.35e-01, error(valid)=2.70e-01, acc(valid)=9.28e-01\n",
"Epoch 70: 0.5s to complete\n",
" error(train)=2.35e-01, acc(train)=9.36e-01, error(valid)=2.75e-01, acc(valid)=9.27e-01\n",
"Epoch 75: 0.5s to complete\n",
" error(train)=2.31e-01, acc(train)=9.36e-01, error(valid)=2.70e-01, acc(valid)=9.26e-01\n",
"Epoch 80: 0.6s to complete\n",
" error(train)=2.35e-01, acc(train)=9.34e-01, error(valid)=2.76e-01, acc(valid)=9.25e-01\n",
"Epoch 85: 0.6s to complete\n",
" error(train)=2.32e-01, acc(train)=9.35e-01, error(valid)=2.75e-01, acc(valid)=9.26e-01\n",
"Epoch 90: 0.5s to complete\n",
" error(train)=2.29e-01, acc(train)=9.37e-01, error(valid)=2.74e-01, acc(valid)=9.26e-01\n",
"Epoch 95: 0.6s to complete\n",
" error(train)=2.31e-01, acc(train)=9.35e-01, error(valid)=2.76e-01, acc(valid)=9.27e-01\n",
"Epoch 100: 0.5s to complete\n",
" error(train)=2.31e-01, acc(train)=9.36e-01, error(valid)=2.77e-01, acc(valid)=9.27e-01\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAACL1klEQVR4nOzdd3hUVfrA8e+5mfRKEiAhBUKvoSMgUgWlKEWwYEPXXrCii4ptAcW6rv23oqy6ropKEUECBJQmVUCkhl4SSCC9z9zz+2NgNCZAEpLMZPJ+noeHzJ1b3snJTN6ce857lNZaI4QQQgghhIswnB2AEEIIIYQQfyYJqhBCCCGEcCmSoAohhBBCCJciCaoQQgghhHApkqAKIYQQQgiXIgmqEEIIIYRwKZKgCiGEEEIIlyIJqhBCCCGEcCmSoAohhBBCCJciCaoQQgghhHApFmcHUNXS09OxWq3ODkNUkfr165OamursMEQ1kLZ1T9Ku7kva1n3VZNtaLBbq1at34f1qIJYaZbVaKS4udnYYogoopQB7m2qtnRyNqErStu5J2tV9Sdu6L1dtW7nFL4QQQgghXIokqEIIIYQQwqVIgiqEEEIIIVyKJKhCCCGEEMKlSIIqhBBCCCFciiSoQgghhBDCpUiCKoQQQgghXIokqEIIIYQQwqVIgiqEEEIIIVyK260kJYQQwjmyCq1kn8wm0NmBCCFqPelBFUIIcdG01ryQeISb/rOBrcm5zg5HCFHLSYIqhBDiou04mc/eUwVo4L/bUl1qTW8hRO0jCaoQQoiL9sOedMfXu1Lz2ZqS58RohBC1XaXGoC5evJj58+eTkZFBdHQ0EyZMoE2bNmXuu27dOhISEjh48CBWq5Xo6GjGjRtHp06dSuz3ww8/kJCQQFpaGkFBQVxyySWMHz8eLy+vyoQohBCihpzKK2btkWwAejYJ5ZeDp/nytzQ6RvihlHJydEKI2qjCPahr1qxh1qxZjBkzhhkzZtCmTRumT59OWlpamfvv3LmT+Ph4Jk+ezMsvv0y7du2YMWMGBw4ccOyzcuVKvvjiC8aNG8ebb77JPffcw9q1a/niiy8q/8qEEELUiB/3ZmBqaNfAj2eHtsHTUOyUXlQhxEWocIK6YMECBg4cyKBBgxy9p+Hh4SQkJJS5/4QJExg5ciTNmzcnMjKS8ePHExkZyaZNmxz77Nmzh1atWtGnTx8aNGhAx44dufTSS9m/f3/lX5kQQohqV2wzWZyUAcDwVvWoH+DNlS1CAPjytzQZiyqEqJQK3eK3Wq3s37+fUaNGldgeHx/P7t27y3UO0zTJz88nICDAsa1169asXLmSpKQkmjdvzokTJ/j111/p16/fOc9TXFxMcXGx47FSCl9fX8fXovY7247Snu5H2tZ9rD2SQ2aBjTA/C71igwC4pn04P+7NYGdqPr+dyKdjpL+ToxQXS96z7stV27ZCCWpWVhamaRIcHFxie3BwMBkZGeU6x4IFCygsLKRXr16ObZdeeilZWVlMmTIFAJvNxpAhQ0olwn82Z84cvvnmG8fjuLg4ZsyYQf369cv/gkStEBER4ewQRDWRtq39EhKPATCuSwzRjSIBaNc0hjGd8vlq81G+3ZXJFZ2budwvP1E58p51X67WtpWaJFXWB015PnxWrVrF7NmzmTRpUokk9/fff+e7777jjjvuoEWLFqSkpPDJJ58QEhLC2LFjyzzX6NGjGTFiRKnrp6amYrVaK/qShAtSShEREUFKSorcJnQz0rbuIelUPr8dz8JiQK+GFlJSUhztemUTH77bothyLJPFv+6TXtRaTt6z7qum29ZisZSrM7FCCWpQUBCGYZTqLc3MzCzVq/pXa9as4YMPPuDRRx8lPj6+xHNfffUVffv2ZdCgQQDExsZSUFDA//3f/zFmzBgMo/RQWU9PTzw9Pcu8lrx53IvWWtrUTUnb1m4LdttLS10aG0SIj4ejLbXWhPpaGNIihB92p/O/ba
l0aOgrvahuQN6z7svV2rZCk6QsFgtNmzZl27ZtJbZv27aNVq1anfO4VatW8e677zJx4kS6dOlS6vnCwsJSH1yGYbjUN0oIIcQfsgqsrDyYBdgnR5XlmrahWAzFjtR8fjshM/qFEOVX4Vn8I0aMYNmyZSQmJnL06FFmzZpFWloagwcPBuCLL77gnXfecex/Njm95ZZbaNmyJRkZGWRkZJCX98eHVdeuXVmyZAmrV6/m5MmTbNu2ja+++opu3bqV2XsqhBDCuZbsy6TY1DQL9aFlmE+Z+4T5eXJFc/vdtf9tkxn9Qojyq/AY1N69e5Odnc23335Leno6MTExTJ482TGeID09vURN1KVLl2Kz2Zg5cyYzZ850bO/Xrx/3338/ANdccw1KKb788ktOnz5NUFAQXbt25YYbbrjY11ctdH4eeslcsFoxxtzi7HCEEKJG2UzNojMrRw1vGXLeW/dj2oWxOCnT0YsaHyFjUYUQF6a0m/1Jm5qaWqL8VHXQO7divjEFPCwYU99HhTes1uvVVUopIiMjSU5Olp4XNyNtW7utO5LN9J+PEejtwcxRzfC22O90natdP9yQwsI9GbRr4Mv0wY2dFba4CPKedV813baenp7lmiQl988rQbXpCG06gs2Kni+rXQkh6pYfzvSeDm4W7EhOz+eadmFYDMXvJ/P57URudYcnhHADkqBWkjHafmtf/7ICfeyQk6MRQoiacSSzkK0peRgKhrYoe3LUX4X7eTK4mX0s6pfbyl4WWwgh/kwS1EpScS2ga2/QGnPOZ84ORwghasTZsafdowJoEFB2qb+ynO1F3S69qEKIcpAE9SIYo24Cw4Ct69FJO50djhBCVKu8YhvL9p+/tNS51Pf/Uy/qb6eqPDYhhHuRBPUiqIho1KWXA2DO+VQGjgsh3Nry/VkUWE2ig7yIb+hX4ePtvaiw/USe9KIKIc5LEtSLpEZcDxZP2PM7/L7Z2eEIIUS10Fo7JkcNa1mvUqtC2XtRQwDpRRVCnJ8kqBdJhYajBg4HwPzuU7RpOjkiIYSoeltT8jiWVYSvxWBA06BKn+fPvajbZXUpIZzmRE4Ri/ak84/lR0jYdcLZ4ZRS4UL9ojQ1dCx6ZQIcOYDeuArVo6+zQxJCiCq18Ezv6cCmQfh5elT6PPX9Pbm8WQg/7s3gy9/SmNowtqpCFEKcR7HNZEdqPpuO5bDpeC5Hs4ocz4UEptKhW5gToytNEtQqoAKCUENGo+f9Fz33c3SX3iiLfGuFEO7hRE4RG47lAPbb+xdrbLswlu7L4LcTefx+Io92lRjPKoS4sNTcYjYdz2Hz8Vy2puRSYP1jroyhoHW4L12jAhjWqQlYs50XaBkki6oi6vKr0YkLIDUFvWoJqv9QZ4ckhBBV4se9GZgaOkb4ER3sfdHnq+/vyaCmISxOsvei/kN6UYWoEsU2zc7UPDYdz2Xz8RwOZxaVeL6ejwedGwXQtZE/nSL8CfD2sK8kVT+A5GRJUN2S8vFFjbgO/b//Qy/4Ct1rIMr74j/IhRDCmQqtJkv2ZQIwvAp6T88a2y6MZfsz2HYij99P5tGugfSiClEZaXnFbD6ey6bjOWxJzqPA+sdcGENByzBfukb507VRAHH1vDEqMcHRGSRBrUKq7xXohLlw6iQ6cQFq6DXODkkIIS7KykNZZBfaaOBvoVtUQJWdt0HAX3pRB0kvqhDlYTU1u1Lz2XTcPpb0UEZhieeDfTzoEulPl0YBdI70J9C78mPGnUkS1CqkLJ6oUTeiZ76J/vEbdN8rUP5V94EuhBA1SWvND7vtk6OGtqiHh1G1PS+OXtQU6UUV4nxOOXpJ7WNJ84r/6CVVQMtwH7o2CqBLI3+ahfrUml7S85EEtYqpHn3RP34Hxw6hF3+LGnOrs0MSQohK2Z1WwP70Qrw8FJc3D6ny80svqhBls5maXWn2Gfebk3M5kF6ylzTI+2wvqT+dI/0J8nG/dM79XpGTKcMDY/TNmO9MRS/7Hj1wBC
rEtUo3CCFEeZwtzH9Z4yCCquk24dkZ/dtS8thxMo+20ovqknaczGN/fhr+tiLC/Sxu0UPnajLyrWw8M+N+S3IuuX/
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAArEAAAF0CAYAAAA0F2G3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAACpVklEQVR4nOzdd3hUVfrA8e+5M5n03iEBEjqBgDQ1SBFQpLiUxbY2VhGUtWDdn7quDbuCq65tUVkL6loQjAohgoai9N4DBEIK6b3O3PP7Y5KBSALpM0nO53l4SGZueZObTN459z3vEVJKiaIoiqIoiqK0IZq9A1AURVEURVGUhlJJrKIoiqIoitLmqCRWURRFURRFaXNUEqsoiqIoiqK0OSqJVRRFURRFUdoclcQqiqIoiqIobY5KYhVFURRFUZQ2RyWxiqIoiqIoSpujklhFURRFURSlzVFJrKIoiqIoitLmGO0dgD3k5uZiNpvtHYbSTAIDA8nMzLR3GEozU9e1/VLXtv1S17Z9au3rajQa8fX1vfB2rRCLwzGbzVRWVto7DKUZCCEA6zWVUto5GqW5qOvafqlr236pa9s+OfJ1VeUEiqIoiqIoSpujklhFURRFURSlzVFJrKIoiqIoitLmqCRWURRFURRFaXM65MSuukgpKSoqcrjCZeX8SktLqaioqPU5Z2dnnJ2dWzkiRVEURVFamkpiz1JUVISzszMmk8neoSgN4OTkVGu3CSklpaWlFBcX4+7ubofIFEVRFEVpKaqc4CxSSpXAtiNCCNzc3FRPYEVRFEVph1QSq7R71T3uFEVRFEVpP1QSqyiKoiiKorQ5KolVFEVRFEVpAJmXjUw7Ze8wOjyVxCoXlJiYyKBBgygqKmqR47/22mtcccUVDdpn0qRJ/Pjjjy0Sj6IoiqLURZ46jv7Pv6E/+Tf02C+Qum7vkDoslcQqF/TSSy9x66234uHhAcCXX35J3759m+34d955J19++WWD9pk/fz7PP/88unrxUBRFUVqJzM5Af/1pKC0BKZHLl6K//TyypNjeoXVIKolVzis1NZXVq1dz3XXXNXjfunq3/pG7uzt+fn4NOva4ceMoLCzkl19+aXBciqIoitJQsqgA/fWnID8HOnVB3DAHjE6wazP6cw8iU07aO8QORyWxdZBSIsvL7POvgYstrF27lmnTptG3b1+ioqK45ZZbSEpKsj2fmprKXXfdRVRUFD169GDixIls377d9nxcXBwTJ04kMjKS/v37M3v2bNtz33//Pf369aNTp04AbNy4kQceeICCggI6d+5M586dee211wC4+OKLef3115k/fz59+vTh4YcfBuC5557jsssuo3v37lx66aW8/PLLNfq6/rGcYP78+dx22228++67XHTRRURFRfHYY4/V2MdgMDB27Fi+++67Bn2vFEVRFKWhZHk5+lsLIP0U+Aag3fcU2tgpaI+8CH4BkJGK/sJDyG0b7B1qh6IWO6hLRTn63dfa5dTaW/8DZ5d6b19SUsKcOXPo06cPJSUlvPrqq8yePZu4uDhKS0uZOXMmISEhfPTRRwQGBrJnzx7bbfj4+Hhmz57NvffeyxtvvEFFRQU///yz7dibNm0iOjra9vnQoUN5+umnefXVV0lISACosZDAu+++y/z587nvvvtsj7m7u7No0SJCQkI4cOAAjzzyCB4eHsybN6/Or2njxo0EBQXx1Vdfcfz4cVsSfuONN9q2GTRoEO+88069v0+KoiiK0lDSYkH/zytw9CC4uaPd9xTCLwAAEdET7R+L0N9/BQ7uRn/3JcSEGYjpNyMMBjtH3v6pJLYdmDx5co3PX3vtNaKjozl8+DBbt24lOzubH374AV9fXwAiIiJs277xxhtMnTqVhx56yPZYVFSU7ePk5GQGDBhg+9xkMuHp6YkQgqCgoHNiGTFiBHfeeWeNx+bPn2/7ODw8nKNHj7JixYrzJrHe3t4899xzGAwGevTowbhx41
i/fn2NJDY0NJSUlBRVF6soinIWaTaD1BFOavGeppJSIpe+C7s2g9EJ7e4nEJ271NhGeHqjzX8a+e3HyLhlyFXfIk8eRbvjYYSnl50i7xgalcSuWrWKFStWkJeXR1hYGLNmzTrvRJ+VK1eyatUqMjIyCAgIYMaMGYwePdr2/KZNm1i2bBnp6elYLBZCQkK4+uqrGTVqVI3j5OTk8Omnn7Jz504qKioIDQ3lrrvuIjIysjFfxvmZnK0jovZgcm7Q5klJSbzyyits376dnJwcW1KXkpLCvn376N+/vy2B/aN9+/bVSAz/qKysDBeX+o8Knz1qWy02NpbFixeTlJREcXExFovFNkmsLr169cJw1rvY4OBgDhw4UGMbFxcXdF2nvLwco1G9H1MUpWOTKSeQ6+KQv60FoxHtsVcR/ucONij1J7//ApmwCoSGdsdDiJ79at1OGAyIa/6K3q0n8r9vwIFd6AvuR5v3KKJrj1aOuuNo8F/+jRs3smTJEmbPnk3v3r2Jj4/n+eefZ9GiRQQEBJyzfVxcHJ9//jlz586le/fuJCYm8t577+Hu7s7QoUMB8PDwYMaMGXTq1Amj0cj27dt5++238fLyYtCgQQAUFRXxxBNP2Oojvby8OH36NG5ubk37DtRBCNGgW/r2NGvWLDp16sTLL79MSEgIuq4zduxYKisrL5iAXuh5Pz8/8vLy6h3LH6/Htm3bmDdvHg8++CBjxozB09OT5cuX8/7775/3OE5OTuc89sda4dzcXFxdXXF1da1RL6soitJRyPIy5Nb11kTr2KEaz+kfLkJ7cAFCU7e1G0NPWIX8/nMAxF/mIAZfesF9tGGXITuFo7/9PGSkob/4d8TN89BixrV0uC2qoXN1WkuDJ3bFxsYyduxYxo0bZxuFDQgIIC4urtbtExISGD9+PDExMQQHBzNixAjGjh3L8uXLbdtERUUxfPhwwsLCCAkJYdKkSXTt2pWDBw/atlm+fDn+/v7MmzePHj16EBQUxIABAwgJCWnEl91+5OTkcOTIEe677z5GjhxJz549yc/Ptz3ft29f9u3bR25ubq379+3bl/Xr19d5/P79+3PkyJEaj5lMJiwWS73i27JlC2FhYdx3330MHDiQyMhIUlJS6rXvhRw6dKhGqYOiKEpHIU8cRf/0bfSHZyGXvGFNYA0GGHwp4vb7wdkVDu9Dxn1n71DbJLlzE/JT65wLMelatDGT6r2v6NwV7fHXIHoYmCuRH/0L/bN3kea2N9giLRb0jWtIn/tnZMoJe4dzjgaNxJrNZo4dO8a0adNqPB4dHc2hQ4dq3aeysvKcUTWTyURiYiJms/mc28BSSvbu3UtqamqN29xbt25l4MCBLFy4kP379+Pn58eVV17J+PHj64y3srKyxgidEAJXV1fbx+2Bj48Pvr6+fPrppwQFBZGSksILL7xge37atGm8+eab3H777Tz66KMEBQWxd+9egoODGTp0KA888ADXXXcdXbt2ZerUqZjNZtauXWurVx09ejQPP/wwFovFdns/LCyM4uJi1q1bR1RUlG00tDYRERGkpKSwfPlyBg4cyM8//8xPP/3ULF/75s2bzyk5qUt7ud4dRfX1Utet/VHXtvFkaQly06/oCavg5NEzTwSFoo28EhEzDuFtLR3TzWb0/76J/O4ziBqM6NICZXd/0F6urTx60DqRS+qIy65Am35Tg78m4e6JuPsfyB++RF/xOfKXH5HJxzDc9X8IH/8Wirz5SIvF+rMW+4V1RBkQq5djmHWvvUOroUFJbEFBAbqu4+3tXeNxb2/vOm85Dxw4kDVr1jB8+HAiIiI4duwYa9euxWKxUFhYaKvVLCkpYe7cuZjNZjRN4/bbb69RX5mRkcHq1auZPHky06dPJzExkY8++ggnJ6ca9bVnW7ZsGV9//bXt84iICF566SUCAwNr3b60tLTW29iO7v333+fxxx9n3LhxdO/eneeff55p06ZhMBhwd3fnq6++4sknn+Tmm2
/GYrHQq1cvXnzxRdv3bvHixSxcuJB///vfeHp6cskll9i+D1dddRWPPfYYGzduZOzYsQDExMRw6623Mm/ePHJycnj
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"init_scale = 0.1 # scale for random parameter initialisation\n",
"learning_rate = 0.5 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size \n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine + softmax model\n",
"model = MultipleLayerModel([\n",
" AffineLayer(input_dim, output_dim, param_init, param_init),\n",
" SoftmaxLayer()\n",
"])\n",
"\n",
"# Initialise a cross entropy error object\n",
"error = CrossEntropyError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"|`learning_rate`| Final `error(train)` | Final `error(valid)` |\n",
"|---------------|----------------------|----------------------|\n",
"| 0.05 | $2.53\\times 10^{-1}$ | $2.59\\times 10^{-1}$|\n",
"| 0.1 | $2.43\\times 10^{-1}$ | $2.59\\times 10^{-1}$|\n",
"| 0.2 | $2.35\\times 10^{-1}$ | $2.63\\times 10^{-1}$|\n",
"| 0.5 | $2.31\\times 10^{-1}$ | $2.77\\times 10^{-1}$|\n",
"\n",
"<span style=\"color:green\">\n",
"Increasing the learning rate, as would be expected, increase the speed of learning, with the final training error reached monotonically decreasing over the learning rates tested as the learning rate was increased. Note however the validation set error increases for larger learning rates - this suggests the model is overfitting to the data, with the larger learning rates causing the model to begin overfitting sooner - we could have afforded to halt learning earlier in these cases when there was no further improvement in the validation set error. Notice also the error curves for the largest learning rate value are much more noisy suggesting learning is becoming quite unstable with this large a step size, with a lot of the gradient descent steps overshooting and causing the error function value to increase.\n",
"</span>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Optional extra: more efficient softmax gradient evaluation\n",
"\n",
"In the lectures you were shown that for certain combinations of error function and final output layers, that the expressions for the gradients take particularly simple forms. \n",
"\n",
"In particular it can be shown that the combinations of \n",
"\n",
" * logistic sigmoid output layer and binary cross entropy error function\n",
" * softmax output layer and cross entropy error function\n",
" \n",
"lead to particularly simple forms for the gradients of the error function with respect to the inputs to the final layer. In particular for the latter softmax and cross entropy error function case we have that\n",
"\n",
"\\begin{equation}\n",
" y^{(b)}_k = \\textrm{Softmax}_k\\lpa\\vct{x}^{(b)}\\rpa = \\frac{\\exp(x^{(b)}_k)}{\\sum_{d=1}^D \\lbr \\exp(x^{(b)}_d) \\rbr}\n",
" \\qquad\n",
" E^{(b)} = \\textrm{CrossEntropy}\\lpa\\vct{y}^{(b)},\\,\\vct{t}^{(b)}\\rpa = -\\sum_{d=1}^D \\lbr t^{(b)}_d \\log(y^{(b)}_d) \\rbr\n",
"\\end{equation}\n",
"\n",
"and it can be shown (this is an instructive mathematical exercise if you want a challenge!) that\n",
"\n",
"\\begin{equation}\n",
" \\pd{E^{(b)}}{x^{(b)}_d} = y^{(b)}_d - t^{(b)}_d.\n",
"\\end{equation}\n",
"\n",
"\n",
"The combination of `CrossEntropyError` and `SoftmaxLayer` used to train the model above calculate this gradient less directly by first calculating the gradient of the error with respect to the model outputs in `CrossEntropyError.grad` and then back-propagating this gradient to the inputs of the softmax layer using `SoftmaxLayer.bprop`.\n",
"\n",
"Rather than computing the gradient in two steps like this we can instead wrap the softmax transformation in to the definition of the error function and make use of the simpler gradient expression above. More explicitly we define an error function as follows\n",
"\n",
"\\begin{equation}\n",
" E^{(b)} = \\textrm{CrossEntropySoftmax}\\lpa\\vct{y}^{(b)},\\,\\vct{t}^{(b)}\\rpa = -\\sum_{d=1}^D \\lbr t^{(b)}_d \\log\\lsb\\textrm{Softmax}_d\\lpa \\vct{y}^{(b)}\\rpa\\rsb\\rbr\n",
"\\end{equation}\n",
"\n",
"with corresponding gradient\n",
"\n",
"\\begin{equation}\n",
" \\pd{E^{(b)}}{y^{(b)}_d} = \\textrm{Softmax}_d\\lpa \\vct{y}^{(b)}\\rpa - t^{(b)}_d.\n",
"\\end{equation}\n",
"\n",
"The final layer of the model will then be an affine transformation which produces unbounded output values corresponding to the logarithms of the unnormalised predicted class probabilities. An implementation of this error function is provided in `CrossEntropySoftmaxError`. The cell below sets up a model with a single affine transformation layer and trains it on MNIST using this new cost. If you run it with equivalent hyperparameters to one of your runs with the alternative formulation above you should get identical error and classification curves (other than floating point error) but with a minor improvement in training speed.\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 0.4s to complete\n",
" error(train)=3.11e-01, acc(train)=9.13e-01, error(valid)=2.92e-01, acc(valid)=9.18e-01\n",
"Epoch 10: 0.4s to complete\n",
" error(train)=2.89e-01, acc(train)=9.20e-01, error(valid)=2.77e-01, acc(valid)=9.23e-01\n",
"Epoch 15: 0.4s to complete\n",
" error(train)=2.79e-01, acc(train)=9.22e-01, error(valid)=2.70e-01, acc(valid)=9.24e-01\n",
"Epoch 20: 0.4s to complete\n",
" error(train)=2.72e-01, acc(train)=9.24e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 25: 0.4s to complete\n",
" error(train)=2.68e-01, acc(train)=9.25e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n",
"Epoch 30: 0.4s to complete\n",
" error(train)=2.63e-01, acc(train)=9.27e-01, error(valid)=2.62e-01, acc(valid)=9.26e-01\n",
"Epoch 35: 0.4s to complete\n",
" error(train)=2.60e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 40: 0.4s to complete\n",
" error(train)=2.59e-01, acc(train)=9.28e-01, error(valid)=2.61e-01, acc(valid)=9.28e-01\n",
"Epoch 45: 0.4s to complete\n",
" error(train)=2.55e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n",
"Epoch 50: 0.4s to complete\n",
" error(train)=2.54e-01, acc(train)=9.30e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 55: 0.4s to complete\n",
" error(train)=2.52e-01, acc(train)=9.29e-01, error(valid)=2.59e-01, acc(valid)=9.30e-01\n",
"Epoch 60: 0.4s to complete\n",
" error(train)=2.52e-01, acc(train)=9.29e-01, error(valid)=2.60e-01, acc(valid)=9.29e-01\n",
"Epoch 65: 0.4s to complete\n",
" error(train)=2.50e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 70: 0.4s to complete\n",
" error(train)=2.49e-01, acc(train)=9.31e-01, error(valid)=2.59e-01, acc(valid)=9.31e-01\n",
"Epoch 75: 0.4s to complete\n",
" error(train)=2.47e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 80: 0.4s to complete\n",
" error(train)=2.46e-01, acc(train)=9.31e-01, error(valid)=2.58e-01, acc(valid)=9.31e-01\n",
"Epoch 85: 0.4s to complete\n",
" error(train)=2.45e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.31e-01\n",
"Epoch 90: 0.4s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 95: 0.4s to complete\n",
" error(train)=2.44e-01, acc(train)=9.32e-01, error(valid)=2.58e-01, acc(valid)=9.30e-01\n",
"Epoch 100: 0.4s to complete\n",
" error(train)=2.43e-01, acc(train)=9.33e-01, error(valid)=2.59e-01, acc(valid)=9.29e-01\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABxc0lEQVR4nO3deXhU5f3+8feZzGQnC9khCQRCWA2boOCCsmitVMCirahI1apVf7hVK61atYpiq9aq1dbylapFERQF3BCQVgRBUInsa1gTyL5vkzm/PyYZGBIkCUlmMtyv68o1M2fO8px8GLh55jnPMUzTNBERERER8RIWTzdAREREROR4CqgiIiIi4lUUUEVERETEqyigioiIiIhXUUAVEREREa+igCoiIiIiXkUBVURERES8igKqiIiIiHgVBVQRERER8SoKqCIiIiLiVayebkBrKygowG63e7oZ0kpiYmLIycnxdDOkDai2vkl19V2qre9qz9parVYiIyNPvV47tKVd2e12ampqPN0MaQWGYQDOmpqm6eHWSGtSbX2T6uq7VFvf5a211Vf8IiIiIuJVFFBFRERExKsooIqIiIiIV1FAFRERERGv4nMXSYmIiIjnmaZJaWmpV114I42rqKigurq61fYXEBBAQEDAae1DAVVERERaXWlpKQEBAfj7+3u6KXIKNput1WZAMk2TiooKysrKCAkJafF+9BW/iIiItDrTNBVOz0CGYRAcHHzac9IroIqIiIhIq6qfX7WlFFBFRERExKsooLaA6ajF3L4Jx7r/ebopIiIiIj5HAbUltm/C8ZffY77zGqaj1tOtERERER+zatUqLrzwQhwOR5vs/+677+bGG29s8vpVVVUMGzaMjIyMNmnPiRRQW6JXfwjpBCVFsHOLp1sjIiIiPubJJ59k+vTpWCzOqPbss88ybty4Vtv/448/zvPPP9/k9QMCArjtttt48sknW60NP6ZF00x99tlnLFq0iMLCQhITE5k2bRp9+/ZtdN1t27bxn//8h0OHDlFVVUVMTAxjx45l/PjxrnUOHDjAvHnz2Lt3Lzk5Odxwww1cfvnlLTujdmBYrRiDhmN+tRxzw2qM3md5ukkiIiLSzkzTpLa2FqvVPU5VV1e3aAaD+u2++eYb9u7d65aVmqqmpgabzXbK9cLCwpq970mTJvHEE0+wc+dOevXq1eztm6PZPairV69mzpw5XHnllcyaNYu+ffsyc+ZMcnNzG10/ICCASy+9lMcee4znn3+eK6+8knnz5rFs2TLXOlVVVcTFxTFlyhQiIiJafDLtyRh6HgDmt2sw26j7XURExBeYpolZVemZn2beKMA0Tf7+978zYsQIevbsydixY1myZAngzEBdu3Zl5cqVXHbZZaSkpLB27VomT57MH/7wBx599FEGDBjANddcA8CaNWu4/PLLSUlJYfDgwcycOdNt+qWTbbdo0SIuvPBCAgMDAZg3bx7PPfccW7ZsoWvXrnTt2pV58+YB0LVrV9544w1+9atfkZqaygsvvEBtbS333Xcf5557Lj179uSCCy7gX//6l9t5nvgV/+TJk3n44Yd54okn6N+/P4MGDeLZZ59126Zz584MHTqUDz74oFm/05Zodg/qkiVLGD16NGPGjAFg2rRpbNy4kaVLlzJlypQG66ekpJCSkuJ6HRsby7p169i6dStjx44FIDU1ldTUVADmzp3bohNpd30GQlAwFOXDnu2Q2ngPsoiIyBmvugrHnVd75NCWl96FgMAmrz9r1iw++eQTnnrqKVJSUvj666+ZPn06UVFRrnWeeOIJHnnkEZKTk109kfPnz2fq1Kmu8JaVlcX111/P1VdfzQsvvMCuXbu4//77CQgI4L777nPt68TtAL7++msmTpzoen3FFVewfft2Vq5cyTvvvANAp06dXO8/++yzzJgxg0cffRQ/Pz8cDgcJCQm8+uqrdO7cmfXr1/PAAw8QGxvLFVdccdJznz9/Pr
fccguLFy9mw4YN3HPPPQwbNowLL7zQtc7gwYNZu3Ztk3+fLdWsgGq329mzZ4/bLw0gPT2d7du3N2kfe/fuZfv27fzyl79szqEbqKmpcbvrgWEYBAUFuZ63NcPfHzN9OObalZjfrsbSq1+bH/NMU1/H9qintC/V1jeprr7rTKlteXk5r732GvPmzePss88GoFu3bnzzzTe89dZbXHvttQDcf//9bqENoHv37jz00EOu108//TRdunThySefxDAMUlNTyc7OZubMmdxzzz2usaUnbgdw8OBB4uLiXK+DgoIICQnBz8+P2NjYBu2eOHFig1z129/+1vU8OTmZ9evXs3jx4h8NqH379uXee+8FoEePHsyZM8d1sVa9+Ph4Dh48eNJ9HO90/rw0K6AWFxfjcDgIDw93Wx4eHk5hYeGPbnvbbbdRXFxMbW0tV111lasHtqUWLlzIggULXK9TUlKYNWsWMTExp7Xf5igfN568tSuxbFxH/F0P+fwH11Pi4+M93QRpI6qtb1JdfVdzaltRUeEaC2larfCPhW3VrB/nH9Dkf5/37NlDZWWl66v2ejU1NZx11lmusaZDhw51G+dpGAaDBw92W7Z7926GDRvmNhZ1xIgRlJWVkZOTQ2JiYqPbAVRWVhISEuK23GKxYBhGo+NLhwwZ0mD5nDlz+M9//sPBgwepqKigpqaGAQMGuNY7cX+GYdC/f3+3/cTHx5Ofn++2LCQkxK22J+Pv709CQsKPrvNjWnSRVGOFPlXxH3/8cSorK9mxYwdz584lPj6e888/vyWHB5wDdY8fPFx//JycnNO+vVZTmV26Q0AgtUezyPr6S4zubTtg+ExjGAbx8fFkZ2c3ewyReDfV1jeprr6rJbWtrq52v7+7xa+NWncKzcgE1dXVALzxxhsNwri/vz/79u0DGt673jRNAgIC3JY5HA5M03RbVv/cbrdTU1PT6HbgHOuZl5d3yv3VO3EfixYt4pFHHuHhhx/m7LPPJiQkhFdeeYXvvvvOtd7x+7PZbJimicViaXBe9W2tl5eXR1RUVKPtOPF3mZWV1WC51WptUmdiswJqWFgYFoulQW9pUVFRg17VE9V3SScnJ1NUVMT8+fNPK6DabLaTpvd2+4vR5o8xYCjmhq9wrP8KS7fU9jnuGcY0Tf1j56NUW9+kuvouX69tWloaAQEBHDp0iBEjRjR4vz6gNkWvXr34+OOPMU3T1Ym2fv16QkNDT9mz2L9/f3bs2OG2zGazNXlO1HXr1jF06FCmTZvWorb/mG3bttG/f/8mrXs6f1aadRW/1WqlR48eDSZpzcjIoHfv3k3eT30i9wmuq/lX+/SHVkRExNeFhoZy66238uijj/Luu++SmZnJpk2bmDNnDu+++26z9nXDDTdw+PBhHnroIXbt2sVnn33Gs88+yy233OIaf3oyF110Ed98843bsqSkJPbv38+mTZvIz8+nqqrqpNt3796djIwMVq5cye7du3nmmWfYuHFjs9p/MuvWrWPUqFGtsq8f0+xppsaPH8/y5ctZsWIFBw8eZM6cOeTm5romj507dy4vvfSSa/1PP/2U9evXk5WVRVZWFl988QWLFy/mggsucK1jt9vJzMwkMzMTu91Ofn4+mZmZZGdnt8Ipti3jrKFg84ejWXAw09PNERERkdPwwAMPcM899/DSSy9x0UUXMWXKFD7//HOSk5ObtZ+EhATefPNNvv/+e8aNG8eDDz7INddcw1133XXKba+88kp27NjBrl27XMt++tOfctFFF3H11Vdz1lln/ehUT9dffz2XXXYZv/nNb/jZz35GQUEBN9xwQ7Pa35j169dTUlLSLnPVG2YLuv3qJ+ovKCggKSmJG264gX79nFexv/zyy+Tk5PDoo48C8Mknn7Bs2TKOHj2KxWIhPj6eMWPGMHbsWNf/II4ePcqdd97Z4Dj9+vVz7aepcnJyTjkuorXVvjwTvv8aY/wvsEy4tl2P7csMwyAhIYGsrCz1TvsY1d
Y3qa6+qyW1LS4ubtFk8OL0xBNPUFxczDPPPNPmxzpxTO3J3HLLLQwYMIDp06efct2T1d9mszVpDGqLAqo380RAdXz
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAroAAAF0CAYAAADM95pAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAACFXElEQVR4nOzdeXxU1f3/8dedLftKViCBJCRAwi6yaoMstQqWpVTrUuWrKDagpa31V8UFEbe2otVqa6sWawuiYEwMKjGAImBBQAEJAiEgIQkmIZnsyyzn98fA4JggCWaZTD7PxyMPZu6ce++5c5Lwzplzz9GUUgohhBBCCCE8jK6rKyCEEEIIIURHkKArhBBCCCE8kgRdIYQQQgjhkSToCiGEEEIIjyRBVwghhBBCeCQJukIIIYQQwiNJ0BVCCCGEEB5Jgq4QQgghhPBIEnSFEEIIIYRHkqArhBBCCCE8kqGrK+COKioqsFqtXV0N0Y7Cw8MpLS3t6mqIdibt6rmkbT2XtK1n6sx2NRgMhISEtK5sB9elW7JarVgslq6uhmgnmqYBjnZVSnVxbUR7kXb1XNK2nkva1jO5c7vK0AUhhBBCCOGRJOgKIYQQQgiPJEFXCCGEEEJ4JAm6QgghhBDCI8nNaG3U2NhIY2NjV1dDtFF9fT1NTU0tvubl5YWXl1cn10gIIYQQHU2CbhvU1taiaRoBAQHOOwxF92A0GlucSUMpRX19PbW1tfj5+XVBzYQQQgjRUWToQhtYrVZ8fX0l5HoQTdPw9fWVeZOFEEIIDyRBtw0k4HouaVshhBDC80jQFUIIIYQQHkmCrhBCCCGEuGiqsoKqN/+Fstu6uirNSNAV7SIvL48RI0ZQU1PTIcd/+umnmTZtWpv2ufrqq3nvvfc6pD5CCCFET6cqTmN/45/Y/jCfytdeQO3a3tVVakaCrmgXTz31FLfccgv+/v4ArFmzhsGDB7fb8e+8807WrFnTpn0WL17M448/jt1ub7d6CCGEED2dOl2K/b9/x37/7aiN74KlCdPAIWjBIV1dtWYk6IofrKioiA8//JDrrruuzfueb27b7/Lz8yM0NLRNx54yZQrV1dV89NFHba6XEEIIIVyp0lPY//1X7EsWoD56D6xWGJCM7jfLiHj6X2hJQ7q6is1I0P0BlFKoxoau+VKqTXXdvHkzs2bNYvDgwaSkpHDzzTdz/Phx5+tFRUX86le/IiUlhQEDBnDVVVexZ88e5+vZ2dlcddVVxMfHM2TIEObPn+987d133yU5OZnevXsDsH37dn77299SVVVFnz596NOnD08//TQAY8eO5dlnn2Xx4sUMGjSI3//+9wA89thjXHbZZSQkJDB+/Hj++Mc/usx7+92hC4sXL+bWW2/l73//OyNHjiQlJYX777/fZR+9Xs/kyZN555132vReCSGEEOIc9U0R9n/9BfsDd6I+yQabFQYORXfPY+jufQJdyki3nb3oohaM2LBhA5mZmZjNZvr27cu8efO+92PqDz74gA0bNlBSUkJYWBhz5swhNTXV+fqOHTtIT0/n1KlT2Gw2oqKiuOaaa/jRj37kLJOens7OnTspLCzEZDKRlJTETTfd5AxXAC+88AIff/yxy7kTExN57LHHLuYyL6ypEfuiazvm2Beg++ub4OXd6vJ1dXXccccdDBo0iLq6Ov785z8zf/58srOzqa+vZ+7cuURFRfGvf/2L8PBw9u/f7/zIPycnh/nz53P33Xfz3HPP0dTUxMaNG53H3rFjB8OGDXM+Hz16NI888gh//vOf2bJlC4DLYgx///vfWbx4Mb/+9a+d2/z8/HjmmWeIiori4MGD3Hvvvfj7+5OWlnbea9q+fTsRERG89dZbHDt2zBnUb7zxRmeZESNG8Le//a3V75MQQgghHFTxSdR7b6J2bAF1Zhhg8kh0M65DS0zu2sq1UpuD7vbt21m5ciXz589n4MCB5OTk8Pjjj/PMM88QFhbWrHx2dj
arV69mwYIFJCQkkJeXx0svvYSfnx+jR48GwN/fnzlz5tC7d28MBgN79uzhxRdfJDAwkBEjRgCQm5vLlVdeSUJCAjabjTfeeIPly5ezYsUKvL3PBb4RI0a4hCODQRZ/A5g+fbrL86effpphw4Zx+PBhdu3axenTp1m/fj0hIY7xNXFxcc6yzz33HDNnzuSee+5xbktJSXE+LigoYOjQoc7nJpPJuXpcREREs7pMnDiRO++802Xb4sWLnY9jYmI4evQomZmZ3xt0g4KCeOyxx9Dr9QwYMIApU6awdetWl6AbHR1NYWGhjNMVQgghWkkVfo1a/yZq11Y4+wny0NGOgBs/sGsr10ZtToFZWVlMnjyZKVOmADBv3jz27t1LdnY2N9xwQ7PyW7ZsYerUqUyYMAGAyMhIjhw5QkZGhjPofjs0geNu+Y8//pivvvrKGXSXLFniUiYtLY358+eTn59PcvK5vyoMBgPBwcFtvayLY/Jy9Kx2BZNXm4ofP36cP/3pT+zZs4fy8nJn8CssLOTAgQMMGTLEGXK/68CBAy7h8bsaGhpc/ti4kG/3/p6VlZXFyy+/zPHjx6mtrcVmszlvbDufpKQk9Hq983lkZCQHDx50KePt7Y3dbqexsVH+6BFCCCG+hzqRj339Gtjz6bmNI8ahm3EtWr8BXVexH6BN//NbrVby8/OZNWuWy/Zhw4Zx6NChFvexWCwYjUaXbSaTiby8PKxWa7PwoZTiyy+/pKio6HvDVV1dHUCzMJSbm8v8+fPx8/Nj8ODBXH/99QQFBZ23bt8e06lpGj4+Ps7HF6JpWpuGD3SlefPm0bt3b/74xz8SFRWF3W5n8uTJWCyWC4bUC70eGhqK2WxudV18fX1dnu/evZu0tDR+97vfMWnSJAICAsjIyOAf//jH9x7nu99XQLOxyxUVFfj4+ODj4+PS1i1x1/FFomVn20vazfNI23ouaVv3pI4fwZ61BvXFDscGTUMbNcHRgxsT9/07497t2qagW1VVhd1ubxYcg4KCzht0hg8fzqZNmxgzZgxxcXHk5+ezefNmbDYb1dXVzl7Euro6FixYgNVqRafTcdttt7XY8weOMPPaa68xaNAgYmNjndtHjhzJ+PHjCQsLo6SkhDVr1rBs2TKefPLJFkNReno6a9eudT6Pi4vjqaeeIjw8vMXz1tfXt3gcd1deXs6RI0d4+umnGTduHAD/+9//AMcNW0OGDGH16tXU1NS02KubkpLCtm3buOmmm1o8/rBhw8jLy3N5b3x8fLDZbM3eL03T0Ov1Ltv37NlD3759XYZGFBcXA+fCrE6nQ9O08z4/ey3f3ZaXl+f8Pvq+tjOZTERHR5/3deG+oqKiuroKooNI23ouaVv30HhwH1VvvEzD2flvNQ3fH/2YwOtuxdgvoc3Hc8d2vajPcltK7OdL8XPnzsVsNrNkyRKUUgQFBZGamkpmZiY63blJH7y9vfnTn/5EQ0MD+/fv59///jeRkZHNhjUAvPLKK5w4cYJly5a5bD87PAIgNjaWhIQE0tLS2LNnD2PHjm12nNmzZzNjxoxm11BaWorVam1Wvqmp6YK9gu7Iz8+PkJAQVq5cSWhoKIWFhTzxxBMA2Gw2rrnmGp599lluvvlm7rvvPiIiIvjyyy+JjIxk9OjRLF68mOuuu47Y2FhmzpyJ1Wpl8+bNzvGzl19+Ob///e9paGhwDiWIjo6mtraWTZs2kZKS4uxVVUphs9lc3sfY2FgKCwtZu3Ytw4cPZ+PGjaxfvx7AWc5ut6OUOu/zs9fy3W2ffvopl19+ucuxWtLU1OQM16J70DSNqKgoTp061eZZSIR7k7b1XJ3dtqq8FHvmalTeQbSoPhATjxYbj9YvAULC3LIHsjOow19if3cN6uAXjg06HdrYSeim/5ymqL6UAbTh/8TObleDwXDeTslmZdty4MDAQHQ6XbPe28rKyvMODzCZTKSlpXHHHXdQWVlJSEgIOTk5+Pj4EBAQ4Cyn0+mcfw
n079+fwsJC3nnnnWZB99VXX2X37t088sgj9OrV63vrGxISQnh4+HkDjNFoPG8vnyf9ctXpdLz44os89NBDTJkyhfj
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"init_scale = 0.1 # scale for random parameter initialisation\n",
"learning_rate = 0.1 # learning rate for gradient descent\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"\n",
"# Reset random number generator and data provider states on each run\n",
"# to ensure reproducibility of results\n",
"rng.seed(seed)\n",
"train_data.reset()\n",
"valid_data.reset()\n",
"\n",
"# Alter data-provider batch size\n",
"train_data.batch_size = batch_size \n",
"valid_data.batch_size = batch_size\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-init_scale, init_scale]\n",
"param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"# Create affine model (outputs are logs of unnormalised class probabilities)\n",
"model = SingleLayerModel(\n",
" AffineLayer(input_dim, output_dim, param_init, param_init)\n",
")\n",
"\n",
"# Initialise the error object\n",
"error = CrossEntropySoftmaxError()\n",
"\n",
"# Use a basic gradient descent learning rule\n",
"learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"_ = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<span style=\"color:green\">\n",
"This gives exactly the same training curves (and error / accuracy values over training) as the two runs with equivalent parameters above (second `init_scale` experiment and second `learning_rate` experiment).\n",
"</span>\n",
"\n",
"<span style=\"color:green\">\n",
"The times per epoch seems to be slightly lower on average (0.20s compared to 0.22s) suggesting the reformulation gives a small efficiency gain (though this will become less apparent in deeper architectures as the benefit only applies to the final layer).\n",
"</span>"
]
},
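{
"cell_type": "markdown",
"metadata": {},
"source": [
"The simplified gradient used by the combined softmax and cross entropy error can also be checked numerically. The following is a minimal sketch using only NumPy rather than the `mlp` framework classes: the helper names `softmax`, `cross_entropy` and `fused_grad` are illustrative assumptions and not part of the course code, and the error is assumed to be averaged over the batch. It compares the gradient $\\vct{y} - \\vct{t}$ with a central finite-difference estimate on a small random batch.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def softmax(x):\n",
"    # Subtract the row-wise maximum before exponentiating for numerical stability\n",
"    z = x - x.max(axis=-1, keepdims=True)\n",
"    e = np.exp(z)\n",
"    return e / e.sum(axis=-1, keepdims=True)\n",
"\n",
"def cross_entropy(x, t):\n",
"    # Batch mean of -sum_d t_d log Softmax_d(x), with x the softmax inputs\n",
"    return -np.mean(np.sum(t * np.log(softmax(x)), axis=-1))\n",
"\n",
"def fused_grad(x, t):\n",
"    # Gradient of the batch-mean error with respect to the softmax inputs\n",
"    return (softmax(x) - t) / x.shape[0]\n",
"\n",
"rng = np.random.RandomState(0)\n",
"x = rng.normal(size=(4, 10))                # batch of 4 input vectors, 10 classes\n",
"t = np.eye(10)[rng.randint(0, 10, size=4)]  # one-of-K coded targets\n",
"\n",
"# Central finite-difference estimate of the gradient, one element at a time\n",
"eps = 1e-6\n",
"num_grad = np.zeros_like(x)\n",
"for i in range(x.shape[0]):\n",
"    for j in range(x.shape[1]):\n",
"        d = np.zeros_like(x)\n",
"        d[i, j] = eps\n",
"        num_grad[i, j] = (cross_entropy(x + d, t) - cross_entropy(x - d, t)) / (2 * eps)\n",
"\n",
"max_abs_diff = np.abs(num_grad - fused_grad(x, t)).max()\n",
"print(max_abs_diff)\n",
"```\n",
"\n",
"If the derivation holds, the printed difference should be tiny compared with the gradient values themselves, limited only by finite-difference and floating point error."
]
},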
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 2: training deeper models on MNIST\n",
"\n",
"We are now going to investigate using deeper multiple-layer model archictures for the MNIST classification task. You should experiment with training models with two to five `AffineLayer` transformations interleaved with `SigmoidLayer` nonlinear transformations. Intermediate hidden layers between the input and output should have a dimension of 100. For example the `layers` definition of a model with two `AffineLayer` transformations would be\n",
"\n",
"```python\n",
"layers = [\n",
" AffineLayer(input_dim, 100),\n",
" SigmoidLayer(),\n",
" AffineLayer(100, output_dim),\n",
" SoftmaxLayer()\n",
"]\n",
"```\n",
"\n",
"If you read through the extension to the first exercise you may wish to use the `CrossEntropySoftmaxError` without the final `SoftmaxLayer`.\n",
"\n",
"**Your Tasks:**\n",
"- Use the code from the first exercise as a starting point to train models of varying depths, and compare their results. It is a good idea to start with training hyperparameters which gave reasonable performance for the shallow architecture trained previously.\n",
"\n",
"Some questions to investigate:\n",
"\n",
" 1. How does increasing the number of layers affect the model's performance on the training data set? And on the validation data set?\n",
" 2. Do deeper models seem to be harder or easier to train (e.g. in terms of ease of choosing training hyperparameters to give good final performance and/or quick convergence)?\n",
" 3. Do the models seem to be sensitive to the choice of the parameter initialisation range? Can you think of any reasons for why setting individual parameter initialisation scales for each `AffineLayer` in a model might be useful? Can you come up with (or find) any heuristics for setting the parameter initialisation scales?\n",
" \n",
"You do not need to come up with explanations for all of these (though if you can that's great!), they are meant as prompts to get you thinking about the various issues involved in training multiple-layer models. \n",
"\n",
"You may wish to start with shorter pilot training runs (by decreasing the number of training epochs) for each of the model architectures to get an initial idea of appropriate hyperparameter settings before doing one or two longer training runs to assess the final performance of the architectures."
]
},
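{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a starting point for the third question, one widely used heuristic (often called Glorot or Xavier initialisation) sets the half-width of the uniform initialisation interval for each layer from its input and output dimensions, $s = \\sqrt{6 / (n_\\mathrm{in} + n_\\mathrm{out})}$. The sketch below is only an illustration: `glorot_uniform_scale` and `layer_dims` are hypothetical names rather than part of the `mlp` framework, and the layer sizes assume MNIST inputs of dimension 784 with 10 output classes.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def glorot_uniform_scale(fan_in, fan_out):\n",
"    # Half-width s of the uniform interval [-s, s] for a layer's weights\n",
"    return np.sqrt(6.0 / (fan_in + fan_out))\n",
"\n",
"# Dimensions for a model with three affine layers and hidden dimension 100\n",
"layer_dims = [784, 100, 100, 10]\n",
"scales = [glorot_uniform_scale(n_in, n_out)\n",
"          for n_in, n_out in zip(layer_dims[:-1], layer_dims[1:])]\n",
"print(scales)\n",
"```\n",
"\n",
"Each scale `s` could then be used to build a separate initialiser such as `UniformInit(-s, s, rng=rng)` for the corresponding `AffineLayer`, instead of sharing a single `init_scale` across all layers."
]
},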
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# disable logging by setting handler to dummy object\n",
"logger.handlers = [logging.NullHandler()]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Models with two affine layers"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.10\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABnP0lEQVR4nO3deXxU1cH/8c+ZJSF7AiQkJATCEmQLsisquIAbqODjXqu4VJ+ipVqr1SpWfRRFa60/l1qtLS5VERVFXEBUFFRW0YAoYQsQSCCBrGTP3N8fkwwMCZCEJDMZvu/XK6/M3HvunXNzjPly7jnnGsuyLERERERE/ITN1xUQERERETmYAqqIiIiI+BUFVBERERHxKwqoIiIiIuJXFFBFRERExK8ooIqIiIiIX1FAFRERERG/ooAqIiIiIn5FAVVERERE/IoCqoiIiIj4FYevK9DS8vPzqa6u9nU1pIXExsaSm5vr62pIK1DbBia1a+BS2wautmxbh8NBTEzM0cs15+QLFixg3rx5FBQUkJSUxJQpU+jXr1+DZZcvX87ChQvJzMykurqapKQkLr30Uk488URPmcWLF/P888/XO/b1118nKCioSXWrrq6mqqqqSceIfzLGAO42tSzLx7WRlqS2DUxq18Cltg1c/tq2TQ6o3377LbNmzeLGG2+kb9++LFq0iBkzZvDUU0/RuXPneuV//vln0tLSuPLKKwkLC+PLL79k5syZzJgxg5SUFE+5kJAQnn76aa9jmxpORURERKT9a3JAnT9/PmeeeSZnnXUWAFOmTOHHH39k4cKFXHXVVfXKT5kyxev9VVddxapVq1i9erVXQDXGEB0d3dTqiIiIiEiAaVJAra6uZsuWLUyaNMlre1paGhs2bGjUOVwuF2VlZYSHh3ttLy8vZ+rUqbhcLnr06MHll1/uFWAPVVVV5XUr3xhDSEiI57W0f3XtqPYMPGrbwKR2DVxq28Dlr23bpIBaVFSEy+UiKirKa3tUVBQFBQWNOsf8+fOpqKjg5JNP9mzr2rUrU6dOJTk5mbKyMj7++GOmT5/OE088QUJCQoPnmTt3Lu+8847nfUpKCjNnziQ2NrYplyTtQHx8vK+rIK1EbRuY1K6BS20buPytbZs1SaqhlN2Y5L106VLmzJnDnXfe6RVyU1NTSU1N9bzv27cvf/rTn/jkk0+4/vrrGzzX5MmTmThxYr3Pz83N1Sz+AGGMIT4+npycHL8auC3HTm0bmNSugas5bWtZFsXFxfpvoR0ICgqisrKyxc4XHBxMhw4dGtzncDga1ZnYpIAaGRmJzWar11taWFhYr1f1UN9++y0vvPACf/jDH0hLSztiWZvNRq9evcjJyTlsGafTidPpbHCffhkCi2VZatMApbYNTGrXwNWUti0uLiY4OFgTntsBp9PZYisgWZZFWVkZJSUlhIWFNfs8TVqo3+Fw0LNnT9LT0722p6en07dv38Met3TpUp577jmmTZvG0KFDj/o5lmWxbds2TZoSERFppyzLUjg9DhljCA0NPea72U2+xT9x4kSeeeYZevbsSWpqKosWLSIvL4/x48cD8MYbb7Bv3z5uvfVW4EA4nTJlCqmpqZ7e16CgIEJDQwGYM2cOffr0ISEhwTMGNTMzkxtuuOGYLk5ERERE2t6xTrpqckAdPXo0xcXFvPvuu+Tn59OtWzfuuecez3iC/Px88vLyPOUXLVpETU0NL7/8Mi+//LJn+9ixY7nlllsA2L9/Py+++CIFBQWEhoaSkpLCgw8+SO/evY/p4kRERESk/TFWgA0Uys3NbfUnSVkuF2xaj1WwD9vIMa36WcczYwwJCQlkZ2drPFuAUdsGJrVr4GpO2xYVFREZGdnKNZOW0JJjUOscrv2dTmejJkk1aQyq1PolHdcTf8Z66yWsmhpf10ZEREQCzNKlSxkzZgwul6tVzn/bbbcddqWkhlRUVDBixIh685BaiwJqc/QdBBFRUFwIP//o69
qIiIhIgHnkkUeYNm0aNps7qj355JOe+T4t4aGHHuKpp55qdPng4GD+93//l0ceeaTF6nAkCqjNYOx2zPBTALBWfOXj2oiIiIgvWJbV4Gz15q4pWnfcypUr2bp1q9d6743V2Fv1kZGRR10i9FCTJ09mxYoVbNy4scn1aioF1GYytWNPrTXLsCorfFwbERER/2VZFlZFuW++mjge2rIsnn/+eU4++WR69erFuHHjmD9/PuBe0z0xMZHFixdz3nnnkZKSwvLly7nkkku49957eeCBBxg4cCBXXnklAN999x0TJkwgJSWFIUOGMGPGDK9Ae7jj5s2bx5gxYzyL3c+ePZu//e1vrF+/nsTERBITE5k9ezYAiYmJvPrqq1x33XX07t2bp59+mpqaGu644w5OOukkevXqxWmnnca//vUvr+s89Bb/JZdcwvTp03n44YcZMGAAJ554Ik8++aTXMR07dmTYsGG8//77TfqZNkezniQlQM8ToFMc7N0Da1fBsFN8XSMRERH/VFmB69bLfPLRtmffhuCGn2rUkJkzZ/LJJ5/w6KOPkpKSwrJly5g2bRqdOnXylHn44Ye5//77SU5O9kwEmjNnDtdcc40nvGVnZ/PrX/+ayy67jKeffppNmzZx5513EhwczB133OE516HHASxbtoxJkyZ53l944YVs2LCBxYsX89ZbbwEQERHh2f/kk09yzz338MADD2C323G5XCQkJPDCCy/QsWNHVq1axV133UVcXBwXXnjhYa99zpw53HTTTXz44YesXr2a22+/nREjRjBmzIEJ4UOGDGH58uWN/nk2lwJqMxmbDTPiNKxP38W14mvsCqgiIiLtWmlpKS+99BKzZ89m+PDhAHTv3p2VK1fy+uuv86tf/QqAO++80yu0AfTo0YP77rvP8/6xxx6ja9euPPLIIxhj6N27Nzk5OcyYMYPbb7/dM7b00OMAsrKy6NKli+d9SEgIYWFh2O124uLi6tV70qRJXHHFFV7b/vjHP3peJycns2rVKj788MMjBtR+/frxhz/8AYCePXsya9Ysz2StOvHx8WRlZR32HC1FAfUYmFFjsD59F9JXYZWWYELDfV0lERER/xMU7O7J9NFnN1ZGRgbl5eWeW+11qqqqGDhwoOd9Q49sHzx4sNf7TZs2MWzYMK8F60eMGMH+/fvJzs4mMTGxweMAysvLCQ5ufL0bOserr77Km2++SVZWFuXl5VRVVTFgwIAjnqdfv35e7+Pi4rzWtgfo0KEDZWVlja5bcymgHovEHtA1GXZtx1qzDHPKOF/XSERExO8YY5p0m91X6pZ0evXVV4mPj/faFxQUxLZt2wA8T8I8WEhIiNd7y7LqPU2pofGwhx4H7rGehYWFja73ofWZN28eDz74INOnT2f48OGEhYXxj3/8gzVr1hzxPA6Hdyw0xtRb5qqgoMBruENrUUA9BsYYzMgxWO+/jrXia1BAFRERabdSU1MJDg5m586dnHzyyfX21wXUxujTpw8ff/yxV1BdtWoV4eHhJCQkHPHYAQMGkJGR4bXN6XQ2ek3UFStWMGzYMKZMmdKsuh/JL7/8ctSe2JagWfzHqG42Pz+nYxXm+7YyIiIi0mzh4eHcfPPNPPDAA7z99ttkZmaybt06Zs2axdtvN22IwrXXXsuuXbu477772LRpEwsWLODJJ5/kpptu8ow/PZzTTz+dlStXem3r1q0b27dvZ926dezbt4+KisOvINSjRw/S09NZvHgxmzdv5vHHH+fHH1tm3fYVK1YwduzYFjnXkSigHiMTGw8pqWC5sFZ94+vqiIiIyDG46667uP3223n22Wc5/fTTueqqq/jss89ITk5u0nkSEhJ47bXX+OGHHxg/fjx33303V155Jb///e+PeuzFF19MRkYGmzZt8mw7//zzOf3007nssssYNGjQEZd6+vWvf815553Hb3/7Wy644ALy8/O59tprm1T/hqxatYri4mImTJhwzOc6GmMF2AOTc3NzW/
x5skfj+vxDrLdegp59sd/zRJt+diDTc70Dl9o2MKldA1dz2vZwz2KXxnn44YcpKiri8ccfb/XPcjqdjcpON910EwM
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
  {
   "data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 1.87e-02\n",
" final error(valid) = 7.73e-02\n",
" final acc(train) = 9.98e-01\n",
" final acc(valid) = 9.76e-01\n",
" run time per epoch = 1.64\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.20\n",
"--------------------------------------------------------------------------------\n"
]
},
  {
   "data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
  {
   "data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 1.85e-02\n",
" final error(valid) = 7.47e-02\n",
" final acc(train) = 9.98e-01\n",
" final acc(valid) = 9.78e-01\n",
" run time per epoch = 1.68\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.50\n",
"--------------------------------------------------------------------------------\n"
]
},
  {
   "data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
  {
   "data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 2.05e-02\n",
" final error(valid) = 9.07e-02\n",
" final acc(train) = 9.97e-01\n",
" final acc(valid) = 9.75e-01\n",
" run time per epoch = 1.63\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=1.00\n",
"--------------------------------------------------------------------------------\n"
]
},
  {
   "data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
  {
   "data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 3.41e-02\n",
" final error(valid) = 1.26e-01\n",
" final acc(train) = 9.93e-01\n",
" final acc(valid) = 9.64e-01\n",
" run time per epoch = 1.66\n"
]
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100  # number of data points in a batch\n",
"num_epochs = 100  # number of training epochs to perform\n",
"stats_interval = 5  # epoch interval between recording and printing stats\n",
"learning_rate = 0.2  # learning rate for gradient descent\n",
"\n",
"init_scales = [0.1, 0.2, 0.5, 1.]  # scale for random parameter initialisation\n",
"final_errors_train = []\n",
"final_errors_valid = []\n",
"final_accs_train = []\n",
"final_accs_valid = []\n",
"\n",
"for init_scale in init_scales:\n",
"\n",
"    print('-' * 80)\n",
"    print('learning_rate={0:.2f} init_scale={1:.2f}'\n",
"          .format(learning_rate, init_scale))\n",
"    print('-' * 80)\n",
"    # Reset random number generator and data provider states on each run\n",
"    # to ensure reproducibility of results\n",
"    rng.seed(seed)\n",
"    train_data.reset()\n",
"    valid_data.reset()\n",
"\n",
"    # Alter data-provider batch size\n",
"    train_data.batch_size = batch_size\n",
"    valid_data.batch_size = batch_size\n",
"\n",
"    # Create a parameter initialiser which will sample random uniform values\n",
"    # from [-init_scale, init_scale]\n",
"    param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
"    # Create a model with two affine layers\n",
"    hidden_dim = 100\n",
"    model = MultipleLayerModel([\n",
"        AffineLayer(input_dim, hidden_dim, param_init, param_init),\n",
"        SigmoidLayer(),\n",
"        AffineLayer(hidden_dim, output_dim, param_init, param_init)\n",
"    ])\n",
"\n",
"    # Initialise a cross entropy error object\n",
"    error = CrossEntropySoftmaxError()\n",
"\n",
"    # Use a basic gradient descent learning rule\n",
"    learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
"    stats, keys, run_time, fig_1, ax_1, fig_2, ax_2 = train_model_and_plot_stats(\n",
"        model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)\n",
"\n",
"    plt.show()\n",
"\n",
"    print(' final error(train) = {0:.2e}'.format(stats[-1, keys['error(train)']]))\n",
"    print(' final error(valid) = {0:.2e}'.format(stats[-1, keys['error(valid)']]))\n",
"    print(' final acc(train) = {0:.2e}'.format(stats[-1, keys['acc(train)']]))\n",
"    print(' final acc(valid) = {0:.2e}'.format(stats[-1, keys['acc(valid)']]))\n",
"    print(' run time per epoch = {0:.2f}'.format(run_time * 1. / num_epochs))\n",
"\n",
"    final_errors_train.append(stats[-1, keys['error(train)']])\n",
"    final_errors_valid.append(stats[-1, keys['error(valid)']])\n",
"    final_accs_train.append(stats[-1, keys['acc(train)']])\n",
"    final_accs_valid.append(stats[-1, keys['acc(valid)']])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| init_scale | final error(train) | final error(valid) | final acc(train) | final acc(valid) |\n",
"|------------|--------------------|--------------------|------------------|------------------|\n",
"| 0.1        | 1.87e-02           | 7.73e-02           | 1.00             | 0.98             |\n",
"| 0.2        | 1.85e-02           | 7.47e-02           | 1.00             | 0.98             |\n",
"| 0.5        | 2.05e-02           | 9.07e-02           | 1.00             | 0.98             |\n",
"| 1.0        | 3.41e-02           | 1.26e-01           | 0.99             | 0.96             |\n"
]
}
],
"source": [
"print('| init_scale | final error(train) | final error(valid) | final acc(train) | final acc(valid) |')\n",
"print('|------------|--------------------|--------------------|------------------|------------------|')\n",
"for j, init_scale in enumerate(init_scales):\n",
"    print('| {0:.1f}        | {1:.2e}           | {2:.2e}           | {3:.2f}             | {4:.2f}             |'\n",
"          .format(init_scale,\n",
"                  final_errors_train[j], final_errors_valid[j],\n",
"                  final_accs_train[j], final_accs_valid[j]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Models with three affine layers"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.10\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABoQElEQVR4nO3deXwU9eH/8dfskc19AAkJJIEECHIYLjlERRTRKlTBelRrFc+2ainWaqWKVasoWkv9edTW2uLxVREVRFBBvCpyoxAQIVzhTCCBXJBzs/P7Y5MlSxJIQpLdLO/n47GP3Zn5zOxn8jH4zmc+8xnDNE0TERERERE/YfF1BUREREREalNAFRERERG/ooAqIiIiIn5FAVVERERE/IoCqoiIiIj4FQVUEREREfErCqgiIiIi4lcUUEVERETEryigioiIiIhfUUAVEREREb9i83UFWlp+fj5Op9PX1ZAWEhsbS25urq+rIa1AbRuY1K6BS20buNqybW02GzExMScv15yDL1q0iPnz51NQUEBiYiKTJk2iT58+9ZbdvHkz//d//8e+ffsoLy8nNjaWiy66iPHjx3uVW7FiBbNnz+bAgQN07tyZ6667jmHDhjW5bk6nk8rKyuaclvgZwzAAd5uapunj2khLUtsGJrVr4FLbBi5/bdsmB9Rly5Yxa9YsbrvtNnr37s2SJUuYPn06M2fOpFOnTnXKOxwOLrnkErp164bD4WDz5s288sorBAcHc9FFFwGQmZnJ3//+d6699lqGDRvGqlWrmDlzJo899hi9evU69bMUERERkXajyWNQFyxYwIUXXsiYMWM8vaedOnVi8eLF9ZZPSUnh3HPPJSkpibi4OEaNGsWAAQP48ccfPWUWLlxIeno6EydOpGvXrkycOJH+/fuzcOHC5p+ZiIiIiLRLTQqoTqeTHTt2MGDAAK/16enpbNmypVHH2LlzJ1u2bKFv376edZmZmaSnp3uVGzBgAJmZmU2pnoiIiIgEgCZd4i8qKsLlchEVFeW1PioqioKCghPu++tf/5qioiKqqqq4+uqrGTNmjGdbQUEB0dHRXuWjo6NPeMzKykqvsaaGYRASEuL5LO1fTTuqPQOP2jYwqV0Dl9o2cPlr2zbrJqn6TuJkJ/bYY49RVlZGZmYmb731FvHx8Zx77rkNljdN84THnDt3Lu+9955nOSUlhRkzZhAbG9uIM5D2JD4+3tdVkFaitg1MatfA1ZS2dblc7N+/n8rKSr+6+Ubq2rFjR4sdyzAMYmJi6nQ8NlWTAmpkZCQWi6VOz2ZhYWGdXtXjxcXFAZCcnExhYSFz5szxBNT6ektPdsyJEyd6zQRQE2Zzc3M1zVSAMAyD+Ph4cnJy9I9bgFHbBia1a+BqTtsWFRXhcDhwOBytXDs5VXa7vcVmQDJNk7y8PHJzcwkPD6+z3WazNaozsUkB1WazkZqaSkZGhtcUUBkZGQwdOrTRxzFN0ytEpqWlsWHDBq/AmZGRQVpaWoPHsNvt2O32Bo8vgcM0TbVpgFLbBia1a+BqStuapklQUFAr10j8jWEYhIaGUlhYeEr/DjT5Lv7x48fz+eef88UXX7B3715mzZpFXl4eY8eOBeCtt97ihRde8JT/9NNPWbNmDdnZ2WRnZ/Pll1/y0Ucfcd5553nKXHbZZaxfv5558+axb98+5s2bx4YNGxg3blyzT0xEREREfONUx7Q2eQzqyJEjKS4u5v333yc/P5+kpCSmTp3q6a7Nz88nLy/PU940Td5++20OHjyIxWIhPj6eX/ziF545UAF69+7NlClTeOedd5g9ezbx8fFMmTJFc6CKiIiInIYMM8Cuw+Tm5rb6k6RMpxO2/4hZmI9l2KhW/a7TmWEYJCQkkJ2drcuFAUZtG5jUroGrOW1bVFREZGRkK9dMWkJLjkGt0VD72+32Ro1BbfIlfgG2b8b11wcx3/4Xpsvl69qIiIhIgFm6dCmjRo3C1Uo5Y8qUKdxyyy
2NLl9eXs7QoUPJyMholfocTwG1OXqcASGhcKQIdm3zdW1EREQkwDzxxBNMnjwZi8Ud1Z599lnP/T4t4bHHHmPmzJmNLu9wOPj1r3/NE0880WJ1OBEF1GYwbDbo436alrlhrY9rIyIiIr5w/KxENSoqKpp1vJr9Vq9ezc6dO71mN2qsxl6qj4yMPOkUocebOHEiq1atYuvWrU2uV1MpoDaT0X8IAOZGBVQREZETMU0Ts7zMN68mjoc2TZOXXnqJs88+mx49enDRRRexYMECAJYtW0bXrl356quvuPTSS0lJSWHlypVcddVVPPjggzzyyCP079+f6667DoDly5czbtw4UlJSGDRoENOnT/cKtA3tN3/+fEaNGkVwcDAAs2fP5m9/+xubNm2ia9eudO3aldmzZwPQtWtXXn/9dW6++WZ69uzJc889R1VVFffeey8jRoygR48enHfeefz73//2Os/jL/FfddVVTJs2jccff5x+/foxcOBAnn32Wa99OnTowJAhQ5g3b16TfqbN0awnSQkY/QZjAmRtxSwuwojQQHAREZF6VZTjuvsan3y15YV3wRHc6PIzZszgk08+4cknnyQlJYUVK1YwefJkOnbs6Cnz+OOP8/DDD5OcnOy5EWjOnDnceOONnvCWnZ3NL3/5S6655hqee+45tm3bxn333YfD4eDee+/1HOv4/QBWrFjBhAkTPMuXX345W7Zs4auvvuKdd94BICIiwrP92WefZerUqTzyyCNYrVZcLhcJCQm8/PLLdOjQgTVr1nD//fcTFxfH5Zdf3uC5z5kzhzvuuIOPPvqItWvXcs899zB06FBGjTp2Q/igQYNYuXJlo3+ezaWA2kxGh07QtRvs24W56XuM4ef7ukoiIiJyCkpKSnjllVeYPXs2Z511FgDdunVj9erVvPnmm/ziF78A4L777vMKbQDdu3fnoYce8iw/9dRTdOnShSeeeALDMOjZsyc5OTlMnz6de+65xzO29Pj9APbu3Uvnzp09yyEhIYSFhWG1Wj1P5qxtwoQJ/PznP/da94c//MHzOTk5mTVr1vDRRx+dMKD26dOH3//+9wCkpqYya9Ysz81aNeLj49m7d2+Dx2gpCqinwOg/BHPfLti4FhRQRURE6hfkcPdk+ui7GyszM5OysjLPpfYalZWV9O/f37Ocnp5eZ98BAwZ4LW/bto0hQ4Z4TVg/dOhQjh49SnZ2Nl27dq13P4CysrImPSK2vmO8/vrrvP322+zdu5eysjIqKyvp16/fCY/Tp08fr+W4uDivue0BgoODKS0tbXTdmksB9RQYZw7BXPQB5g/fY7pcGBYN6RURETmeYRhNuszuKzVTOr3++uvEx8d7bQsKCmLXrl0AhIaG1tk3JCTEa9k0zTpPU6pvPOzx+4F7rGdhYWGj6318febPn8+jjz7KtGnTOOusswgLC+Mf//gH33///QmPY7N5x0LDMOpMc1VQUOA13KG1KKCeih59IDgEigth13ZI0ZOvRERE2qu0tDQcDgf79u3j7LPPrrO9JqA2Rq9evfj444+9guqaNWsIDw8nISHhhPv269ePzMxMr3V2u73Rc6KuWrWKIUOGMGnSpGbV/UQ2b9580p7YlqAuv1PgNd2U7uYXERFp18LDw/nVr37FI488wrvvvktWVhYbN25k1qxZvPtu04Yo3HTTTezfv5+HHnqIbdu2sWjRIp599lnuuOMOz/jThowePZrVq1d7rUtKSmL37t1s3LiRw4cPU15e3uD+3bt3JyMjg6+++ort27fz9NNPs379+ibVvyGrVq3i/PNbf1ijAuop0nRTIiIigeP+++/nnnvu4YUXXmD06NFcf/31fPbZZyQnJzfpOAkJCbzxxhusW7eOsWPH8sADD3Ddddfxu9/97qT7XnnllWRmZrJt27GHAV122WWMHj2aa665hjPPPPOEUz398pe/5NJLL+U3v/kNP/3pT8nPz+emm25qUv3rs2bNGoqLixk3btwpH+tkDD
PAHpicm5vb4s+TPRHzcB6uP94ChoHlb29ghGu6qZai53oHLrVtYFK7Bq7mtG1Dz2KXxnn88ccpKiri6aefbvXvstv
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABbYElEQVR4nO3deXxU9b3/8deZLfsGSQgQAmGVRZBFsGDFgkqrVMCmbrWKSrWCP0VbvbUurYharIrLra29WvFqQeqCICpgBC+biuDCpkBYZEkCCdn3Wc7vj0mGDAlCIMlMJu/n45HHzJxzZvI5+RLz9vs93+8xTNM0EREREREJEpZAFyAiIiIiUp8CqoiIiIgEFQVUEREREQkqCqgiIiIiElQUUEVEREQkqCigioiIiEhQUUAVERERkaCigCoiIiIiQUUBVURERESCigKqiIiIiAQVW6ALaG6FhYW4XK5AlyHNJCkpiby8vECXIS1AbRua1K6hS20bulqzbW02GwkJCSc/rhVqaVUulwun0xnoMqQZGIYBeNvUNM0AVyPNSW0bmtSuoUttG7qCtW01xC8iIiIiQUUBVURERESCigKqiIiIiAQVBVQRERERCSohN0nqREzTpKysLKguAJaTq6yspKamptF9YWFhhIWFtXJFIiIi0tLaTUAtKysjLCwMh8MR6FKkCex2e6OrMpimSWVlJeXl5URFRQWgMhEREWkpTQ6o27dvZ8mSJezdu5fCwkJ+//vfM3LkyJO+59VXX+XgwYMkJCRw+eWXc8kll/gd89lnn7Fw4UIOHz5Mp06duOaaa076uU1hmqbCaQgxDIPIyEiKi4sDXYqIiIg0syZfg1pdXU2PHj246aabTun4I0eO8Pjjj9O/f3/mzJnDlClTeOWVV/jss898x+zcuZNnnnmGCy64gL/+9a9ccMEFzJ07l127djW1PGln6tZvExERkdDR5B7UoUOHMnTo0FM+fsWKFSQmJjJ16lQAUlNT2b17N++99x7nnXceAO+//z6DBw9mypQpAEyZMoXt27fz/vvvM3PmzKaWKCIiIiJtWItfg7pr1y4GDx7st+2cc85h1apVuFwubDYbO3fu5LLLLvM7ZsiQIXzwwQcn/Fyn0+l3baJhGERERPieS/uh9m576tpMbRda1K6hS20bfEzTBLcb3C7weLyPbrf3y+M+9tzt8ntt+va7wO0Bj5uaswZiRMUF+pT8tHhALSoqIi7O/6Tj4uJwu92UlpaSkJBAUVER8fHxfsfEx8dTVFR0ws9dtGgRb731lu91eno6c+bMISkpqdHjKysrsdvtp30eoSorK4tJkybx+eefEx0d3eyf/8QTT/Dhhx+yatWqU37PJZdcwh133MHEiRMBfrDdHA4HnTt3PuM6JTBSUlICXYK0ALVr6GrJtjU9HnA5MZ1OzNpHnDX1XteA0/to+h5rwOXyPW/sPX7bnDWYLmfDz3HVgNN17HiPCzDAqP3CAAMwjNqQXm/fcccYDd5jqXccGI1+bu0xnroQ6Q2bpssbLk2Xq9HtzaV84i9Jue2/mu3zmkOrzOI//v+46pZ6+qH/EzNN8wf3T5kyxRdg6n9WXl4eLperwfE1NTWNzgZv72bPns0NN9xAWFgYTqeThQsX8uc//5lvv/22WT7/lltu4YYbbmjSz/7OO+9k1qxZXHzxxb66TqSmpoacnJzmKFVakWEYpKSkkJubq6XfQojate0yTRNcTqiqrPdVgVn73KiuIiY8jJKCAu9xLpc36Lld4HL5tuFygdvpt82s23fCY2tfezyB/jGEBqsNrBbvo8UK1npfFmvt/trnNhuGxYq1U9dW+7212Wwn7Ez0O66lC2msJ7SkpASr1errsWvsmOLi4gY9r/XZ7fYT9qzpP4ynJjs7m48++oiHH364ye+tqak5pVURoqKimrwM1Pjx47nnnnv45JNPmDBhwkmPV3u3XaZpqv1CkNq1dZgeN1RVecNkdaVfuD
RrA6Z/4Kz0BU7f8ZUVx167f7hHrqh1TuuYuiBls4Pd7n202cFmO7bNavPtM+r212232cFuq/e+4/fVe4+t3nH22mAHYJreL0wwOe65x/uIWe+4eu+pf2z9Y37o8yynECwbbKv3HoulyZdhGIZBbOfOlOfkBNXvbYsH1D59+rBp0ya/bd988w09e/bEZvN++759+7Jlyxa/HtHNmzfTt2/fFqnJNE2oqW6Rzz4pR1iT/vGsWrWKZ599lh07dmCxWBg+fDizZs2iR48egDdkPvLII6xevZrq6mr69OnDo48+yrBhwwDvJLW5c+eyY8cOIiMjOe+883jppZcAeO+99xgwYABdunQBYP369dx9990AdO3aFYC7776b3/3ud4waNYprrrmGffv2sWzZMiZMmMCzzz7Lo48+yocffkhOTg7JyclMmTKFu+66y/c/D0899RTLli3jo48+AmDmzJmUlJQwcuRIXnzxRWpqapg0aRIPP/yw7z1Wq5Vx48bx7rvvnlJAFRFpy0yPG8rLoKwUykqgvASz7nlZKZSXYtZ7TkWZN1i21N8xRxiER/h9GeERhMfGU+V01Ya52kBXF/bqttUFRuuxYwybDaz1Q2C9Y+u/rntutYPVimHRzS7bsyYH1KqqKnJzc32vjxw5wr59+4iOjiYxMZH58+dTUFDA7bffDnivJ1y+fDmvvvoq48ePZ+fOnaxcuZI777zT9xmXXnopf/rTn3j33Xc599xz+eKLL9iyZQuzZs1qhlNsRE01ntuvbJnPPgnLf/8HwsJP+fiKigpuueUWzjrrLCoqKnjyySeZNm0aK1asoLKykoyMDFJSUnjllVdISkpiy5YteGqHSTIzM5k2bRp33HEHzz33HDU1NXz88ce+z/7888/9JrCNGDGChx9+mCeffJLVq1cD+PV+/uMf/2DmzJl+bRcVFcXcuXNJSUnh22+/5d577yU6Oprp06ef8JzWr19PcnIyb775Jnv37uW2225j4MCB/OpXv/Idc8455/D3v//9lH9OIiLBwHQ56wXNUiirC5d1r2vDZ3m9AFpRdmbf1GqFsNowGRHpfQw7FiyPD5uERWDUHVdvm/d5OEZd72E9hmGQ2LkzOUHWyyahq8kBdffu3X5Dwv/7v/8LwNixY5kxYwaFhYXk5+f79icnJ3Pffffx6quvsnz5chISErjxxht9S0wB9OvXj5kzZ/LGG2+wcOFCUlJSmDlzJn369DmTcwsJx69u8NRTTzF48GB27tzJxo0bOXr0KO+//z4JCQmAd7JYneeee45Jkybx+9//3rdt4MCBvucHDhzg7LPP9r12OBzExMRgGAbJyckNahkzZgy//e1v/bbVXwasW7du7N69myVLlvxgQI2Li+PRRx/FarXSu3dvxo8fz9q1a/0CaufOnTl06JAvbIuItCbT5fT2alaUecNkeRlmeRlUeJ/XfZn1gidlpd6h8tMVGQVRMRAdC9GxGL7nMRAVgxET690fGX0siIZHeHspNbteQkyTA+rAgQP5z3/+c8L9M2bMaLBtwIABzJkz5wc/97zzzvMLrS3KEebtyQwER9PuHb9v3z7++te/8uWXX1JQUOALbIcOHWLbtm0MGjTIF06Pt23bNr/Qd7yqqirCw0+9N/f45cIAli5dyksvvcS+ffsoLy/H7XafdDWAvn37YrUe+z/0Tp06NZiUFR4ejsfjobq62ncpiIhIU5im6b2+sjZg1oVN8/jgWVEXOOsdV111+t/YsEBUtH+49D33PhrRtWGzLnRGxWBYG/ZcirRX7fIvv2EYTRpmD6SpU6fSpUsXnnjiCVJSUvB4PIwbNw6n03nScHmy/R06dPjBpbyOFxkZ6fd606ZNTJ8+nd/97ndceOGFxMTEsHjxYv75z3/+4Oc0Nrnt+CGjwsJCIiIiiIiI0OoLIgLU/neishyKCqC4ELOoAIoLvK/LSuoFz7Jj12qeySiMYUBElDds1vZcGvWee0NoDEZt6Kzr+SQiUt
dPipyhdhlQ24qCggJ27drFnDlzGDVqFAAbNmzw7e/fvz8LFiygsLCw0V7U/v37s3btWq666qpGP3/QoEENbifrcDh
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 5.52e-03\n",
" final error(valid) = 8.77e-02\n",
" final acc(train) = 1.00e+00\n",
" final acc(valid) = 9.77e-01\n",
" run time per epoch = 1.95\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.20\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABmqUlEQVR4nO3deXhU5d3G8e+ZJZN9YQkJJIFACIIYZBNFCyigsihg3bWKS7WiRaxL1aoVqyjaaq1rXVpcqiAqiICAuLyKyKoQEAUCBAQSSCAb2ZM57x+TDBkSIAlJZjLcn+uaK3PW+U0ewZvnPOc5hmmaJiIiIiIiPsLi7QJERERERGpSQBURERERn6KAKiIiIiI+RQFVRERERHyKAqqIiIiI+BQFVBERERHxKQqoIiIiIuJTFFBFRERExKcooIqIiIiIT1FAFRERERGfYvN2AU0tJyeHiooKb5chTaR9+/ZkZWV5uwxpBmpb/6R29V9qW//Vkm1rs9mIioo6/n4tUEuLqqiooLy83NtlSBMwDANwtalpml6uRpqS2tY/qV39l9rWf/lq2+oSv4iIiIj4FAVUEREREfEpCqgiIiIi4lMUUEVERETEp/jdTVIiIiLifaZpcujQIZ+68UbqVlxcTFlZWZOdz+Fw4HA4TugcCqgiIiLS5A4dOoTD4SAgIMDbpchx2O32JpsByTRNiouLKSwsJCQkpNHn0SV+ERERaXKmaSqcnoQMwyA4OPiE56RXQBURERGRJlU9v2pjNeoS/+LFi5k3bx65ubnExcUxceJEevbsWee+K1euZMmSJaSnp1NRUUFcXByXXXYZp59+unufr7/+mpdffrnWse+++67+9SUiIiJykmlwQF2+fDkzZszg5ptvpkePHixdupRp06bx3HPP0a5du1r7//zzz6SkpHDVVVcREhLCV199xfTp05k2bRqJiYnu/YKCgnj++ec9jvXVcGo6KyHtZ8zcg1jOGOLtckRERET8SoMv8c+fP5/zzjuP4cOHu3tP27Vrx5IlS+rcf+LEiYwbN46kpCRiY2O5+uqriY2NZe3atR77GYZBZGSkx8tnbd6I85kHMWe+jllZ6e1qRERExM8sW7aMIUOG4HQ6m+X8U6ZM4cYbb6z3/qWlpQwcOJDU1NRmqedIDQqoFRUVbN++nT59+nisT0lJYfPmzfU6h9PppLi4mNDQUI/1JSUlTJo0iT/84Q889dRT7NixoyGltazk3hAaBgV5sHmDt6sRERERP/PEE08wefJkLBZXVPvHP/7ByJEjm+z8jz32GM8991y993c4HPzhD3/giSeeaLIajqVBl/jz8/NxOp1ERER4rI+IiCA3N7de55g/fz6lpaWcddZZ7nUdO3Zk0qRJJCQkUFxczMKFC3n44Yd55plniI2NrfM85eXlHlMiGIZBUFCQ+31zMmw2zH6DMb9ZjLn2Oyyn9m3WzztZVbdjc7entDy1rX9Su/ovtW3dTNOksrISm80zTpWVlTVqmGL1catXr2bHjh2MHTu2wecoLy/Hbrcfd7/w8PAGn3vChAk8/vjjbN26le7dux93/xP576VRN0nV9YH1KWLZsmXMnj2be++91yPkJicnk5yc7F7u0aMHf/7zn/nss8+O2v08Z84cPvzwQ/dyYmIi06dPp3379g35Ko1WcuF4sr5ZjLFuBTF3T8WwaUrZ5hITE+PtEqSZqG39k9rVfzWkbYuLi91ByTRNKCttrrKOLcDRoKBkmiYvvvgib731Fvv376dr167cfffdXHTRRXz33XdMmDCBmTNn8uSTT7Jp0yZmzZrF3//+d0455RTsdjuzZ8+mR48efPLJJyxfvpypU6fy008/ERkZyRVXXMEDDzzgDrTjx4+v87j58+czbNgwwsLCAJg5cybPPvssAJ06dQLgX//6F1deeSXR0dE8/fTTfPnll3zzzTfcdttt3HPPPdx9990sW7aM/fv306lTJ2644QZuueUW9/f84x
//SF5eHm+//TYAl112Gb169cLhcPC///0Pu93O9ddfz3333ec+pkOHDgwcOJB58+Zx//33H/vXHhBw1E7G+mhQqgoPD8disdTqLc3Ly6vVq3qk5cuX8+qrr/KnP/2JlJSUY+5rsVjo1q0bmZmZR91nwoQJHv+yqP6PLysr64Tn3qoPs20shEXizM9l71eLsfTu1+yfebIxDIOYmBgyMzP1JBI/o7b1T2pX/9WYti0rK3Nf6TRLS3DecXlzlnhUlhc/wHAE1nv/p556is8++4wnn3ySxMREVqxYwaRJk4iIiHB/96lTp/LII4+QkJBAeHg4pmkya9YsrrvuOubMmQPArl27uOqqq7j88sv55z//SVpaGvfeey92u527774boM7jysvLWb58OePHj3f//kaPHs1PP/3E119/zcyZMwEICwtzb3/66ad54IEHeOSRR7BarZSWltKhQwdeeeUV2rRpw5o1a7jvvvto27YtF198MeAacmmaprvHtbqWW265hU8//ZS1a9dy11130b9/f4YMOXxDeJ8+ffj++++PO7F/WVkZGRkZtdbbbLZ6dSY2KKDabDa6du1KamoqZ5xxhnt9amoqAwcOPOpxy5Yt45VXXuHOO++kX7/jBznTNNm5cyfx8fFH3cdutx+1C7tF/mK0WDD6n4X59WeYq7/F1GX+ZmOapv5n56fUtv5J7eq//L1ti4qKeP3115k1axYDBgwAoHPnzqxevZp3332Xa665BoB7773XI7QBdOnShYceesi9/NRTT9GxY0eeeOIJDMMgKSmJzMxMpk2bxl133eUeW3rkcQC7d++mQ4cO7uWgoCBCQkKwWq1ER0fXqnv8+PFceeWVHuvuuece9/uEhATWrFnDp59+6g6odenZsyd/+tOfAOjatSszZsxw36xVLSYmht27dx/1HDWdyH8rDb4uPXbsWF544QW6du1KcnIyS5cuJTs72z1w97333uPgwYPccccdgCucvvTSS0ycOJHk5GR372tAQADBwcEAzJ49m+7duxMbG+seg5qens5NN93U6C/WEowBv3EF1B9XYF57G4bt+GM+RERETjoBDiwvfuC1z66vLVu2UFJSwlVXXeWxvry8nN69e7uX67oSfOQN5GlpafTv399jeMHAgQMpLCwkIyPDfan+yOPAdeN4Q55lX9c53n77bd5//312795NSUkJ5eXlnHrqqcc8z5Fz2kdHR5Odne2xLjAwkOLi4nrX1lgNDqiDBw+moKCAjz76iJycHOLj43nggQfc3bU5OTkeX2bp0qVUVlby5ptv8uabb7rXDx06lNtvvx2AwsJCXnvtNXJzcwkODiYxMZGpU6eSlJR0ot+veXXvCRFRkJcDP6+H0wZ4uyIRERGfYxgGNOAyu7dUT+n09ttv1xpvGxAQwM6dOwHcHWw1Vd+oXc00zVpjX+vqUTzyOIA2bdqQl5dX77qPrGfevHlMnTqVhx9+mAEDBhASEsIrr7zCjz/+eMzzHHmzl2EYtaa5ys3NpW3btvWurbEadWfPBRdcwAUXXFDnturQWe3RRx897vkmTpzIxIkTG1OKVxkWK0a/wZhfLcBcvQxDAVVERKTVSk5OxuFwsGfPHo/ZhqpVB9T66N69OwsXLvQIqmvWrCE0NPS4Nw+deuqpbNmyxWOd3W6v95yoq1aton///h7ZqiG1H8svv/xy3J7YptDgifrFkzHwNwCY61ZiHmfAsIiIiPiu0NBQbr31Vh599FE++OAD0tPT2bhxIzNmzOCDDxo2ROH6669n7969PPTQQ6SlpbF48WL+8Y9/cMstt7jHnx7NsGHDWL16tce6+Ph4du3axcaNGzl48CClpUefFaFLly6kpqby9ddfs23bNp5++mnWr1/foPqPZtWqVQwdOrRJznUsCqgnqtspENkGigth07G7zkVERMS33Xfffdx11128+OKLDBs2jKuvvprPP/+chISEBp0nNjaWd955h3Xr1jFy5Ejuv/9+rrrqKu
68887jHnvJJZewZcsW0tLS3OtGjx7NsGHDuPzyyznttNOYO3fuUY//3e9+x6hRo7jtttu46KKLyMnJ4frrr29Q/XV
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABocklEQVR4nO3deXxU9b3/8deZzEz2DbICARI2AQHZFVQsiLRKBSzV1q1UqVyBq3Sxt160vSrqRVtxaW3tTytWL2hdEIwKMYIFQUVxAQQDISBIEknIvs9yfn9MMjAkCIEkM5m8n48Hj8x8z/dMPidfom++55zvMUzTNBERERERCRAWfxcgIiIiInI8BVQRERERCSgKqCIiIiISUBRQRURERCSgKKCKiIiISEBRQBURERGRgKKAKiIiIiIBRQFVRERERAKKAqqIiIiIBBQFVBEREREJKFZ/F9DWSktLcTqd/i5D2khiYiJFRUX+LkPagcY2OGlcg5fGNnh15NharVbi4+NP3a8DaulQTqcTh8Ph7zKkDRiGAXjG1DRNP1cjbUljG5w0rsFLYxu8AnVsdYpfRERERAKKAqqIiIiIBBQFVBEREREJKAqoIiIiIhJQgu4mqZMxTZOqqqqAugBYTq22tpaGhoYWt4WGhhIaGtrBFYmIiEh76zIBtaqqitDQUOx2u79LkVaw2Wwtrspgmia1tbVUV1cTGRnph8pERESkvbQ6oO7atYs1a9awf/9+SktL+c1vfsO4ceNOuc9zzz3HN998Q3x8PFdeeSWXXXaZT58PP/yQl156iW+//Zbk5GR++tOfnvJzW8M0TYXTIGIYBhEREZSXl/u7FBEREWljrb4Gtb6+nr59+3LTTTedVv8jR47w4IMPMnjwYJYuXcqsWbN49tln+fDDD7199uzZw6OPPsrFF1/Mww8/zMUXX8yyZcvYu3dva8uTLqZp/TYREREJHq2eQR05ciQjR4487f5ZWVkkJCQwZ84cAHr16sW+fft44403OP/88wF48803GT58OLNmzQJg1qxZ7Nq1izfffJNFixa1tkQRERER6cTa/RrUvXv3Mnz4cJ+28847jw0bNuB0OrFarezZs4crrrjCp8+IESN46623Tvq5DofD59pEwzAIDw/3vpauQ+Pd+TSNmcYuuGhcg5fG1pdpmmC6wW1C0+uTtbmbtn1Hm/v4z2i8mdv7HsD09ME8bj/f95637qYCW+zj/TzT3dju6d5wzlCMyNiO/SGeQrsH1LKyMmJjfQ86NjYWl8tFZWUl8fHxlJWVERcX59MnLi6OsrKyk37uqlWreOWVV7zv09PTWbp0KYmJiS32r62txWaznfFxBKvc3FxmzJjBRx99RFRUVJt//kMPPcTbb7/Nhg0bTnufyy67jNtuu43p06cDfOe42e12UlNTz7pO8Y+UlBR/lyDtQOMavM50bE23G5wOTKcD0+EEl7PxtcO3vem10wEOB6bLiel0grOxv7Np38a+rqZtnu3ebS4npsN5bLur6bOPa3M6vNuOfY/GfV1OnyBput2+4TLIVF/xY1Lm/5e/y/DRIXfxn/gvrqalnr7rX2KmaX7n9lmzZnkDzPGfVVRUhNPpbNa/oaGhxbvBu7olS5bws5/9jNDQUBwOBy+99BL/8z//w+7du9vk82+55RZ+9rOftepnf/vtt3PvvfcydepUb10n09DQQEFBQVuUKh3IMAxSUlIoLCzU0m9BROPauZimCQ31UFsDtdVQW4PZ+JWaaqitxmx6XVdDWEgIdVWV0BTmXA7va1xNXx3gch3X3tgnCENdqxgWsBhgGJ7XTV992hr7YXjaMY61GZzQftwfn/fHfcZp7mtYDKypPTvs99ZqtZ50MtGnX3sX0tJMaEVFBSEhId4Zu5b6lJeXN5t5PZ7NZjvpzJr+w3h68vPzeeedd7jnnntavW9DQ8NprYoQGRnZ6mWgpkyZwh133MF7773HtGnTTtlf4915maap8Q
tCGtf2Z5omOBp8wqU3ZDaGS0+bJ2CaLfSjttoTJk9TbVsegMUCViuE2Bq/Wk/yNQSsNs/XEE+70fS6aXtIUz/rsdch1uP+hPh8jnGy/sd/niXEU+PxgdJyGuGyhf6BflmEYRhEp6ZSVVAQUL+37R5QBwwYwLZt23zavvjiCzIyMrBaPd9+4MCB7Nixw2dGdPv27QwcOLBdavL+q9Ef7KGt+su6YcMGHnvsMXJycrBYLIwePZp7772Xvn37Ap6Qed9997Fx40bq6+sZMGAA999/P6NGjQI8N6ktW7aMnJwcIiIiOP/883n66acBeOONNxgyZAg9evQAYMuWLfzqV78CoGfPngD86le/4te//jXjx4/npz/9KQcOHGDt2rVMmzaNxx57jPvvv5+3336bgoICkpKSmDVrFr/85S+9/3j405/+xNq1a3nnnXcAWLRoERUVFYwbN46nnnqKhoYGZsyYwT333OPdJyQkhMmTJ/P666+fVkAVEQkGpsMBFWXeP2ZFqed1ZXnj+8ZtVRWeWU1X87OFZ8SwQHg4hEdCeETjn0iMxq9ERGKERxKTmEhFdQ1mY5gzvCHvhJB5fJv1uKBoPRYCDUtI29QuQavVAbWuro7CwkLv+yNHjnDgwAGioqJISEhgxYoVlJSUsHDhQsBzPeG6det47rnnmDJlCnv27GH9+vXcfvvt3s+4/PLL+cMf/sDrr7/O2LFj+fjjj9mxYwf33ntvGxxiCxrqcS+8un0++xQsf/4XhIaddv+amhpuueUWzjnnHGpqavjjH//I3LlzycrKora2ltmzZ5OSksKzzz5LYmIiO3bswN14KiU7O5u5c+dy22238fjjj9PQ0MC7777r/eyPPvrI5wa2MWPGcM899/DHP/6RjRs3AvjMfv7tb39j0aJFPmMXGRnJsmXLSElJYffu3fz2t78lKiqK+fPnn/SYtmzZQlJSEi+//DL79+/n1ltvZejQoVx33XXePueddx5//etfT/vnJCISiMz6evAGzeNC5vGBs8ITQKmtbv03MAwIi/AJloRHYIRHQkTT+2PB04iI9OlHRCSEhp9y4iRQZ9kkeLU6oO7bt8/nlPA///lPACZNmsSCBQsoLS2luLjYuz0pKYk777yT5557jnXr1hEfH8/Pf/5z7xJTAIMGDWLRokW8+OKLvPTSS6SkpLBo0SIGDBhwNscWFE5c3eBPf/oTw4cPZ8+ePXzyySccPXqUN998k/j4eMBzs1iTxx9/nBkzZvCb3/zG2zZ06FDv60OHDjFs2DDve7vdTnR0NIZhkJSU1KyWiRMn8h//8R8+bccvA5aWlsa+fftYs2bNdwbU2NhY7r//fkJCQujfvz9Tpkzh/fff9wmoqampHD582Bu2RUQCgel0QE0VVFdDdcUJQdP3NRXlUN/KE+MhVoiJ8/4xYmKPvY+Ow4iJg+jYYwEzLBzD0uolzUUCXqsD6tChQ/nXv/510u0LFixo1jZkyBCWLl36nZ97/vnn+4TWdmUP9cxk+oO9dc+OP3DgAA8//DCffvopJSUl3sB2+PBhvvzyS84991xvOD3Rl19+6RP6TlRXV0dY2OnP5p64XBhAZmYmTz/9NAcOHKC6uhqXy3XK1QAGDhxISMix0zvJycnNbsoKCwvD7XZTX1/vvRRERKQtmKbpCY7VVY1/KqGmCrPF95WetppKTyhtbeAEz7WPPqEz7oT38cfeR0QG/DWLIh2hS/6f3zCMVp1m96c5c+bQo0cPHnroIVJSUnC73UyePBmHw3HKcHmq7d26dfvOpbxOFBER4fN+27ZtzJ8/n1//+tdccsklREdHs3r1av7+979/5+e0dHPbiaeMSktLCQ8PJzw8XKsviEiLTNOEutrjrstsHjKprsL0vm4Km1WtujmoGcPwzF5GRntmM08InUbjbKc3dIZHKHSKtFKXDKidRUlJCXv37mXp0qWMHz8egK1bt3q3Dx48mJUrV1JaWtriLOrgwYN5//33ueaaa1
r8/HPPPbfZ42Ttdjuu0/wP98cff0yvXr18rkk9fPjwae17Kjk5OT6XH4hI12CapucGoJZOmVeeeAq9zHMn+5myWj0
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 5.21e-03\n",
" final error(valid) = 8.95e-02\n",
" final acc(train) = 1.00e+00\n",
" final acc(valid) = 9.76e-01\n",
" run time per epoch = 1.92\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.50\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABmCklEQVR4nO3dd3xUVf7G8c+5mUmvkEBCIBBKkGJoAoKKKGKDVXHRtYtl3V11sa2uXXGxoOuqP9s2VyyrIiqKoIKorCIgRSUgSA89kEB6n8z9/THJwJCAJARmMjzv1ytm5rb5Xg6JD+eee66xbdtGRERERCRAWP4uQERERERkXwqoIiIiIhJQFFBFREREJKAooIqIiIhIQFFAFREREZGAooAqIiIiIgFFAVVEREREAooCqoiIiIgEFAVUEREREQkoCqgiIiIiElAc/i6gueXn5+NyufxdhjSTpKQkcnNz/V2GHAFq2+Ckdg1eatvgdTTb1uFwkJCQ8MvbHYVajiqXy0V1dbW/y5BmYIwBPG1q27afq5HmpLYNTmrX4KW2DV6B2ra6xC8iIiIiAUUBVUREREQCigKqiIiIiAQUBVQRERERCShBd5OUiIiI+J9t25SUlATUjTfSsPLycqqqqprteGFhYYSFhR3WMRRQRUREpNmVlJQQFhZGaGiov0uRX+B0OpttBiTbtikvL6e0tJSoqKgmH0eX+EVERKTZ2batcHoMMsYQGRl52HPSK6CKiIiISLOqm1+1qRRQRURERCSgKKA2gW3b2GtW4F78jb9LEREREQk6CqhN8dMPuJ+6F/vtf2K79FhVERERaV7z5s1j2LBhuN3uI3L8W2+9lWuvvfaQt6+srGTgwIFkZWUdkXr2p4DaFD36QFwrKC6ErMX+rkZERESCzKOPPsr48eOxLE9Ue/rppxk5cmSzHf+RRx7hmWeeOeTtw8LC+P3vf8+jjz7abDUcjAJqE5iQEMzQ0wBwz5vj52pERETEH2zbbvBu9abOKVq33+LFi9m4cSOjR49u9DEOdbqo2NhY4uLiGnXsMWPGsGjRItauXdvouhpLAbWJzNAzPC9WfI+dv9u/xYiIiAQw27axKyv889XIBwXYts1LL73EkCFD6NKlC2eccQYzZswAYP78+aSmpjJ37lzOOecc0tPT+e677xg7diz33XcfDz/8ML179+bSSy8FYMGCBYwaNYr09HT69evHY4895hNoD7Tf9OnTGTZsGOHh4QBMmTKFv/3tb6xcuZLU1FRSU1OZMmUKAKmpqbz++utcc801dO3aleeee46amhruuOMOTjzxRLp06cIpp5zCv//9b5/z3P8S/9ixY3nggQeYOHEivXr1om/fvjz99NM++7Rq1YoBAwbw4YcfNurPtCk0UX8TmeRU6NYT1q7Env8FZtTF/i5JREQkMFVV4r7ZP/+ftF54F8LCD3n7SZMm8emnn/L444+Tnp7OwoULGT9+PK1bt/ZuM3HiRB588EHS0tKIjY0FYOrUqVx11VXe8LZjxw6uvPJKLr74Yp577jnWrVvHnXfeSVhYGHfccYf3WPvvB7Bw4UIuuOAC7/vzzjuP1atXM3fuXN555x0AYmJivOuffvpp7rnnHh5++GFCQkJwu92kpKTw97//nVatWrFkyRLuuusu2rRpw3nnnXfAc586dSo33HADH3/8MUuXLuW2225j4MCBDBs2zLtNv379+O677w75z7OpFFAPgzl5JPbaldjfzsE+96LDnvNLRERE/KesrIx//etfTJkyhRNOOAGAjh07snjxYt58800uv/xyAO68806f0AbQqVMn7r//fu/7J554gnbt2vHoo49ijKFr167k5OTw2GOPcdttt3nHlu6/H8DWrVtp27at931ERARRUVGEhITQpk2benVfcMEFXHLJJT7L/vSnP3lfp6WlsWTJEj7++OODBtQePXpw++23A9C5c2cmT57svVmrTnJyMlu3bj3gMZqLAuphMA
NOwn77n5CbA2t+gu69/V2SiIhI4AkN8/Rk+umzD9WaNWuoqKjwXmqvU11dTe/ee/8fn5mZWW/fPn36+Lxft24dAwYM8Om8GjhwIKWlpezYsYPU1NQG9wOoqKho1LPsGzrG66+/zttvv83WrVupqKigurqaXr16HfQ4PXr08Hnfpk0b8vLyfJaFh4dTXl5+yLU1lQLqYTBh4ZiBp2B/Mxt73ucYBVQREZF6jDGNuszuL3VTOr3++uskJyf7rAsNDWXTpk0AREZG1ts3IiLC571t2/WurDY0Hnb//cAz1rOwsPCQ696/nunTpzNhwgQeeOABTjjhBKKionj55Zf54YcfDnoch8M3Fhpj6k1zVVBQ4DPc4UhRQD1M5uSRnoD6/bfYl96AiYzyd0kiIiLSBBkZGYSFhbFt2zaGDBlSb31dQD0U3bp145NPPvEJqkuWLCE6OpqUlJSD7turVy/WrFnjs8zpdB7ynKiLFi1iwIABjBs3rkm1H8zPP//8iz2xzUF38R+u9AxI6QBVVdh6spSIiEiLFR0dze9+9zsefvhh3n33XbKzs1mxYgWTJ0/m3XcbN0Th6quvZvv27dx///2sW7eOWbNm8fTTT3PDDTd4x58eyPDhw1m82Hee9Q4dOrB582ZWrFjBnj17qKysPOD+nTp1Iisri7lz57J+/XqefPJJli1b1qj6D2TRokWceuqpzXKsg1FAPUzGGMzJnolz7Xmf+7kaERERORx33XUXt912Gy+88ALDhw/nsssu4/PPPyctLa1Rx0lJSeGNN97gxx9/ZOTIkdx9991ceuml3HLLLb+474UXXsiaNWtYt26dd9m5557L8OHDufjiizn++OMPOtXTlVdeyTnnnMMf/vAHfvWrX5Gfn8/VV1/dqPobsmTJEoqLixk1atRhH+uXGLuxE4QFuNzc3EOepLap3LbNvE3FlFXXcHa3BOziQtx3joOaGqyH/g/TvtMR/fxjhTGGlJQUduzY0eh57CSwqW2Dk9o1eDWlbYuKirxTMEnjTZw4kaKiIp588skj/llOp/OQstMNN9xA7969GT9+/C9ue6D2dzqdJCUl/eL+6kFtgiXbSnj62+1M/j6XosoaTEwc9BkEgP2tniwlIiIih2f8+PG0b9+empoaf5cCQGVlJT179uS3v/3tUfk8BdQmOCE1mvSEMMpdbt7/yfMUKavuMv/Cr7CPcA+uiIiIBLfY2FjGjx9PSEiIv0sBICwsjFtvvbXBWQeOBAXUJrCM4Yo+nu7pT9bks7usGnr1g/jWUFIMy478ExZEREREgpUCahMNaBdFj6QIqmps3l2xG2OFYIaeDoBbl/lFREREmkwBtYmMMVxZ24v6+boCdhRXYU46w7Pypx+w9+T6sToRERGRlksB9TD0ahtJv5Qoamx4JysP0yYFuh8Pto09/wt/lyciIiLSIimgHqa6saj/yy5iU0El5mRPL6r97RfYh/jEBxERERHZSwH1MHVtHc6QDjHYwH+X5WL6DYWISMjbCauX+7s8ERERkRZHAbUZXNYnEcvAd1tLWFvsxgwaBoA9TzdLiYiIiDSWAmozSIsLY3i652kJbyzL3fvo0+/nY5eW+LM0ERERaYHmzZvHsGHDcDfjcMFbb72Va6+91vt+7NixPPjggwfdZ/DgwfzrX/8CPJP1Dxw4kKysrGar6UAUUJvJJccn4rAgK6eMZWEpkNoRXNXYi772d2kiIiLSwjz66KOMHz8eyzpyUe1f//oXd9111yFvHxYWxu9//3seffTRI1ZTHQXUZtI2OpSzusYD8N9leXBSbS/qvM/9WJWIiIgcKbZt43K56i2vqqpq0vHq9lu8eDEbN25k9OjRh1XfL0lISCA6OrpR+4wZM4ZFixaxdu3aI1SVhwJqM7qodyKhIYY1uytY0ulEcDhg83rszRv8XZqIiIjf2LZNhcvtly/bthtd60svvcSQIUPo0qULZ5xxBjNmzABg/vz5pKamMnfuXM455xzS09
P57rvvGDt2LPfddx8PP/wwvXv35tJLLwVgwYIFjBo1ivT0dPr168djjz3mE2gPtN/06dMZNmwY4eHhAKxbt47U1FT
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABp00lEQVR4nO3deXxU1f3/8dedzEzInpAdAhJWWWSTRcGKBZVWqSxfXGrVolKpQK211W/9qu1XRS1qi9XWFn9asfoFqQuCsQJGaEFQcQVZCkRAgSQkIfuemTm/PyYZGBKWsM1k8n4+Hjwyc+femzPzIfDOueecaxljDCIiIiIiQcIW6AaIiIiIiBxOAVVEREREgooCqoiIiIgEFQVUEREREQkqCqgiIiIiElQUUEVEREQkqCigioiIiEhQUUAVERERkaCigCoiIiIiQUUBVURERESCij3QDTjdSkpKcLlcgW6GnCbJyckUFhYGuhlyBqi2oUl1DV2qbeg6m7W12+0kJCQcf7+z0JazyuVy0dDQEOhmyGlgWRbgrakxJsCtkdNJtQ1NqmvoUm1DV7DWVpf4RURERCSoKKCKiIiISFBRQBURERGRoKKAKiIiIiJBJeQmSR2NMYbKysqgGgAsx1dTU0N9fX2Lr4WHhxMeHn6WWyQiIiJnWqsD6tatW1m2bBm7d++mpKSEX/3qV4wYMeK4x7z00kvs27ePhIQErrrqKi6//HK/fT766CMWL17MgQMHSE1N5Yc//OFxz9salZWVhIeH43Q6T9s55cxzOBwtrspgjKGmpoaqqiqioqIC0DIRERE5U1p9ib+uro5u3bpxyy23nND+BQUFPPbYY/Tt25e5c+cyefJkXnzxRT766CPfPjt27OCpp57i4osv5oknnuDiiy9m3rx57Ny5s7XNOypjjMJpCLEsi8jISK15KyIiEoJa3YM6ZMgQhgwZcsL7r1y5kqSkJKZNmwZARkYGX3/9NW+//TYXXHABAO+88w4DBw5k8uTJAEyePJmtW7fyzjvvcOedd7a2idKONK3fJiIiIqHjjE+S2rlzJwMHDvTbNnjwYHbt2uXr/dqxY0ezfQYNGsSOHTvOdPNEREREJMic8UlSpaWlxMXF+W2Li4vD7XZTUVFBQkICpaWlxMfH++0THx9PaWnpUc/b0NDgNzbRsiwiIiJ8j6X9UL3bnqaaqXahRXUNXart2WM8HvB4wBzx9cg/5iiPm+3jBo/HO0m82etu8Bjq+/bHiooP9Fv3c1Zm8R/5F7ppJv2x/qIbY475+pIlS3j99dd9zzMzM5k7dy7Jyckt7l9TU4PD4WhNs9uFnJwcJk6cyMcff0x0dPRpP//jjz/Ou+++y+rVq0/4mMsvv5w77riDCRMmAByzbk6nk/T09FNupwRGWlpaoJsgZ4DqGrraS22NMdBQj6mrwzTUY+pqMfV1mPr6xq+1hz1u/FN3lNfr6jANdYed6yjH19dBgOZVVE24mrTb/zsg3/toznhAbakntLy8nLCwMF8gammfsrKyZj2vh5s8ebIvwMChsFtYWNjixJn6+voWZ4O3d3PmzOHHP/4x4eHhNDQ0sHjxYv73f/+Xbdu2nZbz33bbbfz4xz9u1Wf/85//nIceeojLLrvM166jqa+vJy8v73Q0Vc4iy7JIS0sjPz9fS7+FENU1dJ3p2hpjoK4GqiqhqgLT+JWqSnA1gNsNbhe43ZjGrxz51XWU7W4X5ijbmx9/2LZgZdnAZjV+PeKPFdb4+Fiv+79m2cKwp3U+az+3drv9qJ2Jfvud6Yb06tWLzz77zG/bxo0b6d69O3a799v37t2br776yi9wbtq0id69ex/1vA6H46g9a/qH8cTk5uby3nvv8eCDD7b62Pr6+hNaFSEqKqrVy0CNGzeOu+++m3/961+MHz/+uPur3m2XMUb1C0Gqa+g6Xm2NxwM11YfCZVUFpqoCqiuP2Ob/nOpKbzAMRpYNnE5wHPbH77kDHO
FYzkOPcTrB3rTf4dscWM7wI87T+LrDAXb7oZBpWS0ETNtpH2ZhWRYx6elU5uUF1c9tqwNqbW0t+fn5vucFBQXs2bOH6OhokpKSWLhwIcXFxcyePRvwXq5dsWIFL730EuPGjWPHjh2sWrWKn//8575zXHHFFfz2t7/lrbfeYvjw4XzyySd89dVXPPTQQ6fhLTZnjIH6ujNy7uNyhrfqL9fq1av54x//yPbt27HZbJx//vk89NBDdOvWDfCGzIcffpg1a9ZQV1dHr169eOSRRxg6dCjgXUVh3rx5bN++ncjISC644AKef/55AN5++2369etHp06dAFi/fj133XUXAJ07dwbgrrvu4pe//CUjR47khz/8IXv27GH58uWMHz+eP/7xjzzyyCO8++675OXlkZKSwuTJk/nFL37h++Xh97//PcuXL+e9994D4M4776S8vJwRI0Ywf/586uvrmThxIg8++KDvmLCwMMaOHctbb711QgFVREROL2MM1NZAZTlUlmMqK6jaZsOzfx+mssIXLk31YSGzqtIbNE8l5NgdEB0DUTEQFQ0RUVgOJ4SFQZi9lV8PPbbsrTzWbm8Mjd7vrbG3Z1+rA+rXX3/t1+P297//HYAxY8Ywa9YsSkpKKCoq8r2ekpLCvffey0svvcSKFStISEjg5ptv9i0xBdCnTx/uvPNOXn31VRYvXkxaWhp33nknvXr1OpX3dnT1dXhmX3Nmzn0ctj/9A8I7nPD+1dXV3HbbbZx77rlUV1fz5JNPMn36dFauXElNTQ1Tp04lLS2NF198keTkZL766is8Hg8A2dnZTJ8+nTvuuIOnn36a+vp63n//fd+5P/74Y7/VE4YNG8aDDz7Ik08+yZo1awD8ej//+te/cuedd/r9chEVFcW8efNIS0tj27Zt3HPPPURHRzNz5syjvqf169eTkpLCa6+9xu7du7n99tvp378/P/rRj3z7DB48mL/85S8n/DmJiMjRmYZ6qCiHyjJv4Kwoh8oKXwD1bitrfNy4/YjL3MWt+YbhEd6AGRXtC5tWU+iMioHIpucx/vs4dXdA8Wp1QO3fvz//+Mc/jvr6rFmzmm3r168fc+fOPeZ5L7jgAr/QKl5XXnml3/Pf//73DBw4kB07dvDpp59y8OBB3nnnHRISEgDvZLEmTz/9NBMnTuRXv/qVb1v//v19j/fu3ct5553ne+50OomJicGyLFJSUpq1ZfTo0fz0pz/123b4OrVdunTh66+/ZtmyZccMqHFxcTzyyCOEhYXRs2dPxo0bxwcffOAXUNPT09m/f78vbIuIiJdxu6GqHCoqDuvhbAyajSHTVB4RQOtqT+6bhXeA6FiIiqFDUjJ1didEHiN4RkV7w6ddk5Ll1JyVWfxBxxnu7ckM0PdujT179vDEE0/w+eefU1xc7Ats+/fvZ8uWLQwYMMAXTo+0ZcsWv9B3pNraWjp0OPHe3CPXqgXIysri+eefZ8+ePVRVVeF2u4+7GkDv3r0JCwvzPU9NTW02KatDhw54PB7q6up8Y5VFRNoLU1UBBfmYglwozIfCfExhHhTkQ1mr+jIPCbN7w2Z0DETHYsXENT6P9W23YmIhOu7Q88b/syzLIjk9nbwgG6cooatd/s9vWVarLrMH0rRp0+jUqROPP/44aWlpeDwexo4dS0NDw3HD5fFe79ix4zHXmj1SZGSk3/PPPvuMmTNn8stf/pJLLrmEmJgYli5dynPPPXfM87Q0ue3If/BKSkqIiIggIiJCqy+ISMgxHg+UFh8KnoX5UJCHKcyHwjyorjr2CSzL21sZfShQesNljC9wWtGxcHgI7RChsZTSZrTLgNpWFBcXs3PnTubOncvIkSMB2LBhg+/1vn37smjRIkpKSlrsRe3bty8ffPAB1157bYvnHzBgADt37vTb5nQ6cZ/gTMpPPvmEjIwMvzGp+/fvP6Fjj2f79u1+ww9ERNoa43LBwQIobAyeBYe+UnQAGuqPfYL4jpCchpWcDslpkJ
LufZyY7O3dtIUd+3iRNkwBNYjFx8eTkJDAK6+8QkpKCvv37+exxx7zvT5p0iSeeeYZbr31Vu69915SUlLYvHkzqam
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 7.20e-03\n",
" final error(valid) = 1.02e-01\n",
" final acc(train) = 1.00e+00\n",
" final acc(valid) = 9.74e-01\n",
" run time per epoch = 1.96\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=1.00\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABn50lEQVR4nO3deXhU1eH/8feZJfsKBBISAmEJexAQUVRAFK1CVazVulRxqbZoqda6VbFqEYt+1fqrW1u1uFRFXBBxAUGxsgvKJsoe1gQCZN8nc39/TDIwJEASksxk+LyeZ56ZuXPvnXNzCHw49yzGsiwLEREREZEAYfN3AUREREREDqeAKiIiIiIBRQFVRERERAKKAqqIiIiIBBQFVBEREREJKAqoIiIiIhJQFFBFREREJKAooIqIiIhIQFFAFREREZGAooAqIiIiIgHF4e8CNLXc3FxcLpe/iyFNJCEhgZycHH8XQ5qB6jY4qV6Dl+o2eLVk3TocDuLj44+/XwuUpUW5XC4qKyv9XQxpAsYYwFOnlmX5uTTSlFS3wUn1GrxUt8ErUOtWt/hFREREJKAooIqIiIhIQFFAFREREZGAooAqIiIiIgEl6AZJiYiIiP9ZlkVRUVFADbyRupWWllJRUdFk5wsNDSU0NPSEzqGAKiIiIk2uqKiI0NBQQkJC/F0UOQ6n09lkMyBZlkVpaSnFxcVERkY2+jy6xS8iIiJNzrIshdOTkDGGiIiIE56TXgFVRERERJpUzfyqjaWAKiIiIiIBRQG1ESzLwtq4Dve33/i7KCIiIiJBRwG1MX74HveTf8Z6+19YJ9jHQkRERORICxcuZPjw4bjd7mY5/x133MGNN95Y7/3Ly8sZMmQIa9asaZbyHEkBtTF6D4DYeCjMh7Ur/F0aERERCTKPPfYYEydOxGbzRLWnnnqK0aNHN9n5H330UZ555pl67x8aGspvf/tbHnvssSYrw7EooDaCsdsxp58DgHvRPD+XRkRERPzBsqw6R6s3dk7RmuO+/fZbtm3bxtixYxt8jvpOFxUTE0NsbGyDzj1u3DiWL1/Opk2bGlyuhlJAbSRz5nmeF2tXYOUd9G9hREREAphlWVjlZf55NHChAMuyeOGFFzjjjDPo1q0b5513HrNnzwZg8eLFJCcns2DBAi688ELS0tJYtmwZl19+OQ888AAPP/ww/fr146qrrgJgyZIljBkzhrS0NAYOHMiUKVN8Au3Rjps1axbDhw8nLCwMgOnTp/P000+zfv16kpOTSU5OZvr06QAkJyfz+uuvc8MNN9C9e3eeffZZqqqquOuuuzj99NPp1q0bZ599Ni+//LLPdR55i//yyy9n0qRJTJ48mb59+3LKKafw1FNP+RzTpk0bBg8ezMyZMxv0M20MTdTfSCYpBbr1gi0/YS39CvOzX/i7SCIiIoGpohz37Vf45attz70LoWH13n/q1Kl89tlnPP7446SlpbF06VImTpxI27ZtvftMnjyZhx56iNTUVGJiYgCYMWMG1113nTe8ZWVl8etf/5orrriCZ599ls2bN3P33XcTGhrKXXfd5T3XkccBLF26lEsvvdT7/uKLL2bDhg0sWLCAd955B4Do6Gjv50899RT3338/Dz/8MHa7HbfbTVJSEi+99BJt2rRhxYoV3HPPPbRv356LL774qNc+Y8YMbrnlFj7++GNWrlzJnXfeyZAhQxg+fLh3n4EDB7Js2bJ6/zwbSwH1BJgzz8Pa8hPWovlYF1x2wnN+iYiIiP+UlJTw73//m+nTp3PqqacC0LlzZ7799lvefPNNrrnmGgDuvvtun9AG0KVLFx588EHv+7/97W907NiRxx57DGMM3bt3Jzs7mylTpnDnnXd6+5YeeRzArl276NChg/d9eHg4kZGR2O122rdvX6vcl156Kb/61a98tv3pT3/yvk5NTWXFihV8/PHHxwyovXv35o9//CMAXbt2Zdq0ad7BWjUSExPZtWvXUc/RVB
RQT4AZchbWO/+G7F2wdYOnRVVERER8hYR6WjL99N31tXHjRsrKyry32mtUVlbSr18/7/uMjIxaxw4YMMDn/ebNmxk8eLBP49WQIUMoLi4mKyuL5OTkOo8DKCsra9Ba9nWd4/XXX+ftt99m165dlJWVUVlZSd++fY95nt69e/u8b9++Pfv37/fZFhYWRmlpab3L1lgKqCfAhEVgBp+JteRLrEXzMAqoIiIitRhjGnSb3V9qpnR6/fXXSUxM9PksJCSE7du3AxAREVHr2PDwcJ/3lmXVurNaV3/YI48DT1/P/Pz8epf7yPLMmjWLRx55hEmTJnHqqacSGRnJiy++yPfff3/M8zgcvrHQGFNrmqu8vDyf7g7NRQH1BJmzzvME1OXfYF15M6YV/AKKiIhIbenp6YSGhrJ7927OOOOMWp/XBNT66NGjB59++qlPUF2xYgVRUVEkJSUd89i+ffuyceNGn21Op7Pec6IuX76cwYMHM378+EaV/Vh++umn47bENgWN4j9RPfpC+yQoL8VaucjfpREREZFGioqK4tZbb+Xhhx/m3XffJTMzk3Xr1jFt2jTefbdhXRSuv/569uzZw4MPPsjmzZuZM2cOTz31FLfccou3/+nRjBw5km+//dZnW6dOndixYwfr1q3j4MGDlJeXH/X4Ll26sGbNGhYsWMCWLVt44oknWL16dYPKfzTLly9nxIgRTXKuY1FAPUHGGMywcwGwNCeqiIhIq3bPPfdw55138txzzzFy5EiuvvpqvvjiC1JTUxt0nqSkJN544w1WrVrF6NGjue+++7jqqqv4wx/+cNxjL7vsMjZu3MjmzZu92y666CJGjhzJFVdcQf/+/Y851dOvf/1rLrzwQn73u9/x85//nNzcXK6//voGlb8uK1asoLCwkDFjxpzwuY7HWA2dICzA5eTk1HuS2sZyuS0+25hLpdvisj5tsQ7ux33fzWC5sT32EqZ9x2b9/pOFMYakpCSysrIaPI+dBDbVbXBSvQavxtRtQUGBdwomabjJkydTUFDAE0880ezf5XQ665WdbrnlFvr168fEiROPu+/R6t/pdJKQkHDc49WC2girsop5eeU+3l6zn5ziSkybdtB3IADWoi/9XDoRERFp7SZOnEhKSgpVVVX+LgoA5eXl9OnTh9/85jct8n0KqI0wuGMk/TpEUFFl8dr3+wCwnVl9m3/xfCx3YPxhEhERkdYpJiaGiRMnYrfb/V0UAEJDQ7njjjvqnHWgOSigNoIxhpsGtccA32wv5Md9JTBgKERGQ94BWL/K30UUERERabUaNc3UnDlzmDVrFnl5eaSkpDB+/Phak7vWWLZsGXPnziUzMxOXy0VKSgq//OUvOeWUU7z7LFiwgBdeeKHWsW+++SYhISGNKWKz69omjNHdY5m7OZ+XV+7jyZ91xpw+Emv+x1gL52H6DfZ3EUVERERapQYH1MWLFzNt2jRuvvlmevbsybx585gyZQrPPPMM7dq1q7X/jz/+SEZGBldddRWRkZF89dVXTJ06lSlTppCWlubdLzw8nGeffdbn2EANpzWuyUjgm8xCNh8sY8G2As458zxPQF21DKuwABOtzuEiIiIiDdXgW/yzZ89m1KhRnHvuud7W03bt2jF37tw69x8/fjyXXHIJ3bt3JykpiauvvpqkpCRWrlzps58xhri4OJ9HoIsLd3BFf89qCq+vyqEssTOkdoMqF9byr/1cOhEREZHWqUEtqC6Xi61bt3LppZf6bM/IyGDDhg31Oofb7aa0tJSoqCif7WVlZUyYMAG3202XLl248sorfVpYj1RZWekzJYIxxttx98ilxZrTxb3aMGdTHtlFlXyw/gBXnzUa91tbsBbNx5x3cYuVIxjV1GNL1qe0DNVtcFK9Bi/VrTTGifx5aVBALSgowO12Exsb67M9NjaWvLy8ep1j9uzZlJeX+ywh1rFjRyZMmEBqaiqlpaV8+umnTJo0iSeffPKoy4F9+OGHvPfee973aW
lpTJ06tV5zazW1u84L4e6Za5n5Yy6/uvJimPEq7NxK25J8Qrr1avHyBJsj10OW4KG6DU6q1+DVkLotLS3F6XQ2Y2m
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 1.73e-02\n",
" final error(valid) = 1.31e-01\n",
" final acc(train) = 9.98e-01\n",
" final acc(valid) = 9.66e-01\n",
" run time per epoch = 1.92\n"
]
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"learning_rate = 0.2 # learning rate for gradient descent\n",
"\n",
"init_scales = [0.1, 0.2, 0.5, 1.] # scale for random parameter initialisation\n",
"final_errors_train = []\n",
"final_errors_valid = []\n",
"final_accs_train = []\n",
"final_accs_valid = []\n",
"\n",
"for init_scale in init_scales:\n",
"\n",
" print('-' * 80)\n",
" print('learning_rate={0:.2f} init_scale={1:.2f}'\n",
" .format(learning_rate, init_scale))\n",
" print('-' * 80)\n",
" # Reset random number generator and data provider states on each run\n",
" # to ensure reproducibility of results\n",
" rng.seed(seed)\n",
" train_data.reset()\n",
" valid_data.reset()\n",
"\n",
" # Alter data-provider batch size\n",
" train_data.batch_size = batch_size \n",
" valid_data.batch_size = batch_size\n",
"\n",
" # Create a parameter initialiser which will sample random uniform values\n",
" # from [-init_scale, init_scale]\n",
" param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
" # Create a model with three affine layers\n",
" hidden_dim = 100\n",
" model = MultipleLayerModel([\n",
" AffineLayer(input_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, output_dim, param_init, param_init)\n",
" ])\n",
"\n",
" # Initialise a cross entropy error object\n",
" error = CrossEntropySoftmaxError()\n",
"\n",
" # Use a basic gradient descent learning rule\n",
" learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
" stats, keys, run_time, fig_1, ax_1, fig_2, ax_2 = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)\n",
"\n",
" plt.show()\n",
"\n",
" print(' final error(train) = {0:.2e}'.format(stats[-1, keys['error(train)']]))\n",
" print(' final error(valid) = {0:.2e}'.format(stats[-1, keys['error(valid)']]))\n",
" print(' final acc(train) = {0:.2e}'.format(stats[-1, keys['acc(train)']]))\n",
" print(' final acc(valid) = {0:.2e}'.format(stats[-1, keys['acc(valid)']]))\n",
" print(' run time per epoch = {0:.2f}'.format(run_time * 1. / num_epochs))\n",
" \n",
" final_errors_train.append(stats[-1, keys['error(train)']])\n",
" final_errors_valid.append(stats[-1, keys['error(valid)']])\n",
" final_accs_train.append(stats[-1, keys['acc(train)']])\n",
" final_accs_valid.append(stats[-1, keys['acc(valid)']])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| init_scale | final error(train) | final error(valid) | final acc(train) | final acc(valid) |\n",
"|------------|--------------------|--------------------|------------------|------------------|\n",
"| 0.1 | 5.52e-03 | 8.77e-02 | 1.00 | 0.98 |\n",
"| 0.2 | 5.21e-03 | 8.95e-02 | 1.00 | 0.98 |\n",
"| 0.5 | 7.20e-03 | 1.02e-01 | 1.00 | 0.97 |\n",
"| 1.0 | 1.73e-02 | 1.31e-01 | 1.00 | 0.97 |\n"
]
}
],
"source": [
"print('| init_scale | final error(train) | final error(valid) | final acc(train) | final acc(valid) |')\n",
"print('|------------|--------------------|--------------------|------------------|------------------|')\n",
"# Print one table row per initialisation scale, pairing each scale with its final statistics\n",
"for init_scale, err_train, err_valid, acc_train, acc_valid in zip(\n",
"        init_scales, final_errors_train, final_errors_valid,\n",
"        final_accs_train, final_accs_valid):\n",
"    print('| {0:.1f}        | {1:.2e}           | {2:.2e}           | {3:.2f}             | {4:.2f}             |'\n",
"          .format(init_scale, err_train, err_valid, acc_train, acc_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Models with four affine layers"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.10\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 2.03e-03\n",
" final error(valid) = 1.35e-01\n",
" final acc(train) = 1.00e+00\n",
" final acc(valid) = 9.73e-01\n",
" run time per epoch = 2.25\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.20\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 1.99e-03\n",
" final error(valid) = 1.17e-01\n",
" final acc(train) = 1.00e+00\n",
" final acc(valid) = 9.75e-01\n",
" run time per epoch = 2.21\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.50\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABmwUlEQVR4nO3deXwV9b3/8decLftGQhYgSAIEWQRBBAQryCIutCyl2lrrRaVyRa6lVr21Lq170Vqstrb2VyteLWhdKIgVkMXKorKpyCIQFkGSQEL2/Szz++MkBw4JSyDJOTl5Px+PPHJmzszkM+dL4M33O98ZwzRNExERERGRIGEJdAEiIiIiIidSQBURERGRoKKAKiIiIiJBRQFVRERERIKKAqqIiIiIBBUFVBEREREJKgqoIiIiIhJUFFBFREREJKgooIqIiIhIUFFAFREREZGgYgt0Ac2tqKgIl8sV6DKkmXTs2JH8/PxAlyEtQG0bmtSuoUttG7pas21tNhsJCQln3q4VamlVLpcLp9MZ6DKkGRiGAXjb1DTNAFcjzUltG5rUrqFLbRu6grVtNcQvIiIiIkFFAVVEREREgooCqoiIiIgEFQVUEREREQkqITdJ6lRM06S8vDyoLgCWM6uqqqK2trbR98LCwggLC2vlikRERKSltZuAWl5eTlhYGA6HI9ClSBPY7fZG78pgmiZVVVVUVFQQFRUVgMpERESkpTQ5oO7YsYPFixezf/9+ioqKuOeeexgyZMgZ93n11Vf59ttvSUhI4Hvf+x5XXXWV3zaffvopb775JkeOHCElJYUf/ehHZzxuU5imqXAaQgzDIDIykpKSkkCXIiIiIs2sydeg1tTU0K1bN2699daz2v7o0aM89dRT9O7dmzlz5jB58mReeeUVPv30U982u3fv5rnnnuOKK67gmWee4YorrmDu3Lns2bOnqeVJO1N//zYREREJHU3uQR04cCADBw486+2XL19OUlIS06ZNA6BLly7s3buX9957j2HDhgHw/vvv079/fyZPngzA5MmT2bFjB++//z6zZ89uaokiIiIi0oa1+DWoe/bsoX///n7rLr74YlavXo3L5cJms7F7926uu+46v20GDBjAv//975YuT0RERNog0zTB9IDbAx4PeNx1Xx5wn/Da4276Nqan7qcYUD9SZ5y0jOFdd/LyifudvOxbrF8+6djeMwOz7rv3RL1fx0/8hG0aWaZu+5OXfcfDb9kEantfBJGxZ/rIW1WLB9Ti4mLi4uL81sXFxeF2uykrKyMhIYHi4mLi4+P9tomPj6e4uPiUx3U6nX6TZwzDICIiwvdazk52djZTp05l7dq1REdHN/vxn332WZYuXcqHH3541vtce+21zJo1i2uvvfastld7tz31baa2Cy1q17bBdDmhuhpqqqCmGmqqMWuqG6yjphqzum65toZjdhvuysq6AGeCpy4gmqY32Plee9ebnrrlk7fxmA33Od02pwuW0iwqrvsBxpT/CnQZflplFv/Jf1nV3+rpdH+JmaZ52vcXLlzI22+/7VvOyMhgzpw5dOzYsdHtq6qqsNvtTSm7XXjmmWe49dZbSUhIAOCNN97gwQcfJDs7u1mOP2vWLG6//fYmffa/+MUv+M1vfsN3v/tdgNPu63A4SEtLO+86JTBSU1MDXYK0ALVr05geD7hdmE6nNzy6XJguV91r50nrnZi1NZjV1XiqKzGrKjGrqzCrq/BUV2FWVdat919nVlfhqarErKkCl+uc6qxs5vNuURYLWCwYVitYbMdfW60YlrrvVqt3O6vN9/r49sbxXsv63kzf8gm9j/W9kb4eTvO0+5kNejFPOJ5pntQbW9+7atR1zp64fLpe3OPbGSf2zJ5mP1taZxKC7Pe2xQNqYz2hpaWlWK1WX49dY9uUlJQ06Hk90eTJk5kwYYJvuT7M5ufn42rkl6+2trbR2xW1Zzk5OSxbtoxf//rXvs+m/rM702dVW1t7Vn
dFqL9XaVM++1GjRlFaWsqHH37I+PHjT7tvbW0tubm5Z31sCQ6GYZCamkpeXp7uTRxCQqVdTdP09iRWVkJVBVRWYFZVQlW5b51ZWQHVld6wVxcecbvqll2Ybqe3p69uGfdJ25z43e0OzIna7BAeDo5w7/ewCIywcAiPAEeY93tYOISFY4RFENuhA6Xl5ceDjuENgb7XhuENdhbL8WXf+rp1lsb3NerXW6wnHb9uncXi/W494bXF4r9stfqOa1iaNgfcPOl7MDqX2s5mH8MwiGnF31ubzXbKzkS/7Vq6kJ49e7J582a/dV9++SWZmZnYbN4fn5WVxVdffeUXOLdu3UpWVtYpj2u320/Zs3amD9g0TaitOdtTaF6OsCYNf61evZo//OEP7Nq1C4vFwiWXXMKjjz5Kt27dAG/IfOyxx/j444+pqamhZ8+ePPHEEwwaNAjwTlKbO3cuu3btIjIykmHDhvG3v/0NgPfee48+ffrQqVMnANavX8/dd98NQOfOnQG4++67+cUvfsHQoUP50Y9+xIEDB1i6dCnjx4/nD3/4A0888QQffPABubm5JCcnM3nyZH7+85/72ubkIf7Zs2dTWlrKkCFDeOmll6itrWXixIk88sgjvn2sViujR4/mX//6F+PHjz/jZ9SW/yFs70zTVPuFoEC3q+lyQlUlVFb4AiZVlZgnvD4ePOuWK8uP71NdGdjhY8PwhkebDax1X7YTvtcHy/pAefKXX7D0bufd/oQw6gjHsJ19BDAMg5i0NMpzc9vE72xbqDHYBPr39mRNDqjV1dXk5eX5lo8ePcqBAweIjo4mKSmJ+fPnU1hYyKxZswC46qqrWLZsGa+++ipjxoxh9+7drFq1ip/97Ge+Y1x77bX8+te/5l//+heXXnopGzdu5KuvvuLRRx9thlNsRG0NnlnXt8yxz8Dyx396/5I4S5WVldx+++1ceOGFVFZW8rvf/Y7p06ezfPlyqqqqmDp1Kqmpqbzyyit07NiRr776Ck/dX6wrVqxg+vTp3HXXXTz//PPU1taycuVK37E/++wzvwlsgwcP5pFHHuF3v/sdH3/8MYDfTfD/8pe/MHv2bL+2i4qKYu7cuaSmprJz507uu+8+oqOjmTlz5inPaf369SQnJ/PWW2+xf/9+7rjjDvr27cuPf/xj3zYXX3wxf/7zn8/6cxKR0GM6a6G8DCpK676XYZaXQXkpVHi/m3XrfWG0qgJO8fS5JrNaISIKIiIhMtr7PSISIzLKuz48EhwO/xBZFyQN3+v6oGmt28Z+Utis28Zq9S0bFmvz1C/ShjU5oO7du5dHHnnEt/x///d/AIwcOZI777yToqIiCgoKfO8nJydz//338+qrr7Js2TISEhK45ZZbfLeYAujVqxezZ8/mjTfe4M033yQ1NZXZs2fTs2fP8zm3kHDy3Q2effZZ+vfvz+7du9m0aRPHjh3j/fff911DmpGR4dv2+eefZ+LEidxzzz2+dX379vW9PnToEBdddJFv2eFwEBMTg2EYJCcnN6hlxIgR/Pd//7ffuhNvA5aens7evXtZvHjxaQNqXFwcTzzxBFarlR49ejBmzBjWrl3rF1DT0tI4fPiwL2yLSNtlmiZUV3mDZV3g9AXLEwKo6Quede/VVJ/fDw6P8AbJyLqQGRGFceJyfdA8MXT6to0Gh0MTvkQCpMkBtW/fvvzzn/885ft33nlng3V9+vRhzpw5pz3usGHD/EJri3KEeXsyA8HRtGfHHzhwgGeeeYYtW7ZQWFjoC2yHDx9m+/bt9OvXzxdOT7Z9+3a/0Hey6upqwsPPvjf35NuFASxZsoS//e1vHDhwgIqKCtxu9xnvBpCVlYXVeryHICUlhZ07d/ptEx4ejsfjoaamxncpiIgEF9PjgeJCKDiCWXAECo5QWFuF++gRzPKS40Gzotx7reW5sFggKsb7FR0D0bEY9a+jYiE6xrscFXNS6IxQT6RIG9Yu/+U3DKNJw+yBNG3aNDp16sTTTz
9NamoqHo+H0aNH43Q6zxguz/R+hw4dTnsrr5NFRkb6LW/evJmZM2fyi1/8glGjRhETE8OiRYv461//etrjNHbt8Mn
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 3.07e-03\n",
" final error(valid) = 1.34e-01\n",
" final acc(train) = 1.00e+00\n",
" final acc(valid) = 9.71e-01\n",
" run time per epoch = 2.23\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=1.00\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABotUlEQVR4nO3dd3xUVf7/8deZzKQRUkghgQRIgERa6EV0QcWyCivioq6uq1hWdy2Iuro2XPVrWXRtP+sWd7GsiqAgggpiWxGpKgFBeuglIZ20mcz9/THJQEiAJCTMZHg/H495ZObOvXc+wzH45txzzjWWZVmIiIiIiPgJm68LEBERERE5lAKqiIiIiPgVBVQRERER8SsKqCIiIiLiVxRQRURERMSvKKCKiIiIiF9RQBURERERv6KAKiIiIiJ+RQFVRERERPyKAqqIiIiI+BW7rwtobvn5+bhcLl+XIc0kPj6enJwcX5chLUBtG5jUroFLbRu4TmTb2u12YmJijr3fCajlhHK5XDidTl+XIc3AGAN42tSyLB9XI81JbRuY1K6BS20buPy1bXWJX0RERET8igKqiIiIiPgVBVQRERER8SsKqCIiIiLiVwJukpSIiIj4nmVZlJSU+NXEG6lfWVkZlZWVzXa+kJAQQkJCjuscCqgiIiLS7EpKSggJCSE4ONjXpcgxOByOZlsBybIsysrKOHDgAG3atGnyeXSJX0RERJqdZVkKpychYwzh4eHHvSa9AqqIiIiINKua9VWbSgFVRERERPyKAmoTWO4qrPWrcS/9n69LEREREQk4TZokNW/ePGbPnk1BQQHJyclMmDCBHj161LvvkiVLmD9/PtnZ2bhcLpKTk7nkkkvo16+fd5+vvvqKl19+uc6xb731ln+OX1m3Gvczk6FtFNbA0zBBQb6uSERERALIwoULue+++/jqq6+w2Zq/P3HSpEkUFRXx73//u0H7V1RUcPrpp/Paa6+RmZnZ7PUcrtEBddGiRUydOpXrr7+ejIwMFixYwOOPP86zzz5LXFxcnf3Xrl1LZmYml19+OW3atOHLL79kypQpPP7446Smpnr3CwsL4/nnn691rF+GU4D03hDRFooLYd0q6NnP1xWJiIhIAHnssceYOHGiN5w+/fTTfPrpp3z22WfNcv5HHnmkUUuAhYSE8Ic//IHHHnuMadOmNUsNR9PoSD5nzhzOOussRo0a5e09jYuLY/78+fXuP2HCBMaOHUu3bt1ISkriiiuuICkpiRUrVtTazxhDdHR0rYe/MkFBmAHDAbCWL/RxNSIiIuILlmXVO1u9qWuK1hy3bNkytmzZwpgxYxp9joYuFxUZGUlUVFSjzj1u3DiWLl3Khg0bGl1XYzWqB9XlcrF582YuuuiiWtszMzNZt25dg87hdrspKysjIiKi1vby8nJuuukm3G43Xbp04bLLLqvVw3o4p9NZqxGMMYSFhXmftzQz+BdY/5uH9cN38Ns/YuxaUra51bTjiWhPObHUtoFJ7Rq4jrdtLcuCyormLKnhgkMaVbdlWbzyyiu8+eab7Nu3j9TUVCZNmsSYMWNYtGgRl1xyCf/973+ZMmUKa9eu5b///S/PPvssGRkZOBwOZsyYQUZGBu+//z7fffcdjz76KGvWrCE6OppLLrmEu+++G3t1Zhg/fny9x82ePZsRI0YQGhoKwLRp03jmmWcA6NixIwDPPPMMl112GR07duSJJ57gyy+/5JtvvuEPf/gDt99+O3fffTfffvstOTk5dOjQgauvvprrr7/e+z0Pv8Q/fvx4evToQUhICO+88w4Oh4Pf/e533Hnnnd5j2rVrx8CBA5k1axZ33XXXMf8sj+fvgkalqqKiItxud53EHRUVRUFBQYPOMWfOHCoqKjj11FO92zp06MBNN91Ep06dKCsr4+OPP2by5Mk89dRTJCUl1XuemTNnMmPGDO/r1NRUpkyZQnx8fGO+UpNZCfHseu0Z3AV5xOzbQdjAU4
99kDRJYmKir0uQFqK2DUxq18DVmLYtKyvD4XAAYFWUU3HLpS1V1lGF/H0mJji0wfs//vjjzJ07lyeffJK0tDQWL17MxIkTad++vTdYPvbYYzz00EN07tyZqKgojDFMnz6dCRMmMHfuXCzLIjc3l6uuuorLLruMl156iQ0bNnDnnXcSFhbG3XffDVDvcQ6HgyVLljBu3Djvn9+vf/1rNmzYwJdffsn06dMBTw9ozftPP/00DzzwAI8++ihBQUEEBQXRsWNH/vWvf9GuXTuWLVvGn/70Jzp06MDYsWMBsNlsGGO856ip5Q9/+AOffvopy5YtY+LEiQwbNowzzjjD++czcOBAli1b5j3uSIKDg4+Y4RqiSd1+9SXihqTkhQsXMn36dO66665aITc9PZ309HTv64yMDP785z/zySefcO2119Z7rnHjxtXq+q75/JycnONeHLahrH5D4atP2D9/NkEdupyQzzyZGGNITExkz549ulVegFHbBia1a+BqSttWVlZ6r3RazXSXoqZwOp0YW8MmM5eWlvLqq68ybdo0Bg0aBHjC4XfffcfUqVP57W9/C8Cf/vQnTjvtNO9xlmXRpUsX7rvvPu+2v/71ryQlJfF///d/GGPo0qULd9xxB48//ji33XYbNput3uOcTifbt28nPj7e++dnt9sJCwvDZrPRrl27WvsCXHTRRVxyySW1vssdd9zhfT527FiWLFnCzJkzueCCCwDPFW3LsnA6nTgcDizLokePHkyaNAmAlJQUXnvtNb7++uta3zUhIYFt27YdcyhBZWUlu3fvrrPdbrc3qDOxUQE1MjISm81Wp7e0sLDwmOMYFi1axKuvvsodd9xxzNlfNpuNrl27smfPniPu43A4jpjeT9RfjGbQL7C++gTr++9w//YPGPvR/zUhTWNZlv5nF6DUtoFJ7Rq4mty2wSHYXnyv+Qtq4Gc31Pr16ykvL+fyyy+vtd3pdNK7d2/v6/pyTN++fWu93rhxIwMHDqzVgTd48GAOHDjA7t27vZfqDz8OPMMeG3Mv+/rO8cYbb/DOO++wY8cOysvLcTqd9OrV66jnOXxFpoSEBHJzc2ttCw0NpaysrEF1Hc/fA40KqHa7nbS0NLKyshgyZIh3e1ZWFoMHDz7icQsXLuSVV17htttuY8CAAcf8HMuy2Lp1KykpKY0p74SqcLnJTexKYlQMFObD2izoM9DXZYmIiPgdYwyENPwyu6+43W7AE+4OH84QHBzM1q1bAQgPD69zbM08mBqWZdW5ulxfYDv8OPCM9SwsLGxw3YfXM3v2bB5++GEmT57MoEGDaNOmDa+88go//PDDUc9jP2w+jTHG+2dSo6CggNjY2AbX1lSNvsQ/ZswYXnjhBdLS0khPT2fBggXk5uZyzjnnAPD222+Tl5fHLbfcAnjC6UsvvcSECRNIT0/39r4GBwd7/0CnT59O9+7dSUpK8o5Bzc7O5rrrrmumr9m81uwr5a/f7KRdmJ2n+w+Hr+ZiLfsGo4AqIiLSaqWnpxMSEsLOnTtrzZWpURNQG6J79+58/PHHtYLq8uXLiYiIOObYzF69erF+/fpa2xwOR52weCRLly5l4MCBTJgwoUm1H83PP/98zJ7Y5tDogDp8+HCKi4t5//33yc/PJyUlhXvvvdc7niA/P79Wd/CCBQuoqqritdde47XXXvNuHzlyJDfffDMABw4c4B//+AcFBQWEh4eTmprKww8/TLdu3Y73+7WI5KgQKlwWW/IrWN79dAZ9NRfrxyVYTifmGIOGRURExD9FRERw44038tBDD+F2uxkyZAglJSUsX76c8PBwkpOTG3yuq6++mn/961888MADXHPNNWzatImnn36aG2644ZgL759xxhneyVA1UlJS2LZtG6tXr6ZDhw60adPmiMMAunTpwowZM/jqq69ISUnh/fffZ+XKlc1yZXrp0qUNmsF/vJo0Seq8887jvPPOq/e9mtBZ46GHHjrm+SZMmFAr5fu7yJAgLk
iP5oM1eUzfH87A6HaYgjxY8yP0PfJQBxEREfFvd999N3Fxcbz44ots27aNyMhI+vTpw6233trgHkyApKQk3nzzTR5
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABpyElEQVR4nO3deXxU9b3/8deZzEz2jYQskCBhlUWQXUHFskirVMBLtbW2RaVSgWuprbZeta170bZY/dVbrFasXtC6IBgVYgQLghVBZREFIiBLEpOQfZ/l/P6YZGBIAgSSzGTyfj4ePGbmbPmcfEl48z3n+z2GaZomIiIiIiIBwuLvAkRERERETqSAKiIiIiIBRQFVRERERAKKAqqIiIiIBBQFVBEREREJKAqoIiIiIhJQFFBFREREJKAooIqIiIhIQFFAFREREZGAooAqIiIiIgHF6u8C2lpJSQlOp9PfZUgb6d69O4WFhf4uQ9qB2jY4qV2Dl9o2eHVk21qtVuLj40+/XQfU0qGcTicOh8PfZUgbMAwD8LSpaZp+rkbakto2OKldg5faNngFatvqEr+IiIiIBBQFVBEREREJKAqoIiIiIhJQFFBFREREJKAE3SCplpimSWVlZUDdACynV1NTQ319fbPrQkNDCQ0N7eCKREREpL21OqDu3r2b1atXc+DAAUpKSvjVr37F2LFjT7vP888/z5EjR4iPj+fqq6/miiuu8NnmP//5Dy+//DLffPMNycnJ/OAHPzjtcVujsrKS0NBQ7HZ7mx1T2p/NZmt2VgbTNKmpqaGqqorIyEg/VCYiIiLtpdWX+Ovq6ujduzc33XTTGW1fUFDAI488wqBBg1i8eDGzZs3iueee4z//+Y93m7179/L4449z2WWX8dhjj3HZZZexZMkS9u3b19ryWmSapsJpEDEMg4iICM15KyIiEoRa3YM6YsQIRowYccbbZ2VlkZiYyJw5cwBIS0vjq6++4s033+Siiy4C4K233mLYsGHMmjULgFmzZrF7927eeustFi1a1NoSpQtpnL9NREREgke734O6b98+hg0b5rPswgsvZP369TidTqxWK3v37uWqq67y2Wb48OG8/fbbLR7X4XD4XPo1DIPw8HDve+k61N6dT2Obqe2Ci9o1eKltg1egtm27B9TS0lJiY2N9lsXGxuJyuaioqCA+Pp7S0lLi4uJ8tomLi6O0tLTF465cuZJXX33V+zkjI4PFixfTvXv3ZrevqanBZrOd9XmI/5yq3ex2O6mpqR1YjbSllJQUf5cg7UDtGrzUtoHNNE1wuTCdDnA4MB31mE4npqMenA5MhwPT6Tjhs2ddfVUpKX3P93f5PjpkFP/JqbxxJP2p0rppmqdcP2vWLKZPn97kaxQWFjZ7X2J9fb0egdqMnJwcZs+ezQcffEBUVFSbH/9Pf/oTa9as4d133z3jfa688koWLlzIlVde2eIgqUb19fXk5eW1RanSgQzDICUlhfz8fM2sEUTUrsFLbXt6pmmC0wH19eCoa3ith/o6cNRjnri8YVnjtmbjtg4HuJwNrw5wOj3HdDo9odPZuKxhuXfbE17Pon2ipn+P4llxHdK2Vqu1xc5En+3au5DmekLLy8sJCQnxBqLmtikrK2vS83oim83WYs+afnjO3OLFi/nJT37ibYuXX36Z3//+93zxxRdtcvyf/exn3Hjjja3aZ9GiRdx///18+9vfPqPt1d6dl2maar8gpHYNXoHUtqbbDS7XCUHO6QloLqfvZ2+Yc3k/m43hzuVq2O6kYzgcDYHREyZNb5g8KVw2vnfUefYJkO+ND6sNbDYIsXreW5u+hiT3DKi2hQ4IqP3792fbtm0+y7Zv306fPn2wWj1ffsCAAezcudOnR3THjh0MGDCgvcvr0nJzc3n33Xe57777Wr1vfX39Gc2KEBkZ2eppoCZPnswdd9zB+++/z7Rp01pdm4iIBBbT7YbaGqiphppKqK6GmirM6iqoqY
LqqoZ1nvdmTePn6uO9ho1hszFMut3+Pq2WGRaw28Fmb3gNPeG9Heyez8ZJn7HZmg+RIVYMawvrWnoNsUJIyGnvLTUMg5jUVKoC7GpkqwNqbW0t+fn53s8FBQUcPHiQqKgoEhMTWb58OcXFxSxcuBCAK664grVr1/L8888zefJk9u7dy7p16/j5z3/uPcaVV17J7373O9544w3GjBnDxx9/zM6dO7n//vvb4BSbMk3T878ef7CHtupG5PXr1/OXv/yFPXv2YLFYGDVqFPfffz+9e/cGPCHzgQceYMOGDdTV1dG/f38eeughRo4cCXhmUViyZAl79uwhIiKCiy66iGeeeQaAN998k8GDB9OjRw8ANm/ezO233w5Az549Abj99tv55S9/ybhx4/jBD37AwYMHWbNmDdOmTeMvf/kLDz30EO+88w55eXkkJSUxa9YsfvGLX3h7t0++xL9o0SLKy8sZO3YsS5cupb6+nhkzZnDfffd59wkJCWHSpEm88cYbCqgiIgHAdDhwlRZj5h9tCJWVUFPteX9SyPSEyxNCZ3UV1Fa3f++iYRwPaCENf6wnvJ68/IR1RogNQkJOWG5rGjDtoWC3Y/gsa9jmxJBpt3sCZYANOupsWh1Qv/rqK58et3/+858ATJw4kQULFlBSUkJRUZF3fVJSEnfddRfPP/88a9euJT4+nhtvvNE7xRTAwIEDWbRoES+99BIvv/wyKSkpLFq0iP79+5/LubWsvg73wmvb59inYfl//4LQsDPevrq6mltuuYXzzz+f6upq/vjHPzJ37lyysrKoqalh9uzZpKSk8Nxzz9G9e3d27tyJu+F/ldnZ2cydO5fbbruNJ554gvr6et577z3vsT/66COfGRZGjx7Nfffdxx//+Ec2bNgA4NP7+be//Y1Fixb5/OciMjKSJUuWkJKSwhdffMGdd95JVFQU8+fPb/GcNm/eTFJSEq+88goHDhzg1ltvZciQIfzwhz/0bnPhhRfyv//7v2f8fRIRkaZMp+N4T2RtNdTUeHouG9/XNvRaNrw3fbY94b3TSW5bFGS1Qnik509Ew5/wCIyIKAiP8FlnhEd4ltnsLYTLxjDqCZaGJaQtKpQA0eqAOmTIEP71r3+1uH7BggVNlg0ePJjFixef8rgXXXSRT2gVj5On3/rTn/7EsGHD2Lt3L1u3buXYsWO89dZbxMfHA57ZDBo98cQTzJgxg1/96lfeZUOGDPG+P3z4MBdccIH3s91uJzo6GsMwSEpKalLLhAkT+NnPfuaz7MR5atPT0/nqq69YvXr1KQNqbGwsDz30ECEhIfTr14/JkyfzwQcf+ATU1NRUjh496g3bIiJdhecqX70nGNbWHA+LtTWeYOm9VH48SJo1Jy6vOv7e2caDgxtD48lBMuKE0BkeAeFRGI3vT1hn2PTAHDkzHTKKP+DYQz09mX762q1x8OBBHnvsMT755BOKi4u9ge3o0aN8/vnnDB061BtOT/b555/7hL6T1dbWEhZ25r25J89nC5CZmckzzzzDwYMHqaqqwuVynXY2gAEDBhAScvx/usnJyU0GZYWFheF2u6mrq/PeqywiEqi8I7gbg6E3WJ4QKlsKmyfvU1vT9vdXhoZBWASEh3vCYli4p+cyPKJh+Qmv4REYPtt6tkvN6EN+QUFADaSR4NUl/+U3DKNVl9n9ac6cOfTo0YNHH32UlJQU3G43kyZNwuFwnDZcnm59t27dTjnX7MkiIiJ8Pm/bto358+fzy1/+kssvv5zo6GhWrVrF008/fcrjNDf7wsm/8EpKSggPDyc8PFzTg4mIX5im6QmOJcegpAiz9Njx9yXHoLTY01tZV+O5RO5q40cvG4YnSIaGNwTIcO8fIzzy+LLwSE+YbAiSnBw6w8IxQs7t8rdhGOd8DJHW6JIBtbMoLi5m3759LF68mHHjxgGwZcsW7/pBgwaxYsUKSkpKmu1FHTRoEB988AHXXXdds8cfOnQo+/bt81lmt9txuVxnVN
/HH39MWlqazz2pR48ePaN9T2fPnj0+tx+IiLQl0zShsgJKiqDkGGbDq28QPeYJn60V2hAkG0KjJ1RGYHiXhR8PkKH
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 7.72e-03\n",
" final error(valid) = 1.62e-01\n",
" final acc(train) = 1.00e+00\n",
" final acc(valid) = 9.62e-01\n",
" run time per epoch = 2.23\n"
]
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"learning_rate = 0.2 # learning rate for gradient descent\n",
"\n",
"init_scales = [0.1, 0.2, 0.5, 1.] # scale for random parameter initialisation\n",
"final_errors_train = []\n",
"final_errors_valid = []\n",
"final_accs_train = []\n",
"final_accs_valid = []\n",
"\n",
"for init_scale in init_scales:\n",
"\n",
" print('-' * 80)\n",
" print('learning_rate={0:.2f} init_scale={1:.2f}'\n",
" .format(learning_rate, init_scale))\n",
" print('-' * 80)\n",
" # Reset random number generator and data provider states on each run\n",
" # to ensure reproducibility of results\n",
" rng.seed(seed)\n",
" train_data.reset()\n",
" valid_data.reset()\n",
"\n",
" # Alter data-provider batch size\n",
" train_data.batch_size = batch_size \n",
" valid_data.batch_size = batch_size\n",
"\n",
" # Create a parameter initialiser which will sample random uniform values\n",
" # from [-init_scale, init_scale]\n",
" param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
" # Create a model with four affine layers\n",
" hidden_dim = 100\n",
" model = MultipleLayerModel([\n",
" AffineLayer(input_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, output_dim, param_init, param_init)\n",
" ])\n",
"\n",
" # Initialise a cross entropy error object\n",
" error = CrossEntropySoftmaxError()\n",
"\n",
" # Use a basic gradient descent learning rule\n",
" learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
" stats, keys, run_time, fig_1, ax_1, fig_2, ax_2 = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)\n",
"\n",
" plt.show()\n",
"\n",
" print(' final error(train) = {0:.2e}'.format(stats[-1, keys['error(train)']]))\n",
" print(' final error(valid) = {0:.2e}'.format(stats[-1, keys['error(valid)']]))\n",
" print(' final acc(train) = {0:.2e}'.format(stats[-1, keys['acc(train)']]))\n",
" print(' final acc(valid) = {0:.2e}'.format(stats[-1, keys['acc(valid)']]))\n",
" print(' run time per epoch = {0:.2f}'.format(run_time * 1. / num_epochs))\n",
"\n",
" final_errors_train.append(stats[-1, keys['error(train)']])\n",
" final_errors_valid.append(stats[-1, keys['error(valid)']])\n",
" final_accs_train.append(stats[-1, keys['acc(train)']])\n",
" final_accs_valid.append(stats[-1, keys['acc(valid)']])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| init_scale | final error(train) | final error(valid) | final acc(train) | final acc(valid) |\n",
"|------------|--------------------|--------------------|------------------|------------------|\n",
"| 0.1 | 2.03e-03 | 1.35e-01 | 1.00 | 0.97 |\n",
"| 0.2 | 1.99e-03 | 1.17e-01 | 1.00 | 0.97 |\n",
"| 0.5 | 3.07e-03 | 1.34e-01 | 1.00 | 0.97 |\n",
"| 1.0 | 7.72e-03 | 1.62e-01 | 1.00 | 0.96 |\n"
]
}
],
"source": [
"j = 0\n",
"print('| init_scale | final error(train) | final error(valid) | final acc(train) | final acc(valid) |')\n",
"print('|------------|--------------------|--------------------|------------------|------------------|')\n",
"for init_scale in init_scales:\n",
" print('| {0:.1f} | {1:.2e} | {2:.2e} | {3:.2f} | {4:.2f} |'\n",
" .format(init_scale, \n",
" final_errors_train[j], final_errors_valid[j],\n",
" final_accs_train[j], final_accs_valid[j]))\n",
" j += 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Models with five affine layers"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.10\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeMAAAEJCAYAAACnqE/cAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAA0h0lEQVR4nO3de3xTdZ438M85OUma3tOEtvQCNOWuIGAVBkdFqayCri4iM446MjqOis+q6ywOOPPyeT3uqiiLuszAg48jqPPa2RlmR3EQUaw31CIWi8LIrTSlUkpb2vTe5np+zx9J0zstNOlp0s/79eor55p886P0k9/vnJwjCSEEiIiISDOy1gUQERGNdgxjIiIijTGMiYiINMYwJiIi0hjDmIiISGMMYyIiIo0pWr54ZWWlli8fdaxWK2pra7UuI6qwTUOPbRoebNfQC0ebZmRk9LmcPWMiIiKNMYyJiIg0xjAmIiLSmKbHjImIaPgIIeB0OqGqKiRJ0rqcEa+6uhoul+u89xNCQJZlxMTEDLqdGcZERKOE0+mEXq+HovBP/2AoigKdTndB+3q9XjidTphMpkFtz2FqIqJRQlVVBvEwURQFqqoOenuGMRHRKMGh6eF1Pu0dFR+RzlSexYfFZZAlCZIEyJIEWZYgd0wH5iUgsLzLssC0JEvQdSyTO9clJyUgNT0FyTEKf5GJiCgsoiKMayrO4H8a4iGkUHT0ReCnQz2AesSobqSKdqQpXqSZdEhLNiHNkoS0DCvSk2MRo3CQgYiILkxUhPHMWVPwlq0BwueDKgRUnwqhqlCFgPCp8AkVQhVQVRWqT0AIAZ9PhRACquhcpgb2UVUBoQr4VBX1zW2obmhHtdOHaq+CGsmEQ554ONuMQKUTOFQBAEjytSNNciHVKJAWr0e6OR6paRakj0nEmDgDdDJ71URE4SSEwPLly7Flyxaoqoq33noLK1asOO/nueuuu/C73/0OFoul322eeuopXHvttfjhD384hIo7RUUYSwYjYE3zD0OH+bWEEBBN9Wg6U4OqqlpUO1pQ0+xCtROoFkaUuBJQ6DZBrQdgrwNQB1moGCPakarzIs0kY8Y4MxZcOjHMlRIRRR6fz9ftDOae830Rwt+h+uijjzB9+nQkJCTg1KlTeOONN/oM44Ge8w9/+MOAdd5zzz1YtWoVw1grkiRBSkpBclIKkqcCU3usFy4XfGerUHumBtVnG/y96jYvqr06VMOEvW4rPjriwryL3IiJMWjyHoiI1D+9AnGqLKTPKWXnQP7xfefc5q9//Su2bNkCt9uN2bNn49lnn8XUqVPxi1/8Ap9++imefPJJ3HHHHd3mDxw4gD//+c8AgNtvvx333XcfTp06hTvvvBPz58/H119/jS1btuCtt97CHXfcAQB45plnUF5ejuuuuw5XXXUVFi5ciBdeeAFpaWn47rvv8Mknn+Cee+5BZWUlXC4X7r33Xtx5550AgLlz52LXrl1wOp24/fbbcfnll2P//v1IT0/Hli1bYDKZkJWVhfr6etTU1CA1NXXIbccwDjHJaISSNR7pWeOR3mOdUFXs3f05nquLRbn9FKZMz9WkRiIiLZSUlOBvf/sbtm/fDr1ejzVr1uDNN99EW1sbpkyZglWrVgFAt/mDBw9i27ZteOeddyCEwI033ogf/OAHSEpKQmlpKV544QU8++yzAICioiI899xzAIAnnngCx44dwwcffAAAKCwsxDfffIOPPvoI48aNAwCsX78eZrMZ7e3tWLJkCRYvXoyUlJRuNZeVlWHjxo1Yt24d7r//frz77ru49dZbAQAzZsxAUVERlixZMuS2YRgPI0mWYZsyAShsQ2l5DcOYiDQzUA82HD7//HMcOnQIixcvBuC/CInVaoVOp+sWaF3nv/rqK1x//fWIjY0FANxwww3Yt28fFi1ahKysLFx66aXB/RoaGhAfH9/v68+aNSsYxACwZcsW7Nq1C4D/LoJlZWW9wjg7Ox
sXX3wxAGDmzJk4depUcJ3FYkF1dfUFtUVPDONhljouE/F7voXd4dS6FCKiYSWEwG233YY1a9Z0W7558+Zux3CNRmNwXgiB/nQEdIeOC23Ict9nD3XdvrCwEJ999hl27NgBk8mEZcuW9XnpS6PRGJzW6XRwOjv/drtcLsTExPRb3/ng93GGmazTIcfbALubn4OIaHT54Q9/iHfeeSd4j+D6+npUVFScc5958+bh/fffR3t7O9ra2vDee+9h7ty5fW5rs9lQXl4OAIiLi0NLS0u/z9vc3IykpCSYTCacOHECxcXF5/1+7HY7pkyZct779YWJoAGbScVOnxkerxd6XpqOiEaJyZMn4/HHH8ftt98OIQQURcHTTz99zn1mzJiB2267LThsffvtt+Piiy/uNlzcYeHChdi7dy9ycnKQkpKCyy67DNdeey2uueYaLFy4sNu2CxYswB/+8Afk5+fDZrNhzpw55/VePB4PTp48iUsuueS89uuPJM41BhBmlZWVWr20pj7ZvRcvnjXjpctNyJk0PmTPa7Vag584KTTYpqHHNg2PwbRrW1tbr6HdaFJdXY1HHnkEf/rTn0LyfIqiwOv19rlu165dOHToEB5//PF+9++rvTMyMvrclsPUGsid4D/PurSsSuNKiIiiR1paGn7yk5+gubk57K/l9Xpx//33h+z5GMYaGJuTDaPPDXttq9alEBFFlX/8x39EQkJC2F/npptuQlJSUsiejwcsNaDoFUzw1qPMe2H3ySQioujCnrFGbEYv7EoyfD6f1qUQEZHGGMYasaXEwKkzour70XkSGxERdWIYayR3fBoAoLTsjMaVEBGR1hjGGsnOHQdF9cJeE/6z/oiIRoOOK3xd6NnUkyZNAgBUVVXhvvv6vlzosmXL8O233wIAfvSjH6GhoeGCXqsnhrFGDEYDsj0NsLfxn4CIqEPP82gGc15Nx/3oP/zww+AtFIciPT0dr7zyyoDb3XrrrXj99deH9FodeDa1hmx6F4p8yee8lioRUTj8fn81yupDe438HHMMfp6Xds5thusWik8//TQyMzOD9zNev3494uLicNddd+FnP/sZGhsb4fV68fjjj+Mf/uEfutV46tQp3H333dizZw/a29vx2GOPoaSkBBMnTux2bepFixZh6dKleOSRR4bcdkwADdnMRjTp41BXGZq7fhARjWRdb6H4wQcfQKfTdbuF4jvvvIPLL7+823xMTEzwFoo7duzAH//4R/z9738HAJSWlmLZsmXYvXs3srKyUFRUhJkzZwIAbr75ZuzYsSP42jt27MBNN90Eo9GIV199Fe+//z7+8pe/4KmnnjrnzSjeeOMNmEwmFBQU4OGHH8bBgweD65KTk+FyueBwOIbcNuwZayg3OxVoVGEvrcCYrLFal0NEo8hAPdhwGM5bKF588cWora1FVVUV6urqkJSUhMzMTHg8Hqxduxb79u2DJEmoqqrC2bNnkZqa2mfN+/btwz333AMAmD59OqZNm9ZtvdVqRXV1da9bL54vhrGGxk8aB+mQHfbqJvR9DxIiougx3LdQXLJkCXbu3ImamhrcfPPNAIA333wTdXV12LVrF/R6PebOndvnrRO7kiSp33Whuo0ih6k1FBsbgwxPI0pbNLtXBxHRsBnOWygC/qHqt99+Gzt37gz2tJubm2G1WqHX6/HFF18M+Ppz587FW2+9BQA4evQojhw5ElwnhMDZs2eRnZ098JsfwIA949raWmzcuBENDQ2QJAn5+fnBIYauBW3duhUHDhyA0WjEypUrYbPZhlzcaGBTnDjiTYAQ4pyfvoiIIt1w3kIRAKZMmYLW1lakp6cjLc0/LL906VLcfffduOGGG3DRRRdh4sSJ53z9n/70p3jssceQn5+P6dOnY9asWcF1Bw8exJw5c6CE4Fa4A95Csb6+HvX19bDZbGhvb8fq1auxatUqZGVlBbcpLi7Ge++9hzVr1qCkpASvvfYannnmmQFffLTeQrGrN//2OV
5vtuKNfAuS0sYM6bl4a7rQY5uGHts0PHgLxeG9hSIAPPnkk7juuutw5ZVX9rk+pLdQNJvNwV6uyWRCZmZmrzPH9u/
"text/plain": [
"<Figure size 576x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeMAAAEJCAYAAACnqE/cAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAxL0lEQVR4nO3deXhU5d0+8PucmcnMZJ/MkIQsGAKBJC5ADIsgsiQGBRUvtWjdStFfrdpXqpZWqdq+9qUvLhSlLxSqgEvt4tUK1lbURkDRWA2bUpJgwhIIZJ3JOsms5/z+mGTIZCFAZnJmkvtzXbnmLM/M+eYh5M5zVkGWZRlERESkGFHpAoiIiEY6hjEREZHCGMZEREQKYxgTEREpjGFMRESkMIYxERGRwtRKbvzMmTNKbn7YMZlMaGhoULqMYYV96n/s08Bgv/pfIPo0KSmpz+UcGRMRESmMYUxERKQwhjEREZHCFD1m3JMsy7DZbJAkCYIgKF1OyKmtrYXdbgfg6UtRFKHT6diXRERBLqjC2GazQaPRQK0OqrJChlqthkql8s67XC7YbDbo9XoFqyIiooEMmHobNmzA/v37ERMTgzVr1vRaL8sytm7digMHDkCr1eKhhx5Cenr6RRUjSRKD2I/UarV3pExERMFrwGPGc+fOxcqVK/tdf+DAAdTU1GDdunX4wQ9+gFdfffWii+HuVP9jnxIRBb8Bh6HZ2dmoq6vrd/3evXtxzTXXQBAETJgwAVarFY2NjTAYDH4tlIiIzk2WJEByA91f3RIguz2vfa33advHOlkGIAOSDMhSt2nPcrlruSz3aNs13fMzup7a223a+yDfHuu7Xvpa5p2Wfd+Pbu17PiHYZ9t9LOtRS5s+HPK8RRA0YQP2/WANep+wxWKByWTyzhuNRlgslj7DuLCwEIWFhQCA1atX+7wP8JyAxN3Ug9Oz/7Raba9+pvOnVqvZf36mZJ/KLhdku63XF+w2yHY7ZHtH5zK759VhB9wuyC6X59XtAlzuzlcXILnPrvN5dfexrPP9bvfZ4JJlyN2Do/tXVzB1Bo/ssxw92siolSRF+nQ4swoCRi2+A2JEVMC3Nejkk3v+5YH+d43m5+cjPz/fO9/zziZ2u93nBKRQJcsylixZgi1btkCSJGzbtg1Lly694M+555578H//93+IiYnpt82zzz6L+fPn4+qrr4ZarYbL5fJZb7fbeVeeQeBdjfzvXH0qSxLgsAMOG2CzeV7tdsDeAXQGZH/r4LB7wtNu6/yMri+Hp63D7gnCCyUIgEoFqNSAqDo7reo53eNVHQZo9YBKDaHnewTB8wUBENBtutsX4NNG6N7e2xaAIAIQEB4RgXa7HRBFT52iCKhEQFB5Xrsv75wWfNqqPJ+lUp1t21WLKPrWJog+24bY7fsRO9cDZ9/XfXlX/d2/v26L+lzW9f32t8zbL/Bd1rN99/W96uidXd6f1Q7/nXvT3x24Bh3GRqPR5z+W2Wwe8buoP/74Y2RnZyMqKgqnTp3CG2+80WcYu93uc/7x8eabbw64rWXLlmHFihW4+uqrB1MyjWCy2+0JNJtnhOgJt87Ac9o96yUJcHtGgp5Rn9t3uue85O5sL3lHiZAkyG4XGlUquFtbOrdlOxue9g5PcF4ItRoI0wFaHaDVAmFaz7w+AoiJgxCm7ba825dWB4RpPeu9y/popwnzBGkIiDSZYOMfjiFr0GGcm5uLDz74ALNmzUJ5eTnCw8P9EsbSn1+BfOr4oD+nOyF1LMQ7/t+A7ZYtW4YzZ87Abrfjvvvuw913341du3Zh9erVcLvdiIuLw9tvvw2r1YqnnnoK33zzDQRBwKOPPopFixZh27ZtuOuuuwAAv/71r1FZWYlrr70W11xzDfLy8vCb3/wGCQkJOHz4MHbv3t3n9gBg+vTp2LFjB6xWK+6++25MmzYNe/fuRWJiIrZs2QK9Xo+UlBQ0Njairq6u37+4aPiR3W6gvQ1oawWsLU
BbK2Rbx9kwtfm+du2KhbdNt/Uup3+KEsUeo0eV73TnvKTVetrpI4DYOAhavScItbqzX2GecPWuC9MBOl2P4NVB4GEtGiYG/El+6aWXUFJSgtbWVvzwhz/EkiVLvLtCCwoKMGXKFOzfvx+PPPIIwsLC8NBDDwW86EBbs2YNDAYDOjo6sGjRIixYsAArVqzAO++8gzFjxqCxsRGAp2+ioqLw8ccfAwCampoAAMXFxXjuuecAACtXrsSRI0fwr3/9CwBQVFSEgwcPYufOnRgzZkyf21u4cCHi4uJ8ajp+/DjWr1+PF154AQ888ADef/993HrrrQCAyy+/HMXFxVi8eHHA+4b8T3bYgbaWzmBthdwtYGFt9QRtW0vndOdru3XgD+4aAer0Z0NOHwEYjBC0Os8uVG1nyGnPthG87fWARtNjl6yqj7BVA6J43mfuG7nrn6iXAcP4xz/+8TnXC4KA+++/31/1eJ3PCDZQtmzZgh07dgDwPFnqD3/4A2bMmOENz66R/549e7Bhwwbv+2JjYwF4QjkyMrLfz588ebL3s/ra3vHjx3uFcWpqKi677DIAwBVXXIFTp0551xmNRtTW1l7st0sBIrdbgcYGwFIP2dIAWDqnmy1Aa2eoWlvOvWtWpwciooDIaCAiCsKoxM75bssio4GISE/QdoVrmBaCGBq7V4koyO7AFQyKioqwZ88evPfee9Dr9bjttttw6aWX4tixY73ayrLc52hArVZDkiSIYt+XcYeHh59ze33dqEOr1XqnVSoVbDabd95ut0On013Q90mDIzsdnUHbANlS7wnaRt9pdLT7vkkUgVgjYDACxlEQLkkHIqI9wRoRBSEyymcekVEQ1BplvkEiGlIM4x5aW1sRExMDvV6PiooK7N+/H3a7HV988QVOnjzp3U1tMBgwZ84cbN26Fc8++ywAz4g4NjYW6enpqKysxNixYxEREYG2trYL2t6FOnbsGG644YaL/p6pN9nWDtSchu3I15BOHPMNWks90Nrc+01RMUDcKCA+CULWJMBgAuJMEOJGeaZjDSE1WpU6r5QQg+jGMbIsw+GWYXdJsLlk2FyS90uSPTV3XQ0kybL3slYJnuWeeblzvWda6tbe26azPeD5/kUBUImdr4LgMy2Kna+dbVR9tO+5Tvb5njqvVOp2nWvnRUs9LrWV0W2RzyW6sgzUuVphtrTDLQMuSYZbkuGSPa9uCXDLcudyz7RbOjt/tp3s836px8UyPX8UzvWTcc62PVYK/bXrsaDnuv623/Man17X/Mj9r+s+r9c1Y0lmFLTqwD9TiWHcw9y5c/Hmm28iPz8f6enpyMnJgdFoxPPPP4/7778fkiTBZDLhz3/+M5YvX46VK1di/vz5EEURjz32GBYuXIi8vDx88cUXGDt2LOLi4jB16lTMnz8f8+bNQ15e3oDbuxBOpxMnTpzApEmT/NkNI4Isy0BLE1B9CnJNFVBdBbn6FFBz2jOyBeCNXJ3eE7RxJgiXjPMN2jgTYDBd0I0BJFmG0+35ZeiSZDglGS53t+nOeWeP9d75nm17rD+7DD7Lum/H1XN5j8/o+kWsFgGNKCJMLSBMFKBRiQhTCd6v7vNd0xqVp22YWux8j4CwznUGs4yGpubOEO0KVqmP+d7rHG6pV0CQ/6g7/7BQi54/OLqungIuPtA8C+R+18n9zvT+o6X/pjJ6RnOv4D7nHwd9rxOEVtw6sf9Djv4kyH1dKDxEzpw54zPf3t7usws3VNXW1mL58uX485//HPBt7dixA4cOHcJPf/rTPq8zHi59Ohiy2w2Ya7uFbRXk6iqgpsr3RCitHhidAiExxfM6OhWxE7LQJKghhEec/TxZht0tw+pww+qQPK9OCW19zncuc55d1+6U4A7A/zq16PklqhG7TasE77T3SyVA03NZH201ogAZnuB3SDIcLhlOSYLD3bnM7QnH7vPOznmH29PWdR73oRAAaNUi9GoBWrUInVrsfBV6Tfdap+pa5qlXED
wBIgqAAM8oVOy8RlcQPPf/FTpHp12X93aNfAVB6Fx/drr7CLtrNNk1LXWOMCUZnSPKs9Ndo3TvKLRH+67vG/C97LX
"text/plain": [
"<Figure size 576x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 4.75e-03\n",
" final error(valid) = 2.19e-01\n",
" final acc(train) = 1.00e+00\n",
" final acc(valid) = 9.64e-01\n",
" run time per epoch = 1.48\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.20\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeMAAAEJCAYAAACnqE/cAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAA6vklEQVR4nO3deXhU9b3H8fc5M5PJvs2QhISwJKCioogRLK5I5LrU1sdaq7VWxHpb9VarrQpWbWvF0ouIcotXrwW0vbXVaxUrlUrjLnFBEcWtbAGBbGRPINvMOfePmQwJBAJkkkkmn9fz5Jlzzpxz5psfIZ/8fmczbNu2ERERkYgxI12AiIjIUKcwFhERiTCFsYiISIQpjEVERCJMYSwiIhJhCmMREZEIc0byw0tLSyP58VHH6/VSVVUV6TKiito0/NSmfUPtGn590abZ2dndLlfPWEREJMIUxiIiIhGmMBYREYmwiB4zFhGR/mPbNi0tLViWhWEYkS5nwKuoqKC1tfWwt7NtG9M0iY2NPeR2VhiLiAwRLS0tuFwunE796j8UTqcTh8NxRNv6fD5aWlqIi4s7pPU1TC0iMkRYlqUg7idOpxPLsg55fYWxiMgQoaHp/nU47R0VYWxXlGI99wfsuppIlyIiInLYoiKMaazDXvksfLU50pWIiIgctugI4+yRANilX0W4EBERiRTbtvn2t79NY2Mj9fX1PPHEE0e0n6uuuor6+vqDrnPvvffy9ttvH9H+uxMVYWzEJ0JqOiiMRUQGNb/ff9D57ti2jWVZvPLKKxx77LEkJSXR0NDAH/7wh0P6jH398Y9/JCUl5aDrzJo1i8WLF/dY26GKntPqskdil26PdBUiIoOC9ZfHsbeXhHWfRu4YzMuvO+g6f/3rX1m6dCltbW2cdNJJ/OY3v+GYY47h3//933njjTe45557uPLKK7vMf/TRRzz99NMAXHHFFVx33XVs376d733ve0ydOpUPP/yQpUuX8vzzz3PllVcCcP/997Nt2zbOPfdczjzzTKZPn86DDz5IZmYmn332Ga+//jqzZs2itLSU1tZWrr32Wr73ve8BMGXKFFauXElLSwtXXHEFkydP5oMPPiArK4ulS5cSFxfHiBEjqK2tpbKykoyMjF63XVT0jAGM7JFQ9hX2YZxKLiIi/Wfjxo387W9/Y/ny5fzzn//E4XDw3HPPsWfPHo4++mhWrFjB5MmTu8zHxsbyzDPPsGLFCl588UWeeuopPv30UwA2b97MpZdeyqpVqxgxYgRr1qzhhBNOAODOO+9k1KhR/POf/+Tuu+8GYN26ddxxxx28/vrrACxYsIB//OMfvPTSSyxdupSamv1PAi4pKeHqq6/mtddeIzk5mZdeein03oQJE1izZk1Y2iaqesa0tUFVBWQMj3Q1IiIDWk892L7w9ttvs379ei644AIgcBMSr9eLw+HgwgsvDK3Xef7999/nvPPOIz4+HoDzzz+f9957jxkzZjBixAhOPvnk0HZ1dXUkJiYe8PMnTpzIyJEjQ/NLly5l5cqVQOApgiUlJaSnp3fZJjc3l+OPPx6AE044ge3b947AejweKioqjqgt9hU1YWxkj8SGwHFjhbGIyIDTcYLVnDlzuix/9NFHu9zpyu12h+Zt2z7g/joCukPHjTZMs/tB387rFxcX89Zbb/Hiiy8SFxfHpZde2u2tL91ud2ja4XDQ0tISmm9tbSU2NvaA9R2OqBmmZnguoDOqRUQGqtNPP50VK1aEnhFcW1vLjh07DrrNqaeeyssvv0xzczN79uzhH//4B1OmTOl23by8PLZt2wZAQkICTU1NB9xvY2MjKSkpxMXFsWnTJtauXXvY38+WLVs4+uijD3u77kRPzzg+AdK8OqNaRGSAOuqoo7j99tu54oorsG0bp9PJ3LlzD7rNhAkT+Pa3vx0atr7iiis4/vjjuwwXd5g+fTrvvPMOY8aMIT09nVNOOYVzzjmHadOmMX369C
7rnn322fzxj3+ksLCQvLw8Jk2adFjfS3t7O1u3buXEE088rO0OxLAPNgbQx0pLS8O6P/9Dv4CGOhz3PBzW/Q4WXq839BenhIfaNPzUpn3jUNp1z549+w3tRpOKigpuvvlm/vKXv4Rlf06nE5/P1+17K1euZP369dx+++0H3L679s7Ozu523egZpiZ4RnX5Tmyr5+vSREQkumRmZvLd736XxsbGPv8sn8/HD3/4w7DtL2qGqYHAGdXtbbCrAjK7/+tDRESi1ze+8Y1++ZyLLroorPuLvp4x6LixiIgMKlEVxmTrjGoRERl8oiaM65p9+FxxkD5MPWMRERlUoiKMP9jZxNXPbWJLbUvwHtUKYxERGTwOKYzXrVvHzTffzI9//GOWL1++3/ufffYZV199Nbfddhu33XYbzz77bLjrPKiRKYE7pGypacHIGQnlO7AP4UkfIiISPTo/QvFIjBs3DoDy8nKuu67724VeeumlfPzxxwB85zvfoa6u7og+a189nk1tWRZLlizhrrvuwuPxMGfOHAoKChgxYkSX9caPH8/s2bPDUtThGpbgJCnGDPWM8flgVxlkjeh5YxERGTD8fn+XW2PuO98d27axbZtXX3019AjF3sjKyuLxxx/vcb1vfetbPPnkk9x88829+jw4hDDetGkTWVlZZGZmAjB16lTWrFmzXxhHkmEYjEmPZXNNK8b4TveoVhiLiHTr9x9UUFLb0vOKh2FMWiw/KMg86Dr99QjFuXPnkpOTw8yZM4HAE5oSEhK46qqruOaaa6ivr8fn83H77bfzb//2b11q3L59O1dffTVvvvkmzc3N3HrrrWzcuJGxY8d2uTf1jBkzuOSSS/onjGtqavB4PKF5j8fDxo0b91tvw4YN3HbbbaSlpXHVVVeRm5u73zpFRUUUFRUBMG/ePLxeb29q7+K47Eae/biU1PEnUQPE1VWTGMb9DwZOpzOsbSpq076gNu0bh9KuFRUVOJ2BX/umaWIYRlhrME0ztP/ubNiwgRdffJEVK1bgcrm44447eOGFF9izZw/HHnts6AESnec//vhjnnnmmdDTlc4//3xOP/10UlJS2Lx5Mw8//DDz588H4IMPPmDBggU4nU4uueQS7r77bn7wgx8AsGLFCv785z+TkJDAk08+SVJSEtXV1VxwwQVccMEFobZwOp1deuJ/+tOfiI+P5/XXX+ezzz7j3HPPxeFwhNq7ra2NhoaG/Z72BIGHTBzqz3qPYdzd3TL3/QccM2YMjzzyCLGxsaxdu5b58+ezaNGi/bYrLCyksLAwNB/OW+INj7Vo99t8XFpDrjeTPRu/oGWI3XJPtxkMP7Vp+KlN+8ahtGtra2soaGZNGtYndRzo9pEAb7zxBp988gkzZswAAo9QTE9Px+FwcN5554W27Tz/zjvvcN5554WennTeeedRXFwceoTixIkTQ9vV1tYSGxuLz+dj/Pjx7Nq1ix07dlBdXU1ycjJZWVm0t7dz33338d5772EYBuXl5ZSVlZGRkRGq39/pnKPi4mJmzZqFz+fj6KOPZvz48fj9/tBnejwedu7cSXJy8n7fb2tr637/Jge6HWaPYezxeKiurg7NV1dXk5aW1mWdzvfenDRpEkuWLKGhoaHb4vpKXnrgMVZbalvJ1RnVIiIDTn8/QvHCCy/k73//O5WVlXzzm98E4LnnnqO6upqVK1ficrmYMmVKt49O7OxgIwjheoxij2dT5+fnU1ZWRmVlJT6fj+LiYgoKCrqsU1dXF2qwTZs2YVlWrw+gH67spBjcDiNwRnX2SKgoxT7IX2giItK/+vMRigDf/OY3eeGFF/j73/8eeupTY2MjXq8Xl8vF6tWre/z8KVOm8PzzzwPw5Zdf8sUXX4Tes22bXbt2dXtY9nD12DN2OBzMmjWLuXPnYlkW06ZNIzc3l1WrVgGBA9jvvvsuq1atwuFwEBMTw09+8pOwH4vosU7TYHSae+8Z1X4fVJYGpkVEJOL68xGKAEcffT
S7d+/uchLyJZdcwtVXX83555/Pcccdx9ixYw/6+d///ve59dZbKSws5Nhjj2XixImh9z755BMmTZp00OPkhyqqHqH
"text/plain": [
"<Figure size 576x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfAAAAEJCAYAAABrMXU3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAABAQklEQVR4nO3deXxU5b348c85M1nJNpnJQjaSgIrUKNogCiJiIrUimkstrWspWhX01guWFv1Z7FVpcaGxVBB+XhatXeR3r+TW1oUGEZd4Mcp63ZpIgISEbDPZM5PMnPP7Y5IhIYEQMsnMkO/79cpr5pzznHOe8ySZ7zzPec7zKLqu6wghhBAioKi+zoAQQgghBk8CuBBCCBGAJIALIYQQAUgCuBBCCBGAJIALIYQQAUgCuBBCCBGAjL7OwGBVVlb6OgvnDIvFQl1dna+zcc6RcvU+KdPhIeXqfd4u06SkpFNukxq4EEIIEYAkgAshhBABSAK4EEIIEYAkgAshhBABSAK4EEIIEYAG7IW+bt069uzZQ3R0NKtXr+6zXdd1Nm/ezN69ewkJCWHx4sVkZmYCsG/fPjZv3oymaeTk5JCXlwdAS0sL+fn51NbWEhcXx5IlS4iIiPDulQkhhBDnsAFr4Ndccw2PPvroKbfv3buX48ePs2bNGu69917+4z/+AwBN09i4cSOPPvoo+fn5fPTRR1RUVABQUFBAVlYWa9asISsri4KCAu9cjRBCCDFKDFgDnzRpEjU1Nafc/umnn3L11VejKArnn38+ra2t2Gw2amtrSUxMJCEhAYBp06ZRXFxMSkoKxcXF/OpXvwJg5syZ/OpXv+KOO+7wzhUJIUYdXdNAc4Gr69Wz3ON99/qeaXQNdL33DzpoXa8nb+uZ5rTb8KTRu99DjzSclPbk5ZPTcdpz9FnfJ4/9p20JH4PW1tp1/J4FevIs03q/b0+x4iTKaRZP3qacctMpT+fVvA5dW1IKTJk57OcBLwzkYrVasVgsnmWz2YzVasVqtWI2m3utLykpAaCxsRGTyQSAyWSiqanplMcvLCyksLAQgFWrVvU6lxgao9Eo5TkMzuVy1V1O9LY2tNZm9LYW9LZW9/vWVvTWZrS2VvTODnfgdDnRnU5wudBdzhPrTnrF5XKn07peXb330V1Oal2aex9N6wrKLvTuAO1y+bpYAlartw6knCLa9gmufuBUefUSe8b5WL77vWE9R7chB3C9n1+QoiinXD9Yubm55ObmepZl1CDvkVGYhoe/l6uu69DcALXV0BWEaW+F9jbo8d6zvq1rW3sbONoHPoHB4P5RDWAwgqq6X7vX97dO7XoNDfNsV7q3qwZCxozB0dEJBtWdVlVRuvdRT6xzH0c9cTxFPXF8Ve067on3KAZ3TU9Ru2p8CqiK+1Xp/lH7pjndNrX7c64rDUpXmp7rex6n6z090ig99u3e3Of83e/VHsdX6Pc6euZBObHeYrFQV9/zb7VnDbj35/XZfH6fTp8Y0Wv55G29s3bamjvez+tgxI7gSGxDDuBms7lXZuvr6zGZTDidTurr6/usB4iOjsZms2EymbDZbERFRQ01G0KIk+itzVBdiV5TCdXuH726EmoqwX6KQGwwQvgYCAuHsDHu99EmlLAxJ5a7tind78O7tnWvNxi8fi3Rfv6lKFApBoP7i5Avzn1ykPVh0A1UQw7g2dnZvP3220yfPp2SkhLCw8MxmUxERUVRVVVFTU0NsbGxFBUV8dOf/tSzz65du8jLy2PXrl1MmTJlyBcixGik29uhpiswdwVnT5BuaT6RUFHBEg/xY1EmXAjxSSjxiTAmsndgDgr2ae1FCHHmFL2/tu4enn/+eb744guam5uJjo5m/vz5OJ1OAGbPno2u62zcuJH9+/cTHBzM4sWLGT9+PAB79uzh5ZdfRtM0Zs2axbx58wBobm4mPz+furo6LBYLS5cuPePHyGQyE+/x96beQOXNctV13d2MbbNC9TFPbd
r9WgWN1t47mCzuIJ2QDAldr/FJYElACQrySp58Qf5Wh4eUq/eN5GQmAwZwfyMB3Hvkn3d49Feuuq677x+3trhrxq3N7ibu1uau5RZobUJvbXGv617f1uLuLd1TZDQkJKEkJLlr0l3BmrgklJCQEbzSkSN/q8NDytX7RjKAB9x0okL4G93lcjddl5fBsSM02ttwWes8gdr90+LuXX0qIWEwJgIiImFMJIrJ4l4eE+V+jTadCNjhY0bu4oQQfksCuBCDoLc0QcVh9IrDUFGGXnEEKo9CZ4c7gcFAR0ys+77ymEgYm4rSIzAzJhJlzIn3RERCeERAN28LIXxDArgQ/dBdLvc9556BurwMGk48WUFkNKRmoMyaAynpKKnpkJhCXOJYaZYUQgw7CeBi1DtRqy5zv5YfdteqnZ3uBAYjjE1BmZgFKRkoKemQmo4SZfJltoUQo5wEcDEq6eVlaAWvwtFD/deqr72xV61aMUoTtxDCv0gAF6OOXvIF2u+fAGMQyrculVq1ECIgSQAXo4p+8FO09avAFIe65AkUc5yvsySEEGdFArgYNbRP3kfflA/J41Af+hVKVIyvsySEEGdNArgYFbT33kL/03qYcCHqg7+UZ6mFEAFPArg4p+m6jv7Wf6Jv+wNkZaPe94tzdrQyIcToIgFcnLN0XUf/ry3o72xDuXwmyo8fQjHKn7wQ4twgn2binKRrLvQ/rEP/8B8o19yAcuu9KKrq62wJIYTXSAAX5xy9sxNt42r4rAhlznyUm2+XKTKFEOccCeDinKLb29Fe/A18sQ/l+wtRZ+f5OktCCDEsJICLc4be2oy25gkoK0H50b+iXnWdr7MkhBDDRgK4OCfoDVa05x+H6mOo9/8c5bJpvs6SEEIMKwngIuDptcfR8ldAUwPqv65AmTTZ11kSQohhJwFcBDT92BG0/MehswN16ZMomRf4OktCCDEiJICLgKUf+hrtd/8OQcGoP/8NSvI4X2dJCCFGjARwEZD0L/ahrfs1RMW4JyWJS/R1loQQYkRJAB+l9E8/xLanCC3G7J5OMy0jYOa91vcUob30HCQko/7bv6PExPo6S0IIMeLOKIDv27ePzZs3o2kaOTk55OXl9dre0tLCiy++SHV1NUFBQSxatIi0tDQqKyvJz8/3pKupqWH+/PnMmTOHrVu3smPHDqKiogC49dZbueyyy7x3ZaJfemcH+taN6O+9hTPWgt7cBJ0d6AAGI4xNRUnNgNSME69jIn2dbQ/to0L0l1+AjPNQf7rCr/ImhBAjacAArmkaGzdu5LHHHsNsNvPII4+QnZ1NSkqKJ822bdtIT09n2bJlHDt2jI0bN7JixQqSkpJ49tlnPce57777uPzyyz37zZkzh5tuumkYLkv0R6+uRNvwNJSXocz+Fyw/WUJdfT3UVKIfPQQVh9HLD6F/sRc+ftcd1AFMFncg766pp2RAXOKID02qbS9A/3+bYNKlqIsfQQkJHdHzCyGEPxkwgJeWlpKYmEhCQgIA06ZNo7i4uFcAr6io4F/+5V8ASE5Opra2loaGBmJiYjxpDh48SGJiInFxcV6+BHEmtE/eR39lLRiNqP/6S5SLp6AYjSgGg7vWPTYVps70pNebbFB+GL2iDMrL0MvL0P/3M3RNcycICYWUdJSUdHcTfGqGe3kYgqqu6+gFf0R/cyvKt6ej3L0UJcj/m/qFEGI4DRjArVYrZrPZs2w2mykpKemVZty4cezevZuJEydSWlpKbW0tVqu1VwD/6KOPmD59eq/93nnnHd5//30yMzO56667iIiI6HP+wsJCCgsLAVi1ahUWi2VQFzja6Q4HzZuep337fxM0MYvoh5/AYHF/GTMajacuT4sFMs/rfawOB87ywzjLSug8XILzcCnO4g/Rd73trq0rCobEZJQwL8+13dmBVl5GWO5cIu//uftLhx87bbmKsyJlOjykXL1vJMt0wACu63qfdSdPDJGXl8eWLVtYtmwZaWlpZGRkoPZoXn
U6nXz22WfcdtttnnWzZ8/mlltuAeC1117jlVdeYfHixX3OlZubS25urme5rq7uDC5LAOjHK9A2PAMVh1Gu/x6um2/
"text/plain": [
"<Figure size 576x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 1.14e-03\n",
" final error(valid) = 1.49e-01\n",
" final acc(train) = 1.00e+00\n",
" final acc(valid) = 9.71e-01\n",
" run time per epoch = 1.47\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=0.50\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeoAAAEJCAYAAABbvWQWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAA+vklEQVR4nO3deXxU5d3//9d1ZpLJvk1IQhL2HVkCBtlU0ESqLEpREa1WxFq3u6htrWjdbhRL64KlP/xqFdC2t1VbFQVBaRRwCUUQIrggQQGBBEL2yZ6Zc/3+mDAkkJAEJplJ8nk+HvPInDnnzHxyKXnPuc65rqO01hohhBBC+CXD1wUIIYQQomkS1EIIIYQfk6AWQggh/JgEtRBCCOHHJKiFEEIIPyZBLYQQQvgxq68LaEpOTo6vS+g0YmNjyc/P93UZnYq0aduQdvU+adO24e12TUxMbHKdHFELIYQQfkyCWgghhPBjEtRCCCGEH/Pbc9RCCCHah9aaqqoqTNNEKeXrcjqEo0ePUl1d3ap9tNYYhkFQUFCr2lmCWgghuriqqioCAgKwWiUSWspqtWKxWFq9n9PppKqqiuDg4BbvI13fQgjRxZmmKSHdTqxWK6ZptmofCWohhOjipLu7fbW2vTt1UGvThfnhanTWf31dihBCCHFGOnVQK8OC3rgW86P3fF2KEEIIcUY6dVADqJFjYc9X6IoyX5cihBDCB7TWXH311TgcDkpKSnj55ZfP6H1uuOEGSkpKTrvNwoUL+fTTT8/o/ZvS+YN61DhwudC7vvB1KUIIIc6Qy+U67XJjtNaYpsmHH37I0KFDCQ8Pp7S0lL/97W8t+oyT/f3vfycyMvK028ybN49ly5Y1W1trdP7L/PoMhIgoyNoCYyf5uhohhPBr5msvog/u8+p7qh59MObcctpt3nzzTVasWEFNTQ2jRo3iD3/4A4MHD+aXv/wlmzZt4uGHH+ZnP/tZg+UdO3bw+uuvA3Dttddyyy23cPDgQa6//nomTJjAF198wYoVK3j77bf52c9+BsATTzzBgQMHuOSSS7jwwgtJS0vjmWeeIT4+nq+//pqNGzcyb948cnJyqK6u5uabb+b6668HYOzYsaxbt47y8nJuuOEGxowZw7Zt20hISGDFihUEBweTnJxMUVEReXl5xMXFeaX9Ov8RtWGgRp6H/uoLdG2tr8sRQghxkuzsbN59911WrVrFf/7zHywWC2+99RYVFRUMGjSINWvWcN555zVYDgoK4o033mDNmjWsXr2aV199la+++gqA77//nquuuor169eTnJzM1q1bGTFiBAAPPPAAvXr14j//+Q8PPfQQAFlZWdx3331s3LgRgKeffpr333+ftWvXsmLFCgoLC0+p+YcffuDGG29kw4YNREREsHbtWs+64cOHs3XrVq+1T+c/ogZUylj0J+vhu50w7FxflyOEEH6ruSPftvDpp5+ya9cupk6dCrgnYImNjcVisTBt2jTPdvWXP//8cy699FJCQkIAuOyyy9iyZQtTpkwhOTmZc8898be+uLiYsLCwJj8/JSWFnj17epZXrFjBunXrAPedHPft20dMTEyDfXr27MmwYcMAGDFiBAcPHvSss9vtHD169IzaojFdIqgZMhJsQeisLSgJaiGE8CvHL/a6//77G7z+/PPPN5j9y2azeZa11k2+3/HwPu74JCOG0Xgncv3tMzMz+eSTT1i9ejXBwcFcddVVjU4VGhgY6HlusVioqqryLFdXVxMUFNRkfa3V6bu+AVRAIJwzGv3l5+hWzggjhBCibZ1//vmsWbPGc3/noqIiDh06dNp9xo0bxwcffEBlZSUVFRW8//77jB07ttFt+/bty4EDBwAIDQ2lrKzpUUAOh4PIyEiCg4PZu3cv27dvb/Xv88MPPzBo0KBW79eUrnFETV339/ZMOLDXfYGZEEIIvzBw4EB+97vfce2116K1xmq1smjRotPuM3z4cK6++mpPV/i111
7LsGHDGnRBH5eWlsbmzZvp06cPMTExjBkzhosvvpiLLrqItLS0BttOnjyZv//976Snp9O3b19Gjx7dqt+ltraW/fv3M3LkyFbtdzpKn67/wIdycnK88j7FVU5qXZpYqjB/fQPq0isxfnqDV967o4iNjfV8UxXeIW3aNqRdva8lbVpRUXFKd3FncvToUe666y5ee+01r72n1WrF6XSe8vq6devYtWsXv/vd75rct7H2TkxMbHL7Tt31XevS/GrNPv6RdQwVGg4DzkHvkOlEhRCiK4mPj+e6667D4XC0+Wc5nU5uvfVWr75npw7qAItiUu8IPv2xlIKKWlTKWMg9iM7zztG6EEKIjuHyyy8nPDy8zT9nxowZzU6K0lotOkedlZXFypUrMU2TtLQ0Zs6c2WD9J598wjvvvANAUFAQv/jFL+jdu3eL9m1r0wZFs+a7It7PLua6lLHo119yX/095aftWocQQghxJpo9ojZNk+XLl/PAAw+wZMkSPvvss1OuxouLi+PRRx/lqaee4sorr+Svf/1ri/dta93DAzkvOYz3s4upieoGyX3QO7a0aw1CCCHEmWo2qPfu3UtCQgLx8fFYrVYmTJhwyowrgwYN8gwmHzBgAAUFBS3etz1cPjiG0moXm/aXuru/v9+NLi1u9zqEEEKI1mq267uwsBC73e5ZttvtZGdnN7n9Rx99xKhRo1q9b0ZGBhkZGQAsXryY2NjYlv0GLTDJrhnwZQFr95Zy1eQpFK15jbAfdhOcPt1rn+HPrFarV9tTSJu2FWlX72tJmx49ehSrtcuM1vWaM20zm83Wqv/Pm/2UxkZvKaUa3farr75iw4YNLFy4sNX7pqenk56e7ln29hCNqf0j+PPmXD6uiGF4TDdKP82gPGWcVz/DX8mQF++TNm0b0q7e15I2ra6ubjADWGejtWb27NmsWLHijC4oGzBgANnZ2Rw5coSHHnqIF1988ZThWVdddRUPPfQQI0eO5JprruGFF14gKiqq0ferrq4+5b/JWQ3Pstvtnq5sgIKCAqKjo0/Z7sCBA7zwwgvce++9noZo6b7t4YJe4UQFWVi9u8jd/f1tFrqRaeGEEEL4H2/d5vJsJCQk8OKLLza73ZVXXskrr7xyVp9VX7NH1P369SM3N5e8vDxiYmLIzMxk/vz5DbbJz8/nqaee4n/+538afCtoyb7tJcBicNnAaP65M5+coePo/tEa+GYHjOoaR9VCCNESL207yr6iquY3bIU+0UH8IjX+tNu0120uFy1aRFJSEnPnzgXcd8oKDQ3lhhtu4KabbqKkpASn08nvfvc7fvKTnzSo8eDBg9x444189NFHVFZWMn/+fLKzs+nfv3+Dub6nTJnCrFmzuOuuu7zSfs0eUVssFubNm8eiRYu45557GD9+PD169GD9+vWsX78egH//+9+UlZXx0ksvce+997JgwYLT7usrlw6Iwmoo1tR0g5BQdJZc/S2EEL7Wnre5vOKKK1i9erXns1evXs2MGTOw2WwsX76cDz74gH/9618sXLjwtDf+eOWVVwgODiYjI4P58+ezc+dOz7qoqCiqq6sbvT3mmWjRmfDRo0efMt/plClTPM9vu+02brvtthbv6ytRQVYm9Y7go32lXDd8HKE7P0e7XKhOfG5GCCFao7kj37bQnre5HDZsGPn5+Rw5coSCggIiIyNJSkqitraWxYsXs2XLFpRSHDlyhGPHjhEXF9dozZs3b+amm24CYOjQoQwZMqTB+tjYWI4ePXrK7THPRJe7zG/G4Gg+/KGEjJ7nc8WWD+H7b2HgMF+XJYQQXVZ73+Zy2rRpvPfee+Tl5XHFFVcA8NZbb1FQUMC6desICAhg7Nixjd7esr6mLo4G797qslNPIdqYPtFBjIgP4b2yCFzWQOn+FkIIH2vP21yCu/v7nXfe4b333vMcoTscDmJjYwkICGjR5Fzjx4/n7bffBmD37t18++23nnVaa44dO+a1U71dLqjBfVSdX+nivy
OmorO2nPabmRBCiLZV/zaX6enpXHvttRw9evS0+9S/zeX06dM9t7lszPHbXB43aNAgysvLPRNyAcyaNYsvv/ySyy6
"text/plain": [
"<Figure size 576x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeoAAAEJCAYAAABbvWQWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAABBjElEQVR4nO3deXxU9b3/8dc5s2RfJyQhGwmLAoIiBrHgBoloRa/UWmqt9lK0FZBqq6Ui12vvry1eXCgtFZdaBLt4rfe20GJFaXAnFiOL4oImAoEsZJnJnsxMZs7398eEIYGEAJlkJsnn+XjkkTlzzsz5nG8t73y/Z/lqSimFEEIIIUKSHuwChBBCCNEzCWohhBAihElQCyGEECFMgloIIYQIYRLUQgghRAiToBZCCCFCmDnYBfSkoqIi2CUMGUlJSdTW1ga7jCFF2rR/SLsGnrRp/wh0u6alpfW4TnrUQgghRAiToBZCCCFCmAS1EEIIEcJC9hz1iZRSOJ1ODMNA07RglzOoVFVV4XK5/MtKKXRdJzw8XNpSCCFC3KAJaqfTicViwWweNCWHDLPZjMlk6vKex+PB6XQSERERpKqEEEKcjl5T78knn2T37t3ExcWxevXqk9YrpdiwYQN79uwhLCyMJUuWMHr0aAD27t3Lhg0bMAyDvLw85s2bd9aFGoYhIR1AZrO5Sy9bCCFEaOr1HPWVV17JihUrely/Z88ejh49ytq1a/n+97/P7373O8AXrOvXr2fFihWsWbOGHTt2UFZWdtaFyhBt4EmbCiFE6Ou1izpx4kSqq6t7XP/BBx9w+eWXo2ka55xzDi0tLdTV1VFTU0NqaiopKSkAzJgxg6KiIjIyMgJXvRBCDCNKKfB6wOMBr9f3uvNvj+ek91wV0aj6eqBjRmMFKNV1GdXxXjfLqE6f6bSMQhkd32MYHauMjm07/1bH93dse9XDT6fv7rGeE3+for6uNNA6v+5m2b94bLnTNicst0+/HOJsDIQ+jyU7HA6SkpL8yzabDYfDgcPhwGazdXm/uLi4x+8pKCigoKAAgFWrVnX5TvBdECVD32evu7YLCws7qZ3F6TGbzdJ2/eBU7eoLKS/K7UK5XeB2oVwu/7L/Pbe763vtbpTX6wsTrxdleMHw+l57fb8xDN/7Xt+6nt73f8YwfCHkK+x4ff5ldTwnVKdwgm6Dx/9Zwxe2yusFTzuqI3iV1wOejhrOUP0Zf2IIORa+6sTQ7jtvfAJJ194U8O/tTp+TT3XTAJqm9fh+T/Lz88nPz/cvn/jEF5fLddIFUYONUor58+fz3HPPYRgGmzZtYsGCBWf8PbfddhtPPPEEcXFxPW7zs5/9jNmzZ3PppZdiNpvxeDwnbeNyueSJRWdJnvbUO+X1grMN2lqgrbXjdxuq87Kz1fe6tRXlbMVieGlvbQG3G9rd0O46/trtPh6OfWEygX7sRweTfsJyp9fHlrVO75tMvgDQTZ16YR09rS7L3fXQOi1301vTju3PZAaz2f9aO/ZeN+u6/racsK2JuMREGhoaT78H2bH6pOXOn9HwtQka6NrxdT39dN7Ov71+/HuOfb9+YvudQT2dlk+VNapz7x1OPcLgj7GTPxOWmjpgTybrc1DbbLYuxdrtdhISEvB4PNjt9pPeH862b9/OxIkTiYmJ4ciRI/z+97/vNqi9Xu8p/yj5wx/+0Ou+Fi5cyLJly7j00kv7UrIQQEfo1tWCvRplrwZH7fEAbm1BtbUeD91j77ucvX+xyQwRkR0/URAdDeGREBuPZrGCxQrWjt+WsE6vj7+vWTvWddn22OswMFv8Yazpw+/REdakJDT5o9JPO/GPprP9Hos1ANWcnj4HdW5uLq+++iozZ86kuLiYyMhIEhISiI2NpbKykurqahITEyksLOTuu+8ORM0YLz6LOnIwIN91jJaZg37z9065zcKFC6moqMDlcnH77bdz66238sYbb7Bq1Sq8Xi+JiYm89NJLtLS08O
CDD/LRRx+haRo/+tGPmDt3Lps2beLb3/42AA8//DClpaVcddVVXH755eTl5fHLX/6SlJQUPvnkE958881u9wcwffp0tm7dSktLC7feeisXX3wxH3zwAampqTz33HNERESQkZFBXV0d1dXVp/xLTQjAN8TaOYhrq8FehbLXgL3at844oTdrDesashGRkGBDC++0HOl7rUVE+gL4xPdP+McuUUYqhDhJr0H9q1/9ik8//ZSmpiYWLVrE/Pnz/cOoc+bM4cILL2T37t3cfffdWK1WlixZAoDJZGLhwoWsXLkSwzCYNWsWmZmZ/Xs0/Wz16tUkJCTQ1tbG3Llzufrqq1m2bBl//etfycrKoq6uDvC1WUxMDNu3bwegvr4egKKiIh555BEAVqxYweeff84///lPAAoLC9m7dy+vv/46WVlZ3e7v2muvJTExsUtNBw8eZN26dTz22GPceeedvPLKK3z9618HYPLkyRQVFXHDDTf0e9uI0KY87b5e8LEg7ghjZa/qCGJH12FlTYN4G9iS0cZNhMRkSEpGs/l+kzACzWIJ3gEJMYz0GtQ//OEPT7le0zTuuOOObtdNnTqVqVOnnlVhp9Jbz7e/PPfcc2zduhXwze71xz/+kUsuucQfrMeG9t955x2efPJJ/+fi4+MBX2BHR0f3+P1Tpkzxf1d3+zt48OBJQZ2ZmcmkSZMAOP/88zly5Ih/nc1mo6qq6mwPVwxCyu2CyiOoskNQVooqPwRHy6He3vWCGk2HBBvYRqCdM9kXvraOILYlQ2ISmlmCWIhQIJdRn6bCwkLeeecdtmzZQkREBDfddBPnnXceBw4cOGlbpVS3FzOYzWYMw0Dv4TxZZGTkKffX3QNKwsLC/K9NJhNO5/Hzgi6Xi/Dw8DM6TjE4KKXAUQNlhzpCueN3VcXxnrHVCmmj0MZPhqSUrkGckIQmd1EIMSjI/1NPU1NTE3FxcURERFBSUsLu3btxuVy89957HD582D/0nZCQwBVXXMGGDRv42c9+Bvh60vHx8YwePZrS0lJycnKIioqiubn5jPZ3pg4cOMB111131scsQoNqa4XyUl8Ql3cEcnmp74KtY0akQno2Wu6laBnZkD4KklPR9MF9p4QQQoL6tF155ZX84Q9/ID8/n9GjRzN16lRsNhuPPvood9xxB4ZhkJSUxIsvvsg999zDihUrmD17Nrquc++993LttdeSl5fHe++9R05ODomJiUybNo3Zs2cza9Ys8vLyet3fmWhvb+fQoUNccMEFgWwG0Y+U4YXqo/4wPtZTprbT6YuIKMgYhTb9SsjI7gjlLN8FXEKIIUlT3d3wHAIqKiq6LLe2tnYZGh6MqqqquOeee3jxxRf7fV9bt25l3759/OQnP+nxPuqh0KbBEoj7qFVTAxz4HPXlftSBz+FQ8fFbmjQdUtP9vWMtIwcysn3njofwo1/l/vTAkzbtH4Fu1369j1qcvpSUFG655RaampqIiYnp1315PB7uvPPOft2HOH3K6/UNXx/YD19+7vtdXelbaTJBRg7azHwYNcYXziMzB/Q+TSFE6JKgHmD/9m//NiD7uf766wdkP6J7qqnR11s+sB/15f6uveXYeBg9Hu2yOWijx0P2WDRr2Cm/TwgxfElQC9FHyvBC+WFfIB/Yj/ryc6juOHXTubc8+ly0MeN9V18P4eFrIURgSVALcYaUpx3XrkKM3e/7hrAPFoOrzbfS31u+ytdbHjUWLUx6y0KIsydBLcRpUu3tqB3/RG39P+odtb4JGzJHo82YDWPGo40+F5JSpLcshAgoCWoheqHa3ah3tqG2/sX3hK8x44lfdD+NadnSWxZC9LvhN5VMECml+MY3vkFTU9NZfX7cuHEAHD16lO99r/vHqN500018+OGHAHzzm9/0P2dcnDnldmFs34Kx4vuo//ktJKWg/+hn6Pc/Qti0mRLSQogBIT3qAdR5msu+SE1N5dlnn+11u69//es8//zz3HfffX3a33CjXC7U26
+iXvsrNNTBOZPQb78Xzp0sw9pCiAE3KIP6dx9UcbDuNOa6PQM5CeHckZtyym0COc3lypUrSU9P989HvXr1aqKiorj
"text/plain": [
"<Figure size 576x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 1.81e-03\n",
" final error(valid) = 1.47e-01\n",
" final acc(train) = 1.00e+00\n",
" final acc(valid) = 9.71e-01\n",
" run time per epoch = 1.49\n",
"--------------------------------------------------------------------------------\n",
"learning_rate=0.20 init_scale=1.00\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeoAAAENCAYAAADALCYAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAABBtklEQVR4nO3deXxU1f3/8de5s2XfZrIQ9l1ANk0EWVQggIIornWp/bq1KvjTVosV61etiqKCWFssVBG1rVa/LlgVBKMoYKyGTVBEQBFZEkJWsieTe35/3DAkkJBAJpksn+fjkUdm5t4785nzUN455957jtJaa4QQQgjRKhmBLkAIIYQQ9ZOgFkIIIVoxCWohhBCiFZOgFkIIIVoxCWohhBCiFZOgFkIIIVoxe2N22rx5M0uXLsU0TSZMmMD06dNrbV+7di3vvvsuAEFBQdx888306NEDgJkzZxIUFIRhGNhsNubOnevXLyCEEEK0Zw0GtWmaLFmyhPvvvx+3283s2bNJSkqiS5cuvn3i4uJ46KGHCAsLY9OmTfz973/nscce821/8MEHiYiIaJ5vIIQQQrRjDQ5979q1i4SEBOLj47Hb7YwaNYr09PRa+/Tv35+wsDAA+vbtS05OTvNUK4QQQnQwDfaoc3Nzcbvdvudut5udO3fWu/8nn3zC8OHDa702Z84cACZOnEhKSsqp1iqEEEJ0OA0GdV0zjCql6tz3m2++YfXq1Tz88MO+1x555BFiYmIoKCjg0UcfJTExkYEDBx53bGpqKqmpqQDMnTuXioqKRn8JcWJ2ux2v1xvoMtoVadPmIe3qf9KmzcPf7ep0Ouv/rIYOdrvdtYayc3JyiI6OPm6/PXv2sHjxYmbPnk14eLjv9ZiYGAAiIyNJTk5m165ddQZ1SkpKrd52dnZ2Q6WJRvJ4PNKefiZt2jykXf1P2rR5+LtdExMT693W4Dnq3r17k5GRQVZWFl6vl7S0NJKSkmrtk52dzbx587j99ttrfVhZWRmlpaW+x1u2bKFbt26n+j2EEEKIDqfBHrXNZuPGG29kzpw5mKbJuHHj6Nq1K6tWrQJg0qRJvPnmmxQVFfHCCy/4jpk7dy4FBQXMmzcPgKqqKsaMGcOwYcOa79sIIYQQ7YxqrctcHjhwINAltBsy9OV/0qbNQ9rV/xrTplprysrKME2z3muQRG0ul4vy8vKTOkZrjWEYBAUFHdfOJxr6btSEJ0IIIdqvsrIyHA4HdrtEQmPZ7XZsNttJH+f1eikrKyM4OLjRx8gUokII0cGZpikh3ULsdjumaZ7UMRLUQgjRwclwd8s62fZu10Gtq6owP34fvfnLQJcihBBCnJJ2HdQYBvrTDzBT/xPoSoQQQgSI1porrriCwsJCCgoKeOmll07pfa677joKCgpOuM/DDz/MunXrTun969Oug1ophUoaCzu+QefnBrocIYQQp6iqquqEz+uitcY0TT7++GMGDhxIeHg4hw8f5pVXXmnUZxzrH//4B5GRkSfc58Ybb2ThwoUN1nYy2nVQA6jkMaA1ekNaoEsRQghRj7feeoupU6cyceJE7rnnHqqqqujbty9PPfUUF154IRs2bDju+eLFixk/fjzjx4/n+eefB2Dv3r2ce+65zJ49m8mTJ3PgwAHeeecdJk+eDMBjjz3Gnj17mDhxIo888ghpaWlcfvnlzJw5kwkTJgBW2J5//vmMGzeOf/7zn74aR4wYQW5uLnv37mXMmDHMmjWLcePGcfXVV/sm9+rSpQt5eXlkZWX5rW3a/WV+FXFdKOvaj7D1a2HChYEuRwghWjXz38+j9+7263uqrj0xrvp1vdt37tzJf/7zH5YtW4bD4WD27Nm8/fbblJSU0L9/f2bNmgVQ6/mWLVt44403eP/999Fac+GFF3L22WcTGRnJDz/8wNNPP83jjz8OQHp6Ok888QQA9913H99//z0fffQRAGlpaWzevJlPPvnEN3Pm/PnziY6OprS0lK
lTpzJlyhTfdNhH/Pjjj/z1r3/lqaee4pZbbmH58uVcdtllAAwePJj09HSmTp3ql/Zr10FdWWVy+/s/MnzAJdyy6gl07iFUTGygyxJCCFHDunXr2Lp1K1OmTAGs+7o9Hg82m61W2NV8/tVXX3H++ecTEhICwAUXXMCXX37JpEmT6NKlC2eeeabvuPz8fN9SzHUZNmxYremtX3zxRVasWAFYk2/t3r37uKDu1q0bp59+OgBDhgxh7969vm1ut5uDBw+eUlvUpV0HtcNmcGZiGB/t8nJJUDTx69ehJl0S6LKEEKLVOlHPt7kcudhr9uzZtV5ftGhRrUlFXC6X7/mJJtU8Et5HHLl32TDqPttbc/+0tDTWrl3Le++9R3BwMJdffnmdM5DVXO3KZrNRVlbme15eXk5QUFC99Z2sdn+O+vLT3SileHPgdHS6f6/EE0II0XRjxozh/fff9011mpeXx759+054zMiRI1m5ciWlpaWUlJTw4YcfMmLEiDr37dWrF3v27AEgNDSUoqKiet+3sLCQyMhIgoOD2bVrFxs3bjzp7/Pjjz/Sv3//kz6uPu0+qD0hDib1jWJ1xGlkZuagD2UGuiQhhBA19OvXj3vuuYerr76alJQUrr766gaHjgcPHswVV1zB1KlTufDCC7n66qt9Q9HHmjBhAl988QVgLb2cnJzM+PHjeeSRR47b97zzzqOqqoqUlBSefPJJzjjjjJP6LpWVlfz0008MHTr0pI47kQ6xKEdOSSW3vvsDY/an8/8GBWNccLnf3rstkIUO/E/atHlIu/pfY9q0pKTkuOHi9uTgwYPceeed/Pvf//bbe9rtdrxe73Gvr1ixgq1bt3LPPffUe2xd7d2k9ajbA3eIg/P7RfNpwpkc2PR1oMsRQgjRguLj47nmmmsoLCxs9s/yer3ccsstfn3PDhHUAJcNdGNX8H+u/ujM/YEuRwghRAu66KKLCA8Pb/bPmTZtWoOTopysDhPUUcF2pvQMZU38Gez78qtAlyOEEEI0SocJaoBLzuiMQ1fx+oEO9bWFEEK0YR0qsaKC7EyJKGJdeF9+3vFToMsRQgghGtShghrgklH9cFVV8vom/11VLoQQQjSXDhfUkbFuppbt4PPKaH7KK2v4ACGEEG1azWUuT0Xfvn0ByMzM5Ne/rnvmtssvv5yvv7buKvrFL35Bfn7+KX1WXTpcUANc3C+SoKpyXv9qT6BLEUII0Qj+WuayKRISEnyrdJ3IZZddxssvv9ykz6qpQwZ1RNJIpu1fR1q2Zrf0qoUQIuBaapnLOXPm8NJLL/k+d/78+SxatIji4mKuvPJKJk+ezIQJE1i5cuVxNe7du5fx48cDUFpaym233UZKSgq33nprrbm+J02axLvvvuu3tmnXi3LUR4VHMC00n/erynhtSzb3ndsl0CUJIUSr8ML6g37vwPSMDuLmpPh6t7fkMpcXX3wxDz74INdffz0A7733Hv/6179wuVwsWbKE8PBwcnNzmTZtGpMmTUIpVWfNL7/8MsHBwaSmprJt2zbOP/9837aoqCjKy8vJzc09btWtU9EhgxogLGkkF61ew79tk/ght4zeMf5b6UQIIUTjteQyl6effjrZ2dlkZmaSk5NDZGQknTt3prKykrlz5/Lll1+ilCIzM5NDhw4RFxdXZ81ffPEFN9xwAwADBw5kwIABtbZ7PB4OHjwoQd0UathIpr66hPd6jue1Ldncf570qoUQ4kQ93+bS0stcTp06lQ8++ICsrCwuvvhiAN5++21ycnJYsWIFDoeDESNG1Lm8ZU319bbBv0tddshz1AAqNIzQ0wZycUYa6fuL2JlTGuiShBCiQ2rJZS7BGv5+9913+eCDD3w99MLCQjweDw6Hg88//7zBzz/77LN55513ANi+fTvfffedb5vWmkOHDtG1a9eGv3wjdNigBlDJY5i66yPC7ZrXtsiKPUIIEQgtucwlQP/+/SkuLiYhIYH4eGsE4dJLL+Xrr7/mggsu4J
133qFPnz4n/Pz/+Z//obi4mJSUFJ577jmGDRvm27ZlyxbOOOMM7Hb/DFp3iGUu66NLSzDvuo63x97MP3VPnpzcnf6
"text/plain": [
"<Figure size 576x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeoAAAEJCAYAAABbvWQWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAABEZUlEQVR4nO3deXxU9b3/8dc5M5N9nyEJ2Qj7LkgDUnDDRFzQGpdS61aKXgXk2lZv6nIt/n5t8dIq1R8F0WsR1NZrva3Qq9VqA+7xQtgERTARBJJAlpnsy0xmzvf3x4SBACERZjKT5PN8PPKYOUvmfM4XzXvO9v1qSimFEEIIIUKSHuwChBBCCNE1CWohhBAihElQCyGEECFMgloIIYQIYRLUQgghRAiToBZCCCFCmDnYBXSloqIi2CX0GzabjZqammCX0a9ImwaGtKv/SZsGhr/bNS0trctlckQthBBChDAJaiGEECKESVALIYQQISxkr1GfTClFW1sbhmGgaVqwy+lTKisrcTqdvmmlFLquExERIW0phBAhrs8EdVtbGxaLBbO5z5QcMsxmMyaTqdM8t9tNW1sbkZGRQapKCCFET3Sbes888wzbt28nPj6e5cuXn7JcKcXatWvZsWMH4eHhLFq0iGHDhgGwc+dO1q5di2EY5Obmkp+ff9aFGoYhIe1HZrO501G2EEKI0NTtNepLL72URx55pMvlO3bs4OjRo6xYsYK7776bP/zhD4A3WNesWcMjjzzCU089xSeffEJZWdlZFyqnaP1P2lQIIUJft4eo48aNo6qqqsvlW7du5eKLL0bTNEaNGkVzczO1tbVUV1eTmppKSkoKADNmzKC4uJiMjAz/VS+EEKJfUkqBux3a28HtgnZ3x7TL++pb5n1VJ06fuJ7HE5D6XDMvA9vggHz2yc75XLLD4cBms/mmrVYrDocDh8OB1WrtNL+kpKTLzyksLKSwsBCAZcuWdfpM8N4QJae+z97p2i48PPyUdhY9Yzabpe0CQNrV/75NmyqlwNmG0dqCamlGtbWgWlswWppRrS2df9zt4HGjPB7wuMHtRhke76vHAx4PyuP2BqXHjXK7T5jX8d7tBqPj9dg67S7UscD1lwCcPTQGp2MbM9Hvn3s655x8SqlT5mma1uX8ruTl5ZGXl+ebPrnHF6fTecoNUX2NUoq5c+fywgsvYBgG69evZ968ed/6c26//XZWrlxJfHx8l+v88pe/5LLLLuPCCy/EbDbjdrtPWcfpdEqPRWdJensKDGnXb0e528HpBGeb98fV8ep0grMV5XQSYzHRVFMNba3Q1gJtrai21o7pE39aoK0NlNGzjZvNYDKDbgKTyfvepHe8mk6abzo+LywcTNGgm9BO/j2zGcwWNIsFzCf8WMK8yywWNEvYCfNPv87xaYt3GwEQ0Ys9k51zUFut1k7F2u12EhMTcbvd2O32U+YPZBs3bmTcuHHExsZy+PBhXnrppdMGtcfjOeOXkpdffrnbbc2fP5+CggIuvPDCcylZCHEayuPpOMV6/OjQd5rV4wFPu+9IEre703vfEajnxN/teO9ydYTt8aA9Hr5tx0P52LwenNZtPPbGZIKIKIiIPP4TFQ1JNrSIyFOXRUR1zD8+7XsfFo6mSzccveWcgzonJ4d//OMfzJw5k5KSEqKiokhMTCQuLo4jR45QVVVFUlISRUVF3Hffff6oGePV51GHD/jls47RMoei3/wvZ1xn/vz5VFRU4HQ6ufPOO7ntttt47733WLZsGR6Ph6SkJF577TWam5t59NFH2bVrF5qm8bOf/Yw5c+awfv16br31VgAef/xxDh48yOWXX87FF19Mbm4uv/vd70hJSeGLL77g/fffP+32AC644ALefvttmpubue2225g2bRpbt24lNTWVF154gcjISDIyMqitraWqquqM39SEGMiOneqlqQGaG6GpEeV73wBNjdDciOp49c1ztgauKEsYhIdDuDcQCY/w/i
RY0cIjvPMiTlx2fF3t2LrHfsLCsaalY29p9R5dyg2kfVK3Qf3000+zZ88eGhsbWbBgAXPnzvWdRp09ezbnn38+27dv57777iMsLIxFixYBYDKZmD9/PkuXLsUwDGbNmkVmZmZg9ybAli9fTmJiIq2trcyZM4crrriCgoICXn/9dbKysqitrQW8bRYbG8vGjRsBqKurA6C4uJjf/OY3ADzyyCPs27ePf/7znwAUFRWxc+dONm3aRFZW1mm3d/XVV5OUlNSppgMHDrBq1SqeeOIJ7rnnHt566y1uvPFGACZOnEhxcTHXXXddwNtGiFCinG1QdQSqKlD2qs6Be0Io09zgPZrtSlQ0xMRBdCzEJaANzuyYjvYGqskEJsvxU7xm8wmnc03HTw+fvN7Jy469t1jQdP+eqtXjE9HaA3NDlegd3Qb1T3/60zMu1zSNu+6667TLpkyZwpQpU86qsDPp7sg3UF544QXefvttwDu61x//+EemT5/uC9Zjp/Y/+ugjnnnmGd/vJSQkAN7AjomJ6fLzJ0+e7Pus023vwIEDpwR1ZmYmEyZMAOC8887j8OHDvmVWq5XKysqz3V0hQppyu6GmEiorUJXl3teqCqisgNqTrh3qujdsY+IgJhYGDUYbOur4dHQsWkzH8uiOeVExAbu+KcS3IbdR91BRUREfffQRb7zxBpGRkdx0002MHz+e/fv3n7KuUuq0p5jMZjOGYaB3cW0nKirqjNs7XQcl4eHhvvcmk4m2tjbftNPpJCIi4lvtpxChRBkG1NqhstwXwqqyAirLvSFtnHDjU1QMpKajjZkIyWmQko6Wkga2ZIiMltO+os+SoO6hxsZG4uPjiYyMpLS0lO3bt+N0Ovn00085dOiQ79R3YmIil1xyCWvXruWXv/wl4D2STkhIYNiwYRw8eJChQ4cSHR1NU1PTt9ret7V//36uueaas95nIQLNe424Fepqob6W1s+aML7e1xHGFd7T1+2u478QFg7JaWiZwyDnIkhJ84ZxShpaTFzwdkSIAJKg7qFLL72Ul19+mby8PIYNG8aUKVOwWq389re/5a677sIwDGw2G6+++io/+clPeOSRR7jsssvQdZ3777+fq6++mtzcXD799FOGDh1KUlISU6dO5bLLLmPWrFnk5uZ2u71vo729nW+++YZJkyb5sxmE6BFleKCxAeq9AawaaqHOAQ11qHqHbz71teA6fqaoAbzXbAelegN53OTjR8Yp6ZCQJEfGYsDR1OkeeA4BFRUVnaZbWlo6nRruiyorK/nJT37Cq6++GvBtvf322+zevZuf//znXT5H3R/aNFgG8vO+yuWEwwegztERunVQ70A1eF+pr4WG+tM/jxsZDfEJEJ+EFud99U3HJ5A4ciy1euCefR2IBvJ/q4Hk73YN6HPUoudSUlK45ZZbaGxsJDY2NqDbcrvd3HPPPQHdhhgYVFsrfL0X9dXnqK++gG++6nyntKZDXALEJ3oDN2s4xCVCQiJafKL3fcerdsI9FadjttnQJFSE6ESCupd973vf65XtXHvttb2yHdH/qNYWKN2D2vc56qvP4dDX3o41dB2GjEDLvRZtxFiwpniPhmPi/P5IkRDiOAlqIQY41dwEJV8cP2I+tN972tpkhuwRaLOvRxs1AUaMQYuQSyVC9DYJaiEGGNXYACXeUFb7Pofyb0Apb//Iw0ahzfm+N5iHjen2VLUQIvAkqIXo51R9rfdI+auOU9kVh7wLwsJg+Fi07/3QG8xDR3kHPBBChBQJaiH6GdXSDF/tRn25C/XlZ3Cko7e68Ejv6esLLvEGc/YINLMluMUKIbolQd2LThzm8mzu+h45ciQlJSUcPXqUX/ziFzz//POnrHPTTTfxi1/8gkmTJvGDH/yA5557Tsb37edUezt8/WVHMO+Eb0q915jDwmHUeLSZud5gzhoujz0J0QdJUPeiE4e5PBepqamnDemT3Xjjjbz44os88MAD57Q9EVqUYcDhA6gvd6K+3A
WlX3iHR9R17+nrOd9HGzMJho32jusrhOjT+mRQ/2FrJQdq27pf8VsYmhjBXTkpZ1zHn8NcLl26lPT0dN941MuXLyc
"text/plain": [
"<Figure size 576x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" final error(train) = 4.95e-03\n",
" final error(valid) = 1.88e-01\n",
" final acc(train) = 1.00e+00\n",
" final acc(valid) = 9.60e-01\n",
" run time per epoch = 1.58\n"
]
}
],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n",
"num_epochs = 100 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats\n",
"learning_rate = 0.2 # learning rate for gradient descent\n",
"\n",
"init_scales = [0.1, 0.2, 0.5, 1.] # scale for random parameter initialisation\n",
"final_errors_train = []\n",
"final_errors_valid = []\n",
"final_accs_train = []\n",
"final_accs_valid = []\n",
"\n",
"for init_scale in init_scales:\n",
"\n",
" print('-' * 80)\n",
" print('learning_rate={0:.2f} init_scale={1:.2f}'\n",
" .format(learning_rate, init_scale))\n",
" print('-' * 80)\n",
" # Reset random number generator and data provider states on each run\n",
" # to ensure reproducibility of results\n",
" rng.seed(seed)\n",
" train_data.reset()\n",
" valid_data.reset()\n",
"\n",
" # Alter data-provider batch size\n",
" train_data.batch_size = batch_size \n",
" valid_data.batch_size = batch_size\n",
"\n",
" # Create a parameter initialiser which will sample random uniform values\n",
" # from [-init_scale, init_scale]\n",
" param_init = UniformInit(-init_scale, init_scale, rng=rng)\n",
"\n",
" # Create a model with five affine layers\n",
" hidden_dim = 100\n",
" model = MultipleLayerModel([\n",
" AffineLayer(input_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, hidden_dim, param_init, param_init),\n",
" SigmoidLayer(),\n",
" AffineLayer(hidden_dim, output_dim, param_init, param_init)\n",
" ])\n",
"\n",
" # Initialise a cross entropy error object\n",
" error = CrossEntropySoftmaxError()\n",
"\n",
" # Use a basic gradient descent learning rule\n",
" learning_rule = GradientDescentLearningRule(learning_rate=learning_rate)\n",
"\n",
" stats, keys, run_time, fig_1, ax_1, fig_2, ax_2 = train_model_and_plot_stats(\n",
" model, error, learning_rule, train_data, valid_data, num_epochs, stats_interval)\n",
"\n",
" plt.show()\n",
"\n",
" print(' final error(train) = {0:.2e}'.format(stats[-1, keys['error(train)']]))\n",
" print(' final error(valid) = {0:.2e}'.format(stats[-1, keys['error(valid)']]))\n",
" print(' final acc(train) = {0:.2e}'.format(stats[-1, keys['acc(train)']]))\n",
" print(' final acc(valid) = {0:.2e}'.format(stats[-1, keys['acc(valid)']]))\n",
" print(' run time per epoch = {0:.2f}'.format(run_time * 1. / num_epochs))\n",
"\n",
" final_errors_train.append(stats[-1, keys['error(train)']])\n",
" final_errors_valid.append(stats[-1, keys['error(valid)']])\n",
" final_accs_train.append(stats[-1, keys['acc(train)']])\n",
" final_accs_valid.append(stats[-1, keys['acc(valid)']])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| init_scale | final error(train) | final error(valid) | final acc(train) | final acc(valid) |\n",
"|------------|--------------------|--------------------|------------------|------------------|\n",
"| 0.1 | 4.75e-03 | 2.19e-01 | 1.00 | 0.96 |\n",
"| 0.2 | 1.14e-03 | 1.49e-01 | 1.00 | 0.97 |\n",
"| 0.5 | 1.81e-03 | 1.47e-01 | 1.00 | 0.97 |\n",
"| 1.0 | 4.95e-03 | 1.88e-01 | 1.00 | 0.96 |\n"
]
}
],
"source": [
"j = 0\n",
"print('| init_scale | final error(train) | final error(valid) | final acc(train) | final acc(valid) |')\n",
"print('|------------|--------------------|--------------------|------------------|------------------|')\n",
"for init_scale in init_scales:\n",
" print('| {0:.1f} | {1:.2e} | {2:.2e} | {3:.2f} | {4:.2f} |'\n",
" .format(init_scale, \n",
" final_errors_train[j], final_errors_valid[j],\n",
" final_accs_train[j], final_accs_valid[j]))\n",
" j += 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> How does increasing the number of layers affect the model's performance on the training data set? And on the validation data set?\n",
"\n",
"<span style='color: green;'>\n",
"The best final training set error across the four initialisation scales used above for each model architecture, consistently decreases as we increase the number of layers.\n",
"</span>\n",
"\n",
"| Number of affine layers | Best final training set error |\n",
"|-------------------------|-------------------------------|\n",
"| 2 | $1.85 \\times 10^{-2}$ |\n",
"| 3 | $5.21 \\times 10^{-3}$ |\n",
"| 4 | $1.99 \\times 10^{-3}$ |\n",
"| 5 | $1.14 \\times 10^{-3}$ |\n",
"\n",
"<span style='color: green;'>\n",
"This makes sense as adding more layers increases the total number of free parameters in the model and so we would expect the model to be able to fit too the training data better.\n",
"</span>\n",
"\n",
"<span style='color: green;'>\n",
"If we look at the validation set however we see the opposite trend; as the number of layers increases the best final validation set error increases.\n",
"</span>\n",
"\n",
"| Number of affine layers | Best final validation set error |\n",
"|-------------------------|---------------------------------|\n",
"| 2 | $7.47 \\times 10^{-2}$ |\n",
"| 3 | $8.77 \\times 10^{-2}$ |\n",
"| 4 | $1.17 \\times 10^{-1}$ |\n",
"| 5 | $1.47 \\times 10^{-1}$ |\n",
"\n",
"\n",
"If we look more closely at the training curves for the models with more layers we can see what is happening here. For the models with three or more layers, after a certain number of epochs the validation set error begins to **grow** even as the training set error continues to decrease. This indicates that these models have begun **overfitting** to the training data. We could get a better validation set error in these cases by stopping the training early. **Early stopping** like this is one way of trying to overcome overfitting, in later labs we will consider other methods for improving generalisation by reducing overfitting.\n",
"\n",
"\n",
"> Do deeper models seem to be harder or easier to train (e.g. in terms of ease of choosing training hyperparameters to give good final performance and/or quick convergence)?\n",
"\n",
"> Do the models seem to be sensitive to the choice of the parameter initialisation range? Can you think of any reasons for why setting individual parameter initialisation scales for each AffineLayer in a model might be useful? Can you come up with (or find) any heuristics for setting the parameter initialisation scales?\n",
"\n",
"The final performance of the deeper models becomes increasingly sensitive to the choice of parameter initialisation. For the models with two affine layers, the final training errors for initialisation scales 0.1, 0.2 and 0.5 are all within approximately 10% of each other, while for the models with five affine layers there is an approximately 400% increase in final training error if moving from an initialisation scale of 0.2 to 0.1 and a 50% increase in final training error when moving from 0.2 to 0.5. The smaller parameter initialisation scales for the deeper models in particular seem to give poorer initial performance (error curves start from higher values) and for the five affine layer model the smallest parameter initialisation scale run shows a pronounced flatter section at the start of training with around 15 epochs before the error starts significantly decreasing.\n",
"\n",
"\n",
"In general the models with more layers also take longer to train, which can be accounted for by their increased complexity, so on top of issues of potential overfitting and difficulty of choosing parameter initilisations we also need to factor in the potentially slower training of deeper models if computational time is a key constraint.\n",
"\n",
"\n",
"We might expect the appropriate initialisation scale for a given affine layer to depend on its input and output dimensionalities. Each output is calculated as the weighted sum of all the inputs, and so for a larger number of inputs the typical magnitude of the output is expected to grow as the sum will contain more terms. Similarly the backpropagated gradient at each input is calculated as a weighted sum over the gradients at each output, and so for a larger number outputs the typical magnitude of backpropagated gradients will become larger.\n",
"\n",
"\n",
"If we wish to keep the typical magnitude of the activations and backpropagated gradients at a given layer roughly constant through the network, we may wish to initialise a layer parameters according to its size. One heuristic based on trying to achieve a roughly constant variance in activations and backpropagated gradients through the network is to initialise the weights for a layer from a distribution with variance inversely proportional to the sum of the input and output dimensions of the layer. This is sometimes known as the Glorot or Xavier initialisation, after the name of the author of [the paper](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf) in which this scheme was proposed. This is discussed in [lecture 4](http://www.inf.ed.ac.uk/teaching/courses/mlp/2017-18/mlp04-learn.pdf).\n"
]
},
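The Glorot heuristic described above can be sketched in a few lines of NumPy. This is only an illustrative sketch, not part of the `mlp` framework: the helper names `glorot_uniform_scale` and `glorot_uniform_weights`, the seed value, and the example layer sizes (784 inputs, 10 outputs) are assumptions made here. It uses the fact that a uniform distribution on $[-a, a]$ has variance $a^2 / 3$, so choosing $a = \sqrt{6 / (D_{in} + D_{out})}$ gives a variance of $2 / (D_{in} + D_{out})$.

```python
import numpy as np


def glorot_uniform_scale(input_dim, output_dim):
    """Half-width a of U(-a, a) such that Var[w] = 2 / (input_dim + output_dim)."""
    # A uniform distribution on [-a, a] has variance a**2 / 3.
    return np.sqrt(6.0 / (input_dim + output_dim))


def glorot_uniform_weights(input_dim, output_dim, rng):
    """Sample an (output_dim, input_dim) weight matrix using the Glorot scale."""
    scale = glorot_uniform_scale(input_dim, output_dim)
    return rng.uniform(-scale, scale, size=(output_dim, input_dim))


# Hypothetical layer sizes and seed, chosen only for illustration.
rng = np.random.RandomState(27092016)
layer_dims = [784, 100, 100, 10]
for fan_in, fan_out in zip(layer_dims[:-1], layer_dims[1:]):
    weights = glorot_uniform_weights(fan_in, fan_out, rng)
    target_var = 2.0 / (fan_in + fan_out)
    print('{0:4d} -> {1:4d}: scale = {2:.3f}, empirical var = {3:.2e}, target var = {4:.2e}'
          .format(fan_in, fan_out, glorot_uniform_scale(fan_in, fan_out),
                  weights.var(), target_var))
```

In the notebook's own framework, one way to apply the same idea would presumably be to compute a separate scale for each `AffineLayer` and pass it to `UniformInit(-scale, scale, rng=rng)`, rather than sharing a single `init_scale` across all layers.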
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 3: Hyperbolic tangent and rectified linear layers\n",
"\n",
"In the models we have been investigating so far, we have been applying elementwise logistic sigmoid transformations to the outputs of intermediate (affine) layers. The logistic sigmoid is just one particular choice of an elementwise non-linearity we can use. \n",
"\n",
"As discussed in lecture 3, although logistic sigmoid has some favourable properties in terms of interpretability, there are also disadvantages from a computational perspective. In particular:\n",
"1. the gradients of the sigmoid become close to zero (and may actually become zero because of finite numerical precision) for large positive or negative inputs, \n",
"2. the outputs are non-centred - they cover the interval $[0,\\,1]$ so negative outputs are never produced.\n",
"\n",
"Two alternative elementwise non-linearities which are often used in multiple layer models are the hyperbolic tangent (tanh) and the rectified linear function (ReLU).\n",
"\n",
"For tanh (`TanhLayer`) layer the forward propagation corresponds to\n",
"\n",
"\\begin{equation}\n",
" y^{(b)}_k = \n",
" \\tanh\\left(x^{(b)}_k\\right) = \n",
" \\frac{\\exp\\left(x^{(b)}_k\\right) - \\exp\\left(-x^{(b)}_k\\right)}{\\exp\\left(x^{(b)}_k\\right) + \\exp\\left(-x^{(b)}_k\\right)}\n",
"\\end{equation}\n",
"\n",
"which has corresponding partial derivatives\n",
"\n",
"\\begin{equation}\n",
" \\frac{\\partial y^{(b)}_k}{\\partial x^{(b)}_d} = \n",
" \\begin{cases} \n",
" 1 - \\left(y^{(b)}_k\\right)^2 & \\quad k = d \\\\\n",
" 0 & \\quad k \\neq d\n",
" \\end{cases}.\n",
"\\end{equation}\n",
"\n",
"For a ReLU (`ReluLayer`) the forward propagation corresponds to\n",
"\n",
"\\begin{equation}\n",
" y^{(b)}_k = \n",
" \\max\\left(0,\\,x^{(b)}_k\\right)\n",
"\\end{equation}\n",
"\n",
"which has corresponding partial derivatives\n",
"\n",
"\\begin{equation}\n",
" \\frac{\\partial y^{(b)}_k}{\\partial x^{(b)}_d} = \n",
" \\begin{cases} \n",
" 1 & \\quad k = d \\quad\\textrm{and} &x^{(b)}_d > 0 \\\\\n",
" 0 & \\quad k \\neq d \\quad\\textrm{or} &x^{(b)}_d < 0\n",
" \\end{cases}.\n",
"\\end{equation}\n",
"\n",
"**Your Tasks:**\n",
"- Using these definitions implement the `fprop` and `bprop` methods for the skeleton `TanhLayer` and `ReluLayer` class definitions below."
]
},
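Before implementing the layers, it can help to confirm the two derivative formulas above numerically. The following is a minimal sketch that compares them against central finite differences; the helper names `tanh_grad`, `relu_grad` and `central_difference` are hypothetical and are not part of the `mlp` framework. The ReLU derivative is undefined at exactly zero, so the sketch uses 0 there by convention and only tests points away from zero.

```python
import numpy as np


def tanh_grad(x):
    """Analytic derivative of tanh: 1 - tanh(x)**2."""
    return 1.0 - np.tanh(x) ** 2


def relu_grad(x):
    """Derivative of max(0, x) for x != 0; 0 is used at x == 0 by convention."""
    return (x > 0).astype(float)


def central_difference(func, x, eps=1e-6):
    """Elementwise numerical derivative (f(x + eps) - f(x - eps)) / (2 * eps)."""
    return (func(x + eps) - func(x - eps)) / (2.0 * eps)


# Test points kept well away from zero, where the ReLU derivative is undefined.
x = np.array([[0.1, -0.2, 0.3], [-0.4, 0.5, -0.6]])
tanh_ok = np.allclose(tanh_grad(x), central_difference(np.tanh, x))
relu_ok = np.allclose(
    relu_grad(x), central_difference(lambda v: np.maximum(v, 0.0), x))
print('tanh derivative matches finite differences:', tanh_ok)
print('ReLU derivative matches finite differences:', relu_ok)
```

Because both transformations are elementwise, the Jacobian is diagonal, so a `bprop` implementation only needs to multiply the incoming gradients elementwise by these derivative values.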
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from mlp.layers import Layer\n",
"\n",
"class TanhLayer(Layer):\n",
" \"\"\"Layer implementing an element-wise hyperbolic tangent transformation.\"\"\"\n",
"\n",
" def fprop(self, inputs):\n",
" \"\"\"Forward propagates activations through the layer transformation.\n",
"\n",
" For inputs `x` and outputs `y` this corresponds to `y = tanh(x)`.\n",
" \"\"\"\n",
2024-10-10 15:52:23 +02:00
" return np.tanh(inputs)\n",
2024-10-03 15:53:33 +02:00
"\n",
" def bprop(self, inputs, outputs, grads_wrt_outputs):\n",
" \"\"\"Back propagates gradients through a layer.\n",
"\n",
" Given gradients with respect to the outputs of the layer calculates the\n",
" gradients with respect to the layer inputs.\n",
" \"\"\"\n",
2024-10-10 15:52:23 +02:00
" return (1. - outputs**2) * grads_wrt_outputs\n",
2024-10-03 15:53:33 +02:00
"\n",
" def __repr__(self):\n",
" return 'TanhLayer'\n",
" \n",
"\n",
"class ReluLayer(Layer):\n",
" \"\"\"Layer implementing an element-wise rectified linear transformation.\"\"\"\n",
"\n",
" def fprop(self, inputs):\n",
" \"\"\"Forward propagates activations through the layer transformation.\n",
"\n",
" For inputs `x` and outputs `y` this corresponds to `y = max(0, x)`.\n",
" \"\"\"\n",
2024-10-10 15:52:23 +02:00
" return np.maximum(inputs, 0.)\n",
2024-10-03 15:53:33 +02:00
"\n",
" def bprop(self, inputs, outputs, grads_wrt_outputs):\n",
" \"\"\"Back propagates gradients through a layer.\n",
"\n",
" Given gradients with respect to the outputs of the layer calculates the\n",
" gradients with respect to the layer inputs.\n",
" \"\"\"\n",
2024-10-10 15:52:23 +02:00
" return (outputs > 0) * grads_wrt_outputs\n",
2024-10-03 15:53:33 +02:00
"\n",
" def __repr__(self):\n",
" return 'ReluLayer'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Test your implementations by running the cells below."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Outputs and gradients calculated correctly for TanhLayer.\n"
]
}
],
"source": [
"test_inputs = np.array([[0.1, -0.2, 0.3], [-0.4, 0.5, -0.6]])\n",
"test_grads_wrt_outputs = np.array([[5., 10., -10.], [-5., 0., 10.]])\n",
"test_tanh_outputs = np.array(\n",
" [[ 0.09966799, -0.19737532, 0.29131261],\n",
" [-0.37994896, 0.46211716, -0.53704957]])\n",
"test_tanh_grads_wrt_inputs = np.array(\n",
" [[ 4.95033145, 9.61042983, -9.15136962],\n",
" [-4.27819393, 0., 7.11577763]])\n",
"tanh_layer = TanhLayer()\n",
"tanh_outputs = tanh_layer.fprop(test_inputs)\n",
"all_correct = True\n",
"if not tanh_outputs.shape == test_tanh_outputs.shape:\n",
" print('TanhLayer.fprop returned array with wrong shape.')\n",
" all_correct = False\n",
"elif not np.allclose(test_tanh_outputs, tanh_outputs):\n",
" print('TanhLayer.fprop calculated incorrect outputs.')\n",
" all_correct = False\n",
"tanh_grads_wrt_inputs = tanh_layer.bprop(\n",
" test_inputs, tanh_outputs, test_grads_wrt_outputs)\n",
"if not tanh_grads_wrt_inputs.shape == test_tanh_grads_wrt_inputs.shape:\n",
" print('TanhLayer.bprop returned array with wrong shape.')\n",
" all_correct = False\n",
"elif not np.allclose(tanh_grads_wrt_inputs, test_tanh_grads_wrt_inputs):\n",
" print('TanhLayer.bprop calculated incorrect gradients with respect to inputs.')\n",
" all_correct = False\n",
"if all_correct:\n",
" print('Outputs and gradients calculated correctly for TanhLayer.')"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Outputs and gradients calculated correctly for ReluLayer.\n"
]
}
],
"source": [
"test_inputs = np.array([[0.1, -0.2, 0.3], [-0.4, 0.5, -0.6]])\n",
"test_grads_wrt_outputs = np.array([[5., 10., -10.], [-5., 0., 10.]])\n",
"test_relu_outputs = np.array([[0.1, 0., 0.3], [0., 0.5, 0.]])\n",
"test_relu_grads_wrt_inputs = np.array([[5., 0., -10.], [-0., 0., 0.]])\n",
"relu_layer = ReluLayer()\n",
"relu_outputs = relu_layer.fprop(test_inputs)\n",
"all_correct = True\n",
"if not relu_outputs.shape == test_relu_outputs.shape:\n",
" print('ReluLayer.fprop returned array with wrong shape.')\n",
" all_correct = False\n",
"elif not np.allclose(test_relu_outputs, relu_outputs):\n",
" print('ReluLayer.fprop calculated incorrect outputs.')\n",
" all_correct = False\n",
"relu_grads_wrt_inputs = relu_layer.bprop(\n",
" test_inputs, relu_outputs, test_grads_wrt_outputs)\n",
"if not relu_grads_wrt_inputs.shape == test_relu_grads_wrt_inputs.shape:\n",
" print('ReluLayer.bprop returned array with wrong shape.')\n",
" all_correct = False\n",
"elif not np.allclose(relu_grads_wrt_inputs, test_relu_grads_wrt_inputs):\n",
" print('ReluLayer.bprop calculated incorrect gradients with respect to inputs.')\n",
" all_correct = False\n",
"if all_correct:\n",
" print('Outputs and gradients calculated correctly for ReluLayer.')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PyTorch\n",
"\n",
"In this section we will builld on we learned in the previous lab and will use PyTorch to build a multi-layer model for the MNIST classification task. "
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f0a02286d70>"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torchvision import datasets,transforms\n",
"from torch.utils.data.sampler import SubsetRandomSampler\n",
"\n",
"torch.manual_seed(seed)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Neural networks are typically take a long time to converge. This process can be sped up by using a GPU. If you have a GPU available, you can use it by setting the `device` variable below to `cuda`. If you do not have a GPU available, you can still run the code on the CPU by setting `device` to `cpu`.\n",
"\n",
"When training, both the model and the data should be on the same device. The `to` method can be used to move a tensor to a device. For example, `x = x.to(device)` will move the tensor `x` to the device specified by `device`. Look through the code to see where we put the model and the data on the CPU or GPU device."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cpu')"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Device configuration\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"device"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# Set training run hyperparameters\n",
"batch_size = 128 # number of data points in a batch\n",
"learning_rate = 0.001 # learning rate for gradient descent\n",
"num_epochs = 50 # number of training epochs to perform\n",
"stats_interval = 5 # epoch interval between recording and printing stats"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The [transforms](https://pytorch.org/vision/0.9/transforms.html) are common transformations for the image datasets. We will use the `Compose` transform to combine the `ToTensor` and `Normalize` transforms. The `ToTensor` transform converts the image to a tensor and the `Normalize` transform normalizes the image by subtracting the mean and dividing by the standard deviation. The `Normalize` transform takes two arguments: the mean and the standard deviation. The mean and standard deviation are calculated for each channel. The mean and standard deviation for the MNIST dataset are $0.1307$ and $0.3081$ respectively. `Normalize` transform is particularly useful when there is a big discrepancy between the values of the pixels in an image.\n",
"\n",
"When working with images, transforms are used to augment the dataset (i.e. create artificial images based on existing ones). This is done to increase the size of the dataset and to make the model more robust to changes in the input images. An illustration of how transforms affect an image is shown [here](https://pytorch.org/vision/0.11/auto_examples/plot_transforms.html#sphx-glr-download-auto-examples-plot-transforms-py)."
]
},
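{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of the arithmetic that `Normalize` performs, the sketch below applies $(x - \\text{mean}) / \\text{std}$ to a few pixel values in plain Python. The helper `normalize_pixel` is a hypothetical name introduced here for illustration and is not part of torchvision; it assumes pixel values already scaled to $[0, 1]$.\n",
"\n",
"```python\n",
"MNIST_MEAN = 0.1307  # dataset mean quoted in the text\n",
"MNIST_STD = 0.3081   # dataset standard deviation quoted in the text\n",
"\n",
"\n",
"def normalize_pixel(value, mean=MNIST_MEAN, std=MNIST_STD):\n",
"    # Hypothetical helper: subtract the mean, then divide by the standard deviation.\n",
"    return (value - mean) / std\n",
"\n",
"\n",
"# A black pixel, a pixel at the mean intensity and a white pixel.\n",
"normalized = [normalize_pixel(v) for v in (0.0, MNIST_MEAN, 1.0)]\n",
"print(normalized)\n",
"```\n",
"\n",
"Pixels darker than the dataset mean map to negative values and brighter pixels to positive values, so the inputs end up roughly centred around zero."
]
},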
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"transform=transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Popular machine learning datasets are available in the [`torchvision.datasets`](https://pytorch.org/vision/0.15/datasets.html) module. The `MNIST` dataset is available in the [`torchvision.datasets.MNIST`](https://pytorch.org/vision/0.15/generated/torchvision.datasets.MNIST.html) class. This way, we can download the dataset directly from the PyTorch library into the `data` folder of our repository. \n",
"\n",
"The `MNIST` class takes the following arguments:\n",
"- `root`: the path where the dataset will be stored\n",
"- `train`: if `True`, the training set is returned, otherwise the test set is returned\n",
"- `download`: if `True`, the dataset is downloaded from the internet and put in `root`. If the dataset is already downloaded, it is not downloaded again\n",
"- `transform`: the transform to be applied to the dataset"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
"Failed to download (trying next):\n",
"HTTP Error 403: Forbidden\n",
"\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100.0%\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
"Failed to download (trying next):\n",
"HTTP Error 403: Forbidden\n",
"\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100.0%\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
"Failed to download (trying next):\n",
"HTTP Error 403: Forbidden\n",
"\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100.0%\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
"Failed to download (trying next):\n",
"HTTP Error 403: Forbidden\n",
"\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100.0%"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"test_dataset = datasets.MNIST('../data', train=False, download=True, transform=transform)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that the training set has $60,000$ images and the test set has $10,000$ images. Each image is a $28 \\times 28$ grayscale image. "
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train dataset: \n",
" Dataset MNIST\n",
" Number of datapoints: 60000\n",
" Root location: ../data\n",
" Split: Train\n",
" StandardTransform\n",
"Transform: Compose(\n",
" ToTensor()\n",
" Normalize(mean=(0.1307,), std=(0.3081,))\n",
" )\n",
"torch.Size([60000, 28, 28])\n",
"torch.Size([60000])\n",
"\n",
"Test dataset: \n",
" Dataset MNIST\n",
" Number of datapoints: 10000\n",
" Root location: ../data\n",
" Split: Test\n",
" StandardTransform\n",
"Transform: Compose(\n",
" ToTensor()\n",
" Normalize(mean=(0.1307,), std=(0.3081,))\n",
" )\n"
]
}
],
"source": [
"print(\"Train dataset: \\n\", train_dataset)\n",
"print(train_dataset.data.size())\n",
"print(train_dataset.targets.size())\n",
"print(\"\\nTest dataset: \\n\", test_dataset)\n"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAG0CAYAAACbs5jqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAfj0lEQVR4nO3df2zV1f3H8dctvaWX20tvYwFLacMPy8+MwmIo/uhADCOSLtjNMdhMdIJzKzEhmXFz+AsCSrfI3BzJnJYgS/xRwEaiJDJcgiUsmonCXB1ZBIa5g5QipZdLW1p6v3/cb+8s9xb6ud7bd2/v85GY+Dmfc+7n8Oa0L8798bmucDgcFgAABrKsJwAAyFyEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQMgu3bt8vlcl3zvxEjRlhPExh02dYTADLBnDlz9NRTT8U919jYqL/+9a+66667BnlWgD1CCBgEc+bM0Zw5c+Keu+WWWyRJP/nJTwZxRsDQ4OIu2oCdTz/9VN/4xjdUXFys//znPzwlh4zDa0KAoRdffFGStGrVKgIIGYmdEGCkvb1d48ePV1tbm06ePKmSkhLrKQGDjp0QYKS+vl6tra266667CCBkLEIIMPKnP/1JkvTQQw8ZzwSww9NxgIGmpibNmjVLEyZM0MmTJ3k9CBmLnRBggDckABHshIBB1tHRofHjx+vChQu8IQEZj50QMMh27typ8+fPa+nSpQQQMh4hBAyy3jckcIcEgKfjgEH12WefaebMmbwhAfh/hBAAwAxPxwEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMJNtPYH+/PCHP9S//vWvPm1er1eNjY2qrKxUKBQympk96hBBHSKoQwR1iBgKdZg+fbpeffXVAfVN2R0T3n33Xe3Zs0etra2aMGGC7r//fs2YMWPA47/5zW/q448/7tPm8/nU1tam0aNHKxgMJnvKaYM6RFCHCOoQQR0ihkId5s6dq8OHDw+ob0qejjt06JC2b9+u7373u6qtrdWMGTP0zDPPqKWlJRWXAwCkqZSE0Ntvv61FixbpzjvvjO6CCgsLtW/fvlRcDgCQppL+mlB3d7eOHz+uu+++u0/77NmzdezYsZj+XV1d6urqih67XC55PB55vV75fL4+fXuPr27PNNQhgjpEUIcI6hAxFOrg9XoH3DfpIdTW1qaenh7l5+f3ac/Pz1dra2tM/4aGBu3atSt6PGnSJNXW1qqxsbHfawQCgaTNN51RhwjqEEEdIqhDRLrUIWXvjnO5XANqq66uVlVVVUyfyspKHTlypE9fn8+nQCCg4uLijH/hkTpQh17UIYI6RAyFOpSXl19zI/FVSQ+h0aNHKysrK2bXc+HChZjdkSS53W653e6Y9lAo1G8Bg8FgRi+yXtQhgjpEUIcI6hBhWQcnbw1P+hsTsrOzNXnyZB09erRP+9GjRzVt2rRkXw4AkMZS8nRcVVWVXnjhBU2ePFlTp07V/v371dLSosWLF6ficgCANJWSELr11lsVDAa1e/dunT9/XiUlJXrsscc0ZsyYVFwOAJCmUvbGhCVLlmjJkiWpengAwDDADUwBAGYIIQCAGUIIAGCGEAIAmCGEAABmCCEAgBlCCABghhACAJghhAAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmCGEAABmCCEAgBlCCABghhACAJghhAAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmCGEAABmCCEAgBlCCABghhACAJghhAAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmCGEAABmCCEAgBlCCABghhACAJghhAAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmCGEAABmCCEAgBlCCABgJjvZD1hfX69du3b1ac
vPz9dLL72U7EsBANJc0kNIkkpKSvTEE09Ej7Oy2HABAGKlJISysrLk9/tT8dAAgGEkJSF05swZPfTQQ8rOzlZZWZlWrlypcePGxe3b1dWlrq6u6LHL5ZLH45HX65XP5+vTt/f46vZMQx0iqEMEdYigDhFDoQ5er3fAfV3hcDiczIt//PHH6uzs1Pjx49Xa2qo333xTgUBAW7ZsiVuUq19DmjRpkmpra5M5JQDAEJX0ELpaR0eHHn74YS1btkxVVVUx5/vbCVVWVurIkSN9+vp8PgUCARUXFysYDKZy2kMadYigDhHUIYI6RAyFOpSXl6uxsXFAfVPydNxX5ebmqrS0VKdPn4573u12y+12x7SHQqF+CxgMBjN6kfWiDhHUIYI6RFCHCMs6hEKhAfdN+dvWurq6FAgEVFBQkOpLAQDSTNJ3Qjt27NDNN9+swsJCXbhwQbt371Z7e7sWLFiQ7EsBANJc0kPoyy+/1O9+9zu1tbVp9OjRKisr06ZNmzRmzJhkXwoAkOaSHkJr165N9kMCAIYpbmUAADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMwQQgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADATMq/1A4Y7lwul+MxRUVFjsd8//vfj9s+cuRISdLPfvYzdXZ29jl3zz33OL6OJE2ZMsXxmPnz5zsec+rUKcdjMLywEwIAmCGEAABmCCEAgBlCCABghhACAJghhAAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmOEu2hiWJkyYkNC4ZcuWOR6zYsUKx2Nuu+02x2OuZ/PmzUl7rFAo5HjMpUuXknZ9ZA52QgAAM4QQAMAMIQQAMEMIAQDMEEIAADOEEADADCEEADBDCAEAzBBCAAAzhBAAwAwhBAAwQwgBAMxwA1MMqtmzZzse89hjj8Vtz86OLN9t27apu7u7z7nq6mrnk5OUk5PjeMzJkycdj/nDH/7geEzvn/dqbrdbq1evVl1dnbq6uvqc++lPf+r4OpL0l7/8xfGYlpaWhK6FzMZOCABghhACAJghhAAAZgghAIAZQggAYIYQAgCYIYQAAGYIIQCAGUIIAGCGEAIAmCGEAABmCCEAgBluYArdcccdCY3btm2b4zHjxo1zPCY3N/ea57/3ve/FtL300kuOryNJf/7znx2POXz4sOMxly5dcjxmzpw5cdu9Xq9Wr16t7du3KxQK9TmX6A1M//GPfyQ0DnCKnRAAwAwhBAAw4/jpuKamJu3Zs0cnTpzQ+fPn9cgjj2jevHnR8+FwWDt37tR7772nixcvqqysTKtWrVJJSUlSJw4ASH+Od0KdnZ2aOHGiHnjggbjn33rrLb3zzjt64IEH9Oyzz8rv92vjxo1qb2//2pMFAAwvjkNo7ty5WrFihSoqKmLOhcNh7d27V9XV1aqoqFBpaanWrFmjzs5OHTx4MCkTBgAMH0l9d1xzc7NaW1tVXl4ebXO73Zo5c6aOHTumxYsXx4zp6urq85XELpdLHo9HXq9XPp+vT9/e46vbM02y6zBq1KiExrlcrqRcPxXcbndC4xKpRSJ/DyNGjHA8xuv1xm3vnXOif4/xJPI159Y/l/x+iBgKdehvrcaT1BBqbW2VJOXn5/dpz8/P7/f75xsaGrRr167o8aRJk1RbW6vGxsZ+rxMIBL7+ZIcB6hARLwx//OMfJ/RYiY4bCvbt25e0x/rlL385KGNSgZ+LiHSpQ0o+J3T1L4VwONxv3+rqalVVVcWMrays1JEjR/r09fl8CgQCKi4uVjAYTOKM00uy6/Ctb30roXFbt251PGbs2LGOx1zrc0Iulyvu+tq+fbvj60jS66+/7njM1et0IBL5nNDs2bPjto8aNUr79u3Tt7/97ZjHvdY/5q6ltrbW8ZhnnnkmoWslC78fIoZCHcrLywe89pIaQn6/X1JkR1RQUBBtb2tri9kd9XK73XGfOgmFQv0WMBgMZvQi65WsOiTyC1
G69j8urH31KV4nEqlFIn8HiVzn6g+ixnvM6/UZqMuXLzseM1R+Jvn9EGFZByfrMKmfExo7dqz8fr+OHj0abevu7lZ
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(train_dataset.data[42], cmap='gray')\n",
"plt.title('%i' % train_dataset.targets[42])\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, since we want to evaluate the performance of our model during training, we need a validation set. We can create a validation set by splitting the training set into two parts. As a general rule, the validation set should be $10-20\\%$ of the training set. The `SubsetRandomSampler` class can be used to create a subset of the training set. The `SubsetRandomSampler` class takes a list of randomly shuffled indices as an argument and selects a subset of the training set based on these indices.\n",
"\n",
"*Why would we want to randomly shuffle the data when creating the separate training and validation set?*\n",
"\n",
"*We could just take the first 80% of data points and assign them to the training set and the last 20% of data points to the validation set. When and why would this be a bad practice?*\n",
"\n",
"*Why do we want to shuffle the training and valisation sets but not the test set (see `shuffle=False`)?*"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"valid_size=0.2 # Leave 20% of training set as validation set\n",
"num_train = len(train_dataset)\n",
"indices = list(range(num_train))\n",
"split = int(np.floor(valid_size * num_train))\n",
"np.random.shuffle(indices) # Shuffle indices in-place\n",
"train_idx, valid_idx = indices[split:], indices[:split] # Split indices into training and validation sets\n",
"train_sampler = SubsetRandomSampler(train_idx)\n",
"valid_sampler = SubsetRandomSampler(valid_idx)\n",
"\n",
"# Create the dataloaders\n",
"train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, pin_memory=True)\n",
"valid_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=valid_sampler, pin_memory=True)\n",
"test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To create a multy-layer model, we will use the [`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) container. The `nn.Sequential` container takes a list of layers as an argument and applies them sequentially. The `nn.Sequential` container is a convenient way to create a model with multiple layers. However, it is not very flexible. For example, we cannot have skip connections in a model created using the `nn.Sequential` container.\n",
"\n",
"Since we are working with images, we will have to flatten the images before passing them to the model. We can do this using the [`nn.Flatten`](https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html) layer. The `nn.Flatten` layer takes a tensor of shape `(N, C, H, W)` and flattens it to a tensor of shape `(N, C*H*W)`.\n",
"\n",
"In between the affine layers, we will use the [`nn.ReLU`](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html) activation function. The `nn.ReLU` activation function applies the ReLU function elementwise to the input tensor. There are other acrtivation functions available in PyTorch, such as the [`nn.Tanh`](https://pytorch.org/docs/stable/generated/torch.nn.Tanh.html) activation function, the [`nn.Sigmoid`](https://pytorch.org/docs/stable/generated/torch.nn.Sigmoid.html) activation function, and the [`nn.LeakyReLU`](https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html) activation function. The `nn.LeakyReLU` activation function is similar to the `nn.ReLU` activation function, but it allows a small gradient when the input is negative. This can help with training."
]
},
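{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the difference between ReLU and leaky ReLU concrete, here is a minimal scalar sketch in plain Python. The function names `relu` and `leaky_relu` are introduced here for illustration, and the slope value `0.01` is an assumption chosen for the example; the PyTorch layers operate elementwise on whole tensors.\n",
"\n",
"```python\n",
"def relu(x):\n",
"    # Rectified linear unit: negative inputs are clamped to zero.\n",
"    return max(0.0, x)\n",
"\n",
"\n",
"def leaky_relu(x, negative_slope=0.01):\n",
"    # Leaky variant: negative inputs are scaled by a small slope instead of zeroed.\n",
"    return x if x >= 0 else negative_slope * x\n",
"\n",
"\n",
"for x in (-2.0, 0.0, 3.0):\n",
"    print(x, relu(x), leaky_relu(x))\n",
"```\n",
"\n",
"For negative inputs ReLU is flat, so its gradient there is zero, whereas the leaky variant keeps a small non-zero slope."
]
},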
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"class MultipleLayerModel(nn.Module):\n",
" \"\"\"Multiple layer model.\"\"\"\n",
" def __init__(self, input_dim, output_dim, hidden_dim):\n",
" super().__init__()\n",
" self.flatten = nn.Flatten()\n",
" self.linear_relu_stack = nn.Sequential(\n",
" nn.Linear(input_dim, hidden_dim),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_dim, hidden_dim),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_dim, output_dim),\n",
" )\n",
" \n",
" def forward(self, x):\n",
" x = self.flatten(x)\n",
" logits = self.linear_relu_stack(x)\n",
" return logits"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since our image size is $1 \\times 28 \\times 28$, this will be the output size of the `nn.Flatten` layer. The input size of the first affine layer that will take in $x_i$ datapoints will be $1 \\times 28 \\times 28$ and the output size will be $100$. The input size of the second affine layer will be $100$ and the output size will be $10$. The output size of the second affine layer is 10 because we have $10$ classes in the MNIST dataset. Therefore, the last layer will out put the a vector $y = (y_1, \\dots, y_{K})^{\\top}$ where $y_k$ is the probability that the image belongs to class $k$."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MultipleLayerModel(\n",
" (flatten): Flatten(start_dim=1, end_dim=-1)\n",
" (linear_relu_stack): Sequential(\n",
" (0): Linear(in_features=784, out_features=100, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=100, out_features=100, bias=True)\n",
" (3): ReLU()\n",
" (4): Linear(in_features=100, out_features=10, bias=True)\n",
" )\n",
")\n"
]
}
],
"source": [
"input_dim = 1*28*28\n",
"output_dim = 10\n",
"hidden_dim = 100\n",
"\n",
"model = MultipleLayerModel(input_dim, output_dim, hidden_dim).to(device)\n",
"print(model)"
]
},
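{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check on the layer sizes, the sketch below counts the trainable parameters of the three affine layers in plain Python. The helper name `count_affine_parameters` is hypothetical and introduced only for this illustration; it assumes every affine layer has a bias vector.\n",
"\n",
"```python\n",
"layer_sizes = [1 * 28 * 28, 100, 100, 10]  # input, two hidden layers, output\n",
"\n",
"\n",
"def count_affine_parameters(sizes):\n",
"    # Each affine layer has a weight matrix (n_out x n_in) plus a bias vector (n_out).\n",
"    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))\n",
"\n",
"\n",
"print(count_affine_parameters(layer_sizes))\n",
"```\n",
"\n",
"Most of the parameters sit in the first layer, because it maps the $784$ input pixels to the $100$ hidden units."
]
},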
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we want to calissify images by their labels, we will use the [`nn.CrossEntropyLoss`](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) loss function. The `nn.CrossEntropyLoss` loss function combines the softmax function and the cross entropy loss function. The `nn.CrossEntropyLoss` loss function takes the logits as an input and returns the loss. The logits are the outputs of the last affine layer before the softmax function is applied. The `nn.CrossEntropyLoss` loss function is equivalent to applying the softmax function to the logits and then applying the cross entropy loss function to the softmax outputs and the labels."
]
},
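{
"cell_type": "markdown",
"metadata": {},
"source": [
"The equivalence described above can be sketched for a single example in plain Python. This is an illustration under stated assumptions rather than the PyTorch implementation: the function names `softmax` and `cross_entropy_from_logits` and the logit values are made up for this example, and batching and averaging over examples are left out.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def softmax(logits):\n",
"    # Subtract the maximum logit first for numerical stability.\n",
"    largest = max(logits)\n",
"    exps = [math.exp(z - largest) for z in logits]\n",
"    total = sum(exps)\n",
"    return [e / total for e in exps]\n",
"\n",
"\n",
"def cross_entropy_from_logits(logits, target):\n",
"    # Negative log-probability that the softmax assigns to the target class.\n",
"    return -math.log(softmax(logits)[target])\n",
"\n",
"\n",
"logits = [2.0, 0.5, -1.0]  # made-up scores for three classes\n",
"print(softmax(logits))\n",
"print(cross_entropy_from_logits(logits, target=0))\n",
"```\n",
"\n",
"The loss is small when the target class has the largest logit and grows as the target class receives less probability."
]
},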
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"loss = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=learning_rate) # Adam optimiser"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, our training loop will combine a training and an evaluation pass per epoch. During the training pass, we propagate the training data through the model and calculate the loss. Then, we calculate the gradients of the loss with respect to the parameters of the model and update the parameters of the model. During the evaluation pass, we propagate the validation data through the model and calculate the loss and the accuracy. We do not calculate the gradients of the loss with respect to the parameters of the model and we do not update the parameters of the model. \n",
"\n",
"*What would happen if we do?*"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0 \tError(train): 0.358887 \tAccuracy(train): 0.896062 \tError(valid): 0.185471 \tAccuracy(valid): 0.946836\n",
"Epoch: 5 \tError(train): 0.052651 \tAccuracy(train): 0.983104 \tError(valid): 0.095970 \tAccuracy(valid): 0.971687\n",
"Epoch: 10 \tError(train): 0.022297 \tAccuracy(train): 0.992771 \tError(valid): 0.090072 \tAccuracy(valid): 0.974374\n",
"Epoch: 15 \tError(train): 0.014676 \tAccuracy(train): 0.995188 \tError(valid): 0.104987 \tAccuracy(valid): 0.975731\n",
"Epoch: 20 \tError(train): 0.009700 \tAccuracy(train): 0.996771 \tError(valid): 0.109034 \tAccuracy(valid): 0.977089\n",
"Epoch: 25 \tError(train): 0.010248 \tAccuracy(train): 0.996354 \tError(valid): 0.122468 \tAccuracy(valid): 0.978363\n",
"Epoch: 30 \tError(train): 0.005657 \tAccuracy(train): 0.998208 \tError(valid): 0.166619 \tAccuracy(valid): 0.970855\n",
"Epoch: 35 \tError(train): 0.010991 \tAccuracy(train): 0.996333 \tError(valid): 0.143561 \tAccuracy(valid): 0.976978\n",
"Epoch: 40 \tError(train): 0.011190 \tAccuracy(train): 0.996521 \tError(valid): 0.152304 \tAccuracy(valid): 0.975150\n",
"Epoch: 45 \tError(train): 0.002502 \tAccuracy(train): 0.999458 \tError(valid): 0.152308 \tAccuracy(valid): 0.977255\n",
"Epoch: 50 \tError(train): 0.002459 \tAccuracy(train): 0.999167 \tError(valid): 0.227140 \tAccuracy(valid): 0.971493\n"
]
}
],
"source": [
"# Keep track of the loss values over training\n",
"train_loss = [] \n",
"valid_loss = []\n",
"\n",
"# Keep track of the accuracy values over training\n",
"train_acc = []\n",
"valid_acc = []\n",
"\n",
"for i in range(num_epochs+1):\n",
" # Training\n",
" model.train()\n",
" batch_loss = []\n",
" batch_acc = []\n",
" for batch_idx, (x, t) in enumerate(train_loader):\n",
" x = x.to(device)\n",
" t = t.to(device)\n",
" \n",
" # Forward pass\n",
" y = model(x)\n",
" E_value = loss(y, t)\n",
" \n",
" # Backward pass\n",
" optimizer.zero_grad()\n",
" E_value.backward()\n",
" optimizer.step()\n",
" \n",
" # Calculate accuracy\n",
" _, argmax = torch.max(y, 1)\n",
" acc = (t == argmax.squeeze()).float().mean()\n",
" \n",
" # Logging\n",
" batch_loss.append(E_value.item())\n",
" batch_acc.append(acc.item())\n",
" \n",
" train_loss.append(np.mean(batch_loss))\n",
" train_acc.append(np.mean(batch_acc))\n",
"\n",
" # Validation\n",
" model.eval()\n",
" batch_loss = []\n",
" batch_acc = []\n",
" for batch_idx, (x, t) in enumerate(valid_loader):\n",
" x = x.to(device)\n",
" t = t.to(device)\n",
" \n",
" # Forward pass\n",
" y = model(x)\n",
" E_value = loss(y, t)\n",
" \n",
" # Calculate accuracy\n",
" _, argmax = torch.max(y, 1)\n",
" acc = (t == argmax.squeeze()).float().mean()\n",
" \n",
" # Logging\n",
" batch_loss.append(E_value.item())\n",
" batch_acc.append(acc.item())\n",
" \n",
" valid_loss.append(np.mean(batch_loss))\n",
" valid_acc.append(np.mean(batch_acc))\n",
"\n",
" if i % stats_interval == 0:\n",
" print('Epoch: {} \\tError(train): {:.6f} \\tAccuracy(train): {:.6f} \\tError(valid): {:.6f} \\tAccuracy(valid): {:.6f}'.format(\n",
" i, train_loss[-1], train_acc[-1], valid_loss[-1], valid_acc[-1]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see below the evolution of our training and validation losses, as well as, the respective accuracies. We can see that the training loss decreases and the training accuracy increases with each epoch. However, the validation loss starts increasing after $10$ epochs and the validation accuracy increases only up to a certain point. \n",
"\n",
"*What could be happening here?*\n",
"\n",
"*Is training for 50 epoch a sensible choice?* \n",
"\n",
"*What number of epochs would be a better choice and why?* "
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqgAAAF0CAYAAADvrMuVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAACGaElEQVR4nO3dd3hUZfrw8e+Zkt57hwBJDGDoVQUFEQurYG8/xfbqoruirqvY2+Kia2+rrruoK4qCoAICS1OpKgKhFyGBhPRk0suU8/5xkoEhCaRMMjNwf64rV5IzZ848kzuT3POU+1FUVVURQgghhBDCTehc3QAhhBBCCCGOJwmqEEIIIYRwK5KgCiGEEEIItyIJqhBCCCGEcCuSoAohhBBCCLciCaoQQgghhHArkqAKIYQQQgi3IgmqEEIIIYRwK5KgCiGEEEIItyIJqhBCCCGEcCsGVzfA2crKyrBYLN3yWJGRkRQVFXXLY4muITH0bBI/zycx9HwSQ8/XnTE0GAyEhoae+rxuaEu3slgsmM3mLn8cRVHsj6eqapc/nnA+iaFnk/h5Pomh55MYej53jaEM8QshhBBCCLciCaoQQgghhHArkqAKIYQQQgi3IgmqEEIIIYRwK6fdIikhhBBCuF59fT319fWuboZog9raWhoaGpx2PW9vb7y9vTt1DUlQhRBCCOFU1dXVKIpCYGCgfZW4cF9Go9FpFZBUVaW2tpbq6mr8/f07fB0Z4hdCCCGEU1ksFvz8/CQ5PQMpioKfn1+na9JLgiqEEEIIp5LEVHT2d0ASVCGEEEII4VZkDmoHqHU1qPt2UnMwEHqlu7o5QgghhBCnFelB7YiyUmxvPU/Zmy+4uiVCCCGEOM2UlpaSkZHBkSNHuuT6c+fOJT29fR1sd911F++//36XtKclkqB2ROOqNFt1JarN6uLGCCGEEMIZpk+fTnx8fLOPm266qVvb8fbbbzNhwgQSExMBWL9+PfHx8ZSXlzvl+pdffjk//fRTu+7zwAMP8Oabb1JZWemUNpyKDPF3hF+A9llVoaYG/ANc2x4hhBBCOMUFF1zAq6++6nDMy8urxXPNZjNGo/GUx9qi6X61tbV88cUXfPLJJ+2+RkNDQ6ttPZ6vry++vr7tunbfvn1JTEzk66+/5tZbb21329qrQwnqsmXL+PbbbzGZTCQkJDB16tRWu4r37NnDZ599Rm5uLvX19URGRnLhhRcyadIk+zlr1qzh3XffbXbf//73v236QXc3xWAEb1+or4XqSklQhRBCiJNQVRUaXFS038u7XSvKvby8iIqKavG2+Ph4XnzxRVavXs1PP/3EPffcg6IoLF26lDvuuIM33niDI0eOcOTIEY4ePcoTTzzB2rVr0el0nH/++bzwwgtERkYC8Morr7R4v9WrV6PX6xk6dCgAR44c4ZprrgG0JBHgmmuu4fXXX+fqq68mLS0No9HIvHnzSEtLY/78+bz//vt8+eWXZGdnExISwoQJE3jiiSfsdUnnzp3LM888w+7duwF46aWXWLJkCXfffTcvv/wy5eXlXHDBBbz88ssEBBzLcS666CK++eYb90xQ169fz+zZs7nzzjtJS0tjxYoVzJw5k9dee42IiIhm53t7ezNx4kR69OiBt7c3e/bs4cMPP8THx4cLL7zQfp6vry9vvPGGw33dMTm18/fXEtSaKle3RAghhHBvDfXY7rvWJQ+te/tL8PZx2vVeeeUVZsyYwTPPPINer2fu3LlkZWXx3Xff8eGHH6LTabMnb7/9dvz8/Jg/fz4Wi4XHHnuMP/7xj8ybN89+rZbut3HjRgYMGGA/Jy4ujg8//JC77rqLH3/8kcDAQHx8jj2fr776iltuuYWFCxcee846Hc899xyJiYkcPnyYxx57jBdeeIEXX3yx1eeVnZ3NsmXL+PjjjykvL+eee+7h7bff5tFHH7WfM3DgQN5++23q6+s7vVPUqbQ7QV20aBHjxo1j/PjxAEydOp
Vt27axfPlybrzxxmbnJycnk5ycbP8+KiqKn3/+md27dzskqIqiEBIS0oGn4CL+gVBajFpdhVR7E0IIIU4PK1asICUlxeHYtGnTeOCBBwCYPHky119/vcPtZrOZN998k/DwcAB+/PFHdu/ezYYNG4iPjwfgzTff5IILLmDr1q0MHDiwxfsB5OTkEB0dbf9er9fb86OIiAiCg4MdHrtnz5488cQTDsfuuusu+9dJSUk8/PDDzJgx46QJqs1m47XXXrP3mF511VWsXbvW4ZyYmBjq6+spKioiISGh1Ws5Q7sSVIvFwsGDB5k8ebLD8YyMDPbu3dumaxw6dIi9e/c2C25dXR3Tpk3DZrPRs2dPrrvuOofE1t0o/oGooA3xCyGEEKJ1Xt5aT6aLHrs9Ro8e3SyRO74D7fjezSbx8fEOSeb+/fuJi4uzJ6cAqampBAcHs3//fnuCeuL9QMuH2tM72VJ71q1bx1tvvcX+/fuprKzEarVSV1dHTU0Nfn5+LV4nMTHRYTg/KiqKkpISh3Oaem5ra2vb3L6OaleCWlFRgc1ma5a9BwcHYzKZTnrfe+65h4qKCqxWK9dcc429Bxa07utp06aRlJREbW0tS5Ys4cknn+Tll18mNja2xeuZzWaHfWMVRbFP+O2WHSya5p3WVMmOGR6qKW4SP88k8fN8EkPP19YYKori1GH2ruTn53fSDrKWErwTj6mq2uLP5MTjLV0rLCysXav1T1zslJOTwy233MLNN9/Mww8/TEhICL/88gsPPfSQQ950IoPBMSVUFAWbzeZwrCnXOzGpbk1nXtsdWiTV0gOeqhHPPfccdXV17Nu3jzlz5hATE8O5554LaO8qUlNT7eempaXxyCOP8P3333P77be3eL0FCxY4zONITk5m1qxZ9snHXa00IopqIEAHwa0k0cIzxMTEuLoJohMkfp5PYuj5ToxhbW1th1ayu5pOp0NRlJO2Xa/XO9ze0n3S09PJzc2lsLDQ3ou6d+9eKioqSE9Px2g0tvpYGRkZzJs3z+F4UxKq0+kcjiuK0qw9O3bswGKx8MILL9jntS5ZsgQAo9GI0Wi0J6MnXuv47/V6fbNjBw4cIC4uzmEKQmu8vLxa7WRsi3YlqEFBQeh0uma9peXl5c16VU/UtCIuKSmJ8vJyvvrqK3uCeiKdTkfv3r3Jz89v9XpTpkxxqATQlCAXFRVhsVja8nQ6xaboAajKz6MmL6/LH084n6IoxMTEkJ+fr60wFR5F4uf5JIaer7UYNjQ0nLS3zl3ZbDbq6+vJzc11OG4wGAgLCwPAarU6PDebzYaqqg7HRo8eTXp6Ovfccw/PPvusfZHUqFGj6NevH2azucX7AZx33nn87W9/o6ioyD61ICYmBkVR+P777xk/fjw+Pj74+/ujqmqz9iQkJGCxWHj//feZMGECv/zyC7NnzwaOjT435UnH3+/Etlit1mbH1q9fz5gxY9oU24aGBvJayI8MBkObOhPblaAaDAZ69epFZmYmw4cPtx/PzMxk2LBhbb6OqqonTSJVVSU7O9teoLYlTe8CWrt/l2sc4lerK+UPq4dTVVVi6MEkfp5PYuj5TqcYrl69mkGDBjkc6927Nz/++GObr6EoCv/+97954oknuPLKKx3KTJ1Keno6GRkZfPfdd/zf//0fALGxsTz00EO8+OKLPPjgg1x99dW8/vrrLd6/f//+PP3007z77ru8+OKLjBw5khkzZnD//fe3uf0tqaurY+nSpXz22Wdtvk9nficUtZ33Xr9+PW+99RZ33XUXqamprFixgpUrV/Lqq68SGRnJnDlzKC0t5b777gNg6dKlRERE2Lu49+zZw+zZs7nkkkvsC6W++uorUlJSiI2Ntc9B/emnn3j++efp06dPu55QUVFRt7xrU39aju2Tt1EGjkB37+Nd/njC+RRFITY2lry8vNPmD+uZROLn+SSGnq+1GFZUVBAUFOTClnm2lStX8vzzz7Nq1Sr7MH1XMhqNp8ydZs
+ezbJly/j888/bdM3WfgeMRqPze1BB67aurKxk/vz5lJWVkZiYyIwZM+wPVlZWRnFxsf18VVX5/PPPKSwsRKfTERM
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot the change in the validation and training set error over training.\n",
"fig_1 = plt.figure(figsize=(8, 4))\n",
"ax_1 = fig_1.add_subplot(111)\n",
"ax_1.plot(train_loss, label='Error(train)')\n",
"ax_1.plot(valid_loss, label='Error(valid)')\n",
"ax_1.legend(loc=0)\n",
"ax_1.set_xlabel('Epoch number')\n",
"\n",
"# Plot the change in the validation and training set accuracy over training.\n",
"fig_2 = plt.figure(figsize=(8, 4))\n",
"ax_2 = fig_2.add_subplot(111)\n",
"ax_2.plot(train_acc, label='Accuracy(train)')\n",
"ax_2.plot(valid_acc, label='Accuracy(valid)')\n",
"ax_2.legend(loc=0)\n",
"ax_2.set_xlabel('Epoch number')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once our model is trained and we are satisfied with the results, we can test its performance on unseen data. We do this by propagating the test data through the model and calculating the accuracy. Here the test accuracy is similar to the validation accuracy.\n",
"\n",
"*Although a test set is not necessary for training, why is it important to have one?*"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy(test): 0.971915\n"
]
}
],
"source": [
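"# Testing: evaluate the trained model on the held-out test set.\n",
"test_acc = []\n",
"model.eval()  # put the model in evaluation mode\n",
"with torch.no_grad():  # gradients are not needed for evaluation\n",
"    for x, t in test_loader:\n",
"        x = x.to(device)\n",
"        t = t.to(device)\n",
"\n",
"        # Forward pass\n",
"        y = model(x)\n",
"\n",
"        # Calculate accuracy: the predicted class is the index of the largest output\n",
"        argmax = torch.argmax(y, dim=1)\n",
"        acc = (t == argmax).float().mean()\n",
"\n",
"        test_acc.append(acc.item())\n",
"# Average the per-batch accuracies\n",
"test_acc = np.mean(test_acc)\n",
"print('Accuracy(test): {:.6f}'.format(test_acc))"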
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}