Update data provider
This commit is contained in:
parent
4e22eefbce
commit
ba7b5387ad
@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr
|
||||
|
||||
rng = np.random.RandomState(seed=seed) # set seed
|
||||
|
||||
train_data = CIFAR100DataProvider(which_set="train", batch_size=batch_size, rng=rng)
|
||||
train_data = CIFAR100DataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True)
|
||||
val_data = CIFAR100DataProvider(which_set="valid", batch_size=batch_size, rng=rng)
|
||||
test_data = CIFAR100DataProvider(which_set="test", batch_size=batch_size, rng=rng)
|
||||
# setup our data providers
|
||||
|
@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr
|
||||
|
||||
rng = np.random.RandomState(seed=seed) # set seed
|
||||
|
||||
train_data = CIFAR10DataProvider(which_set="train", batch_size=batch_size, rng=rng)
|
||||
train_data = CIFAR10DataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True)
|
||||
val_data = CIFAR10DataProvider(which_set="valid", batch_size=batch_size, rng=rng)
|
||||
test_data = CIFAR10DataProvider(which_set="test", batch_size=batch_size, rng=rng)
|
||||
# setup our data providers
|
||||
|
@ -138,7 +138,7 @@ class MNISTDataProvider(DataProvider):
|
||||
"""Data provider for MNIST handwritten digit images."""
|
||||
|
||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||
random_sampling=True, rng=None):
|
||||
random_sampling=False, rng=None):
|
||||
"""Create a new MNIST data provider object.
|
||||
|
||||
Args:
|
||||
|
@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr
|
||||
|
||||
rng = np.random.RandomState(seed=seed) # set seed
|
||||
|
||||
train_data = EMNISTDataProvider(which_set="train", batch_size=batch_size, rng=rng)
|
||||
train_data = EMNISTDataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True)
|
||||
val_data = EMNISTDataProvider(which_set="valid", batch_size=batch_size, rng=rng)
|
||||
test_data = EMNISTDataProvider(which_set="test", batch_size=batch_size, rng=rng)
|
||||
# setup our data providers
|
||||
|
@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr
|
||||
|
||||
rng = np.random.RandomState(seed=seed) # set seed
|
||||
|
||||
train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng)
|
||||
train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True)
|
||||
val_data = MSD10GenreDataProvider(which_set="valid", batch_size=batch_size, rng=rng)
|
||||
test_data = MSD10GenreDataProvider(which_set="test", batch_size=batch_size, rng=rng)
|
||||
# setup our data providers
|
||||
|
@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr
|
||||
|
||||
rng = np.random.RandomState(seed=seed) # set seed
|
||||
|
||||
train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng)
|
||||
train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True)
|
||||
val_data = MSD10GenreDataProvider(which_set="valid", batch_size=batch_size, rng=rng)
|
||||
test_data = MSD10GenreDataProvider(which_set="test", batch_size=batch_size, rng=rng)
|
||||
# setup our data providers
|
||||
|
Loading…
Reference in New Issue
Block a user