Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/event.py: 46%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

230 statements  

1import asyncio 

2import threading 

3from abc import ABC, abstractmethod 

4from enum import Enum 

5from typing import Dict, List, Optional, Type, Union 

6 

7from redis.auth.token import TokenInterface 

8from redis.credentials import CredentialProvider, StreamingCredentialProvider 

9 

10 

class EventListenerInterface(ABC):
    """
    Contract for a synchronous listener that reacts to a single event object.

    Concrete listeners subclass this and implement :meth:`listen`.
    """

    @abstractmethod
    def listen(self, event: object):
        """Handle ``event``; every concrete listener must override this."""

19 

20 

class AsyncEventListenerInterface(ABC):
    """
    Contract for an asynchronous listener that reacts to a single event object.

    Concrete listeners subclass this and implement :meth:`listen` as a
    coroutine function; the dispatcher awaits it.
    """

    @abstractmethod
    async def listen(self, event: object):
        """Handle ``event`` asynchronously; concrete listeners must override."""

29 

30 

class EventDispatcherInterface(ABC):
    """
    Contract for an object that routes each event to the listeners
    registered for it.
    """

    @abstractmethod
    def dispatch(self, event: object):
        """Deliver ``event`` to its listeners synchronously."""

    @abstractmethod
    async def dispatch_async(self, event: object):
        """Deliver ``event`` to its listeners from asynchronous code."""

    @abstractmethod
    def register_listeners(
        self,
        mappings: Dict[
            Type[object],
            List[Union[EventListenerInterface, AsyncEventListenerInterface]],
        ],
    ):
        """Register additional listeners."""

55 

56 

class EventException(Exception):
    """
    Exception wrapper that attaches the originating event to an error.

    :param exception: The underlying error; also becomes the sole ``args``
        entry, so ``str()`` of this exception matches ``str(exception)``.
    :param event: The event object whose handling produced the error.
    """

    def __init__(self, exception: Exception, event: object):
        super().__init__(exception)
        self.event = event
        self.exception = exception

66 

67 

class EventDispatcher(EventDispatcherInterface):
    # TODO: Make dispatcher to accept external mappings.
    def __init__(
        self,
        event_listeners: Optional[
            Dict[
                Type[object],
                List[Union[EventListenerInterface, AsyncEventListenerInterface]],
            ]
        ] = None,
    ):
        """
        Dispatcher that dispatches events to listeners associated with given event.

        :param event_listeners: Optional extra ``{event_type: [listener, ...]}``
            mappings, merged on top of the built-in re-authentication
            listeners via :meth:`register_listeners`.
        """
        self._event_listeners_mapping: Dict[
            Type[object],
            List[Union[EventListenerInterface, AsyncEventListenerInterface]],
        ] = {
            AfterConnectionReleasedEvent: [
                ReAuthConnectionListener(),
            ],
            AfterPooledConnectionsInstantiationEvent: [
                RegisterReAuthForPooledConnections()
            ],
            AfterSingleConnectionInstantiationEvent: [
                RegisterReAuthForSingleConnection()
            ],
            AfterPubSubConnectionInstantiationEvent: [RegisterReAuthForPubSub()],
            AfterAsyncClusterInstantiationEvent: [RegisterReAuthForAsyncClusterNodes()],
            AsyncAfterConnectionReleasedEvent: [
                AsyncReAuthConnectionListener(),
            ],
        }

        # Guards the mapping; also serializes synchronous dispatch.
        self._lock = threading.Lock()
        # Created lazily inside dispatch_async (i.e. while a loop is running);
        # presumably to avoid creating an asyncio primitive outside a loop.
        self._async_lock = None

        if event_listeners:
            self.register_listeners(event_listeners)

    def dispatch(self, event: object):
        """
        Synchronously call ``listen(event)`` on every listener for ``type(event)``.

        Lookup is by exact type, so a subclass event (for example
        ``AsyncAfterConnectionReleasedEvent``) is not routed to its parent's
        listeners. Listeners run while ``self._lock`` (a non-reentrant
        ``threading.Lock``) is held, so a listener must not call back into
        ``dispatch`` or ``register_listeners`` on this dispatcher.
        """
        with self._lock:
            listeners = self._event_listeners_mapping.get(type(event), [])

            for listener in listeners:
                listener.listen(event)

    async def dispatch_async(self, event: object):
        """
        Await ``listen(event)`` on every listener for ``type(event)``, in order.

        Listeners are expected to be async (their ``listen`` must return an
        awaitable). Concurrent ``dispatch_async`` calls are serialized by an
        ``asyncio.Lock`` held across the listener awaits.
        """
        if self._async_lock is None:
            self._async_lock = asyncio.Lock()

        async with self._async_lock:
            listeners = self._event_listeners_mapping.get(type(event), [])

            for listener in listeners:
                await listener.listen(event)

    def register_listeners(
        self,
        mappings: Dict[
            Type[object],
            List[Union[EventListenerInterface, AsyncEventListenerInterface]],
        ],
    ):
        """
        Register additional listeners.

        :param mappings: ``{event_type: [listener, ...]}``.

        For an event type that already has listeners, the new listeners are
        appended after the existing ones and duplicates (by hash/equality)
        are dropped, keeping first-registration order so dispatch order is
        deterministic. For an event type with no listeners yet, a copy of the
        given list is stored, so later mutation of the caller's list does not
        leak into the dispatcher.
        """
        with self._lock:
            for event_type, listeners in mappings.items():
                if event_type in self._event_listeners_mapping:
                    combined = self._event_listeners_mapping[event_type] + list(
                        listeners
                    )
                    # dict.fromkeys dedups like set() but preserves insertion
                    # order; set() iteration order of listener objects is
                    # arbitrary and would shuffle dispatch order.
                    self._event_listeners_mapping[event_type] = list(
                        dict.fromkeys(combined)
                    )
                else:
                    self._event_listeners_mapping[event_type] = list(listeners)

139 

140 

class AfterConnectionReleasedEvent:
    """
    Event dispatched after a connection has been released.

    It carries the released connection so listeners can act on it. The
    default :class:`EventDispatcher` mapping routes it to
    :class:`ReAuthConnectionListener`, which calls ``connection.re_auth()``.
    NOTE(review): confirm the exact dispatch site(s) in the connection/pool
    code.

    :param connection: The connection that was released.
    """

    def __init__(self, connection):
        self._connection = connection

    @property
    def connection(self):
        """The released connection (read-only)."""
        return self._connection

152 

153 

class AsyncAfterConnectionReleasedEvent(AfterConnectionReleasedEvent):
    """
    Async-client counterpart of :class:`AfterConnectionReleasedEvent`.

    It has the same payload (``connection``), but is a distinct type so the
    dispatcher's exact-type lookup sends it to the async re-auth listener
    rather than the synchronous one.
    """

156 

157 

class ClientType(Enum):
    """
    Kind of client (synchronous or asyncio) that an event originates from.

    Listeners branch on this to pick sync or async callbacks.

    Note: because of the trailing commas, each member's ``value`` is a
    one-element tuple (``("sync",)`` / ``("async",)``), not a plain string.
    Compare members directly (``ClientType.SYNC``) rather than by value.
    """

    SYNC = ("sync",)
    ASYNC = ("async",)

161 

162 

class AfterPooledConnectionsInstantiationEvent:
    """
    Event fired after pooled connection instances were created.

    :param connection_pools: The connection pools that were created.
    :param client_type: Whether the owning client is sync or asyncio.
    :param credential_provider: Optional credential provider carried with
        the event; ``None`` when not supplied.
    """

    def __init__(
        self,
        connection_pools: List,
        client_type: ClientType,
        credential_provider: Optional[CredentialProvider] = None,
    ):
        self._credential_provider = credential_provider
        self._client_type = client_type
        self._connection_pools = connection_pools

    @property
    def connection_pools(self):
        """Pools passed at construction (read-only)."""
        return self._connection_pools

    @property
    def client_type(self) -> ClientType:
        """Sync/async flavour of the originating client (read-only)."""
        return self._client_type

    @property
    def credential_provider(self) -> Optional[CredentialProvider]:
        """Credential provider passed at construction, or ``None``."""
        return self._credential_provider

189 

190 

class AfterSingleConnectionInstantiationEvent:
    """
    Event fired after a single (non-pooled) connection instance was created.

    :param connection: The connection that was created.
    :param client_type: Whether the owning client is sync or asyncio.
    :param connection_lock: For sync client thread-lock should be provided,
        for async asyncio.Lock
    """

    def __init__(
        self,
        connection,
        client_type: ClientType,
        connection_lock: Union[threading.RLock, asyncio.Lock],
    ):
        self._connection_lock = connection_lock
        self._client_type = client_type
        self._connection = connection

    @property
    def connection(self):
        """The created connection (read-only)."""
        return self._connection

    @property
    def client_type(self) -> ClientType:
        """Sync/async flavour of the originating client (read-only)."""
        return self._client_type

    @property
    def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
        """Lock guarding use of :attr:`connection` (read-only)."""
        return self._connection_lock

220 

221 

class AfterPubSubConnectionInstantiationEvent:
    """
    Event fired after a pub/sub connection instance was created.

    :param pubsub_connection: The pub/sub connection that was created.
    :param connection_pool: The connection pool associated with it.
    :param client_type: Whether the owning client is sync or asyncio.
    :param connection_lock: Lock guarding the pub/sub connection; a
        thread lock for sync clients, an ``asyncio.Lock`` for async ones.
    """

    def __init__(
        self,
        pubsub_connection,
        connection_pool,
        client_type: ClientType,
        connection_lock: Union[threading.RLock, asyncio.Lock],
    ):
        self._connection_lock = connection_lock
        self._client_type = client_type
        self._connection_pool = connection_pool
        self._pubsub_connection = pubsub_connection

    @property
    def pubsub_connection(self):
        """The created pub/sub connection (read-only)."""
        return self._pubsub_connection

    @property
    def connection_pool(self):
        """Pool passed at construction (read-only)."""
        return self._connection_pool

    @property
    def client_type(self) -> ClientType:
        """Sync/async flavour of the originating client (read-only)."""
        return self._client_type

    @property
    def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
        """Lock guarding use of :attr:`pubsub_connection` (read-only)."""
        return self._connection_lock

250 

251 

class AfterAsyncClusterInstantiationEvent:
    """
    Event fired after async cluster instance was created.

    Async cluster doesn't use connection pools,
    instead ClusterNode object manages connections.

    :param nodes: Mapping of the cluster's nodes; its values are the node
        objects.
    :param credential_provider: Optional credential provider carried with
        the event; ``None`` when not supplied.
    """

    def __init__(
        self,
        nodes: dict,
        credential_provider: Optional[CredentialProvider] = None,
    ):
        self._credential_provider = credential_provider
        self._nodes = nodes

    @property
    def nodes(self) -> dict:
        """Node mapping passed at construction (read-only)."""
        return self._nodes

    @property
    def credential_provider(self) -> Optional[CredentialProvider]:
        """Credential provider passed at construction, or ``None``."""
        return self._credential_provider

275 

276 

class OnCommandsFailEvent:
    """
    Event fired whenever a command fails during execution.

    :param commands: The command(s) that were being executed.
    :param exception: The error raised by the failed execution.
    """

    def __init__(
        self,
        commands: tuple,
        exception: Exception,
    ):
        self._exception = exception
        self._commands = commands

    @property
    def commands(self) -> tuple:
        """Commands passed at construction (read-only)."""
        return self._commands

    @property
    def exception(self) -> Exception:
        """Error passed at construction (read-only)."""
        return self._exception

297 

298 

class AsyncOnCommandsFailEvent(OnCommandsFailEvent):
    """
    Async-client counterpart of :class:`OnCommandsFailEvent`.

    It has the same payload (``commands``, ``exception``) under a distinct
    type, so an exact-type dispatcher can route it separately.
    """

301 

302 

class ReAuthConnectionListener(EventListenerInterface):
    """
    Synchronous listener that re-authenticates the connection carried by
    the event, by calling its ``re_auth()`` method.
    """

    def listen(self, event: AfterConnectionReleasedEvent):
        released = event.connection
        released.re_auth()

310 

311 

class AsyncReAuthConnectionListener(AsyncEventListenerInterface):
    """
    Async listener that re-authenticates the connection carried by the
    event, by awaiting its ``re_auth()`` coroutine.
    """

    async def listen(self, event: AsyncAfterConnectionReleasedEvent):
        released = event.connection
        await released.re_auth()

319 

320 

class RegisterReAuthForPooledConnections(EventListenerInterface):
    """
    Listener that registers a re-authentication callback for pooled connections.
    Required by :class:`StreamingCredentialProvider`.

    If the event's credential provider is streaming, this listener remembers
    the event and subscribes to the provider. Each new token is forwarded to
    every pool's ``re_auth_callback``. Provider errors are re-raised as
    :class:`EventException` wrapping the remembered event. Sync or async
    callbacks are chosen from ``event.client_type``.
    """

    def __init__(self):
        # Set on the first streaming-provider event; callbacks read it.
        self._event = None

    def listen(self, event: AfterPooledConnectionsInstantiationEvent):
        provider = event.credential_provider
        if not isinstance(provider, StreamingCredentialProvider):
            return

        self._event = event

        if event.client_type == ClientType.SYNC:
            next_cb, error_cb = self._re_auth, self._raise_on_error
        else:
            next_cb, error_cb = self._re_auth_async, self._raise_on_error_async

        provider.on_next(next_cb)
        provider.on_error(error_cb)

    def _re_auth(self, token):
        pools = self._event.connection_pools
        for connection_pool in pools:
            connection_pool.re_auth_callback(token)

    async def _re_auth_async(self, token):
        pools = self._event.connection_pools
        for connection_pool in pools:
            await connection_pool.re_auth_callback(token)

    def _raise_on_error(self, error: Exception):
        raise EventException(error, self._event)

    async def _raise_on_error_async(self, error: Exception):
        raise EventException(error, self._event)

354 

355 

class RegisterReAuthForSingleConnection(EventListenerInterface):
    """
    Listener that registers a re-authentication callback for single connection.
    Required by :class:`StreamingCredentialProvider`.

    If the connection's credential provider is streaming, each new token is
    applied by sending ``AUTH`` (with ``token.try_get("oid")`` and
    ``token.get_value()``) on the connection and reading the reply. This
    happens while the event's ``connection_lock`` is held. Provider errors
    are re-raised as :class:`EventException`.
    """

    def __init__(self):
        # Set on the first streaming-provider event; callbacks read it.
        self._event = None

    def listen(self, event: AfterSingleConnectionInstantiationEvent):
        provider = event.connection.credential_provider
        if not isinstance(provider, StreamingCredentialProvider):
            return

        self._event = event
        is_sync = event.client_type == ClientType.SYNC

        provider.on_next(self._re_auth if is_sync else self._re_auth_async)
        provider.on_error(
            self._raise_on_error if is_sync else self._raise_on_error_async
        )

    def _re_auth(self, token):
        with self._event.connection_lock:
            conn = self._event.connection
            conn.send_command("AUTH", token.try_get("oid"), token.get_value())
            conn.read_response()

    async def _re_auth_async(self, token):
        async with self._event.connection_lock:
            conn = self._event.connection
            await conn.send_command(
                "AUTH", token.try_get("oid"), token.get_value()
            )
            await conn.read_response()

    def _raise_on_error(self, error: Exception):
        raise EventException(error, self._event)

    async def _raise_on_error_async(self, error: Exception):
        raise EventException(error, self._event)

399 

400 

class RegisterReAuthForAsyncClusterNodes(EventListenerInterface):
    """
    Listener that registers re-authentication callbacks for async cluster nodes.

    If the event's credential provider is streaming, each new token is
    awaited through ``re_auth_callback`` on every node in ``event.nodes``,
    one at a time. Provider errors are re-raised as :class:`EventException`.
    Both callbacks are coroutine functions.
    """

    def __init__(self):
        # Set on the first streaming-provider event; callbacks read it.
        self._event = None

    def listen(self, event: AfterAsyncClusterInstantiationEvent):
        provider = event.credential_provider
        if not isinstance(provider, StreamingCredentialProvider):
            return

        self._event = event
        provider.on_next(self._re_auth)
        provider.on_error(self._raise_on_error)

    async def _re_auth(self, token: TokenInterface):
        for node in self._event.nodes.values():
            await node.re_auth_callback(token)

    async def _raise_on_error(self, error: Exception):
        raise EventException(error, self._event)

417 

418 

class RegisterReAuthForPubSub(EventListenerInterface):
    """
    Listener that registers re-authentication callbacks for a pub/sub connection.

    It acts only when the pub/sub connection's credential provider is a
    :class:`StreamingCredentialProvider` and ``get_protocol()`` reports
    version 3 (``3`` or ``"3"``). On each new token it sends ``AUTH`` on the
    pub/sub connection and reads the reply while holding the event's
    ``connection_lock``. It then forwards the token to the pool's
    ``re_auth_callback``. Provider errors are re-raised as
    :class:`EventException`. Sync or async callbacks are chosen from the
    event's ``client_type``.
    """

    def __init__(self):
        self._connection = None
        self._connection_pool = None
        self._client_type = None
        self._connection_lock = None
        self._event = None

    def listen(self, event: AfterPubSubConnectionInstantiationEvent):
        connection = event.pubsub_connection
        streaming = isinstance(
            connection.credential_provider, StreamingCredentialProvider
        )
        # `and` short-circuits, so get_protocol() is only queried for
        # streaming providers, exactly as before.
        if not (streaming and connection.get_protocol() in [3, "3"]):
            return

        self._event = event
        self._connection = connection
        self._connection_pool = event.connection_pool
        self._client_type = event.client_type
        self._connection_lock = event.connection_lock

        provider = self._connection.credential_provider
        if self._client_type == ClientType.SYNC:
            provider.on_next(self._re_auth)
            provider.on_error(self._raise_on_error)
        else:
            provider.on_next(self._re_auth_async)
            provider.on_error(self._raise_on_error_async)

    def _re_auth(self, token: TokenInterface):
        with self._connection_lock:
            conn = self._connection
            conn.send_command("AUTH", token.try_get("oid"), token.get_value())
            conn.read_response()

        self._connection_pool.re_auth_callback(token)

    async def _re_auth_async(self, token: TokenInterface):
        async with self._connection_lock:
            conn = self._connection
            await conn.send_command(
                "AUTH", token.try_get("oid"), token.get_value()
            )
            await conn.read_response()

        await self._connection_pool.re_auth_callback(token)

    def _raise_on_error(self, error: Exception):
        raise EventException(error, self._event)

    async def _raise_on_error_async(self, error: Exception):
        raise EventException(error, self._event)