Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/lark/parsers/cyk.py: 24%
200 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:30 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:30 +0000
1"""This module implements a CYK parser."""
3# Author: https://github.com/ehudt (2018)
4#
5# Adapted by Erez
8from collections import defaultdict
9import itertools
11from ..exceptions import ParseError
12from ..lexer import Token
13from ..tree import Tree
14from ..grammar import Terminal as T, NonTerminal as NT, Symbol
def match(t, s):
    """Tell whether token ``s`` is of the kind named by terminal ``t``.

    A terminal matches a token when the terminal's name equals the token's type.
    """
    assert isinstance(t, T)
    expected = t.name
    return expected == s.type
class Rule:
    """Context-free grammar rule ``lhs -> rhs``.

    Equality and hashing look only at ``lhs`` and the symbols of ``rhs``;
    ``weight`` and ``alias`` are ignored, so rules that differ only in those
    attributes collapse together inside a set/frozenset.
    """

    def __init__(self, lhs, rhs, weight, alias):
        super(Rule, self).__init__()
        assert isinstance(lhs, NT), lhs
        assert all(isinstance(x, NT) or isinstance(x, T) for x in rhs), rhs
        self.lhs = lhs
        self.rhs = rhs
        # Cost of using this rule; _parse keeps the lowest total weight.
        self.weight = weight
        # Back-reference to where the rule came from (a lark rule, or a marker
        # string such as 'Split' / 'Term' for rules created during CNF conversion).
        self.alias = alias

    def __str__(self):
        return '%s -> %s' % (str(self.lhs), ' '.join(str(x) for x in self.rhs))

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.lhs, tuple(self.rhs)))

    def __eq__(self, other):
        if not isinstance(other, Rule):
            # Let Python fall back to the other operand / identity instead of
            # crashing on a missing ``lhs`` attribute.
            return NotImplemented
        # Compare rhs as tuples so a list rhs and a tuple rhs holding the same
        # symbols are equal -- consistent with __hash__, which uses tuple(rhs).
        return self.lhs == other.lhs and tuple(self.rhs) == tuple(other.rhs)

    def __ne__(self, other):
        # Go through ``==`` (not self.__eq__) so subclass-reflected __eq__
        # (e.g. UnitSkipRule) keeps taking priority, as before.
        return not (self == other)
class Grammar:
    """Context-free grammar: an immutable, de-duplicated collection of rules."""

    def __init__(self, rules):
        # frozenset removes duplicate rules (per the rules' __eq__/__hash__)
        # and makes the rule collection immutable.
        self.rules = frozenset(rules)

    def __eq__(self, other):
        if not isinstance(other, Grammar):
            # Avoid AttributeError on ``other.rules``; let Python fall back.
            return NotImplemented
        return self.rules == other.rules

    def __str__(self):
        # Sorted so the textual form is independent of set iteration order.
        return '\n' + '\n'.join(sorted(repr(x) for x in self.rules)) + '\n'

    def __repr__(self):
        return str(self)
65# Parse tree data structures
class RuleNode:
    """Parse-tree node: the rule applied here (with its full rhs) and its children.

    ``weight`` is the cost associated with this subtree; _parse compares it to
    keep the lightest tree for each span and lhs.
    """

    def __init__(self, rule, children, weight=0):
        self.rule, self.children, self.weight = rule, children, weight

    def __repr__(self):
        head = repr(self.rule.lhs)
        body = ', '.join(map(str, self.children))
        return 'RuleNode(%s, [%s])' % (head, body)
class Parser:
    """CYK parser wrapper: converts lark rules to CNF once, then parses token lists."""

    def __init__(self, rules):
        super(Parser, self).__init__()
        # ``rules`` is traversed twice below; materialize it so a one-shot
        # iterator (e.g. a generator) does not leave the CNF grammar empty.
        rules = list(rules)
        # Identity map used by _to_tree to recover the original lark rule
        # from a Rule's alias.
        self.orig_rules = {rule: rule for rule in rules}
        rules = [self._to_rule(rule) for rule in rules]
        self.grammar = to_cnf(Grammar(rules))

    def _to_rule(self, lark_rule):
        """Converts a lark rule (``origin``, ``expansion``, ``options``) to a Rule.

        The lark rule itself is stored as the new Rule's alias.
        """
        assert isinstance(lark_rule.origin, NT)
        assert all(isinstance(x, Symbol) for x in lark_rule.expansion)
        return Rule(
            lark_rule.origin, lark_rule.expansion,
            # A missing/falsy priority counts as weight 0.
            # NOTE(review): _parse keeps the *lowest*-weight tree, so a larger
            # lark priority makes a rule less preferred here -- confirm intended.
            weight=lark_rule.options.priority if lark_rule.options.priority else 0,
            alias=lark_rule)

    def parse(self, tokenized, start):  # pylint: disable=invalid-name
        """Parses ``tokenized`` (a sequence of tokens) from the rule named ``start``.

        Returns:
            A lark Tree for the lowest-weight parse found.

        Raises:
            ParseError: if the whole input cannot be derived from ``start``
                (including empty input, since empty rules are unsupported).
        """
        assert start
        start = NT(start)

        table, trees = _parse(tokenized, self.grammar)
        # The parse succeeded iff some rule for ``start`` covers the full span.
        whole = (0, len(tokenized) - 1)
        if all(r.lhs != start for r in table[whole]):
            raise ParseError('Parsing failed.')
        parse = trees[whole][start]
        return self._to_tree(revert_cnf(parse))

    def _to_tree(self, rule_node):
        """Converts a reverted RuleNode parse tree to a lark Tree.

        Each node's ``rule.alias`` must be an original lark rule (see _to_rule);
        leaves are terminals wrapping a Token in ``.name``.
        """
        orig_rule = self.orig_rules[rule_node.rule.alias]
        children = []
        for child in rule_node.children:
            if isinstance(child, RuleNode):
                children.append(self._to_tree(child))
            else:
                assert isinstance(child.name, Token)
                children.append(child.name)
        t = Tree(orig_rule.origin, children)
        t.rule = orig_rule
        return t
def print_parse(node, indent=0):
    """Debug helper: print a RuleNode parse tree to stdout, one node per line.

    Each nesting level is indented by two more spaces. Inner nodes print their
    rule's lhs; leaves print the token they wrap.
    """
    pad = ' ' * (indent * 2)
    if isinstance(node, RuleNode):
        print(pad + str(node.rule.lhs))
        for child in node.children:
            print_parse(child, indent + 1)
    else:
        # Leaves are T(token) objects built by _parse; the token lives in
        # ``.name`` (the same attribute Parser._to_tree reads). There is no
        # ``.s`` attribute on them.
        print(pad + str(node.name))
def _parse(s, g):
    """Run the CYK dynamic program over the token list ``s`` with CNF grammar ``g``.

    ``g`` must expose ``terminal_rules`` and ``nonterminal_rules`` (as CnfWrapper
    does). Spans are inclusive ``(first, last)`` index pairs.

    Returns:
        (table, trees): ``table[span]`` is the set of CNF rules able to derive
        that span; ``trees[span]`` maps each such rule's lhs to the lightest
        (lowest-weight) RuleNode found for it.
    """
    table = defaultdict(set)
    trees = defaultdict(dict)
    size = len(s)

    # Single-token spans: every terminal rule whose terminal matches the token.
    for pos, token in enumerate(s):
        cell = (pos, pos)
        for terminal, term_rules in g.terminal_rules.items():
            if not match(terminal, token):
                continue
            for rule in term_rules:
                table[cell].add(rule)
                best = trees[cell].get(rule.lhs)
                if best is None or rule.weight < best.weight:
                    trees[cell][rule.lhs] = RuleNode(rule, [T(token)], weight=rule.weight)

    # Longer spans, shortest first, so both halves of every split are ready.
    for length in range(2, size + 1):
        for first in range(size - length + 1):
            last = first + length - 1
            cell = (first, last)
            # ``mid`` is where the right half starts.
            for mid in range(first + 1, last + 1):
                left = (first, mid - 1)
                right = (mid, last)
                for r1, r2 in itertools.product(table[left], table[right]):
                    for rule in g.nonterminal_rules.get((r1.lhs, r2.lhs), []):
                        table[cell].add(rule)
                        left_tree = trees[left][r1.lhs]
                        right_tree = trees[right][r2.lhs]
                        total = rule.weight + left_tree.weight + right_tree.weight
                        best = trees[cell].get(rule.lhs)
                        if best is None or total < best.weight:
                            trees[cell][rule.lhs] = RuleNode(rule, [left_tree, right_tree], weight=total)
    return table, trees
