1# coding=utf-8
2# --------------------------------------------------------------------------
3# Copyright (c) Microsoft Corporation. All rights reserved.
4# Licensed under the MIT License. See License.txt in the project root for
5# license information.
6# --------------------------------------------------------------------------
7import datetime
8import sys
9from typing import (
10 Any,
11 AsyncContextManager,
12 Iterable,
13 Iterator,
14 Mapping,
15 MutableMapping,
16 Optional,
17 Tuple,
18 Union,
19 Dict,
20)
21from datetime import timezone
22
# Module-wide UTC tzinfo: the ``datetime.timezone.utc`` singleton.
TZ_UTC = datetime.timezone.utc
24
25
class _FixedOffset(datetime.tzinfo):
    """Fixed offset in minutes east from UTC.

    Copy/pasted from Python doc

    :param int offset: offset in minutes
    """

    def __init__(self, offset):
        self.__offset = datetime.timedelta(minutes=offset)

    def utcoffset(self, dt):
        """Return the fixed offset, regardless of ``dt``.

        :param dt: Ignored.
        :return: The offset from UTC.
        :rtype: ~datetime.timedelta
        """
        return self.__offset

    def tzname(self, dt):
        """Return the offset in hours as a string (e.g. ``"1.5"`` for 90 minutes).

        :param dt: Ignored.
        :rtype: str
        """
        return str(self.__offset.total_seconds() / 3600)

    def __repr__(self):
        return "<FixedOffset {}>".format(self.tzname(None))

    def dst(self, dt):
        """Return a zero DST adjustment; a fixed offset never observes DST.

        :param dt: Ignored.
        :rtype: ~datetime.timedelta
        """
        return datetime.timedelta(0)

    def __getinitargs__(self):
        """Return the constructor arguments used when pickling or deep-copying.

        ``tzinfo.__reduce__`` re-creates the instance by calling the class with
        these arguments. Without this method it would call ``_FixedOffset()``
        with no arguments, and unpickling would fail with a TypeError.

        :return: A one-element tuple holding the offset in minutes.
        :rtype: tuple
        """
        return (self.__offset.total_seconds() / 60,)
48
49
def _convert_to_isoformat(date_time):
    """Deserialize a date in RFC 3339 format to datetime object.

    Check https://tools.ietf.org/html/rfc3339#section-5.8 for examples.

    :param str date_time: The date in RFC 3339 format. It must end either with
        ``Z``/``z`` (UTC) or with a numeric ``+HH:MM`` / ``-HH:MM`` offset.
    :return: A timezone-aware datetime, or None when ``date_time`` is empty or None.
    :rtype: ~datetime.datetime or None
    :raises ValueError: If the offset or the date/time portion cannot be parsed.
    """
    if not date_time:
        return None
    if date_time[-1] in ("Z", "z"):
        # RFC 3339 allows a lower-case "z" as well as "Z" for UTC.
        delta = 0
        timestamp = date_time[:-1]
    else:
        timestamp = date_time[:-6]
        # Slicing (not indexing) so short inputs yield "" instead of IndexError.
        sign, offset = date_time[-6:-5], date_time[-5:]
        # Without this check a naive timestamp ending in digits would be
        # silently split and parsed with a garbage offset.
        if sign not in ("+", "-") or offset[2:3] != ":":
            raise ValueError("Invalid RFC 3339 time zone offset in {!r}".format(date_time))
        # offset is "HH:MM": the hour is the first two characters.
        delta = int(sign + offset[:2]) * 60 + int(sign + offset[-2:])

    # datetime only supports microseconds (6 fractional digits), and strptime's
    # %f rejects more, so truncate any extra fractional digits.
    whole, dot, fraction = timestamp.partition(".")
    if dot:
        n_digits = 0
        while n_digits < len(fraction) and fraction[n_digits].isdigit():
            n_digits += 1
        if n_digits > 6:
            timestamp = whole + "." + fraction[:6] + fraction[n_digits:]

    # timezone.utc is the same object as this module's TZ_UTC constant.
    if delta == 0:
        tzinfo = timezone.utc
    else:
        tzinfo = timezone(datetime.timedelta(minutes=delta))

    try:
        deserialized = datetime.datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f")
    except ValueError:
        # No fractional-seconds component.
        deserialized = datetime.datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S")

    return deserialized.replace(tzinfo=tzinfo)
