diff --git a/mlp/data_providers.py b/mlp/data_providers.py index 3868a66..9166b17 100644 --- a/mlp/data_providers.py +++ b/mlp/data_providers.py @@ -284,6 +284,8 @@ class CCPPDataProvider(DataProvider): assert input_dims.issubset({0, 1, 2, 3}), ( 'input_dims should be a subset of {0, 1, 2, 3}' ) + # convert to list as cannot index ndarray by set + input_dims = list(input_dims) loaded = np.load(data_path) inputs = loaded[which_set + '_inputs'] if input_dims is not None: