Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/event.py: 47%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

236 statements  

1import asyncio 

2import threading 

3from abc import ABC, abstractmethod 

4from enum import Enum 

5from typing import Dict, List, Optional, Type, Union 

6 

7from redis.auth.token import TokenInterface 

8from redis.credentials import CredentialProvider, StreamingCredentialProvider 

9from redis.observability.recorder import ( 

10 init_connection_count, 

11 register_pools_connection_count, 

12) 

13from redis.utils import check_protocol_version 

14 

15 

class EventListenerInterface(ABC):
    """
    Abstract contract for a synchronous listener that reacts to an event object.
    """

    @abstractmethod
    def listen(self, event: object):
        """Handle ``event``; concrete listeners must override this."""

24 

25 

class AsyncEventListenerInterface(ABC):
    """
    Abstract contract for a listener whose ``listen`` method is a coroutine.
    """

    @abstractmethod
    async def listen(self, event: object):
        """Handle ``event`` asynchronously; concrete listeners must override this."""

34 

35 

class EventDispatcherInterface(ABC):
    """
    Abstract contract for an object that routes each event to the listeners
    registered for that event.
    """

    @abstractmethod
    def dispatch(self, event: object):
        """Deliver ``event`` synchronously to its listeners."""

    @abstractmethod
    async def dispatch_async(self, event: object):
        """Deliver ``event`` to its listeners from within a coroutine."""

    @abstractmethod
    def register_listeners(
        self,
        mappings: Dict[
            Type[object],
            List[Union[EventListenerInterface, AsyncEventListenerInterface]],
        ],
    ):
        """Register additional listeners."""

60 

61 

class EventException(Exception):
    """
    Wraps an underlying exception together with the event that was being
    handled when it occurred.

    :param exception: the original error; it is also passed to ``Exception``
        and so becomes ``args[0]``.
    :param event: the event object associated with the failure.
    """

    def __init__(self, exception: Exception, event: object):
        super().__init__(exception)
        self.exception, self.event = exception, event

71 

72 

class EventDispatcher(EventDispatcherInterface):
    # TODO: Make dispatcher to accept external mappings.
    def __init__(
        self,
        event_listeners: Optional[
            Dict[Type[object], List[EventListenerInterface]]
        ] = None,
    ):
        """
        Dispatcher that dispatches events to listeners associated with given event.

        Listeners are looked up by the *exact* type of the event
        (``type(event)``), so instances of a subclass of a registered event
        type are not routed to that type's listeners.

        :param event_listeners: optional extra ``{event_type: [listeners]}``
            mappings, merged on top of the built-in defaults via
            :meth:`register_listeners`.
        """
        self._event_listeners_mapping: Dict[
            Type[object], List[EventListenerInterface]
        ] = {
            AfterConnectionReleasedEvent: [
                ReAuthConnectionListener(),
            ],
            AfterPooledConnectionsInstantiationEvent: [
                RegisterReAuthForPooledConnections(),
                InitializeConnectionCountObservability(),
            ],
            AfterSingleConnectionInstantiationEvent: [
                RegisterReAuthForSingleConnection()
            ],
            AfterPubSubConnectionInstantiationEvent: [RegisterReAuthForPubSub()],
            AfterAsyncClusterInstantiationEvent: [RegisterReAuthForAsyncClusterNodes()],
            AsyncAfterConnectionReleasedEvent: [
                AsyncReAuthConnectionListener(),
            ],
        }

        # Guards the mapping for sync dispatch and registration. This is a
        # plain (non-reentrant) Lock that is held while sync listeners run, so
        # a listener must not call dispatch() or register_listeners() on the
        # same dispatcher from the same thread, or it will deadlock.
        self._lock = threading.Lock()
        # Created lazily on the first async dispatch (presumably so that
        # constructing the dispatcher does not need a running event loop).
        self._async_lock = None

        if event_listeners:
            self.register_listeners(event_listeners)

    def dispatch(self, event: object):
        """
        Call ``listen(event)`` on every listener registered for
        ``type(event)``, in registration order, while holding ``self._lock``.
        Unknown event types are ignored.
        """
        with self._lock:
            listeners = self._event_listeners_mapping.get(type(event), [])

            for listener in listeners:
                listener.listen(event)

    async def dispatch_async(self, event: object):
        """
        Await ``listen(event)`` on every listener registered for
        ``type(event)``, in registration order, while holding the (lazily
        created) asyncio lock. Unknown event types are ignored.
        """
        if self._async_lock is None:
            self._async_lock = asyncio.Lock()

        async with self._async_lock:
            listeners = self._event_listeners_mapping.get(type(event), [])

            for listener in listeners:
                await listener.listen(event)

    def register_listeners(
        self,
        mappings: Dict[
            Type[object],
            List[Union[EventListenerInterface, AsyncEventListenerInterface]],
        ],
    ):
        """
        Merge additional listeners into the dispatcher.

        For each event type, the given listeners are appended after the ones
        already registered. Duplicates are dropped while keeping
        first-registration order, so dispatch order is deterministic. A fresh
        list is always stored, so later mutation of the caller's lists does
        not leak into the dispatcher.
        """
        with self._lock:
            for event_type, new_listeners in mappings.items():
                existing = self._event_listeners_mapping.get(event_type, [])
                # dict.fromkeys de-duplicates while preserving insertion
                # order; set() would scramble the order listeners run in.
                self._event_listeners_mapping[event_type] = list(
                    dict.fromkeys([*existing, *new_listeners])
                )

145 

146 

class AfterConnectionReleasedEvent:
    """
    Event carrying a connection, emitted after that connection is released
    (presumably back to its pool; confirm at the dispatch site).

    The default ``EventDispatcher`` routes it to ``ReAuthConnectionListener``,
    which calls ``connection.re_auth()``.

    :param connection: the connection that was released.
    """

    def __init__(self, connection):
        self._connection = connection

    @property
    def connection(self):
        """The released connection."""
        return self._connection

158 

159 

class AsyncAfterConnectionReleasedEvent(AfterConnectionReleasedEvent):
    """
    Async variant of ``AfterConnectionReleasedEvent``.

    It is a distinct type because ``EventDispatcher`` matches listeners on the
    exact ``type(event)``. That lets async listeners be registered separately
    from the sync ones.
    """

162 

163 

class ClientType(Enum):
    """
    Kind of client (sync or async) that emitted an instantiation event.

    Callers compare against the members, e.g. ``client_type == ClientType.SYNC``.
    The values used to be 1-tuples (``("sync",)``) because of stray trailing
    commas. Lookups using that legacy form are still accepted.
    """

    SYNC = "sync"
    ASYNC = "async"

    @classmethod
    def _missing_(cls, value):
        # Backward compatibility: map the old tuple-valued form, such as
        # ClientType(("sync",)), onto the matching member.
        if isinstance(value, tuple) and len(value) == 1:
            for member in cls:
                if member.value == value[0]:
                    return member
        return None

167 

168 

class AfterPooledConnectionsInstantiationEvent:
    """
    Event fired after the pooled connection instances have been created.

    :param connection_pools: the connection pools that were created.
    :param client_type: whether the owning client is sync or async.
    :param credential_provider: the client's credential provider, if any.
    """

    def __init__(
        self,
        connection_pools: List,
        client_type: ClientType,
        credential_provider: Optional[CredentialProvider] = None,
    ):
        self._connection_pools, self._client_type, self._credential_provider = (
            connection_pools,
            client_type,
            credential_provider,
        )

    @property
    def connection_pools(self):
        """The pools passed at construction."""
        return self._connection_pools

    @property
    def client_type(self) -> ClientType:
        """Sync or async flavour of the emitting client."""
        return self._client_type

    @property
    def credential_provider(self) -> Union[CredentialProvider, None]:
        """The credential provider, or ``None`` when none was given."""
        return self._credential_provider

195 

196 

class AfterSingleConnectionInstantiationEvent:
    """
    Event fired after a single connection instance has been created.

    :param connection: the newly created connection.
    :param client_type: whether the owning client is sync or async.
    :param connection_lock: lock guarding the connection. For a sync client a
        thread lock should be provided; for an async client, an
        ``asyncio.Lock``.
    """

    def __init__(
        self,
        connection,
        client_type: ClientType,
        connection_lock: Union[threading.RLock, asyncio.Lock],
    ):
        self._connection, self._client_type, self._connection_lock = (
            connection,
            client_type,
            connection_lock,
        )

    @property
    def connection(self):
        """The connection that was created."""
        return self._connection

    @property
    def client_type(self) -> ClientType:
        """Sync or async flavour of the emitting client."""
        return self._client_type

    @property
    def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
        """The lock supplied for guarding the connection."""
        return self._connection_lock

226 

227 

