Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/azure/core/serialization.py: 15%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

204 statements  

1# coding=utf-8 

2# -------------------------------------------------------------------------- 

3# Copyright (c) Microsoft Corporation. All rights reserved. 

4# Licensed under the MIT License. See License.txt in the project root for 

5# license information. 

6# -------------------------------------------------------------------------- 

7# pylint: disable=protected-access 

8import base64 

9from functools import partial 

10from json import JSONEncoder 

11from typing import Dict, List, Optional, Union, cast, Any, Type, Callable, Tuple 

12from datetime import datetime, date, time, timedelta 

13from datetime import timezone 

14 

15 

# Public API of this module.
__all__ = [
    "NULL",
    "AzureJSONEncoder",
    "is_generated_model",
    "as_attribute_dict",
    "attribute_list",
    "TypeHandlerRegistry",
    "get_backcompat_attr_name",
]
# UTC tzinfo used when normalizing datetimes during serialization.
TZ_UTC = timezone.utc

26 

27 

28class _Null: 

29 """To create a Falsy object""" 

30 

31 def __bool__(self) -> bool: 

32 return False 

33 

34 

35NULL = _Null() 

36""" 

37A falsy sentinel object which is supposed to be used to specify attributes 

38with no data. This gets serialized to `null` on the wire. 

39""" 

40 

41 

