1"""Process URI templates per http://tools.ietf.org/html/rfc6570."""
2
3from __future__ import annotations
4
5import re
6from typing import TYPE_CHECKING
7
8from .expansions import (CommaExpansion, Expansion,
9 FormStyleQueryContinuation, FormStyleQueryExpansion,
10 FragmentExpansion, LabelExpansion, Literal,
11 PathExpansion, PathStyleExpansion,
12 ReservedCommaExpansion, ReservedExpansion, SimpleExpansion)
13
14if (TYPE_CHECKING):
15 from collections.abc import Iterable
16 from .variable import Variable
17
18
class ExpansionReservedError(Exception):
    """Exception thrown for reserved but unsupported expansions."""

    # The offending expression text; URITemplate passes the whole "{...}" part.
    expansion: str

    def __init__(self, expansion: str) -> None:
        # Forward the expression to Exception so self.args is always populated,
        # even when constructed by keyword. Pickling/copying re-create exceptions
        # as cls(*self.args), which fails with a TypeError if args is empty.
        super().__init__(expansion)
        self.expansion = expansion

    def __str__(self) -> str:
        """Convert to string."""
        return 'Unsupported expansion: ' + self.expansion
30
31
class ExpansionInvalidError(Exception):
    """Exception thrown for unknown expansions."""

    # The offending text; URITemplate passes the "{...}" part or the literal
    # segment containing a stray brace.
    expansion: str

    def __init__(self, expansion: str) -> None:
        # Forward the text to Exception so self.args is always populated, even
        # when constructed by keyword. Pickling/copying re-create exceptions as
        # cls(*self.args), which fails with a TypeError if args is empty.
        super().__init__(expansion)
        self.expansion = expansion

    def __str__(self) -> str:
        """Convert to string."""
        return 'Bad expansion: ' + self.expansion
43
44
class URITemplate:
    """
    URI Template object.

    Parses a template string into an ordered list of literal and expansion parts,
    which can then be fully expanded (`expand`) or partially expanded (`partial`).

    Constructor may raise ExpansionReservedError, ExpansionInvalidError, or VariableInvalidError.
    """

    expansions: list[Expansion]

    # An expression with no operator character starts with a varname character
    # (letter, digit, "_" or a %XX escape); such expressions are simple expansions.
    _SIMPLE_EXPRESSION = re.compile(r'^([a-zA-Z0-9_]|%[0-9a-fA-F][0-9a-fA-F]).*$')

    def __init__(self, template: str) -> None:
        """
        Parse `template` into `self.expansions`.

        Raises ExpansionReservedError for the operators '=', '!', '@' and '|',
        and ExpansionInvalidError for unknown operators, an empty "{}" expression,
        or a brace that is not part of a complete "{...}" expression.
        """
        self.expansions = []
        # The capturing group makes re.split keep every "{...}" expression in
        # its result, interleaved with the literal text around it.
        for part in re.split(r'(\{[^\}]*\})', template):
            if not part:
                # re.split yields empty strings at the edges and between
                # adjacent expressions; they contribute nothing.
                continue
            if part.startswith('{') and part.endswith('}'):
                self.expansions.append(self._parse_expression(part))
            elif ('{' not in part) and ('}' not in part):
                self.expansions.append(Literal(part))
            else:
                # A stray brace outside a complete "{...}" expression.
                raise ExpansionInvalidError(part)

    def _parse_expression(self, part: str) -> Expansion:
        """Build the Expansion for one "{...}" part, dispatching on its operator character."""
        expression = part[1:-1]
        if self._SIMPLE_EXPRESSION.match(expression):
            return SimpleExpansion(expression)
        operator = part[1]
        if operator == ',':
            # part is at least "{,}" here, so part[2] always exists;
            # ",+" selects the reserved variant of the comma expansion.
            if part[2] == '+':
                return ReservedCommaExpansion(expression)
            return CommaExpansion(expression)
        # Built per call (not at class level) so the expansion classes are only
        # looked up when a template actually contains an operator expression.
        operator_classes = {
            '+': ReservedExpansion,
            '#': FragmentExpansion,
            '.': LabelExpansion,
            '/': PathExpansion,
            ';': PathStyleExpansion,
            '?': FormStyleQueryExpansion,
            '&': FormStyleQueryContinuation,
        }
        expansion_class = operator_classes.get(operator)
        if expansion_class is not None:
            return expansion_class(expression)
        if operator in '=!@|':
            raise ExpansionReservedError(part)
        # Anything else, including the '}' of an empty "{}" expression.
        raise ExpansionInvalidError(part)

    @property
    def variables(self) -> Iterable[Variable]:
        """
        Get all variables in template.

        Variables are de-duplicated by name: each name appears once, at the
        position of its first occurrence, holding the Variable from its last
        occurrence.
        """
        by_name: dict[str, Variable] = {}
        for expansion in self.expansions:
            for var in expansion.variables:
                by_name[var.name] = var
        return by_name.values()

    @property
    def variable_names(self) -> Iterable[str]:
        """Get names of all variables in template, de-duplicated, in first-occurrence order."""
        return [var.name for var in self.variables]

    def expand(self, **kwargs) -> str:
        """
        Expand the template.

        Each expansion receives the keyword arguments as a dict; expansions that
        return None contribute nothing to the result.

        May raise ExpansionFailed if a composite value is passed to a variable with a prefix modifier.
        """
        expanded = (expansion.expand(kwargs) for expansion in self.expansions)
        return ''.join(text for text in expanded if text is not None)

    def partial(self, **kwargs) -> URITemplate:
        """
        Expand the template, preserving expansions for missing variables.

        Returns a new URITemplate parsed from the concatenated partial results.
        NOTE(review): assumes every expansion's partial() returns a str (a None
        would make the join fail) — confirm against the expansion classes.

        May raise ExpansionFailed if a composite value is passed to a variable with a prefix modifier.
        """
        expanded = [expansion.partial(kwargs) for expansion in self.expansions]
        return URITemplate(''.join(expanded))

    @property
    def expanded(self) -> bool:
        """Determine if template is fully expanded (expanding with no variables changes nothing)."""
        return (str(self) == self.expand())

    def __str__(self) -> str:
        """Convert to string, returns original template."""
        return ''.join([str(expansion) for expansion in self.expansions])

    def __repr__(self) -> str:
        """Convert to string, returns original template."""
        return str(self)