From 057a25ed07e99577ef444bcd59adb5b570f21805 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 3 Nov 2016 18:28:41 +0000 Subject: [PATCH] Adding autoencoder wrapper data provider. --- mlp/data_providers.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mlp/data_providers.py b/mlp/data_providers.py index 0d2eef9..159d0c1 100644 --- a/mlp/data_providers.py +++ b/mlp/data_providers.py @@ -332,3 +332,14 @@ class AugmentedMNISTDataProvider(MNISTDataProvider): AugmentedMNISTDataProvider, self).next() transformed_inputs_batch = self.transformer(inputs_batch, self.rng) return transformed_inputs_batch, targets_batch + + +class MNISTAutoencoderDataProvider(MNISTDataProvider): + """Simple wrapper data provider for training an autoencoder on MNIST.""" + + def next(self): + """Returns next data batch or raises `StopIteration` if at end.""" + inputs_batch, targets_batch = super( + MNISTAutoencoderDataProvider, self).next() + # return inputs as targets for autoencoder training + return inputs_batch, inputs_batch