Adding autoencoder wrapper data provider.

This commit is contained in:
Matt Graham 2016-11-03 18:28:41 +00:00
parent 2a4364d226
commit 057a25ed07

View File

@ -332,3 +332,14 @@ class AugmentedMNISTDataProvider(MNISTDataProvider):
AugmentedMNISTDataProvider, self).next()
transformed_inputs_batch = self.transformer(inputs_batch, self.rng)
return transformed_inputs_batch, targets_batch
class MNISTAutoencoderDataProvider(MNISTDataProvider):
"""Simple wrapper data provider for training an autoencoder on MNIST."""
def next(self):
"""Returns next data batch or raises `StopIteration` if at end."""
inputs_batch, targets_batch = super(
MNISTAutoencoderDataProvider, self).next()
# return inputs as targets for autoencoder training
return inputs_batch, inputs_batch