Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/event.py: 47%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

251 statements  

1import asyncio 

2import threading 

3from abc import ABC, abstractmethod 

4from enum import Enum 

5from typing import Dict, List, Optional, Type, Union 

6 

7from redis.auth.token import TokenInterface 

8from redis.credentials import CredentialProvider, StreamingCredentialProvider 

9from redis.observability.recorder import ( 

10 init_connection_count, 

11 register_pools_connection_count, 

12) 

13from redis.utils import check_protocol_version, deprecated_function 

14 

15 

class EventListenerInterface(ABC):
    """
    Contract for a synchronous listener that reacts to a dispatched event.
    """

    @abstractmethod
    def listen(self, event: object):
        """Handle ``event``; concrete listeners must override this."""

24 

25 

class AsyncEventListenerInterface(ABC):
    """
    Contract for a coroutine-based listener that reacts to a dispatched event.
    """

    @abstractmethod
    async def listen(self, event: object):
        """Handle ``event`` asynchronously; concrete listeners must override this."""

34 

35 

class EventDispatcherInterface(ABC):
    """
    Contract for a dispatcher that routes each event to the listeners
    registered for that event's type.
    """

    @abstractmethod
    def dispatch(self, event: object):
        """Deliver ``event`` to its listeners synchronously."""

    @abstractmethod
    async def dispatch_async(self, event: object):
        """Asynchronous counterpart of :meth:`dispatch`."""

    @abstractmethod
    def register_listeners(
        self,
        mappings: Dict[
            Type[object],
            List[Union[EventListenerInterface, AsyncEventListenerInterface]],
        ],
    ):
        """Register additional listeners."""

    @abstractmethod
    def unregister_listeners(
        self,
        mappings: Dict[
            Type[object],
            List[Union[EventListenerInterface, AsyncEventListenerInterface]],
        ],
    ):
        """Remove previously registered listeners by identity."""

71 

72 

class EventException(Exception):
    """
    Wraps an error raised while handling an event and keeps a reference to
    the event that was being processed.

    :param exception: The original error. It is also the single positional
        argument passed to ``Exception``, so ``str()`` of the wrapper matches it.
    :param event: The event object in whose context the error occurred.
    """

    def __init__(self, exception: Exception, event: object):
        super().__init__(exception)
        self.event = event
        self.exception = exception

82 

83 

