Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/redis/asyncio/client.py: 23%

660 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-23 06:16 +0000

1import asyncio 

2import copy 

3import inspect 

4import re 

5import ssl 

6import warnings 

7from typing import ( 

8 TYPE_CHECKING, 

9 Any, 

10 AsyncIterator, 

11 Awaitable, 

12 Callable, 

13 Dict, 

14 Iterable, 

15 List, 

16 Mapping, 

17 MutableMapping, 

18 Optional, 

19 Protocol, 

20 Set, 

21 Tuple, 

22 Type, 

23 TypedDict, 

24 TypeVar, 

25 Union, 

26 cast, 

27) 

28 

29from redis._cache import ( 

30 DEFAULT_BLACKLIST, 

31 DEFAULT_EVICTION_POLICY, 

32 DEFAULT_WHITELIST, 

33 AbstractCache, 

34) 

35from redis._parsers.helpers import ( 

36 _RedisCallbacks, 

37 _RedisCallbacksRESP2, 

38 _RedisCallbacksRESP3, 

39 bool_ok, 

40) 

41from redis.asyncio.connection import ( 

42 Connection, 

43 ConnectionPool, 

44 SSLConnection, 

45 UnixDomainSocketConnection, 

46) 

47from redis.asyncio.lock import Lock 

48from redis.asyncio.retry import Retry 

49from redis.client import ( 

50 EMPTY_RESPONSE, 

51 NEVER_DECODE, 

52 AbstractRedis, 

53 CaseInsensitiveDict, 

54) 

55from redis.commands import ( 

56 AsyncCoreCommands, 

57 AsyncRedisModuleCommands, 

58 AsyncSentinelCommands, 

59 list_or_args, 

60) 

61from redis.credentials import CredentialProvider 

62from redis.exceptions import ( 

63 ConnectionError, 

64 ExecAbortError, 

65 PubSubError, 

66 RedisError, 

67 ResponseError, 

68 TimeoutError, 

69 WatchError, 

70) 

71from redis.typing import ChannelT, EncodableT, KeyT 

72from redis.utils import ( 

73 HIREDIS_AVAILABLE, 

74 _set_info_logger, 

75 deprecated_function, 

76 get_lib_version, 

77 safe_str, 

78 str_if_bytes, 

79) 

80 

# Signature of a handler awaited for each delivered pub/sub message dict.
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
# Module-private TypeVars keeping generic signatures precise across this file.
_KeyT = TypeVar("_KeyT", bound=KeyT)
_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
_RedisT = TypeVar("_RedisT", bound="Redis")
_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object])
if TYPE_CHECKING:
    # Imported only for annotations; avoids a runtime import cycle.
    from redis.commands.core import Script

88 

89 

class ResponseCallbackProtocol(Protocol):
    """Structural type of a synchronous response post-processing callback."""

    def __call__(self, response: Any, **kwargs): ...

92 

93 

class AsyncResponseCallbackProtocol(Protocol):
    """Structural type of an asynchronous response post-processing callback."""

    async def __call__(self, response: Any, **kwargs): ...

96 

97 

# A response callback may be either synchronous or asynchronous.
ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol]

99 

100 

