1from enum import Enum
2from typing import TYPE_CHECKING, Any, Awaitable, Dict, Optional, Tuple, Union
3
4from redis.exceptions import IncorrectPolicyType, RedisError, ResponseError
5from redis.utils import str_if_bytes
6
7if TYPE_CHECKING:
8 from redis.asyncio.cluster import ClusterNode
9
10
class RequestPolicy(Enum):
    """
    Values a command may declare through its ``request_policy:<value>`` tip.

    ``DEFAULT_KEYLESS`` and ``DEFAULT_KEYED`` are the fallbacks the command
    parsers assign to commands that declare no request policy, chosen by
    whether the command takes keys.
    """

    ALL_NODES = "all_nodes"
    ALL_SHARDS = "all_shards"
    ALL_REPLICAS = "all_replicas"
    MULTI_SHARD = "multi_shard"
    SPECIAL = "special"
    DEFAULT_KEYLESS = "default_keyless"
    DEFAULT_KEYED = "default_keyed"
    DEFAULT_NODE = "default_node"
20
21
class ResponsePolicy(Enum):
    """
    Values a command may declare through its ``response_policy:<value>`` tip.

    ``DEFAULT_KEYLESS`` and ``DEFAULT_KEYED`` are the fallbacks the command
    parsers assign to commands that declare no response policy, chosen by
    whether the command takes keys.
    """

    ONE_SUCCEEDED = "one_succeeded"
    ALL_SUCCEEDED = "all_succeeded"
    AGG_LOGICAL_AND = "agg_logical_and"
    AGG_LOGICAL_OR = "agg_logical_or"
    AGG_MIN = "agg_min"
    AGG_MAX = "agg_max"
    AGG_SUM = "agg_sum"
    SPECIAL = "special"
    DEFAULT_KEYLESS = "default_keyless"
    DEFAULT_KEYED = "default_keyed"
33
34
class CommandPolicies:
    """
    The request and response policy attached to a single command.

    Args:
        request_policy: the command's RequestPolicy
            (defaults to ``RequestPolicy.DEFAULT_KEYLESS``).
        response_policy: the command's ResponsePolicy
            (defaults to ``ResponsePolicy.DEFAULT_KEYLESS``).
    """

    def __init__(
        self,
        request_policy: RequestPolicy = RequestPolicy.DEFAULT_KEYLESS,
        response_policy: ResponsePolicy = ResponsePolicy.DEFAULT_KEYLESS,
    ):
        self.request_policy = request_policy
        self.response_policy = response_policy

    def __repr__(self) -> str:
        # Makes policy tables readable in logs and debuggers.
        return (
            f"{type(self).__name__}(request_policy={self.request_policy!r}, "
            f"response_policy={self.response_policy!r})"
        )


# Policies keyed first by module name ("core" for commands without a
# "<module>." prefix), then by command name; subcommands are stored as
# "<command> <subcommand>".
PolicyRecords = dict[str, dict[str, CommandPolicies]]
46
47
class AbstractCommandsParser:
    """Helpers shared by the sync and async command parsers."""

    def _get_pubsub_keys(self, *args):
        """
        Get the keys (channels) from a pubsub command.

        Although PubSub commands have predetermined key locations, they are not
        supported in the 'COMMAND's output, so the key positions are hardcoded
        in this method.

        Returns:
            A list of channel names, or ``None`` when the command carries none
            (or is not a recognised pubsub command).
        """
        if len(args) < 2:
            # The command has no keys in it
            return None
        args = [str_if_bytes(arg) for arg in args]
        command = args[0].upper()

        if command == "PUBSUB":
            # The second argument is part of the command name, e.g.
            # ['PUBSUB', 'NUMSUB', 'foo'].
            pubsub_type = args[1].upper()
            if pubsub_type in ("CHANNELS", "NUMSUB", "SHARDCHANNELS", "SHARDNUMSUB"):
                return args[2:]
            return None

        if command in (
            "SUBSCRIBE",
            "PSUBSCRIBE",
            "SSUBSCRIBE",
            "UNSUBSCRIBE",
            "PUNSUBSCRIBE",
            "SUNSUBSCRIBE",
        ):
            # format example:
            # SUBSCRIBE channel [channel ...]
            return args[1:]

        if command in ("PUBLISH", "SPUBLISH"):
            # format example:
            # PUBLISH channel message
            return [args[1]]

        return None

    def parse_subcommand(self, command, **options):
        """
        Convert one raw COMMAND (sub)command entry into a dict.

        Indices 0-5 (name, arity, flags, first/last key position, step) are
        always read. Tips (7), key specifications (8) and nested subcommands
        (9) are only stored when the entry is long enough to contain them, so
        shorter entries no longer raise IndexError.
        """
        cmd_dict = {
            "name": str_if_bytes(command[0]),
            "arity": int(command[1]),
            "flags": [str_if_bytes(flag) for flag in command[2]],
            "first_key_pos": command[3],
            "last_key_pos": command[4],
            "step_count": command[5],
        }
        for index, field in ((7, "tips"), (8, "key_specifications"), (9, "subcommands")):
            if len(command) > index:
                cmd_dict[field] = command[index]
        return cmd_dict
92
93
class CommandsParser(AbstractCommandsParser):
    """
    Parses Redis commands to get command keys.
    COMMAND output is used to determine key locations.
    Commands that do not have a predefined key location are flagged with
    'movablekeys', and these commands' keys are determined by the command
    'COMMAND GETKEYS'.
    """

    def __init__(self, redis_connection):
        self.commands = {}
        self.redis_connection = redis_connection
        self.initialize(self.redis_connection)

    def initialize(self, r):
        """
        (Re)load the command table by calling ``r.command()``.

        The table is stored in ``self.commands`` keyed by lower-cased command
        name so that lookups in :meth:`get_keys` are case-insensitive.
        """
        commands = r.command()
        self.commands = {name.lower(): info for name, info in commands.items()}

    # NOTE: COMMAND also reports key specifications (redis/redis#8324; kept
    # under "key_specifications" by parse_subcommand), but the key extraction
    # below still relies on the legacy first/last/step key positions.
    def get_keys(self, redis_conn, *args):
        """
        Get the keys from the passed command.

        Args:
            redis_conn: connection used for ``COMMAND GETKEYS`` and for
                refreshing the command table.
            *args: the command name followed by its arguments. A multi-word
                command name (e.g. ``"MEMORY USAGE"``) may be passed as a
                single string; it is split before lookup.

        Returns:
            A list of keys, or ``None`` if the command has no keys.

        Raises:
            RedisError: if the command is unknown even after refreshing the
                command table.

        NOTE: Due to a bug in redis<7.0, this function does not work properly
        for EVAL or EVALSHA when the `numkeys` arg is 0.
        - issue: https://github.com/redis/redis/issues/9493
        - fix: https://github.com/redis/redis/pull/9733

        So, don't use this function with EVAL or EVALSHA.
        """
        if len(args) < 2:
            # The command has no keys in it
            return None

        cmd_name = str_if_bytes(args[0]).lower()
        if cmd_name not in self.commands:
            # Try to split the command name and take only the main command,
            # e.g. 'memory' for 'memory usage'.
            cmd_name_split = cmd_name.split() or [cmd_name]
            cmd_name = cmd_name_split[0]
            if cmd_name not in self.commands:
                # We'll try to reinitialize the commands cache, if the engine
                # version has changed, the commands may not be current
                self.initialize(redis_conn)
                if cmd_name not in self.commands:
                    raise RedisError(
                        f"{cmd_name.upper()} command doesn't exist in Redis commands"
                    )
            # Save the split command into args. This must also happen after a
            # cache refresh, otherwise args[1] would not be the subcommand.
            args = cmd_name_split + list(args[1:])

        command = self.commands[cmd_name]
        if "movablekeys" in command["flags"]:
            return self._get_moveable_keys(redis_conn, *args)
        if "pubsub" in command["flags"] or command["name"] == "pubsub":
            return self._get_pubsub_keys(*args)

        if (
            command["step_count"] == 0
            and command["first_key_pos"] == 0
            and command["last_key_pos"] == 0
        ):
            # The command itself declares no keys; its key positions may be
            # declared by the subcommand named in args[1] instead.
            command = self._find_keyed_subcommand(command, cmd_name, args[1])
            if command is None:
                # The command doesn't have keys in it
                return None

        last_key_pos = command["last_key_pos"]
        if last_key_pos < 0:
            # Negative positions count from the end of the argument list.
            last_key_pos = len(args) - abs(last_key_pos)
        keys_pos = range(
            command["first_key_pos"], last_key_pos + 1, command["step_count"]
        )
        return [args[pos] for pos in keys_pos]

    def _find_keyed_subcommand(self, command, cmd_name, subcmd_arg):
        """
        Return the parsed ``<cmd_name>|<subcmd_arg>`` subcommand entry when it
        declares keys (``first_key_pos > 0``), otherwise ``None``.

        ``subcmd_arg`` may be bytes or a non-string value (e.g. an int
        argument of a keyless command); it is normalised before comparison.
        """
        subcommands = command.get("subcommands")
        if not subcommands:
            return None
        subcmd_name = f"{cmd_name}|{str(str_if_bytes(subcmd_arg)).lower()}"
        for subcmd in subcommands:
            if str_if_bytes(subcmd[0]) == subcmd_name:
                parsed = self.parse_subcommand(subcmd)
                return parsed if parsed["first_key_pos"] > 0 else None
        return None

    def _get_moveable_keys(self, redis_conn, *args):
        """
        Resolve keys via ``COMMAND GETKEYS`` for commands flagged 'movablekeys'.

        Returns ``None`` when the server reports invalid arguments or no key
        arguments; any other ResponseError is re-raised.

        NOTE: Due to a bug in redis<7.0, this function does not work properly
        for EVAL or EVALSHA when the `numkeys` arg is 0.
        - issue: https://github.com/redis/redis/issues/9493
        - fix: https://github.com/redis/redis/pull/9733

        So, don't use this function with EVAL or EVALSHA.
        """
        # The command name should be split into separate arguments,
        # e.g. 'MEMORY USAGE' will be split into ['MEMORY', 'USAGE']
        pieces = args[0].split() + list(args[1:])
        try:
            return redis_conn.execute_command("COMMAND GETKEYS", *pieces)
        except ResponseError as e:
            message = str(e)
            if (
                "Invalid arguments" in message
                or "The command has no key arguments" in message
            ):
                return None
            raise

    def _is_keyless_command(
        self, command_name: str, subcommand_name: Optional[str] = None
    ) -> bool:
        """
        Determines whether a given command or subcommand is considered "keyless".

        A command or subcommand is treated as keyless when its first key
        position is zero or negative.

        Parameters:
            command_name: str
                The name of the command to check.
            subcommand_name: Optional[str], default=None
                The name of the subcommand to check (as reported by COMMAND,
                e.g. ``"config|get"``). If not provided, only the command
                itself is checked.

        Returns:
            bool
                True if the specified command or subcommand is keyless,
                False otherwise.

        Raises:
            ValueError
                If the command does not exist in the available commands, or
                the subcommand is not found within the command.
        """
        command_details = self.commands.get(command_name)
        if command_details is None:
            raise ValueError(f"Command {command_name} not found in commands")

        if not subcommand_name:
            return command_details["first_key_pos"] <= 0

        for subcommand in command_details.get("subcommands") or []:
            if str_if_bytes(subcommand[0]) == subcommand_name:
                return self.parse_subcommand(subcommand)["first_key_pos"] <= 0
        raise ValueError(
            f"Subcommand {subcommand_name} not found in command {command_name}"
        )

    @staticmethod
    def _split_module_name(command):
        """
        Split ``"<module>.<command>"`` at the first dot; names without a dot
        belong to the ``"core"`` module.
        """
        module_name, dot, command_name = command.partition(".")
        if not dot:
            return "core", command
        return module_name, command_name

    @staticmethod
    def _default_policies(is_keyless):
        """Build a CommandPolicies holding the keyless or keyed defaults."""
        if is_keyless:
            return CommandPolicies(
                request_policy=RequestPolicy.DEFAULT_KEYLESS,
                response_policy=ResponsePolicy.DEFAULT_KEYLESS,
            )
        return CommandPolicies(
            request_policy=RequestPolicy.DEFAULT_KEYED,
            response_policy=ResponsePolicy.DEFAULT_KEYED,
        )

    @classmethod
    def _apply_policy_tips(cls, data, policies):
        """
        Recursively scan ``data`` for ``request_policy:<type>`` and
        ``response_policy:<type>`` strings and store them on ``policies``.

        Args:
            data: a str/bytes item or a (possibly nested) list/tuple/dict of
                COMMAND reply items; other types are ignored.
            policies: the CommandPolicies object updated in place.

        Raises:
            IncorrectPolicyType: if a tip names an unknown policy type.
        """
        if isinstance(data, (str, bytes)):
            # str_if_bytes handles both decoded (str) and raw (bytes) replies.
            tip = str_if_bytes(data)
            if tip.startswith("request_policy"):
                policy_type = tip.partition(":")[2]
                try:
                    policies.request_policy = RequestPolicy(policy_type)
                except ValueError as err:
                    raise IncorrectPolicyType(
                        f"Incorrect request policy type: {policy_type}"
                    ) from err
            elif tip.startswith("response_policy"):
                policy_type = tip.partition(":")[2]
                try:
                    policies.response_policy = ResponsePolicy(policy_type)
                except ValueError as err:
                    raise IncorrectPolicyType(
                        f"Incorrect response policy type: {policy_type}"
                    ) from err
        elif isinstance(data, (list, tuple)):
            for item in data:
                cls._apply_policy_tips(item, policies)
        elif isinstance(data, dict):
            for value in data.values():
                cls._apply_policy_tips(value, policies)

    def get_command_policies(self) -> PolicyRecords:
        """
        Retrieve and process the command policies for all commands and subcommands.

        Each command (and each of its subcommands) starts with the keyless or
        keyed default policies. Any ``request_policy``/``response_policy`` tips
        found in its COMMAND entry then override those defaults.

        Returns:
            PolicyRecords: policies keyed by module name (``"core"`` for
            commands without a module prefix) and then by command name;
            subcommands are stored as ``"<command> <subcommand>"``.

        Raises:
            IncorrectPolicyType: If an invalid policy type is encountered during policy extraction.
        """
        command_with_policies: PolicyRecords = {}

        for command, details in self.commands.items():
            module_name, command_name = self._split_module_name(command)
            module_policies = command_with_policies.setdefault(module_name, {})

            policies = self._default_policies(self._is_keyless_command(command))
            module_policies[command_name] = policies

            # Process tips for the main command
            tips = details.get("tips")
            if tips:
                self._apply_policy_tips(tips, policies)

            for subcommand_details in details.get("subcommands") or []:
                parsed = self.parse_subcommand(subcommand_details)
                # COMMAND reports e.g. "config|get"; expose it as "config get".
                subcmd_name = parsed["name"].replace("|", " ")

                subcmd_policies = self._default_policies(parsed["first_key_pos"] <= 0)
                module_policies[subcmd_name] = subcmd_policies

                # Tips may sit anywhere in the rest of the entry, so scan it all.
                for subcommand_detail in subcommand_details[1:]:
                    self._apply_policy_tips(subcommand_detail, subcmd_policies)

        return command_with_policies
