Scala で BIT (Binary Indexed Tree) を実装する
概要
コード
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Immutable, 0-indexed Binary Indexed Tree (Fenwick tree) over any [[scala.math.Numeric]] type.
  *
  * Element `i` is stored at internal slot `i + 1`; slot 0 of `cumulatives` is never read or
  * written. `updated` maintains the invariant that slot `n` holds the sum of the elements in the
  * half-open index range `[n - (n & -n), n)`, where `n & -n` isolates the lowest set bit of `n`.
  *
  * @param size        number of elements held by the tree
  * @param cumulatives internal partial sums; must contain exactly `size + 1` entries
  * @param num         arithmetic used to add and subtract values
  */
case class BinaryIndexedTree[A](size: Int, cumulatives: Vector[A])(implicit num: Numeric[A]) {
  require(
    cumulatives.size == size + 1,
    s"cumulatives must have ${size + 1} entries (size + 1), but has ${cumulatives.size}"
  )

  /** Returns the sum of the elements in the half-open range `[0, until)`.
    *
    * `until` is clamped into `[0, size]`, so out-of-range bounds never throw.
    * Time: O(log(size)).
    */
  def sum(until: Int): A = {
    // Clearing the lowest set bit of `n` jumps to the slot that covers the preceding range.
    @annotation.tailrec
    def loop(n: Int, acc: A): A =
      if (n == 0) acc
      else loop(n - (n & -n), num.plus(acc, cumulatives(n)))
    loop(math.max(0, math.min(size, until)), num.zero)
  }

  /** Returns the sum of the elements in the half-open range `[from, until)`.
    *
    * Both bounds are clamped into `[0, size]` exactly as in the single-argument `sum`.
    * When `from > until` the result is the negated sum of `[until, from)`.
    * Time: O(log(size)).
    */
  def sum(from: Int, until: Int): A = num.minus(sum(until), sum(from))

  /** Returns a new tree in which the element at `index` is replaced by (not incremented by)
    * `value`. Time: O(log(size)).
    *
    * @throws IllegalArgumentException if `index` is outside `[0, size)`
    */
  def updated(index: Int, value: A): BinaryIndexedTree[A] = {
    require((0 until size).contains(index), s"index $index out of range [0, $size)")
    // Only the difference between the new and the current value has to be propagated.
    val diff = num.minus(value, get(index))
    // Adding the lowest set bit of `n` moves to the next slot whose range also covers `index`.
    @annotation.tailrec
    def loop(n: Int, acc: Vector[A]): Vector[A] =
      if (size < n) acc
      else loop(n + (n & -n), acc.updated(n, num.plus(acc(n), diff)))
    copy(cumulatives = loop(index + 1, cumulatives))
  }

  /** Applies several single-element replacements in order; for a repeated index the last pair
    * wins. Time: O(k * log(size)) for `k` pairs.
    *
    * @throws IllegalArgumentException if any index is outside `[0, size)`
    */
  def updated(elems: (Int, A)*): BinaryIndexedTree[A] =
    elems.foldLeft(this) { case (tree, (i, x)) => tree.updated(i, x) }

  /** Returns the element at `index`, computed as the range sum `[index, index + 1)`.
    * Time: O(log(size)).
    *
    * @throws IllegalArgumentException if `index` is outside `[0, size)`
    */
  def get(index: Int): A = {
    require((0 until size).contains(index), s"index $index out of range [0, $size)")
    sum(index, index + 1)
  }
}
/** Factory helpers for [[BinaryIndexedTree]]. */
object BinaryIndexedTree {

  /** Creates a tree holding `size` elements, every one of them equal to `num.zero`.
    *
    * The backing vector gets `size + 1` slots because the tree is addressed 1-based internally.
    *
    * @param size number of (0-indexed) elements the tree holds
    * @param num  arithmetic used for the stored values
    */
  def apply[A](size: Int)(implicit num: Numeric[A]): BinaryIndexedTree[A] = {
    val zeroes: Vector[A] = Vector.tabulate(size + 1)(_ => num.zero)
    new BinaryIndexedTree[A](size, zeroes)(num)
  }
}
実行例
REPL にコードを貼り付けて実行する場合は :paste モードを使う必要がある。
(ケースクラスとそのコンパニオンオブジェクトは同時に定義しなければならないため)
scala> :paste
// Entering paste mode (ctrl-D to finish)

...snip...

// Exiting paste mode, now interpreting.

defined class BinaryIndexedTree
defined object BinaryIndexedTree

scala> val b = BinaryIndexedTree[Int](6).updated(0 -> 1, 2 -> 2, 3 -> 1, 4 -> 1, 5 -> 3)
b: BinaryIndexedTree[Int] = BinaryIndexedTree(6,Vector(0, 1, 1, 2, 4, 1, 4))

scala> b.sum(6) // 1 + 2 + 1 + 1 + 3
res0: Int = 8

scala> b.sum(2, 5) // 2 + 1 + 1
res1: Int = 4
0 件のコメント:
コメントを投稿