Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/google/auth/compute_engine/_mtls.py: 60%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

75 statements  

1# -*- coding: utf-8 -*- 

2# 

3# Copyright 2024 Google LLC 

4# 

5# Licensed under the Apache License, Version 2.0 (the "License"); 

6# you may not use this file except in compliance with the License. 

7# You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, software 

12# distributed under the License is distributed on an "AS IS" BASIS, 

13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

14# See the License for the specific language governing permissions and 

15# limitations under the License. 

16# 

17"""Mutual TLS for Google Compute Engine metadata server.""" 

18 

from dataclasses import dataclass, field
import enum
import logging
import os
from pathlib import Path
import ssl
from typing import Optional
from urllib.parse import urlparse, urlunparse

import requests
from requests.adapters import HTTPAdapter

from google.auth import environment_vars, exceptions

31 

32 

_LOGGER = logging.getLogger(__name__)

# ``os.name`` reports this value on Windows hosts.
_WINDOWS_OS_NAME = "nt"

# Base directories that hold the MDS mTLS root certificate and the client
# credentials. The well-known per-OS locations are documented at:
# https://cloud.google.com/compute/docs/metadata/overview#https-mds-certificates
_MTLS_COMPONENTS_BASE_PATH = Path("/run/google-mds-mtls")  # non-Windows hosts
_WINDOWS_MTLS_COMPONENTS_BASE_PATH = Path("C:/ProgramData/Google/ComputeEngine")

42 

43 

def _get_mds_root_crt_path():
    """Return the location of the MDS mTLS root CA certificate for this OS."""
    if os.name != _WINDOWS_OS_NAME:
        return _MTLS_COMPONENTS_BASE_PATH / "root.crt"
    return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-root.crt"

49 

50 

def _get_mds_client_combined_cert_path():
    """Return the location of the MDS mTLS client credentials file for this OS.

    The single file at this path holds both the client certificate and its
    private key.
    """
    on_windows = os.name == _WINDOWS_OS_NAME
    base_dir, file_name = (
        (_WINDOWS_MTLS_COMPONENTS_BASE_PATH, "mds-mtls-client.key")
        if on_windows
        else (_MTLS_COMPONENTS_BASE_PATH, "client.key")
    )
    return base_dir / file_name

56 

57 

@dataclass
class MdsMtlsConfig:
    """Filesystem locations of the credentials used for MDS mTLS.

    Each field defaults to the OS-specific well-known location, computed by a
    factory every time an instance is created without that argument.

    Attributes:
        ca_cert_path: CA certificate used to verify the metadata server.
        client_combined_cert_path: Single file containing both the client
            certificate and its private key.
    """

    # path to CA certificate
    ca_cert_path: Path = field(default_factory=_get_mds_root_crt_path)
    # path to file containing client certificate and key
    client_combined_cert_path: Path = field(
        default_factory=_get_mds_client_combined_cert_path
    )

66 

67 

def _certs_exist(mds_mtls_config: MdsMtlsConfig):
    """Report whether both configured mTLS credential files are present.

    Returns:
        bool: ``True`` only when the CA certificate path and the combined
        client certificate/key path both exist.
    """
    # The client file is only checked once the CA certificate is known to exist.
    if not os.path.exists(mds_mtls_config.ca_cert_path):
        return False
    return os.path.exists(mds_mtls_config.client_combined_cert_path)

73 

74 

class MdsMtlsMode(enum.Enum):
    """How connections to the metadata server (MDS) should use mutual TLS.

    STRICT: Always use HTTPS/mTLS; missing local certificates are an error.
    NONE: Never use mTLS; requests go over plain HTTP.
    DEFAULT: Use mTLS when certificates are found locally, plain HTTP otherwise.
    """

    STRICT = "strict"  # mTLS is mandatory
    NONE = "none"  # mTLS is disabled
    DEFAULT = "default"  # mTLS only when certificates are available

86 

87 

def _parse_mds_mode():
    """Parses the GCE_METADATA_MTLS_MODE environment variable.

    The value is lower-cased before lookup, so matching is case-insensitive.
    When the variable is unset, ``"default"`` is used.

    Returns:
        MdsMtlsMode: The mode selected by the environment.

    Raises:
        ValueError: If the variable holds a value other than ``strict``,
            ``none`` or ``default``.
    """
    mode_str = os.environ.get(
        environment_vars.GCE_METADATA_MTLS_MODE, MdsMtlsMode.DEFAULT.value
    ).lower()
    try:
        return MdsMtlsMode(mode_str)
    except ValueError as exc:
        # Chain explicitly so the traceback presents the failed enum lookup
        # (which names the offending value) as the cause, instead of as an
        # unrelated error raised "during handling of the above exception".
        raise ValueError(
            "Invalid value for GCE_METADATA_MTLS_MODE. Must be one of 'strict', 'none', or 'default'."
        ) from exc

99 

100 

