1"""PEP 610"""
2
3from __future__ import annotations
4
5import json
6import re
7import urllib.parse
8from collections.abc import Iterable
9from dataclasses import dataclass
10from typing import Any, ClassVar, TypeVar, Union
11
# Public API of this module.
__all__ = [
    "DirectUrl",
    "DirectUrlValidationError",
    "DirInfo",
    "ArchiveInfo",
    "VcsInfo",
]

# Type variable for the typed dictionary accessors (_get / _get_required).
T = TypeVar("T")

# File name for PEP 610 direct URL metadata. Not referenced within this
# module; presumably consumed by callers that read/write the metadata file.
DIRECT_URL_METADATA_NAME = "direct_url.json"
# Matches userinfo made only of environment-variable references, such as
# "${USER}" or "${USER}:${PASSWORD}". Such userinfo is kept (not stripped)
# when DirectUrl redacts credentials from a URL.
ENV_VAR_RE = re.compile(r"^\$\{[A-Za-z0-9-_]+\}(:\$\{[A-Za-z0-9-_]+\})?$")
24
25
class DirectUrlValidationError(Exception):
    """Raised when direct URL data is malformed or fails validation."""
28
29
30def _get(
31 d: dict[str, Any], expected_type: type[T], key: str, default: T | None = None
32) -> T | None:
33 """Get value from dictionary and verify expected type."""
34 if key not in d:
35 return default
36 value = d[key]
37 if not isinstance(value, expected_type):
38 raise DirectUrlValidationError(
39 f"{value!r} has unexpected type for {key} (expected {expected_type})"
40 )
41 return value
42
43
def _get_required(
    d: dict[str, Any], expected_type: type[T], key: str, default: T | None = None
) -> T:
    """Like ``_get``, but a ``None`` result is an error instead of a value.

    Raises ``DirectUrlValidationError`` when ``key`` is absent and ``default``
    is ``None``, or when the stored value has the wrong type.
    """
    found = _get(d, expected_type, key, default)
    if found is not None:
        return found
    raise DirectUrlValidationError(f"{key} must have a value")