class EventDispatcher(EventDispatcherInterface):
    """
    Dispatcher that dispatches events to listeners associated with given event.

    A default set of re-authentication listeners is always installed. Extra
    listeners can be supplied through ``event_listeners`` at construction time
    or added later with :meth:`register_listeners`.
    """

    def __init__(
        self,
        event_listeners: Optional[
            Dict[
                Type[object],
                List[Union[EventListenerInterface, AsyncEventListenerInterface]],
            ]
        ] = None,
    ):
        """
        :param event_listeners: Optional mapping of event type to listeners,
            merged on top of the default listeners via :meth:`register_listeners`.
        """
        self._event_listeners_mapping: Dict[
            Type[object],
            List[Union[EventListenerInterface, AsyncEventListenerInterface]],
        ] = {
            AfterConnectionReleasedEvent: [
                ReAuthConnectionListener(),
            ],
            AfterPooledConnectionsInstantiationEvent: [
                RegisterReAuthForPooledConnections(),
            ],
            AfterSingleConnectionInstantiationEvent: [
                RegisterReAuthForSingleConnection()
            ],
            AfterPubSubConnectionInstantiationEvent: [RegisterReAuthForPubSub()],
            AfterAsyncClusterInstantiationEvent: [RegisterReAuthForAsyncClusterNodes()],
            AsyncAfterConnectionReleasedEvent: [
                AsyncReAuthConnectionListener(),
            ],
        }

        # Reentrant so a finalizer/listener that runs on the same thread
        # while the lock is held (e.g. a weakref.finalize callback fired
        # from cyclic GC during an allocation inside register_listeners /
        # unregister_listeners) can re-enter without deadlocking.
        # This single lock guards every read and write of the mapping, for
        # both dispatch() and dispatch_async().
        self._lock = threading.RLock()

        if event_listeners:
            self.register_listeners(event_listeners)

    def dispatch(self, event: object):
        """Invoke every listener registered for ``type(event)``, in order."""
        # Snapshot listeners under the lock, then release it before invoking
        # them. Holding the lock across listener execution would turn any
        # listener that calls register_listeners / unregister_listeners /
        # dispatch back into the dispatcher into a deadlock.
        with self._lock:
            listeners = list(self._event_listeners_mapping.get(type(event), []))
        for listener in listeners:
            listener.listen(event)

    async def dispatch_async(self, event: object):
        """Await every listener registered for ``type(event)``, in order."""
        # The snapshot is a plain synchronous copy with no await inside, so
        # holding the threading lock never spans a suspension point. Taking
        # the same lock as register_listeners / unregister_listeners is what
        # actually excludes writers on other threads; a separate asyncio.Lock
        # would not. The lock is released before any listener is awaited so
        # listeners may re-enter the dispatcher.
        with self._lock:
            listeners = list(self._event_listeners_mapping.get(type(event), []))
        for listener in listeners:
            await listener.listen(event)

    def register_listeners(
        self,
        mappings: Dict[
            Type[object],
            List[Union[EventListenerInterface, AsyncEventListenerInterface]],
        ],
    ):
        """
        Register additional listeners.

        Listeners are de-duplicated by identity, the same criterion
        :meth:`unregister_listeners` uses, so listener ``__eq__``/``__hash__``
        implementations are never consulted. Registration order is preserved.
        The caller's lists are copied, never stored, so later mutation of them
        does not leak into the dispatcher.
        """
        with self._lock:
            for event_type, new_listeners in mappings.items():
                merged = list(self._event_listeners_mapping.get(event_type, ()))
                # id() is stable here: every object is kept alive by `merged`
                # or `new_listeners` for the duration of this loop.
                seen = {id(listener) for listener in merged}
                for listener in new_listeners:
                    if id(listener) not in seen:
                        seen.add(id(listener))
                        merged.append(listener)
                self._event_listeners_mapping[event_type] = merged

    def unregister_listeners(
        self,
        mappings: Dict[
            Type[object],
            List[Union[EventListenerInterface, AsyncEventListenerInterface]],
        ],
    ):
        """Remove previously registered listeners by identity."""
        with self._lock:
            for event_type, to_remove in mappings.items():
                current = self._event_listeners_mapping.get(event_type)
                if not current:
                    continue
                # Remove by identity to match register semantics and to avoid
                # reliance on listener __eq__ implementations. The id set is
                # built once, so `to_remove` may be any iterable.
                remove_ids = {id(target) for target in to_remove}
                self._event_listeners_mapping[event_type] = [
                    listener for listener in current if id(listener) not in remove_ids
                ]

184 

185 

class AfterConnectionReleasedEvent:
    """
    Event carrying a connection that has been released.

    The default :class:`EventDispatcher` routes this event to
    :class:`ReAuthConnectionListener`, which calls ``connection.re_auth()``.
    """

    def __init__(self, connection):
        self._connection = connection

    @property
    def connection(self):
        """The connection instance this event was created with."""
        return self._connection

197 

198 

class AsyncAfterConnectionReleasedEvent(AfterConnectionReleasedEvent):
    """
    Async-client variant of :class:`AfterConnectionReleasedEvent`; the default
    dispatcher routes it to :class:`AsyncReAuthConnectionListener`.
    """

201 

202 

class AfterSlotsCacheRefreshEvent:
    """
    Signal that NodesManager's slots cache has just been refreshed.

    A refresh happens either through a full re-initialization or through a
    MOVED-driven slot re-mapping. The event carries no payload. Listeners
    typically use it to reconcile per-node bookkeeping, such as ClusterPubSub
    shard subscriptions.
    """

212 

213 

class AsyncAfterSlotsCacheRefreshEvent(AfterSlotsCacheRefreshEvent):
    """Async-client variant of :class:`AfterSlotsCacheRefreshEvent`."""

216 

217 

class ClientType(Enum):
    """
    Whether a sync or an async client emitted an instantiation event.

    NOTE(review): each value is a one-element tuple because of the trailing
    comma, e.g. ``ClientType.SYNC.value == ("sync",)``. Compare members
    directly (``client_type == ClientType.SYNC``) rather than relying on
    ``.value``.
    """

    SYNC = ("sync",)
    ASYNC = ("async",)

