diff --git a/data/mnist-test.npz b/data/mnist-test.npz
new file mode 100644
index 0000000..0ef69fa
Binary files /dev/null and b/data/mnist-test.npz differ
diff --git a/data/mnist-train.npz b/data/mnist-train.npz
new file mode 100644
index 0000000..b16e9ab
Binary files /dev/null and b/data/mnist-train.npz differ
diff --git a/data/mnist-valid.npz b/data/mnist-valid.npz
new file mode 100644
index 0000000..d4fe806
Binary files /dev/null and b/data/mnist-valid.npz differ
diff --git a/data/mnist_test.pkl.gz b/data/mnist_test.pkl.gz
deleted file mode 100644
index 316f266..0000000
Binary files a/data/mnist_test.pkl.gz and /dev/null differ
diff --git a/data/mnist_train.pkl.gz b/data/mnist_train.pkl.gz
deleted file mode 100644
index a5e9aeb..0000000
Binary files a/data/mnist_train.pkl.gz and /dev/null differ
diff --git a/data/mnist_valid.pkl.gz b/data/mnist_valid.pkl.gz
deleted file mode 100644
index 69afd9f..0000000
Binary files a/data/mnist_valid.pkl.gz and /dev/null differ
diff --git a/mlp/data_providers.py b/mlp/data_providers.py
index d786c1c..279b178 100644
--- a/mlp/data_providers.py
+++ b/mlp/data_providers.py
@@ -5,10 +5,8 @@ This module provides classes for loading datasets and iterating over batches of
 data points.
 """
 
-import cPickle
-import gzip
-import numpy as np
 import os
+import numpy as np
 
 from mlp import DEFAULT_SEED
 
@@ -130,6 +128,10 @@ class DataProvider(object):
         self._curr_batch += 1
         return inputs_batch, targets_batch
 
+    # Python 3.x compatibility
+    def __next__(self):
+        return self.next()
+
 
 class OneOfKDataProvider(DataProvider):
     """1-of-K classification target data provider.
@@ -194,14 +196,14 @@ class MNISTDataProvider(OneOfKDataProvider):
         # separator for the current platform / OS is used
         # MLP_DATA_DIR environment variable should point to the data directory
         data_path = os.path.join(
-            os.environ['MLP_DATA_DIR'], 'mnist_{0}.pkl.gz'.format(which_set))
+            os.environ['MLP_DATA_DIR'], 'mnist-{0}.npz'.format(which_set))
         assert os.path.isfile(data_path), (
             'Data file does not exist at expected path: ' + data_path
         )
-        # use a context-manager to ensure the files are properly closed after
-        # we are finished with them
-        with gzip.open(data_path) as f:
-            inputs, targets = cPickle.load(f)
+        # load data from compressed numpy file
+        loaded = np.load(data_path)
+        inputs, targets = loaded['inputs'], loaded['targets']
+        inputs = inputs.astype(np.float32)
         # pass the loaded data to the parent class __init__
         super(MNISTDataProvider, self).__init__(
             inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
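
Note on the binary data files above: the diff only records that the gzipped-pickle MNIST splits were replaced by .npz archives, not how those archives were produced. A one-off conversion along the following lines would be consistent with the file names and the 'inputs' / 'targets' keys the new loader reads; the script itself, the relative paths, and the encoding='latin1' argument (needed to unpickle Python 2 arrays under Python 3) are assumptions, not part of this change.

import gzip
import pickle

import numpy as np

# Hypothetical one-off conversion: repackage each gzipped-pickle split as a
# compressed .npz archive with the keys expected by MNISTDataProvider.
for which_set in ['train', 'valid', 'test']:
    old_path = 'data/mnist_{0}.pkl.gz'.format(which_set)
    new_path = 'data/mnist-{0}.npz'.format(which_set)
    with gzip.open(old_path, 'rb') as f:
        # encoding='latin1' lets Python 3 unpickle numpy arrays written by Python 2
        inputs, targets = pickle.load(f, encoding='latin1')
    # store under the 'inputs' / 'targets' keys read by the new loader
    np.savez_compressed(new_path,
                        inputs=inputs.astype(np.float32),
                        targets=targets)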