lab_1_pickle_bug_fix
This commit is contained in:
parent
a74821dc31
commit
e8a9fce41e
BIN
data/mnist-test.npz
Normal file
BIN
data/mnist-test.npz
Normal file
Binary file not shown.
BIN
data/mnist-train.npz
Normal file
BIN
data/mnist-train.npz
Normal file
Binary file not shown.
BIN
data/mnist-valid.npz
Normal file
BIN
data/mnist-valid.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -5,7 +5,7 @@ This module provides classes for loading datasets and iterating over batches of
|
|||||||
data points.
|
data points.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import cPickle
|
import pickle
|
||||||
import gzip
|
import gzip
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
@ -121,22 +121,25 @@ class MNISTDataProvider(DataProvider):
|
|||||||
# separator for the current platform / OS is used
|
# separator for the current platform / OS is used
|
||||||
# MLP_DATA_DIR environment variable should point to the data directory
|
# MLP_DATA_DIR environment variable should point to the data directory
|
||||||
data_path = os.path.join(
|
data_path = os.path.join(
|
||||||
os.environ['MLP_DATA_DIR'], 'mnist_{0}.pkl.gz'.format(which_set))
|
os.environ['MLP_DATA_DIR'], 'mnist-{0}.npz'.format(which_set))
|
||||||
assert os.path.isfile(data_path), (
|
assert os.path.isfile(data_path), (
|
||||||
'Data file does not exist at expected path: ' + data_path
|
'Data file does not exist at expected path: ' + data_path
|
||||||
)
|
)
|
||||||
# use a context-manager to ensure the files are properly closed after
|
# load data from compressed numpy file
|
||||||
# we are finished with them
|
loaded = np.load(data_path)
|
||||||
with gzip.open(data_path) as f:
|
inputs, targets = loaded['inputs'], loaded['targets']
|
||||||
inputs, targets = cPickle.load(f)
|
inputs = inputs.astype(np.float32)
|
||||||
# pass the loaded data to the parent class __init__
|
# pass the loaded data to the parent class __init__
|
||||||
super(MNISTDataProvider, self).__init__(
|
super(MNISTDataProvider, self).__init__(
|
||||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||||
|
|
||||||
#def next(self):
|
def next(self):
|
||||||
# """Returns next data batch or raises `StopIteration` if at end."""
|
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||||
# inputs_batch, targets_batch = super(MNISTDataProvider, self).next()
|
inputs_batch, targets_batch = super(MNISTDataProvider, self).next()
|
||||||
# return inputs_batch, self.to_one_of_k(targets_batch)
|
return inputs_batch, self.to_one_of_k(targets_batch)
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
return self.next()
|
||||||
|
|
||||||
def to_one_of_k(self, int_targets):
|
def to_one_of_k(self, int_targets):
|
||||||
"""Converts integer coded class target to 1 of K coded targets.
|
"""Converts integer coded class target to 1 of K coded targets.
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user