Scala: 行列累乗でフィボナッチ数列を解く
目的
- いくつかの方法でフィボナッチ数列を求めるプログラムをScalaで実装してみる
 - 次の操作にかかる時間を測定し、比較する
 - 0番目〜100番目のフィボナッチ数を列挙する
 - 100,000番目のフィボナッチ数を求める
 
時間計測用の処理
このようなメソッドを予め用意した。
 処理中に例外が発生した場合は、その概要が出力される。
import scala.util.control.Exception._
object Fibonacci {
  def timeit(description: String = "")(proc: => Unit) {
    val start = System.currentTimeMillis()
    allCatch.either(proc) match {
      case Left(e) => println(e)
      case _ =>
    }
    println(description + ": " + (System.currentTimeMillis() - start) + " msec")
  }
}
1. ワンライナー版
行列を使わずに、Streamで実装されたもの。
 Scalaらしく、非常にエレガントなコードだ。
object FibonacciOneLiner {
  val fibs: Stream[BigInt] = BigInt(0) #:: BigInt(1) #:: fibs.zip(fibs.tail).map { n => n._1 + n._2 }
}
呼び出しは以下のようにする。
object Fibonacci {
  def main(args: Array[String]) {
    val n = 100
    val m = 100000
    timeit("[1] Enum first 100 Fibs") {
      FibonacciOneLiner.fibs.take(n + 1).zipWithIndex.foreach { x => println(x._2 + " : " + x._1) }
    }
    timeit("[1] Find 100,000th Fib") {
      println(m + ": " + FibonacciOneLiner.fibs.take(m + 1)(m))
    }
  }
}
計測結果は以下の通り。
- 100番目までの列挙 => 390 ms
 - 100,000番目の表示 => java.lang.OutOfMemoryError 発生
 
2. 行列累乗版
行列の積によってフィボナッチ数列が求められることと、
 平方行列のn乗が Θ(lg n) の計算時間で算出できることを利用して解く方法。
数学的な解説はReferencesを参照。
import annotation.tailrec
object FibonacciMatrix {
  type Matrix = Array[Array[BigInt]]
  implicit def arrayArrayInt2Matrix(m: Array[Array[Int]]): Matrix = {
    m.map {
      a: Array[Int] => a.map {
        x: Int => BigInt(x)
      }
    }
  }
  def mul(m1: Matrix, m2: Matrix) = {
    val res = Array.fill(m1.length, m2(0).length)(BigInt(0))
    for (row <- (0 until m1.length);
         col <- (0 until m2(0).length);
         i <- 0 until m1(0).length) {
      res(row)(col) += m1(row)(i) * m2(i)(col)
    }
    res
  }
  def pow(a: Matrix, n: Int) = {
    @tailrec
    def powLocal(a: Matrix, b: Matrix, n: Int): Matrix = n match {
      case 0 => b
      case x if (x & 1) == 1 => powLocal(mul(a, a), mul(a, b), x >> 1)
      case x => powLocal(mul(a, a), b, x >> 1)
    }
    val unit = Array(Array(BigInt(1), BigInt(0)), Array(BigInt(0), BigInt(1))) // 単位行列
    if (n < 0) throw new IllegalArgumentException
    powLocal(a, unit, n)
  }
  def fib(n: Int) = {
    val m = Array(Array(1, 1), Array(1, 0))
    pow(m, n)(1)(0)
  }
}
行列(Matrix)をBigIntの二次元配列で表現。
Intの二次元配列からBigIntの二次元配列への暗黙の型変換メソッドを作成。
累乗を求める部分では、内部関数として末尾再帰の最適化が行われるよう工夫した。
呼び出し部分はこちら。
object Fibonacci {
  def main(args: Array[String]) {
    val n = 100
    val m = 100000
    timeit("[2] Enum first 100 Fibs") {
      (0 to n).foreach { i => println(i + ": " + FibonacciMatrix.fib(i)) }
    }
    timeit("[2] Find 100,000th Fib") {
      println(m + ": " + FibonacciMatrix.fib(m))
    }
  }
}
計測結果は以下の通り。
- 100番目までの列挙 => 108 ms
 - 100,000番目の表示 => 120 ms
 
OutOfMemoryが発生しない上に、処理時間そのものも大幅に短縮できた。
3. 行列累乗(並列処理)版
行列計算の mul メソッドを par を使って並列化したバージョン。
import annotation.tailrec
object FibonacciMatrix {
  type Matrix = Array[Array[BigInt]]
  implicit def arrayArrayInt2Matrix(m: Array[Array[Int]]): Matrix = {
    m.map {
      a: Array[Int] => a.map {
        x: Int => BigInt(x)
      }
    }
  }
  def mul(m1: Matrix, m2: Matrix) = {
    val res = Array.fill(m1.length, m2(0).length)(BigInt(0))
    for (row <- (0 until m1.length).par;
         col <- (0 until m2(0).length).par;
         i <- 0 until m1(0).length) {
      res(row)(col) += m1(row)(i) * m2(i)(col)
    }
    res
  }
  def pow(a: Matrix, n: Int) = {
    @tailrec
    def powLocal(a: Matrix, b: Matrix, n: Int): Matrix = n match {
      case 0 => b
      case x if (x & 1) == 1 => powLocal(mul(a, a), mul(a, b), x >> 1)
      case x => powLocal(mul(a, a), b, x >> 1)
    }
    val unit = Array(Array(BigInt(1), BigInt(0)), Array(BigInt(0), BigInt(1))) // 単位行列
    if (n < 0) throw new IllegalArgumentException
    powLocal(a, unit, n)
  }
  def fib(n: Int) = {
    val m = Array(Array(1, 1), Array(1, 0))
    pow(m, n)(1)(0)
  }
}
呼び出し方法。
object Fibonacci {
  def main(args: Array[String]) {
    val n = 100
    val m = 100000
    timeit("[3] Enum first 100 Fibs") {
      (0 to n).foreach { i => println(i + ": " + FibonacciMatrixParallel.fib(i)) }
    }
    timeit("[3] Find 100,000th Fib") {
      println(m + ": " + FibonacciMatrixParallel.fib(m))
    }
  }
}
計測結果は以下の通り。
- 100番目までの列挙 => 498 ms
 - 100,000番目の表示 => 106 ms
 
100,000番目の算出はごく僅かに速くなったが、100番目までの列挙は逆に遅くなってしまった。
References
MIT OCW - Lecture 3: Divide-and-Conquer: Strassen, Fibonacci, Polynomial Multiplication
http://ocw.mit.edu/courses/electrical-engineering-and-computer-science/6-046j-introduction-to-algorithms-sma-5503-fall-2005/video-lectures/lecture-3-divide-and-conquer-strassen-fibonacci-polynomial-multiplication/
ワンライナー・フィボナッチ
http://stackoverflow.com/questions/7388416/what-is-the-fastest-way-to-write-fibonacci-function-in-scala
行列の積の求め方
http://blog.scala4java.com/2011/12/matrix-multiplication-in-scala-single.html