10.07.2014

Counting Number of Trailing Zeros in Scala

Scala: 末尾から続く0の個数を数える

 

与えられた64ビット整数値に対して、それを2進数にしたときに末尾から0が何個連続するかを求めたい。

たとえば十進数 88 の場合、2進数にすると 1011000 となるので、答えは 3 となる。

これは、ntz (number of trailing zeros) として知られる典型的なビット操作の問題である。

 

今回はそれを Scala で解くコードをいくつか準備し、その性能を比較してみた。

 

環境

  • CPU: 1.8 GHz Intel Core i5
  • OS: Mac OS X 10.9.4
  • Scala: 2.11.2
  • sbt: 0.13.5
  • sbt-jmh: 0.1.6

 

Scala での実装

 

1ビットずつシフトして数える

いかにもコストが重そうだが、find を使って最初にビットの立った位置を探す方法。
x は負数をとることもあるので、符号なしの右シフト(>>>)が必要な点に注意。

1
2
def ntzNaive(x: Long): Int =
  (0 until 64).find(i => ((x >>> i) & 1L) != 0L).getOrElse(64)

var を使って工夫 (2種類)

1
2
3
4
5
6
7
8
9
10
11
12
13
def ntzLoop1(x: Long): Int = {
  var y = ~x & (x - 1)
  var n = 0
  while (y != 0) {n += 1; y >>>= 1}
  n
}
 
def ntzLoop2(x: Long): Int = {
  var y = x
  var n = 64
  while (y != 0) {n -= 1; y <<= 1}
  n
}
二分探索木を用いた方法 (3種類)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def ntzBinarySearch1(x: Long): Int = {
  if (x == 0L) {
    64
  } else {
    var n = 1
    var y = x
    if ((y & 0x00000000ffffffffL) == 0L) {n += 32; y >>>= 32}
    if ((y & 0x000000000000ffffL) == 0L) {n += 16; y >>>= 16}
    if ((y & 0x00000000000000ffL) == 0L) {n += 8; y >>>= 8}
    if ((y & 0x000000000000000fL) == 0L) {n += 4; y >>>= 4}
    if ((y & 0x0000000000000003L) == 0L) {n += 2; y >>>= 2}
    n - (y & 1L).toInt
  }
}
 
def ntzBinarySearch2(x: Long): Int = {
  if (x == 0L) {
    64
  } else {
    var n = 63
    var y = 0L
    var z = x << 32; if (z != 0) {n -= 32; y = z}
    z = y << 16; if (z != 0) {n -= 16; y = z}
    z = y << 8; if (z != 0) {n -= 8; y = z}
    z = y << 4; if (z != 0) {n -= 4; y = z}
    z = y << 2; if (z != 0) {n -= 2; y = z}
    z = y << 1; if (z != 0) n -= 1
    n
  }
}
 
