You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

76 lines
2.3 KiB
Scala

4 years ago
import breeze.stats.distributions.Uniform
import breeze.stats.distributions.Gaussian
import scala.language.postfixOps
/** Rectified-linear activation used by the hidden layer of [[RSNN]]. */
object Activation {
  /** ReLU: `x` when it is positive, `0.0` otherwise; a NaN input stays NaN. */
  def apply(x: Double): Double = math.max(x, 0.0)

  /** Derivative of the ReLU: `1.0` for strictly positive `x`, else `0.0` (also for 0 and NaN). */
  def d(x: Double): Double =
    if (x > 0) 1.0
    else 0.0
}
/** Regression network with one hidden layer of random, frozen ReLU features.
  *
  * The hidden-layer parameters are sampled once when the instance is constructed;
  * only the output weights `ws` are fitted (see [[learn]] and [[train]]).
  *
  * @param n     number of hidden units
  * @param gamma default learning rate used by [[train]]
  */
class RSNN(val n: Int, val gamma: Double = 0.001) {
  val g = Uniform(-10, 10)
  // Only referenced by the commented-out alternative initialization below.
  val g_1 = Uniform(-5, 5)//scala.math.exp(1))
  // NOTE(review): assumes Breeze's Gaussian(mu, sigma) takes a standard deviation — confirm.
  val g_3 = Gaussian(0, 5)
  val xis = g.sample(n)
  val vs = g_3.sample(n)
  // Hidden unit j computes ReLU(b_j + v_j * x) = ReLU(v_j * (xi_j + x)), so its kink
  // sits at x = -xi_j; the kinks are therefore spread over [-10, 10] by `g`.
  val bs = xis zip vs map { case (xi, v) => xi * v }
  //val vs = g_1.sample(n)
  //val bs = g.sample(n)

  /** Hidden-layer activations for input `x`: one ReLU output per hidden unit. */
  def computeL1(x: Double) = (bs zip vs) map { case (b, v) => Activation(b + v * x) }

  /** Network output: dot product of hidden activations `l1` with output weights `ws`.
    *
    * @throws IllegalArgumentException if `l1` and `ws` differ in length
    */
  def computeL2(l1: Seq[Double], ws: Seq[Double]): Double = {
    // zip would silently drop the unmatched tail and produce a wrong prediction.
    require(l1.length == ws.length, s"expected ${l1.length} weights, got ${ws.length}")
    (l1 zip ws).map { case (l, w) => w * l }.sum
  }

  /** Prediction function for fixed output weights `ws`. */
  def output(ws: Seq[Double])(x: Double): Double = computeL2(computeL1(x), ws)

  /** One batch gradient-descent step on the output weights.
    *
    * @param data   training pairs `(x, y)`
    * @param ws     current output weights (one per hidden unit)
    * @param lambda L2 regularization strength
    * @param gamma  step size
    * @return the updated weights
    */
  def learn(data: Seq[(Double, Double)], ws: Seq[Double], lambda: Double, gamma: Double): Seq[Double] = {
    // One delta vector (length n) per sample, all evaluated at the incoming `ws`.
    val deltas = data.map { case (x, y) =>
      val hidden = computeL1(x)
      val out = computeL2(hidden, ws)
      // -gamma * d/dw_j [ (out - y)^2 + lambda * sum_k w_k^2 ]
      (hidden zip ws) map { case (h, w) => (h * 2 * (out - y) + lambda * 2 * w) * gamma * -1 }
    }
    // The L2 term is part of every sample's delta, so the effective penalty grows with data.size.
    deltas.foldRight(ws)((delta, acc) => acc.zip(delta).map { case (w, d) => w + d })
  }

  /** Fits the output weights with `iter` calls to [[learn]], starting from all-zero weights.
    *
    * Before each step it prints the iteration number and the mean of the current weights.
    *
    * @param data   training pairs `(x, y)`
    * @param iter   number of gradient steps
    * @param lambda L2 regularization strength
    * @param gamma  learning rate; defaults to the instance's `gamma`
    * @return the fitted weights and the corresponding prediction function
    */
  def train(data: Seq[(Double, Double)], iter: Int, lambda: Double, gamma: Double = this.gamma): (Seq[Double], Double => Double) = {
    val initial: Seq[Double] = Vector.fill(n)(0.0)
    // foldLeft so the logged iteration numbers run 1..iter (foldRight counted down from iter).
    val ws = (1 to iter).foldLeft(initial) { (w, i) =>
      println(s"Training iteration $i")
      println(w.sum / w.length)
      // NOTE(review): the step actually applied is gamma / 10; SOURCE gives no reason — confirm intended.
      learn(data, w, lambda, gamma / 10)
    }
    (ws, output(ws))
  }
}
/** Demo: fits an [[RSNN]] to sin(x) on (0, 1] and prints each sample next to its prediction. */
object Main {
  def main(args: Array[String]): Unit = {
    val network = new RSNN(10, gamma = 0.0001)
    // Inputs 0.01, 0.02, ..., 1.00 (as i * 0.01), each paired with sin(x).
    val inputs = (1 to 100).map(i => i * 0.01)
    val samples = inputs.map(x => (x, math.sin(x)))
    val (_, predict) = network.train(samples, iter = 1000, lambda = 0.8)
    val predictions = samples.map { case (x, _) => predict(x) }
    for (pair <- samples.zip(predictions)) println(pair)
  }
}
/** Evenly spaced points over a closed interval. */
object EqSeq {
  /** Returns `steps + 1` evenly spaced points from `left` to `right`, both inclusive.
    *
    * @param left  first point
    * @param right last point; returned exactly rather than recomputed with rounding
    * @param steps number of intervals; `0` yields the single point `Seq(left)`
    * @return the grid points in order from `left` to `right`
    * @throws IllegalArgumentException if `steps` is negative
    */
  def apply(left: Double, right: Double, steps: Int): Seq[Double] = {
    require(steps >= 0, s"steps must be non-negative, got $steps")
    // Without this branch, 0 * (right - left) / 0 evaluates to NaN.
    if (steps == 0) Seq(left)
    else
      (0 to steps).map { i =>
        if (i == steps) right else i * (right - left) / steps + left
      }
  }
}