221 

222 

class AfterPooledConnectionsInstantiationEvent:
    """
    Event fired after pooled connection instances were created.

    :param connection_pools: The pools whose connections were created.
    :param client_type: Whether a sync or an async client emitted the event.
    :param credential_provider: Credential provider associated with the pools,
        or ``None``.
    """

    def __init__(
        self,
        connection_pools: List,
        client_type: ClientType,
        credential_provider: Optional[CredentialProvider] = None,
    ):
        self._connection_pools = connection_pools
        self._client_type = client_type
        self._credential_provider = credential_provider

    @property
    def connection_pools(self):
        """Pools passed at construction, stored by reference."""
        return self._connection_pools

    @property
    def client_type(self) -> ClientType:
        """Client flavour that emitted this event."""
        return self._client_type

    @property
    def credential_provider(self) -> Union[CredentialProvider, None]:
        """Credential provider passed at construction, or ``None``."""
        return self._credential_provider

249 

250 

class AfterSingleConnectionInstantiationEvent:
    """
    Event fired after a single connection instance was created.

    :param connection: The newly created connection.
    :param client_type: Whether a sync or an async client emitted the event.
    :param connection_lock: Lock guarding the connection. Provide a thread lock
        for the sync client and an ``asyncio.Lock`` for the async client.
    """

    def __init__(
        self,
        connection,
        client_type: ClientType,
        connection_lock: Union[threading.RLock, asyncio.Lock],
    ):
        self._connection = connection
        self._client_type = client_type
        self._connection_lock = connection_lock

    @property
    def connection(self):
        """The connection this event was created with."""
        return self._connection

    @property
    def client_type(self) -> ClientType:
        """Client flavour that emitted this event."""
        return self._client_type

    @property
    def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
        """Lock that must be held while talking to :attr:`connection`."""
        return self._connection_lock

280 

281 

class AfterPubSubConnectionInstantiationEvent:
    """
    Event fired after a pub/sub connection instance was created.

    :param pubsub_connection: The pub/sub connection.
    :param connection_pool: The pool associated with that connection.
    :param client_type: Whether a sync or an async client emitted the event.
    :param connection_lock: Thread lock (sync) or ``asyncio.Lock`` (async)
        guarding the pub/sub connection.
    """

    def __init__(
        self,
        pubsub_connection,
        connection_pool,
        client_type: ClientType,
        connection_lock: Union[threading.RLock, asyncio.Lock],
    ):
        self._pubsub_connection = pubsub_connection
        self._connection_pool = connection_pool
        self._client_type = client_type
        self._connection_lock = connection_lock

    @property
    def pubsub_connection(self):
        """The pub/sub connection this event was created with."""
        return self._pubsub_connection

    @property
    def connection_pool(self):
        """Pool associated with :attr:`pubsub_connection`."""
        return self._connection_pool

    @property
    def client_type(self) -> ClientType:
        """Client flavour that emitted this event."""
        return self._client_type

    @property
    def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
        """Lock that must be held while talking to :attr:`pubsub_connection`."""
        return self._connection_lock

310 

311 

class AfterAsyncClusterInstantiationEvent:
    """
    Event fired after an async cluster instance was created.

    The async cluster does not use connection pools; each ClusterNode object
    manages its own connections, so the event carries the nodes instead.

    :param nodes: Mapping of cluster nodes, stored by reference (not copied).
    :param credential_provider: Credential provider for the cluster, or ``None``.
    """

    def __init__(
        self,
        nodes: dict,
        credential_provider: Optional[CredentialProvider] = None,
    ):
        self._nodes = nodes
        self._credential_provider = credential_provider

    @property
    def nodes(self) -> dict:
        """The nodes mapping passed at construction."""
        return self._nodes

    @property
    def credential_provider(self) -> Union[CredentialProvider, None]:
        """Credential provider passed at construction, or ``None``."""
        return self._credential_provider

335 

336 

class OnCommandsFailEvent:
    """
    Event fired whenever a command fails during execution.

    :param commands: The command(s) that were being executed.
    :param exception: The error that caused the failure.
    """

    def __init__(
        self,
        commands: tuple,
        exception: Exception,
    ):
        self._commands = commands
        self._exception = exception

    @property
    def commands(self) -> tuple:
        """The failed command(s), as passed at construction."""
        return self._commands

    @property
    def exception(self) -> Exception:
        """The error that caused the failure."""
        return self._exception

