# -*- coding: utf-8 -*-
# Reconstructed from the patch "Add mistakenly deleted data providers"
# (commit 8586a3194947e0a3f1af556e91267ca699e6e3bf, AntreasAntoniou,
# 1 Feb 2018), which creates data_providers.py.
"""Data providers.

This module provides classes for loading datasets and iterating over batches of
data points.
"""

import os

import numpy as np

DEFAULT_SEED = 22012018


def _check_which_set(which_set, valid_sets=('train', 'valid', 'test')):
    """Assert that `which_set` names one of the supported data portions."""
    options = ', '.join(valid_sets[:-1]) + ' or ' + valid_sets[-1]
    assert which_set in valid_sets, (
        'Expected which_set to be either {0}. '
        'Got {1}'.format(options, which_set)
    )


def _data_file_path(filename):
    """Return the path of `filename` inside MLP_DATA_DIR, asserting it exists.

    os.path.join is used so the correct path separator for the current
    platform is applied. The MLP_DATA_DIR environment variable should point
    to the data directory; a KeyError is raised if it is not set.
    """
    data_path = os.path.join(os.environ['MLP_DATA_DIR'], filename)
    assert os.path.isfile(data_path), (
        'Data file does not exist at expected path: ' + data_path
    )
    return data_path


def _load_npz_arrays(data_path, *keys):
    """Read the arrays stored under `keys` in an .npz archive.

    The archive is opened in a `with` block so its file handle is closed as
    soon as the requested arrays have been read into memory.
    """
    with np.load(data_path) as archive:
        return tuple(archive[key] for key in keys)


def _prepare_cifar_inputs(inputs, flatten):
    """Cast CIFAR inputs to float32, shape them and scale them by 1/255.

    With `flatten` each data point becomes a vector of length 32*32*3,
    otherwise the data is reshaped to (-1, 3, 32, 32) and the axes are
    permuted to (0, 2, 3, 1).
    """
    inputs = inputs.astype(np.float32)
    if flatten:
        inputs = inputs.reshape((-1, 32 * 32 * 3))
    else:
        inputs = inputs.reshape((-1, 3, 32, 32))
        inputs = np.transpose(inputs, axes=(0, 2, 3, 1))
    return inputs / 255.0


class DataProvider(object):
    """Generic data provider."""

    def __init__(self, inputs, targets, batch_size, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new data provider object.

        Args:
            inputs (ndarray): Array of data input features of shape
                (num_data, input_dim).
            targets (ndarray): Array of data output targets of shape
                (num_data, output_dim) or (num_data,) if output_dim == 1.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator. If `None` a
                `RandomState` seeded with `DEFAULT_SEED` is created.

        Raises:
            ValueError: If `batch_size < 1`, or if `max_num_batches` is 0 or
                less than -1.
        """
        self.inputs = inputs
        self.targets = targets
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1')
        self._batch_size = batch_size
        if max_num_batches == 0 or max_num_batches < -1:
            raise ValueError('max_num_batches must be -1 or > 0')
        self._max_num_batches = max_num_batches
        self._update_num_batches()
        self.shuffle_order = shuffle_order
        # tracks the permutation applied so far so that reset() can undo it
        self._current_order = np.arange(inputs.shape[0])
        if rng is None:
            rng = np.random.RandomState(DEFAULT_SEED)
        self.rng = rng
        self.new_epoch()

    @property
    def batch_size(self):
        """Number of data points to include in each batch."""
        return self._batch_size

    @batch_size.setter
    def batch_size(self, value):
        if value < 1:
            raise ValueError('batch_size must be >= 1')
        self._batch_size = value
        self._update_num_batches()

    @property
    def max_num_batches(self):
        """Maximum number of batches to iterate over in an epoch."""
        return self._max_num_batches

    @max_num_batches.setter
    def max_num_batches(self, value):
        if value == 0 or value < -1:
            raise ValueError('max_num_batches must be -1 or > 0')
        self._max_num_batches = value
        self._update_num_batches()

    def _update_num_batches(self):
        """Updates number of batches to iterate over."""
        # maximum possible number of batches is equal to number of whole times
        # batch_size divides in to the number of data points which can be
        # found using integer division
        possible_num_batches = self.inputs.shape[0] // self.batch_size
        if self.max_num_batches == -1:
            self.num_batches = possible_num_batches
        else:
            self.num_batches = min(self.max_num_batches, possible_num_batches)

    def __iter__(self):
        """Implements Python iterator interface.

        This should return an object implementing a `next` method which steps
        through a sequence returning one element at a time and raising
        `StopIteration` when at the end of the sequence. Here the object
        returned is the DataProvider itself.
        """
        return self

    def new_epoch(self):
        """Starts a new epoch (pass through data), possibly shuffling first."""
        self._curr_batch = 0
        if self.shuffle_order:
            self.shuffle()

    def __next__(self):
        """Python 3 iterator hook; delegates to `next`."""
        return self.next()

    def reset(self):
        """Resets the provider to the initial state."""
        # argsort of the accumulated permutation is its inverse, so indexing
        # with it restores the order the data was originally supplied in
        inv_perm = np.argsort(self._current_order)
        self._current_order = self._current_order[inv_perm]
        self.inputs = self.inputs[inv_perm]
        self.targets = self.targets[inv_perm]
        self.new_epoch()

    def shuffle(self):
        """Randomly shuffles order of data."""
        perm = self.rng.permutation(self.inputs.shape[0])
        self._current_order = self._current_order[perm]
        self.inputs = self.inputs[perm]
        self.targets = self.targets[perm]

    def next(self):
        """Returns next data batch or raises `StopIteration` if at end."""
        if self._curr_batch + 1 > self.num_batches:
            # no more batches in current iteration through data set so start
            # new epoch ready for another pass and indicate iteration is at end
            self.new_epoch()
            raise StopIteration()
        # create an index slice corresponding to current batch number
        batch_slice = slice(self._curr_batch * self.batch_size,
                            (self._curr_batch + 1) * self.batch_size)
        inputs_batch = self.inputs[batch_slice]
        targets_batch = self.targets[batch_slice]
        self._curr_batch += 1
        return inputs_batch, targets_batch


class _OneOfKDataProvider(DataProvider):
    """Shared base for providers whose targets are integer class labels.

    Subclasses set `num_classes` and, optionally, `one_hot`. When `one_hot`
    is true the targets of every batch are converted to a 1 of K coding.
    """

    # class-level default; subclasses override it per class or per instance
    one_hot = False

    def next(self):
        """Returns next data batch or raises `StopIteration` if at end."""
        inputs_batch, targets_batch = super(_OneOfKDataProvider, self).next()
        if self.one_hot:
            return inputs_batch, self.to_one_of_k(targets_batch)
        return inputs_batch, targets_batch

    def to_one_of_k(self, int_targets):
        """Converts integer coded class target to 1 of K coded targets.

        Args:
            int_targets (ndarray): Array of integer coded class targets (i.e.
                where an integer from 0 to `num_classes` - 1 is used to
                indicate which is the correct class). This should be of shape
                (num_data,).

        Returns:
            Array of 1 of K coded targets i.e. an array of shape
            (num_data, num_classes) where for each row all elements are equal
            to zero except for the column corresponding to the correct class
            which is equal to one.
        """
        num_data = int_targets.shape[0]
        one_of_k_targets = np.zeros((num_data, self.num_classes))
        one_of_k_targets[np.arange(num_data), int_targets] = 1
        return one_of_k_targets


class MNISTDataProvider(_OneOfKDataProvider):
    """Data provider for MNIST handwritten digit images."""

    # MNIST batches always carry 1 of K coded targets
    one_hot = True

    def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new MNIST data provider object.

        Args:
            which_set: One of 'train', 'valid' or 'test'. Determines which
                portion of the MNIST data this object should provide.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.
        """
        # check a valid which_set was provided
        _check_which_set(which_set)
        self.which_set = which_set
        self.num_classes = 10
        data_path = _data_file_path('mnist-{0}.npz'.format(which_set))
        # load data from compressed numpy file
        inputs, targets = _load_npz_arrays(data_path, 'inputs', 'targets')
        inputs = inputs.astype(np.float32)
        # pass the loaded data to the parent class __init__
        super(MNISTDataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)


class EMNISTDataProvider(_OneOfKDataProvider):
    """Data provider for EMNIST handwritten images (47 classes)."""

    def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None, flatten=False, one_hot=False):
        """Create a new EMNIST data provider object.

        Args:
            which_set: One of 'train', 'valid' or 'test'. Determines which
                portion of the EMNIST data this object should provide.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.
            flatten (bool): If true each input is reshaped to a vector of
                length 28*28, otherwise a trailing axis of size one is
                appended to the loaded inputs.
            one_hot (bool): Whether batches should carry 1 of K coded targets
                instead of integer coded ones.
        """
        # check a valid which_set was provided
        _check_which_set(which_set)
        self.one_hot = one_hot
        self.which_set = which_set
        self.num_classes = 47
        data_path = _data_file_path('emnist-{0}.npz'.format(which_set))
        # load data from compressed numpy file
        inputs, targets = _load_npz_arrays(data_path, 'inputs', 'targets')
        inputs = inputs.astype(np.float32)
        if flatten:
            inputs = inputs.reshape((-1, 28 * 28))
        else:
            inputs = np.expand_dims(inputs, axis=3)
        inputs = inputs / 255.0
        # pass the loaded data to the parent class __init__
        super(EMNISTDataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)


