Python: Redis上でスレッドセーフなオート・インクリメントを実現する
目的
Redisで以下のような3種類のDBを使い、色々な名前の登録処理をしたい。
- 名前 -> ID の検索テーブル
- ID -> 名前 の検索テーブル
- 特定の key で ID の登録数 (=払い出したIDの最大値) を保持するテーブル
ID は 1 を起点とするオート・インクリメントなもので、名前ごとにユニークな数値を割り当てる。
インターフェース
1個の Redis インスタンスの中にある 3つの DBを使用する想定。それぞれのDB番号とカウンターとして使用するキーの名前を指定して初期化。
register メソッドは文字列として name を受け取り、その name に対応するユニークな ID を返す。name が DB に登録されていない場合のみ、登録処理を行う。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | import redis class Registerer( object ): def __init__( self , counter_key, db_counter, db, db_invert = None , host = 'localhost' , port = 6379 ): self .counter_key = counter_key self .db_counter = db_counter self .db = db self .db_invert = db_invert self .host = host self .port = port self .redis = redis.Redis(host, port, db) self .redis_cnt = redis.Redis(host, port, db_counter) self .redis_inv = (redis.Redis(host, port, db_invert) if db_invert is not None else None ) def register( self , name): return ??? |
テスト
マルチスレッド・プログラミングをする時はテストがないと不安なので、先にテストを書く。
(実行すると、ローカルの Redis のデータは全て消える)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 | #!/usr/bin/env python import unittest import threading from repos.registerer import Registerer class TestRegisterer(unittest.TestCase): def _clear( self ): self .r.redis.flushall() def setUp( self ): self .num_records = 10000 self .r = Registerer( 'count' , 15 , 14 , 13 ) self ._clear() def tearDown( self ): self ._clear() def test_register_serial( self ): r = self .r for i in range ( self .num_records): r.register( 'user-%04d' % i) for i in range ( 10 ): r.register( 'user-%04d' % i) self .assertEqual(r.redis.dbsize(), self .num_records) self .assertEqual(r.redis_inv.dbsize(), self .num_records) self .assertEqual( int (r.redis_cnt.get( 'count' )), self .num_records) for i in range ( 100 ): self .assertEqual( int (r.redis.get(r.redis_inv.get(i * 50 + 1 ))), i * 50 + 1 ) self .assertEqual(r.redis_inv.get(i * 50 + 1 ), 'user-%04d' % (i * 50 )) def test_register_parallel( self ): r = self .r threads = [threading.Thread(target = r.register, args = ( 'user-%04d' % i,)) for i in range ( 10000 )] for t in threads: t.start() for t in threads: t.join() self .assertEqual(r.redis.dbsize(), self .num_records) self .assertEqual(r.redis_inv.dbsize(), self .num_records) self .assertEqual( int (r.redis_cnt.get( 'count' )), self .num_records) for i in range ( 100 ): self .assertEqual( int (r.redis.get(r.redis_inv.get(i * 50 + 1 ))), i * 50 + 1 ) # this will fail # self.assertEqual(r.invert_redis.get(i * 50 + 1), 'user-%04d' % (i * 50)) def test_register_parallel_same_id( self ): r = self .r threads = [threading.Thread(target = r.register, args = ( 'user-0000' ,)) for _ in range ( self .num_records)] for t in threads: t.start() for t in threads: t.join() self .assertEqual(r.redis.dbsize(), 1 ) self .assertEqual(r.redis_inv.dbsize(), 1 ) self .assertEqual( int (r.redis_cnt.get( 'count' )), 1 ) self .assertEqual( int (r.redis.get( 'user-0000' )), 1 ) self .assertEqual(r.redis_inv.get( 1 ), 'user-0000' ) if __name__ = = '__main__' : unittest.main() |
ナイーブな実装 (問題あり)
名前 -> ID のテーブルを検索し、値が取得できなかったら Redis の INCR コマンドで値を更新し、その ID を正引き/逆引きのテーブルそれぞれ登録する。
念の為に SETNX コマンドを利用し、キーが既に存在していたら例外を送出するようにする。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | def register_naive( self , name): index = self .redis.get(name) if index is None : index = self .redis_cnt.incr( self .counter_key) if not self .redis.setnx(name, index): raise Exception( 'Failed to register: db=%d, key=%s, value=%s' % ( self .db, name, index)) if self .db_invert is not None : if not self .redis_inv.setnx(index, name): raise Exception( 'Failed to register: db=%d, key=%s, value=%s' % ( self .db_invert, index, name)) return int (index) |
Redis の INCR はアトミックな処理であるし、一見問題はなさそうだが、先程書いたテストが失敗する。
.Exception in thread Thread-10001: Traceback (most recent call last): File "/Users/xxxxxx/.pyenv/versions/2.7.9/lib/python2.7/threading.py", line 810, in __bootstrap_inner self.run() File "/Users/xxxxxx/.pyenv/versions/2.7.9/lib/python2.7/threading.py", line 763, in run self.__target(*self.__args, **self.__kwargs) File "/xxxxxx/registerer.py", line 36, in register_naive raise Exception('Failed to register: db=%d, key=%s, value=%s' % (self.db, name, index)) Exception: Failed to register: db=14, key=user-0000, value=1 F. ====================================================================== FAIL: test_register_parallel_same_id (__main__.TestRegisterer) ---------------------------------------------------------------------- Traceback (most recent call last): File "./tests/test_registerer.py", line 72, in test_register_parallel_same_id self.assertEqual(int(r.redis_cnt.get('count')), 1) AssertionError: 2 != 1 ---------------------------------------------------------------------- Ran 3 tests in 9.492s FAILED (failures=1) |
これは、同じ名前を同時に大量に登録した場合にのみ発生する。
名前 -> IDテーブルのチェックと更新の間のタイミングで競合状態が発生し、同じ名前の登録処理が複数実行されているためである。
トランザクションを利用した実装
redis-py には transaction というお誂え向きのヘルパーメソッドが用意されている。
パイプラインを引数に取る関数と、楽観的ロックをかけるキーを指定するだけでシンプルにトランザクションを記述できる。
最終的なコードは以下のようになった。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 | import redis class Registerer( object ): def __init__( self , counter_key, db_counter, db, db_invert = None , host = 'localhost' , port = 6379 ): self .counter_key = counter_key self .db_counter = db_counter self .db = db self .db_invert = db_invert self .host = host self .port = port self .redis = redis.Redis(host, port, db) self .redis_cnt = redis.Redis(host, port, db_counter) self .redis_inv = (redis.Redis(host, port, db_invert) if db_invert is not None else None ) def register( self , name): def f(pipe): index = self .redis.get(name) if index is None : index = pipe.incr( self .counter_key) if not self .redis.setnx(name, index): raise Exception( 'Failed to register: db=%d, key=%s, value=%s' % ( self .db, name, index)) if self .db_invert is not None : if not self .redis_inv.setnx(index, name): raise Exception( 'Failed to register: db=%d, key=%s, value=%s' % ( self .db_invert, index, name)) return int (index) return self .redis_cnt.transaction( f, self .counter_key, value_from_callable = True ) |
無事、テストも通った。
... ---------------------------------------------------------------------- Ran 3 tests in 13.818s OK |