1from __future__ import annotations
2
3import string
4from typing import TYPE_CHECKING, Callable, ClassVar
5
6from dissect.cstruct.exceptions import ExpressionParserError, ExpressionTokenizerError
7
8if TYPE_CHECKING:
9 from dissect.cstruct import cstruct
10
11
HEXBIN_SUFFIX = {"x", "X", "b", "B"}


class ExpressionTokenizer:
    """Split a C-like arithmetic expression string into a flat list of tokens.

    Recognized tokens are integer literals (decimal, octal, hexadecimal and
    binary notation, with optional ``u``/``U``/``l``/``L`` integer suffixes),
    identifiers, single-character operators and the two-character shift
    operators ``<<``/``>>``. Whitespace separates tokens and is discarded.
    """

    def __init__(self, expression: str):
        self.expression = expression
        # Current scan position into ``expression``.
        self.pos = 0
        self.tokens: list[str] = []

    def equal(self, token: str, expected: str | set[str]) -> bool:
        """Return True if ``token`` equals ``expected``, or is a member of it when a set."""
        if isinstance(expected, set):
            return token in expected
        return token == expected

    def alnum(self, token: str) -> bool:
        """Return True if ``token`` is alphanumeric."""
        return token.isalnum()

    def alpha(self, token: str) -> bool:
        """Return True if ``token`` is alphabetic."""
        return token.isalpha()

    def digit(self, token: str) -> bool:
        """Return True if ``token`` is a decimal digit."""
        return token.isdigit()

    def hexdigit(self, token: str) -> bool:
        """Return True if ``token`` is a hexadecimal digit (0-9, a-f, A-F)."""
        return token in string.hexdigits

    def operator(self, token: str) -> bool:
        """Return True if ``token`` is a single-character operator."""
        return token in {"*", "/", "+", "-", "%", "&", "^", "|", "(", ")", "~"}

    def match(
        self,
        func: Callable[[str], bool] | None = None,
        expected: str | set[str] | None = None,
        consume: bool = True,
        append: bool = True,
    ) -> bool:
        """Try to match the current character against ``expected`` or ``func``.

        Args:
            func: Predicate called with the current character.
            expected: Literal character, or set of characters, to compare against.
            consume: Advance the scan position on a match.
            append: Append the matched character to ``self.tokens`` on a match.

        Returns:
            True on a match (after optionally consuming/appending), False otherwise
            (including when the end of the expression has been reached).
        """
        if self.eol():
            return False

        token = self.get_token()

        if expected and self.equal(token, expected):
            if append:
                self.tokens.append(token)
            if consume:
                self.consume()
            return True

        if func and func(token):
            if append:
                self.tokens.append(token)
            if consume:
                self.consume()
            return True

        return False

    def consume(self) -> None:
        """Advance the scan position by one character."""
        self.pos += 1

    def eol(self) -> bool:
        """Return True if the entire expression has been consumed."""
        return self.pos >= len(self.expression)

    def get_token(self) -> str:
        """Return the character at the current scan position.

        Raises:
            ExpressionTokenizerError: If the position is past the end of the expression.
        """
        if self.eol():
            raise ExpressionTokenizerError(f"Out of bounds index: {self.pos}, length: {len(self.expression)}")
        return self.expression[self.pos]

    def tokenize(self) -> list[str]:
        """Tokenize the expression and return the list of tokens.

        Raises:
            ExpressionTokenizerError: On malformed number notation or an
                unrecognized character.
        """
        token = ""

        # Single linear scan over the expression.
        while not self.eol():
            # Single-character operators are tokens by themselves.
            if self.match(self.operator):
                continue

            # Numbers: a leading digit, an optional base marker, remaining
            # (hex)digits, and optional integer suffixes.
            if self.match(self.digit, consume=False, append=False):
                token += self.get_token()
                self.consume()

                # Optional base marker directly after the first digit ("0x1F", "0b101").
                if self.match(expected=HEXBIN_SUFFIX, consume=False, append=False):
                    token += self.get_token()
                    self.consume()

                # Remaining digits of the number; hexdigits also covers decimal.
                # match() returns False at end of input, so no extra eol check is needed.
                while self.match(self.hexdigit, consume=False, append=False):
                    token += self.get_token()
                    self.consume()

                # Integer suffixes ("u"/"U" optionally followed by up to two
                # "l"/"L", or "l"/"L" optionally followed by "l"/"L" and
                # "u"/"U") are consumed but not kept in the token.
                if self.match(expected={"u", "U"}, consume=False, append=False):
                    self.consume()
                    self.match(expected={"l", "L"}, append=False)
                    self.match(expected={"l", "L"}, append=False)
                elif self.match(expected={"l", "L"}, append=False):
                    self.match(expected={"l", "L"}, append=False)
                    self.match(expected={"u", "U"}, append=False)

                # A number cannot end on its base marker ("0x", "0b").
                if len(token) == 2 and token[-1] in HEXBIN_SUFFIX:
                    raise ExpressionTokenizerError("Invalid binary or hex notation")

                # A leading zero without a base marker denotes octal; rewrite it
                # to Python's "0o" notation so int(token, 0) can parse it later.
                if len(token) > 1 and token[0] == "0" and token[1] not in HEXBIN_SUFFIX:
                    token = token[:1] + "o" + token[1:]

                self.tokens.append(token)
                token = ""

            # Identifiers: a letter or underscore followed by alphanumerics/underscores.
            elif self.match(self.alpha, consume=False, append=False) or self.match(
                expected="_", consume=False, append=False
            ):
                while self.match(self.alnum, consume=False, append=False) or self.match(
                    expected="_", consume=False, append=False
                ):
                    token += self.get_token()
                    self.consume()
                self.tokens.append(token)
                token = ""
            # Two-character shift operators: both characters must be present.
            elif self.match(expected=">", append=False) and self.match(expected=">", append=False):
                self.tokens.append(">>")
            elif self.match(expected="<", append=False) and self.match(expected="<", append=False):
                self.tokens.append("<<")
            # Whitespace separates tokens but is otherwise ignored.
            elif self.match(expected={" ", "\t"}, append=False):
                continue
            else:
                raise ExpressionTokenizerError(
                    f"Tokenizer does not recognize following token '{self.expression[self.pos]}'"
                )
        return self.tokens
