Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/asyncio/client.py: 23%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

662 statements  

1import asyncio 

2import copy 

3import inspect 

4import re 

5import warnings 

6from typing import ( 

7 TYPE_CHECKING, 

8 Any, 

9 AsyncIterator, 

10 Awaitable, 

11 Callable, 

12 Dict, 

13 Iterable, 

14 List, 

15 Mapping, 

16 MutableMapping, 

17 Optional, 

18 Protocol, 

19 Set, 

20 Tuple, 

21 Type, 

22 TypedDict, 

23 TypeVar, 

24 Union, 

25 cast, 

26) 

27 

28from redis._parsers.helpers import ( 

29 _RedisCallbacks, 

30 _RedisCallbacksRESP2, 

31 _RedisCallbacksRESP3, 

32 bool_ok, 

33) 

34from redis.asyncio.connection import ( 

35 Connection, 

36 ConnectionPool, 

37 SSLConnection, 

38 UnixDomainSocketConnection, 

39) 

40from redis.asyncio.lock import Lock 

41from redis.asyncio.retry import Retry 

42from redis.backoff import ExponentialWithJitterBackoff 

43from redis.client import ( 

44 EMPTY_RESPONSE, 

45 NEVER_DECODE, 

46 AbstractRedis, 

47 CaseInsensitiveDict, 

48) 

49from redis.commands import ( 

50 AsyncCoreCommands, 

51 AsyncRedisModuleCommands, 

52 AsyncSentinelCommands, 

53 list_or_args, 

54) 

55from redis.credentials import CredentialProvider 

56from redis.driver_info import DriverInfo, resolve_driver_info 

57from redis.event import ( 

58 AfterPooledConnectionsInstantiationEvent, 

59 AfterPubSubConnectionInstantiationEvent, 

60 AfterSingleConnectionInstantiationEvent, 

61 ClientType, 

62 EventDispatcher, 

63) 

64from redis.exceptions import ( 

65 ConnectionError, 

66 ExecAbortError, 

67 PubSubError, 

68 RedisError, 

69 ResponseError, 

70 WatchError, 

71) 

72from redis.typing import ChannelT, EncodableT, KeyT 

73from redis.utils import ( 

74 SSL_AVAILABLE, 

75 _set_info_logger, 

76 deprecated_args, 

77 deprecated_function, 

78 safe_str, 

79 str_if_bytes, 

80 truncate_text, 

81) 

82 

# ssl types are only needed for annotations. A type checker (with SSL support
# available) sees the real types; at runtime TYPE_CHECKING is False, so the
# else branch always runs and the names are bound to None. That keeps the
# annotations in Redis.__init__ (e.g. ``Optional[TLSVersion]``,
# ``Union[str, VerifyMode]``) evaluable without importing these ssl names.
if TYPE_CHECKING and SSL_AVAILABLE:
    from ssl import TLSVersion, VerifyFlags, VerifyMode
else:
    TLSVersion = None
    VerifyMode = None
    VerifyFlags = None

# Signature of an async pub/sub message handler: takes a message dict and
# returns an awaitable.
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
# Constrained (not bound) TypeVar: resolves to exactly KeyT or EncodableT.
_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
# Bound to Redis so methods such as initialize()/__aenter__ are typed as
# returning the caller's (sub)class.
_RedisT = TypeVar("_RedisT", bound="Redis")
_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object])
if TYPE_CHECKING:
    # Annotation-only import; not executed at runtime.
    from redis.commands.core import Script

97 

98 

class ResponseCallbackProtocol(Protocol):
    """Structural type of a synchronous response callback."""

    def __call__(self, response: Any, **kwargs):
        """Post-process a raw server reply; command options arrive as keywords."""


class AsyncResponseCallbackProtocol(Protocol):
    """Structural type of a coroutine-based response callback."""

    async def __call__(self, response: Any, **kwargs):
        """Asynchronously post-process a raw server reply."""


# A registered response callback may be either a plain callable or a coroutine
# function; Redis.parse_response awaits the result when it is awaitable.
ResponseCallbackT = Union[
    ResponseCallbackProtocol,
    AsyncResponseCallbackProtocol,
]

108 

109 

class Redis(
    AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
):
    """
    Implementation of the Redis protocol.

    This class provides a Python interface to all Redis commands
    and an implementation of the Redis protocol.

    Pipelines derive from this, implementing how
    the commands are sent and received to the Redis server. Based on
    configuration, an instance will either use a ConnectionPool, or
    Connection object to talk to redis.
    """

    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT]

    @classmethod
    def from_url(
        cls: Type["Redis"],
        url: str,
        single_connection_client: bool = False,
        auto_close_connection_pool: Optional[bool] = None,
        **kwargs,
    ) -> "Redis":
        """
        Return a Redis client object configured from the given URL

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

            1. A ``db`` querystring option, e.g. redis://localhost?db=0

            2. If using the redis:// or rediss:// schemes, the path argument
               of the url, e.g. redis://localhost/0

            3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.

        """
        connection_pool = ConnectionPool.from_url(url, **kwargs)
        client = cls(
            connection_pool=connection_pool,
            single_connection_client=single_connection_client,
        )
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True
        # The constructor marks an injected pool as not owned; the pool was
        # created here, so ownership is overridden afterwards.
        client.auto_close_connection_pool = auto_close_connection_pool
        return client

    @classmethod
    def from_pool(
        cls: Type["Redis"],
        connection_pool: ConnectionPool,
    ) -> "Redis":
        """
        Return a Redis client from the given connection pool.
        The Redis client will take ownership of the connection pool and
        close it when the Redis client is closed.
        """
        client = cls(
            connection_pool=connection_pool,
        )
        client.auto_close_connection_pool = True
        return client

    @deprecated_args(
        args_to_warn=["retry_on_timeout"],
        reason="TimeoutError is included by default.",
        version="6.0.0",
    )
    @deprecated_args(
        args_to_warn=["lib_name", "lib_version"],
        reason="Use 'driver_info' parameter instead. "
        "lib_name and lib_version will be removed in a future version.",
    )
    def __init__(
        self,
        *,
        host: str = "localhost",
        port: int = 6379,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        socket_keepalive: Optional[bool] = None,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        connection_pool: Optional[ConnectionPool] = None,
        unix_socket_path: Optional[str] = None,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        retry_on_timeout: bool = False,
        # NOTE: this default is evaluated once and shared by every call; it is
        # deep-copied below before being handed to a pool, so it is never mutated.
        retry: Retry = Retry(
            backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
        ),
        retry_on_error: Optional[list] = None,
        ssl: bool = False,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: Union[str, VerifyMode] = "required",
        ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_ca_path: Optional[str] = None,
        ssl_check_hostname: bool = True,
        ssl_min_version: Optional[TLSVersion] = None,
        ssl_ciphers: Optional[str] = None,
        ssl_password: Optional[str] = None,
        max_connections: Optional[int] = None,
        single_connection_client: bool = False,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = None,
        lib_version: Optional[str] = None,
        driver_info: Optional["DriverInfo"] = None,
        username: Optional[str] = None,
        auto_close_connection_pool: Optional[bool] = None,
        redis_connect_func=None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        event_dispatcher: Optional[EventDispatcher] = None,
    ):
        """
        Initialize a new Redis client.

        To specify a retry policy for specific errors, you have two options:

        1. Set the `retry_on_error` to a list of the error/s to retry on, and
        you can also set `retry` to a valid `Retry` object(in case the default
        one is not appropriate) - with this approach the retries will be triggered
        on the default errors specified in the Retry object enriched with the
        errors specified in `retry_on_error`.

        2. Define a `Retry` object with configured 'supported_errors' and set
        it to the `retry` parameter - with this approach you completely redefine
        the errors on which retries will happen.

        `retry_on_timeout` is deprecated - please include the TimeoutError
        either in the Retry object or in the `retry_on_error` list.

        When 'connection_pool' is provided - the retry configuration of the
        provided pool will be used. In that case the client does not own the
        pool and will not close it on aclose().
        """
        kwargs: Dict[str, Any]
        if event_dispatcher is None:
            self._event_dispatcher = EventDispatcher()
        else:
            self._event_dispatcher = event_dispatcher
        # auto_close_connection_pool only has an effect if connection_pool is
        # None. It is assumed that if connection_pool is not None, the user
        # wants to manage the connection pool themselves.
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True

        if connection_pool is None:
            # Create internal connection pool, expected to be closed by Redis instance
            if not retry_on_error:
                retry_on_error = []

            # Handle driver_info: if provided, use it; otherwise create from lib_name/lib_version
            computed_driver_info = resolve_driver_info(
                driver_info, lib_name, lib_version
            )

            kwargs = {
                "db": db,
                "username": username,
                "password": password,
                "credential_provider": credential_provider,
                "socket_timeout": socket_timeout,
                "encoding": encoding,
                "encoding_errors": encoding_errors,
                "decode_responses": decode_responses,
                "retry_on_error": retry_on_error,
                # Copy so that the shared default Retry instance is never mutated.
                "retry": copy.deepcopy(retry),
                "max_connections": max_connections,
                "health_check_interval": health_check_interval,
                "client_name": client_name,
                "driver_info": computed_driver_info,
                "redis_connect_func": redis_connect_func,
                "protocol": protocol,
            }
            # based on input, setup appropriate connection args
            if unix_socket_path is not None:
                kwargs.update(
                    {
                        "path": unix_socket_path,
                        "connection_class": UnixDomainSocketConnection,
                    }
                )
            else:
                # TCP specific options
                kwargs.update(
                    {
                        "host": host,
                        "port": port,
                        "socket_connect_timeout": socket_connect_timeout,
                        "socket_keepalive": socket_keepalive,
                        "socket_keepalive_options": socket_keepalive_options,
                    }
                )

                if ssl:
                    kwargs.update(
                        {
                            "connection_class": SSLConnection,
                            "ssl_keyfile": ssl_keyfile,
                            "ssl_certfile": ssl_certfile,
                            "ssl_cert_reqs": ssl_cert_reqs,
                            "ssl_include_verify_flags": ssl_include_verify_flags,
                            "ssl_exclude_verify_flags": ssl_exclude_verify_flags,
                            "ssl_ca_certs": ssl_ca_certs,
                            "ssl_ca_data": ssl_ca_data,
                            "ssl_ca_path": ssl_ca_path,
                            "ssl_check_hostname": ssl_check_hostname,
                            "ssl_min_version": ssl_min_version,
                            "ssl_ciphers": ssl_ciphers,
                            "ssl_password": ssl_password,
                        }
                    )
            # This arg only used if no pool is passed in
            self.auto_close_connection_pool = auto_close_connection_pool
            connection_pool = ConnectionPool(**kwargs)
        else:
            # If a pool is passed in, do not close it
            self.auto_close_connection_pool = False

        # Dispatched for both internally created and user-supplied pools.
        self._event_dispatcher.dispatch(
            AfterPooledConnectionsInstantiationEvent(
                [connection_pool], ClientType.ASYNC, credential_provider
            )
        )

        self.connection_pool = connection_pool
        self.single_connection_client = single_connection_client
        self.connection: Optional[Connection] = None

        self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)

        if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]:
            self.response_callbacks.update(_RedisCallbacksRESP3)
        else:
            self.response_callbacks.update(_RedisCallbacksRESP2)

        # If using a single connection client, we need to lock creation-of and use-of
        # the client in order to avoid race conditions such as using asyncio.gather
        # on a set of redis commands
        self._single_conn_lock = asyncio.Lock()

        # When used as an async context manager, we need to increment and decrement
        # a usage counter so that we can close the connection pool when no one is
        # using the client.
        self._usage_counter = 0
        self._usage_lock = asyncio.Lock()

    def __repr__(self):
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"({self.connection_pool!r})>"
        )

    def __await__(self):
        # ``await Redis(...)`` is equivalent to ``await Redis(...).initialize()``.
        return self.initialize().__await__()

    async def initialize(self: _RedisT) -> _RedisT:
        """
        Prepare the client for use and return it.

        For a single-connection client this acquires the dedicated connection
        from the pool (once, under the single-connection lock). Otherwise it
        does nothing.
        """
        if self.single_connection_client:
            async with self._single_conn_lock:
                if self.connection is None:
                    self.connection = await self.connection_pool.get_connection()
                    # Dispatch only when a connection was actually acquired.
                    # execute_command() calls initialize() for every command, so
                    # dispatching unconditionally would re-announce the same
                    # connection to every listener on each command.
                    self._event_dispatcher.dispatch(
                        AfterSingleConnectionInstantiationEvent(
                            self.connection, ClientType.ASYNC, self._single_conn_lock
                        )
                    )
        return self

    def set_response_callback(self, command: str, callback: ResponseCallbackT):
        """Set a custom Response Callback"""
        self.response_callbacks[command] = callback

    def get_encoder(self):
        """Get the connection pool's encoder"""
        return self.connection_pool.get_encoder()

    def get_connection_kwargs(self):
        """Get the connection's key-word arguments"""
        return self.connection_pool.connection_kwargs

    def get_retry(self) -> Optional[Retry]:
        """Return the Retry object stored in the pool's connection kwargs, if any."""
        return self.get_connection_kwargs().get("retry")

    def set_retry(self, retry: Retry) -> None:
        """Store ``retry`` in the connection kwargs and apply it to the pool."""
        self.get_connection_kwargs().update({"retry": retry})
        self.connection_pool.set_retry(retry)

    def load_external_module(self, funcname, func):
        """
        This function can be used to add externally defined redis modules,
        and their namespaces to the redis client.

        funcname - A string containing the name of the function to create
        func - The function, being added to this class.

        ex: Assume that one has a custom redis module named foomod that
        creates command named 'foo.dothing' and 'foo.anotherthing' in redis.
        To load these functions into this namespace:

        from redis import Redis
        from foomodule import F
        r = Redis()
        r.load_external_module("foo", F)
        r.foo().dothing('your', 'arguments')

        For a concrete example see the reimport of the redisjson module in
        tests/test_connection.py::test_loading_external_modules
        """
        setattr(self, funcname, func)

    def pipeline(
        self, transaction: bool = True, shard_hint: Optional[str] = None
    ) -> "Pipeline":
        """
        Return a new pipeline object that can queue multiple commands for
        later execution. ``transaction`` indicates whether all commands
        should be executed atomically. Apart from making a group of operations
        atomic, pipelines are useful for reducing the back-and-forth overhead
        between the client and server.
        """
        return Pipeline(
            self.connection_pool, self.response_callbacks, transaction, shard_hint
        )

    async def transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        """
        Convenience method for executing the callable `func` as a transaction
        while watching all keys specified in `watches`. The 'func' callable
        should expect a single argument which is a Pipeline object.

        On WatchError the whole sequence (WATCH, func, EXEC) is retried, after
        sleeping ``watch_delay`` seconds when it is positive. Returns the value
        of ``func`` if ``value_from_callable`` is true, else the EXEC result.
        """
        pipe: Pipeline
        async with self.pipeline(True, shard_hint) as pipe:
            while True:
                try:
                    if watches:
                        await pipe.watch(*watches)
                    func_value = func(pipe)
                    if inspect.isawaitable(func_value):
                        func_value = await func_value
                    exec_value = await pipe.execute()
                    return func_value if value_from_callable else exec_value
                except WatchError:
                    if watch_delay is not None and watch_delay > 0:
                        await asyncio.sleep(watch_delay)
                    continue

    def lock(
        self,
        name: KeyT,
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[float] = None,
        lock_class: Optional[Type[Lock]] = None,
        thread_local: bool = True,
        raise_on_release_error: bool = True,
    ) -> Lock:
        """
        Return a new Lock object using key ``name`` that mimics
        the behavior of threading.Lock.

        If specified, ``timeout`` indicates a maximum life for the lock.
        By default, it will remain locked until release() is called.

        ``sleep`` indicates the amount of time to sleep per loop iteration
        when the lock is in blocking mode and another client is currently
        holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``lock_class`` forces the specified lock implementation. Note that as
        of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
        a Lua-based lock). So, it's unlikely you'll need this parameter, unless
        you have created your own custom lock class.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release the thread-2's lock.

        ``raise_on_release_error`` indicates whether to raise an exception when
        the lock is no longer owned when exiting the context manager. By default,
        this is True, meaning an exception will be raised. If False, the warning
        will be logged and the exception will be suppressed.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage."""
        if lock_class is None:
            lock_class = Lock
        return lock_class(
            self,
            name,
            timeout=timeout,
            sleep=sleep,
            blocking=blocking,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
            raise_on_release_error=raise_on_release_error,
        )

    def pubsub(self, **kwargs) -> "PubSub":
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.
        """
        return PubSub(
            self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
        )

    def monitor(self) -> "Monitor":
        """Return a Monitor bound to this client's connection pool."""
        return Monitor(self.connection_pool)

    def client(self) -> "Redis":
        """Return a new single-connection client sharing this client's pool."""
        return self.__class__(
            connection_pool=self.connection_pool, single_connection_client=True
        )

    async def __aenter__(self: _RedisT) -> _RedisT:
        """
        Async context manager entry. Increments a usage counter so that the
        connection pool is only closed (via aclose()) when no context is using
        the client.
        """
        await self._increment_usage()
        try:
            # Initialize the client (i.e. establish connection, etc.)
            return await self.initialize()
        except BaseException:
            # BaseException, not Exception: asyncio.CancelledError derives from
            # BaseException, and __aexit__ is never called when __aenter__
            # raises, so the increment above must be undone here for every
            # failure mode or the pool would never be auto-closed.
            await asyncio.shield(self._decrement_usage())
            raise

    async def _increment_usage(self) -> int:
        """
        Helper coroutine to increment the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter += 1
            return self._usage_counter

    async def _decrement_usage(self) -> int:
        """
        Helper coroutine to decrement the usage counter while holding the lock.
        Returns the new value of the usage counter.
        """
        async with self._usage_lock:
            self._usage_counter -= 1
            return self._usage_counter

    async def __aexit__(self, exc_type, exc_value, traceback):
        """
        Async context manager exit. Decrements a usage counter. If this is the
        last exit (counter becomes zero), the client closes its connection pool.
        """
        current_usage = await asyncio.shield(self._decrement_usage())
        if current_usage == 0:
            # This was the last active context, so disconnect the pool.
            await asyncio.shield(self.aclose())

    _DEL_MESSAGE = "Unclosed Redis client"

    # passing _warnings and _grl as argument default since they may be gone
    # by the time __del__ is called at shutdown
    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        # Only a client still holding its dedicated connection needs cleanup.
        if hasattr(self, "connection") and (self.connection is not None):
            _warn(f"Unclosed client session {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                _grl().call_exception_handler(context)
            except RuntimeError:
                # No running event loop (e.g. interpreter shutdown).
                pass
            self.connection._close()

    async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Closes Redis client connection

        Args:
            close_connection_pool:
                decides whether to close the connection pool used by this Redis client,
                overriding Redis.auto_close_connection_pool.
                By default, let Redis.auto_close_connection_pool decide
                whether to close the connection pool.
        """
        conn = self.connection
        if conn:
            self.connection = None
            await self.connection_pool.release(conn)
        if close_connection_pool or (
            close_connection_pool is None and self.auto_close_connection_pool
        ):
            await self.connection_pool.disconnect()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Alias for aclose(), for backwards compatibility
        """
        await self.aclose(close_connection_pool)

    async def _send_command_parse_response(self, conn, command_name, *args, **options):
        """
        Send a command and parse the response
        """
        await conn.send_command(*args)
        return await self.parse_response(conn, command_name, **options)

    async def _close_connection(self, conn: Connection):
        """
        Close the connection before retrying.

        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic(on connection level).
        """
        await conn.disconnect()

    # COMMAND EXECUTION AND PROTOCOL PARSING
    async def execute_command(self, *args, **options):
        """Execute a command and return a parsed response"""
        await self.initialize()
        pool = self.connection_pool
        command_name = args[0]
        # Single-connection clients reuse their dedicated connection; others
        # borrow one from the pool for the duration of this command.
        conn = self.connection or await pool.get_connection()

        if self.single_connection_client:
            await self._single_conn_lock.acquire()
        try:
            return await conn.retry.call_with_retry(
                lambda: self._send_command_parse_response(
                    conn, command_name, *args, **options
                ),
                lambda _: self._close_connection(conn),
            )
        finally:
            if self.single_connection_client:
                self._single_conn_lock.release()
            if not self.connection:
                await pool.release(conn)

    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        """Parses a response from the Redis server"""
        try:
            if NEVER_DECODE in options:
                response = await connection.read_response(disable_decoding=True)
                options.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            # Callers may supply a fallback value for error replies.
            if EMPTY_RESPONSE in options:
                return options[EMPTY_RESPONSE]
            raise

        if EMPTY_RESPONSE in options:
            options.pop(EMPTY_RESPONSE)

        # Remove keys entry, it needs only for cache.
        options.pop("keys", None)

        if command_name in self.response_callbacks:
            # Mypy bug: https://github.com/python/mypy/issues/10977
            command_name = cast(str, command_name)
            retval = self.response_callbacks[command_name](response, **options)
            # Callbacks may be sync or async (see ResponseCallbackT).
            return await retval if inspect.isawaitable(retval) else retval
        return response

779 

780 

# StrictRedis and Redis name the very same class object. The alias is
# presumably kept so code written against the older name keeps working
# (assumption - confirm against the project changelog).
StrictRedis = Redis

782 

783 

# One parsed MONITOR line, as produced by Monitor.next_command():
#   time           - server timestamp of the command
#   db             - database index the command ran against
#   client_address - client host, or "lua" / "unix" for those client kinds
#   client_port    - client port (socket path for unix clients, "" for lua)
#   client_type    - "tcp", "unix" or "lua"
#   command        - the command and its arguments joined by spaces
MonitorCommandInfo = TypedDict(
    "MonitorCommandInfo",
    {
        "time": float,
        "db": int,
        "client_address": str,
        "client_port": str,
        "client_type": str,
        "command": str,
    },
)

791 

792 

class Monitor:
    """
    Monitor is useful for handling the MONITOR command to the redis server.
    next_command() method returns one command from monitor
    listen() method yields commands from monitor.

    Use it as ``async with client.monitor() as m:``; entering sends MONITOR
    on a dedicated pooled connection, and exiting disconnects and returns it.
    """

    monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)")
    command_re = re.compile(r'"(.*?)(?<!\\)"')

    def __init__(self, connection_pool: ConnectionPool):
        self.connection_pool = connection_pool
        self.connection: Optional[Connection] = None

    async def connect(self):
        """Acquire a connection from the pool if one is not already held."""
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection()

    async def __aenter__(self):
        """Send MONITOR and verify the server acknowledged it.

        Raises:
            RedisError: if the server does not reply OK to MONITOR.
        """
        await self.connect()
        try:
            await self.connection.send_command("MONITOR")
            # check that monitor returns 'OK', but don't return it to user
            response = await self.connection.read_response()
            if not bool_ok(response):
                raise RedisError(f"MONITOR failed: {response}")
        except BaseException:
            # __aexit__ is not called when __aenter__ raises, so the connection
            # (left in an unknown protocol state) must be cleaned up here.
            await self._release_connection()
            raise
        return self

    async def __aexit__(self, *args):
        await self._release_connection()

    async def _release_connection(self) -> None:
        """Disconnect the held connection, return it to the pool and forget it."""
        conn = self.connection
        if conn is None:
            return
        # Drop the reference first: once released, the pool may hand the same
        # Connection object to another client, and next_command() must not
        # keep reading from it.
        self.connection = None
        try:
            await conn.disconnect()
        finally:
            await self.connection_pool.release(conn)

    async def next_command(self) -> MonitorCommandInfo:
        """Parse the response from a monitor command

        Raises:
            RedisError: if the server line does not match the MONITOR format.
        """
        await self.connect()
        response = await self.connection.read_response()
        if isinstance(response, bytes):
            response = self.connection.encoder.decode(response, force=True)
        command_time, command_data = response.split(" ", 1)
        m = self.monitor_re.match(command_data)
        if m is None:
            raise RedisError(f"Unable to parse MONITOR output: {command_data!r}")
        db_id, client_info, command = m.groups()
        command = " ".join(self.command_re.findall(command))
        # Redis escapes double quotes because each piece of the command
        # string is surrounded by double quotes. We don't have that
        # requirement so remove the escaping and leave the quote.
        command = command.replace('\\"', '"')

        if client_info == "lua":
            client_address = "lua"
            client_port = ""
            client_type = "lua"
        elif client_info.startswith("unix"):
            client_address = "unix"
            # Skip the "unix:" prefix (5 characters).
            client_port = client_info[5:]
            client_type = "unix"
        else:
            # use rsplit as ipv6 addresses contain colons
            client_address, client_port = client_info.rsplit(":", 1)
            client_type = "tcp"
        return {
            "time": float(command_time),
            "db": int(db_id),
            "client_address": client_address,
            "client_port": client_port,
            "client_type": client_type,
            "command": command,
        }

    async def listen(self) -> AsyncIterator[MonitorCommandInfo]:
        """Listen for commands coming to the server."""
        while True:
            yield await self.next_command()

864 

865 

866class PubSub: 

867 """ 

868 PubSub provides publish, subscribe and listen support to Redis channels. 

869 

870 After subscribing to one or more channels, the listen() method will block 

871 until a message arrives on one of the subscribed channels. That message 

872 will be returned and it's safe to start listening again. 

873 """ 

874 

875 PUBLISH_MESSAGE_TYPES = ("message", "pmessage") 

876 UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe") 

877 HEALTH_CHECK_MESSAGE = "redis-py-health-check" 

878 

def __init__(
    self,
    connection_pool: ConnectionPool,
    shard_hint: Optional[str] = None,
    ignore_subscribe_messages: bool = False,
    encoder=None,
    push_handler_func: Optional[Callable] = None,
    event_dispatcher: Optional["EventDispatcher"] = None,
):
    """Create a pub/sub object bound to ``connection_pool``.

    No connection is acquired here; ``self.connection`` starts as None.
    When ``encoder`` is omitted the pool's encoder is used. When no
    ``push_handler_func`` is given, the info logger is configured.
    """
    self._event_dispatcher = (
        EventDispatcher() if event_dispatcher is None else event_dispatcher
    )
    self.connection_pool = connection_pool
    self.shard_hint = shard_hint
    self.ignore_subscribe_messages = ignore_subscribe_messages
    self.connection = None
    # The encoding options are needed to look up channel and pattern names
    # for callback handlers.
    self.encoder = encoder
    self.push_handler_func = push_handler_func
    if encoder is None:
        self.encoder = self.connection_pool.get_encoder()

    # Expected reply to the health-check PING, in decoded or raw form.
    health_msg = self.HEALTH_CHECK_MESSAGE
    if self.encoder.decode_responses:
        pong, payload = "pong", health_msg
    else:
        pong, payload = b"pong", self.encoder.encode(health_msg)
    self.health_check_response = [[pong, payload], payload]

    if self.push_handler_func is None:
        _set_info_logger()
    self.channels = {}
    self.pending_unsubscribe_channels = set()
    self.patterns = {}
    self.pending_unsubscribe_patterns = set()
    self._lock = asyncio.Lock()

919 

920 async def __aenter__(self): 

921 return self 

922 

923 async def __aexit__(self, exc_type, exc_value, traceback): 

924 await self.aclose() 

925 

926 def __del__(self): 

927 if self.connection: 

928 self.connection.deregister_connect_callback(self.on_connect) 

929 

930 async def aclose(self): 

931 # In case a connection property does not yet exist 

932 # (due to a crash earlier in the Redis() constructor), return 

933 # immediately as there is nothing to clean-up. 

934 if not hasattr(self, "connection"): 

935 return 

936 async with self._lock: 

937 if self.connection: 

938 await self.connection.disconnect() 

939 self.connection.deregister_connect_callback(self.on_connect) 

940 await self.connection_pool.release(self.connection) 

941 self.connection = None 

942 self.channels = {} 

943 self.pending_unsubscribe_channels = set() 

944 self.patterns = {} 

945 self.pending_unsubscribe_patterns = set() 

946 

947 @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close") 

948 async def close(self) -> None: 

949 """Alias for aclose(), for backwards compatibility""" 

950 await self.aclose() 

951 

952 @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset") 

953 async def reset(self) -> None: 

954 """Alias for aclose(), for backwards compatibility""" 

955 await self.aclose() 

956 

957 async def on_connect(self, connection: Connection): 

958 """Re-subscribe to any channels and patterns previously subscribed to""" 

959 # NOTE: for python3, we can't pass bytestrings as keyword arguments 

960 # so we need to decode channel/pattern names back to unicode strings 

961 # before passing them to [p]subscribe. 

962 self.pending_unsubscribe_channels.clear() 

963 self.pending_unsubscribe_patterns.clear() 

964 if self.channels: 

965 channels = {} 

966 for k, v in self.channels.items(): 

967 channels[self.encoder.decode(k, force=True)] = v 

968 await self.subscribe(**channels) 

969 if self.patterns: 

970 patterns = {} 

971 for k, v in self.patterns.items(): 

972 patterns[self.encoder.decode(k, force=True)] = v 

973 await self.psubscribe(**patterns) 

974 

975 @property 

976 def subscribed(self): 

977 """Indicates if there are subscriptions to any channels or patterns""" 

978 return bool(self.channels or self.patterns) 

979 

980 async def execute_command(self, *args: EncodableT): 

981 """Execute a publish/subscribe command""" 

982 

983 # NOTE: don't parse the response in this function -- it could pull a 

984 # legitimate message off the stack if the connection is already 

985 # subscribed to one or more channels 

986 

987 await self.connect() 

988 connection = self.connection 

989 kwargs = {"check_health": not self.subscribed} 

990 await self._execute(connection, connection.send_command, *args, **kwargs) 

991 

992 async def connect(self): 

993 """ 

