Updated data provider with better sampling
This commit is contained in:
parent
ee4adeaa66
commit
4e22eefbce
@ -15,7 +15,7 @@ class DataProvider(object):
|
|||||||
"""Generic data provider."""
|
"""Generic data provider."""
|
||||||
|
|
||||||
def __init__(self, inputs, targets, batch_size, max_num_batches=-1,
|
def __init__(self, inputs, targets, batch_size, max_num_batches=-1,
|
||||||
shuffle_order=True, rng=None):
|
random_sampling=True, rng=None):
|
||||||
"""Create a new data provider object.
|
"""Create a new data provider object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -28,23 +28,29 @@ class DataProvider(object):
|
|||||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
only as many batches as the data can be split into will be
|
only as many batches as the data can be split into will be
|
||||||
used. If set to -1 all of the data will be used.
|
used. If set to -1 all of the data will be used.
|
||||||
shuffle_order (bool): Whether to randomly permute the order of
|
random_sampling (bool): Whether to randomly permute the order of
|
||||||
the data before each epoch.
|
the data before each epoch.
|
||||||
rng (RandomState): A seeded random number generator.
|
rng (RandomState): A seeded random number generator.
|
||||||
"""
|
"""
|
||||||
self.inputs = inputs
|
self.inputs = inputs
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
|
|
||||||
if batch_size < 1:
|
if batch_size < 1:
|
||||||
raise ValueError('batch_size must be >= 1')
|
raise ValueError('batch_size must be >= 1')
|
||||||
|
|
||||||
self._batch_size = batch_size
|
self._batch_size = batch_size
|
||||||
|
|
||||||
if max_num_batches == 0 or max_num_batches < -1:
|
if max_num_batches == 0 or max_num_batches < -1:
|
||||||
raise ValueError('max_num_batches must be -1 or > 0')
|
raise ValueError('max_num_batches must be -1 or > 0')
|
||||||
|
|
||||||
self._max_num_batches = max_num_batches
|
self._max_num_batches = max_num_batches
|
||||||
self._update_num_batches()
|
self._update_num_batches()
|
||||||
self.shuffle_order = shuffle_order
|
self.random_sampling = random_sampling
|
||||||
self._current_order = np.arange(inputs.shape[0])
|
self._current_order = np.arange(inputs.shape[0])
|
||||||
|
|
||||||
if rng is None:
|
if rng is None:
|
||||||
rng = np.random.RandomState(DEFAULT_SEED)
|
rng = np.random.RandomState(DEFAULT_SEED)
|
||||||
|
|
||||||
self.rng = rng
|
self.rng = rng
|
||||||
self.new_epoch()
|
self.new_epoch()
|
||||||
|
|
||||||
@ -96,8 +102,6 @@ class DataProvider(object):
|
|||||||
def new_epoch(self):
|
def new_epoch(self):
|
||||||
"""Starts a new epoch (pass through data), possibly shuffling first."""
|
"""Starts a new epoch (pass through data), possibly shuffling first."""
|
||||||
self._curr_batch = 0
|
self._curr_batch = 0
|
||||||
if self.shuffle_order:
|
|
||||||
self.shuffle()
|
|
||||||
|
|
||||||
def __next__(self):
|
def __next__(self):
|
||||||
return self.next()
|
return self.next()
|
||||||
@ -110,13 +114,6 @@ class DataProvider(object):
|
|||||||
self.targets = self.targets[inv_perm]
|
self.targets = self.targets[inv_perm]
|
||||||
self.new_epoch()
|
self.new_epoch()
|
||||||
|
|
||||||
def shuffle(self):
|
|
||||||
"""Randomly shuffles order of data."""
|
|
||||||
perm = self.rng.permutation(self.inputs.shape[0])
|
|
||||||
self._current_order = self._current_order[perm]
|
|
||||||
self.inputs = self.inputs[perm]
|
|
||||||
self.targets = self.targets[perm]
|
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||||
if self._curr_batch + 1 > self.num_batches:
|
if self._curr_batch + 1 > self.num_batches:
|
||||||
@ -125,8 +122,13 @@ class DataProvider(object):
|
|||||||
self.new_epoch()
|
self.new_epoch()
|
||||||
raise StopIteration()
|
raise StopIteration()
|
||||||
# create an index slice corresponding to current batch number
|
# create an index slice corresponding to current batch number
|
||||||
batch_slice = slice(self._curr_batch * self.batch_size,
|
if self.random_sampling:
|
||||||
(self._curr_batch + 1) * self.batch_size)
|
batch_slice = self.rng.choice(self.inputs.shape[0], size=self.batch_size, replace=False)
|
||||||
|
|
||||||
|
else:
|
||||||
|
batch_slice = slice(self._curr_batch * self.batch_size,
|
||||||
|
(self._curr_batch + 1) * self.batch_size)
|
||||||
|
|
||||||
inputs_batch = self.inputs[batch_slice]
|
inputs_batch = self.inputs[batch_slice]
|
||||||
targets_batch = self.targets[batch_slice]
|
targets_batch = self.targets[batch_slice]
|
||||||
self._curr_batch += 1
|
self._curr_batch += 1
|
||||||
@ -136,7 +138,7 @@ class MNISTDataProvider(DataProvider):
|
|||||||
"""Data provider for MNIST handwritten digit images."""
|
"""Data provider for MNIST handwritten digit images."""
|
||||||
|
|
||||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||||
shuffle_order=True, rng=None):
|
random_sampling=True, rng=None):
|
||||||
"""Create a new MNIST data provider object.
|
"""Create a new MNIST data provider object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -147,7 +149,7 @@ class MNISTDataProvider(DataProvider):
|
|||||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
only as many batches as the data can be split into will be
|
only as many batches as the data can be split into will be
|
||||||
used. If set to -1 all of the data will be used.
|
used. If set to -1 all of the data will be used.
|
||||||
shuffle_order (bool): Whether to randomly permute the order of
|
random_sampling (bool): Whether to randomly permute the order of
|
||||||
the data before each epoch.
|
the data before each epoch.
|
||||||
rng (RandomState): A seeded random number generator.
|
rng (RandomState): A seeded random number generator.
|
||||||
"""
|
"""
|
||||||
@ -172,7 +174,7 @@ class MNISTDataProvider(DataProvider):
|
|||||||
inputs = inputs.astype(np.float32)
|
inputs = inputs.astype(np.float32)
|
||||||
# pass the loaded data to the parent class __init__
|
# pass the loaded data to the parent class __init__
|
||||||
super(MNISTDataProvider, self).__init__(
|
super(MNISTDataProvider, self).__init__(
|
||||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||||
@ -202,7 +204,7 @@ class EMNISTDataProvider(DataProvider):
|
|||||||
"""Data provider for EMNIST handwritten digit images."""
|
"""Data provider for EMNIST handwritten digit images."""
|
||||||
|
|
||||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||||
shuffle_order=True, rng=None, flatten=False, one_hot=False):
|
random_sampling=True, rng=None, flatten=False, one_hot=False):
|
||||||
"""Create a new EMNIST data provider object.
|
"""Create a new EMNIST data provider object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -213,7 +215,7 @@ class EMNISTDataProvider(DataProvider):
|
|||||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
only as many batches as the data can be split into will be
|
only as many batches as the data can be split into will be
|
||||||
used. If set to -1 all of the data will be used.
|
used. If set to -1 all of the data will be used.
|
||||||
shuffle_order (bool): Whether to randomly permute the order of
|
random_sampling (bool): Whether to randomly permute the order of
|
||||||
the data before each epoch.
|
the data before each epoch.
|
||||||
rng (RandomState): A seeded random number generator.
|
rng (RandomState): A seeded random number generator.
|
||||||
"""
|
"""
|
||||||
@ -246,7 +248,7 @@ class EMNISTDataProvider(DataProvider):
|
|||||||
|
|
||||||
# pass the loaded data to the parent class __init__
|
# pass the loaded data to the parent class __init__
|
||||||
super(EMNISTDataProvider, self).__init__(
|
super(EMNISTDataProvider, self).__init__(
|
||||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||||
@ -279,7 +281,7 @@ class CIFAR10DataProvider(DataProvider):
|
|||||||
"""Data provider for CIFAR-10 object images."""
|
"""Data provider for CIFAR-10 object images."""
|
||||||
|
|
||||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||||
shuffle_order=True, rng=None, flatten=False, one_hot=False):
|
random_sampling=True, rng=None, flatten=False, one_hot=False):
|
||||||
"""Create a new EMNIST data provider object.
|
"""Create a new EMNIST data provider object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -290,7 +292,7 @@ class CIFAR10DataProvider(DataProvider):
|
|||||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
only as many batches as the data can be split into will be
|
only as many batches as the data can be split into will be
|
||||||
used. If set to -1 all of the data will be used.
|
used. If set to -1 all of the data will be used.
|
||||||
shuffle_order (bool): Whether to randomly permute the order of
|
random_sampling (bool): Whether to randomly permute the order of
|
||||||
the data before each epoch.
|
the data before each epoch.
|
||||||
rng (RandomState): A seeded random number generator.
|
rng (RandomState): A seeded random number generator.
|
||||||
"""
|
"""
|
||||||
@ -327,7 +329,7 @@ class CIFAR10DataProvider(DataProvider):
|
|||||||
|
|
||||||
# pass the loaded data to the parent class __init__
|
# pass the loaded data to the parent class __init__
|
||||||
super(CIFAR10DataProvider, self).__init__(
|
super(CIFAR10DataProvider, self).__init__(
|
||||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||||
@ -362,7 +364,7 @@ class CIFAR100DataProvider(DataProvider):
|
|||||||
"""Data provider for CIFAR-100 object images."""
|
"""Data provider for CIFAR-100 object images."""
|
||||||
|
|
||||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||||
shuffle_order=True, rng=None, flatten=False, one_hot=False):
|
random_sampling=True, rng=None, flatten=False, one_hot=False):
|
||||||
"""Create a new EMNIST data provider object.
|
"""Create a new EMNIST data provider object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -373,7 +375,7 @@ class CIFAR100DataProvider(DataProvider):
|
|||||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
only as many batches as the data can be split into will be
|
only as many batches as the data can be split into will be
|
||||||
used. If set to -1 all of the data will be used.
|
used. If set to -1 all of the data will be used.
|
||||||
shuffle_order (bool): Whether to randomly permute the order of
|
random_sampling (bool): Whether to randomly permute the order of
|
||||||
the data before each epoch.
|
the data before each epoch.
|
||||||
rng (RandomState): A seeded random number generator.
|
rng (RandomState): A seeded random number generator.
|
||||||
"""
|
"""
|
||||||
@ -407,7 +409,7 @@ class CIFAR100DataProvider(DataProvider):
|
|||||||
|
|
||||||
# pass the loaded data to the parent class __init__
|
# pass the loaded data to the parent class __init__
|
||||||
super(CIFAR100DataProvider, self).__init__(
|
super(CIFAR100DataProvider, self).__init__(
|
||||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||||
@ -441,7 +443,7 @@ class MSD10GenreDataProvider(DataProvider):
|
|||||||
"""Data provider for Million Song Dataset 10-genre classification task."""
|
"""Data provider for Million Song Dataset 10-genre classification task."""
|
||||||
|
|
||||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||||
shuffle_order=True, rng=None, one_hot=False, flatten=True):
|
random_sampling=True, rng=None, one_hot=False, flatten=True):
|
||||||
"""Create a new EMNIST data provider object.
|
"""Create a new EMNIST data provider object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -452,7 +454,7 @@ class MSD10GenreDataProvider(DataProvider):
|
|||||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
only as many batches as the data can be split into will be
|
only as many batches as the data can be split into will be
|
||||||
used. If set to -1 all of the data will be used.
|
used. If set to -1 all of the data will be used.
|
||||||
shuffle_order (bool): Whether to randomly permute the order of
|
random_sampling (bool): Whether to randomly permute the order of
|
||||||
the data before each epoch.
|
the data before each epoch.
|
||||||
rng (RandomState): A seeded random number generator.
|
rng (RandomState): A seeded random number generator.
|
||||||
"""
|
"""
|
||||||
@ -500,7 +502,7 @@ class MSD10GenreDataProvider(DataProvider):
|
|||||||
|
|
||||||
# pass the loaded data to the parent class __init__
|
# pass the loaded data to the parent class __init__
|
||||||
super(MSD10GenreDataProvider, self).__init__(
|
super(MSD10GenreDataProvider, self).__init__(
|
||||||
inputs, target, batch_size, max_num_batches, shuffle_order, rng)
|
inputs, target, batch_size, max_num_batches, random_sampling, rng)
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||||
@ -533,7 +535,7 @@ class MSD25GenreDataProvider(DataProvider):
|
|||||||
"""Data provider for Million Song Dataset 25-genre classification task."""
|
"""Data provider for Million Song Dataset 25-genre classification task."""
|
||||||
|
|
||||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||||
shuffle_order=True, rng=None, one_hot=False, flatten=True):
|
random_sampling=True, rng=None, one_hot=False, flatten=True):
|
||||||
"""Create a new EMNIST data provider object.
|
"""Create a new EMNIST data provider object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -544,7 +546,7 @@ class MSD25GenreDataProvider(DataProvider):
|
|||||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
only as many batches as the data can be split into will be
|
only as many batches as the data can be split into will be
|
||||||
used. If set to -1 all of the data will be used.
|
used. If set to -1 all of the data will be used.
|
||||||
shuffle_order (bool): Whether to randomly permute the order of
|
random_sampling (bool): Whether to randomly permute the order of
|
||||||
the data before each epoch.
|
the data before each epoch.
|
||||||
rng (RandomState): A seeded random number generator.
|
rng (RandomState): A seeded random number generator.
|
||||||
"""
|
"""
|
||||||
@ -575,7 +577,7 @@ class MSD25GenreDataProvider(DataProvider):
|
|||||||
#inputs, target
|
#inputs, target
|
||||||
# pass the loaded data to the parent class __init__
|
# pass the loaded data to the parent class __init__
|
||||||
super(MSD25GenreDataProvider, self).__init__(
|
super(MSD25GenreDataProvider, self).__init__(
|
||||||
inputs, target, batch_size, max_num_batches, shuffle_order, rng)
|
inputs, target, batch_size, max_num_batches, random_sampling, rng)
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||||
@ -610,7 +612,7 @@ class MetOfficeDataProvider(DataProvider):
|
|||||||
"""South Scotland Met Office weather data provider."""
|
"""South Scotland Met Office weather data provider."""
|
||||||
|
|
||||||
def __init__(self, window_size, batch_size=10, max_num_batches=-1,
|
def __init__(self, window_size, batch_size=10, max_num_batches=-1,
|
||||||
shuffle_order=True, rng=None):
|
random_sampling=True, rng=None):
|
||||||
"""Create a new Met Office data provider object.
|
"""Create a new Met Office data provider object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -623,7 +625,7 @@ class MetOfficeDataProvider(DataProvider):
|
|||||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
only as many batches as the data can be split into will be
|
only as many batches as the data can be split into will be
|
||||||
used. If set to -1 all of the data will be used.
|
used. If set to -1 all of the data will be used.
|
||||||
shuffle_order (bool): Whether to randomly permute the order of
|
random_sampling (bool): Whether to randomly permute the order of
|
||||||
the data before each epoch.
|
the data before each epoch.
|
||||||
rng (RandomState): A seeded random number generator.
|
rng (RandomState): A seeded random number generator.
|
||||||
"""
|
"""
|
||||||
@ -651,12 +653,12 @@ class MetOfficeDataProvider(DataProvider):
|
|||||||
# targets are last entry in windows
|
# targets are last entry in windows
|
||||||
targets = windowed[:, -1]
|
targets = windowed[:, -1]
|
||||||
super(MetOfficeDataProvider, self).__init__(
|
super(MetOfficeDataProvider, self).__init__(
|
||||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
|
||||||
|
|
||||||
class CCPPDataProvider(DataProvider):
|
class CCPPDataProvider(DataProvider):
|
||||||
|
|
||||||
def __init__(self, which_set='train', input_dims=None, batch_size=10,
|
def __init__(self, which_set='train', input_dims=None, batch_size=10,
|
||||||
max_num_batches=-1, shuffle_order=True, rng=None):
|
max_num_batches=-1, random_sampling=True, rng=None):
|
||||||
"""Create a new Combined Cycle Power Plant data provider object.
|
"""Create a new Combined Cycle Power Plant data provider object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -671,7 +673,7 @@ class CCPPDataProvider(DataProvider):
|
|||||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
only as many batches as the data can be split into will be
|
only as many batches as the data can be split into will be
|
||||||
used. If set to -1 all of the data will be used.
|
used. If set to -1 all of the data will be used.
|
||||||
shuffle_order (bool): Whether to randomly permute the order of
|
random_sampling (bool): Whether to randomly permute the order of
|
||||||
the data before each epoch.
|
the data before each epoch.
|
||||||
rng (RandomState): A seeded random number generator.
|
rng (RandomState): A seeded random number generator.
|
||||||
"""
|
"""
|
||||||
@ -697,14 +699,14 @@ class CCPPDataProvider(DataProvider):
|
|||||||
inputs = inputs[:, input_dims]
|
inputs = inputs[:, input_dims]
|
||||||
targets = loaded[which_set + '_targets']
|
targets = loaded[which_set + '_targets']
|
||||||
super(CCPPDataProvider, self).__init__(
|
super(CCPPDataProvider, self).__init__(
|
||||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
inputs, targets, batch_size, max_num_batches, random_sampling, rng)
|
||||||
|
|
||||||
|
|
||||||
class AugmentedMNISTDataProvider(MNISTDataProvider):
|
class AugmentedMNISTDataProvider(MNISTDataProvider):
|
||||||
"""Data provider for MNIST dataset which randomly transforms images."""
|
"""Data provider for MNIST dataset which randomly transforms images."""
|
||||||
|
|
||||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||||
shuffle_order=True, rng=None, transformer=None):
|
random_sampling=True, rng=None, transformer=None):
|
||||||
"""Create a new augmented MNIST data provider object.
|
"""Create a new augmented MNIST data provider object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -715,7 +717,7 @@ class AugmentedMNISTDataProvider(MNISTDataProvider):
|
|||||||
in an epoch. If `max_num_batches * batch_size > num_data` then
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
only as many batches as the data can be split into will be
|
only as many batches as the data can be split into will be
|
||||||
used. If set to -1 all of the data will be used.
|
used. If set to -1 all of the data will be used.
|
||||||
shuffle_order (bool): Whether to randomly permute the order of
|
random_sampling (bool): Whether to randomly permute the order of
|
||||||
the data before each epoch.
|
the data before each epoch.
|
||||||
rng (RandomState): A seeded random number generator.
|
rng (RandomState): A seeded random number generator.
|
||||||
transformer: Function which takes an `inputs` array of shape
|
transformer: Function which takes an `inputs` array of shape
|
||||||
@ -727,7 +729,7 @@ class AugmentedMNISTDataProvider(MNISTDataProvider):
|
|||||||
the data provider.
|
the data provider.
|
||||||
"""
|
"""
|
||||||
super(AugmentedMNISTDataProvider, self).__init__(
|
super(AugmentedMNISTDataProvider, self).__init__(
|
||||||
which_set, batch_size, max_num_batches, shuffle_order, rng)
|
which_set, batch_size, max_num_batches, random_sampling, rng)
|
||||||
self.transformer = transformer
|
self.transformer = transformer
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user