1from abc import ABC, abstractmethod
2from typing import Optional
3
4from redis._parsers.commands import (
5 CommandPolicies,
6 CommandsParser,
7 PolicyRecords,
8 RequestPolicy,
9 ResponsePolicy,
10)
11
def _build_static_policies() -> PolicyRecords:
    """
    Build the built-in policy records used by the static resolvers.

    The table is keyed by module name (``"ft"`` for ``FT.*`` commands and
    ``"core"`` for module-less commands) and then by lower-case command name.
    Every entry is a freshly constructed ``CommandPolicies`` instance.
    """

    def keyless() -> CommandPolicies:
        # Default policies for commands that take no key argument.
        return CommandPolicies(
            request_policy=RequestPolicy.DEFAULT_KEYLESS,
            response_policy=ResponsePolicy.DEFAULT_KEYLESS,
        )

    def keyed() -> CommandPolicies:
        # Default policies for commands that take a key argument.
        return CommandPolicies(
            request_policy=RequestPolicy.DEFAULT_KEYED,
            response_policy=ResponsePolicy.DEFAULT_KEYED,
        )

    return {
        "ft": {
            "explaincli": keyless(),
            "suglen": keyed(),
            "profile": keyless(),
            "dropindex": keyless(),
            "aliasupdate": keyless(),
            "alter": keyless(),
            "aggregate": keyless(),
            "syndump": keyless(),
            "create": keyless(),
            "explain": keyless(),
            "sugget": keyed(),
            "dictdel": keyless(),
            "aliasadd": keyless(),
            "dictadd": keyless(),
            "synupdate": keyless(),
            "drop": keyless(),
            "info": keyless(),
            "sugadd": keyed(),
            "dictdump": keyless(),
            # The only entry whose request policy is SPECIAL rather than a default.
            "cursor": CommandPolicies(
                request_policy=RequestPolicy.SPECIAL,
                response_policy=ResponsePolicy.DEFAULT_KEYLESS,
            ),
            "search": keyless(),
            "tagvals": keyless(),
            "aliasdel": keyless(),
            "sugdel": keyed(),
            "spellcheck": keyless(),
        },
        "core": {
            "command": keyless(),
        },
    }


STATIC_POLICIES: PolicyRecords = _build_static_policies()
122
123
class PolicyResolver(ABC):
    """
    Interface for synchronous objects that map a command name to its policies.
    """

    @abstractmethod
    def resolve(self, command_name: str) -> Optional[CommandPolicies]:
        """
        Look up the command policies registered for ``command_name``.

        Args:
            command_name: Name of the command to look up.

        Returns:
            Optional[CommandPolicies]: The policies for the command, or ``None``
            when no policies are known for it.
        """
        ...

    @abstractmethod
    def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver":
        """
        Create a new resolver of the same kind that is chained to ``fallback``.

        Args:
            fallback: Resolver to consult when this one cannot resolve a command.

        Returns:
            PolicyResolver: A new resolver instance using ``fallback``.
        """
        ...
150
151
class AsyncPolicyResolver(ABC):
    """
    Interface for asynchronous objects that map a command name to its policies.
    """

    @abstractmethod
    async def resolve(self, command_name: str) -> Optional[CommandPolicies]:
        """
        Look up (awaitably) the command policies registered for ``command_name``.

        Args:
            command_name: Name of the command to look up.

        Returns:
            Optional[CommandPolicies]: The policies for the command, or ``None``
            when no policies are known for it.
        """
        ...

    @abstractmethod
    def with_fallback(self, fallback: "AsyncPolicyResolver") -> "AsyncPolicyResolver":
        """
        Create a new async resolver of the same kind that is chained to ``fallback``.

        Args:
            fallback: Async resolver to consult when this one cannot resolve a command.

        Returns:
            AsyncPolicyResolver: A new resolver instance using ``fallback``.
        """
        ...
178
179
class BasePolicyResolver(PolicyResolver):
    """
    Base class for policy resolvers backed by a ``PolicyRecords`` mapping.

    Names are resolved as ``"<module>.<command>"`` (for example ``"ft.search"``).
    A name without a dot is looked up under the ``"core"`` module. Any miss is
    delegated to the optional fallback resolver.
    """

    def __init__(
        self, policies: PolicyRecords, fallback: Optional[PolicyResolver] = None
    ) -> None:
        """
        Args:
            policies: Nested mapping of module name -> command name -> policies.
            fallback: Resolver consulted when ``policies`` has no entry for a
                command; ``None`` means unresolved commands yield ``None``.
        """
        self._policies = policies
        self._fallback = fallback

    def resolve(self, command_name: str) -> Optional[CommandPolicies]:
        """
        Resolve ``command_name`` against the local records, then the fallback.

        Args:
            command_name: ``"command"`` or ``"module.command"``.

        Returns:
            Optional[CommandPolicies]: The local entry if present. Otherwise the
            fallback's result, or ``None`` when there is no fallback.

        Raises:
            ValueError: If ``command_name`` contains more than one ``"."``.
        """
        parts = command_name.split(".")

        if len(parts) > 2:
            raise ValueError(f"Wrong command or module name: {command_name}")

        module, command = parts if len(parts) == 2 else ("core", parts[0])

        # One lookup per level. A missing module (or a None entry for it) is
        # treated exactly like a missing command: defer to the fallback.
        policy = (self._policies.get(module) or {}).get(command)
        if policy is not None:
            return policy

        if self._fallback is not None:
            return self._fallback.resolve(command_name)
        return None

    @abstractmethod
    def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver":
        pass
