Add the ability to adjust the batch size after initialisation and to reset the provider to its initial state (order).
This commit is contained in:
parent
f979aae573
commit
8d6f37668f
@ -35,23 +35,54 @@ class DataProvider(object):
|
|||||||
"""
|
"""
|
||||||
self.inputs = inputs
|
self.inputs = inputs
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
self.batch_size = batch_size
|
if batch_size < 1:
|
||||||
assert max_num_batches != 0 and not max_num_batches < -1, (
|
raise ValueError('batch_size must be >= 1')
|
||||||
'max_num_batches should be -1 or > 0')
|
self._batch_size = batch_size
|
||||||
self.max_num_batches = max_num_batches
|
if max_num_batches == 0 or max_num_batches < -1:
|
||||||
|
raise ValueError('max_num_batches must be -1 or > 0')
|
||||||
|
self._max_num_batches = max_num_batches
|
||||||
|
self._update_num_batches()
|
||||||
|
self.shuffle_order = shuffle_order
|
||||||
|
self._current_order = np.arange(inputs.shape[0])
|
||||||
|
if rng is None:
|
||||||
|
rng = np.random.RandomState(DEFAULT_SEED)
|
||||||
|
self.rng = rng
|
||||||
|
self.new_epoch()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def batch_size(self):
|
||||||
|
"""Number of data points to include in each batch."""
|
||||||
|
return self._batch_size
|
||||||
|
|
||||||
|
@batch_size.setter
|
||||||
|
def batch_size(self, value):
|
||||||
|
if value < 1:
|
||||||
|
raise ValueError('batch_size must be >= 1')
|
||||||
|
self._batch_size = value
|
||||||
|
self._update_num_batches()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_num_batches(self):
|
||||||
|
"""Maximum number of batches to iterate over in an epoch."""
|
||||||
|
return self._max_num_batches
|
||||||
|
|
||||||
|
@max_num_batches.setter
|
||||||
|
def max_num_batches(self, value):
|
||||||
|
if value == 0 or value < -1:
|
||||||
|
raise ValueError('max_num_batches must be -1 or > 0')
|
||||||
|
self._max_num_batches = value
|
||||||
|
self._update_num_batches()
|
||||||
|
|
||||||
|
def _update_num_batches(self):
|
||||||
|
"""Updates number of batches to iterate over."""
|
||||||
# maximum possible number of batches is equal to number of whole times
|
# maximum possible number of batches is equal to number of whole times
|
||||||
# batch_size divides in to the number of data points which can be
|
# batch_size divides in to the number of data points which can be
|
||||||
# found using integer division
|
# found using integer division
|
||||||
possible_num_batches = self.inputs.shape[0] // batch_size
|
possible_num_batches = self.inputs.shape[0] // self.batch_size
|
||||||
if self.max_num_batches == -1:
|
if self.max_num_batches == -1:
|
||||||
self.num_batches = possible_num_batches
|
self.num_batches = possible_num_batches
|
||||||
else:
|
else:
|
||||||
self.num_batches = min(self.max_num_batches, possible_num_batches)
|
self.num_batches = min(self.max_num_batches, possible_num_batches)
|
||||||
self.shuffle_order = shuffle_order
|
|
||||||
if rng is None:
|
|
||||||
rng = np.random.RandomState(DEFAULT_SEED)
|
|
||||||
self.rng = rng
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
"""Implements Python iterator interface.
|
"""Implements Python iterator interface.
|
||||||
@ -63,24 +94,33 @@ class DataProvider(object):
|
|||||||
"""
|
"""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def reset(self):
|
def new_epoch(self):
|
||||||
"""Resets the provider to the initial state to use in a new epoch."""
|
"""Starts a new epoch (pass through data), possibly shuffling first."""
|
||||||
self._curr_batch = 0
|
self._curr_batch = 0
|
||||||
if self.shuffle_order:
|
if self.shuffle_order:
|
||||||
self.shuffle()
|
self.shuffle()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Resets the provider to the initial state."""
|
||||||
|
inv_perm = np.argsort(self._current_order)
|
||||||
|
self._current_order = self._current_order[inv_perm]
|
||||||
|
self.inputs = self.inputs[inv_perm]
|
||||||
|
self.targets = self.targets[inv_perm]
|
||||||
|
self.new_epoch()
|
||||||
|
|
||||||
def shuffle(self):
|
def shuffle(self):
|
||||||
"""Randomly shuffles order of data."""
|
"""Randomly shuffles order of data."""
|
||||||
new_order = self.rng.permutation(self.inputs.shape[0])
|
perm = self.rng.permutation(self.inputs.shape[0])
|
||||||
self.inputs = self.inputs[new_order]
|
self._current_order = self._current_order[perm]
|
||||||
self.targets = self.targets[new_order]
|
self.inputs = self.inputs[perm]
|
||||||
|
self.targets = self.targets[perm]
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||||
if self._curr_batch + 1 > self.num_batches:
|
if self._curr_batch + 1 > self.num_batches:
|
||||||
# no more batches in current iteration through data set so reset
|
# no more batches in current iteration through data set so start
|
||||||
# the dataset for another pass and indicate iteration is at end
|
# new epoch ready for another pass and indicate iteration is at end
|
||||||
self.reset()
|
self.new_epoch()
|
||||||
raise StopIteration()
|
raise StopIteration()
|
||||||
# create an index slice corresponding to current batch number
|
# create an index slice corresponding to current batch number
|
||||||
batch_slice = slice(self._curr_batch * self.batch_size,
|
batch_slice = slice(self._curr_batch * self.batch_size,
|
||||||
|
Loading…
Reference in New Issue
Block a user