class TypeHandlerRegistry:
    """A registry for custom serializers and deserializers for specific types or conditions.

    A handler is registered under either an exact type or a predicate callable. Exact-type
    handlers take precedence over predicates, and predicates are tried in registration order.
    Lookup results (including "no handler") are memoized per type; registering a new handler
    of the same kind discards that memo.
    """

    def __init__(self) -> None:
        # Handlers keyed by the exact type they apply to.
        self._serializer_types: Dict[Type, Callable] = {}
        self._deserializer_types: Dict[Type, Callable] = {}
        # (predicate, handler) pairs, evaluated in registration order.
        self._serializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = []
        self._deserializer_predicates: List[Tuple[Callable[[Any], bool], Callable]] = []

        # Memoized lookups; a stored None records that no handler matched.
        self._serializer_cache: Dict[Type, Optional[Callable]] = {}
        self._deserializer_cache: Dict[Type, Optional[Callable]] = {}

    def register_serializer(
        self, condition: Union[Type, Callable[[Any], bool]]
    ) -> Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]:
        """Decorator to register a serializer.

        The handler receives the object being serialized and returns a dictionary
        representation of it. The decorator returns the handler unchanged.

        Examples:

        .. code-block:: python

            @registry.register_serializer(CustomModel)
            def serialize_single_type(value: CustomModel) -> dict:
                return value.to_dict()

            @registry.register_serializer(lambda x: isinstance(x, BaseModel))
            def serialize_with_condition(value: BaseModel) -> dict:
                return value.to_dict()

            # Called manually for a specific type
            def custom_serializer(value: CustomModel) -> Dict[str, Any]:
                return {"custom": value.custom}

            registry.register_serializer(CustomModel)(custom_serializer)

        :param condition: An exact type, or a predicate that receives the object and returns a bool.
        :type condition: Union[Type, Callable[[Any], bool]]
        :return: A decorator that registers the handler function.
        :rtype: Callable[[Callable[[Any], Dict[str, Any]]], Callable[[Any], Dict[str, Any]]]
        :raises TypeError: If the condition is neither a type nor a callable.
        """

        def _register(handler_func: Callable[[Any], Dict[str, Any]]) -> Callable[[Any], Dict[str, Any]]:
            # Reject anything that is neither a class nor callable before touching state.
            if not isinstance(condition, type) and not callable(condition):
                raise TypeError("Condition must be a type or a callable predicate function.")
            # Classes are callable too, so the type check has to decide first.
            if isinstance(condition, type):
                self._serializer_types[condition] = handler_func
            else:
                self._serializer_predicates.append((condition, handler_func))
            # Memoized lookups may resolve differently now.
            self._serializer_cache.clear()
            return handler_func

        return _register

    def register_deserializer(
        self, condition: Union[Type, Callable[[Any], bool]]
    ) -> Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]:
        """Decorator to register a deserializer.

        The handler receives the target class and the data dictionary, and returns an
        instance of the target class. The decorator returns the handler unchanged.

        Examples:

        .. code-block:: python

            @registry.register_deserializer(CustomModel)
            def deserialize_single_type(cls: Type[CustomModel], data: dict) -> CustomModel:
                return cls(**data)

            @registry.register_deserializer(lambda t: issubclass(t, BaseModel))
            def deserialize_with_condition(cls: Type[BaseModel], data: dict) -> BaseModel:
                return cls(**data)

            # Called manually for a specific type
            def custom_deserializer(cls: Type[CustomModel], data: Dict[str, Any]) -> CustomModel:
                return cls(custom=data["custom"])

            registry.register_deserializer(CustomModel)(custom_deserializer)

        :param condition: An exact type, or a predicate that receives the target class and returns a bool.
        :type condition: Union[Type, Callable[[Any], bool]]
        :return: A decorator that registers the handler function.
        :rtype: Callable[[Callable[[Type, Dict[str, Any]], Any]], Callable[[Type, Dict[str, Any]], Any]]
        :raises TypeError: If the condition is neither a type nor a callable.
        """

        def _register(handler_func: Callable[[Type, Dict[str, Any]], Any]) -> Callable[[Type, Dict[str, Any]], Any]:
            # Reject anything that is neither a class nor callable before touching state.
            if not isinstance(condition, type) and not callable(condition):
                raise TypeError("Condition must be a type or a callable predicate function.")
            # Classes are callable too, so the type check has to decide first.
            if isinstance(condition, type):
                self._deserializer_types[condition] = handler_func
            else:
                self._deserializer_predicates.append((condition, handler_func))
            # Memoized lookups may resolve differently now.
            self._deserializer_cache.clear()
            return handler_func

        return _register

    def get_serializer(self, obj: Any) -> Optional[Callable[[Any], Dict[str, Any]]]:
        """Gets the appropriate serializer for an object.

        An exact-type handler for ``type(obj)`` is preferred; otherwise the first predicate
        accepting ``obj`` supplies the handler. The outcome is memoized by ``type(obj)``, so
        predicates are only evaluated for the first object of each type.

        :param obj: The object to serialize.
        :type obj: any
        :return: The serializer function if found, otherwise None.
        :rtype: Optional[Callable[[Any], Dict[str, Any]]]
        """
        key = type(obj)
        try:
            return self._serializer_cache[key]
        except KeyError:
            pass

        found = self._serializer_types.get(key)
        if not found:
            # Fall back to the first matching predicate; keep `found` if none match.
            found = next(
                (handler for predicate, handler in self._serializer_predicates if predicate(obj)),
                found,
            )

        self._serializer_cache[key] = found
        return found

    def get_deserializer(self, cls: Type) -> Optional[Callable[[Dict[str, Any]], Any]]:
        """Gets the appropriate deserializer for a class.

        An exact-type handler for ``cls`` is preferred; otherwise the first predicate
        accepting ``cls`` supplies the handler. The handler is bound to ``cls`` with
        :func:`functools.partial` and the bound result is memoized per class.

        :param cls: The class to deserialize.
        :type cls: type
        :return: A deserializer function bound to the specified class that takes a dictionary and returns
            an instance of that class, or None if no deserializer is found.
        :rtype: Optional[Callable[[Dict[str, Any]], Any]]
        """
        try:
            return self._deserializer_cache[cls]
        except KeyError:
            pass

        found = self._deserializer_types.get(cls)
        if not found:
            # Fall back to the first matching predicate; keep `found` if none match.
            found = next(
                (handler for predicate, handler in self._deserializer_predicates if predicate(cls)),
                found,
            )

        bound = partial(found, cls) if found else None
        self._deserializer_cache[cls] = bound
        return bound

199 

200 

201def _timedelta_as_isostr(td: timedelta) -> str: 

202 """Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S' 

203 

204 Function adapted from the Tin Can Python project: https://github.com/RusticiSoftware/TinCanPython 

205 

206 :param td: The timedelta object to convert 

207 :type td: datetime.timedelta 

208 :return: An ISO 8601 formatted string representing the timedelta object 

209 :rtype: str 

210 """ 

211 

212 # Split seconds to larger units 

213 seconds = td.total_seconds() 

