1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import annotations
6
7import binascii
8import re
9import sys
10import typing
11import warnings
12from collections.abc import Iterable, Iterator
13
14from cryptography import utils
15from cryptography.hazmat.bindings._rust import x509 as rust_x509
16from cryptography.x509.oid import NameOID, ObjectIdentifier
17
18
class _ASN1Type(utils.Enum):
    """ASN.1 types a NameAttribute can be tagged with (its ``_type``).

    Each member's value is the ASN.1 universal tag number of that type
    (X.680), e.g. UTF8String is universal tag 12. ``_ASN1_TYPE_TO_ENUM``
    maps those numbers back to members.
    """

    BitString = 3
    OctetString = 4
    UTF8String = 12
    NumericString = 18
    PrintableString = 19
    T61String = 20
    IA5String = 22
    UTCTime = 23
    GeneralizedTime = 24
    VisibleString = 26
    UniversalString = 28
    BMPString = 30
33
# Reverse lookup: ASN.1 universal tag number -> _ASN1Type member.
_ASN1_TYPE_TO_ENUM = {member.value: member for member in _ASN1Type}

# OIDs whose values default to an ASN.1 type other than UTF8String when a
# NameAttribute is created without an explicit ``_type``.
_NAMEOID_DEFAULT_TYPE: dict[ObjectIdentifier, _ASN1Type] = {
    NameOID.COUNTRY_NAME: _ASN1Type.PrintableString,
    NameOID.JURISDICTION_COUNTRY_NAME: _ASN1Type.PrintableString,
    NameOID.SERIAL_NUMBER: _ASN1Type.PrintableString,
    NameOID.DN_QUALIFIER: _ASN1Type.PrintableString,
    NameOID.EMAIL_ADDRESS: _ASN1Type.IA5String,
    NameOID.DOMAIN_COMPONENT: _ASN1Type.IA5String,
}

# Type aliases for the OID <-> attribute-name mappings accepted as
# ``attr_name_overrides`` when formatting or parsing RFC 4514 strings.
_OidNameMap = typing.Mapping[ObjectIdentifier, str]
_NameOidMap = typing.Mapping[str, ObjectIdentifier]

#: Short attribute names defined in RFC 4514, section 3:
#: https://tools.ietf.org/html/rfc4514#section-3
_NAMEOID_TO_NAME: _OidNameMap = {
    NameOID.COMMON_NAME: "CN",
    NameOID.LOCALITY_NAME: "L",
    NameOID.STATE_OR_PROVINCE_NAME: "ST",
    NameOID.ORGANIZATION_NAME: "O",
    NameOID.ORGANIZATIONAL_UNIT_NAME: "OU",
    NameOID.COUNTRY_NAME: "C",
    NameOID.STREET_ADDRESS: "STREET",
    NameOID.DOMAIN_COMPONENT: "DC",
    NameOID.USER_ID: "UID",
}
# Inverse of the table above, used when parsing short names back to OIDs.
_NAME_TO_NAMEOID: _NameOidMap = {
    short_name: oid for oid, short_name in _NAMEOID_TO_NAME.items()
}

# Inclusive (min, max) bounds on the UTF-8 encoded byte length of values
# for these OIDs; checked in NameAttribute.__init__.
_NAMEOID_LENGTH_LIMIT: dict[ObjectIdentifier, tuple[int, int]] = {
    NameOID.COUNTRY_NAME: (2, 2),
    NameOID.JURISDICTION_COUNTRY_NAME: (2, 2),
    NameOID.COMMON_NAME: (1, 64),
}
68
69
# Single-pass translation table for the characters RFC 4514 section 2.4
# requires to be backslash-escaped wherever they occur in a value.
_DN_ESCAPE_TABLE = str.maketrans(
    {
        "\\": "\\\\",
        '"': '\\"',
        "+": "\\+",
        ",": "\\,",
        ";": "\\;",
        "<": "\\<",
        ">": "\\>",
        "\0": "\\00",
    }
)


