Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/lark/parsers/cyk.py: 24%

200 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:30 +0000

1"""This module implements a CYK parser.""" 

2 

3# Author: https://github.com/ehudt (2018) 

4# 

5# Adapted by Erez 

6 

7 

8from collections import defaultdict 

9import itertools 

10 

11from ..exceptions import ParseError 

12from ..lexer import Token 

13from ..tree import Tree 

14from ..grammar import Terminal as T, NonTerminal as NT, Symbol 

15 

def match(t, s):
    """Return True when terminal *t* matches token *s* (name vs. type)."""
    assert isinstance(t, T)
    return s.type == t.name

19 

20 

class Rule:
    """A single context-free production: ``lhs -> rhs``."""

    def __init__(self, lhs, rhs, weight, alias):
        super(Rule, self).__init__()
        assert isinstance(lhs, NT), lhs
        assert all(isinstance(sym, (NT, T)) for sym in rhs), rhs
        self.lhs = lhs        # left-hand-side non-terminal
        self.rhs = rhs        # sequence of terminals / non-terminals
        self.weight = weight  # priority; lighter parses win in the CYK table
        self.alias = alias    # originating lark rule (or a marker string)

    def __str__(self):
        rhs_text = ' '.join(str(sym) for sym in self.rhs)
        return '%s -> %s' % (str(self.lhs), rhs_text)

    def __repr__(self):
        return str(self)

    def __hash__(self):
        # Hash on (lhs, rhs) only, consistent with __eq__.
        return hash((self.lhs, tuple(self.rhs)))

    def __eq__(self, other):
        return self.lhs == other.lhs and self.rhs == other.rhs

    def __ne__(self, other):
        return not (self == other)

47 

48 

class Grammar:
    """An immutable collection of context-free grammar rules."""

    def __init__(self, rules):
        # frozenset: rule order is irrelevant and duplicates collapse.
        self.rules = frozenset(rules)

    def __eq__(self, other):
        return self.rules == other.rules

    def __str__(self):
        # Sorted repr so the printed form is deterministic.
        body = '\n'.join(sorted(repr(rule) for rule in self.rules))
        return '\n' + body + '\n'

    def __repr__(self):
        return str(self)

63 

64 

65# Parse tree data structures 

class RuleNode:
    """A parse-tree node that also records the full rule that produced it."""

    def __init__(self, rule, children, weight=0):
        self.rule = rule          # the Rule applied at this node
        self.children = children  # child RuleNodes or wrapped terminals
        self.weight = weight      # accumulated weight of this subtree

    def __repr__(self):
        child_text = ', '.join(str(child) for child in self.children)
        return 'RuleNode(%s, [%s])' % (repr(self.rule.lhs), child_text)

76 

77 

78 

class Parser:
    """Wraps a CNF-converted grammar and exposes a lark-style parse()."""

    def __init__(self, rules):
        super(Parser, self).__init__()
        # Keep the original lark rules so parse trees can be mapped back.
        self.orig_rules = {rule: rule for rule in rules}
        converted = [self._to_rule(rule) for rule in rules]
        self.grammar = to_cnf(Grammar(converted))

    def _to_rule(self, lark_rule):
        """Converts a lark rule, (lhs, rhs, callback, options), to a Rule."""
        assert isinstance(lark_rule.origin, NT)
        assert all(isinstance(sym, Symbol) for sym in lark_rule.expansion)
        priority = lark_rule.options.priority
        return Rule(
            lark_rule.origin, lark_rule.expansion,
            weight=priority if priority else 0,
            alias=lark_rule)

    def parse(self, tokenized, start):  # pylint: disable=invalid-name
        """Parses a list of tokens; raises ParseError when no tree exists."""
        assert start
        start = NT(start)

        table, trees = _parse(tokenized, self.grammar)
        whole_input = (0, len(tokenized) - 1)
        # The parse succeeded iff the start symbol derives the whole input.
        if all(rule.lhs != start for rule in table[whole_input]):
            raise ParseError('Parsing failed.')
        parse = trees[whole_input][start]
        return self._to_tree(revert_cnf(parse))

    def _to_tree(self, rule_node):
        """Converts a RuleNode parse tree to a lark Tree."""
        orig_rule = self.orig_rules[rule_node.rule.alias]
        children = []
        for child in rule_node.children:
            if isinstance(child, RuleNode):
                children.append(self._to_tree(child))
            else:
                # Terminal wrappers carry the matched Token in .name.
                assert isinstance(child.name, Token)
                children.append(child.name)
        tree = Tree(orig_rule.origin, children)
        tree.rule = orig_rule
        return tree