357 

358 

class AsyncOnCommandsFailEvent(OnCommandsFailEvent):
    """Async-client variant of :class:`OnCommandsFailEvent`."""

361 

362 

class ReAuthConnectionListener(EventListenerInterface):
    """
    Re-authenticates the connection carried by an
    :class:`AfterConnectionReleasedEvent`.
    """

    def listen(self, event: AfterConnectionReleasedEvent):
        released = event.connection
        released.re_auth()

370 

371 

class AsyncReAuthConnectionListener(AsyncEventListenerInterface):
    """
    Async counterpart of :class:`ReAuthConnectionListener`: awaits
    ``re_auth()`` on the connection carried by the event.
    """

    async def listen(self, event: AsyncAfterConnectionReleasedEvent):
        released = event.connection
        await released.re_auth()

379 

380 

class RegisterReAuthForPooledConnections(EventListenerInterface):
    """
    Listener that registers a re-authentication callback for pooled connections.
    Required by :class:`StreamingCredentialProvider`.

    Events whose credential provider is not streaming are ignored. The most
    recent matching event is remembered on the instance, and the registered
    callbacks read it when they fire.
    """

    def __init__(self):
        self._event = None

    def listen(self, event: AfterPooledConnectionsInstantiationEvent):
        provider = event.credential_provider
        if not isinstance(provider, StreamingCredentialProvider):
            return

        self._event = event
        if event.client_type == ClientType.SYNC:
            next_callback, error_callback = self._re_auth, self._raise_on_error
        else:
            next_callback = self._re_auth_async
            error_callback = self._raise_on_error_async
        provider.on_next(next_callback)
        provider.on_error(error_callback)

    def _re_auth(self, token):
        """Pass a fresh token to each pool's ``re_auth_callback``."""
        for pool in self._event.connection_pools:
            pool.re_auth_callback(token)

    async def _re_auth_async(self, token):
        """Async counterpart of :meth:`_re_auth`; awaits each pool in turn."""
        for pool in self._event.connection_pools:
            await pool.re_auth_callback(token)

    def _raise_on_error(self, error: Exception):
        """Re-raise a provider error wrapped together with the stored event."""
        raise EventException(error, self._event)

    async def _raise_on_error_async(self, error: Exception):
        """Async counterpart of :meth:`_raise_on_error`."""
        raise EventException(error, self._event)

414 

415 

class RegisterReAuthForSingleConnection(EventListenerInterface):
    """
    Listener that registers a re-authentication callback for single connection.
    Required by :class:`StreamingCredentialProvider`.

    On each token update, the callbacks send ``AUTH <oid> <token>`` over the
    connection and read the reply while holding the event's connection lock.
    """

    def __init__(self):
        self._event = None

    def listen(self, event: AfterSingleConnectionInstantiationEvent):
        provider = event.connection.credential_provider
        if not isinstance(provider, StreamingCredentialProvider):
            return

        self._event = event
        if event.client_type == ClientType.SYNC:
            callbacks = (self._re_auth, self._raise_on_error)
        else:
            callbacks = (self._re_auth_async, self._raise_on_error_async)
        on_next_callback, on_error_callback = callbacks
        provider.on_next(on_next_callback)
        provider.on_error(on_error_callback)

    def _re_auth(self, token):
        """Send AUTH with the new token under the (thread) connection lock."""
        event = self._event
        with event.connection_lock:
            conn = event.connection
            conn.send_command("AUTH", token.try_get("oid"), token.get_value())
            conn.read_response()

    async def _re_auth_async(self, token):
        """Send AUTH with the new token under the (asyncio) connection lock."""
        event = self._event
        async with event.connection_lock:
            conn = event.connection
            await conn.send_command("AUTH", token.try_get("oid"), token.get_value())
            await conn.read_response()

    def _raise_on_error(self, error: Exception):
        """Re-raise a provider error wrapped together with the stored event."""
        raise EventException(error, self._event)

    async def _raise_on_error_async(self, error: Exception):
        """Async counterpart of :meth:`_raise_on_error`."""
        raise EventException(error, self._event)

