Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/utils/operator_helpers.py: 23%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import logging
21from datetime import datetime
22from typing import TYPE_CHECKING, Any, Callable, Collection, Mapping, TypeVar
24from airflow import settings
25from airflow.utils.context import Context, lazy_mapping_from_context
27if TYPE_CHECKING:
28 from airflow.utils.context import OutletEventAccessors
# Return type variable for make_kwargs_callable: the wrapper preserves the
# wrapped callable's return type.
R = TypeVar("R")

# Prefix used for context keys rendered in dotted "default" form (abc.def.ghi).
DEFAULT_FORMAT_PREFIX = "airflow.ctx."
# Prefix used for context keys rendered as environment-variable names (ABC_DEF_GHI).
ENV_VAR_FORMAT_PREFIX = "AIRFLOW_CTX_"

# For each exported context variable, the key name to use in each of the two
# output formats produced by context_to_airflow_vars() below.
AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
    "AIRFLOW_CONTEXT_DAG_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_ID",
    },
    "AIRFLOW_CONTEXT_TASK_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}task_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TASK_ID",
    },
    "AIRFLOW_CONTEXT_EXECUTION_DATE": {
        "default": f"{DEFAULT_FORMAT_PREFIX}execution_date",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}EXECUTION_DATE",
    },
    "AIRFLOW_CONTEXT_TRY_NUMBER": {
        "default": f"{DEFAULT_FORMAT_PREFIX}try_number",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TRY_NUMBER",
    },
    "AIRFLOW_CONTEXT_DAG_RUN_ID": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_run_id",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_RUN_ID",
    },
    "AIRFLOW_CONTEXT_DAG_OWNER": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_owner",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_OWNER",
    },
    "AIRFLOW_CONTEXT_DAG_EMAIL": {
        "default": f"{DEFAULT_FORMAT_PREFIX}dag_email",
        "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_EMAIL",
    },
}
def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool = False) -> dict[str, str]:
    """
    Return values used to externally reconstruct relations between dags, dag_runs, tasks and task_instances.

    Given a context, this function provides a dictionary of values that can be used to
    externally reconstruct relations between dags, dag_runs, tasks and task_instances.
    Default to abc.def.ghi format and can be made to ABC_DEF_GHI format if
    in_env_var_format is set to True.

    :param context: The context for the task_instance of interest.
    :param in_env_var_format: If returned vars should be in ABC_DEF_GHI format.
    :return: task_instance context as dict.
    """
    name_format = "env_var_format" if in_env_var_format else "default"
    params: dict[str, str] = {}

    # User-supplied context vars come first so that the well-known keys below
    # win on collision. Keys/values must already be strings; keys get the
    # appropriate prefix unless they carry it already.
    for key, value in settings.get_airflow_context_vars(context).items():
        if not isinstance(key, str):
            raise TypeError(f"key <{key}> must be string")
        if not isinstance(value, str):
            raise TypeError(f"value of key <{key}> must be string, not {type(value)}")
        if in_env_var_format:
            if not key.startswith(ENV_VAR_FORMAT_PREFIX):
                key = ENV_VAR_FORMAT_PREFIX + key.upper()
        elif not key.startswith(DEFAULT_FORMAT_PREFIX):
            key = DEFAULT_FORMAT_PREFIX + key
        params[key] = value

    task = context.get("task")
    task_instance = context.get("task_instance")
    dag_run = context.get("dag_run")

    # (object holding the attribute, attribute name, mapping key) triples.
    attribute_sources = (
        (task, "email", "AIRFLOW_CONTEXT_DAG_EMAIL"),
        (task, "owner", "AIRFLOW_CONTEXT_DAG_OWNER"),
        (task_instance, "dag_id", "AIRFLOW_CONTEXT_DAG_ID"),
        (task_instance, "task_id", "AIRFLOW_CONTEXT_TASK_ID"),
        (task_instance, "execution_date", "AIRFLOW_CONTEXT_EXECUTION_DATE"),
        (task_instance, "try_number", "AIRFLOW_CONTEXT_TRY_NUMBER"),
        (dag_run, "run_id", "AIRFLOW_CONTEXT_DAG_RUN_ID"),
    )

    for holder, attr_name, mapping_key in attribute_sources:
        attr_value = getattr(holder, attr_name, None)
        if not (holder and attr_value):
            continue
        target_key = AIRFLOW_VAR_NAME_FORMAT_MAPPING[mapping_key][name_format]
        if isinstance(attr_value, str):
            params[target_key] = attr_value
        elif isinstance(attr_value, datetime):
            params[target_key] = attr_value.isoformat()
        elif isinstance(attr_value, list):
            # os env variable value needs to be string
            params[target_key] = ",".join(attr_value)
        else:
            params[target_key] = str(attr_value)

    return params
