10.07.2014

Counting Number of Trailing Zeros in Scala

Scala: 末尾から続く0の個数を数える

 

与えられた64ビット整数値に対して、それを2進数にしたときに末尾から0が何個連続するかを求めたい。

たとえば十進数 88 の場合、2進数にすると 1011000 となるので、答えは 3 となる。

これは、ntz (number of trailing zeros) として知られる典型的なビット操作の問題である。

 

今回はそれを Scala で解くコードをいくつか準備し、その性能を比較してみた。

 

環境

  • CPU: 1.8 GHz Intel Core i5
  • OS: Mac OS X 10.9.4
  • Scala: 2.11.2
  • sbt: 0.13.5
  • sbt-jmh: 0.1.6

 

Scala での実装

 

1ビットずつシフトして数える

いかにもコストが重そうだが、find を使って最初にビットの立った位置を探す方法。
x は負数をとることもあるので、符号なしの右シフト(>>>)が必要な点に注意。

  def ntzNaive(x: Long): Int =
    (0 until 64).find(i => ((x >>> i) & 1L) != 0L).getOrElse(64)

var を使って工夫 (2種類)

  def ntzLoop1(x: Long): Int = {
    var y = ~x & (x - 1)
    var n = 0
    while (y != 0) {n += 1; y >>>= 1}
    n
  }

  def ntzLoop2(x: Long): Int = {
    var y = x
    var n = 64
    while (y != 0) {n -= 1; y <<= 1}
    n
  }
二分探索木を用いた方法 (3種類)
  def ntzBinarySearch1(x: Long): Int = {
    if (x == 0L) {
      64
    } else {
      var n = 1
      var y = x
      if ((y & 0x00000000ffffffffL) == 0L) {n += 32; y >>>= 32}
      if ((y & 0x000000000000ffffL) == 0L) {n += 16; y >>>= 16}
      if ((y & 0x00000000000000ffL) == 0L) {n += 8; y >>>= 8}
      if ((y & 0x000000000000000fL) == 0L) {n += 4; y >>>= 4}
      if ((y & 0x0000000000000003L) == 0L) {n += 2; y >>>= 2}
      n - (y & 1L).toInt
    }
  }

  def ntzBinarySearch2(x: Long): Int = {
    if (x == 0L) {
      64
    } else {
      var n = 63
      var y = 0L
      var z = x << 32; if (z != 0) {n -= 32; y = z}
      z = y << 16; if (z != 0) {n -= 16; y = z}
      z = y << 8; if (z != 0) {n -= 8; y = z}
      z = y << 4; if (z != 0) {n -= 4; y = z}
      z = y << 2; if (z != 0) {n -= 2; y = z}
      z = y << 1; if (z != 0) n -= 1
      n
    }
  }

  def ntzBinarySearch3(x: Long): Int = {
    if (x == 0L)
      64
    else
      if ((x & 0xffffffffL) != 0L)
        if ((x & 0xffffL) != 0L)
          if ((x & 0xffL) != 0L)
            if ((x & 0xfL) != 0L)
              if ((x & 3L) != 0L)
                if ((x & 1L) != 0L) 0 else 1
              else
                if ((x & 4L) != 0L) 2 else 3
            else
              if ((x & 0x30L) != 0L)
                if ((x & 0x10L) != 0L) 4 else 5
              else
                if ((x & 0x40L) != 0L) 6 else 7
          else
            if ((x & 0xf00L) != 0L)
              if ((x & 0x300L) != 0L)
                if ((x & 0x100L) != 0L) 8 else 9
              else
                if ((x & 0x400L) != 0L) 10 else 11
            else
              if ((x & 0x3000L) != 0L)
                if ((x & 0x1000L) != 0L) 12 else 13
              else
                if ((x & 0x4000L) != 0L) 14 else 15
        else
          if ((x & 0xff0000L) != 0L)
            if ((x & 0xf0000L) != 0L)
              if ((x & 30000L) != 0L)
                if ((x & 10000L) != 0L) 16 else 17
              else
                if ((x & 40000L) != 0L) 18 else 19
            else
              if ((x & 0x300000L) != 0L)
                if ((x & 0x100000L) != 0L) 20 else 21
              else
                if ((x & 0x400000L) != 0L) 22 else 23
          else
            if ((x & 0xf000000L) != 0L)
              if ((x & 0x3000000L) != 0L)
                if ((x & 0x1000000L) != 0L) 24 else 25
              else
                if ((x & 0x4000000L) != 0L) 26 else 27
            else
              if ((x & 0x30000000L) != 0L)
                if ((x & 0x10000000L) != 0L) 28 else 29
              else
                if ((x & 0x40000000L) != 0L) 30 else 31
      else
        if ((x & 0xffff00000000L) != 0L)
          if ((x & 0xff00000000L) != 0L)
            if ((x & 0xf00000000L) != 0L)
              if ((x & 0x300000000L) != 0L)
                if ((x & 0x100000000L) != 0L) 32 else 33
              else
                if ((x & 0x400000000L) != 0L) 34 else 35
            else
              if ((x & 0x3000000000L) != 0L)
                if ((x & 0x1000000000L) != 0L) 36 else 37
              else
                if ((x & 0x4000000000L) != 0L) 38 else 39
          else
            if ((x & 0xf0000000000L) != 0L)
              if ((x & 0x30000000000L) != 0L)
                if ((x & 0x10000000000L) != 0L) 40 else 41
              else
                if ((x & 0x40000000000L) != 0L) 42 else 43
            else
              if ((x & 0x300000000000L) != 0L)
                if ((x & 0x100000000000L) != 0L) 44 else 45
              else
                if ((x & 0x400000000000L) != 0L) 46 else 47
        else
          if ((x & 0xff000000000000L) != 0L)
            if ((x & 0xf000000000000L) != 0L)
              if ((x & 3000000000000L) != 0L)
                if ((x & 1000000000000L) != 0L) 48 else 49
              else
                if ((x & 4000000000000L) != 0L) 50 else 51
            else
              if ((x & 0x30000000000000L) != 0L)
                if ((x & 0x10000000000000L) != 0L) 52 else 53
              else
                if ((x & 0x40000000000000L) != 0L) 54 else 55
          else
            if ((x & 0xf00000000000000L) != 0L)
              if ((x & 0x300000000000000L) != 0L)
                if ((x & 0x100000000000000L) != 0L) 56 else 57
              else
                if ((x & 0x400000000000000L) != 0L) 58 else 59
            else
              if ((x & 0x3000000000000000L) != 0L)
                if ((x & 0x1000000000000000L) != 0L) 60 else 61
              else
                if ((x & 0x4000000000000000L) != 0L) 62 else 63
  }
