You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

52 lines
2.6 KiB
Scala

import breeze.plot._
import breeze.plot.DomainFunction._
import breeze.linalg._
import breeze.stats.distributions.Gaussian
// ---------------------------------------------------------------------------
// Training data.
// Defined first because every statement below reads `data`, `x_i` or `y_i`.
// Previously they were declared at the very end of the script (after their
// uses) and `x_i` / `y_i` were additionally declared a second time.
// ---------------------------------------------------------------------------

// 16 abscissae, evenly spaced over [-Pi, Pi].
val x_i = Seq(-3.141592653589793,-2.722713633111154,-2.303834612632515,-1.8849555921538759,-1.4660765716752369,-1.0471975511965979,-0.6283185307179586,-0.2094395102393194,0.2094395102393194,0.6283185307179586,1.0471975511965974,1.4660765716752362,1.8849555921538759,2.3038346126325155,2.7227136331111543,3.1415926535897922)
// Matching ordinates. NOTE(review): presumably one recorded draw of the
// commented-out generator below (sin(t) + Gaussian noise) -- confirm.
val y_i = Seq(0.0802212608585366,-0.3759376368887911,-1.3264180339054117,-0.8971334213504949,-0.7724344034354425,-0.9501497164520739,-0.6224628757084738,-0.35622668982623207,-0.18377660088356823,0.7836770998126841,0.5874762732054489,1.0696991264956026,1.1297065441952743,0.7587275382323738,-0.030547103790458163,0.044327111895927106)
val data = x_i zip y_i

// ---------------------------------------------------------------------------
// Helpers.
// ---------------------------------------------------------------------------

// Directory that receives every CSV written by this script.
val outputDir = "C:/Users/tobia/Documents/Studium/Masterarbeit/Outputs"

/** Appends `row` followed by a newline to the file `fileName` inside `outputDir`. */
def appendCsvRow(fileName: String, row: String): Unit =
  reflect.io.File(s"$outputDir/$fileName").appendAll(row + "\n")

/** Regularisation weight passed to `train`: a base term depending on `net.n`, multiplied by `scale`.
  *
  * The expression is kept in the exact evaluation order used throughout the
  * script so the floating-point result is unchanged.
  */
def lambdaFor(net: RSNN, scale: Double): Double =
  (1.0 / 20) / 5 * (net.n * 8) * scale

/** Appends one labelled row for `x` and one for `y` to `scala_out_d_1.csv`.
  *
  * Each row starts with a label of the form `x_n_<n>_tl_<tlambda>` (resp.
  * `y_...`), followed by the vector entries, all separated by `;`.
  *
  * @param nn      network whose `n` is written into the row label
  * @param x       values written to the `x_` row
  * @param y       values written to the `y_` row
  * @param tlambda lambda multiplier written into the row label
  */
def print_data(nn: RSNN, x: DenseVector[Double], y: DenseVector[Double], tlambda: Double): Unit = {
  val n = nn.n
  val tag = s"n_${n}_tl_$tlambda"
  appendCsvRow("scala_out_d_1.csv", s"x_$tag;" + x.toArray.mkString(";"))
  appendCsvRow("scala_out_d_1.csv", s"y_$tag;" + y.toArray.mkString(";"))
}

// ---------------------------------------------------------------------------
// Baseline fit: train once, plot the fitted curve over the samples, dump CSVs.
// ---------------------------------------------------------------------------

val nn = new RSNN(5000, 0.0000001)
// Noise distribution referenced only by the commented-out generator below.
val g = Gaussian(0, 0.3)
//val data = EqSeq(-math.Pi, math.Pi, 15) map (t => (t, math.sin(t)+ g.sample(1).last))
val (ws, evaluate) = nn.train(data, iter = 100000, lambda = lambdaFor(nn, 1.0))

val f = Figure()
val p = f.subplot(0)
// Evaluation grid on [-5, 5]; wider than the sampled x range.
val x = linspace(-5, 5)
val y = x.map(evaluate)
//print_data(nn, x, y, 3)
p += plot(x, y)
// Constant marker size for every sample point.
p += scatter(x_i, y_i, _ => 0.1)
f.saveas("lines.png")

// Raw samples (x row, then y row) and the baseline curve (grid row, then prediction row).
appendCsvRow("data_sin_d.csv", x_i.mkString(";"))
appendCsvRow("data_sin_d.csv", y_i.mkString(";"))
appendCsvRow("vals1.csv", x.toArray.mkString(";"))
appendCsvRow("vals1.csv", y.toArray.mkString(";"))

// ---------------------------------------------------------------------------
// Sweep: retrain for each lambda multiplier `j` and each size 5 * 10^i, then
// append the resulting curve (on the same grid `x`) to scala_out_d_1.csv.
// ---------------------------------------------------------------------------
for {
  j <- List(0.1, 1.0, 3.0)
  i <- 3 until 4
} {
  val sweepNet = new RSNN((5 * math.pow(10, i)).toInt, 0.0000001)
  val (_, sweepEvaluate) = sweepNet.train(data, iter = 100000, lambda = lambdaFor(sweepNet, j))
  print_data(sweepNet, x, x.map(sweepEvaluate), j)
}