Scala: 末尾再帰の最適化について
関数型プログラミングの中でも重要なテーマの一つ、末尾再帰の最適化について実際に試してみる。
ここでは例として、階乗の結果を定数 1,000,000,007 で割った剰余を求める関数を作る。(何の役に立つか謎だが)
1. 末尾再帰でない場合
まずは、ごく自然な再帰で階乗を求めてみる。
・コード
object Blog20120705_1 { val modulo = 1000000007 def factorial(n: Int): Int = n match { case x if x < 0 => throw new IllegalArgumentException case 0 => 1 case x => (x.toLong * factorial(x - 1) % modulo).toInt } def main(args: Array[String]): Unit = { println(factorial(0)) println(factorial(6)) println(factorial(1000)) println(factorial(10000)) } }
・実行結果
1 720 641419708 Exception in thread "main" java.lang.StackOverflowError at blog.Blog20120705_1$.factorial(Blog20120705.scala:8) at blog.Blog20120705_1$.factorial(Blog20120705.scala:8) at blog.Blog20120705_1$.factorial(Blog20120705.scala:8) (以下略)
factorial(1000) までは問題ない。
しかし factorial(10000) を計算したとき、スタックオーバーフローを引き起こしてしまった。
2. 末尾再帰
前の例では再帰処理の結果に対して計算が行われるため、再帰の呼び出しにあたってスタックを使用する必要があった。
次の例のようにアキュームレータを使うことで、関数の処理の最後に再帰呼び出しを行うようにすれば、コンパイラが最適化してスタックを使わないマシンコードを生成してくれる。
・コード
object Blog20120705_2 { val modulo = 1000000007 def factorial(n: Int): Int = { def factorialLocal(n: Int, sofar: Int): Int = n match { case 0 => sofar case x => factorialLocal(x - 1, (x.toLong * sofar % modulo).toInt) } if (n < 0) throw new IllegalArgumentException factorialLocal(n, 1) } def main(args: Array[String]): Unit = { println(factorial(0)) println(factorial(6)) println(factorial(1000)) println(factorial(10000)) } }
・実行結果
1 720 641419708 531950728
晴れて、factorial(10000) を計算できた。
3. アノテーションの活用
Scala 2.8 以降では、末尾再帰の最適化を保証してくれる "@tailrec" アノテーションが登場した。
これを利用すれば、最適化されない場合にコンパイルエラーとしてくれる。
末尾再帰の最適化を期待する関数には、安全のため必ず付けるようにしよう。
(末尾再帰をしていても、いくつかの制約により最適化されない場合もあるとのこと)
アノテーションを使用するためには、「import scala.annotation.tailrec」でインポートしておく。
・最適化されないコードは、次のエラーメッセージが出てコンパイルできない
import scala.annotation.tailrec object Blog20120705_3 { val modulo = 1000000007 @tailrec def factorial(n: Int): Int = n match { case x if x < 0 => throw new IllegalArgumentException case 0 => 1 case x => (x.toLong * factorial(x - 1) % modulo).toInt } def main(args: Array[String]): Unit = { println(factorial(0)) println(factorial(6)) println(factorial(1000)) println(factorial(10000)) } }
・エラーメッセージ
could not optimize @tailrec annotated method factorial: it contains a recursive call not in tail position
・最適化されるコードは、正常にコンパイルされる
import scala.annotation.tailrec object Blog20120705_4 { val modulo = 1000000007 def factorial(n: Int): Int = { @tailrec def factorialLocal(n: Int, sofar: Int): Int = n match { case 0 => sofar case x => factorialLocal(x - 1, (x.toLong * sofar % modulo).toInt) } if (n < 0) throw new IllegalArgumentException factorialLocal(n, 1) } def main(args: Array[String]): Unit = { println(factorial(0)) println(factorial(6)) println(factorial(1000)) println(factorial(10000)) } }
0 件のコメント:
コメントを投稿