994 Ensure that the PubSub is connected 

995 """ 

996 if self.connection is None: 

997 self.connection = await self.connection_pool.get_connection() 

998 # register a callback that re-subscribes to any channels we 

999 # were listening to when we were disconnected 

1000 self.connection.register_connect_callback(self.on_connect) 

1001 else: 

1002 await self.connection.connect() 

1003 if self.push_handler_func is not None: 

1004 self.connection._parser.set_pubsub_push_handler(self.push_handler_func) 

1005 

1006 self._event_dispatcher.dispatch( 

1007 AfterPubSubConnectionInstantiationEvent( 

1008 self.connection, self.connection_pool, ClientType.ASYNC, self._lock 

1009 ) 

1010 ) 

1011 

1012 async def _reconnect(self, conn): 

1013 """ 

1014 Try to reconnect 

1015 """ 

1016 await conn.disconnect() 

1017 await conn.connect() 

1018 

1019 async def _execute(self, conn, command, *args, **kwargs): 

1020 """ 

1021 Connect manually upon disconnection. If the Redis server is down, 

1022 this will fail and raise a ConnectionError as desired. 

1023 After reconnection, the ``on_connect`` callback should have been 

1024 called by the # connection to resubscribe us to any channels and 

1025 patterns we were previously listening to 

