Python: Redis上でスレッドセーフなオート・インクリメントを実現する
目的
Redisで以下のような3種類のDBを使い、色々な名前の登録処理をしたい。
- 名前 -> ID の検索テーブル
- ID -> 名前 の検索テーブル
- 特定の key で ID の登録数 (=払い出したIDの最大値) を保持するテーブル
ID は 1 を起点とするオート・インクリメントなもので、名前ごとにユニークな数値を割り当てる。
インターフェース
1個の Redis インスタンスの中にある 3つの DBを使用する想定。それぞれのDB番号とカウンターとして使用するキーの名前を指定して初期化。
register メソッドは文字列として name を受け取り、その name に対応するユニークな ID を返す。name が DB に登録されていない場合のみ、登録処理を行う。
import redis
class Registerer(object):
def __init__(self, counter_key, db_counter, db, db_invert=None,
host='localhost', port=6379):
self.counter_key = counter_key
self.db_counter = db_counter
self.db = db
self.db_invert = db_invert
self.host = host
self.port = port
self.redis = redis.Redis(host, port, db)
self.redis_cnt = redis.Redis(host, port, db_counter)
self.redis_inv = (redis.Redis(host, port, db_invert)
if db_invert is not None else None)
def register(self, name):
return ???
テスト
マルチスレッド・プログラミングをする時はテストがないと不安なので、先にテストを書く。
(実行すると、ローカルの Redis のデータは全て消える)
#!/usr/bin/env python
import unittest
import threading
from repos.registerer import Registerer
class TestRegisterer(unittest.TestCase):
def _clear(self):
self.r.redis.flushall()
def setUp(self):
self.num_records = 10000
self.r = Registerer('count', 15, 14, 13)
self._clear()
def tearDown(self):
self._clear()
def test_register_serial(self):
r = self.r
for i in range(self.num_records):
r.register('user-%04d' % i)
for i in range(10):
r.register('user-%04d' % i)
self.assertEqual(r.redis.dbsize(), self.num_records)
self.assertEqual(r.redis_inv.dbsize(), self.num_records)
self.assertEqual(int(r.redis_cnt.get('count')), self.num_records)
for i in range(100):
self.assertEqual(int(r.redis.get(r.redis_inv.get(i * 50 + 1))), i * 50 + 1)
self.assertEqual(r.redis_inv.get(i * 50 + 1), 'user-%04d' % (i * 50))
def test_register_parallel(self):
r = self.r
threads = [threading.Thread(target=r.register, args=('user-%04d' % i,)) for i in range(10000)]
for t in threads:
t.start()
for t in threads:
t.join()
self.assertEqual(r.redis.dbsize(), self.num_records)
self.assertEqual(r.redis_inv.dbsize(), self.num_records)
self.assertEqual(int(r.redis_cnt.get('count')), self.num_records)
for i in range(100):
self.assertEqual(int(r.redis.get(r.redis_inv.get(i * 50 + 1))), i * 50 + 1)
# this will fail
# self.assertEqual(r.invert_redis.get(i * 50 + 1), 'user-%04d' % (i * 50))
def test_register_parallel_same_id(self):
r = self.r
threads = [threading.Thread(target=r.register, args=('user-0000',)) for _ in range(self.num_records)]
for t in threads:
t.start()
for t in threads:
t.join()
self.assertEqual(r.redis.dbsize(), 1)
self.assertEqual(r.redis_inv.dbsize(), 1)
self.assertEqual(int(r.redis_cnt.get('count')), 1)
self.assertEqual(int(r.redis.get('user-0000')), 1)
self.assertEqual(r.redis_inv.get(1), 'user-0000')
if __name__ == '__main__':
unittest.main()
ナイーブな実装 (問題あり)
名前 -> ID のテーブルを検索し、値が取得できなかったら Redis の INCR コマンドで値を更新し、その ID を正引き/逆引きのテーブルそれぞれ登録する。
念の為に SETNX コマンドを利用し、キーが既に存在していたら例外を送出するようにする。
def register_naive(self, name):
index = self.redis.get(name)
if index is None:
index = self.redis_cnt.incr(self.counter_key)
if not self.redis.setnx(name, index):
raise Exception(
'Failed to register: db=%d, key=%s, value=%s' %
(self.db, name, index))
if self.db_invert is not None:
if not self.redis_inv.setnx(index, name):
raise Exception(
'Failed to register: db=%d, key=%s, value=%s' %
(self.db_invert, index, name))
return int(index)
Redis の INCR はアトミックな処理であるし、一見問題はなさそうだが、先程書いたテストが失敗する。
.Exception in thread Thread-10001:
Traceback (most recent call last):
File "/Users/xxxxxx/.pyenv/versions/2.7.9/lib/python2.7/threading.py", line 810, in __bootstrap_inner
self.run()
File "/Users/xxxxxx/.pyenv/versions/2.7.9/lib/python2.7/threading.py", line 763, in run
self.__target(*self.__args, **self.__kwargs)
File "/xxxxxx/registerer.py", line 36, in register_naive
raise Exception('Failed to register: db=%d, key=%s, value=%s' % (self.db, name, index))
Exception: Failed to register: db=14, key=user-0000, value=1
F.
======================================================================
FAIL: test_register_parallel_same_id (__main__.TestRegisterer)
----------------------------------------------------------------------
Traceback (most recent call last):
File "./tests/test_registerer.py", line 72, in test_register_parallel_same_id
self.assertEqual(int(r.redis_cnt.get('count')), 1)
AssertionError: 2 != 1
----------------------------------------------------------------------
Ran 3 tests in 9.492s
FAILED (failures=1)
これは、同じ名前を同時に大量に登録した場合にのみ発生する。
名前 -> IDテーブルのチェックと更新の間のタイミングで競合状態が発生し、同じ名前の登録処理が複数実行されているためである。
トランザクションを利用した実装
redis-py には transaction というお誂え向きのヘルパーメソッドが用意されている。
パイプラインを引数に取る関数と、楽観的ロックをかけるキーを指定するだけでシンプルにトランザクションを記述できる。
最終的なコードは以下のようになった。
import redis
class Registerer(object):
def __init__(self, counter_key, db_counter, db, db_invert=None,
host='localhost', port=6379):
self.counter_key = counter_key
self.db_counter = db_counter
self.db = db
self.db_invert = db_invert
self.host = host
self.port = port
self.redis = redis.Redis(host, port, db)
self.redis_cnt = redis.Redis(host, port, db_counter)
self.redis_inv = (redis.Redis(host, port, db_invert)
if db_invert is not None else None)
def register(self, name):
def f(pipe):
index = self.redis.get(name)
if index is None:
index = pipe.incr(self.counter_key)
if not self.redis.setnx(name, index):
raise Exception(
'Failed to register: db=%d, key=%s, value=%s' %
(self.db, name, index))
if self.db_invert is not None:
if not self.redis_inv.setnx(index, name):
raise Exception(
'Failed to register: db=%d, key=%s, value=%s' %
(self.db_invert, index, name))
return int(index)
return self.redis_cnt.transaction(
f, self.counter_key, value_from_callable=True)
無事、テストも通った。
... ---------------------------------------------------------------------- Ran 3 tests in 13.818s OK