Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/param.py: 38%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
19import contextlib
20import copy
21import datetime
22import json
23import logging
24import warnings
25from typing import TYPE_CHECKING, Any, ClassVar, ItemsView, Iterable, MutableMapping, ValuesView
27from pendulum.parsing import parse_iso8601
29from airflow.exceptions import AirflowException, ParamValidationError, RemovedInAirflow3Warning
30from airflow.utils import timezone
31from airflow.utils.mixins import ResolveMixin
32from airflow.utils.types import NOTSET, ArgNotSet
34if TYPE_CHECKING:
35 from airflow.models.dag import DAG
36 from airflow.models.dagrun import DagRun
37 from airflow.models.operator import Operator
38 from airflow.serialization.pydantic.dag_run import DagRunPydantic
39 from airflow.utils.context import Context
41logger = logging.getLogger(__name__)
class Param:
    """
    Class to hold the default value of a Param and rule set to do the validations.

    Without the rule set it always validates and returns the default value.

    :param default: The value this Param object holds
    :param description: Optional help text for the Param
    :param schema: The validation schema of the Param, if not given then all kwargs except
        default & description will form the schema
    """

    __version__: ClassVar[int] = 1

    CLASS_IDENTIFIER = "__class"

    def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs):
        # Only an explicitly supplied default is checked for JSON-serializability.
        if default is not NOTSET:
            self._warn_if_not_json(default)
        self.value = default
        self.description = description
        # An explicit ``schema=`` keyword wins; otherwise the remaining kwargs are the schema.
        self.schema = kwargs.pop("schema") if "schema" in kwargs else kwargs

    def __copy__(self) -> Param:
        """Return a new Param built from this one's value, description and schema."""
        return Param(self.value, self.description, schema=self.schema)

    @staticmethod
    def _warn_if_not_json(value):
        """Emit a ``RemovedInAirflow3Warning`` if ``json.dumps(value)`` fails for any reason."""
        try:
            json.dumps(value)
        except Exception:
            warnings.warn(
                "The use of non-json-serializable params is deprecated and will be removed in "
                "a future release",
                RemovedInAirflow3Warning,
                stacklevel=1,
            )

    @staticmethod
    def _warn_if_not_rfc3339_dt(value):
        """
        Fallback to iso8601 datetime validation if rfc3339 failed.

        :param value: The value that failed ``date-time`` format validation
        :return: ``value`` unchanged if it parses as an ISO 8601 datetime (a deprecation
            warning is emitted, plus a second one if the datetime is naive); ``None`` if it
            cannot be parsed or does not parse to a ``datetime.datetime``.
        """
        try:
            iso8601_value = parse_iso8601(value)
        except Exception:
            return None
        if not isinstance(iso8601_value, datetime.datetime):
            return None
        warnings.warn(
            f"The use of non-RFC3339 datetime: {value!r} is deprecated "
            "and will be removed in a future release",
            RemovedInAirflow3Warning,
            stacklevel=1,
        )
        if timezone.is_naive(iso8601_value):
            warnings.warn(
                "The use naive datetime is deprecated and will be removed in a future release",
                RemovedInAirflow3Warning,
                stacklevel=1,
            )
        return value

    def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:
        """
        Run the validations and returns the Param's final value.

        We first check that value is json-serializable; if not, warn.
        In future release we will require the value to be json-serializable.
        On success the validated value is also stored as the Param's new value.

        :param value: The value to be updated for the Param
        :param suppress_exception: To raise an exception or not when the validations fails.
            If true and validations fails, the return value would be None.
        :raises ParamValidationError: if no value is passed and no value already exists, or
            if the value fails schema validation (unless ``suppress_exception`` is true).
        """
        import jsonschema
        from jsonschema import FormatChecker
        from jsonschema.exceptions import ValidationError

        if value is not NOTSET:
            self._warn_if_not_json(value)
        final_val = self.value if value is NOTSET else value
        if isinstance(final_val, ArgNotSet):
            if suppress_exception:
                return None
            raise ParamValidationError("No value passed and Param has no default value")
        try:
            jsonschema.validate(final_val, self.schema, format_checker=FormatChecker())
        except ValidationError as err:
            # ``err.schema`` is the (sub)schema that rejected the value. JSON Schema permits
            # boolean schemas, which have no ``format`` keyword (and no ``.get``), so only
            # look the format up when the failing schema is not a bool.
            failed_schema = err.schema
            if not isinstance(failed_schema, bool) and failed_schema.get("format") == "date-time":
                rfc3339_value = self._warn_if_not_rfc3339_dt(final_val)
                if rfc3339_value:
                    self.value = rfc3339_value
                    return rfc3339_value
            if suppress_exception:
                return None
            raise ParamValidationError(err) from None
        self.value = final_val
        return final_val

    def dump(self) -> dict:
        """Dump the Param as a dictionary."""
        out_dict: dict[str, str | None] = {
            self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"
        }
        out_dict.update(self.__dict__)
        # Ensure that not set is translated to None
        if self.value is NOTSET:
            out_dict["value"] = None
        return out_dict

    @property
    def has_value(self) -> bool:
        """Whether the Param currently holds a value that is neither NOTSET nor None."""
        return self.value is not NOTSET and self.value is not None

    def serialize(self) -> dict:
        """Return the value, description and schema as a plain dictionary."""
        return {"value": self.value, "description": self.description, "schema": self.schema}

    @staticmethod
    def deserialize(data: dict[str, Any], version: int) -> Param:
        """
        Rebuild a Param from the dictionary layout produced by :meth:`serialize`.

        :param data: Mapping with the ``value``, ``description`` and ``schema`` keys
        :param version: Version the data was serialized with
        :raises TypeError: if ``version`` is newer than this class's ``__version__``
        """
        if version > Param.__version__:
            raise TypeError("serialized version > class version")

        return Param(default=data["value"], description=data["description"], schema=data["schema"])
