Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/param.py: 37%
188 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
19import contextlib
20import copy
21import datetime
22import json
23import logging
24import warnings
25from typing import TYPE_CHECKING, Any, ClassVar, ItemsView, Iterable, MutableMapping, ValuesView
27from pendulum.parsing import parse_iso8601
29from airflow.exceptions import AirflowException, ParamValidationError, RemovedInAirflow3Warning
30from airflow.utils import timezone
31from airflow.utils.context import Context
32from airflow.utils.mixins import ResolveMixin
33from airflow.utils.types import NOTSET, ArgNotSet
35if TYPE_CHECKING:
36 from airflow.models.dag import DAG
37 from airflow.models.dagrun import DagRun
38 from airflow.models.operator import Operator
40logger = logging.getLogger(__name__)
class Param:
    """
    Hold one parameter value together with the JSON schema used to validate it.

    When no schema rules are supplied the schema is an empty mapping, so any
    value passes validation and is returned unchanged.

    :param default: The value this Param object holds
    :param description: Optional help text for the Param
    :param schema: The validation schema of the Param; when not given, every keyword
        argument other than ``default`` and ``description`` forms the schema
    """

    __version__: ClassVar[int] = 1

    CLASS_IDENTIFIER = "__class"

    def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs):
        if default is not NOTSET:
            self._warn_if_not_json(default)
        # Attribute order matters: ``dump`` exposes ``__dict__`` in insertion order.
        self.value = default
        self.description = description
        try:
            # An explicit ``schema`` keyword wins; any other keywords are then ignored.
            self.schema = kwargs.pop("schema")
        except KeyError:
            self.schema = kwargs

    def __copy__(self) -> Param:
        """Build a new Param that shares this one's value, description and schema objects."""
        return Param(default=self.value, description=self.description, schema=self.schema)

    @staticmethod
    def _warn_if_not_json(value):
        """Emit a deprecation warning when ``value`` cannot be serialized by ``json.dumps``."""
        try:
            json.dumps(value)
        except Exception:
            message = (
                "The use of non-json-serializable params is deprecated and will be removed in "
                "a future release"
            )
            warnings.warn(message, RemovedInAirflow3Warning)

    @staticmethod
    def _warn_if_not_rfc3339_dt(value):
        """
        Fall back to ISO 8601 datetime parsing for a value that failed ``date-time`` validation.

        :return: ``value`` unchanged when it parses to a ``datetime.datetime`` (after emitting
            deprecation warnings), otherwise ``None``.
        """
        try:
            parsed = parse_iso8601(value)
        except Exception:
            return None
        if not isinstance(parsed, datetime.datetime):
            # Parsed to something other than a datetime, so it is not acceptable here.
            return None

        warnings.warn(
            f"The use of non-RFC3339 datetime: {value!r} is deprecated "
            "and will be removed in a future release",
            RemovedInAirflow3Warning,
        )
        if timezone.is_naive(parsed):
            warnings.warn(
                "The use naive datetime is deprecated and will be removed in a future release",
                RemovedInAirflow3Warning,
            )
        return value

    def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:
        """
        Validate a value against the schema, store it, and return it.

        An explicitly passed ``value`` is first checked for JSON serializability (warning
        only). When no ``value`` is passed, the currently stored value is validated instead.
        A value that fails ``date-time`` format validation but parses as an ISO 8601
        datetime is still accepted, with a deprecation warning.

        :param value: The value to be updated for the Param
        :param suppress_exception: When true, return ``None`` instead of raising if there is
            no value to resolve or validation fails.
        :raises ParamValidationError: If no value is available or validation fails, and
            ``suppress_exception`` is false.
        """
        import jsonschema
        from jsonschema import FormatChecker
        from jsonschema.exceptions import ValidationError

        if value is NOTSET:
            candidate = self.value
        else:
            self._warn_if_not_json(value)
            candidate = value

        if isinstance(candidate, ArgNotSet):
            if suppress_exception:
                return None
            raise ParamValidationError("No value passed and Param has no default value")

        try:
            jsonschema.validate(candidate, self.schema, format_checker=FormatChecker())
        except ValidationError as err:
            if err.schema.get("format") == "date-time":
                fallback = self._warn_if_not_rfc3339_dt(candidate)
                if fallback:
                    self.value = fallback
                    return fallback
            if suppress_exception:
                return None
            raise ParamValidationError(err) from None

        self.value = candidate
        return candidate

    def dump(self) -> dict:
        """Dump the Param as a dictionary: a class identifier entry plus all instance attributes."""
        qualified_name = f"{self.__module__}.{self.__class__.__name__}"
        return {self.CLASS_IDENTIFIER: qualified_name, **self.__dict__}

    @property
    def has_value(self) -> bool:
        """Whether a value (anything other than ``NOTSET``) is currently stored."""
        return self.value is not NOTSET

    def serialize(self) -> dict:
        """Return the value, description and schema as a plain dictionary."""
        return dict(value=self.value, description=self.description, schema=self.schema)

    @staticmethod
    def deserialize(data: dict[str, Any], version: int) -> Param:
        """
        Rebuild a Param from the dictionary produced by ``serialize``.

        :raises TypeError: If ``version`` is newer than this class's ``__version__``.
        """
        if version > Param.__version__:
            raise TypeError("serialized version > class version")

        return Param(default=data["value"], description=data["description"], schema=data["schema"])
