Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/event.py: 47%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

237 statements  

1import asyncio 

2import threading 

3from abc import ABC, abstractmethod 

4from enum import Enum 

5from typing import Dict, List, Optional, Type, Union 

6 

7from redis.auth.token import TokenInterface 

8from redis.credentials import CredentialProvider, StreamingCredentialProvider 

9from redis.observability.recorder import ( 

10 init_connection_count, 

11 register_pools_connection_count, 

12) 

13from redis.utils import check_protocol_version, deprecated_function 

14 

15 

class EventListenerInterface(ABC):
    """
    Contract for a synchronous listener reacting to a dispatched event object.

    Concrete listeners override :meth:`listen`, which receives the event
    instance being dispatched.
    """

    @abstractmethod
    def listen(self, event: object):
        """Handle ``event``; must be implemented by subclasses."""

24 

25 

class AsyncEventListenerInterface(ABC):
    """
    Contract for an asyncio listener reacting to a dispatched event object.

    Concrete listeners override the coroutine :meth:`listen`, which receives
    the event instance being dispatched.
    """

    @abstractmethod
    async def listen(self, event: object):
        """Handle ``event`` asynchronously; must be implemented by subclasses."""

34 

35 

class EventDispatcherInterface(ABC):
    """
    Abstract dispatcher that routes an event object to the listeners
    registered for that event.
    """

    @abstractmethod
    def dispatch(self, event: object):
        """Deliver ``event`` to its listeners synchronously."""

    @abstractmethod
    async def dispatch_async(self, event: object):
        """Asynchronous counterpart of :meth:`dispatch`."""

    @abstractmethod
    def register_listeners(
        self,
        mappings: Dict[
            Type[object],
            List[Union[EventListenerInterface, AsyncEventListenerInterface]],
        ],
    ):
        """Register additional listeners, keyed by event type."""

60 

61 

class EventException(Exception):
    """
    Wraps an exception raised while handling an event, keeping a reference to
    the event that was being processed.

    :param exception: the original exception; stored as ``.exception`` and
        passed on as the sole exception argument (so ``str()`` mirrors it).
    :param event: the event object associated with the failure, stored as
        ``.event``.
    """

    def __init__(self, exception: Exception, event: object):
        super().__init__(exception)
        self.exception, self.event = exception, event

71 

72 

class EventDispatcher(EventDispatcherInterface):
    """
    Dispatcher that dispatches events to listeners associated with given event.

    It starts out with the built-in re-authentication listeners and accepts
    more via ``event_listeners`` or :meth:`register_listeners`. Listeners are
    looked up by the exact type of the event (``type(event)``), so an event
    subclass does not reach the listeners registered for its base class.
    """

    # TODO: Make dispatcher to accept external mappings.
    def __init__(
        self,
        event_listeners: Optional[
            Dict[Type[object], List[EventListenerInterface]]
        ] = None,
    ):
        """
        Dispatcher that dispatches events to listeners associated with given event.

        :param event_listeners: optional ``{event_type: [listener, ...]}``
            mapping merged into the built-in defaults through
            :meth:`register_listeners`.
        """
        self._event_listeners_mapping: Dict[
            Type[object], List[EventListenerInterface]
        ] = {
            AfterConnectionReleasedEvent: [
                ReAuthConnectionListener(),
            ],
            AfterPooledConnectionsInstantiationEvent: [
                RegisterReAuthForPooledConnections(),
            ],
            AfterSingleConnectionInstantiationEvent: [
                RegisterReAuthForSingleConnection()
            ],
            AfterPubSubConnectionInstantiationEvent: [RegisterReAuthForPubSub()],
            AfterAsyncClusterInstantiationEvent: [RegisterReAuthForAsyncClusterNodes()],
            AsyncAfterConnectionReleasedEvent: [
                AsyncReAuthConnectionListener(),
            ],
        }

        # Guards mutation of the mapping and serializes synchronous dispatch.
        self._lock = threading.Lock()
        # Created lazily on the first async dispatch.
        self._async_lock = None

        if event_listeners:
            self.register_listeners(event_listeners)

    def dispatch(self, event: object):
        """
        Call ``listen(event)`` on every listener registered for
        ``type(event)``, in registration order, while holding the lock.
        """
        with self._lock:
            listeners = self._event_listeners_mapping.get(type(event), [])

            for listener in listeners:
                listener.listen(event)

    async def dispatch_async(self, event: object):
        """
        Await ``listen(event)`` on every listener registered for
        ``type(event)``, in registration order, serialized by an
        ``asyncio.Lock``.
        """
        if self._async_lock is None:
            self._async_lock = asyncio.Lock()

        async with self._async_lock:
            listeners = self._event_listeners_mapping.get(type(event), [])

            for listener in listeners:
                await listener.listen(event)

    def register_listeners(
        self,
        mappings: Dict[
            Type[object],
            List[Union[EventListenerInterface, AsyncEventListenerInterface]],
        ],
    ):
        """
        Merge ``mappings`` into the listener registry.

        For an event type that already has listeners, the new ones are
        appended after the existing ones and any listener already present is
        kept only once, at its first position. For a new event type the given
        list is stored as-is (as a copy).
        """
        with self._lock:
            for event_type, new_listeners in mappings.items():
                existing = self._event_listeners_mapping.get(event_type)
                if existing is None:
                    # Store a copy so later mutation of the caller's list
                    # cannot silently change what this dispatcher invokes.
                    self._event_listeners_mapping[event_type] = list(new_listeners)
                else:
                    # dict.fromkeys de-duplicates like set() did, but keeps
                    # insertion order, so listener call order stays
                    # deterministic (set() made it arbitrary).
                    self._event_listeners_mapping[event_type] = list(
                        dict.fromkeys([*existing, *new_listeners])
                    )

