mlpractical/mlp/dataset.py
2016-09-19 11:16:53 +01:00

123 lines
4.6 KiB
Python

# -*- coding: utf-8 -*-
"""Data providers."""
import cPickle
import gzip
import numpy as np
import os
from mlp import DEFAULT_SEED
class DataProvider(object):
    """Generic provider that iterates over a dataset in fixed-size batches.

    The provider is its own iterator: iterating over it yields
    ``(inputs_batch, targets_batch)`` tuples until ``num_batches`` batches
    have been produced, at which point it resets itself (optionally
    reshuffling the data) so it can be iterated again for the next epoch.
    """

    def __init__(self, inputs, targets, batch_size, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Create a new data provider.

        Args:
            inputs: Array of input data points. Must expose ``shape`` and
                support slicing and integer-array indexing along its first
                axis (e.g. a NumPy array).
            targets: Array of targets, indexed in lockstep with ``inputs``
                along the first axis.
            batch_size: Number of data points in each batch. Must be > 0.
            max_num_batches: Maximum number of batches to produce per
                epoch. ``-1`` (the default) means use every complete batch
                available in the data.
            shuffle_order: Whether to randomly permute the data each time
                the provider is reset (including on construction).
            rng: Object with a ``permutation`` method (e.g. a
                ``numpy.random.RandomState``) used for shuffling. When
                ``None`` a ``RandomState`` seeded with ``DEFAULT_SEED`` is
                created.

        Raises:
            AssertionError: If ``batch_size`` is not positive or
                ``max_num_batches`` is 0 or less than -1.
        """
        self.inputs = inputs
        self.targets = targets
        # A zero batch size would otherwise surface as a ZeroDivisionError
        # and a negative one as a provider that silently yields nothing.
        assert batch_size > 0, 'batch_size should be > 0'
        self.batch_size = batch_size
        assert max_num_batches != 0 and not max_num_batches < -1, (
            'max_num_batches should be -1 or > 0')
        self.max_num_batches = max_num_batches
        # Only complete batches are served; any remainder is dropped.
        possible_num_batches = self.inputs.shape[0] // batch_size
        if self.max_num_batches == -1:
            self.num_batches = possible_num_batches
        else:
            self.num_batches = min(self.max_num_batches, possible_num_batches)
        self.shuffle_order = shuffle_order
        if rng is None:
            rng = np.random.RandomState(DEFAULT_SEED)
        self.rng = rng
        self.reset()

    def __iter__(self):
        """Return the provider itself, which implements the iterator protocol."""
        return self

    def reset(self):
        """Resets the provider to the initial state to use in a new epoch."""
        self._curr_batch = 0
        if self.shuffle_order:
            self.shuffle()

    def shuffle(self):
        """Shuffles order of data, keeping inputs and targets aligned."""
        new_order = self.rng.permutation(self.inputs.shape[0])
        self.inputs = self.inputs[new_order]
        self.targets = self.targets[new_order]

    def next(self):
        """Return the next ``(inputs_batch, targets_batch)`` tuple.

        Raises:
            StopIteration: When all ``num_batches`` batches of the current
                epoch have been returned. The provider is reset before
                raising so that it can be iterated over again.
        """
        if self._curr_batch + 1 > self.num_batches:
            # Prepare for the next epoch before signalling exhaustion.
            self.reset()
            raise StopIteration()
        batch_slice = slice(self._curr_batch * self.batch_size,
                            (self._curr_batch + 1) * self.batch_size)
        inputs_batch = self.inputs[batch_slice]
        targets_batch = self.targets[batch_slice]
        self._curr_batch += 1
        return inputs_batch, targets_batch

    def __next__(self):
        """Python 3 iterator protocol entry point.

        Delegates to ``self.next()`` (rather than aliasing the method) so
        that subclasses which override ``next`` are honoured.
        """
        return self.next()
class MNISTDataProvider(DataProvider):
    """Data provider for MNIST handwritten digit images.

    Loads one of the gzipped pickle files ``mnist_<which_set>.pkl.gz`` from
    the directory named by the ``MLP_DATA_DIR`` environment variable and
    serves batches whose targets are converted to a one-of-k encoding.
    """

    def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Load the requested MNIST portion and set up batching.

        Args:
            which_set: One of ``'train'``, ``'valid'`` or ``'eval'``;
                selects which data file is loaded.
            batch_size: Number of data points in each batch.
            max_num_batches: Maximum number of batches per epoch, or -1 to
                use all available data.
            shuffle_order: Whether to shuffle the data on each reset.
            rng: Random number generator passed on to the base class.

        Raises:
            AssertionError: If ``which_set`` is not recognised or the data
                file is not found at the expected path.
            KeyError: If ``MLP_DATA_DIR`` is not set in the environment.
        """
        allowed_sets = ['train', 'valid', 'eval']
        assert which_set in allowed_sets, (
            'Expected which_set to be either train, valid or eval. '
            'Got {0}'.format(which_set)
        )
        self.which_set = which_set
        self.num_classes = 10
        file_name = 'mnist_{0}.pkl.gz'.format(which_set)
        data_path = os.path.join(os.environ['MLP_DATA_DIR'], file_name)
        assert os.path.isfile(data_path), (
            'Data file does not exist at expected path: ' + data_path
        )
        # NOTE: unpickling can execute arbitrary code, so MLP_DATA_DIR
        # should only ever point at trusted data files.
        with gzip.open(data_path) as data_file:
            loaded = cPickle.load(data_file)
        inputs, targets = loaded
        super(MNISTDataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)

    def next(self):
        """Return the next batch with targets in one-of-k encoded form."""
        batch = super(MNISTDataProvider, self).next()
        inputs_batch, int_targets_batch = batch
        return inputs_batch, self.to_one_of_k(int_targets_batch)

    def to_one_of_k(self, int_targets):
        """Convert integer class labels to a one-of-k (one-hot) encoding.

        Args:
            int_targets: Array of integer labels whose first axis indexes
                data points; each label is used as a column index and so
                must be a valid index below ``self.num_classes``.

        Returns:
            Array of shape ``(int_targets.shape[0], self.num_classes)``
            that is zero everywhere except for a one in each row at the
            column given by the corresponding label.
        """
        num_points = int_targets.shape[0]
        encoded = np.zeros((num_points, self.num_classes))
        encoded[range(num_points), int_targets] = 1
        return encoded
class MetOfficeDataProvider(DataProvider):
    """South Scotland Met Office weather data provider.

    Reads ``HadSSP_daily_qc.txt`` from the directory named by the
    ``MLP_DATA_DIR`` environment variable, normalises the valid readings
    and serves rolling windows over the resulting sequence: the first
    ``window_size - 1`` values of each window are the inputs and the last
    value is the target.
    """

    def __init__(self, window_size, batch_size=10, max_num_batches=-1,
                 shuffle_order=True, rng=None):
        """Load the weather data and build the windowed inputs/targets.

        Args:
            window_size: Length of each rolling window, including the
                final target value. Must be greater than 1.
            batch_size: Number of data points in each batch.
            max_num_batches: Maximum number of batches per epoch, or -1 to
                use all available data.
            shuffle_order: Whether to shuffle the data on each reset.
            rng: Random number generator passed on to the base class.

        Raises:
            AssertionError: If the data file is not found at the expected
                path or ``window_size`` is not greater than 1.
            KeyError: If ``MLP_DATA_DIR`` is not set in the environment.
        """
        data_dir = os.environ['MLP_DATA_DIR']
        data_path = os.path.join(data_dir, 'HadSSP_daily_qc.txt')
        assert os.path.isfile(data_path), (
            'Data file does not exist at expected path: ' + data_path
        )
        # Skip the three leading rows and read columns 2 to 31 inclusive.
        # NOTE(review): this reads 30 columns - confirm that matches the
        # number of per-day columns actually present in the data file.
        raw = np.loadtxt(data_path, skiprows=3, usecols=range(2, 32))
        assert window_size > 1, 'window_size must be at least 2.'
        self.window_size = window_size
        # Negative entries are treated as missing readings and dropped;
        # what remains is a single flat vector.
        valid = raw[raw >= 0].flatten()
        # Standardise to zero mean and unit standard deviation.
        centred = valid - np.mean(valid)
        normalised = centred / np.std(valid)
        # Build a rolling-window view: each row starts one element after
        # the previous one, so both axes advance by the same stride.
        num_windows = normalised.shape[-1] - self.window_size + 1
        element_stride = normalised.strides[-1]
        windowed = np.lib.stride_tricks.as_strided(
            normalised,
            shape=(num_windows, self.window_size),
            strides=normalised.strides + (element_stride,))
        # All but the last entry of each window form the inputs...
        inputs = windowed[:, :-1]
        # ...and the last entry is the target to predict.
        targets = windowed[:, -1]
        super(MetOfficeDataProvider, self).__init__(
            inputs, targets, batch_size, max_num_batches, shuffle_order, rng)