1# -*- coding: utf-8 -*-
2#
3# Copyright 2024 Google LLC
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17"""Mutual TLS for Google Compute Engine metadata server."""
18
from dataclasses import dataclass, field
import enum
import logging
import os
from pathlib import Path
import ssl
from typing import Optional
from urllib.parse import urlparse, urlunparse

import requests
from requests.adapters import HTTPAdapter

from google.auth import environment_vars, exceptions
31
32
# Module-level logger; this module uses it to warn about mTLS -> HTTP fallbacks.
_LOGGER = logging.getLogger(name=__name__)

# Value reported by ``os.name`` on Windows.
_WINDOWS_OS_NAME = "nt"

# Base directories holding the MDS mTLS certificates, selected per OS. The
# well-known locations are documented at:
# https://cloud.google.com/compute/docs/metadata/overview#https-mds-certificates
_WINDOWS_MTLS_COMPONENTS_BASE_PATH = Path(
    "C:/", "ProgramData", "Google", "ComputeEngine"
)
_MTLS_COMPONENTS_BASE_PATH = Path("/run", "google-mds-mtls")
42
43
def _get_mds_root_crt_path():
    """Return the path of the MDS mTLS root (CA) certificate for this OS.

    Windows and other operating systems use different directories and file
    names.
    """
    if os.name != _WINDOWS_OS_NAME:
        return _MTLS_COMPONENTS_BASE_PATH / "root.crt"
    return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-root.crt"
49
50
def _get_mds_client_combined_cert_path():
    """Return the path of the file holding the MDS mTLS client cert and key.

    Windows and other operating systems use different directories and file
    names.
    """
    if os.name != _WINDOWS_OS_NAME:
        return _MTLS_COMPONENTS_BASE_PATH / "client.key"
    return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-client.key"
56
57
@dataclass
class MdsMtlsConfig:
    """File locations of the certificates used for mTLS to the metadata server.

    Both fields default to OS-specific locations, computed each time an
    instance is created.
    """

    # CA (root) certificate file.
    ca_cert_path: Path = field(default_factory=_get_mds_root_crt_path)
    # Single file containing both the client certificate and its private key.
    client_combined_cert_path: Path = field(
        default_factory=_get_mds_client_combined_cert_path
    )
66
67
def _certs_exist(mds_mtls_config: MdsMtlsConfig):
    """Report whether both mTLS certificate files are present on disk.

    Args:
        mds_mtls_config (MdsMtlsConfig): The certificate locations to check.

    Returns:
        bool: True only if the CA certificate file and the combined client
        certificate/key file both exist.
    """
    # Check the CA certificate first; skip the second lookup if it is missing.
    if not os.path.exists(mds_mtls_config.ca_cert_path):
        return False
    return os.path.exists(mds_mtls_config.client_combined_cert_path)
73
74
class MdsMtlsMode(enum.Enum):
    """Connection policy used when talking to the metadata server (MDS).

    Members:
        STRICT: Always use HTTPS/mTLS; if the certificates are not found
            locally, an error is raised.
        NONE: Never use mTLS; requests use plain HTTP.
        DEFAULT: Use mTLS when the certificates are found locally, otherwise
            fall back to plain HTTP.
    """

    STRICT = "strict"  # mTLS is mandatory
    NONE = "none"  # mTLS is disabled
    DEFAULT = "default"  # mTLS is used when available
86
87
def _parse_mds_mode():
    """Read GCE_METADATA_MTLS_MODE from the environment as an ``MdsMtlsMode``.

    An unset variable is treated as ``"default"``. Matching is
    case-insensitive.

    Returns:
        MdsMtlsMode: The configured mode.

    Raises:
        ValueError: If the variable holds an unrecognized value.
    """
    raw_value = os.environ.get(environment_vars.GCE_METADATA_MTLS_MODE, "default")
    try:
        mode = MdsMtlsMode(raw_value.lower())
    except ValueError:
        raise ValueError(
            "Invalid value for GCE_METADATA_MTLS_MODE. Must be one of 'strict', 'none', or 'default'."
        )
    return mode
99
100
def should_use_mds_mtls(mds_mtls_config: Optional[MdsMtlsConfig] = None):
    """Determines if mTLS should be used for the metadata server.

    The decision follows the GCE_METADATA_MTLS_MODE environment variable:

    * ``strict``: mTLS is always used; missing certificates are an error.
    * ``none``: mTLS is never used.
    * ``default``: mTLS is used only if both certificate files exist.

    Args:
        mds_mtls_config (Optional[MdsMtlsConfig]): Certificate locations to
            check. When omitted, a fresh ``MdsMtlsConfig()`` with the
            platform's default paths is used.

    Returns:
        bool: True if mTLS should be used, False otherwise.

    Raises:
        google.auth.exceptions.MutualTLSChannelError: In strict mode, if the
            certificate files are not found.
        ValueError: If GCE_METADATA_MTLS_MODE holds an invalid value.
    """
    mode = _parse_mds_mode()
    if mode == MdsMtlsMode.NONE:
        return False

    # Build the default config per call rather than once at import time, so
    # no single mutable instance is shared between calls.
    if mds_mtls_config is None:
        mds_mtls_config = MdsMtlsConfig()

    if mode == MdsMtlsMode.STRICT:
        if not _certs_exist(mds_mtls_config):
            raise exceptions.MutualTLSChannelError(
                "mTLS certificates not found in strict mode."
            )
        return True

    # Default mode: use mTLS only when the certificates are present.
    return _certs_exist(mds_mtls_config)