144 

145 

class AfterConnectionReleasedEvent:
    """
    Event fired after a connection has been released.

    :param connection: the released connection, exposed read-only through
        :attr:`connection`.
    """

    def __init__(self, connection):
        self._connection = connection

    @property
    def connection(self):
        """The connection this event refers to."""
        return self._connection

157 

158 

class AsyncAfterConnectionReleasedEvent(AfterConnectionReleasedEvent):
    """
    Asyncio flavour of :class:`AfterConnectionReleasedEvent`.

    Kept as a distinct type because listeners are looked up by the exact
    event type, which lets async listeners be attached to it separately.
    """

161 

162 

class ClientType(Enum):
    """Flavour of client (synchronous or asyncio) that emitted an event."""

    # NOTE(review): because of the trailing commas these values are
    # one-element tuples, not plain strings. Kept unchanged since ``.value``
    # is observable to callers.
    SYNC = ("sync",)
    ASYNC = ("async",)

166 

167 

class AfterPooledConnectionsInstantiationEvent:
    """
    Event fired once pooled connection instances have been created.

    :param connection_pools: the connection pools that were created.
    :param client_type: :class:`ClientType` of the client that created them.
    :param credential_provider: the client's credential provider, if any.
    """

    def __init__(
        self,
        connection_pools: List,
        client_type: ClientType,
        credential_provider: Optional[CredentialProvider] = None,
    ):
        self._credential_provider = credential_provider
        self._client_type = client_type
        self._connection_pools = connection_pools

    @property
    def connection_pools(self):
        """Pools passed to the constructor."""
        return self._connection_pools

    @property
    def client_type(self) -> ClientType:
        """Type of the client that created the pools."""
        return self._client_type

    @property
    def credential_provider(self) -> Union[CredentialProvider, None]:
        """Credential provider passed to the constructor, or ``None``."""
        return self._credential_provider

194 

195 

class AfterSingleConnectionInstantiationEvent:
    """
    Event fired once a single (non-pooled) connection instance has been
    created.

    :param connection: the connection that was created.
    :param client_type: :class:`ClientType` of the client that created it.
    :param connection_lock: For sync client thread-lock should be provided,
        for async asyncio.Lock
    """

    def __init__(
        self,
        connection,
        client_type: ClientType,
        connection_lock: Union[threading.RLock, asyncio.Lock],
    ):
        self._connection_lock = connection_lock
        self._client_type = client_type
        self._connection = connection

    @property
    def connection(self):
        """Connection passed to the constructor."""
        return self._connection

    @property
    def client_type(self) -> ClientType:
        """Type of the client that created the connection."""
        return self._client_type

    @property
    def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
        """Lock passed to the constructor."""
        return self._connection_lock

225 

226 