89
90
def case_insensitive_dict(
    *args: Optional[Union[Mapping[str, Any], Iterable[Tuple[str, Any]]]], **kwargs: Any
) -> MutableMapping[str, Any]:
    """Build a case-insensitive mutable mapping from the given mapping structure.

    Every positional and keyword argument is forwarded unchanged to the
    ``CaseInsensitiveDict`` constructor. Keyword arguments become additional
    entries.

    :param args: The positional arguments to pass to the dict.
    :type args: Mapping[str, Any] or Iterable[Tuple[str, Any]]
    :return: A case-insensitive mutable mapping object.
    :rtype: ~collections.abc.MutableMapping
    """
    mapping = CaseInsensitiveDict(*args, **kwargs)
    return mapping
102
103
class CaseInsensitiveDict(MutableMapping[str, Any]):
    """Case insensitive dictionary implementation.

    The keys are expected to be strings. Lookups, assignments and deletions
    match keys case-insensitively (keys are compared via ``str.lower``). Iteration
    yields each key with the casing it was most recently set with.

    .. code-block:: python

        case_insensitive_dict = CaseInsensitiveDict()
        case_insensitive_dict['Key'] = 'some_value'
        case_insensitive_dict['key'] == 'some_value'  # True

    NOTE: This implementation is heavily inspired from the case insensitive
    dictionary from the requests library. Thank you !!

    :param data: Initial data to store in the dictionary.
    :type data: Mapping[str, Any] or Iterable[Tuple[str, Any]]
    """

    def __init__(
        self, data: Optional[Union[Mapping[str, Any], Iterable[Tuple[str, Any]]]] = None, **kwargs: Any
    ) -> None:
        # Maps lower-cased key -> (key as originally supplied, value).
        self._store: Dict[str, Tuple[str, Any]] = {}
        if data is None:
            data = {}

        self.update(data, **kwargs)

    def copy(self) -> "CaseInsensitiveDict":
        """Return a shallow copy that preserves the original key casing.

        :return: A new CaseInsensitiveDict with the same entries.
        :rtype: CaseInsensitiveDict
        """
        # The stored (original key, value) pairs are a valid iterable of pairs.
        return CaseInsensitiveDict(self._store.values())

    def __setitem__(self, key: str, value: Any) -> None:
        """Set the `key` to `value`.

        The original key will be stored with the value

        :param str key: The key to set.
        :param value: The value to set the key to.
        :type value: any
        """
        self._store[key.lower()] = (key, value)

    def __getitem__(self, key: str) -> Any:
        return self._store[key.lower()][1]

    def __delitem__(self, key: str) -> None:
        del self._store[key.lower()]

    def __iter__(self) -> Iterator[str]:
        # Yield keys with their original casing, not the lower-cased store keys.
        return (key for key, _ in self._store.values())

    def __len__(self) -> int:
        return len(self._store)

    def lowerkey_items(self) -> Iterator[Tuple[str, Any]]:
        """Iterate over ``(lower-cased key, value)`` pairs.

        :return: An iterator of (lower-cased key, value) tuples.
        :rtype: Iterator[Tuple[str, Any]]
        """
        return ((lower_case_key, pair[1]) for lower_case_key, pair in self._store.items())

    def __eq__(self, other: Any) -> bool:
        """Compare against any mapping, ignoring the case of keys.

        :param other: The object to compare against.
        :return: Whether both hold the same case-insensitive keys and equal values.
            For non-mappings, returns ``NotImplemented`` so that Python can try
            the reflected comparison before falling back to identity.
        """
        if not isinstance(other, Mapping):
            return NotImplemented
        if not isinstance(other, CaseInsensitiveDict):
            other = CaseInsensitiveDict(other)

        return dict(self.lowerkey_items()) == dict(other.lowerkey_items())

    def __repr__(self) -> str:
        return str(dict(self.items()))
166
167
def get_running_async_lock() -> AsyncContextManager:
    """Create a lock from whichever async library is driving the current context.

    If an asyncio event loop is running, an ``asyncio.Lock`` is returned.
    Otherwise, if ``trio`` has already been imported, trio is assumed to be
    the running library and a ``trio.Lock`` is returned.

    :return: An instance of the running async library's Lock class.
    :rtype: AsyncContextManager
    :raises RuntimeError: if the current context is not running under an async library.
    """
    import asyncio  # pylint: disable=do-not-import-asyncio

    try:
        # get_running_loop() raises RuntimeError when no asyncio loop is running.
        asyncio.get_running_loop()
        return asyncio.Lock()
    except RuntimeError as no_loop_error:
        if "trio" not in sys.modules:
            raise RuntimeError("An asyncio or trio event loop is required.") from no_loop_error
        # trio is only assumed (not verified) to be running, based on it being imported.
        import trio  # pylint: disable=networking-import-outside-azure-core-transport

        return trio.Lock()