396
397
class AsyncCommandsParser(AbstractCommandsParser):
    """
    Parses Redis commands to get command keys.

    COMMAND output is used to determine key locations.
    Commands that do not have a predefined key location are flagged with 'movablekeys',
    and these commands' keys are determined by the command 'COMMAND GETKEYS'.

    NOTE: Due to a bug in redis<7.0, this does not work properly
    for EVAL or EVALSHA when the `numkeys` arg is 0.
    - issue: https://github.com/redis/redis/issues/9493
    - fix: https://github.com/redis/redis/pull/9733

    So, don't use this with EVAL or EVALSHA.
    """

    __slots__ = ("commands", "node")

    def __init__(self) -> None:
        self.commands: Dict[str, Union[int, Dict[str, Any]]] = {}

    async def initialize(self, node: Optional["ClusterNode"] = None) -> None:
        """
        (Re)load the command table by sending ``COMMAND`` to a node.

        Args:
            node: node to query. When given, it is stored in ``self.node`` and
                reused by later refreshes and ``COMMAND GETKEYS`` calls. When
                omitted, the previously stored node is used.
        """
        if node:
            self.node = node

        commands = await self.node.execute_command("COMMAND")
        self.commands = {cmd.lower(): command for cmd, command in commands.items()}

    # NOTE: COMMAND also reports key specifications (redis/redis#8324; kept
    # under "key_specifications" by parse_subcommand), but the key extraction
    # below still relies on the legacy first/last/step key positions.
    async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]:
        """
        Get the keys from the passed command.

        Args:
            *args: the command name followed by its arguments. A multi-word
                command name (e.g. ``"MEMORY USAGE"``) may be passed as a
                single string; it is split before lookup.

        Returns:
            The keys (a list in practice), or ``None`` if the command has no
            keys.

        Raises:
            RedisError: if the command is unknown even after refreshing the
                command table.

        NOTE: Due to a bug in redis<7.0, this function does not work properly
        for EVAL or EVALSHA when the `numkeys` arg is 0.
        - issue: https://github.com/redis/redis/issues/9493
        - fix: https://github.com/redis/redis/pull/9733

        So, don't use this function with EVAL or EVALSHA.
        """
        if len(args) < 2:
            # The command has no keys in it
            return None

        cmd_name = str_if_bytes(args[0]).lower()
        if cmd_name not in self.commands:
            # Try to split the command name and take only the main command,
            # e.g. 'memory' for 'memory usage'.
            cmd_name_split = cmd_name.split() or [cmd_name]
            cmd_name = cmd_name_split[0]
            if cmd_name not in self.commands:
                # We'll try to reinitialize the commands cache, if the engine
                # version has changed, the commands may not be current
                await self.initialize()
                if cmd_name not in self.commands:
                    raise RedisError(
                        f"{cmd_name.upper()} command doesn't exist in Redis commands"
                    )
            # Save the split command into args. This must also happen after a
            # cache refresh, otherwise args[1] would not be the subcommand.
            args = cmd_name_split + list(args[1:])

        command = self.commands[cmd_name]
        if "movablekeys" in command["flags"]:
            return await self._get_moveable_keys(*args)
        if "pubsub" in command["flags"] or command["name"] == "pubsub":
            return self._get_pubsub_keys(*args)

        if (
            command["step_count"] == 0
            and command["first_key_pos"] == 0
            and command["last_key_pos"] == 0
        ):
            # The command itself declares no keys; its key positions may be
            # declared by the subcommand named in args[1] instead.
            command = self._find_keyed_subcommand(command, cmd_name, args[1])
            if command is None:
                # The command doesn't have keys in it
                return None

        last_key_pos = command["last_key_pos"]
        if last_key_pos < 0:
            # Negative positions count from the end of the argument list.
            last_key_pos = len(args) - abs(last_key_pos)
        keys_pos = range(
            command["first_key_pos"], last_key_pos + 1, command["step_count"]
        )
        return [args[pos] for pos in keys_pos]

    def _find_keyed_subcommand(self, command, cmd_name, subcmd_arg):
        """
        Return the parsed ``<cmd_name>|<subcmd_arg>`` subcommand entry when it
        declares keys (``first_key_pos > 0``), otherwise ``None``.

        ``subcmd_arg`` may be bytes or a non-string value (e.g. an int
        argument of a keyless command); it is normalised before comparison.
        """
        subcommands = command.get("subcommands")
        if not subcommands:
            return None
        subcmd_name = f"{cmd_name}|{str(str_if_bytes(subcmd_arg)).lower()}"
        for subcmd in subcommands:
            if str_if_bytes(subcmd[0]) == subcmd_name:
                parsed = self.parse_subcommand(subcmd)
                return parsed if parsed["first_key_pos"] > 0 else None
        return None

    async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]:
        """
        Resolve keys via ``COMMAND GETKEYS`` on ``self.node`` for commands
        flagged 'movablekeys'.

        Returns ``None`` when the server reports invalid arguments or no key
        arguments; any other ResponseError is re-raised.
        """
        # The command name should be split into separate arguments,
        # e.g. 'MEMORY USAGE' will be split into ['MEMORY', 'USAGE']
        pieces = args[0].split() + list(args[1:])
        try:
            return await self.node.execute_command("COMMAND GETKEYS", *pieces)
        except ResponseError as e:
            message = str(e)
            if (
                "Invalid arguments" in message
                or "The command has no key arguments" in message
            ):
                return None
            raise

    async def _is_keyless_command(
        self, command_name: str, subcommand_name: Optional[str] = None
    ) -> bool:
        """
        Determines whether a given command or subcommand is considered "keyless".

        A command or subcommand is treated as keyless when its first key
        position is zero or negative.

        Parameters:
            command_name: str
                The name of the command to check.
            subcommand_name: Optional[str], default=None
                The name of the subcommand to check (as reported by COMMAND,
                e.g. ``"config|get"``). If not provided, only the command
                itself is checked.

        Returns:
            bool
                True if the specified command or subcommand is keyless,
                False otherwise.

        Raises:
            ValueError
                If the command does not exist in the available commands, or
                the subcommand is not found within the command.
        """
        command_details = self.commands.get(command_name)
        if command_details is None:
            raise ValueError(f"Command {command_name} not found in commands")

        if not subcommand_name:
            return command_details["first_key_pos"] <= 0

        for subcommand in command_details.get("subcommands") or []:
            if str_if_bytes(subcommand[0]) == subcommand_name:
                return self.parse_subcommand(subcommand)["first_key_pos"] <= 0
        raise ValueError(
            f"Subcommand {subcommand_name} not found in command {command_name}"
        )

    @staticmethod
    def _split_module_name(command):
        """
        Split ``"<module>.<command>"`` at the first dot; names without a dot
        belong to the ``"core"`` module.
        """
        module_name, dot, command_name = command.partition(".")
        if not dot:
            return "core", command
        return module_name, command_name

    @staticmethod
    def _default_policies(is_keyless):
        """Build a CommandPolicies holding the keyless or keyed defaults."""
        if is_keyless:
            return CommandPolicies(
                request_policy=RequestPolicy.DEFAULT_KEYLESS,
                response_policy=ResponsePolicy.DEFAULT_KEYLESS,
            )
        return CommandPolicies(
            request_policy=RequestPolicy.DEFAULT_KEYED,
            response_policy=ResponsePolicy.DEFAULT_KEYED,
        )

    @classmethod
    def _apply_policy_tips(cls, data, policies):
        """
        Recursively scan ``data`` for ``request_policy:<type>`` and
        ``response_policy:<type>`` strings and store them on ``policies``.

        Args:
            data: a str/bytes item or a (possibly nested) list/tuple/dict of
                COMMAND reply items; other types are ignored.
            policies: the CommandPolicies object updated in place.

        Raises:
            IncorrectPolicyType: if a tip names an unknown policy type.
        """
        if isinstance(data, (str, bytes)):
            # str_if_bytes handles both decoded (str) and raw (bytes) replies.
            tip = str_if_bytes(data)
            if tip.startswith("request_policy"):
                policy_type = tip.partition(":")[2]
                try:
                    policies.request_policy = RequestPolicy(policy_type)
                except ValueError as err:
                    raise IncorrectPolicyType(
                        f"Incorrect request policy type: {policy_type}"
                    ) from err
            elif tip.startswith("response_policy"):
                policy_type = tip.partition(":")[2]
                try:
                    policies.response_policy = ResponsePolicy(policy_type)
                except ValueError as err:
                    raise IncorrectPolicyType(
                        f"Incorrect response policy type: {policy_type}"
                    ) from err
        elif isinstance(data, (list, tuple)):
            for item in data:
                cls._apply_policy_tips(item, policies)
        elif isinstance(data, dict):
            for value in data.values():
                cls._apply_policy_tips(value, policies)

    async def get_command_policies(self) -> PolicyRecords:
        """
        Retrieve and process the command policies for all commands and subcommands.

        Each command (and each of its subcommands) starts with the keyless or
        keyed default policies. Any ``request_policy``/``response_policy`` tips
        found in its COMMAND entry then override those defaults.

        Returns:
            PolicyRecords: policies keyed by module name (``"core"`` for
            commands without a module prefix) and then by command name;
            subcommands are stored as ``"<command> <subcommand>"``.

        Raises:
            IncorrectPolicyType: If an invalid policy type is encountered during policy extraction.
        """
        command_with_policies: PolicyRecords = {}

        for command, details in self.commands.items():
            module_name, command_name = self._split_module_name(command)
            module_policies = command_with_policies.setdefault(module_name, {})

            is_keyless = await self._is_keyless_command(command)
            policies = self._default_policies(is_keyless)
            module_policies[command_name] = policies

            # Process tips for the main command
            tips = details.get("tips")
            if tips:
                self._apply_policy_tips(tips, policies)

            for subcommand_details in details.get("subcommands") or []:
                parsed = self.parse_subcommand(subcommand_details)
                # COMMAND reports e.g. "config|get"; expose it as "config get".
                subcmd_name = parsed["name"].replace("|", " ")

                subcmd_policies = self._default_policies(parsed["first_key_pos"] <= 0)
                module_policies[subcmd_name] = subcmd_policies

                # Tips may sit anywhere in the rest of the entry, so scan it all.
                for subcommand_detail in subcommand_details[1:]:
                    self._apply_policy_tips(subcommand_detail, subcmd_policies)

        return command_with_policies