1from __future__ import annotations
2
3import re
4import typing as t
5from dataclasses import dataclass
6from dataclasses import field
7
8from .converters import ValidationError
9from .exceptions import NoMatch
10from .exceptions import RequestAliasRedirect
11from .exceptions import RequestPath
12from .rules import Rule
13from .rules import RulePart
14
15
class SlashRequired(Exception):
    """Internal signal that the path only matches with a trailing slash.

    Raised while walking the state machine when a strict-slashes rule is
    found one empty ("") static transition past the end of the requested
    path. ``StateMachineMatcher.match`` catches it and raises
    ``RequestPath`` with a slash appended to the path.
    """
18
19
@dataclass
class State:
    """One node of the matcher's state machine.

    A state holds the *rules* that end at this node together with the
    outgoing transitions to further states: *static* ones keyed by the
    literal part content, and *dynamic* ones that are tried in list order.
    """

    # (rule part, next state) pairs for non-static parts.
    dynamic: list[tuple[RulePart, State]] = field(default_factory=list)
    # Rules whose final part leads to this state.
    rules: list[Rule] = field(default_factory=list)
    # Literal part content -> next state.
    static: dict[str, State] = field(default_factory=dict)
31
32
class StateMachineMatcher:
    """Match a domain and path against rules stored as a tree of states.

    Rules are added part by part with :meth:`add`; each part becomes a
    static or dynamic transition between :class:`State` objects.
    :meth:`update` orders the dynamic transitions and :meth:`match` walks
    the tree for a concrete request.
    """

    def __init__(self, merge_slashes: bool) -> None:
        """Create an empty matcher.

        :param merge_slashes: When true, a path that fails to match is
            retried with runs of consecutive slashes collapsed into one.
        """
        self._root = State()
        self.merge_slashes = merge_slashes

    def add(self, rule: Rule) -> None:
        """Insert *rule* into the state machine.

        Follows (creating where needed) one transition per entry in
        ``rule._parts`` and appends the rule to the state reached last.
        """
        state = self._root
        for part in rule._parts:
            if part.static:
                # Only allocate a new state when this static transition
                # does not exist yet.
                next_state = state.static.get(part.content)
                if next_state is None:
                    next_state = state.static[part.content] = State()
                state = next_state
            else:
                # Reuse an existing dynamic transition for an equal part,
                # otherwise append a new one.
                for test_part, new_state in state.dynamic:
                    if test_part == part:
                        state = new_state
                        break
                else:
                    new_state = State()
                    state.dynamic.append((part, new_state))
                    state = new_state
        state.rules.append(rule)

    def update(self) -> None:
        """Sort the dynamic transitions of every state by part weight."""

        def _update_state(state: State) -> None:
            state.dynamic.sort(key=lambda entry: entry[0].weight)
            for new_state in state.static.values():
                _update_state(new_state)
            for _, new_state in state.dynamic:
                _update_state(new_state)

        _update_state(self._root)

    def match(
        self, domain: str, path: str, method: str, websocket: bool
    ) -> tuple[Rule, t.MutableMapping[str, t.Any]]:
        """Find the rule matching the request and its converted values.

        The domain followed by the "/"-separated path segments are matched
        against the transitions starting at the root state. Static
        transitions are tried before dynamic ones.

        :param domain: First part to match, ahead of the path segments.
        :param path: Request path; it is split on "/".
        :param method: Request method, checked against ``rule.methods``
            (rules whose ``methods`` is ``None`` accept any method).
        :param websocket: Must equal ``rule.websocket`` for a rule to match.
        :return: The matched rule and a mapping of converter names to the
            converted values, updated with ``rule.defaults`` if set.
        :raises RequestPath: If the path would match a strict-slashes rule
            with a trailing slash appended, or if it only matches after
            merging consecutive slashes.
        :raises RequestAliasRedirect: If the matched rule is an alias and
            ``rule.map.redirect_defaults`` is set.
        :raises NoMatch: If no rule matches or a converter rejects a value.
            It carries the methods of rules that matched the path but not
            the method, and whether a websocket mismatch was seen.
        """
        have_match_for: set[str] = set()
        websocket_mismatch = False

        def _match(
            state: State, parts: list[str], values: list[str]
        ) -> tuple[Rule, list[str]] | None:
            # Called recursively: tries to consume the head of *parts*
            # through one of the transitions of *state*.
            nonlocal websocket_mismatch

            # Base case: every part has been consumed. Return the first
            # rule of this state whose methods and websocket flag fit,
            # together with the raw values collected on the way.
            if not parts:
                for rule in state.rules:
                    if rule.methods is not None and method not in rule.methods:
                        have_match_for.update(rule.methods)
                    elif rule.websocket != websocket:
                        websocket_mismatch = True
                    else:
                        return rule, values

                # Check whether the path would match with a trailing
                # slash. A strict-slashes rule reports that via
                # SlashRequired; any other rule matches directly.
                if "" in state.static:
                    for rule in state.static[""].rules:
                        if websocket == rule.websocket and (
                            rule.methods is None or method in rule.methods
                        ):
                            if rule.strict_slashes:
                                raise SlashRequired()
                            return rule, values
                return None

            part = parts[0]
            # Static transitions take precedence over dynamic ones.
            if part in state.static:
                rv = _match(state.static[part], parts[1:], values)
                if rv is not None:
                    return rv

            # No match via the static transitions, try the dynamic ones in
            # their stored order.
            for test_part, new_state in state.dynamic:
                target = part
                remaining = parts[1:]
                # A final part always consumes all remaining parts, i.e.
                # it transitions to a final state.
                if test_part.final:
                    target = "/".join(parts)
                    remaining = []
                match = re.compile(test_part.content).match(target)
                if match is not None:
                    if test_part.suffixed:
                        # If a part_isolating=False part has a slash
                        # suffix, remove the suffix from the match and
                        # check for the slash redirect next.
                        suffix = match.groups()[-1]
                        if suffix == "/":
                            remaining = [""]

                    # Only groups named "__werkzeug_*" carry converter
                    # values; they are taken in sorted name order.
                    converter_groups = sorted(
                        match.groupdict().items(), key=lambda entry: entry[0]
                    )
                    groups = [
                        value
                        for key, value in converter_groups
                        if key[:11] == "__werkzeug_"
                    ]
                    rv = _match(new_state, remaining, values + groups)
                    if rv is not None:
                        return rv

            # Nothing matched and the only part left is a trailing slash
            # (""): rules that are not strict about slashes still match.
            if parts == [""]:
                for rule in state.rules:
                    if rule.strict_slashes:
                        continue
                    if rule.methods is not None and method not in rule.methods:
                        have_match_for.update(rule.methods)
                    elif rule.websocket != websocket:
                        websocket_mismatch = True
                    else:
                        return rule, values

            return None

        try:
            rv = _match(self._root, [domain, *path.split("/")], [])
        except SlashRequired:
            raise RequestPath(f"{path}/") from None

        if self.merge_slashes and rv is None:
            # Try again with slashes merged. The quantifier must be greedy:
            # a lazy "/{2,}?" consumes only two slashes per match and would
            # turn "a///b" into "a//b" instead of "a/b".
            path = re.sub("/{2,}", "/", path)
            try:
                rv = _match(self._root, [domain, *path.split("/")], [])
            except SlashRequired:
                raise RequestPath(f"{path}/") from None
            if rv is None or rv[0].merge_slashes is False:
                raise NoMatch(have_match_for, websocket_mismatch)
            # The merged path matches: report it instead of the rule.
            raise RequestPath(path)
        elif rv is not None:
            rule, values = rv

            result = {}
            # Raw values are paired with the converters in order.
            for name, value in zip(rule._converters.keys(), values):
                try:
                    value = rule._converters[name].to_python(value)
                except ValidationError:
                    raise NoMatch(have_match_for, websocket_mismatch) from None
                result[str(name)] = value
            if rule.defaults:
                result.update(rule.defaults)

            if rule.alias and rule.map.redirect_defaults:
                raise RequestAliasRedirect(result, rule.endpoint)

            return rule, result

        raise NoMatch(have_match_for, websocket_mismatch)