Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/definitions/param.py: 33%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

186 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19import contextlib 

20import copy 

21import json 

22import logging 

23from collections.abc import ItemsView, Iterable, Mapping, MutableMapping, ValuesView 

24from typing import TYPE_CHECKING, Any, ClassVar, Literal 

25 

26from airflow.sdk.definitions._internal.mixins import ResolveMixin 

27from airflow.sdk.definitions._internal.types import NOTSET, is_arg_set 

28from airflow.sdk.exceptions import ParamValidationError 

29 

30if TYPE_CHECKING: 

31 from airflow.sdk.definitions.context import Context 

32 from airflow.sdk.definitions.dag import DAG 

33 from airflow.sdk.types import Operator 

34 

35logger = logging.getLogger(__name__) 

36 

37 

class Param:
    """
    Class to hold the default value of a Param and rule set to do the validations.

    Without the rule set it always validates and returns the default value.

    :param default: The value this Param object holds
    :param description: Optional help text for the Param
    :param source: Optional origin marker for the Param, either ``"dag"`` or ``"task"``
    :param schema: The validation schema of the Param. If not given, all remaining
        keyword arguments form the schema
    """

    __version__: ClassVar[int] = 1

    CLASS_IDENTIFIER = "__class"

    def __init__(
        self,
        default: Any = NOTSET,
        description: str | None = None,
        source: Literal["dag", "task"] | None = None,
        **kwargs,
    ):
        # A default is optional, but when one is given it must be JSON-serializable.
        if default is not NOTSET:
            self._check_json(default)
        self.value = default
        self.description = description
        # An explicit ``schema=`` wins; otherwise the leftover kwargs *are* the schema.
        self.schema = kwargs.pop("schema") if "schema" in kwargs else kwargs
        self.source = source

    def __copy__(self) -> Param:
        """Return a new Param sharing this one's value, description, schema and source."""
        return Param(
            self.value,
            self.description,
            schema=self.schema,
            source=self.source,
        )

    @staticmethod
    def _check_json(value):
        """
        Ensure ``value`` can be serialized by :func:`json.dumps`.

        :param value: The candidate value
        :raises ParamValidationError: If ``json.dumps`` fails for any reason
        """
        try:
            json.dumps(value)
        except Exception as err:
            # Chain the underlying error so the traceback shows *why* serialization failed.
            raise ParamValidationError(
                f"All provided parameters must be json-serializable. The value '{value}' is not serializable."
            ) from err

    def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:
        """
        Run the validations and return the Param's final value.

        The candidate is ``value`` when one is passed, otherwise the value already
        stored on this Param. It is validated against the schema with ``jsonschema``
        and, on success, stored as the Param's new value.

        :param value: The value to be updated for the Param
        :param suppress_exception: If true, a missing value or a failed schema validation
            returns ``None`` instead of raising. It does not suppress the
            JSON-serializability check on an explicitly passed ``value``.
        :raises ParamValidationError: If ``value`` is not JSON-serializable, or (unless
            suppressed) if no value is set or the schema validation fails
        """
        # Imported lazily, as in the original implementation.
        import jsonschema
        from jsonschema import FormatChecker
        from jsonschema.exceptions import ValidationError

        if value is not NOTSET:
            self._check_json(value)
        final_val = self.value if value is NOTSET else value
        if not is_arg_set(final_val):
            if suppress_exception:
                return None
            raise ParamValidationError("No value passed and Param has no default value")
        try:
            jsonschema.validate(final_val, self.schema, format_checker=FormatChecker())
        except ValidationError as err:
            if suppress_exception:
                return None
            raise ParamValidationError(err) from None
        # Only a successfully validated value replaces the stored one.
        self.value = final_val
        return final_val

    def dump(self) -> dict:
        """Dump the Param as a dictionary."""
        out_dict: dict[str, Any] = {
            self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"
        }
        out_dict.update(self.__dict__)
        # Ensure that not set is translated to None
        if self.value is NOTSET:
            out_dict["value"] = None
        return out_dict

    @property
    def has_value(self) -> bool:
        """Whether the stored value is neither ``NOTSET`` nor ``None``."""
        return self.value is not NOTSET and self.value is not None

    def serialize(self) -> dict:
        """Return the value, description, schema and source as a plain dictionary."""
        return {
            "value": self.value,
            "description": self.description,
            "schema": self.schema,
            "source": self.source,
        }

    @staticmethod
    def deserialize(data: dict[str, Any], version: int) -> Param:
        """
        Rebuild a Param from the dictionary produced by :meth:`serialize`.

        :param data: Mapping with ``value``, ``description`` and ``schema`` keys, and an
            optional ``source`` key
        :param version: Version the data was serialized with
        :raises TypeError: If ``version`` is newer than this class's ``__version__``
        """
        if version > Param.__version__:
            raise TypeError("serialized version > class version")

        return Param(
            default=data["value"],
            description=data["description"],
            schema=data["schema"],
            source=data.get("source", None),
        )

152 

153 

