1# Copyright 2024 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import asyncio
16from contextlib import asynccontextmanager
17import functools
18import time
19from typing import Mapping, Optional
20
21from google.auth import _exponential_backoff, exceptions
22from google.auth.aio import transport
23from google.auth.aio.credentials import Credentials
24from google.auth.exceptions import TimeoutError
25
26try:
27 from google.auth.aio.transport.aiohttp import Request as AiohttpRequest
28
29 AIOHTTP_INSTALLED = True
30except ImportError: # pragma: NO COVER
31 AIOHTTP_INSTALLED = False
32
33
@asynccontextmanager
async def timeout_guard(timeout):
    """
    timeout_guard is an asynchronous context manager to apply a timeout to an asynchronous block of code.

    The context manager yields a ``with_timeout`` helper. Every awaitable
    passed to ``with_timeout`` is only given the time that remains of the
    overall ``timeout`` budget, measured from when the context was entered.
    The budget is checked once more when the context exits, so time spent
    between ``with_timeout`` calls also counts against it.

    Args:
        timeout (float): The time in seconds before the context manager times out.

    Raises:
        google.auth.exceptions.TimeoutError: If the code within the context exceeds the provided timeout.
            Because the exit check runs in a ``finally`` clause, once the
            budget is spent it also supersedes any exception raised inside
            the context (that exception remains reachable as ``__context__``).

    Usage:
        async with timeout_guard(10) as with_timeout:
            await with_timeout(async_function())
    """
    start = time.monotonic()
    total_timeout = timeout

    def _remaining_time():
        # Seconds left in the overall budget; raises once it is used up.
        elapsed = time.monotonic() - start
        remaining = total_timeout - elapsed
        if remaining <= 0:
            raise TimeoutError(
                f"Context manager exceeded the configured timeout of {total_timeout}s."
            )
        return remaining

    def _operation_timeout(coro):
        # Error reported when a single wrapped operation runs out of budget.
        return TimeoutError(
            f"The operation {coro} exceeded the configured timeout of {total_timeout}s."
        )

    async def with_timeout(coro):
        try:
            remaining = _remaining_time()
        except TimeoutError as e:
            # The budget is already spent, so ``coro`` will never be awaited.
            # Close it explicitly; otherwise it is collected unawaited and
            # Python emits a "coroutine ... was never awaited" RuntimeWarning.
            if asyncio.iscoroutine(coro):
                coro.close()
            raise _operation_timeout(coro) from e
        try:
            return await asyncio.wait_for(coro, remaining)
        except (asyncio.TimeoutError, TimeoutError) as e:
            raise _operation_timeout(coro) from e

    try:
        yield with_timeout
    finally:
        # Re-check on exit (normal or exceptional) so that time spent outside
        # with_timeout() calls is also charged against the budget.
        _remaining_time()
