Changing data providers to skeleton code for first lab.

This commit is contained in:
Matt Graham 2016-09-22 15:41:46 +01:00
parent 04fe5a7279
commit d560f96a90

View File

@ -133,10 +133,10 @@ class MNISTDataProvider(DataProvider):
super(MNISTDataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
def next(self):
"""Returns next data batch or raises `StopIteration` if at end."""
inputs_batch, targets_batch = super(MNISTDataProvider, self).next()
return inputs_batch, self.to_one_of_k(targets_batch)
#def next(self):
# """Returns next data batch or raises `StopIteration` if at end."""
# inputs_batch, targets_batch = super(MNISTDataProvider, self).next()
# return inputs_batch, self.to_one_of_k(targets_batch)
def to_one_of_k(self, int_targets):
"""Converts integer coded class target to 1 of K coded targets.
@ -153,9 +153,7 @@ class MNISTDataProvider(DataProvider):
to zero except for the column corresponding to the correct class
which is equal to one.
"""
one_of_k_targets = np.zeros((int_targets.shape[0], self.num_classes))
one_of_k_targets[range(int_targets.shape[0]), int_targets] = 1
return one_of_k_targets
raise NotImplementedError()
class MetOfficeDataProvider(DataProvider):
@ -179,28 +177,25 @@ class MetOfficeDataProvider(DataProvider):
the data before each epoch.
rng (RandomState): A seeded random number generator.
"""
self.window_size = window_size
assert window_size > 1, 'window_size must be at least 2.'
data_path = os.path.join(
os.environ['MLP_DATA_DIR'], 'HadSSP_daily_qc.txt')
assert os.path.isfile(data_path), (
'Data file does not exist at expected path: ' + data_path
)
raw = np.loadtxt(data_path, skiprows=3, usecols=range(2, 32))
assert window_size > 1, 'window_size must be at least 2.'
self.window_size = window_size
# load raw data from text file
# ...
# filter out all missing datapoints and flatten to a vector
filtered = raw[raw >= 0].flatten()
# ...
# normalise data to zero mean, unit standard deviation
mean = np.mean(filtered)
std = np.std(filtered)
normalised = (filtered - mean) / std
# create a view onto the array corresponding to a rolling window
shape = (normalised.shape[-1] - self.window_size + 1, self.window_size)
strides = normalised.strides + (normalised.strides[-1],)
windowed = np.lib.stride_tricks.as_strided(
normalised, shape=shape, strides=strides)
# ...
# convert from flat sequence to windowed data
# ...
# inputs are first (window_size - 1) entries in windows
inputs = windowed[:, :-1]
# inputs = ...
# targets are last entry in windows
targets = windowed[:, -1]
super(MetOfficeDataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
# targets = ...
# initialise base class with inputs and targets arrays
# super(MetOfficeDataProvider, self).__init__(
# inputs, targets, batch_size, max_num_batches, shuffle_order, rng)