1# coding=utf-8
2# --------------------------------------------------------------------------
3# Copyright (c) Microsoft Corporation. All rights reserved.
4# Licensed under the MIT License. See License.txt in the project root for
5# license information.
6# --------------------------------------------------------------------------
7# pylint: disable=protected-access
8import base64
9from functools import partial
10from json import JSONEncoder
11from typing import Dict, List, Optional, Union, cast, Any, Type, Callable, Tuple
12from datetime import datetime, date, time, timedelta
13from datetime import timezone
14
15
# Single shared UTC tzinfo; datetimes are normalized to it before being rendered as ISO 8601.
TZ_UTC = timezone.utc

# Public API of this module, i.e. the names exported by ``from ... import *``.
__all__ = [
    "NULL",
    "AzureJSONEncoder",
    "is_generated_model",
    "as_attribute_dict",
    "attribute_list",
    "TypeHandlerRegistry",
    "get_backcompat_attr_name",
]
26
27
28class _Null:
29 """To create a Falsy object"""
30
31 def __bool__(self) -> bool:
32 return False
33
34
35NULL = _Null()
36"""
37A falsy sentinel object which is supposed to be used to specify attributes
38with no data. This gets serialized to `null` on the wire.
39"""
40
41
class TypeHandlerRegistry:
    """Registry of custom serializers and deserializers, selected by exact type or by a predicate."""

    def __init__(self) -> None:
        # Handlers registered for one exact type.
        self._serializer_types: Dict[Type, Callable] = {}
        self._deserializer_types: Dict[Type, Callable] = {}
        # (predicate, handler) pairs, tried in registration order when no exact-type handler matches.
        self._serializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = []
        self._deserializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = []

        # Memoized lookup results, misses included (stored as None); emptied on every registration.
        self._serializer_cache: Dict[Type, Optional[Callable]] = {}
        self._deserializer_cache: Dict[Type, Optional[Callable]] = {}

    @staticmethod
    def _find_handler(
        exact: Dict[Type, Callable],
        predicates: List[Tuple[Callable[[Any], bool], Callable]],
        key: Type,
        subject: Any,
    ) -> Optional[Callable]:
        """Resolve a handler: exact match on ``key`` first, then the first predicate accepting ``subject``.

        :param exact: Mapping of exact types to handlers.
        :type exact: Dict[Type, Callable]
        :param predicates: Ordered (predicate, handler) pairs.
        :type predicates: List[Tuple[Callable[[Any], bool], Callable]]
        :param key: The type used for the exact-match lookup.
        :type key: type
        :param subject: The value handed to each predicate.
        :type subject: any
        :return: The matching handler, or the (falsy) exact-lookup result when nothing matches.
        :rtype: Optional[Callable]
        """
        found = exact.get(key)
        if found:
            return found
        return next((handler for accepts, handler in predicates if accepts(subject)), found)

    def register_serializer(
        self, condition: Union[Type, Callable[[Any], bool]]
    ) -> Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]:
        """Decorator factory that registers a serializer.

        The handler receives the object to serialize as its only argument and must return a
        dictionary representation of it. The decorated function is returned unchanged.

        Examples:

        .. code-block:: python

            @registry.register_serializer(CustomModel)
            def serialize_single_type(value: CustomModel) -> dict:
                return value.to_dict()

            @registry.register_serializer(lambda x: isinstance(x, BaseModel))
            def serialize_with_condition(value: BaseModel) -> dict:
                return value.to_dict()

            # Called manually for a specific type
            def custom_serializer(value: CustomModel) -> Dict[str, Any]:
                return {"custom": value.custom}

            registry.register_serializer(CustomModel)(custom_serializer)

        :param condition: Either an exact type, or a predicate called with the object being serialized.
        :type condition: Union[Type, Callable[[Any], bool]]
        :return: A decorator that registers the handler function.
        :rtype: Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]
        :raises TypeError: If the condition is neither a type nor a callable (raised when the
            returned decorator is applied).
        """

        def _register(handler_func: Callable[[Any], Dict[str, Any]]) -> Callable[[Any], Dict[str, Any]]:
            if isinstance(condition, type):
                # Classes are callable as well, so the type test has to come first.
                self._serializer_types[condition] = handler_func
            elif not callable(condition):
                raise TypeError("Condition must be a type or a callable predicate function.")
            else:
                self._serializer_predicates.append((condition, handler_func))
            # Earlier lookups may resolve differently now.
            self._serializer_cache.clear()
            return handler_func

        return _register

    def register_deserializer(
        self, condition: Union[Type, Callable[[Any], bool]]
    ) -> Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]:
        """Decorator factory that registers a deserializer.

        The handler receives the target class and the data dictionary, and must return an
        instance of that class. The decorated function is returned unchanged.

        Examples:

        .. code-block:: python

            @registry.register_deserializer(CustomModel)
            def deserialize_single_type(cls: Type[CustomModel], data: dict) -> CustomModel:
                return cls(**data)

            @registry.register_deserializer(lambda t: issubclass(t, BaseModel))
            def deserialize_with_condition(cls: Type[BaseModel], data: dict) -> BaseModel:
                return cls(**data)

            # Called manually for a specific type
            def custom_deserializer(cls: Type[CustomModel], data: Dict[str, Any]) -> CustomModel:
                return cls(custom=data["custom"])

            registry.register_deserializer(CustomModel)(custom_deserializer)

        :param condition: Either an exact type, or a predicate called with the target class.
        :type condition: Union[Type, Callable[[Any], bool]]
        :return: A decorator that registers the handler function.
        :rtype: Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]
        :raises TypeError: If the condition is neither a type nor a callable (raised when the
            returned decorator is applied).
        """

        def _register(handler_func: Callable[[Type, Dict[str, Any]], Any]) -> Callable[[Type, Dict[str, Any]], Any]:
            if isinstance(condition, type):
                # Classes are callable as well, so the type test has to come first.
                self._deserializer_types[condition] = handler_func
            elif not callable(condition):
                raise TypeError("Condition must be a type or a callable predicate function.")
            else:
                self._deserializer_predicates.append((condition, handler_func))
            # Earlier lookups may resolve differently now.
            self._deserializer_cache.clear()
            return handler_func

        return _register

    def get_serializer(self, obj: Any) -> Optional[Callable[[Any], Dict[str, Any]]]:
        """Find the serializer for an object.

        An exact-type registration for ``type(obj)`` wins; otherwise the first registered predicate
        that accepts ``obj`` is used. Results are cached per ``type(obj)``, so a predicate's answer
        for the first instance seen of a type is reused for later instances of that type.

        :param obj: The object to serialize.
        :type obj: any
        :return: The serializer function if found, otherwise None.
        :rtype: Optional[Callable[[Any], Dict[str, Any]]]
        """
        obj_type = type(obj)
        if obj_type not in self._serializer_cache:
            self._serializer_cache[obj_type] = self._find_handler(
                self._serializer_types, self._serializer_predicates, obj_type, obj
            )
        return self._serializer_cache[obj_type]

    def get_deserializer(self, cls: Type) -> Optional[Callable[[Dict[str, Any]], Any]]:
        """Find the deserializer for a class.

        An exact-type registration for ``cls`` wins; otherwise the first registered predicate that
        accepts ``cls`` is used. Results are cached per class.

        :param cls: The class to deserialize.
        :type cls: type
        :return: The handler with ``cls`` pre-bound as its first argument (so it only needs the data
            dictionary), or None if no deserializer is found.
        :rtype: Optional[Callable[[Dict[str, Any]], Any]]
        """
        if cls not in self._deserializer_cache:
            found = self._find_handler(self._deserializer_types, self._deserializer_predicates, cls, cls)
            self._deserializer_cache[cls] = partial(found, cls) if found else None
        return self._deserializer_cache[cls]