class Redis(
    AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
):
    """
    Implementation of the Redis protocol.

    This abstract class provides a Python interface to all Redis commands
    and an implementation of the Redis protocol.

    Pipelines derive from this, implementing how
    the commands are sent and received to the Redis server. Based on
    configuration, an instance will either use a ConnectionPool, or
    Connection object to talk to redis.
    """

    # Maps a command name to the callable used to turn its raw wire reply
    # into the value returned to the caller (see parse_response()).
    response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT]

    @classmethod
    def from_url(
        cls,
        url: str,
        single_connection_client: bool = False,
        auto_close_connection_pool: Optional[bool] = None,
        **kwargs,
    ):
        """
        Return a Redis client object configured from the given URL

        For example::

            redis://[[username]:[password]]@localhost:6379/0
            rediss://[[username]:[password]]@localhost:6379/0
            unix://[username@]/path/to/socket.sock?db=0[&password=password]

        Three URL schemes are supported:

        - `redis://` creates a TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/redis>
        - `rediss://` creates a SSL wrapped TCP socket connection. See more at:
          <https://www.iana.org/assignments/uri-schemes/prov/rediss>
        - ``unix://``: creates a Unix Domain Socket connection.

        The username, password, hostname, path and all querystring values
        are passed through urllib.parse.unquote in order to replace any
        percent-encoded values with their corresponding characters.

        There are several ways to specify a database number. The first value
        found will be used:

        1. A ``db`` querystring option, e.g. redis://localhost?db=0

        2. If using the redis:// or rediss:// schemes, the path argument
           of the url, e.g. redis://localhost/0

        3. A ``db`` keyword argument to this function.

        If none of these options are specified, the default db=0 is used.

        All querystring options are cast to their appropriate Python types.
        Boolean arguments can be specified with string values "True"/"False"
        or "Yes"/"No". Values that cannot be properly cast cause a
        ``ValueError`` to be raised. Once parsed, the querystring arguments
        and keyword arguments are passed to the ``ConnectionPool``'s
        class initializer. In the case of conflicting arguments, querystring
        arguments always win.

        """
        connection_pool = ConnectionPool.from_url(url, **kwargs)
        client = cls(
            connection_pool=connection_pool,
            single_connection_client=single_connection_client,
        )
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide to the Redis() constructor instead."
                )
            )
        else:
            # The pool was created here, so by default the client owns it
            # and will close it in aclose().
            auto_close_connection_pool = True
        client.auto_close_connection_pool = auto_close_connection_pool
        return client

    @classmethod
    def from_pool(
        cls: Type["Redis"],
        connection_pool: ConnectionPool,
    ) -> "Redis":
        """
        Return a Redis client from the given connection pool.
        The Redis client will take ownership of the connection pool and
        close it when the Redis client is closed.
        """
        client = cls(
            connection_pool=connection_pool,
        )
        client.auto_close_connection_pool = True
        return client

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: int = 6379,
        db: Union[str, int] = 0,
        password: Optional[str] = None,
        socket_timeout: Optional[float] = None,
        socket_connect_timeout: Optional[float] = None,
        socket_keepalive: Optional[bool] = None,
        socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
        connection_pool: Optional[ConnectionPool] = None,
        unix_socket_path: Optional[str] = None,
        encoding: str = "utf-8",
        encoding_errors: str = "strict",
        decode_responses: bool = False,
        retry_on_timeout: bool = False,
        retry_on_error: Optional[list] = None,
        ssl: bool = False,
        ssl_keyfile: Optional[str] = None,
        ssl_certfile: Optional[str] = None,
        ssl_cert_reqs: str = "required",
        ssl_ca_certs: Optional[str] = None,
        ssl_ca_data: Optional[str] = None,
        ssl_check_hostname: bool = False,
        ssl_min_version: Optional[ssl.TLSVersion] = None,
        max_connections: Optional[int] = None,
        single_connection_client: bool = False,
        health_check_interval: int = 0,
        client_name: Optional[str] = None,
        lib_name: Optional[str] = "redis-py",
        # NOTE: default is evaluated once at import time.
        lib_version: Optional[str] = get_lib_version(),
        username: Optional[str] = None,
        retry: Optional[Retry] = None,
        auto_close_connection_pool: Optional[bool] = None,
        redis_connect_func=None,
        credential_provider: Optional[CredentialProvider] = None,
        protocol: Optional[int] = 2,
        cache_enabled: bool = False,
        client_cache: Optional[AbstractCache] = None,
        cache_max_size: int = 100,
        cache_ttl: int = 0,
        cache_policy: str = DEFAULT_EVICTION_POLICY,
        # NOTE(review): these defaults are shared module-level lists; assumed
        # to be treated as read-only downstream — confirm.
        cache_blacklist: List[str] = DEFAULT_BLACKLIST,
        cache_whitelist: List[str] = DEFAULT_WHITELIST,
    ):
        """
        Initialize a new Redis client.
        To specify a retry policy for specific errors, first set
        `retry_on_error` to a list of the error/s to retry on, then set
        `retry` to a valid `Retry` object.
        To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
        """
        kwargs: Dict[str, Any]
        # auto_close_connection_pool only has an effect if connection_pool is
        # None. It is assumed that if connection_pool is not None, the user
        # wants to manage the connection pool themselves.
        if auto_close_connection_pool is not None:
            warnings.warn(
                DeprecationWarning(
                    '"auto_close_connection_pool" is deprecated '
                    "since version 5.0.1. "
                    "Please create a ConnectionPool explicitly and "
                    "provide to the Redis() constructor instead."
                )
            )
        else:
            auto_close_connection_pool = True

        if not connection_pool:
            # Create internal connection pool, expected to be closed by Redis instance
            if not retry_on_error:
                retry_on_error = []
            if retry_on_timeout is True:
                retry_on_error.append(TimeoutError)
            kwargs = {
                "db": db,
                "username": username,
                "password": password,
                "credential_provider": credential_provider,
                "socket_timeout": socket_timeout,
                "encoding": encoding,
                "encoding_errors": encoding_errors,
                "decode_responses": decode_responses,
                "retry_on_timeout": retry_on_timeout,
                "retry_on_error": retry_on_error,
                # Deep-copy so this client's retry state is not shared with
                # other clients built from the same Retry object.
                "retry": copy.deepcopy(retry),
                "max_connections": max_connections,
                "health_check_interval": health_check_interval,
                "client_name": client_name,
                "lib_name": lib_name,
                "lib_version": lib_version,
                "redis_connect_func": redis_connect_func,
                "protocol": protocol,
                "cache_enabled": cache_enabled,
                "client_cache": client_cache,
                "cache_max_size": cache_max_size,
                "cache_ttl": cache_ttl,
                "cache_policy": cache_policy,
                "cache_blacklist": cache_blacklist,
                "cache_whitelist": cache_whitelist,
            }
            # based on input, setup appropriate connection args
            if unix_socket_path is not None:
                kwargs.update(
                    {
                        "path": unix_socket_path,
                        "connection_class": UnixDomainSocketConnection,
                    }
                )
            else:
                # TCP specific options
                kwargs.update(
                    {
                        "host": host,
                        "port": port,
                        "socket_connect_timeout": socket_connect_timeout,
                        "socket_keepalive": socket_keepalive,
                        "socket_keepalive_options": socket_keepalive_options,
                    }
                )

                # NOTE: the ``ssl`` parameter shadows the ``ssl`` module
                # inside this body; the annotation above still refers to the
                # module (evaluated at def time).
                if ssl:
                    kwargs.update(
                        {
                            "connection_class": SSLConnection,
                            "ssl_keyfile": ssl_keyfile,
                            "ssl_certfile": ssl_certfile,
                            "ssl_cert_reqs": ssl_cert_reqs,
                            "ssl_ca_certs": ssl_ca_certs,
                            "ssl_ca_data": ssl_ca_data,
                            "ssl_check_hostname": ssl_check_hostname,
                            "ssl_min_version": ssl_min_version,
                        }
                    )
            # This arg only used if no pool is passed in
            self.auto_close_connection_pool = auto_close_connection_pool
            connection_pool = ConnectionPool(**kwargs)
        else:
            # If a pool is passed in, do not close it
            self.auto_close_connection_pool = False

        self.connection_pool = connection_pool
        self.single_connection_client = single_connection_client
        # Bound only in single-connection mode (see initialize()).
        self.connection: Optional[Connection] = None

        self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)

        # Select RESP3 vs RESP2 reply post-processing based on the protocol
        # the pool was configured with (may arrive as str or int).
        if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]:
            self.response_callbacks.update(_RedisCallbacksRESP3)
        else:
            self.response_callbacks.update(_RedisCallbacksRESP2)

        # If using a single connection client, we need to lock creation-of and use-of
        # the client in order to avoid race conditions such as using asyncio.gather
        # on a set of redis commands
        self._single_conn_lock = asyncio.Lock()

    def __repr__(self):
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__}"
            f"({self.connection_pool!r})>"
        )

    def __await__(self):
        # Allows ``await Redis(...)`` as shorthand for awaiting initialize().
        return self.initialize().__await__()

    async def initialize(self: _RedisT) -> _RedisT:
        """Bind the dedicated connection (single-connection mode only)."""
        if self.single_connection_client:
            # Lock so concurrent initializers don't each grab a connection.
            async with self._single_conn_lock:
                if self.connection is None:
                    self.connection = await self.connection_pool.get_connection("_")
        return self

    def set_response_callback(self, command: str, callback: ResponseCallbackT):
        """Set a custom Response Callback"""
        self.response_callbacks[command] = callback

    def get_encoder(self):
        """Get the connection pool's encoder"""
        return self.connection_pool.get_encoder()

    def get_connection_kwargs(self):
        """Get the connection's key-word arguments"""
        return self.connection_pool.connection_kwargs

    def get_retry(self) -> Optional["Retry"]:
        """Return the retry policy configured for new connections, if any."""
        return self.get_connection_kwargs().get("retry")

    def set_retry(self, retry: "Retry") -> None:
        """Apply ``retry`` to future connections and to the existing pool."""
        self.get_connection_kwargs().update({"retry": retry})
        self.connection_pool.set_retry(retry)

    def load_external_module(self, funcname, func):
        """
        This function can be used to add externally defined redis modules,
        and their namespaces to the redis client.

        funcname - A string containing the name of the function to create
        func - The function, being added to this class.

        ex: Assume that one has a custom redis module named foomod that
        creates command named 'foo.dothing' and 'foo.anotherthing' in redis.
        To load function functions into this namespace:

        from redis import Redis
        from foomodule import F
        r = Redis()
        r.load_external_module("foo", F)
        r.foo().dothing('your', 'arguments')

        For a concrete example see the reimport of the redisjson module in
        tests/test_connection.py::test_loading_external_modules
        """
        setattr(self, funcname, func)

    def pipeline(
        self, transaction: bool = True, shard_hint: Optional[str] = None
    ) -> "Pipeline":
        """
        Return a new pipeline object that can queue multiple commands for
        later execution. ``transaction`` indicates whether all commands
        should be executed atomically. Apart from making a group of operations
        atomic, pipelines are useful for reducing the back-and-forth overhead
        between the client and server.
        """
        return Pipeline(
            self.connection_pool, self.response_callbacks, transaction, shard_hint
        )

    async def transaction(
        self,
        func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
        *watches: KeyT,
        shard_hint: Optional[str] = None,
        value_from_callable: bool = False,
        watch_delay: Optional[float] = None,
    ):
        """
        Convenience method for executing the callable `func` as a transaction
        while watching all keys specified in `watches`. The 'func' callable
        should expect a single argument which is a Pipeline object.
        """
        pipe: Pipeline
        async with self.pipeline(True, shard_hint) as pipe:
            while True:
                try:
                    if watches:
                        await pipe.watch(*watches)
                    # ``func`` may be sync or async; await only if needed.
                    func_value = func(pipe)
                    if inspect.isawaitable(func_value):
                        func_value = await func_value
                    exec_value = await pipe.execute()
                    return func_value if value_from_callable else exec_value
                except WatchError:
                    # A watched key changed: optionally back off, then retry.
                    if watch_delay is not None and watch_delay > 0:
                        await asyncio.sleep(watch_delay)
                    continue

    def lock(
        self,
        name: KeyT,
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[float] = None,
        lock_class: Optional[Type[Lock]] = None,
        thread_local: bool = True,
    ) -> Lock:
        """
        Return a new Lock object using key ``name`` that mimics
        the behavior of threading.Lock.

        If specified, ``timeout`` indicates a maximum life for the lock.
        By default, it will remain locked until release() is called.

        ``sleep`` indicates the amount of time to sleep per loop iteration
        when the lock is in blocking mode and another client is currently
        holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``lock_class`` forces the specified lock implementation. Note that as
        of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
        a Lua-based lock). So, it's unlikely you'll need this parameter, unless
        you have created your own custom lock class.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release the thread-2's lock.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage."""
        if lock_class is None:
            lock_class = Lock
        return lock_class(
            self,
            name,
            timeout=timeout,
            sleep=sleep,
            blocking=blocking,
            blocking_timeout=blocking_timeout,
            thread_local=thread_local,
        )

    def pubsub(self, **kwargs) -> "PubSub":
        """
        Return a Publish/Subscribe object. With this object, you can
        subscribe to channels and listen for messages that get published to
        them.
        """
        return PubSub(self.connection_pool, **kwargs)

    def monitor(self) -> "Monitor":
        """Return a Monitor object for the MONITOR command."""
        return Monitor(self.connection_pool)

    def client(self) -> "Redis":
        """Return a single-connection client sharing this client's pool."""
        return self.__class__(
            connection_pool=self.connection_pool, single_connection_client=True
        )

    async def __aenter__(self: _RedisT) -> _RedisT:
        return await self.initialize()

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.aclose()

    _DEL_MESSAGE = "Unclosed Redis client"

    # passing _warnings and _grl as argument default since they may be gone
    # by the time __del__ is called at shutdown
    def __del__(
        self,
        _warn: Any = warnings.warn,
        _grl: Any = asyncio.get_running_loop,
    ) -> None:
        # Finalizer: warn (and notify the loop's exception handler) if the
        # client still holds a dedicated connection, then close it directly.
        if hasattr(self, "connection") and (self.connection is not None):
            _warn(f"Unclosed client session {self!r}", ResourceWarning, source=self)
            try:
                context = {"client": self, "message": self._DEL_MESSAGE}
                _grl().call_exception_handler(context)
            except RuntimeError:
                # No running event loop at interpreter shutdown.
                pass
            self.connection._close()

    async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Closes Redis client connection

        :param close_connection_pool: decides whether to close the connection pool used
        by this Redis client, overriding Redis.auto_close_connection_pool. By default,
        let Redis.auto_close_connection_pool decide whether to close the connection
        pool.
        """
        conn = self.connection
        if conn:
            # Return the dedicated connection before (possibly) closing the pool.
            self.connection = None
            await self.connection_pool.release(conn)
        if close_connection_pool or (
            close_connection_pool is None and self.auto_close_connection_pool
        ):
            await self.connection_pool.disconnect()

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self, close_connection_pool: Optional[bool] = None) -> None:
        """
        Alias for aclose(), for backwards compatibility
        """
        await self.aclose(close_connection_pool)

    async def _send_command_parse_response(self, conn, command_name, *args, **options):
        """
        Send a command and parse the response
        """
        await conn.send_command(*args)
        return await self.parse_response(conn, command_name, **options)

    async def _disconnect_raise(self, conn: Connection, error: Exception):
        """
        Close the connection and raise an exception
        if retry_on_error is not set or the error
        is not one of the specified error types
        """
        await conn.disconnect()
        if (
            conn.retry_on_error is None
            or isinstance(error, tuple(conn.retry_on_error)) is False
        ):
            raise error

    # COMMAND EXECUTION AND PROTOCOL PARSING
    async def execute_command(self, *args, **options):
        """Execute a command and return a parsed response"""
        await self.initialize()
        command_name = args[0]
        keys = options.pop("keys", None)  # keys are used only for client side caching
        pool = self.connection_pool
        conn = self.connection or await pool.get_connection(command_name, **options)
        response_from_cache = await conn._get_from_local_cache(args)
        try:
            if response_from_cache is not None:
                # Client-side cache hit: skip the network round trip.
                return response_from_cache
            else:
                try:
                    # Serialize use of the dedicated connection.
                    if self.single_connection_client:
                        await self._single_conn_lock.acquire()
                    response = await conn.retry.call_with_retry(
                        lambda: self._send_command_parse_response(
                            conn, command_name, *args, **options
                        ),
                        lambda error: self._disconnect_raise(conn, error),
                    )
                    conn._add_to_local_cache(args, response, keys)
                    return response
                finally:
                    if self.single_connection_client:
                        self._single_conn_lock.release()
        finally:
            # Pooled connections are released; the dedicated one is kept.
            if not self.connection:
                await pool.release(conn)

    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        """Parses a response from the Redis server"""
        try:
            if NEVER_DECODE in options:
                response = await connection.read_response(disable_decoding=True)
                options.pop(NEVER_DECODE)
            else:
                response = await connection.read_response()
        except ResponseError:
            # Some callers provide a fallback value for error replies.
            if EMPTY_RESPONSE in options:
                return options[EMPTY_RESPONSE]
            raise

        # Strip the internal marker before handing options to the callback.
        if EMPTY_RESPONSE in options:
            options.pop(EMPTY_RESPONSE)

        if command_name in self.response_callbacks:
            # Mypy bug: https://github.com/python/mypy/issues/10977
            command_name = cast(str, command_name)
            retval = self.response_callbacks[command_name](response, **options)
            # Callbacks may be sync or async (see ResponseCallbackT).
            return await retval if inspect.isawaitable(retval) else retval
        return response

    def flush_cache(self):
        """Best-effort flush of the client-side cache (no-op if disabled)."""
        try:
            if self.connection:
                self.connection.client_cache.flush()
            else:
                self.connection_pool.flush_cache()
        except AttributeError:
            # Caching not configured; nothing to flush.
            pass

    def delete_command_from_cache(self, command):
        """Best-effort removal of one command's entry from the local cache."""
        try:
            if self.connection:
                self.connection.client_cache.delete_command(command)
            else:
                self.connection_pool.delete_command_from_cache(command)
        except AttributeError:
            pass

    def invalidate_key_from_cache(self, key):
        """Best-effort invalidation of cached entries touching ``key``."""
        try:
            if self.connection:
                self.connection.client_cache.invalidate_key(key)
            else:
                self.connection_pool.invalidate_key_from_cache(key)
        except AttributeError:
            pass

