From ba7b5387ad2a6bc41f7792a98172d7a1902867a9 Mon Sep 17 00:00:00 2001 From: AntreasAntoniou Date: Fri, 9 Feb 2018 21:53:13 +0000 Subject: [PATCH] Update data provider --- cifar100_network_trainer.py | 2 +- cifar10_network_trainer.py | 2 +- data_providers.py | 2 +- emnist_network_trainer.py | 2 +- msd10_network_trainer.py | 2 +- msd25_network_trainer.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cifar100_network_trainer.py b/cifar100_network_trainer.py index 2683530..e5cc4de 100644 --- a/cifar100_network_trainer.py +++ b/cifar100_network_trainer.py @@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr rng = np.random.RandomState(seed=seed) # set seed -train_data = CIFAR100DataProvider(which_set="train", batch_size=batch_size, rng=rng) +train_data = CIFAR100DataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True) val_data = CIFAR100DataProvider(which_set="valid", batch_size=batch_size, rng=rng) test_data = CIFAR100DataProvider(which_set="test", batch_size=batch_size, rng=rng) # setup our data providers diff --git a/cifar10_network_trainer.py b/cifar10_network_trainer.py index 0f1554f..d456c9c 100644 --- a/cifar10_network_trainer.py +++ b/cifar10_network_trainer.py @@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr rng = np.random.RandomState(seed=seed) # set seed -train_data = CIFAR10DataProvider(which_set="train", batch_size=batch_size, rng=rng) +train_data = CIFAR10DataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True) val_data = CIFAR10DataProvider(which_set="valid", batch_size=batch_size, rng=rng) test_data = CIFAR10DataProvider(which_set="test", batch_size=batch_size, rng=rng) # setup our data providers diff --git a/data_providers.py b/data_providers.py index 345c478..c6ae1b2 100644 --- a/data_providers.py +++ b/data_providers.py @@ -138,7 +138,7 @@ class MNISTDataProvider(DataProvider): """Data provider for MNIST handwritten digit images.""" def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, - random_sampling=True, rng=None): + random_sampling=False, rng=None): """Create a new MNIST data provider object. Args: diff --git a/emnist_network_trainer.py b/emnist_network_trainer.py index 1111408..f33a436 100644 --- a/emnist_network_trainer.py +++ b/emnist_network_trainer.py @@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr rng = np.random.RandomState(seed=seed) # set seed -train_data = EMNISTDataProvider(which_set="train", batch_size=batch_size, rng=rng) +train_data = EMNISTDataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True) val_data = EMNISTDataProvider(which_set="valid", batch_size=batch_size, rng=rng) test_data = EMNISTDataProvider(which_set="test", batch_size=batch_size, rng=rng) # setup our data providers diff --git a/msd10_network_trainer.py b/msd10_network_trainer.py index 87b8eec..584347b 100644 --- a/msd10_network_trainer.py +++ b/msd10_network_trainer.py @@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr rng = np.random.RandomState(seed=seed) # set seed -train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng) +train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True) val_data = MSD10GenreDataProvider(which_set="valid", batch_size=batch_size, rng=rng) test_data = MSD10GenreDataProvider(which_set="test", batch_size=batch_size, rng=rng) # setup our data providers diff --git a/msd25_network_trainer.py b/msd25_network_trainer.py index 94337b9..bc94997 100644 --- a/msd25_network_trainer.py +++ b/msd25_network_trainer.py @@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr rng = np.random.RandomState(seed=seed) # set seed -train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng) +train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True) val_data = MSD10GenreDataProvider(which_set="valid", batch_size=batch_size, rng=rng) test_data = MSD10GenreDataProvider(which_set="test", batch_size=batch_size, rng=rng) # setup our data providers