From fc2eecf1357da39f0021671a7bce2431148fe297 Mon Sep 17 00:00:00 2001
From: pswietojanski
Date: Sun, 15 Nov 2015 16:41:21 +0000
Subject: [PATCH] adding overlooked dataset changes
---
mlp/dataset.py | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/mlp/dataset.py b/mlp/dataset.py
index d081cc4..fdd8124 100644
--- a/mlp/dataset.py
+++ b/mlp/dataset.py
@@ -83,7 +83,8 @@ class MNISTDataProvider(DataProvider):
max_num_batches=-1,
max_num_examples=-1,
randomize=True,
- rng=None):
+ rng=None,
+ conv_reshape=False):
super(MNISTDataProvider, self).\
__init__(batch_size, randomize, rng)
@@ -119,6 +120,7 @@ class MNISTDataProvider(DataProvider):
self.x = x
self.t = t
self.num_classes = 10
+ self.conv_reshape = conv_reshape
self._rand_idx = None
if self.randomize:
@@ -162,6 +164,9 @@ class MNISTDataProvider(DataProvider):
self._curr_idx += self.batch_size
+ if self.conv_reshape:
+ rval_x = rval_x.reshape(self.batch_size, 1, 28, 28)
+
return rval_x, self.__to_one_of_k(rval_t)
def num_examples(self):