703 

704 

# Backwards-compatible alias retained from older redis-py releases.
StrictRedis = Redis

706 

707 

class MonitorCommandInfo(TypedDict):
    """Parsed fields of one MONITOR output line (see Monitor.next_command)."""

    time: float  # timestamp token at the start of the MONITOR line
    db: int  # database number from the bracketed client info
    client_address: str  # IP address, or "unix"/"lua" for those client types
    client_port: str  # port (empty string for lua clients)
    client_type: str  # "tcp", "unix" or "lua"
    command: str  # command text with quote escaping removed

715 

716 

class Monitor:
    """
    Monitor is useful for handling the MONITOR command to the redis server.
    next_command() method returns one command from monitor
    listen() method yields commands from monitor.
    """

    # MONITOR lines look like: "<time> [<db> <client info>] <command>".
    monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)")
    # Each command token is double-quoted; the lookbehind skips escaped quotes.
    command_re = re.compile(r'"(.*?)(?<!\\)"')

    def __init__(self, connection_pool: ConnectionPool):
        self.connection_pool = connection_pool
        # Dedicated connection, acquired lazily by connect().
        self.connection: Optional[Connection] = None

    async def connect(self):
        """Acquire a connection from the pool if one is not already held."""
        if self.connection is None:
            self.connection = await self.connection_pool.get_connection("MONITOR")

    async def __aenter__(self):
        await self.connect()
        await self.connection.send_command("MONITOR")
        # check that monitor returns 'OK', but don't return it to user
        response = await self.connection.read_response()
        if not bool_ok(response):
            raise RedisError(f"MONITOR failed: {response}")
        return self

    async def __aexit__(self, *args):
        # Disconnect before releasing: the connection has been switched into
        # MONITOR mode and must not be reused as-is.
        await self.connection.disconnect()
        await self.connection_pool.release(self.connection)

    async def next_command(self) -> MonitorCommandInfo:
        """Parse the response from a monitor command"""
        await self.connect()
        response = await self.connection.read_response()
        if isinstance(response, bytes):
            response = self.connection.encoder.decode(response, force=True)
        command_time, command_data = response.split(" ", 1)
        m = self.monitor_re.match(command_data)
        db_id, client_info, command = m.groups()
        command = " ".join(self.command_re.findall(command))
        # Redis escapes double quotes because each piece of the command
        # string is surrounded by double quotes. We don't have that
        # requirement so remove the escaping and leave the quote.
        command = command.replace('\\"', '"')

        if client_info == "lua":
            client_address = "lua"
            client_port = ""
            client_type = "lua"
        elif client_info.startswith("unix"):
            client_address = "unix"
            # Everything after the "unix:" prefix is the socket path.
            client_port = client_info[5:]
            client_type = "unix"
        else:
            # use rsplit as ipv6 addresses contain colons
            client_address, client_port = client_info.rsplit(":", 1)
            client_type = "tcp"
        return {
            "time": float(command_time),
            "db": int(db_id),
            "client_address": client_address,
            "client_port": client_port,
            "client_type": client_type,
            "command": command,
        }

    async def listen(self) -> AsyncIterator[MonitorCommandInfo]:
        """Listen for commands coming to the server."""
        while True:
            yield await self.next_command()