214 minutes, seconds = divmod(seconds, 60) 

215 hours, minutes = divmod(minutes, 60) 

216 days, hours = divmod(hours, 24) 

217 

218 days, hours, minutes = list(map(int, (days, hours, minutes))) 

219 seconds = round(seconds, 6) 

220 

221 # Build date 

222 date_str = "" 

223 if days: 

224 date_str = "%sD" % days 

225 

226 # Build time 

227 time_str = "T" 

228 

229 # Hours 

230 bigger_exists = date_str or hours 

231 if bigger_exists: 

232 time_str += "{:02}H".format(hours) 

233 

234 # Minutes 

235 bigger_exists = bigger_exists or minutes 

236 if bigger_exists: 

237 time_str += "{:02}M".format(minutes) 

238 

239 # Seconds 

240 try: 

241 if seconds.is_integer(): 

242 seconds_string = "{:02}".format(int(seconds)) 

243 else: 

244 # 9 chars long w/ leading 0, 6 digits after decimal 

245 seconds_string = "%09.6f" % seconds 

246 # Remove trailing zeros 

247 seconds_string = seconds_string.rstrip("0") 

248 except AttributeError: # int.is_integer() raises 

249 seconds_string = "{:02}".format(seconds) 

250 

251 time_str += "{}S".format(seconds_string) 

252 

253 return "P" + date_str + time_str 

254 

255 

256def _datetime_as_isostr(dt: Union[datetime, date, time, timedelta]) -> str: 

257 """Converts a datetime.(datetime|date|time|timedelta) object into an ISO 8601 formatted string. 

258 

259 :param dt: The datetime object to convert 

260 :type dt: datetime.datetime or datetime.date or datetime.time or datetime.timedelta 

261 :return: An ISO 8601 formatted string representing the datetime object 

262 :rtype: str 

263 """ 

264 # First try datetime.datetime 

265 if hasattr(dt, "year") and hasattr(dt, "hour"): 

266 dt = cast(datetime, dt) 

267 # astimezone() fails for naive times in Python 2.7, so make make sure dt is aware (tzinfo is set) 

268 if not dt.tzinfo: 

269 iso_formatted = dt.replace(tzinfo=TZ_UTC).isoformat() 

270 else: 

271 iso_formatted = dt.astimezone(TZ_UTC).isoformat() 

272 # Replace the trailing "+00:00" UTC offset with "Z" (RFC 3339: https://www.ietf.org/rfc/rfc3339.txt) 

273 return iso_formatted.replace("+00:00", "Z") 

274 # Next try datetime.date or datetime.time 

275 try: 

276 dt = cast(Union[date, time], dt) 

277 return dt.isoformat() 

278 # Last, try datetime.timedelta 

279 except AttributeError: 

280 dt = cast(timedelta, dt) 

281 return _timedelta_as_isostr(dt) 

282 

283 

class AzureJSONEncoder(JSONEncoder):
    """JSON encoder that also handles bytes-like and date/time values.

    ``bytes`` and ``bytearray`` become base64 strings. datetime, date, time and timedelta
    values become ISO 8601 strings. Everything else is handed to ``JSONEncoder.default``.
    """

    def default(self, o: Any) -> Any:
        """Override the default method to handle datetime and bytes serialization.

        :param o: The object to serialize.
        :type o: any
        :return: A JSON-serializable representation of the object.
        :rtype: any
        """
        if isinstance(o, (bytes, bytearray)):
            return base64.b64encode(o).decode()

        try:
            iso_string = _datetime_as_isostr(o)
        except AttributeError:
            # Not date/time-like; defer to the base class below.
            iso_string = None

        if iso_string is not None:
            return iso_string
        return super().default(o)

301 

302 

def is_generated_model(obj: Any) -> bool:
    """Check if the object is a generated SDK model.

    An object qualifies if it has a truthy ``_is_model`` attribute or has an
    ``_attribute_map`` attribute (the msrest model layout).

    :param obj: The object to check.
    :type obj: any
    :return: True if the object is a generated SDK model, False otherwise.
    :rtype: bool
    """
    if getattr(obj, "_is_model", False):
        return True
    return hasattr(obj, "_attribute_map")

312 

313 

314def _is_readonly(p: Any) -> bool: 

