1# coding=utf-8
2# --------------------------------------------------------------------------
3# Copyright (c) Microsoft Corporation. All rights reserved.
4# Licensed under the MIT License. See License.txt in the project root for
5# license information.
6# --------------------------------------------------------------------------
7import base64
8from functools import partial
9from json import JSONEncoder
10from typing import Dict, List, Optional, Union, cast, Any, Type, Callable, Tuple
11from datetime import datetime, date, time, timedelta
12from datetime import timezone
13
14
# Names exported by "from <module> import *"; this is the module's public API.
__all__ = [
    "NULL",
    "AzureJSONEncoder",
    "is_generated_model",
    "as_attribute_dict",
    "attribute_list",
    "TypeHandlerRegistry",
]
# UTC tzinfo used by _datetime_as_isostr to normalize datetimes before ISO 8601 formatting.
TZ_UTC = timezone.utc
24
25
class _Null:
    """Type of the ``NULL`` sentinel; every instance evaluates as falsy."""

    def __bool__(self) -> bool:
        """Report the sentinel as falsy so truthiness checks treat it like "no value".

        :return: Always False.
        :rtype: bool
        """
        is_truthy = False
        return is_truthy


NULL = _Null()
"""
Falsy sentinel used to mark an attribute that explicitly carries no data.
It is serialized to `null` on the wire.
"""
38
39
class TypeHandlerRegistry:
    """A registry for custom serializers and deserializers for specific types or conditions.

    A handler is registered either for an exact type or for a predicate. Lookups consult the
    exact-type table first and then the predicates in registration order. The outcome, including
    "no handler", is memoized per type until another handler of the same kind is registered.
    """

    def __init__(self) -> None:
        # Handlers registered for an exact type.
        self._serializer_types: Dict[Type, Callable] = {}
        self._deserializer_types: Dict[Type, Callable] = {}
        # (predicate, handler) pairs, consulted in registration order.
        self._serializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = []
        self._deserializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = []

        # Memoized lookups keyed by type; a stored None means "no handler found".
        self._serializer_cache: Dict[Type, Optional[Callable]] = {}
        self._deserializer_cache: Dict[Type, Optional[Callable]] = {}

    @staticmethod
    def _add_handler(
        condition: Union[Type, Callable[[Any], bool]],
        handler_func: Callable,
        type_table: Dict[Type, Callable],
        predicate_list: List[Tuple[Callable[[Any], bool], Callable]],
        cache: Dict[Type, Optional[Callable]],
    ) -> None:
        """Store a handler under its condition and invalidate the matching lookup cache.

        :param condition: A type (matched exactly) or a predicate.
        :type condition: Union[Type, Callable[[Any], bool]]
        :param handler_func: The handler being registered.
        :type handler_func: Callable
        :param type_table: Exact-type table updated when ``condition`` is a type.
        :type type_table: Dict[Type, Callable]
        :param predicate_list: Predicate list appended to when ``condition`` is a callable.
        :type predicate_list: List[Tuple[Callable[[Any], bool], Callable]]
        :param cache: Lookup cache cleared after a successful registration.
        :type cache: Dict[Type, Optional[Callable]]
        :raises TypeError: If the condition is neither a type nor a callable.
        """
        # Types are callable too, so the type check has to come first.
        if isinstance(condition, type):
            type_table[condition] = handler_func
        elif callable(condition):
            predicate_list.append((condition, handler_func))
        else:
            raise TypeError("Condition must be a type or a callable predicate function.")
        cache.clear()

    @staticmethod
    def _find_handler(
        key: Type,
        probe: Any,
        type_table: Dict[Type, Callable],
        predicate_list: List[Tuple[Callable[[Any], bool], Callable]],
    ) -> Optional[Callable]:
        """Resolve a handler by exact type, else by the first predicate that accepts ``probe``.

        :param key: The type looked up in ``type_table``.
        :type key: type
        :param probe: The value handed to each predicate.
        :type probe: any
        :param type_table: Exact-type handler table.
        :type type_table: Dict[Type, Callable]
        :param predicate_list: Ordered (predicate, handler) pairs.
        :type predicate_list: List[Tuple[Callable[[Any], bool], Callable]]
        :return: The matching handler, or the (falsy) exact-table result when nothing matches.
        :rtype: Optional[Callable]
        """
        exact = type_table.get(key)
        if exact:
            return exact
        for predicate, candidate in predicate_list:
            if predicate(probe):
                return candidate
        # Nothing matched: fall back to the exact-table result (normally None).
        return exact

    def register_serializer(
        self, condition: Union[Type, Callable[[Any], bool]]
    ) -> Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]:
        """Return a decorator that registers a serializer for ``condition``.

        The handler receives the object to serialize as its only argument and returns a
        dictionary representation of it.

        Examples:

        .. code-block:: python

            @registry.register_serializer(CustomModel)
            def serialize_single_type(value: CustomModel) -> dict:
                return value.to_dict()

            @registry.register_serializer(lambda x: isinstance(x, BaseModel))
            def serialize_with_condition(value: BaseModel) -> dict:
                return value.to_dict()

            # Called manually for a specific type
            def custom_serializer(value: CustomModel) -> Dict[str, Any]:
                return {"custom": value.custom}

            registry.register_serializer(CustomModel)(custom_serializer)

        :param condition: A type (matched exactly) or a predicate that takes an object and returns a bool.
        :type condition: Union[Type, Callable[[Any], bool]]
        :return: A decorator that registers the handler function and returns it unchanged.
        :rtype: Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]
        :raises TypeError: If the condition is neither a type nor a callable.
        """

        def decorator(handler_func: Callable[[Any], Dict[str, Any]]) -> Callable[[Any], Dict[str, Any]]:
            self._add_handler(
                condition,
                handler_func,
                self._serializer_types,
                self._serializer_predicates,
                self._serializer_cache,
            )
            return handler_func

        return decorator

    def register_deserializer(
        self, condition: Union[Type, Callable[[Any], bool]]
    ) -> Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]:
        """Return a decorator that registers a deserializer for ``condition``.

        The handler receives the target type and the data dictionary and returns an instance
        of the target type.

        Examples:

        .. code-block:: python

            @registry.register_deserializer(CustomModel)
            def deserialize_single_type(cls: Type[CustomModel], data: dict) -> CustomModel:
                return cls(**data)

            @registry.register_deserializer(lambda t: issubclass(t, BaseModel))
            def deserialize_with_condition(cls: Type[BaseModel], data: dict) -> BaseModel:
                return cls(**data)

            # Called manually for a specific type
            def custom_deserializer(cls: Type[CustomModel], data: Dict[str, Any]) -> CustomModel:
                return cls(custom=data["custom"])

            registry.register_deserializer(CustomModel)(custom_deserializer)

        :param condition: A type (matched exactly) or a predicate that takes a class and returns a bool.
        :type condition: Union[Type, Callable[[Any], bool]]
        :return: A decorator that registers the handler function and returns it unchanged.
        :rtype: Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]
        :raises TypeError: If the condition is neither a type nor a callable.
        """

        def decorator(handler_func: Callable[[Type, Dict[str, Any]], Any]) -> Callable[[Type, Dict[str, Any]], Any]:
            self._add_handler(
                condition,
                handler_func,
                self._deserializer_types,
                self._deserializer_predicates,
                self._deserializer_cache,
            )
            return handler_func

        return decorator

    def get_serializer(self, obj: Any) -> Optional[Callable]:
        """Find the serializer registered for ``obj``.

        The exact-type table is checked for ``type(obj)`` first, then the predicates in
        registration order. Predicates are called with ``obj`` itself, but the result is
        memoized by ``type(obj)``. A predicate should therefore answer the same way for every
        instance of a type.

        :param obj: The object to serialize.
        :type obj: any
        :return: The serializer function if found, otherwise None.
        :rtype: Optional[Callable]
        """
        obj_type = type(obj)
        try:
            return self._serializer_cache[obj_type]
        except KeyError:
            pass

        handler = self._find_handler(obj_type, obj, self._serializer_types, self._serializer_predicates)
        self._serializer_cache[obj_type] = handler
        return handler

    def get_deserializer(self, cls: Type) -> Optional[Callable]:
        """Find the deserializer registered for ``cls``.

        The exact-type table is checked first, then the predicates in registration order
        (each called with ``cls``). The result is memoized per class.

        :param cls: The class to deserialize.
        :type cls: type
        :return: The deserializer with ``cls`` bound as its first argument, otherwise None.
        :rtype: Optional[Callable]
        """
        if cls not in self._deserializer_cache:
            handler = self._find_handler(cls, cls, self._deserializer_types, self._deserializer_predicates)
            # Bind the target class so callers only need to supply the data dictionary.
            self._deserializer_cache[cls] = partial(handler, cls) if handler else None
        return self._deserializer_cache[cls]