788 

789 

class PubSub:
    """
    PubSub provides publish, subscribe and listen support to Redis channels.

    After subscribing to one or more channels, the listen() method will block
    until a message arrives on one of the subscribed channels. That message
    will be returned and it's safe to start listening again.
    """

    # Reply types that carry an actual published payload.
    PUBLISH_MESSAGE_TYPES = ("message", "pmessage")
    # Reply types confirming completion of an unsubscribe request.
    UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe")
    # Payload used for PING-based connection health checks.
    HEALTH_CHECK_MESSAGE = "redis-py-health-check"

802 

    def __init__(
        self,
        connection_pool: ConnectionPool,
        shard_hint: Optional[str] = None,
        ignore_subscribe_messages: bool = False,
        encoder=None,
        push_handler_func: Optional[Callable] = None,
    ):
        self.connection_pool = connection_pool
        self.shard_hint = shard_hint
        self.ignore_subscribe_messages = ignore_subscribe_messages
        # Dedicated connection, acquired lazily by connect().
        self.connection = None
        # we need to know the encoding options for this connection in order
        # to lookup channel and pattern names for callback handlers.
        self.encoder = encoder
        self.push_handler_func = push_handler_func
        if self.encoder is None:
            self.encoder = self.connection_pool.get_encoder()
        # Pre-compute the two reply shapes a health-check PING can produce,
        # in either str or bytes depending on decode_responses.
        if self.encoder.decode_responses:
            self.health_check_response = [
                ["pong", self.HEALTH_CHECK_MESSAGE],
                self.HEALTH_CHECK_MESSAGE,
            ]
        else:
            self.health_check_response = [
                [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)],
                self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
            ]
        if self.push_handler_func is None:
            _set_info_logger()
        # Subscription bookkeeping: name -> handler (or None), plus the sets
        # of names whose unsubscribe confirmation is still pending.
        self.channels = {}
        self.pending_unsubscribe_channels = set()
        self.patterns = {}
        self.pending_unsubscribe_patterns = set()
        self._lock = asyncio.Lock()

838 

    async def __aenter__(self):
        # No eager setup; the connection is established lazily on first use.
        return self

841 

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Release the connection and clear subscription state.
        await self.aclose()

844 

    def __del__(self):
        # Best-effort finalizer: unhook the re-subscribe callback from the
        # connection (presumably to break the connection -> PubSub reference
        # so the object can be collected — confirm).
        if self.connection:
            self.connection.deregister_connect_callback(self.on_connect)

