1"""
2Digest authentication middleware for aiohttp client.
3
4This middleware implements HTTP Digest Authentication according to RFC 7616,
5providing a more secure alternative to Basic Authentication. It supports all
6standard hash algorithms including MD5, SHA, SHA-256, SHA-512 and their session
7variants, as well as both 'auth' and 'auth-int' quality of protection (qop) options.
8"""
9
10import hashlib
11import os
12import re
13import time
14from typing import (
15 Callable,
16 Dict,
17 Final,
18 FrozenSet,
19 List,
20 Literal,
21 Tuple,
22 TypedDict,
23 Union,
24)
25
26from yarl import URL
27
28from . import hdrs
29from .client_exceptions import ClientError
30from .client_middlewares import ClientHandlerType
31from .client_reqrep import ClientRequest, ClientResponse
32from .payload import Payload
33
34
# Parameters of a Digest challenge taken from a WWW-Authenticate header.
# Declared with total=False because a server may omit any of them; the
# middleware checks the required ones (realm, nonce) when it builds the
# Authorization header.
DigestAuthChallenge = TypedDict(
    "DigestAuthChallenge",
    {
        "realm": str,
        "nonce": str,
        "qop": str,
        "algorithm": str,
        "opaque": str,
        "domain": str,
        "stale": str,
    },
    total=False,
)
43
44
# Upper-cased algorithm token -> hashlib constructor. Every base algorithm has
# a "-SESS" session variant that uses the same hash function, and SHA-256 /
# SHA-512 are accepted both with and without the hyphen.
DigestFunctions: Dict[str, Callable[[bytes], "hashlib._Hash"]] = {
    **dict.fromkeys(("MD5", "MD5-SESS"), hashlib.md5),
    **dict.fromkeys(("SHA", "SHA-SESS"), hashlib.sha1),
    **dict.fromkeys(
        ("SHA256", "SHA256-SESS", "SHA-256", "SHA-256-SESS"), hashlib.sha256
    ),
    **dict.fromkeys(
        ("SHA512", "SHA512-SESS", "SHA-512", "SHA-512-SESS"), hashlib.sha512
    ),
}
59
60
# Compiled once at import time. Matches a single ``name=value`` pair where the
# value is either a double-quoted string (which may be empty and may contain
# backslash-escaped characters, including escaped quotes) or a bare value that
# runs up to the next whitespace or comma.
_HEADER_PAIRS_PATTERN = re.compile(
    r"""
    (\w+)                       # group 1: parameter name (word characters)
    \s*=\s*                     # '=' with optional whitespace on either side
    (?:
        "((?:[^"\\]|\\.)*)"     # group 2: quoted value, escapes allowed, may be empty
        |
        ([^\s,]+)               # group 3: unquoted value, one or more non-space non-comma chars
    )
    """,
    re.VERBOSE,
)
82
83
# RFC 7616: the challenge parameters copied out of a parsed WWW-Authenticate
# header, in the order they are looked up.
CHALLENGE_FIELDS: Final[
    Tuple[Literal["realm", "nonce", "qop", "algorithm", "opaque", "domain", "stale"], ...]
] = ("realm", "nonce", "qop", "algorithm", "opaque", "domain", "stale")
98
# Every algorithm token accepted in a challenge, sorted alphabetically so that
# documentation and "unsupported algorithm" error messages list them in a
# stable order.
SUPPORTED_ALGORITHMS: Final[Tuple[str, ...]] = tuple(sorted(DigestFunctions))
102
# RFC 7616: Authorization-header parameters whose values are emitted as
# double-quoted strings. The remaining parameters the middleware sends
# (algorithm, qop, nc) are written as bare tokens. The header builder checks
# membership in this set to decide whether to wrap a value in quotes.
QUOTED_AUTH_FIELDS: Final[FrozenSet[str]] = frozenset(
    ("cnonce", "nonce", "opaque", "realm", "response", "uri", "username")
)
111
112
def escape_quotes(value: str) -> str:
    """
    Escape a value for use inside a double-quoted HTTP header parameter.

    Backslashes are escaped first and double quotes second. Without the
    backslash step, an input ending in a backslash would escape the closing
    quote and leave the quoted string unterminated. Any other backslash would
    be consumed by the recipient's quoted-pair decoding (RFC 7230, Section
    3.2.6).

    Args:
        value: The raw string to embed between double quotes.

    Returns:
        The escaped string (without surrounding quotes).
    """
    return value.replace("\\", "\\\\").replace('"', '\\"')
