10.28.2014

Python: How to Parse DateTime with Local Timezone

Python: ローカルのタイムゾーン情報付きで日時表記文字列を読み込む

 

Python の datetime クラスは、"naive" と "aware" の 2種類の顔を持つ。

"aware" な datetime はタイムゾーンの情報を持ち、"naive" は持たない。

プログラム内部では "aware" な datetime を利用し、"naive" は入出力境界のみにとどめるのがセオリーだ。

 

タイムゾーンの自動認識には、dateutil.tz.tzlocal を使うのが最も適当な様子。

たとえば datetime.strptime を使って日時を認識するコードは以下のようになる。

>>> from datetime import datetime
>>> from dateutil.tz import tzlocal
>>> t = datetime.strptime('201410281234', '%Y%m%d%H%M').replace(tzinfo=tzlocal())
>>> t
datetime.datetime(2014, 10, 28, 12, 34, tzinfo=tzlocal())

datetime#astimezone を使ってタイムゾーンを変更する。

>>> import pytz
>>> t.astimezone(pytz.utc)
datetime.datetime(2014, 10, 28, 3, 34, tzinfo=<UTC>)

epoch time への変換には一手間かかる。calendar.timegm を使うのがベストプラクティスのようだ。

>>> import calendar
>>> calendar.timegm(t.timetuple())
1414499640
>>> calendar.timegm(t.astimezone(pytz.utc).timetuple())
1414467240

10.22.2014

Ansible: Playbook for AWS CloudWatch Monitoring Scripts

Ansible: CloudWatch 用 Linux 監視スクリプトをインストールする Playbook

 

前提

  • EC2 インスタンスの OS は Amazon Linux とする
  • 認証情報はファイルに保存する
    • インスタンス作成前であれば、IAM Role の設定により認証情報の保持が不要となる
    • CloudWatch API 専用の IAM ユーザの作成を推奨
  • 各種パラメータについては vars/main.yml を参照
  • 課金が発生する可能性があるので注意

 

コード

---
- name: install additional perl modules
  yum: name={{ item }} state=present
  with_items:
    - perl-Switch
    - perl-Sys-Syslog
    - perl-LWP-Protocol-https
  tags: aws-scripts-mon

- name: check if script is installed
  command: /usr/bin/test -e {{ path_to_script }}
  ignore_errors: True
  changed_when: False
  register: is_installed
  tags: aws-scripts-mon

- name: download scripts from AWS server
  get_url: url={{ download_url }} dest={{ path_to_download }}
  when: is_installed | failed
  tags: aws-scripts-mon

- name: unzip downloaded file
  unarchive: copy=no src={{ path_to_download }} dest={{ home_dir }}
  when: is_installed | failed
  tags: aws-scripts-mon

- name: create credential file
  template: src={{ item }}.j2 dest={{ script_dir }}/{{ item }} owner={{ user }} group={{ user }} mode="0600"
  with_items:
    - awscreds.conf
  tags: aws-scripts-mon

- name: set directory owner
  file: path={{ script_dir }} state=directory owner={{ user }} group={{ user }} recurse=yes
  tags: aws-scripts-mon

- name: remove downloaded file
  file: path={{ path_to_download }} state=absent
  tags: aws-scripts-mon

- name: set cron
  cron: user={{ user }}
        state=present
        name="CloudWatch monitoring script"
        minute="{{ cron.minute }}"
        hour="{{ cron.hour }}"
        job="{{ cron.job }}"
  tags: aws-scripts-mon
AWSAccessKeyId={{ access_key }}
AWSSecretKey={{ secret_key }}
---
user: ec2-user
version: 1.1.0
filename: CloudWatchMonitoringScripts-v{{ version }}.zip
download_url: http://ec2-downloads.s3.amazonaws.com/cloudwatch-samples/{{ filename }}

path_to_download: "/tmp/{{ filename }}"

home_dir: "/home/{{ user }}"
script_dir: "{{ home_dir }}/aws-scripts-mon"
path_to_script: "{{ script_dir }}/mon-put-instance-data.pl"

access_key: "{{ aws_cloudwatch_agent_access_key_id }}"
secret_key: "{{ aws_cloudwatch_agent_secret_access_key }}"

cron:
  hour: "*"
  minute: "*/5"
  job: "{{ path_to_script }} --mem-util --mem-used --mem-avail --swap-util --swap-used --disk-path=/ --disk-space-util --disk-space-used --disk-space-avail --aws-credential-file={{ script_dir }}/awscreds.conf --from-cron"

credential 情報 (aws_cloudwatch_agent_xxx) は、extra-vars などで渡す想定。

 

 

References

10.20.2014

Scala: How to Limit DynamoDB's Range Query

Scala: AWS DynamoDB のテーブルに対して件数制限付きのレンジクエリを実行する

 

DynamoDB の range クエリについて、理解が足りなかったので整理しておく。

今回は例として、食事の記録を以下の項目とともに DynamoDB に格納し、
ユーザごとの最新 n 件の食事を調べるクエリを投げるようなアプリケーションを考えてみる。

DynamoDB: FoodLog
  • UserId [ハッシュキー]: ユーザID (文字列)
  • Timestamp [レンジキー]: タイムスタンプ(epoch からの経過時間をミリ秒単位で格納) (整数値)
  • Food: 食事した内容 (文字列)
  • Calorie: 摂取カロリー (整数値)

 

テーブル作成

aws-cli で以下のコマンドを実行し、FoodLog テーブルを作成する。
(aws-cli および認証情報はセットアップ済みの前提)

$ aws dynamodb create-table \
--table-name FoodLog \
--attribute-definitions \
AttributeName=UserId,AttributeType=S \
AttributeName=Timestamp,AttributeType=N \
--key-schema AttributeName=UserId,KeyType=HASH AttributeName=Timestamp,KeyType=RANGE \
--provisioned-throughput ReadCapacityUnits=1,WriteCapacityUnits=1

 

ベース部分実装

Java のライブラリと O/R マッパーを使って FoodLog クラスを実装。
接続情報は環境変数(AWS_ACCESS_KEY, AWS_SECRET_KEY)より与えられる想定である。

package com.github.mogproject.example.dynamodb

import com.amazonaws.auth.BasicAWSCredentials
import com.amazonaws.regions.RegionUtils
import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient
import com.amazonaws.services.dynamodbv2.datamodeling._

import scala.annotation.meta.beanGetter
import scala.beans.BeanProperty
import scala.collection.JavaConverters._

trait DynamoDBClient {
  private[this] val accessKeyId = sys.env("AWS_ACCESS_KEY")
  private[this] val secretAccessKey = sys.env("AWS_SECRET_KEY")
  private[this] val region = RegionUtils.getRegion("ap-northeast-1")
  private[this] val endpoint = region.getServiceEndpoint("dynamodb")

  private[this] val credentials = new BasicAWSCredentials(accessKeyId, secretAccessKey)
  private[this] val client = {
    val ret = new AmazonDynamoDBClient(credentials)
    ret.setRegion(region)
    ret.setEndpoint(endpoint)
    ret
  }
  protected val mapper = new DynamoDBMapper(client)

  def batchSave(xs: FoodLog*) = mapper.batchSave(xs.asJava)

  def batchDelete(xs: FoodLog*) = mapper.batchDelete(xs.asJava)

  def batchWrite(toWrite: Seq[FoodLog], toDelete: Seq[FoodLog]) = mapper.batchWrite(toWrite.asJava, toDelete.asJava)
}

@DynamoDBTable(tableName = "FoodLog")
case class FoodLog(
                    @(DynamoDBHashKey@beanGetter)(attributeName = "UserId") @BeanProperty var userId: String,
                    @(DynamoDBRangeKey@beanGetter)(attributeName = "Timestamp") @BeanProperty var timestamp: Long,
                    @DynamoDBAttribute(attributeName = "Food") @BeanProperty var food: String,
                    @DynamoDBAttribute(attributeName = "Calorie") @BeanProperty var calorie: Int
                    ) {
  def this() = this(null, 0, null, 0)
}

object FoodLog extends DynamoDBClient {
  def readRecent(userId: String, limit: Int): Seq[FoodLog] = ???
}

 

テストデータ投入

REPL を使い、2ユーザx10件ずつのランダムなテストデータを投入する。

$ export AWS_ACCESS_KEY="xxxxxx"
$ export AWS_SECRET_KEY="xxxxxx"
$ sbt console

scala> import com.github.mogproject.example.dynamodb.FoodLog
import com.github.mogproject.example.dynamodb.FoodLog

scala> import scala.util.Random
import scala.util.Random

scala> val item1 = (1 to 10).map(i => FoodLog("user-1", Random.nextInt(100000), s"food-$i", Random.nextInt(2000)))
item1: scala.collection.immutable.IndexedSeq[com.github.mogproject.example.dynamodb.FoodLog] = Vector(FoodLog(user-1,78548,food-1,1911), FoodLog(user-1,67632,food-2,974), FoodLog(user-1,34756,food-3,1639), FoodLog(user-1,15595,food-4,937), FoodLog(user-1,77366,food-5,158), FoodLog(user-1,9615,food-6,393), FoodLog(user-1,64601,food-7,429), FoodLog(user-1,6847,food-8,1834), FoodLog(user-1,55271,food-9,1434), FoodLog(user-1,74394,food-10,885))

scala> val item2 = (1 to 10).map(i => FoodLog("user-2", Random.nextInt(100000), s"food-$i", Random.nextInt(2000)))
item2: scala.collection.>immutable.IndexedSeq[com.github.mogproject.example.dynamodb.FoodLog] = Vector(FoodLog(user-2,15618,food-1,1356), FoodLog(user-2,27456,food-2,123), FoodLog(user-2,62137,food-3,1122), FoodLog(user-2,43501,food-4,673), FoodLog(user-2,80906,food-5,577), FoodLog(user-2,96682,food-6,1112), FoodLog(user-2,40193,food-7,1961), FoodLog(user-2,44857,food-8,1064), FoodLog(user-2,88767,food-9,1618), FoodLog(user-2,42126,food-10,761))

scala> FoodLog.batchSave(item1 ++ item2: _*)
res0: java.util.List[com.amazonaws.services.dynamodbv2.datamodeling.DynamoDBMapper.FailedBatch] = []

 

クエリ部分(仮)実装

件数制限付きのクエリを素直に書くと、以下のようになる。

  def readRecent(userId: String, limit: Int): Seq[FoodLog] = {
    val query = new DynamoDBQueryExpression[FoodLog]()
      .withHashKeyValues(FoodLog(userId, 0, null, 0))
      .withScanIndexForward(false)
      .withLimit(limit)
      .withConsistentRead(false)
    mapper.query(classOf[FoodLog], query).asScala
  }

そして limit=5 としてクエリを実行すると、結果は …… 10個ある!?

$ sbt console
scala> import com.github.mogproject.example.dynamodb.FoodLog
import com.github.mogproject.example.dynamodb.FoodLog

scala> FoodLog.readRecent("user-1", 5)
res0: Seq[com.github.mogproject.example.dynamodb.FoodLog] = Buffer(FoodLog(user-1,78548,food-1,1911), FoodLog(user-1,77366,food-5,158), FoodLog(user-1,74394,food-10,885), FoodLog(user-1,67632,food-2,974), FoodLog(user-1,64601,food-7,429), FoodLog(user-1,55271,food-9,1434), FoodLog(user-1,34756,food-3,1639), FoodLog(user-1,15595,food-4,937), FoodLog(user-1,9615,food-6,393), FoodLog(user-1,6847,food-8,1834))

scala> res0.size
res1: Int = 10

scala> res0 foreach println
FoodLog(user-1,78548,food-1,1911)
FoodLog(user-1,77366,food-5,158)
FoodLog(user-1,74394,food-10,885)
FoodLog(user-1,67632,food-2,974)
FoodLog(user-1,64601,food-7,429)
FoodLog(user-1,55271,food-9,1434)
FoodLog(user-1,34756,food-3,1639)
FoodLog(user-1,15595,food-4,937)
FoodLog(user-1,9615,food-6,393)
FoodLog(user-1,6847,food-8,1834)

 