def ntzBinarySearch3(x: Long): Int = {
  if (x == 0L)
    64
  else
    if ((x & 0xffffffffL) != 0L)
      if ((x & 0xffffL) != 0L)
        if ((x & 0xffL) != 0L)
          if ((x & 0xfL) != 0L)
            if ((x & 3L) != 0L)
              if ((x & 1L) != 0L) 0 else 1
            else
              if ((x & 4L) != 0L) 2 else 3
          else
            if ((x & 0x30L) != 0L)
              if ((x & 0x10L) != 0L) 4 else 5
            else
              if ((x & 0x40L) != 0L) 6 else 7
        else
          if ((x & 0xf00L) != 0L)
            if ((x & 0x300L) != 0L)
              if ((x & 0x100L) != 0L) 8 else 9
            else
              if ((x & 0x400L) != 0L) 10 else 11
          else
            if ((x & 0x3000L) != 0L)
              if ((x & 0x1000L) != 0L) 12 else 13
            else
              if ((x & 0x4000L) != 0L) 14 else 15
      else
        if ((x & 0xff0000L) != 0L)
          if ((x & 0xf0000L) != 0L)
            if ((x & 30000L) != 0L)
              if ((x & 10000L) != 0L) 16 else 17
            else
              if ((x & 40000L) != 0L) 18 else 19
          else
            if ((x & 0x300000L) != 0L)
              if ((x & 0x100000L) != 0L) 20 else 21
            else
              if ((x & 0x400000L) != 0L) 22 else 23
        else
          if ((x & 0xf000000L) != 0L)
            if ((x & 0x3000000L) != 0L)
              if ((x & 0x1000000L) != 0L) 24 else 25
            else
              if ((x & 0x4000000L) != 0L) 26 else 27
          else
            if ((x & 0x30000000L) != 0L)
              if ((x & 0x10000000L) != 0L) 28 else 29
            else
              if ((x & 0x40000000L) != 0L) 30 else 31
    else
      if ((x & 0xffff00000000L) != 0L)
        if ((x & 0xff00000000L) != 0L)
          if ((x & 0xf00000000L) != 0L)
            if ((x & 0x300000000L) != 0L)
              if ((x & 0x100000000L) != 0L) 32 else 33
            else
              if ((x & 0x400000000L) != 0L) 34 else 35
          else
            if ((x & 0x3000000000L) != 0L)
              if ((x & 0x1000000000L) != 0L) 36 else 37
            else
              if ((x & 0x4000000000L) != 0L) 38 else 39
        else
          if ((x & 0xf0000000000L) != 0L)
            if ((x & 0x30000000000L) != 0L)
              if ((x & 0x10000000000L) != 0L) 40 else 41
            else
              if ((x & 0x40000000000L) != 0L) 42 else 43
          else
            if ((x & 0x300000000000L) != 0L)
              if ((x & 0x100000000000L) != 0L) 44 else 45
            else
              if ((x & 0x400000000000L) != 0L) 46 else 47
      else
        if ((x & 0xff000000000000L) != 0L)
          if ((x & 0xf000000000000L) != 0L)
            if ((x & 3000000000000L) != 0L)
              if ((x & 1000000000000L) != 0L) 48 else 49
            else
              if ((x & 4000000000000L) != 0L) 50 else 51
          else
            if ((x & 0x30000000000000L) != 0L)
              if ((x & 0x10000000000000L) != 0L) 52 else 53
            else
              if ((x & 0x40000000000000L) != 0L) 54 else 55
        else
          if ((x & 0xf00000000000000L) != 0L)
            if ((x & 0x300000000000000L) != 0L)
              if ((x & 0x100000000000000L) != 0L) 56 else 57
            else
              if ((x & 0x400000000000000L) != 0L) 58 else 59
          else
            if ((x & 0x3000000000000000L) != 0L)
              if ((x & 0x1000000000000000L) != 0L) 60 else 61
            else
              if ((x & 0x4000000000000000L) != 0L) 62 else 63
}
対数計算を利用
1
2
3
4
5
6
7
8
9
private[this] val ln2 = math.log(2)
 
def ntzLogarithm(x: Long): Int =
  if (x == 0L)
    64
  else if (x == 0x8000000000000000L)
    63
  else
    (math.log(x & -x) / ln2).toInt
ハッシュを利用 (黒魔術)
1
2
3
4
5
6
private[this] val ntzHash = 0x03F566ED27179461L
private[this] lazy val ntzTable: Array[Int] =
  (0 until 64).map(i => (ntzHash << i >>> 58).toInt -> i).sorted.map(_._2).toArray
 
def ntzHash(x: Long): Int =
  if (x != 0L) ntzTable((((x & -x) * ntzHash) >>> 58).toInt) else 64

 

ネイティブでの実装

C++ のインラインアセンブリを使って x86-64の命令を直接呼び出す。
それを JNI/JNA でラップして Scala から呼び出した。

JNI (Java Native Interface)
  • C++ コード
1
2
3
4
5
6
7
8
9
10
11
12
#include <jni.h>
 
extern "C" JNIEXPORT jint JNICALL Java_com_github_mogproject_util_BitOperation_ntzJni
  (JNIEnv *env, jobject obj, jlong x) {
  if (x == 0) return 64;
  int pos = 0;
  __asm__(
    "bsf %1,%q0;"
    :"=r"(pos)
    :"rm"(x));
  return pos;
}
  • ビルド
$ g++ -dynamiclib -O3 \
-I/usr/include -I$JAVA_HOME/include -I$JAVA_HOME/include/darwin \
bitops_jni.cpp -o libbitops_jni.dylib
  • Scala コード
1
2
3
4
5
6
7
8
9
10
11
12
13
package com.github.mogproject.util
 
class BitOperation {
  @native def ntzJni(x: Long): Int
}
 
object BitOperation {
  System.load("/path/to/libbitops_jni.dylib")
 
  private[this] val bitop = new BitOperation
 
  def ntzJni(x: => Long) = bitop.ntzJni(x)
}
JNA (Java Native Access)
  • C++ コード
