Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/lark/parsers/lalr_analysis.py: 85%
212 statements
coverage.py v7.3.1, created at 2023-09-25 06:30 +0000
1"""This module builds a LALR(1) transition-table for lalr_parser.py
3For now, shift/reduce conflicts are automatically resolved as shifts.
4"""
6# Author: Erez Shinan (2017)
7# Email : erezshin@gmail.com
9from collections import defaultdict
11from ..utils import classify, classify_bool, bfs, fzset, Enumerator, logger
12from ..exceptions import GrammarError
14from .grammar_analysis import GrammarAnalyzer, Terminal, LR0ItemSet
15from ..grammar import Rule
17###{standalone
class Action:
    """Named marker object used as the action tag in parse-table entries.

    Parse tables test these tags by identity (``action is Reduce``,
    ``v[0] is Shift``), so code is expected to use the module-level
    singletons defined right after this class.
    """

    def __init__(self, name):
        self.name = name  # label produced by both str() and repr()

    def __repr__(self):
        # Render like str() so tables print as e.g. ``(Shift, 3)``.
        return str(self)

    def __str__(self):
        return self.name


# The two action kinds a parse-table entry can carry.
Shift, Reduce = Action('Shift'), Action('Reduce')
class ParseTable:
    """Parse table: per-state ``token -> (action, arg)`` maps plus start/end states.

    ``action`` is ``Shift`` (``arg`` is the target state) or ``Reduce``
    (``arg`` is the rule to reduce by). ``start_states`` and ``end_states``
    map each start name to a state.
    """

    def __init__(self, states, start_states, end_states):
        self.states = states
        self.start_states = start_states
        self.end_states = end_states

    def serialize(self, memo):
        """Return a plain-data dict describing this table.

        Tokens are replaced by ids handed out by an ``Enumerator``; the
        ``'tokens'`` entry holds its reversed mapping so ``deserialize`` can
        translate ids back. Each action becomes ``(1, serialized_rule)`` for a
        reduce or ``(0, arg)`` for a shift.
        """
        token_ids = Enumerator()
        packed_states = {}
        for state, actions in self.states.items():
            packed_row = {}
            for token, (action, arg) in actions.items():
                # Allocate the token id before serializing the value, matching
                # the key-then-value order of a dict comprehension.
                token_id = token_ids.get(token)
                if action is Reduce:
                    packed_row[token_id] = (1, arg.serialize(memo))
                else:
                    packed_row[token_id] = (0, arg)
            packed_states[state] = packed_row

        return {
            'tokens': token_ids.reversed(),
            'states': packed_states,
            'start_states': self.start_states,
            'end_states': self.end_states,
        }

    @classmethod
    def deserialize(cls, data, memo):
        """Rebuild a table from the dict produced by ``serialize``."""
        token_lookup = data['tokens']
        unpacked_states = {}
        for state, actions in data['states'].items():
            unpacked_row = {}
            for token_id, (kind, arg) in actions.items():
                # Resolve the token first, as the original comprehension did.
                token = token_lookup[token_id]
                if kind == 1:
                    unpacked_row[token] = (Reduce, Rule.deserialize(arg, memo))
                else:
                    unpacked_row[token] = (Shift, arg)
            unpacked_states[state] = unpacked_row
        return cls(unpacked_states, data['start_states'], data['end_states'])
class IntParseTable(ParseTable):
    """A ParseTable whose states are renumbered as consecutive ints.

    The numbering follows the iteration order of the source table's
    ``states`` mapping.
    """

    @classmethod
    def from_ParseTable(cls, parse_table):
        """Build an IntParseTable from ``parse_table``.

        State keys, Shift targets, and start/end states are replaced by the
        state's index; every non-Shift entry is copied unchanged.
        """
        index_of = {}
        for position, state in enumerate(parse_table.states):
            index_of[state] = position

        renumbered = {}
        for state, row in parse_table.states.items():
            new_row = {}
            for token, entry in row.items():
                if entry[0] is Shift:
                    # A shift's argument is itself a state, so renumber it too.
                    new_row[token] = (entry[0], index_of[entry[1]])
                else:
                    new_row[token] = entry
            renumbered[index_of[state]] = new_row

        def remap(by_start):
            # start name -> state  becomes  start name -> state index
            return {start: index_of[s] for start, s in by_start.items()}

        return cls(renumbered, remap(parse_table.start_states), remap(parse_table.end_states))