76
77
class AsyncAuthorizedSession:
    """An asynchronous implementation of :class:`google.auth.transport.requests.AuthorizedSession`.

    A session that attaches credentials to every request it sends. Requests
    are made through an instance of a class that implements
    :class:`google.auth.aio.transport.Request`. It is supplied by the caller
    or, when omitted and the external ``aiohttp`` package is installed,
    defaults to :class:`google.auth.aio.transport.aiohttp.Request`.

    This class is used to perform asynchronous requests to API endpoints that require
    authorization::

        from google.auth.aio.transport import sessions

        async with sessions.AsyncAuthorizedSession(credentials) as authed_session:
            response = await authed_session.request(
                'GET', 'https://www.googleapis.com/storage/v1/b')

    Leaving the ``async with`` block calls :meth:`close`. When the session is
    not used as a context manager, call :meth:`close` explicitly when done.

    The underlying :meth:`request` implementation handles adding the
    credentials' headers to the request and refreshing credentials as needed.

    Args:
        credentials (google.auth.aio.credentials.Credentials):
            The credentials to add to the request.
        auth_request (Optional[google.auth.aio.transport.Request]):
            An instance of a class that implements
            :class:`~google.auth.aio.transport.Request` used to make requests
            and refresh credentials. If not passed,
            an instance of :class:`~google.auth.aio.transport.aiohttp.Request`
            is created.

    Raises:
        - google.auth.exceptions.TransportError: If `auth_request` is `None`
          and the external package `aiohttp` is not installed.
        - google.auth.exceptions.InvalidType: If the provided credentials are
          not of type `google.auth.aio.credentials.Credentials`.
    """

    def __init__(
        self, credentials: Credentials, auth_request: Optional[transport.Request] = None
    ):
        if not isinstance(credentials, Credentials):
            raise exceptions.InvalidType(
                f"The configured credentials of type {type(credentials)} are invalid and must be of type `google.auth.aio.credentials.Credentials`"
            )
        self._credentials = credentials
        _auth_request = auth_request
        # Fall back to the aiohttp transport only when the caller passed
        # nothing; an explicitly supplied transport is always used as-is.
        if _auth_request is None and AIOHTTP_INSTALLED:
            _auth_request = AiohttpRequest()
        if _auth_request is None:
            raise exceptions.TransportError(
                "`auth_request` must either be configured or the external package `aiohttp` must be installed to use the default value."
            )
        self._auth_request = _auth_request

    async def __aenter__(self) -> "AsyncAuthorizedSession":
        """Enter an ``async with`` block and return this session."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        """Close the underlying transport when leaving an ``async with`` block.

        Exceptions raised inside the block are not suppressed.
        """
        await self.close()

    async def request(
        self,
        method: str,
        url: str,
        data: Optional[bytes] = None,
        headers: Optional[Mapping[str, str]] = None,
        max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
        timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
        **kwargs,
    ) -> transport.Response:
        """Send an authorized HTTP request, retrying retryable status codes.

        The credentials' ``before_request`` hook runs first and may refresh
        expired credentials. The request is then sent up to
        ``transport.DEFAULT_MAX_RETRY_ATTEMPTS`` times with exponential
        backoff while the response status is in
        ``transport.DEFAULT_RETRYABLE_STATUS_CODES``. The last response is
        returned even if its status is still retryable.

        Args:
            method (str): The http method used to make the request.
            url (str): The URI to be requested.
            data (Optional[bytes]): The payload or body in HTTP request.
            headers (Optional[Mapping[str, str]]): Request headers. They are
                copied before the credentials add their headers, so the
                caller's mapping is never modified. ``None`` means no extra
                headers.
            timeout (float):
                The amount of time in seconds to wait for the server response
                with each individual request.
            max_allowed_time (float):
                If the method runs longer than this, a
                :class:`google.auth.exceptions.TimeoutError` is raised. Unlike
                the ``timeout`` parameter, this value applies to the total
                method execution time, even if multiple requests are made
                under the hood.

                The credential refresh and every request attempt are only
                given the time remaining in this budget. Delays introduced by
                the backoff between attempts are not themselves interrupted,
                so the error may be raised somewhat after
                ``max_allowed_time``. Work done after this method returns,
                such as reading the response body, is not covered.

            **kwargs: Forwarded unchanged to the underlying transport request.

        Returns:
            google.auth.aio.transport.Response: The HTTP response.

        Raises:
            google.auth.exceptions.TimeoutError: If the method does not complete within
                the configured `max_allowed_time` or the request exceeds the configured
                `timeout`.
        """
        # before_request writes the credentials' headers into the mapping it
        # receives. Hand it a private copy so that (a) the default
        # ``headers=None`` still yields a mutable mapping, and (b) the
        # caller's object is never mutated, as in the synchronous
        # google.auth.transport.requests.AuthorizedSession.
        request_headers = dict(headers) if headers is not None else {}

        retries = _exponential_backoff.AsyncExponentialBackoff(
            total_attempts=transport.DEFAULT_MAX_RETRY_ATTEMPTS
        )
        async with timeout_guard(max_allowed_time) as with_timeout:
            await with_timeout(
                # Note: before_request will attempt to refresh credentials if expired.
                self._credentials.before_request(
                    self._auth_request, method, url, request_headers
                )
            )
            # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch`
            # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372
            async for _ in retries:  # pragma: no branch
                # NOTE(review): a retryable response that is discarded here is
                # not explicitly closed before the next attempt; confirm
                # whether the transport's Response needs closing to free
                # its connection.
                response = await with_timeout(
                    self._auth_request(
                        url, method, data, request_headers, timeout, **kwargs
                    )
                )
                if response.status_code not in transport.DEFAULT_RETRYABLE_STATUS_CODES:
                    break
        return response

    # The HTTP-verb helpers below forward to request() with a fixed method.
    # NOTE(review): functools.wraps copies request's __name__ and __doc__ onto
    # each helper and sets __wrapped__ = request, so inspect.signature()/help()
    # report request's parameters (including ``method``) for these helpers.

    @functools.wraps(request)
    async def get(
        self,
        url: str,
        data: Optional[bytes] = None,
        headers: Optional[Mapping[str, str]] = None,
        max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
        timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
        **kwargs,
    ) -> transport.Response:
        return await self.request(
            "GET", url, data, headers, max_allowed_time, timeout, **kwargs
        )

    @functools.wraps(request)
    async def post(
        self,
        url: str,
        data: Optional[bytes] = None,
        headers: Optional[Mapping[str, str]] = None,
        max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
        timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
        **kwargs,
    ) -> transport.Response:
        return await self.request(
            "POST", url, data, headers, max_allowed_time, timeout, **kwargs
        )

    @functools.wraps(request)
    async def put(
        self,
        url: str,
        data: Optional[bytes] = None,
        headers: Optional[Mapping[str, str]] = None,
        max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
        timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
        **kwargs,
    ) -> transport.Response:
        return await self.request(
            "PUT", url, data, headers, max_allowed_time, timeout, **kwargs
        )

    @functools.wraps(request)
    async def patch(
        self,
        url: str,
        data: Optional[bytes] = None,
        headers: Optional[Mapping[str, str]] = None,
        max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
        timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
        **kwargs,
    ) -> transport.Response:
        return await self.request(
            "PATCH", url, data, headers, max_allowed_time, timeout, **kwargs
        )

    @functools.wraps(request)
    async def delete(
        self,
        url: str,
        data: Optional[bytes] = None,
        headers: Optional[Mapping[str, str]] = None,
        max_allowed_time: float = transport._DEFAULT_TIMEOUT_SECONDS,
        timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
        **kwargs,
    ) -> transport.Response:
        return await self.request(
            "DELETE", url, data, headers, max_allowed_time, timeout, **kwargs
        )

    async def close(self) -> None:
        """
        Close the underlying auth request session.
        """
        await self._auth_request.close()