Adding MSD data providers.

This commit is contained in:
Matt Graham 2017-01-21 19:34:01 +00:00
parent 510a1ca7be
commit 822b946033

View File

@@ -135,6 +135,8 @@ class OneOfKDataProvider(DataProvider):
"""1-of-K classification target data provider.
Transforms integer target labels to binary 1-of-K encoded targets.
Derived classes must set self.num_classes appropriately.
"""
def next(self):
@@ -302,6 +304,96 @@ class CIFAR100DataProvider(OneOfKDataProvider):
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
class MSD10GenreDataProvider(OneOfKDataProvider):
    """Data provider for Million Song Dataset 10-genre classification task."""

    def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new Million Song Dataset 10-genre data provider object.

        Args:
            which_set: One of 'train' or 'valid'. Determines which
                portion of the MSD 10-genre data this object should provide.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.

        Raises:
            AssertionError: If `which_set` is not 'train' or 'valid', or if
                no data file exists at the expected path.
            KeyError: If the MLP_DATA_DIR environment variable is not set.
        """
        # check a valid which_set was provided
        assert which_set in ['train', 'valid'], (
            'Expected which_set to be either train or valid. '
            'Got {0}'.format(which_set)
        )
        self.which_set = which_set
        self.num_classes = 10
        # construct path to data using os.path.join to ensure the correct path
        # separator for the current platform / OS is used
        # MLP_DATA_DIR environment variable should point to the data directory
        data_path = os.path.join(
            os.environ['MLP_DATA_DIR'],
            'msd-10-genre-{0}.npz'.format(which_set))
        assert os.path.isfile(data_path), (
            'Data file does not exist at expected path: ' + data_path
        )
        # load data from the numpy .npz archive; np.load keeps the archive's
        # file handle open, so read every array we need inside the `with`
        # block and let the context manager close the handle afterwards
        with np.load(data_path) as loaded:
            inputs, targets = loaded['inputs'], loaded['targets']
            # label map gives strings corresponding to integer label targets
            self.label_map = loaded['label_map']
        # pass the loaded data to the parent class __init__
        super(MSD10GenreDataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
class MSD25GenreDataProvider(OneOfKDataProvider):
    """Data provider for Million Song Dataset 25-genre classification task."""

    def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new Million Song Dataset 25-genre data provider object.

        Args:
            which_set: One of 'train' or 'valid'. Determines which
                portion of the MSD 25-genre data this object should provide.
            batch_size (int): Number of data points to include in each batch.
            max_num_batches (int): Maximum number of batches to iterate over
                in an epoch. If `max_num_batches * batch_size > num_data` then
                only as many batches as the data can be split into will be
                used. If set to -1 all of the data will be used.
            shuffle_order (bool): Whether to randomly permute the order of
                the data before each epoch.
            rng (RandomState): A seeded random number generator.

        Raises:
            AssertionError: If `which_set` is not 'train' or 'valid', or if
                no data file exists at the expected path.
            KeyError: If the MLP_DATA_DIR environment variable is not set.
        """
        # check a valid which_set was provided
        assert which_set in ['train', 'valid'], (
            'Expected which_set to be either train or valid. '
            'Got {0}'.format(which_set)
        )
        self.which_set = which_set
        self.num_classes = 25
        # construct path to data using os.path.join to ensure the correct path
        # separator for the current platform / OS is used
        # MLP_DATA_DIR environment variable should point to the data directory
        data_path = os.path.join(
            os.environ['MLP_DATA_DIR'],
            'msd-25-genre-{0}.npz'.format(which_set))
        assert os.path.isfile(data_path), (
            'Data file does not exist at expected path: ' + data_path
        )
        # load data from the numpy .npz archive; np.load keeps the archive's
        # file handle open, so read every array we need inside the `with`
        # block and let the context manager close the handle afterwards
        with np.load(data_path) as loaded:
            inputs, targets = loaded['inputs'], loaded['targets']
            # label map gives strings corresponding to integer label targets
            self.label_map = loaded['label_map']
        # pass the loaded data to the parent class __init__
        super(MSD25GenreDataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
class MetOfficeDataProvider(DataProvider):
"""South Scotland Met Office weather data provider."""