Updating data_providers module to be Python 3 compatible.
This commit is contained in:
parent
7867a372e4
commit
e89ffedf3f
BIN
data/mnist-test.npz
Normal file
BIN
data/mnist-test.npz
Normal file
Binary file not shown.
BIN
data/mnist-train.npz
Normal file
BIN
data/mnist-train.npz
Normal file
Binary file not shown.
BIN
data/mnist-valid.npz
Normal file
BIN
data/mnist-valid.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -5,10 +5,8 @@ This module provides classes for loading datasets and iterating over batches of
|
|||||||
data points.
|
data points.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import cPickle
|
|
||||||
import gzip
|
|
||||||
import numpy as np
|
|
||||||
import os
|
import os
|
||||||
|
import numpy as np
|
||||||
from mlp import DEFAULT_SEED
|
from mlp import DEFAULT_SEED
|
||||||
|
|
||||||
|
|
||||||
@ -130,6 +128,10 @@ class DataProvider(object):
|
|||||||
self._curr_batch += 1
|
self._curr_batch += 1
|
||||||
return inputs_batch, targets_batch
|
return inputs_batch, targets_batch
|
||||||
|
|
||||||
|
# Python 3.x compatibility
|
||||||
|
def __next__(self):
|
||||||
|
return self.next()
|
||||||
|
|
||||||
|
|
||||||
class OneOfKDataProvider(DataProvider):
|
class OneOfKDataProvider(DataProvider):
|
||||||
"""1-of-K classification target data provider.
|
"""1-of-K classification target data provider.
|
||||||
@ -194,14 +196,14 @@ class MNISTDataProvider(OneOfKDataProvider):
|
|||||||
# separator for the current platform / OS is used
|
# separator for the current platform / OS is used
|
||||||
# MLP_DATA_DIR environment variable should point to the data directory
|
# MLP_DATA_DIR environment variable should point to the data directory
|
||||||
data_path = os.path.join(
|
data_path = os.path.join(
|
||||||
os.environ['MLP_DATA_DIR'], 'mnist_{0}.pkl.gz'.format(which_set))
|
os.environ['MLP_DATA_DIR'], 'mnist-{0}.npz'.format(which_set))
|
||||||
assert os.path.isfile(data_path), (
|
assert os.path.isfile(data_path), (
|
||||||
'Data file does not exist at expected path: ' + data_path
|
'Data file does not exist at expected path: ' + data_path
|
||||||
)
|
)
|
||||||
# use a context-manager to ensure the files are properly closed after
|
# load data from compressed numpy file
|
||||||
# we are finished with them
|
loaded = np.load(data_path)
|
||||||
with gzip.open(data_path) as f:
|
inputs, targets = loaded['inputs'], loaded['targets']
|
||||||
inputs, targets = cPickle.load(f)
|
inputs = inputs.astype(np.float32)
|
||||||
# pass the loaded data to the parent class __init__
|
# pass the loaded data to the parent class __init__
|
||||||
super(MNISTDataProvider, self).__init__(
|
super(MNISTDataProvider, self).__init__(
|
||||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||||
|
Loading…
Reference in New Issue
Block a user