1from __future__ import annotations
2
3import dataclasses
4import re
5import urllib.parse
6from collections.abc import Mapping
7from typing import TYPE_CHECKING, Any, Protocol, TypeVar
8
9if TYPE_CHECKING: # pragma: no cover
10 import sys
11 from collections.abc import Collection
12
13 if sys.version_info >= (3, 11):
14 from typing import Self
15 else:
16 from typing_extensions import Self
17
# Public API of this module; also returned by the module-level `__dir__` below.
__all__ = [
    "ArchiveInfo",
    "DirInfo",
    "DirectUrl",
    "DirectUrlValidationError",
    "VcsInfo",
]
25
26
def __dir__() -> list[str]:
    """Module-level ``__dir__`` hook: list only the names in ``__all__``."""
    return __all__
29
30
# Generic type variable tying the value returned by `_get` / `_get_required`
# to the `expected_type` they are asked to check against.
_T = TypeVar("_T")
32
33
class _FromMappingProtocol(Protocol):  # pragma: no cover
    """Structural type for classes that can be built from a mapping.

    Any class exposing a ``_from_dict`` classmethod with this signature
    satisfies the protocol.
    """

    @classmethod
    def _from_dict(cls, d: Mapping[str, Any]) -> Self:
        """Build an instance of ``cls`` from the mapping ``d``."""
37
38
# Type variable bound to `_FromMappingProtocol`, so `_get_object` can return an
# instance of exactly the `target_type` it was given.
_FromMappingProtocolT = TypeVar("_FromMappingProtocolT", bound=_FromMappingProtocol)
40
41
def _json_dict_factory(data: list[tuple[str, Any]]) -> dict[str, Any]:
    """Build a dict from ``(key, value)`` pairs, dropping pairs whose value is None.

    Used as the ``dict_factory`` for ``dataclasses.asdict`` so that unset
    optional fields are omitted from the output instead of being emitted
    as ``None``.
    """
    result: dict[str, Any] = {}
    for name, item in data:
        if item is not None:
            result[name] = item
    return result
44
45
def _get(d: Mapping[str, Any], expected_type: type[_T], key: str) -> _T | None:
    """Look up an optional ``key`` in ``d`` and type-check the value.

    Returns ``None`` when the key is absent or its value is ``None``;
    otherwise returns the value if it is an instance of ``expected_type``.

    Raises ``DirectUrlValidationError`` (with ``key`` as context) when the
    value is present but of the wrong type.
    """
    value = d.get(key)
    if value is None:
        return None
    if isinstance(value, expected_type):
        return value
    raise DirectUrlValidationError(
        f"Unexpected type {type(value).__name__} "
        f"(expected {expected_type.__name__})",
        context=key,
    )
57
58
def _get_required(d: Mapping[str, Any], expected_type: type[_T], key: str) -> _T:
    """Look up a mandatory ``key`` in ``d`` and type-check the value.

    Behaves like ``_get`` but raises ``_DirectUrlRequiredKeyError`` when the
    key is absent or its value is ``None``.
    """
    value = _get(d, expected_type, key)
    if value is not None:
        return value
    raise _DirectUrlRequiredKeyError(key)
64
65
def _get_object(
    d: Mapping[str, Any], target_type: type[_FromMappingProtocolT], key: str
) -> _FromMappingProtocolT | None:
    """Read a nested mapping under ``key`` and convert it with ``target_type``.

    Returns ``None`` when the key is absent (or ``None``). Any exception raised
    by ``target_type._from_dict`` is re-raised as a ``DirectUrlValidationError``
    whose context is ``key``, chained to the original exception.
    """
    nested = _get(d, Mapping, key)  # type: ignore[type-abstract]
    if nested is None:
        return None
    try:
        converted = target_type._from_dict(nested)
    except Exception as exc:
        raise DirectUrlValidationError(exc, context=key) from exc
    else:
        return converted