848 

    async def aclose(self):
        """Disconnect, return the connection to the pool and clear state."""
        # In case a connection property does not yet exist
        # (due to a crash earlier in the Redis() constructor), return
        # immediately as there is nothing to clean-up.
        if not hasattr(self, "connection"):
            return
        async with self._lock:
            if self.connection:
                await self.connection.disconnect()
                # Unhook on_connect before releasing so a future user of the
                # pooled connection doesn't trigger our re-subscribe logic.
                self.connection.deregister_connect_callback(self.on_connect)
                await self.connection_pool.release(self.connection)
                self.connection = None
            # Forget all subscription bookkeeping.
            self.channels = {}
            self.pending_unsubscribe_channels = set()
            self.patterns = {}
            self.pending_unsubscribe_patterns = set()

865 

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
    async def close(self) -> None:
        """Alias for aclose(), for backwards compatibility"""
        await self.aclose()

870 

    @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset")
    async def reset(self) -> None:
        """Alias for aclose(), for backwards compatibility"""
        await self.aclose()

875 

876 async def on_connect(self, connection: Connection): 

877 """Re-subscribe to any channels and patterns previously subscribed to""" 

878 # NOTE: for python3, we can't pass bytestrings as keyword arguments 

879 # so we need to decode channel/pattern names back to unicode strings 

880 # before passing them to [p]subscribe. 

881 self.pending_unsubscribe_channels.clear() 

882 self.pending_unsubscribe_patterns.clear() 

883 if self.channels: 

884 channels = {} 

885 for k, v in self.channels.items(): 

886 channels[self.encoder.decode(k, force=True)] = v 

887 await self.subscribe(**channels) 

888 if self.patterns: 

889 patterns = {} 

890 for k, v in self.patterns.items(): 

891 patterns[self.encoder.decode(k, force=True)] = v 

892 await self.psubscribe(**patterns) 

893 

894 @property 

895 def subscribed(self): 

896 """Indicates if there are subscriptions to any channels or patterns""" 

897 return bool(self.channels or self.patterns) 

898 

899 async def execute_command(self, *args: EncodableT): 

900 """Execute a publish/subscribe command""" 

901 

902 # NOTE: don't parse the response in this function -- it could pull a 

903 # legitimate message off the stack if the connection is already 

904 # subscribed to one or more channels 

905 

906 await self.connect() 

907 connection = self.connection 

908 kwargs = {"check_health": not self.subscribed} 

909 await self._execute(connection, connection.send_command, *args, **kwargs) 

910 

911 async def connect(self): 

912 """ 

913 Ensure that the PubSub is connected 

914 """ 

915 if self.connection is None: 

916 self.connection = await self.connection_pool.get_connection( 

917 "pubsub", self.shard_hint 

918 ) 

919 # register a callback that re-subscribes to any channels we 

920 # were listening to when we were disconnected 

921 self.connection.register_connect_callback(self.on_connect) 

922 else: 

923 await self.connection.connect() 

924 if self.push_handler_func is not None and not HIREDIS_AVAILABLE: 

925 self.connection._parser.set_pubsub_push_handler(self.push_handler_func) 

926 

927 async def _disconnect_raise_connect(self, conn, error): 

928 """ 

929 Close the connection and raise an exception 

930 if retry_on_error is not set or the error is not one 

931 of the specified error types. Otherwise, try to 

932 reconnect 

933 """ 

934 await conn.disconnect() 

935 if ( 

936 conn.retry_on_error is None 

937 or isinstance(error, tuple(conn.retry_on_error)) is False 

938 ): 

939 raise error 

940 await conn.connect() 

941 

942 async def _execute(self, conn, command, *args, **kwargs): 

943 """ 

944 Connect manually upon disconnection. If the Redis server is down, 

945 this will fail and raise a ConnectionError as desired. 

946 After reconnection, the ``on_connect`` callback should have been 

947 called by the # connection to resubscribe us to any channels and 

948 patterns we were previously listening to 

949 """ 

950 return await conn.retry.call_with_retry( 

951 lambda: command(*args, **kwargs), 

952 lambda error: self._disconnect_raise_connect(conn, error), 

953 ) 

954 

    async def parse_response(self, block: bool = True, timeout: float = 0):
        """Parse the response from a publish/subscribe command.

        With ``block=True`` the read waits indefinitely; otherwise it waits at
        most ``timeout`` seconds. Returns ``None`` when the reply was only a
        health-check PONG.
        """
        conn = self.connection
        if conn is None:
            raise RuntimeError(
                "pubsub connection not set: "
                "did you forget to call subscribe() or psubscribe()?"
            )

        # May send a PING on this same connection before we read.
        await self.check_health()

        if not conn.is_connected:
            await conn.connect()

        # block=True maps to "no read timeout at all".
        read_timeout = None if block else timeout
        response = await self._execute(
            conn,
            conn.read_response,
            timeout=read_timeout,
            # keep the connection alive on errors so subscription state
            # survives; callers decide how to recover
            disconnect_on_error=False,
            push_request=True,
        )

        if conn.health_check_interval and response in self.health_check_response:
            # ignore the health check message as user might not expect it
            return None
        return response

982 

983 async def check_health(self): 

984 conn = self.connection 

985 if conn is None: 

986 raise RuntimeError( 

987 "pubsub connection not set: " 

988 "did you forget to call subscribe() or psubscribe()?" 

989 ) 

990 

991 if ( 

992 conn.health_check_interval 

993 and asyncio.get_running_loop().time() > conn.next_health_check 

994 ): 

995 await conn.send_command( 

996 "PING", self.HEALTH_CHECK_MESSAGE, check_health=False 

997 ) 

998 

999 def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT: 

1000 """ 

1001 normalize channel/pattern names to be either bytes or strings 

1002 based on whether responses are automatically decoded. this saves us 

1003 from coercing the value for each message coming in. 

1004 """ 

1005 encode = self.encoder.encode 

1006 decode = self.encoder.decode 

1007 return {decode(encode(k)): v for k, v in data.items()} # type: ignore[return-value] # noqa: E501 

1008 

1009 async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): 

1010 """ 

1011 Subscribe to channel patterns. Patterns supplied as keyword arguments 

1012 expect a pattern name as the key and a callable as the value. A 

1013 pattern's callable will be invoked automatically when a message is 

1014 received on that pattern rather than producing a message via 

1015 ``listen()``. 

1016 """ 

1017 parsed_args = list_or_args((args[0],), args[1:]) if args else args 

1018 new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args) 

1019 # Mypy bug: https://github.com/python/mypy/issues/10970 

1020 new_patterns.update(kwargs) # type: ignore[arg-type] 

1021 ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys()) 

1022 # update the patterns dict AFTER we send the command. we don't want to 

1023 # subscribe twice to these patterns, once for the command and again 

1024 # for the reconnection. 

1025 new_patterns = self._normalize_keys(new_patterns) 

1026 self.patterns.update(new_patterns) 

1027 self.pending_unsubscribe_patterns.difference_update(new_patterns) 