def _escape_dn_value(val: str | bytes) -> str:
    """Return *val* escaped for use as an RFC 4514 attribute value.

    ``str`` values get the RFC 4514 section 2.4 special characters escaped
    (NUL becomes ``\\00``), plus an escape before a leading ``#``, before a
    leading space (only when the value is longer than one character) and
    before a trailing space. ``bytes`` values are rendered as ``#`` followed
    by the lowercase hex encoding of the octets. Empty input yields ``""``.
    """
    if not val:
        return ""

    if isinstance(val, bytes):
        # RFC 4514 section 2.4: '#' followed by the hex encoding of octets.
        return f"#{val.hex()}"

    escaped = val.translate(_DN_ESCAPE_TABLE)
    head = escaped[0]
    # A lone space is both leading and trailing; only the trailing rule is
    # applied to it so it is not escaped twice.
    escape_head = head == "#" or (head == " " and len(escaped) > 1)
    if escaped[-1] == " ":
        escaped = escaped[:-1] + "\\ "
    if escape_head:
        escaped = "\\" + escaped
    return escaped
97
98
def _unescape_dn_value(val: str) -> str:
    r"""Undo RFC 4514 escaping in an attribute value taken from a DN string.

    Two escape forms are handled (https://tools.ietf.org/html/rfc4514#section-3):

    * a backslash followed by a special character (e.g. ``\,``), which
      yields that character; and
    * a run of consecutive backslash-hexpair escapes (e.g. ``\C3\A9``),
      whose octets are joined and decoded as UTF-8 so that multi-byte
      characters survive.

    An invalid hex run raises ``binascii.Error``, and undecodable octets
    raise ``UnicodeDecodeError``; both are ``ValueError`` subclasses.
    """
    if not val:
        return ""

    # special = escaped / SPACE / SHARP / EQUALS
    # escaped = DQUOTE / PLUS / COMMA / SEMI / LANGLE / RANGLE
    def _replace(match) -> str:
        special = match.group(1)
        if special is not None:
            # "\X" -> "X"
            return special[1]
        octets = binascii.unhexlify(match.group(2).replace("\\", ""))
        return octets.decode()

    return _RFC4514NameParser._PAIR_MULTI_RE.sub(_replace, val)