1026 """ 

1027 return await conn.retry.call_with_retry( 

1028 lambda: command(*args, **kwargs), 

1029 lambda _: self._reconnect(conn), 

1030 ) 

1031 

1032 async def parse_response(self, block: bool = True, timeout: float = 0): 

1033 """Parse the response from a publish/subscribe command""" 

1034 conn = self.connection 

1035 if conn is None: 

1036 raise RuntimeError( 

1037 "pubsub connection not set: " 

1038 "did you forget to call subscribe() or psubscribe()?" 

1039 ) 

1040 

1041 await self.check_health() 

1042 

1043 if not conn.is_connected: 

1044 await conn.connect() 

1045 

1046 read_timeout = None if block else timeout 

1047 response = await self._execute( 

1048 conn, 

1049 conn.read_response, 

1050 timeout=read_timeout, 

1051 disconnect_on_error=False, 

1052 push_request=True, 

1053 ) 

1054 

1055 if conn.health_check_interval and response in self.health_check_response: 

1056 # ignore the health check message as user might not expect it 

1057 return None 

1058 return response 

1059 

1060 async def check_health(self): 

1061 conn = self.connection 

1062 if conn is None: 

1063 raise RuntimeError( 

1064 "pubsub connection not set: " 

1065 "did you forget to call subscribe() or psubscribe()?" 

1066 ) 

1067 

1068 if ( 

1069 conn.health_check_interval 

1070 and asyncio.get_running_loop().time() > conn.next_health_check 

1071 ): 

1072 await conn.send_command( 

1073 "PING", self.HEALTH_CHECK_MESSAGE, check_health=False 

1074 ) 

1075 

1076 def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT: 

1077 """ 

1078 normalize channel/pattern names to be either bytes or strings 

1079 based on whether responses are automatically decoded. this saves us 

1080 from coercing the value for each message coming in. 

1081 """ 

1082 encode = self.encoder.encode 

1083 decode = self.encoder.decode 

1084 return {decode(encode(k)): v for k, v in data.items()} # type: ignore[return-value] # noqa: E501 

1085 

1086 async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): 

1087 """ 

1088 Subscribe to channel patterns. Patterns supplied as keyword arguments 

1089 expect a pattern name as the key and a callable as the value. A 

1090 pattern's callable will be invoked automatically when a message is 

1091 received on that pattern rather than producing a message via 

1092 ``listen()``. 

1093 """ 

1094 parsed_args = list_or_args((args[0],), args[1:]) if args else args 

1095 new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args) 

1096 # Mypy bug: https://github.com/python/mypy/issues/10970 

1097 new_patterns.update(kwargs) # type: ignore[arg-type] 

1098 ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys()) 

1099 # update the patterns dict AFTER we send the command. we don't want to 

1100 # subscribe twice to these patterns, once for the command and again 

1101 # for the reconnection. 

1102 new_patterns = self._normalize_keys(new_patterns) 

1103 self.patterns.update(new_patterns) 

1104 self.pending_unsubscribe_patterns.difference_update(new_patterns) 

1105 return ret_val 

1106 

1107 def punsubscribe(self, *args: ChannelT) -> Awaitable: 

1108 """ 

1109 Unsubscribe from the supplied patterns. If empty, unsubscribe from 

1110 all patterns. 

1111 """ 

1112 patterns: Iterable[ChannelT] 

1113 if args: 

1114 parsed_args = list_or_args((args[0],), args[1:]) 

1115 patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys() 

1116 else: 

1117 parsed_args = [] 

1118 patterns = self.patterns 

1119 self.pending_unsubscribe_patterns.update(patterns) 

1120 return self.execute_command("PUNSUBSCRIBE", *parsed_args) 

1121 

1122 async def subscribe(self, *args: ChannelT, **kwargs: Callable): 

1123 """ 

1124 Subscribe to channels. Channels supplied as keyword arguments expect 

1125 a channel name as the key and a callable as the value. A channel's 

1126 callable will be invoked automatically when a message is received on 

1127 that channel rather than producing a message via ``listen()`` or 

1128 ``get_message()``. 

1129 """ 

1130 parsed_args = list_or_args((args[0],), args[1:]) if args else () 

1131 new_channels = dict.fromkeys(parsed_args) 

1132 # Mypy bug: https://github.com/python/mypy/issues/10970 

1133 new_channels.update(kwargs) # type: ignore[arg-type] 

1134 ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys()) 

1135 # update the channels dict AFTER we send the command. we don't want to 

1136 # subscribe twice to these channels, once for the command and again 

1137 # for the reconnection. 

1138 new_channels = self._normalize_keys(new_channels) 

1139 self.channels.update(new_channels) 

1140 self.pending_unsubscribe_channels.difference_update(new_channels) 

1141 return ret_val 

1142 

1143 def unsubscribe(self, *args) -> Awaitable: 

1144 """ 

1145 Unsubscribe from the supplied channels. If empty, unsubscribe from 

1146 all channels 

