1"""This module builds a LALR(1) transition-table for lalr_parser.py
2
3For now, shift/reduce conflicts are automatically resolved as shifts.
4"""
5
6# Author: Erez Shinan (2017)
7# Email : erezshin@gmail.com
8
9from typing import Dict, Set, Iterator, Tuple, List, TypeVar, Generic
10from collections import defaultdict
11
12from ..utils import classify, classify_bool, bfs, fzset, Enumerator, logger
13from ..exceptions import GrammarError
14
15from .grammar_analysis import GrammarAnalyzer, Terminal, LR0ItemSet, RulePtr, State
16from ..grammar import Rule, Symbol
17from ..common import ParserConf
18
19###{standalone
20
21class Action:
22 def __init__(self, name):
23 self.name = name
24 def __str__(self):
25 return self.name
26 def __repr__(self):
27 return str(self)
28
29Shift = Action('Shift')
30Reduce = Action('Reduce')
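# Parse-table entries pair one of these singletons with its argument:
# (Shift, next_state) or (Reduce, rule).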

StateT = TypeVar("StateT")

class ParseTableBase(Generic[StateT]):
    states: Dict[StateT, Dict[str, Tuple]]
    start_states: Dict[str, StateT]
    end_states: Dict[str, StateT]

    def __init__(self, states, start_states, end_states):
        self.states = states
        self.start_states = start_states
        self.end_states = end_states

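    # serialize()/deserialize() store the table compactly: token names are
    # interned through an Enumerator, and each action is encoded as
    # (1, serialized rule) for Reduce or (0, shift target) for Shift.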
    def serialize(self, memo):
        tokens = Enumerator()

        states = {
            state: {tokens.get(token): ((1, arg.serialize(memo)) if action is Reduce else (0, arg))
                    for token, (action, arg) in actions.items()}
            for state, actions in self.states.items()
        }

        return {
            'tokens': tokens.reversed(),
            'states': states,
            'start_states': self.start_states,
            'end_states': self.end_states,
        }

    @classmethod
    def deserialize(cls, data, memo):
        tokens = data['tokens']
        states = {
            state: {tokens[token]: ((Reduce, Rule.deserialize(arg, memo)) if action==1 else (Shift, arg))
                    for token, (action, arg) in actions.items()}
            for state, actions in data['states'].items()
        }
        return cls(states, data['start_states'], data['end_states'])

class ParseTable(ParseTableBase['State']):
    """Parse-table whose key is State, i.e. set[RulePtr]

    Slower than IntParseTable, but useful for debugging
    """
    pass


class IntParseTable(ParseTableBase[int]):
    """Parse-table whose key is int. Best for performance."""

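    # Renumbers the states in enumeration order and rewrites every Shift target
    # through that mapping; Reduce entries are kept unchanged.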
    @classmethod
    def from_ParseTable(cls, parse_table: ParseTable):
        enum = list(parse_table.states)
        state_to_idx: Dict['State', int] = {s:i for i,s in enumerate(enum)}
        int_states = {}

        for s, la in parse_table.states.items():
            la = {k:(v[0], state_to_idx[v[1]]) if v[0] is Shift else v
                  for k,v in la.items()}
            int_states[ state_to_idx[s] ] = la


        start_states = {start:state_to_idx[s] for start, s in parse_table.start_states.items()}
        end_states = {start:state_to_idx[s] for start, s in parse_table.end_states.items()}
        return cls(int_states, start_states, end_states)

###}


# digraph and traverse, see The Theory and Practice of Compiler Writing

# computes F(x) = G(x) union (union { F(y) | x R y }),
# i.e. F(x) = union { G(y) | y is reachable from x through R, including x itself }
# X: nodes
# R: relation (function mapping node -> list of nodes that satisfy the relation)
# G: set valued function
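# A small illustrative example: with
#   X = ['a', 'b'],  R = {'a': ['b'], 'b': []},  G = {'a': {1}, 'b': {2}}
# digraph(X, R, G) returns {'a': {1, 2}, 'b': {2}}: 'a' picks up G('b') because
# a R b.  Note that F[x] starts out as the same set object as G[x], so G's sets
# may be updated in place.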
def digraph(X, R, G):
    F = {}
    S = []
    N = dict.fromkeys(X, 0)
    for x in X:
        # this is always true for the first iteration, but N[x] may be updated in traverse below
        if N[x] == 0:
            traverse(x, S, N, X, R, G, F)
    return F

# x: single node
# S: stack
# N: weights
# X: nodes
# R: relation (see above)
# G: set valued function
# F: set valued function we are computing (map of input -> output)
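# traverse() follows the scheme of Tarjan's strongly-connected-components
# algorithm: N holds the visit depth (-1 once a node's component is finished),
# and every node popped off S as part of the same component ends up sharing
# the same F set.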
def traverse(x, S, N, X, R, G, F):
    S.append(x)
    d = len(S)
    N[x] = d
    F[x] = G[x]
    for y in R[x]:
        if N[y] == 0:
            traverse(y, S, N, X, R, G, F)
        n_x = N[x]
        assert n_x > 0
        n_y = N[y]
        assert n_y != 0
        if (n_y > 0) and (n_y < n_x):
            N[x] = n_y
        F[x].update(F[y])
    if N[x] == d:
        f_x = F[x]
        while True:
            z = S.pop()
            N[z] = -1
            F[z] = f_x
            if z == x:
                break


class LALR_Analyzer(GrammarAnalyzer):
    lr0_itemsets: Set[LR0ItemSet]
    nonterminal_transitions: List[Tuple[LR0ItemSet, Symbol]]
    lookback: Dict[Tuple[LR0ItemSet, Symbol], Set[Tuple[LR0ItemSet, Rule]]]
    includes: Dict[Tuple[LR0ItemSet, Symbol], Set[Tuple[LR0ItemSet, Symbol]]]
    reads: Dict[Tuple[LR0ItemSet, Symbol], Set[Tuple[LR0ItemSet, Symbol]]]
    directly_reads: Dict[Tuple[LR0ItemSet, Symbol], Set[Symbol]]


    def __init__(self, parser_conf: ParserConf, debug: bool=False, strict: bool=False):
        GrammarAnalyzer.__init__(self, parser_conf, debug, strict)
        self.nonterminal_transitions = []
        self.directly_reads = defaultdict(set)
        self.reads = defaultdict(set)
        self.includes = defaultdict(set)
        self.lookback = defaultdict(set)


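    # Builds the LR(0) automaton: starting from the start itemsets, advance the
    # dot over each pending symbol, close the resulting kernel over nonterminal
    # rules, and record the transition.  Kernels are cached so that each
    # itemset is only created once.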
    def compute_lr0_states(self) -> None:
        self.lr0_itemsets = set()
        # map of kernels to LR0ItemSets
        cache: Dict['State', LR0ItemSet] = {}

        def step(state: LR0ItemSet) -> Iterator[LR0ItemSet]:
            _, unsat = classify_bool(state.closure, lambda rp: rp.is_satisfied)

            d = classify(unsat, lambda rp: rp.next)
            for sym, rps in d.items():
                kernel = fzset({rp.advance(sym) for rp in rps})
                new_state = cache.get(kernel, None)
                if new_state is None:
                    closure = set(kernel)
                    for rp in kernel:
                        if not rp.is_satisfied and not rp.next.is_term:
                            closure |= self.expand_rule(rp.next, self.lr0_rules_by_origin)
                    new_state = LR0ItemSet(kernel, closure)
                    cache[kernel] = new_state

                state.transitions[sym] = new_state
                yield new_state

            self.lr0_itemsets.add(state)

        for _ in bfs(self.lr0_start_states.values(), step):
            pass

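    # Computes the DR ("directly reads") and "reads" relations of
    # DeRemer & Pennello: for a nonterminal transition (state, A),
    # directly_reads collects the terminals that can be shifted right after it,
    # and reads links it to the nonterminal transitions over nullable symbols
    # that follow it.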
    def compute_reads_relations(self):
        # handle start state
        for root in self.lr0_start_states.values():
            assert len(root.kernel) == 1
            for rp in root.kernel:
                assert rp.index == 0
                self.directly_reads[(root, rp.next)] = set([ Terminal('$END') ])

        for state in self.lr0_itemsets:
            seen = set()
            for rp in state.closure:
                if rp.is_satisfied:
                    continue
                s = rp.next
                # if s is not a nonterminal
                if s not in self.lr0_rules_by_origin:
                    continue
                if s in seen:
                    continue
                seen.add(s)
                nt = (state, s)
                self.nonterminal_transitions.append(nt)
                dr = self.directly_reads[nt]
                r = self.reads[nt]
                next_state = state.transitions[s]
                for rp2 in next_state.closure:
                    if rp2.is_satisfied:
                        continue
                    s2 = rp2.next
                    # if s2 is a terminal
                    if s2 not in self.lr0_rules_by_origin:
                        dr.add(s2)
                    if s2 in self.NULLABLE:
                        r.add((next_state, s2))

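    # Computes the "includes" and "lookback" relations: (p, A) includes (p', B)
    # when B -> beta A gamma with gamma nullable and p' reaches p along beta;
    # lookback ties each nonterminal transition to the (final state, rule)
    # pairs whose reductions its follow set will feed.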
    def compute_includes_lookback(self):
        for nt in self.nonterminal_transitions:
            state, nonterminal = nt
            includes = []
            lookback = self.lookback[nt]
            for rp in state.closure:
                if rp.rule.origin != nonterminal:
                    continue
                # walk the transitions along the rest of rp.rule's expansion
                state2 = state
                for i in range(rp.index, len(rp.rule.expansion)):
                    s = rp.rule.expansion[i]
                    nt2 = (state2, s)
                    state2 = state2.transitions[s]
                    if nt2 not in self.reads:
                        continue
                    for j in range(i + 1, len(rp.rule.expansion)):
                        if rp.rule.expansion[j] not in self.NULLABLE:
                            break
                    else:
                        includes.append(nt2)
                # state2 is now the final state for rp.rule
                if rp.index == 0:
                    for rp2 in state2.closure:
                        if (rp2.rule == rp.rule) and rp2.is_satisfied:
                            lookback.add((state2, rp2.rule))
            for nt2 in includes:
                self.includes[nt2].add(nt)

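    # Propagates the lookaheads: Read sets are the closure of directly_reads
    # over the reads relation, Follow sets are the closure of the Read sets
    # over the includes relation, and every reduction receives the Follow set
    # of its lookback transitions.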
    def compute_lookaheads(self):
        read_sets = digraph(self.nonterminal_transitions, self.reads, self.directly_reads)
        follow_sets = digraph(self.nonterminal_transitions, self.includes, read_sets)

        for nt, lookbacks in self.lookback.items():
            for state, rule in lookbacks:
                for s in follow_sets[nt]:
                    state.lookaheads[s].add(rule)

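    # Turns the LR(0) itemsets and the computed lookaheads into a parse table.
    # Shift/Reduce conflicts are resolved as shifts (or rejected in strict
    # mode); Reduce/Reduce conflicts are resolved by rule priority when
    # possible, and otherwise reported as a GrammarError.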
    def compute_lalr1_states(self) -> None:
        m: Dict[LR0ItemSet, Dict[str, Tuple]] = {}
        reduce_reduce = []
        for itemset in self.lr0_itemsets:
            actions: Dict[Symbol, Tuple] = {la: (Shift, next_state.closure)
                                            for la, next_state in itemset.transitions.items()}
            for la, rules in itemset.lookaheads.items():
                if len(rules) > 1:
                    # Try to resolve conflict based on priority
                    p = [(r.options.priority or 0, r) for r in rules]
                    p.sort(key=lambda r: r[0], reverse=True)
                    best, second_best = p[:2]
                    if best[0] > second_best[0]:
                        rules = {best[1]}
                    else:
                        reduce_reduce.append((itemset, la, rules))
                        continue

                rule, = rules
                if la in actions:
                    if self.strict:
                        msg = f'Shift/Reduce conflict for terminal {la.name}. [strict-mode]\n' \
                              f' * {rule}\n'
                        raise GrammarError(msg)
                    elif self.debug:
                        logger.warning('Shift/Reduce conflict for terminal %s: (resolving as shift)', la.name)
                        logger.warning(' * %s', rule)
                    else:
                        logger.debug('Shift/Reduce conflict for terminal %s: (resolving as shift)', la.name)
                        logger.debug(' * %s', rule)
                else:
                    actions[la] = (Reduce, rule)
            m[itemset] = { k.name: v for k, v in actions.items() }

        if reduce_reduce:
            msgs = []
            for itemset, la, rules in reduce_reduce:
                msg = 'Reduce/Reduce collision in %s between the following rules: %s' % (la, ''.join([ '\n\t- ' + str(r) for r in rules ]))
                if self.debug:
                    msg += '\n collision occurred in state: {%s\n }' % ''.join(['\n\t' + str(x) for x in itemset.closure])
                msgs.append(msg)
            raise GrammarError('\n\n'.join(msgs))

        states = { k.closure: v for k, v in m.items() }

        # compute end states
        end_states: Dict[str, 'State'] = {}
        for state in states:
            for rp in state:
                for start in self.lr0_start_states:
                    if rp.rule.origin.name == ('$root_' + start) and rp.is_satisfied:
                        assert start not in end_states
                        end_states[start] = state

        start_states = { start: state.closure for start, state in self.lr0_start_states.items() }
        _parse_table = ParseTable(states, start_states, end_states)

        if self.debug:
            self.parse_table = _parse_table
        else:
            self.parse_table = IntParseTable.from_ParseTable(_parse_table)

    def compute_lalr(self):
        self.compute_lr0_states()
        self.compute_reads_relations()
        self.compute_includes_lookback()
        self.compute_lookaheads()
        self.compute_lalr1_states()
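
# Typical usage (an illustrative sketch; see lalr_parser.py for the actual driver):
#
#     analysis = LALR_Analyzer(parser_conf, debug=debug, strict=strict)
#     analysis.compute_lalr()
#     table = analysis.parse_table   # IntParseTable, or ParseTable when debug=True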