Updated data provider with better sampling
This commit is contained in:
parent
ee4adeaa66
commit
4e22eefbce
@ -15,7 +15,7 @@ class DataProvider(object):
|
||||
"""Generic data provider."""
|
||||
|
||||
def __init__(self, inputs, targets, batch_size, max_num_batches=-1,
|
||||
shuffle_order=True, rng=None):
|
||||
random_sampling=True, rng=None):
|
||||
"""Create a new data provider object.
|
||||
|
||||
Args:
|
||||
@ -28,23 +28,29 @@ class DataProvider(object):
|
||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||
only as many batches as the data can be split into will be
|
||||
used. If set to -1 all of the data will be used.
|
||||
shuffle_order (bool): Whether to randomly permute the order of
|
||||
random_sampling (bool): Whether to draw each batch as a random sample
|
||||
of `batch_size` data points (without replacement) rather than a slice.
|
||||
rng (RandomState): A seeded random number generator.
|
||||
"""
|
||||
self.inputs = inputs
|
||||
self.targets = targets
|
||||
|
||||
if batch_size < 1:
|
||||
raise ValueError('batch_size must be >= 1')
|
||||
|
||||
self._batch_size = batch_size
|
||||
|
||||
if max_num_batches == 0 or max_num_batches < -1:
|
||||
raise ValueError('max_num_batches must be -1 or > 0')
|
||||
|
||||
self._max_num_batches = max_num_batches
|
||||
self._update_num_batches()
|
||||
self.shuffle_order = shuffle_order
|
||||
self.random_sampling = random_sampling
|
||||
self._current_order = np.arange(inputs.shape[0])
|
||||
|
||||
if rng is None:
|
||||
rng = np.random.RandomState(DEFAULT_SEED)
|
||||
|
||||
self.rng = rng
|
||||
self.new_epoch()
|
||||
|
||||
@ -96,8 +102,6 @@ class DataProvider(object):
|
||||
def new_epoch(self):
|
||||
"""Starts a new epoch (pass through data), possibly shuffling first."""
|
||||
self._curr_batch = 0
|
||||
if self.shuffle_order:
|
||||
self.shuffle()
|
||||
|
||||
def __next__(self):
|
||||
return self.next()
|
||||
@ -110,13 +114,6 @@ class DataProvider(object):
|
||||
self.targets = self.targets[inv_perm]
|
||||
self.new_epoch()
|
||||
|
||||
def shuffle(self):
|
||||
"""Randomly shuffles order of data."""
|
||||
perm = self.rng.permutation(self.inputs.shape[0])
|
||||
self._current_order = self._current_order[perm]
|
||||
self.inputs = self.inputs[perm]
|
||||
self.targets = self.targets[perm]
|
||||
|
||||
def next(self):
|
||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||
if self._curr_batch + 1 > self.num_batches:
|
||||
@ -125,8 +122,13 @@ class DataProvider(object):
|
||||
self.new_epoch()
|
||||
raise StopIteration()
|
||||
# create an index slice corresponding to current batch number
|
||||
if self.random_sampling:
|
||||
batch_slice = self.rng.choice(self.inputs.shape[0], size=self.batch_size, replace=False)
|
||||
|
||||
else:
|
||||
batch_slice = slice(self._curr_batch * self.batch_size,
|
||||
(self._curr_batch + 1) * self.batch_size)
|
||||
|
||||
inputs_batch = self.inputs[batch_slice]
|
||||
targets_batch = self.targets[batch_slice]
|
||||
self._curr_batch += 1
|
||||
@ -136,7 +138,7 @@ class MNISTDataProvider(DataProvider):
|
||||
"""Data provider for MNIST handwritten digit images."""
|
||||
|
||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||
shuffle_order=True, rng=None):
|
||||
random_sampling=True, rng=None):
|
||||
"""Create a new MNIST data provider object.
|
||||
|
||||
Args:
|
||||
@ -147,7 +149,7 @@ class MNISTDataProvider(DataProvider):
|
||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||
only as many batches as the data can be split into will be
|
||||
used. If set to -1 all of the data will be used.
|
||||
shuffle_order (bool): Whether to randomly permute the order of
|
||||
random_sampling (bool): Whether to draw each batch as a random sample
|
||||
of `batch_size` data points (without replacement) rather than a slice.
|
||||
rng (RandomState): A seeded random number generator.
|
||||
"""
|
||||
@ -172,7 +174,7 @@ class MNISTDataProvider(DataProvider):
|
||||
inputs = inputs.astype(np.float32)
|
||||
# pass the loaded data to the parent class __init__
|
||||
super(MNISTDataProvider, self).__init__(
|
||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
|
||||
|
||||
def next(self):
|
||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||
@ -202,7 +204,7 @@ class EMNISTDataProvider(DataProvider):
|
||||
"""Data provider for EMNIST handwritten digit images."""
|
||||
|
||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||
shuffle_order=True, rng=None, flatten=False, one_hot=False):
|
||||
random_sampling=True, rng=None, flatten=False, one_hot=False):
|
||||
"""Create a new EMNIST data provider object.
|
||||
|
||||
Args:
|
||||
@ -213,7 +215,7 @@ class EMNISTDataProvider(DataProvider):
|
||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||
only as many batches as the data can be split into will be
|
||||
used. If set to -1 all of the data will be used.
|
||||
shuffle_order (bool): Whether to randomly permute the order of
|
||||
random_sampling (bool): Whether to draw each batch as a random sample
|
||||
of `batch_size` data points (without replacement) rather than a slice.
|
||||
rng (RandomState): A seeded random number generator.
|
||||
"""
|
||||
@ -246,7 +248,7 @@ class EMNISTDataProvider(DataProvider):
|
||||
|
||||
# pass the loaded data to the parent class __init__
|
||||
super(EMNISTDataProvider, self).__init__(
|
||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
|
||||
|
||||
def next(self):
|
||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||
@ -279,7 +281,7 @@ class CIFAR10DataProvider(DataProvider):
|
||||
"""Data provider for CIFAR-10 object images."""
|
||||
|
||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||
shuffle_order=True, rng=None, flatten=False, one_hot=False):
|
||||
random_sampling=True, rng=None, flatten=False, one_hot=False):
|
||||
"""Create a new CIFAR-10 data provider object.
|
||||
|
||||
Args:
|
||||
@ -290,7 +292,7 @@ class CIFAR10DataProvider(DataProvider):
|
||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||
only as many batches as the data can be split into will be
|
||||
used. If set to -1 all of the data will be used.
|
||||
shuffle_order (bool): Whether to randomly permute the order of
|
||||
random_sampling (bool): Whether to draw each batch as a random sample
|
||||
of `batch_size` data points (without replacement) rather than a slice.
|
||||
rng (RandomState): A seeded random number generator.
|
||||
"""
|
||||
@ -327,7 +329,7 @@ class CIFAR10DataProvider(DataProvider):
|
||||
|
||||
# pass the loaded data to the parent class __init__
|
||||
super(CIFAR10DataProvider, self).__init__(
|
||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
|
||||
|
||||
def next(self):
|
||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||
@ -362,7 +364,7 @@ class CIFAR100DataProvider(DataProvider):
|
||||
"""Data provider for CIFAR-100 object images."""
|
||||
|
||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||
shuffle_order=True, rng=None, flatten=False, one_hot=False):
|
||||
random_sampling=True, rng=None, flatten=False, one_hot=False):
|
||||
"""Create a new CIFAR-100 data provider object.
|
||||
|
||||
Args:
|
||||
@ -373,7 +375,7 @@ class CIFAR100DataProvider(DataProvider):
|
||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||
only as many batches as the data can be split into will be
|
||||
used. If set to -1 all of the data will be used.
|
||||
shuffle_order (bool): Whether to randomly permute the order of
|
||||
random_sampling (bool): Whether to draw each batch as a random sample
|
||||
of `batch_size` data points (without replacement) rather than a slice.
|
||||
rng (RandomState): A seeded random number generator.
|
||||
"""
|
||||
@ -407,7 +409,7 @@ class CIFAR100DataProvider(DataProvider):
|
||||
|
||||
# pass the loaded data to the parent class __init__
|
||||
super(CIFAR100DataProvider, self).__init__(
|
||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
|
||||
|
||||
def next(self):
|
||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||
@ -441,7 +443,7 @@ class MSD10GenreDataProvider(DataProvider):
|
||||
"""Data provider for Million Song Dataset 10-genre classification task."""
|
||||
|
||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||
shuffle_order=True, rng=None, one_hot=False, flatten=True):
|
||||
random_sampling=True, rng=None, one_hot=False, flatten=True):
|
||||
"""Create a new Million Song Dataset 10-genre data provider object.
|
||||
|
||||
Args:
|
||||
@ -452,7 +454,7 @@ class MSD10GenreDataProvider(DataProvider):
|
||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||
only as many batches as the data can be split into will be
|
||||
used. If set to -1 all of the data will be used.
|
||||
shuffle_order (bool): Whether to randomly permute the order of
|
||||
random_sampling (bool): Whether to draw each batch as a random sample
|
||||
of `batch_size` data points (without replacement) rather than a slice.
|
||||
rng (RandomState): A seeded random number generator.
|
||||
"""
|
||||
@ -500,7 +502,7 @@ class MSD10GenreDataProvider(DataProvider):
|
||||
|
||||
# pass the loaded data to the parent class __init__
|
||||
super(MSD10GenreDataProvider, self).__init__(
|
||||
inputs, target, batch_size, max_num_batches, shuffle_order, rng)
|
||||
inputs, target, batch_size, max_num_batches, random_sampling, rng)
|
||||
|
||||
def next(self):
|
||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||
@ -533,7 +535,7 @@ class MSD25GenreDataProvider(DataProvider):
|
||||
"""Data provider for Million Song Dataset 25-genre classification task."""
|
||||
|
||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||
shuffle_order=True, rng=None, one_hot=False, flatten=True):
|
||||
random_sampling=True, rng=None, one_hot=False, flatten=True):
|
||||
"""Create a new Million Song Dataset 25-genre data provider object.
|
||||
|
||||
Args:
|
||||
@ -544,7 +546,7 @@ class MSD25GenreDataProvider(DataProvider):
|
||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||
only as many batches as the data can be split into will be
|
||||
used. If set to -1 all of the data will be used.
|
||||
shuffle_order (bool): Whether to randomly permute the order of
|
||||
random_sampling (bool): Whether to draw each batch as a random sample
|
||||
of `batch_size` data points (without replacement) rather than a slice.
|
||||
rng (RandomState): A seeded random number generator.
|
||||
"""
|
||||
@ -575,7 +577,7 @@ class MSD25GenreDataProvider(DataProvider):
|
||||
#inputs, target
|
||||
# pass the loaded data to the parent class __init__
|
||||
super(MSD25GenreDataProvider, self).__init__(
|
||||
inputs, target, batch_size, max_num_batches, shuffle_order, rng)
|
||||
inputs, target, batch_size, max_num_batches, random_sampling, rng)
|
||||
|
||||
def next(self):
|
||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||
@ -610,7 +612,7 @@ class MetOfficeDataProvider(DataProvider):
|
||||
"""South Scotland Met Office weather data provider."""
|
||||
|
||||
def __init__(self, window_size, batch_size=10, max_num_batches=-1,
|
||||
shuffle_order=True, rng=None):
|
||||
random_sampling=True, rng=None):
|
||||
"""Create a new Met Office data provider object.
|
||||
|
||||
Args:
|
||||
@ -623,7 +625,7 @@ class MetOfficeDataProvider(DataProvider):
|
||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||
only as many batches as the data can be split into will be
|
||||
used. If set to -1 all of the data will be used.
|
||||
shuffle_order (bool): Whether to randomly permute the order of
|
||||
random_sampling (bool): Whether to draw each batch as a random sample
|
||||
of `batch_size` data points (without replacement) rather than a slice.
|
||||
rng (RandomState): A seeded random number generator.
|
||||
"""
|
||||
@ -651,12 +653,12 @@ class MetOfficeDataProvider(DataProvider):
|
||||
# targets are last entry in windows
|
||||
targets = windowed[:, -1]
|
||||
super(MetOfficeDataProvider, self).__init__(
|
||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
|
||||
|
||||
class CCPPDataProvider(DataProvider):
|
||||
|
||||
def __init__(self, which_set='train', input_dims=None, batch_size=10,
|
||||
max_num_batches=-1, shuffle_order=True, rng=None):
|
||||
max_num_batches=-1, random_sampling=True, rng=None):
|
||||
"""Create a new Combined Cycle Power Plant data provider object.
|
||||
|
||||
Args:
|
||||
@ -671,7 +673,7 @@ class CCPPDataProvider(DataProvider):
|
||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||
only as many batches as the data can be split into will be
|
||||
used. If set to -1 all of the data will be used.
|
||||
shuffle_order (bool): Whether to randomly permute the order of
|
||||
random_sampling (bool): Whether to draw each batch as a random sample
|
||||
of `batch_size` data points (without replacement) rather than a slice.
|
||||
rng (RandomState): A seeded random number generator.
|
||||
"""
|
||||
@ -697,14 +699,14 @@ class CCPPDataProvider(DataProvider):
|
||||
inputs = inputs[:, input_dims]
|
||||
targets = loaded[which_set + '_targets']
|
||||
super(CCPPDataProvider, self).__init__(
|
||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
|
||||
|
||||
|
||||
class AugmentedMNISTDataProvider(MNISTDataProvider):
|
||||
"""Data provider for MNIST dataset which randomly transforms images."""
|
||||
|
||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||
shuffle_order=True, rng=None, transformer=None):
|
||||
random_sampling=True, rng=None, transformer=None):
|
||||
"""Create a new augmented MNIST data provider object.
|
||||
|
||||
Args:
|
||||
@ -715,7 +717,7 @@ class AugmentedMNISTDataProvider(MNISTDataProvider):
|
||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||
only as many batches as the data can be split into will be
|
||||
used. If set to -1 all of the data will be used.
|
||||
shuffle_order (bool): Whether to randomly permute the order of
|
||||
random_sampling (bool): Whether to draw each batch as a random sample
|
||||
of `batch_size` data points (without replacement) rather than a slice.
|
||||
rng (RandomState): A seeded random number generator.
|
||||
transformer: Function which takes an `inputs` array of shape
|
||||
@ -727,7 +729,7 @@ class AugmentedMNISTDataProvider(MNISTDataProvider):
|
||||
the data provider.
|
||||
"""
|
||||
super(AugmentedMNISTDataProvider, self).__init__(
|
||||
which_set, batch_size, max_num_batches, shuffle_order, rng)
|
||||
which_set, batch_size, max_num_batches, random_sampling, rng)
|
||||
self.transformer = transformer
|
||||
|
||||
def next(self):
|
||||
|
Loading…
Reference in New Issue
Block a user