1""" 
    2TLS with SNI_-support for Python 2. Follow these instructions if you would 
    3like to verify TLS certificates in Python 2. Note, the default libraries do 
    4*not* do certificate checking; you need to do additional work to validate 
    5certificates yourself. 
    6 
    7This needs the following packages installed: 
    8 
    9* `pyOpenSSL`_ (tested with 16.0.0) 
    10* `cryptography`_ (minimum 1.3.4, from pyopenssl) 
    11* `idna`_ (minimum 2.0, from cryptography) 
    12 
    13However, pyopenssl depends on cryptography, which depends on idna, so while we 
    14use all three directly here we end up having relatively few packages required. 
    15 
    16You can install them with the following command: 
    17 
    18.. code-block:: bash 
    19 
    20    $ python -m pip install pyopenssl cryptography idna 
    21 
    22To activate certificate checking, call 
    23:func:`~urllib3.contrib.pyopenssl.inject_into_urllib3` from your Python code 
    24before you begin making HTTP requests. This can be done in a ``sitecustomize`` 
    25module, or at any other time before your application begins using ``urllib3``, 
    26like this: 
    27 
    28.. code-block:: python 
    29 
    30    try: 
    31        import urllib3.contrib.pyopenssl 
    32        urllib3.contrib.pyopenssl.inject_into_urllib3() 
    33    except ImportError: 
    34        pass 
    35 
    36Now you can use :mod:`urllib3` as you normally would, and it will support SNI 
    37when the required modules are installed. 
    38 
    39Activating this module also has the positive side effect of disabling SSL/TLS 
    40compression in Python 2 (see `CRIME attack`_). 
    41 
    42.. _sni: https://en.wikipedia.org/wiki/Server_Name_Indication 
    43.. _crime attack: https://en.wikipedia.org/wiki/CRIME_(security_exploit) 
    44.. _pyopenssl: https://www.pyopenssl.org 
    45.. _cryptography: https://cryptography.io 
    46.. _idna: https://github.com/kjd/idna 
    47""" 
    48from __future__ import absolute_import 
    49 
    50import OpenSSL.crypto 
    51import OpenSSL.SSL 
    52from cryptography import x509 
    53from cryptography.hazmat.backends.openssl import backend as openssl_backend 
    54 
try:
    from cryptography.x509 import UnsupportedExtension
except ImportError:
    # UnsupportedExtension is gone in cryptography >= 2.1.0. Define a local
    # stand-in so the ``except`` tuple in get_subj_alt_name() can still
    # reference the name; nothing in this module raises it.
    class UnsupportedExtension(Exception):
        pass
    61 
    62 
    63from io import BytesIO 
    64from socket import error as SocketError 
    65from socket import timeout 
    66 
    67try:  # Platform-specific: Python 2 
    68    from socket import _fileobject 
    69except ImportError:  # Platform-specific: Python 3 
    70    _fileobject = None 
    71    from ..packages.backports.makefile import backport_makefile 
    72 
    73import logging 
    74import ssl 
    75import sys 
    76import warnings 
    77 
    78from .. import util 
    79from ..packages import six 
    80from ..util.ssl_ import PROTOCOL_TLS_CLIENT 
    81 
