{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "$\\newcommand{\\vct}[1]{\\boldsymbol{#1}}\n", "\\newcommand{\\mtx}[1]{\\mathbf{#1}}\n", "\\newcommand{\\tr}{^\\mathrm{T}}\n", "\\newcommand{\\reals}{\\mathbb{R}}\n", "\\newcommand{\\lpa}{\\left(}\n", "\\newcommand{\\rpa}{\\right)}\n", "\\newcommand{\\lsb}{\\left[}\n", "\\newcommand{\\rsb}{\\right]}\n", "\\newcommand{\\lbr}{\\left\\lbrace}\n", "\\newcommand{\\rbr}{\\right\\rbrace}\n", "\\newcommand{\\fset}[1]{\\lbr #1 \\rbr}\n", "\\newcommand{\\pd}[2]{\\frac{\\partial #1}{\\partial #2}}$\n", "\n", "# Single layer models\n", "\n", "In this lab we will implement a single-layer network model consisting of solely of an affine transformation of the inputs. The relevant material for this was covered in [the slides of the first lecture](http://www.inf.ed.ac.uk/teaching/courses/mlp/2016/mlp01-intro.pdf). \n", "\n", "We will first implement the forward propagation of inputs to the network to produce predicted outputs. We will then move on to considering how to use gradients of a cost function evaluated on the outputs to compute the gradients with respect to the model parameters to allow us to perform an iterative gradient-descent training procedure.\n", "\n", "## A general note on random number generators\n", "\n", "It is generally a good practice (for machine learning applications **not** for cryptography!) to seed a pseudo-random number generator once at the beginning of each experiment. This makes it easier to reproduce results as the same random draws will produced each time the experiment is run (e.g. the same random initialisations used for parameters). \n", "\n", "Therefore generally when we need to generate random values during this course, we will create a seeded random number generator object as follows:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Seed a random number generator to be used later\n", "seed = 27092016 \n", "rng = np.random.RandomState(seed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exercise 1: linear and affine transforms\n", "\n", "Any *linear transform* (also called a linear map) of a finite-dimensional vector space can be parametrised by a matrix. So for example if we consider $\\vct{x} \\in \\reals^{M}$ as the input space of a model with $M$ dimensional real-valued inputs, then a matrix $\\mtx{W} \\in \\reals^{N\\times M}$ can be used to define a prediction model consisting solely of a linear transform of the inputs\n", "\n", "\\begin{equation}\n", " \\vct{y} = \\mtx{W} \\vct{x}\n", " \\qquad\n", " \\Leftrightarrow\n", " \\qquad\n", " y_i = \\sum_{j=1}^M \\lpa W_{ij} x_j \\rpa \\qquad \\forall i \\in \\fset{1 \\dots N}\n", "\\end{equation}\n", "\n", "with here $\\vct{y} \\in \\reals^N$ the $N$-dimensional real-valued output of the model. Geometrically we can think of a linear transform doing some combination of rotation, scaling, reflection and shearing of the input.\n", "\n", "An *affine transform* consists of a linear transform plus an additional translation parameterised by a vector $\\vct{b} \\in \\reals^N$. A model consisting of an affine transformation of the inputs can then be defined as\n", "\n", "\\begin{equation}\n", " \\vct{y} = \\mtx{W}\\vct{x} + \\vct{b}\n", " \\qquad\n", " \\Leftrightarrow\n", " \\qquad\n", " y_i = \\sum_{j=1}^M \\lpa W_{ij} x_j \\rpa + b_i \\qquad \\forall i \\in \\fset{1 \\dots N}\n", "\\end{equation}\n", "\n", "In machine learning we will usually refer to the matrix $\\mtx{W}$ as a *weight matrix* and the vector $\\vct{b}$ as a *bias vector*. \n", "\n", "Implement *forward propagation* for a single-layer model consisting of an affine transformation of the inputs. This should work for a batch of inputs of shape `(batch_size, input_dim)` producing a batch of outputs of shape `(batch_size, output_dim)`. \n", "\n", "You may find the NumPy [`dot`](http://docs.scipy.org/doc/numpy/reference/generated/numpy.dot.html) function and [broadcasting features](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) useful." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def fprop(inputs, weights, biases):\n", " \"\"\"Forward propagates activations through the layer transformation.\n", "\n", " For inputs `x`, outputs `y`, weights `W` and biases `b` the layer\n", " corresponds to `y = W x + b`.\n", "\n", " Args:\n", " inputs: Array of layer inputs of shape (batch_size, input_dim).\n", " weights: Array of weight parameters of shape \n", " (output_dim, input_dim).\n", " biases: Array of bias parameters of shape (output_dim, ).\n", "\n", " Returns:\n", " outputs: Array of layer outputs of shape (batch_size, output_dim).\n", " \"\"\"\n", " return weights.dot(inputs.T).T + biases" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check your implementation by running the test cell below" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "All outputs correct!\n" ] } ], "source": [ "inputs = np.array([[0., -1., 2.], [-6., 3., 1.]])\n", "weights = np.array([[2., -3., -1.], [-5., 7., 2.]])\n", "biases = np.array([5., -3.])\n", "true_outputs = np.array([[6., -6.], [-17., 50.]])\n", "\n", "if not np.allclose(fprop(inputs, weights, biases), true_outputs):\n", " print('Wrong outputs computed.')\n", "else:\n", " print('All outputs correct!')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exercise 2: visualising random models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this exercise you will use your `fprop` implementation to visualise the outputs of a single-layer affine transform model with two-dimensional inputs and a one-dimensional output. In this simple case we can visualise the joint input-output space on a 3D axis.\n", "\n", "For this task and the learning experiments later in the notebook we will use a regression dataset from the [UCI machine learning repository](http://archive.ics.uci.edu/ml/index.html). In particular we will use a version of the [Combined Cycle Power Plant dataset](http://archive.ics.uci.edu/ml/datasets/Combined+Cycle+Power+Plant), where the task is to predict the energy output of a power plant given observations of the local ambient conditions (e.g. temperature, pressure and humidity).\n", "\n", "The original dataset has four input dimensions and a single target output dimension. We have preprocessed the dataset by [whitening](https://en.wikipedia.org/wiki/Whitening_transformation) it, a common preprocessing step. We will only use the first two dimensions of the whitened inputs (corresponding to the first two principal components of the inputs) so we can easily visualise the joint input-output space.\n", "\n", "The dataset has been wrapped in the `CCPPDataProvider` class in the `mlp.data_providers` module. Running the cell below will initialise an instance of this class, get a single batch of inputs and outputs and import the necessary `matplotlib` objects." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from mpl_toolkits.mplot3d import Axes3D\n", "from mlp.data_providers import CCPPDataProvider\n", "%matplotlib notebook\n", "\n", "data_provider = CCPPDataProvider(\n", " which_set='train',\n", " input_dims=[0, 1],\n", " batch_size=5000, \n", " max_num_batches=1, \n", " shuffle_order=False\n", ")\n", "\n", "input_dim, output_dim = 2, 1\n", "\n", "inputs, targets = data_provider.next()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now run the cell below to plot the predicted outputs of a randomly initialised model across the two dimensional input space as well as the true target outputs. This sort of visualisation can be a useful method (in low dimensions) to assess how well the model is likely to be able to fit the data and to judge appropriate initialisation scales for the parameters. Each time you re-run the cell a new set of random parameters will be sampled\n", "\n", "Some questions to consider:\n", "\n", " * How do the weights and bias initialisation scale affect the sort of predicted input-output relationships?\n", " * Does the linear form of the model seem a good fit for the data here?" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " this.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '');\n", " var titletext = $(\n", " '');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width);\n", " canvas.attr('height', height);\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n", " 'ui-button-icon-only');\n", " button.attr('role', 'button');\n", " button.attr('aria-disabled', 'false');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", "\n", " var icon_img = $('');\n", " icon_img.addClass('ui-button-icon-primary ui-icon');\n", " icon_img.addClass(image);\n", " icon_img.addClass('ui-corner-all');\n", "\n", " var tooltip_span = $('');\n", " tooltip_span.addClass('ui-button-text');\n", " tooltip_span.html(tooltip);\n", "\n", " button.append(icon_img);\n", " button.append(tooltip_span);\n", "\n", " nav_element.append(button);\n", " }\n", "\n", " var fmt_picker_span = $('');\n", "\n", " var fmt_picker = $('');\n", " fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n", " fmt_picker_span.append(fmt_picker);\n", " nav_element.append(fmt_picker_span);\n", " this.format_dropdown = fmt_picker[0];\n", "\n", " for (var ind in mpl.extensions) {\n", " var fmt = mpl.extensions[ind];\n", " var option = $(\n", " '', {selected: fmt === mpl.default_extension}).html(fmt);\n", " fmt_picker.append(option)\n", " }\n", "\n", " // Add hover states to the ui-buttons\n", " $( \".ui-button\" ).hover(\n", " function() { $(this).addClass(\"ui-state-hover\");},\n", " function() { $(this).removeClass(\"ui-state-hover\");}\n", " );\n", "\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "}\n", "\n", "mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n", " // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n", " // which will in turn request a refresh of the image.\n", " this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n", "}\n", "\n", "mpl.figure.prototype.send_message = function(type, properties) {\n", " properties['type'] = type;\n", " properties['figure_id'] = this.id;\n", " this.ws.send(JSON.stringify(properties));\n", "}\n", "\n", "mpl.figure.prototype.send_draw_message = function() {\n", " if (!this.waiting) {\n", " this.waiting = true;\n", " this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n", " }\n", "}\n", "\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " var format_dropdown = fig.format_dropdown;\n", " var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n", " fig.ondownload(fig, format);\n", "}\n", "\n", "\n", "mpl.figure.prototype.handle_resize = function(fig, msg) {\n", " var size = msg['size'];\n", " if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n", " fig._resize_canvas(size[0], size[1]);\n", " fig.send_message(\"refresh\", {});\n", " };\n", "}\n", "\n", "mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n", " var x0 = msg['x0'];\n", " var y0 = fig.canvas.height - msg['y0'];\n", " var x1 = msg['x1'];\n", " var y1 = fig.canvas.height - msg['y1'];\n", " x0 = Math.floor(x0) + 0.5;\n", " y0 = Math.floor(y0) + 0.5;\n", " x1 = Math.floor(x1) + 0.5;\n", " y1 = Math.floor(y1) + 0.5;\n", " var min_x = Math.min(x0, x1);\n", " var min_y = Math.min(y0, y1);\n", " var width = Math.abs(x1 - x0);\n", " var height = Math.abs(y1 - y0);\n", "\n", " fig.rubberband_context.clearRect(\n", " 0, 0, fig.canvas.width, fig.canvas.height);\n", "\n", " fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n", "}\n", "\n", "mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n", " // Updates the figure title.\n", " fig.header.textContent = msg['label'];\n", "}\n", "\n", "mpl.figure.prototype.handle_cursor = function(fig, msg) {\n", " var cursor = msg['cursor'];\n", " switch(cursor)\n", " {\n", " case 0:\n", " cursor = 'pointer';\n", " break;\n", " case 1:\n", " cursor = 'default';\n", " break;\n", " case 2:\n", " cursor = 'crosshair';\n", " break;\n", " case 3:\n", " cursor = 'move';\n", " break;\n", " }\n", " fig.rubberband_canvas.style.cursor = cursor;\n", "}\n", "\n", "mpl.figure.prototype.handle_message = function(fig, msg) {\n", " fig.message.textContent = msg['message'];\n", "}\n", "\n", "mpl.figure.prototype.handle_draw = function(fig, msg) {\n", " // Request the server to send over a new figure.\n", " fig.send_draw_message();\n", "}\n", "\n", "mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n", " fig.image_mode = msg['mode'];\n", "}\n", "\n", "mpl.figure.prototype.updated_canvas_event = function() {\n", " // Called whenever the canvas gets updated.\n", " this.send_message(\"ack\", {});\n", "}\n", "\n", "// A function to construct a web socket function for onmessage handling.\n", "// Called in the figure constructor.\n", "mpl.figure.prototype._make_on_message_function = function(fig) {\n", " return function socket_on_message(evt) {\n", " if (evt.data instanceof Blob) {\n", " /* FIXME: We get \"Resource interpreted as Image but\n", " * transferred with MIME type text/plain:\" errors on\n", " * Chrome. But how to set the MIME type? It doesn't seem\n", " * to be part of the websocket stream */\n", " evt.data.type = \"image/png\";\n", "\n", " /* Free the memory for the previous frames */\n", " if (fig.imageObj.src) {\n", " (window.URL || window.webkitURL).revokeObjectURL(\n", " fig.imageObj.src);\n", " }\n", "\n", " fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n", " evt.data);\n", " fig.updated_canvas_event();\n", " fig.waiting = false;\n", " return;\n", " }\n", " else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n", " fig.imageObj.src = evt.data;\n", " fig.updated_canvas_event();\n", " fig.waiting = false;\n", " return;\n", " }\n", "\n", " var msg = JSON.parse(evt.data);\n", " var msg_type = msg['type'];\n", "\n", " // Call the \"handle_{type}\" callback, which takes\n", " // the figure and JSON message as its only arguments.\n", " try {\n", " var callback = fig[\"handle_\" + msg_type];\n", " } catch (e) {\n", " console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n", " return;\n", " }\n", "\n", " if (callback) {\n", " try {\n", " // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n", " callback(fig, msg);\n", " } catch (e) {\n", " console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n", " }\n", " }\n", " };\n", "}\n", "\n", "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n", "mpl.findpos = function(e) {\n", " //this section is from http://www.quirksmode.org/js/events_properties.html\n", " var targ;\n", " if (!e)\n", " e = window.event;\n", " if (e.target)\n", " targ = e.target;\n", " else if (e.srcElement)\n", " targ = e.srcElement;\n", " if (targ.nodeType == 3) // defeat Safari bug\n", " targ = targ.parentNode;\n", "\n", " // jQuery normalizes the pageX and pageY\n", " // pageX,Y are the mouse positions relative to the document\n", " // offset() returns the position of the element relative to the document\n", " var x = e.pageX - $(targ).offset().left;\n", " var y = e.pageY - $(targ).offset().top;\n", "\n", " return {\"x\": x, \"y\": y};\n", "};\n", "\n", "/*\n", " * return a copy of an object with only non-object keys\n", " * we need this to avoid circular references\n", " * http://stackoverflow.com/a/24161582/3208463\n", " */\n", "function simpleKeys (original) {\n", " return Object.keys(original).reduce(function (obj, key) {\n", " if (typeof original[key] !== 'object')\n", " obj[key] = original[key]\n", " return obj;\n", " }, {});\n", "}\n", "\n", "mpl.figure.prototype.mouse_event = function(event, name) {\n", " var canvas_pos = mpl.findpos(event)\n", "\n", " if (name === 'button_press')\n", " {\n", " this.canvas.focus();\n", " this.canvas_div.focus();\n", " }\n", "\n", " var x = canvas_pos.x;\n", " var y = canvas_pos.y;\n", "\n", " this.send_message(name, {x: x, y: y, button: event.button,\n", " step: event.step,\n", " guiEvent: simpleKeys(event)});\n", "\n", " /* This prevents the web browser from automatically changing to\n", " * the text insertion cursor when the button is pressed. We want\n", " * to control all of the cursor setting manually through the\n", " * 'cursor' event from matplotlib */\n", " event.preventDefault();\n", " return false;\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " // Handle any extra behaviour associated with a key event\n", "}\n", "\n", "mpl.figure.prototype.key_event = function(event, name) {\n", "\n", " // Prevent repeat events\n", " if (name == 'key_press')\n", " {\n", " if (event.which === this._key)\n", " return;\n", " else\n", " this._key = event.which;\n", " }\n", " if (name == 'key_release')\n", " this._key = null;\n", "\n", " var value = '';\n", " if (event.ctrlKey && event.which != 17)\n", " value += \"ctrl+\";\n", " if (event.altKey && event.which != 18)\n", " value += \"alt+\";\n", " if (event.shiftKey && event.which != 16)\n", " value += \"shift+\";\n", "\n", " value += 'k';\n", " value += event.which.toString();\n", "\n", " this._key_event_extra(event, name);\n", "\n", " this.send_message(name, {key: value,\n", " guiEvent: simpleKeys(event)});\n", " return false;\n", "}\n", "\n", "mpl.figure.prototype.toolbar_button_onclick = function(name) {\n", " if (name == 'download') {\n", " this.handle_save(this, null);\n", " } else {\n", " this.send_message(\"toolbar_button\", {name: name});\n", " }\n", "};\n", "\n", "mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n", " this.message.textContent = tooltip;\n", "};\n", "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n", "\n", "mpl.extensions = [\"eps\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\"];\n", "\n", "mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n", " // Create a \"websocket\"-like object which calls the given IPython comm\n", " // object with the appropriate methods. Currently this is a non binary\n", " // socket, so there is still some room for performance tuning.\n", " var ws = {};\n", "\n", " ws.close = function() {\n", " comm.close()\n", " };\n", " ws.send = function(m) {\n", " //console.log('sending', m);\n", " comm.send(m);\n", " };\n", " // Register the callback with on_msg.\n", " comm.on_msg(function(msg) {\n", " //console.log('receiving', msg['content']['data'], msg);\n", " // Pass the mpl event to the overriden (by mpl) onmessage function.\n", " ws.onmessage(msg['content']['data'])\n", " });\n", " return ws;\n", "}\n", "\n", "mpl.mpl_figure_comm = function(comm, msg) {\n", " // This is the function which gets called when the mpl process\n", " // starts-up an IPython Comm through the \"matplotlib\" channel.\n", "\n", " var id = msg.content.data.id;\n", " // Get hold of the div created by the display call when the Comm\n", " // socket was opened in Python.\n", " var element = $(\"#\" + id);\n", " var ws_proxy = comm_websocket_adapter(comm)\n", "\n", " function ondownload(figure, format) {\n", " window.open(figure.imageObj.src);\n", " }\n", "\n", " var fig = new mpl.figure(id, ws_proxy,\n", " ondownload,\n", " element.get(0));\n", "\n", " // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n", " // web socket which is closed, not our websocket->open comm proxy.\n", " ws_proxy.onopen();\n", "\n", " fig.parent_element = element.get(0);\n", " fig.cell_info = mpl.find_output_cell(\"\");\n", " if (!fig.cell_info) {\n", " console.error(\"Failed to find cell for figure\", id, fig);\n", " return;\n", " }\n", "\n", " var output_index = fig.cell_info[2]\n", " var cell = fig.cell_info[0];\n", "\n", "};\n", "\n", "mpl.figure.prototype.handle_close = function(fig, msg) {\n", " fig.root.unbind('remove')\n", "\n", " // Update the output cell to use the data from the current canvas.\n", " fig.push_to_output();\n", " var dataURL = fig.canvas.toDataURL();\n", " // Re-enable the keyboard manager in IPython - without this line, in FF,\n", " // the notebook keyboard shortcuts fail.\n", " IPython.keyboard_manager.enable()\n", " $(fig.parent_element).html('');\n", " fig.close_ws(fig, msg);\n", "}\n", "\n", "mpl.figure.prototype.close_ws = function(fig, msg){\n", " fig.send_message('closing', msg);\n", " // fig.ws.close()\n", "}\n", "\n", "mpl.figure.prototype.push_to_output = function(remove_interactive) {\n", " // Turn the data on the canvas into data in the output cell.\n", " var dataURL = this.canvas.toDataURL();\n", " this.cell_info[1]['text/html'] = '';\n", "}\n", "\n", "mpl.figure.prototype.updated_canvas_event = function() {\n", " // Tell IPython that the notebook contents must change.\n", " IPython.notebook.set_dirty(true);\n", " this.send_message(\"ack\", {});\n", " var fig = this;\n", " // Wait a second, then push the new image to the DOM so\n", " // that it is saved nicely (might be nice to debounce this).\n", " setTimeout(function () { fig.push_to_output() }, 1000);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items){\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) { continue; };\n", "\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " // select the cell after this one\n", " var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n", " IPython.notebook.select(index + 1);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i