117
118
# Type parameter for NameAttribute values. It is constrained to exactly one of
# ``str | bytes``, ``str`` or ``bytes`` (bytes being the BitString case that
# NameAttribute.__init__ accepts), and covariant so that, for example, a
# NameAttribute[str] is accepted where NameAttribute[str | bytes] is expected.
NameAttributeValueType = typing.TypeVar(
    "NameAttributeValueType",
    typing.Union[str, bytes],
    str,
    bytes,
    covariant=True,
)
126
127
class NameAttribute(typing.Generic[NameAttributeValueType]):
    """One attribute of an X.509 name: an OID paired with its value.

    ``value`` is a ``str``, or ``bytes`` when ``_type`` is
    ``_ASN1Type.BitString`` (only permitted for
    ``NameOID.X500_UNIQUE_IDENTIFIER``). Equality and hashing use the OID and
    value only; the ASN.1 type is not compared.
    """

    def __init__(
        self,
        oid: ObjectIdentifier,
        value: NameAttributeValueType,
        _type: _ASN1Type | None = None,
        *,
        _validate: bool = True,
    ) -> None:
        """Validate and store the attribute.

        :param oid: the attribute type; must be an ``ObjectIdentifier``.
        :param value: the attribute value (``bytes`` only for BitString).
        :param _type: explicit ASN.1 type. When ``None``, a per-OID default
            from ``_NAMEOID_DEFAULT_TYPE`` is used, falling back to
            UTF8String.
        :param _validate: when exactly ``True``, a length-limit violation
            raises ``ValueError``; otherwise it is reported through
            ``warnings.warn``.
        :raises TypeError: invalid ``oid``, ``value`` type or ``_type``.
        :raises ValueError: value length outside the OID's limits.
        """
        if not isinstance(oid, ObjectIdentifier):
            raise TypeError(
                "oid argument must be an ObjectIdentifier instance."
            )

        is_bit_string = _type == _ASN1Type.BitString
        if is_bit_string and oid != NameOID.X500_UNIQUE_IDENTIFIER:
            raise TypeError(
                "oid must be X500_UNIQUE_IDENTIFIER for BitString type."
            )
        if is_bit_string and not isinstance(value, bytes):
            raise TypeError("value must be bytes for BitString")
        if not is_bit_string and not isinstance(value, str):
            raise TypeError("value argument must be a str")

        bounds = _NAMEOID_LENGTH_LIMIT.get(oid)
        if bounds is not None:
            lower, upper = bounds
            # Only non-BitString OIDs carry limits, so value is a str here.
            assert isinstance(value, str)
            encoded_len = len(value.encode("utf8"))
            if not lower <= encoded_len <= upper:
                msg = (
                    f"Attribute's length must be >= {lower} and "
                    f"<= {upper}, but it was {encoded_len}"
                )
                if _validate is not True:
                    warnings.warn(msg, stacklevel=2)
                else:
                    raise ValueError(msg)

        # UTF8String is the general preference (RFC 2459), but RFCs 3280 and
        # 5280 assign other string types to some OIDs, so an omitted type is
        # resolved through the per-OID default table.
        resolved_type = (
            _NAMEOID_DEFAULT_TYPE.get(oid, _ASN1Type.UTF8String)
            if _type is None
            else _type
        )
        if not isinstance(resolved_type, _ASN1Type):
            raise TypeError("_type must be from the _ASN1Type enum")

        self._oid = oid
        self._value: NameAttributeValueType = value
        self._type: _ASN1Type = resolved_type

    @property
    def oid(self) -> ObjectIdentifier:
        """The attribute type OID."""
        return self._oid

    @property
    def value(self) -> NameAttributeValueType:
        """The attribute value as passed to the constructor."""
        return self._value

    @property
    def rfc4514_attribute_name(self) -> str:
        """
        RFC 4514 short name for this attribute (e.g. "CN"), or the dotted
        OID string when no short name is known.
        """
        return _NAMEOID_TO_NAME.get(self.oid, self.oid.dotted_string)

    def rfc4514_string(
        self, attr_name_overrides: _OidNameMap | None = None
    ) -> str:
        """
        Render as an RFC 4514 ``type=value`` string.

        A name from ``attr_name_overrides`` wins; otherwise the short name
        or the dotted OID string is used.
        """
        name = None
        if attr_name_overrides:
            name = attr_name_overrides.get(self.oid)
        if name is None:
            name = self.rfc4514_attribute_name
        escaped = _escape_dn_value(self.value)
        return f"{name}={escaped}"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, NameAttribute):
            return self.oid == other.oid and self.value == other.value
        return NotImplemented

    def __hash__(self) -> int:
        key = (self.oid, self.value)
        return hash(key)

    def __repr__(self) -> str:
        return f"<NameAttribute(oid={self.oid}, value={self.value!r})>"
