Fix bug with data provider flag of random_sampling being set to true as default instead of false as it should have been

This commit is contained in:
AntreasAntoniou 2018-03-02 22:51:35 +00:00
parent 8ac193d906
commit 84e22e9796

View File

@ -204,7 +204,7 @@ class EMNISTDataProvider(DataProvider):
"""Data provider for EMNIST handwritten digit images.""" """Data provider for EMNIST handwritten digit images."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
random_sampling=True, rng=None, flatten=False, one_hot=False): random_sampling=False, rng=None, flatten=False, one_hot=False):
"""Create a new EMNIST data provider object. """Create a new EMNIST data provider object.
Args: Args:
@ -281,7 +281,7 @@ class CIFAR10DataProvider(DataProvider):
"""Data provider for CIFAR-10 object images.""" """Data provider for CIFAR-10 object images."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
random_sampling=True, rng=None, flatten=False, one_hot=False): random_sampling=False, rng=None, flatten=False, one_hot=False):
"""Create a new EMNIST data provider object. """Create a new EMNIST data provider object.
Args: Args:
@ -364,7 +364,7 @@ class CIFAR100DataProvider(DataProvider):
"""Data provider for CIFAR-100 object images.""" """Data provider for CIFAR-100 object images."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
random_sampling=True, rng=None, flatten=False, one_hot=False): random_sampling=False, rng=None, flatten=False, one_hot=False):
"""Create a new EMNIST data provider object. """Create a new EMNIST data provider object.
Args: Args:
@ -443,7 +443,7 @@ class MSD10GenreDataProvider(DataProvider):
"""Data provider for Million Song Dataset 10-genre classification task.""" """Data provider for Million Song Dataset 10-genre classification task."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
random_sampling=True, rng=None, one_hot=False, flatten=True): random_sampling=False, rng=None, one_hot=False, flatten=True):
"""Create a new EMNIST data provider object. """Create a new EMNIST data provider object.
Args: Args:
@ -535,7 +535,7 @@ class MSD25GenreDataProvider(DataProvider):
"""Data provider for Million Song Dataset 25-genre classification task.""" """Data provider for Million Song Dataset 25-genre classification task."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
random_sampling=True, rng=None, one_hot=False, flatten=True): random_sampling=False, rng=None, one_hot=False, flatten=True):
"""Create a new EMNIST data provider object. """Create a new EMNIST data provider object.
Args: Args:
@ -612,7 +612,7 @@ class MetOfficeDataProvider(DataProvider):
"""South Scotland Met Office weather data provider.""" """South Scotland Met Office weather data provider."""
def __init__(self, window_size, batch_size=10, max_num_batches=-1, def __init__(self, window_size, batch_size=10, max_num_batches=-1,
random_sampling=True, rng=None): random_sampling=False, rng=None):
"""Create a new Met Office data provider object. """Create a new Met Office data provider object.
Args: Args:
@ -658,7 +658,7 @@ class MetOfficeDataProvider(DataProvider):
class CCPPDataProvider(DataProvider): class CCPPDataProvider(DataProvider):
def __init__(self, which_set='train', input_dims=None, batch_size=10, def __init__(self, which_set='train', input_dims=None, batch_size=10,
max_num_batches=-1, random_sampling=True, rng=None): max_num_batches=-1, random_sampling=False, rng=None):
"""Create a new Combined Cycle Power Plant data provider object. """Create a new Combined Cycle Power Plant data provider object.
Args: Args:
@ -706,7 +706,7 @@ class AugmentedMNISTDataProvider(MNISTDataProvider):
"""Data provider for MNIST dataset which randomly transforms images.""" """Data provider for MNIST dataset which randomly transforms images."""
def __init__(self, which_set='train', batch_size=100, max_num_batches=-1, def __init__(self, which_set='train', batch_size=100, max_num_batches=-1,
random_sampling=True, rng=None, transformer=None): random_sampling=False, rng=None, transformer=None):
"""Create a new augmented MNIST data provider object. """Create a new augmented MNIST data provider object.
Args: Args: