Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/param.py: 38%
168 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
19import contextlib
20import copy
21import json
22import logging
23import warnings
24from typing import TYPE_CHECKING, Any, ClassVar, ItemsView, Iterable, MutableMapping, ValuesView
26from airflow.exceptions import AirflowException, ParamValidationError, RemovedInAirflow3Warning
27from airflow.utils.context import Context
28from airflow.utils.mixins import ResolveMixin
29from airflow.utils.types import NOTSET, ArgNotSet
31if TYPE_CHECKING:
32 from airflow.models.dag import DAG
33 from airflow.models.dagrun import DagRun
34 from airflow.models.operator import Operator
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
class Param:
    """
    Class to hold the default value of a Param and rule set to do the validations. Without the rule set
    it always validates and returns the default value.

    :param default: The value this Param object holds
    :param description: Optional help text for the Param
    :param schema: The validation schema of the Param, if not given then all kwargs except
        default & description will form the schema
    """

    __version__: ClassVar[int] = 1

    CLASS_IDENTIFIER = "__class"

    def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs):
        # The "not set" sentinel is recognised by type (as ``resolve`` already does) rather than by
        # identity with NOTSET, so an ArgNotSet instance that is not the NOTSET singleton is still
        # treated as "no default" instead of triggering a spurious non-JSON warning.
        if not isinstance(default, ArgNotSet):
            self._warn_if_not_json(default)
        self.value = default
        self.description = description
        # An explicit ``schema`` kwarg wins; otherwise the remaining kwargs themselves form the schema.
        self.schema = kwargs.pop("schema") if "schema" in kwargs else kwargs

    def __copy__(self) -> Param:
        return Param(self.value, self.description, schema=self.schema)

    @staticmethod
    def _warn_if_not_json(value):
        """Emit a RemovedInAirflow3Warning if ``value`` cannot be serialized by :func:`json.dumps`."""
        try:
            json.dumps(value)
        except Exception:
            warnings.warn(
                "The use of non-json-serializable params is deprecated and will be removed in "
                "a future release",
                RemovedInAirflow3Warning,
            )

    def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:
        """
        Runs the validations and returns the Param's final value.

        The candidate value is ``value`` when one is passed, otherwise the value already held by
        the Param. A passed value is first checked for JSON-serializability; if it is not
        serializable a deprecation warning is emitted (a future release will require it). The
        candidate is then validated against ``self.schema`` with :func:`jsonschema.validate`; on
        success it is stored as the Param's value and returned.

        :param value: The value to be updated for the Param
        :param suppress_exception: To raise an exception or not when the validations fails.
            If true and validations fails, the return value would be None.
        :raises ParamValidationError: if no value is passed and the Param holds no value, or if the
            value fails schema validation (only when ``suppress_exception`` is false).
        :return: The validated value, or None when validation fails and exceptions are suppressed.
        """
        import jsonschema
        from jsonschema import FormatChecker
        from jsonschema.exceptions import ValidationError

        # Use the same sentinel test everywhere in this class (see __init__).
        value_passed = not isinstance(value, ArgNotSet)
        if value_passed:
            self._warn_if_not_json(value)
        final_val = value if value_passed else self.value
        if isinstance(final_val, ArgNotSet):
            if suppress_exception:
                return None
            raise ParamValidationError("No value passed and Param has no default value")
        try:
            jsonschema.validate(final_val, self.schema, format_checker=FormatChecker())
        except ValidationError as err:
            if suppress_exception:
                return None
            raise ParamValidationError(err) from None
        self.value = final_val
        return final_val

    def dump(self) -> dict:
        """Dump the Param as a dictionary: a class identifier plus the instance attributes."""
        out_dict: dict[str, Any] = {self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"}
        out_dict.update(self.__dict__)
        return out_dict

    @property
    def has_value(self) -> bool:
        """Whether the Param currently holds a value (i.e. is not the "not set" sentinel)."""
        return not isinstance(self.value, ArgNotSet)

    def serialize(self) -> dict:
        return {"value": self.value, "description": self.description, "schema": self.schema}

    @staticmethod
    def deserialize(data: dict[str, Any], version: int) -> Param:
        """Rebuild a Param from :meth:`serialize` output.

        :raises TypeError: if ``version`` is newer than this class's ``__version__``.
        """
        if version > Param.__version__:
            raise TypeError("serialized version > class version")

        return Param(default=data["value"], description=data["description"], schema=data["schema"])
class ParamsDict(MutableMapping[str, Any]):
    """
    String-keyed mapping of :class:`Param` objects used for DAG- and task-level params.

    Any value stored that is not already a :class:`Param` is wrapped in one. Reading a key
    returns the Param's resolved (validated) value, while :meth:`items`, :meth:`values` and
    :meth:`get_param` expose the underlying Param objects. It stands in for a plain params
    dict implicitly and should rarely need to be used directly.
    """

    __version__: ClassVar[int] = 1
    __slots__ = ["__dict", "suppress_exception"]

    def __init__(self, dict_obj: dict | None = None, suppress_exception: bool = False):
        """
        :param dict_obj: Initial contents; values that are not already Param objects get wrapped
        :param suppress_exception: When true, validation failures are suppressed (resolve returns
            None) instead of raising ParamValidationError
        """
        source = dict_obj or {}
        self.__dict = {
            name: (item if isinstance(item, Param) else Param(item)) for name, item in source.items()
        }
        self.suppress_exception = suppress_exception

    def __bool__(self) -> bool:
        return len(self.__dict) > 0

    def __eq__(self, other: Any) -> bool:
        # Compare by resolved values, against either another ParamsDict or a plain dict.
        if isinstance(other, ParamsDict):
            return self.dump() == other.dump()
        if not isinstance(other, dict):
            return NotImplemented
        return self.dump() == other

    def __copy__(self) -> ParamsDict:
        # Shallow: the new mapping shares the same Param objects.
        return ParamsDict(self.__dict, self.suppress_exception)

    def __deepcopy__(self, memo: dict[int, Any] | None) -> ParamsDict:
        copied_params = copy.deepcopy(self.__dict, memo)
        return ParamsDict(copied_params, self.suppress_exception)

    def __contains__(self, o: object) -> bool:
        return o in self.__dict

    def __len__(self) -> int:
        return len(self.__dict)

    def __delitem__(self, v: str) -> None:
        del self.__dict[v]

    def __iter__(self):
        return iter(self.__dict)

    def __repr__(self):
        return repr(self.dump())

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Store ``value`` under ``key``, always as a :class:`Param`.

        A Param is stored as-is. A raw value for a new key is wrapped in a fresh Param. A raw value
        for an existing key is resolved (validated) against that key's existing Param, which keeps
        its schema.

        :param key: The key to insert or update
        :param value: A Param, or any value to be held by a Param
        """
        if isinstance(value, Param):
            param = value
        elif key not in self.__dict:
            # Brand-new key with a raw value: wrap it in a new Param.
            param = Param(value)
        else:
            param = self.__dict[key]
            try:
                param.resolve(value=value, suppress_exception=self.suppress_exception)
            except ParamValidationError as ve:
                raise ParamValidationError(f"Invalid input for param {key}: {ve}") from None

        self.__dict[key] = param

    def __getitem__(self, key: str) -> Any:
        """
        Look up the Param stored under ``key`` and return its resolved value.

        :param key: The key to fetch
        """
        return self.__dict[key].resolve(suppress_exception=self.suppress_exception)

    def get_param(self, key: str) -> Param:
        """Return the underlying :class:`.Param` object stored under ``key``."""
        return self.__dict[key]

    def items(self):
        # Yields (key, Param) pairs -- Param objects, not resolved values.
        return ItemsView(self.__dict)

    def values(self):
        # Yields Param objects, not resolved values.
        return ValuesView(self.__dict)

    def update(self, *args, **kwargs) -> None:
        is_lone_params_dict = len(args) == 1 and not kwargs and isinstance(args[0], ParamsDict)
        if is_lone_params_dict:
            # Feed the other mapping's raw Param objects, so schemas are carried over instead of
            # MutableMapping.update reading other[key] (which would yield resolved values).
            super().update(args[0].__dict)
        else:
            super().update(*args, **kwargs)

    def dump(self) -> dict[str, Any]:
        """Return a plain dict of resolved values; validation failures yield None instead of raising."""
        return {name: param.resolve(suppress_exception=True) for name, param in self.items()}

    def validate(self) -> dict[str, Any]:
        """Resolve every stored Param and return a plain dict of the resulting values.

        :raises ParamValidationError: naming the first key whose Param fails validation (unless
            ``suppress_exception`` is set).
        """
        resolved = {}
        for name, param in self.items():
            try:
                resolved[name] = param.resolve(suppress_exception=self.suppress_exception)
            except ParamValidationError as ve:
                raise ParamValidationError(f"Invalid input for param {name}: {ve}") from None
        return resolved

    def serialize(self) -> dict[str, Any]:
        return self.dump()

    @staticmethod
    def deserialize(data: dict, version: int) -> ParamsDict:
        if version > ParamsDict.__version__:
            raise TypeError("serialized version > class version")

        return ParamsDict(data)
class DagParam(ResolveMixin):
    """DAG run parameter reference.

    This binds a simple Param object to a name within a DAG instance, so that it
    can be resolved during the runtime via the ``{{ context }}`` dictionary. The
    ideal use case of this class is to implicitly convert args passed to a
    method decorated by ``@dag``.

    It can be used to parameterize a DAG. You can overwrite its value by setting
    it on conf when you trigger your DagRun.

    This can also be used in templates by accessing ``{{ context.params }}``.

    **Example**:

        with DAG(...) as dag:
          EmailOperator(subject=dag.param('subject', 'Hi from Airflow!'))

    :param current_dag: Dag being used for parameter.
    :param name: key value which is used to set the parameter
    :param default: Default value used if no parameter was set.
    """

    def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
        # Registering the default on the DAG makes it visible through the DAG's params.
        if default is not NOTSET:
            current_dag.params[name] = default
        self._name = name
        self._default = default

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        return ()

    def resolve(self, context: Context) -> Any:
        """Pull DagParam value from DagRun context. This method is run during ``op.execute()``.

        Lookup order: the DagRun's ``conf``, then this reference's own default, then
        ``context["params"]``.

        :raises AirflowException: if none of those sources provides a value.
        """
        with contextlib.suppress(KeyError):
            # A None/empty conf is treated as "no overrides" (process_params in this module also
            # guards against a falsy dag_run.conf) so we fall through instead of raising TypeError.
            return (context["dag_run"].conf or {})[self._name]
        if self._default is not NOTSET:
            return self._default
        with contextlib.suppress(KeyError):
            return context["params"][self._name]
        raise AirflowException(f"No value could be resolved for parameter {self._name}")
def process_params(
    dag: DAG,
    task: Operator,
    dag_run: DagRun | None,
    *,
    suppress_exception: bool,
) -> dict[str, Any]:
    """Build a task's effective params and return them as a validated plain dict.

    Sources are layered in increasing precedence: DAG params, then task params, then the
    DagRun's ``conf`` when ``[core] dag_run_conf_overrides_params`` is enabled and a non-empty
    conf is present.
    """
    from airflow.configuration import conf

    merged = ParamsDict(suppress_exception=suppress_exception)

    # Base layer. Any AttributeError raised while reading/merging the DAG params is swallowed.
    with contextlib.suppress(AttributeError):
        merged.update(dag.params)

    if task.params:
        merged.update(task.params)

    conf_overrides_enabled = conf.getboolean("core", "dag_run_conf_overrides_params")
    if conf_overrides_enabled and dag_run and dag_run.conf:
        logger.debug("Updating task params (%s) with DagRun.conf (%s)", merged, dag_run.conf)
        merged.update(dag_run.conf)

    return merged.validate()