1# This file is dual licensed under the terms of the Apache License, Version 
    2# 2.0, and the BSD License. See the LICENSE file in the root of this repository 
    3# for complete details. 
    4 
    5from __future__ import annotations 
    6 
    7import binascii 
    8import re 
    9import sys 
    10import typing 
    11import warnings 
    12from collections.abc import Iterable, Iterator 
    13 
    14from cryptography import utils 
    15from cryptography.hazmat.bindings._rust import x509 as rust_x509 
    16from cryptography.x509.oid import NameOID, ObjectIdentifier 
    17 
    18 
class _ASN1Type(utils.Enum):
    """
    ASN.1 types that a name attribute value can be tagged with.

    The integer values match the ASN.1 universal tag numbers of the
    corresponding types.
    """

    BitString = 3
    OctetString = 4
    UTF8String = 12
    NumericString = 18
    PrintableString = 19
    T61String = 20
    IA5String = 22
    UTCTime = 23
    GeneralizedTime = 24
    VisibleString = 26
    UniversalString = 28
    BMPString = 30
    32 
    33 
# Reverse lookup: integer value -> _ASN1Type member.
_ASN1_TYPE_TO_ENUM = {i.value: i for i in _ASN1Type}
# OIDs whose values default to a type other than UTF8String when a
# NameAttribute is created without an explicit ``_type``.
_NAMEOID_DEFAULT_TYPE: dict[ObjectIdentifier, _ASN1Type] = {
    NameOID.COUNTRY_NAME: _ASN1Type.PrintableString,
    NameOID.JURISDICTION_COUNTRY_NAME: _ASN1Type.PrintableString,
    NameOID.SERIAL_NUMBER: _ASN1Type.PrintableString,
    NameOID.DN_QUALIFIER: _ASN1Type.PrintableString,
    NameOID.EMAIL_ADDRESS: _ASN1Type.IA5String,
    NameOID.DOMAIN_COMPONENT: _ASN1Type.IA5String,
}

# Type aliases for the ``attr_name_overrides`` mappings: OID -> short name is
# used when formatting a name, short name -> OID when parsing one.
_OidNameMap = typing.Mapping[ObjectIdentifier, str]
_NameOidMap = typing.Mapping[str, ObjectIdentifier]

#: Short attribute names from RFC 4514:
#: https://tools.ietf.org/html/rfc4514#page-7
_NAMEOID_TO_NAME: _OidNameMap = {
    NameOID.COMMON_NAME: "CN",
    NameOID.LOCALITY_NAME: "L",
    NameOID.STATE_OR_PROVINCE_NAME: "ST",
    NameOID.ORGANIZATION_NAME: "O",
    NameOID.ORGANIZATIONAL_UNIT_NAME: "OU",
    NameOID.COUNTRY_NAME: "C",
    NameOID.STREET_ADDRESS: "STREET",
    NameOID.DOMAIN_COMPONENT: "DC",
    NameOID.USER_ID: "UID",
}
# Inverse of _NAMEOID_TO_NAME (short name -> OID), consulted by the parser.
_NAME_TO_NAMEOID = {v: k for k, v in _NAMEOID_TO_NAME.items()}

