Scala: 末尾から続く0の個数を数える
与えられた64ビット整数値に対して、それを2進数にしたときに末尾から0が何個連続するかを求めたい。
たとえば十進数 88 の場合、2進数にすると 1011000 となるので、答えは 3 となる。
これは、ntz (number of trailing zeros) として知られる典型的なビット操作の問題である。
今回はそれを Scala で解くコードをいくつか準備し、その性能を比較してみた。
環境
- CPU: 1.8 GHz Intel Core i5
- OS: Mac OS X 10.9.4
- Scala: 2.11.2
- sbt: 0.13.5
- sbt-jmh: 0.1.6
Scala での実装
1ビットずつシフトして数える
いかにもコストが重そうだが、find を使って最初にビットの立った位置を探す方法。
x は負数をとることもあるので、符号なしの右シフト(>>>)が必要な点に注意。
def ntzNaive(x: Long): Int = (0 until 64).find(i => ((x >>> i) & 1L) != 0L).getOrElse(64)
var を使って工夫 (2種類)
def ntzLoop1(x: Long): Int = { var y = ~x & (x - 1) var n = 0 while (y != 0) {n += 1; y >>>= 1} n } def ntzLoop2(x: Long): Int = { var y = x var n = 64 while (y != 0) {n -= 1; y <<= 1} n }
二分探索木を用いた方法 (3種類)
def ntzBinarySearch1(x: Long): Int = { if (x == 0L) { 64 } else { var n = 1 var y = x if ((y & 0x00000000ffffffffL) == 0L) {n += 32; y >>>= 32} if ((y & 0x000000000000ffffL) == 0L) {n += 16; y >>>= 16} if ((y & 0x00000000000000ffL) == 0L) {n += 8; y >>>= 8} if ((y & 0x000000000000000fL) == 0L) {n += 4; y >>>= 4} if ((y & 0x0000000000000003L) == 0L) {n += 2; y >>>= 2} n - (y & 1L).toInt } } def ntzBinarySearch2(x: Long): Int = { if (x == 0L) { 64 } else { var n = 63 var y = 0L var z = x << 32; if (z != 0) {n -= 32; y = z} z = y << 16; if (z != 0) {n -= 16; y = z} z = y << 8; if (z != 0) {n -= 8; y = z} z = y << 4; if (z != 0) {n -= 4; y = z} z = y << 2; if (z != 0) {n -= 2; y = z} z = y << 1; if (z != 0) n -= 1 n } } def ntzBinarySearch3(x: Long): Int = { if (x == 0L) 64 else if ((x & 0xffffffffL) != 0L) if ((x & 0xffffL) != 0L) if ((x & 0xffL) != 0L) if ((x & 0xfL) != 0L) if ((x & 3L) != 0L) if ((x & 1L) != 0L) 0 else 1 else if ((x & 4L) != 0L) 2 else 3 else if ((x & 0x30L) != 0L) if ((x & 0x10L) != 0L) 4 else 5 else if ((x & 0x40L) != 0L) 6 else 7 else if ((x & 0xf00L) != 0L) if ((x & 0x300L) != 0L) if ((x & 0x100L) != 0L) 8 else 9 else if ((x & 0x400L) != 0L) 10 else 11 else if ((x & 0x3000L) != 0L) if ((x & 0x1000L) != 0L) 12 else 13 else if ((x & 0x4000L) != 0L) 14 else 15 else if ((x & 0xff0000L) != 0L) if ((x & 0xf0000L) != 0L) if ((x & 30000L) != 0L) if ((x & 10000L) != 0L) 16 else 17 else if ((x & 40000L) != 0L) 18 else 19 else if ((x & 0x300000L) != 0L) if ((x & 0x100000L) != 0L) 20 else 21 else if ((x & 0x400000L) != 0L) 22 else 23 else if ((x & 0xf000000L) != 0L) if ((x & 0x3000000L) != 0L) if ((x & 0x1000000L) != 0L) 24 else 25 else if ((x & 0x4000000L) != 0L) 26 else 27 else if ((x & 0x30000000L) != 0L) if ((x & 0x10000000L) != 0L) 28 else 29 else if ((x & 0x40000000L) != 0L) 30 else 31 else if ((x & 0xffff00000000L) != 0L) if ((x & 0xff00000000L) != 0L) if ((x & 0xf00000000L) != 0L) if ((x & 0x300000000L) != 0L) if ((x & 0x100000000L) != 0L) 32 else 33 else if ((x & 0x400000000L) != 0L) 34 else 35 else if ((x & 0x3000000000L) != 0L) if ((x & 0x1000000000L) != 0L) 36 else 37 else if ((x & 0x4000000000L) != 0L) 38 else 39 else if ((x & 0xf0000000000L) != 0L) if ((x & 0x30000000000L) != 0L) if ((x & 0x10000000000L) != 0L) 40 else 41 else if ((x & 0x40000000000L) != 0L) 42 else 43 else if ((x & 0x300000000000L) != 0L) if ((x & 0x100000000000L) != 0L) 44 else 45 else if ((x & 0x400000000000L) != 0L) 46 else 47 else if ((x & 0xff000000000000L) != 0L) if ((x & 0xf000000000000L) != 0L) if ((x & 3000000000000L) != 0L) if ((x & 1000000000000L) != 0L) 48 else 49 else if ((x & 4000000000000L) != 0L) 50 else 51 else if ((x & 0x30000000000000L) != 0L) if ((x & 0x10000000000000L) != 0L) 52 else 53 else if ((x & 0x40000000000000L) != 0L) 54 else 55 else if ((x & 0xf00000000000000L) != 0L) if ((x & 0x300000000000000L) != 0L) if ((x & 0x100000000000000L) != 0L) 56 else 57 else if ((x & 0x400000000000000L) != 0L) 58 else 59 else if ((x & 0x3000000000000000L) != 0L) if ((x & 0x1000000000000000L) != 0L) 60 else 61 else if ((x & 0x4000000000000000L) != 0L) 62 else 63 }
対数計算を利用
private[this] val ln2 = math.log(2) def ntzLogarithm(x: Long): Int = if (x == 0L) 64 else if (x == 0x8000000000000000L) 63 else (math.log(x & -x) / ln2).toInt
ハッシュを利用 (黒魔術)
private[this] val ntzHash = 0x03F566ED27179461L private[this] lazy val ntzTable: Array[Int] = (0 until 64).map(i => (ntzHash << i >>> 58).toInt -> i).sorted.map(_._2).toArray def ntzHash(x: Long): Int = if (x != 0L) ntzTable((((x & -x) * ntzHash) >>> 58).toInt) else 64
ネイティブでの実装
C++ のインラインアセンブリを使って x86-64の命令を直接呼び出す。
それを JNI/JNA でラップして Scala から呼び出した。
JNI (Java Native Interface)
- C++ コード
#include <jni.h> extern "C" JNIEXPORT jint JNICALL Java_com_github_mogproject_util_BitOperation_ntzJni (JNIEnv *env, jobject obj, jlong x) { if (x == 0) return 64; int pos = 0; __asm__( "bsf %1,%q0;" :"=r"(pos) :"rm"(x)); return pos; }
- ビルド
$ g++ -dynamiclib -O3 \ -I/usr/include -I$JAVA_HOME/include -I$JAVA_HOME/include/darwin \ bitops_jni.cpp -o libbitops_jni.dylib
- Scala コード
package com.github.mogproject.util class BitOperation { @native def ntzJni(x: Long): Int } object BitOperation { System.load("/path/to/libbitops_jni.dylib") private[this] val bitop = new BitOperation def ntzJni(x: => Long) = bitop.ntzJni(x) }
JNA (Java Native Access)
- C++ コード
extern "C" int ntz(unsigned long long x) { if (x == 0) return 64; int pos = 0; __asm__( "bsf %1,%q0;" :"=r"(pos) :"rm"(x)); return pos; }
- ビルド
$ g++ -dynamiclib -O3 bitops_jna.cpp -o libbitops_jna.dylib
- Scala コード
private[this] val lib = com.sun.jna.NativeLibrary.getInstance("/path/to/libbitops_jna.dylib") private[this] val func = lib.getFunction("ntz") def ntzJna(x: Long): Int = func.invokeInt(Array(x.asInstanceOf[Object]))
テスト
とりあえず REPL でランダムなLong型整数を与え、全てのメソッドの結果が一致することを確認。
scala> import com.github.mogproject.util.BitOperation._ import com.github.mogproject.util.BitOperation._ scala> def f(x: Long) = Seq(ntzNaive(x), ntzLoop1(x), ntzLoop2(x), | ntzBinarySearch1(x), ntzBinarySearch2(x), ntzBinarySearch3(x), ntzLogarithm(x), | ntzHash(x), ntzJni(x), ntzJna(x)) f: (x: Long)Seq[Int] scala> f(0) res0: Seq[Int] = List(64, 64, 64, 64, 64, 64, 64, 64, 64, 64) scala> f(-1) res1: Seq[Int] = List(0, 0, 0, 0, 0, 0, 0, 0, 0, 0) scala> f(0x8000000000000000L) res2: Seq[Int] = List(63, 63, 63, 63, 63, 63, 63, 63, 63, 63) scala> Seq.fill(1000)(scala.util.Random.nextLong) forall (f(_).toSet.size == 1) res3: Boolean = true
ベンチマーク
sbt-jmh でベンチマークを取る。
import com.github.mogproject.util.BitOperation import org.openjdk.jmh.annotations.{Benchmark, Scope, State => JmhState} import scala.util.Random @JmhState(Scope.Thread) class BitOperationBench { val random = new Random(12345) @Benchmark def ntzNaive(): Unit = BitOperation.ntzNaive(random.nextLong()) @Benchmark def ntzLoop1(): Unit = BitOperation.ntzLoop1(random.nextLong()) @Benchmark def ntzLoop2(): Unit = BitOperation.ntzLoop2(random.nextLong()) @Benchmark def ntzBinarySearch1(): Unit = BitOperation.ntzBinarySearch1(random.nextLong()) @Benchmark def ntzBinarySearch2(): Unit = BitOperation.ntzBinarySearch2(random.nextLong()) @Benchmark def ntzBinarySearch3(): Unit = BitOperation.ntzBinarySearch3(random.nextLong()) @Benchmark def ntzLogarithm(): Unit = BitOperation.ntzLogarithm(random.nextLong()) @Benchmark def ntzHash(): Unit = BitOperation.ntzHash(random.nextLong()) @Benchmark def ntzJni(): Unit = BitOperation.ntzJni(random.nextLong()) @Benchmark def ntzJna(): Unit = BitOperation.ntzJna(random.nextLong()) }
測定結果
[info] Running org.openjdk.jmh.Main -i 10 -wi 10 -f1 -t1 .*BitOperation.* (snip) [info] Benchmark Mode Samples Score Score error Units [info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch1 thrpt 10 35184555.294 147974.730 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch2 thrpt 10 35130401.937 234705.330 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch3 thrpt 10 29846819.149 352975.842 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzHash thrpt 10 35127105.013 127143.455 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzJna thrpt 10 1396115.477 16139.308 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzJni thrpt 10 25644475.628 231921.707 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzLogarithm thrpt 10 35144781.544 92927.496 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzLoop1 thrpt 10 30099013.164 308642.730 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzLoop2 thrpt 10 15183889.536 142090.869 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzNaive thrpt 10 14309678.759 317520.811 ops/s
飛び抜けて良いものがある、という訳ではなく
BinarySearch1, BinarySearch2, Logarithm, Hash が先頭グループとしてほぼ同等の成績。
JNI はやはり関数コールのペナルティが大きいのか、実用には厳しいレベル。
JNA は一つだけ桁違いに性能が悪かった。
やはり計測あるのみ。
"Trust no one, bench everything." を実感した。
References
- 404 Blog Not Found:C - でも一番右端の立っているビット位置を求めてみた
- chessprogramming - BitScan
- A Simple Java Native Interface (JNI) example in Java and Scala
0 件のコメント:
コメントを投稿