216
217
class AsyncBasePolicyResolver(AsyncPolicyResolver):
    """
    Async base class for policy resolvers backed by a ``PolicyRecords`` mapping.

    Names are resolved as ``"<module>.<command>"`` (for example ``"ft.search"``).
    A name without a dot is looked up under the ``"core"`` module. Any miss is
    delegated to the optional (async) fallback resolver.
    """

    def __init__(
        self, policies: PolicyRecords, fallback: Optional[AsyncPolicyResolver] = None
    ) -> None:
        """
        Args:
            policies: Nested mapping of module name -> command name -> policies.
            fallback: Async resolver consulted when ``policies`` has no entry for
                a command; ``None`` means unresolved commands yield ``None``.
        """
        self._policies = policies
        self._fallback = fallback

    async def resolve(self, command_name: str) -> Optional[CommandPolicies]:
        """
        Resolve ``command_name`` against the local records, then the fallback.

        Args:
            command_name: ``"command"`` or ``"module.command"``.

        Returns:
            Optional[CommandPolicies]: The local entry if present. Otherwise the
            awaited fallback result, or ``None`` when there is no fallback.

        Raises:
            ValueError: If ``command_name`` contains more than one ``"."``.
        """
        parts = command_name.split(".")

        if len(parts) > 2:
            raise ValueError(f"Wrong command or module name: {command_name}")

        module, command = parts if len(parts) == 2 else ("core", parts[0])

        # One lookup per level. A missing module (or a None entry for it) is
        # treated exactly like a missing command: defer to the fallback.
        policy = (self._policies.get(module) or {}).get(command)
        if policy is not None:
            return policy

        if self._fallback is not None:
            return await self._fallback.resolve(command_name)
        return None

    @abstractmethod
    def with_fallback(self, fallback: "AsyncPolicyResolver") -> "AsyncPolicyResolver":
        pass
254
255
class DynamicPolicyResolver(BasePolicyResolver):
    """
    Resolver whose policy records come from a parsed COMMAND output.
    """

    def __init__(
        self, commands_parser: CommandsParser, fallback: Optional[PolicyResolver] = None
    ) -> None:
        """
        Args:
            commands_parser: COMMAND output parser. The result of its
                ``get_command_policies()`` becomes this resolver's records.
            fallback: Resolver consulted for commands the parsed records do
                not cover.
        """
        self._commands_parser = commands_parser
        parsed_policies = commands_parser.get_command_policies()
        super().__init__(parsed_policies, fallback)

    def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver":
        """Return a new resolver over the same parser, chained to ``fallback``."""
        # NOTE(review): this calls get_command_policies() again instead of
        # reusing self._policies (the async variant reuses its records) —
        # confirm whether re-reading the parser is intended.
        return DynamicPolicyResolver(
            commands_parser=self._commands_parser, fallback=fallback
        )
275
276
class StaticPolicyResolver(BasePolicyResolver):
    """
    Resolver backed by the module-level ``STATIC_POLICIES`` table.
    """

    def __init__(self, fallback: Optional[PolicyResolver] = None) -> None:
        """
        Args:
            fallback: Resolver consulted for commands that ``STATIC_POLICIES``
                does not cover.
        """
        super().__init__(policies=STATIC_POLICIES, fallback=fallback)

    def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver":
        """Return a new static resolver chained to ``fallback``."""
        return StaticPolicyResolver(fallback=fallback)
292
293
class AsyncDynamicPolicyResolver(AsyncBasePolicyResolver):
    """
    Async counterpart of ``DynamicPolicyResolver`` that takes pre-built records.
    """

    def __init__(
        self,
        policy_records: PolicyRecords,
        fallback: Optional[AsyncPolicyResolver] = None,
    ) -> None:
        """
        Args:
            policy_records: Policy records this resolver looks commands up in.
            fallback: Async resolver consulted for commands the records do not
                cover.
        """
        super().__init__(policies=policy_records, fallback=fallback)

    def with_fallback(self, fallback: "AsyncPolicyResolver") -> "AsyncPolicyResolver":
        """Return a new resolver sharing these records, chained to ``fallback``."""
        return AsyncDynamicPolicyResolver(
            policy_records=self._policies, fallback=fallback
        )
314
315
class AsyncStaticPolicyResolver(AsyncBasePolicyResolver):
    """
    Async counterpart of ``StaticPolicyResolver``, backed by ``STATIC_POLICIES``.
    """

    def __init__(self, fallback: Optional[AsyncPolicyResolver] = None) -> None:
        """
        Args:
            fallback: Async resolver consulted for commands that
                ``STATIC_POLICIES`` does not cover.
        """
        super().__init__(policies=STATIC_POLICIES, fallback=fallback)

    def with_fallback(self, fallback: "AsyncPolicyResolver") -> "AsyncPolicyResolver":
        """Return a new async static resolver chained to ``fallback``."""
        return AsyncStaticPolicyResolver(fallback=fallback)