Adding autoencoder wrapper data provider.
This commit is contained in:
parent
2a4364d226
commit
057a25ed07
@ -332,3 +332,14 @@ class AugmentedMNISTDataProvider(MNISTDataProvider):
|
|||||||
AugmentedMNISTDataProvider, self).next()
|
AugmentedMNISTDataProvider, self).next()
|
||||||
transformed_inputs_batch = self.transformer(inputs_batch, self.rng)
|
transformed_inputs_batch = self.transformer(inputs_batch, self.rng)
|
||||||
return transformed_inputs_batch, targets_batch
|
return transformed_inputs_batch, targets_batch
|
||||||
|
|
||||||
|
|
||||||
|
class MNISTAutoencoderDataProvider(MNISTDataProvider):
|
||||||
|
"""Simple wrapper data provider for training an autoencoder on MNIST."""
|
||||||
|
|
||||||
|
def next(self):
|
||||||
|
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||||
|
inputs_batch, targets_batch = super(
|
||||||
|
MNISTAutoencoderDataProvider, self).next()
|
||||||
|
# return inputs as targets for autoencoder training
|
||||||
|
return inputs_batch, inputs_batch
|
||||||
|
Loading…
Reference in New Issue
Block a user