diff --git a/data/ccpp_data.npz b/data/ccpp_data.npz
new file mode 100644
index 0000000..a507ba2
Binary files /dev/null and b/data/ccpp_data.npz differ
diff --git a/mlp/data_providers.py b/mlp/data_providers.py
index 78816c3..bbd9632 100644
--- a/mlp/data_providers.py
+++ b/mlp/data_providers.py
@@ -163,7 +163,7 @@ class MetOfficeDataProvider(DataProvider):
 
     def __init__(self, window_size, batch_size=10, max_num_batches=-1,
                  shuffle_order=True, rng=None):
-        """Create a new Met Offfice data provider object.
+        """Create a new Met Office data provider object.
 
         Args:
             window_size (int): Size of windows to split weather time series
@@ -204,3 +204,50 @@ class MetOfficeDataProvider(DataProvider):
         targets = windowed[:, -1]
         super(MetOfficeDataProvider, self).__init__(
             inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
+
+
+class CCPPDataProvider(DataProvider):
+
+    def __init__(self, which_set='train', input_dims=None, batch_size=10,
+                 max_num_batches=-1, shuffle_order=True, rng=None):
+        """Create a new Combined Cycle Power Plant data provider object.
+
+        Args:
+            which_set: One of 'train' or 'valid'. Determines which portion of
+                data this object should provide.
+            input_dims: Which of the four input dimension to use. If `None` all
+                are used. If an iterable of integers are provided (consisting
+                of a subset of {0, 1, 2, 3}) then only the corresponding
+                input dimensions are included.
+            batch_size (int): Number of data points to include in each batch.
+            max_num_batches (int): Maximum number of batches to iterate over
+                in an epoch. If `max_num_batches * batch_size > num_data` then
+                only as many batches as the data can be split into will be
+                used. If set to -1 all of the data will be used.
+            shuffle_order (bool): Whether to randomly permute the order of
+                the data before each epoch.
+            rng (RandomState): A seeded random number generator.
+        """
+        data_path = os.path.join(
+            os.environ['MLP_DATA_DIR'], 'ccpp_data.npz')
+        assert os.path.isfile(data_path), (
+            'Data file does not exist at expected path: ' + data_path
+        )
+        # check a valid which_set was provided
+        assert which_set in ['train', 'valid'], (
+            'Expected which_set to be either train or valid '
+            'Got {0}'.format(which_set)
+        )
+        # check input_dims are valid (only when a subset was actually given;
+        # the default None means "use all four input dimensions")
+        if input_dims is not None:
+            input_dims = set(input_dims)
+            assert input_dims.issubset({0, 1, 2, 3}), (
+                'input_dims should be a subset of {0, 1, 2, 3}'
+            )
+        loaded = np.load(data_path)
+        inputs = loaded[which_set + '_inputs']
+        if input_dims is not None:
+            # sorted list: NumPy cannot index with a set, and sorting keeps
+            # the selected columns in their original dimension order
+            inputs = inputs[:, sorted(input_dims)]
+        targets = loaded[which_set + '_targets']
+        super(CCPPDataProvider, self).__init__(
+            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)