これはどういうことなのか

sbt run で実行可能なプログラムを作成し、build.sbt に以下の記述を行って http-wire ログを出力してみる。

javaOptions in run ++= Seq(
  "-Dorg.apache.commons.logging.Log=org.apache.commons.logging.impl.SimpleLog",
  "-Dorg.apache.commons.logging.simplelog.showdatetime=true",
  "-Dorg.apache.commons.logging.simplelog.log.org.apache.http.wire=DEBUG"
)

fork in run := true

すると、以下のように DynamoDB との通信が 2回発生していることがわかる。

[error] 2014/10/20 01:09:19:440 JST [DEBUG] wire - >> "POST / HTTP/1.1[\r][\n]"
[error] 2014/10/20 01:09:19:441 JST [DEBUG] wire - >> "Host: dynamodb.ap-northeast-1.amazonaws.com[\r][\n]"
(snip)
[error] 2014/10/20 01:09:19:442 JST [DEBUG] wire - >> "{"TableName":"FoodLog","Limit":5,"ConsistentRead":false,"KeyConditions":{"UserId":{"AttributeValueList":[{"S":"user-1"}],"ComparisonOperator":"EQ"}},"ScanIndexForward":false}"
[error] 2014/10/20 01:09:19:468 JST [DEBUG] wire - << "HTTP/1.1 200 OK[\r][\n]"
(snip)
[error] 2014/10/20 01:09:19:478 JST [DEBUG] wire - << "{"Count":5,"Items":[{"UserId":{"S":"user-1"},"Timestamp":{"N":"78548"},"food":{"S":"food-1"},"calorie":{"N":"1911"}},{"UserId":{"S":"user-1"},"Timestamp":{"N":"77366"},"food":{"S":"food-5"},"calorie":{"N":"158"}},{"UserId":{"S":"user-1"},"Timestamp":{"N":"74394"},"food":{"S":"food-10"},"calorie":{"N":"885"}},{"UserId":{"S":"user-1"},"Timestamp":{"N":"67632"},"food":{"S":"food-2"},"calorie":{"N":"974"}},{"UserId":{"S":"user-1"},"Timestamp":{"N":"64601"},"food":{"S":"food-7"},"calorie":{"N":"429"}}],"LastEvaluatedKey":{"Timestamp":{"N":"64601"},"UserId":{"S":"user-1"}},"ScannedCount":5}"
[info] FoodLog(user-1,78548,food-1,1911)
[info] FoodLog(user-1,77366,food-5,158)
[info] FoodLog(user-1,74394,food-10,885)
[info] FoodLog(user-1,67632,food-2,974)
[info] FoodLog(user-1,64601,food-7,429)
[error] 2014/10/20 01:09:19:504 JST [DEBUG] wire - >> "POST / HTTP/1.1[\r][\n]"
[error] 2014/10/20 01:09:19:504 JST [DEBUG] wire - >> "Host: dynamodb.ap-northeast-1.amazonaws.com[\r][\n]"
(snip)
[error] 2014/10/20 01:09:19:505 JST [DEBUG] wire - >> "{"TableName":"FoodLog","Limit":5,"ConsistentRead":false,"KeyConditions":{"UserId":{"AttributeValueList":[{"S":"user-1"}],"ComparisonOperator":"EQ"}},"ScanIndexForward":false,"ExclusiveStartKey":{"UserId":{"S":"user-1"},"Timestamp":{"N":"64601"}}}"
[error] 2014/10/20 01:09:19:533 JST [DEBUG] wire - << "HTTP/1.1 200 OK[\r][\n]"
(snip)
[error] 2014/10/20 01:09:19:533 JST [DEBUG] wire - << "{"Count":5,"Items":[{"UserId":{"S":"user-1"},"Timestamp":{"N":"55271"},"food":{"S":"food-9"},"calorie":{"N":"1434"}},{"UserId":{"S":"user-1"},"Timestamp":{"N":"34756"},"food":{"S":"food-3"},"calorie":{"N":"1639"}},{"UserId":{"S":"user-1"},"Timestamp":{"N":"15595"},"food":{"S":"food-4"},"calorie":{"N":"937"}},{"UserId":{"S":"user-1"},"Timestamp":{"N":"9615"},"food":{"S":"food-6"},"calorie":{"N":"393"}},{"UserId":{"S":"user-1"},"Timestamp":{"N":"6847"},"food":{"S":"food-8"},"calorie":{"N":"1834"}}],"ScannedCount":5}"
[info] FoodLog(user-1,55271,food-9,1434)
[info] FoodLog(user-1,34756,food-3,1639)
[info] FoodLog(user-1,15595,food-4,937)
[info] FoodLog(user-1,9615,food-6,393)
[info] FoodLog(user-1,6847,food-8,1834)

改めてAPIマニュアル(DynamoDBQueryExpression (AWS SDK for Java - 1.9.1))を読む。

Sets the maximum number of items to retrieve in each service request to DynamoDB and returns a pointer to this object for method-chaining.

Note that when calling DynamoDBMapper.query, multiple requests are made to DynamoDB if needed to retrieve the entire result set. Setting this will limit the number of items retrieved by each request, NOT the total number of results that will be retrieved. Use DynamoDBMapper.queryPage to retrieve a single page of items from DynamoDB.

つまるところ、withLimit で指定しているのはクエリ全体の取得件数ではなく、
1回のリクエストで取得するサイズ (サービスリクエストにおける1ページのサイズ) なのである。

DynamoDBMapper.query メソッドは(ページ単位で遅延評価となる)全体の結果セットを返すため、
その結果に map や size などの横断的な処理を適用すると結果セット全体がスキャンされてしまう。

求める結果が最初のページだけでよければ、DynamoDBMapper.queryPage を利用するのが正解だ。

 

クエリ部分の正しい実装

