import breeze.stats.distributions.Uniform
import breeze.stats.distributions.Gaussian
import scala.language.postfixOps

/** Rectified linear activation used by the hidden layer of [[RSNN]]. */
object Activation {
  /** Returns `x` when it is positive and 0 otherwise. */
  def apply(x: Double): Double = math.max(0.0, x)

  /** Derivative of the activation: 1 for `x > 0`, otherwise 0. */
  def d(x: Double): Double = if (x > 0) 1.0 else 0.0
}

/**
 * One-hidden-layer network on a scalar input whose hidden layer is sampled
 * once at construction and never updated; only the output weights are fitted
 * (see [[learn]] and [[train]]).
 *
 * @param n     number of hidden units
 * @param gamma default step size used by [[train]] when none is passed
 */
class RSNN(val n: Int, val gamma: Double = 0.001) {
  val g = Uniform(-10, 10)
  val g_1 = Uniform(-5, 5) // alternative upper bound tried earlier: scala.math.exp(1)
  val g_3 = Gaussian(0, 5)

  /** Per-unit offsets drawn from `g`. */
  val xis = g.sample(n)

  /** Per-unit input weights drawn from `g_3`. */
  val vs = g_3.sample(n)

  /**
   * Per-unit biases. Because `b = xi * v`, unit i evaluates
   * `Activation(v * (xi + x))`, i.e. its kink sits at `x = -xi`.
   */
  val bs = xis zip vs map { case (xi, v) => xi * v }
  // Earlier initialisation, kept for reference:
  //val vs = g_1.sample(n)
  //val bs = g.sample(n)

  /** Hidden-layer activations for input `x`, one value per hidden unit. */
  def computeL1(x: Double) =
    (bs zip vs) map { case (b, v) => Activation(b + v * x) }

  /**
   * Weighted sum of the hidden activations.
   *
   * @param l1 hidden activations
   * @param ws output weights; pairs beyond the shorter sequence are ignored by `zip`
   */
  def computeL2(l1: Seq[Double], ws: Seq[Double]): Double =
    l1.zip(ws).map { case (l, w) => w * l }.sum

  /** Network output for input `x` under the output weights `ws`. */
  def output(ws: Seq[Double])(x: Double): Double = computeL2(computeL1(x), ws)

  /**
   * Performs one pass over `data` and returns the updated output weights.
   *
   * Every sample's update is computed against the weights passed in (they are
   * not refreshed between samples) and all updates are then added to `ws`.
   * For a sample `(x, y)` the update of weight `w_i` is
   * `-gamma * (2 * (out - y) * h_i + 2 * lambda * w_i)`, so the `lambda` term
   * is applied once per sample.
   *
   * @param data   samples as `(x, y)` pairs (N entries)
   * @param ws     current output weights (n entries)
   * @param lambda coefficient of the `2 * lambda * w` penalty term
   * @param gamma  step size
   * @return `ws` itself when `data` is empty, otherwise the updated weights
   */
  def learn(data: Seq[(Double, Double)], ws: Seq[Double], lambda: Double, gamma: Double): Seq[Double] = {
    // deltas: N x n, one row of weight updates per sample.
    val deltas = data.map { case (x, y) =>
      val hidden = computeL1(x) // n
      val out = computeL2(hidden, ws) // 1
      (hidden zip ws) map { case (h, w) => (h * 2 * (out - y) + lambda * 2 * w) * gamma * -1 } // n
    }

    // foldRight is kept on purpose: it preserves the original right-to-left
    // summation order, so floating-point results are unchanged.
    deltas.foldRight(ws) { (delta, acc) =>
      acc zip delta map { case (w, d) => w + d } // n
    }
    // An alternative that applies the penalty once per pass was left disabled:
    // map (w => w - lambda * gamma * 2 * w)
  }

  /**
   * Runs `iter` passes of [[learn]] starting from all-zero output weights,
   * printing the iteration number and the mean weight before each pass.
   *
   * NOTE(review): the step size handed to [[learn]] is `gamma / 10`, not
   * `gamma` -- confirm that this scaling is intended.
   *
   * @param data   samples as `(x, y)` pairs
   * @param iter   number of passes; no pass is run when `iter <= 0`
   * @param lambda penalty coefficient forwarded to [[learn]]
   * @param gamma  step size before the division by 10; defaults to the
   *               constructor's `gamma`
   * @return the fitted weights together with the corresponding predictor
   */
  def train(data: Seq[(Double, Double)], iter: Int, lambda: Double, gamma: Double = gamma): (Seq[Double], Double => Double) = {
    val initial: Seq[Double] = Vector.fill(n)(0.0)

    // foldLeft visits 1, 2, ..., iter so the progress log counts upwards;
    // the iteration index is used for logging only.
    val ws = (1 to iter).foldLeft(initial) { (w, i) =>
      println(s"Training iteration $i")
      println(w.sum / w.length)
      learn(data, w, lambda, gamma / 10)
    }

    (ws, output(ws))
  }
}

/** Fits an [[RSNN]] to `sin` on (0, 1] and prints each sample next to its prediction. */
object Main {
  def main(args: Array[String]): Unit = {
    val nn = new RSNN(10, gamma = 0.0001)
    val data = (1 to 100).map(_ * 0.01).map(t => (t, math.sin(t)))
    val (_, evaluate) = nn.train(data, iter = 1000, lambda = 0.8)
    val results = data.map { case (x, _) => evaluate(x) }
    data.zip(results).foreach { println(_) }
  }
}

/** Builds evenly spaced points between two bounds. */
object EqSeq {
  /**
   * Returns the `steps + 1` points `left + i * (right - left) / steps` for
   * `i = 0 .. steps`.
   *
   * @param left  first point
   * @param right last point (up to floating-point rounding)
   * @param steps number of intervals; 0 yields the single point `left`
   *              (instead of dividing by zero), a negative value yields an
   *              empty sequence
   */
  def apply(left: Double, right: Double, steps: Int): Seq[Double] =
    if (steps == 0) Vector(left)
    else (0 to steps) map (_ * (right - left) / steps + left)
}