class OneOfKDataProvider(DataProvider):
    """1-of-K classification target data provider.

    Transforms integer target labels to binary 1-of-K encoded targets.

    Derived classes (e.g. the MNIST and CIFAR providers) must set a
    `num_classes` attribute giving the number of classes K, as it is read
    by `to_one_of_k` when building the encoded targets.
    """

    def next(self):
        """Returns next data batch or raises `StopIteration` if at end.

        Returns:
            Tuple `(inputs_batch, targets_batch)` as produced by the parent
            provider, except that the targets have been converted from
            integer labels to 1-of-K coding.
        """
        inputs_batch, targets_batch = super(OneOfKDataProvider, self).next()
        return inputs_batch, self.to_one_of_k(targets_batch)

    def to_one_of_k(self, int_targets):
        """Converts integer coded class target to 1-of-K coded targets.

        Args:
            int_targets (ndarray): Array of integer coded class targets (i.e.
                where an integer from 0 to `num_classes` - 1 is used to
                indicate which is the correct class). This should be of shape
                (num_data,).

        Returns:
            Array of 1-of-K coded targets i.e. an array of shape
            (num_data, num_classes) where for each row all elements are equal
            to zero except for the column corresponding to the correct class
            which is equal to one.
        """
        num_data = int_targets.shape[0]
        one_of_k_targets = np.zeros((num_data, self.num_classes))
        # pair every row index with its class index so that exactly one
        # entry per row is switched on
        one_of_k_targets[np.arange(num_data), int_targets] = 1
        return one_of_k_targets


class CIFAR10DataProvider(OneOfKDataProvider):
    """Data provider for CIFAR-10 object images."""

    def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new CIFAR-10 data provider object.

        Args:
            which_set: One of 'train' or 'valid'. Determines which
                portion of the CIFAR-10 data this object should provide.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.

        Raises:
            AssertionError: If `which_set` is not 'train' or 'valid', or if
                no data file exists at the expected path.
            KeyError: If the `MLP_DATA_DIR` environment variable is not set.
        """
        # check a valid which_set was provided
        assert which_set in ['train', 'valid'], (
            'Expected which_set to be either train or valid. '
            'Got {0}'.format(which_set)
        )
        self.which_set = which_set
        self.num_classes = 10
        # construct path to data using os.path.join to ensure the correct path
        # separator for the current platform / OS is used
        # MLP_DATA_DIR environment variable should point to the data directory
        data_path = os.path.join(
            os.environ['MLP_DATA_DIR'], 'cifar-10-{0}.npz'.format(which_set))
        assert os.path.isfile(data_path), (
            'Data file does not exist at expected path: ' + data_path
        )
        # load data from compressed numpy file; the archive object keeps the
        # underlying file open, so read everything needed inside a `with`
        # block to make sure the file handle is released afterwards
        with np.load(data_path) as loaded:
            inputs, targets = loaded['inputs'], loaded['targets']
            # label map gives strings corresponding to integer label targets
            self.label_map = loaded['label_map']
        # pass the loaded data to the parent class __init__
        super(CIFAR10DataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)


class CIFAR100DataProvider(OneOfKDataProvider):
    """Data provider for CIFAR-100 object images."""

    def __init__(self, which_set='train', use_coarse_targets=False,
                 batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new CIFAR-100 data provider object.

        Args:
            which_set: One of 'train' or 'valid'. Determines which
                portion of the CIFAR-100 data this object should provide.
            use_coarse_targets: Whether to use coarse 'superclass' labels as
                targets instead of standard class labels. This selects 20
                rather than 100 target classes.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.

        Raises:
            AssertionError: If `which_set` is not 'train' or 'valid', or if
                no data file exists at the expected path.
            KeyError: If the `MLP_DATA_DIR` environment variable is not set.
        """
        # check a valid which_set was provided
        assert which_set in ['train', 'valid'], (
            'Expected which_set to be either train or valid. '
            'Got {0}'.format(which_set)
        )
        self.which_set = which_set
        self.use_coarse_targets = use_coarse_targets
        self.num_classes = 20 if use_coarse_targets else 100
        # construct path to data using os.path.join to ensure the correct path
        # separator for the current platform / OS is used
        # MLP_DATA_DIR environment variable should point to the data directory
        data_path = os.path.join(
            os.environ['MLP_DATA_DIR'], 'cifar-100-{0}.npz'.format(which_set))
        assert os.path.isfile(data_path), (
            'Data file does not exist at expected path: ' + data_path
        )
        # load data from compressed numpy file; the archive object keeps the
        # underlying file open, so read everything needed inside a `with`
        # block to make sure the file handle is released afterwards
        with np.load(data_path) as loaded:
            inputs, targets = loaded['inputs'], loaded['targets']
            # label map gives strings corresponding to integer label targets
            label_map_key = (
                'coarse_label_map' if use_coarse_targets else 'fine_label_map'
            )
            self.label_map = loaded[label_map_key]
        # the loaded targets are two-dimensional: column 1 is used for the
        # coarse labels and column 0 for the standard class labels
        targets = targets[:, 1] if use_coarse_targets else targets[:, 0]
        # pass the loaded data to the parent class __init__
        super(CIFAR100DataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)