対数計算を利用
  private[this] val ln2 = math.log(2)

  def ntzLogarithm(x: Long): Int =
    if (x == 0L)
      64
    else if (x == 0x8000000000000000L)
      63
    else
      (math.log(x & -x) / ln2).toInt
ハッシュを利用 (黒魔術)
  private[this] val ntzHash = 0x03F566ED27179461L
  private[this] lazy val ntzTable: Array[Int] =
    (0 until 64).map(i => (ntzHash << i >>> 58).toInt -> i).sorted.map(_._2).toArray

  def ntzHash(x: Long): Int =
    if (x != 0L) ntzTable((((x & -x) * ntzHash) >>> 58).toInt) else 64

 

ネイティブでの実装

C++ のインラインアセンブリを使って x86-64の命令を直接呼び出す。
それを JNI/JNA でラップして Scala から呼び出した。

JNI (Java Native Interface)
  • C++ コード
#include <jni.h>

extern "C" JNIEXPORT jint JNICALL Java_com_github_mogproject_util_BitOperation_ntzJni
  (JNIEnv *env, jobject obj, jlong x) {
  if (x == 0) return 64;
  int pos = 0;
  __asm__(
    "bsf %1,%q0;"
    :"=r"(pos)
    :"rm"(x));
  return pos;
}
  • ビルド
$ g++ -dynamiclib -O3 \
-I/usr/include -I$JAVA_HOME/include -I$JAVA_HOME/include/darwin \
bitops_jni.cpp -o libbitops_jni.dylib
  • Scala コード
package com.github.mogproject.util

class BitOperation {
  @native def ntzJni(x: Long): Int
}

object BitOperation {
  System.load("/path/to/libbitops_jni.dylib")

  private[this] val bitop = new BitOperation

  def ntzJni(x: => Long) = bitop.ntzJni(x)
}
JNA (Java Native Access)
  • C++ コード