# This section implements a converter from a context-free grammar to Chomsky
# normal form (CNF). It also converts parse trees built over the CNF grammar
# back into trees over the original grammar.
# Overview:
# The following transformations are applied, in this order:
# * TERM: eliminates non-solitary terminals from all rules.
# * BIN: eliminates rules with more than 2 symbols on their right-hand side.
# * UNIT: eliminates nonterminal unit rules.
#
# The following grammar characteristics are not handled by the conversion:
# * The start symbol appearing on a right-hand side.
# * Empty rules (epsilon rules).
class CnfWrapper:
    """Wraps a grammar that must already be in CNF and indexes its rules.

    ``terminal_rules`` maps a terminal ``t`` to the rules ``A -> t``;
    ``nonterminal_rules`` maps a pair ``(B, C)`` to the rules ``A -> B C``.
    A rule whose rhs length is not 1 or 2 raises ParseError.
    """

    def __init__(self, grammar):
        super(CnfWrapper, self).__init__()
        self.grammar = grammar
        self.rules = grammar.rules
        self.terminal_rules = defaultdict(list)
        self.nonterminal_rules = defaultdict(list)
        for rule in self.rules:
            assert isinstance(rule.lhs, NT), rule
            rhs = rule.rhs
            size = len(rhs)
            if size not in (1, 2):
                raise ParseError("CYK doesn't support empty rules")
            if size == 1 and isinstance(rhs[0], T):
                self.terminal_rules[rhs[0]].append(rule)
                continue
            if size == 2 and all(isinstance(sym, NT) for sym in rhs):
                self.nonterminal_rules[tuple(rhs)].append(rule)
                continue
            # Anything else (e.g. a leftover unit rule A -> B) is not CNF.
            assert False, rule

    def __eq__(self, other):
        return self.grammar == other.grammar

    def __repr__(self):
        return repr(self.grammar)
class UnitSkipRule(Rule):
    """A rule created by eliminating unit rules; records what was skipped.

    ``skipped_rules`` lists, in order, the rules bypassed to build this rule, so
    that unroll_unit_skiprule can re-insert the intermediate parse-tree nodes.
    """

    def __init__(self, lhs, rhs, skipped_rules, weight, alias):
        super(UnitSkipRule, self).__init__(lhs, rhs, weight, alias)
        self.skipped_rules = skipped_rules

    def __eq__(self, other):
        # A skip rule is never equal to a plain Rule (or any other type).
        if not isinstance(other, type(self)):
            return False
        # Require the same production AND the same skipped chain. Comparing
        # skipped_rules alone made e.g. ``A -> C`` and ``D -> C`` (both skipping
        # ``B -> C``) equal, so the ``x != rule`` filter in _remove_unit_rule
        # silently dropped one of them and changed the grammar's language.
        return (self.lhs == other.lhs
                and tuple(self.rhs) == tuple(other.rhs)
                and self.skipped_rules == other.skipped_rules)

    __hash__ = Rule.__hash__
def build_unit_skiprule(unit_rule, target_rule):
    """Fuse unit rule ``A -> B`` with a rule ``B -> rhs`` into ``A -> rhs``.

    The skipped chain is unit_rule's own skipped rules (if any), then
    target_rule, then target_rule's skipped rules (if any). The new weight is
    the sum of both weights; the alias is taken from unit_rule.
    """
    before = unit_rule.skipped_rules if isinstance(unit_rule, UnitSkipRule) else []
    after = target_rule.skipped_rules if isinstance(target_rule, UnitSkipRule) else []
    chain = [*before, target_rule, *after]
    return UnitSkipRule(unit_rule.lhs, target_rule.rhs, chain,
                        weight=unit_rule.weight + target_rule.weight, alias=unit_rule.alias)
def get_any_nt_unit_rule(g):
    """Return some rule of 'g' shaped ``A -> B`` (B a nonterminal), or None."""
    candidates = (rule for rule in g.rules
                  if len(rule.rhs) == 1 and isinstance(rule.rhs[0], NT))
    return next(candidates, None)
def _remove_unit_rule(g, rule):
    """Drop unit rule ``A -> B`` from 'g', replacing it by ``A -> rhs`` for every ``B -> rhs``.

    The replacements are UnitSkipRules, so the skipped step can be restored in
    the parse tree later. Returns a new Grammar.
    """
    target = rule.rhs[0]
    kept = [other for other in g.rules if other != rule]
    expansions = [build_unit_skiprule(rule, other) for other in g.rules if other.lhs == target]
    return Grammar(kept + expansions)
def _split(rule):
    """Binarize a rule whose right-hand side has more than two symbols.

    ``A -> x1 x2 ... xn`` becomes ``A -> x1 S1``, ``S1 -> x2 S2``, ...,
    ``S(n-2) -> x(n-1) xn``, where the ``S`` nonterminals get ``__SP_``-prefixed
    names (revert_cnf relies on that prefix). Only the first rule keeps the
    original weight and alias; the helper rules have weight 0 and alias 'Split'.
    """
    symbols = rule.rhs
    label = str(rule.lhs) + '__' + '_'.join(str(sym) for sym in symbols)
    template = '__SP_%s' % (label) + '_%d'

    def link(k):
        return NT(template % k)

    tail = len(symbols) - 2
    yield Rule(rule.lhs, [symbols[0], link(1)], weight=rule.weight, alias=rule.alias)
    for k in range(1, tail):
        yield Rule(link(k), [symbols[k], link(k + 1)], weight=0, alias='Split')
    yield Rule(link(tail), symbols[-2:], weight=0, alias='Split')