315 """Check if an attribute is readonly. 

316 

317 :param any p: The property to check. 

318 :return: True if the property is readonly, False otherwise. 

319 :rtype: bool 

320 """ 

321 try: 

322 return p._visibility == ["read"] 

323 except AttributeError: 

324 return False 

325 

326 

327def _as_attribute_dict_value(v: Any, *, exclude_readonly: bool = False) -> Any: 

328 if v is None or isinstance(v, _Null): 

329 return None 

330 if isinstance(v, (list, tuple, set)): 

331 return type(v)(_as_attribute_dict_value(x, exclude_readonly=exclude_readonly) for x in v) 

332 if isinstance(v, dict): 

333 return {dk: _as_attribute_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} 

334 return as_attribute_dict(v, exclude_readonly=exclude_readonly) if is_generated_model(v) else v 

335 

336 

337def _get_backcompat_name(rest_field: Any, default_attr_name: str) -> str: 

338 """Get the backcompat name for an attribute. 

339 

340 :param any rest_field: The rest field to get the backcompat name from. 

341 :param str default_attr_name: The default attribute name to use if no backcompat name 

342 :return: The backcompat name. 

343 :rtype: str 

344 """ 

345 try: 

346 return rest_field._original_tsp_name or default_attr_name 

347 except AttributeError: 

348 return default_attr_name 

349 

350 

351def _get_flattened_attribute(obj: Any) -> Optional[str]: 

352 """Get the name of the flattened attribute in a generated TypeSpec model if one exists. 

353 

354 :param any obj: The object to check. 

355 :return: The name of the flattened attribute if it exists, otherwise None. 

356 :rtype: Optional[str] 

357 """ 

358 flattened_items = None 

359 try: 

360 flattened_items = getattr(obj, next(a for a in dir(obj) if "__flattened_items" in a), None) 

361 except StopIteration: 

362 return None 

363 

364 if flattened_items is None: 

365 return None 

366 

367 for k, v in obj._attr_to_rest_field.items(): 

368 try: 

369 if set(v._class_type._attr_to_rest_field.keys()).intersection(set(flattened_items)): 

370 return k 

371 except AttributeError: 

372 # if the attribute does not have _class_type, it is not a typespec generated model 

373 continue 

374 return None 

375 

376 

def attribute_list(obj: Any) -> List[str]:
    """Get a list of attribute names for a generated SDK model.

    msrest models report the keys of their ``_attribute_map``. TypeSpec models report their
    attribute names in declaration order, with the flattened attribute (if any) replaced by
    the attribute names of the model it wraps.

    :param obj: The object to get attributes from.
    :type obj: any
    :return: A list of attribute names.
    :rtype: List[str]
    :raises TypeError: If the object is not a generated SDK model.
    """
    if not is_generated_model(obj):
        raise TypeError("Object is not a generated SDK model.")
    if hasattr(obj, "_attribute_map"):
        # msrest model
        return list(obj._attribute_map.keys())

    flattened_attribute = _get_flattened_attribute(obj)
    names: List[str] = []
    for attr_name, rest_field in obj._attr_to_rest_field.items():
        if attr_name == flattened_attribute:
            # Inline the wrapped model's attributes in place of the flattened one.
            names.extend(attribute_list(rest_field._class_type))
            continue
        names.append(attr_name)
    return names

398 

399 

