Adding more documentation to data providers module.
This commit is contained in:
parent
7aad4e02a0
commit
48b02abf39
@ -1,5 +1,9 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""Data providers."""
|
"""Data providers.
|
||||||
|
|
||||||
|
This module provides classes for loading datasets and iterating over batches of
|
||||||
|
data points.
|
||||||
|
"""
|
||||||
|
|
||||||
import cPickle
|
import cPickle
|
||||||
import gzip
|
import gzip
|
||||||
@ -9,11 +13,25 @@ from mlp import DEFAULT_SEED
|
|||||||
|
|
||||||
|
|
||||||
class DataProvider(object):
|
class DataProvider(object):
|
||||||
"""Interface for generic data-independent readers."""
|
"""Generic data provider."""
|
||||||
|
|
||||||
def __init__(self, inputs, targets, batch_size, max_num_batches=-1,
|
def __init__(self, inputs, targets, batch_size, max_num_batches=-1,
|
||||||
shuffle_order=True, rng=None):
|
shuffle_order=True, rng=None):
|
||||||
"""
|
"""Create a new data provider object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs (ndarray): Array of data input features of shape
|
||||||
|
(num_data, input_dim).
|
||||||
|
targets (ndarray): Array of data output targets of shape
|
||||||
|
(num_data, output_dim) or (num_data,) if output_dim == 1.
|
||||||
|
batch_size (int): Number of data points to include in each batch.
|
||||||
|
max_num_batches (int): Maximum number of batches to iterate over
|
||||||
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
|
only as many batches as the data can be split into will be
|
||||||
|
used. If set to -1 all of the data will be used.
|
||||||
|
shuffle_order (bool): Whether to randomly permute the order of
|
||||||
|
the data before each epoch.
|
||||||
|
rng (RandomState): A seeded random number generator.
|
||||||
"""
|
"""
|
||||||
self.inputs = inputs
|
self.inputs = inputs
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
@ -21,6 +39,9 @@ class DataProvider(object):
|
|||||||
assert max_num_batches != 0 and not max_num_batches < -1, (
|
assert max_num_batches != 0 and not max_num_batches < -1, (
|
||||||
'max_num_batches should be -1 or > 0')
|
'max_num_batches should be -1 or > 0')
|
||||||
self.max_num_batches = max_num_batches
|
self.max_num_batches = max_num_batches
|
||||||
|
# maximum possible number of batches is equal to number of whole times
|
||||||
|
# batch_size divides in to the number of data points which can be
|
||||||
|
# found using integer division
|
||||||
possible_num_batches = self.inputs.shape[0] // batch_size
|
possible_num_batches = self.inputs.shape[0] // batch_size
|
||||||
if self.max_num_batches == -1:
|
if self.max_num_batches == -1:
|
||||||
self.num_batches = possible_num_batches
|
self.num_batches = possible_num_batches
|
||||||
@ -33,6 +54,13 @@ class DataProvider(object):
|
|||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
"""Implements Python iterator interface.
|
||||||
|
|
||||||
|
This should return an object implementing a `next` method which steps
|
||||||
|
through a sequence returning one element at a time and raising
|
||||||
|
`StopIteration` when at the end of the sequence. Here the object
|
||||||
|
returned is the DataProvider itself.
|
||||||
|
"""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
@ -42,15 +70,19 @@ class DataProvider(object):
|
|||||||
self.shuffle()
|
self.shuffle()
|
||||||
|
|
||||||
def shuffle(self):
|
def shuffle(self):
|
||||||
"""Shuffles order of data."""
|
"""Randomly shuffles order of data."""
|
||||||
new_order = self.rng.permutation(self.inputs.shape[0])
|
new_order = self.rng.permutation(self.inputs.shape[0])
|
||||||
self.inputs = self.inputs[new_order]
|
self.inputs = self.inputs[new_order]
|
||||||
self.targets = self.targets[new_order]
|
self.targets = self.targets[new_order]
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
|
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||||
if self._curr_batch + 1 > self.num_batches:
|
if self._curr_batch + 1 > self.num_batches:
|
||||||
|
# no more batches in current iteration through data set so reset
|
||||||
|
# the dataset for another pass and indicate iteration is at end
|
||||||
self.reset()
|
self.reset()
|
||||||
raise StopIteration()
|
raise StopIteration()
|
||||||
|
# create an index slice corresponding to current batch number
|
||||||
batch_slice = slice(self._curr_batch * self.batch_size,
|
batch_slice = slice(self._curr_batch * self.batch_size,
|
||||||
(self._curr_batch + 1) * self.batch_size)
|
(self._curr_batch + 1) * self.batch_size)
|
||||||
inputs_batch = self.inputs[batch_slice]
|
inputs_batch = self.inputs[batch_slice]
|
||||||
@ -64,27 +96,63 @@ class MNISTDataProvider(DataProvider):
|
|||||||
|
|
||||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||||
shuffle_order=True, rng=None):
|
shuffle_order=True, rng=None):
|
||||||
|
"""Create a new MNIST data provider object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
which_set: One of 'train', 'valid' or 'eval'. Determines which
|
||||||
|
portion of the MNIST data this object should provide.
|
||||||
|
batch_size (int): Number of data points to include in each batch.
|
||||||
|
max_num_batches (int): Maximum number of batches to iterate over
|
||||||
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
|
only as many batches as the data can be split into will be
|
||||||
|
used. If set to -1 all of the data will be used.
|
||||||
|
shuffle_order (bool): Whether to randomly permute the order of
|
||||||
|
the data before each epoch.
|
||||||
|
rng (RandomState): A seeded random number generator.
|
||||||
|
"""
|
||||||
|
# check a valid which_set was provided
|
||||||
assert which_set in ['train', 'valid', 'eval'], (
|
assert which_set in ['train', 'valid', 'eval'], (
|
||||||
'Expected which_set to be either train, valid or eval. '
|
'Expected which_set to be either train, valid or eval. '
|
||||||
'Got {0}'.format(which_set)
|
'Got {0}'.format(which_set)
|
||||||
)
|
)
|
||||||
self.which_set = which_set
|
self.which_set = which_set
|
||||||
self.num_classes = 10
|
self.num_classes = 10
|
||||||
|
# construct path to data using os.path.join to ensure the correct path
|
||||||
|
# separator for the current platform / OS is used
|
||||||
|
# MLP_DATA_DIR environment variable should point to the data directory
|
||||||
data_path = os.path.join(
|
data_path = os.path.join(
|
||||||
os.environ['MLP_DATA_DIR'], 'mnist_{0}.pkl.gz'.format(which_set))
|
os.environ['MLP_DATA_DIR'], 'mnist_{0}.pkl.gz'.format(which_set))
|
||||||
assert os.path.isfile(data_path), (
|
assert os.path.isfile(data_path), (
|
||||||
'Data file does not exist at expected path: ' + data_path
|
'Data file does not exist at expected path: ' + data_path
|
||||||
)
|
)
|
||||||
|
# use a context-manager to ensure the files are properly closed after
|
||||||
|
# we are finished with them
|
||||||
with gzip.open(data_path) as f:
|
with gzip.open(data_path) as f:
|
||||||
inputs, targets = cPickle.load(f)
|
inputs, targets = cPickle.load(f)
|
||||||
|
# pass the loaded data to the parent class __init__
|
||||||
super(MNISTDataProvider, self).__init__(
|
super(MNISTDataProvider, self).__init__(
|
||||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
|
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||||
inputs_batch, targets_batch = super(MNISTDataProvider, self).next()
|
inputs_batch, targets_batch = super(MNISTDataProvider, self).next()
|
||||||
return inputs_batch, self.to_one_of_k(targets_batch)
|
return inputs_batch, self.to_one_of_k(targets_batch)
|
||||||
|
|
||||||
def to_one_of_k(self, int_targets):
|
def to_one_of_k(self, int_targets):
|
||||||
|
"""Converts integer coded class target to 1 of K coded targets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
int_targets (ndarray): Array of integer coded class targets (i.e.
|
||||||
|
where an integer from 0 to `num_classes` - 1 is used to
|
||||||
|
indicate which is the correct class). This should be of shape
|
||||||
|
(num_data,).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Array of 1 of K coded targets i.e. an array of shape
|
||||||
|
(num_data, num_classes) where for each row all elements are equal
|
||||||
|
to zero except for the column corresponding to the correct class
|
||||||
|
which is equal to one.
|
||||||
|
"""
|
||||||
one_of_k_targets = np.zeros((int_targets.shape[0], self.num_classes))
|
one_of_k_targets = np.zeros((int_targets.shape[0], self.num_classes))
|
||||||
one_of_k_targets[range(int_targets.shape[0]), int_targets] = 1
|
one_of_k_targets[range(int_targets.shape[0]), int_targets] = 1
|
||||||
return one_of_k_targets
|
return one_of_k_targets
|
||||||
@ -95,6 +163,22 @@ class MetOfficeDataProvider(DataProvider):
|
|||||||
|
|
||||||
def __init__(self, window_size, batch_size=10, max_num_batches=-1,
|
def __init__(self, window_size, batch_size=10, max_num_batches=-1,
|
||||||
shuffle_order=True, rng=None):
|
shuffle_order=True, rng=None):
|
||||||
|
"""Create a new Met Offfice data provider object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
window_size (int): Size of windows to split weather time series
|
||||||
|
data into. The constructed input features will be the first
|
||||||
|
`window_size - 1` entries in each window and the target outputs
|
||||||
|
the last entry in each window.
|
||||||
|
batch_size (int): Number of data points to include in each batch.
|
||||||
|
max_num_batches (int): Maximum number of batches to iterate over
|
||||||
|
in an epoch. If `max_num_batches * batch_size > num_data` then
|
||||||
|
only as many batches as the data can be split into will be
|
||||||
|
used. If set to -1 all of the data will be used.
|
||||||
|
shuffle_order (bool): Whether to randomly permute the order of
|
||||||
|
the data before each epoch.
|
||||||
|
rng (RandomState): A seeded random number generator.
|
||||||
|
"""
|
||||||
data_path = os.path.join(
|
data_path = os.path.join(
|
||||||
os.environ['MLP_DATA_DIR'], 'HadSSP_daily_qc.txt')
|
os.environ['MLP_DATA_DIR'], 'HadSSP_daily_qc.txt')
|
||||||
assert os.path.isfile(data_path), (
|
assert os.path.isfile(data_path), (
|
||||||
|
Loading…
Reference in New Issue
Block a user