76
77
# Matches a user:password section made only of environment variable
# references, i.e. `${USER}` or `${USER}:${PASSWORD}`. The pattern is the one
# given by PEP 610; such references carry no secret and may be kept.
_PEP610_USER_PASS_ENV_VARS_REGEX = re.compile(
    r"^\$\{[A-Za-z0-9-_]+\}(:\$\{[A-Za-z0-9-_]+\})?$"
)


def _strip_auth_from_netloc(netloc: str, safe_user_passwords: Collection[str]) -> str:
    """Return ``netloc`` without its ``user:password@`` prefix, when unsafe to keep.

    The credentials are kept (``netloc`` is returned unchanged) when:

    * there is no ``@`` in ``netloc``;
    * the user:password part is listed in ``safe_user_passwords``;
    * the user:password part consists solely of ``${ENV_VAR}`` references.

    Otherwise only the host part is returned.

    :param netloc: The network location component of a URL.
    :param safe_user_passwords: user:password strings that must not be stripped.
    """
    # Split on the LAST "@": that is how urllib.parse separates userinfo from
    # the host. Splitting on the first one would leave the tail of a password
    # containing "@" in the returned value.
    user_pass, separator, host = netloc.rpartition("@")
    if not separator:
        return netloc
    if user_pass in safe_user_passwords:
        return netloc
    if _PEP610_USER_PASS_ENV_VARS_REGEX.match(user_pass):
        return netloc
    return host
92
93
def _strip_url(url: str, safe_user_passwords: Collection[str]) -> str:
    """Return ``url`` with the user:password part of its netloc removed.

    The credentials are left in place when they are formed with environment
    variables as specified in PEP 610, or when they are one of the
    ``safe_user_passwords`` (such as `git`).
    """
    scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url)
    cleaned_netloc = _strip_auth_from_netloc(netloc, safe_user_passwords)
    return urllib.parse.urlunsplit((scheme, cleaned_netloc, path, query, fragment))
110
111
class DirectUrlValidationError(Exception):
    """Raised when input data is not spec-compliant.

    ``message`` holds the human-readable description and ``context`` the
    dotted path of the offending key (or ``None`` when not applicable).
    """

    context: str | None = None
    message: str

    def __init__(
        self,
        cause: str | Exception,
        *,
        context: str | None = None,
    ) -> None:
        """Create the error from a message or from another exception.

        When ``cause`` is itself a ``DirectUrlValidationError``, its message is
        reused and its context is nested under ``context`` (``outer.inner``).
        """
        if not isinstance(cause, DirectUrlValidationError):
            # Plain message or foreign exception: take its text as-is.
            self.context = context
            self.message = str(cause)
            return

        inner_context = cause.context
        if inner_context:
            self.context = f"{context}.{inner_context}" if context else inner_context
        else:
            self.context = context  # pragma: no cover
        self.message = cause.message

    def __str__(self) -> str:
        """Return the message, followed by the quoted context when there is one."""
        return f"{self.message} in {self.context!r}" if self.context else self.message
140
141
class _DirectUrlRequiredKeyError(DirectUrlValidationError):
    """Validation error for a required key that is missing (or ``None``)."""

    def __init__(self, key: str) -> None:
        """Record ``key`` as the context of a "Missing required value" error."""
        super().__init__("Missing required value", context=key)