ハイライト部分を修正。

  def readRecent(userId: String, limit: Int): Seq[FoodLog] = {
    val query = new DynamoDBQueryExpression[FoodLog]()
      .withHashKeyValues(FoodLog(userId, 0, null, 0))
      .withScanIndexForward(false)
      .withLimit(limit)
      .withConsistentRead(false)
    mapper.queryPage(classOf[FoodLog], query).getResults.asScala
  }

結果は想定通り、5個のみとなった。

scala> import com.github.mogproject.example.dynamodb.FoodLog
import com.github.mogproject.example.dynamodb.FoodLog

scala> FoodLog.readRecent("user-1", 5)
res0: Seq[com.github.mogproject.example.dynamodb.FoodLog] = Buffer(FoodLog(user-1,78548,food-1,1911), FoodLog(user-1,77366,food-5,158), FoodLog(user-1,74394,food-10,885), FoodLog(user-1,67632,food-2,974), FoodLog(user-1,64601,food-7,429))

scala> res0.size
res1: Int = 5

scala> res0 foreach println
FoodLog(user-1,78548,food-1,1911)
FoodLog(user-1,77366,food-5,158)
FoodLog(user-1,74394,food-10,885)
FoodLog(user-1,67632,food-2,974)
FoodLog(user-1,64601,food-7,429)

DynamoDBのスキャンが想定外に繰り返されると、応答が遅くなるだけでなく
読み込みキャパシティの限界突破のリスクも非常に高くなる。
このような落とし穴にはよくよく注意が必要である。

 

さいごに、簡単なベンチマークを行った。

100件のデータを用意し、「limit=5 の全ページスキャン」「limit=5 の最初のページのみスキャン」「limit=100 の最初のページのみスキャン」の所要時間を測定したところ、結果は以下のようになった。

  • limit=5 の全ページスキャン                    : 340msec
  • limit=5 の最初のページのみスキャン      :  19msec
  • limit=100 の最初のページのみスキャン  :  28msec

やはり、DynamoDBに対してHTTP通信を繰り返す(上記の例では20回)のは非常にコストが高い。
全ページのスキャンを行うくらいなら、最初から limit を引き上げたほうが得策だろう。

 

 

Source code

 

References

10.19.2014

Scala: Property-Based Testing with ScalaTest and ScalaCheck

Scala: ScalaTest + ScalaCheck を使ってプロパティベースのユニットテストを行う

  • ScalaTest とは
    多くの Scala プロジェクトで採用されているユニットテスト・フレームワークのデファクトの一つ。
    双璧をなす Specs2 とどちらを使うかは好みの問題。
  • ScalaCheck とは
    Haskell の QuickCheck から派生したツール。
    実行時にランダムな入力を自動的に生成し、振る舞いをテストすることができる。
    テスト仕様とテストデータを分離できるようになるのが嬉しい。
    ScalaTest/Specs2 には ScalaCheck を透過的に扱うための仕組みが標準で組み込まれている。

 

sbt の設定

ScalaTest/ScalaCheck を使うにはsbt プロジェクトを作るのが一番簡便だろう。

build.sbt などに、以下のように依存ライブラリを追加する。

libraryDependencies ++= Seq(
  "org.scalatest" %% "scalatest" % "2.2.0" % "test",
  "org.scalacheck" %% "scalacheck" % "1.11.6" % "test"
)

 

REPL で ScalaCheck を試す

sbt test:console で REPL を立ち上げれば、すぐに ScalaCheck の動作に触れることができる。

はじめての ScalaCheck

任意のInt型整数に対して、2 を掛けるのと自分自身を足し合わせた結果が同じになることを確かめる。
ランダムな入力 100個に対するテストが行われ、結果「OK」が表示される。

scala> import org.scalacheck.Prop.forAll
import org.scalacheck.Prop.forAll

scala> val prop = forAll { x: Int => x * 2 == x + x }
prop: org.scalacheck.Prop = Prop

scala> prop.check
+ OK, passed 100 tests.

テストに失敗する例。x * 2 でオーバーフローする場合、その値を 2 で割っても x には戻らない。

scala> val prop = forAll { x: Int => x * 2 / 2 == x }

scala> prop.check
! Falsified after 4 passed tests.
> ARG_0: 1073741824
> ARG_0_ORIGINAL: 1461275699

Int 以外の組み込み型も、そのまま直感的に扱える。

scala> val prop =
     |   forAll { (x: String, y: String, n: Int) => (x + y).startsWith(x.take(n)) }
prop: org.scalacheck.Prop = Prop

scala> prop.check
+ OK, passed 100 tests.

scala> val prop =
     |   forAll { xss: List[List[Int]] => xss.map(_.length).sum == xss.flatten.length }
prop: org.scalacheck.Prop = Prop

scala> prop.check

OK, passed 100 tests.

ランダムに生成される String や List の長さ(サイズ)は、デフォルトでは 0以上 100以下となる。

 

入力値の制約

入力に対して制約を与えるには、以下2種類の方法がある。

1. ジェネレータ(Gen)を与える

scala> import org.scalacheck.Gen
import org.scalacheck.Gen

scala> val prop = forAll(Gen.choose(-10000, 10000)){ x: Int => x * 2 / 2 == x }
prop: org.scalacheck.Prop = Prop

scala> prop.check
+ OK, passed 100 tests.

2. テストの内部でフィルタリングする (マッチしない場合のテストを放棄)

scala> import org.scalacheck.Prop.BooleanOperators
import org.scalacheck.Prop.BooleanOperators

scala> val prop = forAll{ x: Int => (-10000 <= x && x <= 10000) ==> (x * 2 / 2 == x) }
prop: org.scalacheck.Prop = Prop

scala> prop.check
+ OK, passed 100 tests.

ただし後者の場合、制約が強すぎるとテストが全く行われない可能性があるので注意。

scala> val prop = forAll{ x: Int => (x == 12345) ==> (x * 2 / 2 == x) }
prop: org.scalacheck.Prop = Prop

scala> prop.check
! Gave up after only 0 passed tests. 101 tests were discarded.

 

ScalaTest の一部として ScalaCheck を使う

BDD スタイルの FlatSpec を使う場合の例。

単純なテストであれば、org.scalatest.prop.Checkers.check を使うだけでよい。