def as_attribute_dict(obj: Any, *, exclude_readonly: bool = False) -> Dict[str, Any]:
    """Convert an object to a dictionary of its attributes.

    Made solely for backward compatibility with the legacy `.as_dict()` on msrest models.

    .. deprecated:: 1.35.0
        This function is added for backcompat purposes only.

    msrest models are delegated to their own ``as_dict``. For TypeSpec generated models the
    keys are Python attribute names, the properties of the flattened property (if any) are
    inlined at the top level, and nested values are converted recursively.

    :param any obj: The object to convert to a dictionary
    :keyword bool exclude_readonly: Whether to exclude readonly properties, including readonly
        properties of a flattened property
    :return: A dictionary containing the object's attributes
    :rtype: dict[str, any]
    :raises TypeError: If the object is not a generated model instance
    """
    if not is_generated_model(obj):
        raise TypeError("Object must be a generated model instance.")
    if hasattr(obj, "_attribute_map"):
        # msrest generated model
        return obj.as_dict(keep_readonly=not exclude_readonly)
    try:
        # now we're a typespec generated model
        result = {}
        # Rest names of readonly properties to skip. Top-level and flattened names are kept
        # in separate sets so that one cannot shadow the other.
        readonly_props = set()
        flattened_readonly_props = set()

        # create a reverse mapping from rest field name to attribute name
        rest_to_attr = {}
        # first rest field declared for each rest name, so lookups below are O(1)
        rest_name_to_field: Dict[str, Any] = {}
        flattened_attribute = _get_flattened_attribute(obj)
        # obj.items() is keyed by rest name, so the flattened property has to be matched
        # by its rest name, which may differ from its attribute name
        flattened_rest_name: Optional[str] = None
        for attr_name, rest_field in obj._attr_to_rest_field.items():
            rest_name_to_field.setdefault(rest_field._rest_name, rest_field)
            if exclude_readonly and _is_readonly(rest_field):
                # if we're excluding readonly properties, we need to track them
                readonly_props.add(rest_field._rest_name)
            if flattened_attribute == attr_name:
                flattened_rest_name = rest_field._rest_name
                for fk, fv in rest_field._class_type._attr_to_rest_field.items():
                    rest_to_attr[fv._rest_name] = fk
                    if exclude_readonly and _is_readonly(fv):
                        flattened_readonly_props.add(fv._rest_name)
            else:
                rest_to_attr[rest_field._rest_name] = attr_name
        for k, v in obj.items():
            if exclude_readonly and k in readonly_props:  # pyright: ignore
                continue
            if flattened_rest_name is not None and k == flattened_rest_name:
                # inline the flattened model's properties at the top level
                for fk, fv in v.items():
                    if fk in flattened_readonly_props:
                        continue
                    result[rest_to_attr.get(fk, fk)] = _as_attribute_dict_value(fv, exclude_readonly=exclude_readonly)
            else:
                matching_field = rest_name_to_field.get(k)
                # multipart file inputs are passed through without conversion
                is_multipart_file_input = matching_field is not None and matching_field._is_multipart_file_input
                result[rest_to_attr.get(k, k)] = (
                    v if is_multipart_file_input else _as_attribute_dict_value(v, exclude_readonly=exclude_readonly)
                )
        return result
    except AttributeError as exc:
        # not a typespec generated model
        raise TypeError("Object must be a generated model instance.") from exc

459 

460 

def get_backcompat_attr_name(model: Any, attr_name: str) -> str:
    """Get the backcompat attribute name for a given attribute.

    This function takes an attribute name and returns the backcompat name (original TSP name)
    if one exists, otherwise returns the attribute name itself. Properties of a flattened
    property are looked up inside the flattened model. msrest models carry no original TSP
    names, so for them the attribute name is returned unchanged.

    :param any model: The model instance.
    :param str attr_name: The attribute name to get the backcompat name for.
    :return: The backcompat attribute name (original TSP name) or the attribute name itself.
    :rtype: str
    :raises TypeError: If the object is not a generated model instance.
    """
    if not is_generated_model(model):
        raise TypeError("Object must be a generated model instance.")

    if hasattr(model, "_attribute_map"):
        # msrest generated models have no `_attr_to_rest_field` (accessing it used to raise
        # AttributeError) and no original TSP names, so the name is already the backcompat one.
        return attr_name

    # Check if attr_name exists in the model's attributes
    flattened_attribute = _get_flattened_attribute(model)
    for field_attr_name, rest_field in model._attr_to_rest_field.items():
        # Check if this is the attribute we're looking for
        if field_attr_name == attr_name:
            # Return the original TSP name if it exists, otherwise the attribute name
            return _get_backcompat_name(rest_field, attr_name)

        # If this is a flattened attribute, check inside it
        if flattened_attribute == field_attr_name:
            for fk, fv in rest_field._class_type._attr_to_rest_field.items():
                if fk == attr_name:
                    # Return the original TSP name for this flattened property
                    return _get_backcompat_name(fv, fk)

    # If not found in the model, just return the attribute name as-is
    return attr_name

490 # If not found in the model, just return the attribute name as-is 

491 return attr_name