114
115
class MdsMtlsAdapter(HTTPAdapter):
    """An HTTP adapter that uses mTLS for the metadata server.

    Every connection pool the adapter creates, directly or through a proxy,
    uses an SSL context that trusts the configured CA certificate and presents
    the configured client certificate/key. Unless GCE_METADATA_MTLS_MODE is
    ``strict``, a request that fails over mTLS is retried once over plain
    HTTP. A failure here means an SSL error or an HTTP error status.
    """

    def __init__(
        self, mds_mtls_config: Optional[MdsMtlsConfig] = None, *args, **kwargs
    ):
        """Builds the mTLS SSL context and initializes the adapter.

        Args:
            mds_mtls_config (Optional[MdsMtlsConfig]): Certificate locations.
                When omitted, a fresh ``MdsMtlsConfig()`` with the platform's
                default paths is used.
            *args: Positional arguments forwarded to ``HTTPAdapter``.
            **kwargs: Keyword arguments forwarded to ``HTTPAdapter``.

        Raises:
            OSError: If a certificate file cannot be read.
            ssl.SSLError: If a certificate file cannot be loaded.
        """
        # Build the default config per instance instead of sharing a single
        # mutable instance created once at import time.
        if mds_mtls_config is None:
            mds_mtls_config = MdsMtlsConfig()

        # The context must be set before HTTPAdapter.__init__ runs, because
        # the base initializer calls init_poolmanager(), which reads it.
        self.ssl_context = ssl.create_default_context()
        self.ssl_context.load_verify_locations(cafile=mds_mtls_config.ca_cert_path)
        self.ssl_context.load_cert_chain(
            certfile=mds_mtls_config.client_combined_cert_path
        )
        # Plain-HTTP adapter for the non-strict fallback path. Creating it once
        # lets its connection pool be reused across fallbacks and released by
        # close(), instead of abandoning a fresh pool on every fallback.
        self._http_fallback_adapter = HTTPAdapter()
        super(MdsMtlsAdapter, self).__init__(*args, **kwargs)

    def init_poolmanager(self, *args, **kwargs):
        """Creates the pool manager, injecting the mTLS SSL context."""
        kwargs["ssl_context"] = self.ssl_context
        return super(MdsMtlsAdapter, self).init_poolmanager(*args, **kwargs)

    def proxy_manager_for(self, *args, **kwargs):
        """Creates a proxy manager, injecting the mTLS SSL context."""
        kwargs["ssl_context"] = self.ssl_context
        return super(MdsMtlsAdapter, self).proxy_manager_for(*args, **kwargs)

    def close(self):
        """Disposes of pooled connections, including the HTTP fallback's."""
        self._http_fallback_adapter.close()
        super(MdsMtlsAdapter, self).close()

    def send(self, request, **kwargs):
        """Sends a request over mTLS, falling back to plain HTTP if allowed.

        In strict mode the request is sent over mTLS only, and the result or
        exception is passed through unchanged.

        In any other mode, the mTLS attempt may raise an SSL error or return
        an HTTP error status. In that case a warning is logged, the failed
        response (if any) is closed, and the request is re-sent with its URL
        scheme switched to ``http``. Note that ``request.url`` is rewritten
        in place.

        Args:
            request (requests.PreparedRequest): The request to send.
            **kwargs: Forwarded to ``HTTPAdapter.send`` (e.g. ``timeout``).

        Returns:
            requests.Response: The mTLS response, or the HTTP fallback response.

        Raises:
            ValueError: If GCE_METADATA_MTLS_MODE holds an invalid value.
        """
        # If we are in strict mode, always use mTLS (no HTTP fallback).
        if _parse_mds_mode() == MdsMtlsMode.STRICT:
            return super(MdsMtlsAdapter, self).send(request, **kwargs)

        # Otherwise, attempt mTLS first, then fall back to HTTP on failure.
        response = None
        try:
            response = super(MdsMtlsAdapter, self).send(request, **kwargs)
            response.raise_for_status()
            return response
        except (
            ssl.SSLError,
            requests.exceptions.SSLError,
            requests.exceptions.HTTPError,
        ) as e:
            _LOGGER.warning(
                "mTLS connection to Compute Engine Metadata server failed. "
                "Falling back to standard HTTP. Reason: %s",
                e,
            )
            # If mTLS produced a response (HTTP error status), it may still
            # hold a connection; release it before sending the fallback.
            if response is not None:
                response.close()

            # Fallback to standard HTTP.
            parsed_original_url = urlparse(request.url)
            request.url = urlunparse(parsed_original_url._replace(scheme="http"))
            return self._http_fallback_adapter.send(request, **kwargs)