Scala: 末尾から続く0の個数を数える
与えられた64ビット整数値に対して、それを2進数にしたときに末尾から0が何個連続するかを求めたい。
たとえば十進数 88 の場合、2進数にすると 1011000 となるので、答えは 3 となる。
これは、ntz (number of trailing zeros) として知られる典型的なビット操作の問題である。
今回はそれを Scala で解くコードをいくつか準備し、その性能を比較してみた。
環境
- CPU: 1.8 GHz Intel Core i5
- OS: Mac OS X 10.9.4
- Scala: 2.11.2
- sbt: 0.13.5
- sbt-jmh: 0.1.6
Scala での実装
1ビットずつシフトして数える
いかにもコストが重そうだが、find を使って最初にビットの立った位置を探す方法。
x は負数をとることもあるので、符号なしの右シフト(>>>)が必要な点に注意。
def ntzNaive(x: Long): Int =
(0 until 64).find(i => ((x >>> i) & 1L) != 0L).getOrElse(64)
var を使って工夫 (2種類)
def ntzLoop1(x: Long): Int = {
var y = ~x & (x - 1)
var n = 0
while (y != 0) {n += 1; y >>>= 1}
n
}
def ntzLoop2(x: Long): Int = {
var y = x
var n = 64
while (y != 0) {n -= 1; y <<= 1}
n
}
二分探索木を用いた方法 (3種類)
def ntzBinarySearch1(x: Long): Int = {
if (x == 0L) {
64
} else {
var n = 1
var y = x
if ((y & 0x00000000ffffffffL) == 0L) {n += 32; y >>>= 32}
if ((y & 0x000000000000ffffL) == 0L) {n += 16; y >>>= 16}
if ((y & 0x00000000000000ffL) == 0L) {n += 8; y >>>= 8}
if ((y & 0x000000000000000fL) == 0L) {n += 4; y >>>= 4}
if ((y & 0x0000000000000003L) == 0L) {n += 2; y >>>= 2}
n - (y & 1L).toInt
}
}
def ntzBinarySearch2(x: Long): Int = {
if (x == 0L) {
64
} else {
var n = 63
var y = 0L
var z = x << 32; if (z != 0) {n -= 32; y = z}
z = y << 16; if (z != 0) {n -= 16; y = z}
z = y << 8; if (z != 0) {n -= 8; y = z}
z = y << 4; if (z != 0) {n -= 4; y = z}
z = y << 2; if (z != 0) {n -= 2; y = z}
z = y << 1; if (z != 0) n -= 1
n
}
}
def ntzBinarySearch3(x: Long): Int = {
if (x == 0L)
64
else
if ((x & 0xffffffffL) != 0L)
if ((x & 0xffffL) != 0L)
if ((x & 0xffL) != 0L)
if ((x & 0xfL) != 0L)
if ((x & 3L) != 0L)
if ((x & 1L) != 0L) 0 else 1
else
if ((x & 4L) != 0L) 2 else 3
else
if ((x & 0x30L) != 0L)
if ((x & 0x10L) != 0L) 4 else 5
else
if ((x & 0x40L) != 0L) 6 else 7
else
if ((x & 0xf00L) != 0L)
if ((x & 0x300L) != 0L)
if ((x & 0x100L) != 0L) 8 else 9
else
if ((x & 0x400L) != 0L) 10 else 11
else
if ((x & 0x3000L) != 0L)
if ((x & 0x1000L) != 0L) 12 else 13
else
if ((x & 0x4000L) != 0L) 14 else 15
else
if ((x & 0xff0000L) != 0L)
if ((x & 0xf0000L) != 0L)
if ((x & 30000L) != 0L)
if ((x & 10000L) != 0L) 16 else 17
else
if ((x & 40000L) != 0L) 18 else 19
else
if ((x & 0x300000L) != 0L)
if ((x & 0x100000L) != 0L) 20 else 21
else
if ((x & 0x400000L) != 0L) 22 else 23
else
if ((x & 0xf000000L) != 0L)
if ((x & 0x3000000L) != 0L)
if ((x & 0x1000000L) != 0L) 24 else 25
else
if ((x & 0x4000000L) != 0L) 26 else 27
else
if ((x & 0x30000000L) != 0L)
if ((x & 0x10000000L) != 0L) 28 else 29
else
if ((x & 0x40000000L) != 0L) 30 else 31
else
if ((x & 0xffff00000000L) != 0L)
if ((x & 0xff00000000L) != 0L)
if ((x & 0xf00000000L) != 0L)
if ((x & 0x300000000L) != 0L)
if ((x & 0x100000000L) != 0L) 32 else 33
else
if ((x & 0x400000000L) != 0L) 34 else 35
else
if ((x & 0x3000000000L) != 0L)
if ((x & 0x1000000000L) != 0L) 36 else 37
else
if ((x & 0x4000000000L) != 0L) 38 else 39
else
if ((x & 0xf0000000000L) != 0L)
if ((x & 0x30000000000L) != 0L)
if ((x & 0x10000000000L) != 0L) 40 else 41
else
if ((x & 0x40000000000L) != 0L) 42 else 43
else
if ((x & 0x300000000000L) != 0L)
if ((x & 0x100000000000L) != 0L) 44 else 45
else
if ((x & 0x400000000000L) != 0L) 46 else 47
else
if ((x & 0xff000000000000L) != 0L)
if ((x & 0xf000000000000L) != 0L)
if ((x & 3000000000000L) != 0L)
if ((x & 1000000000000L) != 0L) 48 else 49
else
if ((x & 4000000000000L) != 0L) 50 else 51
else
if ((x & 0x30000000000000L) != 0L)
if ((x & 0x10000000000000L) != 0L) 52 else 53
else
if ((x & 0x40000000000000L) != 0L) 54 else 55
else
if ((x & 0xf00000000000000L) != 0L)
if ((x & 0x300000000000000L) != 0L)
if ((x & 0x100000000000000L) != 0L) 56 else 57
else
if ((x & 0x400000000000000L) != 0L) 58 else 59
else
if ((x & 0x3000000000000000L) != 0L)
if ((x & 0x1000000000000000L) != 0L) 60 else 61
else
if ((x & 0x4000000000000000L) != 0L) 62 else 63
}
対数計算を利用
private[this] val ln2 = math.log(2)
def ntzLogarithm(x: Long): Int =
if (x == 0L)
64
else if (x == 0x8000000000000000L)
63
else
(math.log(x & -x) / ln2).toInt
ハッシュを利用 (黒魔術)
private[this] val ntzHash = 0x03F566ED27179461L
private[this] lazy val ntzTable: Array[Int] =
(0 until 64).map(i => (ntzHash << i >>> 58).toInt -> i).sorted.map(_._2).toArray
def ntzHash(x: Long): Int =
if (x != 0L) ntzTable((((x & -x) * ntzHash) >>> 58).toInt) else 64
ネイティブでの実装
C++ のインラインアセンブリを使って x86-64の命令を直接呼び出す。
それを JNI/JNA でラップして Scala から呼び出した。
JNI (Java Native Interface)
- C++ コード
#include <jni.h>
extern "C" JNIEXPORT jint JNICALL Java_com_github_mogproject_util_BitOperation_ntzJni
(JNIEnv *env, jobject obj, jlong x) {
if (x == 0) return 64;
int pos = 0;
__asm__(
"bsf %1,%q0;"
:"=r"(pos)
:"rm"(x));
return pos;
}
- ビルド
$ g++ -dynamiclib -O3 \ -I/usr/include -I$JAVA_HOME/include -I$JAVA_HOME/include/darwin \ bitops_jni.cpp -o libbitops_jni.dylib
- Scala コード
package com.github.mogproject.util
class BitOperation {
@native def ntzJni(x: Long): Int
}
object BitOperation {
System.load("/path/to/libbitops_jni.dylib")
private[this] val bitop = new BitOperation
def ntzJni(x: => Long) = bitop.ntzJni(x)
}
JNA (Java Native Access)
- C++ コード
extern "C" int ntz(unsigned long long x) {
if (x == 0) return 64;
int pos = 0;
__asm__(
"bsf %1,%q0;"
:"=r"(pos)
:"rm"(x));
return pos;
}
- ビルド
$ g++ -dynamiclib -O3 bitops_jna.cpp -o libbitops_jna.dylib
- Scala コード
private[this] val lib =
com.sun.jna.NativeLibrary.getInstance("/path/to/libbitops_jna.dylib")
private[this] val func = lib.getFunction("ntz")
def ntzJna(x: Long): Int = func.invokeInt(Array(x.asInstanceOf[Object]))
テスト
とりあえず REPL でランダムなLong型整数を与え、全てのメソッドの結果が一致することを確認。
scala> import com.github.mogproject.util.BitOperation._
import com.github.mogproject.util.BitOperation._
scala> def f(x: Long) = Seq(ntzNaive(x), ntzLoop1(x), ntzLoop2(x),
| ntzBinarySearch1(x), ntzBinarySearch2(x), ntzBinarySearch3(x), ntzLogarithm(x),
| ntzHash(x), ntzJni(x), ntzJna(x))
f: (x: Long)Seq[Int]
scala> f(0)
res0: Seq[Int] = List(64, 64, 64, 64, 64, 64, 64, 64, 64, 64)
scala> f(-1)
res1: Seq[Int] = List(0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
scala> f(0x8000000000000000L)
res2: Seq[Int] = List(63, 63, 63, 63, 63, 63, 63, 63, 63, 63)
scala> Seq.fill(1000)(scala.util.Random.nextLong) forall (f(_).toSet.size == 1)
res3: Boolean = true
ベンチマーク
sbt-jmh でベンチマークを取る。
import com.github.mogproject.util.BitOperation
import org.openjdk.jmh.annotations.{Benchmark, Scope, State => JmhState}
import scala.util.Random
@JmhState(Scope.Thread)
class BitOperationBench {
val random = new Random(12345)
@Benchmark def ntzNaive(): Unit = BitOperation.ntzNaive(random.nextLong())
@Benchmark def ntzLoop1(): Unit = BitOperation.ntzLoop1(random.nextLong())
@Benchmark def ntzLoop2(): Unit = BitOperation.ntzLoop2(random.nextLong())
@Benchmark def ntzBinarySearch1(): Unit = BitOperation.ntzBinarySearch1(random.nextLong())
@Benchmark def ntzBinarySearch2(): Unit = BitOperation.ntzBinarySearch2(random.nextLong())
@Benchmark def ntzBinarySearch3(): Unit = BitOperation.ntzBinarySearch3(random.nextLong())
@Benchmark def ntzLogarithm(): Unit = BitOperation.ntzLogarithm(random.nextLong())
@Benchmark def ntzHash(): Unit = BitOperation.ntzHash(random.nextLong())
@Benchmark def ntzJni(): Unit = BitOperation.ntzJni(random.nextLong())
@Benchmark def ntzJna(): Unit = BitOperation.ntzJna(random.nextLong())
}
測定結果
[info] Running org.openjdk.jmh.Main -i 10 -wi 10 -f1 -t1 .*BitOperation.* (snip) [info] Benchmark Mode Samples Score Score error Units [info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch1 thrpt 10 35184555.294 147974.730 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch2 thrpt 10 35130401.937 234705.330 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch3 thrpt 10 29846819.149 352975.842 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzHash thrpt 10 35127105.013 127143.455 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzJna thrpt 10 1396115.477 16139.308 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzJni thrpt 10 25644475.628 231921.707 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzLogarithm thrpt 10 35144781.544 92927.496 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzLoop1 thrpt 10 30099013.164 308642.730 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzLoop2 thrpt 10 15183889.536 142090.869 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzNaive thrpt 10 14309678.759 317520.811 ops/s
飛び抜けて良いものがある、という訳ではなく
BinarySearch1, BinarySearch2, Logarithm, Hash が先頭グループとしてほぼ同等の成績。
JNI はやはり関数コールのペナルティが大きいのか、実用には厳しいレベル。
JNA は一つだけ桁違いに性能が悪かった。
やはり計測あるのみ。
"Trust no one, bench everything." を実感した。
References
- 404 Blog Not Found:C - でも一番右端の立っているビット位置を求めてみた
- chessprogramming - BitScan
- A Simple Java Native Interface (JNI) example in Java and Scala
0 件のコメント:
コメントを投稿