# Importing this module emits a DeprecationWarning. stacklevel=2 is intended to
# attribute the warning to the importing code rather than to this file
# (NOTE(review): exact attribution depends on the interpreter's handling of
# import-machinery frames).
warnings.warn(
    "'urllib3.contrib.pyopenssl' module is deprecated and will be removed "
    "in a future release of urllib3 2.x. Read more in this issue: "
    "https://github.com/urllib3/urllib3/issues/2680",
    category=DeprecationWarning,
    stacklevel=2,
)
    89 
    90__all__ = ["inject_into_urllib3", "extract_from_urllib3"] 
    91 
    92# SNI always works. 
    93HAS_SNI = True 
    94 
    95# Map from urllib3 to PyOpenSSL compatible parameter-values. 
    96_openssl_versions = { 
    97    util.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD, 
    98    PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD, 
    99    ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD, 
    100} 
    101 
    102if hasattr(ssl, "PROTOCOL_SSLv3") and hasattr(OpenSSL.SSL, "SSLv3_METHOD"): 
    103    _openssl_versions[ssl.PROTOCOL_SSLv3] = OpenSSL.SSL.SSLv3_METHOD 
    104 
    105if hasattr(ssl, "PROTOCOL_TLSv1_1") and hasattr(OpenSSL.SSL, "TLSv1_1_METHOD"): 
    106    _openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD 
    107 
    108if hasattr(ssl, "PROTOCOL_TLSv1_2") and hasattr(OpenSSL.SSL, "TLSv1_2_METHOD"): 
    109    _openssl_versions[ssl.PROTOCOL_TLSv1_2] = OpenSSL.SSL.TLSv1_2_METHOD 
    110 
    111 
    112_stdlib_to_openssl_verify = { 
    113    ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE, 
    114    ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER, 
    115    ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER 
    116    + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, 
    117} 
    118_openssl_to_stdlib_verify = dict((v, k) for k, v in _stdlib_to_openssl_verify.items()) 
    119 
    120# OpenSSL will only write 16K at a time 
    121SSL_WRITE_BLOCKSIZE = 16384 
    122 
    123orig_util_HAS_SNI = util.HAS_SNI 
    124orig_util_SSLContext = util.ssl_.SSLContext 
    125 
    126 
    127log = logging.getLogger(__name__) 
    128 
    129 
    130def inject_into_urllib3(): 
    131    "Monkey-patch urllib3 with PyOpenSSL-backed SSL-support." 
    132 
    133    _validate_dependencies_met() 
    134 
    135    util.SSLContext = PyOpenSSLContext 
    136    util.ssl_.SSLContext = PyOpenSSLContext 
    137    util.HAS_SNI = HAS_SNI 
    138    util.ssl_.HAS_SNI = HAS_SNI 
    139    util.IS_PYOPENSSL = True 
    140    util.ssl_.IS_PYOPENSSL = True 
    141 
    142 
    143def extract_from_urllib3(): 
    144    "Undo monkey-patching by :func:`inject_into_urllib3`." 
    145 
    146    util.SSLContext = orig_util_SSLContext 
    147    util.ssl_.SSLContext = orig_util_SSLContext 
    148    util.HAS_SNI = orig_util_HAS_SNI 
    149    util.ssl_.HAS_SNI = orig_util_HAS_SNI 
    150    util.IS_PYOPENSSL = False 
    151    util.ssl_.IS_PYOPENSSL = False 
    152 
    153 
    154def _validate_dependencies_met(): 
    155    """ 
    156    Verifies that PyOpenSSL's package-level dependencies have been met. 
    157    Throws `ImportError` if they are not met. 
    158    """ 
    159    # Method added in `cryptography==1.1`; not available in older versions 
    160    from cryptography.x509.extensions import Extensions 
    161 
    162    if getattr(Extensions, "get_extension_for_class", None) is None: 
    163        raise ImportError( 
    164            "'cryptography' module missing required functionality.  " 
    165            "Try upgrading to v1.3.4 or newer." 
    166        ) 
    167 
    168    # pyOpenSSL 0.14 and above use cryptography for OpenSSL bindings. The _x509 
    169    # attribute is only present on those versions. 
    170    from OpenSSL.crypto import X509 
    171 
    172    x509 = X509() 
    173    if getattr(x509, "_x509", None) is None: 
    174        raise ImportError( 
    175            "'pyOpenSSL' module missing required functionality. " 
    176            "Try upgrading to v0.14 or newer." 
    177        ) 
    178 
    179 
    180def _dnsname_to_stdlib(name): 
    181    """ 
    182    Converts a dNSName SubjectAlternativeName field to the form used by the 
    183    standard library on the given Python version. 
    184 
    185    Cryptography produces a dNSName as a unicode string that was idna-decoded 
    186    from ASCII bytes. We need to idna-encode that string to get it back, and 
    187    then on Python 3 we also need to convert to unicode via UTF-8 (the stdlib 
    188    uses PyUnicode_FromStringAndSize on it, which decodes via UTF-8). 
    189 
    190    If the name cannot be idna-encoded then we return None signalling that 
    191    the name given should be skipped. 
    192    """ 
    193 
    194    def idna_encode(name): 
    195        """ 
    196        Borrowed wholesale from the Python Cryptography Project. It turns out 
    197        that we can't just safely call `idna.encode`: it can explode for 
    198        wildcard names. This avoids that problem. 
    199        """ 
    200        import idna 
    201 
    202        try: 
    203            for prefix in [u"*.", u"."]: 
    204                if name.startswith(prefix): 
    205                    name = name[len(prefix) :] 
    206                    return prefix.encode("ascii") + idna.encode(name) 
    207            return idna.encode(name) 
    208        except idna.core.IDNAError: 
    209            return None 
    210 
    211    # Don't send IPv6 addresses through the IDNA encoder. 
    212    if ":" in name: 
    213        return name 
    214 
    215    name = idna_encode(name) 
    216    if name is None: 
    217        return None 
    218    elif sys.version_info >= (3, 0): 
    219        name = name.decode("utf-8") 
    220    return name 
    221 
    222 
    223def get_subj_alt_name(peer_cert): 
    224    """ 
    225    Given an PyOpenSSL certificate, provides all the subject alternative names. 
    226    """ 
    227    # Pass the cert to cryptography, which has much better APIs for this. 
    228    if hasattr(peer_cert, "to_cryptography"): 
    229        cert = peer_cert.to_cryptography() 
    230    else: 
    231        der = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, peer_cert) 
    232        cert = x509.load_der_x509_certificate(der, openssl_backend) 
    233 
    234    # We want to find the SAN extension. Ask Cryptography to locate it (it's 
    235    # faster than looping in Python) 
    236    try: 
    237        ext = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value 
    238    except x509.ExtensionNotFound: 
    239        # No such extension, return the empty list. 
    240        return [] 
    241    except ( 
    242        x509.DuplicateExtension, 
    243        UnsupportedExtension, 
    244        x509.UnsupportedGeneralNameType, 
    245        UnicodeError, 
    246    ) as e: 
    247        # A problem has been found with the quality of the certificate. Assume 
    248        # no SAN field is present. 
    249        log.warning( 
    250            "A problem was encountered with the certificate that prevented " 
    251            "urllib3 from finding the SubjectAlternativeName field. This can " 
    252            "affect certificate validation. The error was %s", 
    253            e, 
    254        ) 
    255        return [] 
    256 
    257    # We want to return dNSName and iPAddress fields. We need to cast the IPs 
    258    # back to strings because the match_hostname function wants them as 
    259    # strings. 
    260    # Sadly the DNS names need to be idna encoded and then, on Python 3, UTF-8 
    261    # decoded. This is pretty frustrating, but that's what the standard library 
    262    # does with certificates, and so we need to attempt to do the same. 
    263    # We also want to skip over names which cannot be idna encoded. 
    264    names = [ 
    265        ("DNS", name) 
    266        for name in map(_dnsname_to_stdlib, ext.get_values_for_type(x509.DNSName)) 
    267        if name is not None 
    268    ] 
    269    names.extend( 
    270        ("IP Address", str(name)) for name in ext.get_values_for_type(x509.IPAddress) 
    271    ) 
    272 
    273    return names 
    274 
    275 
    276class WrappedSocket(object): 
    277    """API-compatibility wrapper for Python OpenSSL's Connection-class. 
    278 
    279    Note: _makefile_refs, _drop() and _reuse() are needed for the garbage 
    280    collector of pypy. 
    281    """ 
    282 
    283    def __init__(self, connection, socket, suppress_ragged_eofs=True): 
    284        self.connection = connection 
    285        self.socket = socket 
    286        self.suppress_ragged_eofs = suppress_ragged_eofs 
    287        self._makefile_refs = 0 
    288        self._closed = False 
    289 
    290    def fileno(self): 
    291        return self.socket.fileno() 
    292 
    293    # Copy-pasted from Python 3.5 source code 
    294    def _decref_socketios(self): 
    295        if self._makefile_refs > 0: 
    296            self._makefile_refs -= 1 
    297        if self._closed: 
    298            self.close() 
    299 
    300    def recv(self, *args, **kwargs): 
    301        try: 
    302            data = self.connection.recv(*args, **kwargs) 
    303        except OpenSSL.SSL.SysCallError as e: 
    304            if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"): 
    305                return b"" 
    306            else: 
    307                raise SocketError(str(e)) 
    308        except OpenSSL.SSL.ZeroReturnError: 
    309            if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: 
    310                return b"" 
    311            else: 
    312                raise 
    313        except OpenSSL.SSL.WantReadError: 
    314            if not util.wait_for_read(self.socket, self.socket.gettimeout()): 
    315                raise timeout("The read operation timed out") 
    316            else: 
    317                return self.recv(*args, **kwargs) 
    318 
    319        # TLS 1.3 post-handshake authentication 
    320        except OpenSSL.SSL.Error as e: 
    321            raise ssl.SSLError("read error: %r" % e) 
    322        else: 
    323            return data 
    324 
    325    def recv_into(self, *args, **kwargs): 
    326        try: 
    327            return self.connection.recv_into(*args, **kwargs) 
    328        except OpenSSL.SSL.SysCallError as e: 
    329            if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"): 
    330                return 0 
    331            else: 
    332                raise SocketError(str(e)) 
    333        except OpenSSL.SSL.ZeroReturnError: 
    334            if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: 
    335                return 0 
    336            else: 
    337                raise 
    338        except OpenSSL.SSL.WantReadError: 
    339            if not util.wait_for_read(self.socket, self.socket.gettimeout()): 
    340                raise timeout("The read operation timed out") 
    341            else: 
    342                return self.recv_into(*args, **kwargs) 
    343 
    344        # TLS 1.3 post-handshake authentication 
    345        except OpenSSL.SSL.Error as e: 
    346            raise ssl.SSLError("read error: %r" % e) 
    347 
    348    def settimeout(self, timeout): 
    349        return self.socket.settimeout(timeout) 
    350 
    351    def _send_until_done(self, data): 
    352        while True: 
    353            try: 
    354                return self.connection.send(data) 
    355            except OpenSSL.SSL.WantWriteError: 
    356                if not util.wait_for_write(self.socket, self.socket.gettimeout()): 
    357                    raise timeout() 
    358                continue 
    359            except OpenSSL.SSL.SysCallError as e: 
    360                raise SocketError(str(e)) 
    361 
    362    def sendall(self, data): 
    363        total_sent = 0 
    364        while total_sent < len(data): 
    365            sent = self._send_until_done( 
    366                data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE] 
    367            ) 
    368            total_sent += sent 
    369 
    370    def shutdown(self): 
    371        # FIXME rethrow compatible exceptions should we ever use this 
    372        self.connection.shutdown() 
    373 
    374    def close(self): 
    375        if self._makefile_refs < 1: 
    376            try: 
    377                self._closed = True 
    378                return self.connection.close() 
    379            except OpenSSL.SSL.Error: 
    380                return 
    381        else: 
    382            self._makefile_refs -= 1 
    383 
    384    def getpeercert(self, binary_form=False): 
    385        x509 = self.connection.get_peer_certificate() 
    386 
    387        if not x509: 
    388            return x509 
    389 
    390        if binary_form: 
    391            return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509) 
    392 
    393        return { 
    394            "subject": ((("commonName", x509.get_subject().CN),),), 
    395            "subjectAltName": get_subj_alt_name(x509), 
    396        } 
    397 
    398    def version(self): 
    399        return self.connection.get_protocol_version_name() 
    400 
    401    def _reuse(self): 
    402        self._makefile_refs += 1 
    403 
    404    def _drop(self): 
    405        if self._makefile_refs < 1: 
    406            self.close() 
    407        else: 
    408            self._makefile_refs -= 1 
    409 
    410 
    411if _fileobject:  # Platform-specific: Python 2 
    412 
    413    def makefile(self, mode, bufsize=-1): 
    414        self._makefile_refs += 1 
    415        return _fileobject(self, mode, bufsize, close=True) 
    416 
    417else:  # Platform-specific: Python 3 
    418    makefile = backport_makefile 
    419 
    420WrappedSocket.makefile = makefile 
    421 
    422 
    423class PyOpenSSLContext(object): 
    424    """ 
    425    I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible 
    426    for translating the interface of the standard library ``SSLContext`` object 
    427    to calls into PyOpenSSL. 
    428    """ 
    429 
    430    def __init__(self, protocol): 
    431        self.protocol = _openssl_versions[protocol] 
    432        self._ctx = OpenSSL.SSL.Context(self.protocol) 
    433        self._options = 0 
    434        self.check_hostname = False 
    435 
    436    @property 
    437    def options(self): 
    438        return self._options 
    439 
    440    @options.setter 
    441    def options(self, value): 
    442        self._options = value 
    443        self._ctx.set_options(value) 
    444 
    445    @property 
    446    def verify_mode(self): 
    447        return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()] 
    448 
    449    @verify_mode.setter 
    450    def verify_mode(self, value): 
    451        self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback) 
    452 
    453    def set_default_verify_paths(self): 
    454        self._ctx.set_default_verify_paths() 
    455 
    456    def set_ciphers(self, ciphers): 
    457        if isinstance(ciphers, six.text_type): 
    458            ciphers = ciphers.encode("utf-8") 
    459        self._ctx.set_cipher_list(ciphers) 
    460 
    461    def load_verify_locations(self, cafile=None, capath=None, cadata=None): 
    462        if cafile is not None: 
    463            cafile = cafile.encode("utf-8") 
    464        if capath is not None: 
    465            capath = capath.encode("utf-8") 
    466        try: 
    467            self._ctx.load_verify_locations(cafile, capath) 
    468            if cadata is not None: 
    469                self._ctx.load_verify_locations(BytesIO(cadata)) 
    470        except OpenSSL.SSL.Error as e: 
    471            raise ssl.SSLError("unable to load trusted certificates: %r" % e) 
    472 
    473    def load_cert_chain(self, certfile, keyfile=None, password=None): 
    474        self._ctx.use_certificate_chain_file(certfile) 
    475        if password is not None: 
    476            if not isinstance(password, six.binary_type): 
    477                password = password.encode("utf-8") 
    478            self._ctx.set_passwd_cb(lambda *_: password) 
    479        self._ctx.use_privatekey_file(keyfile or certfile) 
    480 
    481    def set_alpn_protocols(self, protocols): 
    482        protocols = [six.ensure_binary(p) for p in protocols] 
    483        return self._ctx.set_alpn_protos(protocols) 
    484 
    485    def wrap_socket( 
    486        self, 
    487        sock, 
    488        server_side=False, 
    489        do_handshake_on_connect=True, 
    490        suppress_ragged_eofs=True, 
    491        server_hostname=None, 
    492    ): 
    493        cnx = OpenSSL.SSL.Connection(self._ctx, sock) 
    494 
    495        if isinstance(server_hostname, six.text_type):  # Platform-specific: Python 3 
    496            server_hostname = server_hostname.encode("utf-8") 
    497 
    498        if server_hostname is not None: 
    499            cnx.set_tlsext_host_name(server_hostname) 
    500 
    501        cnx.set_connect_state() 
    502 
    503        while True: 
    504            try: 
    505                cnx.do_handshake() 
    506            except OpenSSL.SSL.WantReadError: 
    507                if not util.wait_for_read(sock, sock.gettimeout()): 
    508                    raise timeout("select timed out") 
    509                continue 
    510            except OpenSSL.SSL.Error as e: 
    511                raise ssl.SSLError("bad handshake: %r" % e) 
    512            break 
    513 
    514        return WrappedSocket(cnx, sock) 
    515 
    516 
    517def _verify_callback(cnx, x509, err_no, err_depth, return_code): 
    518    return err_no == 0