1"""
2class Ruler
3
4Helper class, used by [[MarkdownIt#core]], [[MarkdownIt#block]] and
5[[MarkdownIt#inline]] to manage sequences of functions (rules):
6
- keep rules in a defined order
- assign a name to each rule
- enable/disable rules
- add/replace rules
- allow assigning rules to additional named chains (within the same ruler)
- cache lists of active rules
13
You will not need to use this class directly unless you write plugins. For
simple rule control, use [[MarkdownIt.disable]], [[MarkdownIt.enable]] and
[[MarkdownIt.use]].
17"""
18
19from __future__ import annotations
20
21from collections.abc import Iterable
22from dataclasses import dataclass, field
23from typing import TYPE_CHECKING, Generic, TypedDict, TypeVar
24import warnings
25
26from markdown_it._compat import DATACLASS_KWARGS
27
28from .utils import EnvType
29
30if TYPE_CHECKING:
31 from markdown_it import MarkdownIt
32
33
class StateBase:
    """Base state object storing ``src``, ``md`` and ``env``.

    Assigning to ``src`` also drops the lazily built tuple of code points
    that backs the deprecated ``srcCharCode`` property.
    """

    def __init__(self, src: str, md: MarkdownIt, env: EnvType):
        # ``src`` is routed through the property setter below, which also
        # initialises the code point cache.
        self.src = src
        self.md = md
        self.env = env

    @property
    def src(self) -> str:
        """The source text."""
        return self._src

    @src.setter
    def src(self, value: str) -> None:
        # Any cached code points describe the previous text, so discard them.
        self._srcCharCode: tuple[int, ...] | None = None
        self._src = value

    @property
    def srcCharCode(self) -> tuple[int, ...]:
        """Deprecated: the code points of ``src``, computed on first access.

        Emits a ``DeprecationWarning`` on every access.
        """
        warnings.warn(
            "StateBase.srcCharCode is deprecated. Use StateBase.src instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        codes = self._srcCharCode
        if codes is None:
            codes = self._srcCharCode = tuple(map(ord, self._src))
        return codes
59
60
class RuleOptionsType(TypedDict, total=False):
    """Optional settings that can be supplied when registering a rule."""

    # Names of additional chains the rule should be part of.
    alt: list[str]
63
64
# Type variable for the rule callable type; it carries no bound or constraints.
RuleFuncTv = TypeVar("RuleFuncTv")
"""A rule function, whose signature is dependent on the state type."""
67
68
@dataclass(**DATACLASS_KWARGS)
class Rule(Generic[RuleFuncTv]):
    """A single named rule entry: its function, state and extra chains."""

    # Name used to look the rule up.
    name: str
    # Whether the rule is currently switched on.
    enabled: bool
    # The rule callable itself; left out of the generated ``repr``.
    fn: RuleFuncTv = field(repr=False)
    # Names of the additional chains this rule belongs to.
    alt: list[str]
75
76
class Ruler(Generic[RuleFuncTv]):
    """Ordered collection of named rules that can be enabled or disabled.

    Rules are kept in insertion order. Each rule always belongs to the
    default chain (``""``) and may also be listed in extra named chains via
    its ``alt`` option. The per-chain lists of active rule functions are
    cached and rebuilt lazily after any modification.
    """

    def __init__(self) -> None:
        # List of added rules, in execution order.
        self.__rules__: list[Rule[RuleFuncTv]] = []
        # Cached rule chains: chain name ('' for default) -> active functions.
        # ``None`` marks the cache as stale; ``getRules`` rebuilds it on demand.
        self.__cache__: dict[str, list[RuleFuncTv]] | None = None

    def __find__(self, name: str) -> int:
        """Find rule index by name.

        :param name: rule name to look for.
        :return: index of the first rule with that name, or -1 if absent.
        """
        for i, rule in enumerate(self.__rules__):
            if rule.name == name:
                return i
        return -1

    def __compile__(self) -> None:
        """Build the rules lookup cache from the currently enabled rules."""
        # Collect the unique chain names; the default chain always exists.
        chains = {""}
        for rule in self.__rules__:
            if rule.enabled:
                chains.update(rule.alt)
        # The default chain ('') takes every enabled rule; a named chain takes
        # only the enabled rules that list it in ``alt``.
        self.__cache__ = {
            chain: [
                rule.fn
                for rule in self.__rules__
                if rule.enabled and (not chain or chain in rule.alt)
            ]
            for chain in chains
        }

    def _set_enabled(
        self, names: str | Iterable[str], ignoreInvalid: bool, enabled: bool
    ) -> list[str]:
        """Set the ``enabled`` flag of the named rules (shared by enable/disable).

        :param names: name or iterable of rule names to update.
        :param ignoreInvalid: skip unknown names instead of raising.
        :param enabled: the state to assign to every found rule.
        :raises: KeyError if name not found and not ignoreInvalid
        :return: list of found rule names
        """
        if isinstance(names, str):
            names = [names]
        # Invalidate up front: if an unknown name raises part-way through, the
        # rules already toggled must not stay hidden behind a stale cache.
        self.__cache__ = None
        result: list[str] = []
        for name in names:
            idx = self.__find__(name)
            if idx < 0:
                if ignoreInvalid:
                    continue
                raise KeyError(f"Rules manager: invalid rule name {name}")
            self.__rules__[idx].enabled = enabled
            result.append(name)
        return result

    def at(
        self, ruleName: str, fn: RuleFuncTv, options: RuleOptionsType | None = None
    ) -> None:
        """Replace rule by name with new function & options.

        The rule keeps its position and enabled state. Its ``alt`` chains are
        replaced by ``options["alt"]``, or cleared when that is not given.

        :param ruleName: rule name to replace.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).
        :raises: KeyError if name not found
        """
        index = self.__find__(ruleName)
        options = options or {}
        if index == -1:
            raise KeyError(f"Parser rule not found: {ruleName}")
        self.__rules__[index].fn = fn
        self.__rules__[index].alt = options.get("alt", [])
        self.__cache__ = None

    def before(
        self,
        beforeName: str,
        ruleName: str,
        fn: RuleFuncTv,
        options: RuleOptionsType | None = None,
    ) -> None:
        """Add new (enabled) rule to chain before one with given name.

        :param beforeName: new rule will be added before this one.
        :param ruleName: name of the added rule.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).
        :raises: KeyError if name not found
        """
        index = self.__find__(beforeName)
        options = options or {}
        if index == -1:
            raise KeyError(f"Parser rule not found: {beforeName}")
        self.__rules__.insert(
            index, Rule[RuleFuncTv](ruleName, True, fn, options.get("alt", []))
        )
        self.__cache__ = None

    def after(
        self,
        afterName: str,
        ruleName: str,
        fn: RuleFuncTv,
        options: RuleOptionsType | None = None,
    ) -> None:
        """Add new (enabled) rule to chain after one with given name.

        :param afterName: new rule will be added after this one.
        :param ruleName: name of the added rule.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).
        :raises: KeyError if name not found
        """
        index = self.__find__(afterName)
        options = options or {}
        if index == -1:
            raise KeyError(f"Parser rule not found: {afterName}")
        self.__rules__.insert(
            index + 1, Rule[RuleFuncTv](ruleName, True, fn, options.get("alt", []))
        )
        self.__cache__ = None

    def push(
        self, ruleName: str, fn: RuleFuncTv, options: RuleOptionsType | None = None
    ) -> None:
        """Push new (enabled) rule to the end of chain.

        :param ruleName: name of the added rule.
        :param fn: new rule function.
        :param options: new rule options (not mandatory).

        """
        self.__rules__.append(
            Rule[RuleFuncTv](ruleName, True, fn, (options or {}).get("alt", []))
        )
        self.__cache__ = None

    def enable(
        self, names: str | Iterable[str], ignoreInvalid: bool = False
    ) -> list[str]:
        """Enable rules with given names.

        :param names: name or list of rule names to enable.
        :param ignoreInvalid: ignore errors when rule not found
        :raises: KeyError if name not found and not ignoreInvalid
        :return: list of found rule names
        """
        return self._set_enabled(names, ignoreInvalid, True)

    def enableOnly(
        self, names: str | Iterable[str], ignoreInvalid: bool = False
    ) -> list[str]:
        """Enable rules with given names, and disable everything else.

        All rules are disabled first, so if an invalid name raises, the rules
        stay disabled except for those enabled before the error.

        :param names: name or list of rule names to enable.
        :param ignoreInvalid: ignore errors when rule not found
        :raises: KeyError if name not found and not ignoreInvalid
        :return: list of found rule names
        """
        if isinstance(names, str):
            names = [names]
        # Drop the cache before mutating so it can never outlive the change,
        # even if ``enable`` raises on an invalid name.
        self.__cache__ = None
        for rule in self.__rules__:
            rule.enabled = False
        return self.enable(names, ignoreInvalid)

    def disable(
        self, names: str | Iterable[str], ignoreInvalid: bool = False
    ) -> list[str]:
        """Disable rules with given names.

        :param names: name or list of rule names to disable.
        :param ignoreInvalid: ignore errors when rule not found
        :raises: KeyError if name not found and not ignoreInvalid
        :return: list of found rule names
        """
        return self._set_enabled(names, ignoreInvalid, False)

    def getRules(self, chainName: str = "") -> list[RuleFuncTv]:
        """Return array of active functions (rules) for given chain name.
        It analyzes rules configuration, compiles caches if not exists and returns result.

        Default chain name is `''` (empty string). It can't be skipped.
        That's done intentionally, to keep signature monomorphic for high speed.

        :param chainName: chain to fetch; unknown chains yield an empty list.
        :return: list of enabled rule functions, in rule order.
        """
        if self.__cache__ is None:
            self.__compile__()
            assert self.__cache__ is not None
        # Chain can be empty, if rules disabled. But we still have to return Array.
        return self.__cache__.get(chainName, []) or []

    def get_all_rules(self) -> list[str]:
        """Return all available rule names."""
        return [r.name for r in self.__rules__]

    def get_active_rules(self) -> list[str]:
        """Return the active rule names."""
        return [r.name for r in self.__rules__ if r.enabled]