import org.scalatest.prop.Checkers.check
import org.scalatest.{MustMatchers, FlatSpec}

class ExampleSpec extends FlatSpec with MustMatchers {
  "multiplying integer with two" should "be same as adding itself" in check { x: Int =>
    x * 2 == x + x
  }
}

sbt test を流せば以下のような結果が表示される。

[info] ExampleSpec:
[info] multiplying integer with two
[info] - should be same as adding itself

GeneratorDrivenPropertyChecks トレイトをミックスインすれば、より複雑なテストケースを書けるようになる。

例えばこのような平面座標の距離を求める簡単な関数に対して、

case class Coord(x: Double, y: Double) {
  def distance(c: Coord) = math.sqrt(math.pow(c.x - x, 2) + math.pow(c.y - y, 2))
}

ScalaCheck を使ってみる。

import org.scalatest.prop.GeneratorDrivenPropertyChecks
import org.scalatest.{MustMatchers, FlatSpec}

class CoordSpec extends FlatSpec with MustMatchers with GeneratorDrivenPropertyChecks {
  "Coord#distance" should "be same as norm when one side is origin" in forAll { (x: Double, y: Double) =>
    val norm = math.sqrt(x * x + y * y)
    Coord(x, y).distance(Coord(0, 0)) mustBe norm
    Coord(0, 0).distance(Coord(x, y)) mustBe norm
  }
}

 

独自のクラスに対してテストを行う

組み込み型以外のクラスを任意に生成するには、独自のジェネレータを作ればよい。

ジェネレータの実装例

  def genCoord: Gen[Coord] =
    for {
      x <- Gen.choose(-100.0, 100.0)
      y <- Gen.choose(-100.0, 100.0)
    } yield Coord(x, y)

Gen.choose, Gen.oneOf, Gen.someOf がよく使われる。
Gen.frequency で出現頻度を調整したり、Gen.suchThat で制約を付けたり
org.scalacheck.Arbitrary.arbitrary で任意の値を選んだりもできる。(参考: User Guide · rickynils/scalacheck Wiki)

テストコードはこのような形になった。

import org.scalacheck.Gen
import org.scalacheck.Arbitrary.arbitrary
import org.scalatest.prop.GeneratorDrivenPropertyChecks
import org.scalatest.{MustMatchers, FlatSpec}

class CoordSpec extends FlatSpec with MustMatchers with GeneratorDrivenPropertyChecks {
  def genCoord: Gen[Coord] =
    for {
      x <- Gen.choose(-100.0, 100.0)
      y <- Gen.choose(-100.0, 100.0)
    } yield Coord(x, y)

  "Coord#distance" should "be norm when one side is origin" in forAll { (x: Double, y: Double) =>
    val norm = math.sqrt(x * x + y * y)
    Coord(x, y).distance(Coord(0, 0)) mustBe norm
    Coord(0, 0).distance(Coord(x, y)) mustBe norm
  }

  it should "be zero with same coordinates" in forAll(genCoord) { (a: Coord) =>
    a.distance(a) mustBe 0.0
  }
  it should "be positive or zero" in forAll(genCoord, genCoord) { (a: Coord, b: Coord) =>
    a.distance(b) must be >= 0.0
  }
  it should "be less than 300" in forAll(genCoord, genCoord) { (a: Coord, b: Coord) =>
    a.distance(b) must be < 300.0
  }
  it should "not change after swapping parameters" in forAll(genCoord, genCoord) { (a: Coord, b: Coord) =>
    a.distance(b) mustBe b.distance(a)
  }
  it should "not change after parallel shift" in forAll(
    genCoord, genCoord, Gen.choose(-100.0, 100.0), arbitrary[Int], minSuccessful(500), maxDiscarded(2000)) {

    (a: Coord, b: Coord, dx: Double, dy: Int) => whenever(-10000 <= dy && dy <= 10000) {
      a.distance(b) mustBe (Coord(a.x + dx, a.y + dy).distance(Coord(b.x + dx, b.y + dy)) +- 1.0E-8)
    }
  }
}

最後の例は、whenever で入力を制限したり、+- を使うことで浮動小数点数の誤差を吸収している。

 

 

Source code

 

References

10.07.2014

Counting Number of Trailing Zeros in Scala

Scala: 末尾から続く0の個数を数える

 

与えられた64ビット整数値に対して、それを2進数にしたときに末尾から0が何個連続するかを求めたい。

たとえば十進数 88 の場合、2進数にすると 1011000 となるので、答えは 3 となる。

これは、ntz (number of trailing zeros) として知られる典型的なビット操作の問題である。

 

今回はそれを Scala で解くコードをいくつか準備し、その性能を比較してみた。

 

環境

  • CPU: 1.8 GHz Intel Core i5
  • OS: Mac OS X 10.9.4
  • Scala: 2.11.2
  • sbt: 0.13.5
  • sbt-jmh: 0.1.6

 

Scala での実装

 

1ビットずつシフトして数える

いかにもコストが重そうだが、find を使って最初にビットの立った位置を探す方法。
x は負数をとることもあるので、符号なしの右シフト(>>>)が必要な点に注意。

  def ntzNaive(x: Long): Int =
    (0 until 64).find(i => ((x >>> i) & 1L) != 0L).getOrElse(64)

var を使って工夫 (2種類)

  def ntzLoop1(x: Long): Int = {
    var y = ~x & (x - 1)
    var n = 0
    while (y != 0) {n += 1; y >>>= 1}
    n
  }

  def ntzLoop2(x: Long): Int = {
    var y = x
    var n = 64
    while (y != 0) {n -= 1; y <<= 1}
    n
  }