class AfterPubSubConnectionInstantiationEvent:
    """
    Event fired once a pub/sub connection has been created.

    :param pubsub_connection: the pub/sub connection.
    :param connection_pool: the pool associated with that connection.
    :param client_type: :class:`ClientType` of the client that created it.
    :param connection_lock: lock guarding the connection (a thread lock for
        sync clients, an ``asyncio.Lock`` for async ones, per the annotation).
    """

    def __init__(
        self,
        pubsub_connection,
        connection_pool,
        client_type: ClientType,
        connection_lock: Union[threading.RLock, asyncio.Lock],
    ):
        self._connection_lock = connection_lock
        self._client_type = client_type
        self._connection_pool = connection_pool
        self._pubsub_connection = pubsub_connection

    @property
    def pubsub_connection(self):
        """Pub/sub connection passed to the constructor."""
        return self._pubsub_connection

    @property
    def connection_pool(self):
        """Connection pool passed to the constructor."""
        return self._connection_pool

    @property
    def client_type(self) -> ClientType:
        """Type of the client that created the connection."""
        return self._client_type

    @property
    def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
        """Lock passed to the constructor."""
        return self._connection_lock

255 

256 

class AfterAsyncClusterInstantiationEvent:
    """
    Event that will be fired after async cluster instance was created.

    Async cluster doesn't use connection pools,
    instead ClusterNode object manages connections.

    :param nodes: mapping of the cluster's nodes.
    :param credential_provider: the cluster's credential provider, if any.
    """

    def __init__(
        self,
        nodes: dict,
        credential_provider: Optional[CredentialProvider] = None,
    ):
        self._credential_provider = credential_provider
        self._nodes = nodes

    @property
    def nodes(self) -> dict:
        """Node mapping passed to the constructor."""
        return self._nodes

    @property
    def credential_provider(self) -> Union[CredentialProvider, None]:
        """Credential provider passed to the constructor, or ``None``."""
        return self._credential_provider

280 

281 

class OnCommandsFailEvent:
    """
    Event fired whenever a command fails during the execution.

    :param commands: the command(s) that were being executed.
    :param exception: the exception that caused the failure.
    """

    def __init__(
        self,
        commands: tuple,
        exception: Exception,
    ):
        self._exception = exception
        self._commands = commands

    @property
    def commands(self) -> tuple:
        """Commands passed to the constructor."""
        return self._commands

    @property
    def exception(self) -> Exception:
        """Exception passed to the constructor."""
        return self._exception

302 

303 

class AsyncOnCommandsFailEvent(OnCommandsFailEvent):
    """
    Asyncio flavour of :class:`OnCommandsFailEvent`.

    A distinct type so listeners (looked up by exact event type) can be
    attached to it separately.
    """

306 

307 

class ReAuthConnectionListener(EventListenerInterface):
    """
    Synchronous listener that re-authenticates the connection carried by the
    event by calling its ``re_auth()`` method.
    """

    def listen(self, event: AfterConnectionReleasedEvent):
        connection = event.connection
        connection.re_auth()

315 

316 

class AsyncReAuthConnectionListener(AsyncEventListenerInterface):
    """
    Asyncio listener that re-authenticates the connection carried by the
    event by awaiting its ``re_auth()`` method.
    """

    async def listen(self, event: AsyncAfterConnectionReleasedEvent):
        connection = event.connection
        await connection.re_auth()

324 

325 

class RegisterReAuthForPooledConnections(EventListenerInterface):
    """
    Listener that registers a re-authentication callback for pooled connections.
    Required by :class:`StreamingCredentialProvider`.

    The default EventDispatcher creates a single instance of this listener
    and reuses it for every event it dispatches. The callbacks handed to the
    credential provider therefore capture the event they were registered
    for, instead of reading ``self._event`` (which each new event
    overwrites). Otherwise one client's provider could re-authenticate
    another client's pools.
    """

    def __init__(self):
        # Most recently handled event; used by the helper methods when they
        # are called without an explicit event.
        self._event = None

    def listen(self, event: AfterPooledConnectionsInstantiationEvent):
        """
        Subscribe the event's pools to token updates from its provider.

        Does nothing unless the provider is a StreamingCredentialProvider.
        SYNC clients get plain callbacks; any other client type gets
        coroutine callbacks.
        """
        provider = event.credential_provider
        if not isinstance(provider, StreamingCredentialProvider):
            return

        self._event = event

        if event.client_type == ClientType.SYNC:

            def on_next(token):
                self._re_auth(token, event)

            def on_error(error: Exception):
                self._raise_on_error(error, event)

            provider.on_next(on_next)
            provider.on_error(on_error)
        else:

            async def on_next_async(token):
                await self._re_auth_async(token, event)

            async def on_error_async(error: Exception):
                await self._raise_on_error_async(error, event)

            provider.on_next(on_next_async)
            provider.on_error(on_error_async)

    def _re_auth(self, token, event=None):
        """Pass ``token`` to ``re_auth_callback`` of every pool of ``event``
        (defaults to the last handled event)."""
        target = self._event if event is None else event
        for pool in target.connection_pools:
            pool.re_auth_callback(token)

    async def _re_auth_async(self, token, event=None):
        """Async variant of :meth:`_re_auth`; awaits each pool callback."""
        target = self._event if event is None else event
        for pool in target.connection_pools:
            await pool.re_auth_callback(token)

    def _raise_on_error(self, error: Exception, event=None):
        """Re-raise a provider error wrapped in :class:`EventException`."""
        target = self._event if event is None else event
        # Chain the cause so the provider's traceback is preserved.
        raise EventException(error, target) from error

    async def _raise_on_error_async(self, error: Exception, event=None):
        """Async variant of :meth:`_raise_on_error`."""
        target = self._event if event is None else event
        raise EventException(error, target) from error

