Adding CIFAR data providers.
This commit is contained in:
parent
c4a662465e
commit
510a1ca7be
@ -131,7 +131,38 @@ class DataProvider(object):
|
|||||||
return inputs_batch, targets_batch
|
return inputs_batch, targets_batch
|
||||||
|
|
||||||
|
|
||||||
class MNISTDataProvider(DataProvider):
|
class OneOfKDataProvider(DataProvider):
|
||||||
|
"""1-of-K classification target data provider.
|
||||||
|
|
||||||
|
Transforms integer target labels to binary 1-of-K encoded targets.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def next(self):
|
||||||
|
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||||
|
inputs_batch, targets_batch = super(OneOfKDataProvider, self).next()
|
||||||
|
return inputs_batch, self.to_one_of_k(targets_batch)
|
||||||
|
|
||||||
|
def to_one_of_k(self, int_targets):
|
||||||
|
"""Converts integer coded class target to 1-of-K coded targets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
int_targets (ndarray): Array of integer coded class targets (i.e.
|
||||||
|
where an integer from 0 to `num_classes` - 1 is used to
|
||||||
|
indicate which is the correct class). This should be of shape
|
||||||
|
(num_data,).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Array of 1-of-K coded targets i.e. an array of shape
|
||||||
|
(num_data, num_classes) where for each row all elements are equal
|
||||||
|
to zero except for the column corresponding to the correct class
|
||||||
|
which is equal to one.
|
||||||
|
"""
|
||||||
|
one_of_k_targets = np.zeros((int_targets.shape[0], self.num_classes))
|
||||||
|
one_of_k_targets[range(int_targets.shape[0]), int_targets] = 1
|
||||||
|
return one_of_k_targets
|
||||||
|
|
||||||
|
|
||||||
|
class MNISTDataProvider(OneOfKDataProvider):
|
||||||
"""Data provider for MNIST handwritten digit images."""
|
"""Data provider for MNIST handwritten digit images."""
|
||||||
|
|
||||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||||
@ -173,29 +204,102 @@ class MNISTDataProvider(DataProvider):
|
|||||||
super(MNISTDataProvider, self).__init__(
|
super(MNISTDataProvider, self).__init__(
|
||||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||||
|
|
||||||
def next(self):
|
|
||||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
|
||||||
inputs_batch, targets_batch = super(MNISTDataProvider, self).next()
|
|
||||||
return inputs_batch, self.to_one_of_k(targets_batch)
|
|
||||||
|
|
||||||
def to_one_of_k(self, int_targets):
|
class CIFAR10DataProvider(OneOfKDataProvider):
|
||||||
"""Converts integer coded class target to 1 of K coded targets.
|
"""Data provider for CIFAR-10 object images."""
|
||||||
|
|
||||||
|
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||||
|
shuffle_order=True, rng=None):
|
||||||
|
"""Create a new CIFAR-10 data provider object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
int_targets (ndarray): Array of integer coded class targets (i.e.
|
which_set: One of 'train' or 'valid'. Determines which
|
||||||
where an integer from 0 to `num_classes` - 1 is used to
|
portion of the CIFAR-10 data this object should provide.
|
||||||
indicate which is the correct class). This should be of shape
|
batch_size (int): Number of data points to include in each batch.
|
||||||
(num_data,).
|
max_num_batches (int): Maximum number of batches to iterate over
|
||||||
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
Returns:
|
only as many batches as the data can be split into will be
|
||||||
Array of 1 of K coded targets i.e. an array of shape
|
used. If set to -1 all of the data will be used.
|
||||||
(num_data, num_classes) where for each row all elements are equal
|
shuffle_order (bool): Whether to randomly permute the order of
|
||||||
to zero except for the column corresponding to the correct class
|
the data before each epoch.
|
||||||
which is equal to one.
|
rng (RandomState): A seeded random number generator.
|
||||||
"""
|
"""
|
||||||
one_of_k_targets = np.zeros((int_targets.shape[0], self.num_classes))
|
# check a valid which_set was provided
|
||||||
one_of_k_targets[range(int_targets.shape[0]), int_targets] = 1
|
assert which_set in ['train', 'valid'], (
|
||||||
return one_of_k_targets
|
'Expected which_set to be either train or valid. '
|
||||||
|
'Got {0}'.format(which_set)
|
||||||
|
)
|
||||||
|
self.which_set = which_set
|
||||||
|
self.num_classes = 10
|
||||||
|
# construct path to data using os.path.join to ensure the correct path
|
||||||
|
# separator for the current platform / OS is used
|
||||||
|
# MLP_DATA_DIR environment variable should point to the data directory
|
||||||
|
data_path = os.path.join(
|
||||||
|
os.environ['MLP_DATA_DIR'], 'cifar-10-{0}.npz'.format(which_set))
|
||||||
|
assert os.path.isfile(data_path), (
|
||||||
|
'Data file does not exist at expected path: ' + data_path
|
||||||
|
)
|
||||||
|
# load data from compressed numpy file
|
||||||
|
loaded = np.load(data_path)
|
||||||
|
inputs, targets = loaded['inputs'], loaded['targets']
|
||||||
|
# label map gives strings corresponding to integer label targets
|
||||||
|
self.label_map = loaded['label_map']
|
||||||
|
# pass the loaded data to the parent class __init__
|
||||||
|
super(CIFAR10DataProvider, self).__init__(
|
||||||
|
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||||
|
|
||||||
|
|
||||||
|
class CIFAR100DataProvider(OneOfKDataProvider):
|
||||||
|
"""Data provider for CIFAR-100 object images."""
|
||||||
|
|
||||||
|
def __init__(self, which_set='train', use_coarse_targets=False,
|
||||||
|
batch_size=100, max_num_batches=-1,
|
||||||
|
shuffle_order=True, rng=None):
|
||||||
|
"""Create a new CIFAR-100 data provider object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
which_set: One of 'train' or 'valid'. Determines which
|
||||||
|
portion of the CIFAR-100 data this object should provide.
|
||||||
|
use_coarse_targets: Whether to use coarse 'superclass' labels as
|
||||||
|
targets instead of standard class labels.
|
||||||
|
batch_size (int): Number of data points to include in each batch.
|
||||||
|
max_num_batches (int): Maximum number of batches to iterate over
|
||||||
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
|
only as many batches as the data can be split into will be
|
||||||
|
used. If set to -1 all of the data will be used.
|
||||||
|
shuffle_order (bool): Whether to randomly permute the order of
|
||||||
|
the data before each epoch.
|
||||||
|
rng (RandomState): A seeded random number generator.
|
||||||
|
"""
|
||||||
|
# check a valid which_set was provided
|
||||||
|
assert which_set in ['train', 'valid'], (
|
||||||
|
'Expected which_set to be either train or valid. '
|
||||||
|
'Got {0}'.format(which_set)
|
||||||
|
)
|
||||||
|
self.which_set = which_set
|
||||||
|
self.use_coarse_targets = use_coarse_targets
|
||||||
|
self.num_classes = 20 if use_coarse_targets else 100
|
||||||
|
# construct path to data using os.path.join to ensure the correct path
|
||||||
|
# separator for the current platform / OS is used
|
||||||
|
# MLP_DATA_DIR environment variable should point to the data directory
|
||||||
|
data_path = os.path.join(
|
||||||
|
os.environ['MLP_DATA_DIR'], 'cifar-100-{0}.npz'.format(which_set))
|
||||||
|
assert os.path.isfile(data_path), (
|
||||||
|
'Data file does not exist at expected path: ' + data_path
|
||||||
|
)
|
||||||
|
# load data from compressed numpy file
|
||||||
|
loaded = np.load(data_path)
|
||||||
|
inputs, targets = loaded['inputs'], loaded['targets']
|
||||||
|
targets = targets[:, 1] if use_coarse_targets else targets[:, 0]
|
||||||
|
# label map gives strings corresponding to integer label targets
|
||||||
|
self.label_map = (
|
||||||
|
loaded['coarse_label_map']
|
||||||
|
if use_coarse_targets else
|
||||||
|
loaded['fine_label_map']
|
||||||
|
)
|
||||||
|
# pass the loaded data to the parent class __init__
|
||||||
|
super(CIFAR100DataProvider, self).__init__(
|
||||||
|
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||||
|
|
||||||
|
|
||||||
class MetOfficeDataProvider(DataProvider):
|
class MetOfficeDataProvider(DataProvider):
|
||||||
|
Loading…
Reference in New Issue
Block a user