二分探索木を用いた方法 (3種類)
  def ntzBinarySearch1(x: Long): Int = {
    if (x == 0L) {
      64
    } else {
      var n = 1
      var y = x
      if ((y & 0x00000000ffffffffL) == 0L) {n += 32; y >>>= 32}
      if ((y & 0x000000000000ffffL) == 0L) {n += 16; y >>>= 16}
      if ((y & 0x00000000000000ffL) == 0L) {n += 8; y >>>= 8}
      if ((y & 0x000000000000000fL) == 0L) {n += 4; y >>>= 4}
      if ((y & 0x0000000000000003L) == 0L) {n += 2; y >>>= 2}
      n - (y & 1L).toInt
    }
  }

  def ntzBinarySearch2(x: Long): Int = {
    if (x == 0L) {
      64
    } else {
      var n = 63
      var y = 0L
      var z = x << 32; if (z != 0) {n -= 32; y = z}
      z = y << 16; if (z != 0) {n -= 16; y = z}
      z = y << 8; if (z != 0) {n -= 8; y = z}
      z = y << 4; if (z != 0) {n -= 4; y = z}
      z = y << 2; if (z != 0) {n -= 2; y = z}
      z = y << 1; if (z != 0) n -= 1
      n
    }
  }

  def ntzBinarySearch3(x: Long): Int = {
    if (x == 0L)
      64
    else
      if ((x & 0xffffffffL) != 0L)
        if ((x & 0xffffL) != 0L)
          if ((x & 0xffL) != 0L)
            if ((x & 0xfL) != 0L)
              if ((x & 3L) != 0L)
                if ((x & 1L) != 0L) 0 else 1
              else
                if ((x & 4L) != 0L) 2 else 3
            else
              if ((x & 0x30L) != 0L)
                if ((x & 0x10L) != 0L) 4 else 5
              else
                if ((x & 0x40L) != 0L) 6 else 7
          else
            if ((x & 0xf00L) != 0L)
              if ((x & 0x300L) != 0L)
                if ((x & 0x100L) != 0L) 8 else 9
              else
                if ((x & 0x400L) != 0L) 10 else 11
            else
              if ((x & 0x3000L) != 0L)
                if ((x & 0x1000L) != 0L) 12 else 13
              else
                if ((x & 0x4000L) != 0L) 14 else 15
        else
          if ((x & 0xff0000L) != 0L)
            if ((x & 0xf0000L) != 0L)
              if ((x & 30000L) != 0L)
                if ((x & 10000L) != 0L) 16 else 17
              else
                if ((x & 40000L) != 0L) 18 else 19
            else
              if ((x & 0x300000L) != 0L)
                if ((x & 0x100000L) != 0L) 20 else 21
              else
                if ((x & 0x400000L) != 0L) 22 else 23
          else
            if ((x & 0xf000000L) != 0L)
              if ((x & 0x3000000L) != 0L)
                if ((x & 0x1000000L) != 0L) 24 else 25
              else
                if ((x & 0x4000000L) != 0L) 26 else 27
            else
              if ((x & 0x30000000L) != 0L)
                if ((x & 0x10000000L) != 0L) 28 else 29
              else
                if ((x & 0x40000000L) != 0L) 30 else 31
      else
        if ((x & 0xffff00000000L) != 0L)
          if ((x & 0xff00000000L) != 0L)
            if ((x & 0xf00000000L) != 0L)
              if ((x & 0x300000000L) != 0L)
                if ((x & 0x100000000L) != 0L) 32 else 33
              else
                if ((x & 0x400000000L) != 0L) 34 else 35
            else
              if ((x & 0x3000000000L) != 0L)
                if ((x & 0x1000000000L) != 0L) 36 else 37
              else
                if ((x & 0x4000000000L) != 0L) 38 else 39
          else
            if ((x & 0xf0000000000L) != 0L)
              if ((x & 0x30000000000L) != 0L)
                if ((x & 0x10000000000L) != 0L) 40 else 41
              else
                if ((x & 0x40000000000L) != 0L) 42 else 43
            else
              if ((x & 0x300000000000L) != 0L)
                if ((x & 0x100000000000L) != 0L) 44 else 45
              else
                if ((x & 0x400000000000L) != 0L) 46 else 47
        else
          if ((x & 0xff000000000000L) != 0L)
            if ((x & 0xf000000000000L) != 0L)
              if ((x & 3000000000000L) != 0L)
                if ((x & 1000000000000L) != 0L) 48 else 49
              else
                if ((x & 4000000000000L) != 0L) 50 else 51
            else
              if ((x & 0x30000000000000L) != 0L)
                if ((x & 0x10000000000000L) != 0L) 52 else 53
              else
                if ((x & 0x40000000000000L) != 0L) 54 else 55
          else
            if ((x & 0xf00000000000000L) != 0L)
              if ((x & 0x300000000000000L) != 0L)
                if ((x & 0x100000000000000L) != 0L) 56 else 57
              else
                if ((x & 0x400000000000000L) != 0L) 58 else 59
            else
              if ((x & 0x3000000000000000L) != 0L)
                if ((x & 0x1000000000000000L) != 0L) 60 else 61
              else
                if ((x & 0x4000000000000000L) != 0L) 62 else 63
  }
対数計算を利用
  private[this] val ln2 = math.log(2)

  def ntzLogarithm(x: Long): Int =
    if (x == 0L)
      64
    else if (x == 0x8000000000000000L)
      63
    else
      (math.log(x & -x) / ln2).toInt
ハッシュを利用 (黒魔術)
  private[this] val ntzHash = 0x03F566ED27179461L
  private[this] lazy val ntzTable: Array[Int] =
    (0 until 64).map(i => (ntzHash << i >>> 58).toInt -> i).sorted.map(_._2).toArray

  def ntzHash(x: Long): Int =
    if (x != 0L) ntzTable((((x & -x) * ntzHash) >>> 58).toInt) else 64

 

ネイティブでの実装

C++ のインラインアセンブリを使って x86-64の命令を直接呼び出す。
それを JNI/JNA でラップして Scala から呼び出した。

JNI (Java Native Interface)
  • C++ コード
#include <jni.h>

extern "C" JNIEXPORT jint JNICALL Java_com_github_mogproject_util_BitOperation_ntzJni
  (JNIEnv *env, jobject obj, jlong x) {
  if (x == 0) return 64;
  int pos = 0;
  __asm__(
    "bsf %1,%q0;"
    :"=r"(pos)
    :"rm"(x));
  return pos;
}
  • ビルド