1147 """ 

1148 if args: 

1149 parsed_args = list_or_args(args[0], args[1:]) 

1150 channels = self._normalize_keys(dict.fromkeys(parsed_args)) 

1151 else: 

1152 parsed_args = [] 

1153 channels = self.channels 

1154 self.pending_unsubscribe_channels.update(channels) 

1155 return self.execute_command("UNSUBSCRIBE", *parsed_args) 

1156 

1157 async def listen(self) -> AsyncIterator: 

1158 """Listen for messages on channels this client has been subscribed to""" 

1159 while self.subscribed: 

1160 response = await self.handle_message(await self.parse_response(block=True)) 

1161 if response is not None: 

1162 yield response 

1163 

1164 async def get_message( 

1165 self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0 

1166 ): 

1167 """ 

1168 Get the next message if one is available, otherwise None. 

1169 

1170 If timeout is specified, the system will wait for `timeout` seconds 

1171 before returning. Timeout should be specified as a floating point 

1172 number or None to wait indefinitely. 

1173 """ 

1174 response = await self.parse_response(block=(timeout is None), timeout=timeout) 

1175 if response: 

1176 return await self.handle_message(response, ignore_subscribe_messages) 

1177 return None 

1178 

1179 def ping(self, message=None) -> Awaitable[bool]: 

1180 """ 

1181 Ping the Redis server to test connectivity. 

1182 

1183 Sends a PING command to the Redis server and returns True if the server 

1184 responds with "PONG". 

1185 """ 

1186 args = ["PING", message] if message is not None else ["PING"] 

1187 return self.execute_command(*args) 

1188 

1189 async def handle_message(self, response, ignore_subscribe_messages=False): 

1190 """ 

1191 Parses a pub/sub message. If the channel or pattern was subscribed to 

1192 with a message handler, the handler is invoked instead of a parsed 

1193 message being returned. 

1194 """ 

1195 if response is None: 

1196 return None 

1197 if isinstance(response, bytes): 

1198 response = [b"pong", response] if response != b"PONG" else [b"pong", b""] 

1199 message_type = str_if_bytes(response[0]) 

1200 if message_type == "pmessage": 

1201 message = { 

1202 "type": message_type, 

1203 "pattern": response[1], 

1204 "channel": response[2], 

1205 "data": response[3], 

1206 } 

1207 elif message_type == "pong": 

1208 message = { 

1209 "type": message_type, 

1210 "pattern": None, 

1211 "channel": None, 

1212 "data": response[1], 

1213 } 

1214 else: 

1215 message = { 

1216 "type": message_type, 

1217 "pattern": None, 

1218 "channel": response[1], 

1219 "data": response[2], 

1220 } 

1221 

1222 # if this is an unsubscribe message, remove it from memory 

1223 if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: 

1224 if message_type == "punsubscribe": 

1225 pattern = response[1] 

1226 if pattern in self.pending_unsubscribe_patterns: 

1227 self.pending_unsubscribe_patterns.remove(pattern) 

1228 self.patterns.pop(pattern, None) 

1229 else: 

1230 channel = response[1] 

1231 if channel in self.pending_unsubscribe_channels: 

1232 self.pending_unsubscribe_channels.remove(channel) 

1233 self.channels.pop(channel, None) 

1234 

1235 if message_type in self.PUBLISH_MESSAGE_TYPES: 

1236 # if there's a message handler, invoke it 

1237 if message_type == "pmessage": 

1238 handler = self.patterns.get(message["pattern"], None) 

1239 else: 

1240 handler = self.channels.get(message["channel"], None) 

1241 if handler: 

1242 if inspect.iscoroutinefunction(handler): 

1243 await handler(message) 

1244 else: 

1245 handler(message) 

1246 return None 

1247 elif message_type != "pong": 

1248 # this is a subscribe/unsubscribe message. ignore if we don't 

1249 # want them 

1250 if ignore_subscribe_messages or self.ignore_subscribe_messages: 

1251 return None 

1252 

1253 return message 

1254 

1255 async def run( 

1256 self, 

1257 *, 

1258 exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, 

1259 poll_timeout: float = 1.0, 

1260 pubsub=None, 

1261 ) -> None: 

1262 """Process pub/sub messages using registered callbacks. 

1263 

1264 This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in 

1265 redis-py, but it is a coroutine. To launch it as a separate task, use 

1266 ``asyncio.create_task``: 

1267 

1268 >>> task = asyncio.create_task(pubsub.run()) 

1269 

1270 To shut it down, use asyncio cancellation: 

1271 

1272 >>> task.cancel() 

1273 >>> await task 

1274 """ 

1275 for channel, handler in self.channels.items(): 

1276 if handler is None: 

1277 raise PubSubError(f"Channel: '{channel}' has no handler registered") 

1278 for pattern, handler in self.patterns.items(): 

1279 if handler is None: 

1280 raise PubSubError(f"Pattern: '{pattern}' has no handler registered") 

1281 

1282 await self.connect() 

1283 while True: 

1284 try: 

1285 if pubsub is None: 

1286 await self.get_message( 

1287 ignore_subscribe_messages=True, timeout=poll_timeout 

1288 ) 

1289 else: 

1290 await pubsub.get_message( 

1291 ignore_subscribe_messages=True, timeout=poll_timeout 

1292 ) 

1293 except asyncio.CancelledError: 

1294 raise 

1295 except BaseException as e: 

1296 if exception_handler is None: 

1297 raise 

1298 res = exception_handler(e, self) 

1299 if inspect.isawaitable(res): 

1300 await res 

1301 # Ensure that other tasks on the event loop get a chance to run 

1302 # if we didn't have to block for I/O anywhere. 

1303 await asyncio.sleep(0) 

1304 

1305 

class PubsubWorkerExceptionHandler(Protocol):
    """Structural type for a plain (non-async) ``PubSub.run`` error handler."""

    def __call__(self, e: BaseException, pubsub: PubSub):
        """Handle ``e``, raised while ``pubsub`` was being polled."""

1308 

1309 

class AsyncPubsubWorkerExceptionHandler(Protocol):
    """Structural type for a coroutine ``PubSub.run`` error handler."""

    async def __call__(self, e: BaseException, pubsub: PubSub):
        """Handle ``e``, raised while ``pubsub`` was being polled."""

1312 

1313 

# Type of PubSub.run(exception_handler=...): either flavour is accepted, and
# run() awaits the handler's return value only when it is awaitable.
PSWorkerThreadExcHandlerT = Union[PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler]

1317 

1318 

# A queued pipeline command: (positional args, options mapping), matching
# the (args, options) pairs appended by Pipeline.pipeline_execute_command.
CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]]
# The pipeline's queue of commands, in submission order.
CommandStackT = List[CommandT]

1321 

1322 

1323class Pipeline(Redis): # lgtm [py/init-calls-subclass] 

1324 """ 

1325 Pipelines provide a way to transmit multiple commands to the Redis server 

1326 in one transmission. This is convenient for batch processing, such as 

1327 saving all the values in a list to Redis. 

1328 

1329 All commands executed within a pipeline(when running in transactional mode, 

1330 which is the default behavior) are wrapped with MULTI and EXEC 

1331 calls. This guarantees all commands executed in the pipeline will be 

1332 executed atomically. 

1333 

1334 Any command raising an exception does *not* halt the execution of 

1335 subsequent commands in the pipeline. Instead, the exception is caught 

1336 and its instance is placed into the response list returned by execute(). 

1337 Code iterating over the response list should be able to deal with an 

1338 instance of an exception as a potential value. In general, these will be 

1339 ResponseError exceptions, such as those raised when issuing a command 

1340 on a key of a different datatype. 

