1"""
2
3uritemplate.variable
4====================
5
6This module contains the URIVariable class which powers the URITemplate class.
7
8What treasures await you:
9
10- URIVariable class
11
12You see a hammer in front of you.
13What do you do?
14>
15
16"""
17import collections.abc
18import typing as t
19import urllib.parse
20
#: A single value that can be substituted for a template variable.
ScalarVariableValue = t.Union[int, float, complex, str]
#: Every value shape a variable may take: a sequence of scalars, a mapping
#: of names to scalars, a single ``(name, scalar)`` pair, or a bare scalar.
VariableValue = t.Union[
    t.Sequence[ScalarVariableValue],
    t.Mapping[str, ScalarVariableValue],
    t.Tuple[str, ScalarVariableValue],
    ScalarVariableValue,
]
#: The variables passed to ``URIVariable.expand``: name -> value.
VariableValueDict = t.Dict[str, VariableValue]
29
30
class URIVariable:
    """Parse and expand a single template expression for ``URITemplate``.

    Each instance is built from the text of one expression (for example
    ``?var,hello`` or ``/path*``). It records the operator, the variables
    together with their explode (``*``) and prefix (``:N``) modifiers, and
    any ``name=default`` values, then expands them on request, truncating
    values as the prefix modifiers require.

    Please note that just like the :class:`URITemplate <URITemplate>`, this
    object's ``__str__`` and ``__repr__`` methods do not return the same
    information. Calling ``str(var)`` will return the original variable.

    This object does the majority of the heavy lifting. The ``URITemplate``
    object finds the variables in the URI and then creates ``URIVariable``
    objects. Expansions of the URI are handled by each ``URIVariable``
    object. ``URIVariable.expand()`` returns a dictionary of the original
    variable and the expanded value. Check that method's documentation for
    more information.
    """

    #: Characters that may start an expression and select its expansion.
    operators = ("+", "#", ".", "/", ";", "?", "&", "|", "!", "@")
    #: Characters left unescaped by the "+" and "#" expansions.
    reserved = ":/?#[]@!$&'()*+,;="

    def __init__(self, var: str):
        #: The original string that comes through with the variable
        self.original: str = var
        #: The operator for the variable
        self.operator: str = ""
        #: List of safe characters when quoting the string
        self.safe: str = ""
        #: List of variables in this variable
        self.variables: t.List[
            t.Tuple[str, t.MutableMapping[str, t.Any]]
        ] = []
        #: List of variable names
        self.variable_names: t.List[str] = []
        #: List of defaults passed in
        self.defaults: t.MutableMapping[str, ScalarVariableValue] = {}
        # Parse the variable itself.
        self.parse()
        self.post_parse()

    def __repr__(self) -> str:
        return "URIVariable(%s)" % self

    def __str__(self) -> str:
        return self.original

    def parse(self) -> None:
        """Parse the variable.

        This finds the:
        - operator,
        - set of safe characters,
        - variables, and
        - defaults.

        """
        var_list_str = self.original
        # A leading operator character selects the expansion style and is
        # not part of any variable name.
        if self.original[0] in URIVariable.operators:
            self.operator = self.original[0]
            var_list_str = self.original[1:]

        # "+" and "#" keep reserved characters unescaped; post_parse()
        # recomputes ``safe`` as well.
        if self.operator in URIVariable.operators[:2]:
            self.safe = URIVariable.reserved

        var_list = var_list_str.split(",")

        for var in var_list:
            default_val = None
            name = var
            # "name=default" supplies a fallback for undefined values.
            if "=" in var:
                name, default_val = tuple(var.split("=", 1))

            # A trailing "*" is the explode modifier.
            explode = False
            if name.endswith("*"):
                explode = True
                name = name[:-1]

            # "name:N" is the prefix modifier: keep at most N characters.
            prefix: t.Optional[int] = None
            if ":" in name:
                name, prefix_str = tuple(name.split(":", 1))
                prefix = int(prefix_str)

            if default_val:
                self.defaults[name] = default_val

            self.variables.append(
                (name, {"explode": explode, "prefix": prefix})
            )

        self.variable_names = [varname for (varname, _) in self.variables]

    def post_parse(self) -> None:
        """Set ``start``, ``join_str`` and ``safe`` attributes.

        After parsing the variable, we need to set up these attributes and it
        only makes sense to do it in a more easily testable way.
        """
        self.safe = ""
        self.start = self.join_str = self.operator
        if self.operator == "+":
            self.start = ""
        if self.operator in ("+", "#", ""):
            self.join_str = ","
        if self.operator == "#":
            self.start = "#"
        if self.operator == "?":
            self.start = "?"
            self.join_str = "&"

        if self.operator in ("+", "#"):
            self.safe = URIVariable.reserved

    @staticmethod
    def _is_undefined(value: t.Any) -> bool:
        """Return True if ``value`` counts as undefined (RFC 6570, 2.3).

        ``None``, empty containers, and mappings whose values are all
        ``None`` are undefined. Strings (even ``""``) and numbers (even
        ``0``) are always defined.
        """
        if value is None:
            return True
        if isinstance(value, collections.abc.Mapping):
            return all(member is None for member in value.values())
        if isinstance(value, collections.abc.Sized) and not isinstance(
            value, (str, bytes)
        ):
            return len(value) == 0
        return False

    @staticmethod
    def _scalar_text(
        value: t.Any, prefix: t.Optional[int]
    ) -> t.Union[str, bytes]:
        """Return a scalar ``value`` as text, cut to ``prefix`` characters.

        Numbers are converted with ``str()`` first so that prefix modifiers
        such as ``{var:3}`` work on them and ``0`` is not treated as empty.
        """
        if not isinstance(value, (str, bytes)):
            value = str(value)
        return value[:prefix] if prefix else value

    @staticmethod
    def _defined_pairs(
        value: t.Any,
        items: t.Optional[t.Sequence[t.Tuple[str, ScalarVariableValue]]],
    ) -> t.List[t.Tuple[str, ScalarVariableValue]]:
        """Return the (key, value) pairs to expand, skipping ``None`` values.

        ``items`` is the list-of-pairs form reported by ``is_list_of_tuples``;
        when it is ``None``, ``value`` is a mapping and its items are used in
        sorted order.
        """
        pairs = items if items is not None else sorted(value.items())
        return [(key, member) for key, member in pairs if member is not None]

    def _query_expansion(
        self,
        name: str,
        value: VariableValue,
        explode: bool,
        prefix: t.Optional[int],
    ) -> t.Optional[str]:
        """Expansion method for the '?' and '&' operators.

        Lists become ``name=a,b`` (``name=a&name=b`` when exploded), mappings
        become ``name=k,v`` (``k=v&...`` when exploded), scalars become
        ``name=value``. ``None`` members are skipped; returns ``None`` when
        nothing is left to emit.
        """
        if self._is_undefined(value):
            return None

        tuples, items = is_list_of_tuples(value)
        safe = self.safe

        if list_test(value) and not tuples:
            value = t.cast(t.Sequence[ScalarVariableValue], value)
            members = [quote(v, safe) for v in value if v is not None]
            if not members:
                return None
            if explode:
                return self.join_str.join(f"{name}={m}" for m in members)
            joined = ",".join(members)
            return f"{name}={joined}"

        if dict_test(value) or tuples:
            pairs = self._defined_pairs(value, items)
            if not pairs:
                return None
            if explode:
                return self.join_str.join(
                    f"{quote(k, safe)}={quote(v, safe)}" for k, v in pairs
                )
            joined = ",".join(
                f"{quote(k, safe)},{quote(v, safe)}" for k, v in pairs
            )
            return f"{name}={joined}"

        # An empty string still yields "name=" (quote("") is "").
        text = self._scalar_text(value, prefix)
        return f"{name}={quote(text, safe)}"

    def _label_path_expansion(
        self,
        name: str,
        value: VariableValue,
        explode: bool,
        prefix: t.Optional[int],
    ) -> t.Optional[str]:
        """Label and path expansion method.

        Expands for operators: '/', '.'

        Exploded lists/mappings are joined with the operator itself,
        otherwise with ",". ``None`` members are skipped; returns ``None``
        when nothing is left to emit.
        """
        if self._is_undefined(value):
            return None

        join_str = self.join_str
        safe = self.safe
        tuples, items = is_list_of_tuples(value)

        if list_test(value) and not tuples:
            value = t.cast(t.Sequence[ScalarVariableValue], value)
            fragments = [quote(v, safe) for v in value if v is not None]
            if not fragments:
                return None
            return (join_str if explode else ",").join(fragments)

        if dict_test(value) or tuples:
            pairs = self._defined_pairs(value, items)
            if not pairs:
                return None
            if explode:
                return join_str.join(
                    f"{quote(k, safe)}={quote(v, safe)}" for k, v in pairs
                )
            return ",".join(
                f"{quote(k, safe)},{quote(v, safe)}" for k, v in pairs
            )

        return quote(self._scalar_text(value, prefix), safe)

    def _semi_path_expansion(
        self,
        name: str,
        value: VariableValue,
        explode: bool,
        prefix: t.Optional[int],
    ) -> t.Optional[str]:
        """Expansion method for ';' operator.

        Like the query expansion, except that an empty scalar yields just
        ``name`` (no trailing "="). ``None`` members are skipped; returns
        ``None`` when nothing is left to emit.
        """
        if self._is_undefined(value):
            return None

        join_str = self.join_str
        safe = self.safe
        tuples, items = is_list_of_tuples(value)

        if list_test(value) and not tuples:
            value = t.cast(t.Sequence[ScalarVariableValue], value)
            members = [quote(v, safe) for v in value if v is not None]
            if not members:
                return None
            if explode:
                return join_str.join(f"{name}={m}" for m in members)
            joined = ",".join(members)
            return f"{name}={joined}"

        if dict_test(value) or tuples:
            pairs = self._defined_pairs(value, items)
            if not pairs:
                return None
            if explode:
                return join_str.join(
                    f"{quote(k, safe)}={quote(v, safe)}" for k, v in pairs
                )
            joined = ",".join(
                f"{quote(k, safe)},{quote(v, safe)}" for k, v in pairs
            )
            return f"{name}={joined}"

        text = self._scalar_text(value, prefix)
        if text:
            return f"{name}={quote(text, safe)}"
        return name

    def _string_expansion(
        self,
        name: str,
        value: VariableValue,
        explode: bool,
        prefix: t.Optional[int],
    ) -> t.Optional[str]:
        """Expansion method for '', '+', '#' and the remaining operators.

        Members are always joined with ","; exploded mappings render as
        ``k=v``, non-exploded ones as ``k,v``. ``None`` members are skipped;
        returns ``None`` when nothing is left to emit.
        """
        if self._is_undefined(value):
            return None

        safe = self.safe
        tuples, items = is_list_of_tuples(value)

        if list_test(value) and not tuples:
            value = t.cast(t.Sequence[ScalarVariableValue], value)
            members = [quote(v, safe) for v in value if v is not None]
            return ",".join(members) if members else None

        if dict_test(value) or tuples:
            pairs = self._defined_pairs(value, items)
            if not pairs:
                return None
            format_str = "%s=%s" if explode else "%s,%s"
            return ",".join(
                format_str % (quote(k, safe), quote(v, safe)) for k, v in pairs
            )

        return quote(self._scalar_text(value, prefix), safe)

    def expand(
        self, var_dict: t.Optional[VariableValueDict] = None
    ) -> t.Mapping[str, str]:
        """Expand the variable in question.

        Using ``var_dict`` and the previously parsed defaults, expand this
        variable and subvariables. A value of ``None``, an empty container,
        or a mapping whose values are all ``None`` is undefined: the parsed
        default is used if there is one, otherwise the variable is left out.
        Empty strings and zero are defined values.

        :param dict var_dict: dictionary of key-value pairs to be used during
            expansion
        :returns: dict(variable=value)

        Examples::

            # (1)
            v = URIVariable('/var')
            expansion = v.expand({'var': 'value'})
            print(expansion)
            # => {'/var': '/value'}

            # (2)
            v = URIVariable('?var,hello,x,y')
            expansion = v.expand({'var': 'value', 'hello': 'Hello World!',
                                  'x': '1024', 'y': '768'})
            print(expansion)
            # => {'?var,hello,x,y':
            #     '?var=value&hello=Hello%20World%21&x=1024&y=768'}

        """
        if var_dict is None:
            return {self.original: self.original}

        # The operator is fixed for this expression, so choose the expansion
        # strategy once instead of per variable.
        if self.operator in ("/", "."):
            expansion = self._label_path_expansion
        elif self.operator in ("?", "&"):
            expansion = self._query_expansion
        elif self.operator == ";":
            expansion = self._semi_path_expansion
        else:
            expansion = self._string_expansion

        return_values = []
        for name, opts in self.variables:
            value = var_dict.get(name, None)
            if self._is_undefined(value) and name in self.defaults:
                value = self.defaults[name]

            if self._is_undefined(value):
                continue

            expanded = expansion(name, value, opts["explode"], opts["prefix"])
            if expanded is not None:
                return_values.append(expanded)

        result = ""
        if return_values:
            result = self.start + self.join_str.join(return_values)
        return {self.original: result}