226
227
class RelativeDistinguishedName:
    """A non-empty set of distinct NameAttributes forming one RDN.

    The caller's order is kept for iteration and string output, while
    equality and hashing treat the attributes as an unordered set.
    """

    def __init__(self, attributes: Iterable[NameAttribute[str | bytes]]):
        members = list(attributes)
        if not members:
            raise ValueError("a relative distinguished name cannot be empty")
        if any(not isinstance(member, NameAttribute) for member in members):
            raise TypeError("attributes must be an iterable of NameAttribute")

        # The list preserves order; the frozenset provides order-insensitive
        # comparison/hashing and lets duplicates be detected by size.
        self._attributes = members
        self._attribute_set = frozenset(members)

        if len(self._attribute_set) != len(members):
            raise ValueError("duplicate attributes are not allowed")

    def get_attributes_for_oid(
        self,
        oid: ObjectIdentifier,
    ) -> list[NameAttribute[str | bytes]]:
        """Return the attributes in this RDN whose OID equals *oid*."""
        return [member for member in self if member.oid == oid]

    def rfc4514_string(
        self, attr_name_overrides: _OidNameMap | None = None
    ) -> str:
        """
        Render as an RFC 4514 string, joining the attributes with '+'
        (multi-valued RDNs are uncommon in real certificates).
        """
        rendered = [
            member.rfc4514_string(attr_name_overrides)
            for member in self._attributes
        ]
        return "+".join(rendered)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, RelativeDistinguishedName):
            return self._attribute_set == other._attribute_set
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self._attribute_set)

    def __iter__(self) -> Iterator[NameAttribute[str | bytes]]:
        return iter(self._attributes)

    def __len__(self) -> int:
        return len(self._attributes)

    def __repr__(self) -> str:
        return f"<RelativeDistinguishedName({self.rfc4514_string()})>"
280
281
class Name:
    """An X.509 distinguished name: an ordered list of RDNs.

    Iterating a ``Name`` yields every ``NameAttribute`` of every RDN in
    order; ``rdns`` exposes the RDNs themselves.
    """

    @typing.overload
    def __init__(
        self, attributes: Iterable[NameAttribute[str | bytes]]
    ) -> None: ...

    @typing.overload
    def __init__(
        self, attributes: Iterable[RelativeDistinguishedName]
    ) -> None: ...

    def __init__(
        self,
        attributes: Iterable[
            NameAttribute[str | bytes] | RelativeDistinguishedName
        ],
    ) -> None:
        """Build a name from NameAttributes or from RDNs (not a mixture).

        Each bare ``NameAttribute`` is wrapped in its own single-valued
        ``RelativeDistinguishedName``; a sequence of RDNs is used as given.
        An empty iterable produces an empty name.

        :raises TypeError: if the items are neither all ``NameAttribute`` nor
            all ``RelativeDistinguishedName``.
        """
        attributes = list(attributes)
        if all(isinstance(x, NameAttribute) for x in attributes):
            self._attributes = [
                RelativeDistinguishedName([typing.cast(NameAttribute, x)])
                for x in attributes
            ]
        elif all(isinstance(x, RelativeDistinguishedName) for x in attributes):
            self._attributes = typing.cast(
                typing.List[RelativeDistinguishedName], attributes
            )
        else:
            raise TypeError(
                "attributes must be a list of NameAttribute"
                " or a list of RelativeDistinguishedName"
            )

    @classmethod
    def from_rfc4514_string(
        cls,
        data: str,
        attr_name_overrides: _NameOidMap | None = None,
    ) -> Name:
        """Parse an RFC 4514 string such as 'CN=foo,O=Bar' into a Name.

        ``attr_name_overrides`` maps additional short names to OIDs and is
        consulted before the built-in RFC 4514 names.

        :raises ValueError: if *data* is not a valid RFC 4514 string.
        """
        return _RFC4514NameParser(data, attr_name_overrides or {}).parse()

    def rfc4514_string(
        self, attr_name_overrides: _OidNameMap | None = None
    ) -> str:
        """
        Format as RFC4514 Distinguished Name string.
        For example 'CN=foobar.com,O=Foo Corp,C=US'

        An X.509 name is a two-level structure: a list of sets of attributes.
        Each list element is separated by ',' and within each list element, set
        elements are separated by '+'. The latter is almost never used in
        real world certificates. According to RFC4514 section 2.1 the
        RDNSequence must be reversed when converting to string representation.
        """
        return ",".join(
            attr.rfc4514_string(attr_name_overrides)
            for attr in reversed(self._attributes)
        )

    def get_attributes_for_oid(
        self,
        oid: ObjectIdentifier,
    ) -> list[NameAttribute[str | bytes]]:
        """Return every attribute, across all RDNs, whose OID equals *oid*."""
        return [i for i in self if i.oid == oid]

    @property
    def rdns(self) -> list[RelativeDistinguishedName]:
        """The RDNs of this name (the internal list itself, not a copy)."""
        return self._attributes

    def public_bytes(self, backend: typing.Any = None) -> bytes:
        """Return the encoding of this name from ``rust_x509.encode_name_bytes``.

        ``backend`` is unused. NOTE(review): the encoding is presumably DER;
        confirm against the Rust binding.
        """
        return rust_x509.encode_name_bytes(self)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Name):
            return NotImplemented

        # RDN order is significant for names.
        return self._attributes == other._attributes

    def __hash__(self) -> int:
        # TODO: this is relatively expensive, if this looks like a bottleneck
        # for you, consider optimizing!
        return hash(tuple(self._attributes))

    def __iter__(self) -> Iterator[NameAttribute[str | bytes]]:
        for rdn in self._attributes:
            yield from rdn

    def __len__(self) -> int:
        """Total number of attributes across all RDNs."""
        return sum(len(rdn) for rdn in self._attributes)

    def __repr__(self) -> str:
        return f"<Name({self.rfc4514_string()})>"
