diff --git a/mlp/data_providers.py b/mlp/data_providers.py index 78816c3..a90e733 100644 --- a/mlp/data_providers.py +++ b/mlp/data_providers.py @@ -133,10 +133,10 @@ class MNISTDataProvider(DataProvider): super(MNISTDataProvider, self).__init__( inputs, targets, batch_size, max_num_batches, shuffle_order, rng) - def next(self): - """Returns next data batch or raises `StopIteration` if at end.""" - inputs_batch, targets_batch = super(MNISTDataProvider, self).next() - return inputs_batch, self.to_one_of_k(targets_batch) + #def next(self): + # """Returns next data batch or raises `StopIteration` if at end.""" + # inputs_batch, targets_batch = super(MNISTDataProvider, self).next() + # return inputs_batch, self.to_one_of_k(targets_batch) def to_one_of_k(self, int_targets): """Converts integer coded class target to 1 of K coded targets. @@ -153,9 +153,7 @@ class MNISTDataProvider(DataProvider): to zero except for the column corresponding to the correct class which is equal to one. """ - one_of_k_targets = np.zeros((int_targets.shape[0], self.num_classes)) - one_of_k_targets[range(int_targets.shape[0]), int_targets] = 1 - return one_of_k_targets + raise NotImplementedError() class MetOfficeDataProvider(DataProvider): @@ -179,28 +177,25 @@ class MetOfficeDataProvider(DataProvider): the data before each epoch. rng (RandomState): A seeded random number generator. """ + self.window_size = window_size + assert window_size > 1, 'window_size must be at least 2.' data_path = os.path.join( os.environ['MLP_DATA_DIR'], 'HadSSP_daily_qc.txt') assert os.path.isfile(data_path), ( 'Data file does not exist at expected path: ' + data_path ) - raw = np.loadtxt(data_path, skiprows=3, usecols=range(2, 32)) - assert window_size > 1, 'window_size must be at least 2.' - self.window_size = window_size + # load raw data from text file + # ... # filter out all missing datapoints and flatten to a vector - filtered = raw[raw >= 0].flatten() + # ... # normalise data to zero mean, unit standard deviation - mean = np.mean(filtered) - std = np.std(filtered) - normalised = (filtered - mean) / std - # create a view on to array corresponding to a rolling window - shape = (normalised.shape[-1] - self.window_size + 1, self.window_size) - strides = normalised.strides + (normalised.strides[-1],) - windowed = np.lib.stride_tricks.as_strided( - normalised, shape=shape, strides=strides) + # ... + # convert from flat sequence to windowed data + # ... # inputs are first (window_size - 1) entries in windows - inputs = windowed[:, :-1] + # inputs = ... # targets are last entry in windows - targets = windowed[:, -1] - super(MetOfficeDataProvider, self).__init__( - inputs, targets, batch_size, max_num_batches, shuffle_order, rng) + # targets = ... + # initialise base class with inputs and targets arrays + # super(MetOfficeDataProvider, self).__init__( + # inputs, targets, batch_size, max_num_batches, shuffle_order, rng)