1from __future__ import annotations 
    2 
    3import re 
    4import typing as t 
    5from dataclasses import dataclass 
    6from dataclasses import field 
    7 
    8from .converters import ValidationError 
    9from .exceptions import NoMatch 
    10from .exceptions import RequestAliasRedirect 
    11from .exceptions import RequestPath 
    12from .rules import Rule 
    13from .rules import RulePart 
    14 
    15 
    16class SlashRequired(Exception): 
    17    pass 
    18 
    19 
    20@dataclass 
    21class State: 
    22    """A representation of a rule state. 
    23 
    24    This includes the *rules* that correspond to the state and the 
    25    possible *static* and *dynamic* transitions to the next state. 
    26    """ 
    27 
    28    dynamic: list[tuple[RulePart, State]] = field(default_factory=list) 
    29    rules: list[Rule] = field(default_factory=list) 
    30    static: dict[str, State] = field(default_factory=dict) 
    31 
    32 
    33class StateMachineMatcher: 
    34    def __init__(self, merge_slashes: bool) -> None: 
    35        self._root = State() 
    36        self.merge_slashes = merge_slashes 
    37 
    38    def add(self, rule: Rule) -> None: 
    39        state = self._root 
    40        for part in rule._parts: 
    41            if part.static: 
    42                state.static.setdefault(part.content, State()) 
    43                state = state.static[part.content] 
    44            else: 
    45                for test_part, new_state in state.dynamic: 
    46                    if test_part == part: 
    47                        state = new_state 
    48                        break 
    49                else: 
    50                    new_state = State() 
    51                    state.dynamic.append((part, new_state)) 
    52                    state = new_state 
    53        state.rules.append(rule) 
    54 
    55    def update(self) -> None: 
    56        # For every state the dynamic transitions should be sorted by 
    57        # the weight of the transition 
    58        state = self._root 
    59 
    60        def _update_state(state: State) -> None: 
    61            state.dynamic.sort(key=lambda entry: entry[0].weight) 
    62            for new_state in state.static.values(): 
    63                _update_state(new_state) 
    64            for _, new_state in state.dynamic: 
    65                _update_state(new_state) 
    66 
    67        _update_state(state) 
    68 
    69    def match( 
    70        self, domain: str, path: str, method: str, websocket: bool 
    71    ) -> tuple[Rule, t.MutableMapping[str, t.Any]]: 
    72        # To match to a rule we need to start at the root state and 
    73        # try to follow the transitions until we find a match, or find 
    74        # there is no transition to follow. 
    75 
    76        have_match_for = set() 
    77        websocket_mismatch = False 
    78 
    79        def _match( 
    80            state: State, parts: list[str], values: list[str] 
    81        ) -> tuple[Rule, list[str]] | None: 
    82            # This function is meant to be called recursively, and will attempt 
    83            # to match the head part to the state's transitions. 
    84            nonlocal have_match_for, websocket_mismatch 
    85 
    86            # The base case is when all parts have been matched via 
    87            # transitions. Hence if there is a rule with methods & 
    88            # websocket that work return it and the dynamic values 
    89            # extracted. 
    90            if parts == []: 
    91                for rule in state.rules: 
    92                    if rule.methods is not None and method not in rule.methods: 
    93                        have_match_for.update(rule.methods) 
    94                    elif rule.websocket != websocket: 
    95                        websocket_mismatch = True 
    96                    else: 
    97                        return rule, values 
    98 
    99                # Test if there is a match with this path with a 
    100                # trailing slash, if so raise an exception to report 
    101                # that matching is possible with an additional slash 
    102                if "" in state.static: 
    103                    for rule in state.static[""].rules: 
    104                        if websocket == rule.websocket and ( 
    105                            rule.methods is None or method in rule.methods 
    106                        ): 
    107                            if rule.strict_slashes: 
    108                                raise SlashRequired() 
    109                            else: 
    110                                return rule, values 
    111                return None 
    112 
    113            part = parts[0] 
    114            # To match this part try the static transitions first 
    115            if part in state.static: 
    116                rv = _match(state.static[part], parts[1:], values) 
    117                if rv is not None: 
    118                    return rv 
    119            # No match via the static transitions, so try the dynamic 
    120            # ones. 
    121            for test_part, new_state in state.dynamic: 
    122                target = part 
    123                remaining = parts[1:] 
    124                # A final part indicates a transition that always 
    125                # consumes the remaining parts i.e. transitions to a 
    126                # final state. 
    127                if test_part.final: 
    128                    target = "/".join(parts) 
    129                    remaining = [] 
    130                match = re.compile(test_part.content).match(target) 
    131                if match is not None: 
    132                    if test_part.suffixed: 
    133                        # If a part_isolating=False part has a slash suffix, remove the 
    134                        # suffix from the match and check for the slash redirect next. 
    135                        suffix = match.groups()[-1] 
    136                        if suffix == "/": 
    137                            remaining = [""] 
    138 
    139                    converter_groups = sorted( 
    140                        match.groupdict().items(), key=lambda entry: entry[0] 
    141                    ) 
    142                    groups = [ 
    143                        value 
    144                        for key, value in converter_groups 
    145                        if key[:11] == "__werkzeug_" 
    146                    ] 
    147                    rv = _match(new_state, remaining, values + groups) 
    148                    if rv is not None: 
    149                        return rv 
    150 
    151            # If there is no match and the only part left is a 
    152            # trailing slash ("") consider rules that aren't 
    153            # strict-slashes as these should match if there is a final 
    154            # slash part. 
    155            if parts == [""]: 
    156                for rule in state.rules: 
    157                    if rule.strict_slashes: 
    158                        continue 
    159                    if rule.methods is not None and method not in rule.methods: 
    160                        have_match_for.update(rule.methods) 
    161                    elif rule.websocket != websocket: 
    162                        websocket_mismatch = True 
    163                    else: 
    164                        return rule, values 
    165 
    166            return None 
    167 
    168        try: 
    169            rv = _match(self._root, [domain, *path.split("/")], []) 
    170        except SlashRequired: 
    171            raise RequestPath(f"{path}/") from None 
    172 
    173        if self.merge_slashes and rv is None: 
    174            # Try to match again, but with slashes merged 
    175            path = re.sub("/{2,}?", "/", path) 
    176            try: 
    177                rv = _match(self._root, [domain, *path.split("/")], []) 
    178            except SlashRequired: 
    179                raise RequestPath(f"{path}/") from None 
    180            if rv is None or rv[0].merge_slashes is False: 
    181                raise NoMatch(have_match_for, websocket_mismatch) 
    182            else: 
    183                raise RequestPath(f"{path}") 
    184        elif rv is not None: 
    185            rule, values = rv 
    186 
    187            result = {} 
    188            for name, value in zip(rule._converters.keys(), values): 
    189                try: 
    190                    value = rule._converters[name].to_python(value) 
    191                except ValidationError: 
    192                    raise NoMatch(have_match_for, websocket_mismatch) from None 
    193                result[str(name)] = value 
    194            if rule.defaults: 
    195                result.update(rule.defaults) 
    196 
    197            if rule.alias and rule.map.redirect_defaults: 
    198                raise RequestAliasRedirect(result, rule.endpoint) 
    199 
    200            return rule, result 
    201 
    202        raise NoMatch(have_match_for, websocket_mismatch)