82###}
85# digraph and traverse, see The Theory and Practice of Compiler Writing
87# computes F(x) = G(x) union (union { G(y) | x R y })
88# X: nodes
89# R: relation (function mapping node -> list of nodes that satisfy the relation)
90# G: set valued function
91def digraph(X, R, G):
92 F = {}
93 S = []
94 N = dict.fromkeys(X, 0)
95 for x in X:
96 # this is always true for the first iteration, but N[x] may be updated in traverse below
97 if N[x] == 0:
98 traverse(x, S, N, X, R, G, F)
99 return F
def traverse(x, S, N, X, R, G, F):
    """Depth-first step of the DIGRAPH/TRAVERSE algorithm.

    Based on "The Theory and Practice of Compiler Writing". Visits ``x`` and
    everything reachable from it through ``R``, filling ``F``. Strongly
    connected components are collapsed so that all their members end up
    sharing one result set.

    Args:
        x: node to visit; callers only pass nodes with ``N[x] == 0``.
        S: stack of nodes whose component is not finished yet.
        N: per-node marker. 0 means unvisited, a positive value is a stack
            depth (possibly lowered to the lowest depth reachable), and -1
            means finished.
        X: all nodes; only passed through to recursive calls.
        R: relation; ``R[x]`` iterates the successors of ``x``.
        G: initial set for each node. It is copied and never mutated.
        F: output mapping node -> set, filled in place.
    """
    S.append(x)
    d = len(S)
    N[x] = d
    # Copy rather than alias: the update() calls below would otherwise
    # rewrite the caller's input sets in G (e.g. the analyzer's
    # directly_reads), and sets that G shares between nodes would leak
    # results across them.
    F[x] = set(G[x])
    for y in R[x]:
        if N[y] == 0:
            traverse(y, S, N, X, R, G, F)
        n_x = N[x]
        assert n_x > 0
        n_y = N[y]
        assert n_y != 0
        # y is still on the stack below x, so x belongs to y's component.
        if 0 < n_y < n_x:
            N[x] = n_y
        F[x].update(F[y])
    if N[x] == d:
        # x is the root of its component. Pop every member and give all of
        # them the root's (now complete) set.
        f_x = F[x]
        while True:
            z = S.pop()
            N[z] = -1
            F[z] = f_x
            if z == x:
                break
