1# coding=utf-8
2# --------------------------------------------------------------------------
3# Copyright (c) Microsoft Corporation. All rights reserved.
4# Licensed under the MIT License. See License.txt in the project root for
5# license information.
6# --------------------------------------------------------------------------
7import base64
8from functools import partial
9from json import JSONEncoder
10from typing import Dict, List, Optional, Union, cast, Any, Type, Callable, Tuple
11from datetime import datetime, date, time, timedelta
12from datetime import timezone
13
14
# Public names exported by this module (what ``from <module> import *`` exposes).
__all__ = [
    "NULL",
    "AzureJSONEncoder",
    "is_generated_model",
    "as_attribute_dict",
    "attribute_list",
    "TypeHandlerRegistry",
]
# Module-level alias for the standard library UTC timezone object.
TZ_UTC = timezone.utc
24
25
26class _Null:
27 """To create a Falsy object"""
28
29 def __bool__(self) -> bool:
30 return False
31
32
# Module-level instance of the falsy ``_Null`` type, exported as the public sentinel.
NULL = _Null()
"""
A falsy sentinel object which is supposed to be used to specify attributes
with no data. This gets serialized to `null` on the wire.
"""
38
39
class TypeHandlerRegistry:
    """Registry that maps types or predicates to custom serializer and deserializer handlers."""

    def __init__(self) -> None:
        # Handlers registered against an exact type; these are consulted first during lookup.
        self._serializer_types: Dict[Type, Callable] = {}
        self._deserializer_types: Dict[Type, Callable] = {}
        # Handlers registered against a predicate; evaluated in registration order.
        self._serializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = []
        self._deserializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = []

        # Memoized lookup results (misses included); emptied whenever a handler is registered.
        self._serializer_cache: Dict[Type, Optional[Callable]] = {}
        self._deserializer_cache: Dict[Type, Optional[Callable]] = {}

    @staticmethod
    def _store_handler(
        condition: Any,
        handler_func: Callable,
        typed_handlers: Dict[Type, Callable],
        predicate_handlers: List[Tuple[Callable[[Any], bool], Callable]],
        cache: Dict[Type, Optional[Callable]],
    ) -> None:
        """Record a handler under a type or predicate condition and invalidate the lookup cache.

        :param condition: A type or a callable predicate.
        :type condition: any
        :param handler_func: The handler to associate with the condition.
        :type handler_func: Callable
        :param typed_handlers: The mapping that receives type-based registrations.
        :type typed_handlers: Dict[Type, Callable]
        :param predicate_handlers: The list that receives predicate-based registrations.
        :type predicate_handlers: List[Tuple[Callable[[Any], bool], Callable]]
        :param cache: The lookup cache to clear after a successful registration.
        :type cache: Dict[Type, Optional[Callable]]
        :raises TypeError: If the condition is neither a type nor a callable.
        """
        if isinstance(condition, type):
            typed_handlers[condition] = handler_func
        elif callable(condition):
            predicate_handlers.append((condition, handler_func))
        else:
            raise TypeError("Condition must be a type or a callable predicate function.")
        # Earlier lookups may have memoized a different handler (or a miss) for this condition.
        cache.clear()

    def register_serializer(
        self, condition: Union[Type, Callable[[Any], bool]]
    ) -> Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]:
        """Build a decorator that registers a serializer for the given condition.

        The decorated handler receives one argument, the object being serialized, and is expected
        to return a dictionary representation of it. The handler itself is returned unchanged.

        Examples:

        .. code-block:: python

            @registry.register_serializer(CustomModel)
            def serialize_single_type(value: CustomModel) -> dict:
                return value.to_dict()

            @registry.register_serializer(lambda x: isinstance(x, BaseModel))
            def serialize_with_condition(value: BaseModel) -> dict:
                return value.to_dict()

            # Called manually for a specific type
            def custom_serializer(value: CustomModel) -> Dict[str, Any]:
                return {"custom": value.custom}

            registry.register_serializer(CustomModel)(custom_serializer)

        :param condition: Either a type, matched exactly against ``type(obj)``, or a predicate that
            takes the object and returns a bool.
        :type condition: Union[Type, Callable[[Any], bool]]
        :return: A decorator that registers the handler function.
        :rtype: Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]
        :raises TypeError: When the decorator is applied and the condition is neither a type nor a callable.
        """

        def decorator(handler_func: Callable[[Any], Dict[str, Any]]) -> Callable[[Any], Dict[str, Any]]:
            self._store_handler(
                condition,
                handler_func,
                self._serializer_types,
                self._serializer_predicates,
                self._serializer_cache,
            )
            return handler_func

        return decorator

    def register_deserializer(
        self, condition: Union[Type, Callable[[Any], bool]]
    ) -> Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]:
        """Build a decorator that registers a deserializer for the given condition.

        The decorated handler receives two arguments, the target type and the data dictionary, and
        is expected to return an instance of the target type. The handler itself is returned unchanged.

        Examples:

        .. code-block:: python

            @registry.register_deserializer(CustomModel)
            def deserialize_single_type(cls: Type[CustomModel], data: dict) -> CustomModel:
                return cls(**data)

            @registry.register_deserializer(lambda t: issubclass(t, BaseModel))
            def deserialize_with_condition(cls: Type[BaseModel], data: dict) -> BaseModel:
                return cls(**data)

            # Called manually for a specific type
            def custom_deserializer(cls: Type[CustomModel], data: Dict[str, Any]) -> CustomModel:
                return cls(custom=data["custom"])

            registry.register_deserializer(CustomModel)(custom_deserializer)

        :param condition: Either a type, matched exactly against the requested class, or a predicate
            that takes the requested class and returns a bool.
        :type condition: Union[Type, Callable[[Any], bool]]
        :return: A decorator that registers the handler function.
        :rtype: Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]
        :raises TypeError: When the decorator is applied and the condition is neither a type nor a callable.
        """

        def decorator(handler_func: Callable[[Type, Dict[str, Any]], Any]) -> Callable[[Type, Dict[str, Any]], Any]:
            self._store_handler(
                condition,
                handler_func,
                self._deserializer_types,
                self._deserializer_predicates,
                self._deserializer_cache,
            )
            return handler_func

        return decorator

    def get_serializer(self, obj: Any) -> Optional[Callable[[Any], Dict[str, Any]]]:
        """Look up the serializer registered for an object.

        An exact match on ``type(obj)`` in the type registrations wins. Otherwise the predicates are
        tried in registration order and the first one returning a truthy value supplies the handler.

        The outcome, including a miss, is cached per object type.

        :param obj: The object to serialize.
        :type obj: any
        :return: The serializer function if found, otherwise None.
        :rtype: Optional[Callable[[Any], Dict[str, Any]]]
        """
        lookup_type = type(obj)
        try:
            return self._serializer_cache[lookup_type]
        except KeyError:
            # Nothing memoized for this type yet; resolve it below.
            pass

        handler = self._serializer_types.get(lookup_type)
        if not handler:
            for predicate, candidate in self._serializer_predicates:
                if predicate(obj):
                    handler = candidate
                    break

        self._serializer_cache[lookup_type] = handler
        return handler

    def get_deserializer(self, cls: Type) -> Optional[Callable[[Dict[str, Any]], Any]]:
        """Look up the deserializer registered for a class.

        An exact match on ``cls`` in the type registrations wins. Otherwise the predicates are
        tried in registration order and the first one returning a truthy value supplies the handler.

        The outcome, including a miss, is cached per class.

        :param cls: The class to deserialize.
        :type cls: type
        :return: The matching handler with ``cls`` pre-bound as its first argument, so that it only
            takes the data dictionary, or None if no deserializer is found.
        :rtype: Optional[Callable[[Dict[str, Any]], Any]]
        """
        try:
            return self._deserializer_cache[cls]
        except KeyError:
            # Nothing memoized for this class yet; resolve it below.
            pass

        handler = self._deserializer_types.get(cls)
        if not handler:
            for predicate, candidate in self._deserializer_predicates:
                if predicate(cls):
                    handler = candidate
                    break

        # Bind the class now so callers only need to supply the data dictionary.
        bound_handler = partial(handler, cls) if handler else None
        self._deserializer_cache[cls] = bound_handler
        return bound_handler
