1import re
2import typing as t
3from dataclasses import dataclass
4from dataclasses import field
5
6from .converters import ValidationError
7from .exceptions import NoMatch
8from .exceptions import RequestAliasRedirect
9from .exceptions import RequestPath
10from .rules import Rule
11from .rules import RulePart
12
13
class SlashRequired(Exception):
    """Internal signal: the path only matches once a trailing slash is added.

    Raised while walking the state machine when a ``strict_slashes`` rule
    would match the path plus ``/``; the matcher turns it into a
    ``RequestPath`` redirect to that slash-terminated path.
    """
16
17
@dataclass()
class State:
    """One node of the rule-matching state machine.

    ``rules`` holds the rules whose path ends at this node, ``static``
    maps a literal path part to the next node, and ``dynamic`` lists
    ``(RulePart, State)`` transitions whose parts are tested by regex.
    """

    # Regex-tested transitions to the next state, tried in list order.
    dynamic: t.List[t.Tuple[RulePart, "State"]] = field(default_factory=list)
    # Rules whose complete path terminates at this state.
    rules: t.List[Rule] = field(default_factory=list)
    # Literal-part transitions, keyed by the exact text of the part.
    static: t.Dict[str, "State"] = field(default_factory=dict)
29
30
class StateMachineMatcher:
    """Match requests against rules using a tree of :class:`State` nodes.

    Each added rule becomes a path of transitions from a root state:
    static parts are dictionary lookups keyed by the literal text, and
    dynamic parts are transitions whose ``content`` is used as a regex.
    """

    def __init__(self, merge_slashes: bool) -> None:
        self._root = State()
        # When true, a failed match is retried with runs of slashes in the
        # path collapsed to one; a hit on that retry becomes a redirect.
        self.merge_slashes = merge_slashes

    def add(self, rule: Rule) -> None:
        """Insert *rule* into the state machine.

        Static parts share states through ``State.static``. A dynamic part
        reuses an existing transition whose part compares equal, otherwise
        a new transition is appended. The rule is attached to the state
        reached after its last part.
        """
        state = self._root
        for part in rule._parts:
            if part.static:
                state = state.static.setdefault(part.content, State())
            else:
                for test_part, new_state in state.dynamic:
                    if test_part == part:
                        state = new_state
                        break
                else:
                    new_state = State()
                    state.dynamic.append((part, new_state))
                    state = new_state
        state.rules.append(rule)

    def update(self) -> None:
        """Sort every state's dynamic transitions by ascending part weight.

        :meth:`match` tries dynamic transitions in list order, so this must
        run after rules are added for the weight ordering to take effect.
        """

        def _update_state(state: State) -> None:
            state.dynamic.sort(key=lambda entry: entry[0].weight)
            for new_state in state.static.values():
                _update_state(new_state)
            for _, new_state in state.dynamic:
                _update_state(new_state)

        _update_state(self._root)

    def match(
        self, domain: str, path: str, method: str, websocket: bool
    ) -> t.Tuple[Rule, t.MutableMapping[str, t.Any]]:
        """Find the rule matching *domain* and *path* for *method*.

        :return: ``(rule, values)`` where ``values`` maps each converter
            name to its converted value, updated with ``rule.defaults``.
        :raises RequestPath: when the path needs a trailing slash for a
            ``strict_slashes`` rule, or when ``merge_slashes`` is enabled
            and only the slash-merged path matches.
        :raises RequestAliasRedirect: when the matched rule is an alias and
            ``rule.map.redirect_defaults`` is set.
        :raises NoMatch: otherwise (including a converter rejecting its
            value), carrying the methods of rules that matched the path but
            not *method*, and whether a websocket mismatch was seen.
        """
        # To match to a rule we need to start at the root state and
        # try to follow the transitions until we find a match, or find
        # there is no transition to follow.

        # Methods of rules that matched the path but not *method*.
        have_match_for: t.Set[str] = set()
        websocket_mismatch = False

        def _match(
            state: State, parts: t.List[str], values: t.List[str]
        ) -> t.Optional[t.Tuple[Rule, t.List[str]]]:
            # This function is meant to be called recursively, and will attempt
            # to match the head part to the state's transitions.
            nonlocal websocket_mismatch

            # The base case is when all parts have been matched via
            # transitions. Hence if there is a rule with methods &
            # websocket that work return it and the dynamic values
            # extracted.
            if not parts:
                for rule in state.rules:
                    if rule.methods is not None and method not in rule.methods:
                        have_match_for.update(rule.methods)
                    elif rule.websocket != websocket:
                        websocket_mismatch = True
                    else:
                        return rule, values

                # Test if there is a match with this path with a
                # trailing slash, if so raise an exception to report
                # that matching is possible with an additional slash
                if "" in state.static:
                    for rule in state.static[""].rules:
                        if websocket == rule.websocket and (
                            rule.methods is None or method in rule.methods
                        ):
                            if rule.strict_slashes:
                                raise SlashRequired()
                            else:
                                return rule, values
                return None

            part = parts[0]
            # To match this part try the static transitions first
            if part in state.static:
                rv = _match(state.static[part], parts[1:], values)
                if rv is not None:
                    return rv
            # No match via the static transitions, so try the dynamic
            # ones.
            for test_part, new_state in state.dynamic:
                target = part
                remaining = parts[1:]
                # A final part indicates a transition that always
                # consumes the remaining parts i.e. transitions to a
                # final state.
                if test_part.final:
                    target = "/".join(parts)
                    remaining = []
                match = re.compile(test_part.content).match(target)
                if match is not None:
                    groups = list(match.groups())
                    if test_part.suffixed:
                        # If a part_isolating=False part has a slash suffix, remove the
                        # suffix from the match and check for the slash redirect next.
                        suffix = groups.pop()
                        if suffix == "/":
                            remaining = [""]
                    rv = _match(new_state, remaining, values + groups)
                    if rv is not None:
                        return rv

            # If there is no match and the only part left is a
            # trailing slash ("") consider rules that aren't
            # strict-slashes as these should match if there is a final
            # slash part.
            if parts == [""]:
                for rule in state.rules:
                    if rule.strict_slashes:
                        continue
                    if rule.methods is not None and method not in rule.methods:
                        have_match_for.update(rule.methods)
                    elif rule.websocket != websocket:
                        websocket_mismatch = True
                    else:
                        return rule, values

            return None

        try:
            rv = _match(self._root, [domain, *path.split("/")], [])
        except SlashRequired:
            raise RequestPath(f"{path}/") from None

        if rv is None and self.merge_slashes:
            # Retry with every run of two or more slashes collapsed to one.
            # The quantifier must be greedy: the lazy "/{2,}?" only ever
            # consumes two slashes, so "/a///b" became "/a//b" and failed.
            merged = re.sub("/{2,}", "/", path)
            # If nothing was merged the retry would repeat the failed match.
            if merged != path:
                try:
                    rv = _match(self._root, [domain, *merged.split("/")], [])
                except SlashRequired:
                    raise RequestPath(f"{merged}/") from None
                if rv is not None:
                    # Only the merged form matches: redirect to it.
                    raise RequestPath(merged)

        if rv is None:
            raise NoMatch(have_match_for, websocket_mismatch)

        rule, values = rv

        result: t.Dict[str, t.Any] = {}
        # Values were collected in converter order along the matched path.
        for (name, converter), value in zip(rule._converters.items(), values):
            try:
                value = converter.to_python(value)
            except ValidationError:
                raise NoMatch(have_match_for, websocket_mismatch) from None
            result[str(name)] = value
        if rule.defaults:
            result.update(rule.defaults)

        if rule.alias and rule.map.redirect_defaults:
            raise RequestAliasRedirect(result, rule.endpoint)

        return rule, result