class ParamsDict(MutableMapping[str, Any]):
    """
    Mapping that holds all params for dags or tasks.

    Values are always stored as ``Param`` objects (plain values are wrapped on the way in),
    while item access returns the resolved value of the stored ``Param``. This class is meant
    to replace a plain params dictionary implicitly and ideally does not need to be used
    directly.
    """

    __version__: ClassVar[int] = 1
    __slots__ = ["__dict", "suppress_exception"]

    def __init__(self, dict_obj: MutableMapping | None = None, suppress_exception: bool = False):
        """
        :param dict_obj: A dict or dict like object to init ParamsDict
        :param suppress_exception: Flag to suppress value exceptions while initializing the ParamsDict
        """
        source = dict_obj or {}
        # Wrap anything that is not already a Param; existing Param objects are reused as-is.
        self.__dict = {
            name: (item if isinstance(item, Param) else Param(item)) for name, item in source.items()
        }
        self.suppress_exception = suppress_exception

    def __bool__(self) -> bool:
        return bool(self.__dict)

    def __eq__(self, other: Any) -> bool:
        """Compare resolved values against another ParamsDict or a plain dict."""
        if not isinstance(other, (ParamsDict, dict)):
            return NotImplemented
        own_values = self.dump()
        other_values = other.dump() if isinstance(other, ParamsDict) else other
        return own_values == other_values

    def __copy__(self) -> ParamsDict:
        """Return a new ParamsDict that shares the same Param objects."""
        return ParamsDict(self.__dict, self.suppress_exception)

    def __deepcopy__(self, memo: dict[int, Any] | None) -> ParamsDict:
        """Return a new ParamsDict holding deep copies of the stored Param objects."""
        cloned = copy.deepcopy(self.__dict, memo)
        return ParamsDict(cloned, self.suppress_exception)

    def __contains__(self, o: object) -> bool:
        return o in self.__dict

    def __len__(self) -> int:
        return len(self.__dict)

    def __delitem__(self, v: str) -> None:
        del self.__dict[v]

    def __iter__(self):
        return iter(self.__dict)

    def __repr__(self):
        return repr(self.dump())

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Store ``value`` under ``key``, making sure the stored object is always a ``Param``.

        A ``Param`` value replaces whatever is stored. A plain value for a new key is wrapped
        in a fresh ``Param``. A plain value for an existing key is resolved through the
        existing ``Param``, so it is validated against that Param's schema.

        :param key: A key which needs to be inserted or updated in the dict
        :param value: A value which needs to be set against the key. It could be of any
            type but will be converted and stored as a Param object eventually.
        :raises ParamValidationError: If validation against the existing Param fails and
            exceptions are not suppressed.
        """
        if isinstance(value, Param):
            entry = value
        elif key not in self.__dict:
            # Unknown key with a plain value: wrap it in a new, schema-less Param.
            entry = Param(value)
        else:
            entry = self.__dict[key]
            try:
                entry.resolve(value=value, suppress_exception=self.suppress_exception)
            except ParamValidationError as ve:
                raise ParamValidationError(f"Invalid input for param {key}: {ve}") from None

        self.__dict[key] = entry

    def __getitem__(self, key: str) -> Any:
        """
        Return the resolved value of the ``Param`` stored under ``key``.

        :param key: The key to fetch
        """
        return self.__dict[key].resolve(suppress_exception=self.suppress_exception)

    def get_param(self, key: str) -> Param:
        """Get the internal :class:`.Param` object for this key."""
        return self.__dict[key]

    def items(self):
        """Return a view of ``(key, Param)`` pairs over the internal storage (values are not resolved)."""
        return ItemsView(self.__dict)

    def values(self):
        """Return a view of the stored ``Param`` objects (not resolved)."""
        return ValuesView(self.__dict)

    def update(self, *args, **kwargs) -> None:
        """Update from a mapping or keyword arguments, like ``dict.update``."""
        if len(args) != 1 or kwargs or not isinstance(args[0], ParamsDict):
            super().update(*args, **kwargs)
            return None
        # A lone ParamsDict argument: pass its internal storage so Param objects are handed over.
        return super().update(args[0].__dict)

    def dump(self) -> dict[str, Any]:
        """Dumps the ParamsDict object as a dictionary, while suppressing exceptions."""
        resolved: dict[str, Any] = {}
        for name, param in self.items():
            resolved[name] = param.resolve(suppress_exception=True)
        return resolved

    def validate(self) -> dict[str, Any]:
        """
        Resolve every stored Param and return the values as a plain dictionary.

        :raises ParamValidationError: If a Param fails to resolve and exceptions are not
            suppressed; the message names the offending key.
        """
        resolved_dict = {}
        for name, param in self.items():
            try:
                resolved_dict[name] = param.resolve(suppress_exception=self.suppress_exception)
            except ParamValidationError as ve:
                raise ParamValidationError(f"Invalid input for param {name}: {ve}") from None

        return resolved_dict

    def serialize(self) -> dict[str, Any]:
        """Return the resolved values as a plain dictionary (same as ``dump``)."""
        return self.dump()

    @staticmethod
    def deserialize(data: dict, version: int) -> ParamsDict:
        """
        Rebuild a ParamsDict from a plain dictionary.

        :raises TypeError: If ``version`` is newer than this class's ``__version__``.
        """
        if version > ParamsDict.__version__:
            raise TypeError("serialized version > class version")

        return ParamsDict(data)
class DagParam(ResolveMixin):
    """DAG run parameter reference.

    This binds a simple Param object to a name within a DAG instance, so that it
    can be resolved during the runtime via the ``{{ context }}`` dictionary. The
    ideal use case of this class is to implicitly convert args passed to a
    method decorated by ``@dag``.

    It can be used to parameterize a DAG. You can overwrite its value by setting
    it on conf when you trigger your DagRun.

    This can also be used in templates by accessing ``{{ context.params }}``.

    **Example**:

        with DAG(...) as dag:
          EmailOperator(subject=dag.param('subject', 'Hi from Airflow!'))

    :param current_dag: Dag being used for parameter.
    :param name: key value which is used to set the parameter
    :param default: Default value used if no parameter was set.
    """

    def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
        # Register the default on the DAG's params only when one was actually given.
        if default is not NOTSET:
            current_dag.params[name] = default
        self._name = name
        self._default = default

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Return an empty tuple: this object reports no operator references."""
        return ()

    def resolve(self, context: Context) -> Any:
        """Pull DagParam value from DagRun context. This method is run during ``op.execute()``.

        Lookup order: the DagRun's ``conf``, then the default given at construction,
        then ``context["params"]``.

        :param context: The task execution context.
        :raises AirflowException: If none of the sources provides a value for this name.
        """
        with contextlib.suppress(KeyError):
            run_conf = context["dag_run"].conf
            # A missing/empty conf (e.g. ``None``) cannot hold the value; skip it rather
            # than fail with a TypeError on subscripting, and fall through to the default.
            if run_conf:
                return run_conf[self._name]
        if self._default is not NOTSET:
            return self._default
        with contextlib.suppress(KeyError):
            return context["params"][self._name]
        raise AirflowException(f"No value could be resolved for parameter {self._name}")
