Scala: 末尾から続く0の個数を数える
与えられた64ビット整数値に対して、それを2進数にしたときに末尾から0が何個連続するかを求めたい。
たとえば十進数 88 の場合、2進数にすると 1011000 となるので、答えは 3 となる。
これは、ntz (number of trailing zeros) として知られる典型的なビット操作の問題である。
今回はそれを Scala で解くコードをいくつか準備し、その性能を比較してみた。
環境
- CPU: 1.8 GHz Intel Core i5
- OS: Mac OS X 10.9.4
- Scala: 2.11.2
- sbt: 0.13.5
- sbt-jmh: 0.1.6
Scala での実装
1ビットずつシフトして数える
いかにもコストが重そうだが、find を使って最初にビットの立った位置を探す方法。
x は負数をとることもあるので、符号なしの右シフト(>>>)が必要な点に注意。
1 2 | def ntzNaive(x : Long) : Int = ( 0 until 64 ).find(i = > ((x >>> i) & 1 L) ! = 0 L).getOrElse( 64 ) |
var を使って工夫 (2種類)
1 2 3 4 5 6 7 8 9 10 11 12 13 | def ntzLoop 1 (x : Long) : Int = { var y = ~x & (x - 1 ) var n = 0 while (y ! = 0 ) {n + = 1 ; y >>> = 1 } n } def ntzLoop 2 (x : Long) : Int = { var y = x var n = 64 while (y ! = 0 ) {n - = 1 ; y << = 1 } n } |
二分探索木を用いた方法 (3種類)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | def ntzBinarySearch 1 (x : Long) : Int = { if (x == 0 L) { 64 } else { var n = 1 var y = x if ((y & 0x00000000ffffffff L) == 0 L) {n + = 32 ; y >>> = 32 } if ((y & 0x000000000000ffff L) == 0 L) {n + = 16 ; y >>> = 16 } if ((y & 0x00000000000000ff L) == 0 L) {n + = 8 ; y >>> = 8 } if ((y & 0x000000000000000f L) == 0 L) {n + = 4 ; y >>> = 4 } if ((y & 0x0000000000000003 L) == 0 L) {n + = 2 ; y >>> = 2 } n - (y & 1 L).toInt } } def ntzBinarySearch 2 (x : Long) : Int = { if (x == 0 L) { 64 } else { var n = 63 var y = 0 L var z = x << 32 ; if (z ! = 0 ) {n - = 32 ; y = z} z = y << 16 ; if (z ! = 0 ) {n - = 16 ; y = z} z = y << 8 ; if (z ! = 0 ) {n - = 8 ; y = z} z = y << 4 ; if (z ! = 0 ) {n - = 4 ; y = z} z = y << 2 ; if (z ! = 0 ) {n - = 2 ; y = z} z = y << 1 ; if (z ! = 0 ) n - = 1 n } } def ntzBinarySearch 3 (x : Long) : Int = { if (x == 0 L) 64 else if ((x & 0xffffffff L) ! = 0 L) if ((x & 0xffff L) ! = 0 L) if ((x & 0xff L) ! = 0 L) if ((x & 0xf L) ! = 0 L) if ((x & 3 L) ! = 0 L) if ((x & 1 L) ! = 0 L) 0 else 1 else if ((x & 4 L) ! = 0 L) 2 else 3 else if ((x & 0x30 L) ! = 0 L) if ((x & 0x10 L) ! = 0 L) 4 else 5 else if ((x & 0x40 L) ! = 0 L) 6 else 7 else if ((x & 0xf00 L) ! = 0 L) if ((x & 0x300 L) ! = 0 L) if ((x & 0x100 L) ! = 0 L) 8 else 9 else if ((x & 0x400 L) ! = 0 L) 10 else 11 else if ((x & 0x3000 L) ! = 0 L) if ((x & 0x1000 L) ! = 0 L) 12 else 13 else if ((x & 0x4000 L) ! = 0 L) 14 else 15 else if ((x & 0xff0000 L) ! = 0 L) if ((x & 0xf0000 L) ! = 0 L) if ((x & 30000 L) ! = 0 L) if ((x & 10000 L) ! = 0 L) 16 else 17 else if ((x & 40000 L) ! = 0 L) 18 else 19 else if ((x & 0x300000 L) ! = 0 L) if ((x & 0x100000 L) ! = 0 L) 20 else 21 else if ((x & 0x400000 L) ! = 0 L) 22 else 23 else if ((x & 0xf000000 L) ! = 0 L) if ((x & 0x3000000 L) ! = 0 L) if ((x & 0x1000000 L) ! = 0 L) 24 else 25 else if ((x & 0x4000000 L) ! = 0 L) 26 else 27 else if ((x & 0x30000000 L) ! = 0 L) if ((x & 0x10000000 L) ! = 0 L) 28 else 29 else if ((x & 0x40000000 L) ! = 0 L) 30 else 31 else if ((x & 0xffff00000000 L) ! = 0 L) if ((x & 0xff00000000 L) ! = 0 L) if ((x & 0xf00000000 L) ! = 0 L) if ((x & 0x300000000 L) ! = 0 L) if ((x & 0x100000000 L) ! = 0 L) 32 else 33 else if ((x & 0x400000000 L) ! = 0 L) 34 else 35 else if ((x & 0x3000000000 L) ! = 0 L) if ((x & 0x1000000000 L) ! = 0 L) 36 else 37 else if ((x & 0x4000000000 L) ! = 0 L) 38 else 39 else if ((x & 0xf0000000000 L) ! = 0 L) if ((x & 0x30000000000 L) ! = 0 L) if ((x & 0x10000000000 L) ! = 0 L) 40 else 41 else if ((x & 0x40000000000 L) ! = 0 L) 42 else 43 else if ((x & 0x300000000000 L) ! = 0 L) if ((x & 0x100000000000 L) ! = 0 L) 44 else 45 else if ((x & 0x400000000000 L) ! = 0 L) 46 else 47 else if ((x & 0xff000000000000 L) ! = 0 L) if ((x & 0xf000000000000 L) ! = 0 L) if ((x & 3000000000000 L) ! = 0 L) if ((x & 1000000000000 L) ! = 0 L) 48 else 49 else if ((x & 4000000000000 L) ! = 0 L) 50 else 51 else if ((x & 0x30000000000000 L) ! = 0 L) if ((x & 0x10000000000000 L) ! = 0 L) 52 else 53 else if ((x & 0x40000000000000 L) ! = 0 L) 54 else 55 else if ((x & 0xf00000000000000 L) ! = 0 L) if ((x & 0x300000000000000 L) ! = 0 L) if ((x & 0x100000000000000 L) ! = 0 L) 56 else 57 else if ((x & 0x400000000000000 L) ! = 0 L) 58 else 59 else if ((x & 0x3000000000000000 L) ! = 0 L) if ((x & 0x1000000000000000 L) ! = 0 L) 60 else 61 else if ((x & 0x4000000000000000 L) ! = 0 L) 62 else 63 } |
対数計算を利用
1 2 3 4 5 6 7 8 9 | private [ this ] val ln 2 = math.log( 2 ) def ntzLogarithm(x : Long) : Int = if (x == 0 L) 64 else if (x == 0x8000000000000000 L) 63 else (math.log(x & -x) / ln 2 ).toInt |
ハッシュを利用 (黒魔術)
1 2 3 4 5 6 | private [ this ] val ntzHash = 0x03F566ED27179461 L private [ this ] lazy val ntzTable : Array[Int] = ( 0 until 64 ).map(i = > (ntzHash << i >>> 58 ).toInt -> i).sorted.map( _ . _ 2 ).toArray def ntzHash(x : Long) : Int = if (x ! = 0 L) ntzTable((((x & -x) * ntzHash) >>> 58 ).toInt) else 64 |
ネイティブでの実装
C++ のインラインアセンブリを使って x86-64の命令を直接呼び出す。
それを JNI/JNA でラップして Scala から呼び出した。
JNI (Java Native Interface)
- C++ コード
1 2 3 4 5 6 7 8 9 10 11 12 | #include <jni.h> extern "C" JNIEXPORT jint JNICALL Java_com_github_mogproject_util_BitOperation_ntzJni (JNIEnv *env, jobject obj, jlong x) { if (x == 0) return 64; int pos = 0; __asm__( "bsf %1,%q0;" : "=r" (pos) : "rm" (x)); return pos; } |
- ビルド
$ g++ -dynamiclib -O3 \ -I/usr/include -I$JAVA_HOME/include -I$JAVA_HOME/include/darwin \ bitops_jni.cpp -o libbitops_jni.dylib |
- Scala コード
1 2 3 4 5 6 7 8 9 10 11 12 13 | package com.github.mogproject.util class BitOperation { @ native def ntzJni(x : Long) : Int } object BitOperation { System.load( "/path/to/libbitops_jni.dylib" ) private [ this ] val bitop = new BitOperation def ntzJni(x : = > Long) = bitop.ntzJni(x) } |
JNA (Java Native Access)
- C++ コード
1 2 3 4 5 6 7 8 9 | extern "C" int ntz(unsigned long long x) { if (x == 0) return 64; int pos = 0; __asm__( "bsf %1,%q0;" : "=r" (pos) : "rm" (x)); return pos; } |
- ビルド
$ g++ -dynamiclib -O3 bitops_jna.cpp -o libbitops_jna.dylib |
- Scala コード
1 2 3 4 5 6 | private [ this ] val lib = com.sun.jna.NativeLibrary.getInstance( "/path/to/libbitops_jna.dylib" ) private [ this ] val func = lib.getFunction( "ntz" ) def ntzJna(x : Long) : Int = func.invokeInt(Array(x.asInstanceOf[Object])) |
テスト
とりあえず REPL でランダムなLong型整数を与え、全てのメソッドの結果が一致することを確認。
scala> import com.github.mogproject.util.BitOperation. _ import com.github.mogproject.util.BitOperation. _ scala> def f(x : Long) = Seq(ntzNaive(x), ntzLoop 1 (x), ntzLoop 2 (x), | ntzBinarySearch 1 (x), ntzBinarySearch 2 (x), ntzBinarySearch 3 (x), ntzLogarithm(x), | ntzHash(x), ntzJni(x), ntzJna(x)) f : (x : Long)Seq[Int] scala> f( 0 ) res 0 : Seq[Int] = List( 64 , 64 , 64 , 64 , 64 , 64 , 64 , 64 , 64 , 64 ) scala> f(- 1 ) res 1 : Seq[Int] = List( 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ) scala> f( 0x8000000000000000 L) res 2 : Seq[Int] = List( 63 , 63 , 63 , 63 , 63 , 63 , 63 , 63 , 63 , 63 ) scala> Seq.fill( 1000 )(scala.util.Random.nextLong) forall (f( _ ).toSet.size == 1 ) res 3 : Boolean = true |
ベンチマーク
sbt-jmh でベンチマークを取る。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | import com.github.mogproject.util.BitOperation import org.openjdk.jmh.annotations.{Benchmark, Scope, State = > JmhState} import scala.util.Random @ JmhState(Scope.Thread) class BitOperationBench { val random = new Random( 12345 ) @ Benchmark def ntzNaive() : Unit = BitOperation.ntzNaive(random.nextLong()) @ Benchmark def ntzLoop 1 () : Unit = BitOperation.ntzLoop 1 (random.nextLong()) @ Benchmark def ntzLoop 2 () : Unit = BitOperation.ntzLoop 2 (random.nextLong()) @ Benchmark def ntzBinarySearch 1 () : Unit = BitOperation.ntzBinarySearch 1 (random.nextLong()) @ Benchmark def ntzBinarySearch 2 () : Unit = BitOperation.ntzBinarySearch 2 (random.nextLong()) @ Benchmark def ntzBinarySearch 3 () : Unit = BitOperation.ntzBinarySearch 3 (random.nextLong()) @ Benchmark def ntzLogarithm() : Unit = BitOperation.ntzLogarithm(random.nextLong()) @ Benchmark def ntzHash() : Unit = BitOperation.ntzHash(random.nextLong()) @ Benchmark def ntzJni() : Unit = BitOperation.ntzJni(random.nextLong()) @ Benchmark def ntzJna() : Unit = BitOperation.ntzJna(random.nextLong()) } |
測定結果
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | [info] Running org.openjdk.jmh.Main -i 10 -wi 10 -f1 -t1 .*BitOperation.* (snip) [info] Benchmark Mode Samples Score Score error Units [info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch1 thrpt 10 35184555.294 147974.730 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch2 thrpt 10 35130401.937 234705.330 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch3 thrpt 10 29846819.149 352975.842 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzHash thrpt 10 35127105.013 127143.455 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzJna thrpt 10 1396115.477 16139.308 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzJni thrpt 10 25644475.628 231921.707 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzLogarithm thrpt 10 35144781.544 92927.496 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzLoop1 thrpt 10 30099013.164 308642.730 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzLoop2 thrpt 10 15183889.536 142090.869 ops/s [info] c.g.m.m.u.b.BitOperationBench.ntzNaive thrpt 10 14309678.759 317520.811 ops/s |
飛び抜けて良いものがある、という訳ではなく
BinarySearch1, BinarySearch2, Logarithm, Hash が先頭グループとしてほぼ同等の成績。
JNI はやはり関数コールのペナルティが大きいのか、実用には厳しいレベル。
JNA は一つだけ桁違いに性能が悪かった。
やはり計測あるのみ。
"Trust no one, bench everything." を実感した。
References
- 404 Blog Not Found:C - でも一番右端の立っているビット位置を求めてみた
- chessprogramming - BitScan
- A Simple Java Native Interface (JNI) example in Java and Scala
0 件のコメント:
コメントを投稿