51
52
53def _exactly_one_of(infos: Iterable[InfoType | None]) -> InfoType:
54 infos = [info for info in infos if info is not None]
55 if not infos:
56 raise DirectUrlValidationError(
57 "missing one of archive_info, dir_info, vcs_info"
58 )
59 if len(infos) > 1:
60 raise DirectUrlValidationError(
61 "more than one of archive_info, dir_info, vcs_info"
62 )
63 assert infos[0] is not None
64 return infos[0]
65
66
67def _filter_none(**kwargs: Any) -> dict[str, Any]:
68 """Make dict excluding None values."""
69 return {k: v for k, v in kwargs.items() if v is not None}
70
71
@dataclass
class VcsInfo:
    """The ``vcs_info`` member of a direct URL record."""

    name: ClassVar = "vcs_info"

    vcs: str
    commit_id: str
    requested_revision: str | None = None

    @classmethod
    def _from_dict(cls, d: dict[str, Any] | None) -> VcsInfo | None:
        """Build an instance from its JSON mapping, or return ``None`` if absent.

        ``vcs`` and ``commit_id`` must be strings; ``requested_revision`` is an
        optional string. Type/presence errors raise
        ``DirectUrlValidationError`` via the accessors.
        """
        if d is None:
            return None
        # Keep the lookup order: it decides which error is reported first.
        vcs = _get_required(d, str, "vcs")
        commit_id = _get_required(d, str, "commit_id")
        revision = _get(d, str, "requested_revision")
        return cls(vcs=vcs, commit_id=commit_id, requested_revision=revision)

    def _to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-ready dict, dropping ``None``-valued fields."""
        fields_out = dict(
            vcs=self.vcs,
            requested_revision=self.requested_revision,
            commit_id=self.commit_id,
        )
        return _filter_none(**fields_out)
96
97
class ArchiveInfo:
    """The ``archive_info`` member of a direct URL record.

    ``hash`` is the legacy single ``"<algorithm>=<digest>"`` string and
    ``hashes`` maps algorithm names to digests. Setting ``hash`` also records
    it in ``hashes`` (without overwriting an existing entry for that
    algorithm); ``hash`` is never back-filled from ``hashes``.

    Instances compare equal when both ``hash`` and ``hashes`` are equal, so
    equality of the ``DirectUrl`` dataclass holding them is value-based, as it
    already is for the dataclass-based ``DirInfo`` and ``VcsInfo``. Like those
    mutable dataclasses, instances are unhashable.
    """

    name: ClassVar = "archive_info"

    def __init__(
        self,
        hash: str | None = None,
        hashes: dict[str, str] | None = None,
    ) -> None:
        # Set hashes before hash, since the hash setter reads and may extend
        # self.hashes.
        self.hashes = hashes
        self.hash = hash

    def __repr__(self) -> str:
        return f"{type(self).__name__}(hash={self.hash!r}, hashes={self.hashes!r})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArchiveInfo):
            return NotImplemented
        return self.hash == other.hash and self.hashes == other.hashes

    # Defining __eq__ already disables hashing; spelled out because instances
    # are mutable (hash and hashes can be reassigned after construction).
    __hash__ = None  # type: ignore[assignment]

    @property
    def hash(self) -> str | None:
        """Legacy ``"<algorithm>=<digest>"`` hash string, or ``None``."""
        return self._hash

    @hash.setter
    def hash(self, value: str | None) -> None:
        """Set the legacy hash and mirror it into ``hashes``.

        Raises:
            DirectUrlValidationError: if ``value`` contains no ``=``.
        """
        if value is not None:
            # Auto-populate the hashes key to upgrade to the new format
            # automatically. We don't back-populate the legacy hash key from
            # hashes.
            try:
                hash_name, hash_value = value.split("=", 1)
            except ValueError:
                # The unpacking error adds nothing beyond this message.
                raise DirectUrlValidationError(
                    f"invalid archive_info.hash format: {value!r}"
                ) from None
            if self.hashes is None:
                self.hashes = {hash_name: hash_value}
            elif hash_name not in self.hashes:
                # Copy before adding: the dict may be owned by the caller.
                self.hashes = self.hashes.copy()
                self.hashes[hash_name] = hash_value
        self._hash = value

    @classmethod
    def _from_dict(cls, d: dict[str, Any] | None) -> ArchiveInfo | None:
        """Build an instance from its JSON mapping, or return ``None`` if absent.

        ``hash`` must be a string and ``hashes`` a dict when present; type
        errors raise ``DirectUrlValidationError`` via ``_get``.
        """
        if d is None:
            return None
        return cls(hash=_get(d, str, "hash"), hashes=_get(d, dict, "hashes"))

    def _to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-ready dict, omitting unset fields."""
        return _filter_none(hash=self.hash, hashes=self.hashes)