151
152
class Expression:
    """Expression parser for calculations in definitions."""

    binary_operators: ClassVar[dict[str, Callable[[int, int], int]]] = {
        "|": lambda a, b: a | b,
        "^": lambda a, b: a ^ b,
        "&": lambda a, b: a & b,
        "<<": lambda a, b: a << b,
        ">>": lambda a, b: a >> b,
        "+": lambda a, b: a + b,
        "-": lambda a, b: a - b,
        "*": lambda a, b: a * b,
        "/": lambda a, b: a // b,  # integer division, as in C
        "%": lambda a, b: a % b,
    }

    unary_operators: ClassVar[dict[str, Callable[[int], int]]] = {
        "u": lambda a: -a,  # unary minus; "-" tokens are rewritten to "u" during evaluation
        "~": lambda a: ~a,
    }

    # Higher value binds tighter; mirrors C operator precedence.
    precedence_levels: ClassVar[dict[str, int]] = {
        "|": 0,
        "^": 1,
        "&": 2,
        "<<": 3,
        ">>": 3,
        "+": 4,
        "-": 4,
        "*": 5,
        "/": 5,
        "%": 5,
        "u": 6,
        "~": 6,
        "sizeof": 6,
    }

    def __init__(self, cstruct: cstruct, expression: str):
        self.cstruct = cstruct
        self.expression = expression
        self.tokens = ExpressionTokenizer(expression).tokenize()
        self.stack = []
        self.queue = []

    def __repr__(self) -> str:
        return self.expression

    def precedence(self, o1: str, o2: str) -> bool:
        """Return True if operator ``o1`` binds at least as tightly as ``o2``."""
        return self.precedence_levels[o1] >= self.precedence_levels[o2]

    def evaluate_exp(self) -> None:
        """Pop one operator off the stack and apply it to the operand queue.

        Raises:
            ExpressionParserError: If there are not enough operands available.
        """
        operator = self.stack.pop(-1)

        if not self.queue:
            raise ExpressionParserError("Invalid expression: not enough operands")

        right = self.queue.pop(-1)
        if operator in self.unary_operators:
            res = self.unary_operators[operator](right)
        else:
            if not self.queue:
                raise ExpressionParserError("Invalid expression: not enough operands")

            left = self.queue.pop(-1)
            res = self.binary_operators[operator](left, right)

        self.queue.append(res)

    def is_number(self, token: str) -> bool:
        """Return True if ``token`` is a numeric literal parseable by ``int(token, 0)``."""
        return token.isnumeric() or (len(token) > 2 and token[0] == "0" and token[1] in ("x", "X", "b", "B", "o", "O"))

    def evaluate(self, context: dict[str, int] | None = None) -> int:
        """Evaluates an expression using a Shunting-Yard implementation.

        Identifiers are resolved against ``context`` first, then against the
        cstruct instance's constants.

        Args:
            context: Optional mapping of identifier names to integer values.

        Returns:
            The integer result of the expression.

        Raises:
            ExpressionParserError: If the expression is malformed.
        """

        self.stack = []
        self.queue = []
        operators = set(self.binary_operators.keys()) | set(self.unary_operators.keys())

        context = context or {}
        tmp_expression = self.tokens

        # Unary minus tokens; we change the semantic of '-' depending on the previous token
        for i in range(len(self.tokens)):
            if self.tokens[i] == "-":
                if i == 0:
                    self.tokens[i] = "u"
                    continue
                if self.tokens[i - 1] in operators or self.tokens[i - 1] == "u" or self.tokens[i - 1] == "(":
                    self.tokens[i] = "u"
                    continue

        i = 0
        while i < len(tmp_expression):
            current_token = tmp_expression[i]
            if self.is_number(current_token):
                self.queue.append(int(current_token, 0))
            elif current_token in context:
                self.queue.append(int(context[current_token]))
            elif current_token in self.cstruct.consts:
                self.queue.append(int(self.cstruct.consts[current_token]))
            elif current_token in self.unary_operators:
                self.stack.append(current_token)
            elif current_token == "sizeof":
                # Requires exactly "sizeof ( name )", i.e. three more tokens
                # (indices i+1 .. i+3), hence a minimum length of i + 4.
                if len(tmp_expression) < i + 4 or tmp_expression[i + 1] != "(" or tmp_expression[i + 3] != ")":
                    raise ExpressionParserError("Invalid sizeof operation")
                self.queue.append(len(self.cstruct.resolve(tmp_expression[i + 2])))
                i += 3
            elif current_token in operators:
                # Pop operators of equal or higher precedence before pushing this one.
                while self.stack and self.stack[-1] != "(" and self.precedence(self.stack[-1], current_token):
                    self.evaluate_exp()
                self.stack.append(current_token)
            elif current_token == "(":
                if i > 0:
                    previous_token = tmp_expression[i - 1]
                    if self.is_number(previous_token):
                        raise ExpressionParserError(
                            f"Parser expected sizeof or an arithmetic operator instead got: '{previous_token}'"
                        )

                self.stack.append(current_token)
            elif current_token == ")":
                if i > 0:
                    previous_token = tmp_expression[i - 1]
                    if previous_token == "(":
                        raise ExpressionParserError(
                            f"Parser expected an expression, instead received empty parenthesis. Index: {i}"
                        )

                if len(self.stack) == 0:
                    raise ExpressionParserError("Invalid expression")

                # Evaluate until the matching opening parenthesis; an exhausted
                # stack means the ")" is unbalanced.
                while self.stack and self.stack[-1] != "(":
                    self.evaluate_exp()

                if not self.stack:
                    raise ExpressionParserError("Invalid expression")

                self.stack.pop(-1)
            else:
                raise ExpressionParserError(f"Unmatched token: '{current_token}'")
            i += 1

        # Apply any remaining operators.
        while len(self.stack) != 0:
            if self.stack[-1] == "(":
                raise ExpressionParserError("Invalid expression")

            self.evaluate_exp()

        if len(self.queue) != 1:
            raise ExpressionParserError("Invalid expression")

        return self.queue[0]