1import asyncio
2import logging
3import threading
4import uuid
5from types import SimpleNamespace
6from typing import TYPE_CHECKING, Awaitable, Optional, Union
7
8from redis.exceptions import LockError, LockNotOwnedError
9from redis.typing import Number
10
11if TYPE_CHECKING:
12 from redis.asyncio import Redis, RedisCluster
13
# Module-level logger; Lock.__aexit__ uses it to emit a warning instead of
# raising when a release fails and raise_on_release_error is False.
logger = logging.getLogger(__name__)
15
16
class Lock:
    """
    A shared, distributed Lock. Using Redis for locking allows the Lock
    to be shared across processes and/or machines.

    It's left to the user to resolve deadlock issues and make sure
    multiple clients play nicely together.
    """

    # Script objects, registered lazily by register_scripts() and cached on
    # the class so every instance of that class reuses them.
    lua_release = None
    lua_extend = None
    lua_reacquire = None

    # KEYS[1] - lock name
    # ARGV[1] - token
    # return 1 if the lock was released, otherwise 0
    LUA_RELEASE_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        redis.call('del', KEYS[1])
        return 1
    """

    # KEYS[1] - lock name
    # ARGV[1] - token
    # ARGV[2] - additional milliseconds
    # ARGV[3] - "0" if the additional time should be added to the lock's
    #           existing ttl or "1" if the existing ttl should be replaced
    # return 1 if the lock's time was extended, otherwise 0
    # (a negative PTTL, e.g. a key without an expiry, also yields 0)
    LUA_EXTEND_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        local expiration = redis.call('pttl', KEYS[1])
        if not expiration then
            expiration = 0
        end
        if expiration < 0 then
            return 0
        end

        local newttl = ARGV[2]
        if ARGV[3] == "0" then
            newttl = ARGV[2] + expiration
        end
        redis.call('pexpire', KEYS[1], newttl)
        return 1
    """

    # KEYS[1] - lock name
    # ARGV[1] - token
    # ARGV[2] - milliseconds
    # return 1 if the lock's time was reacquired, otherwise 0
    LUA_REACQUIRE_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        redis.call('pexpire', KEYS[1], ARGV[2])
        return 1
    """

    def __init__(
        self,
        redis: Union["Redis", "RedisCluster"],
        name: Union[str, bytes, memoryview],
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[Number] = None,
        thread_local: bool = True,
        raise_on_release_error: bool = True,
    ):
        """
        Create a new Lock instance named ``name`` using the Redis client
        supplied by ``redis``.

        ``timeout`` indicates a maximum life for the lock in seconds.
        By default, it will remain locked until release() is called.
        ``timeout`` can be specified as a float or integer, both representing
        the number of seconds after which Redis expires the lock key.

        ``sleep`` indicates the amount of time to sleep in seconds per loop
        iteration when the lock is in blocking mode and another client is
        currently holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release thread-2's lock.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage.

        Note that thread-local storage is per *thread*, not per asyncio task:
        all coroutines running on the same event loop thread share one token
        slot, so a single Lock instance gives no token isolation between
        tasks. Use a separate Lock instance per task if that matters.

        ``raise_on_release_error`` indicates whether to raise an exception when
        the lock is no longer owned when exiting the context manager. By default,
        this is True, meaning an exception will be raised. If False, the warning
        will be logged and the exception will be suppressed.
        """
        self.redis = redis
        self.name = name
        self.timeout = timeout
        self.sleep = sleep
        self.blocking = blocking
        self.blocking_timeout = blocking_timeout
        self.thread_local = bool(thread_local)
        self.local = threading.local() if self.thread_local else SimpleNamespace()
        self.raise_on_release_error = raise_on_release_error
        self.local.token = None
        self.register_scripts()

    def register_scripts(self):
        """Register the Lua scripts on first use and cache them on the class."""
        # NOTE(review): the ``is None`` checks resolve through inheritance, so
        # a subclass that overrides a LUA_*_SCRIPT may reuse the parent's
        # already-registered script if the parent was instantiated first.
        cls = self.__class__
        client = self.redis
        if cls.lua_release is None:
            cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT)
        if cls.lua_extend is None:
            cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT)
        if cls.lua_reacquire is None:
            cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)

    def _get_encoder(self):
        """Return the client's encoder, for both standalone and cluster clients."""
        try:
            return self.redis.connection_pool.get_encoder()
        except AttributeError:
            # Cluster: fall back to the encoder exposed by the client itself.
            return self.redis.get_encoder()

    async def __aenter__(self):
        """Acquire the lock with the instance defaults or raise LockError."""
        if await self.acquire():
            return self
        raise LockError(
            "Unable to acquire lock within the time specified",
            lock_name=self.name,
        )

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Release the lock; optionally downgrade release errors to a warning."""
        try:
            await self.release()
        except LockError:
            if self.raise_on_release_error:
                raise
            logger.warning(
                "Lock was unlocked or no longer owned when exiting context manager."
            )

    async def acquire(
        self,
        blocking: Optional[bool] = None,
        blocking_timeout: Optional[Number] = None,
        token: Optional[Union[str, bytes]] = None,
    ):
        """
        Use Redis to hold a shared, distributed lock named ``name``.
        Returns True once the lock is acquired.

        If ``blocking`` is False, always return immediately. If the lock
        was acquired, return True, otherwise return False.

        ``blocking_timeout`` specifies the maximum number of seconds to
        wait trying to acquire the lock. If the next attempt would start
        after that deadline, False is returned instead of sleeping again.

        ``token`` specifies the token value to be used. If provided, token
        must be a bytes object or a string that can be encoded to a bytes
        object with the default encoding. If a token isn't specified, a UUID
        will be generated.
        """
        sleep = self.sleep
        if token is None:
            token = uuid.uuid1().hex.encode()
        else:
            token = self._get_encoder().encode(token)
        if blocking is None:
            blocking = self.blocking
        if blocking_timeout is None:
            blocking_timeout = self.blocking_timeout
        loop = asyncio.get_running_loop()
        stop_trying_at = None
        if blocking_timeout is not None:
            stop_trying_at = loop.time() + blocking_timeout
        while True:
            if await self.do_acquire(token):
                self.local.token = token
                return True
            if not blocking:
                return False
            next_try_at = loop.time() + sleep
            if stop_trying_at is not None and next_try_at > stop_trying_at:
                return False
            await asyncio.sleep(sleep)

    async def do_acquire(self, token: Union[str, bytes]) -> bool:
        """Attempt a single SET NX of ``token``; return True on success."""
        # A falsy timeout (None or 0) means the key gets no expiry.
        timeout = int(self.timeout * 1000) if self.timeout else None
        return bool(await self.redis.set(self.name, token, nx=True, px=timeout))

    async def locked(self) -> bool:
        """
        Returns True if this key is locked by any process, otherwise False.
        """
        return await self.redis.get(self.name) is not None

    async def owned(self) -> bool:
        """
        Returns True if this key is locked by this lock, otherwise False.
        """
        stored_token = await self.redis.get(self.name)
        # The local token is always bytes, but the client may return str
        # (e.g. when it decodes responses), so normalise before comparing.
        # TODO: this can be simplified when the context manager is finished
        if stored_token and not isinstance(stored_token, bytes):
            stored_token = self._get_encoder().encode(stored_token)
        return self.local.token is not None and stored_token == self.local.token

    def release(self) -> Awaitable[None]:
        """
        Releases the already acquired lock.

        The local token is cleared immediately (before the returned awaitable
        runs), so a second ``release()`` raises LockError right away.

        Raises LockError if this instance holds no token; the returned
        awaitable raises LockNotOwnedError if the key no longer holds it.
        """
        expected_token = self.local.token
        if expected_token is None:
            raise LockError(
                "Cannot release a lock that's not owned or is already unlocked.",
                lock_name=self.name,
            )
        self.local.token = None
        return self.do_release(expected_token)

    async def do_release(self, expected_token: bytes) -> None:
        """Delete the key only if it still holds ``expected_token``."""
        if not bool(
            await self.lua_release(
                keys=[self.name], args=[expected_token], client=self.redis
            )
        ):
            raise LockNotOwnedError(
                "Cannot release a lock that's no longer owned",
                lock_name=self.name,
            )

    def extend(
        self, additional_time: Number, replace_ttl: bool = False
    ) -> Awaitable[bool]:
        """
        Adds more time to an already acquired lock.

        ``additional_time`` can be specified as an integer or a float, both
        representing the number of seconds to add.

        ``replace_ttl`` if False (the default), add `additional_time` to
        the lock's existing ttl. If True, replace the lock's ttl with
        `additional_time`.

        Raises LockError if the lock is not held or has no timeout; the
        returned awaitable raises LockNotOwnedError if ownership was lost.
        """
        if self.local.token is None:
            raise LockError("Cannot extend an unlocked lock", lock_name=self.name)
        if self.timeout is None:
            raise LockError(
                "Cannot extend a lock with no timeout", lock_name=self.name
            )
        return self.do_extend(additional_time, replace_ttl)

    async def do_extend(self, additional_time, replace_ttl) -> bool:
        """Run the extend script; return True or raise LockNotOwnedError."""
        additional_time = int(additional_time * 1000)
        replace_flag = "1" if replace_ttl else "0"
        if not bool(
            await self.lua_extend(
                keys=[self.name],
                args=[self.local.token, additional_time, replace_flag],
                client=self.redis,
            )
        ):
            raise LockNotOwnedError(
                "Cannot extend a lock that's no longer owned",
                lock_name=self.name,
            )
        return True

    def reacquire(self) -> Awaitable[bool]:
        """
        Resets a TTL of an already acquired lock back to a timeout value.

        Raises LockError if the lock is not held or has no timeout; the
        returned awaitable raises LockNotOwnedError if ownership was lost.
        """
        if self.local.token is None:
            raise LockError("Cannot reacquire an unlocked lock", lock_name=self.name)
        if self.timeout is None:
            raise LockError(
                "Cannot reacquire a lock with no timeout", lock_name=self.name
            )
        return self.do_reacquire()

    async def do_reacquire(self) -> bool:
        """Reset the key's TTL to ``timeout``; return True or raise."""
        timeout = int(self.timeout * 1000)
        if not bool(
            await self.lua_reacquire(
                keys=[self.name], args=[self.local.token, timeout], client=self.redis
            )
        ):
            raise LockNotOwnedError(
                "Cannot reacquire a lock that's no longer owned",
                lock_name=self.name,
            )
        return True