def should_use_mds_mtls(mds_mtls_config: Optional[MdsMtlsConfig] = None):
    """Determines if mTLS should be used for the metadata server.

    Args:
        mds_mtls_config: Locations of the mTLS credentials. When omitted, a
            fresh ``MdsMtlsConfig()`` is built for this call. Building it per
            call, rather than as a default argument, means the OS-specific
            default paths are resolved when the function runs. It also avoids
            sharing one mutable instance across every call.

    Returns:
        bool: ``True`` if requests to the metadata server should use mTLS.

    Raises:
        google.auth.exceptions.MutualTLSChannelError: In ``strict`` mode when
            the certificates are not found.
        ValueError: If GCE_METADATA_MTLS_MODE holds an unsupported value.
    """
    if mds_mtls_config is None:
        mds_mtls_config = MdsMtlsConfig()

    mode = _parse_mds_mode()
    if mode == MdsMtlsMode.STRICT:
        if not _certs_exist(mds_mtls_config):
            raise exceptions.MutualTLSChannelError(
                "mTLS certificates not found in strict mode."
            )
        return True
    elif mode == MdsMtlsMode.NONE:
        return False
    else:  # Default mode: use mTLS only when the certificates are present.
        return _certs_exist(mds_mtls_config)

114 

115 

class MdsMtlsAdapter(HTTPAdapter):
    """An HTTP adapter that uses mTLS for the metadata server.

    Connections present the client certificate and key from
    ``MdsMtlsConfig.client_combined_cert_path``. The server is verified
    against ``MdsMtlsConfig.ca_cert_path``. Unless the mode is ``strict``, a
    request whose mTLS attempt fails is retried once over plain HTTP.
    """

    def __init__(
        self, mds_mtls_config: Optional[MdsMtlsConfig] = None, *args, **kwargs
    ):
        """Builds the mTLS SSL context and initializes the adapter.

        Args:
            mds_mtls_config: Credential locations to use. When omitted, a
                fresh ``MdsMtlsConfig()`` is built for this adapter instead of
                sharing one instance created at import time.
            *args: Forwarded to ``HTTPAdapter.__init__``.
            **kwargs: Forwarded to ``HTTPAdapter.__init__``.
        """
        if mds_mtls_config is None:
            mds_mtls_config = MdsMtlsConfig()

        self.ssl_context = ssl.create_default_context()
        self.ssl_context.load_verify_locations(cafile=mds_mtls_config.ca_cert_path)
        self.ssl_context.load_cert_chain(
            certfile=mds_mtls_config.client_combined_cert_path
        )
        # A single plain-HTTP adapter is reused for every fallback request, so
        # its connection pool is shared and released by ``close()`` rather
        # than leaked once per failed request.
        self._http_fallback_adapter = HTTPAdapter()
        # ``ssl_context`` is assigned before the base initializer runs because
        # requests' ``HTTPAdapter.__init__`` calls ``init_poolmanager`` (which
        # is overridden below and reads ``self.ssl_context``).
        super(MdsMtlsAdapter, self).__init__(*args, **kwargs)

    def init_poolmanager(self, *args, **kwargs):
        """Creates the pool manager with the mTLS SSL context."""
        kwargs["ssl_context"] = self.ssl_context
        return super(MdsMtlsAdapter, self).init_poolmanager(*args, **kwargs)

    def proxy_manager_for(self, *args, **kwargs):
        """Creates proxy managers with the same mTLS SSL context."""
        kwargs["ssl_context"] = self.ssl_context
        return super(MdsMtlsAdapter, self).proxy_manager_for(*args, **kwargs)

    def close(self):
        """Closes the mTLS connection pools and the HTTP fallback adapter."""
        try:
            super(MdsMtlsAdapter, self).close()
        finally:
            self._http_fallback_adapter.close()

    def send(self, request, **kwargs):
        """Sends ``request`` to the metadata server, preferring mTLS.

        In ``strict`` mode the request is only sent over mTLS. In any other
        mode, an SSL error or an HTTP error status on the mTLS attempt is
        logged, and the request is re-sent once over plain HTTP.

        Args:
            request: The prepared request to send. It is not modified; the
                HTTP fallback is sent from a copy.
            **kwargs: Forwarded to ``HTTPAdapter.send``.

        Returns:
            requests.Response: The mTLS response, or the plain-HTTP fallback
            response.

        Raises:
            ValueError: If GCE_METADATA_MTLS_MODE holds an unsupported value.
        """
        # If we are in strict mode, always use mTLS (no HTTP fallback)
        if _parse_mds_mode() == MdsMtlsMode.STRICT:
            return super(MdsMtlsAdapter, self).send(request, **kwargs)

        # Otherwise attempt mTLS first, then fall back to HTTP on failure.
        try:
            response = super(MdsMtlsAdapter, self).send(request, **kwargs)
            response.raise_for_status()
            return response
        except (
            ssl.SSLError,
            requests.exceptions.SSLError,
            requests.exceptions.HTTPError,
        ) as e:
            _LOGGER.warning(
                "mTLS connection to Compute Engine Metadata server failed. "
                "Falling back to standard HTTP. Reason: %s",
                e,
            )
            # An HTTPError carries the failed mTLS response. Close it so its
            # connection is released instead of lingering until garbage
            # collection. ``ssl.SSLError`` has no ``response`` attribute.
            failed_response = getattr(e, "response", None)
            if failed_response is not None:
                failed_response.close()

            # Fallback to standard HTTP on a copy so the caller's request
            # object keeps its original (https) URL.
            fallback_request = request.copy()
            parsed_original_url = urlparse(request.url)
            fallback_request.url = urlunparse(
                parsed_original_url._replace(scheme="http")
            )
            return self._http_fallback_adapter.send(fallback_request, **kwargs)