374
375
class _RFC4514NameParser:
    """Recursive-descent parser for RFC 4514 distinguished name strings.

    Grammar (RFC 4514 section 3)::

        distinguishedName = [ relativeDistinguishedName
                              *( COMMA relativeDistinguishedName ) ]
        relativeDistinguishedName = attributeTypeAndValue
                                    *( PLUS attributeTypeAndValue )
        attributeTypeAndValue = attributeType EQUALS attributeValue
        attributeType = descr / numericoid
        attributeValue = string / hexstring

    A cursor (``_idx``) into ``_data`` advances as tokens are consumed.
    Every syntax error surfaces as ``ValueError``.
    """

    # attributeType: a dotted-decimal OID, or a short name (descr).
    _OID_RE = re.compile(r"(0|([1-9]\d*))(\.(0|([1-9]\d*)))+")
    _DESCR_RE = re.compile(r"[a-zA-Z][a-zA-Z\d-]*")

    # pair = ESC ( ESC / special / hexpair ). A hexpair is two HEX digits,
    # and RFC 4514 defines HEX as DIGIT / "A"-"F" / "a"-"f" only.
    _ESCAPE_SPECIAL = r"[\\ #=\"\+,;<>]"
    _ESCAPE_HEX = r"[\da-fA-F]{2}"
    _PAIR = rf"\\({_ESCAPE_SPECIAL}|{_ESCAPE_HEX})"
    # Used by _unescape_dn_value: a run of consecutive hex pairs is matched
    # as one unit so multi-byte UTF-8 sequences decode together.
    _PAIR_MULTI_RE = re.compile(
        rf"(\\{_ESCAPE_SPECIAL})|((\\{_ESCAPE_HEX})+)"
    )
    # Character classes for the first, middle and last characters of an
    # unescaped string value (LUTF1 / SUTF1 / TUTF1 / UTFMB in RFC 4514).
    _LUTF1 = r"[\x01-\x1f\x21\x24-\x2A\x2D-\x3A\x3D\x3F-\x5B\x5D-\x7F]"
    _SUTF1 = r"[\x01-\x21\x23-\x2A\x2D-\x3A\x3D\x3F-\x5B\x5D-\x7F]"
    _TUTF1 = r"[\x01-\x1F\x21\x23-\x2A\x2D-\x3A\x3D\x3F-\x5B\x5D-\x7F]"
    _UTFMB = rf"[\x80-{chr(sys.maxunicode)}]"
    _LEADCHAR = rf"{_LUTF1}|{_UTFMB}"
    _STRINGCHAR = rf"{_SUTF1}|{_UTFMB}"
    _TRAILCHAR = rf"{_TUTF1}|{_UTFMB}"
    _STRING_RE = re.compile(
        rf"""
        (
            ({_LEADCHAR}|{_PAIR})
            (
                ({_STRINGCHAR}|{_PAIR})*
                ({_TRAILCHAR}|{_PAIR})
            )?
        )?
        """,
        re.VERBOSE,
    )
    # hexstring = SHARP 1*hexpair
    _HEXSTRING_RE = re.compile(r"#([\da-fA-F]{2})+")

    def __init__(self, data: str, attr_name_overrides: _NameOidMap) -> None:
        self._data = data
        self._idx = 0

        # Consulted before the built-in RFC 4514 short names.
        self._attr_name_overrides = attr_name_overrides

    def _has_data(self) -> bool:
        """True while unconsumed input remains."""
        return self._idx < len(self._data)

    def _peek(self) -> str | None:
        """Return the next character without consuming it, or None at end."""
        if self._has_data():
            return self._data[self._idx]
        return None

    def _read_char(self, ch: str) -> None:
        """Consume exactly *ch* at the cursor, or raise ``ValueError``."""
        if self._peek() != ch:
            raise ValueError(
                f"expected {ch!r} at position {self._idx} of RFC 4514 string"
            )
        self._idx += 1

    def _read_re(self, pat: re.Pattern[str]) -> str:
        """Consume and return the match of *pat* anchored at the cursor.

        :raises ValueError: if *pat* does not match at the cursor.
        """
        match = pat.match(self._data, pos=self._idx)
        if match is None:
            raise ValueError(
                f"unexpected data at position {self._idx} of RFC 4514 string"
            )
        val = match.group()
        self._idx += len(val)
        return val

    def parse(self) -> Name:
        """
        Parses the `data` string and converts it to a Name.

        According to RFC4514 section 2.1 the RDNSequence must be
        reversed when converting to string representation. So, when
        we parse it, we need to reverse again to get the RDNs on the
        correct order.
        """

        if not self._has_data():
            return Name([])

        rdns = [self._parse_rdn()]

        while self._has_data():
            self._read_char(",")
            rdns.append(self._parse_rdn())

        return Name(reversed(rdns))

    def _parse_rdn(self) -> RelativeDistinguishedName:
        """Parse one RDN: attribute-type-and-value items joined by '+'."""
        nas = [self._parse_na()]
        while self._peek() == "+":
            self._read_char("+")
            nas.append(self._parse_na())

        return RelativeDistinguishedName(nas)

    def _parse_na(self) -> NameAttribute[str]:
        """Parse one ``type=value`` item into a NameAttribute."""
        try:
            oid_value = self._read_re(self._OID_RE)
        except ValueError:
            # Not a dotted OID; fall back to a short attribute name.
            name = self._read_re(self._DESCR_RE)
            oid = self._attr_name_overrides.get(
                name, _NAME_TO_NAMEOID.get(name)
            )
            if oid is None:
                raise ValueError(
                    f"unknown attribute type name {name!r}"
                ) from None
        else:
            oid = ObjectIdentifier(oid_value)

        self._read_char("=")
        if self._peek() == "#":
            # NOTE: RFC 4514 defines a hexstring as the BER encoding of the
            # value; here the octets are decoded directly as UTF-8 text,
            # mirroring how _escape_dn_value renders bytes as '#' + raw hex.
            value = self._read_re(self._HEXSTRING_RE)
            value = binascii.unhexlify(value[1:]).decode()
        else:
            raw_value = self._read_re(self._STRING_RE)
            value = _unescape_dn_value(raw_value)

        return NameAttribute(oid, value)