1"""
2Digest authentication middleware for aiohttp client.
3
4This middleware implements HTTP Digest Authentication according to RFC 7616,
5providing a more secure alternative to Basic Authentication. It supports all
6standard hash algorithms including MD5, SHA, SHA-256, SHA-512 and their session
7variants, as well as both 'auth' and 'auth-int' quality of protection (qop) options.
8"""
9
10import hashlib
11import os
12import re
13import sys
14import time
15from typing import (
16 Callable,
17 Dict,
18 Final,
19 FrozenSet,
20 List,
21 Literal,
22 Tuple,
23 TypedDict,
24 Union,
25)
26
27from yarl import URL
28
29from . import hdrs
30from .client_exceptions import ClientError
31from .client_middlewares import ClientHandlerType
32from .client_reqrep import ClientRequest, ClientResponse
33from .payload import Payload
34
35
# Parameters of a Digest challenge taken from a WWW-Authenticate header
# (RFC 7616 Section 3.3). Declared with total=False, so every key is optional
# and may be absent when the server does not send it.
DigestAuthChallenge = TypedDict(
    "DigestAuthChallenge",
    {
        "realm": str,
        "nonce": str,
        "qop": str,
        "algorithm": str,
        "opaque": str,
        "domain": str,
        "stale": str,
    },
    total=False,
)
44
45
# Map each supported algorithm token to its hashlib constructor. Every base
# token also has a "-SESS" session variant that uses the same hash; the
# session behaviour itself is applied where the digest is computed.
DigestFunctions: Dict[str, Callable[[bytes], "hashlib._Hash"]] = {
    f"{token}{variant}": constructor
    for token, constructor in (
        ("MD5", hashlib.md5),
        ("SHA", hashlib.sha1),
        ("SHA256", hashlib.sha256),
        ("SHA-256", hashlib.sha256),
        ("SHA512", hashlib.sha512),
        ("SHA-512", hashlib.sha512),
    )
    for variant in ("", "-SESS")
}
60
61
# Compile the regex pattern once at module level for performance
#
# Matches one ``key=value`` auth-param. With findall() each match yields a
# 3-tuple (key, quoted_value, unquoted_value); the value group that did not
# participate in the match is returned as an empty string.
#
#   (?:^|\s|,\s*)         start of string, whitespace, or a comma separator
#   (\w+) / ((?>\w+))     parameter name (word characters). On Python 3.11+
#                         it is wrapped in an atomic group, intended to reduce
#                         backtracking; ``re`` only accepts (?>...) from 3.11,
#                         hence the version check.
#   \s*=\s*               "=" with optional surrounding whitespace
#   "((?:[^"\\]|\\.)*)"   quoted value: characters other than " or \, or a
#                         backslash followed by any character; may be empty
#   ([^\s,]+)             unquoted value: one or more characters that are
#                         neither whitespace nor a comma
_HEADER_PAIRS_PATTERN = re.compile(
    r'(?:^|\s|,\s*)(\w+)\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,]+))'
    if sys.version_info < (3, 11)
    else r'(?:^|\s|,\s*)((?>\w+))\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,]+))'
)
86
87
# RFC 7616 Section 3.3: the WWW-Authenticate Digest challenge parameters we
# read from a server challenge, in a fixed order.
_ChallengeField = Literal[
    "realm", "nonce", "qop", "algorithm", "opaque", "domain", "stale"
]
CHALLENGE_FIELDS: Final[Tuple[_ChallengeField, ...]] = (
    "realm", "nonce", "qop", "algorithm", "opaque", "domain", "stale",
)
102
# Every algorithm token accepted in a challenge, alphabetically ordered so that
# error messages and documentation list them deterministically.
SUPPORTED_ALGORITHMS: Final[Tuple[str, ...]] = tuple(sorted(DigestFunctions))
106
# RFC 7616: Authorization header fields whose values are sent as quoted
# strings. Anything not listed here (algorithm, qop, nc) is emitted bare.
# The header builder consults this set to decide whether to wrap a value
# in double quotes.
QUOTED_AUTH_FIELDS: Final[FrozenSet[str]] = frozenset(
    "username realm nonce uri response opaque cnonce".split()
)
115
116
# RFC 7230 Section 3.2.6 quoted-pair: a backslash followed by any character
# stands for that character. DOTALL so an escaped line break is handled too.
_QUOTED_PAIR_PATTERN: Final = re.compile(r"\\(.)", re.DOTALL)