359 

360 

class RegisterReAuthForSingleConnection(EventListenerInterface):
    """
    Listener that registers a re-authentication callback for single connection.
    Required by :class:`StreamingCredentialProvider`.

    The default EventDispatcher reuses one instance of this listener for
    every event it dispatches. The callbacks registered with the provider
    therefore capture their own event rather than reading ``self._event``;
    otherwise a token from one client's provider could be sent as ``AUTH`` on
    another client's connection.
    """

    def __init__(self):
        # Most recently handled event; used by the helper methods when they
        # are called without an explicit event.
        self._event = None

    def listen(self, event: AfterSingleConnectionInstantiationEvent):
        """
        Subscribe the event's connection to token updates from its provider.

        Does nothing unless the connection's provider is a
        StreamingCredentialProvider. SYNC clients get plain callbacks; any
        other client type gets coroutine callbacks.
        """
        provider = event.connection.credential_provider
        if not isinstance(provider, StreamingCredentialProvider):
            return

        self._event = event

        if event.client_type == ClientType.SYNC:

            def on_next(token):
                self._re_auth(token, event)

            def on_error(error: Exception):
                self._raise_on_error(error, event)

            provider.on_next(on_next)
            provider.on_error(on_error)
        else:

            async def on_next_async(token):
                await self._re_auth_async(token, event)

            async def on_error_async(error: Exception):
                await self._raise_on_error_async(error, event)

            provider.on_next(on_next_async)
            provider.on_error(on_error_async)

    def _re_auth(self, token, event=None):
        """Send ``AUTH <oid> <token>`` on the event's connection under its
        lock and read the reply (defaults to the last handled event)."""
        target = self._event if event is None else event
        with target.connection_lock:
            target.connection.send_command(
                "AUTH", token.try_get("oid"), token.get_value()
            )
            target.connection.read_response()

    async def _re_auth_async(self, token, event=None):
        """Async variant of :meth:`_re_auth`."""
        target = self._event if event is None else event
        async with target.connection_lock:
            await target.connection.send_command(
                "AUTH", token.try_get("oid"), token.get_value()
            )
            await target.connection.read_response()

    def _raise_on_error(self, error: Exception, event=None):
        """Re-raise a provider error wrapped in :class:`EventException`."""
        target = self._event if event is None else event
        # Chain the cause so the provider's traceback is preserved.
        raise EventException(error, target) from error

    async def _raise_on_error_async(self, error: Exception, event=None):
        """Async variant of :meth:`_raise_on_error`."""
        target = self._event if event is None else event
        raise EventException(error, target) from error

404 

405 

class RegisterReAuthForAsyncClusterNodes(EventListenerInterface):
    """
    Listener that registers (coroutine) re-authentication callbacks for the
    nodes of an async cluster whose provider is a
    :class:`StreamingCredentialProvider`.

    The callbacks capture the event they were registered for, so a later
    event handled by the same (shared) listener instance cannot redirect
    them to another cluster's nodes.
    """

    def __init__(self):
        # Most recently handled event; used by the helper methods when they
        # are called without an explicit event.
        self._event = None

    def listen(self, event: AfterAsyncClusterInstantiationEvent):
        provider = event.credential_provider
        if not isinstance(provider, StreamingCredentialProvider):
            return

        self._event = event

        async def on_next(token: TokenInterface):
            await self._re_auth(token, event)

        async def on_error(error: Exception):
            await self._raise_on_error(error, event)

        provider.on_next(on_next)
        provider.on_error(on_error)

    async def _re_auth(self, token: TokenInterface, event=None):
        """Await ``re_auth_callback(token)`` on every node of the event
        (defaults to the last handled event)."""
        target = self._event if event is None else event
        # Snapshot the nodes first: each await suspends this coroutine, and
        # if the dict were mutated meanwhile, iterating it live would raise
        # "dictionary changed size during iteration".
        for node in list(target.nodes.values()):
            await node.re_auth_callback(token)

    async def _raise_on_error(self, error: Exception, event=None):
        """Re-raise a provider error wrapped in :class:`EventException`."""
        target = self._event if event is None else event
        # Chain the cause so the provider's traceback is preserved.
        raise EventException(error, target) from error