196
197
def _timedelta_as_isostr(td: timedelta) -> str:
    """Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S'

    Negative durations are rendered as their magnitude with a leading minus sign
    (e.g. ``timedelta(seconds=-1)`` -> ``'-PT01S'``), because ISO 8601 duration components
    are unsigned.

    Function adapted from the Tin Can Python project: https://github.com/RusticiSoftware/TinCanPython

    :param td: The timedelta object to convert
    :type td: datetime.timedelta
    :return: An ISO 8601 formatted string representing the timedelta object
    :rtype: str
    :raises AttributeError: If ``td`` has no ``total_seconds()`` method.
    """
    total_seconds = td.total_seconds()
    # Format the magnitude and prepend the sign. Splitting a negative total with divmod would
    # floor each unit and produce mixed-sign output such as "P-1DT23H59M59S".
    sign = "-" if total_seconds < 0 else ""

    # Split seconds to larger units
    minutes, seconds = divmod(abs(total_seconds), 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)

    days, hours, minutes = list(map(int, (days, hours, minutes)))
    seconds = round(seconds, 6)

    # Build date
    date_str = ""
    if days:
        date_str = "%sD" % days

    # Build time
    time_str = "T"

    # Hours (emitted once any larger unit is present, to keep positions unambiguous)
    bigger_exists = date_str or hours
    if bigger_exists:
        time_str += "{:02}H".format(hours)

    # Minutes
    bigger_exists = bigger_exists or minutes
    if bigger_exists:
        time_str += "{:02}M".format(minutes)

    # Seconds
    try:
        if seconds.is_integer():
            seconds_string = "{:02}".format(int(seconds))
        else:
            # 9 chars long w/ leading 0, 6 digits after decimal
            seconds_string = "%09.6f" % seconds
            # Remove trailing zeros
            seconds_string = seconds_string.rstrip("0")
    except AttributeError:
        # A duck-typed total_seconds() may return an int, which lacks is_integer() before Python 3.12.
        seconds_string = "{:02}".format(seconds)

    time_str += "{}S".format(seconds_string)

    return sign + "P" + date_str + time_str
