1from __future__ import annotations
2
3import string
4from typing import TYPE_CHECKING, ClassVar
5
6from dissect.cstruct.exceptions import ExpressionParserError, ExpressionTokenizerError
7
8if TYPE_CHECKING:
9 from collections.abc import Callable
10
11 from dissect.cstruct import cstruct
12
13
# Characters that can follow a leading "0" to mark a hexadecimal (0x/0X) or binary (0b/0B) literal
HEXBIN_SUFFIX = {"x", "X", "b", "B"}
15
16
17class ExpressionTokenizer:
18 def __init__(self, expression: str):
19 self.expression = expression
20 self.pos = 0
21 self.tokens = []
22
23 def equal(self, token: str, expected: str | set[str]) -> bool:
24 if isinstance(expected, set):
25 return token in expected
26 return token == expected
27
28 def alnum(self, token: str) -> bool:
29 return token.isalnum()
30
31 def alpha(self, token: str) -> bool:
32 return token.isalpha()
33
34 def digit(self, token: str) -> bool:
35 return token.isdigit()
36
37 def hexdigit(self, token: str) -> bool:
38 return token in string.hexdigits
39
40 def operator(self, token: str) -> bool:
41 return token in {"*", "/", "+", "-", "%", "&", "^", "|", "(", ")", "~"}
42
43 def match(
44 self,
45 func: Callable[[str], bool] | None = None,
46 expected: str | None = None,
47 consume: bool = True,
48 append: bool = True,
49 ) -> bool:
50 if self.eol():
51 return False
52
53 token = self.get_token()
54
55 if expected and self.equal(token, expected):
56 if append:
57 self.tokens.append(token)
58 if consume:
59 self.consume()
60 return True
61
62 if func and func(token):
63 if append:
64 self.tokens.append(token)
65 if consume:
66 self.consume()
67 return True
68
69 return False
70
71 def consume(self) -> None:
72 self.pos += 1
73
74 def eol(self) -> bool:
75 return self.pos >= len(self.expression)
76
77 def get_token(self) -> str:
78 if self.eol():
79 raise ExpressionTokenizerError(f"Out of bounds index: {self.pos}, length: {len(self.expression)}")
80 return self.expression[self.pos]
81
82 def tokenize(self) -> list[str]:
83 token = ""
84
85 # Loop over expression runs in linear time
86 while not self.eol():
87 # If token is a single character operand add it to tokens
88 if self.match(self.operator):
89 continue
90
91 # If token is a single digit, keep looping over expression and build the number
92 if self.match(self.digit, consume=False, append=False):
93 token += self.get_token()
94 self.consume()
95
96 # Support for binary and hexadecimal notation
97 if self.match(expected=HEXBIN_SUFFIX, consume=False, append=False):
98 token += self.get_token()
99 self.consume()
100
101 while self.match(self.hexdigit, consume=False, append=False):
102 token += self.get_token()
103 self.consume()
104 if self.eol():
105 break
106
107 # Checks for suffixes in numbers
108 if self.match(expected={"u", "U"}, consume=False, append=False):
109 self.consume()
110 self.match(expected={"l", "L"}, append=False)
111 self.match(expected={"l", "L"}, append=False)
112
113 elif self.match(expected={"l", "L"}, append=False):
114 self.match(expected={"l", "L"}, append=False)
115 self.match(expected={"u", "U"}, append=False)
116 else:
117 pass
118
119 # Number cannot end on x or b in the case of binary or hexadecimal notation
120 if len(token) == 2 and token[-1] in HEXBIN_SUFFIX:
121 raise ExpressionTokenizerError("Invalid binary or hex notation")
122
123 if len(token) > 1 and token[0] == "0" and token[1] not in HEXBIN_SUFFIX:
124 token = token[:1] + "o" + token[1:]
125 self.tokens.append(token)
126 token = ""
127
128 # If token is alpha or underscore we need to build the identifier
129 elif self.match(self.alpha, consume=False, append=False) or self.match(
130 expected="_", consume=False, append=False
131 ):
132 while self.match(self.alnum, consume=False, append=False) or self.match(
133 expected="_", consume=False, append=False
134 ):
135 token += self.get_token()
136 self.consume()
137 if self.eol():
138 break
139 self.tokens.append(token)
140 token = ""
141 # If token is length 2 operand make sure next character is part of length 2 operand append to tokens
142 elif self.match(expected=">", append=False) and self.match(expected=">", append=False):
143 self.tokens.append(">>")
144 elif self.match(expected="<", append=False) and self.match(expected="<", append=False):
145 self.tokens.append("<<")
146 elif self.match(expected={" ", "\n", "\t"}, append=False):
147 continue
148 else:
149 raise ExpressionTokenizerError(
150 f"Tokenizer does not recognize following token '{self.expression[self.pos]}'"
151 )
152 return self.tokens
153
154
class Expression:
    """Expression parser for calculations in definitions.

    Parses and evaluates integer arithmetic expressions (with constants,
    ``sizeof()`` and unary/binary operators) using the Shunting-Yard algorithm.
    """

    binary_operators: ClassVar[dict[str, Callable[[int, int], int]]] = {
        "|": lambda a, b: a | b,
        "^": lambda a, b: a ^ b,
        "&": lambda a, b: a & b,
        "<<": lambda a, b: a << b,
        ">>": lambda a, b: a >> b,
        "+": lambda a, b: a + b,
        "-": lambda a, b: a - b,
        "*": lambda a, b: a * b,
        "/": lambda a, b: a // b,  # integer division, as in C
        "%": lambda a, b: a % b,
    }

    unary_operators: ClassVar[dict[str, Callable[[int], int]]] = {
        "u": lambda a: -a,  # unary minus; '-' tokens are rewritten to 'u' during evaluation
        "~": lambda a: ~a,
    }

    # Higher value binds more tightly
    precedence_levels: ClassVar[dict[str, int]] = {
        "|": 0,
        "^": 1,
        "&": 2,
        "<<": 3,
        ">>": 3,
        "+": 4,
        "-": 4,
        "*": 5,
        "/": 5,
        "%": 5,
        "u": 6,
        "~": 6,
        "sizeof": 6,
    }

    def __init__(self, expression: str):
        self.expression = expression
        self.tokens = ExpressionTokenizer(expression).tokenize()
        self.stack = []
        self.queue = []

    def __repr__(self) -> str:
        return self.expression

    def precedence(self, o1: str, o2: str) -> bool:
        """Return ``True`` if operator ``o1`` binds at least as tightly as ``o2``."""
        return self.precedence_levels[o1] >= self.precedence_levels[o2]

    def evaluate_exp(self) -> None:
        """Pop one operator off the stack and apply it to the operand queue.

        Raises:
            ExpressionParserError: If the queue does not hold enough operands.
        """
        operator = self.stack.pop(-1)

        if not self.queue:
            raise ExpressionParserError("Invalid expression: not enough operands")

        right = self.queue.pop(-1)
        if operator in self.unary_operators:
            res = self.unary_operators[operator](right)
        else:
            if not self.queue:
                raise ExpressionParserError("Invalid expression: not enough operands")

            left = self.queue.pop(-1)
            res = self.binary_operators[operator](left, right)

        self.queue.append(res)

    def is_number(self, token: str) -> bool:
        """Return ``True`` if the token is a decimal number or a 0x/0b/0o prefixed literal."""
        return token.isnumeric() or (len(token) > 2 and token[0] == "0" and token[1] in ("x", "X", "b", "B", "o", "O"))

    def evaluate(self, cs: cstruct, context: dict[str, int] | None = None) -> int:
        """Evaluates an expression using a Shunting-Yard implementation.

        Args:
            cs: A cstruct instance, used to look up constants and resolve sizeof() targets.
            context: Optional mapping of additional identifiers to their integer values.
                     Entries here take precedence over ``cs.consts``.

        Returns:
            The integer result of the expression.

        Raises:
            ExpressionParserError: If the expression is malformed.
        """
        self.stack = []
        self.queue = []
        operators = set(self.binary_operators) | set(self.unary_operators)

        context = context or {}
        tmp_expression = self.tokens

        # Unary minus tokens; we change the semantic of '-' depending on the previous token:
        # at the start of the expression, or after an operator or '(', it negates.
        for i, tok in enumerate(self.tokens):
            if tok == "-" and (i == 0 or self.tokens[i - 1] in operators or self.tokens[i - 1] == "("):
                self.tokens[i] = "u"

        i = 0
        while i < len(tmp_expression):
            current_token = tmp_expression[i]
            if self.is_number(current_token):
                self.queue.append(int(current_token, 0))
            elif current_token in context:
                self.queue.append(int(context[current_token]))
            elif current_token in cs.consts:
                self.queue.append(int(cs.consts[current_token]))
            elif current_token in self.unary_operators:
                self.stack.append(current_token)
            elif current_token == "sizeof":
                # Expect exactly: sizeof ( <type> ) -> four tokens starting at index i
                if len(tmp_expression) < i + 4 or (tmp_expression[i + 1] != "(" or tmp_expression[i + 3] != ")"):
                    raise ExpressionParserError("Invalid sizeof operation")
                self.queue.append(len(cs.resolve(tmp_expression[i + 2])))
                i += 3
            elif current_token in operators:
                # Apply any stacked operator of equal or higher precedence first
                while (
                    len(self.stack) != 0 and self.stack[-1] != "(" and (self.precedence(self.stack[-1], current_token))
                ):
                    self.evaluate_exp()
                self.stack.append(current_token)
            elif current_token == "(":
                if i > 0:
                    previous_token = tmp_expression[i - 1]
                    if self.is_number(previous_token):
                        raise ExpressionParserError(
                            f"Parser expected sizeof or an arithmetic operator instead got: '{previous_token}'"
                        )

                self.stack.append(current_token)
            elif current_token == ")":
                if i > 0:
                    previous_token = tmp_expression[i - 1]
                    if previous_token == "(":
                        raise ExpressionParserError(
                            f"Parser expected an expression, instead received empty parenthesis. Index: {i}"
                        )

                if len(self.stack) == 0:
                    raise ExpressionParserError("Invalid expression")

                # Reduce until the matching '(' is found, then discard it
                while self.stack[-1] != "(":
                    self.evaluate_exp()

                self.stack.pop(-1)
            else:
                raise ExpressionParserError(f"Unmatched token: '{current_token}'")
            i += 1

        # Apply any remaining stacked operators
        while len(self.stack) != 0:
            if self.stack[-1] == "(":
                raise ExpressionParserError("Invalid expression")

            self.evaluate_exp()

        if len(self.queue) != 1:
            raise ExpressionParserError("Invalid expression")

        return self.queue[0]