145
146
@dataclasses.dataclass(frozen=True, init=False)
class VcsInfo:
    """Immutable holder for the ``vcs_info`` fields of a direct URL."""

    vcs: str
    commit_id: str
    requested_revision: str | None = None

    def __init__(
        self,
        *,
        vcs: str,
        commit_id: str,
        requested_revision: str | None = None,
    ) -> None:
        """Store the given fields; all arguments are keyword-only."""
        # The dataclass is frozen, so attributes are set via object.__setattr__.
        for name, value in (
            ("vcs", vcs),
            ("commit_id", commit_id),
            ("requested_revision", requested_revision),
        ):
            object.__setattr__(self, name, value)

    @classmethod
    def _from_dict(cls, d: Mapping[str, Any]) -> Self:
        """Build a ``VcsInfo`` from a mapping, type-checking each field."""
        # The vcs value itself is not validated because the set of VCS names
        # is not closed. Lookups are done in this order so that the first
        # reported error is stable.
        vcs = _get_required(d, str, "vcs")
        requested_revision = _get(d, str, "requested_revision")
        commit_id = _get_required(d, str, "commit_id")
        return cls(
            vcs=vcs,
            commit_id=commit_id,
            requested_revision=requested_revision,
        )
172
173
@dataclasses.dataclass(frozen=True, init=False)
class ArchiveInfo:
    """Immutable holder for the ``archive_info`` fields of a direct URL."""

    hashes: Mapping[str, str] | None = None

    def __init__(
        self,
        *,
        hashes: Mapping[str, str] | None = None,
    ) -> None:
        """Store ``hashes`` (keyword-only); the dataclass is frozen."""
        object.__setattr__(self, "hashes", hashes)

    @classmethod
    def _from_dict(cls, d: Mapping[str, Any]) -> Self:
        """Build an ``ArchiveInfo`` from a mapping.

        Reads the ``hashes`` mapping and the legacy ``hash`` string
        (``<algorithm>=<hash>``). When only ``hash`` is given, ``hashes`` is
        derived from it; when both are given they must agree.

        Raises ``DirectUrlValidationError`` on type or consistency errors.
        """
        hashes = _get(d, Mapping, "hashes")  # type: ignore[type-abstract]
        if hashes is not None:
            for digest in hashes.values():
                if not isinstance(digest, str):
                    raise DirectUrlValidationError(
                        "Hash values must be strings", context="hashes"
                    )

        legacy_hash = _get(d, str, "hash")
        if legacy_hash is None:
            return cls(hashes=hashes)

        hash_algorithm, separator, hash_value = legacy_hash.partition("=")
        if not separator:
            raise DirectUrlValidationError(
                "Invalid hash format (expected '<algorithm>=<hash>')",
                context="hash",
            )

        if hashes is None:
            # No `hashes` field: derive it from the legacy `hash`.
            return cls(hashes={hash_algorithm: hash_value})

        # Both fields present: the legacy `hash` must match an entry of `hashes`.
        if hash_algorithm not in hashes:
            raise DirectUrlValidationError(
                f"Algorithm {hash_algorithm!r} used in hash field "
                f"is not present in hashes field",
                context="hashes",
            )
        if hashes[hash_algorithm] != hash_value:
            raise DirectUrlValidationError(
                f"Algorithm {hash_algorithm!r} used in hash field "
                f"has different value in hashes field",
                context="hash",
            )
        return cls(hashes=hashes)
218
219
@dataclasses.dataclass(frozen=True, init=False)
class DirInfo:
    """Immutable holder for the ``dir_info`` fields of a direct URL."""

    editable: bool | None = None

    def __init__(
        self,
        *,
        editable: bool | None = None,
    ) -> None:
        """Store ``editable`` (keyword-only); the dataclass is frozen."""
        object.__setattr__(self, "editable", editable)

    @classmethod
    def _from_dict(cls, d: Mapping[str, Any]) -> Self:
        """Build a ``DirInfo`` from a mapping; ``editable`` must be a bool if set."""
        editable = _get(d, bool, "editable")
        return cls(editable=editable)
