Scala: 末尾再帰の最適化について
関数型プログラミングの中でも重要なテーマの一つ、末尾再帰の最適化について実際に試してみる。
ここでは例として、階乗の結果を定数 1,000,000,007 で割った剰余を求める関数を作る。(何の役に立つか謎だが)
1. 末尾再帰でない場合
まずは、ごく自然な再帰で階乗を求めてみる。
・コード
object Blog20120705_1 {
val modulo = 1000000007
def factorial(n: Int): Int = n match {
case x if x < 0 => throw new IllegalArgumentException
case 0 => 1
case x => (x.toLong * factorial(x - 1) % modulo).toInt
}
def main(args: Array[String]): Unit = {
println(factorial(0))
println(factorial(6))
println(factorial(1000))
println(factorial(10000))
}
}
・実行結果
1
720
641419708
Exception in thread "main" java.lang.StackOverflowError
at blog.Blog20120705_1$.factorial(Blog20120705.scala:8)
at blog.Blog20120705_1$.factorial(Blog20120705.scala:8)
at blog.Blog20120705_1$.factorial(Blog20120705.scala:8)
(以下略)
factorial(1000) までは問題ない。
しかし factorial(10000) を計算したとき、スタックオーバーフローを引き起こしてしまった。
2. 末尾再帰
前の例では再帰処理の結果に対して計算が行われるため、再帰の呼び出しにあたってスタックを使用する必要があった。
次の例のようにアキュームレータを使うことで、関数の処理の最後に再帰呼び出しを行うようにすれば、コンパイラが最適化してスタックを使わないマシンコードを生成してくれる。
・コード
object Blog20120705_2 {
val modulo = 1000000007
def factorial(n: Int): Int = {
def factorialLocal(n: Int, sofar: Int): Int = n match {
case 0 => sofar
case x => factorialLocal(x - 1, (x.toLong * sofar % modulo).toInt)
}
if (n < 0) throw new IllegalArgumentException
factorialLocal(n, 1)
}
def main(args: Array[String]): Unit = {
println(factorial(0))
println(factorial(6))
println(factorial(1000))
println(factorial(10000))
}
}
・実行結果
1 720 641419708 531950728
晴れて、factorial(10000) を計算できた。
3. アノテーションの活用
Scala 2.8 以降では、末尾再帰の最適化を保証してくれる "@tailrec" アノテーションが登場した。
これを利用すれば、最適化されない場合にコンパイルエラーとしてくれる。
末尾再帰の最適化を期待する関数には、安全のため必ず付けるようにしよう。
(末尾再帰をしていても、いくつかの制約により最適化されない場合もあるとのこと)
アノテーションを使用するためには、「import scala.annotation.tailrec」でインポートしておく。
・最適化されないコードは、次のエラーメッセージが出てコンパイルできない
import scala.annotation.tailrec
object Blog20120705_3 {
val modulo = 1000000007
@tailrec
def factorial(n: Int): Int = n match {
case x if x < 0 => throw new IllegalArgumentException
case 0 => 1
case x => (x.toLong * factorial(x - 1) % modulo).toInt
}
def main(args: Array[String]): Unit = {
println(factorial(0))
println(factorial(6))
println(factorial(1000))
println(factorial(10000))
}
}
・エラーメッセージ
could not optimize @tailrec annotated method factorial: it contains a recursive call not in tail position
・最適化されるコードは、正常にコンパイルされる
import scala.annotation.tailrec
object Blog20120705_4 {
val modulo = 1000000007
def factorial(n: Int): Int = {
@tailrec
def factorialLocal(n: Int, sofar: Int): Int = n match {
case 0 => sofar
case x => factorialLocal(x - 1, (x.toLong * sofar % modulo).toInt)
}
if (n < 0) throw new IllegalArgumentException
factorialLocal(n, 1)
}
def main(args: Array[String]): Unit = {
println(factorial(0))
println(factorial(6))
println(factorial(1000))
println(factorial(10000))
}
}
0 件のコメント:
コメントを投稿