116
117
def unescape_quotes(value: str) -> str:
    """
    Decode escaped double quotes and escaped backslashes in a quoted value.

    A backslash followed by a double quote becomes a double quote, and two
    backslashes become one. Decoding happens in a single left-to-right pass,
    so an escaped backslash directly followed by an escaped quote is handled
    correctly. A backslash before any other character is left as-is. This is
    deliberately lenient toward servers that send unescaped backslashes.

    Args:
        value: The contents of a quoted string, without the surrounding quotes.

    Returns:
        The decoded string.
    """
    return re.sub(r'\\([\\"])', r"\1", value)
121
122
def parse_header_pairs(header: str) -> Dict[str, str]:
    """
    Turn a list of auth parameters into a ``{name: value}`` mapping.

    Intended for the parameter part of a WWW-Authenticate (or similar)
    header. A value may be double-quoted, in which case it can contain
    commas and backslash escapes and is decoded with ``unescape_quotes``.
    Otherwise the value ends at the next whitespace or comma. Whitespace
    around ``=`` is tolerated. When a name appears more than once, its last
    value wins.

    Examples of accepted input:
        key1="value1", key2=value2
        key1 = "value1" , key2="value, with, commas"
        key1=value1,key2="value2"
        realm="example.com", nonce="12345", qop="auth"

    Args:
        header: The header value to parse.

    Returns:
        Parameter names (case preserved) mapped to their decoded values.
    """
    pairs: Dict[str, str] = {}
    for raw_name, quoted, bare in _HEADER_PAIRS_PATTERN.findall(header):
        name = raw_name.strip()
        if not name:
            continue
        if quoted:
            pairs[name] = unescape_quotes(quoted)
        else:
            pairs[name] = bare
    return pairs
