diff --git a/mlp/data_providers.py b/mlp/data_providers.py index 0d2eef9..159d0c1 100644 --- a/mlp/data_providers.py +++ b/mlp/data_providers.py @@ -332,3 +332,14 @@ class AugmentedMNISTDataProvider(MNISTDataProvider): AugmentedMNISTDataProvider, self).next() transformed_inputs_batch = self.transformer(inputs_batch, self.rng) return transformed_inputs_batch, targets_batch + + +class MNISTAutoencoderDataProvider(MNISTDataProvider): + """Simple wrapper data provider for training an autoencoder on MNIST.""" + + def next(self): + """Returns next data batch or raises `StopIteration` if at end.""" + inputs_batch, targets_batch = super( + MNISTAutoencoderDataProvider, self).next() + # return inputs as targets for autoencoder training + return inputs_batch, inputs_batch