def escape_quotes(value: str) -> str:
    """Escape a value for use inside an HTTP quoted-string.

    Both ``"`` and ``\\`` must be sent as a quoted-pair (RFC 7230 Section
    3.2.6); otherwise a compliant server drops or misinterprets backslashes,
    e.g. in a ``DOMAIN\\user`` username. Backslashes are escaped first so
    the backslashes introduced for quotes are not doubled.
    """
    return value.replace("\\", "\\\\").replace('"', '\\"')


def unescape_quotes(value: str) -> str:
    """Resolve quoted-pairs in the content of an HTTP quoted-string.

    Every ``\\X`` becomes ``X`` (RFC 7230 Section 3.2.6), so an escaped
    backslash yields a single backslash and an escaped quote yields ``"``.
    This is the inverse of ``escape_quotes``.
    """
    return _QUOTED_PAIR_PATTERN.sub(r"\1", value)
125
126
def parse_header_pairs(header: str) -> Dict[str, str]:
    """
    Extract ``key=value`` parameters from a WWW-Authenticate style header value.

    Quoted and unquoted values are both accepted. Quoted values may contain
    commas and backslash-escaped characters; their escape sequences are
    resolved with ``unescape_quotes``. Whitespace around "=" and between
    pairs is tolerated. If a key appears more than once, the last value wins.

    Examples of supported formats:
        - key1="value1", key2=value2
        - key1 = "value1" , key2="value, with, commas"
        - key1=value1,key2="value2"
        - realm="example.com", nonce="12345", qop="auth"

    Args:
        header: The header value string to parse

    Returns:
        Dictionary mapping parameter names to their values
    """
    params: Dict[str, str] = {}
    for raw_name, quoted, unquoted in _HEADER_PAIRS_PATTERN.findall(header):
        name = raw_name.strip()
        if not name:
            continue
        # An empty quoted value falls through to the (equally empty)
        # unquoted group, so both cases yield "".
        if quoted:
            params[name] = unescape_quotes(quoted)
        else:
            params[name] = unquoted
    return params