class CIFAR10DataProvider(_OneOfKDataProvider):
    """Data provider for CIFAR-10 object images."""

    def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None, flatten=False, one_hot=False):
        """Create a new CIFAR-10 data provider object.

        Args:
            which_set: One of 'train', 'valid' or 'test'. Determines which
                portion of the CIFAR-10 data this object should provide.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.
            flatten (bool): If true each input is reshaped to a vector of
                length 32*32*3, otherwise inputs are reshaped to
                (-1, 3, 32, 32) and transposed with axes (0, 2, 3, 1).
            one_hot (bool): Whether batches should carry 1 of K coded targets
                instead of integer coded ones.
        """
        # check a valid which_set was provided
        _check_which_set(which_set)
        self.one_hot = one_hot
        self.which_set = which_set
        self.num_classes = 10
        data_path = _data_file_path('cifar10-{0}.npz'.format(which_set))
        # load data from compressed numpy file
        inputs, targets = _load_npz_arrays(data_path, 'inputs', 'targets')
        inputs = _prepare_cifar_inputs(inputs, flatten)
        # pass the loaded data to the parent class __init__
        super(CIFAR10DataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)


class CIFAR100DataProvider(_OneOfKDataProvider):
    """Data provider for CIFAR-100 object images."""

    def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None, flatten=False, one_hot=False):
        """Create a new CIFAR-100 data provider object.

        Args:
            which_set: One of 'train', 'valid' or 'test'. Determines which
                portion of the CIFAR-100 data this object should provide.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.
            flatten (bool): If true each input is reshaped to a vector of
                length 32*32*3, otherwise inputs are reshaped to
                (-1, 3, 32, 32) and transposed with axes (0, 2, 3, 1).
            one_hot (bool): Whether batches should carry 1 of K coded targets
                instead of integer coded ones.
        """
        # check a valid which_set was provided
        _check_which_set(which_set)
        self.one_hot = one_hot
        self.which_set = which_set
        self.num_classes = 100
        data_path = _data_file_path('cifar100-{0}.npz'.format(which_set))
        # load data from compressed numpy file
        inputs, targets = _load_npz_arrays(data_path, 'inputs', 'targets')
        inputs = _prepare_cifar_inputs(inputs, flatten)
        # pass the loaded data to the parent class __init__
        super(CIFAR100DataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)


class MSD10GenreDataProvider(_OneOfKDataProvider):
    """Data provider for Million Song Dataset 10-genre classification task."""

    def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None, one_hot=False, flatten=True):
        """Create a new MSD 10-genre data provider object.

        Args:
            which_set: One of 'train', 'valid' or 'test'. Determines which
                portion of the data this object should provide. The 'test'
                portion is read from separate inputs and targets files.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.
            one_hot (bool): Whether batches should carry 1 of K coded targets
                instead of integer coded ones.
            flatten (bool): If true each input is reshaped to a vector of
                length 120*25.
        """
        # check a valid which_set was provided
        _check_which_set(which_set)
        self.one_hot = one_hot
        self.which_set = which_set
        self.num_classes = 10
        # compare by value: identity comparison against a string literal is
        # not guaranteed to hold for equal strings
        if which_set != 'test':
            data_path = _data_file_path('msd10-{0}.npz'.format(which_set))
            # load data from compressed numpy file
            inputs, targets = _load_npz_arrays(data_path, 'inputs', 'targets')
        else:
            input_data_path = _data_file_path('msd-10-genre-test-inputs.npz')
            # check the targets file itself (not the inputs file a second time)
            target_data_path = _data_file_path(
                'msd-10-genre-test-targets.npz')
            # load data from compressed numpy files
            (inputs,) = _load_npz_arrays(input_data_path, 'inputs')
            (targets,) = _load_npz_arrays(target_data_path, 'targets')
        if flatten:
            inputs = inputs.reshape((-1, 120 * 25))
        # pass the loaded data to the parent class __init__
        super(MSD10GenreDataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)


