Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/markdown_it/ruler.py: 72%
129 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:15 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:15 +0000
1"""
class Ruler

Helper class, used by [[MarkdownIt#core]], [[MarkdownIt#block]] and
[[MarkdownIt#inline]] to manage sequences of functions (rules):

- keep rules in defined order
- assign the name to each rule
- enable/disable rules
- add/replace rules
- allow assigning rules to additional named chains (in the same ruler)
- cache lists of active rules

You will not need to use this class directly unless you write plugins. For
simple rule control use [[MarkdownIt.disable]], [[MarkdownIt.enable]] and
[[MarkdownIt.use]].
17"""
18from __future__ import annotations
20from collections.abc import Iterable
21from dataclasses import dataclass, field
22from typing import TYPE_CHECKING, Generic, TypedDict, TypeVar
23import warnings
25from markdown_it._compat import DATACLASS_KWARGS
27from .utils import EnvType
29if TYPE_CHECKING:
30 from markdown_it import MarkdownIt
class StateBase:
    """Base class for state objects: holds source text, parser and environment.

    ``src`` is exposed through a property so that assigning a new source also
    discards the lazily computed (and deprecated) ``srcCharCode`` tuple.
    """

    def __init__(self, src: str, md: MarkdownIt, env: EnvType):
        """Store the given values on the instance.

        :param src: source text.
        :param md: the ``MarkdownIt`` instance, stored as ``self.md``.
        :param env: environment object, stored as ``self.env``.
        """
        # Goes through the ``src`` setter, which also initialises the
        # ``_srcCharCode`` memo to ``None``.
        self.src = src
        self.env = env
        self.md = md

    @property
    def src(self) -> str:
        """The source text."""
        return self._src

    @src.setter
    def src(self, value: str) -> None:
        self._src = value
        # Forget code points computed for the previous source; they are
        # rebuilt on the next ``srcCharCode`` access.
        self._srcCharCode: tuple[int, ...] | None = None

    @property
    def srcCharCode(self) -> tuple[int, ...]:
        """Code points of ``src`` as a tuple (deprecated).

        Emits a ``DeprecationWarning`` on every access. The tuple is computed
        on first access and reused until ``src`` is reassigned.
        """
        warnings.warn(
            "StateBase.srcCharCode is deprecated. Use StateBase.src instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        if self._srcCharCode is None:
            self._srcCharCode = tuple(ord(c) for c in self._src)
        return self._srcCharCode
class RuleOptionsType(TypedDict, total=False):
    """Optional settings passed alongside a rule function.

    ``total=False``: every key may be omitted.
    """

    # Names of additional chains the rule is attached to.
    alt: list[str]
RuleFuncTv = TypeVar("RuleFuncTv")
"""A rule function, whose signature is dependent on the state type."""
@dataclass(**DATACLASS_KWARGS)
class Rule(Generic[RuleFuncTv]):
    """One named entry in a ruler's rule list."""

    # Name the rule is looked up by.
    name: str
    # Rules with ``enabled`` False are left out when chains are compiled.
    enabled: bool
    # The rule function itself; excluded from the generated ``repr``.
    fn: RuleFuncTv = field(repr=False)
    # Names of the additional chains this rule is attached to.
    alt: list[str]
class Ruler(Generic[RuleFuncTv]):
    """Ordered collection of named rule functions.

    Rules keep the order in which they were added or inserted. Each rule can
    be enabled or disabled and can be attached to additional named chains via
    its ``alt`` list. :meth:`getRules` returns the enabled functions of one
    chain from a cache that is rebuilt lazily after any modification.
    """

    def __init__(self) -> None:
        # List of added rules.
        self.__rules__: list[Rule[RuleFuncTv]] = []
        # Cached rule chains, keyed by chain name ('' for the default chain).
        # ``None`` marks the cache as stale; ``getRules`` rebuilds it.
        self.__cache__: dict[str, list[RuleFuncTv]] | None = None

    def __find__(self, name: str) -> int:
        """Find rule index by name; return -1 if no rule has that name."""
        for i, rule in enumerate(self.__rules__):
            if rule.name == name:
                return i
        return -1

    def __compile__(self) -> None:
        """Build rules lookup cache."""
        active = [rule for rule in self.__rules__ if rule.enabled]
        # The default chain '' always exists, even when nothing is enabled.
        chains = {""}
        for rule in active:
            chains.update(rule.alt)
        # The default chain takes every enabled rule; a named chain takes only
        # the enabled rules that list it in ``alt``.
        self.__cache__ = {
            chain: [rule.fn for rule in active if not chain or chain in rule.alt]
            for chain in chains
        }

    def _set_enabled(
        self, names: str | Iterable[str], ignoreInvalid: bool, enabled: bool
    ) -> list[str]:
        """Set the ``enabled`` flag of the named rules; shared by enable/disable.

        :param names: name or iterable of rule names.
        :param ignoreInvalid: skip unknown names instead of raising.
        :param enabled: value to assign to each found rule's ``enabled`` flag.
        :raises: KeyError if name not found and not ignoreInvalid
        :return: list of found rule names
        """
        if isinstance(names, str):
            names = [names]
        found: list[str] = []
        try:
            for name in names:
                idx = self.__find__(name)
                if idx < 0:
                    if ignoreInvalid:
                        continue
                    raise KeyError(f"Rules manager: invalid rule name {name}")
                self.__rules__[idx].enabled = enabled
                found.append(name)
        finally:
            # Rules toggled before a KeyError stay toggled, so the cache has
            # to be dropped on the error path as well as on success.
            self.__cache__ = None
        return found

    def at(
        self, ruleName: str, fn: RuleFuncTv, options: RuleOptionsType | None = None
    ) -> None:
        """Replace rule by name with new function & options.

        :param ruleName: rule name to replace.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).
        :raises: KeyError if name not found
        """
        index = self.__find__(ruleName)
        if index == -1:
            raise KeyError(f"Parser rule not found: {ruleName}")
        rule = self.__rules__[index]
        rule.fn = fn
        # Omitting options (or "alt") resets the rule's alt chains to none.
        rule.alt = (options or {}).get("alt", [])
        self.__cache__ = None

    def before(
        self,
        beforeName: str,
        ruleName: str,
        fn: RuleFuncTv,
        options: RuleOptionsType | None = None,
    ) -> None:
        """Add new rule to chain before one with given name.

        :param beforeName: new rule will be added before this one.
        :param ruleName: name of the added rule.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).
        :raises: KeyError if name not found
        """
        index = self.__find__(beforeName)
        if index == -1:
            raise KeyError(f"Parser rule not found: {beforeName}")
        options = options or {}
        self.__rules__.insert(
            index, Rule[RuleFuncTv](ruleName, True, fn, options.get("alt", []))
        )
        self.__cache__ = None

    def after(
        self,
        afterName: str,
        ruleName: str,
        fn: RuleFuncTv,
        options: RuleOptionsType | None = None,
    ) -> None:
        """Add new rule to chain after one with given name.

        :param afterName: new rule will be added after this one.
        :param ruleName: name of the added rule.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).
        :raises: KeyError if name not found
        """
        index = self.__find__(afterName)
        if index == -1:
            raise KeyError(f"Parser rule not found: {afterName}")
        options = options or {}
        self.__rules__.insert(
            index + 1, Rule[RuleFuncTv](ruleName, True, fn, options.get("alt", []))
        )
        self.__cache__ = None

    def push(
        self, ruleName: str, fn: RuleFuncTv, options: RuleOptionsType | None = None
    ) -> None:
        """Push new rule to the end of chain.

        :param ruleName: name of the added rule.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).
        """
        self.__rules__.append(
            Rule[RuleFuncTv](ruleName, True, fn, (options or {}).get("alt", []))
        )
        self.__cache__ = None

    def enable(
        self, names: str | Iterable[str], ignoreInvalid: bool = False
    ) -> list[str]:
        """Enable rules with given names.

        :param names: name or list of rule names to enable.
        :param ignoreInvalid: ignore errors when rule not found
        :raises: KeyError if name not found and not ignoreInvalid
        :return: list of found rule names
        """
        return self._set_enabled(names, ignoreInvalid, True)

    def enableOnly(
        self, names: str | Iterable[str], ignoreInvalid: bool = False
    ) -> list[str]:
        """Enable rules with given names, and disable everything else.

        :param names: name or list of rule names to enable.
        :param ignoreInvalid: ignore errors when rule not found
        :raises: KeyError if name not found and not ignoreInvalid
        :return: list of found rule names
        """
        if isinstance(names, str):
            names = [names]
        for rule in self.__rules__:
            rule.enabled = False
        # Every rule was just switched off: drop the cache now so it cannot
        # stay stale if enable() raises on an unknown name.
        self.__cache__ = None
        return self.enable(names, ignoreInvalid)

    def disable(
        self, names: str | Iterable[str], ignoreInvalid: bool = False
    ) -> list[str]:
        """Disable rules with given names.

        :param names: name or list of rule names to disable.
        :param ignoreInvalid: ignore errors when rule not found
        :raises: KeyError if name not found and not ignoreInvalid
        :return: list of found rule names
        """
        return self._set_enabled(names, ignoreInvalid, False)

    def getRules(self, chainName: str = "") -> list[RuleFuncTv]:
        """Return array of active functions (rules) for given chain name.

        Compiles the cache first if it does not exist.

        :param chainName: chain to look up; the default chain is ``''``
            (empty string).
        :return: enabled rule functions of the chain; an empty list if the
            chain is unknown or has no enabled rules.
        """
        if self.__cache__ is None:
            self.__compile__()
            assert self.__cache__ is not None
        # Chain can be empty, if rules disabled. But we still have to return Array.
        return self.__cache__.get(chainName, []) or []

    def get_all_rules(self) -> list[str]:
        """Return all available rule names."""
        return [r.name for r in self.__rules__]

    def get_active_rules(self) -> list[str]:
        """Return the active rule names."""
        return [r.name for r in self.__rules__ if r.enabled]