9.19.2015

Python: How to Implement Thread-Safe Auto-Increment in Redis

Python: Redis上でスレッドセーフなオート・インクリメントを実現する

 

目的

Redisで以下のような3種類のDBを使い、色々な名前の登録処理をしたい。

  • 名前 -> ID の検索テーブル
  • ID -> 名前 の検索テーブル
  • 特定の key で ID の登録数 (=払い出したIDの最大値) を保持するテーブル

ID は 1 を起点とするオート・インクリメントなもので、名前ごとにユニークな数値を割り当てる。

 

インターフェース

1個の Redis インスタンスの中にある 3つの DBを使用する想定。それぞれのDB番号とカウンターとして使用するキーの名前を指定して初期化。

register メソッドは文字列として name を受け取り、その name に対応するユニークな ID を返す。name が DB に登録されていない場合のみ、登録処理を行う。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import redis
 
 
class Registerer(object):
    def __init__(self, counter_key, db_counter, db, db_invert=None,
                 host='localhost', port=6379):
        self.counter_key = counter_key
        self.db_counter = db_counter
        self.db = db
        self.db_invert = db_invert
        self.host = host
        self.port = port
        self.redis = redis.Redis(host, port, db)
        self.redis_cnt = redis.Redis(host, port, db_counter)
        self.redis_inv = (redis.Redis(host, port, db_invert)
                          if db_invert is not None else None)
 
    def register(self, name):
        return ???

 

テスト

マルチスレッド・プログラミングをする時はテストがないと不安なので、先にテストを書く。
(実行すると、ローカルの Redis のデータは全て消える)

test_registerer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#!/usr/bin/env python
import unittest
import threading
 
from repos.registerer import Registerer
 
 
class TestRegisterer(unittest.TestCase):
    def _clear(self):
        self.r.redis.flushall()
 
    def setUp(self):
        self.num_records = 10000
        self.r = Registerer('count', 15, 14, 13)
        self._clear()
 
    def tearDown(self):
        self._clear()
 
    def test_register_serial(self):
        r = self.r
        for i in range(self.num_records):
            r.register('user-%04d' % i)
 
        for i in range(10):
            r.register('user-%04d' % i)
 
        self.assertEqual(r.redis.dbsize(), self.num_records)
        self.assertEqual(r.redis_inv.dbsize(), self.num_records)
        self.assertEqual(int(r.redis_cnt.get('count')), self.num_records)
 
        for i in range(100):
            self.assertEqual(int(r.redis.get(r.redis_inv.get(i * 50 + 1))), i * 50 + 1)
            self.assertEqual(r.redis_inv.get(i * 50 + 1), 'user-%04d' % (i * 50))
 
    def test_register_parallel(self):
        r = self.r
        threads = [threading.Thread(target=r.register, args=('user-%04d' % i,)) for i in range(10000)]
 
        for t in threads:
            t.start()
 
        for t in threads:
            t.join()
 
        self.assertEqual(r.redis.dbsize(), self.num_records)
        self.assertEqual(r.redis_inv.dbsize(), self.num_records)
        self.assertEqual(int(r.redis_cnt.get('count')), self.num_records)
 
        for i in range(100):
            self.assertEqual(int(r.redis.get(r.redis_inv.get(i * 50 + 1))), i * 50 + 1)
 
            # this will fail
            # self.assertEqual(r.invert_redis.get(i * 50 + 1), 'user-%04d' % (i * 50))
 
    def test_register_parallel_same_id(self):
        r = self.r
        threads = [threading.Thread(target=r.register, args=('user-0000',)) for _ in range(self.num_records)]
 
        for t in threads:
            t.start()
 
        for t in threads:
            t.join()
 
        self.assertEqual(r.redis.dbsize(), 1)
        self.assertEqual(r.redis_inv.dbsize(), 1)
        self.assertEqual(int(r.redis_cnt.get('count')), 1)
        self.assertEqual(int(r.redis.get('user-0000')), 1)
        self.assertEqual(r.redis_inv.get(1), 'user-0000')
 
 