122 

123 

def print_parse(node, indent=0):
    """Debug helper: print a parse tree, two spaces of indent per level.

    *node* is either a RuleNode (interior node) or a wrapped terminal (leaf).
    """
    if isinstance(node, RuleNode):
        print(' ' * (indent * 2) + str(node.rule.lhs))
        for child in node.children:
            print_parse(child, indent + 1)
    else:
        # Bug fix: terminal wrappers have no `.s` attribute; the matched
        # token is stored in `.name` (as read by match() and _to_tree()).
        print(' ' * (indent * 2) + str(node.name))

131 

132 

def _parse(s, g):
    """Parses sentence 's' using CNF grammar 'g'.

    Returns (table, trees): table maps a span (start, end) to the set of
    rules that can derive that span; trees maps a span to a dict from lhs
    non-terminal to the lightest RuleNode deriving it.
    """
    # The CYK table. Indexed with a 2-tuple: (start pos, end pos)
    table = defaultdict(set)
    # Top-level structure is similar to the CYK table. Each cell is a dict from
    # rule name to the best (lightest) tree for that rule.
    trees = defaultdict(dict)
    # Populate base case with existing terminal production rules
    for i, w in enumerate(s):
        for terminal, rules in g.terminal_rules.items():
            if match(terminal, w):
                for rule in rules:
                    table[(i, i)].add(rule)
                    # Keep only the lightest tree per lhs; strict '<' keeps
                    # the first tree encountered on weight ties.
                    if (rule.lhs not in trees[(i, i)] or
                        rule.weight < trees[(i, i)][rule.lhs].weight):
                        trees[(i, i)][rule.lhs] = RuleNode(rule, [T(w)], weight=rule.weight)

    # Iterate over lengths of sub-sentences
    for l in range(2, len(s) + 1):
        # Iterate over sub-sentences with the given length
        for i in range(len(s) - l + 1):
            # Choose partition of the sub-sentence in [1, l)
            for p in range(i + 1, i + l):
                span1 = (i, p - 1)
                span2 = (p, i + l - 1)
                # Try every pair of derivations of the two halves against
                # every binary rule  lhs -> r1.lhs r2.lhs.
                for r1, r2 in itertools.product(table[span1], table[span2]):
                    for rule in g.nonterminal_rules.get((r1.lhs, r2.lhs), []):
                        table[(i, i + l - 1)].add(rule)
                        r1_tree = trees[span1][r1.lhs]
                        r2_tree = trees[span2][r2.lhs]
                        rule_total_weight = rule.weight + r1_tree.weight + r2_tree.weight
                        # Again keep the lightest tree per lhs for this span.
                        if (rule.lhs not in trees[(i, i + l - 1)]
                            or rule_total_weight < trees[(i, i + l - 1)][rule.lhs].weight):
                            trees[(i, i + l - 1)][rule.lhs] = RuleNode(rule, [r1_tree, r2_tree], weight=rule_total_weight)
    return table, trees

168 

169 

170# This section implements context-free grammar converter to Chomsky normal form. 

171# It also implements a conversion of parse trees from its CNF to the original 

172# grammar. 

173# Overview: 

174# Applies the following operations in this order: 

175# * TERM: Eliminates non-solitary terminals from all rules 

176# * BIN: Eliminates rules with more than 2 symbols on their right-hand-side. 

177# * UNIT: Eliminates non-terminal unit rules 

178# 

179# The following grammar characteristics aren't featured: 

180# * Start symbol appears on RHS 

181# * Empty rules (epsilon rules) 

