Adding augmented MNIST data provider.
This commit is contained in:
parent
a021ce585b
commit
a353526790
@ -247,6 +247,7 @@ class MetOfficeDataProvider(DataProvider):
|
|||||||
|
|
||||||
|
|
||||||
class CCPPDataProvider(DataProvider):
|
class CCPPDataProvider(DataProvider):
|
||||||
|
"""Combine Cycle Power Plant data provider."""
|
||||||
|
|
||||||
def __init__(self, which_set='train', input_dims=None, batch_size=10,
|
def __init__(self, which_set='train', input_dims=None, batch_size=10,
|
||||||
max_num_batches=-1, shuffle_order=True, rng=None):
|
max_num_batches=-1, shuffle_order=True, rng=None):
|
||||||
@ -293,3 +294,41 @@ class CCPPDataProvider(DataProvider):
|
|||||||
targets = loaded[which_set + '_targets']
|
targets = loaded[which_set + '_targets']
|
||||||
super(CCPPDataProvider, self).__init__(
|
super(CCPPDataProvider, self).__init__(
|
||||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||||
|
|
||||||
|
|
||||||
|
class AugmentedMNISTDataProvider(MNISTDataProvider):
|
||||||
|
"""Data provider for MNIST dataset which randomly transforms images."""
|
||||||
|
|
||||||
|
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||||
|
shuffle_order=True, rng=None, transformer=None):
|
||||||
|
"""Create a new augmented MNIST data provider object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
which_set: One of 'train', 'valid' or 'test'. Determines which
|
||||||
|
portion of the MNIST data this object should provide.
|
||||||
|
batch_size (int): Number of data points to include in each batch.
|
||||||
|
max_num_batches (int): Maximum number of batches to iterate over
|
||||||
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
|
only as many batches as the data can be split into will be
|
||||||
|
used. If set to -1 all of the data will be used.
|
||||||
|
shuffle_order (bool): Whether to randomly permute the order of
|
||||||
|
the data before each epoch.
|
||||||
|
rng (RandomState): A seeded random number generator.
|
||||||
|
transformer: Function which takes an `inputs` array of shape
|
||||||
|
(batch_size, input_dim) corresponding to a batch of input
|
||||||
|
images and a `rng` random number generator object (i.e. a
|
||||||
|
call signature `transformer(inputs, rng)`) and applies a
|
||||||
|
potentiall random set of transformations to some / all of the
|
||||||
|
input images as each new batch is returned when iterating over
|
||||||
|
the data provider.
|
||||||
|
"""
|
||||||
|
super(AugmentedMNISTDataProvider, self).__init__(
|
||||||
|
which_set, batch_size, max_num_batches, shuffle_order, rng)
|
||||||
|
self.transformer = transformer
|
||||||
|
|
||||||
|
def next(self):
|
||||||
|
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||||
|
inputs_batch, targets_batch = super(
|
||||||
|
AugmentedMNISTDataProvider, self).next()
|
||||||
|
transformed_inputs_batch = self.transformer(inputs_batch, self.rng)
|
||||||
|
return transformed_inputs_batch, targets_batch
|
||||||
|
Loading…
Reference in New Issue
Block a user