1028 return ret_val 

1029 

1030 def punsubscribe(self, *args: ChannelT) -> Awaitable: 

1031 """ 

1032 Unsubscribe from the supplied patterns. If empty, unsubscribe from 

1033 all patterns. 

1034 """ 

1035 patterns: Iterable[ChannelT] 

1036 if args: 

1037 parsed_args = list_or_args((args[0],), args[1:]) 

1038 patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys() 

1039 else: 

1040 parsed_args = [] 

1041 patterns = self.patterns 

1042 self.pending_unsubscribe_patterns.update(patterns) 

1043 return self.execute_command("PUNSUBSCRIBE", *parsed_args) 

1044 

1045 async def subscribe(self, *args: ChannelT, **kwargs: Callable): 

1046 """ 

1047 Subscribe to channels. Channels supplied as keyword arguments expect 

1048 a channel name as the key and a callable as the value. A channel's 

1049 callable will be invoked automatically when a message is received on 

1050 that channel rather than producing a message via ``listen()`` or 

1051 ``get_message()``. 

1052 """ 

1053 parsed_args = list_or_args((args[0],), args[1:]) if args else () 

1054 new_channels = dict.fromkeys(parsed_args) 

1055 # Mypy bug: https://github.com/python/mypy/issues/10970 

1056 new_channels.update(kwargs) # type: ignore[arg-type] 

1057 ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys()) 

1058 # update the channels dict AFTER we send the command. we don't want to 

1059 # subscribe twice to these channels, once for the command and again 

1060 # for the reconnection. 

1061 new_channels = self._normalize_keys(new_channels) 

1062 self.channels.update(new_channels) 

1063 self.pending_unsubscribe_channels.difference_update(new_channels) 

1064 return ret_val 

1065 

1066 def unsubscribe(self, *args) -> Awaitable: 

1067 """ 

1068 Unsubscribe from the supplied channels. If empty, unsubscribe from 

1069 all channels 

1070 """ 

1071 if args: 

1072 parsed_args = list_or_args(args[0], args[1:]) 

1073 channels = self._normalize_keys(dict.fromkeys(parsed_args)) 

1074 else: 

1075 parsed_args = [] 

1076 channels = self.channels 

1077 self.pending_unsubscribe_channels.update(channels) 

1078 return self.execute_command("UNSUBSCRIBE", *parsed_args) 

1079 

1080 async def listen(self) -> AsyncIterator: 

1081 """Listen for messages on channels this client has been subscribed to""" 

1082 while self.subscribed: 

1083 response = await self.handle_message(await self.parse_response(block=True)) 

1084 if response is not None: 

1085 yield response 

1086 

1087 async def get_message( 

1088 self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0 

1089 ): 

1090 """ 

1091 Get the next message if one is available, otherwise None. 

1092 

1093 If timeout is specified, the system will wait for `timeout` seconds 

1094 before returning. Timeout should be specified as a floating point 

1095 number or None to wait indefinitely. 

1096 """ 

1097 response = await self.parse_response(block=(timeout is None), timeout=timeout) 

1098 if response: 

1099 return await self.handle_message(response, ignore_subscribe_messages) 

1100 return None 

1101 

1102 def ping(self, message=None) -> Awaitable: 

1103 """ 

1104 Ping the Redis server 

1105 """ 

1106 args = ["PING", message] if message is not None else ["PING"] 

1107 return self.execute_command(*args) 

1108 

    async def handle_message(self, response, ignore_subscribe_messages=False):
        """
        Parses a pub/sub message. If the channel or pattern was subscribed to
        with a message handler, the handler is invoked instead of a parsed
        message being returned.

        Returns the message dict, or None when the message was consumed by a
        handler, was an ignored subscribe/unsubscribe notice, or was None.
        """
        if response is None:
            return None
        # A bare bytes reply is a PING response; normalize it to pong form.
        if isinstance(response, bytes):
            response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
        message_type = str_if_bytes(response[0])
        if message_type == "pmessage":
            message = {
                "type": message_type,
                "pattern": response[1],
                "channel": response[2],
                "data": response[3],
            }
        elif message_type == "pong":
            message = {
                "type": message_type,
                "pattern": None,
                "channel": None,
                "data": response[1],
            }
        else:
            message = {
                "type": message_type,
                "pattern": None,
                "channel": response[1],
                "data": response[2],
            }

        # if this is an unsubscribe message, remove it from memory
        if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
            if message_type == "punsubscribe":
                pattern = response[1]
                # only drop state for unsubscribes we initiated ourselves
                if pattern in self.pending_unsubscribe_patterns:
                    self.pending_unsubscribe_patterns.remove(pattern)
                    self.patterns.pop(pattern, None)
            else:
                channel = response[1]
                if channel in self.pending_unsubscribe_channels:
                    self.pending_unsubscribe_channels.remove(channel)
                    self.channels.pop(channel, None)

        if message_type in self.PUBLISH_MESSAGE_TYPES:
            # if there's a message handler, invoke it
            if message_type == "pmessage":
                handler = self.patterns.get(message["pattern"], None)
            else:
                handler = self.channels.get(message["channel"], None)
            if handler:
                if inspect.iscoroutinefunction(handler):
                    await handler(message)
                else:
                    handler(message)
                # handled messages are consumed, not returned to the caller
                return None
        elif message_type != "pong":
            # this is a subscribe/unsubscribe message. ignore if we don't
            # want them
            if ignore_subscribe_messages or self.ignore_subscribe_messages:
                return None

        return message

1174 

    async def run(
        self,
        *,
        exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
        poll_timeout: float = 1.0,
    ) -> None:
        """Process pub/sub messages using registered callbacks.

        This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
        redis-py, but it is a coroutine. To launch it as a separate task, use
        ``asyncio.create_task``:

        >>> task = asyncio.create_task(pubsub.run())

        To shut it down, use asyncio cancellation:

        >>> task.cancel()
        >>> await task
        """
        # Fail fast: messages are consumed here instead of being returned, so
        # every subscription must have a handler.
        for channel, handler in self.channels.items():
            if handler is None:
                raise PubSubError(f"Channel: '{channel}' has no handler registered")
        for pattern, handler in self.patterns.items():
            if handler is None:
                raise PubSubError(f"Pattern: '{pattern}' has no handler registered")

        await self.connect()
        while True:
            try:
                await self.get_message(
                    ignore_subscribe_messages=True, timeout=poll_timeout
                )
            except asyncio.CancelledError:
                # cancellation must always propagate so the task can stop
                raise
            except BaseException as e:
                if exception_handler is None:
                    raise
                # the handler may be sync or async; await only when needed
                res = exception_handler(e, self)
                if inspect.isawaitable(res):
                    await res
            # Ensure that other tasks on the event loop get a chance to run
            # if we didn't have to block for I/O anywhere.
            await asyncio.sleep(0)

1218 

1219 

