diff --git a/mlp/data_providers.py b/mlp/data_providers.py
index bbd9632..c49e8f3 100644
--- a/mlp/data_providers.py
+++ b/mlp/data_providers.py
@@ -35,23 +35,54 @@ class DataProvider(object):
         """
         self.inputs = inputs
         self.targets = targets
-        self.batch_size = batch_size
-        assert max_num_batches != 0 and not max_num_batches < -1, (
-            'max_num_batches should be -1 or > 0')
-        self.max_num_batches = max_num_batches
+        if batch_size < 1:
+            raise ValueError('batch_size must be >= 1')
+        self._batch_size = batch_size
+        if max_num_batches == 0 or max_num_batches < -1:
+            raise ValueError('max_num_batches must be -1 or > 0')
+        self._max_num_batches = max_num_batches
+        self._update_num_batches()
+        self.shuffle_order = shuffle_order
+        self._current_order = np.arange(inputs.shape[0])
+        if rng is None:
+            rng = np.random.RandomState(DEFAULT_SEED)
+        self.rng = rng
+        self.new_epoch()
+
+    @property
+    def batch_size(self):
+        """Number of data points to include in each batch."""
+        return self._batch_size
+
+    @batch_size.setter
+    def batch_size(self, value):
+        if value < 1:
+            raise ValueError('batch_size must be >= 1')
+        self._batch_size = value
+        self._update_num_batches()
+
+    @property
+    def max_num_batches(self):
+        """Maximum number of batches to iterate over in an epoch."""
+        return self._max_num_batches
+
+    @max_num_batches.setter
+    def max_num_batches(self, value):
+        if value == 0 or value < -1:
+            raise ValueError('max_num_batches must be -1 or > 0')
+        self._max_num_batches = value
+        self._update_num_batches()
+
+    def _update_num_batches(self):
+        """Updates number of batches to iterate over."""
         # maximum possible number of batches is equal to number of whole times
         # batch_size divides in to the number of data points which can be
         # found using integer division
-        possible_num_batches = self.inputs.shape[0] // batch_size
+        possible_num_batches = self.inputs.shape[0] // self.batch_size
         if self.max_num_batches == -1:
            self.num_batches = possible_num_batches
         else:
             self.num_batches = min(self.max_num_batches, possible_num_batches)
-        self.shuffle_order = shuffle_order
-        if rng is None:
-            rng = np.random.RandomState(DEFAULT_SEED)
-        self.rng = rng
-        self.reset()
 
     def __iter__(self):
         """Implements Python iterator interface.
@@ -63,24 +94,33 @@ class DataProvider(object):
         """
         return self
 
-    def reset(self):
-        """Resets the provider to the initial state to use in a new epoch."""
+    def new_epoch(self):
+        """Starts a new epoch (pass through data), possibly shuffling first."""
         self._curr_batch = 0
         if self.shuffle_order:
             self.shuffle()
 
+    def reset(self):
+        """Resets the provider to the initial state."""
+        inv_perm = np.argsort(self._current_order)
+        self._current_order = self._current_order[inv_perm]
+        self.inputs = self.inputs[inv_perm]
+        self.targets = self.targets[inv_perm]
+        self.new_epoch()
+
     def shuffle(self):
         """Randomly shuffles order of data."""
-        new_order = self.rng.permutation(self.inputs.shape[0])
-        self.inputs = self.inputs[new_order]
-        self.targets = self.targets[new_order]
+        perm = self.rng.permutation(self.inputs.shape[0])
+        self._current_order = self._current_order[perm]
+        self.inputs = self.inputs[perm]
+        self.targets = self.targets[perm]
 
     def next(self):
         """Returns next data batch or raises `StopIteration` if at end."""
         if self._curr_batch + 1 > self.num_batches:
-            # no more batches in current iteration through data set so reset
-            # the dataset for another pass and indicate iteration is at end
-            self.reset()
+            # no more batches in current iteration through data set so start
+            # new epoch ready for another pass and indicate iteration is at end
+            self.new_epoch()
             raise StopIteration()
         # create an index slice corresponding to current batch number
         batch_slice = slice(self._curr_batch * self.batch_size,