11.25.2012

Scala: Find Fibonacci Numbers Using Matrix Multiplication

Scala: 行列累乗でフィボナッチ数列を解く

目的

  • いくつかの方法でフィボナッチ数列を求めるプログラムをScalaで実装してみる
  • 次の操作にかかる時間を測定し、比較する
    1. 0番目〜100番目のフィボナッチ数を列挙する
    2. 100,000番目のフィボナッチ数を求める

 

時間計測用の処理

このようなメソッドを予め用意した。
処理中に例外が発生した場合は、その概要が出力される。

import scala.util.control.Exception._

object Fibonacci {
  def timeit(description: String = "")(proc: => Unit) {
    val start = System.currentTimeMillis()
    allCatch.either(proc) match {
      case Left(e) => println(e)
      case _ =>
    }
    println(description + ": " + (System.currentTimeMillis() - start) + " msec")
  }
}

 

1. ワンライナー版

行列を使わずに、Streamで実装されたもの。
Scalaらしく、非常にエレガントなコードだ。

object FibonacciOneLiner {
  val fibs: Stream[BigInt] = BigInt(0) #:: BigInt(1) #:: fibs.zip(fibs.tail).map { n => n._1 + n._2 }
}

呼び出しは以下のようにする。

object Fibonacci {
  def main(args: Array[String]) {
    val n = 100
    val m = 100000
    timeit("[1] Enum first 100 Fibs") {
      FibonacciOneLiner.fibs.take(n + 1).zipWithIndex.foreach { x => println(x._2 + " : " + x._1) }
    }
    timeit("[1] Find 100,000th Fib") {
      println(m + ": " + FibonacciOneLiner.fibs.take(m + 1)(m))
    }
  }
}

計測結果は以下の通り。

  • 100番目までの列挙 => 390 ms
  • 100,000番目の表示 => java.lang.OutOfMemoryError 発生

 

2. 行列累乗版

行列の積によってフィボナッチ数列が求められることと、
平方行列のn乗が Θ(lg n) の計算時間で算出できることを利用して解く方法。

数学的な解説はReferencesを参照。

import annotation.tailrec

object FibonacciMatrix {
  type Matrix = Array[Array[BigInt]]

  implicit def arrayArrayInt2Matrix(m: Array[Array[Int]]): Matrix = {
    m.map {
      a: Array[Int] => a.map {
        x: Int => BigInt(x)
      }
    }
  }

  def mul(m1: Matrix, m2: Matrix) = {
    val res = Array.fill(m1.length, m2(0).length)(BigInt(0))
    for (row <- (0 until m1.length);
         col <- (0 until m2(0).length);
         i <- 0 until m1(0).length) {
      res(row)(col) += m1(row)(i) * m2(i)(col)
    }
    res
  }

  def pow(a: Matrix, n: Int) = {
    @tailrec
    def powLocal(a: Matrix, b: Matrix, n: Int): Matrix = n match {
      case 0 => b
      case x if (x & 1) == 1 => powLocal(mul(a, a), mul(a, b), x >> 1)
      case x => powLocal(mul(a, a), b, x >> 1)
    }
    val unit = Array(Array(BigInt(1), BigInt(0)), Array(BigInt(0), BigInt(1))) // 単位行列
    if (n < 0) throw new IllegalArgumentException
    powLocal(a, unit, n)
  }

  def fib(n: Int) = {
    val m = Array(Array(1, 1), Array(1, 0))
    pow(m, n)(1)(0)
  }
}

行列(Matrix)をBigIntの二次元配列で表現。
Intの二次元配列からBigIntの二次元配列への暗黙の型変換メソッドを作成。
累乗を求める部分では、内部関数として末尾再帰の最適化が行われるよう工夫した。

呼び出し部分はこちら。

object Fibonacci {
  def main(args: Array[String]) {
    val n = 100
    val m = 100000
    timeit("[2] Enum first 100 Fibs") {
      (0 to n).foreach { i => println(i + ": " + FibonacciMatrix.fib(i)) }
    }
    timeit("[2] Find 100,000th Fib") {
      println(m + ": " + FibonacciMatrix.fib(m))
    }
  }
}

計測結果は以下の通り。

  • 100番目までの列挙 => 108 ms
  • 100,000番目の表示 => 120 ms

OutOfMemoryが発生しない上に、処理時間そのものも大幅に短縮できた。

 

3. 行列累乗(並列処理)版

行列計算の mul メソッドを par を使って並列化したバージョン。

import annotation.tailrec

object FibonacciMatrix {
  type Matrix = Array[Array[BigInt]]

  implicit def arrayArrayInt2Matrix(m: Array[Array[Int]]): Matrix = {
    m.map {
      a: Array[Int] => a.map {
        x: Int => BigInt(x)
      }
    }
  }

  def mul(m1: Matrix, m2: Matrix) = {
    val res = Array.fill(m1.length, m2(0).length)(BigInt(0))
    for (row <- (0 until m1.length).par;
         col <- (0 until m2(0).length).par;
         i <- 0 until m1(0).length) {
      res(row)(col) += m1(row)(i) * m2(i)(col)
    }
    res
  }

  def pow(a: Matrix, n: Int) = {
    @tailrec
    def powLocal(a: Matrix, b: Matrix, n: Int): Matrix = n match {
      case 0 => b
      case x if (x & 1) == 1 => powLocal(mul(a, a), mul(a, b), x >> 1)
      case x => powLocal(mul(a, a), b, x >> 1)
    }
    val unit = Array(Array(BigInt(1), BigInt(0)), Array(BigInt(0), BigInt(1))) // 単位行列
    if (n < 0) throw new IllegalArgumentException
    powLocal(a, unit, n)
  }

  def fib(n: Int) = {
    val m = Array(Array(1, 1), Array(1, 0))
    pow(m, n)(1)(0)
  }
}

呼び出し方法。

object Fibonacci {
  def main(args: Array[String]) {
    val n = 100
    val m = 100000
    timeit("[3] Enum first 100 Fibs") {
      (0 to n).foreach { i => println(i + ": " + FibonacciMatrixParallel.fib(i)) }
    }
    timeit("[3] Find 100,000th Fib") {
      println(m + ": " + FibonacciMatrixParallel.fib(m))
    }
  }
}

計測結果は以下の通り。

  • 100番目までの列挙 => 498 ms
  • 100,000番目の表示 => 106 ms

100,000番目の算出はごく僅かに速くなったが、100番目までの列挙は逆に遅くなってしまった。

 

References

MIT OCW - Lecture 3: Divide-and-Conquer: Strassen, Fibonacci, Polynomial Multiplication
http://ocw.mit.edu/courses/electrical-engineering-and-computer-science/6-046j-introduction-to-algorithms-sma-5503-fall-2005/video-lectures/lecture-3-divide-and-conquer-strassen-fibonacci-polynomial-multiplication/

ワンライナー・フィボナッチ
http://stackoverflow.com/questions/7388416/what-is-the-fastest-way-to-write-fibonacci-function-in-scala

行列の積の求め方
http://blog.scala4java.com/2011/12/matrix-multiplication-in-scala-single.html 

0 件のコメント:

コメントを投稿