422 

423 

class RegisterReAuthForPubSub(EventListenerInterface):
    """
    Listener that registers re-authentication callbacks for a pub/sub
    connection, when its provider is a :class:`StreamingCredentialProvider`
    and the connection speaks protocol 3.

    Each pub/sub instance dispatches its own event, but the default
    EventDispatcher reuses a single instance of this listener. The callbacks
    therefore capture the event they were registered for, instead of reading
    ``self._connection`` and friends, which every later event overwrites.
    """

    def __init__(self):
        # State of the most recently handled event; used by the helper
        # methods when they are called without an explicit event.
        self._connection = None
        self._connection_pool = None
        self._client_type = None
        self._connection_lock = None
        self._event = None

    def listen(self, event: AfterPubSubConnectionInstantiationEvent):
        connection = event.pubsub_connection
        if not (
            isinstance(connection.credential_provider, StreamingCredentialProvider)
            and check_protocol_version(connection.get_protocol(), 3)
        ):
            return

        self._event = event
        self._connection = connection
        self._connection_pool = event.connection_pool
        self._client_type = event.client_type
        self._connection_lock = event.connection_lock

        provider = connection.credential_provider
        if event.client_type == ClientType.SYNC:

            def on_next(token: TokenInterface):
                self._re_auth(token, event)

            def on_error(error: Exception):
                self._raise_on_error(error, event)

            provider.on_next(on_next)
            provider.on_error(on_error)
        else:

            async def on_next_async(token: TokenInterface):
                await self._re_auth_async(token, event)

            async def on_error_async(error: Exception):
                await self._raise_on_error_async(error, event)

            provider.on_next(on_next_async)
            provider.on_error(on_error_async)

    def _targets(self, event):
        """Return ``(connection, pool, lock)`` for ``event``, or the stored
        state of the last handled event when ``event`` is None."""
        if event is None:
            return self._connection, self._connection_pool, self._connection_lock
        return event.pubsub_connection, event.connection_pool, event.connection_lock

    def _re_auth(self, token: TokenInterface, event=None):
        """AUTH the pub/sub connection under its lock, then forward the token
        to the pool's ``re_auth_callback``."""
        connection, pool, lock = self._targets(event)
        with lock:
            connection.send_command(
                "AUTH", token.try_get("oid"), token.get_value()
            )
            connection.read_response()

        pool.re_auth_callback(token)

    async def _re_auth_async(self, token: TokenInterface, event=None):
        """Async variant of :meth:`_re_auth`."""
        connection, pool, lock = self._targets(event)
        async with lock:
            await connection.send_command(
                "AUTH", token.try_get("oid"), token.get_value()
            )
            await connection.read_response()

        await pool.re_auth_callback(token)

    def _raise_on_error(self, error: Exception, event=None):
        """Re-raise a provider error wrapped in :class:`EventException`."""
        target = self._event if event is None else event
        # Chain the cause so the provider's traceback is preserved.
        raise EventException(error, target) from error

    async def _raise_on_error_async(self, error: Exception, event=None):
        """Async variant of :meth:`_raise_on_error`."""
        target = self._event if event is None else event
        raise EventException(error, target) from error

474 

475 

class InitializeConnectionCountObservability(EventListenerInterface):
    """
    Deprecated listener that sets up connection-count observability for the
    pools carried by an :class:`AfterPooledConnectionsInstantiationEvent`.
    """

    @deprecated_function(
        reason="Connection count is now tracked via record_connection_count(). "
        "This functionality will be removed in the next major version",
        version="7.4.0",
    )
    def listen(self, event: AfterPooledConnectionsInstantiationEvent):
        # Set up the connection-count gauge; repeated calls are expected to
        # have no further effect (behaviour lives in
        # redis.observability.recorder). Both recorder helpers used here are
        # deprecated and emit their own warnings.
        init_connection_count()

        # Then hand this event's pools to the gauge.
        pools = event.connection_pools
        register_pools_connection_count(pools)