lab_1_pickle_bug_fix

This commit is contained in:
AntreasAntoniou 2017-09-22 20:57:33 +01:00
parent a74821dc31
commit e8a9fce41e
8 changed files with 55 additions and 213 deletions

BIN
data/mnist-test.npz Normal file

Binary file not shown.

BIN
data/mnist-train.npz Normal file

Binary file not shown.

BIN
data/mnist-valid.npz Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -5,7 +5,7 @@ This module provides classes for loading datasets and iterating over batches of
data points. data points.
""" """
import cPickle import pickle
import gzip import gzip
import numpy as np import numpy as np
import os import os
@ -121,22 +121,25 @@ class MNISTDataProvider(DataProvider):
# separator for the current platform / OS is used # separator for the current platform / OS is used
# MLP_DATA_DIR environment variable should point to the data directory # MLP_DATA_DIR environment variable should point to the data directory
data_path = os.path.join( data_path = os.path.join(
os.environ['MLP_DATA_DIR'], 'mnist_{0}.pkl.gz'.format(which_set)) os.environ['MLP_DATA_DIR'], 'mnist-{0}.npz'.format(which_set))
assert os.path.isfile(data_path), ( assert os.path.isfile(data_path), (
'Data file does not exist at expected path: ' + data_path 'Data file does not exist at expected path: ' + data_path
) )
# use a context-manager to ensure the files are properly closed after # load data from compressed numpy file
# we are finished with them loaded = np.load(data_path)
with gzip.open(data_path) as f: inputs, targets = loaded['inputs'], loaded['targets']
inputs, targets = cPickle.load(f) inputs = inputs.astype(np.float32)
# pass the loaded data to the parent class __init__ # pass the loaded data to the parent class __init__
super(MNISTDataProvider, self).__init__( super(MNISTDataProvider, self).__init__(
inputs, targets, batch_size, max_num_batches, shuffle_order, rng) inputs, targets, batch_size, max_num_batches, shuffle_order, rng)
#def next(self): def next(self):
# """Returns next data batch or raises `StopIteration` if at end.""" """Returns next data batch or raises `StopIteration` if at end."""
# inputs_batch, targets_batch = super(MNISTDataProvider, self).next() inputs_batch, targets_batch = super(MNISTDataProvider, self).next()
# return inputs_batch, self.to_one_of_k(targets_batch) return inputs_batch, self.to_one_of_k(targets_batch)
def __next__(self):
return self.next()
def to_one_of_k(self, int_targets): def to_one_of_k(self, int_targets):
"""Converts integer coded class target to 1 of K coded targets. """Converts integer coded class target to 1 of K coded targets.

File diff suppressed because one or more lines are too long