$ g++ -dynamiclib -O3 \
-I/usr/include -I$JAVA_HOME/include -I$JAVA_HOME/include/darwin \
bitops_jni.cpp -o libbitops_jni.dylib
  • Scala コード
package com.github.mogproject.util

class BitOperation {
  @native def ntzJni(x: Long): Int
}

object BitOperation {
  System.load("/path/to/libbitops_jni.dylib")

  private[this] val bitop = new BitOperation

  def ntzJni(x: => Long) = bitop.ntzJni(x)
}
JNA (Java Native Access)
  • C++ コード
extern "C" int ntz(unsigned long long x) {
  if (x == 0) return 64;
  int pos = 0;
  __asm__(
    "bsf %1,%q0;"
    :"=r"(pos)
    :"rm"(x));
  return pos;
}
  • ビルド
$ g++ -dynamiclib -O3 bitops_jna.cpp -o libbitops_jna.dylib
  • Scala コード
  private[this] val lib =
    com.sun.jna.NativeLibrary.getInstance("/path/to/libbitops_jna.dylib")

  private[this] val func = lib.getFunction("ntz")

  def ntzJna(x: Long): Int = func.invokeInt(Array(x.asInstanceOf[Object]))

 

テスト

とりあえず REPL でランダムなLong型整数を与え、全てのメソッドの結果が一致することを確認。

scala> import com.github.mogproject.util.BitOperation._
import com.github.mogproject.util.BitOperation._

scala> def f(x: Long) = Seq(ntzNaive(x), ntzLoop1(x), ntzLoop2(x),
     | ntzBinarySearch1(x), ntzBinarySearch2(x), ntzBinarySearch3(x), ntzLogarithm(x),
     | ntzHash(x), ntzJni(x), ntzJna(x))
f: (x: Long)Seq[Int]

scala> f(0)
res0: Seq[Int] = List(64, 64, 64, 64, 64, 64, 64, 64, 64, 64)

scala> f(-1)
res1: Seq[Int] = List(0, 0, 0, 0, 0, 0, 0, 0, 0, 0)

scala> f(0x8000000000000000L)
res2: Seq[Int] = List(63, 63, 63, 63, 63, 63, 63, 63, 63, 63)

scala> Seq.fill(1000)(scala.util.Random.nextLong) forall (f(_).toSet.size == 1)
res3: Boolean = true

 

ベンチマーク

sbt-jmh でベンチマークを取る。

import com.github.mogproject.util.BitOperation
import org.openjdk.jmh.annotations.{Benchmark, Scope, State => JmhState}

import scala.util.Random


@JmhState(Scope.Thread)
class BitOperationBench {
  val random = new Random(12345)

  @Benchmark def ntzNaive(): Unit = BitOperation.ntzNaive(random.nextLong())

  @Benchmark def ntzLoop1(): Unit = BitOperation.ntzLoop1(random.nextLong())
  @Benchmark def ntzLoop2(): Unit = BitOperation.ntzLoop2(random.nextLong())

  @Benchmark def ntzBinarySearch1(): Unit = BitOperation.ntzBinarySearch1(random.nextLong())
  @Benchmark def ntzBinarySearch2(): Unit = BitOperation.ntzBinarySearch2(random.nextLong())
  @Benchmark def ntzBinarySearch3(): Unit = BitOperation.ntzBinarySearch3(random.nextLong())

  @Benchmark def ntzLogarithm(): Unit = BitOperation.ntzLogarithm(random.nextLong())
  @Benchmark def ntzHash(): Unit = BitOperation.ntzHash(random.nextLong())

  @Benchmark def ntzJni(): Unit = BitOperation.ntzJni(random.nextLong())
  @Benchmark def ntzJna(): Unit = BitOperation.ntzJna(random.nextLong())
}
測定結果
[info] Running org.openjdk.jmh.Main -i 10 -wi 10 -f1 -t1 .*BitOperation.*

(snip)