class KeywordParameters:
    """Wrapper representing ``**kwargs`` to a callable.

    The actual ``kwargs`` can be obtained by calling either ``unpacking()`` or
    ``serializing()``. They behave almost the same and are only different if
    the containing ``kwargs`` is an Airflow Context object, and the calling
    function uses ``**kwargs`` in the argument list.

    In this particular case, ``unpacking()`` uses ``lazy-object-proxy`` to
    prevent the Context from emitting deprecation warnings too eagerly when it's
    unpacked by ``**``. ``serializing()`` does not do this, and will allow the
    warnings to be emitted eagerly, which is useful when you want to dump the
    content and use it somewhere else without needing ``lazy-object-proxy``.
    """

    def __init__(self, kwargs: Mapping[str, Any], *, wildcard: bool) -> None:
        self._kwargs = kwargs
        self._wildcard = wildcard

    @classmethod
    def determine(
        cls,
        func: Callable[..., Any],
        args: Collection[Any],
        kwargs: Mapping[str, Any],
    ) -> KeywordParameters:
        import inspect
        import itertools

        parameters = inspect.signature(func).parameters
        accepts_arbitrary_keywords = any(
            param.kind == inspect.Parameter.VAR_KEYWORD for param in parameters.values()
        )

        # The first len(args) parameter slots are filled positionally; the same
        # names must not also be supplied via kwargs.
        for name in itertools.islice(parameters, len(args)):
            if name in kwargs:
                raise ValueError(f"The key {name!r} in args is a part of kwargs and therefore reserved.")

        if accepts_arbitrary_keywords:
            # A **kwargs parameter accepts everything as-is.
            return cls(kwargs, wildcard=True)

        # Otherwise, forward only the keyword arguments the callable declares.
        accepted = {name: kwargs[name] for name in parameters if name in kwargs}
        return cls(accepted, wildcard=False)

    def unpacking(self) -> Mapping[str, Any]:
        """Dump the kwargs mapping to unpack with ``**`` in a function call."""
        if not self._wildcard:
            return self._kwargs
        if isinstance(self._kwargs, Context):  # type: ignore[misc]
            # Lazy proxies avoid eagerly triggering Context deprecation warnings.
            return lazy_mapping_from_context(self._kwargs)
        return self._kwargs

    def serializing(self) -> Mapping[str, Any]:
        """Dump the kwargs mapping for serialization purposes."""
        return self._kwargs
def determine_kwargs(
    func: Callable[..., Any],
    args: Collection[Any],
    kwargs: Mapping[str, Any],
) -> Mapping[str, Any]:
    """
    Inspect the signature of a callable to determine which kwargs need to be passed to the callable.

    :param func: The callable that you want to invoke
    :param args: The positional arguments that need to be passed to the callable, so we know how many to skip.
    :param kwargs: The keyword arguments that need to be filtered before passing to the callable.
    :return: A dictionary which contains the keyword arguments that are compatible with the callable.
    """
    parameters = KeywordParameters.determine(func, args, kwargs)
    return parameters.unpacking()
def make_kwargs_callable(func: Callable[..., R]) -> Callable[..., R]:
    """
    Create a new callable that only forwards necessary arguments from any provided input.

    Make a new callable that can accept any number of positional or keyword arguments
    but only forwards those required by the given callable func.
    """
    import functools

    @functools.wraps(func)
    def kwargs_func(*args, **kwargs):
        # Drop any keyword arguments func's signature cannot accept.
        accepted_kwargs = determine_kwargs(func, args, kwargs)
        return func(*args, **accepted_kwargs)

    return kwargs_func
class ExecutionCallableRunner:
    """Run an execution callable against a task context and given arguments.

    If the callable is a simple function, this simply calls it with the supplied
    arguments (including the context). If the callable is a generator function,
    the generator is exhausted here, with the yielded values getting fed back
    into the task context automatically for execution.

    :meta private:
    """

    def __init__(
        self,
        func: Callable,
        outlet_events: OutletEventAccessors,
        *,
        logger: logging.Logger | None,
    ) -> None:
        self.func = func
        self.outlet_events = outlet_events
        self.logger = logger if logger is not None else logging.getLogger(__name__)

    def run(self, *args, **kwargs) -> Any:
        import inspect

        from airflow.datasets.metadata import Metadata
        from airflow.utils.types import NOTSET

        # Plain callables need none of the generator-draining machinery below.
        if not inspect.isgeneratorfunction(self.func):
            return self.func(*args, **kwargs)

        return_value: Any = NOTSET

        def _drain():
            # Capture the generator's return value while re-yielding its items.
            nonlocal return_value
            return_value = yield from self.func(*args, **kwargs)

        for yielded in _drain():
            if isinstance(yielded, Metadata):
                # Metadata yields feed extra information into the outlet events.
                self.outlet_events[yielded.uri].extra.update(yielded.extra)
            else:
                self.logger.warning("Ignoring unknown data of %r received from task", type(yielded))
                if self.logger.isEnabledFor(logging.DEBUG):
                    self.logger.debug("Full yielded value: %r", yielded)

        return return_value