Updating MSD data providers to reshape and upcast.

This commit is contained in:
Matt Graham 2017-01-22 12:44:24 +00:00
parent 8507224d90
commit 5556963490

View File

@ -342,6 +342,8 @@ class MSD10GenreDataProvider(OneOfKDataProvider):
# load data from compressed numpy file
loaded = np.load(data_path)
inputs, targets = loaded['inputs'], loaded['targets']
# flatten inputs to vectors and upcast to float32
inputs = inputs.reshape((inputs.shape[0], -1)).astype('float32')
# label map gives strings corresponding to integer label targets
self.label_map = loaded['label_map']
# pass the loaded data to the parent class __init__
@ -387,6 +389,8 @@ class MSD25GenreDataProvider(OneOfKDataProvider):
# load data from compressed numpy file
loaded = np.load(data_path)
inputs, targets = loaded['inputs'], loaded['targets']
# flatten inputs to vectors and upcast to float32
inputs = inputs.reshape((inputs.shape[0], -1)).astype('float32')
# label map gives strings corresponding to integer label targets
self.label_map = loaded['label_map']
# pass the loaded data to the parent class __init__