From 822b946033572dfcaac9283d5caf01fcf8932afc Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Sat, 21 Jan 2017 19:34:01 +0000 Subject: [PATCH] Adding MSD data providers. --- mlp/data_providers.py | 92 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/mlp/data_providers.py b/mlp/data_providers.py index 1e6af7e..c99549d 100644 --- a/mlp/data_providers.py +++ b/mlp/data_providers.py @@ -135,6 +135,8 @@ class OneOfKDataProvider(DataProvider): """1-of-K classification target data provider. Transforms integer target labels to binary 1-of-K encoded targets. + + Derived classes must set self.num_classes appropriately. """ def next(self): @@ -302,6 +304,96 @@ class CIFAR100DataProvider(OneOfKDataProvider): inputs, targets, batch_size, max_num_batches, shuffle_order, rng) +class MSD10GenreDataProvider(OneOfKDataProvider): + """Data provider for Million Song Dataset 10-genre classification task.""" + + def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, + shuffle_order=True, rng=None): + """Create a new Million Song Dataset 10-genre data provider object. + + Args: + which_set: One of 'train' or 'valid'. Determines which + portion of the MSD 10-genre data this object should provide. + batch_size (int): Number of data points to include in each batch. + max_num_batches (int): Maximum number of batches to iterate over + in an epoch. If `max_num_batches * batch_size > num_data` then + only as many batches as the data can be split into will be + used. If set to -1 all of the data will be used. + shuffle_order (bool): Whether to randomly permute the order of + the data before each epoch. + rng (RandomState): A seeded random number generator. + """ + # check a valid which_set was provided + assert which_set in ['train', 'valid'], ( + 'Expected which_set to be either train or valid. ' + 'Got {0}'.format(which_set) + ) + self.which_set = which_set + self.num_classes = 10 + # construct path to data using os.path.join to ensure the correct path + # separator for the current platform / OS is used + # MLP_DATA_DIR environment variable should point to the data directory + data_path = os.path.join( + os.environ['MLP_DATA_DIR'], + 'msd-10-genre-{0}.npz'.format(which_set)) + assert os.path.isfile(data_path), ( + 'Data file does not exist at expected path: ' + data_path + ) + # load data from compressed numpy file + loaded = np.load(data_path) + inputs, targets = loaded['inputs'], loaded['targets'] + # label map gives strings corresponding to integer label targets + self.label_map = loaded['label_map'] + # pass the loaded data to the parent class __init__ + super(MSD10GenreDataProvider, self).__init__( + inputs, targets, batch_size, max_num_batches, shuffle_order, rng) + + +class MSD25GenreDataProvider(OneOfKDataProvider): + """Data provider for Million Song Dataset 25-genre classification task.""" + + def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, + shuffle_order=True, rng=None): + """Create a new Million Song Dataset 25-genre data provider object. + + Args: + which_set: One of 'train' or 'valid'. Determines which + portion of the MSD 25-genre data this object should provide. + batch_size (int): Number of data points to include in each batch. + max_num_batches (int): Maximum number of batches to iterate over + in an epoch. If `max_num_batches * batch_size > num_data` then + only as many batches as the data can be split into will be + used. If set to -1 all of the data will be used. + shuffle_order (bool): Whether to randomly permute the order of + the data before each epoch. + rng (RandomState): A seeded random number generator. + """ + # check a valid which_set was provided + assert which_set in ['train', 'valid'], ( + 'Expected which_set to be either train or valid. ' + 'Got {0}'.format(which_set) + ) + self.which_set = which_set + self.num_classes = 25 + # construct path to data using os.path.join to ensure the correct path + # separator for the current platform / OS is used + # MLP_DATA_DIR environment variable should point to the data directory + data_path = os.path.join( + os.environ['MLP_DATA_DIR'], + 'msd-25-genre-{0}.npz'.format(which_set)) + assert os.path.isfile(data_path), ( + 'Data file does not exist at expected path: ' + data_path + ) + # load data from compressed numpy file + loaded = np.load(data_path) + inputs, targets = loaded['inputs'], loaded['targets'] + # label map gives strings corresponding to integer label targets + self.label_map = loaded['label_map'] + # pass the loaded data to the parent class __init__ + super(MSD25GenreDataProvider, self).__init__( + inputs, targets, batch_size, max_num_batches, shuffle_order, rng) + + class MetOfficeDataProvider(DataProvider): """South Scotland Met Office weather data provider."""