lab_1_pickle_bug_fix
This commit is contained in:
parent
a74821dc31
commit
e8a9fce41e
BIN
data/mnist-test.npz
Normal file
BIN
data/mnist-test.npz
Normal file
Binary file not shown.
BIN
data/mnist-train.npz
Normal file
BIN
data/mnist-train.npz
Normal file
Binary file not shown.
BIN
data/mnist-valid.npz
Normal file
BIN
data/mnist-valid.npz
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -5,7 +5,7 @@ This module provides classes for loading datasets and iterating over batches of
|
||||
data points.
|
||||
"""
|
||||
|
||||
import cPickle
|
||||
import pickle
|
||||
import gzip
|
||||
import numpy as np
|
||||
import os
|
||||
@ -121,22 +121,25 @@ class MNISTDataProvider(DataProvider):
|
||||
# separator for the current platform / OS is used
|
||||
# MLP_DATA_DIR environment variable should point to the data directory
|
||||
data_path = os.path.join(
|
||||
os.environ['MLP_DATA_DIR'], 'mnist_{0}.pkl.gz'.format(which_set))
|
||||
os.environ['MLP_DATA_DIR'], 'mnist-{0}.npz'.format(which_set))
|
||||
assert os.path.isfile(data_path), (
|
||||
'Data file does not exist at expected path: ' + data_path
|
||||
)
|
||||
# use a context-manager to ensure the files are properly closed after
|
||||
# we are finished with them
|
||||
with gzip.open(data_path) as f:
|
||||
inputs, targets = cPickle.load(f)
|
||||
# load data from compressed numpy file
|
||||
loaded = np.load(data_path)
|
||||
inputs, targets = loaded['inputs'], loaded['targets']
|
||||
inputs = inputs.astype(np.float32)
|
||||
# pass the loaded data to the parent class __init__
|
||||
super(MNISTDataProvider, self).__init__(
|
||||
inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
|
||||
|
||||
#def next(self):
|
||||
# """Returns next data batch or raises `StopIteration` if at end."""
|
||||
# inputs_batch, targets_batch = super(MNISTDataProvider, self).next()
|
||||
# return inputs_batch, self.to_one_of_k(targets_batch)
|
||||
def next(self):
|
||||
"""Returns next data batch or raises `StopIteration` if at end."""
|
||||
inputs_batch, targets_batch = super(MNISTDataProvider, self).next()
|
||||
return inputs_batch, self.to_one_of_k(targets_batch)
|
||||
|
||||
def __next__(self):
|
||||
return self.next()
|
||||
|
||||
def to_one_of_k(self, int_targets):
|
||||
"""Converts integer coded class target to 1 of K coded targets.
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user