Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/redis/_parsers/base.py: 36%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

276 statements  

1import logging 

2from abc import ABC, abstractmethod 

3from asyncio import IncompleteReadError, StreamReader 

4from typing import Awaitable, Callable, List, Optional, Protocol, Union 

5 

6from redis.maint_notifications import ( 

7 MaintenanceNotification, 

8 NodeFailedOverNotification, 

9 NodeFailingOverNotification, 

10 NodeMigratedNotification, 

11 NodeMigratingNotification, 

12 NodeMovingNotification, 

13 OSSNodeMigratedNotification, 

14 OSSNodeMigratingNotification, 

15) 

16from redis.utils import deprecated_function, safe_str 

17 

18from ..exceptions import ( 

19 AskError, 

20 AuthenticationError, 

21 AuthenticationWrongNumberOfArgsError, 

22 BusyLoadingError, 

23 ClusterCrossSlotError, 

24 ClusterDownError, 

25 ConnectionError, 

26 ExecAbortError, 

27 ExternalAuthProviderError, 

28 MasterDownError, 

29 ModuleError, 

30 MovedError, 

31 NoPermissionError, 

32 NoScriptError, 

33 OutOfMemoryError, 

34 ReadOnlyError, 

35 ResponseError, 

36 TryAgainError, 

37) 

38from ..typing import EncodableT 

39from .encoders import Encoder 

40from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer 

41 

# Server error texts that BaseParser.EXCEPTION_CLASSES maps to ModuleError
# under the "ERR" status code. parse_error() looks them up by exact string
# equality, so every character matters.
MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
MODULE_EXPORTS_DATA_TYPES_ERROR = (
    "Error unloading module: the module exports one or more "
    "module-side data types, can't unload"
)

# Error texts returned when a client sends AUTH to a server that has no
# password configured; both map to AuthenticationError.
NO_AUTH_SET_ERROR = dict.fromkeys(
    (
        # Redis >= 6.0
        "AUTH <password> called without any password configured for the "
        "default user. Are you sure your configuration is correct?",
        # Redis < 6.0
        "Client sent AUTH, but no password is set",
    ),
    AuthenticationError,
)

# Error texts mapped to ExternalAuthProviderError.
EXTERNAL_AUTH_PROVIDER_ERROR = {"problem with LDAP service": ExternalAuthProviderError}

logger = logging.getLogger(__name__)

65 

66 

class BaseParser(ABC):
    """Common base for RESP parsers.

    Provides the translation from server error replies to exception instances
    and declares the connect/disconnect hooks that concrete parsers implement.
    """

    # Leading status word of an error reply -> exception class. A nested dict
    # refines the choice by the exact remaining message text; messages not
    # listed under such a code fall back to ResponseError (see parse_error).
    EXCEPTION_CLASSES = {
        "ERR": {
            "max number of clients reached": ConnectionError,
            "invalid password": AuthenticationError,
            # Some Redis server versions report invalid command syntax in
            # lowercase...
            "wrong number of arguments for 'auth' command": (
                AuthenticationWrongNumberOfArgsError
            ),
            # ...and some in uppercase.
            "wrong number of arguments for 'AUTH' command": (
                AuthenticationWrongNumberOfArgsError
            ),
            MODULE_LOAD_ERROR: ModuleError,
            MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
            NO_SUCH_MODULE_ERROR: ModuleError,
            MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
            **NO_AUTH_SET_ERROR,
            **EXTERNAL_AUTH_PROVIDER_ERROR,
        },
        "OOM": OutOfMemoryError,
        "WRONGPASS": AuthenticationError,
        "EXECABORT": ExecAbortError,
        "LOADING": BusyLoadingError,
        "NOSCRIPT": NoScriptError,
        "READONLY": ReadOnlyError,
        "NOAUTH": AuthenticationError,
        "NOPERM": NoPermissionError,
        "ASK": AskError,
        "TRYAGAIN": TryAgainError,
        "MOVED": MovedError,
        "CLUSTERDOWN": ClusterDownError,
        "CROSSSLOT": ClusterCrossSlotError,
        "MASTERDOWN": MasterDownError,
    }

    @classmethod
    def parse_error(cls, response):
        """Build the exception instance matching an error reply string.

        The first space-separated word of ``response`` is treated as the
        status code. If it appears in ``EXCEPTION_CLASSES``, the code is
        stripped off and the mapped class is instantiated with the remaining
        text and ``status_code=<code>``. For example, ``"ERR invalid password"``
        yields ``AuthenticationError("invalid password", status_code="ERR")``.
        Unknown codes produce ``ResponseError(response)`` with the full text.
        """
        code, _, message = response.partition(" ")
        if code not in cls.EXCEPTION_CLASSES:
            return ResponseError(response)
        exc_type = cls.EXCEPTION_CLASSES[code]
        if isinstance(exc_type, dict):
            # Nested table: pick by exact message, defaulting to ResponseError.
            exc_type = exc_type.get(message, ResponseError)
        return exc_type(message, status_code=code)

    @abstractmethod
    def on_disconnect(self):
        """Hook invoked when the parser is detached from its connection."""

    @abstractmethod
    def on_connect(self, connection):
        """Hook invoked with the connection object the parser attaches to."""

122 

123 

class _RESPBase(BaseParser):
    """Base class for sync-based resp parsing.

    Between on_connect() and on_disconnect() it holds the connection's socket,
    a SocketBuffer wrapped around that socket, and the connection's encoder.
    """

    def __init__(self, socket_read_size):
        self.socket_read_size = socket_read_size
        # Nothing is attached until on_connect() runs.
        self.encoder = self._sock = self._buffer = None

    def __del__(self):
        # Best-effort cleanup during garbage collection; any error is swallowed
        # so the finalizer never propagates an exception.
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        """Attach to ``connection``: buffer its socket and adopt its encoder."""
        sock = connection._sock
        self._sock = sock
        self._buffer = SocketBuffer(
            sock, self.socket_read_size, connection.socket_timeout
        )
        self.encoder = connection.encoder

    def on_disconnect(self):
        """Detach from the socket, closing the SocketBuffer if one is open."""
        self._sock = None
        buf = self._buffer
        if buf is not None:
            buf.close()
            self._buffer = None
        self.encoder = None

    def can_read(self, timeout: float = 0) -> bool:
        """Delegate to the SocketBuffer's can_read(); False when not attached."""
        # TODO: Rename this API; it detects pending data or dirty/closed
        # connection state, not only whether application data can be read.
        buf = self._buffer
        return buf.can_read(timeout) if buf is not None else False

161 

162 

class AsyncBaseParser(BaseParser):
    """Base parsing class for the python-backed async parser.

    Stores the (initially unset) asyncio StreamReader and the configured read
    size; subclasses implement the readiness checks and reply decoding.
    """

    __slots__ = ("_stream", "_read_size")

    def __init__(self, socket_read_size: int):
        self._stream: Optional[StreamReader] = None
        self._read_size = socket_read_size

    @deprecated_function(
        version="8.0.0",
        reason="Use can_read() instead",
        name="can_read_destructive",
    )
    @abstractmethod
    async def can_read_destructive(self) -> bool:
        """Deprecated alias retained for compatibility; use can_read()."""

    @abstractmethod
    async def can_read(self) -> bool:
        """Report whether data is pending or the connection is dirty/closed."""
        # TODO: Rename this API; it detects pending data or dirty/closed
        # connection state, not only whether application data can be read.

    async def read_response(
        self, disable_decoding: bool = False
    ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]:
        """Read one reply from the stream; subclasses must override this."""
        raise NotImplementedError()

189 

190 

class MaintenanceNotificationsParser:
    """Protocol defining maintenance push notification parsing functionality.

    Every parser receives the raw push ``response``: an indexable sequence
    whose element 0 is the message type and whose remaining elements are the
    fields described in each method.
    """

    @staticmethod
    def _split_endpoint(endpoint):
        """Split a ``"host:port"`` string into ``(host, port)`` with an int port.

        The split happens at the *last* colon, so unbracketed IPv6 hosts such
        as ``"::1:6379"`` keep their internal colons. A bracketed literal like
        ``"[::1]:6379"`` has its brackets removed from the host.

        Raises:
            ValueError: if ``endpoint`` contains no colon or the port is not
                an integer.
        """
        host, sep, port = endpoint.rpartition(":")
        if not sep:
            raise ValueError(f"Invalid endpoint {endpoint!r}: expected 'host:port'")
        if host.startswith("[") and host.endswith("]"):
            host = host[1:-1]
        return host, int(port)

    @staticmethod
    def parse_oss_maintenance_start_msg(response):
        # Expected message format is:
        # SMIGRATING <seq_number> <slot, range1-range2,...>
        seq_id = response[1]
        slots = safe_str(response[2])
        return OSSNodeMigratingNotification(seq_id, slots)

    @staticmethod
    def parse_oss_maintenance_completed_msg(response):
        # Expected message format is:
        # SMIGRATED <seq_number> [[<src_host:port> <dest_host:port> <slot_range>], ...]
        seq_id = response[1]
        # Build the nodes_to_slots_mapping dict structure:
        # {
        #     "src_host:port": [
        #         {"dest_host:port": "slot_range"},
        #         ...
        #     ],
        #     ...
        # }
        nodes_to_slots_mapping = {}
        for src_node, dest_node, slots in response[2]:
            nodes_to_slots_mapping.setdefault(safe_str(src_node), []).append(
                {safe_str(dest_node): safe_str(slots)}
            )
        return OSSNodeMigratedNotification(seq_id, nodes_to_slots_mapping)

    @staticmethod
    def parse_maintenance_start_msg(response, notification_type):
        # Expected message format is: <notification_type> <seq_number> <time>
        # Examples:
        # MIGRATING 1 10
        # FAILING_OVER 2 20
        seq_id = response[1]
        ttl = response[2]
        return notification_type(seq_id, ttl)

    @staticmethod
    def parse_maintenance_completed_msg(response, notification_type):
        # Expected message format is: <notification_type> <seq_number>
        # Examples:
        # MIGRATED 1
        # FAILED_OVER 2
        seq_id = response[1]
        return notification_type(seq_id)

    @staticmethod
    def parse_moving_msg(response):
        # Expected message format is: MOVING <seq_number> <time> <endpoint>
        # A None endpoint yields host=None and port=None.
        seq_id = response[1]
        ttl = response[2]
        endpoint = response[3]
        if endpoint is None:
            host, port = None, None
        else:
            host, port = MaintenanceNotificationsParser._split_endpoint(
                safe_str(endpoint)
            )
        return NodeMovingNotification(seq_id, host, port, ttl)

260 

261 

# Push message type names, compared against element 0 of a push response
# after it has been decoded to str.
_INVALIDATION_MESSAGE = "invalidate"
_MOVING_MESSAGE = "MOVING"
_MIGRATING_MESSAGE = "MIGRATING"
_MIGRATED_MESSAGE = "MIGRATED"
_FAILING_OVER_MESSAGE = "FAILING_OVER"
_FAILED_OVER_MESSAGE = "FAILED_OVER"
_SMIGRATING_MESSAGE = "SMIGRATING"
_SMIGRATED_MESSAGE = "SMIGRATED"

# Types routed to maintenance_push_handler_func. MOVING and SMIGRATED are not
# listed here; the push handlers dispatch those through their own branches.
_MAINTENANCE_MESSAGES = (
    _MIGRATING_MESSAGE,
    _MIGRATED_MESSAGE,
    _FAILING_OVER_MESSAGE,
    _FAILED_OVER_MESSAGE,
    _SMIGRATING_MESSAGE,
)

# message type -> (notification class, parser function). The generic
# start/completed parsers receive the class as an argument; the MOVING and
# OSS (SMIGRATING/SMIGRATED) parsers only take the response.
_mnp = MaintenanceNotificationsParser
MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING: dict[
    str, tuple[type[MaintenanceNotification], Callable]
] = {
    _MIGRATING_MESSAGE: (
        NodeMigratingNotification,
        _mnp.parse_maintenance_start_msg,
    ),
    _MIGRATED_MESSAGE: (
        NodeMigratedNotification,
        _mnp.parse_maintenance_completed_msg,
    ),
    _FAILING_OVER_MESSAGE: (
        NodeFailingOverNotification,
        _mnp.parse_maintenance_start_msg,
    ),
    _FAILED_OVER_MESSAGE: (
        NodeFailedOverNotification,
        _mnp.parse_maintenance_completed_msg,
    ),
    _MOVING_MESSAGE: (NodeMovingNotification, _mnp.parse_moving_msg),
    _SMIGRATING_MESSAGE: (
        OSSNodeMigratingNotification,
        _mnp.parse_oss_maintenance_start_msg,
    ),
    _SMIGRATED_MESSAGE: (
        OSSNodeMigratedNotification,
        _mnp.parse_oss_maintenance_completed_msg,
    ),
}
del _mnp

311 

312 

class PushNotificationsParser(Protocol):
    """Protocol defining RESP3-specific parsing functionality.

    Routes push messages to registered callbacks:
    - invalidation messages go to the invalidation handler;
    - MOVING goes to the node-moving handler;
    - other maintenance notifications go to the maintenance handler;
    - SMIGRATED goes to the maintenance and/or OSS-cluster handler;
    - anything else goes to the pub/sub handler.
    """

    pubsub_push_handler_func: Callable
    invalidation_push_handler_func: Optional[Callable] = None
    node_moving_push_handler_func: Optional[Callable] = None
    maintenance_push_handler_func: Optional[Callable] = None
    oss_cluster_maint_push_handler_func: Optional[Callable] = None

    def handle_pubsub_push_response(self, response):
        """Handle pubsub push responses"""
        raise NotImplementedError()

    def handle_push_response(self, response, **kwargs):
        """Dispatch one push ``response`` to the matching handler.

        Returns the pub/sub handler's result for unrecognised message types;
        exceptions from that handler propagate. For recognised types, returns
        the invalidation, MOVING or maintenance handler's result. Returns None
        when no handler applies, for SMIGRATED, or when parsing/handling
        raised; such errors are logged and swallowed.
        """
        kind = response[0]
        if isinstance(kind, bytes):
            kind = kind.decode()

        recognised = (
            _INVALIDATION_MESSAGE,
            *_MAINTENANCE_MESSAGES,
            _MOVING_MESSAGE,
            _SMIGRATED_MESSAGE,
        )
        if kind not in recognised:
            return self.pubsub_push_handler_func(response)

        try:
            # The categories below are mutually exclusive for the current
            # message-type constants, so at most one branch runs.
            if kind == _INVALIDATION_MESSAGE:
                if self.invalidation_push_handler_func:
                    return self.invalidation_push_handler_func(response)
            elif kind == _MOVING_MESSAGE:
                if self.node_moving_push_handler_func:
                    parse = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[kind][1]
                    return self.node_moving_push_handler_func(parse(response))
            elif kind in _MAINTENANCE_MESSAGES:
                if self.maintenance_push_handler_func:
                    notification_cls, parse = (
                        MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[kind]
                    )
                    # The SMIGRATING parser takes only the response.
                    notification = (
                        parse(response)
                        if kind == _SMIGRATING_MESSAGE
                        else parse(response, notification_cls)
                    )
                    if notification is not None:
                        return self.maintenance_push_handler_func(notification)
            elif kind == _SMIGRATED_MESSAGE:
                if (
                    self.oss_cluster_maint_push_handler_func
                    or self.maintenance_push_handler_func
                ):
                    parse = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[kind][1]
                    notification = parse(response)
                    if notification is not None:
                        # Both handlers may be notified; their results are
                        # not returned.
                        if self.maintenance_push_handler_func:
                            self.maintenance_push_handler_func(notification)
                        if self.oss_cluster_maint_push_handler_func:
                            self.oss_cluster_maint_push_handler_func(notification)
        except Exception as exc:
            logger.error(
                "Error handling {} message ({}): {}".format(kind, response, exc)
            )

        return None

    def set_pubsub_push_handler(self, pubsub_push_handler_func):
        """Register the callback for pub/sub push messages."""
        self.pubsub_push_handler_func = pubsub_push_handler_func

    def set_invalidation_push_handler(self, invalidation_push_handler_func):
        """Register the callback for invalidation push messages."""
        self.invalidation_push_handler_func = invalidation_push_handler_func

    def set_node_moving_push_handler(self, node_moving_push_handler_func):
        """Register the callback for parsed MOVING notifications."""
        self.node_moving_push_handler_func = node_moving_push_handler_func

    def set_maintenance_push_handler(self, maintenance_push_handler_func):
        """Register the callback for parsed maintenance notifications."""
        self.maintenance_push_handler_func = maintenance_push_handler_func

    def set_oss_cluster_maint_push_handler(self, oss_cluster_maint_push_handler_func):
        """Register the callback for parsed SMIGRATED notifications."""
        self.oss_cluster_maint_push_handler_func = oss_cluster_maint_push_handler_func

403 

404 

class AsyncPushNotificationsParser(Protocol):
    """Protocol defining async RESP3-specific parsing functionality.

    Async counterpart of the push-message router: recognised notification
    types are parsed and awaited on their registered handler; everything else
    is awaited on the pub/sub handler.
    """

    pubsub_push_handler_func: Callable
    invalidation_push_handler_func: Optional[Callable] = None
    node_moving_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
    maintenance_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
    oss_cluster_maint_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None

    async def handle_pubsub_push_response(self, response):
        """Handle pubsub push responses asynchronously"""
        raise NotImplementedError()

    async def handle_push_response(self, response, **kwargs):
        """Handle push responses asynchronously.

        Returns the awaited pub/sub handler result for unrecognised message
        types; exceptions from that handler propagate. For recognised types,
        returns the awaited result of the matching handler. Returns None when
        no handler applies or when parsing/handling raised; such errors are
        logged and swallowed.
        """
        # Decode the type once; every comparison below uses the str form.
        msg_type = response[0]
        if isinstance(msg_type, bytes):
            msg_type = msg_type.decode()

        if msg_type not in (
            _INVALIDATION_MESSAGE,
            *_MAINTENANCE_MESSAGES,
            _MOVING_MESSAGE,
            _SMIGRATED_MESSAGE,
        ):
            return await self.pubsub_push_handler_func(response)

        try:
            if (
                msg_type == _INVALIDATION_MESSAGE
                and self.invalidation_push_handler_func
            ):
                return await self.invalidation_push_handler_func(response)

            if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
                parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
                    msg_type
                ][1]
                notification = parser_function(response)
                return await self.node_moving_push_handler_func(notification)

            if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
                notification_type, parser_function = (
                    MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[msg_type]
                )
                # The SMIGRATING parser takes only the response; the generic
                # start/completed parsers also need the notification class.
                if msg_type == _SMIGRATING_MESSAGE:
                    notification = parser_function(response)
                else:
                    notification = parser_function(response, notification_type)

                if notification is not None:
                    return await self.maintenance_push_handler_func(notification)

            # NOTE(review): the sync PushNotificationsParser also forwards
            # SMIGRATED to maintenance_push_handler_func, while this async
            # path only uses oss_cluster_maint_push_handler_func. Confirm
            # which behaviour is intended before aligning them.
            if (
                msg_type == _SMIGRATED_MESSAGE
                and self.oss_cluster_maint_push_handler_func
            ):
                parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
                    msg_type
                ][1]
                notification = parser_function(response)
                if notification is not None:
                    return await self.oss_cluster_maint_push_handler_func(notification)
        except Exception as e:
            logger.error(
                "Error handling {} message ({}): {}".format(msg_type, response, e)
            )

        return None

    def set_pubsub_push_handler(self, pubsub_push_handler_func):
        """Set the pubsub push handler function"""
        self.pubsub_push_handler_func = pubsub_push_handler_func

    def set_invalidation_push_handler(self, invalidation_push_handler_func):
        """Set the invalidation push handler function"""
        self.invalidation_push_handler_func = invalidation_push_handler_func

    def set_node_moving_push_handler(self, node_moving_push_handler_func):
        """Set the async handler awaited with parsed MOVING notifications."""
        self.node_moving_push_handler_func = node_moving_push_handler_func

    def set_maintenance_push_handler(self, maintenance_push_handler_func):
        """Set the async handler awaited with parsed maintenance notifications."""
        self.maintenance_push_handler_func = maintenance_push_handler_func

    def set_oss_cluster_maint_push_handler(self, oss_cluster_maint_push_handler_func):
        """Set the async handler awaited with parsed SMIGRATED notifications."""
        self.oss_cluster_maint_push_handler_func = oss_cluster_maint_push_handler_func

497 

498 

class _AsyncRESPBase(AsyncBaseParser):
    """Base class for async resp parsing.

    Reads go through a local byte buffer (``_buffer``) and cursor (``_pos``)
    placed in front of the connection's asyncio StreamReader. Bytes pulled
    from the stream are also recorded in ``_chunks``.
    """

    __slots__ = AsyncBaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks")

    def __init__(self, socket_read_size: int):
        super().__init__(socket_read_size)
        self.encoder: Optional[Encoder] = None
        self._buffer = b""
        self._chunks = []
        self._pos = 0
        # Start out disconnected so that can_read() raises
        # OSError("Buffer is closed.") rather than AttributeError if it is
        # called before on_connect(). _connected is not a slot; it is stored
        # in the instance __dict__, which exists because BaseParser declares
        # no __slots__.
        self._connected = False

    def _clear(self):
        """Drop buffered bytes and recorded chunks (``_pos`` is left as is)."""
        self._buffer = b""
        self._chunks.clear()

    def on_connect(self, connection):
        """Called when the stream connects.

        Adopts the connection's reader and encoder and resets the buffer.

        Raises:
            ConnectionError: if the connection has no reader.
        """
        self._stream = connection._reader
        if self._stream is None:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
        self.encoder = connection.encoder
        self._clear()
        self._connected = True

    def on_disconnect(self):
        """Called when the stream disconnects; only marks the parser closed."""
        self._connected = False

    @deprecated_function(
        version="8.0.0",
        reason="Use can_read() instead",
        name="can_read_destructive",
    )
    async def can_read_destructive(self) -> bool:
        """Deprecated alias for can_read()."""
        return await self.can_read()

    async def can_read(self) -> bool:
        """Return True if unread bytes are buffered locally or in the stream,
        or if the stream has reached EOF.

        Raises:
            OSError: if the parser is not connected.
        """
        # TODO: Rename this API; it detects pending data or dirty/closed
        # connection state, not only whether application data can be read.
        if not self._connected:
            raise OSError("Buffer is closed.")
        if self._buffer:
            return True
        # asyncio.StreamReader has no public non-destructive API for checking
        # buffered bytes. Preserve dirty-connection detection for the Python
        # parser and fail loudly if the private buffer API changes.
        return bool(self._stream._buffer) or self._stream.at_eof()

    async def _read(self, length: int) -> bytes:
        """
        Read `length` bytes of data. These are assumed to be followed
        by a '\r\n' terminator which is subsequently discarded.

        Raises:
            ConnectionError: if the stream ends before enough bytes arrive.
        """
        want = length + 2  # payload plus the trailing CRLF
        end = self._pos + want
        if len(self._buffer) >= end:
            # Fast path: payload and CRLF are already in the local buffer.
            result = self._buffer[self._pos : end - 2]
        else:
            # Take what is left locally and read the remainder from the stream.
            tail = self._buffer[self._pos :]
            try:
                data = await self._stream.readexactly(want - len(tail))
            except IncompleteReadError as error:
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
            result = (tail + data)[:-2]
            # NOTE(review): presumably recorded so a caller can rebuild
            # _buffer from _chunks and replay the reply if reading is
            # interrupted; confirm against the subclasses.
            self._chunks.append(data)
        self._pos += want
        return result

    async def _readline(self) -> bytes:
        """
        read an unknown number of bytes up to the next '\r\n'
        line separator, which is discarded.

        Raises:
            ConnectionError: if the stream ends before a '\r\n' is read.
        """
        found = self._buffer.find(b"\r\n", self._pos)
        if found >= 0:
            result = self._buffer[self._pos : found]
        else:
            tail = self._buffer[self._pos :]
            data = await self._stream.readline()
            # readline() returns a partial line at EOF instead of raising.
            if not data.endswith(b"\r\n"):
                raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
            result = (tail + data)[:-2]
            self._chunks.append(data)
        self._pos += len(result) + 2
        return result