1# 
    2# This file is part of pyasn1 software. 
    3# 
    4# Copyright (c) 2005-2020, Ilya Etingof <etingof@gmail.com> 
    5# License: https://pyasn1.readthedocs.io/en/latest/license.html 
    6# 
    7import warnings 
    8 
    9from pyasn1 import error 
    10from pyasn1.codec.ber import encoder 
    11from pyasn1.type import univ 
    12from pyasn1.type import useful 
    13 
    14__all__ = ['Encoder', 'encode'] 
    15 
    16 
    17class BooleanEncoder(encoder.IntegerEncoder): 
    18    def encodeValue(self, value, asn1Spec, encodeFun, **options): 
    19        if value == 0: 
    20            substrate = (0,) 
    21        else: 
    22            substrate = (255,) 
    23        return substrate, False, False 
    24 
    25 
    26class RealEncoder(encoder.RealEncoder): 
    27    def _chooseEncBase(self, value): 
    28        m, b, e = value 
    29        return self._dropFloatingPoint(m, b, e) 
    30 
    31 
# TODO: a specialized GeneralStringEncoder could be added here
    33 
    34class TimeEncoderMixIn(object): 
    35    Z_CHAR = ord('Z') 
    36    PLUS_CHAR = ord('+') 
    37    MINUS_CHAR = ord('-') 
    38    COMMA_CHAR = ord(',') 
    39    DOT_CHAR = ord('.') 
    40    ZERO_CHAR = ord('0') 
    41 
    42    MIN_LENGTH = 12 
    43    MAX_LENGTH = 19 
    44 
    45    def encodeValue(self, value, asn1Spec, encodeFun, **options): 
    46        # CER encoding constraints: 
    47        # - minutes are mandatory, seconds are optional 
    48        # - sub-seconds must NOT be zero / no meaningless zeros 
    49        # - no hanging fraction dot 
    50        # - time in UTC (Z) 
    51        # - only dot is allowed for fractions 
    52 
    53        if asn1Spec is not None: 
    54            value = asn1Spec.clone(value) 
    55 
    56        numbers = value.asNumbers() 
    57 
    58        if self.PLUS_CHAR in numbers or self.MINUS_CHAR in numbers: 
    59            raise error.PyAsn1Error('Must be UTC time: %r' % value) 
    60 
    61        if numbers[-1] != self.Z_CHAR: 
    62            raise error.PyAsn1Error('Missing "Z" time zone specifier: %r' % value) 
    63 
    64        if self.COMMA_CHAR in numbers: 
    65            raise error.PyAsn1Error('Comma in fractions disallowed: %r' % value) 
    66 
    67        if self.DOT_CHAR in numbers: 
    68 
    69            isModified = False 
    70 
    71            numbers = list(numbers) 
    72 
    73            searchIndex = min(numbers.index(self.DOT_CHAR) + 4, len(numbers) - 1) 
    74 
    75            while numbers[searchIndex] != self.DOT_CHAR: 
    76                if numbers[searchIndex] == self.ZERO_CHAR: 
    77                    del numbers[searchIndex] 
    78                    isModified = True 
    79 
    80                searchIndex -= 1 
    81 
    82            searchIndex += 1 
    83 
    84            if searchIndex < len(numbers): 
    85                if numbers[searchIndex] == self.Z_CHAR: 
    86                    # drop hanging comma 
    87                    del numbers[searchIndex - 1] 
    88                    isModified = True 
    89 
    90            if isModified: 
    91                value = value.clone(numbers) 
    92 
    93        if not self.MIN_LENGTH < len(numbers) < self.MAX_LENGTH: 
    94            raise error.PyAsn1Error('Length constraint violated: %r' % value) 
    95 
    96        options.update(maxChunkSize=1000) 
    97 
    98        return encoder.OctetStringEncoder.encodeValue( 
    99            self, value, asn1Spec, encodeFun, **options 
    100        ) 
    101 
    102 
    103class GeneralizedTimeEncoder(TimeEncoderMixIn, encoder.OctetStringEncoder): 
    104    MIN_LENGTH = 12 
    105    MAX_LENGTH = 20 
    106 
    107 
    108class UTCTimeEncoder(TimeEncoderMixIn, encoder.OctetStringEncoder): 
    109    MIN_LENGTH = 10 
    110    MAX_LENGTH = 14 
    111 
    112 
    113class SetOfEncoder(encoder.SequenceOfEncoder): 
    114    def encodeValue(self, value, asn1Spec, encodeFun, **options): 
    115        chunks = self._encodeComponents( 
    116            value, asn1Spec, encodeFun, **options) 
    117 
    118        # sort by serialised and padded components 
    119        if len(chunks) > 1: 
    120            zero = b'\x00' 
    121            maxLen = max(map(len, chunks)) 
    122            paddedChunks = [ 
    123                (x.ljust(maxLen, zero), x) for x in chunks 
    124            ] 
    125            paddedChunks.sort(key=lambda x: x[0]) 
    126 
    127            chunks = [x[1] for x in paddedChunks] 
    128 
    129        return b''.join(chunks), True, True 
    130 
    131 
    132class SequenceOfEncoder(encoder.SequenceOfEncoder): 
    133    def encodeValue(self, value, asn1Spec, encodeFun, **options): 
    134 
    135        if options.get('ifNotEmpty', False) and not len(value): 
    136            return b'', True, True 
    137 
    138        chunks = self._encodeComponents( 
    139            value, asn1Spec, encodeFun, **options) 
    140 
    141        return b''.join(chunks), True, True 
    142 
    143 
    144class SetEncoder(encoder.SequenceEncoder): 
    145    @staticmethod 
    146    def _componentSortKey(componentAndType): 
    147        """Sort SET components by tag 
    148 
    149        Sort regardless of the Choice value (static sort) 
    150        """ 
    151        component, asn1Spec = componentAndType 
    152 
    153        if asn1Spec is None: 
    154            asn1Spec = component 
    155 
    156        if asn1Spec.typeId == univ.Choice.typeId and not asn1Spec.tagSet: 
    157            if asn1Spec.tagSet: 
    158                return asn1Spec.tagSet 
    159            else: 
    160                return asn1Spec.componentType.minTagSet 
    161        else: 
    162            return asn1Spec.tagSet 
    163 
    164    def encodeValue(self, value, asn1Spec, encodeFun, **options): 
    165 
    166        substrate = b'' 
    167 
    168        comps = [] 
    169        compsMap = {} 
    170 
    171        if asn1Spec is None: 
    172            # instance of ASN.1 schema 
    173            inconsistency = value.isInconsistent 
    174            if inconsistency: 
    175                raise error.PyAsn1Error( 
    176                    f"ASN.1 object {value.__class__.__name__} is inconsistent") 
    177 
    178            namedTypes = value.componentType 
    179 
    180            for idx, component in enumerate(value.values()): 
    181                if namedTypes: 
    182                    namedType = namedTypes[idx] 
    183 
    184                    if namedType.isOptional and not component.isValue: 
    185                            continue 
    186 
    187                    if namedType.isDefaulted and component == namedType.asn1Object: 
    188                            continue 
    189 
    190                    compsMap[id(component)] = namedType 
    191 
    192                else: 
    193                    compsMap[id(component)] = None 
    194 
    195                comps.append((component, asn1Spec)) 
    196 
    197        else: 
    198            # bare Python value + ASN.1 schema 
    199            for idx, namedType in enumerate(asn1Spec.componentType.namedTypes): 
    200 
    201                try: 
    202                    component = value[namedType.name] 
    203 
    204                except KeyError: 
    205                    raise error.PyAsn1Error('Component name "%s" not found in %r' % (namedType.name, value)) 
    206 
    207                if namedType.isOptional and namedType.name not in value: 
    208                    continue 
    209 
    210                if namedType.isDefaulted and component == namedType.asn1Object: 
    211                    continue 
    212 
    213                compsMap[id(component)] = namedType 
    214                comps.append((component, asn1Spec[idx])) 
    215 
    216        for comp, compType in sorted(comps, key=self._componentSortKey): 
    217            namedType = compsMap[id(comp)] 
    218 
    219            if namedType: 
    220                options.update(ifNotEmpty=namedType.isOptional) 
    221 
    222            chunk = encodeFun(comp, compType, **options) 
    223 
    224            # wrap open type blob if needed 
    225            if namedType and namedType.openType: 
    226                wrapType = namedType.asn1Object 
    227                if wrapType.tagSet and not wrapType.isSameTypeWith(comp): 
    228                    chunk = encodeFun(chunk, wrapType, **options) 
    229 
    230            substrate += chunk 
    231 
    232        return substrate, True, True 
    233 
    234 
    235class SequenceEncoder(encoder.SequenceEncoder): 
    236    omitEmptyOptionals = True 
    237 
    238 
    239TAG_MAP = encoder.TAG_MAP.copy() 
    240 
    241TAG_MAP.update({ 
    242    univ.Boolean.tagSet: BooleanEncoder(), 
    243    univ.Real.tagSet: RealEncoder(), 
    244    useful.GeneralizedTime.tagSet: GeneralizedTimeEncoder(), 
    245    useful.UTCTime.tagSet: UTCTimeEncoder(), 
    246    # Sequence & Set have same tags as SequenceOf & SetOf 
    247    univ.SetOf.tagSet: SetOfEncoder(), 
    248    univ.Sequence.typeId: SequenceEncoder() 
    249}) 
    250 
    251TYPE_MAP = encoder.TYPE_MAP.copy() 
    252 
    253TYPE_MAP.update({ 
    254    univ.Boolean.typeId: BooleanEncoder(), 
    255    univ.Real.typeId: RealEncoder(), 
    256    useful.GeneralizedTime.typeId: GeneralizedTimeEncoder(), 
    257    useful.UTCTime.typeId: UTCTimeEncoder(), 
    258    # Sequence & Set have same tags as SequenceOf & SetOf 
    259    univ.Set.typeId: SetEncoder(), 
    260    univ.SetOf.typeId: SetOfEncoder(), 
    261    univ.Sequence.typeId: SequenceEncoder(), 
    262    univ.SequenceOf.typeId: SequenceOfEncoder() 
    263}) 
    264 
    265 
    266class SingleItemEncoder(encoder.SingleItemEncoder): 
    267    fixedDefLengthMode = False 
    268    fixedChunkSize = 1000 
    269 
    270    TAG_MAP = TAG_MAP 
    271    TYPE_MAP = TYPE_MAP 
    272 
    273 
    274class Encoder(encoder.Encoder): 
    275    SINGLE_ITEM_ENCODER = SingleItemEncoder 
    276 
    277 
    278#: Turns ASN.1 object into CER octet stream. 
    279#: 
    280#: Takes any ASN.1 object (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative) 
    281#: walks all its components recursively and produces a CER octet stream. 
    282#: 
    283#: Parameters 
    284#: ---------- 
    285#: value: either a Python or pyasn1 object (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative) 
    286#:     A Python or pyasn1 object to encode. If Python object is given, `asnSpec` 
    287#:     parameter is required to guide the encoding process. 
    288#: 
    289#: Keyword Args 
    290#: ------------ 
    291#: asn1Spec: 
    292#:     Optional ASN.1 schema or value object e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative 
    293#: 
    294#: Returns 
    295#: ------- 
    296#: : :py:class:`bytes` 
    297#:     Given ASN.1 object encoded into BER octet-stream 
    298#: 
    299#: Raises 
    300#: ------ 
    301#: ~pyasn1.error.PyAsn1Error 
    302#:     On encoding errors 
    303#: 
    304#: Examples 
    305#: -------- 
    306#: Encode Python value into CER with ASN.1 schema 
    307#: 
    308#: .. code-block:: pycon 
    309#: 
    310#:    >>> seq = SequenceOf(componentType=Integer()) 
    311#:    >>> encode([1, 2, 3], asn1Spec=seq) 
    312#:    b'0\x80\x02\x01\x01\x02\x01\x02\x02\x01\x03\x00\x00' 
    313#: 
    314#: Encode ASN.1 value object into CER 
    315#: 
    316#: .. code-block:: pycon 
    317#: 
    318#:    >>> seq = SequenceOf(componentType=Integer()) 
    319#:    >>> seq.extend([1, 2, 3]) 
    320#:    >>> encode(seq) 
    321#:    b'0\x80\x02\x01\x01\x02\x01\x02\x02\x01\x03\x00\x00' 
    322#: 
    323encode = Encoder() 
    324 
    325# EncoderFactory queries class instance and builds a map of tags -> encoders 
    326 
    327def __getattr__(attr: str): 
    328    if newAttr := {"tagMap": "TAG_MAP", "typeMap": "TYPE_MAP"}.get(attr): 
    329        warnings.warn(f"{attr} is deprecated. Please use {newAttr} instead.", DeprecationWarning) 
    330        return globals()[newAttr] 
    331    raise AttributeError(attr)