# Inclusive (min, max) length limits, measured on the UTF-8 encoding of the
# value, that NameAttribute enforces for these OIDs.
_NAMEOID_LENGTH_LIMIT = {
    NameOID.COUNTRY_NAME: (2, 2),
    NameOID.JURISDICTION_COUNTRY_NAME: (2, 2),
    NameOID.COMMON_NAME: (1, 64),
}
    68 
    69 
    70def _escape_dn_value(val: str | bytes) -> str: 
    71    """Escape special characters in RFC4514 Distinguished Name value.""" 
    72 
    73    if not val: 
    74        return "" 
    75 
    76    # RFC 4514 Section 2.4 defines the value as being the # (U+0023) character 
    77    # followed by the hexadecimal encoding of the octets. 
    78    if isinstance(val, bytes): 
    79        return "#" + binascii.hexlify(val).decode("utf8") 
    80 
    81    # See https://tools.ietf.org/html/rfc4514#section-2.4 
    82    val = val.replace("\\", "\\\\") 
    83    val = val.replace('"', '\\"') 
    84    val = val.replace("+", "\\+") 
    85    val = val.replace(",", "\\,") 
    86    val = val.replace(";", "\\;") 
    87    val = val.replace("<", "\\<") 
    88    val = val.replace(">", "\\>") 
    89    val = val.replace("\0", "\\00") 
    90 
    91    if val[0] in ("#", " "): 
    92        val = "\\" + val 
    93    if val[-1] == " ": 
    94        val = val[:-1] + "\\ " 
    95 
    96    return val 
    97 
    98 
    99def _unescape_dn_value(val: str) -> str: 
    100    if not val: 
    101        return "" 
    102 
    103    # See https://tools.ietf.org/html/rfc4514#section-3 
    104 
    105    # special = escaped / SPACE / SHARP / EQUALS 
    106    # escaped = DQUOTE / PLUS / COMMA / SEMI / LANGLE / RANGLE 
    107    def sub(m): 
    108        val = m.group(1) 
    109        # Regular escape 
    110        if len(val) == 1: 
    111            return val 
    112        # Hex-value scape 
    113        return chr(int(val, 16)) 
    114 
    115    return _RFC4514NameParser._PAIR_RE.sub(sub, val) 
    116 
    117 