148
149
class DigestAuthMiddleware:
    """
    HTTP digest authentication middleware for aiohttp client.

    This middleware intercepts 401 Unauthorized responses containing a Digest
    authentication challenge, calculates the appropriate digest credentials,
    and automatically retries the request with the proper Authorization header.

    Features:
    - Handles all aspects of Digest authentication handshake automatically
    - Supports all standard hash algorithms:
      - MD5, MD5-SESS
      - SHA, SHA-SESS
      - SHA256, SHA256-SESS, SHA-256, SHA-256-SESS
      - SHA512, SHA512-SESS, SHA-512, SHA-512-SESS
    - Supports 'auth' and 'auth-int' quality of protection modes
    - Properly handles quoted strings and parameter parsing
    - Matches challenge parameter names case-insensitively (RFC 7235)
    - Includes replay attack protection with client nonce count tracking
    - Supports preemptive authentication per RFC 7616 Section 3.6

    Standards compliance:
    - RFC 7616: HTTP Digest Access Authentication (primary reference)
    - RFC 2617: HTTP Authentication (deprecated by RFC 7616)
    - RFC 1945: Section 11.1 (username restrictions)

    Implementation notes:
    The core digest calculation is inspired by the implementation in
    https://github.com/requests/requests/blob/v2.18.4/requests/auth.py
    with added support for modern digest auth features and error handling.
    """

    def __init__(
        self,
        login: str,
        password: str,
        preemptive: bool = True,
    ) -> None:
        """
        Create the middleware.

        Args:
            login: Username to authenticate with. Must not contain ":".
            password: Password to authenticate with.
            preemptive: When True, once a challenge has been received, send
                credentials with later requests whose URL falls inside the
                protection space, without waiting for another 401.

        Raises:
            ValueError: If login or password is None, or login contains ":".
        """
        if login is None:
            raise ValueError("None is not allowed as login value")

        if password is None:
            raise ValueError("None is not allowed as password value")

        if ":" in login:
            raise ValueError('A ":" is not allowed in username (RFC 1945#section-11.1)')

        self._login_str: Final[str] = login
        self._login_bytes: Final[bytes] = login.encode("utf-8")
        self._password_bytes: Final[bytes] = password.encode("utf-8")

        # Server nonce used for the last header and how often it has been used
        self._last_nonce_bytes = b""
        self._nonce_count = 0
        self._challenge: DigestAuthChallenge = {}
        self._preemptive: bool = preemptive
        # Absolute URL prefixes defining the protection space
        self._protection_space: List[str] = []

    async def _encode(
        self, method: str, url: URL, body: Union[Payload, Literal[b""]]
    ) -> str:
        """
        Build digest authorization header for the current challenge.

        When the challenge is valid, this also advances the nonce count for
        the current server nonce (resetting it to 1 when the nonce changed)
        and generates a fresh client nonce.

        Args:
            method: The HTTP method (GET, POST, etc.)
            url: The request URL
            body: The request body (hashed only for qop=auth-int)

        Returns:
            A fully formatted Digest authorization header string

        Raises:
            ClientError: If the challenge is missing required parameters or
                contains unsupported values
        """
        challenge = self._challenge
        if "realm" not in challenge:
            raise ClientError(
                "Malformed Digest auth challenge: Missing 'realm' parameter"
            )

        if "nonce" not in challenge:
            raise ClientError(
                "Malformed Digest auth challenge: Missing 'nonce' parameter"
            )

        # Empty realm values are allowed per RFC 7616 (SHOULD, not MUST, contain host name)
        realm = challenge["realm"]
        nonce = challenge["nonce"]

        # Empty nonce values are not allowed as they are security-critical for replay protection
        if not nonce:
            raise ClientError(
                "Security issue: Digest auth challenge contains empty 'nonce' value"
            )

        qop_raw = challenge.get("qop", "")
        # An absent or empty algorithm parameter both mean the MD5 default.
        algorithm = (challenge.get("algorithm") or "MD5").upper()
        opaque = challenge.get("opaque", "")

        # Convert string values to bytes once
        nonce_bytes = nonce.encode("utf-8")
        realm_bytes = realm.encode("utf-8")
        path = URL(url).path_qs

        # Pick the quality of protection; auth-int is preferred when offered.
        qop = ""
        qop_bytes = b""
        if qop_raw:
            valid_qops = {"auth", "auth-int"}.intersection(
                {q.strip() for q in qop_raw.split(",") if q.strip()}
            )
            if not valid_qops:
                raise ClientError(
                    f"Digest auth error: Unsupported Quality of Protection (qop) value(s): {qop_raw}"
                )

            qop = "auth-int" if "auth-int" in valid_qops else "auth"
            qop_bytes = qop.encode("utf-8")

        if algorithm not in DigestFunctions:
            raise ClientError(
                f"Digest auth error: Unsupported hash algorithm: {algorithm}. "
                f"Supported algorithms: {', '.join(SUPPORTED_ALGORITHMS)}"
            )
        hash_fn: Final = DigestFunctions[algorithm]

        def H(x: bytes) -> bytes:
            """RFC 7616 Section 3: Hash function H(data) = hex(hash(data))."""
            return hash_fn(x).hexdigest().encode()

        def KD(s: bytes, d: bytes) -> bytes:
            """RFC 7616 Section 3: KD(secret, data) = H(concat(secret, ":", data))."""
            return H(b":".join((s, d)))

        # Calculate A1 and A2
        A1 = b":".join((self._login_bytes, realm_bytes, self._password_bytes))
        A2 = f"{method.upper()}:{path}".encode()
        if qop == "auth-int":
            if isinstance(body, Payload):  # will always be empty bytes unless Payload
                entity_bytes = await body.as_bytes()  # Get bytes from Payload
            else:
                entity_bytes = body
            entity_hash = H(entity_bytes)
            A2 = b":".join((A2, entity_hash))

        HA1 = H(A1)
        HA2 = H(A2)

        # Nonce count: repeated use of the same server nonce increments nc
        if nonce_bytes == self._last_nonce_bytes:
            self._nonce_count += 1
        else:
            self._nonce_count = 1

        self._last_nonce_bytes = nonce_bytes
        ncvalue = f"{self._nonce_count:08x}"
        ncvalue_bytes = ncvalue.encode("utf-8")

        # Generate client nonce
        cnonce = hashlib.sha1(
            b"".join(
                [
                    str(self._nonce_count).encode("utf-8"),
                    nonce_bytes,
                    time.ctime().encode("utf-8"),
                    os.urandom(8),
                ]
            )
        ).hexdigest()[:16]
        cnonce_bytes = cnonce.encode("utf-8")

        # Session variants fold the server and client nonces into HA1
        # (``algorithm`` is already upper-cased above).
        if algorithm.endswith("-SESS"):
            HA1 = H(b":".join((HA1, nonce_bytes, cnonce_bytes)))

        # Calculate the response digest
        if qop:
            noncebit = b":".join(
                (nonce_bytes, ncvalue_bytes, cnonce_bytes, qop_bytes, HA2)
            )
            response_digest = KD(HA1, noncebit)
        else:
            response_digest = KD(HA1, b":".join((nonce_bytes, HA2)))

        # Header fields: always-present first, then optional, then qop-dependent
        header_fields = {
            "username": escape_quotes(self._login_str),
            "realm": escape_quotes(realm),
            "nonce": escape_quotes(nonce),
            "uri": path,
            "response": response_digest.decode(),
            "algorithm": algorithm,
        }

        if opaque:
            header_fields["opaque"] = escape_quotes(opaque)

        if qop:
            header_fields["qop"] = qop
            header_fields["nc"] = ncvalue
            header_fields["cnonce"] = cnonce

        # Quote only the fields listed in QUOTED_AUTH_FIELDS
        pairs: List[str] = []
        for field, value in header_fields.items():
            if field in QUOTED_AUTH_FIELDS:
                pairs.append(f'{field}="{value}"')
            else:
                pairs.append(f"{field}={value}")

        return f"Digest {', '.join(pairs)}"

    def _in_protection_space(self, url: URL) -> bool:
        """
        Check if the given URL is within the current protection space.

        According to RFC 7616, a URI is in the protection space if any URI
        in the protection space is a prefix of it (after both have been made
        absolute). The prefix must also end on a path boundary, so that
        ``/api`` covers ``/api``, ``/api/x`` and ``/api?q=1`` but not
        ``/apiv2``.
        """
        request_str = str(url)
        for space_str in self._protection_space:
            if not request_str.startswith(space_str):
                continue
            # Exact match, or the space itself ends with "/" (directory prefix)
            if len(request_str) == len(space_str) or space_str.endswith("/"):
                return True
            # Otherwise the next character must start a new path segment or
            # the query string.
            if request_str[len(space_str)] in "/?":
                return True
        return False

    def _authenticate(self, response: ClientResponse) -> bool:
        """
        Takes the given response and tries digest-auth, if needed.

        On a 401 carrying a Digest challenge, stores the challenge parameters
        and updates the protection space from the ``domain`` parameter
        (defaulting to the response's origin).

        Returns true if the original request must be resent.
        """
        if response.status != 401:
            return False

        auth_header = response.headers.get("www-authenticate", "")
        if not auth_header:
            return False  # No authentication header present

        method, sep, headers = auth_header.partition(" ")
        if not sep:
            return False  # Malformed auth header, missing scheme separator

        if method.lower() != "digest":
            # Not a digest auth challenge (could be Basic, Bearer, etc.)
            return False

        if not headers:
            return False  # Malformed digest header, missing parameters

        # RFC 7235 Section 2.1: auth-param names are matched case-insensitively.
        header_pairs = {
            key.lower(): value for key, value in parse_header_pairs(headers).items()
        }
        if not header_pairs:
            return False  # Malformed digest header, no valid parameters

        # Keep every parameter that is present, even when its value is empty:
        # an empty realm is legal, and an empty nonce must reach _encode so it
        # is rejected as such instead of being reported as missing.
        self._challenge = {}
        for field in CHALLENGE_FIELDS:
            value = header_pairs.get(field)
            if value is not None:
                self._challenge[field] = value

        # Update protection space based on domain parameter or default to origin
        origin = response.url.origin()

        if domain := self._challenge.get("domain"):
            # Space-separated list of URIs
            space: List[str] = []
            for uri in domain.split():
                uri = uri.strip('"')
                if not uri:
                    continue  # Skip empty entries such as a bare ""
                if uri.startswith("/"):
                    # Path-absolute, relative to origin
                    space.append(str(origin.join(URL(uri))))
                else:
                    # Absolute URI
                    space.append(str(URL(uri)))
            self._protection_space = space or [str(origin)]
        else:
            # No domain specified, protection space is entire origin
            self._protection_space = [str(origin)]

        # Return True only if we found at least one challenge parameter
        return bool(self._challenge)

    async def __call__(
        self, request: ClientRequest, handler: ClientHandlerType
    ) -> ClientResponse:
        """
        Run the digest auth middleware.

        The request is sent at most twice. The first attempt carries
        credentials only when preemptive auth applies. If the server replies
        with a Digest challenge, the request is sent once more with an
        Authorization header computed from that challenge.
        """
        response = None
        for retry_count in range(2):
            # Apply authorization header if:
            # 1. This is a retry after 401 (retry_count > 0), OR
            # 2. Preemptive auth is enabled AND we have a challenge AND the URL is in protection space
            if retry_count > 0 or (
                self._preemptive
                and self._challenge
                and self._in_protection_space(request.url)
            ):
                request.headers[hdrs.AUTHORIZATION] = await self._encode(
                    request.method, request.url, request.body
                )

            response = await handler(request)

            if not self._authenticate(response):
                break

        # The loop body always runs at least once, so response is set here
        assert response is not None
        return response