Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/utils/operator_helpers.py: 24%
74 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20from datetime import datetime
21from typing import Any, Callable, Collection, Mapping, TypeVar
23from airflow import settings
24from airflow.utils.context import Context, lazy_mapping_from_context
26R = TypeVar("R")
28DEFAULT_FORMAT_PREFIX = "airflow.ctx."
29ENV_VAR_FORMAT_PREFIX = "AIRFLOW_CTX_"
31AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
32 "AIRFLOW_CONTEXT_DAG_ID": {
33 "default": f"{DEFAULT_FORMAT_PREFIX}dag_id",
34 "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_ID",
35 },
36 "AIRFLOW_CONTEXT_TASK_ID": {
37 "default": f"{DEFAULT_FORMAT_PREFIX}task_id",
38 "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TASK_ID",
39 },
40 "AIRFLOW_CONTEXT_EXECUTION_DATE": {
41 "default": f"{DEFAULT_FORMAT_PREFIX}execution_date",
42 "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}EXECUTION_DATE",
43 },
44 "AIRFLOW_CONTEXT_TRY_NUMBER": {
45 "default": f"{DEFAULT_FORMAT_PREFIX}try_number",
46 "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}TRY_NUMBER",
47 },
48 "AIRFLOW_CONTEXT_DAG_RUN_ID": {
49 "default": f"{DEFAULT_FORMAT_PREFIX}dag_run_id",
50 "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_RUN_ID",
51 },
52 "AIRFLOW_CONTEXT_DAG_OWNER": {
53 "default": f"{DEFAULT_FORMAT_PREFIX}dag_owner",
54 "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_OWNER",
55 },
56 "AIRFLOW_CONTEXT_DAG_EMAIL": {
57 "default": f"{DEFAULT_FORMAT_PREFIX}dag_email",
58 "env_var_format": f"{ENV_VAR_FORMAT_PREFIX}DAG_EMAIL",
59 },
60}
def _stringify_context_value(value: Any) -> str:
    """Render one task / task-instance / dag-run attribute as a plain string.

    OS environment variable values must be strings, so every built-in value
    exported by ``context_to_airflow_vars`` is converted here.

    :param value: The attribute value read from the context object.
    :return: ``value`` unchanged for ``str``; ISO-8601 text for ``datetime``;
        the comma-joined items for ``list``/``tuple``; ``str(value)`` otherwise.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, datetime):
        return value.isoformat()
    if isinstance(value, (list, tuple)):
        # Stringify each item: ",".join() raises TypeError on non-str items, and a
        # tuple must not fall through to str(), which would produce its repr.
        return ",".join(str(item) for item in value)
    return str(value)


def context_to_airflow_vars(context: Mapping[str, Any], in_env_var_format: bool = False) -> dict[str, str]:
    """
    Given a context, this function provides a dictionary of values that can be used to
    externally reconstruct relations between dags, dag_runs, tasks and task_instances.
    Default to abc.def.ghi format and can be made to ABC_DEF_GHI format if
    in_env_var_format is set to True.

    Extra variables returned by ``settings.get_airflow_context_vars(context)`` are
    added first, with the matching prefix prepended when missing (and the key
    upper-cased in env-var format). The built-in variables read from the context's
    ``task``, ``task_instance`` and ``dag_run`` are written afterwards, so they win
    if an extra variable normalizes to the same name.

    :param context: The context for the task_instance of interest.
    :param in_env_var_format: If returned vars should be in ABC_DEF_GHI format.
    :raises TypeError: If a key or value returned by
        ``settings.get_airflow_context_vars`` is not a string.
    :return: task_instance context as dict.
    """
    params: dict[str, str] = {}
    name_format = "env_var_format" if in_env_var_format else "default"

    task = context.get("task")
    task_instance = context.get("task_instance")
    dag_run = context.get("dag_run")

    # (object to read from, attribute name, key into AIRFLOW_VAR_NAME_FORMAT_MAPPING)
    ops = [
        (task, "email", "AIRFLOW_CONTEXT_DAG_EMAIL"),
        (task, "owner", "AIRFLOW_CONTEXT_DAG_OWNER"),
        (task_instance, "dag_id", "AIRFLOW_CONTEXT_DAG_ID"),
        (task_instance, "task_id", "AIRFLOW_CONTEXT_TASK_ID"),
        (task_instance, "execution_date", "AIRFLOW_CONTEXT_EXECUTION_DATE"),
        (task_instance, "try_number", "AIRFLOW_CONTEXT_TRY_NUMBER"),
        (dag_run, "run_id", "AIRFLOW_CONTEXT_DAG_RUN_ID"),
    ]

    context_params = settings.get_airflow_context_vars(context)
    for key, value in context_params.items():
        if not isinstance(key, str):
            raise TypeError(f"key <{key}> must be string")
        if not isinstance(value, str):
            raise TypeError(f"value of key <{key}> must be string, not {type(value)}")

        if in_env_var_format:
            if not key.startswith(ENV_VAR_FORMAT_PREFIX):
                key = ENV_VAR_FORMAT_PREFIX + key.upper()
        else:
            if not key.startswith(DEFAULT_FORMAT_PREFIX):
                key = DEFAULT_FORMAT_PREFIX + key
        params[key] = value

    for subject, attr, mapping_key in ops:
        _attr = getattr(subject, attr, None)
        # Skip absent context objects and falsy attribute values (None, "", 0, []).
        if subject and _attr:
            mapping_value = AIRFLOW_VAR_NAME_FORMAT_MAPPING[mapping_key][name_format]
            params[mapping_value] = _stringify_context_value(_attr)

    return params
class KeywordParameters:
    """Wrapper representing ``**kwargs`` to a callable.

    The actual ``kwargs`` can be obtained by calling either ``unpacking()`` or
    ``serializing()``. They behave almost the same and are only different if
    the containing ``kwargs`` is an Airflow Context object, and the calling
    function uses ``**kwargs`` in the argument list.

    In this particular case, ``unpacking()`` uses ``lazy-object-proxy`` to
    prevent the Context from emitting deprecation warnings too eagerly when it's
    unpacked by ``**``. ``serializing()`` does not do this, and will allow the
    warnings to be emitted eagerly, which is useful when you want to dump the
    content and use it somewhere else without needing ``lazy-object-proxy``.
    """

    def __init__(self, kwargs: Mapping[str, Any], *, wildcard: bool) -> None:
        # Keyword arguments to forward (already filtered when ``wildcard`` is False).
        self._kwargs = kwargs
        # True when the target callable declares a ``**kwargs`` parameter.
        self._wildcard = wildcard

    @classmethod
    def determine(
        cls,
        func: Callable[..., Any],
        args: Collection[Any],
        kwargs: Mapping[str, Any],
    ) -> KeywordParameters:
        """Work out which entries of ``kwargs`` should be forwarded to ``func``.

        :param func: The callable that will eventually be invoked.
        :param args: Positional arguments that will be passed to ``func``; only
            their count is used here.
        :param kwargs: Candidate keyword arguments.
        :raises ValueError: If a positional argument would bind to a named
            parameter whose name is also a key of ``kwargs``.
        :return: A ``KeywordParameters`` holding all of ``kwargs`` when ``func``
            accepts ``**kwargs``, otherwise only the entries ``func`` can take
            by keyword.
        """
        import inspect
        import itertools

        signature = inspect.signature(func)
        parameters = list(signature.parameters.values())
        has_wildcard_kwargs = any(p.kind == p.VAR_KEYWORD for p in parameters)

        # Only positional-only / positional-or-keyword parameters receive
        # positional arguments, and they always come first in a signature. Extra
        # positionals go to ``*args``, so parameters after it (keyword-only ones)
        # can never clash with ``args`` and must not be checked.
        positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        positional_params = itertools.takewhile(lambda p: p.kind in positional_kinds, parameters)
        for param in itertools.islice(positional_params, len(args)):
            # Check if args conflict with names in kwargs.
            if param.name in kwargs:
                raise ValueError(f"The key {param.name!r} in args is a part of kwargs and therefore reserved.")

        if has_wildcard_kwargs:
            # If the callable has a **kwargs argument, it's ready to accept all the kwargs.
            return cls(kwargs, wildcard=True)

        # If the callable has no **kwargs argument, it only wants the arguments it
        # requested -- and only those that can legally be passed by keyword. Passing
        # a positional-only parameter or the ``*args`` name by keyword is a TypeError.
        keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        kwargs = {p.name: kwargs[p.name] for p in parameters if p.kind in keyword_kinds and p.name in kwargs}
        return cls(kwargs, wildcard=False)

    def unpacking(self) -> Mapping[str, Any]:
        """Dump the kwargs mapping to unpack with ``**`` in a function call."""
        # Only a wildcard Context is wrapped lazily; see the class docstring.
        if self._wildcard and isinstance(self._kwargs, Context):
            return lazy_mapping_from_context(self._kwargs)
        return self._kwargs

    def serializing(self) -> Mapping[str, Any]:
        """Dump the kwargs mapping for serialization purposes."""
        return self._kwargs
def determine_kwargs(
    func: Callable[..., Any],
    args: Collection[Any],
    kwargs: Mapping[str, Any],
) -> Mapping[str, Any]:
    """
    Decide which keyword arguments from ``kwargs`` can be handed to ``func``.

    This is a convenience shortcut for
    ``KeywordParameters.determine(func, args, kwargs).unpacking()``.

    :param func: The callable that will be invoked.
    :param args: The positional arguments that will be passed to ``func``; they
        tell the inspection how many leading parameters are already taken.
    :param kwargs: The keyword arguments to filter for ``func``.
    :return: The keyword arguments suitable for unpacking with ``**`` into ``func``.
    """
    keyword_parameters = KeywordParameters.determine(func, args, kwargs)
    return keyword_parameters.unpacking()
def make_kwargs_callable(func: Callable[..., R]) -> Callable[..., R]:
    """
    Wrap ``func`` so it can be called with arbitrary positional and keyword arguments.

    Positional arguments are forwarded unchanged. Keyword arguments are first
    passed through ``determine_kwargs`` so only the ones ``func`` can take are
    forwarded. The wrapper carries ``func``'s metadata via ``functools.wraps``.
    """
    import functools

    def kwargs_func(*args, **kwargs):
        accepted_kwargs = determine_kwargs(func, args, kwargs)
        return func(*args, **accepted_kwargs)

    return functools.wraps(func)(kwargs_func)