extern "C" int ntz(unsigned long long x) {
  if (x == 0) return 64;
  int pos = 0;
  __asm__(
    "bsf %1,%q0;"
    :"=r"(pos)
    :"rm"(x));
  return pos;
}
  • ビルド
$ g++ -dynamiclib -O3 bitops_jna.cpp -o libbitops_jna.dylib
  • Scala コード
  private[this] val lib =
    com.sun.jna.NativeLibrary.getInstance("/path/to/libbitops_jna.dylib")

  private[this] val func = lib.getFunction("ntz")

  def ntzJna(x: Long): Int = func.invokeInt(Array(x.asInstanceOf[Object]))

 

テスト

とりあえず REPL でランダムなLong型整数を与え、全てのメソッドの結果が一致することを確認。

scala> import com.github.mogproject.util.BitOperation._
import com.github.mogproject.util.BitOperation._

scala> def f(x: Long) = Seq(ntzNaive(x), ntzLoop1(x), ntzLoop2(x),
     | ntzBinarySearch1(x), ntzBinarySearch2(x), ntzBinarySearch3(x), ntzLogarithm(x),
     | ntzHash(x), ntzJni(x), ntzJna(x))
f: (x: Long)Seq[Int]

scala> f(0)
res0: Seq[Int] = List(64, 64, 64, 64, 64, 64, 64, 64, 64, 64)

scala> f(-1)
res1: Seq[Int] = List(0, 0, 0, 0, 0, 0, 0, 0, 0, 0)

scala> f(0x8000000000000000L)
res2: Seq[Int] = List(63, 63, 63, 63, 63, 63, 63, 63, 63, 63)

scala> Seq.fill(1000)(scala.util.Random.nextLong) forall (f(_).toSet.size == 1)
res3: Boolean = true

 

ベンチマーク

sbt-jmh でベンチマークを取る。

import com.github.mogproject.util.BitOperation
import org.openjdk.jmh.annotations.{Benchmark, Scope, State => JmhState}

import scala.util.Random


@JmhState(Scope.Thread)
class BitOperationBench {
  val random = new Random(12345)

  @Benchmark def ntzNaive(): Unit = BitOperation.ntzNaive(random.nextLong())

  @Benchmark def ntzLoop1(): Unit = BitOperation.ntzLoop1(random.nextLong())
  @Benchmark def ntzLoop2(): Unit = BitOperation.ntzLoop2(random.nextLong())

  @Benchmark def ntzBinarySearch1(): Unit = BitOperation.ntzBinarySearch1(random.nextLong())
  @Benchmark def ntzBinarySearch2(): Unit = BitOperation.ntzBinarySearch2(random.nextLong())
  @Benchmark def ntzBinarySearch3(): Unit = BitOperation.ntzBinarySearch3(random.nextLong())

  @Benchmark def ntzLogarithm(): Unit = BitOperation.ntzLogarithm(random.nextLong())
  @Benchmark def ntzHash(): Unit = BitOperation.ntzHash(random.nextLong())

  @Benchmark def ntzJni(): Unit = BitOperation.ntzJni(random.nextLong())
  @Benchmark def ntzJna(): Unit = BitOperation.ntzJna(random.nextLong())
}
測定結果
[info] Running org.openjdk.jmh.Main -i 10 -wi 10 -f1 -t1 .*BitOperation.*

(snip)

[info] Benchmark                                          Mode  Samples         Score  Score error  Units
[info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch1    thrpt       10  35184555.294   147974.730  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch2    thrpt       10  35130401.937   234705.330  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch3    thrpt       10  29846819.149   352975.842  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzHash             thrpt       10  35127105.013   127143.455  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzJna              thrpt       10   1396115.477    16139.308  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzJni              thrpt       10  25644475.628   231921.707  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzLogarithm        thrpt       10  35144781.544    92927.496  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzLoop1            thrpt       10  30099013.164   308642.730  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzLoop2            thrpt       10  15183889.536   142090.869  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzNaive            thrpt       10  14309678.759   317520.811  ops/s

飛び抜けて良いものがある、という訳ではなく
BinarySearch1, BinarySearch2, Logarithm, Hash が先頭グループとしてほぼ同等の成績。

JNI はやはり関数コールのペナルティが大きいのか、実用には厳しいレベル。
JNA は一つだけ桁違いに性能が悪かった。

やはり計測あるのみ。
"Trust no one, bench everything." を実感した。

 

References

 

 

Related Posts

0 件のコメント:

コメントを投稿