Changing data providers to skeleton code for first lab.

This commit is contained in:
Matt Graham 2016-09-22 15:41:46 +01:00
parent 04fe5a7279
commit d560f96a90

View File

@ -133,10 +133,10 @@ class MNISTDataProvider(DataProvider):
super(MNISTDataProvider, self).__init__( super(MNISTDataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng) inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
def next(self): #def next(self):
"""Returns next data batch or raises `StopIteration` if at end.""" # """Returns next data batch or raises `StopIteration` if at end."""
inputs_batch, targets_batch = super(MNISTDataProvider, self).next() # inputs_batch, targets_batch = super(MNISTDataProvider, self).next()
return inputs_batch, self.to_one_of_k(targets_batch) # return inputs_batch, self.to_one_of_k(targets_batch)
def to_one_of_k(self, int_targets): def to_one_of_k(self, int_targets):
"""Converts integer coded class target to 1 of K coded targets. """Converts integer coded class target to 1 of K coded targets.
@ -153,9 +153,7 @@ class MNISTDataProvider(DataProvider):
to zero except for the column corresponding to the correct class to zero except for the column corresponding to the correct class
which is equal to one. which is equal to one.
""" """
one_of_k_targets = np.zeros((int_targets.shape[0], self.num_classes)) raise NotImplementedError()
one_of_k_targets[range(int_targets.shape[0]), int_targets] = 1
return one_of_k_targets
class MetOfficeDataProvider(DataProvider): class MetOfficeDataProvider(DataProvider):
@ -179,28 +177,25 @@ class MetOfficeDataProvider(DataProvider):
the data before each epoch. the data before each epoch.
rng (RandomState): A seeded random number generator. rng (RandomState): A seeded random number generator.
""" """
self.window_size = window_size
assert window_size > 1, 'window_size must be at least 2.'
data_path = os.path.join( data_path = os.path.join(
os.environ['MLP_DATA_DIR'], 'HadSSP_daily_qc.txt') os.environ['MLP_DATA_DIR'], 'HadSSP_daily_qc.txt')
assert os.path.isfile(data_path), ( assert os.path.isfile(data_path), (
'Data file does not exist at expected path: ' + data_path 'Data file does not exist at expected path: ' + data_path
) )
raw = np.loadtxt(data_path, skiprows=3, usecols=range(2, 32)) # load raw data from text file
assert window_size > 1, 'window_size must be at least 2.' # ...
self.window_size = window_size
# filter out all missing datapoints and flatten to a vector # filter out all missing datapoints and flatten to a vector
filtered = raw[raw >= 0].flatten() # ...
# normalise data to zero mean, unit standard deviation # normalise data to zero mean, unit standard deviation
mean = np.mean(filtered) # ...
std = np.std(filtered) # convert from flat sequence to windowed data
normalised = (filtered - mean) / std # ...
# create a view on to array corresponding to a rolling window
shape = (normalised.shape[-1] - self.window_size + 1, self.window_size)
strides = normalised.strides + (normalised.strides[-1],)
windowed = np.lib.stride_tricks.as_strided(
normalised, shape=shape, strides=strides)
# inputs are first (window_size - 1) entries in windows # inputs are first (window_size - 1) entries in windows
inputs = windowed[:, :-1] # inputs = ...
# targets are last entry in windows # targets are last entry in windows
targets = windowed[:, -1] # targets = ...
super(MetOfficeDataProvider, self).__init__( # initialise base class with inputs and targets arrays
inputs, targets, batch_size, max_num_batches, shuffle_order, rng) # super(MetOfficeDataProvider, self).__init__(
# inputs, targets, batch_size, max_num_batches, shuffle_order, rng)