# Type variable for the value carried by a NameAttribute. It is constrained
# to ``str | bytes``, ``str`` or ``bytes`` and declared covariant.
NameAttributeValueType = typing.TypeVar(
    "NameAttributeValueType",
    typing.Union[str, bytes],
    str,
    bytes,
    covariant=True,
)
    125 
    126 
    127class NameAttribute(typing.Generic[NameAttributeValueType]): 
    128    def __init__( 
    129        self, 
    130        oid: ObjectIdentifier, 
    131        value: NameAttributeValueType, 
    132        _type: _ASN1Type | None = None, 
    133        *, 
    134        _validate: bool = True, 
    135    ) -> None: 
    136        if not isinstance(oid, ObjectIdentifier): 
    137            raise TypeError( 
    138                "oid argument must be an ObjectIdentifier instance." 
    139            ) 
    140        if _type == _ASN1Type.BitString: 
    141            if oid != NameOID.X500_UNIQUE_IDENTIFIER: 
    142                raise TypeError( 
    143                    "oid must be X500_UNIQUE_IDENTIFIER for BitString type." 
    144                ) 
    145            if not isinstance(value, bytes): 
    146                raise TypeError("value must be bytes for BitString") 
    147        elif not isinstance(value, str): 
    148            raise TypeError("value argument must be a str") 
    149 
    150        length_limits = _NAMEOID_LENGTH_LIMIT.get(oid) 
    151        if length_limits is not None: 
    152            min_length, max_length = length_limits 
    153            assert isinstance(value, str) 
    154            c_len = len(value.encode("utf8")) 
    155            if c_len < min_length or c_len > max_length: 
    156                msg = ( 
    157                    f"Attribute's length must be >= {min_length} and " 
    158                    f"<= {max_length}, but it was {c_len}" 
    159                ) 
    160                if _validate is True: 
    161                    raise ValueError(msg) 
    162                else: 
    163                    warnings.warn(msg, stacklevel=2) 
    164 
    165        # The appropriate ASN1 string type varies by OID and is defined across 
    166        # multiple RFCs including 2459, 3280, and 5280. In general UTF8String 
    167        # is preferred (2459), but 3280 and 5280 specify several OIDs with 
    168        # alternate types. This means when we see the sentinel value we need 
    169        # to look up whether the OID has a non-UTF8 type. If it does, set it 
    170        # to that. Otherwise, UTF8! 
    171        if _type is None: 
    172            _type = _NAMEOID_DEFAULT_TYPE.get(oid, _ASN1Type.UTF8String) 
    173 
    174        if not isinstance(_type, _ASN1Type): 
    175            raise TypeError("_type must be from the _ASN1Type enum") 
    176 
    177        self._oid = oid 
    178        self._value: NameAttributeValueType = value 
    179        self._type: _ASN1Type = _type 
    180 
    181    @property 
    182    def oid(self) -> ObjectIdentifier: 
    183        return self._oid 
    184 
    185    @property 
    186    def value(self) -> NameAttributeValueType: 
    187        return self._value 
    188 
    189    @property 
    190    def rfc4514_attribute_name(self) -> str: 
    191        """ 
    192        The short attribute name (for example "CN") if available, 
    193        otherwise the OID dotted string. 
    194        """ 
    195        return _NAMEOID_TO_NAME.get(self.oid, self.oid.dotted_string) 
    196 
    197    def rfc4514_string( 
    198        self, attr_name_overrides: _OidNameMap | None = None 
    199    ) -> str: 
    200        """ 
    201        Format as RFC4514 Distinguished Name string. 
    202 
    203        Use short attribute name if available, otherwise fall back to OID 
    204        dotted string. 
    205        """ 
    206        attr_name = ( 
    207            attr_name_overrides.get(self.oid) if attr_name_overrides else None 
    208        ) 
    209        if attr_name is None: 
    210            attr_name = self.rfc4514_attribute_name 
    211 
    212        return f"{attr_name}={_escape_dn_value(self.value)}" 
    213 
    214    def __eq__(self, other: object) -> bool: 
    215        if not isinstance(other, NameAttribute): 
    216            return NotImplemented 
    217 
    218        return self.oid == other.oid and self.value == other.value 
    219 
    220    def __hash__(self) -> int: 
    221        return hash((self.oid, self.value)) 
    222 
    223    def __repr__(self) -> str: 
    224        return f"<NameAttribute(oid={self.oid}, value={self.value!r})>" 
    225 
    226 
    227class RelativeDistinguishedName: 
    228    def __init__(self, attributes: Iterable[NameAttribute]): 
    229        attributes = list(attributes) 
    230        if not attributes: 
    231            raise ValueError("a relative distinguished name cannot be empty") 
    232        if not all(isinstance(x, NameAttribute) for x in attributes): 
    233            raise TypeError("attributes must be an iterable of NameAttribute") 
    234 
    235        # Keep list and frozenset to preserve attribute order where it matters 
    236        self._attributes = attributes 
    237        self._attribute_set = frozenset(attributes) 
    238 
    239        if len(self._attribute_set) != len(attributes): 
    240            raise ValueError("duplicate attributes are not allowed") 
    241 
    242    def get_attributes_for_oid( 
    243        self, 
    244        oid: ObjectIdentifier, 
    245    ) -> list[NameAttribute[str | bytes]]: 
    246        return [i for i in self if i.oid == oid] 
    247 
    248    def rfc4514_string( 
    249        self, attr_name_overrides: _OidNameMap | None = None 
    250    ) -> str: 
    251        """ 
    252        Format as RFC4514 Distinguished Name string. 
    253 
    254        Within each RDN, attributes are joined by '+', although that is rarely 
    255        used in certificates. 
    256        """ 
    257        return "+".join( 
    258            attr.rfc4514_string(attr_name_overrides) 
    259            for attr in self._attributes 
    260        ) 
    261 
    262    def __eq__(self, other: object) -> bool: 
    263        if not isinstance(other, RelativeDistinguishedName): 
    264            return NotImplemented 
    265 
    266        return self._attribute_set == other._attribute_set 
    267 
    268    def __hash__(self) -> int: 
    269        return hash(self._attribute_set) 
    270 
    271    def __iter__(self) -> Iterator[NameAttribute]: 
    272        return iter(self._attributes) 
    273 
    274    def __len__(self) -> int: 
    275        return len(self._attributes) 
    276 
    277    def __repr__(self) -> str: 
    278        return f"<RelativeDistinguishedName({self.rfc4514_string()})>" 
    279 
    280 
    281class Name: 
    282    @typing.overload 
    283    def __init__(self, attributes: Iterable[NameAttribute]) -> None: ... 
    284 
    285    @typing.overload 
    286    def __init__( 
    287        self, attributes: Iterable[RelativeDistinguishedName] 
    288    ) -> None: ... 
    289 
    290    def __init__( 
    291        self, 
    292        attributes: Iterable[NameAttribute | RelativeDistinguishedName], 
    293    ) -> None: 
    294        attributes = list(attributes) 
    295        if all(isinstance(x, NameAttribute) for x in attributes): 
    296            self._attributes = [ 
    297                RelativeDistinguishedName([typing.cast(NameAttribute, x)]) 
    298                for x in attributes 
    299            ] 
    300        elif all(isinstance(x, RelativeDistinguishedName) for x in attributes): 
    301            self._attributes = typing.cast( 
    302                typing.List[RelativeDistinguishedName], attributes 
    303            ) 
    304        else: 
    305            raise TypeError( 
    306                "attributes must be a list of NameAttribute" 
    307                " or a list RelativeDistinguishedName" 
    308            ) 
    309 
    310    @classmethod 
    311    def from_rfc4514_string( 
    312        cls, 
    313        data: str, 
    314        attr_name_overrides: _NameOidMap | None = None, 
    315    ) -> Name: 
    316        return _RFC4514NameParser(data, attr_name_overrides or {}).parse() 
    317 
    318    def rfc4514_string( 
    319        self, attr_name_overrides: _OidNameMap | None = None 
    320    ) -> str: 
    321        """ 
    322        Format as RFC4514 Distinguished Name string. 
    323        For example 'CN=foobar.com,O=Foo Corp,C=US' 
    324 
    325        An X.509 name is a two-level structure: a list of sets of attributes. 
    326        Each list element is separated by ',' and within each list element, set 
    327        elements are separated by '+'. The latter is almost never used in 
    328        real world certificates. According to RFC4514 section 2.1 the 
    329        RDNSequence must be reversed when converting to string representation. 
    330        """ 
    331        return ",".join( 
    332            attr.rfc4514_string(attr_name_overrides) 
    333            for attr in reversed(self._attributes) 
    334        ) 
    335 
    336    def get_attributes_for_oid( 
    337        self, 
    338        oid: ObjectIdentifier, 
    339    ) -> list[NameAttribute[str | bytes]]: 
    340        return [i for i in self if i.oid == oid] 
    341 
    342    @property 
    343    def rdns(self) -> list[RelativeDistinguishedName]: 
    344        return self._attributes 
    345 
    346    def public_bytes(self, backend: typing.Any = None) -> bytes: 
    347        return rust_x509.encode_name_bytes(self) 
    348 
    349    def __eq__(self, other: object) -> bool: 
    350        if not isinstance(other, Name): 
    351            return NotImplemented 
    352 
    353        return self._attributes == other._attributes 
    354 
    355    def __hash__(self) -> int: 
    356        # TODO: this is relatively expensive, if this looks like a bottleneck 
    357        # for you, consider optimizing! 
    358        return hash(tuple(self._attributes)) 
    359 
    360    def __iter__(self) -> Iterator[NameAttribute]: 
    361        for rdn in self._attributes: 
    362            yield from rdn 
    363 
    364    def __len__(self) -> int: 
    365        return sum(len(rdn) for rdn in self._attributes) 
    366 
    367    def __repr__(self) -> str: 
    368        rdns = ",".join(attr.rfc4514_string() for attr in self._attributes) 
    369        return f"<Name({rdns})>" 
    370 
    371 
    372class _RFC4514NameParser: 
    373    _OID_RE = re.compile(r"(0|([1-9]\d*))(\.(0|([1-9]\d*)))+") 
    374    _DESCR_RE = re.compile(r"[a-zA-Z][a-zA-Z\d-]*") 
    375 
    376    _PAIR = r"\\([\\ #=\"\+,;<>]|[\da-zA-Z]{2})" 
    377    _PAIR_RE = re.compile(_PAIR) 
    378    _LUTF1 = r"[\x01-\x1f\x21\x24-\x2A\x2D-\x3A\x3D\x3F-\x5B\x5D-\x7F]" 
    379    _SUTF1 = r"[\x01-\x21\x23-\x2A\x2D-\x3A\x3D\x3F-\x5B\x5D-\x7F]" 
    380    _TUTF1 = r"[\x01-\x1F\x21\x23-\x2A\x2D-\x3A\x3D\x3F-\x5B\x5D-\x7F]" 
    381    _UTFMB = rf"[\x80-{chr(sys.maxunicode)}]" 
    382    _LEADCHAR = rf"{_LUTF1}|{_UTFMB}" 
    383    _STRINGCHAR = rf"{_SUTF1}|{_UTFMB}" 
    384    _TRAILCHAR = rf"{_TUTF1}|{_UTFMB}" 
    385    _STRING_RE = re.compile( 
    386        rf""" 
    387        ( 
    388            ({_LEADCHAR}|{_PAIR}) 
    389            ( 
    390                ({_STRINGCHAR}|{_PAIR})* 
    391                ({_TRAILCHAR}|{_PAIR}) 
    392            )? 
    393        )? 
    394        """, 
    395        re.VERBOSE, 
    396    ) 
    397    _HEXSTRING_RE = re.compile(r"#([\da-zA-Z]{2})+") 
    398 
    399    def __init__(self, data: str, attr_name_overrides: _NameOidMap) -> None: 
    400        self._data = data 
    401        self._idx = 0 
    402 
    403        self._attr_name_overrides = attr_name_overrides 
    404 
    405    def _has_data(self) -> bool: 
    406        return self._idx < len(self._data) 
    407 
    408    def _peek(self) -> str | None: 
    409        if self._has_data(): 
    410            return self._data[self._idx] 
    411        return None 
    412 
    413    def _read_char(self, ch: str) -> None: 
    414        if self._peek() != ch: 
    415            raise ValueError 
    416        self._idx += 1 
    417 
    418    def _read_re(self, pat) -> str: 
    419        match = pat.match(self._data, pos=self._idx) 
    420        if match is None: 
    421            raise ValueError 
    422        val = match.group() 
    423        self._idx += len(val) 
    424        return val 
    425 
    426    def parse(self) -> Name: 
    427        """ 
    428        Parses the `data` string and converts it to a Name. 
    429 
    430        According to RFC4514 section 2.1 the RDNSequence must be 
    431        reversed when converting to string representation. So, when 
    432        we parse it, we need to reverse again to get the RDNs on the 
    433        correct order. 
    434        """ 
    435 
    436        if not self._has_data(): 
    437            return Name([]) 
    438 
    439        rdns = [self._parse_rdn()] 
    440 
    441        while self._has_data(): 
    442            self._read_char(",") 
    443            rdns.append(self._parse_rdn()) 
    444 
    445        return Name(reversed(rdns)) 
    446 
    447    def _parse_rdn(self) -> RelativeDistinguishedName: 
    448        nas = [self._parse_na()] 
    449        while self._peek() == "+": 
    450            self._read_char("+") 
    451            nas.append(self._parse_na()) 
    452 
    453        return RelativeDistinguishedName(nas) 
    454 
    455    def _parse_na(self) -> NameAttribute: 
    456        try: 
    457            oid_value = self._read_re(self._OID_RE) 
    458        except ValueError: 
    459            name = self._read_re(self._DESCR_RE) 
    460            oid = self._attr_name_overrides.get( 
    461                name, _NAME_TO_NAMEOID.get(name) 
    462            ) 
    463            if oid is None: 
    464                raise ValueError 
    465        else: 
    466            oid = ObjectIdentifier(oid_value) 
    467 
    468        self._read_char("=") 
    469        if self._peek() == "#": 
    470            value = self._read_re(self._HEXSTRING_RE) 
    471            value = binascii.unhexlify(value[1:]).decode() 
    472        else: 
    473            raw_value = self._read_re(self._STRING_RE) 
    474            value = _unescape_dn_value(raw_value) 
    475 
    476        return NameAttribute(oid, value)