140
141
@dataclass
class DirInfo:
    """The ``dir_info`` member of a direct URL record."""

    name: ClassVar = "dir_info"

    editable: bool = False

    @classmethod
    def _from_dict(cls, d: dict[str, Any] | None) -> DirInfo | None:
        """Build an instance from its JSON mapping, or return ``None`` if absent.

        A missing ``editable`` key defaults to ``False``; a non-bool value
        raises ``DirectUrlValidationError`` via the accessors.
        """
        if d is not None:
            return cls(editable=_get_required(d, bool, "editable", default=False))
        return None

    def _to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-ready dict; ``editable`` is written only when truthy."""
        if self.editable:
            return {"editable": self.editable}
        return {}
156
157
# Any of the three info record types; a DirectUrl carries exactly one of them
# (enforced by _exactly_one_of when parsing). Spelled with typing.Union rather
# than ``A | B | C`` because this is a runtime alias, not an annotation, so
# ``from __future__ import annotations`` does not defer its evaluation —
# presumably to stay importable on Python < 3.10.
InfoType = Union[ArchiveInfo, DirInfo, VcsInfo]
159
160
@dataclass
class DirectUrl:
    """A PEP 610 direct URL record.

    Holds the URL a distribution came from, exactly one info object
    (``ArchiveInfo``, ``DirInfo`` or ``VcsInfo``) and an optional
    subdirectory.
    """

    url: str
    info: InfoType
    subdirectory: str | None = None

    def _remove_auth_from_netloc(self, netloc: str) -> str:
        """Strip a ``userinfo@`` prefix from ``netloc`` unless it must be kept.

        The prefix is kept when it is exactly ``git`` on a git VCS URL, or when
        it consists only of environment-variable references (``ENV_VAR_RE``).
        """
        # Split on the *last* "@": a host cannot contain "@", but an unencoded
        # password can. Splitting on the first "@" would leave the tail of such
        # a password in the redacted netloc. urllib.parse locates userinfo the
        # same way for SplitResult.username/.password.
        user_pass, at_sign, netloc_no_user_pass = netloc.rpartition("@")
        if not at_sign:
            return netloc
        if (
            isinstance(self.info, VcsInfo)
            and self.info.vcs == "git"
            and user_pass == "git"
        ):
            return netloc
        if ENV_VAR_RE.match(user_pass):
            return netloc
        return netloc_no_user_pass

    @property
    def redacted_url(self) -> str:
        """url with user:password part removed unless it is formed with
        environment variables as specified in PEP 610, or it is ``git``
        in the case of a git URL.
        """
        purl = urllib.parse.urlsplit(self.url)
        netloc = self._remove_auth_from_netloc(purl.netloc)
        surl = urllib.parse.urlunsplit(
            (purl.scheme, netloc, purl.path, purl.query, purl.fragment)
        )
        return surl

    def validate(self) -> None:
        """Check that this record survives a ``to_dict``/``from_dict`` round trip.

        Note that ``to_dict`` writes the redacted URL, so that is what gets
        validated. Invalid data raises ``DirectUrlValidationError``.
        """
        self.from_dict(self.to_dict())

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> DirectUrl:
        """Build a record from a parsed direct URL mapping.

        Raises:
            DirectUrlValidationError: if ``d`` is not a dict, ``url`` is
                missing, a field has the wrong type, or not exactly one of
                ``archive_info``/``dir_info``/``vcs_info`` is present.
        """
        # json.loads may yield any JSON value; reject non-objects with the
        # module's validation error instead of a TypeError from ``key in d``.
        if not isinstance(d, dict):
            raise DirectUrlValidationError(
                f"direct URL data must be a JSON object, got {type(d).__name__}"
            )
        # Use cls so subclasses get instances of themselves.
        return cls(
            url=_get_required(d, str, "url"),
            subdirectory=_get(d, str, "subdirectory"),
            info=_exactly_one_of(
                [
                    ArchiveInfo._from_dict(_get(d, dict, "archive_info")),
                    DirInfo._from_dict(_get(d, dict, "dir_info")),
                    VcsInfo._from_dict(_get(d, dict, "vcs_info")),
                ]
            ),
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-ready dict.

        The URL is written in redacted form, ``subdirectory`` is omitted when
        unset, and the info object is stored under its ``name`` key.
        """
        res = _filter_none(
            url=self.redacted_url,
            subdirectory=self.subdirectory,
        )
        res[self.info.name] = self.info._to_dict()
        return res

    @classmethod
    def from_json(cls, s: str) -> DirectUrl:
        """Parse a JSON document; ``json.JSONDecodeError`` propagates unchanged."""
        return cls.from_dict(json.loads(s))

    def to_json(self) -> str:
        """Serialize to a JSON string with sorted keys."""
        return json.dumps(self.to_dict(), sort_keys=True)

    def is_local_editable(self) -> bool:
        """Return True if ``info`` is a ``DirInfo`` marked editable."""
        return isinstance(self.info, DirInfo) and self.info.editable