adding overlooked dataset changes

This commit is contained in:
pswietojanski 2015-11-15 16:41:21 +00:00
parent ed47b36873
commit fc2eecf135

View File

@ -83,7 +83,8 @@ class MNISTDataProvider(DataProvider):
max_num_batches=-1, max_num_batches=-1,
max_num_examples=-1, max_num_examples=-1,
randomize=True, randomize=True,
rng=None): rng=None,
conv_reshape=False):
super(MNISTDataProvider, self).\ super(MNISTDataProvider, self).\
__init__(batch_size, randomize, rng) __init__(batch_size, randomize, rng)
@ -119,6 +120,7 @@ class MNISTDataProvider(DataProvider):
self.x = x self.x = x
self.t = t self.t = t
self.num_classes = 10 self.num_classes = 10
self.conv_reshape = conv_reshape
self._rand_idx = None self._rand_idx = None
if self.randomize: if self.randomize:
@ -162,6 +164,9 @@ class MNISTDataProvider(DataProvider):
self._curr_idx += self.batch_size self._curr_idx += self.batch_size
if self.conv_reshape:
rval_x = rval_x.reshape(self.batch_size, 1, 28, 28)
return rval_x, self.__to_one_of_k(rval_t) return rval_x, self.__to_one_of_k(rval_t)
def num_examples(self): def num_examples(self):