class ParamsDict(MutableMapping[str, Any]):
    """
    Mapping that holds every param of a dag or a task.

    Keys are strings. Any value that is not already a Param is wrapped in one
    when it is stored. The class is meant to stand in for a plain params
    dictionary implicitly and ideally does not need to be used directly.

    :param dict_obj: A dict or dict like object to init ParamsDict
    :param suppress_exception: Flag to suppress value exceptions while initializing the ParamsDict
    """

    __version__: ClassVar[int] = 1
    __slots__ = ["__dict", "suppress_exception"]

    def __init__(self, dict_obj: MutableMapping | None = None, suppress_exception: bool = False):
        source = dict_obj or {}
        # Wrap raw values; keep existing Param objects as they are.
        wrapped: dict[str, Param] = {
            name: (item if isinstance(item, Param) else Param(item)) for name, item in source.items()
        }
        self.__dict = wrapped
        self.suppress_exception = suppress_exception

    def __bool__(self) -> bool:
        return bool(self.__dict)

    def __eq__(self, other: Any) -> bool:
        # Comparison is done on the dumped (resolved) values, not on the Param objects.
        if isinstance(other, ParamsDict):
            mine = self.dump()
            theirs = other.dump()
            return mine == theirs
        if isinstance(other, dict):
            return self.dump() == other
        return NotImplemented

    def __copy__(self) -> ParamsDict:
        return ParamsDict(self.__dict, self.suppress_exception)

    def __deepcopy__(self, memo: dict[int, Any] | None) -> ParamsDict:
        cloned = copy.deepcopy(self.__dict, memo)
        return ParamsDict(cloned, self.suppress_exception)

    def __contains__(self, o: object) -> bool:
        return o in self.__dict

    def __len__(self) -> int:
        return len(self.__dict)

    def __delitem__(self, v: str) -> None:
        del self.__dict[v]

    def __iter__(self):
        return iter(self.__dict)

    def __repr__(self):
        return repr(self.dump())

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Store ``value`` under ``key``, making sure what is stored is always a Param.

        A Param is stored as given. A raw value for an existing key is resolved
        against the Param already stored there, which is then kept. A raw value
        for a new key is wrapped in a fresh Param.

        :param key: A key which needs to be inserted or updated in the dict
        :param value: A value which needs to be set against the key. It could be of any
            type but will be converted and stored as a Param object eventually.
        """
        if not isinstance(value, Param):
            if key not in self.__dict:
                # Unknown key with a raw value: wrap it in a new Param.
                value = Param(value)
            else:
                existing = self.__dict[key]
                try:
                    existing.resolve(value=value, suppress_exception=self.suppress_exception)
                except ParamValidationError as ve:
                    raise ParamValidationError(f"Invalid input for param {key}: {ve}") from None
                value = existing

        self.__dict[key] = value

    def __getitem__(self, key: str) -> Any:
        """
        Look up ``key`` and return the resolved value of the Param stored there.

        :param key: The key to fetch
        """
        return self.__dict[key].resolve(suppress_exception=self.suppress_exception)

    def get_param(self, key: str) -> Param:
        """Get the internal :class:`.Param` object for this key."""
        return self.__dict[key]

    def items(self):
        return ItemsView(self.__dict)

    def values(self):
        return ValuesView(self.__dict)

    def update(self, *args, **kwargs) -> None:
        # A lone ParamsDict argument is unwrapped so its Param objects are copied over
        # rather than their resolved values.
        if not kwargs and len(args) == 1 and isinstance(args[0], ParamsDict):
            args = (args[0].__dict,)
        super().update(*args, **kwargs)

    def dump(self) -> dict[str, Any]:
        """Dump the ParamsDict object as a dictionary, while suppressing exceptions."""
        dumped = {}
        for name, param in self.items():
            dumped[name] = param.resolve(suppress_exception=True)
        return dumped

    def validate(self) -> dict[str, Any]:
        """Resolve every stored Param and return the results keyed by param name."""
        resolved = {}
        for name, param in self.items():
            try:
                resolved[name] = param.resolve(suppress_exception=self.suppress_exception)
            except ParamValidationError as ve:
                raise ParamValidationError(f"Invalid input for param {name}: {ve}") from None

        return resolved

    def serialize(self) -> dict[str, Any]:
        return self.dump()

    @staticmethod
    def deserialize(data: dict, version: int) -> ParamsDict:
        if version > ParamsDict.__version__:
            raise TypeError("serialized version > class version")

        return ParamsDict(data)
class DagParam(ResolveMixin):
    """DAG run parameter reference.

    This binds a simple Param object to a name within a DAG instance, so that it
    can be resolved during the runtime via the ``{{ context }}`` dictionary. The
    ideal use case of this class is to implicitly convert args passed to a
    method decorated by ``@dag``.

    It can be used to parameterize a DAG. You can overwrite its value by setting
    it on conf when you trigger your DagRun.

    This can also be used in templates by accessing ``{{ context.params }}``.

    **Example**:

        with DAG(...) as dag:
            EmailOperator(subject=dag.param('subject', 'Hi from Airflow!'))

    :param current_dag: Dag being used for parameter.
    :param name: key value which is used to set the parameter
    :param default: Default value used if no parameter was set.
    """

    def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
        # A supplied default is also registered on the DAG's params under ``name``.
        if default is not NOTSET:
            current_dag.params[name] = default
        self._name = name
        self._default = default

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Return an empty iterable: this object yields no (operator, key) references."""
        return ()

    def resolve(self, context: Context) -> Any:
        """
        Pull DagParam value from DagRun context. This method is run during ``op.execute()``.

        Lookup order: the DagRun's ``conf`` entry for this name, then the default
        given at construction, then ``context["params"]``.

        :param context: Mapping expected to provide ``dag_run`` and/or ``params``
        :raises AirflowException: if none of the lookups yields a value
        """
        with contextlib.suppress(KeyError):
            run_conf = context["dag_run"].conf
            # ``conf`` may be empty or None; subscripting None would raise a TypeError
            # that is not suppressed here, so only look the name up in a non-empty conf.
            if run_conf:
                return run_conf[self._name]
        if self._default is not NOTSET:
            return self._default
        with contextlib.suppress(KeyError):
            return context["params"][self._name]
        raise AirflowException(f"No value could be resolved for parameter {self._name}")
