9.19.2015

Python: How to Implement Thread-Safe Auto-Increment in Redis

Python: Redis上でスレッドセーフなオート・インクリメントを実現する

 

目的

Redisで以下のような3種類のDBを使い、色々な名前の登録処理をしたい。

  • 名前 -> ID の検索テーブル
  • ID -> 名前 の検索テーブル
  • 特定の key で ID の登録数 (=払い出したIDの最大値) を保持するテーブル

ID は 1 を起点とするオート・インクリメントなもので、名前ごとにユニークな数値を割り当てる。

 

インターフェース

1個の Redis インスタンスの中にある 3つの DBを使用する想定。それぞれのDB番号とカウンターとして使用するキーの名前を指定して初期化。

register メソッドは文字列として name を受け取り、その name に対応するユニークな ID を返す。name が DB に登録されていない場合のみ、登録処理を行う。

import redis


class Registerer(object):
    def __init__(self, counter_key, db_counter, db, db_invert=None,
                 host='localhost', port=6379):
        self.counter_key = counter_key
        self.db_counter = db_counter
        self.db = db
        self.db_invert = db_invert
        self.host = host
        self.port = port
        self.redis = redis.Redis(host, port, db)
        self.redis_cnt = redis.Redis(host, port, db_counter)
        self.redis_inv = (redis.Redis(host, port, db_invert)
                          if db_invert is not None else None)

    def register(self, name):
        return ???

 

テスト

マルチスレッド・プログラミングをする時はテストがないと不安なので、先にテストを書く。
(実行すると、ローカルの Redis のデータは全て消える)

#!/usr/bin/env python
import unittest
import threading

from repos.registerer import Registerer


class TestRegisterer(unittest.TestCase):
    def _clear(self):
        self.r.redis.flushall()

    def setUp(self):
        self.num_records = 10000
        self.r = Registerer('count', 15, 14, 13)
        self._clear()

    def tearDown(self):
        self._clear()

    def test_register_serial(self):
        r = self.r
        for i in range(self.num_records):
            r.register('user-%04d' % i)

        for i in range(10):
            r.register('user-%04d' % i)

        self.assertEqual(r.redis.dbsize(), self.num_records)
        self.assertEqual(r.redis_inv.dbsize(), self.num_records)
        self.assertEqual(int(r.redis_cnt.get('count')), self.num_records)

        for i in range(100):
            self.assertEqual(int(r.redis.get(r.redis_inv.get(i * 50 + 1))), i * 50 + 1)
            self.assertEqual(r.redis_inv.get(i * 50 + 1), 'user-%04d' % (i * 50))

    def test_register_parallel(self):
        r = self.r
        threads = [threading.Thread(target=r.register, args=('user-%04d' % i,)) for i in range(10000)]

        for t in threads:
            t.start()

        for t in threads:
            t.join()

        self.assertEqual(r.redis.dbsize(), self.num_records)
        self.assertEqual(r.redis_inv.dbsize(), self.num_records)
        self.assertEqual(int(r.redis_cnt.get('count')), self.num_records)

        for i in range(100):
            self.assertEqual(int(r.redis.get(r.redis_inv.get(i * 50 + 1))), i * 50 + 1)

            # this will fail
            # self.assertEqual(r.invert_redis.get(i * 50 + 1), 'user-%04d' % (i * 50))

    def test_register_parallel_same_id(self):
        r = self.r
        threads = [threading.Thread(target=r.register, args=('user-0000',)) for _ in range(self.num_records)]

        for t in threads:
            t.start()

        for t in threads:
            t.join()

        self.assertEqual(r.redis.dbsize(), 1)
        self.assertEqual(r.redis_inv.dbsize(), 1)
        self.assertEqual(int(r.redis_cnt.get('count')), 1)
        self.assertEqual(int(r.redis.get('user-0000')), 1)
        self.assertEqual(r.redis_inv.get(1), 'user-0000')


if __name__ == '__main__':
    unittest.main()

 

ナイーブな実装 (問題あり)

名前 -> ID のテーブルを検索し、値が取得できなかったら Redis の INCR コマンドで値を更新し、その ID を正引き/逆引きのテーブルそれぞれ登録する。

念の為に SETNX コマンドを利用し、キーが既に存在していたら例外を送出するようにする。

    def register_naive(self, name):
        index = self.redis.get(name)
        if index is None:
            index = self.redis_cnt.incr(self.counter_key)
            if not self.redis.setnx(name, index):
                raise Exception(
                    'Failed to register: db=%d, key=%s, value=%s' %
                    (self.db, name, index))
            if self.db_invert is not None:
                if not self.redis_inv.setnx(index, name):
                    raise Exception(
                        'Failed to register: db=%d, key=%s, value=%s' %
                        (self.db_invert, index, name))
        return int(index)

Redis の INCR はアトミックな処理であるし、一見問題はなさそうだが、先程書いたテストが失敗する。

.Exception in thread Thread-10001:
Traceback (most recent call last):
  File "/Users/xxxxxx/.pyenv/versions/2.7.9/lib/python2.7/threading.py", line 810, in __bootstrap_inner
    self.run()
  File "/Users/xxxxxx/.pyenv/versions/2.7.9/lib/python2.7/threading.py", line 763, in run
    self.__target(*self.__args, **self.__kwargs)
  File "/xxxxxx/registerer.py", line 36, in register_naive
    raise Exception('Failed to register: db=%d, key=%s, value=%s' % (self.db, name, index))
Exception: Failed to register: db=14, key=user-0000, value=1

F.
======================================================================
FAIL: test_register_parallel_same_id (__main__.TestRegisterer)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "./tests/test_registerer.py", line 72, in test_register_parallel_same_id
    self.assertEqual(int(r.redis_cnt.get('count')), 1)
AssertionError: 2 != 1

----------------------------------------------------------------------
Ran 3 tests in 9.492s

FAILED (failures=1)

これは、同じ名前を同時に大量に登録した場合にのみ発生する。
名前 -> IDテーブルのチェックと更新の間のタイミングで競合状態が発生し、同じ名前の登録処理が複数実行されているためである。

 

トランザクションを利用した実装

redis-py には transaction というお誂え向きのヘルパーメソッドが用意されている。

パイプラインを引数に取る関数と、楽観的ロックをかけるキーを指定するだけでシンプルにトランザクションを記述できる。

最終的なコードは以下のようになった。

import redis


class Registerer(object):
    def __init__(self, counter_key, db_counter, db, db_invert=None,
                 host='localhost', port=6379):
        self.counter_key = counter_key
        self.db_counter = db_counter
        self.db = db
        self.db_invert = db_invert
        self.host = host
        self.port = port
        self.redis = redis.Redis(host, port, db)
        self.redis_cnt = redis.Redis(host, port, db_counter)
        self.redis_inv = (redis.Redis(host, port, db_invert)
                          if db_invert is not None else None)

    def register(self, name):
        def f(pipe):
            index = self.redis.get(name)
            if index is None:
                index = pipe.incr(self.counter_key)
                if not self.redis.setnx(name, index):
                    raise Exception(
                        'Failed to register: db=%d, key=%s, value=%s' %
                        (self.db, name, index))
                if self.db_invert is not None:
                    if not self.redis_inv.setnx(index, name):
                        raise Exception(
                            'Failed to register: db=%d, key=%s, value=%s' %
                            (self.db_invert, index, name))
            return int(index)

        return self.redis_cnt.transaction(
            f, self.counter_key, value_from_callable=True)

 

無事、テストも通った。

...
----------------------------------------------------------------------
Ran 3 tests in 13.818s

OK

 

References

0 件のコメント:

コメントを投稿