#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
18"""
19Example Airflow DAG that performs query in a Cloud SQL instance with SSL support.
20"""

from __future__ import annotations

import base64
import json
import logging
import os
import random
import string
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable

from googleapiclient import discovery

from airflow import settings
from airflow.decorators import task
from airflow.models.connection import Connection
from airflow.models.dag import DAG
from airflow.operators.bash import BashOperator
from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLHook
from airflow.providers.google.cloud.hooks.secret_manager import GoogleCloudSecretManagerHook
from airflow.providers.google.cloud.operators.cloud_sql import (
    CloudSQLCreateInstanceDatabaseOperator,
    CloudSQLCreateInstanceOperator,
    CloudSQLDeleteInstanceOperator,
    CloudSQLExecuteQueryOperator,
)
from airflow.utils.trigger_rule import TriggerRule

ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "Not found")
DAG_ID = "cloudsql-query-ssl"
REGION = "us-central1"
HOME_DIR = Path.home()

COMPOSER_ENVIRONMENT = os.environ.get("COMPOSER_ENVIRONMENT", "")
if COMPOSER_ENVIRONMENT:
    # We assume the test is launched in a Cloud Composer environment because this reserved
    # environment variable is set (https://cloud.google.com/composer/docs/composer-2/set-environment-variables).
    GET_COMPOSER_NETWORK_COMMAND = """
    gcloud composer environments describe $COMPOSER_ENVIRONMENT \
    --location=$COMPOSER_LOCATION \
    --project=$GCP_PROJECT \
    --format="value(config.nodeConfig.network)"
    """
else:
    # The test is launched locally: "echo" is a no-op that pushes an empty string to XCom.
    GET_COMPOSER_NETWORK_COMMAND = "echo"


def run_in_composer() -> bool:
    return bool(COMPOSER_ENVIRONMENT)


CLOUD_SQL_INSTANCE_NAME_TEMPLATE = f"{ENV_ID}-{DAG_ID}".replace("_", "-")
CLOUD_SQL_INSTANCE_CREATE_BODY_TEMPLATE: dict[str, Any] = {
    "name": CLOUD_SQL_INSTANCE_NAME_TEMPLATE,
    "settings": {
        "tier": "db-custom-1-3840",
        "dataDiskSizeGb": 30,
        "pricingPlan": "PER_USE",
        # Placeholder; replaced with an environment-specific configuration by
        # cloud_sql_instance_create_body() below.
        "ipConfiguration": {},
    },
    # Placeholder; overridden with the provider-specific version from DB_PROVIDERS.
    # For the available database versions, see
    # https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1/SqlDatabaseVersion
    "databaseVersion": "MYSQL_8_0",
    "region": REGION,
}

DB_PROVIDERS: Iterable[dict[str, str]] = (
    {
        "database_type": "postgres",
        "port": "5432",
        "database_version": "POSTGRES_15",
        "cloud_sql_instance_name": f"{CLOUD_SQL_INSTANCE_NAME_TEMPLATE}-postgres",
    },
    {
        "database_type": "mysql",
        "port": "3306",
        "database_version": "MYSQL_8_0",
        "cloud_sql_instance_name": f"{CLOUD_SQL_INSTANCE_NAME_TEMPLATE}-mysql",
    },
)


def ip_configuration() -> dict[str, Any]:
    """Generates an IP configuration for a Cloud SQL instance creation body."""
    if run_in_composer():
        # Connect to the Cloud SQL instance via Private IP from within Cloud Composer's network.
        return {
            "ipv4Enabled": True,
            # sslMode supersedes the legacy requireSsl flag: ENCRYPTED_ONLY enforces SSL
            # without requiring trusted client certificates.
            "requireSsl": False,
            "sslMode": "ENCRYPTED_ONLY",
            "enablePrivatePathForGoogleCloudServices": True,
            "privateNetwork": """{{ task_instance.xcom_pull('get_composer_network')}}""",
        }
    else:
        # Connect to the Cloud SQL instance via Public IP from anywhere (mask 0.0.0.0/0).
        # Consider specifying your network mask to allow requests only from trusted sources.
        return {
            "ipv4Enabled": True,
            "requireSsl": False,
            "sslMode": "ENCRYPTED_ONLY",
            "authorizedNetworks": [
                {"value": "0.0.0.0/0"},
            ],
        }


def cloud_sql_instance_create_body(database_provider: dict[str, Any]) -> dict[str, Any]:
    """Generates a Cloud SQL instance creation body for the given database provider."""
    create_body: dict[str, Any] = deepcopy(CLOUD_SQL_INSTANCE_CREATE_BODY_TEMPLATE)
    create_body["name"] = database_provider["cloud_sql_instance_name"]
    create_body["databaseVersion"] = database_provider["database_version"]
    create_body["settings"]["ipConfiguration"] = ip_configuration()
    return create_body

CLOUD_SQL_DATABASE_NAME = "test_db"
CLOUD_SQL_USER = "test_user"
# A hardcoded password is used here only because this is a disposable system-test instance.
# Never hardcode credentials in real DAGs; use a secrets backend instead.
CLOUD_SQL_PASSWORD = "JoxHlwrPzwch0gz9"
# Placeholders; the real IP address and port are filled in by the create_connection task at runtime.
CLOUD_SQL_IP_ADDRESS = "127.0.0.1"
CLOUD_SQL_PUBLIC_PORT = 5432


def cloud_sql_database_create_body(instance: str) -> dict[str, Any]:
    """Generates a Cloud SQL database creation body."""
    return {
        "instance": instance,
        "name": CLOUD_SQL_DATABASE_NAME,
        "project": PROJECT_ID,
    }


CLOUD_SQL_INSTANCE_NAME = ""
DATABASE_TYPE = ""  # "postgres|mysql|mssql"

# [START howto_operator_cloudsql_query_connections]
# Connect directly via TCP (SSL)
CONNECTION_PUBLIC_TCP_SSL_KWARGS = {
    "conn_type": "gcpcloudsql",
    "login": CLOUD_SQL_USER,
    "password": CLOUD_SQL_PASSWORD,
    "host": CLOUD_SQL_IP_ADDRESS,
    "port": CLOUD_SQL_PUBLIC_PORT,
    "schema": CLOUD_SQL_DATABASE_NAME,
    "extra": {
        "database_type": DATABASE_TYPE,
        "project_id": PROJECT_ID,
        "location": REGION,
        "instance": CLOUD_SQL_INSTANCE_NAME,
        "use_proxy": "False",
        "use_ssl": "True",
    },
}
# [END howto_operator_cloudsql_query_connections]
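
# Note: the dict above is a template. The create_connection task below deep-copies it and fills in
# the per-database host, port, instance, and database_type before persisting the connection.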

CONNECTION_PUBLIC_TCP_SSL_ID = f"{DAG_ID}_{ENV_ID}_tcp_ssl"

SQL = [
    "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
    "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",  # the intentional duplicate shows warnings being logged
    "INSERT INTO TABLE_TEST VALUES (0)",
    "CREATE TABLE IF NOT EXISTS TABLE_TEST2 (I INTEGER)",
    "DROP TABLE TABLE_TEST",
    "DROP TABLE TABLE_TEST2",
]

DELETE_CONNECTION_COMMAND = "airflow connections delete {}"

SSL_PATH = f"/{DAG_ID}/{ENV_ID}"
SSL_LOCAL_PATH_PREFIX = "/tmp"
SSL_COMPOSER_PATH_PREFIX = "/home/airflow/gcs/data"
# [START howto_operator_cloudsql_query_connections_env]

# The connections below are created using one of the standard approaches: via environment
# variables named AIRFLOW_CONN_*. The connections can also be created in the Airflow database
# (using the command line or the UI).

postgres_kwargs = {
    "user": "user",
    "password": "password",
    "public_ip": "public_ip",
    "public_port": "public_port",
    "database": "database",
    "project_id": "project_id",
    "location": "location",
    "instance": "instance",
    "client_cert_file": "client_cert_file",
    "client_key_file": "client_key_file",
    "server_ca_file": "server_ca_file",
}

# Postgres: connect directly via TCP (SSL)
os.environ["AIRFLOW_CONN_PUBLIC_POSTGRES_TCP_SSL"] = (
    "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
    "database_type=postgres&"
    "project_id={project_id}&"
    "location={location}&"
    "instance={instance}&"
    "use_proxy=False&"
    "use_ssl=True&"
    "sslcert={client_cert_file}&"
    "sslkey={client_key_file}&"
    "sslrootcert={server_ca_file}".format(**postgres_kwargs)
)

mysql_kwargs = {
    "user": "user",
    "password": "password",
    "public_ip": "public_ip",
    "public_port": "public_port",
    "database": "database",
    "project_id": "project_id",
    "location": "location",
    "instance": "instance",
    "client_cert_file": "client_cert_file",
    "client_key_file": "client_key_file",
    "server_ca_file": "server_ca_file",
}

# MySQL: connect directly via TCP (SSL)
os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL"] = (
    "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
    "database_type=mysql&"
    "project_id={project_id}&"
    "location={location}&"
    "instance={instance}&"
    "use_proxy=False&"
    "use_ssl=True&"
    "sslcert={client_cert_file}&"
    "sslkey={client_key_file}&"
    "sslrootcert={server_ca_file}".format(**mysql_kwargs)
)
# [END howto_operator_cloudsql_query_connections_env]
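
# Note: the AIRFLOW_CONN_* connections above are documentation snippets and are not used by the
# tasks below, which rely on connections created dynamically by the create_connection task.
# An equivalent connection could also be registered with the Airflow CLI, for example
# (illustrative command with placeholder values):
#
#   airflow connections add public_postgres_tcp_ssl \
#       --conn-uri 'gcpcloudsql://user:password@public_ip:public_port/database?database_type=postgres'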


log = logging.getLogger(__name__)

with DAG(
    dag_id=DAG_ID,
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=["example", "cloudsql", "postgres"],
) as dag:
    get_composer_network = BashOperator(
        task_id="get_composer_network",
        bash_command=GET_COMPOSER_NETWORK_COMMAND,
        do_xcom_push=True,
    )

    for db_provider in DB_PROVIDERS:
        database_type: str = db_provider["database_type"]
        cloud_sql_instance_name: str = db_provider["cloud_sql_instance_name"]

        create_cloud_sql_instance = CloudSQLCreateInstanceOperator(
            task_id=f"create_cloud_sql_instance_{database_type}",
            project_id=PROJECT_ID,
            instance=cloud_sql_instance_name,
            body=cloud_sql_instance_create_body(database_provider=db_provider),
        )

        create_database = CloudSQLCreateInstanceDatabaseOperator(
            task_id=f"create_database_{database_type}",
            body=cloud_sql_database_create_body(instance=cloud_sql_instance_name),
            instance=cloud_sql_instance_name,
        )

        @task(task_id=f"create_user_{database_type}")
        def create_user(instance: str) -> None:
            with discovery.build("sqladmin", "v1beta4") as service:
                request = service.users().insert(
                    project=PROJECT_ID,
                    instance=instance,
                    body={
                        "name": CLOUD_SQL_USER,
                        "password": CLOUD_SQL_PASSWORD,
                    },
                )
                request.execute()

        create_user_task = create_user(instance=cloud_sql_instance_name)

        @task(task_id=f"get_ip_address_{database_type}")
        def get_ip_address(instance: str) -> str | None:
            """Returns a Cloud SQL instance IP address.

            If the test is running in Cloud Composer, the Private IP address is used,
            otherwise the Public IP address.
            """
            with discovery.build("sqladmin", "v1beta4") as service:
                request = service.connect().get(
                    project=PROJECT_ID,
                    instance=instance,
                    fields="ipAddresses",
                )
                response = request.execute()
                for ip_item in response.get("ipAddresses", []):
                    if run_in_composer():
                        if ip_item["type"] == "PRIVATE":
                            return ip_item["ipAddress"]
                    else:
                        # "PRIMARY" is the instance's public IP address.
                        if ip_item["type"] == "PRIMARY":
                            return ip_item["ipAddress"]
                return None

        get_ip_address_task = get_ip_address(instance=cloud_sql_instance_name)

        conn_id = f"{CONNECTION_PUBLIC_TCP_SSL_ID}_{database_type}"

        @task(task_id=f"create_connection_{database_type}")
        def create_connection(
            connection_id: str, instance: str, db_type: str, ip_address: str, port: str
        ) -> str | None:
            session = settings.Session()
            if session.query(Connection).filter(Connection.conn_id == connection_id).first():
                log.warning("Connection '%s' already exists", connection_id)
                return connection_id

            connection: dict[str, Any] = deepcopy(CONNECTION_PUBLIC_TCP_SSL_KWARGS)
            connection["extra"]["instance"] = instance
            connection["host"] = ip_address
            connection["extra"]["database_type"] = db_type
            connection["port"] = port
            conn = Connection(conn_id=connection_id, **connection)
            session.add(conn)
            session.commit()
            log.info("Connection created: '%s'", connection_id)
            return connection_id

        create_connection_task = create_connection(
            connection_id=conn_id,
            instance=cloud_sql_instance_name,
            db_type=database_type,
            ip_address=get_ip_address_task,
            port=db_provider["port"],
        )
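
        # Passing get_ip_address_task (the output of a TaskFlow task) as ip_address lets Airflow
        # resolve the value from XCom at run time and adds the task dependency automatically.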

        @task(task_id=f"create_ssl_certificates_{database_type}")
        def create_ssl_certificate(instance: str, connection_id: str) -> dict[str, Any]:
            hook = CloudSQLHook(api_version="v1", gcp_conn_id=connection_id)
            certificate_name = f"test_cert_{''.join(random.choice(string.ascii_letters) for _ in range(8))}"
            response = hook.create_ssl_certificate(
                instance=instance,
                body={"common_name": certificate_name},
                project_id=PROJECT_ID,
            )
            return response

        create_ssl_certificate_task = create_ssl_certificate(
            instance=cloud_sql_instance_name, connection_id=create_connection_task
        )

        @task(task_id=f"save_ssl_cert_locally_{database_type}")
        def save_ssl_cert_locally(ssl_cert: dict[str, Any], db_type: str) -> dict[str, str]:
            folder = SSL_COMPOSER_PATH_PREFIX if run_in_composer() else SSL_LOCAL_PATH_PREFIX
            folder += f"/certs/{db_type}/{ssl_cert['operation']['name']}"
            os.makedirs(folder, exist_ok=True)
            _ssl_root_cert_path = f"{folder}/sslrootcert.pem"
            _ssl_cert_path = f"{folder}/sslcert.pem"
            _ssl_key_path = f"{folder}/sslkey.pem"
            with open(_ssl_root_cert_path, "w") as ssl_root_cert_file:
                ssl_root_cert_file.write(ssl_cert["serverCaCert"]["cert"])
            with open(_ssl_cert_path, "w") as ssl_cert_file:
                ssl_cert_file.write(ssl_cert["clientCert"]["certInfo"]["cert"])
            with open(_ssl_key_path, "w") as ssl_key_file:
                ssl_key_file.write(ssl_cert["clientCert"]["certPrivateKey"])
            return {
                "sslrootcert": _ssl_root_cert_path,
                "sslcert": _ssl_cert_path,
                "sslkey": _ssl_key_path,
            }

        save_ssl_cert_locally_task = save_ssl_cert_locally(
            ssl_cert=create_ssl_certificate_task, db_type=database_type
        )

        @task(task_id=f"save_ssl_cert_to_secret_manager_{database_type}")
        def save_ssl_cert_to_secret_manager(ssl_cert: dict[str, Any], db_type: str) -> str:
            hook = GoogleCloudSecretManagerHook()
            payload = {
                "sslrootcert": ssl_cert["serverCaCert"]["cert"],
                "sslcert": ssl_cert["clientCert"]["certInfo"]["cert"],
                "sslkey": ssl_cert["clientCert"]["certPrivateKey"],
            }
            _secret_id = f"secret_{DAG_ID}_{ENV_ID}_{db_type}"

            if not hook.secret_exists(project_id=PROJECT_ID, secret_id=_secret_id):
                hook.create_secret(
                    secret_id=_secret_id,
                    project_id=PROJECT_ID,
                )

            hook.add_secret_version(
                project_id=PROJECT_ID,
                secret_id=_secret_id,
                secret_payload=dict(data=base64.b64encode(json.dumps(payload).encode("ascii"))),
            )

            return _secret_id

        save_ssl_cert_to_secret_manager_task = save_ssl_cert_to_secret_manager(
            ssl_cert=create_ssl_certificate_task, db_type=database_type
        )
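
        # The secret version stores the three SSL materials as base64-encoded JSON, the format
        # the ssl_secret_id parameter of CloudSQLExecuteQueryOperator below is expected to consume.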

        task_id = f"example_cloud_sql_query_ssl_{database_type}"
        ssl_server_cert_path = (
            f"{{{{ task_instance.xcom_pull('save_ssl_cert_locally_{database_type}')['sslrootcert'] }}}}"
        )
        ssl_cert_path = (
            f"{{{{ task_instance.xcom_pull('save_ssl_cert_locally_{database_type}')['sslcert'] }}}}"
        )
        ssl_key_path = f"{{{{ task_instance.xcom_pull('save_ssl_cert_locally_{database_type}')['sslkey'] }}}}"

        # [START howto_operator_cloudsql_query_operators_ssl]
        query_task = CloudSQLExecuteQueryOperator(
            gcp_cloudsql_conn_id=conn_id,
            task_id=task_id,
            sql=SQL,
            ssl_client_cert=ssl_cert_path,
            ssl_server_cert=ssl_server_cert_path,
            ssl_client_key=ssl_key_path,
        )
        # [END howto_operator_cloudsql_query_operators_ssl]
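
        # The ssl_* arguments above are Jinja-templated: at run time they resolve, via XCom, to
        # the local file paths returned by the save_ssl_cert_locally task.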

        task_id = f"example_cloud_sql_query_ssl_secret_{database_type}"
        secret_id = f"{{{{ task_instance.xcom_pull('save_ssl_cert_to_secret_manager_{database_type}') }}}}"

        # [START howto_operator_cloudsql_query_operators_ssl_secret_id]
        query_task_secret = CloudSQLExecuteQueryOperator(
            gcp_cloudsql_conn_id=conn_id,
            task_id=task_id,
            sql=SQL,
            ssl_secret_id=secret_id,
        )
        # [END howto_operator_cloudsql_query_operators_ssl_secret_id]

        delete_instance = CloudSQLDeleteInstanceOperator(
            task_id=f"delete_cloud_sql_instance_{database_type}",
            project_id=PROJECT_ID,
            instance=cloud_sql_instance_name,
            trigger_rule=TriggerRule.ALL_DONE,
        )

        delete_connection = BashOperator(
            task_id=f"delete_connection_{conn_id}",
            bash_command=DELETE_CONNECTION_COMMAND.format(conn_id),
            trigger_rule=TriggerRule.ALL_DONE,
            # "airflow connections delete" exits with code 1 if the connection does not exist;
            # treat that as a skip instead of a failure.
            skip_on_exit_code=1,
        )

        @task(task_id=f"delete_secret_{database_type}")
        def delete_secret(ssl_secret_id: str, db_type: str) -> None:
            hook = GoogleCloudSecretManagerHook()
            if hook.secret_exists(project_id=PROJECT_ID, secret_id=ssl_secret_id):
                hook.delete_secret(secret_id=ssl_secret_id, project_id=PROJECT_ID)

        delete_secret_task = delete_secret(
            ssl_secret_id=save_ssl_cert_to_secret_manager_task, db_type=database_type
        )

        (
            # TEST SETUP
            get_composer_network
            >> create_cloud_sql_instance
            >> [create_database, create_user_task, get_ip_address_task]
            >> create_connection_task
            >> create_ssl_certificate_task
            >> [save_ssl_cert_locally_task, save_ssl_cert_to_secret_manager_task]
            # TEST BODY
            >> query_task
            >> query_task_secret
            # TEST TEARDOWN
            >> [delete_instance, delete_connection, delete_secret_task]
        )

    # ### Everything below this line is not part of example ###
    # ### Just for system tests purpose ###
    from tests.system.utils.watcher import watcher

    # This test needs the watcher in order to properly mark success/failure
    # when a "tearDown" task with a trigger rule is part of the DAG.
    list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)