1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
18
19import contextlib
20import copy
21import json
22import logging
23from collections.abc import ItemsView, Iterable, Mapping, MutableMapping, ValuesView
24from typing import TYPE_CHECKING, Any, ClassVar, Literal
25
26from airflow.sdk.definitions._internal.mixins import ResolveMixin
27from airflow.sdk.definitions._internal.types import NOTSET, is_arg_set
28from airflow.sdk.exceptions import ParamValidationError
29
30if TYPE_CHECKING:
31 from airflow.sdk.definitions.context import Context
32 from airflow.sdk.definitions.dag import DAG
33 from airflow.sdk.types import Operator
34
# Module-level logger named after this module; process_params() uses it for debug output.
logger = logging.getLogger(__name__)
36
37
class Param:
    """
    Hold one parameter value together with the JSON schema used to validate it.

    Without any schema rules the validation always passes and the stored value is returned.

    :param default: The value this Param object holds
    :param description: Optional help text for the Param
    :param source: Origin of the Param (``"dag"`` or ``"task"``), or ``None`` when unknown
    :param schema: The validation schema of the Param, if not given then all remaining
        keyword arguments will form the schema
    """

    __version__: ClassVar[int] = 1

    CLASS_IDENTIFIER = "__class"

    def __init__(
        self,
        default: Any = NOTSET,
        description: str | None = None,
        source: Literal["dag", "task"] | None = None,
        **kwargs,
    ):
        # "Not set" is a legal default; every other default has to be JSON-serializable.
        if default is not NOTSET:
            self._check_json(default)
        # Keep this assignment order: ``dump`` exposes ``__dict__`` in insertion order.
        self.value = default
        self.description = description
        # An explicit ``schema=`` keyword wins; otherwise the leftover keywords are the schema.
        self.schema = kwargs.pop("schema", kwargs)
        self.source = source

    def __copy__(self) -> Param:
        """Build a new Param that reuses this one's value, description, schema and source."""
        return Param(
            default=self.value,
            description=self.description,
            source=self.source,
            schema=self.schema,
        )

    @staticmethod
    def _check_json(value):
        """Raise ``ParamValidationError`` when ``json.dumps`` cannot handle ``value``."""
        try:
            json.dumps(value)
        except Exception:
            raise ParamValidationError(
                f"All provided parameters must be json-serializable. The value '{value}' is not serializable."
            )

    def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:
        """
        Validate the candidate value against the schema and return it.

        The candidate is ``value`` when one is passed, otherwise the value already stored
        on this Param. A successfully validated candidate is stored back on the Param.

        A passed ``value`` that is not json-serializable raises ``ParamValidationError``
        regardless of ``suppress_exception``.

        :param value: The value to be updated for the Param
        :param suppress_exception: When true, a missing value or a failed schema validation
            makes this method return ``None`` instead of raising ``ParamValidationError``.
        """
        import jsonschema
        from jsonschema import FormatChecker
        from jsonschema.exceptions import ValidationError

        if value is NOTSET:
            candidate = self.value
        else:
            self._check_json(value)
            candidate = value

        # Neither an explicit value nor a stored one: nothing to validate.
        if not is_arg_set(candidate):
            if suppress_exception:
                return None
            raise ParamValidationError("No value passed and Param has no default value")

        try:
            jsonschema.validate(candidate, self.schema, format_checker=FormatChecker())
        except ValidationError as err:
            if suppress_exception:
                return None
            raise ParamValidationError(err) from None

        self.value = candidate
        return candidate

    def dump(self) -> dict:
        """Dump the Param as a dictionary."""
        dumped: dict[str, str | None] = {
            self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}",
            **self.__dict__,
        }
        # Ensure that not set is translated to None
        if self.value is NOTSET:
            dumped["value"] = None
        return dumped

    @property
    def has_value(self) -> bool:
        """Whether the stored value is neither unset nor ``None``."""
        return not (self.value is NOTSET or self.value is None)

    def serialize(self) -> dict:
        """Return the value, description, schema and source as a plain dictionary."""
        return {field: getattr(self, field) for field in ("value", "description", "schema", "source")}

    @staticmethod
    def deserialize(data: dict[str, Any], version: int) -> Param:
        """
        Rebuild a Param from the dictionary produced by ``serialize``.

        :param data: Mapping with ``value``, ``description`` and ``schema`` keys; ``source`` is optional
        :param version: Version of the serialized data; must not exceed the class version
        """
        if version > Param.__version__:
            raise TypeError("serialized version > class version")

        return Param(
            default=data["value"],
            description=data["description"],
            schema=data["schema"],
            source=data.get("source"),
        )