class LALR_Analyzer(GrammarAnalyzer):
    """Builds the LALR(1) parse table for a grammar.

    Pipeline (see ``compute_lalr``):
      1. build the LR(0) automaton;
      2. derive the directly-reads / reads / includes / lookback relations
         over its nonterminal transitions;
      3. propagate lookahead sets through those relations with ``digraph``;
      4. turn states plus lookaheads into a parse table.

    The resulting table is stored in ``self.parse_table``.
    """

    def __init__(self, parser_conf, debug=False, strict=False):
        GrammarAnalyzer.__init__(self, parser_conf, debug, strict)
        # Every (state, nonterminal) goto edge, in discovery order.
        self.nonterminal_transitions = []
        # (state, nonterminal) -> terminals after a dot in the goto target's closure
        # (plus $END for the start states' transitions).
        self.directly_reads = defaultdict(set)
        # (state, nonterminal) -> {(next_state, nullable symbol)} transitions out of the goto target.
        self.reads = defaultdict(set)
        # transition -> transitions whose follow sets are merged into this transition's follow set.
        self.includes = defaultdict(set)
        # (state, nonterminal) -> {(final_state, rule)}: completed reductions fed by this transition.
        self.lookback = defaultdict(set)

    def compute_lr0_states(self):
        """Build the LR(0) automaton reachable from ``self.lr0_start_states``.

        Fills ``self.lr0_states`` and each state's ``transitions`` map.
        States are de-duplicated by kernel, so equal kernels share one
        LR0ItemSet.
        """
        self.lr0_states = set()
        # map of kernels to LR0ItemSets
        cache = {}

        def step(state):
            # Neighbour function for bfs(): yields each successor of `state`
            # and records it in state.transitions.
            _, unsat = classify_bool(state.closure, lambda rp: rp.is_satisfied)

            # Group the not-yet-complete items by the symbol after the dot.
            d = classify(unsat, lambda rp: rp.next)
            for sym, rps in d.items():
                # Kernel of the successor: these items advanced over `sym`.
                kernel = fzset({rp.advance(sym) for rp in rps})
                new_state = cache.get(kernel, None)
                if new_state is None:
                    closure = set(kernel)
                    for rp in kernel:
                        # Expand any nonterminal right after the dot into its rules.
                        if not rp.is_satisfied and not rp.next.is_term:
                            closure |= self.expand_rule(rp.next, self.lr0_rules_by_origin)
                    new_state = LR0ItemSet(kernel, closure)
                    cache[kernel] = new_state

                state.transitions[sym] = new_state
                yield new_state

            self.lr0_states.add(state)

        # Drain the traversal; all the work happens inside step().
        for _ in bfs(self.lr0_start_states.values(), step):
            pass

    def compute_reads_relations(self):
        """Collect nonterminal transitions and their directly-reads / reads sets.

        For every state and every nonterminal A after a dot in its closure,
        the transition (state, A) is recorded once. In the goto target
        state, terminals after a dot are added to ``directly_reads`` and
        nullable symbols after a dot are added to ``reads``. Each start
        state's transition is first seeded with ``$END``.
        """
        # handle start state
        for root in self.lr0_start_states.values():
            assert(len(root.kernel) == 1)
            for rp in root.kernel:
                assert(rp.index == 0)
                # End of input may directly follow the start transition.
                self.directly_reads[(root, rp.next)] = set([ Terminal('$END') ])

        for state in self.lr0_states:
            seen = set()
            for rp in state.closure:
                if rp.is_satisfied:
                    continue
                s = rp.next
                # skip terminals: only nonterminals are keys of lr0_rules_by_origin
                if s not in self.lr0_rules_by_origin:
                    continue
                # each (state, nonterminal) transition is recorded only once
                if s in seen:
                    continue
                seen.add(s)
                nt = (state, s)
                self.nonterminal_transitions.append(nt)
                dr = self.directly_reads[nt]
                r = self.reads[nt]
                next_state = state.transitions[s]
                for rp2 in next_state.closure:
                    if rp2.is_satisfied:
                        continue
                    s2 = rp2.next
                    # if s2 is a terminal
                    if s2 not in self.lr0_rules_by_origin:
                        dr.add(s2)
                    # a nullable symbol lets the lookahead come from past it
                    if s2 in self.NULLABLE:
                        r.add((next_state, s2))

    def compute_includes_lookback(self):
        """Fill ``self.includes`` and ``self.lookback``.

        For each nonterminal transition (state, A) and each item of A in
        ``state.closure``, the item's remaining symbols are walked starting
        from ``state``:
          - a nonterminal transition (state2, s) on that walk gets (state, A)
            added to ``includes`` when every symbol after s is nullable;
          - for items starting at index 0, the walk ends in the state where
            the rule is complete, and (that state, rule) is added to
            ``lookback[(state, A)]``.
        """
        for nt in self.nonterminal_transitions:
            state, nonterminal = nt
            includes = []
            lookback = self.lookback[nt]
            for rp in state.closure:
                if rp.rule.origin != nonterminal:
                    continue
                # traverse the states for rp(.rule)
                state2 = state
                for i in range(rp.index, len(rp.rule.expansion)):
                    s = rp.rule.expansion[i]
                    nt2 = (state2, s)
                    state2 = state2.transitions[s]
                    # Only nonterminal transitions are keys of self.reads;
                    # terminal steps still advance state2 above.
                    if nt2 not in self.reads:
                        continue
                    for j in range(i + 1, len(rp.rule.expansion)):
                        if rp.rule.expansion[j] not in self.NULLABLE:
                            break
                    else:
                        # everything after position i is nullable
                        includes.append(nt2)
                # state2 is at the final state for rp.rule
                if rp.index == 0:
                    for rp2 in state2.closure:
                        if (rp2.rule == rp.rule) and rp2.is_satisfied:
                            lookback.add((state2, rp2.rule))
            for nt2 in includes:
                self.includes[nt2].add(nt)

    def compute_lookaheads(self):
        """Propagate lookaheads and attach them to reductions.

        Read sets are directly_reads closed over ``reads``; follow sets are
        read sets closed over ``includes``. Every (state, rule) in a
        transition's lookback gets that transition's follow terminals in
        ``state.lookaheads``.
        """
        read_sets = digraph(self.nonterminal_transitions, self.reads, self.directly_reads)
        follow_sets = digraph(self.nonterminal_transitions, self.includes, read_sets)

        for nt, lookbacks in self.lookback.items():
            for state, rule in lookbacks:
                for s in follow_sets[nt]:
                    state.lookaheads[s].add(rule)

    def compute_lalr1_states(self):
        """Build the parse table from LR(0) states and their lookaheads.

        Conflict handling:
          - Shifts come from each state's transitions; the target state is
            identified by its closure.
          - When several rules reduce on one lookahead, a single strictly
            highest ``options.priority`` wins (None counts as 0). Otherwise
            the collision is recorded, and all collisions are raised together.
          - On a shift/reduce conflict the shift is kept. This is logged at
            warning level in debug mode and at debug level otherwise; in
            strict mode it raises instead.

        Sets ``self.parse_table``: a closure-keyed ParseTable when
        ``self.debug``, otherwise an IntParseTable.

        Raises:
            GrammarError: on unresolved reduce/reduce collisions, or on any
                shift/reduce conflict in strict mode.
        """
        m = {}
        reduce_reduce = []
        for state in self.lr0_states:
            actions = {la: (Shift, next_state.closure) for la, next_state in state.transitions.items()}
            for la, rules in state.lookaheads.items():
                if len(rules) > 1:
                    # Try to resolve conflict based on priority
                    p = [(r.options.priority or 0, r) for r in rules]
                    p.sort(key=lambda r: r[0], reverse=True)
                    best, second_best = p[:2]
                    if best[0] > second_best[0]:
                        rules = [best[1]]
                    else:
                        reduce_reduce.append((state, la, rules))
                        continue

                # exactly one candidate rule remains here
                rule ,= rules
                if la in actions:
                    if self.strict:
                        raise GrammarError(f"Shift/Reduce conflict for terminal {la.name}. [strict-mode]\n ")
                    elif self.debug:
                        logger.warning('Shift/Reduce conflict for terminal %s: (resolving as shift)', la.name)
                        logger.warning(' * %s', rule)
                    else:
                        logger.debug('Shift/Reduce conflict for terminal %s: (resolving as shift)', la.name)
                        logger.debug(' * %s', rule)
                else:
                    # same rule as `rule` above (rules has a single element)
                    actions[la] = (Reduce, list(rules)[0])
            # key each state's actions by terminal name
            m[state] = { k.name: v for k, v in actions.items() }

        if reduce_reduce:
            msgs = []
            for state, la, rules in reduce_reduce:
                msg = 'Reduce/Reduce collision in %s between the following rules: %s' % (la, ''.join([ '\n\t- ' + str(r) for r in rules ]))
                if self.debug:
                    msg += '\n collision occurred in state: {%s\n }' % ''.join(['\n\t' + str(x) for x in state.closure])
                msgs.append(msg)
            raise GrammarError('\n\n'.join(msgs))

        # identify states by their closure (matches the Shift targets above)
        states = { k.closure: v for k, v in m.items() }

        # compute end states
        # a state is the end state for `start` if it holds the completed `$root_<start>` item
        end_states = {}
        for state in states:
            for rp in state:
                for start in self.lr0_start_states:
                    if rp.rule.origin.name == ('$root_' + start) and rp.is_satisfied:
                        assert(start not in end_states)
                        end_states[start] = state

        _parse_table = ParseTable(states, { start: state.closure for start, state in self.lr0_start_states.items() }, end_states)

        if self.debug:
            self.parse_table = _parse_table
        else:
            self.parse_table = IntParseTable.from_ParseTable(_parse_table)

    def compute_lalr(self):
        """Run the full LALR(1) construction; the result ends up in ``self.parse_table``."""
        self.compute_lr0_states()
        self.compute_reads_relations()
        self.compute_includes_lookback()
        self.compute_lookaheads()
        self.compute_lalr1_states()