197
198
199def _timedelta_as_isostr(td: timedelta) -> str:
200 """Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S'
201
202 Function adapted from the Tin Can Python project: https://github.com/RusticiSoftware/TinCanPython
203
204 :param td: The timedelta object to convert
205 :type td: datetime.timedelta
206 :return: An ISO 8601 formatted string representing the timedelta object
207 :rtype: str
208 """
209
210 # Split seconds to larger units
211 seconds = td.total_seconds()
212 minutes, seconds = divmod(seconds, 60)
213 hours, minutes = divmod(minutes, 60)
214 days, hours = divmod(hours, 24)
215
216 days, hours, minutes = list(map(int, (days, hours, minutes)))
217 seconds = round(seconds, 6)
218
219 # Build date
220 date_str = ""
221 if days:
222 date_str = "%sD" % days
223
224 # Build time
225 time_str = "T"
226
227 # Hours
228 bigger_exists = date_str or hours
229 if bigger_exists:
230 time_str += "{:02}H".format(hours)
231
232 # Minutes
233 bigger_exists = bigger_exists or minutes
234 if bigger_exists:
235 time_str += "{:02}M".format(minutes)
236
237 # Seconds
238 try:
239 if seconds.is_integer():
240 seconds_string = "{:02}".format(int(seconds))
241 else:
242 # 9 chars long w/ leading 0, 6 digits after decimal
243 seconds_string = "%09.6f" % seconds
244 # Remove trailing zeros
245 seconds_string = seconds_string.rstrip("0")
246 except AttributeError: # int.is_integer() raises
247 seconds_string = "{:02}".format(seconds)
248
249 time_str += "{}S".format(seconds_string)
250
251 return "P" + date_str + time_str
252
253
254def _datetime_as_isostr(dt: Union[datetime, date, time, timedelta]) -> str:
255 """Converts a datetime.(datetime|date|time|timedelta) object into an ISO 8601 formatted string.
256
257 :param dt: The datetime object to convert
258 :type dt: datetime.datetime or datetime.date or datetime.time or datetime.timedelta
259 :return: An ISO 8601 formatted string representing the datetime object
260 :rtype: str
261 """
262 # First try datetime.datetime
263 if hasattr(dt, "year") and hasattr(dt, "hour"):
264 dt = cast(datetime, dt)
265 # astimezone() fails for naive times in Python 2.7, so make make sure dt is aware (tzinfo is set)
266 if not dt.tzinfo:
267 iso_formatted = dt.replace(tzinfo=TZ_UTC).isoformat()
268 else:
269 iso_formatted = dt.astimezone(TZ_UTC).isoformat()
270 # Replace the trailing "+00:00" UTC offset with "Z" (RFC 3339: https://www.ietf.org/rfc/rfc3339.txt)
271 return iso_formatted.replace("+00:00", "Z")
272 # Next try datetime.date or datetime.time
273 try:
274 dt = cast(Union[date, time], dt)
275 return dt.isoformat()
276 # Last, try datetime.timedelta
277 except AttributeError:
278 dt = cast(timedelta, dt)
279 return _timedelta_as_isostr(dt)
280
281
class AzureJSONEncoder(JSONEncoder):
    """JSON encoder with added support for bytes-like values and date/time objects."""

    def default(self, o: Any) -> Any:
        """Encode values that the base JSON encoder cannot handle on its own.

        Bytes and bytearrays become base64 text. Date/time values become ISO 8601 strings.
        Everything else is delegated to the base class.

        :param o: The object to serialize.
        :type o: any
        :return: A JSON-serializable representation of the object.
        :rtype: any
        """
        if isinstance(o, (bytes, bytearray)):
            encoded = base64.b64encode(o)
            return encoded.decode()
        try:
            return _datetime_as_isostr(o)
        except AttributeError:
            # Not a date/time-like value; fall through to the base implementation.
            pass
        return super().default(o)
