diff --git a/mlp/data_providers.py b/mlp/data_providers.py index c99549d..d786c1c 100644 --- a/mlp/data_providers.py +++ b/mlp/data_providers.py @@ -342,6 +342,8 @@ class MSD10GenreDataProvider(OneOfKDataProvider): # load data from compressed numpy file loaded = np.load(data_path) inputs, targets = loaded['inputs'], loaded['targets'] + # flatten inputs to vectors and upcast to float32 + inputs = inputs.reshape((inputs.shape[0], -1)).astype('float32') # label map gives strings corresponding to integer label targets self.label_map = loaded['label_map'] # pass the loaded data to the parent class __init__ @@ -387,6 +389,8 @@ class MSD25GenreDataProvider(OneOfKDataProvider): # load data from compressed numpy file loaded = np.load(data_path) inputs, targets = loaded['inputs'], loaded['targets'] + # flatten inputs to vectors and upcast to float32 + inputs = inputs.reshape((inputs.shape[0], -1)).astype('float32') # label map gives strings corresponding to integer label targets self.label_map = loaded['label_map'] # pass the loaded data to the parent class __init__