diff --git a/mlp/dataset.py b/mlp/dataset.py index d081cc4..fdd8124 100644 --- a/mlp/dataset.py +++ b/mlp/dataset.py @@ -83,7 +83,8 @@ class MNISTDataProvider(DataProvider): max_num_batches=-1, max_num_examples=-1, randomize=True, - rng=None): + rng=None, + conv_reshape=False): super(MNISTDataProvider, self).\ __init__(batch_size, randomize, rng) @@ -119,6 +120,7 @@ class MNISTDataProvider(DataProvider): self.x = x self.t = t self.num_classes = 10 + self.conv_reshape = conv_reshape self._rand_idx = None if self.randomize: @@ -162,6 +164,9 @@ class MNISTDataProvider(DataProvider): self._curr_idx += self.batch_size + if self.conv_reshape: + rval_x = rval_x.reshape(self.batch_size, 1, 28, 28) + return rval_x, self.__to_one_of_k(rval_t) def num_examples(self):