67 lines
1.8 KiB
Python
67 lines
1.8 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Parameter initialisers.
|
|
|
|
This module defines classes to initialise the parameters in a layer.
|
|
"""
|
|
|
|
import numpy as np
|
|
|
|
DEFAULT_SEED = 123456 # Default random number generator seed if none provided.
|
|
|
|
|
|
class ConstantInit(object):
|
|
"""Constant parameter initialiser."""
|
|
|
|
def __init__(self, value):
|
|
"""Construct a constant parameter initialiser.
|
|
|
|
Args:
|
|
value: Value to initialise parameter to.
|
|
"""
|
|
self.value = value
|
|
|
|
def __call__(self, shape):
|
|
return np.ones(shape=shape) * self.value
|
|
|
|
|
|
class UniformInit(object):
|
|
"""Random uniform parameter initialiser."""
|
|
|
|
def __init__(self, low, high, rng=None):
|
|
"""Construct a random uniform parameter initialiser.
|
|
|
|
Args:
|
|
low: Lower bound of interval to sample from.
|
|
high: Upper bound of interval to sample from.
|
|
rng (RandomState): Seeded random number generator.
|
|
"""
|
|
self.low = low
|
|
self.high = high
|
|
if rng is None:
|
|
rng = np.random.RandomState(DEFAULT_SEED)
|
|
self.rng = rng
|
|
|
|
def __call__(self, shape):
|
|
return self.rng.uniform(low=self.low, high=self.high, size=shape)
|
|
|
|
|
|
class NormalInit(object):
|
|
"""Random normal parameter initialiser."""
|
|
|
|
def __init__(self, mean, std, rng=None):
|
|
"""Construct a random uniform parameter initialiser.
|
|
|
|
Args:
|
|
mean: Mean of distribution to sample from.
|
|
std: Standard deviation of distribution to sample from.
|
|
rng (RandomState): Seeded random number generator.
|
|
"""
|
|
self.mean = mean
|
|
self.std = std
|
|
if rng is None:
|
|
rng = np.random.RandomState(DEFAULT_SEED)
|
|
self.rng = rng
|
|
|
|
def __call__(self, shape):
|
|
return self.rng.normal(loc=self.mean, scale=self.std, size=shape)
|