1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Secrets backend that routes requests to the Execution API."""
18
19from __future__ import annotations
20
21from typing import TYPE_CHECKING
22
23from airflow.sdk.bases.secrets_backend import BaseSecretsBackend
24
25if TYPE_CHECKING:
26 from airflow.sdk import Connection
27
28
class ExecutionAPISecretsBackend(BaseSecretsBackend):
    """
    Secrets backend for client contexts (workers, DAG processors, triggerers).

    Routes connection and variable requests through SUPERVISOR_COMMS to the
    Execution API server. This backend should only be registered in client
    processes, not in API server/scheduler processes.

    Lookups are best-effort: an ``ErrorResponse`` from the supervisor, or an
    exception raised while talking to SUPERVISOR_COMMS, results in ``None`` so
    that other backends can be tried.
    """

    def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | None:
        """
        Get connection URI via SUPERVISOR_COMMS.

        Not used since we override get_connection directly.

        :param conn_id: connection id
        :param team_name: unused
        :raises NotImplementedError: always; call ``get_connection`` instead.
        """
        raise NotImplementedError("Use get_connection instead")

    def get_connection(self, conn_id: str, team_name: str | None = None) -> Connection | None:  # type: ignore[override]
        """
        Return connection object by routing through SUPERVISOR_COMMS.

        If the synchronous send fails because it was invoked from inside a running
        event loop and the current asyncio task has a greenback portal, the lookup
        is retried through ``aget_connection`` (with a warning telling the caller
        to use the async API).

        :param conn_id: connection id
        :param team_name: Name of the team associated to the task trying to access the connection.
            Unused here because the team name is inferred from the task ID provided in the execution API JWT token.
        :return: Connection object or None if not found
        """
        from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection
        from airflow.sdk.execution_time.context import _process_connection_result_conn
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        try:
            msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id))

            if isinstance(msg, ErrorResponse):
                # Connection not found or error occurred
                return None

            # Convert ExecutionAPI response to SDK Connection
            return _process_connection_result_conn(msg)
        except RuntimeError as e:
            # TriggerCommsDecoder.send() uses async_to_sync internally, which raises RuntimeError
            # when called within an async event loop. In greenback portal contexts (triggerer),
            # we catch this and use greenback to call the async version instead.
            if not str(e).startswith("You cannot use AsyncToSync in the same thread as an async event loop"):
                # Any other RuntimeError is handled like a generic comms failure.
                return None

            import asyncio
            import warnings

            # Exceptions raised inside an ``except`` clause are NOT handled by the sibling
            # ``except Exception`` below, so the portal probing is guarded explicitly to keep
            # the "return None and let other backends try" contract.
            try:
                import greenback

                has_portal = greenback.has_portal(asyncio.current_task())
            except Exception:
                return None

            if not has_portal:
                # No portal for this task: the async variant cannot be driven from sync code.
                return None

            warnings.warn(
                "You should not use sync calls here -- use `await aget_connection` instead",
                stacklevel=2,
            )
            return greenback.await_(self.aget_connection(conn_id))
        except Exception:
            # If SUPERVISOR_COMMS fails for any reason, return None
            # to allow fallback to other backends
            return None

    def get_variable(self, key: str, team_name: str | None = None) -> str | None:
        """
        Return variable value by routing through SUPERVISOR_COMMS.

        NOTE(review): unlike ``get_connection`` there is no greenback fallback here, so
        any exception from the synchronous send (including the in-event-loop
        RuntimeError handled there) yields ``None`` -- confirm this is intended.

        :param key: Variable key
        :param team_name: Name of the team associated to the task trying to access the variable.
            Unused here because the team name is inferred from the task ID provided in the execution API JWT token.
        :return: Variable value or None if not found
        """
        from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable, VariableResult
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        try:
            msg = SUPERVISOR_COMMS.send(GetVariable(key=key))

            if isinstance(msg, ErrorResponse):
                # Variable not found or error occurred
                return None

            # Extract value from VariableResult
            if isinstance(msg, VariableResult):
                return msg.value  # Already a string | None
            # Any other response type is treated as "not found"
            return None
        except Exception:
            # If SUPERVISOR_COMMS fails for any reason, return None
            # to allow fallback to other backends
            return None

    async def aget_connection(self, conn_id: str) -> Connection | None:  # type: ignore[override]
        """
        Return connection object asynchronously via SUPERVISOR_COMMS.

        :param conn_id: connection id
        :return: Connection object or None if not found
        """
        from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection
        from airflow.sdk.execution_time.context import _process_connection_result_conn
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        try:
            msg = await SUPERVISOR_COMMS.asend(GetConnection(conn_id=conn_id))

            if isinstance(msg, ErrorResponse):
                # Connection not found or error occurred
                return None

            # Convert ExecutionAPI response to SDK Connection
            return _process_connection_result_conn(msg)
        except Exception:
            # If SUPERVISOR_COMMS fails for any reason, return None
            # to allow fallback to other backends
            return None

    async def aget_variable(self, key: str) -> str | None:
        """
        Return variable value asynchronously via SUPERVISOR_COMMS.

        :param key: Variable key
        :return: Variable value or None if not found
        """
        from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable, VariableResult
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        try:
            msg = await SUPERVISOR_COMMS.asend(GetVariable(key=key))

            if isinstance(msg, ErrorResponse):
                # Variable not found or error occurred
                return None

            # Extract value from VariableResult
            if isinstance(msg, VariableResult):
                return msg.value  # Already a string | None
            # Any other response type is treated as "not found"
            return None
        except Exception:
            # If SUPERVISOR_COMMS fails for any reason, return None
            # to allow fallback to other backends
            return None