199
200
201def _timedelta_as_isostr(td: timedelta) -> str:
202 """Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S'
203
204 Function adapted from the Tin Can Python project: https://github.com/RusticiSoftware/TinCanPython
205
206 :param td: The timedelta object to convert
207 :type td: datetime.timedelta
208 :return: An ISO 8601 formatted string representing the timedelta object
209 :rtype: str
210 """
211
212 # Split seconds to larger units
213 seconds = td.total_seconds()
214 minutes, seconds = divmod(seconds, 60)
215 hours, minutes = divmod(minutes, 60)
216 days, hours = divmod(hours, 24)
217
218 days, hours, minutes = list(map(int, (days, hours, minutes)))
219 seconds = round(seconds, 6)
220
221 # Build date
222 date_str = ""
223 if days:
224 date_str = "%sD" % days
225
226 # Build time
227 time_str = "T"
228
229 # Hours
230 bigger_exists = date_str or hours
231 if bigger_exists:
232 time_str += "{:02}H".format(hours)
233
234 # Minutes
235 bigger_exists = bigger_exists or minutes
236 if bigger_exists:
237 time_str += "{:02}M".format(minutes)
238
239 # Seconds
240 try:
241 if seconds.is_integer():
242 seconds_string = "{:02}".format(int(seconds))
243 else:
244 # 9 chars long w/ leading 0, 6 digits after decimal
245 seconds_string = "%09.6f" % seconds
246 # Remove trailing zeros
247 seconds_string = seconds_string.rstrip("0")
248 except AttributeError: # int.is_integer() raises
249 seconds_string = "{:02}".format(seconds)
250
251 time_str += "{}S".format(seconds_string)
252
253 return "P" + date_str + time_str
254
255
256def _datetime_as_isostr(dt: Union[datetime, date, time, timedelta]) -> str:
257 """Converts a datetime.(datetime|date|time|timedelta) object into an ISO 8601 formatted string.
258
259 :param dt: The datetime object to convert
260 :type dt: datetime.datetime or datetime.date or datetime.time or datetime.timedelta
261 :return: An ISO 8601 formatted string representing the datetime object
262 :rtype: str
263 """
264 # First try datetime.datetime
265 if hasattr(dt, "year") and hasattr(dt, "hour"):
266 dt = cast(datetime, dt)
267 # astimezone() fails for naive times in Python 2.7, so make make sure dt is aware (tzinfo is set)
268 if not dt.tzinfo:
269 iso_formatted = dt.replace(tzinfo=TZ_UTC).isoformat()
270 else:
271 iso_formatted = dt.astimezone(TZ_UTC).isoformat()
272 # Replace the trailing "+00:00" UTC offset with "Z" (RFC 3339: https://www.ietf.org/rfc/rfc3339.txt)
273 return iso_formatted.replace("+00:00", "Z")
274 # Next try datetime.date or datetime.time
275 try:
276 dt = cast(Union[date, time], dt)
277 return dt.isoformat()
278 # Last, try datetime.timedelta
279 except AttributeError:
280 dt = cast(timedelta, dt)
281 return _timedelta_as_isostr(dt)
282
283
class AzureJSONEncoder(JSONEncoder):
    """A JSON encoder that additionally handles bytes and date/time values."""

    def default(self, o: Any) -> Any:
        """Encode values the base JSON encoder does not support.

        ``bytes``/``bytearray`` become base64 text. Values accepted by the ISO 8601 formatter
        (datetime, date, time, timedelta) become ISO 8601 strings. Everything else is passed to
        :meth:`json.JSONEncoder.default`, which raises ``TypeError``.

        :param o: The object to serialize.
        :type o: any
        :return: A JSON-serializable representation of the object.
        :rtype: any
        """
        if isinstance(o, (bytes, bytearray)):
            encoded = base64.b64encode(o)
            return encoded.decode()
        try:
            return _datetime_as_isostr(o)
        except AttributeError:
            # Not a date/time value; fall through so the base class reports it.
            pass
        return super().default(o)
