Updated data provider with better sampling

This commit is contained in:
AntreasAntoniou 2018-02-09 21:39:20 +00:00
parent ee4adeaa66
commit 4e22eefbce

View File

@ -15,7 +15,7 @@ class DataProvider(object):
"""Generic data provider."""
def __init__(self, inputs, targets, batch_size, max_num_batches=-1,
shuffle_order=True, rng=None):
random_sampling=True, rng=None):
"""Create a new data provider object.
Args:
@ -28,23 +28,29 @@ class DataProvider(object):
in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of
random_sampling (bool): Whether to randomly permute the order of
the data before each epoch.
rng (RandomState): A seeded random number generator.
"""
self.inputs = inputs
self.targets = targets
if batch_size < 1:
raise ValueError('batch_size must be >= 1')
self._batch_size = batch_size
if max_num_batches == 0 or max_num_batches < -1:
raise ValueError('max_num_batches must be -1 or > 0')
self._max_num_batches = max_num_batches
self._update_num_batches()
self.shuffle_order = shuffle_order
self.random_sampling = random_sampling
self._current_order = np.arange(inputs.shape[0])
if rng is None:
rng = np.random.RandomState(DEFAULT_SEED)
self.rng = rng
self.new_epoch()
@ -96,8 +102,6 @@ class DataProvider(object):
def new_epoch(self):
"""Starts a new epoch (pass through data), possibly shuffling first."""
self._curr_batch = 0
if self.shuffle_order:
self.shuffle()
def __next__(self):
return self.next()
@ -110,13 +114,6 @@ class DataProvider(object):
self.targets = self.targets[inv_perm]
self.new_epoch()
def shuffle(self):
"""Randomly shuffles order of data."""
perm = self.rng.permutation(self.inputs.shape[0])
self._current_order = self._current_order[perm]
self.inputs = self.inputs[perm]
self.targets = self.targets[perm]
def next(self):
"""Returns next data batch or raises `StopIteration` if at end."""
if self._curr_batch + 1 > self.num_batches:
@ -125,8 +122,13 @@ class DataProvider(object):
self.new_epoch()
raise StopIteration()
# create an index slice corresponding to current batch number
if self.random_sampling:
batch_slice = self.rng.choice(self.inputs.shape[0], size=self.batch_size, replace=False)
else:
batch_slice = slice(self._curr_batch * self.batch_size,
(self._curr_batch + 1) * self.batch_size)
inputs_batch = self.inputs[batch_slice]
targets_batch = self.targets[batch_slice]
self._curr_batch += 1
@ -136,7 +138,7 @@ class MNISTDataProvider(DataProvider):
"""Data provider for MNIST handwritten digit images."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
shuffle_order=True, rng=None):
random_sampling=True, rng=None):
"""Create a new MNIST data provider object.
Args:
@ -147,7 +149,7 @@ class MNISTDataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of
random_sampling (bool): Whether to randomly permute the order of
the data before each epoch.
rng (RandomState): A seeded random number generator.
"""
@ -172,7 +174,7 @@ class MNISTDataProvider(DataProvider):
inputs = inputs.astype(np.float32)
# pass the loaded data to the parent class __init__
super(MNISTDataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
def next(self):
"""Returns next data batch or raises `StopIteration` if at end."""
@ -202,7 +204,7 @@ class EMNISTDataProvider(DataProvider):
"""Data provider for EMNIST handwritten digit images."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
shuffle_order=True, rng=None, flatten=False, one_hot=False):
random_sampling=True, rng=None, flatten=False, one_hot=False):
"""Create a new EMNIST data provider object.
Args:
@ -213,7 +215,7 @@ class EMNISTDataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of
random_sampling (bool): Whether to randomly permute the order of
the data before each epoch.
rng (RandomState): A seeded random number generator.
"""
@ -246,7 +248,7 @@ class EMNISTDataProvider(DataProvider):
# pass the loaded data to the parent class __init__
super(EMNISTDataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
def next(self):
"""Returns next data batch or raises `StopIteration` if at end."""
@ -279,7 +281,7 @@ class CIFAR10DataProvider(DataProvider):
"""Data provider for CIFAR-10 object images."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
shuffle_order=True, rng=None, flatten=False, one_hot=False):
random_sampling=True, rng=None, flatten=False, one_hot=False):
"""Create a new EMNIST data provider object.
Args:
@ -290,7 +292,7 @@ class CIFAR10DataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of
random_sampling (bool): Whether to randomly permute the order of
the data before each epoch.
rng (RandomState): A seeded random number generator.
"""
@ -327,7 +329,7 @@ class CIFAR10DataProvider(DataProvider):
# pass the loaded data to the parent class __init__
super(CIFAR10DataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
def next(self):
"""Returns next data batch or raises `StopIteration` if at end."""
@ -362,7 +364,7 @@ class CIFAR100DataProvider(DataProvider):
"""Data provider for CIFAR-100 object images."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
shuffle_order=True, rng=None, flatten=False, one_hot=False):
random_sampling=True, rng=None, flatten=False, one_hot=False):
"""Create a new EMNIST data provider object.
Args:
@ -373,7 +375,7 @@ class CIFAR100DataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of
random_sampling (bool): Whether to randomly permute the order of
the data before each epoch.
rng (RandomState): A seeded random number generator.
"""
@ -407,7 +409,7 @@ class CIFAR100DataProvider(DataProvider):
# pass the loaded data to the parent class __init__
super(CIFAR100DataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
def next(self):
"""Returns next data batch or raises `StopIteration` if at end."""
@ -441,7 +443,7 @@ class MSD10GenreDataProvider(DataProvider):
"""Data provider for Million Song Dataset 10-genre classification task."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
shuffle_order=True, rng=None, one_hot=False, flatten=True):
random_sampling=True, rng=None, one_hot=False, flatten=True):
"""Create a new EMNIST data provider object.
Args:
@ -452,7 +454,7 @@ class MSD10GenreDataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of
random_sampling (bool): Whether to randomly permute the order of
the data before each epoch.
rng (RandomState): A seeded random number generator.
"""
@ -500,7 +502,7 @@ class MSD10GenreDataProvider(DataProvider):
# pass the loaded data to the parent class __init__
super(MSD10GenreDataProvider, self).__init__(
inputs, target, batch_size, max_num_batches, shuffle_order, rng)
inputs, target, batch_size, max_num_batches, random_sampling, rng)
def next(self):
"""Returns next data batch or raises `StopIteration` if at end."""
@ -533,7 +535,7 @@ class MSD25GenreDataProvider(DataProvider):
"""Data provider for Million Song Dataset 25-genre classification task."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
shuffle_order=True, rng=None, one_hot=False, flatten=True):
random_sampling=True, rng=None, one_hot=False, flatten=True):
"""Create a new EMNIST data provider object.
Args:
@ -544,7 +546,7 @@ class MSD25GenreDataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of
random_sampling (bool): Whether to randomly permute the order of
the data before each epoch.
rng (RandomState): A seeded random number generator.
"""
@ -575,7 +577,7 @@ class MSD25GenreDataProvider(DataProvider):
#inputs, target
# pass the loaded data to the parent class __init__
super(MSD25GenreDataProvider, self).__init__(
inputs, target, batch_size, max_num_batches, shuffle_order, rng)
inputs, target, batch_size, max_num_batches, random_sampling, rng)
def next(self):
"""Returns next data batch or raises `StopIteration` if at end."""
@ -610,7 +612,7 @@ class MetOfficeDataProvider(DataProvider):
"""South Scotland Met Office weather data provider."""
def __init__(self, window_size, batch_size=10, max_num_batches=-1,
shuffle_order=True, rng=None):
random_sampling=True, rng=None):
"""Create a new Met Office data provider object.
Args:
@ -623,7 +625,7 @@ class MetOfficeDataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of
random_sampling (bool): Whether to randomly permute the order of
the data before each epoch.
rng (RandomState): A seeded random number generator.
"""
@ -651,12 +653,12 @@ class MetOfficeDataProvider(DataProvider):
# targets are last entry in windows
targets = windowed[:, -1]
super(MetOfficeDataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
class CCPPDataProvider(DataProvider):
def __init__(self, which_set='train', input_dims=None, batch_size=10,
max_num_batches=-1, shuffle_order=True, rng=None):
max_num_batches=-1, random_sampling=True, rng=None):
"""Create a new Combined Cycle Power Plant data provider object.
Args:
@ -671,7 +673,7 @@ class CCPPDataProvider(DataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of
random_sampling (bool): Whether to randomly permute the order of
the data before each epoch.
rng (RandomState): A seeded random number generator.
"""
@ -697,14 +699,14 @@ class CCPPDataProvider(DataProvider):
inputs = inputs[:, input_dims]
targets = loaded[which_set + '_targets']
super(CCPPDataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
class AugmentedMNISTDataProvider(MNISTDataProvider):
"""Data provider for MNIST dataset which randomly transforms images."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
shuffle_order=True, rng=None, transformer=None):
random_sampling=True, rng=None, transformer=None):
"""Create a new augmented MNIST data provider object.
Args:
@ -715,7 +717,7 @@ class AugmentedMNISTDataProvider(MNISTDataProvider):
in an epoch. If `max_num_batches * batch_size > num_data` then
only as many batches as the data can be split into will be
used. If set to -1 all of the data will be used.
shuffle_order (bool): Whether to randomly permute the order of
random_sampling (bool): Whether to randomly permute the order of
the data before each epoch.
rng (RandomState): A seeded random number generator.
transformer: Function which takes an `inputs` array of shape
@ -727,7 +729,7 @@ class AugmentedMNISTDataProvider(MNISTDataProvider):
the data provider.
"""
super(AugmentedMNISTDataProvider, self).__init__(
which_set, batch_size, max_num_batches, shuffle_order, rng)
which_set, batch_size, max_num_batches, random_sampling, rng)
self.transformer = transformer
def next(self):