1# coding=utf-8
2# --------------------------------------------------------------------------
3# Copyright (c) Microsoft Corporation. All rights reserved.
4# Licensed under the MIT License. See License.txt in the project root for
5# license information.
6# --------------------------------------------------------------------------
7import datetime
8import sys
9from typing import (
10 Any,
11 AsyncContextManager,
12 Iterable,
13 Iterator,
14 Mapping,
15 MutableMapping,
16 Optional,
17 Tuple,
18 Union,
19 Dict,
20)
21from datetime import timezone
22
# Module-level alias for the fixed UTC timezone object (``datetime.timezone.utc``).
TZ_UTC = timezone.utc
24
25
def _convert_to_isoformat(date_time):
    """Deserialize a date in RFC 3339 format to a timezone-aware datetime object.

    Check https://tools.ietf.org/html/rfc3339#section-5.8 for examples.

    The input is expected to end either with ``Z`` (UTC) or with a numeric
    offset of the form ``+HH:MM`` / ``-HH:MM``. Fractional seconds longer than
    six digits are truncated to microsecond precision, which is the most that
    ``%f`` accepts.

    :param str date_time: The date in RFC 3339 format.
    :return: The parsed timezone-aware datetime, or None if ``date_time`` is empty or None.
    :rtype: ~datetime.datetime or None
    :raises ValueError: If the timestamp or its UTC offset cannot be parsed.
    """
    if not date_time:
        return None

    if date_time[-1] == "Z":
        delta = 0
        timestamp = date_time[:-1]
    else:
        # The trailing six characters are the offset: a sign followed by "HH:MM".
        timestamp = date_time[:-6]
        sign, offset = date_time[-6], date_time[-5:]
        # Both hour digits must be read; taking only the first one would turn
        # "+05:30" into +30 minutes and "-08:00" into UTC.
        delta = int(sign + offset[:2]) * 60 + int(sign + offset[-2:])

    whole, dot, fraction = timestamp.partition(".")
    if dot:
        # Count the leading digits of the fractional part.
        digit_count = next(
            (index for index, char in enumerate(fraction) if not char.isdigit()),
            len(fraction),
        )
        if digit_count > 6:
            # Keep the first six digits only; anything after the digit run is
            # left in place so malformed input still fails in strptime below.
            timestamp = whole + "." + fraction[:6] + fraction[digit_count:]

    tzinfo = timezone.utc if delta == 0 else timezone(datetime.timedelta(minutes=delta))

    try:
        deserialized = datetime.datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f")
    except ValueError:
        # No fractional seconds present; retry with the plain seconds format.
        deserialized = datetime.datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S")

    return deserialized.replace(tzinfo=tzinfo)
65
66
def case_insensitive_dict(
    *args: Optional[Union[Mapping[str, Any], Iterable[Tuple[str, Any]]]], **kwargs: Any
) -> MutableMapping[str, Any]:
    """Build a case-insensitive mutable mapping from the supplied mapping data.

    All positional and keyword arguments are forwarded unchanged to
    ``CaseInsensitiveDict``.

    :param args: The positional arguments to pass to the dict.
    :type args: Mapping[str, Any] or Iterable[Tuple[str, Any]]
    :return: A case-insensitive mutable mapping object.
    :rtype: ~collections.abc.MutableMapping
    """
    mapping: MutableMapping[str, Any] = CaseInsensitiveDict(*args, **kwargs)
    return mapping
78
79
class CaseInsensitiveDict(MutableMapping[str, Any]):
    """
    Case-insensitive dictionary implementation.

    NOTE: This implementation is heavily inspired by the case-insensitive dictionary from the requests library.
    Thank you !!

    Keys are expected to be strings. Entries are indexed by the lower-cased key,
    while the most recently used spelling of the key is kept next to the value
    and is what iteration yields.
        case_insensitive_dict = CaseInsensitiveDict()
        case_insensitive_dict['Key'] = 'some_value'
        case_insensitive_dict['key'] == 'some_value' #True

    :param data: Initial data to store in the dictionary.
    :type data: Mapping[str, Any] or Iterable[Tuple[str, Any]]
    """

    def __init__(
        self, data: Optional[Union[Mapping[str, Any], Iterable[Tuple[str, Any]]]] = None, **kwargs: Any
    ) -> None:
        # Maps lower-cased key -> (key as last written, value).
        self._store: Dict[str, Any] = {}
        initial = {} if data is None else data
        self.update(initial, **kwargs)

    def copy(self) -> "CaseInsensitiveDict":
        # The stored (original_key, value) pairs are a valid iterable of items.
        stored_pairs = self._store.values()
        return CaseInsensitiveDict(stored_pairs)

    def __setitem__(self, key: str, value: Any) -> None:
        """Set the `key` to `value`.

        The entry is indexed by the lower-cased key; the key as given is kept
        together with the value.

        :param str key: The key to set.
        :param value: The value to set the key to.
        :type value: any
        """
        lowered = key.lower()
        self._store[lowered] = (key, value)

    def __getitem__(self, key: str) -> Any:
        _, stored_value = self._store[key.lower()]
        return stored_value

    def __delitem__(self, key: str) -> None:
        lowered = key.lower()
        del self._store[lowered]

    def __iter__(self) -> Iterator[str]:
        # Yield the keys in their original casing.
        return (entry[0] for entry in self._store.values())

    def __len__(self) -> int:
        return len(self._store)

    def lowerkey_items(self) -> Iterator[Tuple[str, Any]]:
        # Pairs of (lower-cased key, value), dropping the original-cased key.
        return ((lowered, entry[1]) for lowered, entry in self._store.items())

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Mapping):
            return False

        # Normalise the other mapping so both sides compare on lower-cased keys.
        normalized = CaseInsensitiveDict(other)
        return dict(self.lowerkey_items()) == dict(normalized.lowerkey_items())

    def __repr__(self) -> str:
        as_plain_dict = dict(self.items())
        return str(as_plain_dict)
