Updating data_providers module to be Python 3 compatible.
This commit is contained in:
parent
7867a372e4
commit
e89ffedf3f
BIN
data/mnist-test.npz
Normal file
BIN
data/mnist-test.npz
Normal file
Binary file not shown.
BIN
data/mnist-train.npz
Normal file
BIN
data/mnist-train.npz
Normal file
Binary file not shown.
BIN
data/mnist-valid.npz
Normal file
BIN
data/mnist-valid.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -5,10 +5,8 @@ This module provides classes for loading datasets and iterating over batches of
|
||||
data points.
|
||||
"""
|
||||
|
||||
import cPickle
|
||||
import gzip
|
||||
import numpy as np
|
||||
import os
|
||||
import numpy as np
|
||||
from mlp import DEFAULT_SEED
|
||||
|
||||
|
||||
@ -130,6 +128,10 @@ class DataProvider(object):
|
||||
self._curr_batch += 1
|
||||
return inputs_batch, targets_batch
|
||||
|
||||
# Python 3.x compatibility
|
||||
def __next__(self):
|
||||
return self.next()
|
||||
|
||||
|
||||
class OneOfKDataProvider(DataProvider):
|
||||
"""1-of-K classification target data provider.
|
||||
@ -194,14 +196,14 @@ class MNISTDataProvider(OneOfKDataProvider):
|
||||
# separator for the current platform / OS is used
|
||||
# MLP_DATA_DIR environment variable should point to the data directory
|
||||
data_path = os.path.join(
|
||||
os.environ['MLP_DATA_DIR'], 'mnist_{0}.pkl.gz'.format(which_set))
|
||||
os.environ['MLP_DATA_DIR'], 'mnist-{0}.npz'.format(which_set))
|
||||
assert os.path.isfile(data_path), (
|
||||
'Data file does not exist at expected path: ' + data_path
|
||||
)
|
||||
# use a context-manager to ensure the files are properly closed after
|
||||
# we are finished with them
|
||||
with gzip.open(data_path) as f:
|
||||
inputs, targets = cPickle.load(f)
|
||||
# load data from compressed numpy file
|
||||
loaded = np.load(data_path)
|
||||
inputs, targets = loaded['inputs'], loaded['targets']
|
||||
inputs = inputs.astype(np.float32)
|
||||
# pass the loaded data to the parent class __init__
|
||||
super(MNISTDataProvider, self).__init__(
|
||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||
|
Loading…
Reference in New Issue
Block a user