diff --git a/mlp/data_providers.py b/mlp/data_providers.py index 9166b17..0d2eef9 100644 --- a/mlp/data_providers.py +++ b/mlp/data_providers.py @@ -247,6 +247,7 @@ class MetOfficeDataProvider(DataProvider): class CCPPDataProvider(DataProvider): + """Combine Cycle Power Plant data provider.""" def __init__(self, which_set='train', input_dims=None, batch_size=10, max_num_batches=-1, shuffle_order=True, rng=None): @@ -293,3 +294,41 @@ class CCPPDataProvider(DataProvider): targets = loaded[which_set + '_targets'] super(CCPPDataProvider, self).__init__( inputs, targets, batch_size, max_num_batches, shuffle_order, rng) + + +class AugmentedMNISTDataProvider(MNISTDataProvider): + """Data provider for MNIST dataset which randomly transforms images.""" + + def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, + shuffle_order=True, rng=None, transformer=None): + """Create a new augmented MNIST data provider object. + + Args: + which_set: One of 'train', 'valid' or 'test'. Determines which + portion of the MNIST data this object should provide. + batch_size (int): Number of data points to include in each batch. + max_num_batches (int): Maximum number of batches to iterate over + in an epoch. If `max_num_batches * batch_size > num_data` then + only as many batches as the data can be split into will be + used. If set to -1 all of the data will be used. + shuffle_order (bool): Whether to randomly permute the order of + the data before each epoch. + rng (RandomState): A seeded random number generator. + transformer: Function which takes an `inputs` array of shape + (batch_size, input_dim) corresponding to a batch of input + images and a `rng` random number generator object (i.e. a + call signature `transformer(inputs, rng)`) and applies a + potentiall random set of transformations to some / all of the + input images as each new batch is returned when iterating over + the data provider. + """ + super(AugmentedMNISTDataProvider, self).__init__( + which_set, batch_size, max_num_batches, shuffle_order, rng) + self.transformer = transformer + + def next(self): + """Returns next data batch or raises `StopIteration` if at end.""" + inputs_batch, targets_batch = super( + AugmentedMNISTDataProvider, self).next() + transformed_inputs_batch = self.transformer(inputs_batch, self.rng) + return transformed_inputs_batch, targets_batch