def process_params(
    dag: DAG,
    task: Operator,
    dag_run: DagRun | DagRunPydantic | None,
    *,
    suppress_exception: bool,
) -> dict[str, Any]:
    """
    Merge, validate params, and convert them into a simple dict.

    Later sources override earlier ones: DAG params, then task params, then the
    DagRun conf. The conf is only applied when the ``core.dag_run_conf_overrides_params``
    option is enabled and the run has a non-empty conf.

    :param dag: DAG whose ``params`` form the base layer
    :param task: Operator whose ``params`` are applied on top
    :param dag_run: Run whose ``conf`` may override the merged params; may be None
    :param suppress_exception: Passed to the ParamsDict that merges and validates the params
    :return: The result of validating the merged params
    """
    from airflow.configuration import conf

    merged = ParamsDict(suppress_exception=suppress_exception)
    try:
        merged.update(dag.params)
    except AttributeError:
        # Any AttributeError raised while applying the DAG params is ignored.
        pass
    if task.params:
        merged.update(task.params)
    conf_overrides_enabled = conf.getboolean("core", "dag_run_conf_overrides_params")
    if conf_overrides_enabled and dag_run and dag_run.conf:
        logger.debug("Updating task params (%s) with DagRun.conf (%s)", merged, dag_run.conf)
        merged.update(dag_run.conf)
    return merged.validate()