7.06.2012

Tail Recursion Optimization in Scala

Scala: 末尾再帰の最適化について

関数型プログラミングの中でも重要なテーマの一つ、末尾再帰の最適化について実際に試してみる。

ここでは例として、階乗の結果を定数 1,000,000,007 で割った剰余を求める関数を作る。(何の役に立つか謎だが)

1. 末尾再帰でない場合

まずは、ごく自然な再帰で階乗を求めてみる。

・コード

object Blog20120705_1 {
  val modulo = 1000000007
  def factorial(n: Int): Int = n match {
    case x if x < 0 => throw new IllegalArgumentException
    case 0 => 1
    case x => (x.toLong * factorial(x - 1) % modulo).toInt
  }
  def main(args: Array[String]): Unit = {
    println(factorial(0))
    println(factorial(6))
    println(factorial(1000))
    println(factorial(10000))
  }
}

・実行結果

1
720
641419708
Exception in thread "main" java.lang.StackOverflowError
    at blog.Blog20120705_1$.factorial(Blog20120705.scala:8)
    at blog.Blog20120705_1$.factorial(Blog20120705.scala:8)
    at blog.Blog20120705_1$.factorial(Blog20120705.scala:8)
(以下略)

factorial(1000) までは問題ない。
しかし factorial(10000) を計算したとき、スタックオーバーフローを引き起こしてしまった。

2. 末尾再帰

前の例では再帰処理の結果に対して計算が行われるため、再帰の呼び出しにあたってスタックを使用する必要があった。
次の例のようにアキュームレータを使うことで、関数の処理の最後に再帰呼び出しを行うようにすれば、コンパイラが最適化してスタックを使わないマシンコードを生成してくれる。

・コード

object Blog20120705_2 {
  val modulo = 1000000007
  def factorial(n: Int): Int = {
    def factorialLocal(n: Int, sofar: Int): Int = n match {
      case 0 => sofar
      case x => factorialLocal(x - 1, (x.toLong * sofar % modulo).toInt)
    }
    if (n < 0) throw new IllegalArgumentException
    factorialLocal(n, 1)
  }
  def main(args: Array[String]): Unit = {
    println(factorial(0))
    println(factorial(6))
    println(factorial(1000))
    println(factorial(10000))
  }
}

・実行結果

1
720
641419708
531950728

晴れて、factorial(10000) を計算できた。

3. アノテーションの活用

Scala 2.8 以降では、末尾再帰の最適化を保証してくれる "@tailrec" アノテーションが登場した。
これを利用すれば、最適化されない場合にコンパイルエラーとしてくれる。
末尾再帰の最適化を期待する関数には、安全のため必ず付けるようにしよう。
(末尾再帰をしていても、いくつかの制約により最適化されない場合もあるとのこと)

アノテーションを使用するためには、「import scala.annotation.tailrec」でインポートしておく。

・最適化されないコードは、次のエラーメッセージが出てコンパイルできない

import scala.annotation.tailrec
object Blog20120705_3 {
  val modulo = 1000000007
  @tailrec
  def factorial(n: Int): Int = n match {
    case x if x < 0 => throw new IllegalArgumentException
    case 0 => 1
    case x => (x.toLong * factorial(x - 1) % modulo).toInt
  }
  def main(args: Array[String]): Unit = {
    println(factorial(0))
    println(factorial(6))
    println(factorial(1000))
    println(factorial(10000))
  }
}

・エラーメッセージ

could not optimize @tailrec annotated method factorial: it contains a recursive call not in tail position

・最適化されるコードは、正常にコンパイルされる

import scala.annotation.tailrec
object Blog20120705_4 {
  val modulo = 1000000007
  def factorial(n: Int): Int = {
    @tailrec
    def factorialLocal(n: Int, sofar: Int): Int = n match {
      case 0 => sofar
      case x => factorialLocal(x - 1, (x.toLong * sofar % modulo).toInt)
    }
    if (n < 0) throw new IllegalArgumentException
    factorialLocal(n, 1)
  }
  def main(args: Array[String]): Unit = {
    println(factorial(0))
    println(factorial(6))
    println(factorial(1000))
    println(factorial(10000))
  }
}

0 件のコメント:

コメントを投稿