182 

183 

class CnfWrapper:
    """CNF wrapper for grammar.

    Validates that the input grammar is CNF and provides helper data structures.
    """

    def __init__(self, grammar):
        super(CnfWrapper, self).__init__()
        self.grammar = grammar
        self.rules = grammar.rules
        # terminal_rules: terminal -> rules of the form  NT -> T
        self.terminal_rules = defaultdict(list)
        # nonterminal_rules: (NT, NT) pair -> rules of the form  NT -> NT NT
        self.nonterminal_rules = defaultdict(list)
        for rule in self.rules:
            # Validate CNF shape while building the lookup tables.
            assert isinstance(rule.lhs, NT), rule
            rhs = rule.rhs
            if len(rhs) not in [1, 2]:
                raise ParseError("CYK doesn't support empty rules")
            if len(rhs) == 1 and isinstance(rhs[0], T):
                self.terminal_rules[rhs[0]].append(rule)
            elif len(rhs) == 2 and all(isinstance(sym, NT) for sym in rhs):
                self.nonterminal_rules[tuple(rhs)].append(rule)
            else:
                assert False, rule

    def __eq__(self, other):
        return self.grammar == other.grammar

    def __repr__(self):
        return repr(self.grammar)

213 

214 

class UnitSkipRule(Rule):
    """A rule that records NTs that were skipped during transformation."""

    def __init__(self, lhs, rhs, skipped_rules, weight, alias):
        super(UnitSkipRule, self).__init__(lhs, rhs, weight, alias)
        # Chain of collapsed unit rules, in the order they were applied.
        self.skipped_rules = skipped_rules

    def __eq__(self, other):
        # NOTE(review): equality compares only the skip chain, not lhs/rhs
        # (unlike Rule.__eq__) — presumably sufficient for deduplication in
        # the grammar's frozenset; verify before relying on it elsewhere.
        return isinstance(other, type(self)) and self.skipped_rules == other.skipped_rules

    __hash__ = Rule.__hash__

226 

227 

def build_unit_skiprule(unit_rule, target_rule):
    """Collapse unit_rule (A -> B) into target_rule (B -> ...).

    The returned UnitSkipRule remembers every intermediate rule that was
    skipped, so unroll_unit_skiprule can rebuild the original tree shape.
    """
    skipped = list(unit_rule.skipped_rules) if isinstance(unit_rule, UnitSkipRule) else []
    skipped.append(target_rule)
    if isinstance(target_rule, UnitSkipRule):
        skipped.extend(target_rule.skipped_rules)
    combined_weight = unit_rule.weight + target_rule.weight
    return UnitSkipRule(unit_rule.lhs, target_rule.rhs, skipped,
                        weight=combined_weight, alias=unit_rule.alias)

237 

238 

def get_any_nt_unit_rule(g):
    """Returns a non-terminal unit rule from 'g', or None if there is none."""
    return next((rule for rule in g.rules
                 if len(rule.rhs) == 1 and isinstance(rule.rhs[0], NT)),
                None)

245 

246 

def _remove_unit_rule(g, rule):
    """Removes 'rule' from 'g' without changing the language produced by 'g'."""
    kept = [r for r in g.rules if r != rule]
    # For each rule producible from the unit rule's target, add a skip-rule
    # that derives it directly from the unit rule's lhs.
    replacements = [build_unit_skiprule(rule, ref)
                    for ref in g.rules if ref.lhs == rule.rhs[0]]
    return Grammar(kept + replacements)

253 

254 

def _split(rule):
    """Splits a rule whose len(rhs) > 2 into a chain of binary rules."""
    symbols = rule.rhs
    base = str(rule.lhs) + '__' + '_'.join(str(x) for x in symbols)
    name_fmt = '__SP_%s' % (base) + '_%d'
    # Head rule keeps the original lhs, weight and alias.
    yield Rule(rule.lhs, [symbols[0], NT(name_fmt % 1)], weight=rule.weight, alias=rule.alias)
    # Middle links each consume one symbol and point at the next link.
    for i in range(1, len(symbols) - 2):
        yield Rule(NT(name_fmt % i), [symbols[i], NT(name_fmt % (i + 1))], weight=0, alias='Split')
    # Tail link absorbs the final two symbols.
    yield Rule(NT(name_fmt % (len(symbols) - 2)), symbols[-2:], weight=0, alias='Split')