1
2
3
4
5
6
7
8
9
extern "C" int ntz(unsigned long long x) {
  if (x == 0) return 64;
  int pos = 0;
  __asm__(
    "bsf %1,%q0;"
    :"=r"(pos)
    :"rm"(x));
  return pos;
}
  • ビルド
$ g++ -dynamiclib -O3 bitops_jna.cpp -o libbitops_jna.dylib
  • Scala コード
1
2
3
4
5
6
private[this] val lib =
  com.sun.jna.NativeLibrary.getInstance("/path/to/libbitops_jna.dylib")
 
private[this] val func = lib.getFunction("ntz")
 
def ntzJna(x: Long): Int = func.invokeInt(Array(x.asInstanceOf[Object]))

 

テスト

とりあえず REPL でランダムなLong型整数を与え、全てのメソッドの結果が一致することを確認。

scala> import com.github.mogproject.util.BitOperation._
import com.github.mogproject.util.BitOperation._
 
scala> def f(x: Long) = Seq(ntzNaive(x), ntzLoop1(x), ntzLoop2(x),
     | ntzBinarySearch1(x), ntzBinarySearch2(x), ntzBinarySearch3(x), ntzLogarithm(x),
     | ntzHash(x), ntzJni(x), ntzJna(x))
f: (x: Long)Seq[Int]
 
scala> f(0)
res0: Seq[Int] = List(64, 64, 64, 64, 64, 64, 64, 64, 64, 64)
 
scala> f(-1)
res1: Seq[Int] = List(0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
 
scala> f(0x8000000000000000L)
res2: Seq[Int] = List(63, 63, 63, 63, 63, 63, 63, 63, 63, 63)
 
scala> Seq.fill(1000)(scala.util.Random.nextLong) forall (f(_).toSet.size == 1)
res3: Boolean = true

 

ベンチマーク

sbt-jmh でベンチマークを取る。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import com.github.mogproject.util.BitOperation
import org.openjdk.jmh.annotations.{Benchmark, Scope, State => JmhState}
 
import scala.util.Random
 
 
@JmhState(Scope.Thread)
class BitOperationBench {
  val random = new Random(12345)
 
  @Benchmark def ntzNaive(): Unit = BitOperation.ntzNaive(random.nextLong())
 
  @Benchmark def ntzLoop1(): Unit = BitOperation.ntzLoop1(random.nextLong())
  @Benchmark def ntzLoop2(): Unit = BitOperation.ntzLoop2(random.nextLong())
 
  @Benchmark def ntzBinarySearch1(): Unit = BitOperation.ntzBinarySearch1(random.nextLong())
  @Benchmark def ntzBinarySearch2(): Unit = BitOperation.ntzBinarySearch2(random.nextLong())
  @Benchmark def ntzBinarySearch3(): Unit = BitOperation.ntzBinarySearch3(random.nextLong())
 
  @Benchmark def ntzLogarithm(): Unit = BitOperation.ntzLogarithm(random.nextLong())
  @Benchmark def ntzHash(): Unit = BitOperation.ntzHash(random.nextLong())
 
  @Benchmark def ntzJni(): Unit = BitOperation.ntzJni(random.nextLong())
  @Benchmark def ntzJna(): Unit = BitOperation.ntzJna(random.nextLong())
}
測定結果
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
[info] Running org.openjdk.jmh.Main -i 10 -wi 10 -f1 -t1 .*BitOperation.*
 
(snip)
 
[info] Benchmark                                          Mode  Samples         Score  Score error  Units
[info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch1    thrpt       10  35184555.294   147974.730  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch2    thrpt       10  35130401.937   234705.330  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch3    thrpt       10  29846819.149   352975.842  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzHash             thrpt       10  35127105.013   127143.455  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzJna              thrpt       10   1396115.477    16139.308  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzJni              thrpt       10  25644475.628   231921.707  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzLogarithm        thrpt       10  35144781.544    92927.496  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzLoop1            thrpt       10  30099013.164   308642.730  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzLoop2            thrpt       10  15183889.536   142090.869  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzNaive            thrpt       10  14309678.759   317520.811  ops/s

飛び抜けて良いものがある、という訳ではなく
BinarySearch1, BinarySearch2, Logarithm, Hash が先頭グループとしてほぼ同等の成績。

JNI はやはり関数コールのペナルティが大きいのか、実用には厳しいレベル。
JNA は一つだけ桁違いに性能が悪かった。

やはり計測あるのみ。
"Trust no one, bench everything." を実感した。

 

References

 

 

Related Posts

0 件のコメント:

コメントを投稿