Making sure input_dims set converted to list.

This commit is contained in:
Matt Graham 2016-10-18 13:44:30 +01:00
parent 5d4c979ef3
commit 237cff12db

View File

@ -284,6 +284,8 @@ class CCPPDataProvider(DataProvider):
assert input_dims.issubset({0, 1, 2, 3}), ( assert input_dims.issubset({0, 1, 2, 3}), (
'input_dims should be a subset of {0, 1, 2, 3}' 'input_dims should be a subset of {0, 1, 2, 3}'
) )
# convert to list as cannot index ndarray by set
input_dims = list(input_dims)
loaded = np.load(data_path) loaded = np.load(data_path)
inputs = loaded[which_set + '_inputs'] inputs = loaded[which_set + '_inputs']
if input_dims is not None: if input_dims is not None: