1from enum import Enum
2from typing import TYPE_CHECKING, Any, Awaitable, Dict, Optional, Tuple, Union
3
4from redis.exceptions import IncorrectPolicyType, RedisError, ResponseError
5from redis.utils import str_if_bytes
6
7if TYPE_CHECKING:
8 from redis.asyncio.cluster import ClusterNode
9
10
class RequestPolicy(Enum):
    """
    Request policy attached to a command or subcommand.

    Each value is the text that follows ``request_policy:`` in a command tip
    reported by ``COMMAND``. When a command carries no such tip, the command
    parsers fall back to ``DEFAULT_KEYLESS`` or ``DEFAULT_KEYED`` depending on
    whether the command takes keys.
    """

    # Policies that can be named explicitly in a ``request_policy:`` tip.
    ALL_NODES = "all_nodes"
    ALL_SHARDS = "all_shards"
    ALL_REPLICAS = "all_replicas"
    MULTI_SHARD = "multi_shard"
    SPECIAL = "special"
    # Default policies; the keyless/keyed pair is assigned when no tip exists.
    DEFAULT_KEYLESS = "default_keyless"
    DEFAULT_KEYED = "default_keyed"
    DEFAULT_NODE = "default_node"
20
21
class ResponsePolicy(Enum):
    """
    Response policy attached to a command or subcommand.

    Each value is the text that follows ``response_policy:`` in a command tip
    reported by ``COMMAND``. When a command carries no such tip, the command
    parsers fall back to ``DEFAULT_KEYLESS`` or ``DEFAULT_KEYED`` depending on
    whether the command takes keys.
    """

    # Policies that can be named explicitly in a ``response_policy:`` tip.
    ONE_SUCCEEDED = "one_succeeded"
    ALL_SUCCEEDED = "all_succeeded"
    AGG_LOGICAL_AND = "agg_logical_and"
    AGG_LOGICAL_OR = "agg_logical_or"
    AGG_MIN = "agg_min"
    AGG_MAX = "agg_max"
    AGG_SUM = "agg_sum"
    SPECIAL = "special"
    # Defaults assigned when no ``response_policy:`` tip is present.
    DEFAULT_KEYLESS = "default_keyless"
    DEFAULT_KEYED = "default_keyed"
33
34
class CommandPolicies:
    """
    Request and response policies for a single command or subcommand.

    ``get_command_policies`` creates one instance per command with
    keyless/keyed defaults and then overwrites the fields from any
    ``request_policy:`` / ``response_policy:`` tips found in ``COMMAND``.

    Args:
        request_policy: policy for sending the request.
        response_policy: policy for combining the responses.
    """

    def __init__(
        self,
        request_policy: RequestPolicy = RequestPolicy.DEFAULT_KEYLESS,
        response_policy: ResponsePolicy = ResponsePolicy.DEFAULT_KEYLESS,
    ) -> None:
        self.request_policy = request_policy
        self.response_policy = response_policy

    def __repr__(self) -> str:
        # Makes printed policy tables readable when debugging or logging.
        return (
            f"{type(self).__name__}("
            f"request_policy={self.request_policy!r}, "
            f"response_policy={self.response_policy!r})"
        )
43
44
# Policy table built by ``get_command_policies``: module name ("core" for
# commands without a '<module>.' prefix) -> command or subcommand name
# (subcommands as '<command> <subcommand>') -> CommandPolicies.
PolicyRecords = dict[str, dict[str, CommandPolicies]]
46
47
class AbstractCommandsParser:
    """Helpers shared by the synchronous and asynchronous command parsers."""

    def _get_pubsub_keys(self, *args):
        """
        Get the keys from pubsub command.

        Although PubSub commands have predetermined key locations, they are not
        supported in the 'COMMAND's output, so the key positions are hardcoded
        in this method.

        Args:
            *args: the full command, name first (str or bytes).

        Returns:
            A list of the channel (or pattern) arguments, or None when the
            command has none or is not one of the recognised pubsub commands.
        """
        if len(args) < 2:
            # The command has no keys in it
            return None
        args = [str_if_bytes(arg) for arg in args]
        command = args[0].upper()
        if command == "PUBSUB":
            # The second argument is part of the command name, e.g.
            # ['PUBSUB', 'NUMSUB', 'foo'].
            pubsub_type = args[1].upper()
            if pubsub_type in ("CHANNELS", "NUMSUB", "SHARDCHANNELS", "SHARDNUMSUB"):
                return args[2:]
            return None
        if command in ("SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"):
            # format example:
            # SUBSCRIBE channel [channel ...]
            return args[1:]
        if command in ("PUBLISH", "SPUBLISH"):
            # format example:
            # PUBLISH channel message
            return [args[1]]
        return None

    def parse_subcommand(self, command, **options):
        """
        Convert one raw subcommand entry from ``COMMAND`` output into a dict.

        Args:
            command: the raw entry: name, arity, flags, first key position,
                last key position, step, and optionally ACL categories, tips,
                key specifications and nested subcommands.
            **options: accepted for signature compatibility; unused.

        Returns:
            A dict with keys ``name``, ``arity``, ``flags``, ``first_key_pos``,
            ``last_key_pos``, ``step_count`` and, when present in the entry,
            ``acl_categories``, ``tips``, ``key_specifications`` and
            ``subcommands``.
        """
        cmd_dict = {
            "name": str_if_bytes(command[0]),
            "arity": int(command[1]),
            "flags": {str_if_bytes(flag) for flag in command[2]},
            "first_key_pos": command[3],
            "last_key_pos": command[4],
            "step_count": command[5],
        }
        if len(command) > 6:
            cmd_dict["acl_categories"] = {str_if_bytes(c) for c in command[6]}
        # Check each trailing field separately so a short entry cannot raise
        # IndexError part-way through.
        for index, field in ((7, "tips"), (8, "key_specifications"), (9, "subcommands")):
            if len(command) > index:
                cmd_dict[field] = command[index]
        return cmd_dict
94
95
class CommandsParser(AbstractCommandsParser):
    """
    Parses Redis commands to get command keys.

    COMMAND output is used to determine key locations.
    Commands that do not have a predefined key location are flagged with
    'movablekeys', and these commands' keys are determined by the command
    'COMMAND GETKEYS'.
    """

    def __init__(self, redis_connection):
        self.commands = {}
        self.redis_connection = redis_connection
        self.initialize(self.redis_connection)

    def initialize(self, r):
        """
        (Re)load the command table from ``r.command()``.

        Command names are lower-cased so lookups are case-insensitive.
        """
        self.commands = {
            name.lower(): details for name, details in r.command().items()
        }

    # As soon as this PR is merged into Redis, we should reimplement
    # our logic to use COMMAND INFO changes to determine the key positions
    # https://github.com/redis/redis/pull/8324
    def get_keys(self, redis_conn, *args):
        """
        Get the keys from the passed command.

        Args:
            redis_conn: connection used for ``COMMAND GETKEYS`` and for
                reloading the command table when the command is unknown.
            *args: the full command, name first. A multi-word name such as
                ``"MEMORY USAGE"`` may be passed as one argument.

        Returns:
            The key arguments, or None if the command has no keys.

        Raises:
            RedisError: if the command is still unknown after reloading the
                command table.

        NOTE: Due to a bug in redis<7.0, this function does not work properly
        for EVAL or EVALSHA when the `numkeys` arg is 0.
        - issue: https://github.com/redis/redis/issues/9493
        - fix: https://github.com/redis/redis/pull/9733

        So, don't use this function with EVAL or EVALSHA.
        """
        if len(args) < 2:
            # The command has no keys in it
            return None

        cmd_name = str_if_bytes(args[0]).lower()
        if cmd_name not in self.commands:
            # Try to split the command name and take only the main command,
            # e.g. 'memory' for 'memory usage'. The split form is kept in
            # args, including after a reload, so that the subcommand becomes
            # its own argument and key positions line up.
            cmd_name_split = cmd_name.split() or [cmd_name]
            cmd_name = cmd_name_split[0]
            args = cmd_name_split + list(args[1:])
            if cmd_name not in self.commands:
                # We'll try to reinitialize the commands cache, if the engine
                # version has changed, the commands may not be current
                self.initialize(redis_conn)
                if cmd_name not in self.commands:
                    raise RedisError(
                        f"{cmd_name.upper()} command doesn't exist in Redis commands"
                    )

        command = self.commands[cmd_name]
        if "movablekeys" in command["flags"]:
            return self._get_moveable_keys(redis_conn, *args)
        if "pubsub" in command["flags"] or command["name"] == "pubsub":
            return self._get_pubsub_keys(*args)

        if (
            command["step_count"] == 0
            and command["first_key_pos"] == 0
            and command["last_key_pos"] == 0
        ):
            # The top-level command reports no keys; they may belong to the
            # subcommand named by the second argument instead.
            command = self._find_keyed_subcommand(cmd_name, command, args[1])
            if command is None:
                # The command doesn't have keys in it
                return None
        return self._keys_at_positions(command, args)

    def _find_keyed_subcommand(self, cmd_name, command, subcmd_arg):
        """Return the parsed '<cmd_name>|<subcmd_arg>' subcommand if it has keys, else None."""
        subcmd_name = f"{cmd_name}|{str_if_bytes(subcmd_arg).lower()}"
        for subcmd in command.get("subcommands") or ():
            if str_if_bytes(subcmd[0]) == subcmd_name:
                parsed = self.parse_subcommand(subcmd)
                return parsed if parsed["first_key_pos"] > 0 else None
        return None

    @staticmethod
    def _keys_at_positions(command, args):
        """Select key arguments using COMMAND's first/last/step positions."""
        last_key_pos = command["last_key_pos"]
        if last_key_pos < 0:
            # A negative last position counts back from the final argument.
            last_key_pos = len(args) + last_key_pos
        positions = range(command["first_key_pos"], last_key_pos + 1, command["step_count"])
        return [args[pos] for pos in positions]

    def _get_moveable_keys(self, redis_conn, *args):
        """
        Ask the server for the keys via ``COMMAND GETKEYS``.

        Returns None when the server reports invalid arguments or no key
        arguments; any other ResponseError is re-raised.

        NOTE: Due to a bug in redis<7.0, this function does not work properly
        for EVAL or EVALSHA when the `numkeys` arg is 0.
        - issue: https://github.com/redis/redis/issues/9493
        - fix: https://github.com/redis/redis/pull/9733

        So, don't use this function with EVAL or EVALSHA.
        """
        # The command name should be split into separate arguments,
        # e.g. 'MEMORY USAGE' becomes ['MEMORY', 'USAGE']
        pieces = args[0].split() + list(args[1:])
        try:
            return redis_conn.execute_command("COMMAND GETKEYS", *pieces)
        except ResponseError as e:
            message = str(e)
            if (
                "Invalid arguments" in message
                or "The command has no key arguments" in message
            ):
                return None
            raise

    def _is_keyless_command(
        self, command_name: str, subcommand_name: Optional[str] = None
    ) -> bool:
        """
        Determines whether a given command or subcommand is considered "keyless".

        A command or subcommand is keyless when its first key position is zero
        or negative.

        Parameters:
            command_name: str
                The name of the command to check.
            subcommand_name: Optional[str], default=None
                The '<command>|<subcommand>' name to check, if applicable. If
                not provided, the check is performed only on the command.

        Returns:
            bool
                True if the specified command or subcommand is keyless,
                False otherwise.

        Raises:
            ValueError
                If the command does not exist in the available commands, or
                the subcommand is not found within the command.
        """
        command_details = self.commands.get(command_name)
        if command_details is None:
            raise ValueError(f"Command {command_name} not found in commands")

        if not subcommand_name:
            return command_details["first_key_pos"] <= 0

        for subcommand in command_details.get("subcommands") or ():
            if str_if_bytes(subcommand[0]) == subcommand_name:
                return self.parse_subcommand(subcommand)["first_key_pos"] <= 0
        raise ValueError(
            f"Subcommand {subcommand_name} not found in command {command_name}"
        )

    @staticmethod
    def _default_policies(is_keyless: bool) -> CommandPolicies:
        """Build the fallback policies used before any tips are applied."""
        if is_keyless:
            return CommandPolicies(
                request_policy=RequestPolicy.DEFAULT_KEYLESS,
                response_policy=ResponsePolicy.DEFAULT_KEYLESS,
            )
        return CommandPolicies(
            request_policy=RequestPolicy.DEFAULT_KEYED,
            response_policy=ResponsePolicy.DEFAULT_KEYED,
        )

    @classmethod
    def _apply_policy_tips(cls, data, policies: CommandPolicies) -> None:
        """
        Recursively find ``request_policy:`` / ``response_policy:`` strings in
        ``data`` (str, bytes, list, tuple or dict) and store them on
        ``policies``.

        Raises:
            IncorrectPolicyType: if a tip names an unknown policy type.
        """
        if isinstance(data, (str, bytes)):
            # Tips are bytes or str depending on response decoding.
            tip = str_if_bytes(data)
            # partition() tolerates a tip without ':'; the empty type is then
            # rejected as IncorrectPolicyType instead of raising IndexError.
            policy_type = tip.partition(":")[2]
            if tip.startswith("request_policy"):
                try:
                    policies.request_policy = RequestPolicy(policy_type)
                except ValueError as err:
                    raise IncorrectPolicyType(
                        f"Incorrect request policy type: {policy_type}"
                    ) from err
            elif tip.startswith("response_policy"):
                try:
                    policies.response_policy = ResponsePolicy(policy_type)
                except ValueError as err:
                    raise IncorrectPolicyType(
                        f"Incorrect response policy type: {policy_type}"
                    ) from err
        elif isinstance(data, (list, tuple)):
            for item in data:
                cls._apply_policy_tips(item, policies)
        elif isinstance(data, dict):
            for value in data.values():
                cls._apply_policy_tips(value, policies)

    def get_command_policies(self) -> PolicyRecords:
        """
        Retrieve and process the command policies for all commands and subcommands.

        Every command and subcommand starts from keyless/keyed default
        policies, which are then overridden by any policy tips found in its
        ``COMMAND`` details.

        Returns:
            PolicyRecords: module name ("core" for commands without a
            '<module>.' prefix) -> command or subcommand name (subcommands as
            '<command> <subcommand>') -> CommandPolicies.

        Raises:
            IncorrectPolicyType: If an invalid policy type is encountered during policy extraction.
        """
        command_with_policies: PolicyRecords = {}

        for command, details in self.commands.items():
            # Check if it's a core or module command ('<module>.<command>').
            split_name = command.split(".")
            if len(split_name) > 1:
                module_name, command_name = split_name[0], split_name[1]
            else:
                module_name, command_name = "core", split_name[0]
            module_policies = command_with_policies.setdefault(module_name, {})

            policies = self._default_policies(self._is_keyless_command(command))
            module_policies[command_name] = policies
            tips = details.get("tips")
            if tips:
                self._apply_policy_tips(tips, policies)

            for subcommand_details in details.get("subcommands") or ():
                # Subcommands are named '<command>|<subcommand>'.
                subcmd_name = str_if_bytes(subcommand_details[0])
                sub_policies = self._default_policies(
                    self._is_keyless_command(command, subcmd_name)
                )
                module_policies[subcmd_name.replace("|", " ")] = sub_policies
                # Policy tips may sit anywhere in the remaining fields.
                for subcommand_detail in subcommand_details[1:]:
                    self._apply_policy_tips(subcommand_detail, sub_policies)

        return command_with_policies
398
399
class AsyncCommandsParser(AbstractCommandsParser):
    """
    Parses Redis commands to get command keys.

    COMMAND output is used to determine key locations.
    Commands that do not have a predefined key location are flagged with 'movablekeys',
    and these commands' keys are determined by the command 'COMMAND GETKEYS'.

    NOTE: Due to a bug in redis<7.0, this does not work properly
    for EVAL or EVALSHA when the `numkeys` arg is 0.
    - issue: https://github.com/redis/redis/issues/9493
    - fix: https://github.com/redis/redis/pull/9733

    So, don't use this with EVAL or EVALSHA.
    """

    __slots__ = ("commands", "node")

    def __init__(self) -> None:
        self.commands: Dict[str, Union[int, Dict[str, Any]]] = {}

    async def initialize(self, node: Optional["ClusterNode"] = None) -> None:
        """
        (Re)load the command table by sending ``COMMAND`` to a node.

        Args:
            node: node to query. It is remembered and reused by later calls
                made without a node, such as the reload in :meth:`get_keys`.
        """
        if node:
            self.node = node

        commands = await self.node.execute_command("COMMAND")
        # Lower-case names so lookups are case-insensitive.
        self.commands = {cmd.lower(): command for cmd, command in commands.items()}

    # As soon as this PR is merged into Redis, we should reimplement
    # our logic to use COMMAND INFO changes to determine the key positions
    # https://github.com/redis/redis/pull/8324
    async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]:
        """
        Get the keys from the passed command.

        Args:
            *args: the full command, name first. A multi-word name such as
                ``"MEMORY USAGE"`` may be passed as one argument.

        Returns:
            The key arguments, or None if the command has no keys.

        Raises:
            RedisError: if the command is still unknown after reloading the
                command table.

        NOTE: Due to a bug in redis<7.0, this function does not work properly
        for EVAL or EVALSHA when the `numkeys` arg is 0.
        - issue: https://github.com/redis/redis/issues/9493
        - fix: https://github.com/redis/redis/pull/9733

        So, don't use this function with EVAL or EVALSHA.
        """
        if len(args) < 2:
            # The command has no keys in it
            return None

        cmd_name = str_if_bytes(args[0]).lower()
        if cmd_name not in self.commands:
            # Try to split the command name and take only the main command,
            # e.g. 'memory' for 'memory usage'. The split form is kept in
            # args, including after a reload, so that the subcommand becomes
            # its own argument and key positions line up.
            cmd_name_split = cmd_name.split() or [cmd_name]
            cmd_name = cmd_name_split[0]
            args = cmd_name_split + list(args[1:])
            if cmd_name not in self.commands:
                # We'll try to reinitialize the commands cache, if the engine
                # version has changed, the commands may not be current
                await self.initialize()
                if cmd_name not in self.commands:
                    raise RedisError(
                        f"{cmd_name.upper()} command doesn't exist in Redis commands"
                    )

        command = self.commands[cmd_name]
        if "movablekeys" in command["flags"]:
            return await self._get_moveable_keys(*args)
        if "pubsub" in command["flags"] or command["name"] == "pubsub":
            return self._get_pubsub_keys(*args)

        if (
            command["step_count"] == 0
            and command["first_key_pos"] == 0
            and command["last_key_pos"] == 0
        ):
            # The top-level command reports no keys; they may belong to the
            # subcommand named by the second argument instead.
            command = self._find_keyed_subcommand(cmd_name, command, args[1])
            if command is None:
                # The command doesn't have keys in it
                return None
        return self._keys_at_positions(command, args)

    def _find_keyed_subcommand(self, cmd_name, command, subcmd_arg):
        """Return the parsed '<cmd_name>|<subcmd_arg>' subcommand if it has keys, else None."""
        subcmd_name = f"{cmd_name}|{str_if_bytes(subcmd_arg).lower()}"
        for subcmd in command.get("subcommands") or ():
            if str_if_bytes(subcmd[0]) == subcmd_name:
                parsed = self.parse_subcommand(subcmd)
                return parsed if parsed["first_key_pos"] > 0 else None
        return None

    @staticmethod
    def _keys_at_positions(command, args):
        """Select key arguments using COMMAND's first/last/step positions."""
        last_key_pos = command["last_key_pos"]
        if last_key_pos < 0:
            # A negative last position counts back from the final argument.
            last_key_pos = len(args) + last_key_pos
        positions = range(command["first_key_pos"], last_key_pos + 1, command["step_count"])
        return [args[pos] for pos in positions]

    async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]:
        """
        Ask the stored node for the keys via ``COMMAND GETKEYS``.

        Returns None when the server reports invalid arguments or no key
        arguments; any other ResponseError is re-raised.
        """
        # The command name should be split into separate arguments,
        # e.g. 'MEMORY USAGE' becomes ['MEMORY', 'USAGE'] (as in the
        # synchronous parser).
        pieces = args[0].split() + list(args[1:])
        try:
            return await self.node.execute_command("COMMAND GETKEYS", *pieces)
        except ResponseError as e:
            message = str(e)
            if (
                "Invalid arguments" in message
                or "The command has no key arguments" in message
            ):
                return None
            raise

    async def _is_keyless_command(
        self, command_name: str, subcommand_name: Optional[str] = None
    ) -> bool:
        """
        Determines whether a given command or subcommand is considered "keyless".

        A command or subcommand is keyless when its first key position is zero
        or negative.

        Parameters:
            command_name: str
                The name of the command to check.
            subcommand_name: Optional[str], default=None
                The '<command>|<subcommand>' name to check, if applicable. If
                not provided, the check is performed only on the command.

        Returns:
            bool
                True if the specified command or subcommand is keyless,
                False otherwise.

        Raises:
            ValueError
                If the command does not exist in the available commands, or
                the subcommand is not found within the command.
        """
        command_details = self.commands.get(command_name)
        if command_details is None:
            raise ValueError(f"Command {command_name} not found in commands")

        if not subcommand_name:
            return command_details["first_key_pos"] <= 0

        for subcommand in command_details.get("subcommands") or ():
            if str_if_bytes(subcommand[0]) == subcommand_name:
                return self.parse_subcommand(subcommand)["first_key_pos"] <= 0
        raise ValueError(
            f"Subcommand {subcommand_name} not found in command {command_name}"
        )

    @staticmethod
    def _default_policies(is_keyless: bool) -> CommandPolicies:
        """Build the fallback policies used before any tips are applied."""
        if is_keyless:
            return CommandPolicies(
                request_policy=RequestPolicy.DEFAULT_KEYLESS,
                response_policy=ResponsePolicy.DEFAULT_KEYLESS,
            )
        return CommandPolicies(
            request_policy=RequestPolicy.DEFAULT_KEYED,
            response_policy=ResponsePolicy.DEFAULT_KEYED,
        )

    @classmethod
    def _apply_policy_tips(cls, data, policies: CommandPolicies) -> None:
        """
        Recursively find ``request_policy:`` / ``response_policy:`` strings in
        ``data`` (str, bytes, list, tuple or dict) and store them on
        ``policies``.

        Raises:
            IncorrectPolicyType: if a tip names an unknown policy type.
        """
        if isinstance(data, (str, bytes)):
            # Tips are bytes or str depending on response decoding.
            tip = str_if_bytes(data)
            # partition() tolerates a tip without ':'; the empty type is then
            # rejected as IncorrectPolicyType instead of raising IndexError.
            policy_type = tip.partition(":")[2]
            if tip.startswith("request_policy"):
                try:
                    policies.request_policy = RequestPolicy(policy_type)
                except ValueError as err:
                    raise IncorrectPolicyType(
                        f"Incorrect request policy type: {policy_type}"
                    ) from err
            elif tip.startswith("response_policy"):
                try:
                    policies.response_policy = ResponsePolicy(policy_type)
                except ValueError as err:
                    raise IncorrectPolicyType(
                        f"Incorrect response policy type: {policy_type}"
                    ) from err
        elif isinstance(data, (list, tuple)):
            for item in data:
                cls._apply_policy_tips(item, policies)
        elif isinstance(data, dict):
            for value in data.values():
                cls._apply_policy_tips(value, policies)

    async def get_command_policies(self) -> PolicyRecords:
        """
        Retrieve and process the command policies for all commands and subcommands.

        Every command and subcommand starts from keyless/keyed default
        policies, which are then overridden by any policy tips found in its
        ``COMMAND`` details.

        Returns:
            PolicyRecords: module name ("core" for commands without a
            '<module>.' prefix) -> command or subcommand name (subcommands as
            '<command> <subcommand>') -> CommandPolicies.

        Raises:
            IncorrectPolicyType: If an invalid policy type is encountered during policy extraction.
        """
        command_with_policies: PolicyRecords = {}

        for command, details in self.commands.items():
            # Check if it's a core or module command ('<module>.<command>').
            split_name = command.split(".")
            if len(split_name) > 1:
                module_name, command_name = split_name[0], split_name[1]
            else:
                module_name, command_name = "core", split_name[0]
            module_policies = command_with_policies.setdefault(module_name, {})

            is_keyless = await self._is_keyless_command(command)
            policies = self._default_policies(is_keyless)
            module_policies[command_name] = policies
            tips = details.get("tips")
            if tips:
                self._apply_policy_tips(tips, policies)

            for subcommand_details in details.get("subcommands") or ():
                # Subcommands are named '<command>|<subcommand>'.
                subcmd_name = str_if_bytes(subcommand_details[0])
                is_keyless = await self._is_keyless_command(command, subcmd_name)
                sub_policies = self._default_policies(is_keyless)
                module_policies[subcmd_name.replace("|", " ")] = sub_policies
                # Policy tips may sit anywhere in the remaining fields.
                for subcommand_detail in subcommand_details[1:]:
                    self._apply_policy_tips(subcommand_detail, sub_policies)

        return command_with_policies