Adding autoencoder wrapper data provider.
This commit is contained in:
parent
2a4364d226
commit
057a25ed07
@ -332,3 +332,14 @@ class AugmentedMNISTDataProvider(MNISTDataProvider):
|
||||
AugmentedMNISTDataProvider, self).next()
|
||||
transformed_inputs_batch = self.transformer(inputs_batch, self.rng)
|
||||
return transformed_inputs_batch, targets_batch
|
||||
|
||||
|
||||
class MNISTAutoencoderDataProvider(MNISTDataProvider):
|
||||
"""Simple wrapper data provider for training an autoencoder on MNIST."""
|
||||
|
||||
def next(self):
|
||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||
inputs_batch, targets_batch = super(
|
||||
MNISTAutoencoderDataProvider, self).next()
|
||||
# return inputs as targets for autoencoder training
|
||||
return inputs_batch, inputs_batch
|
||||
|
Loading…
Reference in New Issue
Block a user