152
153
class ParamsDict(MutableMapping[str, Any]):
    """
    Class to hold all params for dags or tasks.

    All the keys are strictly string and values are converted into Param's object
    if they are not already. This class is to replace param's dictionary implicitly
    and ideally not needed to be used directly.

    :param dict_obj: A dict or dict like object to init ParamsDict
    :param suppress_exception: Flag to suppress value exceptions while initializing the ParamsDict
    """

    __version__: ClassVar[int] = 1
    __slots__ = ["__dict", "suppress_exception"]

    def __init__(self, dict_obj: Mapping[str, Any] | None = None, suppress_exception: bool = False):
        # Wrap every plain value in a Param; existing Param objects are stored as they are.
        self.__dict = {k: v if isinstance(v, Param) else Param(v) for k, v in (dict_obj or {}).items()}
        self.suppress_exception = suppress_exception

    def __bool__(self) -> bool:
        return bool(self.__dict)

    def __eq__(self, other: Any) -> bool:
        """Compare by dumped (resolved) values against another ParamsDict or a plain dict."""
        if isinstance(other, ParamsDict):
            return self.dump() == other.dump()
        if isinstance(other, dict):
            return self.dump() == other
        return NotImplemented

    @staticmethod
    def _freeze(value: Any) -> Any:
        """Convert a dumped value (dicts and lists included) into an equivalent hashable structure."""
        if isinstance(value, dict):
            return frozenset((key, ParamsDict._freeze(item)) for key, item in value.items())
        if isinstance(value, (list, tuple)):
            return tuple(ParamsDict._freeze(item) for item in value)
        return value

    def __hash__(self):
        """
        Hash the dumped values.

        ``dump()`` returns a dict, which cannot be hashed directly, so it is frozen into
        nested frozensets/tuples first. Two instances with equal dumps get equal hashes.
        The hash follows the current contents, so do not mutate an instance while it is
        used as a dict key or set member.
        """
        return hash(self._freeze(self.dump()))

    def __copy__(self) -> ParamsDict:
        """Return a new ParamsDict that shares the same Param objects."""
        return ParamsDict(self.__dict, self.suppress_exception)

    def __deepcopy__(self, memo: dict[int, Any] | None) -> ParamsDict:
        """Return a new ParamsDict holding deep copies of the Param objects."""
        return ParamsDict(copy.deepcopy(self.__dict, memo), self.suppress_exception)

    def __contains__(self, o: object) -> bool:
        return o in self.__dict

    def __len__(self) -> int:
        return len(self.__dict)

    def __delitem__(self, v: str) -> None:
        del self.__dict[v]

    def __iter__(self):
        return iter(self.__dict)

    def __repr__(self):
        return repr(self.dump())

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Override for dictionary's ``setitem`` method to ensure all values are of Param's type only.

        :param key: A key which needs to be inserted or updated in the dict
        :param value: A value which needs to be set against the key. It could be of any
            type but will be converted and stored as a Param object eventually.
        """
        if isinstance(value, Param):
            param = value
        elif key in self.__dict:
            # Existing key: validate the new value against the Param already stored there.
            param = self.__dict[key]
            try:
                param.resolve(value=value, suppress_exception=self.suppress_exception)
            except ParamValidationError as ve:
                raise ParamValidationError(f"Invalid input for param {key}: {ve}") from None
        else:
            # if the key isn't there already and if the value isn't of Param type create a new Param object
            param = Param(value)

        self.__dict[key] = param

    def __getitem__(self, key: str) -> Any:
        """
        Override for dictionary's ``getitem`` method to call the resolve method after fetching the key.

        :param key: The key to fetch
        """
        param = self.__dict[key]
        return param.resolve(suppress_exception=self.suppress_exception)

    def get_param(self, key: str) -> Param:
        """Get the internal :class:`.Param` object for this key."""
        return self.__dict[key]

    def items(self):
        """Return a view of ``(key, Param)`` pairs; the values are the stored Param objects."""
        return ItemsView(self.__dict)

    def values(self):
        """Return a view of the stored Param objects."""
        return ValuesView(self.__dict)

    def update(self, *args, **kwargs) -> None:
        """
        Update the params from another mapping and/or keyword arguments.

        A single ParamsDict argument is merged through its internal dict, so its Param
        objects are handed to ``__setitem__`` as they are.
        """
        if len(args) == 1 and not kwargs and isinstance(args[0], ParamsDict):
            return super().update(args[0].__dict)
        super().update(*args, **kwargs)

    def dump(self) -> dict[str, Any]:
        """Dump the ParamsDict object as a dictionary, while suppressing exceptions."""
        return {k: v.resolve(suppress_exception=True) for k, v in self.items()}

    def validate(self) -> dict[str, Any]:
        """Validate & returns all the Params object stored in the dictionary."""
        resolved_dict = {}
        try:
            for k, v in self.items():
                resolved_dict[k] = v.resolve(suppress_exception=self.suppress_exception)
        except ParamValidationError as ve:
            # ``k`` is the key whose Param failed to resolve.
            raise ParamValidationError(f"Invalid input for param {k}: {ve}") from None

        return resolved_dict

    def serialize(self) -> dict[str, Any]:
        """Serialize to the same plain dictionary that ``dump`` produces."""
        return self.dump()

    @staticmethod
    def deserialize(data: dict, version: int) -> ParamsDict:
        """
        Rebuild a ParamsDict from serialized data.

        :param data: Mapping of param names to values
        :param version: Version of the serialized data; must not exceed the class version
        """
        if version > ParamsDict.__version__:
            raise TypeError("serialized version > class version")

        return ParamsDict(data)

    def _fill_missing_param_source(
        self,
        source: Literal["dag", "task"] | None = None,
    ) -> None:
        """Set ``source`` on every stored Param whose source is still ``None``."""
        for key in self.__dict:
            if self.__dict[key].source is None:
                self.__dict[key].source = source

    @staticmethod
    def filter_params_by_source(params: ParamsDict, source: Literal["dag", "task"]) -> ParamsDict:
        """Return a new ParamsDict with only the Params of ``params`` whose source equals ``source``."""
        return ParamsDict(
            {key: param for key, param in params.__dict.items() if param.source == source},
        )
291
292
class DagParam(ResolveMixin):
    """
    Reference to a Dag run parameter.

    Ties a named parameter to a Dag instance so that its value can be looked up at
    runtime from the task context. The typical use is the implicit conversion of
    arguments passed to a function decorated with ``@dag``.

    It lets a Dag be parameterized; the value can be overridden through the conf
    supplied when the DagRun is triggered.

    The same values are reachable in templates through ``{{ context.params }}``.

    **Example**:

        with DAG(...) as dag:
            EmailOperator(subject=dag.param('subject', 'Hi from Airflow!'))

    :param current_dag: Dag being used for parameter.
    :param name: key value which is used to set the parameter
    :param default: Default value used if no parameter was set.
    """

    def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
        # Register the default on the Dag's params only when one was actually given.
        if default is not NOTSET:
            current_dag.params[name] = default
        self.current_dag = current_dag
        self._name = name
        self._default = default

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Return an empty tuple: no operator references are reported for a DagParam."""
        return ()

    def resolve(self, context: Context) -> Any:
        """
        Pull DagParam value from DagRun context. This method is run during ``op.execute()``.

        Lookup order: the DagRun conf (when it is truthy and holds the name), then the
        default given at construction, then ``context["params"]``. A ``RuntimeError`` is
        raised when none of them provides a value.
        """
        try:
            dag_run = context["dag_run"]
            if dag_run.conf:
                return dag_run.conf[self._name]
        except KeyError:
            # No dag_run in the context, or the conf lacks this name: fall through.
            pass

        if self._default is not NOTSET:
            return self._default

        try:
            return context["params"][self._name]
        except KeyError:
            pass

        raise RuntimeError(f"No value could be resolved for parameter {self._name}")

    def serialize(self) -> dict:
        """Serialize the DagParam object into a dictionary."""
        serialized = {
            "dag_id": self.current_dag.dag_id,
            "name": self._name,
            "default": self._default,
        }
        return serialized

    @classmethod
    def deserialize(cls, data: dict, dags: dict) -> DagParam:
        """
        Deserializes the dictionary back into a DagParam object.

        :param data: The serialized representation of the DagParam.
        :param dags: A dictionary of available Dags to look up the Dag.
        """
        dag_id = data["dag_id"]
        # Look the owning Dag up by id; a missing (or falsy) entry is an error.
        current_dag = dags.get(dag_id)
        if current_dag:
            return cls(current_dag=current_dag, name=data["name"], default=data["default"])
        raise ValueError(f"Dag with id {dag_id} not found.")
