adding overlooked dataset changes
This commit is contained in:
parent
ed47b36873
commit
fc2eecf135
@ -83,7 +83,8 @@ class MNISTDataProvider(DataProvider):
|
|||||||
max_num_batches=-1,
|
max_num_batches=-1,
|
||||||
max_num_examples=-1,
|
max_num_examples=-1,
|
||||||
randomize=True,
|
randomize=True,
|
||||||
rng=None):
|
rng=None,
|
||||||
|
conv_reshape=False):
|
||||||
|
|
||||||
super(MNISTDataProvider, self).\
|
super(MNISTDataProvider, self).\
|
||||||
__init__(batch_size, randomize, rng)
|
__init__(batch_size, randomize, rng)
|
||||||
@ -119,6 +120,7 @@ class MNISTDataProvider(DataProvider):
|
|||||||
self.x = x
|
self.x = x
|
||||||
self.t = t
|
self.t = t
|
||||||
self.num_classes = 10
|
self.num_classes = 10
|
||||||
|
self.conv_reshape = conv_reshape
|
||||||
|
|
||||||
self._rand_idx = None
|
self._rand_idx = None
|
||||||
if self.randomize:
|
if self.randomize:
|
||||||
@ -162,6 +164,9 @@ class MNISTDataProvider(DataProvider):
|
|||||||
|
|
||||||
self._curr_idx += self.batch_size
|
self._curr_idx += self.batch_size
|
||||||
|
|
||||||
|
if self.conv_reshape:
|
||||||
|
rval_x = rval_x.reshape(self.batch_size, 1, 28, 28)
|
||||||
|
|
||||||
return rval_x, self.__to_one_of_k(rval_t)
|
return rval_x, self.__to_one_of_k(rval_t)
|
||||||
|
|
||||||
def num_examples(self):
|
def num_examples(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user