class PubsubWorkerExceptionHandler(Protocol):
    # Structural type of a synchronous exception handler accepted by
    # PubSub.run(exception_handler=...).
    def __call__(self, e: BaseException, pubsub: PubSub): ...

1222 

1223 

class AsyncPubsubWorkerExceptionHandler(Protocol):
    # Structural type of an asynchronous (awaitable) exception handler
    # accepted by PubSub.run(exception_handler=...).
    async def __call__(self, e: BaseException, pubsub: PubSub): ...

1226 

1227 

# Either flavor (sync or async) of exception handler accepted by PubSub.run().
PSWorkerThreadExcHandlerT = Union[
    PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler
]

1231 

1232 

# One buffered pipeline entry: (command argument tuple, execution options).
CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]]
# The pipeline's buffer of pending commands.
CommandStackT = List[CommandT]

1235 

1236 

class Pipeline(Redis):  # lgtm [py/init-calls-subclass]
    """
    Pipelines provide a way to transmit multiple commands to the Redis server
    in one transmission. This is convenient for batch processing, such as
    saving all the values in a list to Redis.

    All commands executed within a pipeline are wrapped with MULTI and EXEC
    calls. This guarantees all commands executed in the pipeline will be
    executed atomically.

    Any command raising an exception does *not* halt the execution of
    subsequent commands in the pipeline. Instead, the exception is caught
    and its instance is placed into the response list returned by execute().
    Code iterating over the response list should be able to deal with an
    instance of an exception as a potential value. In general, these will be
    ResponseError exceptions, such as those raised when issuing a command
    on a key of a different datatype.
    """

    # Commands whose replies clear the WATCH state (see parse_response).
    UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}

1257 

1258 def __init__( 

1259 self, 

1260 connection_pool: ConnectionPool, 

1261 response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT], 

1262 transaction: bool, 

1263 shard_hint: Optional[str], 

1264 ): 

1265 self.connection_pool = connection_pool 

1266 self.connection = None 

1267 self.response_callbacks = response_callbacks 

1268 self.is_transaction = transaction 

1269 self.shard_hint = shard_hint 

1270 self.watching = False 

1271 self.command_stack: CommandStackT = [] 

1272 self.scripts: Set["Script"] = set() 

1273 self.explicit_transaction = False 

1274 

    async def __aenter__(self: _RedisT) -> _RedisT:
        # No setup is required on entry; __aexit__ performs the cleanup.
        return self

1277 

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Always release the connection and clear buffered commands,
        # whether or not the block raised.
        await self.reset()

1280 

    def __await__(self):
        # Makes ``await pipeline`` valid; it simply evaluates to the pipeline.
        return self._async_self().__await__()

1283 

    # NOTE(review): presumably consumed by an inherited __del__/warn path when
    # the pipeline is garbage-collected unclosed — confirm against base class.
    _DEL_MESSAGE = "Unclosed Pipeline client"

1285 

    def __len__(self):
        """Number of commands currently buffered in the pipeline."""
        return len(self.command_stack)

1288 

    def __bool__(self):
        """Pipeline instances should always evaluate to True"""
        # Without this, an empty pipeline would be falsy via __len__.
        return True

1292 

    async def _async_self(self):
        # Coroutine form of "return self"; used by __await__.
        return self

1295 

    async def reset(self):
        """Clear buffered state and release the connection back to the pool."""
        self.command_stack = []
        self.scripts = set()
        # make sure to reset the connection state in the event that we were
        # watching something
        if self.watching and self.connection:
            try:
                # call this manually since our unwatch or
                # immediate_execute_command methods can call reset()
                await self.connection.send_command("UNWATCH")
                await self.connection.read_response()
            except ConnectionError:
                # disconnect will also remove any previous WATCHes
                if self.connection:
                    await self.connection.disconnect()
        # clean up the other instance attributes
        self.watching = False
        self.explicit_transaction = False
        # we can safely return the connection to the pool here since we're
        # sure we're no longer WATCHing anything
        if self.connection:
            await self.connection_pool.release(self.connection)
            self.connection = None

1319 

1320 async def aclose(self) -> None: 

1321 """Alias for reset(), a standard method name for cleanup""" 

1322 await self.reset() 

1323 

1324 def multi(self): 

1325 """ 

1326 Start a transactional block of the pipeline after WATCH commands 

1327 are issued. End the transactional block with `execute`. 

1328 """ 

1329 if self.explicit_transaction: 

1330 raise RedisError("Cannot issue nested calls to MULTI") 

1331 if self.command_stack: 

1332 raise RedisError( 

1333 "Commands without an initial WATCH have already been issued" 

1334 ) 

1335 self.explicit_transaction = True 

1336 

1337 def execute_command( 

1338 self, *args, **kwargs 

1339 ) -> Union["Pipeline", Awaitable["Pipeline"]]: 

1340 kwargs.pop("keys", None) # the keys are used only for client side caching 

1341 if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: 

1342 return self.immediate_execute_command(*args, **kwargs) 

1343 return self.pipeline_execute_command(*args, **kwargs) 

1344 

    async def _disconnect_reset_raise(self, conn, error):
        """
        Close the connection, reset watching state and
        raise an exception if we were watching,
        if retry_on_error is not set or the error is not one
        of the specified error types.
        """
        await conn.disconnect()
        # if we were already watching a variable, the watch is no longer
        # valid since this connection has died. raise a WatchError, which
        # indicates the user should retry this transaction.
        if self.watching:
            await self.aclose()
            raise WatchError(
                "A ConnectionError occurred on while watching one or more keys"
            )
        # if retry_on_error is not set or the error is not one
        # of the specified error types, raise it
        if (
            conn.retry_on_error is None
            or isinstance(error, tuple(conn.retry_on_error)) is False
        ):
            await self.aclose()
            # NOTE(review): bare `raise` re-raises the exception active in the
            # caller's except block (call_with_retry's failure path) — confirm.
            raise

1369 

    async def immediate_execute_command(self, *args, **options):
        """
        Execute a command immediately, but don't auto-retry on a
        ConnectionError if we're already WATCHing a variable. Used when
        issuing WATCH or subsequent commands retrieving their values but before
        MULTI is called.
        """
        command_name = args[0]
        conn = self.connection
        # if this is the first call, we need a connection
        if not conn:
            conn = await self.connection_pool.get_connection(
                command_name, self.shard_hint
            )
            # stash it so reset() can release it back to the pool later
            self.connection = conn

        return await conn.retry.call_with_retry(
            lambda: self._send_command_parse_response(
                conn, command_name, *args, **options
            ),
            lambda error: self._disconnect_reset_raise(conn, error),
        )

1392 

1393 def pipeline_execute_command(self, *args, **options): 

1394 """ 

1395 Stage a command to be executed when execute() is next called 

1396 

1397 Returns the current Pipeline object back so commands can be 

1398 chained together, such as: 

1399 

1400 pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') 

1401 

1402 At some other point, you can then run: pipe.execute(), 

1403 which will execute all commands queued in the pipe. 

1404 """ 