301
302
def is_generated_model(obj: Any) -> bool:
    """Tell whether an object is a generated SDK model.

    TypeSpec-generated models are identified by a truthy ``_is_model`` attribute, and
    msrest-generated models by the presence of ``_attribute_map``.

    :param obj: The object to check.
    :type obj: any
    :return: True if the object is a generated SDK model, False otherwise.
    :rtype: bool
    """
    if getattr(obj, "_is_model", False):
        return True
    return hasattr(obj, "_attribute_map")
312
313
314def _is_readonly(p: Any) -> bool:
315 """Check if an attribute is readonly.
316
317 :param any p: The property to check.
318 :return: True if the property is readonly, False otherwise.
319 :rtype: bool
320 """
321 try:
322 return p._visibility == ["read"]
323 except AttributeError:
324 return False
325
326
def _as_attribute_dict_value(v: Any, *, exclude_readonly: bool = False) -> Any:
    """Recursively prepare one attribute value for ``as_attribute_dict``.

    ``None`` and ``NULL`` become ``None``. Lists, tuples and sets are rebuilt with the same
    container type. Dicts are rebuilt with converted values (keys untouched). Generated models
    are converted with ``as_attribute_dict``. Anything else is returned as-is.

    :param any v: The value to convert.
    :keyword bool exclude_readonly: Whether nested models should drop readonly properties.
    :return: The converted value.
    :rtype: any
    """
    if v is None or isinstance(v, _Null):
        return None
    convert = partial(_as_attribute_dict_value, exclude_readonly=exclude_readonly)
    if isinstance(v, (list, tuple, set)):
        return type(v)(convert(item) for item in v)
    if isinstance(v, dict):
        return {key: convert(val) for key, val in v.items()}
    if is_generated_model(v):
        return as_attribute_dict(v, exclude_readonly=exclude_readonly)
    return v
