Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py: 1%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

167 statements  

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Example Airflow DAG that performs query in a Cloud SQL instance with SSL support.
"""

21 

from __future__ import annotations

import base64
import json
import logging
import os
import random
import string
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable

from googleapiclient import discovery

from airflow import settings
from airflow.decorators import task
from airflow.models.connection import Connection
from airflow.models.dag import DAG
from airflow.operators.bash import BashOperator
from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLHook
from airflow.providers.google.cloud.hooks.secret_manager import GoogleCloudSecretManagerHook
from airflow.providers.google.cloud.operators.cloud_sql import (
    CloudSQLCreateInstanceDatabaseOperator,
    CloudSQLCreateInstanceOperator,
    CloudSQLDeleteInstanceOperator,
    CloudSQLExecuteQueryOperator,
)
from airflow.utils.trigger_rule import TriggerRule

51 

# Identifiers for the system-test environment; DAG_ID and REGION drive resource naming.
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "Not found")
DAG_ID = "cloudsql-query-ssl"
REGION = "us-central1"
HOME_DIR = Path.home()

# COMPOSER_ENVIRONMENT is a variable reserved by Cloud Composer, so its presence means
# the test runs inside Composer rather than locally.
# https://cloud.google.com/composer/docs/composer-2/set-environment-variables
COMPOSER_ENVIRONMENT = os.environ.get("COMPOSER_ENVIRONMENT", "")
if COMPOSER_ENVIRONMENT:
    # Resolve the Composer environment's VPC network so instances can be reached via Private IP.
    GET_COMPOSER_NETWORK_COMMAND = """
    gcloud composer environments describe $COMPOSER_ENVIRONMENT \
    --location=$COMPOSER_LOCATION \
    --project=$GCP_PROJECT \
    --format="value(config.nodeConfig.network)"
    """
else:
    # Running locally: a no-op command keeps the network-lookup task harmless.
    GET_COMPOSER_NETWORK_COMMAND = "echo"

71 

72 

def run_in_composer():
    """Return True when the test is executed inside a Cloud Composer environment."""
    return COMPOSER_ENVIRONMENT != ""

75 

76 

CLOUD_SQL_INSTANCE_NAME_TEMPLATE = f"{ENV_ID}-{DAG_ID}".replace("_", "-")

# Base creation body; "name", "databaseVersion" and settings.ipConfiguration are
# filled in per database provider by cloud_sql_instance_create_body().
CLOUD_SQL_INSTANCE_CREATE_BODY_TEMPLATE: dict[str, Any] = {
    "name": CLOUD_SQL_INSTANCE_NAME_TEMPLATE,
    "settings": {
        "tier": "db-custom-1-3840",
        "dataDiskSizeGb": 30,
        "pricingPlan": "PER_USE",
        "ipConfiguration": {},  # placeholder, replaced with ip_configuration()
    },
    # For using a different database version please check the link below.
    # https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1/SqlDatabaseVersion
    "databaseVersion": "1.2.3",  # placeholder, overwritten per provider
    "region": REGION,
    # NOTE(review): the Cloud SQL Admin API reads ipConfiguration from "settings";
    # this top-level copy looks redundant with settings.ipConfiguration — confirm
    # it is never read before removing.
    "ipConfiguration": {
        "ipv4Enabled": True,
        "requireSsl": True,
        "authorizedNetworks": [
            {"value": "0.0.0.0/0"},
        ],
    },
}

# One entry per database engine exercised by this test.
DB_PROVIDERS: Iterable[dict[str, str]] = (
    {
        "database_type": "postgres",
        "port": "5432",
        "database_version": "POSTGRES_15",
        "cloud_sql_instance_name": f"{CLOUD_SQL_INSTANCE_NAME_TEMPLATE}-postgres",
    },
    {
        "database_type": "mysql",
        "port": "3306",
        "database_version": "MYSQL_8_0",
        "cloud_sql_instance_name": f"{CLOUD_SQL_INSTANCE_NAME_TEMPLATE}-mysql",
    },
)

113 

114 

def ip_configuration() -> dict[str, Any]:
    """Generate the ipConfiguration section of a Cloud SQL instance creation body.

    Inside Cloud Composer the instance is reached through a Private IP on the
    Composer network; locally it is reached through a Public IP open to 0.0.0.0/0.
    """
    config: dict[str, Any] = {
        "ipv4Enabled": True,
        "requireSsl": False,
        "sslMode": "ENCRYPTED_ONLY",
    }
    if run_in_composer():
        # Private IP within the Cloud Composer's network; the network name is
        # resolved at runtime via XCom from the 'get_composer_network' task.
        config["enablePrivatePathForGoogleCloudServices"] = True
        config["privateNetwork"] = """{{ task_instance.xcom_pull('get_composer_network')}}"""
    else:
        # Public IP reachable from anywhere (mask 0.0.0.0/0). Consider specifying
        # your own network mask to allow requests only from trusted sources.
        config["authorizedNetworks"] = [
            {"value": "0.0.0.0/0"},
        ]
    return config

138 

139 

def cloud_sql_instance_create_body(database_provider: dict[str, Any]) -> dict[str, Any]:
    """Build a Cloud SQL instance creation body for the given database provider.

    Starts from CLOUD_SQL_INSTANCE_CREATE_BODY_TEMPLATE (deep-copied so the template
    stays pristine) and fills in the provider-specific name, engine version and
    environment-dependent ipConfiguration.
    """
    body: dict[str, Any] = deepcopy(CLOUD_SQL_INSTANCE_CREATE_BODY_TEMPLATE)
    body["name"] = database_provider["cloud_sql_instance_name"]
    body["databaseVersion"] = database_provider["database_version"]
    body["settings"]["ipConfiguration"] = ip_configuration()
    return body

147 

148 

# Test-only database credentials and connection defaults (host/port are overwritten
# with the real instance address once it is created).
CLOUD_SQL_DATABASE_NAME = "test_db"
CLOUD_SQL_USER = "test_user"
CLOUD_SQL_PASSWORD = "JoxHlwrPzwch0gz9"
CLOUD_SQL_IP_ADDRESS = "127.0.0.1"
CLOUD_SQL_PUBLIC_PORT = 5432

154 

155 

def cloud_sql_database_create_body(instance: str) -> dict[str, Any]:
    """Build a Cloud SQL database creation body for *instance*.

    The database name and project are fixed module-level test constants.
    """
    return {
        "instance": instance,
        "name": CLOUD_SQL_DATABASE_NAME,
        "project": PROJECT_ID,
    }

163 

164 

# Placeholders patched per-instance inside the create_connection task.
CLOUD_SQL_INSTANCE_NAME = ""
DATABASE_TYPE = ""  # "postgres|mysql|mssql"

# [START howto_operator_cloudsql_query_connections]
# Connect directly via TCP (SSL)
CONNECTION_PUBLIC_TCP_SSL_KWARGS = {
    "conn_type": "gcpcloudsql",
    "login": CLOUD_SQL_USER,
    "password": CLOUD_SQL_PASSWORD,
    "host": CLOUD_SQL_IP_ADDRESS,
    "port": CLOUD_SQL_PUBLIC_PORT,
    "schema": CLOUD_SQL_DATABASE_NAME,
    "extra": {
        "database_type": DATABASE_TYPE,
        "project_id": PROJECT_ID,
        "location": REGION,
        "instance": CLOUD_SQL_INSTANCE_NAME,
        "use_proxy": "False",
        "use_ssl": "True",
    },
}
# [END howto_operator_cloudsql_query_connections]

CONNECTION_PUBLIC_TCP_SSL_ID = f"{DAG_ID}_{ENV_ID}_tcp_ssl"

# The duplicated CREATE TABLE is intentional: it exercises IF NOT EXISTS idempotency.
SQL = [
    "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
    "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
    "INSERT INTO TABLE_TEST VALUES (0)",
    "CREATE TABLE IF NOT EXISTS TABLE_TEST2 (I INTEGER)",
    "DROP TABLE TABLE_TEST",
    "DROP TABLE TABLE_TEST2",
]

DELETE_CONNECTION_COMMAND = "airflow connections delete {}"

# Where the generated SSL certificates are written: Composer's GCS-mounted data
# folder when running in Composer, /tmp otherwise.
SSL_PATH = f"/{DAG_ID}/{ENV_ID}"
SSL_LOCAL_PATH_PREFIX = "/tmp"
SSL_COMPOSER_PATH_PREFIX = "/home/airflow/gcs/data"

# [START howto_operator_cloudsql_query_connections_env]

# The connections below are created using one of the standard approaches - via environment
# variables named AIRFLOW_CONN_* . The connections can also be created in the database
# of AIRFLOW (using command line or UI).

postgres_kwargs = {
    "user": "user",
    "password": "password",
    "public_ip": "public_ip",
    "public_port": "public_port",
    "database": "database",
    "project_id": "project_id",
    "location": "location",
    "instance": "instance",
    "client_cert_file": "client_cert_file",
    "client_key_file": "client_key_file",
    "server_ca_file": "server_ca_file",
}

# Postgres: connect directly via TCP (SSL)
os.environ["AIRFLOW_CONN_PUBLIC_POSTGRES_TCP_SSL"] = (
    "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
    "database_type=postgres&"
    "project_id={project_id}&"
    "location={location}&"
    "instance={instance}&"
    "use_proxy=False&"
    "use_ssl=True&"
    "sslcert={client_cert_file}&"
    "sslkey={client_key_file}&"
    "sslrootcert={server_ca_file}".format(**postgres_kwargs)
)

mysql_kwargs = {
    "user": "user",
    "password": "password",
    "public_ip": "public_ip",
    "public_port": "public_port",
    "database": "database",
    "project_id": "project_id",
    "location": "location",
    "instance": "instance",
    "client_cert_file": "client_cert_file",
    "client_key_file": "client_key_file",
    "server_ca_file": "server_ca_file",
}

# MySQL: connect directly via TCP (SSL)
os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL"] = (
    "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
    "database_type=mysql&"
    "project_id={project_id}&"
    "location={location}&"
    "instance={instance}&"
    "use_proxy=False&"
    "use_ssl=True&"
    "sslcert={client_cert_file}&"
    "sslkey={client_key_file}&"
    "sslrootcert={server_ca_file}".format(**mysql_kwargs)
)
# [END howto_operator_cloudsql_query_connections_env]

266 

267 

log = logging.getLogger(__name__)

with DAG(
    dag_id=DAG_ID,
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=["example", "cloudsql", "postgres"],
) as dag:
    # Resolve the Composer VPC network (a no-op "echo" locally); downstream
    # tasks read the result via XCom when building the Private IP configuration.
    get_composer_network = BashOperator(
        task_id="get_composer_network",
        bash_command=GET_COMPOSER_NETWORK_COMMAND,
        do_xcom_push=True,
    )

    # Build an identical setup/query/teardown chain for every database engine.
    for provider in DB_PROVIDERS:
        database_type: str = provider["database_type"]
        cloud_sql_instance_name: str = provider["cloud_sql_instance_name"]

        create_cloud_sql_instance = CloudSQLCreateInstanceOperator(
            task_id=f"create_cloud_sql_instance_{database_type}",
            project_id=PROJECT_ID,
            instance=cloud_sql_instance_name,
            body=cloud_sql_instance_create_body(database_provider=provider),
        )

        create_database = CloudSQLCreateInstanceDatabaseOperator(
            task_id=f"create_database_{database_type}",
            body=cloud_sql_database_create_body(instance=cloud_sql_instance_name),
            instance=cloud_sql_instance_name,
        )

        @task(task_id=f"create_user_{database_type}")
        def create_user(instance: str) -> None:
            """Create the test database user via the Cloud SQL Admin API."""
            user_body = {
                "name": CLOUD_SQL_USER,
                "password": CLOUD_SQL_PASSWORD,
            }
            with discovery.build("sqladmin", "v1beta4") as service:
                service.users().insert(
                    project=PROJECT_ID,
                    instance=instance,
                    body=user_body,
                ).execute()
            return None

        create_user_task = create_user(instance=cloud_sql_instance_name)

        @task(task_id=f"get_ip_address_{database_type}")
        def get_ip_address(instance: str) -> str | None:
            """Returns a Cloud SQL instance IP address.

            If the test is running in Cloud Composer, the Private IP address is used, otherwise Public IP."""
            wanted_type = "PRIVATE" if run_in_composer() else "PRIMARY"
            with discovery.build("sqladmin", "v1beta4") as service:
                response = service.connect().get(
                    project=PROJECT_ID,
                    instance=instance,
                    fields="ipAddresses",
                ).execute()
                for ip_item in response.get("ipAddresses", []):
                    if ip_item["type"] == wanted_type:
                        return ip_item["ipAddress"]
            return None

        get_ip_address_task = get_ip_address(instance=cloud_sql_instance_name)

        conn_id = f"{CONNECTION_PUBLIC_TCP_SSL_ID}_{database_type}"

        @task(task_id=f"create_connection_{database_type}")
        def create_connection(
            connection_id: str, instance: str, db_type: str, ip_address: str, port: str
        ) -> str | None:
            """Persist an Airflow connection for the freshly created instance.

            Reuses an already-existing connection with the same id; returns the id."""
            session = settings.Session()
            if session.query(Connection).filter(Connection.conn_id == connection_id).first():
                log.warning("Connection '%s' already exists", connection_id)
                return connection_id

            conn_kwargs: dict[str, Any] = deepcopy(CONNECTION_PUBLIC_TCP_SSL_KWARGS)
            conn_kwargs["host"] = ip_address
            conn_kwargs["port"] = port
            conn_kwargs["extra"]["instance"] = instance
            conn_kwargs["extra"]["database_type"] = db_type
            session.add(Connection(conn_id=connection_id, **conn_kwargs))
            session.commit()
            log.info("Connection created: '%s'", connection_id)
            return connection_id

        create_connection_task = create_connection(
            connection_id=conn_id,
            instance=cloud_sql_instance_name,
            db_type=database_type,
            ip_address=get_ip_address_task,
            port=provider["port"],
        )

        @task(task_id=f"create_ssl_certificates_{database_type}")
        def create_ssl_certificate(instance: str, connection_id: str) -> dict[str, Any]:
            """Issue a client SSL certificate for the instance; return the API response."""
            random_suffix = "".join(random.choice(string.ascii_letters) for _ in range(8))
            hook = CloudSQLHook(api_version="v1", gcp_conn_id=connection_id)
            return hook.create_ssl_certificate(
                instance=instance,
                body={"common_name": f"test_cert_{random_suffix}"},
                project_id=PROJECT_ID,
            )

        create_ssl_certificate_task = create_ssl_certificate(
            instance=cloud_sql_instance_name, connection_id=create_connection_task
        )

        @task(task_id=f"save_ssl_cert_locally_{database_type}")
        def save_ssl_cert_locally(ssl_cert: dict[str, Any], db_type: str) -> dict[str, str]:
            """Write server CA, client cert and client key to local PEM files.

            Returns the file paths keyed by the query operator's ssl kwarg names."""
            prefix = SSL_COMPOSER_PATH_PREFIX if run_in_composer() else SSL_LOCAL_PATH_PREFIX
            folder = f"{prefix}/certs/{db_type}/{ssl_cert['operation']['name']}"
            if not os.path.exists(folder):
                os.makedirs(folder)
            paths = {
                "sslrootcert": f"{folder}/sslrootcert.pem",
                "sslcert": f"{folder}/sslcert.pem",
                "sslkey": f"{folder}/sslkey.pem",
            }
            with open(paths["sslrootcert"], "w") as root_cert_file:
                root_cert_file.write(ssl_cert["serverCaCert"]["cert"])
            with open(paths["sslcert"], "w") as cert_file:
                cert_file.write(ssl_cert["clientCert"]["certInfo"]["cert"])
            with open(paths["sslkey"], "w") as key_file:
                key_file.write(ssl_cert["clientCert"]["certPrivateKey"])
            return paths

        save_ssl_cert_locally_task = save_ssl_cert_locally(
            ssl_cert=create_ssl_certificate_task, db_type=database_type
        )

        @task(task_id=f"save_ssl_cert_to_secret_manager_{database_type}")
        def save_ssl_cert_to_secret_manager(ssl_cert: dict[str, Any], db_type: str) -> str:
            """Store the certificate bundle in Secret Manager; return the secret id."""
            payload = {
                "sslrootcert": ssl_cert["serverCaCert"]["cert"],
                "sslcert": ssl_cert["clientCert"]["certInfo"]["cert"],
                "sslkey": ssl_cert["clientCert"]["certPrivateKey"],
            }
            _secret_id = f"secret_{DAG_ID}_{ENV_ID}_{db_type}"

            hook = GoogleCloudSecretManagerHook()
            if not hook.secret_exists(project_id=PROJECT_ID, secret_id=_secret_id):
                hook.create_secret(
                    secret_id=_secret_id,
                    project_id=PROJECT_ID,
                )

            hook.add_secret_version(
                project_id=PROJECT_ID,
                secret_id=_secret_id,
                secret_payload=dict(data=base64.b64encode(json.dumps(payload).encode("ascii"))),
            )
            return _secret_id

        save_ssl_cert_to_secret_manager_task = save_ssl_cert_to_secret_manager(
            ssl_cert=create_ssl_certificate_task, db_type=database_type
        )

        # Jinja-templated XCom pulls resolving the saved certificate paths at runtime.
        ssl_server_cert_path = (
            f"{{{{ task_instance.xcom_pull('save_ssl_cert_locally_{database_type}')['sslrootcert'] }}}}"
        )
        ssl_cert_path = (
            f"{{{{ task_instance.xcom_pull('save_ssl_cert_locally_{database_type}')['sslcert'] }}}}"
        )
        ssl_key_path = f"{{{{ task_instance.xcom_pull('save_ssl_cert_locally_{database_type}')['sslkey'] }}}}"

        task_id = f"example_cloud_sql_query_ssl_{database_type}"

        # [START howto_operator_cloudsql_query_operators_ssl]
        query_task = CloudSQLExecuteQueryOperator(
            gcp_cloudsql_conn_id=conn_id,
            task_id=task_id,
            sql=SQL,
            ssl_client_cert=ssl_cert_path,
            ssl_server_cert=ssl_server_cert_path,
            ssl_client_key=ssl_key_path,
        )
        # [END howto_operator_cloudsql_query_operators_ssl]

        task_id = f"example_cloud_sql_query_ssl_secret_{database_type}"
        secret_id = f"{{{{ task_instance.xcom_pull('save_ssl_cert_to_secret_manager_{database_type}') }}}}"

        # [START howto_operator_cloudsql_query_operators_ssl_secret_id]
        query_task_secret = CloudSQLExecuteQueryOperator(
            gcp_cloudsql_conn_id=conn_id,
            task_id=task_id,
            sql=SQL,
            ssl_secret_id=secret_id,
        )
        # [END howto_operator_cloudsql_query_operators_ssl_secret_id]

        delete_instance = CloudSQLDeleteInstanceOperator(
            task_id=f"delete_cloud_sql_instance_{database_type}",
            project_id=PROJECT_ID,
            instance=cloud_sql_instance_name,
            trigger_rule=TriggerRule.ALL_DONE,
        )

        delete_connection = BashOperator(
            task_id=f"delete_connection_{conn_id}",
            bash_command=DELETE_CONNECTION_COMMAND.format(conn_id),
            trigger_rule=TriggerRule.ALL_DONE,
            skip_on_exit_code=1,  # connection may already be gone; don't fail teardown
        )

        @task(task_id=f"delete_secret_{database_type}")
        def delete_secret(ssl_secret_id, db_type: str) -> None:
            """Remove the Secret Manager secret created for this database type, if present."""
            hook = GoogleCloudSecretManagerHook()
            if hook.secret_exists(project_id=PROJECT_ID, secret_id=ssl_secret_id):
                hook.delete_secret(secret_id=ssl_secret_id, project_id=PROJECT_ID)

        delete_secret_task = delete_secret(
            ssl_secret_id=save_ssl_cert_to_secret_manager_task, db_type=database_type
        )

        (
            # TEST SETUP
            get_composer_network
            >> create_cloud_sql_instance
            >> [create_database, create_user_task, get_ip_address_task]
            >> create_connection_task
            >> create_ssl_certificate_task
            >> [save_ssl_cert_locally_task, save_ssl_cert_to_secret_manager_task]
            # TEST BODY
            >> query_task
            >> query_task_secret
            # TEST TEARDOWN
            >> [delete_instance, delete_connection, delete_secret_task]
        )

    # ### Everything below this line is not part of example ###
    # ### Just for system tests purpose ###
    from tests.system.utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()

514 

from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)