236
237
@dataclasses.dataclass(frozen=True, init=False)
class DirectUrl:
    """A class representing a direct URL.

    Holds the ``url``, an optional ``subdirectory`` and the optional
    ``archive_info`` / ``vcs_info`` / ``dir_info`` objects. ``from_dict``
    enforces that exactly one of the three is present.
    """

    url: str
    archive_info: ArchiveInfo | None = None
    vcs_info: VcsInfo | None = None
    dir_info: DirInfo | None = None
    subdirectory: str | None = None  # XXX Path or str?

    def __init__(
        self,
        *,
        url: str,
        archive_info: ArchiveInfo | None = None,
        vcs_info: VcsInfo | None = None,
        dir_info: DirInfo | None = None,
        subdirectory: str | None = None,
    ) -> None:
        """Store the given fields without validation; all are keyword-only."""
        # The dataclass is frozen, so attributes are set via object.__setattr__.
        for name, value in (
            ("url", url),
            ("archive_info", archive_info),
            ("vcs_info", vcs_info),
            ("dir_info", dir_info),
            ("subdirectory", subdirectory),
        ):
            object.__setattr__(self, name, value)

    @classmethod
    def _from_dict(cls, d: Mapping[str, Any]) -> Self:
        """Build and validate a ``DirectUrl`` from a mapping.

        Raises ``DirectUrlValidationError`` when a field has the wrong type,
        when a required field is missing, when the number of info objects is
        not exactly one, or when ``dir_info`` is used with a non-``file://`` URL.
        """
        direct_url = cls(
            url=_get_required(d, str, "url"),
            archive_info=_get_object(d, ArchiveInfo, "archive_info"),
            vcs_info=_get_object(d, VcsInfo, "vcs_info"),
            dir_info=_get_object(d, DirInfo, "dir_info"),
            subdirectory=_get(d, str, "subdirectory"),
        )

        candidates = (
            direct_url.vcs_info,
            direct_url.archive_info,
            direct_url.dir_info,
        )
        present = [info for info in candidates if info]
        if len(present) != 1:
            raise DirectUrlValidationError(
                "Exactly one of vcs_info, archive_info, dir_info must be present"
            )

        is_file_url = direct_url.url.startswith("file://")
        if direct_url.dir_info is not None and not is_file_url:
            raise DirectUrlValidationError(
                "URL scheme must be file:// when dir_info is present",
                context="url",
            )
        # XXX subdirectory must be relative, can we, should we validate that here?
        return direct_url

    @classmethod
    def from_dict(cls, d: Mapping[str, Any], /) -> Self:
        """Create and validate a DirectUrl instance from a JSON dictionary."""
        return cls._from_dict(d)

    def to_dict(
        self,
        *,
        generate_legacy_hash: bool = False,
        strip_user_password: bool = True,
        safe_user_passwords: Collection[str] = ("git",),
    ) -> Mapping[str, Any]:
        """Convert the DirectUrl instance to a JSON dictionary.

        Fields whose value is ``None`` are omitted from the result.

        :param generate_legacy_hash: If True, include a legacy `hash` field in
            `archive_info` for backward compatibility with tools that don't
            support the `hashes` field.
        :param strip_user_password: If True, strip user:password from the URL
            unless it is formed with environment variables as specified in PEP
            610, or it is a safe user:password such as `git`.
        :param safe_user_passwords: A collection of user:password strings that
            should not be stripped from the URL even if `strip_user_password` is
            True.
        """
        payload = dataclasses.asdict(self, dict_factory=_json_dict_factory)

        if generate_legacy_hash:
            archive_info = self.archive_info
            if archive_info and archive_info.hashes:
                # The legacy field carries a single hash: use the first entry.
                first_entry = next(iter(archive_info.hashes.items()))
                hash_algorithm, hash_value = first_entry
                payload["archive_info"]["hash"] = f"{hash_algorithm}={hash_value}"

        if strip_user_password:
            payload["url"] = _strip_url(self.url, safe_user_passwords)

        return payload

    def validate(self) -> None:
        """Validate the DirectUrl instance against the specification.

        Round-trips the instance through ``to_dict`` and ``from_dict``.

        Raises :class:`DirectUrlValidationError` if invalid.
        """
        as_dict = self.to_dict()
        self.from_dict(as_dict)