299
300
def is_generated_model(obj: Any) -> bool:
    """Tell whether an object is a generated SDK model.

    An object qualifies when its ``_is_model`` attribute is truthy, or when it has an
    ``_attribute_map`` attribute.

    :param obj: The object to check.
    :type obj: any
    :return: True if the object is a generated SDK model, False otherwise.
    :rtype: bool
    """
    if getattr(obj, "_is_model", False):
        return True
    return hasattr(obj, "_attribute_map")
310
311
312def _is_readonly(p: Any) -> bool:
313 """Check if an attribute is readonly.
314
315 :param any p: The property to check.
316 :return: True if the property is readonly, False otherwise.
317 :rtype: bool
318 """
319 try:
320 return p._visibility == ["read"] # pylint: disable=protected-access
321 except AttributeError:
322 return False
323
324
325def _as_attribute_dict_value(v: Any, *, exclude_readonly: bool = False) -> Any:
326 if v is None or isinstance(v, _Null):
327 return None
328 if isinstance(v, (list, tuple, set)):
329 return type(v)(_as_attribute_dict_value(x, exclude_readonly=exclude_readonly) for x in v)
330 if isinstance(v, dict):
331 return {dk: _as_attribute_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()}
332 return as_attribute_dict(v, exclude_readonly=exclude_readonly) if is_generated_model(v) else v
333
334
335def _get_flattened_attribute(obj: Any) -> Optional[str]:
336 """Get the name of the flattened attribute in a generated TypeSpec model if one exists.
337
338 :param any obj: The object to check.
339 :return: The name of the flattened attribute if it exists, otherwise None.
340 :rtype: Optional[str]
341 """
342 flattened_items = None
343 try:
344 flattened_items = getattr(obj, next(a for a in dir(obj) if "__flattened_items" in a), None)
345 except StopIteration:
346 return None
347
348 if flattened_items is None:
349 return None
350
351 for k, v in obj._attr_to_rest_field.items(): # pylint: disable=protected-access
352 try:
353 if set(v._class_type._attr_to_rest_field.keys()).intersection( # pylint: disable=protected-access
354 set(flattened_items)
355 ):
356 return k
357 except AttributeError:
358 # if the attribute does not have _class_type, it is not a typespec generated model
359 continue
360 return None
361
362
def attribute_list(obj: Any) -> List[str]:
    """List the attribute names of a generated SDK model.

    For msrest models these are the keys of ``_attribute_map``. For other generated models they
    are the keys of ``_attr_to_rest_field``, with the flattened attribute (if any) replaced by
    the attribute names of its class type.

    :param obj: The object to get attributes from.
    :type obj: any
    :return: A list of attribute names.
    :rtype: List[str]
    :raises TypeError: If the object is not a generated SDK model.
    """
    if not is_generated_model(obj):
        raise TypeError("Object is not a generated SDK model.")
    if hasattr(obj, "_attribute_map"):
        # msrest model
        return list(obj._attribute_map.keys())  # pylint: disable=protected-access

    flattened_name = _get_flattened_attribute(obj)
    names: List[str] = []
    for field_name, field in obj._attr_to_rest_field.items():  # pylint: disable=protected-access
        if flattened_name != field_name:
            names.append(field_name)
            continue
        # Expand the flattened attribute into the attribute names of its nested model class.
        names.extend(attribute_list(field._class_type))  # pylint: disable=protected-access
    return names