class ParamsDict(MutableMapping[str, Any]):
    """
    Class to hold all params for dags or tasks.

    All the keys are strictly string and values are converted into Param's object
    if they are not already. This class is to replace param's dictionary implicitly
    and ideally not needed to be used directly.

    :param dict_obj: A dict or dict like object to init ParamsDict
    :param suppress_exception: Flag to suppress value exceptions while initializing the ParamsDict
    """

    __version__: ClassVar[int] = 1
    __slots__ = ["__dict", "suppress_exception"]

    # ParamsDict is a mutable mapping compared by content, so it must not be hashable.
    # The previous ``__hash__`` returned ``hash(self.dump())``; ``dump()`` returns a dict,
    # so it always raised ``TypeError``. Declaring the class unhashable keeps that outcome
    # for ``hash()`` while reporting it correctly to ``collections.abc.Hashable`` checks.
    __hash__ = None  # type: ignore[assignment]

    def __init__(self, dict_obj: Mapping[str, Any] | None = None, suppress_exception: bool = False):
        # Wrap every plain value in a Param; existing Param objects are stored as-is.
        self.__dict = {k: v if isinstance(v, Param) else Param(v) for k, v in (dict_obj or {}).items()}
        self.suppress_exception = suppress_exception

    def __bool__(self) -> bool:
        return bool(self.__dict)

    def __eq__(self, other: Any) -> bool:
        """Compare by dumped (resolved) values against another ParamsDict or a plain dict."""
        if isinstance(other, ParamsDict):
            return self.dump() == other.dump()
        if isinstance(other, dict):
            return self.dump() == other
        return NotImplemented

    def __copy__(self) -> ParamsDict:
        """Return a shallow copy; the Param objects themselves are shared."""
        return ParamsDict(self.__dict, self.suppress_exception)

    def __deepcopy__(self, memo: dict[int, Any] | None) -> ParamsDict:
        """Return a copy whose Param objects are deep-copied as well."""
        return ParamsDict(copy.deepcopy(self.__dict, memo), self.suppress_exception)

    def __contains__(self, o: object) -> bool:
        return o in self.__dict

    def __len__(self) -> int:
        return len(self.__dict)

    def __delitem__(self, v: str) -> None:
        del self.__dict[v]

    def __iter__(self):
        return iter(self.__dict)

    def __repr__(self):
        return repr(self.dump())

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Override for dictionary's ``setitem`` method to ensure all values are of Param's type only.

        :param key: A key which needs to be inserted or updated in the dict
        :param value: A value which needs to be set against the key. It could be of any
            type but will be converted and stored as a Param object eventually.
        :raises ParamValidationError: If ``value`` fails validation against the existing
            Param stored under ``key``
        """
        if isinstance(value, Param):
            param = value
        elif key in self.__dict:
            # Keep the existing Param (and its schema); only push the new value through it.
            param = self.__dict[key]
            try:
                param.resolve(value=value, suppress_exception=self.suppress_exception)
            except ParamValidationError as ve:
                raise ParamValidationError(f"Invalid input for param {key}: {ve}") from None
        else:
            # if the key isn't there already and if the value isn't of Param type create a new Param object
            param = Param(value)

        self.__dict[key] = param

    def __getitem__(self, key: str) -> Any:
        """
        Override for dictionary's ``getitem`` method to call the resolve method after fetching the key.

        :param key: The key to fetch
        """
        param = self.__dict[key]
        return param.resolve(suppress_exception=self.suppress_exception)

    def get_param(self, key: str) -> Param:
        """Get the internal :class:`.Param` object for this key."""
        return self.__dict[key]

    def items(self):
        """Return a view of ``(key, Param)`` pairs; the values are not resolved."""
        return ItemsView(self.__dict)

    def values(self):
        """Return a view of the stored Param objects; the values are not resolved."""
        return ValuesView(self.__dict)

    def update(self, *args, **kwargs) -> None:
        """Update from a mapping or keyword arguments, like ``dict.update``."""
        if len(args) == 1 and not kwargs and isinstance(args[0], ParamsDict):
            # Feed the raw Param objects through, rather than the resolved values that
            # iterating the other ParamsDict via ``__getitem__`` would yield.
            return super().update(args[0].__dict)
        super().update(*args, **kwargs)

    def dump(self) -> dict[str, Any]:
        """Dump the ParamsDict object as a dictionary, while suppressing exceptions."""
        return {k: v.resolve(suppress_exception=True) for k, v in self.items()}

    def validate(self) -> dict[str, Any]:
        """
        Validate and return all the resolved Param values stored in the dictionary.

        :raises ParamValidationError: If a Param fails to resolve; the message names the key
        """
        resolved_dict = {}
        try:
            for k, v in self.items():
                resolved_dict[k] = v.resolve(suppress_exception=self.suppress_exception)
        except ParamValidationError as ve:
            raise ParamValidationError(f"Invalid input for param {k}: {ve}") from None

        return resolved_dict

    def serialize(self) -> dict[str, Any]:
        """Return the dumped (resolved) values as a plain dictionary."""
        return self.dump()

    @staticmethod
    def deserialize(data: dict, version: int) -> ParamsDict:
        """
        Rebuild a ParamsDict from the dictionary produced by :meth:`serialize`.

        :param data: Mapping of param names to values
        :param version: Version the data was serialized with
        :raises TypeError: If ``version`` is newer than this class's ``__version__``
        """
        if version > ParamsDict.__version__:
            raise TypeError("serialized version > class version")

        return ParamsDict(data)

    def _fill_missing_param_source(
        self,
        source: Literal["dag", "task"] | None = None,
    ) -> None:
        """Set ``source`` on every stored Param whose ``source`` is still ``None``."""
        for key in self.__dict:
            if self.__dict[key].source is None:
                self.__dict[key].source = source

    @staticmethod
    def filter_params_by_source(params: ParamsDict, source: Literal["dag", "task"]) -> ParamsDict:
        """Return a new ParamsDict holding only the Params whose ``source`` equals ``source``."""
        return ParamsDict(
            {key: param for key, param in params.__dict.items() if param.source == source},
        )

291 

292 

class DagParam(ResolveMixin):
    """
    Dag run parameter reference.

    Binds a named parameter to a Dag instance so that its value can be looked up
    at runtime from the task context. The typical use is the implicit conversion
    of arguments passed to a function decorated by ``@dag``.

    It can be used to parameterize a Dag. The value can be overridden through the
    conf supplied when the DagRun is triggered.

    This can also be used in templates by accessing ``{{ context.params }}``.

    **Example**:

        with DAG(...) as dag:
          EmailOperator(subject=dag.param('subject', 'Hi from Airflow!'))

    :param current_dag: Dag being used for parameter.
    :param name: key value which is used to set the parameter
    :param default: Default value used if no parameter was set.
    """

    def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
        # Register the default on the Dag's params first, so a rejected default
        # aborts construction before any state is stored on this object.
        if default is not NOTSET:
            current_dag.params[name] = default
        self.current_dag = current_dag
        self._name = name
        self._default = default

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Return an empty tuple of references."""
        return ()

    def resolve(self, context: Context) -> Any:
        """
        Pull DagParam value from DagRun context. This method is run during ``op.execute()``.

        Lookup order: the DagRun conf, then this object's default, then
        ``context["params"]``.

        :param context: The runtime context mapping
        :raises RuntimeError: If none of the three sources provides a value
        """
        param_name = self._name

        # 1. A non-empty DagRun conf that contains the name wins. A missing
        #    "dag_run" entry or a missing name simply falls through.
        with contextlib.suppress(KeyError):
            dag_run = context["dag_run"]
            if dag_run.conf:
                return dag_run.conf[param_name]

        # 2. The default given at construction time.
        fallback = self._default
        if fallback is not NOTSET:
            return fallback

        # 3. The params already present in the context.
        with contextlib.suppress(KeyError):
            return context["params"][param_name]

        raise RuntimeError(f"No value could be resolved for parameter {self._name}")

    def serialize(self) -> dict:
        """Serialize the DagParam object into a dictionary."""
        serialized = {"dag_id": self.current_dag.dag_id}
        serialized["name"] = self._name
        serialized["default"] = self._default
        return serialized

    @classmethod
    def deserialize(cls, data: dict, dags: dict) -> DagParam:
        """
        Deserializes the dictionary back into a DagParam object.

        :param data: The serialized representation of the DagParam.
        :param dags: A dictionary of available Dags to look up the Dag.
        :raises ValueError: If no (truthy) Dag is registered under the serialized ``dag_id``.
        """
        wanted_id = data["dag_id"]
        # Look the owning Dag up in the mapping supplied by the caller.
        owner = dags.get(wanted_id)
        if owner:
            return cls(current_dag=owner, name=data["name"], default=data["default"])
        raise ValueError(f"Dag with id {wanted_id} not found.")