263 

264 

def _term(g):
    """Applies the TERM rule on 'g' (see top comment).

    Non-solitary terminals are replaced by proxy non-terminals __T_x with a
    single production __T_x -> x.
    """
    terminals = {sym for rule in g.rules for sym in rule.rhs if isinstance(sym, T)}
    proxies = {t: Rule(NT('__T_%s' % str(t)), [t], weight=0, alias='Term') for t in terminals}
    new_rules = []
    for rule in g.rules:
        needs_rewrite = len(rule.rhs) > 1 and any(isinstance(sym, T) for sym in rule.rhs)
        if not needs_rewrite:
            new_rules.append(rule)
            continue
        new_rhs = [proxies[sym].lhs if isinstance(sym, T) else sym for sym in rule.rhs]
        new_rules.append(Rule(rule.lhs, new_rhs, weight=rule.weight, alias=rule.alias))
        # Keep only the proxy rules this rule actually needs.
        new_rules.extend(proxy for t, proxy in proxies.items() if t in rule.rhs)
    return Grammar(new_rules)

278 

279 

def _bin(g):
    """Applies the BIN rule to 'g' (see top comment)."""
    new_rules = []
    for rule in g.rules:
        if len(rule.rhs) <= 2:
            new_rules.append(rule)
        else:
            # Binarize long right-hand sides via a chain of __SP_ rules.
            new_rules.extend(_split(rule))
    return Grammar(new_rules)

289 

290 

def _unit(g):
    """Applies the UNIT rule to 'g' (see top comment).

    Repeatedly collapses non-terminal unit rules until none remain.
    """
    while (rule := get_any_nt_unit_rule(g)) is not None:
        g = _remove_unit_rule(g, rule)
    return g

298 

299 

def to_cnf(g):
    """Creates a CNF grammar from a general context-free grammar 'g'."""
    # TERM, then BIN, then UNIT — applied in this order (see module comment).
    return CnfWrapper(_unit(_bin(_term(g))))

304 

305 

def unroll_unit_skiprule(lhs, orig_rhs, skipped_rules, children, weight, alias):
    """Re-expand a collapsed unit-rule chain into nested RuleNodes.

    Each skipped rule becomes one intermediate unit node; the innermost node
    carries the original rhs and children.
    """
    if not skipped_rules:
        return RuleNode(Rule(lhs, orig_rhs, weight=weight, alias=alias), children, weight=weight)
    first = skipped_rules[0]
    # The outer node's weight excludes what the inner chain accounts for.
    remaining = weight - first.weight
    inner = unroll_unit_skiprule(first.lhs, orig_rhs, skipped_rules[1:],
                                 children, first.weight, first.alias)
    return RuleNode(Rule(lhs, [first.lhs], weight=remaining, alias=alias),
                    [inner], weight=remaining)

317 

318 

def revert_cnf(node):
    """Reverts a parse tree (RuleNode) to its original non-CNF form (Node)."""
    if isinstance(node, T):
        return node
    # Revert TERM: a __T_ proxy node stands for its single terminal child.
    if node.rule.lhs.name.startswith('__T_'):
        return node.children[0]
    children = []
    for child in map(revert_cnf, node.children):
        # Revert BIN: splice __SP_ chain nodes back into their parent.
        if isinstance(child, RuleNode) and child.rule.lhs.name.startswith('__SP_'):
            children.extend(child.children)
        else:
            children.append(child)
    # Revert UNIT: re-expand collapsed unit-rule chains.
    if isinstance(node.rule, UnitSkipRule):
        return unroll_unit_skiprule(node.rule.lhs, node.rule.rhs,
                                    node.rule.skipped_rules, children,
                                    node.rule.weight, node.rule.alias)
    return RuleNode(node.rule, children)