if __name__ == '__main__':
    unittest.main()

 

ナイーブな実装 (問題あり)

名前 -> ID のテーブルを検索し、値が取得できなかったら Redis の INCR コマンドで値を更新し、その ID を正引き/逆引きのテーブルそれぞれ登録する。

念の為に SETNX コマンドを利用し、キーが既に存在していたら例外を送出するようにする。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def register_naive(self, name):
    index = self.redis.get(name)
    if index is None:
        index = self.redis_cnt.incr(self.counter_key)
        if not self.redis.setnx(name, index):
            raise Exception(
                'Failed to register: db=%d, key=%s, value=%s' %
                (self.db, name, index))
        if self.db_invert is not None:
            if not self.redis_inv.setnx(index, name):
                raise Exception(
                    'Failed to register: db=%d, key=%s, value=%s' %
                    (self.db_invert, index, name))
    return int(index)

Redis の INCR はアトミックな処理であるし、一見問題はなさそうだが、先程書いたテストが失敗する。

.Exception in thread Thread-10001:
Traceback (most recent call last):
  File "/Users/xxxxxx/.pyenv/versions/2.7.9/lib/python2.7/threading.py", line 810, in __bootstrap_inner
    self.run()
  File "/Users/xxxxxx/.pyenv/versions/2.7.9/lib/python2.7/threading.py", line 763, in run
    self.__target(*self.__args, **self.__kwargs)
  File "/xxxxxx/registerer.py", line 36, in register_naive
    raise Exception('Failed to register: db=%d, key=%s, value=%s' % (self.db, name, index))
Exception: Failed to register: db=14, key=user-0000, value=1
 
F.
======================================================================
FAIL: test_register_parallel_same_id (__main__.TestRegisterer)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "./tests/test_registerer.py", line 72, in test_register_parallel_same_id
    self.assertEqual(int(r.redis_cnt.get('count')), 1)
AssertionError: 2 != 1
 
----------------------------------------------------------------------
Ran 3 tests in 9.492s
 
FAILED (failures=1)

これは、同じ名前を同時に大量に登録した場合にのみ発生する。
名前 -> IDテーブルのチェックと更新の間のタイミングで競合状態が発生し、同じ名前の登録処理が複数実行されているためである。

 

トランザクションを利用した実装

redis-py には transaction というお誂え向きのヘルパーメソッドが用意されている。

パイプラインを引数に取る関数と、楽観的ロックをかけるキーを指定するだけでシンプルにトランザクションを記述できる。

最終的なコードは以下のようになった。

registerer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import redis
 
 
class Registerer(object):
    def __init__(self, counter_key, db_counter, db, db_invert=None,
                 host='localhost', port=6379):
        self.counter_key = counter_key
        self.db_counter = db_counter
        self.db = db
        self.db_invert = db_invert
        self.host = host
        self.port = port
        self.redis = redis.Redis(host, port, db)
        self.redis_cnt = redis.Redis(host, port, db_counter)
        self.redis_inv = (redis.Redis(host, port, db_invert)
                          if db_invert is not None else None)
 
    def register(self, name):
        def f(pipe):
            index = self.redis.get(name)
            if index is None:
                index = pipe.incr(self.counter_key)
                if not self.redis.setnx(name, index):
                    raise Exception(
                        'Failed to register: db=%d, key=%s, value=%s' %
                        (self.db, name, index))
                if self.db_invert is not None:
                    if not self.redis_inv.setnx(index, name):
                        raise Exception(
                            'Failed to register: db=%d, key=%s, value=%s' %
                            (self.db_invert, index, name))
            return int(index)
 
        return self.redis_cnt.transaction(
            f, self.counter_key, value_from_callable=True)

 

無事、テストも通った。

...
----------------------------------------------------------------------
Ran 3 tests in 13.818s
 
OK

 

References

0 件のコメント:

コメントを投稿