459 

460 

class RegisterReAuthForAsyncClusterNodes(EventListenerInterface):
    """
    Listener that registers a re-authentication callback for every node of an
    async cluster. Required by :class:`StreamingCredentialProvider`.

    Events whose credential provider is not streaming are ignored.
    """

    def __init__(self):
        self._event = None

    def listen(self, event: AfterAsyncClusterInstantiationEvent):
        if isinstance(event.credential_provider, StreamingCredentialProvider):
            self._event = event
            event.credential_provider.on_next(self._re_auth)
            event.credential_provider.on_error(self._raise_on_error)

    async def _re_auth(self, token: TokenInterface):
        """Await ``re_auth_callback(token)`` on each cluster node in turn."""
        # The event stores the nodes dict by reference, and this loop awaits
        # between iterations. If another coroutine adds or removes a node
        # meanwhile, iterating the live dict raises "dictionary changed size
        # during iteration", so iterate over a snapshot of the nodes instead.
        for node in list(self._event.nodes.values()):
            await node.re_auth_callback(token)

    async def _raise_on_error(self, error: Exception):
        """Re-raise a provider error wrapped together with the stored event."""
        raise EventException(error, self._event)

477 

478 

class RegisterReAuthForPubSub(EventListenerInterface):
    """
    Registers re-authentication of a pub/sub connection, and of its pool,
    with a :class:`StreamingCredentialProvider`.

    Applies only when the connection's provider is streaming and
    ``check_protocol_version(<protocol>, 3)`` accepts the connection's
    protocol (presumably RESP3; verify against redis.utils).
    """

    def __init__(self):
        self._connection = None
        self._connection_pool = None
        self._client_type = None
        self._connection_lock = None
        self._event = None

    def listen(self, event: AfterPubSubConnectionInstantiationEvent):
        pubsub_conn = event.pubsub_connection
        eligible = isinstance(
            pubsub_conn.credential_provider, StreamingCredentialProvider
        ) and check_protocol_version(pubsub_conn.get_protocol(), 3)
        if not eligible:
            return

        self._event = event
        self._connection = pubsub_conn
        self._connection_pool = event.connection_pool
        self._client_type = event.client_type
        self._connection_lock = event.connection_lock

        provider = self._connection.credential_provider
        if self._client_type == ClientType.SYNC:
            provider.on_next(self._re_auth)
            provider.on_error(self._raise_on_error)
        else:
            provider.on_next(self._re_auth_async)
            provider.on_error(self._raise_on_error_async)

    def _re_auth(self, token: TokenInterface):
        """AUTH the pub/sub connection under its lock, then notify the pool."""
        with self._connection_lock:
            self._connection.send_command(
                "AUTH", token.try_get("oid"), token.get_value()
            )
            self._connection.read_response()

        # The pool callback runs after the connection lock is released.
        self._connection_pool.re_auth_callback(token)

    async def _re_auth_async(self, token: TokenInterface):
        """Async counterpart of :meth:`_re_auth`."""
        async with self._connection_lock:
            await self._connection.send_command(
                "AUTH", token.try_get("oid"), token.get_value()
            )
            await self._connection.read_response()

        await self._connection_pool.re_auth_callback(token)

    def _raise_on_error(self, error: Exception):
        """Re-raise a provider error wrapped together with the stored event."""
        raise EventException(error, self._event)

    async def _raise_on_error_async(self, error: Exception):
        """Async counterpart of :meth:`_raise_on_error`."""
        raise EventException(error, self._event)

529 

530 

class InitializeConnectionCountObservability(EventListenerInterface):
    """
    Deprecated listener that sets up connection-count observability for the
    pools carried by an :class:`AfterPooledConnectionsInstantiationEvent`.
    """

    @deprecated_function(
        reason="Connection count is now tracked via record_connection_count(). "
        "This functionality will be removed in the next major version",
        version="7.4.0",
    )
    def listen(self, event: AfterPooledConnectionsInstantiationEvent):
        # The gauge is initialized only once; later calls have no effect.
        # Both helpers are themselves deprecated and emit their own warnings.
        init_connection_count()
        register_pools_connection_count(event.connection_pools)