387
388
def is_list_of_tuples(
    value: t.Any,
) -> t.Tuple[bool, t.Optional[t.Sequence[t.Tuple[str, ScalarVariableValue]]]]:
    """Detect a non-empty list/tuple made up solely of 2-tuples.

    Such a value is expanded like an ordered mapping of ``(key, value)``
    pairs.

    :returns: ``(True, value)`` when every member is a 2-tuple, otherwise
        ``(False, None)``.
    """
    if value and isinstance(value, (list, tuple)):
        if all(isinstance(pair, tuple) and len(pair) == 2 for pair in value):
            return True, value
    return False, None
400
401
def list_test(value: t.Any) -> bool:
    """Return True when ``value`` is a list or tuple (a list-style value)."""
    return isinstance(value, list) or isinstance(value, tuple)
404
405
def dict_test(value: t.Any) -> bool:
    """Return True if ``value`` should be expanded as an associative array.

    Any :class:`collections.abc.Mapping` qualifies, including read-only
    mappings such as :class:`types.MappingProxyType`, matching the
    ``t.Mapping`` member of ``VariableValue``. ``dict`` and every mutable
    mapping are Mappings too, so all values accepted before still are.
    """
    return isinstance(value, collections.abc.Mapping)
408
409
410def _encode(value: t.AnyStr, encoding: str = "utf-8") -> bytes:
411 if isinstance(value, str):
412 return value.encode(encoding)
413 return value
414
415
def quote(value: t.Any, safe: str) -> str:
    """Percent-encode ``value`` for use in a URI.

    Values that are neither ``str`` nor ``bytes`` are converted with
    ``str()`` first; text is UTF-8 encoded before quoting. Characters in
    ``safe`` are left unescaped, in addition to those ``urllib.parse.quote``
    never escapes.
    """
    text = value if isinstance(value, (str, bytes)) else str(value)
    return urllib.parse.quote(_encode(text), safe)