Fix unpacking variables bug
This commit is contained in:
parent
c2995c34e7
commit
95a2c07b8d
@ -176,8 +176,10 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
@ -218,7 +220,9 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train_model_and_plot_stats(\n",
|
||||
@ -234,7 +238,7 @@
|
||||
"\n",
|
||||
" # Run the optimiser for 5 epochs (full passes through the training set)\n",
|
||||
" # printing statistics every epoch.\n",
|
||||
" stats, keys = optimiser.train(num_epochs=num_epochs, stats_interval=stats_interval)\n",
|
||||
" stats, keys, _ = optimiser.train(num_epochs=num_epochs, stats_interval=stats_interval)\n",
|
||||
"\n",
|
||||
" # Plot the change in the validation and training set error over training.\n",
|
||||
" fig_1 = plt.figure(figsize=(8, 4))\n",
|
||||
@ -270,7 +274,22 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 5: 1.2s to complete\n",
|
||||
" error(train)=3.11e-01, acc(train)=9.13e-01, error(valid)=2.92e-01, acc(valid)=9.18e-01\n",
|
||||
"Epoch 10: 1.2s to complete\n",
|
||||
" error(train)=2.89e-01, acc(train)=9.20e-01, error(valid)=2.77e-01, acc(valid)=9.23e-01\n",
|
||||
"Epoch 15: 1.1s to complete\n",
|
||||
" error(train)=2.79e-01, acc(train)=9.22e-01, error(valid)=2.70e-01, acc(valid)=9.24e-01\n",
|
||||
"Epoch 20: 0.7s to complete\n",
|
||||
" error(train)=2.72e-01, acc(train)=9.24e-01, error(valid)=2.66e-01, acc(valid)=9.26e-01\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Set training run hyperparameters\n",
|
||||
"batch_size = 100 # number of data points in a batch\n",
|
||||
@ -356,7 +375,9 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Set training run hyperparameters\n",
|
||||
@ -437,7 +458,9 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
@ -451,7 +474,9 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
@ -465,7 +490,9 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
@ -479,7 +506,9 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
@ -536,7 +565,9 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
@ -589,7 +620,9 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_inputs = np.array([[0.1, -0.2, 0.3], [-0.4, 0.5, -0.6]])\n",
|
||||
@ -624,7 +657,9 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_inputs = np.array([[0.1, -0.2, 0.3], [-0.4, 0.5, -0.6]])\n",
|
||||
@ -659,18 +694,6 @@
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
Loading…
Reference in New Issue
Block a user