1341 """ 

1342 

# Commands whose reply makes parse_response() clear the ``watching`` flag.
UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}

def __init__(
    self,
    connection_pool: ConnectionPool,
    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT],
    transaction: bool,
    shard_hint: Optional[str],
):
    """Create an empty pipeline bound to ``connection_pool``.

    No connection is taken here. One is checked out lazily by
    ``immediate_execute_command`` or ``execute`` and handed back by
    ``reset``.
    """
    # Connection handling.
    self.connection_pool = connection_pool
    self.connection = None
    self.shard_hint = shard_hint
    # Reply post-processors keyed by command name.
    self.response_callbacks = response_callbacks
    # MULTI/EXEC and WATCH bookkeeping.
    self.is_transaction = transaction
    self.explicit_transaction = False
    self.watching = False
    # Queued work; both are cleared by reset().
    self.command_stack: CommandStackT = []
    self.scripts: Set[Script] = set()

1361 

async def __aenter__(self: _RedisT) -> _RedisT:
    """Enter ``async with``; the context value is this pipeline."""
    pipeline = self
    return pipeline

1364 

1365 async def __aexit__(self, exc_type, exc_value, traceback): 

1366 await self.reset() 

1367 

1368 def __await__(self): 

1369 return self._async_self().__await__() 

1370 

1371 _DEL_MESSAGE = "Unclosed Pipeline client" 

1372 

1373 def __len__(self): 

1374 return len(self.command_stack) 

1375 

1376 def __bool__(self): 

1377 """Pipeline instances should always evaluate to True""" 

1378 return True 

1379 

1380 async def _async_self(self): 

1381 return self 

1382 

async def reset(self):
    """Drop queued work and hand the connection back to the pool.

    If keys are being WATCHed, UNWATCH is first sent directly on the
    connection. If that fails with ``ConnectionError``, the connection is
    disconnected instead, which also discards the server-side WATCHes.
    """
    self.command_stack, self.scripts = [], set()

    needs_unwatch = self.watching and self.connection
    if needs_unwatch:
        try:
            # Talk to the connection directly: going through unwatch() or
            # immediate_execute_command() could call reset() again.
            await self.connection.send_command("UNWATCH")
            await self.connection.read_response()
        except ConnectionError:
            # disconnect will also remove any previous WATCHes
            if self.connection:
                await self.connection.disconnect()

    # Flags are cleared no matter how the UNWATCH went.
    self.watching = False
    self.explicit_transaction = False

    # Nothing is WATCHed any more, so the connection can be pooled again.
    held = self.connection
    if held:
        await self.connection_pool.release(held)
    self.connection = None

1406 

async def aclose(self) -> None:
    """Release pipeline resources; the standard cleanup name for ``reset()``."""
    reset_pipeline = self.reset
    await reset_pipeline()

1410 

def multi(self):
    """
    Open an explicit transactional block, normally after WATCH commands.

    The block is closed by ``execute``.

    :raises RedisError: MULTI was already issued, or commands were queued
        before it.
    """
    problem = None
    if self.explicit_transaction:
        problem = "Cannot issue nested calls to MULTI"
    elif self.command_stack:
        problem = "Commands without an initial WATCH have already been issued"
    if problem is not None:
        raise RedisError(problem)
    self.explicit_transaction = True

1423 

def execute_command(
    self, *args, **kwargs
) -> Union["Pipeline", Awaitable["Pipeline"]]:
    """Send the command immediately or queue it on the pipeline.

    A WATCH, and anything issued while keys are WATCHed, runs immediately
    unless an explicit MULTI is open. Everything else is queued.
    """
    wants_immediate = self.watching or args[0] == "WATCH"
    if wants_immediate and not self.explicit_transaction:
        return self.immediate_execute_command(*args, **kwargs)
    return self.pipeline_execute_command(*args, **kwargs)

1430 

async def _disconnect_reset_raise_on_watching(
    self,
    conn: Connection,
    error: Exception,
):
    """
    Failure callback passed to ``conn.retry`` by ``immediate_execute_command``.

    Always disconnects ``conn``. If keys were being WATCHed, the pipeline is
    reset and a ``WatchError`` naming the original error type is raised so
    the caller can retry the transaction; otherwise the function returns
    normally.

    Which exceptions reach this callback is decided by the retry object, not
    here. After the disconnect, the connection reconnects and health-checks
    inside ``send_command``.
    """
    await conn.disconnect()
    if not self.watching:
        return
    # The dropped connection took its server-side WATCHes with it, so the
    # optimistic lock can no longer be trusted.
    await self.reset()
    raise WatchError(
        f"A {type(error).__name__} occurred while watching one or more keys"
    )

1455 

async def immediate_execute_command(self, *args, **options):
    """
    Send a command right away and return its parsed reply.

    Used for WATCH and for commands that read values after WATCH but before
    MULTI. Failures go through ``conn.retry``. If keys are WATCHed, the
    failure callback raises ``WatchError`` rather than letting the command
    be retried.
    """
    command_name = args[0]
    conn = self.connection
    if not conn:
        # First immediate command: check out a connection and pin it so
        # later commands (and reset()) use the same one.
        conn = self.connection = await self.connection_pool.get_connection()

    def attempt():
        return self._send_command_parse_response(
            conn, command_name, *args, **options
        )

    def on_failure(error):
        return self._disconnect_reset_raise_on_watching(conn, error)

    return await conn.retry.call_with_retry(attempt, on_failure)

1476 

def pipeline_execute_command(self, *args, **options):
    """
    Queue a command to run on the next ``execute()`` call.

    The pipeline is returned so calls can be chained::

        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')

    and later sent together with ``pipe.execute()``.
    """
    entry = (args, options)
    self.command_stack.append(entry)
    return self

1491 

async def _execute_transaction(  # noqa: C901
    self, connection: Connection, commands: CommandStackT, raise_on_error
):
    """
    Run ``commands`` as one MULTI ... EXEC round trip on ``connection``.

    Every command whose options do not contain ``EMPTY_RESPONSE`` is packed,
    wrapped in MULTI/EXEC, and sent in a single write. The per-command
    queuing replies are then read, followed by the EXEC reply.

    :param connection: connection the commands are sent on and read from.
    :param commands: the queued ``(args, options)`` pairs.
    :param raise_on_error: if true, the first ``ResponseError`` in the
        results is raised via ``raise_first_error``.
    :returns: one entry per command, after response callbacks; exceptions
        that were not raised appear in the list as instances.
    :raises WatchError: EXEC returned None ("Watched variable changed.").
    :raises ResponseError: the EXEC reply has the wrong number of items, or
        ``raise_on_error`` is set and a command failed.
    :raises ExecAbortError: EXEC was aborted. If a queuing error was recorded
        first, that error is raised instead, chained from the abort.
    """
    pre: CommandT = (("MULTI",), {})
    post: CommandT = (("EXEC",), {})
    cmds = (pre, *commands, post)
    all_cmds = connection.pack_commands(
        args for args, options in cmds if EMPTY_RESPONSE not in options
    )
    await connection.send_packed_command(all_cmds)
    errors = []

    # parse off the response for MULTI
    # NOTE: we need to handle ResponseErrors here and continue
    # so that we read all the additional command messages from
    # the socket
    try:
        await self.parse_response(connection, "_")
    except ResponseError as err:
        errors.append((0, err))

    # and all the other commands
    for i, command in enumerate(commands):
        if EMPTY_RESPONSE in command[1]:
            # Not sent (filtered out of the pack above), so there is no reply
            # to read; the option's value is spliced into the results here.
            errors.append((i, command[1][EMPTY_RESPONSE]))
        else:
            try:
                await self.parse_response(connection, "_")
            except ResponseError as err:
                self.annotate_exception(err, i + 1, command[0])
                errors.append((i, err))

    # parse the EXEC.
    try:
        response = await self.parse_response(connection, "_")
    except ExecAbortError as err:
        if errors:
            raise errors[0][1] from err
        raise

    # EXEC clears any watched keys
    self.watching = False

    if response is None:
        raise WatchError("Watched variable changed.") from None

    # put any parse errors into the response
    for i, e in errors:
        response.insert(i, e)

    if len(response) != len(commands):
        # The replies read from ``connection`` no longer line up with the
        # commands sent on it, so it is this connection (not necessarily
        # self.connection) that must not be reused.
        await connection.disconnect()
        raise ResponseError(
            "Wrong number of response items from pipeline execution"
        ) from None

    # find any errors in the response and raise if necessary
    if raise_on_error:
        self.raise_first_error(commands, response)

    # We have to run response callbacks manually
    data = []
    for r, cmd in zip(response, commands):
        if not isinstance(r, Exception):
            args, options = cmd
            command_name = args[0]

            # Remove keys entry, it needs only for cache.
            options.pop("keys", None)

            if command_name in self.response_callbacks:
                r = self.response_callbacks[command_name](r, **options)
                if inspect.isawaitable(r):
                    r = await r
        data.append(r)
    return data

1569 

async def _execute_pipeline(
    self, connection: Connection, commands: CommandStackT, raise_on_error: bool
):
    """
    Send ``commands`` in one write, without MULTI/EXEC, and read one reply each.

    A ``ResponseError`` for one command is stored in its result slot so the
    remaining replies are still read. If ``raise_on_error`` is set, the first
    such error is then raised via ``raise_first_error``.
    """
    # build up all commands into a single request to increase network perf
    packed = connection.pack_commands([args for args, _ in commands])
    await connection.send_packed_command(packed)

    replies = []
    for args, options in commands:
        try:
            reply = await self.parse_response(connection, args[0], **options)
        except ResponseError as exc:
            reply = exc
        replies.append(reply)

    if raise_on_error:
        self.raise_first_error(commands, replies)
    return replies

1589 

def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]):
    """Raise the earliest ``ResponseError`` in ``response``, if there is one.

    Before raising, the error is annotated with its 1-based position and the
    matching command's arguments.
    """
    first = next(
        (
            (position, item)
            for position, item in enumerate(response)
            if isinstance(item, ResponseError)
        ),
        None,
    )
    if first is None:
        return
    position, failure = first
    self.annotate_exception(failure, position + 1, commands[position][0])
    raise failure

1595 

def annotate_exception(
    self, exception: Exception, number: int, command: Iterable[object]
) -> None:
    """Prefix ``exception``'s message with the failing command and its position.

    :param exception: modified in place; ``args[0]`` is replaced by the
        annotated message and any further ``args`` are kept.
    :param number: position reported in the message (callers in this class
        pass a 1-based index).
    :param command: the command's arguments, rendered via ``safe_str`` and
        shortened with ``truncate_text``.
    """
    cmd = " ".join(map(safe_str, command))
    # Use the original message text. Interpolating ``exception.args`` itself
    # would embed the tuple repr, e.g. "('WRONGTYPE ...',)".
    original = exception.args[0] if exception.args else ""
    msg = (
        f"Command # {number} ({truncate_text(cmd)}) "
        f"of pipeline caused error: {original}"
    )
    exception.args = (msg,) + exception.args[1:]

1605 

async def parse_response(
    self, connection: Connection, command_name: Union[str, bytes], **options
):
    """Parse a reply through the base class and track WATCH state.

    A reply to any command in ``UNWATCH_COMMANDS`` clears ``watching``; a
    reply to WATCH sets it.
    """
    reply = await super().parse_response(connection, command_name, **options)
    is_unwatch = command_name in self.UNWATCH_COMMANDS
    if is_unwatch:
        self.watching = False
    elif command_name == "WATCH":
        self.watching = True
    return reply

1615 

async def load_scripts(self):
    """Make sure every script registered on this pipeline exists on the server.

    Runs SCRIPT EXISTS for all known SHAs, then SCRIPT LOAD for any that are
    missing, storing the returned SHA on the script object.
    """
    pending = list(self.scripts)
    run_now = self.immediate_execute_command
    # we can't use the normal script_* methods because they would just
    # get buffered in the pipeline.
    present = await run_now("SCRIPT EXISTS", *[script.sha for script in pending])
    if all(present):
        return
    for script, is_loaded in zip(pending, present):
        if not is_loaded:
            script.sha = await run_now("SCRIPT LOAD", script.script)

1628 

async def _disconnect_raise_on_watching(self, conn: Connection, error: Exception):
    """
    Failure callback passed to ``conn.retry`` by ``execute``.

    Always disconnects ``conn``. If keys were being WATCHed, raises
    ``WatchError`` naming the original error type so the caller can retry
    the transaction. Unlike ``_disconnect_reset_raise_on_watching``, it does
    not reset the pipeline itself. Which exceptions get here is decided by
    the retry object.
    """
    await conn.disconnect()
    if not self.watching:
        return
    # The lost connection also lost its WATCHes on the server.
    raise WatchError(
        f"A {type(error).__name__} occurred while watching one or more keys"
    )

1647 

async def execute(self, raise_on_error: bool = True) -> List[Any]:
    """Execute all the commands in the current pipeline.

    Returns ``[]`` straight away if nothing is queued and no keys are
    WATCHed. Otherwise any registered scripts are loaded first, and the
    stack runs as a transaction (when ``is_transaction`` or an explicit
    MULTI is set) or as a plain pipeline, under ``conn.retry``. The pipeline
    is always reset afterwards.
    """
    stack = self.command_stack
    if not (stack or self.watching):
        return []
    if self.scripts:
        await self.load_scripts()

    run_batch = (
        self._execute_transaction
        if self.is_transaction or self.explicit_transaction
        else self._execute_pipeline
    )

    conn = self.connection
    if not conn:
        # Pin the connection on self so the reset() in ``finally`` returns
        # it to the pool.
        conn = self.connection = await self.connection_pool.get_connection()
    conn = cast(Connection, conn)

    try:
        return await conn.retry.call_with_retry(
            lambda: run_batch(conn, stack, raise_on_error),
            lambda error: self._disconnect_raise_on_watching(conn, error),
        )
    finally:
        await self.reset()

1675 

async def discard(self):
    """Issue DISCARD through ``execute_command``.

    See: https://redis.io/commands/DISCARD
    """
    command = "DISCARD"
    await self.execute_command(command)

1681 

async def watch(self, *names: KeyT):
    """Watches the values at keys ``names``.

    :raises RedisError: an explicit MULTI has already been issued.
    """
    if not self.explicit_transaction:
        return await self.execute_command("WATCH", *names)
    raise RedisError("Cannot issue a WATCH after a MULTI")

1687 

async def unwatch(self):
    """Unwatches all previously specified keys.

    Sends UNWATCH only when keys are being watched. Returns the command's
    result if it is truthy, and True otherwise (including when nothing was
    sent).
    """
    if not self.watching:
        return True
    result = await self.execute_command("UNWATCH")
    return result or True