def _term(g):
    """Apply TERM: move terminals out of rules with more than one rhs symbol.

    Each terminal ``t`` gets a helper rule ``__T_<t> -> t`` (weight 0, alias
    'Term'). Every terminal in a multi-symbol rhs is replaced by its helper
    nonterminal, and the helper rules it needs are added. Single-symbol rules
    are kept unchanged.
    """
    terminals = {sym for rule in g.rules for sym in rule.rhs if isinstance(sym, T)}
    wrap = {t: Rule(NT('__T_%s' % str(t)), [t], weight=0, alias='Term') for t in terminals}
    result = []
    for rule in g.rules:
        needs_term = len(rule.rhs) > 1 and any(isinstance(sym, T) for sym in rule.rhs)
        if not needs_term:
            result.append(rule)
            continue
        rhs = [wrap[sym].lhs if isinstance(sym, T) else sym for sym in rule.rhs]
        result.append(Rule(rule.lhs, rhs, weight=rule.weight, alias=rule.alias))
        result.extend(helper for t, helper in wrap.items() if t in rule.rhs)
    return Grammar(result)
def _bin(g):
    """Apply BIN: split every rule with more than two rhs symbols via _split."""
    result = []
    for rule in g.rules:
        result.extend(_split(rule) if len(rule.rhs) > 2 else (rule,))
    return Grammar(result)
def _unit(g):
    """Apply UNIT: remove nonterminal unit rules one at a time until none remain.

    NOTE(review): with cyclic unit rules (e.g. ``A -> B``, ``B -> A``) this loop
    appears not to terminate -- removing ``A -> A`` re-creates it with a longer
    skipped chain. Confirm whether such grammars can reach here.
    """
    while True:
        unit_rule = get_any_nt_unit_rule(g)
        if unit_rule is None:
            return g
        g = _remove_unit_rule(g, unit_rule)
def to_cnf(g):
    """Convert a general context-free grammar 'g' to CNF, wrapped in a CnfWrapper.

    Applies TERM, then BIN, then UNIT (see the section comment above).
    """
    for step in (_term, _bin, _unit):
        g = step(g)
    return CnfWrapper(g)
def unroll_unit_skiprule(lhs, orig_rhs, skipped_rules, children, weight, alias):
    """Rebuild the chain of unit-rule nodes that a UnitSkipRule collapsed.

    Produces ``lhs -> s0.lhs``, ``s0.lhs -> s1.lhs``, ... down to a final node
    that applies ``orig_rhs`` to ``children``, where ``s0, s1, ...`` are
    ``skipped_rules``. Each wrapping node's weight is its level weight minus
    the weight of the next skipped rule; the innermost node keeps its own weight.
    """
    # Per-level (lhs, weight, alias), outermost first.
    levels = [(lhs, weight, alias)] + [(r.lhs, r.weight, r.alias) for r in skipped_rules]

    # Build the wrapping rules top-down (same construction order as before).
    wrappers = []
    for (lvl_lhs, lvl_weight, lvl_alias), nxt in zip(levels, skipped_rules):
        adjusted = lvl_weight - nxt.weight
        wrappers.append((Rule(lvl_lhs, [nxt.lhs], weight=adjusted, alias=lvl_alias), adjusted))

    leaf_lhs, leaf_weight, leaf_alias = levels[-1]
    node = RuleNode(Rule(leaf_lhs, orig_rhs, weight=leaf_weight, alias=leaf_alias),
                    children, weight=leaf_weight)

    # Wrap the nodes bottom-up so each level holds the one beneath it.
    for rule, adjusted in reversed(wrappers):
        node = RuleNode(rule, [node], weight=adjusted)
    return node
def revert_cnf(node):
    """Turn a CNF parse tree (RuleNode) back into a tree over the original grammar.

    Undoes the three CNF steps:
    * TERM: a ``__T_`` node is replaced by the terminal it wraps.
    * BIN: children of ``__SP_`` nodes are spliced into their parent.
    * UNIT: a node built from a UnitSkipRule is expanded into its skipped chain.
    Terminals (T) are returned unchanged.
    """
    if isinstance(node, T):
        return node
    rule = node.rule
    if rule.lhs.name.startswith('__T_'):
        return node.children[0]

    flat = []
    for reverted in (revert_cnf(child) for child in node.children):
        is_split = isinstance(reverted, RuleNode) and reverted.rule.lhs.name.startswith('__SP_')
        if is_split:
            flat.extend(reverted.children)
        else:
            flat.append(reverted)

    if not isinstance(rule, UnitSkipRule):
        return RuleNode(rule, flat)
    return unroll_unit_skiprule(rule.lhs, rule.rhs, rule.skipped_rules, flat,
                                rule.weight, rule.alias)