335
336
337def _get_backcompat_name(rest_field: Any, default_attr_name: str) -> str:
338 """Get the backcompat name for an attribute.
339
340 :param any rest_field: The rest field to get the backcompat name from.
341 :param str default_attr_name: The default attribute name to use if no backcompat name
342 :return: The backcompat name.
343 :rtype: str
344 """
345 try:
346 return rest_field._original_tsp_name or default_attr_name
347 except AttributeError:
348 return default_attr_name
349
350
351def _get_flattened_attribute(obj: Any) -> Optional[str]:
352 """Get the name of the flattened attribute in a generated TypeSpec model if one exists.
353
354 :param any obj: The object to check.
355 :return: The name of the flattened attribute if it exists, otherwise None.
356 :rtype: Optional[str]
357 """
358 flattened_items = None
359 try:
360 flattened_items = getattr(obj, next(a for a in dir(obj) if "__flattened_items" in a), None)
361 except StopIteration:
362 return None
363
364 if flattened_items is None:
365 return None
366
367 for k, v in obj._attr_to_rest_field.items():
368 try:
369 if set(v._class_type._attr_to_rest_field.keys()).intersection(set(flattened_items)):
370 return k
371 except AttributeError:
372 # if the attribute does not have _class_type, it is not a typespec generated model
373 continue
374 return None
375
376
def attribute_list(obj: Any) -> List[str]:
    """List the attribute names of a generated SDK model.

    For msrest models these are the keys of ``_attribute_map``. For TypeSpec models they are the
    keys of ``_attr_to_rest_field``, with a flattened attribute replaced by its inner attributes.

    :param obj: The object to get attributes from.
    :type obj: any
    :return: A list of attribute names.
    :rtype: List[str]
    :raises TypeError: If the object is not a generated SDK model.
    """
    if not is_generated_model(obj):
        raise TypeError("Object is not a generated SDK model.")
    if hasattr(obj, "_attribute_map"):
        # msrest model
        return list(obj._attribute_map.keys())

    flattened = _get_flattened_attribute(obj)
    names: List[str] = []
    for name, field in obj._attr_to_rest_field.items():
        if name != flattened:
            names.append(name)
            continue
        # Expand the flattened attribute into the attributes of its model class.
        names.extend(attribute_list(field._class_type))
    return names
