1"""
2class Ruler
3
4Helper class, used by [[MarkdownIt#core]], [[MarkdownIt#block]] and
5[[MarkdownIt#inline]] to manage sequences of functions (rules):
6
7- keep rules in defined order
8- assign the name to each rule
9- enable/disable rules
10- add/replace rules
- allow assigning rules to additional named chains (within the same ruler)
12- caching lists of active rules
13
You will not need to use this class directly until you write plugins. For
simple rule control, use [[MarkdownIt.disable]], [[MarkdownIt.enable]] and
[[MarkdownIt.use]].
17"""
18
19from __future__ import annotations
20
21from collections.abc import Iterable
22from dataclasses import dataclass, field
23from typing import TYPE_CHECKING, Generic, TypedDict, TypeVar
24import warnings
25
26from .utils import EnvType
27
28if TYPE_CHECKING:
29 from markdown_it import MarkdownIt
30
31
class StateBase:
    """Base state object holding the source text, parser and environment.

    ``src`` is exposed through a property so that assigning a new source
    text also discards the lazily computed ``srcCharCode`` tuple.
    """

    def __init__(self, src: str, md: MarkdownIt, env: EnvType):
        # Goes through the ``src`` setter, which also initialises the
        # code-point cache.
        self.src = src
        self.env = env
        self.md = md

    @property
    def src(self) -> str:
        """The source text."""
        return self._src

    @src.setter
    def src(self, value: str) -> None:
        # Any previously computed code points belong to the old text.
        self._srcCharCode: tuple[int, ...] | None = None
        self._src = value

    @property
    def srcCharCode(self) -> tuple[int, ...]:
        """Deprecated: the code points of ``src`` as a tuple.

        Emits a ``DeprecationWarning`` on every access. The tuple is computed
        on first access and reused until ``src`` is reassigned.
        """
        warnings.warn(
            "StateBase.srcCharCode is deprecated. Use StateBase.src instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        codes = self._srcCharCode
        if codes is None:
            codes = self._srcCharCode = tuple(map(ord, self._src))
        return codes
57
58
class RuleOptionsType(TypedDict, total=False):
    """Optional settings accepted when a rule is added to or replaced in a Ruler."""

    # Names of the additional (non-default) chains the rule takes part in.
    alt: list[str]
61
62
# Type variable over which the generic ``Rule`` and ``Ruler`` classes are
# parameterised.
RuleFuncTv = TypeVar("RuleFuncTv")
"""A rule function, whose signature is dependent on the state type."""
65
66
@dataclass(slots=True)
class Rule(Generic[RuleFuncTv]):
    """A single named entry managed by a ``Ruler``."""

    # Name under which the rule is looked up.
    name: str
    # Whether the rule is included when the chains are compiled.
    enabled: bool
    # The rule function itself; excluded from the generated repr.
    fn: RuleFuncTv = field(repr=False)
    # Names of the additional chains the rule belongs to.
    alt: list[str]
73
74
class Ruler(Generic[RuleFuncTv]):
    """Ordered, named collection of rule functions with enable/disable state.

    Rules are kept in the order they were added or inserted. Every enabled
    rule is part of the default chain (``""``) and of each additional chain
    named in its ``alt`` option. The per-chain lists of active rule functions
    are built lazily by :meth:`getRules` and thrown away whenever the rule set
    or the enabled state changes.
    """

    def __init__(self) -> None:
        # All added rules, in order.
        self.__rules__: list[Rule[RuleFuncTv]] = []
        # Lazily built cache: chain name ('' for the default chain) mapped to
        # the functions of the enabled rules in that chain.
        # ``None`` means the cache has to be rebuilt before the next lookup.
        self.__cache__: dict[str, list[RuleFuncTv]] | None = None

    def __find__(self, name: str) -> int:
        """Return the index of the first rule called ``name``, or -1 if absent."""
        for i, rule in enumerate(self.__rules__):
            if rule.name == name:
                return i
        return -1

    def __compile__(self) -> None:
        """Rebuild the chain name -> active rule functions cache."""
        active = [rule for rule in self.__rules__ if rule.enabled]
        # The default chain always exists, even when every rule is disabled.
        chains = {""}
        for rule in active:
            chains.update(rule.alt)
        # The default chain takes every active rule; a named chain only takes
        # the rules that list it in ``alt``.
        self.__cache__ = {
            chain: [rule.fn for rule in active if not chain or chain in rule.alt]
            for chain in chains
        }

    def _set_enabled(
        self,
        names: str | Iterable[str],
        ignoreInvalid: bool,
        enabled: bool,
        *,
        exclusive: bool = False,
    ) -> list[str]:
        """Set the enabled flag of the named rules; shared by enable/disable.

        Every name is resolved before any rule is touched, so an invalid name
        raises without leaving a half-applied change or a stale cache behind.

        :param names: name or iterable of rule names to update.
        :param ignoreInvalid: skip unknown names instead of raising.
        :param enabled: the flag value to assign to the named rules.
        :param exclusive: if true, disable every rule before applying the flag.
        :raises: KeyError if name not found and not ignoreInvalid
        :return: list of found rule names
        """
        if isinstance(names, str):
            names = [names]
        found: list[str] = []
        targets: list[Rule[RuleFuncTv]] = []
        for name in names:
            idx = self.__find__(name)
            if idx < 0:
                if ignoreInvalid:
                    continue
                raise KeyError(f"Rules manager: invalid rule name {name}")
            found.append(name)
            targets.append(self.__rules__[idx])
        # Invalidate before mutating: the cached chains are about to go stale.
        self.__cache__ = None
        if exclusive:
            for rule in self.__rules__:
                rule.enabled = False
        for rule in targets:
            rule.enabled = enabled
        return found

    def at(
        self, ruleName: str, fn: RuleFuncTv, options: RuleOptionsType | None = None
    ) -> None:
        """Replace rule by name with new function & options.

        The rule keeps its position and its enabled state.

        :param ruleName: rule name to replace.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).
        :raises: KeyError if name not found
        """
        index = self.__find__(ruleName)
        if index == -1:
            raise KeyError(f"Parser rule not found: {ruleName}")
        rule = self.__rules__[index]
        rule.fn = fn
        # Omitting ``alt`` resets the rule to the default chain only.
        rule.alt = (options or {}).get("alt", [])
        self.__cache__ = None

    def before(
        self,
        beforeName: str,
        ruleName: str,
        fn: RuleFuncTv,
        options: RuleOptionsType | None = None,
    ) -> None:
        """Add new rule to chain before one with given name.

        The new rule starts out enabled.

        :param beforeName: new rule will be added before this one.
        :param ruleName: name of the added rule.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).
        :raises: KeyError if name not found
        """
        index = self.__find__(beforeName)
        if index == -1:
            raise KeyError(f"Parser rule not found: {beforeName}")
        self.__rules__.insert(
            index, Rule[RuleFuncTv](ruleName, True, fn, (options or {}).get("alt", []))
        )
        self.__cache__ = None

    def after(
        self,
        afterName: str,
        ruleName: str,
        fn: RuleFuncTv,
        options: RuleOptionsType | None = None,
    ) -> None:
        """Add new rule to chain after one with given name.

        The new rule starts out enabled.

        :param afterName: new rule will be added after this one.
        :param ruleName: name of the added rule.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).
        :raises: KeyError if name not found
        """
        index = self.__find__(afterName)
        if index == -1:
            raise KeyError(f"Parser rule not found: {afterName}")
        self.__rules__.insert(
            index + 1,
            Rule[RuleFuncTv](ruleName, True, fn, (options or {}).get("alt", [])),
        )
        self.__cache__ = None

    def push(
        self, ruleName: str, fn: RuleFuncTv, options: RuleOptionsType | None = None
    ) -> None:
        """Push new rule to the end of chain.

        The new rule starts out enabled.

        :param ruleName: name of the added rule.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).
        """
        self.__rules__.append(
            Rule[RuleFuncTv](ruleName, True, fn, (options or {}).get("alt", []))
        )
        self.__cache__ = None

    def enable(
        self, names: str | Iterable[str], ignoreInvalid: bool = False
    ) -> list[str]:
        """Enable rules with given names.

        All names are checked before any rule is modified, so a KeyError
        leaves the ruler unchanged.

        :param names: name or list of rule names to enable.
        :param ignoreInvalid: ignore errors when rule not found
        :raises: KeyError if name not found and not ignoreInvalid
        :return: list of found rule names
        """
        return self._set_enabled(names, ignoreInvalid, True)

    def enableOnly(
        self, names: str | Iterable[str], ignoreInvalid: bool = False
    ) -> list[str]:
        """Enable rules with given names, and disable everything else.

        All names are checked before any rule is modified, so a KeyError
        leaves the ruler unchanged.

        :param names: name or list of rule names to enable.
        :param ignoreInvalid: ignore errors when rule not found
        :raises: KeyError if name not found and not ignoreInvalid
        :return: list of found rule names
        """
        return self._set_enabled(names, ignoreInvalid, True, exclusive=True)

    def disable(
        self, names: str | Iterable[str], ignoreInvalid: bool = False
    ) -> list[str]:
        """Disable rules with given names.

        All names are checked before any rule is modified, so a KeyError
        leaves the ruler unchanged.

        :param names: name or list of rule names to disable.
        :param ignoreInvalid: ignore errors when rule not found
        :raises: KeyError if name not found and not ignoreInvalid
        :return: list of found rule names
        """
        return self._set_enabled(names, ignoreInvalid, False)

    def getRules(self, chainName: str = "") -> list[RuleFuncTv]:
        """Return array of active functions (rules) for given chain name.

        Compiles the cache first if it does not exist.

        :param chainName: chain to look up; defaults to ``""``, the default
            chain containing every enabled rule.
        :return: functions of the enabled rules in that chain, in rule order;
            an empty list if the chain is unknown or has no enabled rules.
        """
        if self.__cache__ is None:
            self.__compile__()
            assert self.__cache__ is not None
        # Chain can be empty, if rules disabled. But we still have to return a list.
        return self.__cache__.get(chainName, []) or []

    def get_all_rules(self) -> list[str]:
        """Return all available rule names."""
        return [r.name for r in self.__rules__]

    def get_active_rules(self) -> list[str]:
        """Return the active rule names."""
        return [r.name for r in self.__rules__ if r.enabled]