Update data provider

This commit is contained in:
AntreasAntoniou 2018-02-09 21:53:13 +00:00
parent 4e22eefbce
commit ba7b5387ad
6 changed files with 6 additions and 6 deletions

View File

@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr
rng = np.random.RandomState(seed=seed) # set seed rng = np.random.RandomState(seed=seed) # set seed
train_data = CIFAR100DataProvider(which_set="train", batch_size=batch_size, rng=rng) train_data = CIFAR100DataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True)
val_data = CIFAR100DataProvider(which_set="valid", batch_size=batch_size, rng=rng) val_data = CIFAR100DataProvider(which_set="valid", batch_size=batch_size, rng=rng)
test_data = CIFAR100DataProvider(which_set="test", batch_size=batch_size, rng=rng) test_data = CIFAR100DataProvider(which_set="test", batch_size=batch_size, rng=rng)
# setup our data providers # setup our data providers

View File

@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr
rng = np.random.RandomState(seed=seed) # set seed rng = np.random.RandomState(seed=seed) # set seed
train_data = CIFAR10DataProvider(which_set="train", batch_size=batch_size, rng=rng) train_data = CIFAR10DataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True)
val_data = CIFAR10DataProvider(which_set="valid", batch_size=batch_size, rng=rng) val_data = CIFAR10DataProvider(which_set="valid", batch_size=batch_size, rng=rng)
test_data = CIFAR10DataProvider(which_set="test", batch_size=batch_size, rng=rng) test_data = CIFAR10DataProvider(which_set="test", batch_size=batch_size, rng=rng)
# setup our data providers # setup our data providers

View File

@ -138,7 +138,7 @@ class MNISTDataProvider(DataProvider):
"""Data provider for MNIST handwritten digit images.""" """Data provider for MNIST handwritten digit images."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
random_sampling=True, rng=None): random_sampling=False, rng=None):
"""Create a new MNIST data provider object. """Create a new MNIST data provider object.
Args: Args:

View File

@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr
rng = np.random.RandomState(seed=seed) # set seed rng = np.random.RandomState(seed=seed) # set seed
train_data = EMNISTDataProvider(which_set="train", batch_size=batch_size, rng=rng) train_data = EMNISTDataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True)
val_data = EMNISTDataProvider(which_set="valid", batch_size=batch_size, rng=rng) val_data = EMNISTDataProvider(which_set="valid", batch_size=batch_size, rng=rng)
test_data = EMNISTDataProvider(which_set="test", batch_size=batch_size, rng=rng) test_data = EMNISTDataProvider(which_set="test", batch_size=batch_size, rng=rng)
# setup our data providers # setup our data providers

View File

@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr
rng = np.random.RandomState(seed=seed) # set seed rng = np.random.RandomState(seed=seed) # set seed
train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng) train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True)
val_data = MSD10GenreDataProvider(which_set="valid", batch_size=batch_size, rng=rng) val_data = MSD10GenreDataProvider(which_set="valid", batch_size=batch_size, rng=rng)
test_data = MSD10GenreDataProvider(which_set="test", batch_size=batch_size, rng=rng) test_data = MSD10GenreDataProvider(which_set="test", batch_size=batch_size, rng=rng)
# setup our data providers # setup our data providers

View File

@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr
rng = np.random.RandomState(seed=seed) # set seed rng = np.random.RandomState(seed=seed) # set seed
train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng) train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True)
val_data = MSD10GenreDataProvider(which_set="valid", batch_size=batch_size, rng=rng) val_data = MSD10GenreDataProvider(which_set="valid", batch_size=batch_size, rng=rng)
test_data = MSD10GenreDataProvider(which_set="test", batch_size=batch_size, rng=rng) test_data = MSD10GenreDataProvider(which_set="test", batch_size=batch_size, rng=rng)
# setup our data providers # setup our data providers