class AfterPubSubConnectionInstantiationEvent:
    """
    Event fired after a pub/sub connection has been created.

    :param pubsub_connection: the pub/sub connection.
    :param connection_pool: the connection pool associated with it.
    :param client_type: whether the owning client is sync or async.
    :param connection_lock: lock guarding the pub/sub connection (a thread lock
        for sync clients, an ``asyncio.Lock`` for async ones).
    """

    def __init__(
        self,
        pubsub_connection,
        connection_pool,
        client_type: ClientType,
        connection_lock: Union[threading.RLock, asyncio.Lock],
    ):
        (
            self._pubsub_connection,
            self._connection_pool,
            self._client_type,
            self._connection_lock,
        ) = (pubsub_connection, connection_pool, client_type, connection_lock)

    @property
    def pubsub_connection(self):
        """The pub/sub connection that was created."""
        return self._pubsub_connection

    @property
    def connection_pool(self):
        """The associated connection pool."""
        return self._connection_pool

    @property
    def client_type(self) -> ClientType:
        """Sync or async flavour of the emitting client."""
        return self._client_type

    @property
    def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
        """The lock supplied for guarding the pub/sub connection."""
        return self._connection_lock

256 

257 

class AfterAsyncClusterInstantiationEvent:
    """
    Event fired after an async cluster instance has been created.

    The async cluster does not use connection pools; ClusterNode objects manage
    the connections instead, so this event carries the nodes.

    :param nodes: mapping of the cluster's nodes.
    :param credential_provider: the cluster's credential provider, if any.
    """

    def __init__(
        self,
        nodes: dict,
        credential_provider: Optional[CredentialProvider] = None,
    ):
        self._nodes, self._credential_provider = nodes, credential_provider

    @property
    def nodes(self) -> dict:
        """The node mapping passed at construction."""
        return self._nodes

    @property
    def credential_provider(self) -> Union[CredentialProvider, None]:
        """The credential provider, or ``None`` when none was given."""
        return self._credential_provider

281 

282 

class OnCommandsFailEvent:
    """
    Event describing commands whose execution failed.

    :param commands: the failed command(s), as a tuple.
    :param exception: the error raised during execution.
    """

    def __init__(
        self,
        commands: tuple,
        exception: Exception,
    ):
        self._commands, self._exception = commands, exception

    @property
    def commands(self) -> tuple:
        """The commands that failed."""
        return self._commands

    @property
    def exception(self) -> Exception:
        """The exception raised by the failure."""
        return self._exception

303 

304 

class AsyncOnCommandsFailEvent(OnCommandsFailEvent):
    """
    Async variant of ``OnCommandsFailEvent``. It is a separate type so that
    a dispatcher keyed on the exact ``type(event)`` can route it independently.
    """

307 

308 

class ReAuthConnectionListener(EventListenerInterface):
    """
    Re-authenticates the connection carried by the event it receives.
    """

    def listen(self, event: AfterConnectionReleasedEvent):
        conn = event.connection
        conn.re_auth()

316 

317 

class AsyncReAuthConnectionListener(AsyncEventListenerInterface):
    """
    Re-authenticates, asynchronously, the connection carried by the event.
    """

    async def listen(self, event: AsyncAfterConnectionReleasedEvent):
        conn = event.connection
        await conn.re_auth()

325 

326 

class RegisterReAuthForPooledConnections(EventListenerInterface):
    """
    Listener that registers a re-authentication callback for pooled connections.
    Required by :class:`StreamingCredentialProvider`.

    One instance of this listener lives in an ``EventDispatcher`` and receives
    events from every client that shares that dispatcher. The callbacks handed
    to each credential provider are therefore bound to the event they were
    registered for. Reading ``self._event`` at callback time would let a token
    from one provider re-authenticate the pools of whichever client fired the
    most recent event.
    """

    def __init__(self):
        # Most recently handled event. Only used by the legacy
        # _re_auth* / _raise_on_error* methods kept for compatibility.
        self._event = None

    def listen(self, event: AfterPooledConnectionsInstantiationEvent):
        """
        Subscribe to token updates of ``event.credential_provider`` when it is
        a :class:`StreamingCredentialProvider`; otherwise do nothing.

        Each new token is forwarded to every pool in ``event.connection_pools``.
        Provider errors are re-raised wrapped in :class:`EventException`.
        """
        provider = event.credential_provider
        if not isinstance(provider, StreamingCredentialProvider):
            return

        self._event = event

        if event.client_type == ClientType.SYNC:

            def on_next(token):
                self._re_auth_pools(event, token)

            def on_error(error: Exception):
                raise EventException(error, event)

        else:

            async def on_next(token):
                await self._re_auth_pools_async(event, token)

            async def on_error(error: Exception):
                raise EventException(error, event)

        provider.on_next(on_next)
        provider.on_error(on_error)

    @staticmethod
    def _re_auth_pools(event, token):
        """Pass ``token`` to ``re_auth_callback`` of each pool of ``event``."""
        for pool in event.connection_pools:
            pool.re_auth_callback(token)

    @staticmethod
    async def _re_auth_pools_async(event, token):
        """Await ``re_auth_callback(token)`` on each pool of ``event``."""
        for pool in event.connection_pools:
            await pool.re_auth_callback(token)

    def _re_auth(self, token):
        self._re_auth_pools(self._event, token)

    async def _re_auth_async(self, token):
        await self._re_auth_pools_async(self._event, token)

    def _raise_on_error(self, error: Exception):
        raise EventException(error, self._event)

    async def _raise_on_error_async(self, error: Exception):
        raise EventException(error, self._event)

360 

361 

class RegisterReAuthForSingleConnection(EventListenerInterface):
    """
    Listener that registers a re-authentication callback for single connection.
    Required by :class:`StreamingCredentialProvider`.

    The dispatcher keeps a single instance of this listener, so the callbacks
    given to each credential provider are bound to the event (connection and
    lock) they were registered for. They do not read the most recently stored
    ``self._event``, which a later event from another client sharing the
    dispatcher would overwrite.
    """

    def __init__(self):
        # Most recently handled event. Only used by the legacy
        # _re_auth* / _raise_on_error* methods kept for compatibility.
        self._event = None

    def listen(self, event: AfterSingleConnectionInstantiationEvent):
        """
        Subscribe to token updates of the connection's credential provider
        when it is a :class:`StreamingCredentialProvider`; otherwise do nothing.

        Each new token is sent as ``AUTH <oid> <token>`` on the connection.
        Provider errors are re-raised wrapped in :class:`EventException`.
        """
        provider = event.connection.credential_provider
        if not isinstance(provider, StreamingCredentialProvider):
            return

        self._event = event

        if event.client_type == ClientType.SYNC:

            def on_next(token):
                self._send_auth(event, token)

            def on_error(error: Exception):
                raise EventException(error, event)

        else:

            async def on_next(token):
                await self._send_auth_async(event, token)

            async def on_error(error: Exception):
                raise EventException(error, event)

        provider.on_next(on_next)
        provider.on_error(on_error)

    @staticmethod
    def _send_auth(event, token):
        # Send AUTH and read its reply while holding the event's connection
        # lock, so the request/response pair is not split by other traffic
        # on the same connection (presumably the client holds this same lock
        # around commands; confirm).
        with event.connection_lock:
            event.connection.send_command(
                "AUTH", token.try_get("oid"), token.get_value()
            )
            event.connection.read_response()

    @staticmethod
    async def _send_auth_async(event, token):
        # Async counterpart of _send_auth, using the event's asyncio lock.
        async with event.connection_lock:
            await event.connection.send_command(
                "AUTH", token.try_get("oid"), token.get_value()
            )
            await event.connection.read_response()

    def _re_auth(self, token):
        self._send_auth(self._event, token)

    async def _re_auth_async(self, token):
        await self._send_auth_async(self._event, token)

    def _raise_on_error(self, error: Exception):
        raise EventException(error, self._event)

    async def _raise_on_error_async(self, error: Exception):
        raise EventException(error, self._event)

405 

406 

class RegisterReAuthForAsyncClusterNodes(EventListenerInterface):
    """
    Listener that registers re-authentication callbacks for the nodes of an
    async cluster client. Required by :class:`StreamingCredentialProvider`.

    The callbacks are bound to the event they were registered for rather than
    to ``self._event``. The dispatcher holds a single instance of this
    listener, so a later event would otherwise redirect an earlier provider's
    tokens to another cluster's nodes.
    """

    def __init__(self):
        # Most recently handled event. Only used by the legacy
        # _re_auth / _raise_on_error methods kept for compatibility.
        self._event = None

    def listen(self, event: AfterAsyncClusterInstantiationEvent):
        """
        Subscribe to token updates of ``event.credential_provider`` when it is
        a :class:`StreamingCredentialProvider`; otherwise do nothing.

        The registered callbacks are coroutine functions. Each new token is
        awaited through ``re_auth_callback`` of every node in ``event.nodes``.
        """
        provider = event.credential_provider
        if not isinstance(provider, StreamingCredentialProvider):
            return

        self._event = event

        async def on_next(token: TokenInterface):
            await self._re_auth_nodes(event, token)

        async def on_error(error: Exception):
            raise EventException(error, event)

        provider.on_next(on_next)
        provider.on_error(on_error)

    @staticmethod
    async def _re_auth_nodes(event, token: TokenInterface):
        nodes = event.nodes
        # Snapshot the keys: the dict may change while we await (e.g. a
        # topology refresh, presumably). Iterating it live would raise
        # RuntimeError, and indexing a removed key would raise KeyError.
        for key in list(nodes):
            node = nodes.get(key)
            if node is not None:
                await node.re_auth_callback(token)

    async def _re_auth(self, token: TokenInterface):
        await self._re_auth_nodes(self._event, token)

    async def _raise_on_error(self, error: Exception):
        raise EventException(error, self._event)

423 

424 

class RegisterReAuthForPubSub(EventListenerInterface):
    """
    Listener that registers re-authentication callbacks for a pub/sub
    connection and its pool. It only subscribes when the connection uses a
    :class:`StreamingCredentialProvider` and protocol 3.

    The dispatcher holds one instance of this listener for every pub/sub
    connection. The callbacks are therefore bound to the connection, pool and
    lock of the event they were registered for, instead of reading the
    ``self._*`` attributes that the next event overwrites.
    """

    def __init__(self):
        # State of the most recently handled event. Only used by the legacy
        # _re_auth* / _raise_on_error* methods kept for compatibility.
        self._connection = None
        self._connection_pool = None
        self._client_type = None
        self._connection_lock = None
        self._event = None

    def listen(self, event: AfterPubSubConnectionInstantiationEvent):
        """
        Subscribe to token updates of the pub/sub connection's credential
        provider when it is a :class:`StreamingCredentialProvider` and the
        connection speaks protocol 3; otherwise do nothing.

        Each new token is sent as ``AUTH`` on the pub/sub connection and then
        forwarded to the pool's ``re_auth_callback``. Provider errors are
        re-raised wrapped in :class:`EventException`.
        """
        connection = event.pubsub_connection
        if not (
            isinstance(connection.credential_provider, StreamingCredentialProvider)
            and check_protocol_version(connection.get_protocol(), 3)
        ):
            return

        pool = event.connection_pool
        lock = event.connection_lock

        self._event = event
        self._connection = connection
        self._connection_pool = pool
        self._client_type = event.client_type
        self._connection_lock = lock

        if event.client_type == ClientType.SYNC:

            def on_next(token: TokenInterface):
                self._auth_and_propagate(connection, pool, lock, token)

            def on_error(error: Exception):
                raise EventException(error, event)

        else:

            async def on_next(token: TokenInterface):
                await self._auth_and_propagate_async(connection, pool, lock, token)

            async def on_error(error: Exception):
                raise EventException(error, event)

        connection.credential_provider.on_next(on_next)
        connection.credential_provider.on_error(on_error)

    @staticmethod
    def _auth_and_propagate(connection, pool, lock, token: TokenInterface):
        # AUTH request/response is done under the connection lock; the pool
        # callback runs after the lock is released.
        with lock:
            connection.send_command("AUTH", token.try_get("oid"), token.get_value())
            connection.read_response()

        pool.re_auth_callback(token)

    @staticmethod
    async def _auth_and_propagate_async(
        connection, pool, lock, token: TokenInterface
    ):
        # Async counterpart of _auth_and_propagate.
        async with lock:
            await connection.send_command(
                "AUTH", token.try_get("oid"), token.get_value()
            )
            await connection.read_response()

        await pool.re_auth_callback(token)

    def _re_auth(self, token: TokenInterface):
        self._auth_and_propagate(
            self._connection, self._connection_pool, self._connection_lock, token
        )

    async def _re_auth_async(self, token: TokenInterface):
        await self._auth_and_propagate_async(
            self._connection, self._connection_pool, self._connection_lock, token
        )

    def _raise_on_error(self, error: Exception):
        raise EventException(error, self._event)

    async def _raise_on_error_async(self, error: Exception):
        raise EventException(error, self._event)

475 

476 

class InitializeConnectionCountObservability(EventListenerInterface):
    """
    Listener that sets up connection-count observability for newly created
    connection pools.
    """

    def listen(self, event: AfterPooledConnectionsInstantiationEvent):
        # The gauge is initialized only once; subsequent calls have no effect.
        init_connection_count()

        # Make the new pools visible to the connection-count observability.
        pools = event.connection_pools
        register_pools_connection_count(pools)