
1"""This module builds a LALR(1) transition-table for lalr_parser.py 

2 

3For now, shift/reduce conflicts are automatically resolved as shifts. 

4""" 

5 

6# Author: Erez Shinan (2017) 

7# Email : erezshin@gmail.com 

8 

9from typing import Dict, Set, Iterator, Tuple, List, TypeVar, Generic 

10from collections import defaultdict 

11 

12from ..utils import classify, classify_bool, bfs, fzset, Enumerator, logger 

13from ..exceptions import GrammarError 

14 

15from .grammar_analysis import GrammarAnalyzer, Terminal, LR0ItemSet, RulePtr, State 

16from ..grammar import Rule, Symbol 

17from ..common import ParserConf 

18 

19###{standalone 

20 

21class Action: 

22 def __init__(self, name): 

23 self.name = name 

24 def __str__(self): 

25 return self.name 

26 def __repr__(self): 

27 return str(self) 

28 

29Shift = Action('Shift') 

30Reduce = Action('Reduce') 
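# Note: every action entry in the tables below is a pair. A sketch of the two
# shapes (the `next_state` / `rule` values are placeholders):
#
#   (Shift, next_state)   # consume the token and move to `next_state`
#   (Reduce, rule)        # apply `rule`, popping its expansion off the stack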


StateT = TypeVar("StateT")

class ParseTableBase(Generic[StateT]):
    states: Dict[StateT, Dict[str, Tuple]]
    start_states: Dict[str, StateT]
    end_states: Dict[str, StateT]

    def __init__(self, states, start_states, end_states):
        self.states = states
        self.start_states = start_states
        self.end_states = end_states

    def serialize(self, memo):
        tokens = Enumerator()

        states = {
            state: {tokens.get(token): ((1, arg.serialize(memo)) if action is Reduce else (0, arg))
                    for token, (action, arg) in actions.items()}
            for state, actions in self.states.items()
        }

        return {
            'tokens': tokens.reversed(),
            'states': states,
            'start_states': self.start_states,
            'end_states': self.end_states,
        }

    @classmethod
    def deserialize(cls, data, memo):
        tokens = data['tokens']
        states = {
            state: {tokens[token]: ((Reduce, Rule.deserialize(arg, memo)) if action==1 else (Shift, arg))
                    for token, (action, arg) in actions.items()}
            for state, actions in data['states'].items()
        }
        return cls(states, data['start_states'], data['end_states'])
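
    # Illustrative sketch of the serialized layout produced by serialize()
    # above (the state numbers, token indices and rule payload here are
    # hypothetical):
    #
    #   {'tokens': {0: 'NAME', 1: '$END'},
    #    'states': {0: {0: (0, 1)},                   # (0, arg): Shift to state 1
    #               1: {1: (1, <serialized rule>)}},  # (1, arg): Reduce by rule
    #    'start_states': {'start': 0},
    #    'end_states': {'start': 1}}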


class ParseTable(ParseTableBase['State']):
    """Parse-table whose key is State, i.e. set[RulePtr]

    Slower than IntParseTable, but useful for debugging
    """
    pass


class IntParseTable(ParseTableBase[int]):
    """Parse-table whose key is int. Best for performance."""

    @classmethod
    def from_ParseTable(cls, parse_table: ParseTable):
        enum = list(parse_table.states)
        state_to_idx: Dict['State', int] = {s:i for i,s in enumerate(enum)}
        int_states = {}

        for s, la in parse_table.states.items():
            la = {k:(v[0], state_to_idx[v[1]]) if v[0] is Shift else v
                  for k,v in la.items()}
            int_states[ state_to_idx[s] ] = la

        start_states = {start:state_to_idx[s] for start, s in parse_table.start_states.items()}
        end_states = {start:state_to_idx[s] for start, s in parse_table.end_states.items()}
        return cls(int_states, start_states, end_states)
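
    # Illustrative usage (a sketch; assumes `table` is a ParseTable built by
    # LALR_Analyzer.compute_lalr1_states below, and 'NAME' is a hypothetical
    # terminal of the grammar):
    #
    #   int_table = IntParseTable.from_ParseTable(table)
    #   state = int_table.start_states['start']   # an int, not a frozenset of RulePtrs
    #   action, arg = int_table.states[state]['NAME']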


###}


# digraph and traverse, see The Theory and Practice of Compiler Writing

# computes F(x) = G(x) union (union { F(y) | x R y }),
# i.e. the smallest set-valued F such that F(x) contains G(x),
# and F(x) contains F(y) whenever x R y
# X: nodes
# R: relation (function mapping node -> list of nodes that satisfy the relation)
# G: set valued function
def digraph(X, R, G):
    F = {}
    S = []
    N = dict.fromkeys(X, 0)
    for x in X:
        # this is always true for the first iteration, but N[x] may be updated in traverse below
        if N[x] == 0:
            traverse(x, S, N, X, R, G, F)
    return F

# x: single node
# S: stack
# N: weights
# X: nodes
# R: relation (see above)
# G: set valued function
# F: set valued function we are computing (map of input -> output)
def traverse(x, S, N, X, R, G, F):
    S.append(x)
    d = len(S)
    N[x] = d
    F[x] = G[x]
    for y in R[x]:
        if N[y] == 0:
            traverse(y, S, N, X, R, G, F)
        n_x = N[x]
        assert(n_x > 0)
        n_y = N[y]
        assert(n_y != 0)
        if (n_y > 0) and (n_y < n_x):
            N[x] = n_y
        F[x].update(F[y])
    if N[x] == d:
        f_x = F[x]
        while True:
            z = S.pop()
            N[z] = -1
            F[z] = f_x
            if z == x:
                break
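
# Illustrative sketch of digraph/traverse on a hypothetical 3-node cycle:
# every node of a strongly-connected component ends up sharing one merged set.
#
#   X = ['a', 'b', 'c']
#   R = {'a': ['b'], 'b': ['c'], 'c': ['a']}   # a -> b -> c -> a: one SCC
#   G = {'a': {1}, 'b': {2}, 'c': {3}}
#   F = digraph(X, R, G)
#   # F['a'] == F['b'] == F['c'] == {1, 2, 3}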



class LALR_Analyzer(GrammarAnalyzer):
    lr0_itemsets: Set[LR0ItemSet]
    nonterminal_transitions: List[Tuple[LR0ItemSet, Symbol]]
    lookback: Dict[Tuple[LR0ItemSet, Symbol], Set[Tuple[LR0ItemSet, Rule]]]
    includes: Dict[Tuple[LR0ItemSet, Symbol], Set[Tuple[LR0ItemSet, Symbol]]]
    reads: Dict[Tuple[LR0ItemSet, Symbol], Set[Tuple[LR0ItemSet, Symbol]]]
    directly_reads: Dict[Tuple[LR0ItemSet, Symbol], Set[Symbol]]


    def __init__(self, parser_conf: ParserConf, debug: bool=False, strict: bool=False):
        GrammarAnalyzer.__init__(self, parser_conf, debug, strict)
        self.nonterminal_transitions = []
        self.directly_reads = defaultdict(set)
        self.reads = defaultdict(set)
        self.includes = defaultdict(set)
        self.lookback = defaultdict(set)


    def compute_lr0_states(self) -> None:
        self.lr0_itemsets = set()
        # map of kernels to LR0ItemSets
        cache: Dict['State', LR0ItemSet] = {}

        def step(state: LR0ItemSet) -> Iterator[LR0ItemSet]:
            _, unsat = classify_bool(state.closure, lambda rp: rp.is_satisfied)

            d = classify(unsat, lambda rp: rp.next)
            for sym, rps in d.items():
                kernel = fzset({rp.advance(sym) for rp in rps})
                new_state = cache.get(kernel, None)
                if new_state is None:
                    closure = set(kernel)
                    for rp in kernel:
                        if not rp.is_satisfied and not rp.next.is_term:
                            closure |= self.expand_rule(rp.next, self.lr0_rules_by_origin)
                    new_state = LR0ItemSet(kernel, closure)
                    cache[kernel] = new_state

                state.transitions[sym] = new_state
                yield new_state

            self.lr0_itemsets.add(state)

        for _ in bfs(self.lr0_start_states.values(), step):
            pass
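
    # Illustrative sketch of one `step` above, for a hypothetical grammar with
    # rules  S -> A B  and  A -> "a":
    #
    #   start kernel : { S -> . A B }
    #   its closure  : { S -> . A B,  A -> . "a" }   # expand_rule adds A's rules
    #
    # Advancing on A yields the new kernel { S -> A . B }; the cache ensures
    # identical kernels always map to a single LR0ItemSet.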


    def compute_reads_relations(self):
        # handle start state
        for root in self.lr0_start_states.values():
            assert(len(root.kernel) == 1)
            for rp in root.kernel:
                assert(rp.index == 0)
                self.directly_reads[(root, rp.next)] = set([ Terminal('$END') ])

        for state in self.lr0_itemsets:
            seen = set()
            for rp in state.closure:
                if rp.is_satisfied:
                    continue
                s = rp.next
                # if s is not a nonterminal
                if s not in self.lr0_rules_by_origin:
                    continue
                if s in seen:
                    continue
                seen.add(s)
                nt = (state, s)
                self.nonterminal_transitions.append(nt)
                dr = self.directly_reads[nt]
                r = self.reads[nt]
                next_state = state.transitions[s]
                for rp2 in next_state.closure:
                    if rp2.is_satisfied:
                        continue
                    s2 = rp2.next
                    # if s2 is a terminal
                    if s2 not in self.lr0_rules_by_origin:
                        dr.add(s2)
                    if s2 in self.NULLABLE:
                        r.add((next_state, s2))


    def compute_includes_lookback(self):
        for nt in self.nonterminal_transitions:
            state, nonterminal = nt
            includes = []
            lookback = self.lookback[nt]
            for rp in state.closure:
                if rp.rule.origin != nonterminal:
                    continue
                # traverse the states for rp(.rule)
                state2 = state
                for i in range(rp.index, len(rp.rule.expansion)):
                    s = rp.rule.expansion[i]
                    nt2 = (state2, s)
                    state2 = state2.transitions[s]
                    if nt2 not in self.reads:
                        continue
                    for j in range(i + 1, len(rp.rule.expansion)):
                        if rp.rule.expansion[j] not in self.NULLABLE:
                            break
                    else:
                        includes.append(nt2)
                # state2 is at the final state for rp.rule
                if rp.index == 0:
                    for rp2 in state2.closure:
                        if (rp2.rule == rp.rule) and rp2.is_satisfied:
                            lookback.add((state2, rp2.rule))
            for nt2 in includes:
                self.includes[nt2].add(nt)

    def compute_lookaheads(self):
        read_sets = digraph(self.nonterminal_transitions, self.reads, self.directly_reads)
        follow_sets = digraph(self.nonterminal_transitions, self.includes, read_sets)

        for nt, lookbacks in self.lookback.items():
            for state, rule in lookbacks:
                for s in follow_sets[nt]:
                    state.lookaheads[s].add(rule)
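
    # A sketch of the two digraph passes above (following the reference cited
    # near `digraph`): the first pass closes `directly_reads` over the `reads`
    # relation to produce Read sets; the second closes those Read sets over
    # `includes` to produce Follow sets:
    #
    #   Read(p, A)   = DR(p, A)   union { Read(r, C)    | (p, A) reads (r, C) }
    #   Follow(p, A) = Read(p, A) union { Follow(p', B) | (p, A) includes (p', B) }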


    def compute_lalr1_states(self) -> None:
        m: Dict[LR0ItemSet, Dict[str, Tuple]] = {}
        reduce_reduce = []
        for itemset in self.lr0_itemsets:
            actions: Dict[Symbol, Tuple] = {la: (Shift, next_state.closure)
                                            for la, next_state in itemset.transitions.items()}
            for la, rules in itemset.lookaheads.items():
                if len(rules) > 1:
                    # Try to resolve conflict based on priority
                    p = [(r.options.priority or 0, r) for r in rules]
                    p.sort(key=lambda r: r[0], reverse=True)
                    best, second_best = p[:2]
                    if best[0] > second_best[0]:
                        rules = {best[1]}
                    else:
                        reduce_reduce.append((itemset, la, rules))
                        continue

                rule, = rules
                if la in actions:
                    if self.strict:
                        msg = f'Shift/Reduce conflict for terminal {la.name}. [strict-mode]\n' \
                              f' * {rule}\n'
                        raise GrammarError(msg)
                    elif self.debug:
                        logger.warning('Shift/Reduce conflict for terminal %s: (resolving as shift)', la.name)
                        logger.warning(' * %s', rule)
                    else:
                        logger.debug('Shift/Reduce conflict for terminal %s: (resolving as shift)', la.name)
                        logger.debug(' * %s', rule)
                else:
                    actions[la] = (Reduce, rule)
            m[itemset] = { k.name: v for k, v in actions.items() }

        if reduce_reduce:
            msgs = []
            for itemset, la, rules in reduce_reduce:
                msg = 'Reduce/Reduce collision in %s between the following rules: %s' % (la, ''.join([ '\n\t- ' + str(r) for r in rules ]))
                if self.debug:
                    msg += '\n collision occurred in state: {%s\n }' % ''.join(['\n\t' + str(x) for x in itemset.closure])
                msgs.append(msg)
            raise GrammarError('\n\n'.join(msgs))


        states = { k.closure: v for k, v in m.items() }

        # compute end states
        end_states: Dict[str, 'State'] = {}
        for state in states:
            for rp in state:
                for start in self.lr0_start_states:
                    if rp.rule.origin.name == ('$root_' + start) and rp.is_satisfied:
                        assert start not in end_states
                        end_states[start] = state

        start_states = { start: state.closure for start, state in self.lr0_start_states.items() }
        _parse_table = ParseTable(states, start_states, end_states)

        if self.debug:
            self.parse_table = _parse_table
        else:
            self.parse_table = IntParseTable.from_ParseTable(_parse_table)


    def compute_lalr(self):
        self.compute_lr0_states()
        self.compute_reads_relations()
        self.compute_includes_lookback()
        self.compute_lookaheads()
        self.compute_lalr1_states()
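
    # Illustrative end-to-end usage (a sketch; in lark the ParserConf is
    # normally assembled by the grammar loader, so `parser_conf` below is
    # hypothetical):
    #
    #   analyzer = LALR_Analyzer(parser_conf, debug=False, strict=False)
    #   analyzer.compute_lalr()
    #   table = analyzer.parse_table            # IntParseTable unless debug=True
    #   state0 = table.start_states['start']    # 'start' is lark's default start symbol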