class MSD25GenreDataProvider(_OneOfKDataProvider):
    """Data provider for Million Song Dataset 25-genre classification task."""

    def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None, one_hot=False, flatten=True):
        """Create a new MSD 25-genre data provider object.

        Args:
            which_set: One of 'train', 'valid' or 'test'. Determines which
                portion of the data this object should provide.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.
            one_hot (bool): Whether batches should carry 1 of K coded targets
                instead of integer coded ones.
            flatten (bool): If true each input is reshaped to a vector of
                length 120*25.
        """
        # check a valid which_set was provided
        _check_which_set(which_set)
        self.one_hot = one_hot
        self.which_set = which_set
        self.num_classes = 25
        # NOTE(review): this reads the 'msd10-*' file although num_classes is
        # 25 - looks like a copy-paste from the 10-genre provider; confirm the
        # intended file name before changing it.
        data_path = _data_file_path('msd10-{0}.npz'.format(which_set))
        # load data from compressed numpy file
        inputs, targets = _load_npz_arrays(data_path, 'inputs', 'targets')
        if flatten:
            inputs = inputs.reshape((-1, 120 * 25))
        # pass the loaded data to the parent class __init__
        super(MSD25GenreDataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)


class MetOfficeDataProvider(DataProvider):
    """South Scotland Met Office weather data provider."""

    def __init__(self, window_size, batch_size=10, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new Met Office data provider object.

        Args:
            window_size (int): Size of windows to split weather time series
                data into. The constructed input features will be the first
                `window_size - 1` entries in each window and the target outputs
                the last entry in each window.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.
        """
        data_path = _data_file_path('HadSSP_daily_qc.txt')
        raw = np.loadtxt(data_path, skiprows=3, usecols=range(2, 32))
        assert window_size > 1, 'window_size must be at least 2.'
        self.window_size = window_size
        # filter out all missing datapoints and flatten to a vector
        filtered = raw[raw >= 0].flatten()
        # normalise data to zero mean, unit standard deviation
        mean = np.mean(filtered)
        std = np.std(filtered)
        normalised = (filtered - mean) / std
        # create a view on to array corresponding to a rolling window
        shape = (normalised.shape[-1] - self.window_size + 1, self.window_size)
        strides = normalised.strides + (normalised.strides[-1],)
        windowed = np.lib.stride_tricks.as_strided(
            normalised, shape=shape, strides=strides)
        # inputs are first (window_size - 1) entries in windows
        inputs = windowed[:, :-1]
        # targets are last entry in windows
        targets = windowed[:, -1]
        super(MetOfficeDataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)


class CCPPDataProvider(DataProvider):
    """Combined Cycle Power Plant data provider."""

    def __init__(self, which_set='train', input_dims=None, batch_size=10,
                 max_num_batches=-1, shuffle_order=True, rng=None):
        """Create a new Combined Cycle Power Plant data provider object.

        Args:
            which_set: One of 'train' or 'valid'. Determines which portion of
                data this object should provide.
            input_dims: Which of the four input dimension to use. If `None` all
                are used. If an iterable of integers are provided (consisting
                of a subset of {0, 1, 2, 3}) then only the corresponding
                input dimensions are included.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.
        """
        data_path = _data_file_path('ccpp_data.npz')
        # check a valid which_set was provided
        _check_which_set(which_set, ('train', 'valid'))
        # check input_dims are valid; only done when some were given, and a
        # list (not a set) is kept because it is used to index columns below
        if input_dims is not None:
            input_dims = list(input_dims)
            assert set(input_dims).issubset({0, 1, 2, 3}), (
                'input_dims should be a subset of {0, 1, 2, 3}'
            )
        inputs, targets = _load_npz_arrays(
            data_path, which_set + '_inputs', which_set + '_targets')
        if input_dims is not None:
            inputs = inputs[:, input_dims]
        super(CCPPDataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)


class AugmentedMNISTDataProvider(MNISTDataProvider):
    """Data provider for MNIST dataset which randomly transforms images."""

    def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None, transformer=None):
        """Create a new augmented MNIST data provider object.

        Args:
            which_set: One of 'train', 'valid' or 'test'. Determines which
                portion of the MNIST data this object should provide.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.
            transformer: Function which takes an `inputs` array of shape
                (batch_size, input_dim) corresponding to a batch of input
                images and a `rng` random number generator object (i.e. a
                call signature `transformer(inputs, rng)`) and applies a
                potentially random set of transformations to some / all of the
                input images as each new batch is returned when iterating over
                the data provider. If `None` batches are returned unchanged.
        """
        super(AugmentedMNISTDataProvider, self).__init__(
            which_set, batch_size, max_num_batches, shuffle_order, rng)
        self.transformer = transformer

    def next(self):
        """Returns next data batch or raises `StopIteration` if at end."""
        inputs_batch, targets_batch = super(
            AugmentedMNISTDataProvider, self).next()
        if self.transformer is None:
            # no transformer configured: pass the batch through untouched
            return inputs_batch, targets_batch
        transformed_inputs_batch = self.transformer(inputs_batch, self.rng)
        return transformed_inputs_batch, targets_batch