361 

362 

def process_params(
    dag: DAG,
    task: Operator,
    dagrun_conf: dict[str, Any] | None,
    *,
    suppress_exception: bool,
) -> dict[str, Any]:
    """
    Merge, validate params, and convert them into a simple dict.

    Later sources override earlier ones: Dag params, then task params, then the
    DagRun conf. The conf is applied only when it is non-empty and the
    ``core.dag_run_conf_overrides_params`` option is enabled.

    :param dag: Object whose ``params`` are merged first
    :param task: Operator whose ``params`` are merged second
    :param dagrun_conf: Conf of the DagRun, or ``None``
    :param suppress_exception: Passed to the ``ParamsDict`` that performs the merge and validation
    :return: Mapping of param names to validated values
    """
    from airflow.sdk.configuration import conf

    run_conf = dagrun_conf or {}
    merged = ParamsDict(suppress_exception=suppress_exception)

    # Any AttributeError raised while reading or merging ``dag.params`` is
    # ignored, leaving the Dag-level params out of the merge.
    with contextlib.suppress(AttributeError):
        merged.update(dag.params)

    task_params = task.params
    if task_params:
        merged.update(task_params)

    conf_overrides = conf.getboolean("core", "dag_run_conf_overrides_params")
    if conf_overrides and run_conf:
        logger.debug("Updating task params (%s) with DagRun.conf (%s)", merged, run_conf)
        merged.update(run_conf)

    return merged.validate()