From 55569634903fde2cc0452f956a908e8cd32ef35c Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Sun, 22 Jan 2017 12:44:24 +0000 Subject: [PATCH] Updating MSD data providers to reshape and upcast. --- mlp/data_providers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlp/data_providers.py b/mlp/data_providers.py index c99549d..d786c1c 100644 --- a/mlp/data_providers.py +++ b/mlp/data_providers.py @@ -342,6 +342,8 @@ class MSD10GenreDataProvider(OneOfKDataProvider): # load data from compressed numpy file loaded = np.load(data_path) inputs, targets = loaded['inputs'], loaded['targets'] + # flatten inputs to vectors and upcast to float32 + inputs = inputs.reshape((inputs.shape[0], -1)).astype('float32') # label map gives strings corresponding to integer label targets self.label_map = loaded['label_map'] # pass the loaded data to the parent class __init__ @@ -387,6 +389,8 @@ class MSD25GenreDataProvider(OneOfKDataProvider): # load data from compressed numpy file loaded = np.load(data_path) inputs, targets = loaded['inputs'], loaded['targets'] + # flatten inputs to vectors and upcast to float32 + inputs = inputs.reshape((inputs.shape[0], -1)).astype('float32') # label map gives strings corresponding to integer label targets self.label_map = loaded['label_map'] # pass the loaded data to the parent class __init__