398
399
def as_attribute_dict(obj: Any, *, exclude_readonly: bool = False) -> Dict[str, Any]:
    """Convert an object to a dictionary of its attributes.

    Made solely for backcompatibility with the legacy `.as_dict()` on msrest models.

    .. deprecated:: 1.35.0
        This function is added for backcompat purposes only.

    :param any obj: The object to convert to a dictionary
    :keyword bool exclude_readonly: Whether to exclude readonly properties
    :return: A dictionary containing the object's attributes
    :rtype: dict[str, any]
    :raises TypeError: If the object is not a generated model instance
    """
    if not is_generated_model(obj):
        raise TypeError("Object must be a generated model instance.")
    if hasattr(obj, "_attribute_map"):
        # msrest generated model
        return obj.as_dict(keep_readonly=not exclude_readonly)
    try:
        # now we're a typespec generated model
        result: Dict[str, Any] = {}
        # REST names of readonly top-level fields, and of readonly fields inside the flattened model
        readonly_props = set()
        flattened_readonly_props = set()
        # reverse mapping from rest field name to attribute name
        rest_to_attr: Dict[str, str] = {}
        # REST name -> rest field (first occurrence wins), replacing a linear search per key
        rest_fields_by_name: Dict[str, Any] = {}

        flattened_attribute = _get_flattened_attribute(obj)
        # obj.items() is keyed by REST name, so the flattened value must be recognized by the
        # REST name of the flattened field, not by its Python attribute name.
        flattened_rest_name = None
        for attr_name, rest_field in obj._attr_to_rest_field.items():
            rest_fields_by_name.setdefault(rest_field._rest_name, rest_field)
            if exclude_readonly and _is_readonly(rest_field):
                readonly_props.add(rest_field._rest_name)
            if flattened_attribute == attr_name:
                flattened_rest_name = rest_field._rest_name
                for fk, fv in rest_field._class_type._attr_to_rest_field.items():
                    rest_to_attr[fv._rest_name] = fk
                    if exclude_readonly and _is_readonly(fv):
                        flattened_readonly_props.add(fv._rest_name)
            else:
                rest_to_attr[rest_field._rest_name] = attr_name

        for k, v in obj.items():
            if exclude_readonly and k in readonly_props:  # pyright: ignore
                continue
            if flattened_rest_name is not None and k == flattened_rest_name:
                # Hoist the flattened model's fields to the top level of the result.
                for fk, fv in v.items():
                    if fk in flattened_readonly_props:
                        continue
                    result[rest_to_attr.get(fk, fk)] = _as_attribute_dict_value(fv, exclude_readonly=exclude_readonly)
            else:
                field = rest_fields_by_name.get(k)
                is_multipart_file_input = field._is_multipart_file_input if field is not None else False
                # Multipart file inputs are passed through untouched.
                result[rest_to_attr.get(k, k)] = (
                    v if is_multipart_file_input else _as_attribute_dict_value(v, exclude_readonly=exclude_readonly)
                )
        return result
    except AttributeError as exc:
        # not a typespec generated model
        raise TypeError("Object must be a generated model instance.") from exc
459
460
def get_backcompat_attr_name(model: Any, attr_name: str) -> str:
    """Get the backcompat attribute name for a given attribute.

    This function takes an attribute name and returns the backcompat name (original TSP name)
    if one exists, otherwise returns the attribute name itself. Attributes of a flattened
    inner model are searched as well.

    :param any model: The model instance.
    :param str attr_name: The attribute name to get the backcompat name for.
    :return: The backcompat attribute name (original TSP name) or the attribute name itself.
    :rtype: str
    :raises TypeError: If the object is not a generated model instance.
    """
    if not is_generated_model(model):
        raise TypeError("Object must be a generated model instance.")
    if hasattr(model, "_attribute_map"):
        # msrest models have no original TypeSpec names and no _attr_to_rest_field to search;
        # previously this path raised AttributeError.
        return attr_name

    flattened_attribute = _get_flattened_attribute(model)
    for field_attr_name, rest_field in model._attr_to_rest_field.items():
        # Direct attribute of the model
        if field_attr_name == attr_name:
            return _get_backcompat_name(rest_field, attr_name)

        # Attribute living inside the flattened model
        if flattened_attribute == field_attr_name:
            for fk, fv in rest_field._class_type._attr_to_rest_field.items():
                if fk == attr_name:
                    return _get_backcompat_name(fv, fk)

    # If not found in the model, just return the attribute name as-is
    return attr_name