Updating MSD data providers to reshape and upcast.
This commit is contained in:
parent
8507224d90
commit
5556963490
@ -342,6 +342,8 @@ class MSD10GenreDataProvider(OneOfKDataProvider):
|
||||
# load data from compressed numpy file
|
||||
loaded = np.load(data_path)
|
||||
inputs, targets = loaded['inputs'], loaded['targets']
|
||||
# flatten inputs to vectors and upcast to float32
|
||||
inputs = inputs.reshape((inputs.shape[0], -1)).astype('float32')
|
||||
# label map gives strings corresponding to integer label targets
|
||||
self.label_map = loaded['label_map']
|
||||
# pass the loaded data to the parent class __init__
|
||||
@ -387,6 +389,8 @@ class MSD25GenreDataProvider(OneOfKDataProvider):
|
||||
# load data from compressed numpy file
|
||||
loaded = np.load(data_path)
|
||||
inputs, targets = loaded['inputs'], loaded['targets']
|
||||
# flatten inputs to vectors and upcast to float32
|
||||
inputs = inputs.reshape((inputs.shape[0], -1)).astype('float32')
|
||||
# label map gives strings corresponding to integer label targets
|
||||
self.label_map = loaded['label_map']
|
||||
# pass the loaded data to the parent class __init__
|
||||
|
Loading…
Reference in New Issue
Block a user