commit
01140e3236
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,3 +1,6 @@
|
||||
#dropbox stuff
|
||||
*.dropbox*
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
325
00_Introduction.ipynb
Normal file
325
00_Introduction.ipynb
Normal file
@ -0,0 +1,325 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Introduction\n",
|
||||
"\n",
|
||||
"This notebook shows how to set-up a working python envirnoment for the Machine Learning Practical course.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Setting up the software\n",
|
||||
"\n",
|
||||
"Within this course we are going to work with python (using some auxiliary libraries like numpy and scipy). Depending on the infrastracture and working environment (e.g. DICE), root permission may not be not availabe so the packages cannot be installed in a default locations. A convenient python configuration, which allows to install and update third party libraries easily using package manager, are so called virtual environments. Those can be also used, to work (and test) the code with different versions of software.\n",
|
||||
"\n",
|
||||
"## Instructions for Windows\n",
|
||||
"\n",
|
||||
"The fastest way to get working setup on Windows is to install Anaconda (http://www.continuum.io) package. It's a python environment with precompiled most popular scientific libraries. It also works on MacOS, but numpy is not linked without a fee to a numerical library, hence for MacOS we recommend the following procedure.\n",
|
||||
"\n",
|
||||
"## Instructions for MacOS\n",
|
||||
"\n",
|
||||
"<ul>\n",
|
||||
"<li>Install macports following instructions at https://www.macports.org/install.php</li>\n",
|
||||
"<li>Install the relevant python packages in macports\n",
|
||||
"<ul>\n",
|
||||
"<li> sudo port install py27-scipy +openblas </li>\n",
|
||||
"<li> sudo port install py27-ipython +notebook </li>\n",
|
||||
"<li> sudo port install py27-notebook </li>\n",
|
||||
"<li> sudo port install py27-matplotlib </li>\n",
|
||||
"<li> sudo port select --set python python27 </li>\n",
|
||||
"<li> sudo port select --set ipython2 py27-ipython </li>\n",
|
||||
"<li> sudo port select --set ipython py27-ipython </li>\n",
|
||||
"</ul>\n",
|
||||
"</ul>\n",
|
||||
"\n",
|
||||
"Make sure that your $PATH has /opt/local/bin before /usr/bin so you pick up the version of python you just installed\n",
|
||||
"\n",
|
||||
"## Instructions for DICE:\n",
|
||||
"\n",
|
||||
"### Configuring virtual environment (the generic way)\n",
|
||||
"\n",
|
||||
"<ul>\n",
|
||||
"<li>git clone https://github.com/pypa/virtualenv</li>\n",
|
||||
"<li>Enter the cloned repository and type $\\texttt{virtualenv.py --python /usr/bin/python2.7 --no-site-packages --prefix=~/mlpractical}$ </li>\n",
|
||||
"<li>Activate the environment by typing $\\texttt{source ~/mlpractical/bin/activate}$ (to leave the virtual environment one may type $\\texttt{decativate}$). Environments need to be activated every time ones start the new session (unless you do this explicitly in the shell starting scripts, i.e. ~/.bashrc).\n",
|
||||
"</ul>\n",
|
||||
"\n",
|
||||
"### Configuring virtual environment (more comfy DICE wrapper)\n",
|
||||
"\n",
|
||||
"DICE comes with a handy virtual environment wrapper, called $\\texttt{mkvirtualenv}$, which allows to simplify a bit the above process, to use it:\n",
|
||||
"\n",
|
||||
"<ul>\n",
|
||||
"<li>$\\texttt{source /usr/bin/virtualenvwrapper.sh}$ (add this also to $\\texttt{~/.bashrc}$ script so its available automatically every time you ssh to the grid)</li>\n",
|
||||
"<li>Then type $\\texttt{mkvirtualenv mlpractical --python /usr/bin/python2.7}$ (this will create an environment under ~/.virtualenvs/mlpractical)</li>\n",
|
||||
"<li>To activate the environment you can use $\\texttt{workon}$ script that comes with the wrapper. Simply type: $\\texttt{workon mlpractical}$</li>\n",
|
||||
"</ul>\n",
|
||||
"\n",
|
||||
"Then, before you follow next, install/upgrade the following packages:\n",
|
||||
"\n",
|
||||
"pip install --upgrade pip <br/>\n",
|
||||
"pip install setuptools <br/>\n",
|
||||
"pip install setuptools --upgrade <br/>\n",
|
||||
"pip install ipython <br/>\n",
|
||||
"pip install notebook\n",
|
||||
"\n",
|
||||
"### Installing numpy\n",
|
||||
"\n",
|
||||
"Note, having virtual environment properly installed one may go and type `pip install numpy`, though this will most likely lead to the suboptimal configuration where numpy is linked to ATLAS numerical library, which on DICE is compiled in multi-threaded mode. This means whenever numpy use BLAS accelerated computations (using ATLAS), it will use <b>all</b> the available cores at the given machine. This happens because ATLAS can be compiled to either run computations in single *or* multi threaded modes. However, contrary to some other backends, the latter does not allow to use an arbitrary number of threads (specified by the user prior to computation). This is highly suboptimal, as the potential speed-up resulting from paralleism depends on many factors like the communication overhead between threads, the size of the problem, etc.. Using all cores for our exercises is not-necessary.\n",
|
||||
"\n",
|
||||
"For which reason, we are going to compile our own version of BLAS package, called *OpenBlas*. It allows to specify the number of threads manually by setting an environmental variable OMP_NUM_THREADS=N, where N is a desired number of parallel threads (please use 1 by default).\n",
|
||||
"\n",
|
||||
"#### OpenBlas\n",
|
||||
"\n",
|
||||
"To install OpenBlas library type:\n",
|
||||
"<ul>\n",
|
||||
"<li>$\\texttt{git clone git://github.com/xianyi/OpenBLAS }$</li>\n",
|
||||
"<li>$ \\texttt{cd OpenBLAS}$ </li>\n",
|
||||
"<li>$ \\texttt{make}$</li>\n",
|
||||
"<li>$ \\texttt{make PREFIX=/path/to/OpenBLAS install}$ </li>\n",
|
||||
"<li>Add $\\texttt{/path/to/OpenBLAS/lib}$ to LD_LIBRARY_PATH environmental variable (do it in ~/.bashrc by `export` LD_LIBRARY_PATH=\"\\$LD_LIBRARY_PATH:/path/to/OpenBLAS/lib\") </li>\n",
|
||||
"</ul>\n",
|
||||
"\n",
|
||||
"#### Numpy\n",
|
||||
"\n",
|
||||
"<code>\n",
|
||||
"wget http://downloads.sourceforge.net/project/numpy/NumPy/1.9.2/numpy-1.9.2.zip\n",
|
||||
"unzip numpy-1.9.2.zip\n",
|
||||
"cd numpy-1.9.2\n",
|
||||
"echo \"[openblas]\" >> site.cfg\n",
|
||||
"echo \"library_dirs = /path/to/OpenBlas/lib\" >> site.cfg\n",
|
||||
"echo \"include_dirs = /path/to/OpenBLAS/include\" >> site.cfg\n",
|
||||
"</code>\n",
|
||||
"\n",
|
||||
"python setup.py build --fcompiler=gnu95\n",
|
||||
"\n",
|
||||
"Assuming the virtual environment is activated, the below command will install numpy in a desired space (~/.virtualenvs/mlpractical/...):\n",
|
||||
"\n",
|
||||
"python setup.py install\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"### Installing remaining packages and running tests\n",
|
||||
"\n",
|
||||
"Use pip to install remaining packages: `scipy`, `matplotlib`, `argparse`, `nose`, and check if they pass the tests. An example for numpy is given below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%clear\n",
|
||||
"import numpy\n",
|
||||
"# show_config() prints the configuration of numpy numerical backend \n",
|
||||
"# you should be able to see linkage to OpenBlas or some other library\n",
|
||||
"# in case those are empty, it means something went wrong and \n",
|
||||
"# numpy will use a default (slow) pythonic implementation for algebra\n",
|
||||
"numpy.show_config()\n",
|
||||
"#numpy.test()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Also, below we check whether and how much speedup one may expect by using different number of cores:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import multiprocessing\n",
|
||||
"import timeit\n",
|
||||
"\n",
|
||||
"num_cores = multiprocessing.cpu_count()\n",
|
||||
"N = 1000\n",
|
||||
"x = numpy.random.random((N,N))\n",
|
||||
"\n",
|
||||
"for i in xrange(0, num_cores):\n",
|
||||
" # first, set the number of threads OpenBLAS\n",
|
||||
" # should use, the below line is equivalent\n",
|
||||
" # to typing export OMP_NUM_THREADS=i+1 in bash shell\n",
|
||||
" print 'Running matrix-matrix product on %i core(s)' % i\n",
|
||||
" os.environ['OMP_NUM_THREADS'] = str(i+1)\n",
|
||||
" %%timeit numpy.dot(x,x.T)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Test whether you can plot and display the figures using pyplot"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"x = numpy.linspace(0.0, 2*numpy.pi, 100)\n",
|
||||
"y1 = numpy.sin(x)\n",
|
||||
"y2 = numpy.cos(x)\n",
|
||||
"\n",
|
||||
"plt.plot(x, y1, lw=2, label=r'$\\sin(x)$')\n",
|
||||
"plt.plot(x, y2, lw=2, label=r'$\\cos(x)$')\n",
|
||||
"plt.xlabel('x')\n",
|
||||
"plt.ylabel('y')\n",
|
||||
"plt.legend()\n",
|
||||
"plt.xlim(0.0, 2*numpy.pi)\n",
|
||||
"plt.grid()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Exercises\n",
|
||||
"\n",
|
||||
"Today exercises are meant to get you familiar with ipython notebooks (if you haven't used them so far), data organisation and how to access it. Next week onwars, we will follow with the material covered in lectures.\n",
|
||||
"\n",
|
||||
"## Data providers\n",
|
||||
"\n",
|
||||
"Open (in the browser) `mlp.dataset` module (go to `Home` tab and navigate to mlp package, then click on the link `dataset.py`). Have a look thourgh the code and comments, then follow to exercises.\n",
|
||||
"\n",
|
||||
"<b>General note:</b> you can load the mlp code into your favourite python IDE but it is totally OK if you work (modify & save) the code directly in the browser by opening/modyfing the necessary modules in the tabs.\n",
|
||||
"\n",
|
||||
"### Exercise 1 \n",
|
||||
"\n",
|
||||
"Using MNISTDataProvider, write a code that iterates over the first 5 minibatches of size 100 data-points. Print MNIST digits in 10x10 images grid plot. Images are returned from the provider as tuples of numpy arrays `(features, targets)`. The `features` matrix has shape BxD while the `targets` vector is of size B, where B is the size of a mini-batch and D is dimensionality of the features. By deafult, each data-point (image) is stored in a 784 dimensional vector of pixel intensities normalised to [0,1] range from an inital integer values [0-255]. However, the original spatial domain is two dimensional, so before plotting you need to convert it into 2D matrix (MNIST images have the same number of pixels for height and width).\n",
|
||||
"\n",
|
||||
"Tip: Useful functions for this exercise are: imshow, subplot, gridspec"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.gridspec as gridspec\n",
|
||||
"import matplotlib.cm as cm\n",
|
||||
"from mlp.dataset import MNISTDataProvider\n",
|
||||
"\n",
|
||||
"def show_mnist_image(img):\n",
|
||||
" fig = plt.figure()\n",
|
||||
" gs = gridspec.GridSpec(1, 1)\n",
|
||||
" ax1 = fig.add_subplot(gs[0,0])\n",
|
||||
" ax1.imshow(img, cmap=cm.Greys_r)\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
"def show_mnist_images(batch):\n",
|
||||
" raise NotImplementedError('Write me!')\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# An example for a single MNIST image\n",
|
||||
"mnist_dp = MNISTDataProvider(dset='valid', batch_size=1, max_num_examples=2, randomize=False)\n",
|
||||
"\n",
|
||||
"for batch in mnist_dp:\n",
|
||||
" features, targets = batch\n",
|
||||
" show_mnist_image(features.reshape(28, 28))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#implement here Exercise 1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Exercise 2\n",
|
||||
"\n",
|
||||
"`MNISTDataProvider` as `targets` currently returns a vector of integers, each element in this vector represents an id of the category `features` data-point represent. Later in the course we are going to need 1-of-K representation of targets, for instance, given the minibatch of size 3 and the corresponding targets vector $[2, 2, 0]$ (and assuming there are only 3 different classes to discriminate between), one needs to convert it into matrix $\\left[ \\begin{array}{ccc}\n",
|
||||
"0 & 0 & 1 \\\\\n",
|
||||
"0 & 0 & 1 \\\\\n",
|
||||
"1 & 0 & 0 \\end{array} \\right]$. \n",
|
||||
"\n",
|
||||
"Implement `__to_one_of_k` method of `MNISTDataProvider` class. Then modify (uncomment) an appropriate line in its `next` method, so the raw targets get converted to `1 of K` coding. Test the code in the cell below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"source": [
|
||||
"### Exercise 3\n",
|
||||
"\n",
|
||||
"Write your own data provider `MetOfficeDataProvider` that wraps the weather data for south Scotland (could be obtained from: http://www.metoffice.gov.uk/hadobs/hadukp/data/daily/HadSSP_daily_qc.txt). The file was also downloaded and stored in `data` directory for your convenience. The provider should return a tuple `(x,t)` of the estimates over an arbitrary time windows (i.e. last N-1 days) for `x` and the N-th day as the one which model should be able to predict, `t`. For now, skip missing data-points (denoted by -99.9) and simply use the next correct value. Make sure the provider works for arbitrary `batch_size` settings, including the case where single mini-batch is equal to all datapoints in the dataset. Test the dataset in the cell below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 2",
|
||||
"language": "python",
|
||||
"name": "python2"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
1023
data/HadSSP_daily_qc.txt
Normal file
1023
data/HadSSP_daily_qc.txt
Normal file
File diff suppressed because it is too large
Load Diff
BIN
data/mnist_eval.pkl.gz
Normal file
BIN
data/mnist_eval.pkl.gz
Normal file
Binary file not shown.
BIN
data/mnist_train.pkl.gz
Normal file
BIN
data/mnist_train.pkl.gz
Normal file
Binary file not shown.
BIN
data/mnist_valid.pkl.gz
Normal file
BIN
data/mnist_valid.pkl.gz
Normal file
Binary file not shown.
0
mlp/__init__.py
Normal file
0
mlp/__init__.py
Normal file
187
mlp/dataset.py
Normal file
187
mlp/dataset.py
Normal file
@ -0,0 +1,187 @@
|
||||
|
||||
# Machine Learning Practical (INFR11119),
|
||||
# Pawel Swietojanski, University of Edinburgh
|
||||
|
||||
import cPickle
|
||||
import gzip
|
||||
import numpy
|
||||
import os
|
||||
|
||||
|
||||
class DataProvider(object):
|
||||
"""
|
||||
Data provider defines an interface for our
|
||||
generic data-independent readers.
|
||||
"""
|
||||
def __init__(self, batch_size, randomize=True):
|
||||
"""
|
||||
:param batch_size: int, specifies the number
|
||||
of elements returned at each step
|
||||
:param randomize: bool, shuffles examples prior
|
||||
to iteration, so they are presented in random
|
||||
order for stochastic gradient descent training
|
||||
:return:
|
||||
"""
|
||||
self.batch_size = batch_size
|
||||
self.randomize = randomize
|
||||
self._curr_idx = 0
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the provider to the initial state to
|
||||
use in another epoch
|
||||
:return: None
|
||||
"""
|
||||
self._curr_idx = 0
|
||||
|
||||
def __randomize(self):
|
||||
"""
|
||||
Data-specific implementation of shuffling mechanism
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
"""
|
||||
Data-specific iteration mechanism.
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class MNISTDataProvider(DataProvider):
|
||||
"""
|
||||
The class iterates over MNIST digits dataset, in possibly
|
||||
random order.
|
||||
"""
|
||||
def __init__(self, dset,
|
||||
batch_size=10,
|
||||
max_num_examples=-1,
|
||||
randomize=True):
|
||||
|
||||
super(MNISTDataProvider, self).\
|
||||
__init__(batch_size, randomize)
|
||||
|
||||
assert dset in ['train', 'valid', 'eval'], (
|
||||
"Expected dset to be either 'train', "
|
||||
"'valid' or 'eval' got %s" % dset
|
||||
)
|
||||
|
||||
dset_path = './data/mnist_%s.pkl.gz' % dset
|
||||
assert os.path.isfile(dset_path), (
|
||||
"File %s was expected to exist!." % dset_path
|
||||
)
|
||||
|
||||
with gzip.open(dset_path) as f:
|
||||
x, t = cPickle.load(f)
|
||||
|
||||
self._max_num_examples = max_num_examples
|
||||
self.x = x
|
||||
self.t = t
|
||||
self.num_classes = 10
|
||||
|
||||
self._rand_idx = None
|
||||
if self.randomize:
|
||||
self._rand_idx = self.__randomize()
|
||||
|
||||
def reset(self):
|
||||
super(MNISTDataProvider, self).reset()
|
||||
if self.randomize:
|
||||
self._rand_idx = self.__randomize()
|
||||
|
||||
def __randomize(self):
|
||||
assert isinstance(self.x, numpy.ndarray)
|
||||
return numpy.random.permute(numpy.arange(0, self.x.shape[0]))
|
||||
|
||||
def next(self):
|
||||
|
||||
has_enough = (self._curr_idx + self.batch_size) <= self.x.shape[0]
|
||||
presented_max = (self._max_num_examples > 0 and
|
||||
self._curr_idx + self.batch_size > self._max_num_examples)
|
||||
|
||||
if not has_enough or presented_max:
|
||||
raise StopIteration()
|
||||
|
||||
if self._rand_idx is not None:
|
||||
range_idx = \
|
||||
self._rand_idx[self._curr_idx:self._curr_idx + self.batch_size]
|
||||
else:
|
||||
range_idx = \
|
||||
numpy.arange(self._curr_idx, self._curr_idx + self.batch_size)
|
||||
|
||||
rval_x = self.x[range_idx]
|
||||
rval_t = self.t[range_idx]
|
||||
|
||||
self._curr_idx += self.batch_size
|
||||
|
||||
#return rval_x, self.__to_one_of_k(rval_y)
|
||||
return rval_x, rval_t
|
||||
|
||||
def __to_one_of_k(self, y):
|
||||
raise NotImplementedError('Write me!')
|
||||
|
||||
|
||||
class FuncDataProvider(DataProvider):
|
||||
"""
|
||||
Function gets as an argument a list of functions random samples
|
||||
drawn from normal distribution which means are defined by those
|
||||
functions.
|
||||
"""
|
||||
def __init__(self,
|
||||
fn_list=[lambda x: x ** 2, lambda x: numpy.sin(x)],
|
||||
std_list=[0.1, 0.1],
|
||||
x_from = 0.0,
|
||||
x_to = 1.0,
|
||||
points_per_fn=200,
|
||||
batch_size=10,
|
||||
randomize=True):
|
||||
|
||||
super(FuncDataProvider, self).__init__(batch_size, randomize)
|
||||
|
||||
def sample_points(y, std):
|
||||
ys = numpy.zeros_like(y)
|
||||
for i in xrange(y.shape[0]):
|
||||
ys[i] = numpy.random.normal(y[i], std)
|
||||
return ys
|
||||
|
||||
x = numpy.linspace(x_from, x_to, points_per_fn, dtype=numpy.float32)
|
||||
means = [fn(x) for fn in fn_list]
|
||||
y = [sample_points(mean, std) for mean, std in zip(means, std_list)]
|
||||
|
||||
self.x_orig = x
|
||||
self.y_class = y
|
||||
|
||||
self.x = numpy.concatenate([x for ys in y])
|
||||
self.y = numpy.concatenate([ys for ys in y])
|
||||
|
||||
if self.randomize:
|
||||
self._rand_idx = self.__randomize()
|
||||
else:
|
||||
self._rand_idx = None
|
||||
|
||||
def __randomize(self):
|
||||
assert isinstance(self.x, numpy.ndarray)
|
||||
return numpy.random.permute(numpy.arange(0, self.x.shape[0]))
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
if (self._curr_idx + self.batch_size) >= self.x.shape[0]:
|
||||
raise StopIteration()
|
||||
|
||||
if self._rand_idx is not None:
|
||||
range_idx = self._rand_idx[self._curr_idx:self._curr_idx + self.batch_size]
|
||||
else:
|
||||
range_idx = numpy.arange(self._curr_idx, self._curr_idx + self.batch_size)
|
||||
|
||||
x = self.x[range_idx]
|
||||
y = self.y[range_idx]
|
||||
|
||||
self._curr_idx += self.batch_size
|
||||
|
||||
return x, y
|
||||
|
Loading…
Reference in New Issue
Block a user