142
143
class CaseInsensitiveSet(set):
    """A set of strings that keeps values in their original form but matches case-insensitively.

    A lowercase-to-original index is kept alongside the underlying set and is
    updated incrementally by every mutating method, so a membership test is a
    single dictionary lookup. The first spelling added for a given lowercase
    form is the one that is stored.

    All in-place mutators (``update``, ``difference_update``,
    ``intersection_update``, ``symmetric_difference_update`` and the ``|=``,
    ``-=``, ``&=``, ``^=`` operators) are overridden so the index never goes
    stale; the inherited versions would modify the set without touching it.

    :param data: Initial values for the set.
    :type data: Iterable[str]
    """

    def __init__(self, data: Optional[Iterable[str]] = None) -> None:
        super().__init__()
        # Maps lower-cased value -> value as stored in the underlying set.
        self._lower_to_original: Dict[str, str] = {}
        self.update(data or [])

    def __contains__(self, item):
        return item.lower() in self._lower_to_original

    def add(self, item):
        """Add `item` unless a value with the same lowercase form is already present.

        :param str item: The value to add.
        """
        lower = item.lower()
        if lower not in self._lower_to_original:
            super().add(item)
            self._lower_to_original[lower] = item

    def discard(self, item):
        """Remove the value matching `item` case-insensitively, if there is one.

        :param str item: The value to remove.
        """
        original = self._lower_to_original.pop(item.lower(), None)
        if original is not None:
            super().discard(original)

    def remove(self, item):
        """Remove the value matching `item` case-insensitively.

        :param str item: The value to remove.
        :raises KeyError: If no value matches.
        """
        original = self._lower_to_original.pop(item.lower(), None)
        if original is None:
            raise KeyError(item)
        super().remove(original)

    def update(self, *others):
        """Add every item of every iterable in `others`.

        :param others: Iterables of strings to add.
        """
        for other in others:
            for item in other:
                self.add(item)

    def difference_update(self, *others):
        """Remove every value that matches, case-insensitively, an item in `others`.

        :param others: Iterables of strings to remove.
        """
        for other in others:
            # Snapshot first: `other` may be this very set.
            for item in list(other):
                self.discard(item)

    def intersection_update(self, *others):
        """Keep only the values that match, case-insensitively, an item in each of `others`.

        :param others: Iterables of strings to intersect with.
        """
        for other in others:
            keep = {item.lower() for item in other}
            stale = [lower for lower in self._lower_to_original if lower not in keep]
            for lower in stale:
                self.discard(lower)

    def symmetric_difference_update(self, other):
        """Keep the values found in exactly one of this set and `other`, compared case-insensitively.

        :param other: Iterable of strings.
        :type other: Iterable[str]
        """
        # Collapse `other` case-insensitively first so that two spellings of the
        # same value do not toggle membership twice.
        incoming: Dict[str, str] = {}
        for item in other:
            incoming.setdefault(item.lower(), item)
        for lower, item in incoming.items():
            if lower in self._lower_to_original:
                self.discard(item)
            else:
                self.add(item)

    def __ior__(self, other):
        if not isinstance(other, (set, frozenset)):
            return NotImplemented
        self.update(other)
        return self

    def __isub__(self, other):
        if not isinstance(other, (set, frozenset)):
            return NotImplemented
        self.difference_update(other)
        return self

    def __iand__(self, other):
        if not isinstance(other, (set, frozenset)):
            return NotImplemented
        self.intersection_update(other)
        return self

    def __ixor__(self, other):
        if not isinstance(other, (set, frozenset)):
            return NotImplemented
        self.symmetric_difference_update(other)
        return self

    def clear(self):
        """Remove all values and reset the lowercase index."""
        super().clear()
        self._lower_to_original = {}

    def pop(self):
        """Remove and return an arbitrary value.

        :return: The removed value in its original form.
        :rtype: str
        :raises KeyError: If the set is empty.
        """
        result = super().pop()
        self._lower_to_original.pop(result.lower(), None)
        return result
193
194
def get_running_async_lock() -> AsyncContextManager:
    """Create a lock from whichever async library is driving the current context.

    asyncio is tried first; if no asyncio event loop is running and ``trio``
    has already been imported, a trio lock is returned instead.

    :return: An instance of the running async library's Lock class.
    :rtype: AsyncContextManager
    :raises RuntimeError: if the current context is not running under an async library.
    """

    try:
        import asyncio  # pylint: disable=do-not-import-asyncio

        # Raises RuntimeError when there is no running asyncio event loop.
        asyncio.get_running_loop()
        asyncio_lock = asyncio.Lock()
        return asyncio_lock
    except RuntimeError as err:
        # Not under asyncio: only fall back to trio when it is already imported.
        if "trio" not in sys.modules:
            raise RuntimeError("An asyncio or trio event loop is required.") from err

        import trio  # pylint: disable=networking-import-outside-azure-core-transport

        return trio.Lock()