Updated data provider with better sampling

This commit is contained in:
AntreasAntoniou 2018-02-09 21:39:20 +00:00
parent ee4adeaa66
commit 4e22eefbce

View File

@ -15,7 +15,7 @@ class DataProvider(object):
"""Generic data provider.""" """Generic data provider."""
def __init__(self, inputs, targets, batch_size, max_num_batches=-1, def __init__(self, inputs, targets, batch_size, max_num_batches=-1,
shuffle_order=True, rng=None): random_sampling=True, rng=None):
"""Create a new data provider object. """Create a new data provider object.
Args: Args:
@ -28,23 +28,29 @@ class DataProvider(object):
in an epoch. If `max_num_batches * batch_size > num_data` then in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used. used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of random_sampling (bool): Whether to draw each batch as a random sample
the data before each epoch. (without replacement within the batch) instead of a sequential slice.
rng (RandomState): A seeded random number generator. rng (RandomState): A seeded random number generator.
""" """
self.inputs = inputs self.inputs = inputs
self.targets = targets self.targets = targets
if batch_size < 1: if batch_size < 1:
raise ValueError('batch_size must be >= 1') raise ValueError('batch_size must be >= 1')
self._batch_size = batch_size self._batch_size = batch_size
if max_num_batches == 0 or max_num_batches < -1: if max_num_batches == 0 or max_num_batches < -1:
raise ValueError('max_num_batches must be -1 or > 0') raise ValueError('max_num_batches must be -1 or > 0')
self._max_num_batches = max_num_batches self._max_num_batches = max_num_batches
self._update_num_batches() self._update_num_batches()
self.shuffle_order = shuffle_order self.random_sampling = random_sampling
self._current_order = np.arange(inputs.shape[0]) self._current_order = np.arange(inputs.shape[0])
if rng is None: if rng is None:
rng = np.random.RandomState(DEFAULT_SEED) rng = np.random.RandomState(DEFAULT_SEED)
self.rng = rng self.rng = rng
self.new_epoch() self.new_epoch()
@ -96,8 +102,6 @@ class DataProvider(object):
def new_epoch(self): def new_epoch(self):
"""Starts a new epoch (pass through data), possibly shuffling first.""" """Starts a new epoch (pass through data), possibly shuffling first."""
self._curr_batch = 0 self._curr_batch = 0
if self.shuffle_order:
self.shuffle()
def __next__(self): def __next__(self):
return self.next() return self.next()
@ -110,13 +114,6 @@ class DataProvider(object):
self.targets = self.targets[inv_perm] self.targets = self.targets[inv_perm]
self.new_epoch() self.new_epoch()
def shuffle(self):
"""Randomly shuffles order of data."""
perm = self.rng.permutation(self.inputs.shape[0])
self._current_order = self._current_order[perm]
self.inputs = self.inputs[perm]
self.targets = self.targets[perm]
def next(self): def next(self):
"""Returns next data batch or raises `StopIteration` if at end.""" """Returns next data batch or raises `StopIteration` if at end."""
if self._curr_batch + 1 > self.num_batches: if self._curr_batch + 1 > self.num_batches:
@ -125,8 +122,13 @@ class DataProvider(object):
self.new_epoch() self.new_epoch()
raise StopIteration() raise StopIteration()
# create an index slice corresponding to current batch number # create an index slice corresponding to current batch number
batch_slice = slice(self._curr_batch * self.batch_size, if self.random_sampling:
(self._curr_batch + 1) * self.batch_size) batch_slice = self.rng.choice(self.inputs.shape[0], size=self.batch_size, replace=False)
else:
batch_slice = slice(self._curr_batch * self.batch_size,
(self._curr_batch + 1) * self.batch_size)
inputs_batch = self.inputs[batch_slice] inputs_batch = self.inputs[batch_slice]
targets_batch = self.targets[batch_slice] targets_batch = self.targets[batch_slice]
self._curr_batch += 1 self._curr_batch += 1
@ -136,7 +138,7 @@ class MNISTDataProvider(DataProvider):
"""Data provider for MNIST handwritten digit images.""" """Data provider for MNIST handwritten digit images."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
shuffle_order=True, rng=None): random_sampling=True, rng=None):
"""Create a new MNIST data provider object. """Create a new MNIST data provider object.
Args: Args:
@ -147,7 +149,7 @@ class MNISTDataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used. used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of random_sampling (bool): Whether to draw each batch as a random sample
the data before each epoch. (without replacement within the batch) instead of a sequential slice.
rng (RandomState): A seeded random number generator. rng (RandomState): A seeded random number generator.
""" """
@ -172,7 +174,7 @@ class MNISTDataProvider(DataProvider):
inputs = inputs.astype(np.float32) inputs = inputs.astype(np.float32)
# pass the loaded data to the parent class __init__ # pass the loaded data to the parent class __init__
super(MNISTDataProvider, self).__init__( super(MNISTDataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng) inputs, targets, batch_size, max_num_batches, random_sampling, rng)
def next(self): def next(self):
"""Returns next data batch or raises `StopIteration` if at end.""" """Returns next data batch or raises `StopIteration` if at end."""
@ -202,7 +204,7 @@ class EMNISTDataProvider(DataProvider):
"""Data provider for EMNIST handwritten digit images.""" """Data provider for EMNIST handwritten digit images."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
shuffle_order=True, rng=None, flatten=False, one_hot=False): random_sampling=True, rng=None, flatten=False, one_hot=False):
"""Create a new EMNIST data provider object. """Create a new EMNIST data provider object.
Args: Args:
@ -213,7 +215,7 @@ class EMNISTDataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used. used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of random_sampling (bool): Whether to draw each batch as a random sample
the data before each epoch. (without replacement within the batch) instead of a sequential slice.
rng (RandomState): A seeded random number generator. rng (RandomState): A seeded random number generator.
""" """
@ -246,7 +248,7 @@ class EMNISTDataProvider(DataProvider):
# pass the loaded data to the parent class __init__ # pass the loaded data to the parent class __init__
super(EMNISTDataProvider, self).__init__( super(EMNISTDataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng) inputs, targets, batch_size, max_num_batches, random_sampling, rng)
def next(self): def next(self):
"""Returns next data batch or raises `StopIteration` if at end.""" """Returns next data batch or raises `StopIteration` if at end."""
@ -279,7 +281,7 @@ class CIFAR10DataProvider(DataProvider):
"""Data provider for CIFAR-10 object images.""" """Data provider for CIFAR-10 object images."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
shuffle_order=True, rng=None, flatten=False, one_hot=False): random_sampling=True, rng=None, flatten=False, one_hot=False):
"""Create a new EMNIST data provider object. """Create a new EMNIST data provider object.
Args: Args:
@ -290,7 +292,7 @@ class CIFAR10DataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used. used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of random_sampling (bool): Whether to draw each batch as a random sample
the data before each epoch. (without replacement within the batch) instead of a sequential slice.
rng (RandomState): A seeded random number generator. rng (RandomState): A seeded random number generator.
""" """
@ -327,7 +329,7 @@ class CIFAR10DataProvider(DataProvider):
# pass the loaded data to the parent class __init__ # pass the loaded data to the parent class __init__
super(CIFAR10DataProvider, self).__init__( super(CIFAR10DataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng) inputs, targets, batch_size, max_num_batches, random_sampling, rng)
def next(self): def next(self):
"""Returns next data batch or raises `StopIteration` if at end.""" """Returns next data batch or raises `StopIteration` if at end."""
@ -362,7 +364,7 @@ class CIFAR100DataProvider(DataProvider):
"""Data provider for CIFAR-100 object images.""" """Data provider for CIFAR-100 object images."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
shuffle_order=True, rng=None, flatten=False, one_hot=False): random_sampling=True, rng=None, flatten=False, one_hot=False):
"""Create a new EMNIST data provider object. """Create a new EMNIST data provider object.
Args: Args:
@ -373,7 +375,7 @@ class CIFAR100DataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used. used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of random_sampling (bool): Whether to draw each batch as a random sample
the data before each epoch. (without replacement within the batch) instead of a sequential slice.
rng (RandomState): A seeded random number generator. rng (RandomState): A seeded random number generator.
""" """
@ -407,7 +409,7 @@ class CIFAR100DataProvider(DataProvider):
# pass the loaded data to the parent class __init__ # pass the loaded data to the parent class __init__
super(CIFAR100DataProvider, self).__init__( super(CIFAR100DataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng) inputs, targets, batch_size, max_num_batches, random_sampling, rng)
def next(self): def next(self):
"""Returns next data batch or raises `StopIteration` if at end.""" """Returns next data batch or raises `StopIteration` if at end."""
@ -441,7 +443,7 @@ class MSD10GenreDataProvider(DataProvider):
"""Data provider for Million Song Dataset 10-genre classification task.""" """Data provider for Million Song Dataset 10-genre classification task."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
shuffle_order=True, rng=None, one_hot=False, flatten=True): random_sampling=True, rng=None, one_hot=False, flatten=True):
"""Create a new EMNIST data provider object. """Create a new EMNIST data provider object.
Args: Args:
@ -452,7 +454,7 @@ class MSD10GenreDataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used. used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of random_sampling (bool): Whether to draw each batch as a random sample
the data before each epoch. (without replacement within the batch) instead of a sequential slice.
rng (RandomState): A seeded random number generator. rng (RandomState): A seeded random number generator.
""" """
@ -500,7 +502,7 @@ class MSD10GenreDataProvider(DataProvider):
# pass the loaded data to the parent class __init__ # pass the loaded data to the parent class __init__
super(MSD10GenreDataProvider, self).__init__( super(MSD10GenreDataProvider, self).__init__(
inputs, target, batch_size, max_num_batches, shuffle_order, rng) inputs, target, batch_size, max_num_batches, random_sampling, rng)
def next(self): def next(self):
"""Returns next data batch or raises `StopIteration` if at end.""" """Returns next data batch or raises `StopIteration` if at end."""
@ -533,7 +535,7 @@ class MSD25GenreDataProvider(DataProvider):
"""Data provider for Million Song Dataset 25-genre classification task.""" """Data provider for Million Song Dataset 25-genre classification task."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
shuffle_order=True, rng=None, one_hot=False, flatten=True): random_sampling=True, rng=None, one_hot=False, flatten=True):
"""Create a new EMNIST data provider object. """Create a new EMNIST data provider object.
Args: Args:
@ -544,7 +546,7 @@ class MSD25GenreDataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used. used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of random_sampling (bool): Whether to draw each batch as a random sample
the data before each epoch. (without replacement within the batch) instead of a sequential slice.
rng (RandomState): A seeded random number generator. rng (RandomState): A seeded random number generator.
""" """
@ -575,7 +577,7 @@ class MSD25GenreDataProvider(DataProvider):
#inputs, target #inputs, target
# pass the loaded data to the parent class __init__ # pass the loaded data to the parent class __init__
super(MSD25GenreDataProvider, self).__init__( super(MSD25GenreDataProvider, self).__init__(
inputs, target, batch_size, max_num_batches, shuffle_order, rng) inputs, target, batch_size, max_num_batches, random_sampling, rng)
def next(self): def next(self):
"""Returns next data batch or raises `StopIteration` if at end.""" """Returns next data batch or raises `StopIteration` if at end."""
@ -610,7 +612,7 @@ class MetOfficeDataProvider(DataProvider):
"""South Scotland Met Office weather data provider.""" """South Scotland Met Office weather data provider."""
def __init__(self, window_size, batch_size=10, max_num_batches=-1, def __init__(self, window_size, batch_size=10, max_num_batches=-1,
shuffle_order=True, rng=None): random_sampling=True, rng=None):
"""Create a new Met Office data provider object. """Create a new Met Office data provider object.
Args: Args:
@ -623,7 +625,7 @@ class MetOfficeDataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used. used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of random_sampling (bool): Whether to draw each batch as a random sample
the data before each epoch. (without replacement within the batch) instead of a sequential slice.
rng (RandomState): A seeded random number generator. rng (RandomState): A seeded random number generator.
""" """
@ -651,12 +653,12 @@ class MetOfficeDataProvider(DataProvider):
# targets are last entry in windows # targets are last entry in windows
targets = windowed[:, -1] targets = windowed[:, -1]
super(MetOfficeDataProvider, self).__init__( super(MetOfficeDataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng) inputs, targets, batch_size, max_num_batches, random_sampling, rng)
class CCPPDataProvider(DataProvider): class CCPPDataProvider(DataProvider):
def __init__(self, which_set='train', input_dims=None, batch_size=10, def __init__(self, which_set='train', input_dims=None, batch_size=10,
max_num_batches=-1, shuffle_order=True, rng=None): max_num_batches=-1, random_sampling=True, rng=None):
"""Create a new Combined Cycle Power Plant data provider object. """Create a new Combined Cycle Power Plant data provider object.
Args: Args:
@ -671,7 +673,7 @@ class CCPPDataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used. used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of random_sampling (bool): Whether to draw each batch as a random sample
the data before each epoch. (without replacement within the batch) instead of a sequential slice.
rng (RandomState): A seeded random number generator. rng (RandomState): A seeded random number generator.
""" """
@ -697,14 +699,14 @@ class CCPPDataProvider(DataProvider):
inputs = inputs[:, input_dims] inputs = inputs[:, input_dims]
targets = loaded[which_set + '_targets'] targets = loaded[which_set + '_targets']
super(CCPPDataProvider, self).__init__( super(CCPPDataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng) inputs, targets, batch_size, max_num_batches, random_sampling, rng)
class AugmentedMNISTDataProvider(MNISTDataProvider): class AugmentedMNISTDataProvider(MNISTDataProvider):
"""Data provider for MNIST dataset which randomly transforms images.""" """Data provider for MNIST dataset which randomly transforms images."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
shuffle_order=True, rng=None, transformer=None): random_sampling=True, rng=None, transformer=None):
"""Create a new augmented MNIST data provider object. """Create a new augmented MNIST data provider object.
Args: Args:
@ -715,7 +717,7 @@ class AugmentedMNISTDataProvider(MNISTDataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used. used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of random_sampling (bool): Whether to draw each batch as a random sample
the data before each epoch. (without replacement within the batch) instead of a sequential slice.
rng (RandomState): A seeded random number generator. rng (RandomState): A seeded random number generator.
transformer: Function which takes an `inputs` array of shape transformer: Function which takes an `inputs` array of shape
@ -727,7 +729,7 @@ class AugmentedMNISTDataProvider(MNISTDataProvider):
the data provider. the data provider.
""" """
super(AugmentedMNISTDataProvider, self).__init__( super(AugmentedMNISTDataProvider, self).__init__(
which_set, batch_size, max_num_batches, shuffle_order, rng) which_set, batch_size, max_num_batches, random_sampling, rng)
self.transformer = transformer self.transformer = transformer
def next(self): def next(self):