251
252
def _datetime_as_isostr(dt: Union[datetime, date, time, timedelta]) -> str:
    """Serialize a datetime, date, time or timedelta value to an ISO 8601 string.

    Full datetimes (anything exposing both ``year`` and ``hour``) are normalized to UTC and
    written with a ``Z`` suffix instead of ``+00:00``; naive datetimes are tagged as UTC rather
    than converted. Other values use their own ``isoformat()``. Values without ``isoformat()``
    are formatted as durations.

    :param dt: The value to serialize.
    :type dt: datetime.datetime or datetime.date or datetime.time or datetime.timedelta
    :return: The ISO 8601 representation of ``dt``.
    :rtype: str
    """
    is_full_datetime = hasattr(dt, "year") and hasattr(dt, "hour")
    if not is_full_datetime:
        try:
            # datetime.date and datetime.time both provide isoformat()
            return cast(Union[date, time], dt).isoformat()
        except AttributeError:
            # No isoformat(): treat the value as a datetime.timedelta
            return _timedelta_as_isostr(cast(timedelta, dt))

    moment = cast(datetime, dt)
    # A naive datetime is labelled UTC instead of passed to astimezone(), which would interpret
    # it as local time.
    utc_moment = moment.astimezone(TZ_UTC) if moment.tzinfo else moment.replace(tzinfo=TZ_UTC)
    # Replace the trailing "+00:00" UTC offset with "Z" (RFC 3339: https://www.ietf.org/rfc/rfc3339.txt)
    return utc_moment.isoformat().replace("+00:00", "Z")