361
362
def process_params(
    dag: DAG,
    task: Operator,
    dagrun_conf: dict[str, Any] | None,
    *,
    suppress_exception: bool,
) -> dict[str, Any]:
    """
    Merge, validate params, and convert them into a simple dict.

    Params are layered in this order: the Dag's params, then the task's params (when
    truthy), then the DagRun conf, the latter only when the
    ``core.dag_run_conf_overrides_params`` option is enabled and the conf is non-empty.

    :param dag: Dag whose ``params`` form the base layer
    :param task: Operator whose ``params`` override the Dag's
    :param dagrun_conf: DagRun conf, or ``None``
    :param suppress_exception: Passed to the ParamsDict used for merging and validation
    :return: The result of ``ParamsDict.validate()`` on the merged params
    """
    from airflow.sdk.configuration import conf

    run_conf = dagrun_conf or {}
    merged = ParamsDict(suppress_exception=suppress_exception)

    try:
        merged.update(dag.params)
        if task.params:
            merged.update(task.params)
    except AttributeError:
        # An AttributeError while merging dag/task params is ignored on purpose.
        pass

    if conf.getboolean("core", "dag_run_conf_overrides_params") and run_conf:
        logger.debug("Updating task params (%s) with DagRun.conf (%s)", merged, run_conf)
        merged.update(run_conf)

    return merged.validate()