Adding ability to adjust batch size after initialisation and resetting to initial state (order).
This commit is contained in:
parent
f979aae573
commit
8d6f37668f
@ -35,23 +35,54 @@ class DataProvider(object):
|
||||
"""
|
||||
self.inputs = inputs
|
||||
self.targets = targets
|
||||
self.batch_size = batch_size
|
||||
assert max_num_batches != 0 and not max_num_batches < -1, (
|
||||
'max_num_batches should be -1 or > 0')
|
||||
self.max_num_batches = max_num_batches
|
||||
if batch_size < 1:
|
||||
raise ValueError('batch_size must be >= 1')
|
||||
self._batch_size = batch_size
|
||||
if max_num_batches == 0 or max_num_batches < -1:
|
||||
raise ValueError('max_num_batches must be -1 or > 0')
|
||||
self._max_num_batches = max_num_batches
|
||||
self._update_num_batches()
|
||||
self.shuffle_order = shuffle_order
|
||||
self._current_order = np.arange(inputs.shape[0])
|
||||
if rng is None:
|
||||
rng = np.random.RandomState(DEFAULT_SEED)
|
||||
self.rng = rng
|
||||
self.new_epoch()
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
"""Number of data points to include in each batch."""
|
||||
return self._batch_size
|
||||
|
||||
@batch_size.setter
|
||||
def batch_size(self, value):
|
||||
if value < 1:
|
||||
raise ValueError('batch_size must be >= 1')
|
||||
self._batch_size = value
|
||||
self._update_num_batches()
|
||||
|
||||
@property
|
||||
def max_num_batches(self):
|
||||
"""Maximum number of batches to iterate over in an epoch."""
|
||||
return self._max_num_batches
|
||||
|
||||
@max_num_batches.setter
|
||||
def max_num_batches(self, value):
|
||||
if value == 0 or value < -1:
|
||||
raise ValueError('max_num_batches must be -1 or > 0')
|
||||
self._max_num_batches = value
|
||||
self._update_num_batches()
|
||||
|
||||
def _update_num_batches(self):
|
||||
"""Updates number of batches to iterate over."""
|
||||
# maximum possible number of batches is equal to number of whole times
|
||||
# batch_size divides in to the number of data points which can be
|
||||
# found using integer division
|
||||
possible_num_batches = self.inputs.shape[0] // batch_size
|
||||
possible_num_batches = self.inputs.shape[0] // self.batch_size
|
||||
if self.max_num_batches == -1:
|
||||
self.num_batches = possible_num_batches
|
||||
else:
|
||||
self.num_batches = min(self.max_num_batches, possible_num_batches)
|
||||
self.shuffle_order = shuffle_order
|
||||
if rng is None:
|
||||
rng = np.random.RandomState(DEFAULT_SEED)
|
||||
self.rng = rng
|
||||
self.reset()
|
||||
|
||||
def __iter__(self):
|
||||
"""Implements Python iterator interface.
|
||||
@ -63,24 +94,33 @@ class DataProvider(object):
|
||||
"""
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
"""Resets the provider to the initial state to use in a new epoch."""
|
||||
def new_epoch(self):
|
||||
"""Starts a new epoch (pass through data), possibly shuffling first."""
|
||||
self._curr_batch = 0
|
||||
if self.shuffle_order:
|
||||
self.shuffle()
|
||||
|
||||
def reset(self):
|
||||
"""Resets the provider to the initial state."""
|
||||
inv_perm = np.argsort(self._current_order)
|
||||
self._current_order = self._current_order[inv_perm]
|
||||
self.inputs = self.inputs[inv_perm]
|
||||
self.targets = self.targets[inv_perm]
|
||||
self.new_epoch()
|
||||
|
||||
def shuffle(self):
|
||||
"""Randomly shuffles order of data."""
|
||||
new_order = self.rng.permutation(self.inputs.shape[0])
|
||||
self.inputs = self.inputs[new_order]
|
||||
self.targets = self.targets[new_order]
|
||||
perm = self.rng.permutation(self.inputs.shape[0])
|
||||
self._current_order = self._current_order[perm]
|
||||
self.inputs = self.inputs[perm]
|
||||
self.targets = self.targets[perm]
|
||||
|
||||
def next(self):
|
||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||
if self._curr_batch + 1 > self.num_batches:
|
||||
# no more batches in current iteration through data set so reset
|
||||
# the dataset for another pass and indicate iteration is at end
|
||||
self.reset()
|
||||
# no more batches in current iteration through data set so start
|
||||
# new epoch ready for another pass and indicate iteration is at end
|
||||
self.new_epoch()
|
||||
raise StopIteration()
|
||||
# create an index slice corresponding to current batch number
|
||||
batch_slice = slice(self._curr_batch * self.batch_size,
|
||||
|
Loading…
Reference in New Issue
Block a user