Changing data providers to skeleton code for first lab.
This commit is contained in:
parent
5c946ef7da
commit
4fc931747c
@ -133,10 +133,10 @@ class MNISTDataProvider(DataProvider):
|
||||
super(MNISTDataProvider, self).__init__(
|
||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||
|
||||
def next(self):
|
||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||
inputs_batch, targets_batch = super(MNISTDataProvider, self).next()
|
||||
return inputs_batch, self.to_one_of_k(targets_batch)
|
||||
#def next(self):
|
||||
# """Returns next data batch or raises `StopIteration` if at end."""
|
||||
# inputs_batch, targets_batch = super(MNISTDataProvider, self).next()
|
||||
# return inputs_batch, self.to_one_of_k(targets_batch)
|
||||
|
||||
def to_one_of_k(self, int_targets):
|
||||
"""Converts integer coded class target to 1 of K coded targets.
|
||||
@ -153,9 +153,7 @@ class MNISTDataProvider(DataProvider):
|
||||
to zero except for the column corresponding to the correct class
|
||||
which is equal to one.
|
||||
"""
|
||||
one_of_k_targets = np.zeros((int_targets.shape[0], self.num_classes))
|
||||
one_of_k_targets[range(int_targets.shape[0]), int_targets] = 1
|
||||
return one_of_k_targets
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class MetOfficeDataProvider(DataProvider):
|
||||
@ -179,28 +177,25 @@ class MetOfficeDataProvider(DataProvider):
|
||||
the data before each epoch.
|
||||
rng (RandomState): A seeded random number generator.
|
||||
"""
|
||||
self.window_size = window_size
|
||||
assert window_size > 1, 'window_size must be at least 2.'
|
||||
data_path = os.path.join(
|
||||
os.environ['MLP_DATA_DIR'], 'HadSSP_daily_qc.txt')
|
||||
assert os.path.isfile(data_path), (
|
||||
'Data file does not exist at expected path: ' + data_path
|
||||
)
|
||||
raw = np.loadtxt(data_path, skiprows=3, usecols=range(2, 32))
|
||||
assert window_size > 1, 'window_size must be at least 2.'
|
||||
self.window_size = window_size
|
||||
# load raw data from text file
|
||||
# ...
|
||||
# filter out all missing datapoints and flatten to a vector
|
||||
filtered = raw[raw >= 0].flatten()
|
||||
# ...
|
||||
# normalise data to zero mean, unit standard deviation
|
||||
mean = np.mean(filtered)
|
||||
std = np.std(filtered)
|
||||
normalised = (filtered - mean) / std
|
||||
# create a view on to array corresponding to a rolling window
|
||||
shape = (normalised.shape[-1] - self.window_size + 1, self.window_size)
|
||||
strides = normalised.strides + (normalised.strides[-1],)
|
||||
windowed = np.lib.stride_tricks.as_strided(
|
||||
normalised, shape=shape, strides=strides)
|
||||
# ...
|
||||
# convert from flat sequence to windowed data
|
||||
# ...
|
||||
# inputs are first (window_size - 1) entries in windows
|
||||
inputs = windowed[:, :-1]
|
||||
# inputs = ...
|
||||
# targets are last entry in windows
|
||||
targets = windowed[:, -1]
|
||||
super(MetOfficeDataProvider, self).__init__(
|
||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||
# targets = ...
|
||||
# initialise base class with inputs and targets arrays
|
||||
# super(MetOfficeDataProvider, self).__init__(
|
||||
# inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||
|
Loading…
Reference in New Issue
Block a user