[info] Benchmark                                          Mode  Samples         Score  Score error  Units
[info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch1    thrpt       10  35184555.294   147974.730  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch2    thrpt       10  35130401.937   234705.330  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzBinarySearch3    thrpt       10  29846819.149   352975.842  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzHash             thrpt       10  35127105.013   127143.455  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzJna              thrpt       10   1396115.477    16139.308  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzJni              thrpt       10  25644475.628   231921.707  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzLogarithm        thrpt       10  35144781.544    92927.496  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzLoop1            thrpt       10  30099013.164   308642.730  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzLoop2            thrpt       10  15183889.536   142090.869  ops/s
[info] c.g.m.m.u.b.BitOperationBench.ntzNaive            thrpt       10  14309678.759   317520.811  ops/s

飛び抜けて良いものがある、という訳ではなく
BinarySearch1, BinarySearch2, Logarithm, Hash が先頭グループとしてほぼ同等の成績。

JNI はやはり関数コールのペナルティが大きいのか、実用には厳しいレベル。
JNA は一つだけ桁違いに性能が悪かった。

やはり計測あるのみ。
"Trust no one, bench everything." を実感した。

 

References

 

 

Related Posts

10.05.2014

Micro Benchmark in Scala - Using sbt-jmh

Scala でマイクロベンチマーク - sbt-jmh を使ってみる

 

sbt-jmh は、sbt コンソールの中で、Java のマイクロベンチマークツールである jmh を扱えるようにする
sbt プラグインである。

Scala のマイクロベンチマークにデファクトスタンダードが存在するかどうかは分からないが、

といった経緯もあり、今回触れてみることにした。

 

環境
  • OS: Mac OS X 10.9.4
  • Scala: 2.11.2
  • sbt: 0.13.5
  • sbt-jmh: 0.1.6

 

sbt 定義ファイルの設定

まずは sbt の各種設定ファイルにプラグインの追加設定を行う。

  • plugins.sbt
    addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.1.6")
  • build.sbt (利用時のみ)
    jmhSettings
  • Build.scala (利用時のみ/一例)
    import pl.project13.scala.sbt.SbtJmh._
    
    sbt.Project(...).settings(jmhSettings: _*)
    

 

ベンチマーク対象コードの記述

src/main/scala 配下の任意の .scala ファイルに適当なクラスを作って(objectではダメ)、
測定したいメソッドにアノテーション @org.openjdk.jmh.annotations.Benchmark を付けるだけでよい。

  • 例: 数値の Set/List を作成し、contains メソッドを実行するまでの所要時間を計測
    import org.openjdk.jmh.annotations.Benchmark
    
    class ContainsBench {
      @Benchmark
      def setContains(): Unit = (1 to 100000).toSet.contains(100001)
    
      @Benchmark
      def listContains(): Unit = (1 to 100000).toList.contains(100001)
    }
  • クラス内で状態を保持したり、パラメータ付きのメソッドを定義したい場合は
    @State など他のアノテーションの利用が必要
    import org.openjdk.jmh.annotations.{Benchmark, Scope, State}
    
    @State(Scope.Thread)
    class ContainsBench {
      val xs = (1 to 100000).toSet
      val ys = (1 to 100000).toList
    
      @Benchmark
      def setContains(): Unit = xs.contains(100001)
    
      @Benchmark
      def listContains(): Unit = ys.contains(100001)
    }
    

 

ベンチマークの実行

コードの準備が終わったら、sbt を起動。(マルチプロジェクトの場合は対象プロジェクトへ移動)

run -l でベンチマーク一覧、run -h でヘルプが表示される (オプションは非常に豊富)。

  • 全てのベンチマークの実行
    run -i 3 -wi 3 -f1 -t1

    -i でイテレーション回数、
    -wi でウォームアップイテレーション(測定前に実行される繰り返し)回数を指定。
    正確な測定を行うためには、それぞれ最低でも10〜20回を指定すべきとのこと。

    -f はフォークする数の指定。この回数だけウォームアップ+実測が繰り返される。
    -t はスレッド数。とりあえずは 1 を指定すれば良さそうだ。

    また、いくつかの測定モードが用意されているが、デフォルトではスループット計測モードとなる。
  • 特定のベンチマークの実行
    公式ドキュメントに載っていた例。ワイルドカードを使えるようだ。
    run -i 3 -wi 3 -f1 -t1 .*FalseSharing.*
  • 実行結果の例
    [info] Running org.openjdk.jmh.Main -i 3 -wi 3 -f1 -t1
    [info] # VM invoker: /Library/Java/JavaVirtualMachines/1.7.0.jdk/Contents/Home/jre/bin/java
    [info] # VM options: <none>
    [info] # Warmup: 3 iterations, 1 s each
    [info] # Measurement: 3 iterations, 1 s each
    [info] # Timeout: 10 min per iteration
    [info] # Threads: 1 thread, will synchronize iterations
    [info] # Benchmark mode: Throughput, ops/time
    [info] # Benchmark: com.github.mogproject.util.ContainsBench.listContains
    [info]
    [info] # Run progress: 0.00% complete, ETA 00:00:12
    [info] # Fork: 1 of 1
    [info] # Warmup Iteration   1: 35.322 ops/s
    [info] # Warmup Iteration   2: 40.904 ops/s
    [info] # Warmup Iteration   3: 46.665 ops/s
    [info] Iteration   1: 39.450 ops/s
    [info] Iteration   2: 42.116 ops/s
    [info] Iteration   3: 41.535 ops/s
    [info]
    [info]
    [info] Result: 41.033 ±(99.9%) 25.573 ops/s [Average]
    [info]   Statistics: (min, avg, max) = (39.450, 41.033, 42.116), stdev = 1.402
    [info]   Confidence interval (99.9%): [15.461, 66.606]
    [info]
    [info]
    [info] # VM invoker: /Library/Java/JavaVirtualMachines/1.7.0.jdk/Contents/Home/jre/bin/java
    [info] # VM options: <none>
    [info] # Warmup: 3 iterations, 1 s each
    [info] # Measurement: 3 iterations, 1 s each
    [info] # Timeout: 10 min per iteration
    [info] # Threads: 1 thread, will synchronize iterations
    [info] # Benchmark mode: Throughput, ops/time
    [info] # Benchmark: com.github.mogproject.util.ContainsBench.setContains
    [info]
    [info] # Run progress: 50.00% complete, ETA 00:00:07
    [info] # Fork: 1 of 1
    [info] # Warmup Iteration   1: 4.550 ops/s
    [info] # Warmup Iteration   2: 6.826 ops/s
    [info] # Warmup Iteration   3: 6.980 ops/s
    [info] Iteration   1: 6.730 ops/s
    [info] Iteration   2: 6.901 ops/s
    [info] Iteration   3: 6.798 ops/s
    [info]
    [info]
    [info] Result: 6.810 ±(99.9%) 1.569 ops/s [Average]
    [info]   Statistics: (min, avg, max) = (6.730, 6.810, 6.901), stdev = 0.086
    [info]   Confidence interval (99.9%): [5.241, 8.380]
    [info]
    [info]
    [info] # Run complete. Total time: 00:00:15
    [info]
    [info] Benchmark                              Mode  Samples   Score  Score error  Units
    [info] c.g.m.u.ContainsBench.listContains    thrpt        3  41.033       25.573  ops/s
    [info] c.g.m.u.ContainsBench.setContains     thrpt        3   6.810        1.569  ops/s
    
    contains メソッド自体は Set のほうが圧倒的に速いものの、toSet のコストが大きいため List バージョンのほうが良いスループットが出ることがわかった。

より高度な使い方は公式ドキュメントや以下を参照。

 

 

References

10.01.2014

Scala: How to define initial commands of REPL in SBT

Scala: SBTのREPL起動時に自動で実行されるコマンドを定義する

 

sbt console (マルチプロジェクトの場合は sbt project_name/console)を実行すると
プロジェクトの実行環境が完全にセットアップされた状態で REPL が立ち上がる。

このとき sbt の Key である initialCommands を定義すれば、
その内容がREPL初期化時に自動的に実行されるようになる。

毎回手で打つと面倒な import 文や、ちょっと試すためのデータなどを定義すると非常に捗る。

initialCommands in console := "import your.package.name._"

 

参考: initialCommands を駆使した例