Update data provider
This commit is contained in:
parent
4e22eefbce
commit
ba7b5387ad
@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr
|
|||||||
|
|
||||||
rng = np.random.RandomState(seed=seed) # set seed
|
rng = np.random.RandomState(seed=seed) # set seed
|
||||||
|
|
||||||
train_data = CIFAR100DataProvider(which_set="train", batch_size=batch_size, rng=rng)
|
train_data = CIFAR100DataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True)
|
||||||
val_data = CIFAR100DataProvider(which_set="valid", batch_size=batch_size, rng=rng)
|
val_data = CIFAR100DataProvider(which_set="valid", batch_size=batch_size, rng=rng)
|
||||||
test_data = CIFAR100DataProvider(which_set="test", batch_size=batch_size, rng=rng)
|
test_data = CIFAR100DataProvider(which_set="test", batch_size=batch_size, rng=rng)
|
||||||
# setup our data providers
|
# setup our data providers
|
||||||
|
@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr
|
|||||||
|
|
||||||
rng = np.random.RandomState(seed=seed) # set seed
|
rng = np.random.RandomState(seed=seed) # set seed
|
||||||
|
|
||||||
train_data = CIFAR10DataProvider(which_set="train", batch_size=batch_size, rng=rng)
|
train_data = CIFAR10DataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True)
|
||||||
val_data = CIFAR10DataProvider(which_set="valid", batch_size=batch_size, rng=rng)
|
val_data = CIFAR10DataProvider(which_set="valid", batch_size=batch_size, rng=rng)
|
||||||
test_data = CIFAR10DataProvider(which_set="test", batch_size=batch_size, rng=rng)
|
test_data = CIFAR10DataProvider(which_set="test", batch_size=batch_size, rng=rng)
|
||||||
# setup our data providers
|
# setup our data providers
|
||||||
|
@ -138,7 +138,7 @@ class MNISTDataProvider(DataProvider):
|
|||||||
"""Data provider for MNIST handwritten digit images."""
|
"""Data provider for MNIST handwritten digit images."""
|
||||||
|
|
||||||
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
|
||||||
random_sampling=True, rng=None):
|
random_sampling=False, rng=None):
|
||||||
"""Create a new MNIST data provider object.
|
"""Create a new MNIST data provider object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr
|
|||||||
|
|
||||||
rng = np.random.RandomState(seed=seed) # set seed
|
rng = np.random.RandomState(seed=seed) # set seed
|
||||||
|
|
||||||
train_data = EMNISTDataProvider(which_set="train", batch_size=batch_size, rng=rng)
|
train_data = EMNISTDataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True)
|
||||||
val_data = EMNISTDataProvider(which_set="valid", batch_size=batch_size, rng=rng)
|
val_data = EMNISTDataProvider(which_set="valid", batch_size=batch_size, rng=rng)
|
||||||
test_data = EMNISTDataProvider(which_set="test", batch_size=batch_size, rng=rng)
|
test_data = EMNISTDataProvider(which_set="test", batch_size=batch_size, rng=rng)
|
||||||
# setup our data providers
|
# setup our data providers
|
||||||
|
@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr
|
|||||||
|
|
||||||
rng = np.random.RandomState(seed=seed) # set seed
|
rng = np.random.RandomState(seed=seed) # set seed
|
||||||
|
|
||||||
train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng)
|
train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True)
|
||||||
val_data = MSD10GenreDataProvider(which_set="valid", batch_size=batch_size, rng=rng)
|
val_data = MSD10GenreDataProvider(which_set="valid", batch_size=batch_size, rng=rng)
|
||||||
test_data = MSD10GenreDataProvider(which_set="test", batch_size=batch_size, rng=rng)
|
test_data = MSD10GenreDataProvider(which_set="test", batch_size=batch_size, rng=rng)
|
||||||
# setup our data providers
|
# setup our data providers
|
||||||
|
@ -23,7 +23,7 @@ experiment_name = "experiment_{}_batch_size_{}_bn_{}_mp_{}".format(experiment_pr
|
|||||||
|
|
||||||
rng = np.random.RandomState(seed=seed) # set seed
|
rng = np.random.RandomState(seed=seed) # set seed
|
||||||
|
|
||||||
train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng)
|
train_data = MSD10GenreDataProvider(which_set="train", batch_size=batch_size, rng=rng, random_sampling=True)
|
||||||
val_data = MSD10GenreDataProvider(which_set="valid", batch_size=batch_size, rng=rng)
|
val_data = MSD10GenreDataProvider(which_set="valid", batch_size=batch_size, rng=rng)
|
||||||
test_data = MSD10GenreDataProvider(which_set="test", batch_size=batch_size, rng=rng)
|
test_data = MSD10GenreDataProvider(which_set="test", batch_size=batch_size, rng=rng)
|
||||||
# setup our data providers
|
# setup our data providers
|
||||||
|
Loading…
Reference in New Issue
Block a user