diff --git a/data_providers.py b/data_providers.py index c93cbde..345c478 100644 --- a/data_providers.py +++ b/data_providers.py @@ -15,7 +15,7 @@ class DataProvider(object): """Generic data provider.""" def __init__(self, inputs, targets, batch_size, max_num_batches=-1, - shuffle_order=True, rng=None): + random_sampling=True, rng=None): """Create a new data provider object. Args: @@ -28,23 +28,29 @@ class DataProvider(object): in an epoch. If `max_num_batches * batch_size > num_data` then only as many batches as the data can be split into will be used. If set to -1 all of the data will be used. - shuffle_order (bool): Whether to randomly permute the order of + random_sampling (bool): Whether to randomly permute the order of the data before each epoch. rng (RandomState): A seeded random number generator. """ self.inputs = inputs self.targets = targets + if batch_size < 1: raise ValueError('batch_size must be >= 1') + self._batch_size = batch_size + if max_num_batches == 0 or max_num_batches < -1: raise ValueError('max_num_batches must be -1 or > 0') + self._max_num_batches = max_num_batches self._update_num_batches() - self.shuffle_order = shuffle_order + self.random_sampling = random_sampling self._current_order = np.arange(inputs.shape[0]) + if rng is None: rng = np.random.RandomState(DEFAULT_SEED) + self.rng = rng self.new_epoch() @@ -96,8 +102,6 @@ class DataProvider(object): def new_epoch(self): """Starts a new epoch (pass through data), possibly shuffling first.""" self._curr_batch = 0 - if self.shuffle_order: - self.shuffle() def __next__(self): return self.next() @@ -110,13 +114,6 @@ class DataProvider(object): self.targets = self.targets[inv_perm] self.new_epoch() - def shuffle(self): - """Randomly shuffles order of data.""" - perm = self.rng.permutation(self.inputs.shape[0]) - self._current_order = self._current_order[perm] - self.inputs = self.inputs[perm] - self.targets = self.targets[perm] - def next(self): """Returns next data batch or raises `StopIteration` if at end.""" if self._curr_batch + 1 > self.num_batches: @@ -125,8 +122,13 @@ class DataProvider(object): self.new_epoch() raise StopIteration() # create an index slice corresponding to current batch number - batch_slice = slice(self._curr_batch * self.batch_size, - (self._curr_batch + 1) * self.batch_size) + if self.random_sampling: + batch_slice = self.rng.choice(self.inputs.shape[0], size=self.batch_size, replace=False) + + else: + batch_slice = slice(self._curr_batch * self.batch_size, + (self._curr_batch + 1) * self.batch_size) + inputs_batch = self.inputs[batch_slice] targets_batch = self.targets[batch_slice] self._curr_batch += 1 @@ -136,7 +138,7 @@ class MNISTDataProvider(DataProvider): """Data provider for MNIST handwritten digit images.""" def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, - shuffle_order=True, rng=None): + random_sampling=True, rng=None): """Create a new MNIST data provider object. Args: @@ -147,7 +149,7 @@ class MNISTDataProvider(DataProvider): in an epoch. If `max_num_batches * batch_size > num_data` then only as many batches as the data can be split into will be used. If set to -1 all of the data will be used. - shuffle_order (bool): Whether to randomly permute the order of + random_sampling (bool): Whether to randomly permute the order of the data before each epoch. rng (RandomState): A seeded random number generator. """ @@ -172,7 +174,7 @@ class MNISTDataProvider(DataProvider): inputs = inputs.astype(np.float32) # pass the loaded data to the parent class __init__ super(MNISTDataProvider, self).__init__( - inputs, targets, batch_size, max_num_batches, shuffle_order, rng) + inputs, targets, batch_size, max_num_batches, random_sampling, rng) def next(self): """Returns next data batch or raises `StopIteration` if at end.""" @@ -202,7 +204,7 @@ class EMNISTDataProvider(DataProvider): """Data provider for EMNIST handwritten digit images.""" def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, - shuffle_order=True, rng=None, flatten=False, one_hot=False): + random_sampling=True, rng=None, flatten=False, one_hot=False): """Create a new EMNIST data provider object. Args: @@ -213,7 +215,7 @@ class EMNISTDataProvider(DataProvider): in an epoch. If `max_num_batches * batch_size > num_data` then only as many batches as the data can be split into will be used. If set to -1 all of the data will be used. - shuffle_order (bool): Whether to randomly permute the order of + random_sampling (bool): Whether to randomly permute the order of the data before each epoch. rng (RandomState): A seeded random number generator. """ @@ -246,7 +248,7 @@ class EMNISTDataProvider(DataProvider): # pass the loaded data to the parent class __init__ super(EMNISTDataProvider, self).__init__( - inputs, targets, batch_size, max_num_batches, shuffle_order, rng) + inputs, targets, batch_size, max_num_batches, random_sampling, rng) def next(self): """Returns next data batch or raises `StopIteration` if at end.""" @@ -279,7 +281,7 @@ class CIFAR10DataProvider(DataProvider): """Data provider for CIFAR-10 object images.""" def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, - shuffle_order=True, rng=None, flatten=False, one_hot=False): + random_sampling=True, rng=None, flatten=False, one_hot=False): """Create a new EMNIST data provider object. Args: @@ -290,7 +292,7 @@ class CIFAR10DataProvider(DataProvider): in an epoch. If `max_num_batches * batch_size > num_data` then only as many batches as the data can be split into will be used. If set to -1 all of the data will be used. - shuffle_order (bool): Whether to randomly permute the order of + random_sampling (bool): Whether to randomly permute the order of the data before each epoch. rng (RandomState): A seeded random number generator. """ @@ -327,7 +329,7 @@ class CIFAR10DataProvider(DataProvider): # pass the loaded data to the parent class __init__ super(CIFAR10DataProvider, self).__init__( - inputs, targets, batch_size, max_num_batches, shuffle_order, rng) + inputs, targets, batch_size, max_num_batches, random_sampling, rng) def next(self): """Returns next data batch or raises `StopIteration` if at end.""" @@ -362,7 +364,7 @@ class CIFAR100DataProvider(DataProvider): """Data provider for CIFAR-100 object images.""" def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, - shuffle_order=True, rng=None, flatten=False, one_hot=False): + random_sampling=True, rng=None, flatten=False, one_hot=False): """Create a new EMNIST data provider object. Args: @@ -373,7 +375,7 @@ class CIFAR100DataProvider(DataProvider): in an epoch. If `max_num_batches * batch_size > num_data` then only as many batches as the data can be split into will be used. If set to -1 all of the data will be used. - shuffle_order (bool): Whether to randomly permute the order of + random_sampling (bool): Whether to randomly permute the order of the data before each epoch. rng (RandomState): A seeded random number generator. """ @@ -407,7 +409,7 @@ class CIFAR100DataProvider(DataProvider): # pass the loaded data to the parent class __init__ super(CIFAR100DataProvider, self).__init__( - inputs, targets, batch_size, max_num_batches, shuffle_order, rng) + inputs, targets, batch_size, max_num_batches, random_sampling, rng) def next(self): """Returns next data batch or raises `StopIteration` if at end.""" @@ -441,7 +443,7 @@ class MSD10GenreDataProvider(DataProvider): """Data provider for Million Song Dataset 10-genre classification task.""" def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, - shuffle_order=True, rng=None, one_hot=False, flatten=True): + random_sampling=True, rng=None, one_hot=False, flatten=True): """Create a new EMNIST data provider object. Args: @@ -452,7 +454,7 @@ class MSD10GenreDataProvider(DataProvider): in an epoch. If `max_num_batches * batch_size > num_data` then only as many batches as the data can be split into will be used. If set to -1 all of the data will be used. - shuffle_order (bool): Whether to randomly permute the order of + random_sampling (bool): Whether to randomly permute the order of the data before each epoch. rng (RandomState): A seeded random number generator. """ @@ -500,7 +502,7 @@ class MSD10GenreDataProvider(DataProvider): # pass the loaded data to the parent class __init__ super(MSD10GenreDataProvider, self).__init__( - inputs, target, batch_size, max_num_batches, shuffle_order, rng) + inputs, target, batch_size, max_num_batches, random_sampling, rng) def next(self): """Returns next data batch or raises `StopIteration` if at end.""" @@ -533,7 +535,7 @@ class MSD25GenreDataProvider(DataProvider): """Data provider for Million Song Dataset 25-genre classification task.""" def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, - shuffle_order=True, rng=None, one_hot=False, flatten=True): + random_sampling=True, rng=None, one_hot=False, flatten=True): """Create a new EMNIST data provider object. Args: @@ -544,7 +546,7 @@ class MSD25GenreDataProvider(DataProvider): in an epoch. If `max_num_batches * batch_size > num_data` then only as many batches as the data can be split into will be used. If set to -1 all of the data will be used. - shuffle_order (bool): Whether to randomly permute the order of + random_sampling (bool): Whether to randomly permute the order of the data before each epoch. rng (RandomState): A seeded random number generator. """ @@ -575,7 +577,7 @@ class MSD25GenreDataProvider(DataProvider): #inputs, target # pass the loaded data to the parent class __init__ super(MSD25GenreDataProvider, self).__init__( - inputs, target, batch_size, max_num_batches, shuffle_order, rng) + inputs, target, batch_size, max_num_batches, random_sampling, rng) def next(self): """Returns next data batch or raises `StopIteration` if at end.""" @@ -610,7 +612,7 @@ class MetOfficeDataProvider(DataProvider): """South Scotland Met Office weather data provider.""" def __init__(self, window_size, batch_size=10, max_num_batches=-1, - shuffle_order=True, rng=None): + random_sampling=True, rng=None): """Create a new Met Office data provider object. Args: @@ -623,7 +625,7 @@ class MetOfficeDataProvider(DataProvider): in an epoch. If `max_num_batches * batch_size > num_data` then only as many batches as the data can be split into will be used. If set to -1 all of the data will be used. - shuffle_order (bool): Whether to randomly permute the order of + random_sampling (bool): Whether to randomly permute the order of the data before each epoch. rng (RandomState): A seeded random number generator. """ @@ -651,12 +653,12 @@ class MetOfficeDataProvider(DataProvider): # targets are last entry in windows targets = windowed[:, -1] super(MetOfficeDataProvider, self).__init__( - inputs, targets, batch_size, max_num_batches, shuffle_order, rng) + inputs, targets, batch_size, max_num_batches, random_sampling, rng) class CCPPDataProvider(DataProvider): def __init__(self, which_set='train', input_dims=None, batch_size=10, - max_num_batches=-1, shuffle_order=True, rng=None): + max_num_batches=-1, random_sampling=True, rng=None): """Create a new Combined Cycle Power Plant data provider object. Args: @@ -671,7 +673,7 @@ class CCPPDataProvider(DataProvider): in an epoch. If `max_num_batches * batch_size > num_data` then only as many batches as the data can be split into will be used. If set to -1 all of the data will be used. - shuffle_order (bool): Whether to randomly permute the order of + random_sampling (bool): Whether to randomly permute the order of the data before each epoch. rng (RandomState): A seeded random number generator. """ @@ -697,14 +699,14 @@ class CCPPDataProvider(DataProvider): inputs = inputs[:, input_dims] targets = loaded[which_set + '_targets'] super(CCPPDataProvider, self).__init__( - inputs, targets, batch_size, max_num_batches, shuffle_order, rng) + inputs, targets, batch_size, max_num_batches, random_sampling, rng) class AugmentedMNISTDataProvider(MNISTDataProvider): """Data provider for MNIST dataset which randomly transforms images.""" def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, - shuffle_order=True, rng=None, transformer=None): + random_sampling=True, rng=None, transformer=None): """Create a new augmented MNIST data provider object. Args: @@ -715,7 +717,7 @@ class AugmentedMNISTDataProvider(MNISTDataProvider): in an epoch. If `max_num_batches * batch_size > num_data` then only as many batches as the data can be split into will be used. If set to -1 all of the data will be used. - shuffle_order (bool): Whether to randomly permute the order of + random_sampling (bool): Whether to randomly permute the order of the data before each epoch. rng (RandomState): A seeded random number generator. transformer: Function which takes an `inputs` array of shape @@ -727,7 +729,7 @@ class AugmentedMNISTDataProvider(MNISTDataProvider): the data provider. """ super(AugmentedMNISTDataProvider, self).__init__( - which_set, batch_size, max_num_batches, shuffle_order, rng) + which_set, batch_size, max_num_batches, random_sampling, rng) self.transformer = transformer def next(self):