152
153
class DigestAuthMiddleware:
    """
    HTTP digest authentication middleware for aiohttp client.

    This middleware intercepts 401 Unauthorized responses containing a Digest
    authentication challenge, calculates the appropriate digest credentials,
    and automatically retries the request with the proper Authorization header.

    Features:
    - Handles all aspects of Digest authentication handshake automatically
    - Supports all standard hash algorithms:
      - MD5, MD5-SESS
      - SHA, SHA-SESS
      - SHA256, SHA256-SESS, SHA-256, SHA-256-SESS
      - SHA512, SHA512-SESS, SHA-512, SHA-512-SESS
    - Supports 'auth' and 'auth-int' quality of protection modes
    - Properly handles quoted strings and parameter parsing
    - Includes replay attack protection with client nonce count tracking
    - Supports preemptive authentication per RFC 7616 Section 3.6

    Standards compliance:
    - RFC 7616: HTTP Digest Access Authentication (primary reference)
    - RFC 2617: HTTP Authentication (deprecated by RFC 7616)
    - RFC 1945: Section 11.1 (username restrictions)

    Implementation notes:
    The core digest calculation is inspired by the implementation in
    https://github.com/requests/requests/blob/v2.18.4/requests/auth.py
    with added support for modern digest auth features and error handling.
    """

    def __init__(
        self,
        login: str,
        password: str,
        preemptive: bool = True,
    ) -> None:
        """
        Store the credentials and initialise per-challenge state.

        Args:
            login: Username sent in the ``username`` field.
            password: Password used when computing the digest.
            preemptive: When True, once a challenge has been received,
                requests to URLs inside its protection space are sent with an
                Authorization header up front instead of waiting for a 401.

        Raises:
            ValueError: If ``login`` or ``password`` is None, or ``login``
                contains ":".
        """
        if login is None:
            raise ValueError("None is not allowed as login value")

        if password is None:
            raise ValueError("None is not allowed as password value")

        if ":" in login:
            raise ValueError('A ":" is not allowed in username (RFC 1945#section-11.1)')

        self._login_str: Final[str] = login
        self._login_bytes: Final[bytes] = login.encode("utf-8")
        self._password_bytes: Final[bytes] = password.encode("utf-8")

        # Nonce used for the last header we built, and how many headers have
        # been built with it (the "nc" value).
        self._last_nonce_bytes = b""
        self._nonce_count = 0
        # Parameters of the most recent Digest challenge; empty until a 401.
        self._challenge: DigestAuthChallenge = {}
        self._preemptive: bool = preemptive
        # Absolute URL prefixes defining the protection space
        self._protection_space: List[str] = []

    async def _encode(
        self, method: str, url: URL, body: Union[Payload, Literal[b""]]
    ) -> str:
        """
        Build digest authorization header for the current challenge.

        Side effect: the nonce count is incremented when the challenge nonce
        equals the one used last time, and reset to 1 for a new nonce.

        Args:
            method: The HTTP method (GET, POST, etc.)
            url: The request URL
            body: The request body (used for qop=auth-int)

        Returns:
            A fully formatted Digest authorization header string

        Raises:
            ClientError: If the challenge is missing required parameters or
                contains unsupported values

        """
        challenge = self._challenge
        if "realm" not in challenge:
            raise ClientError(
                "Malformed Digest auth challenge: Missing 'realm' parameter"
            )

        if "nonce" not in challenge:
            raise ClientError(
                "Malformed Digest auth challenge: Missing 'nonce' parameter"
            )

        # Empty realm values are allowed per RFC 7616 (SHOULD, not MUST, contain host name)
        realm = challenge["realm"]
        nonce = challenge["nonce"]

        # Empty nonce values are not allowed as they are security-critical for replay protection
        if not nonce:
            raise ClientError(
                "Security issue: Digest auth challenge contains empty 'nonce' value"
            )

        qop_raw = challenge.get("qop", "")
        # Preserve original algorithm case for response while using uppercase for processing
        algorithm_original = challenge.get("algorithm", "MD5")
        algorithm = algorithm_original.upper()
        opaque = challenge.get("opaque", "")

        # Convert string values to bytes once
        nonce_bytes = nonce.encode("utf-8")
        realm_bytes = realm.encode("utf-8")
        # The digest "uri" is the path plus query string of the request URL.
        path = URL(url).path_qs

        # Process QoP: when the server offers both, auth-int is chosen.
        qop = ""
        qop_bytes = b""
        if qop_raw:
            valid_qops = {"auth", "auth-int"}.intersection(
                {q.strip() for q in qop_raw.split(",") if q.strip()}
            )
            if not valid_qops:
                raise ClientError(
                    f"Digest auth error: Unsupported Quality of Protection (qop) value(s): {qop_raw}"
                )

            qop = "auth-int" if "auth-int" in valid_qops else "auth"
            qop_bytes = qop.encode("utf-8")

        if algorithm not in DigestFunctions:
            raise ClientError(
                f"Digest auth error: Unsupported hash algorithm: {algorithm}. "
                f"Supported algorithms: {', '.join(SUPPORTED_ALGORITHMS)}"
            )
        hash_fn: Final = DigestFunctions[algorithm]

        def H(x: bytes) -> bytes:
            """RFC 7616 Section 3: Hash function H(data) = hex(hash(data))."""
            return hash_fn(x).hexdigest().encode()

        def KD(s: bytes, d: bytes) -> bytes:
            """RFC 7616 Section 3: KD(secret, data) = H(concat(secret, ":", data))."""
            return H(b":".join((s, d)))

        # Calculate A1 and A2
        A1 = b":".join((self._login_bytes, realm_bytes, self._password_bytes))
        A2 = f"{method.upper()}:{path}".encode()
        if qop == "auth-int":
            # auth-int also covers the request body: A2 gets H(entity-body).
            if isinstance(body, Payload):  # will always be empty bytes unless Payload
                entity_bytes = await body.as_bytes()  # Get bytes from Payload
            else:
                entity_bytes = body
            entity_hash = H(entity_bytes)
            A2 = b":".join((A2, entity_hash))

        HA1 = H(A1)
        HA2 = H(A2)

        # Nonce count handling: nc counts headers built with the same nonce.
        if nonce_bytes == self._last_nonce_bytes:
            self._nonce_count += 1
        else:
            self._nonce_count = 1

        self._last_nonce_bytes = nonce_bytes
        # nc is sent as exactly 8 lowercase hex digits.
        ncvalue = f"{self._nonce_count:08x}"
        ncvalue_bytes = ncvalue.encode("utf-8")

        # Generate client nonce (16 hex chars mixing counter, nonce, time and
        # 8 random bytes).
        cnonce = hashlib.sha1(
            b"".join(
                [
                    str(self._nonce_count).encode("utf-8"),
                    nonce_bytes,
                    time.ctime().encode("utf-8"),
                    os.urandom(8),
                ]
            )
        ).hexdigest()[:16]
        cnonce_bytes = cnonce.encode("utf-8")

        # Session variants bind HA1 to this nonce/cnonce pair.
        # ``algorithm`` is already upper-cased above.
        if algorithm.endswith("-SESS"):
            HA1 = H(b":".join((HA1, nonce_bytes, cnonce_bytes)))

        # Calculate the response digest
        if qop:
            noncebit = b":".join(
                (nonce_bytes, ncvalue_bytes, cnonce_bytes, qop_bytes, HA2)
            )
            response_digest = KD(HA1, noncebit)
        else:
            # Legacy RFC 2069 form when the server offered no qop.
            response_digest = KD(HA1, b":".join((nonce_bytes, HA2)))

        # Define a dict mapping of header fields to their values
        # Group fields into always-present, optional, and qop-dependent
        header_fields = {
            # Always present fields
            "username": escape_quotes(self._login_str),
            "realm": escape_quotes(realm),
            "nonce": escape_quotes(nonce),
            "uri": path,
            "response": response_digest.decode(),
            "algorithm": algorithm_original,
        }

        # Optional fields
        if opaque:
            header_fields["opaque"] = escape_quotes(opaque)

        # QoP-dependent fields
        if qop:
            header_fields["qop"] = qop
            header_fields["nc"] = ncvalue
            header_fields["cnonce"] = cnonce

        # Build header: fields listed in QUOTED_AUTH_FIELDS are double-quoted.
        pairs: List[str] = []
        for field, value in header_fields.items():
            if field in QUOTED_AUTH_FIELDS:
                pairs.append(f'{field}="{value}"')
            else:
                pairs.append(f"{field}={value}")

        return f"Digest {', '.join(pairs)}"

    def _in_protection_space(self, url: URL) -> bool:
        """
        Check if the given URL is within the current protection space.

        According to RFC 7616, a URI is in the protection space if any URI
        in the protection space is a prefix of it (after both have been made
        absolute). To avoid false positives such as ``/api`` matching
        ``/apix``, a prefix only counts at a boundary:
        - the two strings are equal,
        - the space URL ends with "/", or
        - the next character of the request URL is "/" or "?" (start of the
          query string).
        """
        request_str = str(url)
        for space_str in self._protection_space:
            # Check if request starts with space URL
            if not request_str.startswith(space_str):
                continue
            # Exact match or space ends with / (proper directory prefix).
            # endswith() is used so an empty entry cannot raise IndexError.
            if len(request_str) == len(space_str) or space_str.endswith("/"):
                return True
            # Next char must start a new path segment or the query string.
            if request_str[len(space_str)] in "/?":
                return True
        return False

    def _authenticate(self, response: ClientResponse) -> bool:
        """
        Take the given response and store its Digest challenge, if any.

        Also updates the protection space from the challenge's ``domain``
        parameter, or from the response origin when ``domain`` is absent.

        Returns true if the original request must be resent.
        """
        if response.status != 401:
            return False

        auth_header = response.headers.get("www-authenticate", "")
        if not auth_header:
            return False  # No authentication header present

        scheme, sep, params = auth_header.partition(" ")
        if not sep:
            # No space found in www-authenticate header
            return False  # Malformed auth header, missing scheme separator

        if scheme.lower() != "digest":
            # Not a digest auth challenge (could be Basic, Bearer, etc.)
            return False

        if not params:
            # We have a digest scheme but no parameters
            return False  # Malformed digest header, missing parameters

        # RFC 7235 Section 2.1: auth-param names are matched case-insensitively,
        # so normalise them before looking up the (lower-case) challenge fields.
        header_pairs = {
            name.lower(): value for name, value in parse_header_pairs(params).items()
        }
        if not header_pairs:
            # Failed to parse any key-value pairs
            return False  # Malformed digest header, no valid parameters

        # Extract challenge parameters
        self._challenge = {}
        for field in CHALLENGE_FIELDS:
            value = header_pairs.get(field)
            if value is None:
                continue
            # An empty realm is legal (RFC 7616: it SHOULD, not MUST, contain
            # the host name) and must be kept so _encode does not report it as
            # missing. Other empty values carry no information and are dropped.
            if value or field == "realm":
                self._challenge[field] = value

        # Update protection space based on domain parameter or default to origin
        origin = response.url.origin()

        if domain := self._challenge.get("domain"):
            # Parse space-separated list of URIs
            self._protection_space = []
            for uri in domain.split():
                # Remove quotes if present
                uri = uri.strip('"')
                if not uri:
                    # Nothing left after stripping quotes; an empty prefix
                    # would otherwise match every URL.
                    continue
                if uri.startswith("/"):
                    # Path-absolute, relative to origin
                    self._protection_space.append(str(origin.join(URL(uri))))
                else:
                    # Absolute URI
                    self._protection_space.append(str(URL(uri)))
        else:
            # No domain specified, protection space is entire origin
            self._protection_space = [str(origin)]

        # Return True only if we found at least one challenge parameter
        return bool(self._challenge)

    async def __call__(
        self, request: ClientRequest, handler: ClientHandlerType
    ) -> ClientResponse:
        """
        Run the digest auth middleware.

        Sends the request at most twice. The first attempt carries an
        Authorization header only when preemptive auth applies. If that
        response is a Digest 401, the request is resent once with credentials
        computed from the new challenge. The last response is returned.
        """
        response = None
        for retry_count in range(2):
            # Apply authorization header if:
            # 1. This is a retry after 401 (retry_count > 0), OR
            # 2. Preemptive auth is enabled AND we have a challenge AND the URL is in protection space
            if retry_count > 0 or (
                self._preemptive
                and self._challenge
                and self._in_protection_space(request.url)
            ):
                request.headers[hdrs.AUTHORIZATION] = await self._encode(
                    request.method, request.url, request.body
                )

            # Send the request
            response = await handler(request)

            # Check if we need to authenticate
            if not self._authenticate(response):
                break

        # At this point, response is guaranteed to be defined
        assert response is not None
        return response