Updating data_providers module to be Python 3 compatible.

This commit is contained in:
Matt Graham 2017-01-24 16:57:40 +00:00
parent 7867a372e4
commit e89ffedf3f
7 changed files with 10 additions and 8 deletions

BIN
data/mnist-test.npz Normal file

Binary file not shown.

BIN
data/mnist-train.npz Normal file

Binary file not shown.

BIN
data/mnist-valid.npz Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -5,10 +5,8 @@ This module provides classes for loading datasets and iterating over batches of
data points.
"""
import cPickle
import gzip
import numpy as np
import os
import numpy as np
from mlp import DEFAULT_SEED
@ -130,6 +128,10 @@ class DataProvider(object):
self._curr_batch += 1
return inputs_batch, targets_batch
# Python 3.x compatibility
def __next__(self):
    """Return the result of ``self.next()``.

    Python 3's iterator protocol calls ``__next__`` whereas Python 2
    calls ``next``; forwarding to ``next`` keeps a single implementation
    usable from both.
    """
    batch = self.next()
    return batch
class OneOfKDataProvider(DataProvider):
"""1-of-K classification target data provider.
@ -194,14 +196,14 @@ class MNISTDataProvider(OneOfKDataProvider):
# separator for the current platform / OS is used
# MLP_DATA_DIR environment variable should point to the data directory
data_path = os.path.join(
os.environ['MLP_DATA_DIR'], 'mnist_{0}.pkl.gz'.format(which_set))
os.environ['MLP_DATA_DIR'], 'mnist-{0}.npz'.format(which_set))
assert os.path.isfile(data_path), (
'Data file does not exist at expected path: ' + data_path
)
# use a context-manager to ensure the files are properly closed after
# we are finished with them
with gzip.open(data_path) as f:
inputs, targets = cPickle.load(f)
# load data from compressed numpy file
loaded = np.load(data_path)
inputs, targets = loaded['inputs'], loaded['targets']
inputs = inputs.astype(np.float32)
# pass the loaded data to the parent class __init__
super(MNISTDataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)