Merge pull request #56 from CSTR-Edinburgh/mlp2017-8/lab3

Fix data provider bug
This commit is contained in:
AntreasAntoniou 2017-10-10 11:54:12 +01:00 committed by GitHub
commit ee36a2f3c1
2 changed files with 49 additions and 26 deletions

View File

@ -101,7 +101,7 @@ class DataProvider(object):
self.shuffle() self.shuffle()
def __next__(self): def __next__(self):
self.next() return self.next()
def reset(self): def reset(self):
"""Resets the provider to the initial state.""" """Resets the provider to the initial state."""

View File

@ -176,8 +176,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 1,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"import numpy as np\n", "import numpy as np\n",
@ -218,7 +220,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"def train_model_and_plot_stats(\n", "def train_model_and_plot_stats(\n",
@ -234,7 +238,7 @@
"\n", "\n",
" # Run the optimiser for 5 epochs (full passes through the training set)\n", " # Run the optimiser for 5 epochs (full passes through the training set)\n",
" # printing statistics every epoch.\n", " # printing statistics every epoch.\n",
" stats, keys = optimiser.train(num_epochs=num_epochs, stats_interval=stats_interval)\n", " stats, keys, _ = optimiser.train(num_epochs=num_epochs, stats_interval=stats_interval)\n",
"\n", "\n",
" # Plot the change in the validation and training set error over training.\n", " # Plot the change in the validation and training set error over training.\n",
" fig_1 = plt.figure(figsize=(8, 4))\n", " fig_1 = plt.figure(figsize=(8, 4))\n",
@ -270,7 +274,22 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5: 1.2s to complete\n",
" error(train)=3.11e-01, acc(train)=9.13e-01, error(valid)=2.92e-01, acc(valid)=9.18e-01\n",
"Epoch 10: 1.2s to complete\n",
" error(train)=2.89e-01, acc(train)=9.20e-01, error(valid)=2.77e-01, acc(valid)=9.23e-01\n",
"Epoch 15: 1.1s to complete\n",
" error(train)=2.79e-01, acc(train)=9.22e-01, error(valid)=2.70e-01, acc(valid)=9.24e-01\n",
"Epoch 20: 0.7s to complete\n",
" error(train)=2.72e-01, acc(train)=9.24e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n"
]
}
],
"source": [ "source": [
"# Set training run hyperparameters\n", "# Set training run hyperparameters\n",
"batch_size = 100 # number of data points in a batch\n", "batch_size = 100 # number of data points in a batch\n",
@ -356,7 +375,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Set training run hyperparameters\n", "# Set training run hyperparameters\n",
@ -437,7 +458,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [] "source": []
}, },
@ -451,7 +474,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [] "source": []
}, },
@ -465,7 +490,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [] "source": []
}, },
@ -479,7 +506,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [] "source": []
}, },
@ -536,7 +565,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"import numpy as np\n", "import numpy as np\n",
@ -589,7 +620,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"test_inputs = np.array([[0.1, -0.2, 0.3], [-0.4, 0.5, -0.6]])\n", "test_inputs = np.array([[0.1, -0.2, 0.3], [-0.4, 0.5, -0.6]])\n",
@ -624,7 +657,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"test_inputs = np.array([[0.1, -0.2, 0.3], [-0.4, 0.5, -0.6]])\n", "test_inputs = np.array([[0.1, -0.2, 0.3], [-0.4, 0.5, -0.6]])\n",
@ -659,18 +694,6 @@
"display_name": "Python 3", "display_name": "Python 3",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
} }
}, },
"nbformat": 4, "nbformat": 4,