1405 self.command_stack.append((args, options)) 

1406 return self 

1407 

    async def _execute_transaction(  # noqa: C901
        self, connection: Connection, commands: CommandStackT, raise_on_error
    ):
        """Send the buffered commands wrapped in MULTI/EXEC and collect results."""
        pre: CommandT = (("MULTI",), {})
        post: CommandT = (("EXEC",), {})
        cmds = (pre, *commands, post)
        # Commands flagged with EMPTY_RESPONSE are not sent at all; their
        # stored error is spliced into the result list below.
        all_cmds = connection.pack_commands(
            args for args, options in cmds if EMPTY_RESPONSE not in options
        )
        await connection.send_packed_command(all_cmds)
        errors = []

        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            await self.parse_response(connection, "_")
        except ResponseError as err:
            errors.append((0, err))

        # and all the other commands
        for i, command in enumerate(commands):
            if EMPTY_RESPONSE in command[1]:
                errors.append((i, command[1][EMPTY_RESPONSE]))
            else:
                try:
                    await self.parse_response(connection, "_")
                except ResponseError as err:
                    self.annotate_exception(err, i + 1, command[0])
                    errors.append((i, err))

        # parse the EXEC.
        try:
            response = await self.parse_response(connection, "_")
        except ExecAbortError as err:
            # transaction aborted: surface the first queuing error if any
            if errors:
                raise errors[0][1] from err
            raise

        # EXEC clears any watched keys
        self.watching = False

        if response is None:
            # a WATCHed key changed, so the server skipped the transaction
            raise WatchError("Watched variable changed.") from None

        # put any parse errors into the response
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(commands):
            if self.connection:
                await self.connection.disconnect()
            raise ResponseError(
                "Wrong number of response items from pipeline execution"
            ) from None

        # find any errors in the response and raise if necessary
        if raise_on_error:
            self.raise_first_error(commands, response)

        # We have to run response callbacks manually
        data = []
        for r, cmd in zip(response, commands):
            if not isinstance(r, Exception):
                args, options = cmd
                command_name = args[0]
                if command_name in self.response_callbacks:
                    r = self.response_callbacks[command_name](r, **options)
                    # callbacks may be sync or async
                    if inspect.isawaitable(r):
                        r = await r
            data.append(r)
        return data

1481 

1482 async def _execute_pipeline( 

1483 self, connection: Connection, commands: CommandStackT, raise_on_error: bool 

1484 ): 

1485 # build up all commands into a single request to increase network perf 

1486 all_cmds = connection.pack_commands([args for args, _ in commands]) 

1487 await connection.send_packed_command(all_cmds) 

1488 

1489 response = [] 

1490 for args, options in commands: 

1491 try: 

1492 response.append( 

1493 await self.parse_response(connection, args[0], **options) 

1494 ) 

1495 except ResponseError as e: 

1496 response.append(e) 

1497 

1498 if raise_on_error: 

1499 self.raise_first_error(commands, response) 

1500 return response 

1501 

1502 def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]): 

1503 for i, r in enumerate(response): 

1504 if isinstance(r, ResponseError): 

1505 self.annotate_exception(r, i + 1, commands[i][0]) 

1506 raise r 

1507 

1508 def annotate_exception( 

1509 self, exception: Exception, number: int, command: Iterable[object] 

1510 ) -> None: 

1511 cmd = " ".join(map(safe_str, command)) 

1512 msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}" 

1513 exception.args = (msg,) + exception.args[1:] 

1514 

    async def parse_response(
        self, connection: Connection, command_name: Union[str, bytes], **options
    ):
        """Parse a reply via the base client and track WATCH state transitions."""
        result = await super().parse_response(connection, command_name, **options)
        if command_name in self.UNWATCH_COMMANDS:
            # DISCARD, EXEC and UNWATCH all clear any watched keys.
            self.watching = False
        elif command_name == "WATCH":
            self.watching = True
        return result

1524 

1525 async def load_scripts(self): 

1526 # make sure all scripts that are about to be run on this pipeline exist 

1527 scripts = list(self.scripts) 

1528 immediate = self.immediate_execute_command 

1529 shas = [s.sha for s in scripts] 

1530 # we can't use the normal script_* methods because they would just 

1531 # get buffered in the pipeline. 

1532 exists = await immediate("SCRIPT EXISTS", *shas) 

1533 if not all(exists): 

1534 for s, exist in zip(scripts, exists): 

1535 if not exist: 

1536 s.sha = await immediate("SCRIPT LOAD", s.script) 

1537 

    async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
        """
        Close the connection, raise an exception if we were watching,
        and raise an exception if retry_on_error is not set or the
        error is not one of the specified error types.
        """
        await conn.disconnect()
        # if we were watching a variable, the watch is no longer valid
        # since this connection has died. raise a WatchError, which
        # indicates the user should retry this transaction.
        if self.watching:
            raise WatchError(
                "A ConnectionError occurred on while watching one or more keys"
            )
        # if retry_on_error is not set or the error is not one
        # of the specified error types, raise it
        if (
            conn.retry_on_error is None
            or isinstance(error, tuple(conn.retry_on_error)) is False
        ):
            await self.reset()
            # NOTE(review): bare `raise` re-raises the exception active in the
            # caller's except block (call_with_retry's failure path) — confirm.
            raise

1560 

    async def execute(self, raise_on_error: bool = True):
        """Execute all the commands in the current pipeline"""
        stack = self.command_stack
        if not stack and not self.watching:
            # nothing queued and nothing watched: no round trip needed
            return []
        if self.scripts:
            await self.load_scripts()
        # Transactions (MULTI/EXEC) and plain batched execution use different
        # wire sequences.
        if self.is_transaction or self.explicit_transaction:
            execute = self._execute_transaction
        else:
            execute = self._execute_pipeline

        conn = self.connection
        if not conn:
            conn = await self.connection_pool.get_connection("MULTI", self.shard_hint)
            # assign to self.connection so reset() releases the connection
            # back to the pool after we're done
            self.connection = conn
        conn = cast(Connection, conn)

        try:
            return await conn.retry.call_with_retry(
                lambda: execute(conn, stack, raise_on_error),
                lambda error: self._disconnect_raise_reset(conn, error),
            )
        finally:
            # always leave the pipeline reusable, even on failure
            await self.reset()

1588 

1589 async def discard(self): 

1590 """Flushes all previously queued commands 

1591 See: https://redis.io/commands/DISCARD 

1592 """ 

1593 await self.execute_command("DISCARD") 

1594 

1595 async def watch(self, *names: KeyT): 

1596 """Watches the values at keys ``names``""" 

1597 if self.explicit_transaction: 

1598 raise RedisError("Cannot issue a WATCH after a MULTI") 

1599 return await self.execute_command("WATCH", *names) 

1600 

1601 async def unwatch(self): 

1602 """Unwatches all previously specified keys""" 

1603 return self.watching and await self.execute_command("UNWATCH") or True