Updating data_providers module to be Python 3 compatible.

This commit is contained in:
Matt Graham 2017-01-24 16:57:40 +00:00
parent 7867a372e4
commit e89ffedf3f
7 changed files with 10 additions and 8 deletions

BIN
data/mnist-test.npz Normal file

Binary file not shown.

BIN
data/mnist-train.npz Normal file

Binary file not shown.

BIN
data/mnist-valid.npz Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -5,10 +5,8 @@ This module provides classes for loading datasets and iterating over batches of
data points. data points.
""" """
import cPickle
import gzip
import numpy as np
import os import os
import numpy as np
from mlp import DEFAULT_SEED from mlp import DEFAULT_SEED
@@ -130,6 +128,10 @@ class DataProvider(object):
self._curr_batch += 1 self._curr_batch += 1
return inputs_batch, targets_batch return inputs_batch, targets_batch
# Python 3.x compatibility
def __next__(self):
return self.next()
class OneOfKDataProvider(DataProvider): class OneOfKDataProvider(DataProvider):
"""1-of-K classification target data provider. """1-of-K classification target data provider.
@@ -194,14 +196,14 @@ class MNISTDataProvider(OneOfKDataProvider):
# separator for the current platform / OS is used # separator for the current platform / OS is used
# MLP_DATA_DIR environment variable should point to the data directory # MLP_DATA_DIR environment variable should point to the data directory
data_path = os.path.join( data_path = os.path.join(
os.environ['MLP_DATA_DIR'], 'mnist_{0}.pkl.gz'.format(which_set)) os.environ['MLP_DATA_DIR'], 'mnist-{0}.npz'.format(which_set))
assert os.path.isfile(data_path), ( assert os.path.isfile(data_path), (
'Data file does not exist at expected path: ' + data_path 'Data file does not exist at expected path: ' + data_path
) )
# use a context-manager to ensure the files are properly closed after # load data from compressed numpy file
# we are finished with them loaded = np.load(data_path)
with gzip.open(data_path) as f: inputs, targets = loaded['inputs'], loaded['targets']
inputs, targets = cPickle.load(f) inputs = inputs.astype(np.float32)
# pass the loaded data to the parent class __init__ # pass the loaded data to the parent class __init__
super(MNISTDataProvider, self).__init__( super(MNISTDataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng) inputs, targets, batch_size, max_num_batches, shuffle_order, rng)