def process_params(
    dag: DAG,
    task: Operator,
    dag_run: DagRun | None,
    *,
    suppress_exception: bool,
) -> dict[str, Any]:
    """
    Merge, validate params, and convert them into a simple dict.

    Sources are layered in order, later ones overriding earlier ones: the DAG's params,
    the task's params, and finally the DagRun's ``conf`` when the
    ``core.dag_run_conf_overrides_params`` setting is enabled and a non-empty conf exists.

    :param dag: DAG whose params form the base layer.
    :param task: Operator whose params are applied on top of the DAG's.
    :param dag_run: Optional DagRun supplying ``conf`` overrides.
    :param suppress_exception: Passed to the ``ParamsDict`` that collects and validates the values.
    :return: The resolved values as returned by ``ParamsDict.validate``.
    """
    from airflow.configuration import conf

    merged = ParamsDict(suppress_exception=suppress_exception)

    # An AttributeError while applying the DAG's params is deliberately ignored.
    with contextlib.suppress(AttributeError):
        merged.update(dag.params)

    task_params = task.params
    if task_params:
        merged.update(task_params)

    conf_overrides_enabled = conf.getboolean("core", "dag_run_conf_overrides_params")
    if conf_overrides_enabled and dag_run and dag_run.conf:
        logger.debug("Updating task params (%s) with DagRun.conf (%s)", merged, dag_run.conf)
        merged.update(dag_run.conf)

    return merged.validate()