Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/redis/asyncio/lock.py: 24%

115 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-23 06:16 +0000

1import asyncio 

2import threading 

3import uuid 

4from types import SimpleNamespace 

5from typing import TYPE_CHECKING, Awaitable, Optional, Union 

6 

7from redis.exceptions import LockError, LockNotOwnedError 

8 

9if TYPE_CHECKING: 

10 from redis.asyncio import Redis, RedisCluster 

11 

12 

class Lock:
    """
    A shared, distributed Lock. Using Redis for locking allows the Lock
    to be shared across processes and/or machines.

    It's left to the user to resolve deadlock issues and make sure
    multiple clients play nicely together.
    """

    # Script objects returned by ``register_script``. They are cached on the
    # class the first time a Lock is constructed and are always invoked with
    # an explicit ``client=`` argument.
    lua_release = None
    lua_extend = None
    lua_reacquire = None

    # KEYS[1] - lock name
    # ARGV[1] - token
    # return 1 if the lock was released, otherwise 0
    LUA_RELEASE_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        redis.call('del', KEYS[1])
        return 1
    """

    # KEYS[1] - lock name
    # ARGV[1] - token
    # ARGV[2] - additional milliseconds
    # ARGV[3] - "0" if the additional time should be added to the lock's
    #           existing ttl or "1" if the existing ttl should be replaced
    # return 1 if the locks time was extended, otherwise 0
    LUA_EXTEND_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        local expiration = redis.call('pttl', KEYS[1])
        if not expiration then
            expiration = 0
        end
        if expiration < 0 then
            return 0
        end

        local newttl = ARGV[2]
        if ARGV[3] == "0" then
            newttl = ARGV[2] + expiration
        end
        redis.call('pexpire', KEYS[1], newttl)
        return 1
    """

    # KEYS[1] - lock name
    # ARGV[1] - token
    # ARGV[2] - milliseconds
    # return 1 if the locks time was reacquired, otherwise 0
    LUA_REACQUIRE_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        redis.call('pexpire', KEYS[1], ARGV[2])
        return 1
    """

    def __init__(
        self,
        redis: Union["Redis", "RedisCluster"],
        name: Union[str, bytes, memoryview],
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[float] = None,
        thread_local: bool = True,
    ):
        """
        Create a new Lock instance named ``name`` using the Redis client
        supplied by ``redis``.

        ``timeout`` indicates a maximum life for the lock in seconds.
        By default, it will remain locked until release() is called.
        ``timeout`` can be specified as a float or integer, both representing
        the number of seconds to wait.

        ``sleep`` indicates the amount of time to sleep in seconds per loop
        iteration when the lock is in blocking mode and another client is
        currently holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release the thread-2's lock.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage.
        """
        self.redis = redis
        self.name = name
        self.timeout = timeout
        self.sleep = sleep
        self.blocking = blocking
        self.blocking_timeout = blocking_timeout
        self.thread_local = bool(thread_local)
        self.local = threading.local() if self.thread_local else SimpleNamespace()
        # With thread-local storage this assignment is only visible in the
        # constructing thread; other threads go through ``_get_token``.
        self.local.token = None
        self.register_scripts()

    def register_scripts(self):
        """
        Register the release/extend/reacquire Lua scripts with the client.

        Each script is registered at most once and cached on the class.
        """
        cls = self.__class__
        client = self.redis
        if cls.lua_release is None:
            cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT)
        if cls.lua_extend is None:
            cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT)
        if cls.lua_reacquire is None:
            cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)

    def _get_token(self):
        """
        Return the token held in ``self.local``, or ``None`` if there is none.

        ``threading.local`` attributes are per-thread, so a thread other than
        the one that ran ``__init__`` has no ``token`` attribute at all until
        it acquires the lock. Treat that case as "no token" instead of
        letting an ``AttributeError`` escape.
        """
        return getattr(self.local, "token", None)

    def _get_encoder(self):
        """Return the client's encoder, falling back to the cluster API."""
        try:
            return self.redis.connection_pool.get_encoder()
        except AttributeError:
            # Cluster
            return self.redis.get_encoder()

    async def __aenter__(self):
        """Acquire the lock on entry; raise ``LockError`` if that fails."""
        if await self.acquire():
            return self
        raise LockError("Unable to acquire lock within the time specified")

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Release the lock on exit from the ``async with`` block."""
        await self.release()

    async def acquire(
        self,
        blocking: Optional[bool] = None,
        blocking_timeout: Optional[float] = None,
        token: Optional[Union[str, bytes]] = None,
    ):
        """
        Use Redis to hold a shared, distributed lock named ``name``.
        Returns True once the lock is acquired.

        If ``blocking`` is False, always return immediately. If the lock
        was acquired, return True, otherwise return False.

        ``blocking_timeout`` specifies the maximum number of seconds to
        wait trying to acquire the lock.

        ``token`` specifies the token value to be used. If provided, token
        must be a bytes object or a string that can be encoded to a bytes
        object with the default encoding. If a token isn't specified, a UUID
        will be generated.
        """
        sleep = self.sleep
        if token is None:
            token = uuid.uuid1().hex.encode()
        else:
            token = self._get_encoder().encode(token)
        if blocking is None:
            blocking = self.blocking
        if blocking_timeout is None:
            blocking_timeout = self.blocking_timeout
        stop_trying_at = None
        if blocking_timeout is not None:
            stop_trying_at = asyncio.get_running_loop().time() + blocking_timeout
        while True:
            if await self.do_acquire(token):
                self.local.token = token
                return True
            if not blocking:
                return False
            # Give up now if the next attempt would fall past the deadline,
            # rather than sleeping first and then failing.
            next_try_at = asyncio.get_running_loop().time() + sleep
            if stop_trying_at is not None and next_try_at > stop_trying_at:
                return False
            await asyncio.sleep(sleep)

    async def do_acquire(self, token: Union[str, bytes]) -> bool:
        """
        Try once to set the lock key to ``token`` with ``nx=True``.

        Returns True if the SET call reported success, otherwise False.
        """
        if self.timeout:
            # convert to milliseconds
            timeout = int(self.timeout * 1000)
        else:
            timeout = None
        if await self.redis.set(self.name, token, nx=True, px=timeout):
            return True
        return False

    async def locked(self) -> bool:
        """
        Returns True if this key is locked by any process, otherwise False.
        """
        return await self.redis.get(self.name) is not None

    async def owned(self) -> bool:
        """
        Returns True if this key is locked by this lock, otherwise False.
        """
        stored_token = await self.redis.get(self.name)
        # need to always compare bytes to bytes
        # TODO: this can be simplified when the context manager is finished
        if stored_token and not isinstance(stored_token, bytes):
            stored_token = self._get_encoder().encode(stored_token)
        local_token = self._get_token()
        return local_token is not None and stored_token == local_token

    def release(self) -> Awaitable[None]:
        """
        Releases the already acquired lock.

        Raises ``LockError`` synchronously if no token is held. Otherwise the
        local token is cleared and an awaitable performing the release is
        returned.
        """
        expected_token = self._get_token()
        if expected_token is None:
            raise LockError("Cannot release an unlocked lock")
        self.local.token = None
        return self.do_release(expected_token)

    async def do_release(self, expected_token: bytes) -> None:
        """
        Run the release script for ``expected_token``.

        Raises ``LockNotOwnedError`` if the script reports the lock was not
        released.
        """
        if not bool(
            await self.lua_release(
                keys=[self.name], args=[expected_token], client=self.redis
            )
        ):
            raise LockNotOwnedError("Cannot release a lock that's no longer owned")

    def extend(
        self, additional_time: float, replace_ttl: bool = False
    ) -> Awaitable[bool]:
        """
        Adds more time to an already acquired lock.

        ``additional_time`` can be specified as an integer or a float, both
        representing the number of seconds to add.

        ``replace_ttl`` if False (the default), add `additional_time` to
        the lock's existing ttl. If True, replace the lock's ttl with
        `additional_time`.
        """
        if self._get_token() is None:
            raise LockError("Cannot extend an unlocked lock")
        if self.timeout is None:
            raise LockError("Cannot extend a lock with no timeout")
        return self.do_extend(additional_time, replace_ttl)

    async def do_extend(self, additional_time, replace_ttl) -> bool:
        """
        Run the extend script with ``additional_time`` converted to ms.

        Returns True on success; raises ``LockNotOwnedError`` if the script
        reports the lock was not extended.
        """
        additional_time = int(additional_time * 1000)
        if not bool(
            await self.lua_extend(
                keys=[self.name],
                args=[
                    self._get_token(),
                    additional_time,
                    "1" if replace_ttl else "0",
                ],
                client=self.redis,
            )
        ):
            raise LockNotOwnedError("Cannot extend a lock that's no longer owned")
        return True

    def reacquire(self) -> Awaitable[bool]:
        """
        Resets a TTL of an already acquired lock back to a timeout value.
        """
        if self._get_token() is None:
            raise LockError("Cannot reacquire an unlocked lock")
        if self.timeout is None:
            raise LockError("Cannot reacquire a lock with no timeout")
        return self.do_reacquire()

    async def do_reacquire(self) -> bool:
        """
        Run the reacquire script with ``self.timeout`` converted to ms.

        Returns True on success; raises ``LockNotOwnedError`` if the script
        reports the lock was not reacquired.
        """
        timeout = int(self.timeout * 1000)
        if not bool(
            await self.lua_reacquire(
                keys=[self.name], args=[self._get_token(), timeout], client=self.redis
            )
        ):
            raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned")
        return True