279
280
class AzureJSONEncoder(JSONEncoder):
    """A JSON encoder that additionally handles bytes and date/time values.

    ``bytes`` and ``bytearray`` become base64 strings. Datetime, date, time and timedelta
    values become ISO 8601 strings. Anything else is handed to :class:`json.JSONEncoder`.
    """

    def default(self, o: Any) -> Any:
        """Return a JSON-serializable stand-in for an object the base encoder cannot handle.

        :param o: The object to serialize.
        :type o: any
        :return: A JSON-serializable representation of the object.
        :rtype: any
        """
        if isinstance(o, (bytes, bytearray)):
            return base64.b64encode(o).decode()

        iso_text: Optional[str]
        try:
            iso_text = _datetime_as_isostr(o)
        except AttributeError:
            # Not date/time-like; the base class raises the standard TypeError below.
            iso_text = None

        if iso_text is not None:
            return iso_text
        return super().default(o)
298
299
def is_generated_model(obj: Any) -> bool:
    """Tell whether ``obj`` looks like a generated SDK model.

    TypeSpec models are recognized by a truthy ``_is_model`` attribute and msrest models by
    the presence of ``_attribute_map``.

    :param obj: The object to check.
    :type obj: any
    :return: True if the object is a generated SDK model, False otherwise.
    :rtype: bool
    """
    if getattr(obj, "_is_model", False):
        return True
    return hasattr(obj, "_attribute_map")
309
310
def _is_readonly(p: Any) -> bool:
    """Tell whether a rest field is visible only for reads.

    :param any p: The property to check.
    :return: True when ``p._visibility`` equals ``["read"]``; False otherwise, including when
        ``p`` has no ``_visibility`` attribute.
    :rtype: bool
    """
    try:
        read_only = p._visibility == ["read"]  # pylint: disable=protected-access
    except AttributeError:
        read_only = False
    return read_only
322
323
def _as_attribute_dict_value(v: Any, *, exclude_readonly: bool = False) -> Any:
    """Recursively convert a value for inclusion in an attribute dictionary.

    ``None`` and the ``NULL`` sentinel become ``None``. Lists, tuples and sets are rebuilt with
    the same container type from converted elements, and dicts keep their keys with converted
    values. Generated models are expanded with ``as_attribute_dict``; any other value is
    returned unchanged.

    :param any v: The value to convert.
    :keyword bool exclude_readonly: Forwarded to nested conversions.
    :return: The converted value.
    :rtype: any
    """
    if v is None or isinstance(v, _Null):
        return None

    def convert(item: Any) -> Any:
        return _as_attribute_dict_value(item, exclude_readonly=exclude_readonly)

    if isinstance(v, (list, tuple, set)):
        return type(v)(map(convert, v))
    if isinstance(v, dict):
        return {key: convert(value) for key, value in v.items()}
    if not is_generated_model(v):
        return v
    return as_attribute_dict(v, exclude_readonly=exclude_readonly)
332
333
def _get_flattened_attribute(obj: Any) -> Optional[str]:
    """Return the name of the flattened attribute of a generated TypeSpec model, if any.

    The model is expected to expose an attribute whose name contains ``__flattened_items``
    (presumably a name-mangled private class attribute). The flattened attribute is the first
    entry of ``_attr_to_rest_field`` whose model type declares at least one of those items.

    :param any obj: The object to check.
    :return: The name of the flattened attribute if it exists, otherwise None.
    :rtype: Optional[str]
    """
    marker_name = next((name for name in dir(obj) if "__flattened_items" in name), None)
    if marker_name is None:
        return None
    flattened_items = getattr(obj, marker_name, None)
    if flattened_items is None:
        return None

    for attr_name, rest_field in obj._attr_to_rest_field.items():  # pylint: disable=protected-access
        try:
            nested_names = set(rest_field._class_type._attr_to_rest_field.keys())  # pylint: disable=protected-access
            overlaps = bool(nested_names & set(flattened_items))
        except AttributeError:
            # Fields without _class_type do not wrap a TypeSpec model, so they cannot be flattened.
            continue
        if overlaps:
            return attr_name
    return None
360
361
def attribute_list(obj: Any) -> List[str]:
    """List the attribute names of a generated SDK model.

    For msrest models this is the key list of ``_attribute_map``. For TypeSpec models it is the
    key list of ``_attr_to_rest_field``, with the flattened attribute (if any) replaced in place
    by the attribute names of its model type.

    :param obj: The object to get attributes from.
    :type obj: any
    :return: A list of attribute names.
    :rtype: List[str]
    :raises TypeError: If ``obj`` is not a generated SDK model.
    """
    if not is_generated_model(obj):
        raise TypeError("Object is not a generated SDK model.")
    if hasattr(obj, "_attribute_map"):
        # msrest model
        return list(obj._attribute_map.keys())  # pylint: disable=protected-access

    flattened = _get_flattened_attribute(obj)
    names: List[str] = []
    for attr_name, rest_field in obj._attr_to_rest_field.items():  # pylint: disable=protected-access
        if attr_name != flattened:
            names.append(attr_name)
            continue
        # Expand the flattened attribute into the attributes of the model it wraps.
        names.extend(attribute_list(rest_field._class_type))  # pylint: disable=protected-access
    return names
383
384
def as_attribute_dict(obj: Any, *, exclude_readonly: bool = False) -> Dict[str, Any]:
    """Convert an object to a dictionary of its attributes.

    Made solely for backcompatibility with the legacy `.as_dict()` on msrest models.

    .. deprecated:: 1.35.0
        This function is added for backcompat purposes only.

    For TypeSpec models, wire (rest) names are mapped back to attribute names. The contents of
    a flattened property are hoisted to the top level, and readonly properties (including those
    nested inside the flattened property) are dropped when ``exclude_readonly`` is set.

    :param any obj: The object to convert to a dictionary
    :keyword bool exclude_readonly: Whether to exclude readonly properties
    :return: A dictionary containing the object's attributes
    :rtype: dict[str, any]
    :raises TypeError: If the object is not a generated model instance
    """
    if not is_generated_model(obj):
        raise TypeError("Object must be a generated model instance.")
    if hasattr(obj, "_attribute_map"):
        # msrest generated model
        return obj.as_dict(keep_readonly=not exclude_readonly)
    try:
        # now we're a typespec generated model
        result: Dict[str, Any] = {}
        # rest names of top-level readonly properties to skip when exclude_readonly is set
        readonly_props = set()
        # rest names of readonly properties nested inside the flattened property
        flattened_readonly_props = set()
        # reverse mapping from rest field name to attribute name
        rest_to_attr: Dict[str, str] = {}
        # rest field by rest name, built once; the first declaration wins for duplicate names
        rest_fields_by_name: Dict[str, Any] = {}
        flattened_attribute = _get_flattened_attribute(obj)
        # obj.items() is keyed by rest name, so the flattened property is matched by its rest name
        flattened_rest_name: Optional[str] = None

        for attr_name, rest_field in obj._attr_to_rest_field.items():  # pylint: disable=protected-access
            rest_name = rest_field._rest_name  # pylint: disable=protected-access
            rest_fields_by_name.setdefault(rest_name, rest_field)
            if exclude_readonly and _is_readonly(rest_field):
                # if we're excluding readonly properties, we need to track them
                readonly_props.add(rest_name)
            if flattened_attribute == attr_name:
                flattened_rest_name = rest_name
                for fk, fv in rest_field._class_type._attr_to_rest_field.items():  # pylint: disable=protected-access
                    rest_to_attr[fv._rest_name] = fk  # pylint: disable=protected-access
                    if exclude_readonly and _is_readonly(fv):
                        flattened_readonly_props.add(fv._rest_name)  # pylint: disable=protected-access
            else:
                rest_to_attr[rest_name] = attr_name

        for k, v in obj.items():
            if exclude_readonly and k in readonly_props:  # pyright: ignore
                continue
            if flattened_rest_name is not None and k == flattened_rest_name:
                if v is None:
                    # Flattened property explicitly set to None: there is nothing to hoist.
                    continue
                for fk, fv in v.items():
                    if fk in flattened_readonly_props:
                        continue
                    result[rest_to_attr.get(fk, fk)] = _as_attribute_dict_value(fv, exclude_readonly=exclude_readonly)
                continue

            matched_field = rest_fields_by_name.get(k)
            is_multipart_file_input = (
                matched_field is not None and matched_field._is_multipart_file_input  # pylint: disable=protected-access
            )
            # Multipart file inputs are passed through untouched.
            result[rest_to_attr.get(k, k)] = (
                v if is_multipart_file_input else _as_attribute_dict_value(v, exclude_readonly=exclude_readonly)
            )
        return result
    except AttributeError as exc:
        # not a typespec generated model
        raise TypeError("Object must be a generated model instance.") from exc