{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Single layer models\n",
"\n",
"In this lab we will implement a single-layer network model consisting solely of an affine transformation of the inputs. The relevant material for this was covered in [the slides of the first lecture](http://www.inf.ed.ac.uk/teaching/courses/mlp/2016/mlp01-intro.pdf). \n",
"\n",
"We will first implement the forward propagation of inputs to the network to produce predicted outputs. We will then move on to considering how to use gradients of an error function evaluated on the outputs to compute the gradients with respect to the model parameters to allow us to perform an iterative gradient-descent training procedure. In the final exercise you will use an interactive visualisation to explore the role of some of the different hyperparameters of gradient-descent based training methods.\n",
"\n",
"#### A note on random number generators\n",
"\n",
"It is generally good practice (for machine learning applications, **not** for cryptography!) to seed a pseudo-random number generator once at the beginning of each experiment. This makes it easier to reproduce results as the same random draws will be produced each time the experiment is run (e.g. the same random initialisations used for parameters). Therefore, when we need to generate random values during this course, we will generally create a seeded random number generator object as we do in the cell below."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"seed = 27092016 \n",
"rng = np.random.RandomState(seed)"
]
},
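{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of the reproducibility point above, the sketch below creates two generators from the same seed and confirms that they produce identical draws. The names `example_seed`, `rng_a` and `rng_b` are illustrative only and are not used elsewhere in the notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"example_seed = 27092016  # illustrative: same value as `seed` above\n",
"rng_a = np.random.RandomState(example_seed)\n",
"rng_b = np.random.RandomState(example_seed)\n",
"\n",
"# Two generators seeded identically yield the same sequence of draws.\n",
"draws_a = rng_a.normal(size=3)\n",
"draws_b = rng_b.normal(size=3)\n",
"print(np.array_equal(draws_a, draws_b))"
]
},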
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 1: linear and affine transforms\n",
"\n",
"Any *linear transform* (also called a linear map) of a finite-dimensional vector space can be parametrised by a matrix. So for example if we consider $\\boldsymbol{x} \\in \\mathbb{R}^{D}$ as the input space of a model with $D$ dimensional real-valued inputs, then a matrix $\\mathbf{W} \\in \\mathbb{R}^{K\\times D}$ can be used to define a prediction model consisting solely of a linear transform of the inputs\n",
"\n",
"\\begin{equation}\n",
" \\boldsymbol{y} = \\mathbf{W} \\boldsymbol{x}\n",
" \\qquad\n",
" \\Leftrightarrow\n",
" \\qquad\n",
" y_k = \\sum_{d=1}^D \\left( W_{kd} x_d \\right) \\quad \\forall k \\in \\left\\lbrace 1 \\dots K\\right\\rbrace\n",
"\\end{equation}\n",
"\n",
"where $\\boldsymbol{y} \\in \\mathbb{R}^K$ is the $K$-dimensional real-valued output of the model. Geometrically we can think of a linear transform as doing some combination of rotation, scaling, reflection and shearing of the input.\n",
"\n",
"An *affine transform* consists of a linear transform plus an additional translation parameterised by a vector $\\boldsymbol{b} \\in \\mathbb{R}^K$. A model consisting of an affine transformation of the inputs can then be defined as\n",
"\n",
"\\begin{equation}\n",
" \\boldsymbol{y} = \\mathbf{W}\\boldsymbol{x} + \\boldsymbol{b}\n",
" \\qquad\n",
" \\Leftrightarrow\n",
" \\qquad\n",
" y_k = \\sum_{d=1}^D \\left( W_{kd} x_d \\right) + b_k \\quad \\forall k \\in \\left\\lbrace 1 \\dots K\\right\\rbrace\n",
"\\end{equation}\n",
"\n",
"In machine learning we will usually refer to the matrix $\\mathbf{W}$ as a *weight matrix* and the vector $\\boldsymbol{b}$ as a *bias vector*.\n",
"\n",
"Generally rather than working with a single data vector $\\boldsymbol{x}$ we will work with batches of datapoints $\\left\\lbrace \\boldsymbol{x}^{(b)}\\right\\rbrace_{b=1}^B$. We could calculate the outputs for each input in the batch sequentially\n",
"\n",
"\\begin{align}\n",
" \\boldsymbol{y}^{(1)} &= \\mathbf{W}\\boldsymbol{x}^{(1)} + \\boldsymbol{b}\\\\\n",
" \\boldsymbol{y}^{(2)} &= \\mathbf{W}\\boldsymbol{x}^{(2)} + \\boldsymbol{b}\\\\\n",
" \\dots &\\\\\n",
" \\boldsymbol{y}^{(B)} &= \\mathbf{W}\\boldsymbol{x}^{(B)} + \\boldsymbol{b}\\\\\n",
"\\end{align}\n",
"\n",
"by looping over each input in the batch and calculating the output. However in general loops in Python are slow (particularly compared to compiled and typed languages such as C). This is due at least in part to the large overhead in dynamically inferring variable types. In general therefore wherever possible we want to avoid having loops in which such overhead will become the dominant computational cost.\n",
"\n",
"For array based numerical operations, one way of overcoming this bottleneck is to *vectorise* operations. NumPy `ndarrays` are typed arrays for which operations such as basic elementwise arithmetic and linear algebra operations such as computing matrix-matrix or matrix-vector products are implemented by calls to highly-optimised compiled libraries. Therefore if you can implement code directly using NumPy operations on arrays rather than by looping over array elements it is often possible to make very substantial performance gains.\n",
"\n",
"As a simple example we can consider adding up two arrays `a` and `b` and writing the result to a third array `c`. First let's initialise `a` and `b` with arbitrary values by running the cell below."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"size = 1000\n",
"a = np.arange(size)\n",
"b = np.ones(size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's time how long it takes to add up each pair of values in the two arrays and write the results to a third array using a loop-based implementation. We will use the `%%timeit` magic briefly mentioned in the previous lab notebook, specifying that the code should be looped 100 times per run and that 3 repeated runs should be made. Run the cell below to print out the average time taken."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.05 ms ± 148 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit -n 100 -r 3\n",
"c = np.empty(size)\n",
"for i in range(size):\n",
"    c[i] = a[i] + b[i]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now we will perform the corresponding summation with the overloaded addition operator of NumPy arrays. Again run the cell below to get a print out of the average time taken."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.01 µs ± 1.53 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit -n 100 -r 3\n",
"c = a + b"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The first loop-based implementation should have taken on the order of milliseconds ($10^{-3}$s) while the vectorised implementation should have taken on the order of microseconds ($10^{-6}$s), i.e. a $\\sim1000\\times$ speedup. Hopefully this simple example should make it clear why we want to vectorise operations whenever possible!\n",
"\n",
"Getting back to our affine model, ideally rather than individually computing the output corresponding to each input we should compute the outputs for all inputs in a batch using a vectorised implementation. As you saw last week, data providers return batches of inputs as arrays of shape `(batch_size, input_dim)`. In the mathematical notation used earlier we can consider this as a matrix $\\mathbf{X}$ of dimensionality $B \\times D$, and in particular\n",
"\n",
"\\begin{equation}\n",
" \\mathbf{X} = \\left[ \\boldsymbol{x}^{(1)} ~ \\boldsymbol{x}^{(2)} ~ \\dots ~ \\boldsymbol{x}^{(B)} \\right]^\\mathrm{T}\n",
"\\end{equation}\n",
"\n",
"i.e. the $b^{\\textrm{th}}$ input vector $\\boldsymbol{x}^{(b)}$ corresponds to the $b^{\\textrm{th}}$ row of $\\mathbf{X}$. If we define the $B \\times K$ matrix of outputs $\\mathbf{Y}$ similarly as\n",
"\n",
"\\begin{equation}\n",
" \\mathbf{Y} = \\left[ \\boldsymbol{y}^{(1)} ~ \\boldsymbol{y}^{(2)} ~ \\dots ~ \\boldsymbol{y}^{(B)} \\right]^\\mathrm{T}\n",
"\\end{equation}\n",
"\n",
"then we can express the relationship between $\\mathbf{X}$ and $\\mathbf{Y}$ using [matrix multiplication](https://en.wikipedia.org/wiki/Matrix_multiplication) and addition as\n",
"\n",
"\\begin{equation}\n",
" \\mathbf{Y} = \\mathbf{X} \\mathbf{W}^\\mathrm{T} + \\mathbf{B}\n",
"\\end{equation}\n",
"\n",
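"where $\\mathbf{B} = \\left[ \\boldsymbol{b} ~ \\boldsymbol{b} ~ \\dots ~ \\boldsymbol{b} \\right]^\\mathrm{T}$ i.e. a $B \\times K$ matrix with each row corresponding to the bias vector. The weight matrix needs to be transposed here as the inner dimensions of a matrix multiplication must match: for a general product $\\mathbf{C} = \\mathbf{A} \\mathbf{B}$ (with $\\mathbf{A}$, $\\mathbf{B}$ and $\\mathbf{C}$ here standing for arbitrary matrices rather than the quantities defined above), if $\\mathbf{A}$ is of dimensionality $K \\times L$ and $\\mathbf{B}$ is of dimensionality $M \\times N$ then it must be the case that $L = M$, and $\\mathbf{C}$ will be of dimensionality $K \\times N$. In our case $\\mathbf{X}$ is $B \\times D$ and $\\mathbf{W}^\\mathrm{T}$ is $D \\times K$, so their product is $B \\times K$ as required.\n",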
"\n",
"The first exercise for this lab is to implement *forward propagation* for a single-layer model consisting of an affine transformation of the inputs in the `fprop` function given as skeleton code in the cell below. This should work for a batch of inputs of shape `(batch_size, input_dim)` producing a batch of outputs of shape `(batch_size, output_dim)`.\n",
" \n",
"You will probably want to use the NumPy `dot` function and [broadcasting features](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) to implement this efficiently. If you are not familiar with either / both of these you may wish to read the [hints](#Hints:-Using-the-dot-function-and-broadcasting) section below which gives some details on these before attempting the exercise."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def fprop(inputs, weights, biases):\n",
"    \"\"\"Forward propagates activations through the layer transformation.\n",
"\n",
"    For inputs `x`, outputs `y`, weights `W` and biases `b` the layer\n",
"    corresponds to `y = W x + b`.\n",
"\n",
"    Args:\n",
"        inputs: Array of layer inputs of shape (batch_size, input_dim).\n",
"        weights: Array of weight parameters of shape\n",
"            (output_dim, input_dim).\n",
"        biases: Array of bias parameters of shape (output_dim, ).\n",
"\n",
"    Returns:\n",
"        outputs: Array of layer outputs of shape (batch_size, output_dim).\n",
"    \"\"\"\n",
"    # Batched affine transform: each row of `inputs` is mapped by `weights`\n",
"    # and `biases` is broadcast across the batch dimension.\n",
"    return inputs.dot(weights.T) + biases"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once you have implemented `fprop` in the cell above you can test your implementation by running the cell below."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All outputs correct!\n"
]
}
],
"source": [
"inputs = np.array([[0., -1., 2.], [-6., 3., 1.]])\n",
"weights = np.array([[2., -3., -1.], [-5., 7., 2.]])\n",
"biases = np.array([5., -3.])\n",
"true_outputs = np.array([[6., -6.], [-17., 50.]])\n",
"\n",
"if not np.allclose(fprop(inputs, weights, biases), true_outputs):\n",
"    print('Wrong outputs computed.')\n",
"else:\n",
"    print('All outputs correct!')"
]
},
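{
"cell_type": "markdown",
"metadata": {},
"source": [
"The matrix form $\\mathbf{Y} = \\mathbf{X} \\mathbf{W}^\\mathrm{T} + \\mathbf{B}$ can also be checked numerically against the per-example definition $\\boldsymbol{y}^{(b)} = \\mathbf{W}\\boldsymbol{x}^{(b)} + \\boldsymbol{b}$. The cell below is a minimal sketch of such a check on randomly generated arrays. The helper `fprop_loop`, the generator `check_rng` and the chosen sizes are illustrative assumptions and are not part of the lab code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def fprop_loop(X, W, bias):\n",
"    # Reference implementation: applies y = W x + b to one input at a time.\n",
"    Y = np.empty((X.shape[0], W.shape[0]))\n",
"    for i, x in enumerate(X):\n",
"        Y[i] = W.dot(x) + bias\n",
"    return Y\n",
"\n",
"check_rng = np.random.RandomState(0)  # arbitrary seed for this check only\n",
"n_batch, n_in, n_out = 4, 3, 2\n",
"X = check_rng.normal(size=(n_batch, n_in))\n",
"W = check_rng.normal(size=(n_out, n_in))\n",
"bias = check_rng.normal(size=n_out)\n",
"\n",
"# Vectorised form: X W^T has shape (n_batch, n_out) and the bias vector is\n",
"# broadcast across rows, playing the role of the matrix B in the text.\n",
"Y_vectorised = X.dot(W.T) + bias\n",
"print(np.allclose(Y_vectorised, fprop_loop(X, W, bias)))"
]
},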
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Hints: Using the `dot` function and broadcasting\n",
"\n",
"For those new to NumPy, below are some details on the `dot` function and broadcasting feature of NumPy that you may want to use for implementing the first exercise. If you are already familiar with these and have already completed the first exercise you can move straight on to the [second exercise](#Exercise-2:-visualising-random-models).\n",
"\n",
"#### `numpy.dot` function\n",
"\n",
"Matrix-matrix, matrix-vector and vector-vector (dot) products can all be computed in NumPy using the [`dot`](http://docs.scipy.org/doc/numpy/reference/generated/numpy.dot.html) function. For example if `A` and `B` are both two dimensional arrays, then `C = np.dot(A, B)` or equivalently `C = A.dot(B)` will both compute the matrix product of `A` and `B` assuming `A` and `B` have compatible dimensions. Similarly if `a` and `b` are one dimensional arrays then `c = np.dot(a, b)` (which is equivalent to `c = a.dot(b)`) will compute the [scalar / dot product](https://en.wikipedia.org/wiki/Dot_product) of the two arrays. If `A` is a two-dimensional array and `b` a one-dimensional array `np.dot(A, b)` (which is equivalent to `A.dot(b)`) will compute the matrix-vector product of `A` and `b`. Examples of all three of these product types are shown in the cell below:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 6. 6. 6.]\n",
" [ 24. 24. 24.]\n",
" [ 42. 42. 42.]]\n",
"[[ 18. 24. 30.]\n",
" [ 18. 24. 30.]\n",
" [ 18. 24. 30.]]\n",
"[ 0.8 2.6 4.4]\n",
"[ 2.4 3. 3.6]\n",
"0.2\n"
]
}
],
"source": [
"# Initialise arrays with arbitrary values\n",
"A = np.arange(9).reshape((3, 3))\n",
"B = np.ones((3, 3)) * 2\n",
"a = np.array([-1., 0., 1.])\n",
"b = np.array([0.1, 0.2, 0.3])\n",
"print(A.dot(B)) # Matrix-matrix product\n",
"print(B.dot(A)) # Reversed product of above A.dot(B) != B.dot(A) in general\n",
"print(A.dot(b)) # Matrix-vector product\n",
"print(b.dot(A)) # Again A.dot(b) != b.dot(A) unless A is symmetric i.e. A == A.T\n",
"print(a.dot(b)) # Vector-vector scalar product"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Broadcasting\n",
"\n",
"Another NumPy feature it will be helpful to get familiar with is [broadcasting](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). Broadcasting allows you to apply operations to arrays of different shapes, for example to add a one-dimensional array to a two-dimensional array or multiply a multidimensional array by a scalar. The complete set of rules for broadcasting as explained in the official documentation page just linked to can sound a bit complex: you might find the [visual explanation on this page](http://www.scipy-lectures.org/intro/numpy/operations.html#broadcasting) more intuitive. The cell below gives a few examples:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.1 1.2]\n",
" [ 2.1 3.2]\n",
" [ 4.1 5.2]]\n",
"[[-1. 0.]\n",
" [ 2. 3.]\n",
" [ 5. 6.]]\n",
"[[ 0. 0.2]\n",
" [ 0.2 0.6]\n",
" [ 0.4 1. ]]\n"
]
}
],
"source": [
"# Initialise arrays with arbitrary values\n",
"A = np.arange(6).reshape((3, 2))\n",
"b = np.array([0.1, 0.2])\n",
"c = np.array([-1., 0., 1.])\n",
"print(A + b) # Add b elementwise to all rows of A\n",
"print((A.T + c).T) # Add c elementwise to all columns of A\n",
"print(A * b) # Multiply each row of A elementwise by b"
]
},
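{
"cell_type": "markdown",
"metadata": {},
"source": [
"To add a vector to every column without the double transpose used above, one option is to give the vector an explicit trailing axis of length one so that it broadcasts across the columns. The sketch below, using the illustrative names `M` and `v`, checks that the two approaches agree and shows that adding arrays with incompatible shapes raises a `ValueError`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"M = np.arange(6).reshape((3, 2))\n",
"v = np.array([-1., 0., 1.])\n",
"\n",
"# v[:, None] has shape (3, 1), which broadcasts against M's shape (3, 2).\n",
"print(np.array_equal(M + v[:, None], (M.T + v).T))\n",
"\n",
"# Shapes (3, 2) and (3,) are not broadcast-compatible: the trailing\n",
"# dimensions (2 and 3) differ and neither of them is 1.\n",
"try:\n",
"    M + v\n",
"except ValueError as error:\n",
"    print('Broadcasting failed:', error)"
]
},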
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 2: visualising random models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this exercise you will use your `fprop` implementation to visualise the outputs of a single-layer affine transform model with two-dimensional inputs and a one-dimensional output. In this simple case we can visualise the joint input-output space on a 3D axis.\n",
"\n",
"For this task and the learning experiments later in the notebook we will use a regression dataset from the [UCI machine learning repository](http://archive.ics.uci.edu/ml/index.html). In particular we will use a version of the [Combined Cycle Power Plant dataset](http://archive.ics.uci.edu/ml/datasets/Combined+Cycle+Power+Plant), where the task is to predict the energy output of a power plant given observations of the local ambient conditions (e.g. temperature, pressure and humidity).\n",
"\n",
"The original dataset has four input dimensions and a single target output dimension. We have preprocessed the dataset by [whitening](https://en.wikipedia.org/wiki/Whitening_transformation) it, a common preprocessing step. We will only use the first two dimensions of the whitened inputs (corresponding to the first two principal components of the inputs) so we can easily visualise the joint input-output space.\n",
"\n",
"The dataset has been wrapped in the `CCPPDataProvider` class in the `mlp.data_providers` module and the data included as a compressed file in the data directory as `ccpp_data.npz`. Running the cell below will initialise an instance of this class, get a single batch of inputs and outputs and import the necessary `matplotlib` objects."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"from mlp.data_providers import CCPPDataProvider\n",
"%matplotlib notebook\n",
"\n",
"data_provider = CCPPDataProvider(\n",
" which_set='train',\n",
" input_dims=[0, 1],\n",
" batch_size=5000, \n",
" max_num_batches=1, \n",
" shuffle_order=False\n",
")\n",
"\n",
"input_dim, output_dim = 2, 1\n",
"\n",
"inputs, targets = data_provider.next()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we used the `%matplotlib notebook` magic command rather than the `%matplotlib inline` we used in the previous lab as this allows us to produce interactive 3D plots which you can rotate and zoom in/out by dragging with the mouse and scrolling the mouse-wheel respectively. Once you have finished interacting with a plot you can close it to produce a static inline plot using the <i class=\"fa fa-power-off\"></i> button in the top-right corner.\n",
"\n",
"Now run the cell below to plot the predicted outputs of a randomly initialised model across the two dimensional input space as well as the true target outputs. This sort of visualisation can be a useful method (in low dimensions) to assess how well the model is likely to be able to fit the data and to judge appropriate initialisation scales for the parameters. Each time you re-run the cell a new set of random parameters will be sampled.\n",
"\n",
"Some questions to consider:\n",
"\n",
" * How do the scales of the weight and bias initialisations affect the sort of input-output relationships predicted?\n",
" * Does the linear form of the model seem appropriate for the data here?"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"/* Put everything inside the global mpl namespace */\n",
"window.mpl = {};\n",
"\n",
"\n",
"mpl.get_websocket_type = function() {\n",
" if (typeof(WebSocket) !== 'undefined') {\n",
" return WebSocket;\n",
" } else if (typeof(MozWebSocket) !== 'undefined') {\n",
" return MozWebSocket;\n",
" } else {\n",
" alert('Your browser does not have WebSocket support.' +\n",
" 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
" 'Firefox 4 and 5 are also supported but you ' +\n",
" 'have to enable WebSockets in about:config.');\n",
" };\n",
"}\n",
"\n",
"mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
" this.id = figure_id;\n",
"\n",
" this.ws = websocket;\n",
"\n",
" this.supports_binary = (this.ws.binaryType != undefined);\n",
"\n",
" if (!this.supports_binary) {\n",
" var warnings = document.getElementById(\"mpl-warnings\");\n",
" if (warnings) {\n",
" warnings.style.display = 'block';\n",
" warnings.textContent = (\n",
" \"This browser does not support binary websocket messages. \" +\n",
" \"Performance may be slow.\");\n",
" }\n",
" }\n",
"\n",
" this.imageObj = new Image();\n",
"\n",
" this.context = undefined;\n",
" this.message = undefined;\n",
" this.canvas = undefined;\n",
" this.rubberband_canvas = undefined;\n",
" this.rubberband_context = undefined;\n",
" this.format_dropdown = undefined;\n",
"\n",
" this.image_mode = 'full';\n",
"\n",
" this.root = $('<div/>');\n",
" this._root_extra_style(this.root)\n",
" this.root.attr('style', 'display: inline-block');\n",
"\n",
" $(parent_element).append(this.root);\n",
"\n",
" this._init_header(this);\n",
" this._init_canvas(this);\n",
" this._init_toolbar(this);\n",
"\n",
" var fig = this;\n",
"\n",
" this.waiting = false;\n",
"\n",
" this.ws.onopen = function () {\n",
" fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
" fig.send_message(\"send_image_mode\", {});\n",
" if (mpl.ratio != 1) {\n",
" fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
" }\n",
" fig.send_message(\"refresh\", {});\n",
" }\n",
"\n",
" this.imageObj.onload = function() {\n",
" if (fig.image_mode == 'full') {\n",
" // Full images could contain transparency (where diff images\n",
" // almost always do), so we need to clear the canvas so that\n",
" // there is no ghosting.\n",
" fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
" }\n",
" fig.context.drawImage(fig.imageObj, 0, 0);\n",
" };\n",
"\n",
" this.imageObj.onunload = function() {\n",
" this.ws.close();\n",
" }\n",
"\n",
" this.ws.onmessage = this._make_on_message_function(this);\n",
"\n",
" this.ondownload = ondownload;\n",
"}\n",
"\n",
"mpl.figure.prototype._init_header = function() {\n",
" var titlebar = $(\n",
" '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ' +\n",
" 'ui-helper-clearfix\"/>');\n",
" var titletext = $(\n",
" '<div class=\"ui-dialog-title\" style=\"width: 100%; ' +\n",
" 'text-align: center; padding: 3px;\"/>');\n",
" titlebar.append(titletext)\n",
" this.root.append(titlebar);\n",
" this.header = titletext[0];\n",
"}\n",
"\n",
"\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._init_canvas = function() {\n",
" var fig = this;\n",
"\n",
" var canvas_div = $('<div/>');\n",
"\n",
" canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
"\n",
" function canvas_keyboard_event(event) {\n",
" return fig.key_event(event, event['data']);\n",
" }\n",
"\n",
" canvas_div.keydown('key_press', canvas_keyboard_event);\n",
" canvas_div.keyup('key_release', canvas_keyboard_event);\n",
" this.canvas_div = canvas_div\n",
" this._canvas_extra_style(canvas_div)\n",
" this.root.append(canvas_div);\n",
"\n",
" var canvas = $('<canvas/>');\n",
" canvas.addClass('mpl-canvas');\n",
" canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
"\n",
" this.canvas = canvas[0];\n",
" this.context = canvas[0].getContext(\"2d\");\n",
"\n",
" var backingStore = this.context.backingStorePixelRatio ||\n",
"\tthis.context.webkitBackingStorePixelRatio ||\n",
"\tthis.context.mozBackingStorePixelRatio ||\n",
"\tthis.context.msBackingStorePixelRatio ||\n",
"\tthis.context.oBackingStorePixelRatio ||\n",
"\tthis.context.backingStorePixelRatio || 1;\n",
"\n",
" mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
"\n",
" var rubberband = $('<canvas/>');\n",
" rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
"\n",
" var pass_mouse_events = true;\n",
"\n",
" canvas_div.resizable({\n",
" start: function(event, ui) {\n",
" pass_mouse_events = false;\n",
" },\n",
" resize: function(event, ui) {\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" stop: function(event, ui) {\n",
" pass_mouse_events = true;\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" });\n",
"\n",
" function mouse_event_fn(event) {\n",
" if (pass_mouse_events)\n",
" return fig.mouse_event(event, event['data']);\n",
" }\n",
"\n",
" rubberband.mousedown('button_press', mouse_event_fn);\n",
" rubberband.mouseup('button_release', mouse_event_fn);\n",
" // Throttle sequential mouse events to 1 every 20ms.\n",
" rubberband.mousemove('motion_notify', mouse_event_fn);\n",
"\n",
" rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
" rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
"\n",
" canvas_div.on(\"wheel\", function (event) {\n",
" event = event.originalEvent;\n",
" event['data'] = 'scroll'\n",
" if (event.deltaY < 0) {\n",
" event.step = 1;\n",
" } else {\n",
" event.step = -1;\n",
" }\n",
" mouse_event_fn(event);\n",
" });\n",
"\n",
" canvas_div.append(canvas);\n",
" canvas_div.append(rubberband);\n",
"\n",
" this.rubberband = rubberband;\n",
" this.rubberband_canvas = rubberband[0];\n",
" this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
" this.rubberband_context.strokeStyle = \"#000000\";\n",
"\n",
" this._resize_canvas = function(width, height) {\n",
" // Keep the size of the canvas, canvas container, and rubber band\n",
" // canvas in synch.\n",
" canvas_div.css('width', width)\n",
" canvas_div.css('height', height)\n",
"\n",
" canvas.attr('width', width * mpl.ratio);\n",
" canvas.attr('height', height * mpl.ratio);\n",
" canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
"\n",
" rubberband.attr('width', width);\n",
" rubberband.attr('height', height);\n",
" }\n",
"\n",
" // Set the figure to an initial 600x600px, this will subsequently be updated\n",
" // upon first draw.\n",
" this._resize_canvas(600, 600);\n",
"\n",
" // Disable right mouse context menu.\n",
" $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
" return false;\n",
" });\n",
"\n",
" function set_focus () {\n",
" canvas.focus();\n",
" canvas_div.focus();\n",
" }\n",
"\n",
" window.setTimeout(set_focus, 100);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items) {\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) {\n",
" // put a spacer in here.\n",
" continue;\n",
" }\n",
" var button = $('<button/>');\n",
" button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
" 'ui-button-icon-only');\n",
" button.attr('role', 'button');\n",
" button.attr('aria-disabled', 'false');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
"\n",
" var icon_img = $('<span/>');\n",
" icon_img.addClass('ui-button-icon-primary ui-icon');\n",
" icon_img.addClass(image);\n",
" icon_img.addClass('ui-corner-all');\n",
"\n",
" var tooltip_span = $('<span/>');\n",
" tooltip_span.addClass('ui-button-text');\n",
" tooltip_span.html(tooltip);\n",
"\n",
" button.append(icon_img);\n",
" button.append(tooltip_span);\n",
"\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" var fmt_picker_span = $('<span/>');\n",
"\n",
" var fmt_picker = $('<select/>');\n",
" fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
" fmt_picker_span.append(fmt_picker);\n",
" nav_element.append(fmt_picker_span);\n",
" this.format_dropdown = fmt_picker[0];\n",
"\n",
" for (var ind in mpl.extensions) {\n",
" var fmt = mpl.extensions[ind];\n",
" var option = $(\n",
" '<option/>', {selected: fmt === mpl.default_extension}).html(fmt);\n",
" fmt_picker.append(option)\n",
" }\n",
"\n",
" // Add hover states to the ui-buttons\n",
" $( \".ui-button\" ).hover(\n",
" function() { $(this).addClass(\"ui-state-hover\");},\n",
" function() { $(this).removeClass(\"ui-state-hover\");}\n",
" );\n",
"\n",
" var status_bar = $('<span class=\"mpl-message\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"}\n",
"\n",
"mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
" // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
" // which will in turn request a refresh of the image.\n",
" this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
"}\n",
"\n",
"mpl.figure.prototype.send_message = function(type, properties) {\n",
" properties['type'] = type;\n",
" properties['figure_id'] = this.id;\n",
" this.ws.send(JSON.stringify(properties));\n",
"}\n",
"\n",
"mpl.figure.prototype.send_draw_message = function() {\n",
" if (!this.waiting) {\n",
" this.waiting = true;\n",
" this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
" }\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" var format_dropdown = fig.format_dropdown;\n",
" var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
" fig.ondownload(fig, format);\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
" var size = msg['size'];\n",
" if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
" fig._resize_canvas(size[0], size[1]);\n",
" fig.send_message(\"refresh\", {});\n",
" };\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
" var x0 = msg['x0'] / mpl.ratio;\n",
" var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
" var x1 = msg['x1'] / mpl.ratio;\n",
" var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
" x0 = Math.floor(x0) + 0.5;\n",
" y0 = Math.floor(y0) + 0.5;\n",
" x1 = Math.floor(x1) + 0.5;\n",
" y1 = Math.floor(y1) + 0.5;\n",
" var min_x = Math.min(x0, x1);\n",
" var min_y = Math.min(y0, y1);\n",
" var width = Math.abs(x1 - x0);\n",
" var height = Math.abs(y1 - y0);\n",
"\n",
" fig.rubberband_context.clearRect(\n",
" 0, 0, fig.canvas.width, fig.canvas.height);\n",
"\n",
" fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
" // Updates the figure title.\n",
" fig.header.textContent = msg['label'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
" var cursor = msg['cursor'];\n",
" switch(cursor)\n",
" {\n",
" case 0:\n",
" cursor = 'pointer';\n",
" break;\n",
" case 1:\n",
" cursor = 'default';\n",
" break;\n",
" case 2:\n",
" cursor = 'crosshair';\n",
" break;\n",
" case 3:\n",
" cursor = 'move';\n",
" break;\n",
" }\n",
" fig.rubberband_canvas.style.cursor = cursor;\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_message = function(fig, msg) {\n",
" fig.message.textContent = msg['message'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
" // Request the server to send over a new figure.\n",
" fig.send_draw_message();\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
" fig.image_mode = msg['mode'];\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Called whenever the canvas gets updated.\n",
" this.send_message(\"ack\", {});\n",
"}\n",
"\n",
"// A function to construct a web socket function for onmessage handling.\n",
"// Called in the figure constructor.\n",
"mpl.figure.prototype._make_on_message_function = function(fig) {\n",
" return function socket_on_message(evt) {\n",
" if (evt.data instanceof Blob) {\n",
" /* FIXME: We get \"Resource interpreted as Image but\n",
" * transferred with MIME type text/plain:\" errors on\n",
" * Chrome. But how to set the MIME type? It doesn't seem\n",
" * to be part of the websocket stream */\n",
" evt.data.type = \"image/png\";\n",
"\n",
" /* Free the memory for the previous frames */\n",
" if (fig.imageObj.src) {\n",
" (window.URL || window.webkitURL).revokeObjectURL(\n",
" fig.imageObj.src);\n",
" }\n",
"\n",
" fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
" evt.data);\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
" else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
" fig.imageObj.src = evt.data;\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
"\n",
" var msg = JSON.parse(evt.data);\n",
" var msg_type = msg['type'];\n",
"\n",
" // Call the \"handle_{type}\" callback, which takes\n",
" // the figure and JSON message as its only arguments.\n",
" try {\n",
" var callback = fig[\"handle_\" + msg_type];\n",
" } catch (e) {\n",
" console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
" return;\n",
" }\n",
"\n",
" if (callback) {\n",
" try {\n",
" // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
" callback(fig, msg);\n",
" } catch (e) {\n",
" console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
" }\n",
" }\n",
" };\n",
"}\n",
"\n",
"// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
"mpl.findpos = function(e) {\n",
" //this section is from http://www.quirksmode.org/js/events_properties.html\n",
" var targ;\n",
" if (!e)\n",
" e = window.event;\n",
" if (e.target)\n",
" targ = e.target;\n",
" else if (e.srcElement)\n",
" targ = e.srcElement;\n",
" if (targ.nodeType == 3) // defeat Safari bug\n",
" targ = targ.parentNode;\n",
"\n",
" // jQuery normalizes the pageX and pageY\n",
" // pageX,Y are the mouse positions relative to the document\n",
" // offset() returns the position of the element relative to the document\n",
" var x = e.pageX - $(targ).offset().left;\n",
" var y = e.pageY - $(targ).offset().top;\n",
"\n",
" return {\"x\": x, \"y\": y};\n",
"};\n",
"\n",
"/*\n",
" * return a copy of an object with only non-object keys\n",
" * we need this to avoid circular references\n",
" * http://stackoverflow.com/a/24161582/3208463\n",
" */\n",
"function simpleKeys (original) {\n",
" return Object.keys(original).reduce(function (obj, key) {\n",
" if (typeof original[key] !== 'object')\n",
" obj[key] = original[key]\n",
" return obj;\n",
" }, {});\n",
"}\n",
"\n",
"mpl.figure.prototype.mouse_event = function(event, name) {\n",
" var canvas_pos = mpl.findpos(event)\n",
"\n",
" if (name === 'button_press')\n",
" {\n",
" this.canvas.focus();\n",
" this.canvas_div.focus();\n",
" }\n",
"\n",
" var x = canvas_pos.x * mpl.ratio;\n",
" var y = canvas_pos.y * mpl.ratio;\n",
"\n",
" this.send_message(name, {x: x, y: y, button: event.button,\n",
" step: event.step,\n",
" guiEvent: simpleKeys(event)});\n",
"\n",
" /* This prevents the web browser from automatically changing to\n",
" * the text insertion cursor when the button is pressed. We want\n",
" * to control all of the cursor setting manually through the\n",
" * 'cursor' event from matplotlib */\n",
" event.preventDefault();\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" // Handle any extra behaviour associated with a key event\n",
"}\n",
"\n",
"mpl.figure.prototype.key_event = function(event, name) {\n",
"\n",
" // Prevent repeat events\n",
" if (name == 'key_press')\n",
" {\n",
" if (event.which === this._key)\n",
" return;\n",
" else\n",
" this._key = event.which;\n",
" }\n",
" if (name == 'key_release')\n",
" this._key = null;\n",
"\n",
" var value = '';\n",
" if (event.ctrlKey && event.which != 17)\n",
" value += \"ctrl+\";\n",
" if (event.altKey && event.which != 18)\n",
" value += \"alt+\";\n",
" if (event.shiftKey && event.which != 16)\n",
" value += \"shift+\";\n",
"\n",
" value += 'k';\n",
" value += event.which.toString();\n",
"\n",
" this._key_event_extra(event, name);\n",
"\n",
" this.send_message(name, {key: value,\n",
" guiEvent: simpleKeys(event)});\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
" if (name == 'download') {\n",
" this.handle_save(this, null);\n",
" } else {\n",
" this.send_message(\"toolbar_button\", {name: name});\n",
" }\n",
"};\n",
"\n",
"mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
" this.message.textContent = tooltip;\n",
"};\n",
"mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
"\n",
"mpl.extensions = [\"eps\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\"];\n",
"\n",
"mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
" // Create a \"websocket\"-like object which calls the given IPython comm\n",
" // object with the appropriate methods. Currently this is a non binary\n",
" // socket, so there is still some room for performance tuning.\n",
" var ws = {};\n",
"\n",
" ws.close = function() {\n",
" comm.close()\n",
" };\n",
" ws.send = function(m) {\n",
" //console.log('sending', m);\n",
" comm.send(m);\n",
" };\n",
" // Register the callback with on_msg.\n",
" comm.on_msg(function(msg) {\n",
" //console.log('receiving', msg['content']['data'], msg);\n",
" // Pass the mpl event to the overriden (by mpl) onmessage function.\n",
" ws.onmessage(msg['content']['data'])\n",
" });\n",
" return ws;\n",
"}\n",
"\n",
"mpl.mpl_figure_comm = function(comm, msg) {\n",
" // This is the function which gets called when the mpl process\n",
" // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
"\n",
" var id = msg.content.data.id;\n",
" // Get hold of the div created by the display call when the Comm\n",
" // socket was opened in Python.\n",
" var element = $(\"#\" + id);\n",
" var ws_proxy = comm_websocket_adapter(comm)\n",
"\n",
" function ondownload(figure, format) {\n",
" window.open(figure.imageObj.src);\n",
" }\n",
"\n",
" var fig = new mpl.figure(id, ws_proxy,\n",
" ondownload,\n",
" element.get(0));\n",
"\n",
" // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
" // web socket which is closed, not our websocket->open comm proxy.\n",
" ws_proxy.onopen();\n",
"\n",
" fig.parent_element = element.get(0);\n",
" fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
" if (!fig.cell_info) {\n",
" console.error(\"Failed to find cell for figure\", id, fig);\n",
" return;\n",
" }\n",
"\n",
" var output_index = fig.cell_info[2]\n",
" var cell = fig.cell_info[0];\n",
"\n",
"};\n",
"\n",
"mpl.figure.prototype.handle_close = function(fig, msg) {\n",
" var width = fig.canvas.width/mpl.ratio\n",
" fig.root.unbind('remove')\n",
"\n",
" // Update the output cell to use the data from the current canvas.\n",
" fig.push_to_output();\n",
" var dataURL = fig.canvas.toDataURL();\n",
" // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
" // the notebook keyboard shortcuts fail.\n",
" IPython.keyboard_manager.enable()\n",
" $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
" fig.close_ws(fig, msg);\n",
"}\n",
"\n",
"mpl.figure.prototype.close_ws = function(fig, msg){\n",
" fig.send_message('closing', msg);\n",
" // fig.ws.close()\n",
"}\n",
"\n",
"mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
" // Turn the data on the canvas into data in the output cell.\n",
" var width = this.canvas.width/mpl.ratio\n",
" var dataURL = this.canvas.toDataURL();\n",
" this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Tell IPython that the notebook contents must change.\n",
" IPython.notebook.set_dirty(true);\n",
" this.send_message(\"ack\", {});\n",
" var fig = this;\n",
" // Wait a second, then push the new image to the DOM so\n",
" // that it is saved nicely (might be nice to debounce this).\n",
" setTimeout(function () { fig.push_to_output() }, 1000);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items){\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) { continue; };\n",
"\n",
" var button = $('<button class=\"btn btn-default\" href=\"#\" title=\"' + name + '\"><i class=\"fa ' + image + ' fa-lg\"></i></button>');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" // Add the status bar.\n",
" var status_bar = $('<span class=\"mpl-message\" style=\"text-align:right; float: right;\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"\n",
" // Add the close button to the window.\n",
" var buttongrp = $('<div class=\"btn-group inline pull-right\"></div>');\n",
" var button = $('<button class=\"btn btn-mini btn-primary\" href=\"#\" title=\"Stop Interaction\"><i class=\"fa fa-power-off icon-remove icon-large\"></i></button>');\n",
" button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
" button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
" buttongrp.append(button);\n",
" var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
" titlebar.prepend(buttongrp);\n",
"}\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(el){\n",
" var fig = this\n",
" el.on(\"remove\", function(){\n",
"\tfig.close_ws(fig, {});\n",
" });\n",
"}\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(el){\n",
" // this is important to make the div 'focusable\n",
" el.attr('tabindex', 0)\n",
" // reach out to IPython and tell the keyboard manager to turn it's self\n",
" // off when our div gets focus\n",
"\n",
" // location in version 3\n",
" if (IPython.notebook.keyboard_manager) {\n",
" IPython.notebook.keyboard_manager.register_events(el);\n",
" }\n",
" else {\n",
" // location in version 2\n",
" IPython.keyboard_manager.register_events(el);\n",
" }\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" var manager = IPython.notebook.keyboard_manager;\n",
" if (!manager)\n",
" manager = IPython.keyboard_manager;\n",
"\n",
" // Check for shift+enter\n",
" if (event.shiftKey && event.which == 13) {\n",
" this.canvas_div.blur();\n",
" // select the cell after this one\n",
" var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n",
" IPython.notebook.select(index + 1);\n",
" }\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" fig.ondownload(fig, null);\n",
"}\n",
"\n",
"\n",
"mpl.find_output_cell = function(html_output) {\n",
" // Return the cell and output element which can be found *uniquely* in the notebook.\n",
" // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
" // IPython event is triggered only after the cells have been serialised, which for\n",
" // our purposes (turning an active figure into a static one), is too late.\n",
" var cells = IPython.notebook.get_cells();\n",
" var ncells = cells.length;\n",
" for (var i=0; i<ncells; i++) {\n",
" var cell = cells[i];\n",
" if (cell.cell_type === 'code'){\n",
" for (var j=0; j<cell.output_area.outputs.length; j++) {\n",
" var data = cell.output_area.outputs[j];\n",
" if (data.data) {\n",
" // IPython >= 3 moved mimebundle to data attribute of output\n",
" data = data.data;\n",
" }\n",
" if (data['text/html'] == html_output) {\n",
" return [cell, data, j];\n",
" }\n",
" }\n",
" }\n",
" }\n",
"}\n",
"\n",
"// Register the function which deals with the matplotlib target/channel.\n",
"// The kernel may be null if the page has been refreshed.\n",
"if (IPython.notebook.kernel != null) {\n",
" IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
"}\n"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAyAAAAMgCAYAAADbcAZoAAAgAElEQVR4Xu3XMQ0AAAzDsJU/6bHI5RGoZO3JzhEgQIAAAQIECBAgQCASWLRjhgABAgQIECBAgAABAidAPAEBAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AA
ECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQICBA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQPwAAQIECBAgQIAAAQKZgADJqA0RIECAAAECBAgQIC
BA/AABAgQIECBAgAABApmAAMmoDREgQIAAAQIECBAgIED8AAECBAgQIECAAAECmYAAyagNESBAgAABAgQIECAgQP
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"weights_init_range = 0.5\n",
"biases_init_range = 0.1\n",
"\n",
"# Randomly initialise weights matrix\n",
"weights = rng.uniform(\n",
" low=-weights_init_range, \n",
" high=weights_init_range, \n",
" size=(output_dim, input_dim)\n",
")\n",
"\n",
"# Randomly initialise biases vector\n",
"biases = rng.uniform(\n",
" low=-biases_init_range, \n",
" high=biases_init_range, \n",
" size=output_dim\n",
")\n",
"# Calculate predicted model outputs\n",
"outputs = fprop(inputs, weights, biases)\n",
"\n",
"# Plot target and predicted outputs against inputs on same axis\n",
"fig = plt.figure(figsize=(8, 8))\n",
"ax = fig.add_subplot(111, projection='3d')\n",
"ax.plot(inputs[:, 0], inputs[:, 1], targets[:, 0], 'r.', ms=2)\n",
"ax.plot(inputs[:, 0], inputs[:, 1], outputs[:, 0], 'b.', ms=2)\n",
"ax.set_xlabel('Input dim 1')\n",
"ax.set_ylabel('Input dim 2')\n",
"ax.set_zlabel('Output')\n",
"ax.legend(['Targets', 'Predictions'], frameon=False)\n",
"fig.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 3: computing the error function and its gradient\n",
"\n",
"Here we will consider the task of regression as covered in the first lecture slides. The aim in a regression problem is given inputs $\\left\\lbrace \\boldsymbol{x}^{(n)}\\right\\rbrace_{n=1}^N$ to produce outputs $\\left\\lbrace \\boldsymbol{y}^{(n)}\\right\\rbrace_{n=1}^N$ that are as 'close' as possible to a set of target outputs $\\left\\lbrace \\boldsymbol{t}^{(n)}\\right\\rbrace_{n=1}^N$. The measure of 'closeness' or distance between target and predicted outputs is a design choice. \n",
"\n",
"A very common choice is the squared Euclidean distance between the predicted and target outputs. This can be computed as the sum of the squared differences between each element in the target and predicted outputs. A common convention is to multiply this value by $\\frac{1}{2}$ as this gives a slightly nicer expression for the error gradient. The error for the $n^{\\textrm{th}}$ training example is then\n",
"\n",
"\\begin{equation}\n",
" E^{(n)} = \\frac{1}{2} \\sum_{k=1}^K \\left\\lbrace \\left( y^{(n)}_k - t^{(n)}_k \\right)^2 \\right\\rbrace.\n",
"\\end{equation}\n",
"\n",
"The overall error is then the *average* of this value across all training examples\n",
"\n",
"\\begin{equation}\n",
" \\bar{E} = \\frac{1}{N} \\sum_{n=1}^N \\left\\lbrace E^{(n)} \\right\\rbrace. \n",
"\\end{equation}\n",
"\n",
"*Note here we are using a slightly different convention from the lectures. There the overall error was considered to be the sum of the individual error terms rather than the mean. To differentiate between the two we will use $\\bar{E}$ to represent the average error here as opposed to sum of errors $E$ as used in the slides with $\\bar{E} = \\frac{E}{N}$. Normalising by the number of training examples is helpful to do in practice as this means we can more easily compare errors across data sets / batches of different sizes, and more importantly it means the size of our gradient updates will be independent of the number of training examples summed over.*\n",
"\n",
"The regression problem is then to find parameters of the model which minimise $\\bar{E}$. For our simple single-layer affine model here that corresponds to finding weights $\\mathbf{W}$ and biases $\\boldsymbol{b}$ which minimise $\\bar{E}$. \n",
"\n",
"As mentioned in the lecture, for this simple case there is actually a closed form solution for the optimal weights and bias parameters. This is the linear least-squares solution those doing MLPR will have come across.\n",
"\n",
"However in general we will be interested in models where closed form solutions do not exist. We will therefore generally use iterative, gradient descent based training methods to find parameters which (locally) minimise the error function. A basic requirement of being able to do gradient-descent based training is (unsuprisingly) the ability to evaluate gradients of the error function.\n",
"\n",
"In the next exercise we will consider how to calculate gradients of the error function with respect to the model parameters $\\mathbf{W}$ and $\\boldsymbol{b}$, but as a first step here we will consider the gradient of the error function with respect to the model outputs $\\left\\lbrace \\boldsymbol{y}^{(n)}\\right\\rbrace_{n=1}^N$. This can be written\n",
"\n",
"\\begin{equation}\n",
" \\frac{\\partial \\bar{E}}{\\partial \\boldsymbol{y}^{(n)}} = \\frac{1}{N} \\left( \\boldsymbol{y}^{(n)} - \\boldsymbol{t}^{(n)} \\right)\n",
" \\qquad \\Leftrightarrow \\qquad\n",
" \\frac{\\partial \\bar{E}}{\\partial y^{(n)}_k} = \\frac{1}{N} \\left( y^{(n)}_k - t^{(n)}_k \\right) \\quad \\forall k \\in \\left\\lbrace 1 \\dots K\\right\\rbrace\n",
"\\end{equation}\n",
"\n",
"i.e. the gradient of the error function with respect to the $n^{\\textrm{th}}$ model output is just the difference between the $n^{\\textrm{th}}$ model and target outputs, corresponding to the $\\boldsymbol{\\delta}^{(n)}$ terms mentioned in the lecture slides.\n",
"\n",
"The third exercise is, using the equations given above, to implement functions computing the mean sum of squared differences error and its gradient with respect to the model outputs. You should implement the functions using the provided skeleton definitions in the cell below."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def error(outputs, targets):\n",
" \"\"\"Calculates error function given a batch of outputs and targets.\n",
"\n",
" Args:\n",
" outputs: Array of model outputs of shape (batch_size, output_dim).\n",
" targets: Array of target outputs of shape (batch_size, output_dim).\n",
"\n",
" Returns:\n",
" Scalar error function value.\n",
" \"\"\"\n",
" return 0.5 * ((outputs - targets)**2).sum() / outputs.shape[0]\n",
" \n",
"def error_grad(outputs, targets):\n",
" \"\"\"Calculates gradient of error function with respect to model outputs.\n",
"\n",
" Args:\n",
" outputs: Array of model outputs of shape (batch_size, output_dim).\n",
" targets: Array of target outputs of shape (batch_size, output_dim).\n",
"\n",
" Returns:\n",
" Gradient of error function with respect to outputs.\n",
" This will be an array of shape (batch_size, output_dim).\n",
" \"\"\"\n",
" return (outputs - targets) / outputs.shape[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check your implementation by running the test cell below."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Error function and gradient computed correctly!\n"
]
}
],
"source": [
"outputs = np.array([[1., 2.], [-1., 0.], [6., -5.], [-1., 1.]])\n",
"targets = np.array([[0., 1.], [3., -2.], [7., -3.], [1., -2.]])\n",
"true_error = 5.\n",
"true_error_grad = np.array([[0.25, 0.25], [-1., 0.5], [-0.25, -0.5], [-0.5, 0.75]])\n",
"\n",
"if not error(outputs, targets) == true_error:\n",
" print('Error calculated incorrectly.')\n",
"elif not np.allclose(error_grad(outputs, targets), true_error_grad):\n",
" print('Error gradient calculated incorrectly.')\n",
"else:\n",
" print('Error function and gradient computed correctly!')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 4: computing gradients with respect to the parameters\n",
"\n",
"In the previous exercise you implemented a function computing the gradient of the error function with respect to the model outputs. For gradient-descent based training, we need to be able to evaluate the gradient of the error function with respect to the model parameters.\n",
"\n",
"Using the [chain rule for derivatives](https://en.wikipedia.org/wiki/Chain_rule#Higher_dimensions) we can write the partial deriviative of the error function with respect to single elements of the weight matrix and bias vector as\n",
"\n",
"\\begin{equation}\n",
" \\frac{\\partial E}{\\partial W_{kj}} = \\sum_{n=1}^N \\left\\lbrace \\frac{\\partial E}{\\partial y^{(n)}_k} \\frac{\\partial y^{(n)}_k}{\\partial W_{kj}} \\right\\rbrace\n",
" \\quad \\textrm{and} \\quad\n",
" \\frac{\\partial E}{\\partial b_k} = \\sum_{n=1}^N \\left\\lbrace \\frac{\\partial E}{\\partial y^{(n)}_k} \\frac{\\partial y^{(n)}_k}{\\partial b_k} \\right\\rbrace.\n",
"\\end{equation}\n",
"\n",
"From the definition of our model at the beginning we have \n",
"\n",
"\\begin{equation}\n",
" y^{(n)}_k = \\sum_{d=1}^D \\left\\lbrace W_{kd} x^{(n)}_d \\right\\rbrace + b_k\n",
" \\quad \\Rightarrow \\quad\n",
" \\frac{\\partial y^{(n)}_k}{\\partial W_{kj}} = x^{(n)}_j\n",
" \\quad \\textrm{and} \\quad\n",
" \\frac{\\partial y^{(n)}_k}{\\partial b_k} = 1.\n",
"\\end{equation}\n",
"\n",
"Putting this together we get that\n",
"\n",
"\\begin{equation}\n",
" \\frac{\\partial E}{\\partial W_{kj}} = \n",
" \\sum_{n=1}^N \\left\\lbrace \\frac{\\partial E}{\\partial y^{(n)}_k} x^{(n)}_j \\right\\rbrace\n",
" \\quad \\textrm{and} \\quad\n",
" \\frac{\\partial E}{\\partial b_{k}} = \n",
" \\sum_{n=1}^N \\left\\lbrace \\frac{\\partial E}{\\partial y^{(n)}_k} \\right\\rbrace.\n",
"\\end{equation}\n",
"\n",
"Although this may seem a bit of a roundabout way to get to these results, this method of decomposing the error gradient with respect to the parameters in terms of the gradient of the error function with respect to the model outputs and the derivatives of the model outputs with respect to the model parameters, will be key when calculating the parameter gradients of more complex models later in the course.\n",
"\n",
"Your task in this exercise is to implement a function calculating the gradient of the error function with respect to the weight and bias parameters of the model given the already computed gradient of the error function with respect to the model outputs. You should implement this in the `grads_wrt_params` function in the cell below."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def grads_wrt_params(inputs, grads_wrt_outputs):\n",
" \"\"\"Calculates gradients with respect to model parameters.\n",
"\n",
" Args:\n",
" inputs: array of inputs to model of shape (batch_size, input_dim)\n",
" grads_wrt_to_outputs: array of gradients of with respect to the model\n",
" outputs of shape (batch_size, output_dim).\n",
"\n",
" Returns:\n",
" list of arrays of gradients with respect to the model parameters\n",
" `[grads_wrt_weights, grads_wrt_biases]`.\n",
" \"\"\"\n",
" grads_wrt_weights = grads_wrt_outputs.T.dot(inputs)\n",
" grads_wrt_biases = grads_wrt_outputs.sum(0)\n",
" return [grads_wrt_weights, grads_wrt_biases]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check your implementation by running the test cell below."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All parameter gradients calculated correctly!\n"
]
}
],
"source": [
"inputs = np.array([[1., 2., 3.], [-1., 4., -9.]])\n",
"grads_wrt_outputs = np.array([[-1., 1.], [2., -3.]])\n",
"true_grads_wrt_weights = np.array([[-3., 6., -21.], [4., -10., 30.]])\n",
"true_grads_wrt_biases = np.array([1., -2.])\n",
"\n",
"grads_wrt_weights, grads_wrt_biases = grads_wrt_params(\n",
" inputs, grads_wrt_outputs)\n",
"\n",
"if not np.allclose(true_grads_wrt_weights, grads_wrt_weights):\n",
" print('Gradients with respect to weights incorrect.')\n",
"elif not np.allclose(true_grads_wrt_biases, grads_wrt_biases):\n",
" print('Gradients with respect to biases incorrect.')\n",
"else:\n",
" print('All parameter gradients calculated correctly!')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 5: wrapping the functions into reusable components\n",
"\n",
"In exercises 1, 3 and 4 you implemented methods to compute the predicted outputs of our model, evaluate the error function and its gradient on the outputs and finally to calculate the gradients of the error with respect to the model parameters. Together they constitute all the basic ingredients we need to implement a gradient-descent based iterative learning procedure for the model.\n",
"\n",
"Although you could implement training code which directly uses the functions you defined, this would only be usable for this particular model architecture. In subsequent labs we will want to use the affine transform functions as the basis for more interesting multi-layer models. We will therefore wrap the implementations you just wrote in to reusable components that we can build more complex models with later in the course.\n",
"\n",
" * In the [`mlp.layers`](/edit/mlp/layers.py) module, use your implementations of `fprop` and `grad_wrt_params` above to implement the corresponding methods in the skeleton `AffineLayer` class provided.\n",
" * In the [`mlp.errors`](/edit/mlp/errors.py) module use your implementation of `error` and `error_grad` to implement the `__call__` and `grad` methods respectively of the skeleton `SumOfSquaredDiffsError` class provided. Note `__call__` is a special Python method that allows an object to be used with a function call syntax.\n",
"\n",
"Run the cell below to use your completed `AffineLayer` and `SumOfSquaredDiffsError` implementations to train a single-layer model using batch gradient descent on the CCPP dataset."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 1: 0.0s to complete\n",
" error(train)=1.67e-01\n",
"Epoch 2: 0.0s to complete\n",
" error(train)=9.30e-02\n",
"Epoch 3: 0.0s to complete\n",
" error(train)=7.95e-02\n",
"Epoch 4: 0.0s to complete\n",
" error(train)=7.71e-02\n",
"Epoch 5: 0.0s to complete\n",
" error(train)=7.66e-02\n",
"Epoch 6: 0.0s to complete\n",
" error(train)=7.65e-02\n",
"Epoch 7: 0.0s to complete\n",
" error(train)=7.65e-02\n",
"Epoch 8: 0.0s to complete\n",
" error(train)=7.65e-02\n",
"Epoch 9: 0.0s to complete\n",
" error(train)=7.63e-02\n",
"Epoch 10: 0.0s to complete\n",
" error(train)=7.64e-02\n"
]
},
{
"data": {
"text/html": [
"[Figure: training error (y-axis 'Error') plotted against epoch number (x-axis 'Epoch number')]"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.text.Text at 0x7f2e8c192e80>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
2017-09-29 18:54:05 +02:00
"source": [
"from mlp.layers import AffineLayer\n",
"from mlp.errors import SumOfSquaredDiffsError\n",
"from mlp.models import SingleLayerModel\n",
"from mlp.initialisers import UniformInit, ConstantInit\n",
"from mlp.learning_rules import GradientDescentLearningRule\n",
"from mlp.optimisers import Optimiser\n",
"import logging\n",
"\n",
"# Seed a random number generator\n",
"seed = 27092016 \n",
"rng = np.random.RandomState(seed)\n",
"\n",
"# Set up a logger object to print info about the training run to stdout\n",
"logger = logging.getLogger()\n",
"logger.setLevel(logging.INFO)\n",
"logger.handlers = [logging.StreamHandler()]\n",
"\n",
"# Create data provider objects for the CCPP training set\n",
"train_data = CCPPDataProvider('train', [0, 1], batch_size=100, rng=rng)\n",
"input_dim, output_dim = 2, 1\n",
"\n",
"# Create a parameter initialiser which will sample random uniform values\n",
"# from [-0.1, 0.1]\n",
"param_init = UniformInit(-0.1, 0.1, rng=rng)\n",
"\n",
"# Create our single layer model\n",
"layer = AffineLayer(input_dim, output_dim, param_init, param_init)\n",
"model = SingleLayerModel(layer)\n",
"\n",
"# Initialise the error object\n",
"error = SumOfSquaredDiffsError()\n",
"\n",
"# Use a basic gradient descent learning rule with a small learning rate\n",
"learning_rule = GradientDescentLearningRule(learning_rate=1e-2)\n",
"\n",
"# Use the created objects to initialise a new Optimiser instance.\n",
"optimiser = Optimiser(model, error, learning_rule, train_data)\n",
"\n",
"# Run the optimiser for 5 epochs (full passes through the training set)\n",
"# printing statistics every epoch.\n",
"stats, keys, _ = optimiser.train(num_epochs=10, stats_interval=1)\n",
"\n",
"# Plot the change in the error over training.\n",
"fig = plt.figure(figsize=(8, 4))\n",
"ax = fig.add_subplot(111)\n",
"ax.plot(np.arange(1, stats.shape[0] + 1), stats[:, keys['error(train)']])\n",
"ax.set_xlabel('Epoch number')\n",
"ax.set_ylabel('Error')"
]
},
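{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Optimiser` object hides the details of the training loop. As a rough illustration of what such a loop might look like, the sketch below runs minibatch gradient descent on an affine model directly in NumPy. It is not the `mlp` package implementation: the synthetic `inputs` and `targets`, the `true_weights` and `true_biases` used to generate them, and the choice of error (half the batch mean of the summed squared differences) are all assumptions made for this example.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical synthetic data standing in for the CCPP inputs and targets.\n",
"rng = np.random.RandomState(27092016)\n",
"inputs = rng.uniform(-1., 1., size=(1000, 2))\n",
"true_weights = np.array([[0.5, -0.3]])  # shape (output_dim, input_dim)\n",
"true_biases = np.array([0.2])\n",
"targets = inputs.dot(true_weights.T) + true_biases\n",
"\n",
"# Random uniform initialisation in [-0.1, 0.1], as in the cell above.\n",
"weights = rng.uniform(-0.1, 0.1, size=(1, 2))\n",
"biases = rng.uniform(-0.1, 0.1, size=1)\n",
"\n",
"learning_rate, batch_size, num_epochs = 1e-2, 100, 10\n",
"num_batches = inputs.shape[0] // batch_size\n",
"errors = []\n",
"for epoch in range(num_epochs):\n",
"    # Visit the data points in a new random order each epoch.\n",
"    order = rng.permutation(inputs.shape[0])\n",
"    epoch_error = 0.\n",
"    for start in range(0, inputs.shape[0], batch_size):\n",
"        idx = order[start:start + batch_size]\n",
"        x, t = inputs[idx], targets[idx]\n",
"        # Forward propagation through the affine model.\n",
"        y = x.dot(weights.T) + biases\n",
"        diff = y - t\n",
"        # Assumed error: half the batch mean of the summed squared differences.\n",
"        epoch_error += 0.5 * np.sum(diff**2) / x.shape[0]\n",
"        # Gradients of the batch error with respect to the parameters.\n",
"        grad_weights = diff.T.dot(x) / x.shape[0]\n",
"        grad_biases = diff.sum(0) / x.shape[0]\n",
"        # Gradient-descent update.\n",
"        weights -= learning_rate * grad_weights\n",
"        biases -= learning_rate * grad_biases\n",
"    # Record the error averaged over the batches in this epoch.\n",
"    errors.append(epoch_error / num_batches)\n",
"```\n",
"\n",
"Plotting `errors` against the epoch number should give a curve of the same general shape as the one produced by the cell above, although the values will differ because the data here are synthetic."
]
},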
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using similar code to previously we can now visualise the joint input-output space for the trained model. If you implemented the required methods correctly you should now see a much improved fit between predicted and target outputs when running the cell below."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"/* Put everything inside the global mpl namespace */\n",
"window.mpl = {};\n",
"\n",
"\n",
"mpl.get_websocket_type = function() {\n",
" if (typeof(WebSocket) !== 'undefined') {\n",
" return WebSocket;\n",
" } else if (typeof(MozWebSocket) !== 'undefined') {\n",
" return MozWebSocket;\n",
" } else {\n",
" alert('Your browser does not have WebSocket support.' +\n",
" 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
" 'Firefox 4 and 5 are also supported but you ' +\n",
" 'have to enable WebSockets in about:config.');\n",
" };\n",
"}\n",
"\n",
"mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
" this.id = figure_id;\n",
"\n",
" this.ws = websocket;\n",
"\n",
" this.supports_binary = (this.ws.binaryType != undefined);\n",
"\n",
" if (!this.supports_binary) {\n",
" var warnings = document.getElementById(\"mpl-warnings\");\n",
" if (warnings) {\n",
" warnings.style.display = 'block';\n",
" warnings.textContent = (\n",
" \"This browser does not support binary websocket messages. \" +\n",
" \"Performance may be slow.\");\n",
" }\n",
" }\n",
"\n",
" this.imageObj = new Image();\n",
"\n",
" this.context = undefined;\n",
" this.message = undefined;\n",
" this.canvas = undefined;\n",
" this.rubberband_canvas = undefined;\n",
" this.rubberband_context = undefined;\n",
" this.format_dropdown = undefined;\n",
"\n",
" this.image_mode = 'full';\n",
"\n",
" this.root = $('<div/>');\n",
" this._root_extra_style(this.root)\n",
" this.root.attr('style', 'display: inline-block');\n",
"\n",
" $(parent_element).append(this.root);\n",
"\n",
" this._init_header(this);\n",
" this._init_canvas(this);\n",
" this._init_toolbar(this);\n",
"\n",
" var fig = this;\n",
"\n",
" this.waiting = false;\n",
"\n",
" this.ws.onopen = function () {\n",
" fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
" fig.send_message(\"send_image_mode\", {});\n",
" if (mpl.ratio != 1) {\n",
" fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
" }\n",
" fig.send_message(\"refresh\", {});\n",
" }\n",
"\n",
" this.imageObj.onload = function() {\n",
" if (fig.image_mode == 'full') {\n",
" // Full images could contain transparency (where diff images\n",
" // almost always do), so we need to clear the canvas so that\n",
" // there is no ghosting.\n",
" fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
" }\n",
" fig.context.drawImage(fig.imageObj, 0, 0);\n",
" };\n",
"\n",
" this.imageObj.onunload = function() {\n",
" this.ws.close();\n",
" }\n",
"\n",
" this.ws.onmessage = this._make_on_message_function(this);\n",
"\n",
" this.ondownload = ondownload;\n",
"}\n",
"\n",
"mpl.figure.prototype._init_header = function() {\n",
" var titlebar = $(\n",
" '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ' +\n",
" 'ui-helper-clearfix\"/>');\n",
" var titletext = $(\n",
" '<div class=\"ui-dialog-title\" style=\"width: 100%; ' +\n",
" 'text-align: center; padding: 3px;\"/>');\n",
" titlebar.append(titletext)\n",
" this.root.append(titlebar);\n",
" this.header = titletext[0];\n",
"}\n",
"\n",
"\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._init_canvas = function() {\n",
" var fig = this;\n",
"\n",
" var canvas_div = $('<div/>');\n",
"\n",
" canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
"\n",
" function canvas_keyboard_event(event) {\n",
" return fig.key_event(event, event['data']);\n",
" }\n",
"\n",
" canvas_div.keydown('key_press', canvas_keyboard_event);\n",
" canvas_div.keyup('key_release', canvas_keyboard_event);\n",
" this.canvas_div = canvas_div\n",
" this._canvas_extra_style(canvas_div)\n",
" this.root.append(canvas_div);\n",
"\n",
" var canvas = $('<canvas/>');\n",
" canvas.addClass('mpl-canvas');\n",
" canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
"\n",
" this.canvas = canvas[0];\n",
" this.context = canvas[0].getContext(\"2d\");\n",
"\n",
" var backingStore = this.context.backingStorePixelRatio ||\n",
"\tthis.context.webkitBackingStorePixelRatio ||\n",
"\tthis.context.mozBackingStorePixelRatio ||\n",
"\tthis.context.msBackingStorePixelRatio ||\n",
"\tthis.context.oBackingStorePixelRatio ||\n",
"\tthis.context.backingStorePixelRatio || 1;\n",
"\n",
" mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
"\n",
" var rubberband = $('<canvas/>');\n",
" rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
"\n",
" var pass_mouse_events = true;\n",
"\n",
" canvas_div.resizable({\n",
" start: function(event, ui) {\n",
" pass_mouse_events = false;\n",
" },\n",
" resize: function(event, ui) {\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" stop: function(event, ui) {\n",
" pass_mouse_events = true;\n",
" fig.request_resize(ui.size.width, ui.size.height);\n",
" },\n",
" });\n",
"\n",
" function mouse_event_fn(event) {\n",
" if (pass_mouse_events)\n",
" return fig.mouse_event(event, event['data']);\n",
" }\n",
"\n",
" rubberband.mousedown('button_press', mouse_event_fn);\n",
" rubberband.mouseup('button_release', mouse_event_fn);\n",
" // Throttle sequential mouse events to 1 every 20ms.\n",
" rubberband.mousemove('motion_notify', mouse_event_fn);\n",
"\n",
" rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
" rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
"\n",
" canvas_div.on(\"wheel\", function (event) {\n",
" event = event.originalEvent;\n",
" event['data'] = 'scroll'\n",
" if (event.deltaY < 0) {\n",
" event.step = 1;\n",
" } else {\n",
" event.step = -1;\n",
" }\n",
" mouse_event_fn(event);\n",
" });\n",
"\n",
" canvas_div.append(canvas);\n",
" canvas_div.append(rubberband);\n",
"\n",
" this.rubberband = rubberband;\n",
" this.rubberband_canvas = rubberband[0];\n",
" this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
" this.rubberband_context.strokeStyle = \"#000000\";\n",
"\n",
" this._resize_canvas = function(width, height) {\n",
" // Keep the size of the canvas, canvas container, and rubber band\n",
" // canvas in synch.\n",
" canvas_div.css('width', width)\n",
" canvas_div.css('height', height)\n",
"\n",
" canvas.attr('width', width * mpl.ratio);\n",
" canvas.attr('height', height * mpl.ratio);\n",
" canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
"\n",
" rubberband.attr('width', width);\n",
" rubberband.attr('height', height);\n",
" }\n",
"\n",
" // Set the figure to an initial 600x600px, this will subsequently be updated\n",
" // upon first draw.\n",
" this._resize_canvas(600, 600);\n",
"\n",
" // Disable right mouse context menu.\n",
" $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
" return false;\n",
" });\n",
"\n",
" function set_focus () {\n",
" canvas.focus();\n",
" canvas_div.focus();\n",
" }\n",
"\n",
" window.setTimeout(set_focus, 100);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items) {\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) {\n",
" // put a spacer in here.\n",
" continue;\n",
" }\n",
" var button = $('<button/>');\n",
" button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
" 'ui-button-icon-only');\n",
" button.attr('role', 'button');\n",
" button.attr('aria-disabled', 'false');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
"\n",
" var icon_img = $('<span/>');\n",
" icon_img.addClass('ui-button-icon-primary ui-icon');\n",
" icon_img.addClass(image);\n",
" icon_img.addClass('ui-corner-all');\n",
"\n",
" var tooltip_span = $('<span/>');\n",
" tooltip_span.addClass('ui-button-text');\n",
" tooltip_span.html(tooltip);\n",
"\n",
" button.append(icon_img);\n",
" button.append(tooltip_span);\n",
"\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" var fmt_picker_span = $('<span/>');\n",
"\n",
" var fmt_picker = $('<select/>');\n",
" fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
" fmt_picker_span.append(fmt_picker);\n",
" nav_element.append(fmt_picker_span);\n",
" this.format_dropdown = fmt_picker[0];\n",
"\n",
" for (var ind in mpl.extensions) {\n",
" var fmt = mpl.extensions[ind];\n",
" var option = $(\n",
" '<option/>', {selected: fmt === mpl.default_extension}).html(fmt);\n",
" fmt_picker.append(option)\n",
" }\n",
"\n",
" // Add hover states to the ui-buttons\n",
" $( \".ui-button\" ).hover(\n",
" function() { $(this).addClass(\"ui-state-hover\");},\n",
" function() { $(this).removeClass(\"ui-state-hover\");}\n",
" );\n",
"\n",
" var status_bar = $('<span class=\"mpl-message\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"}\n",
"\n",
"mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
" // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
" // which will in turn request a refresh of the image.\n",
" this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
"}\n",
"\n",
"mpl.figure.prototype.send_message = function(type, properties) {\n",
" properties['type'] = type;\n",
" properties['figure_id'] = this.id;\n",
" this.ws.send(JSON.stringify(properties));\n",
"}\n",
"\n",
"mpl.figure.prototype.send_draw_message = function() {\n",
" if (!this.waiting) {\n",
" this.waiting = true;\n",
" this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
" }\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" var format_dropdown = fig.format_dropdown;\n",
" var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
" fig.ondownload(fig, format);\n",
"}\n",
"\n",
"\n",
"mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
" var size = msg['size'];\n",
" if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
" fig._resize_canvas(size[0], size[1]);\n",
" fig.send_message(\"refresh\", {});\n",
" };\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
" var x0 = msg['x0'] / mpl.ratio;\n",
" var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
" var x1 = msg['x1'] / mpl.ratio;\n",
" var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
" x0 = Math.floor(x0) + 0.5;\n",
" y0 = Math.floor(y0) + 0.5;\n",
" x1 = Math.floor(x1) + 0.5;\n",
" y1 = Math.floor(y1) + 0.5;\n",
" var min_x = Math.min(x0, x1);\n",
" var min_y = Math.min(y0, y1);\n",
" var width = Math.abs(x1 - x0);\n",
" var height = Math.abs(y1 - y0);\n",
"\n",
" fig.rubberband_context.clearRect(\n",
" 0, 0, fig.canvas.width, fig.canvas.height);\n",
"\n",
" fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
" // Updates the figure title.\n",
" fig.header.textContent = msg['label'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
" var cursor = msg['cursor'];\n",
" switch(cursor)\n",
" {\n",
" case 0:\n",
" cursor = 'pointer';\n",
" break;\n",
" case 1:\n",
" cursor = 'default';\n",
" break;\n",
" case 2:\n",
" cursor = 'crosshair';\n",
" break;\n",
" case 3:\n",
" cursor = 'move';\n",
" break;\n",
" }\n",
" fig.rubberband_canvas.style.cursor = cursor;\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_message = function(fig, msg) {\n",
" fig.message.textContent = msg['message'];\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
" // Request the server to send over a new figure.\n",
" fig.send_draw_message();\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
" fig.image_mode = msg['mode'];\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Called whenever the canvas gets updated.\n",
" this.send_message(\"ack\", {});\n",
"}\n",
"\n",
"// A function to construct a web socket function for onmessage handling.\n",
"// Called in the figure constructor.\n",
"mpl.figure.prototype._make_on_message_function = function(fig) {\n",
" return function socket_on_message(evt) {\n",
" if (evt.data instanceof Blob) {\n",
" /* FIXME: We get \"Resource interpreted as Image but\n",
" * transferred with MIME type text/plain:\" errors on\n",
" * Chrome. But how to set the MIME type? It doesn't seem\n",
" * to be part of the websocket stream */\n",
" evt.data.type = \"image/png\";\n",
"\n",
" /* Free the memory for the previous frames */\n",
" if (fig.imageObj.src) {\n",
" (window.URL || window.webkitURL).revokeObjectURL(\n",
" fig.imageObj.src);\n",
" }\n",
"\n",
" fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
" evt.data);\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
" else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
" fig.imageObj.src = evt.data;\n",
" fig.updated_canvas_event();\n",
" fig.waiting = false;\n",
" return;\n",
" }\n",
"\n",
" var msg = JSON.parse(evt.data);\n",
" var msg_type = msg['type'];\n",
"\n",
" // Call the \"handle_{type}\" callback, which takes\n",
" // the figure and JSON message as its only arguments.\n",
" try {\n",
" var callback = fig[\"handle_\" + msg_type];\n",
" } catch (e) {\n",
" console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
" return;\n",
" }\n",
"\n",
" if (callback) {\n",
" try {\n",
" // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
" callback(fig, msg);\n",
" } catch (e) {\n",
" console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
" }\n",
" }\n",
" };\n",
"}\n",
"\n",
"// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
"mpl.findpos = function(e) {\n",
" //this section is from http://www.quirksmode.org/js/events_properties.html\n",
" var targ;\n",
" if (!e)\n",
" e = window.event;\n",
" if (e.target)\n",
" targ = e.target;\n",
" else if (e.srcElement)\n",
" targ = e.srcElement;\n",
" if (targ.nodeType == 3) // defeat Safari bug\n",
" targ = targ.parentNode;\n",
"\n",
" // jQuery normalizes the pageX and pageY\n",
" // pageX,Y are the mouse positions relative to the document\n",
" // offset() returns the position of the element relative to the document\n",
" var x = e.pageX - $(targ).offset().left;\n",
" var y = e.pageY - $(targ).offset().top;\n",
"\n",
" return {\"x\": x, \"y\": y};\n",
"};\n",
"\n",
"/*\n",
" * return a copy of an object with only non-object keys\n",
" * we need this to avoid circular references\n",
" * http://stackoverflow.com/a/24161582/3208463\n",
" */\n",
"function simpleKeys (original) {\n",
" return Object.keys(original).reduce(function (obj, key) {\n",
" if (typeof original[key] !== 'object')\n",
" obj[key] = original[key]\n",
" return obj;\n",
" }, {});\n",
"}\n",
"\n",
"mpl.figure.prototype.mouse_event = function(event, name) {\n",
" var canvas_pos = mpl.findpos(event)\n",
"\n",
" if (name === 'button_press')\n",
" {\n",
" this.canvas.focus();\n",
" this.canvas_div.focus();\n",
" }\n",
"\n",
" var x = canvas_pos.x * mpl.ratio;\n",
" var y = canvas_pos.y * mpl.ratio;\n",
"\n",
" this.send_message(name, {x: x, y: y, button: event.button,\n",
" step: event.step,\n",
" guiEvent: simpleKeys(event)});\n",
"\n",
" /* This prevents the web browser from automatically changing to\n",
" * the text insertion cursor when the button is pressed. We want\n",
" * to control all of the cursor setting manually through the\n",
" * 'cursor' event from matplotlib */\n",
" event.preventDefault();\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" // Handle any extra behaviour associated with a key event\n",
"}\n",
"\n",
"mpl.figure.prototype.key_event = function(event, name) {\n",
"\n",
" // Prevent repeat events\n",
" if (name == 'key_press')\n",
" {\n",
" if (event.which === this._key)\n",
" return;\n",
" else\n",
" this._key = event.which;\n",
" }\n",
" if (name == 'key_release')\n",
" this._key = null;\n",
"\n",
" var value = '';\n",
" if (event.ctrlKey && event.which != 17)\n",
" value += \"ctrl+\";\n",
" if (event.altKey && event.which != 18)\n",
" value += \"alt+\";\n",
" if (event.shiftKey && event.which != 16)\n",
" value += \"shift+\";\n",
"\n",
" value += 'k';\n",
" value += event.which.toString();\n",
"\n",
" this._key_event_extra(event, name);\n",
"\n",
" this.send_message(name, {key: value,\n",
" guiEvent: simpleKeys(event)});\n",
" return false;\n",
"}\n",
"\n",
"mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
" if (name == 'download') {\n",
" this.handle_save(this, null);\n",
" } else {\n",
" this.send_message(\"toolbar_button\", {name: name});\n",
" }\n",
"};\n",
"\n",
"mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
" this.message.textContent = tooltip;\n",
"};\n",
"mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
"\n",
"mpl.extensions = [\"eps\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\"];\n",
"\n",
"mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
" // Create a \"websocket\"-like object which calls the given IPython comm\n",
" // object with the appropriate methods. Currently this is a non binary\n",
" // socket, so there is still some room for performance tuning.\n",
" var ws = {};\n",
"\n",
" ws.close = function() {\n",
" comm.close()\n",
" };\n",
" ws.send = function(m) {\n",
" //console.log('sending', m);\n",
" comm.send(m);\n",
" };\n",
" // Register the callback with on_msg.\n",
" comm.on_msg(function(msg) {\n",
" //console.log('receiving', msg['content']['data'], msg);\n",
" // Pass the mpl event to the overriden (by mpl) onmessage function.\n",
" ws.onmessage(msg['content']['data'])\n",
" });\n",
" return ws;\n",
"}\n",
"\n",
"mpl.mpl_figure_comm = function(comm, msg) {\n",
" // This is the function which gets called when the mpl process\n",
" // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
"\n",
" var id = msg.content.data.id;\n",
" // Get hold of the div created by the display call when the Comm\n",
" // socket was opened in Python.\n",
" var element = $(\"#\" + id);\n",
" var ws_proxy = comm_websocket_adapter(comm)\n",
"\n",
" function ondownload(figure, format) {\n",
" window.open(figure.imageObj.src);\n",
" }\n",
"\n",
" var fig = new mpl.figure(id, ws_proxy,\n",
" ondownload,\n",
" element.get(0));\n",
"\n",
" // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
" // web socket which is closed, not our websocket->open comm proxy.\n",
" ws_proxy.onopen();\n",
"\n",
" fig.parent_element = element.get(0);\n",
" fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
" if (!fig.cell_info) {\n",
" console.error(\"Failed to find cell for figure\", id, fig);\n",
" return;\n",
" }\n",
"\n",
" var output_index = fig.cell_info[2]\n",
" var cell = fig.cell_info[0];\n",
"\n",
"};\n",
"\n",
"mpl.figure.prototype.handle_close = function(fig, msg) {\n",
" var width = fig.canvas.width/mpl.ratio\n",
" fig.root.unbind('remove')\n",
"\n",
" // Update the output cell to use the data from the current canvas.\n",
" fig.push_to_output();\n",
" var dataURL = fig.canvas.toDataURL();\n",
" // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
" // the notebook keyboard shortcuts fail.\n",
" IPython.keyboard_manager.enable()\n",
" $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
" fig.close_ws(fig, msg);\n",
"}\n",
"\n",
"mpl.figure.prototype.close_ws = function(fig, msg){\n",
" fig.send_message('closing', msg);\n",
" // fig.ws.close()\n",
"}\n",
"\n",
"mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
" // Turn the data on the canvas into data in the output cell.\n",
" var width = this.canvas.width/mpl.ratio\n",
" var dataURL = this.canvas.toDataURL();\n",
" this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
"}\n",
"\n",
"mpl.figure.prototype.updated_canvas_event = function() {\n",
" // Tell IPython that the notebook contents must change.\n",
" IPython.notebook.set_dirty(true);\n",
" this.send_message(\"ack\", {});\n",
" var fig = this;\n",
" // Wait a second, then push the new image to the DOM so\n",
" // that it is saved nicely (might be nice to debounce this).\n",
" setTimeout(function () { fig.push_to_output() }, 1000);\n",
"}\n",
"\n",
"mpl.figure.prototype._init_toolbar = function() {\n",
" var fig = this;\n",
"\n",
" var nav_element = $('<div/>')\n",
" nav_element.attr('style', 'width: 100%');\n",
" this.root.append(nav_element);\n",
"\n",
" // Define a callback function for later on.\n",
" function toolbar_event(event) {\n",
" return fig.toolbar_button_onclick(event['data']);\n",
" }\n",
" function toolbar_mouse_event(event) {\n",
" return fig.toolbar_button_onmouseover(event['data']);\n",
" }\n",
"\n",
" for(var toolbar_ind in mpl.toolbar_items){\n",
" var name = mpl.toolbar_items[toolbar_ind][0];\n",
" var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
" var image = mpl.toolbar_items[toolbar_ind][2];\n",
" var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
"\n",
" if (!name) { continue; };\n",
"\n",
" var button = $('<button class=\"btn btn-default\" href=\"#\" title=\"' + name + '\"><i class=\"fa ' + image + ' fa-lg\"></i></button>');\n",
" button.click(method_name, toolbar_event);\n",
" button.mouseover(tooltip, toolbar_mouse_event);\n",
" nav_element.append(button);\n",
" }\n",
"\n",
" // Add the status bar.\n",
" var status_bar = $('<span class=\"mpl-message\" style=\"text-align:right; float: right;\"/>');\n",
" nav_element.append(status_bar);\n",
" this.message = status_bar[0];\n",
"\n",
" // Add the close button to the window.\n",
" var buttongrp = $('<div class=\"btn-group inline pull-right\"></div>');\n",
" var button = $('<button class=\"btn btn-mini btn-primary\" href=\"#\" title=\"Stop Interaction\"><i class=\"fa fa-power-off icon-remove icon-large\"></i></button>');\n",
" button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
" button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
" buttongrp.append(button);\n",
" var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
" titlebar.prepend(buttongrp);\n",
"}\n",
"\n",
"mpl.figure.prototype._root_extra_style = function(el){\n",
" var fig = this\n",
" el.on(\"remove\", function(){\n",
"\tfig.close_ws(fig, {});\n",
" });\n",
"}\n",
"\n",
"mpl.figure.prototype._canvas_extra_style = function(el){\n",
" // this is important to make the div 'focusable\n",
" el.attr('tabindex', 0)\n",
" // reach out to IPython and tell the keyboard manager to turn it's self\n",
" // off when our div gets focus\n",
"\n",
" // location in version 3\n",
" if (IPython.notebook.keyboard_manager) {\n",
" IPython.notebook.keyboard_manager.register_events(el);\n",
" }\n",
" else {\n",
" // location in version 2\n",
" IPython.keyboard_manager.register_events(el);\n",
" }\n",
"\n",
"}\n",
"\n",
"mpl.figure.prototype._key_event_extra = function(event, name) {\n",
" var manager = IPython.notebook.keyboard_manager;\n",
" if (!manager)\n",
" manager = IPython.keyboard_manager;\n",
"\n",
" // Check for shift+enter\n",
" if (event.shiftKey && event.which == 13) {\n",
" this.canvas_div.blur();\n",
" // select the cell after this one\n",
" var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n",
" IPython.notebook.select(index + 1);\n",
" }\n",
"}\n",
"\n",
"mpl.figure.prototype.handle_save = function(fig, msg) {\n",
" fig.ondownload(fig, null);\n",
"}\n",
"\n",
"\n",
"mpl.find_output_cell = function(html_output) {\n",
" // Return the cell and output element which can be found *uniquely* in the notebook.\n",
" // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
" // IPython event is triggered only after the cells have been serialised, which for\n",
" // our purposes (turning an active figure into a static one), is too late.\n",
" var cells = IPython.notebook.get_cells();\n",
" var ncells = cells.length;\n",
" for (var i=0; i<ncells; i++) {\n",
" var cell = cells[i];\n",
" if (cell.cell_type === 'code'){\n",
" for (var j=0; j<cell.output_area.outputs.length; j++) {\n",
" var data = cell.output_area.outputs[j];\n",
" if (data.data) {\n",
" // IPython >= 3 moved mimebundle to data attribute of output\n",
" data = data.data;\n",
" }\n",
" if (data['text/html'] == html_output) {\n",
" return [cell, data, j];\n",
" }\n",
" }\n",
" }\n",
" }\n",
"}\n",
"\n",
"// Register the function which deals with the matplotlib target/channel.\n",
"// The kernel may be null if the page has been refreshed.\n",
"if (IPython.notebook.kernel != null) {\n",
" IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
"}\n"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAyAAAAMgCAYAAADbcAZoAAAgAElEQVR4XuydB9QeRb3/h6v03gmdUBN6SUjASEkhisd6FWxgvWJBEWwIF0QBERWxIaD3ithFRY8FElOAEBIIUoTQIUCoQlBBqXr5n8/cO/lvNs/z7Mw+M7v7PPudc97zBt7dnZnvzM7+vr+6glETAkJACAgBISAEhIAQEAJCQAhUhMAKFfWjboSAEBACQkAICAEhIASEgBAQAkYERJtACAgBISAEhIAQEAJCQAgIgcoQEAGpDGp1JASEgBAQAkJACAgBISAEhIAIiPaAEBACQkAICAEhIASEgBAQApUhIAJSGdTqSAgIASEgBISAEBACQkAICAEREO0BISAEhIAQEAJCQAgIASEgBCpDQASkMqjVkRAQAkJACAgBISAEhIAQEAIiINoDQkAICAEhIASEgBAQAkJACFSGgAhIZVCrIyEgBISAEBACQkAICAEhIAREQLQHhIAQEAJCQAgIASEgBISAEKgMARGQyqBWR0JACAgBISAEhIAQEAJCQAiIgGgPCAEhIASEgBAQAkJACAgBIVAZAiIglUGtjoSAEBACQkAICAEhIASEgBAQAdEeEAJCQAgIASEgBISAEBACQqAyBERAKoNaHQkBISAEhIAQEAJCQAgIASEgAqI9IASEgBAQAkJACAgBISAEhEBlCIiAVAa1OhICQkAICAEhIASEgBAQAkJABER7QAgIASEgBISAEBACQkAICIHKEIhKQF588cUXKxu5OhICQkAICAEhIASEgBAQAkKgMgRWWGGFKNwhykPcrEVAKlt/dSQEhIAQEAJCQAgIASEgBCpFQASkUrjVmRAQAkJACAgBISAEhIAQaDcCIiDtXn/NXggIASEgBISAEBACQkAIVIqACEilcKszISAEhIAQEAJCQAgIASHQbgREQNq9/pq9EBACQkAICAEhIASEgBCoFAERkErhVmdCQAgIASEgBISAEBACQqDdCIiAtHv9NXshIASEgBAQAkJACAgBIVApAiIglcKtzoSAEBACQkAICAEhIASEQLsREAFp9/pr9kJACAgBISAEhIAQEAJCoFIEREAqhVudCQEhIASEgBAQAkJACAiBdiMgAtLu9dfshYAQEAJCQAgIASEgBIRApQiIgFQKtzoTAkJACAgBISAEhIAQEALtRkAEpN3rr9kLASEgBISAEBACQkAI/B8CK6ywQk8sTj75ZPOZz3ymVrw+9alPmcsuu8zMnz+/1nH007kISD/o6V4hIASEgBAQAkJACAiBoUHgkUceWTqXn/70p+akk04yt99++9L/t8Yaaxh+Qtvzzz9vVlpppdDbOl4vAvL/YelNFwPhfvHFF18MvEWXCwEhIASEgBAQAkJACAiBaAhccMEF5phjjjF//etfl3nmc889Z973vveZ2bNnmz//+c9mq622Mh/+8IfNBz7wgaXXHX744fbfO+64oznvvPPMuuuua2699VbzwAMPmHe/+93m8ssvN5tttpn5whe+YD70oQ9Zq8pRRx1l71myZIk57rjjzG9/+1vzwgsvmLFjx5qzzz7b7Lzzzubcc88173//+5cZz49//GNz2GGHmRNPPNFceOGFdkwbbrihYQxf+tKXouER80GygMREU88SAkJACAgBISAEhIAQGAoEuhGQf/zjH+bLX/6yOfTQQ816661nrrjiCksesJi8+tWvtnNH+P/d735nicGxxx5r/9/o0aPNhAkTDATmm9/8psHd66Mf/ahZsGCBJRiOgLz85S+3BOKEE06w1pavf/3r5qKLLjJ33HGHWXHFFc0nPvEJc9VVV9nn09ZZZx3z85//3JKgn/zkJ2annXYyDz/8sFm4cKF517ve1ci1EAFp5LJoUEJACAgBISAEhIAQEAJLETj3XGPOOMOYT33KmP+zFKRGpxsB6dTve97zHvPss8+aH/zgB0
sJCCThnnvuMS996Uvt/7vhhhvMnnvuaW666Sazyy672P938803m1133dV861vfsgRkxowZlrxAICAbNByDttxyS3PaaaeZI444wnRywTr99NMNlhD6eMlLXpIamr6fLwLSN4R6gBAQAkJACAgBISAEhEBSBLbe2pj77jNmq62MuffepF25h/ciIFgscHe6//77LfEgxmPcuHHWGkKDRGAp+c1vfrN0rFhIsEjw/7NttdVWM2eddZYlIFhWsHCsuuqqy1zzzDPPWBerU045pSMBgehgXYHsTJ061Vpn+GkqGREBqWQLqxMhIASEgBAQAkJACAiB0gg0yAICMfngBz9ovvKVr5gxY8aYNddc05x66qnmtttuW5qZysWA4BLlmg8BgWBAbKZNm7YcVMSRrL/++h0JCBc//fTTZvr06daKgsvWqFGjzMyZMxtJQkRASr8JulEICAEhIASEgBAQAkJgWBHoZgF573vfax566KGlMRjM/2Uve5n55z//2ZOAOBcs3K4IKKcRp4E7lnPBwmLy7//+79Z1iyD1To3MXJdccomNHenWbrzxRrPHHnvY5xN70rQmAtK0FdF4hIAQEAJCQAgIASEgBGpHoBsBOfPMM80ZZ5xhrQxbbLGF+a//+i+bnQqLg6vN0ckCwoRwkyKz1Te+8Q07P7JdXX311earX/2qzaz1r3/9y+y33372N31su+225sEHH7SuXG9961vNbrvtZv77v//bBrbPmTPHjBgxwqy11lrWaoL7FRYZ3LfIvHXOOedYooSFpmlNBKRpK6LxCAEhIASEgBAQAkJACNSOQDcCQjwGVhBIAUI/xICA8blz5xYSkMWLF9s0vMSKYOGAzBAX8rWvfc0ceeSRds5/+9vfzPHHH28uvvhi88QTT5hNNtnEHHjggZaQQDiIIaFPihFyLcHnxHp88YtftG5gBK1DVD7/+c8bMmo1sYmANHFVNCYhIASEgBAQAkJACAiBoUfg7rvvNtttt5258sorzf777z/083UTFAFpzVJrokJACAgBISAEhIAQEAJ1IkCQOBmziAHBtepjH/uY+ctf/mJuueWWRgaLp8JKBCQVsnquEBACQkAICAEhIASEgBDIIIDb1ic/+Ulz77332tgNYkLIprX55pu3CicRkFYttyYrBISAEBACQkAICAEhIATqRUAEpF781bsQEAJCQAgIASEgBISAEGgVAiIgrVpuTVYICAEhIASEgBAQAkJACNSLgAhIvfirdyEgBISAEBACQkAICAEh0CoEREBatdyarBAQAkJACAgBISAEhIAQqBcBEZB68VfvQkAICAEhIASEgBAQAkKgVQiIgLRquTVZISAEhIAQEAJCQAgIASFQLwIiIPXir96FgBAQAkJACAgBISAEWooA9UC22WYbc/3115s99tjDXHbZZeaggw6yxQnXWWedUqjEeEapjgNuEgEJAEuXCgEhIASEgBAQAkJACAw/Au94xzvM9773PTvRFVdc0Wy55ZbmiCOOMJ/+9KfNS1/60mgA5AkIVdKfeOIJs/HGG5sVVlihsJ8DDzzQEpezzz576bWhzyjsJMEFIiAJQNUjhYAQEAJCQAgIASEgBAYXAQjIo48+ar773e+a5557zvz+9783H/zgB81pp51mjj/++GUm9q9//cuShX/7t38LnnCegIQ+oBMBCX1GHdeLgNSBuvoUAkJACAgBISAEhIAQaCwCEJC//vWv5le/+tXSMU6ZMsU89dRT5n3ve5855phjzIUXXmg+9alPmTvuuMPcddddZuuttzbf+c53zJe//GWzaNEi+98f/vCHzQc+8IGlz7jmmmvs/bfeeqvZZZddzAknnGBe//rX93TBmjt3rr2Oe1deeWUzduxY85Of/MR89KMfXWqlcR3QL6Qm78b1i1/8wpx00kl2nCNGjDBHH320Oe6445aOi7H+x3/8h/37RRddZNZdd11z4okn2v9Hw6py7LHHGp6DexgWmqOOOmo5Mua7oC
IgvkjpOiEgBISAEBACQkAICIFWINCJgLzmNa8xDzzwgBXeEczHjBljvvjFL5r111/fbLHFFubiiy82H//4x803vv
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data_provider = CCPPDataProvider(\n",
"    which_set='train',\n",
"    input_dims=[0, 1],\n",
"    batch_size=5000,\n",
"    max_num_batches=1,\n",
"    shuffle_order=False\n",
")\n",
"\n",
"inputs, targets = data_provider.next()\n",
"\n",
"# Calculate predicted model outputs\n",
"outputs = model.fprop(inputs)[-1]\n",
"\n",
"# Plot target and predicted outputs against inputs on same axis\n",
"fig = plt.figure(figsize=(8, 8))\n",
"ax = fig.add_subplot(111, projection='3d')\n",
"ax.plot(inputs[:, 0], inputs[:, 1], targets[:, 0], 'r.', ms=2)\n",
"ax.plot(inputs[:, 0], inputs[:, 1], outputs[:, 0], 'b.', ms=2)\n",
"ax.set_xlabel('Input dim 1')\n",
"ax.set_ylabel('Input dim 2')\n",
"ax.set_zlabel('Output')\n",
"ax.legend(['Targets', 'Predictions'], frameon=False)\n",
"fig.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 6: visualising training trajectories in parameter space"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Running the cell below will display an interactive widget which plots the trajectories of gradient-based training of the single-layer affine model on the CCPP dataset in the three-dimensional parameter space (two weights plus bias) from random initialisations. Also shown on the right is a plot of the evolution of the error function (evaluated on the current batch) over training. By moving the sliders you can alter the training hyperparameters to investigate the effect they have on how training proceeds.\n",
"\n",
"Some questions to explore:\n",
"\n",
" * Are there multiple local minima in parameter space here? Why?\n",
" * What happens to learning for very small learning rates? And very large learning rates?\n",
" * How does the batch size affect learning?\n",
" \n",
"**Note:** You don't need to understand how the code below works. The idea of this exercise is to help you understand the role of the various hyperparameters involved in gradient-descent based training methods."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f2e8a086710>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from ipywidgets import interact\n",
"%matplotlib inline\n",
"\n",
"def setup_figure():\n",
"    # create figure and axes\n",
"    fig = plt.figure(figsize=(12, 6))\n",
"    ax1 = fig.add_axes([0., 0., 0.5, 1.], projection='3d')\n",
"    ax2 = fig.add_axes([0.6, 0.1, 0.4, 0.8])\n",
"    # set axes properties\n",
"    ax2.spines['right'].set_visible(False)\n",
"    ax2.spines['top'].set_visible(False)\n",
"    ax2.yaxis.set_ticks_position('left')\n",
"    ax2.xaxis.set_ticks_position('bottom')\n",
"    ax2.set_yscale('log')\n",
"    ax1.set_xlim((-2, 2))\n",
"    ax1.set_ylim((-2, 2))\n",
"    ax1.set_zlim((-2, 2))\n",
"    # set axes labels and title\n",
"    ax1.set_title('Parameter trajectories over training')\n",
"    ax1.set_xlabel('Weight 1')\n",
"    ax1.set_ylabel('Weight 2')\n",
"    ax1.set_zlabel('Bias')\n",
"    ax2.set_title('Batch errors over training')\n",
"    ax2.set_xlabel('Batch update number')\n",
"    ax2.set_ylabel('Batch error')\n",
"    return fig, ax1, ax2\n",
"\n",
"def visualise_training(n_epochs=1, batch_size=200, log_lr=-1., n_inits=5,\n",
"                       w_scale=1., b_scale=1., elev=30., azim=0.):\n",
"    fig, ax1, ax2 = setup_figure()\n",
"    # create seeded random number generator\n",
"    rng = np.random.RandomState(1234)\n",
"    # create data provider\n",
"    data_provider = CCPPDataProvider(\n",
"        input_dims=[0, 1],\n",
"        batch_size=batch_size,\n",
"        shuffle_order=False,\n",
"    )\n",
"    learning_rate = 10 ** log_lr\n",
"    n_batches = data_provider.num_batches\n",
"    weights_traj = np.empty((n_inits, n_epochs * n_batches + 1, 1, 2))\n",
"    biases_traj = np.empty((n_inits, n_epochs * n_batches + 1, 1))\n",
"    errors_traj = np.empty((n_inits, n_epochs * n_batches))\n",
"    # randomly initialise parameters\n",
"    weights = rng.uniform(-w_scale, w_scale, (n_inits, 1, 2))\n",
"    biases = rng.uniform(-b_scale, b_scale, (n_inits, 1))\n",
"    # store initial parameters\n",
"    weights_traj[:, 0] = weights\n",
"    biases_traj[:, 0] = biases\n",
"    # iterate across different initialisations\n",
"    for i in range(n_inits):\n",
"        # iterate across epochs\n",
"        for e in range(n_epochs):\n",
"            # iterate across batches\n",
"            for b, (inputs, targets) in enumerate(data_provider):\n",
"                outputs = fprop(inputs, weights[i], biases[i])\n",
"                errors_traj[i, e * n_batches + b] = error(outputs, targets)\n",
"                grad_wrt_outputs = error_grad(outputs, targets)\n",
"                weights_grad, biases_grad = grads_wrt_params(inputs, grad_wrt_outputs)\n",
"                weights[i] -= learning_rate * weights_grad\n",
"                biases[i] -= learning_rate * biases_grad\n",
"                weights_traj[i, e * n_batches + b + 1] = weights[i]\n",
"                biases_traj[i, e * n_batches + b + 1] = biases[i]\n",
"    # choose a different color for each trajectory\n",
"    colors = plt.cm.jet(np.linspace(0, 1, n_inits))\n",
"    # plot all trajectories\n",
"    for i in range(n_inits):\n",
"        lines_1 = ax1.plot(\n",
"            weights_traj[i, :, 0, 0],\n",
"            weights_traj[i, :, 0, 1],\n",
"            biases_traj[i, :, 0],\n",
"            '-', c=colors[i], lw=2)\n",
"        lines_2 = ax2.plot(\n",
"            np.arange(n_batches * n_epochs),\n",
"            errors_traj[i],\n",
"            c=colors[i]\n",
"        )\n",
"    ax1.view_init(elev, azim)\n",
"    plt.show()\n",
"\n",
"w = interact(\n",
"    visualise_training,\n",
"    elev=(-90, 90, 2),\n",
"    azim=(-180, 180, 2),\n",
"    n_epochs=(1, 5),\n",
"    batch_size=(100, 1000, 100),\n",
"    log_lr=(-3., 1.),\n",
"    w_scale=(0., 2.),\n",
"    b_scale=(0., 2.),\n",
"    n_inits=(1, 10)\n",
")\n",
"\n",
"for child in w.widget.children:\n",
"    child.layout.width = '100%'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}