384
385
def as_attribute_dict(obj: Any, *, exclude_readonly: bool = False) -> Dict[str, Any]:
    """Convert an object to a dictionary of its attributes.

    Made solely for backcompatibility with the legacy `.as_dict()` on msrest models.

    .. deprecated::1.35.0
        This function is added for backcompat purposes only.

    :param any obj: The object to convert to a dictionary
    :keyword bool exclude_readonly: Whether to exclude readonly properties
    :return: A dictionary containing the object's attributes
    :rtype: dict[str, any]
    :raises TypeError: If the object is not a generated model instance
    """
    if not is_generated_model(obj):
        raise TypeError("Object must be a generated model instance.")
    if hasattr(obj, "_attribute_map"):
        # msrest generated model
        return obj.as_dict(keep_readonly=not exclude_readonly)
    try:
        # now we're a typespec generated model
        result: Dict[str, Any] = {}
        readonly_props = set()

        # reverse mapping from rest field name to attribute name
        rest_to_attr: Dict[str, Any] = {}
        # rest field name -> rest field, built once so the loop over obj.items() below does not
        # have to rescan every rest field for each key
        rest_to_field: Dict[str, Any] = {}
        flattened_attribute = _get_flattened_attribute(obj)
        for attr_name, rest_field in obj._attr_to_rest_field.items():  # pylint: disable=protected-access
            rest_name = rest_field._rest_name  # pylint: disable=protected-access
            # setdefault keeps the first field declared for a rest name
            rest_to_field.setdefault(rest_name, rest_field)

            if exclude_readonly and _is_readonly(rest_field):
                # if we're excluding readonly properties, we need to track them
                readonly_props.add(rest_name)
            if flattened_attribute == attr_name:
                # the flattened attribute contributes the names of its nested model's fields
                for fk, fv in rest_field._class_type._attr_to_rest_field.items():  # pylint: disable=protected-access
                    rest_to_attr[fv._rest_name] = fk  # pylint: disable=protected-access
            else:
                rest_to_attr[rest_name] = attr_name
        for k, v in obj.items():
            if exclude_readonly and k in readonly_props:  # pyright: ignore
                continue
            if k == flattened_attribute:
                # lift the nested model's entries into the top-level result
                for fk, fv in v.items():
                    result[rest_to_attr.get(fk, fk)] = _as_attribute_dict_value(fv, exclude_readonly=exclude_readonly)
            else:
                matching_field = rest_to_field.get(k)
                is_multipart_file_input = (
                    False
                    if matching_field is None
                    else matching_field._is_multipart_file_input  # pylint: disable=protected-access
                )

                # multipart file inputs are stored untouched instead of being converted recursively
                result[rest_to_attr.get(k, k)] = (
                    v if is_multipart_file_input else _as_attribute_dict_value(v, exclude_readonly=exclude_readonly)
                )
        return result
    except AttributeError as exc:
        # not a typespec generated model
        raise TypeError("Object must be a generated model instance.") from exc