Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/models/param.py: 38%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

191 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19import contextlib 

20import copy 

21import datetime 

22import json 

23import logging 

24import warnings 

25from typing import TYPE_CHECKING, Any, ClassVar, ItemsView, Iterable, MutableMapping, ValuesView 

26 

27from pendulum.parsing import parse_iso8601 

28 

29from airflow.exceptions import AirflowException, ParamValidationError, RemovedInAirflow3Warning 

30from airflow.utils import timezone 

31from airflow.utils.mixins import ResolveMixin 

32from airflow.utils.types import NOTSET, ArgNotSet 

33 

34if TYPE_CHECKING: 

35 from airflow.models.dag import DAG 

36 from airflow.models.dagrun import DagRun 

37 from airflow.models.operator import Operator 

38 from airflow.serialization.pydantic.dag_run import DagRunPydantic 

39 from airflow.utils.context import Context 

40 

41logger = logging.getLogger(__name__) 

42 

43 

class Param:
    """
    Class to hold the default value of a Param and rule set to do the validations.

    Without the rule set it always validates and returns the default value.

    :param default: The value this Param object holds
    :param description: Optional help text for the Param
    :param schema: The validation schema of the Param, if not given then all kwargs except
        default & description will form the schema
    """

    __version__: ClassVar[int] = 1

    # Key under which ``dump()`` records the dotted path of this class.
    CLASS_IDENTIFIER = "__class"

    def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs):
        if default is not NOTSET:
            self._warn_if_not_json(default)
        self.value = default
        self.description = description
        # An explicit ``schema=`` keyword wins; otherwise the remaining keyword
        # arguments themselves are used as the JSON schema.
        self.schema = kwargs.pop("schema") if "schema" in kwargs else kwargs

    def __copy__(self) -> Param:
        """Return a new Param carrying the same value, description and schema object."""
        return Param(self.value, self.description, schema=self.schema)

    @staticmethod
    def _warn_if_not_json(value):
        """Emit a deprecation warning when ``value`` cannot be dumped with ``json.dumps``."""
        try:
            json.dumps(value)
        except Exception:
            warnings.warn(
                "The use of non-json-serializable params is deprecated and will be removed in "
                "a future release",
                RemovedInAirflow3Warning,
                stacklevel=1,
            )

    @staticmethod
    def _warn_if_not_rfc3339_dt(value):
        """
        Fallback to iso8601 datetime validation if rfc3339 failed.

        :param value: The value that failed the ``date-time`` format check
        :return: ``value`` unchanged (after emitting deprecation warnings) when it parses
            as an ISO 8601 datetime, otherwise ``None``
        """
        try:
            iso8601_value = parse_iso8601(value)
        except Exception:
            return None
        # Only a full datetime is accepted here; any other parse result is rejected.
        if not isinstance(iso8601_value, datetime.datetime):
            return None
        warnings.warn(
            f"The use of non-RFC3339 datetime: {value!r} is deprecated "
            "and will be removed in a future release",
            RemovedInAirflow3Warning,
            stacklevel=1,
        )
        if timezone.is_naive(iso8601_value):
            warnings.warn(
                "The use naive datetime is deprecated and will be removed in a future release",
                RemovedInAirflow3Warning,
                stacklevel=1,
            )
        return value

    def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:
        """
        Run the validations and returns the Param's final value.

        We first check that value is json-serializable; if not, warn.
        In future release we will require the value to be json-serializable.

        On success the validated value is also stored on the Param as ``self.value``.

        :param value: The value to be updated for the Param
        :param suppress_exception: To raise an exception or not when the validations fails.
            If true and validations fails, the return value would be None.
        :raises ParamValidationError: if validation fails, or if no value is passed and the
            Param has no value of its own (only when ``suppress_exception`` is false)
        """
        import jsonschema
        from jsonschema import FormatChecker
        from jsonschema.exceptions import ValidationError

        if value is not NOTSET:
            self._warn_if_not_json(value)
        final_val = self.value if value is NOTSET else value
        if isinstance(final_val, ArgNotSet):
            if suppress_exception:
                return None
            raise ParamValidationError("No value passed and Param has no default value")
        try:
            jsonschema.validate(final_val, self.schema, format_checker=FormatChecker())
        except ValidationError as err:
            # ``err.schema`` is the (sub)schema that rejected the value. JSON Schema
            # permits boolean schemas, so it can be ``False`` instead of a mapping;
            # calling ``.get`` on it would raise AttributeError and bypass both
            # ``suppress_exception`` and ParamValidationError.
            failed_schema = err.schema
            if not isinstance(failed_schema, bool) and failed_schema.get("format") == "date-time":
                rfc3339_value = self._warn_if_not_rfc3339_dt(final_val)
                if rfc3339_value:
                    self.value = rfc3339_value
                    return rfc3339_value
            if suppress_exception:
                return None
            raise ParamValidationError(err) from None
        self.value = final_val
        return final_val

    def dump(self) -> dict:
        """Dump the Param as a dictionary."""
        out_dict: dict[str, Any] = {
            self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"
        }
        out_dict.update(self.__dict__)
        # Ensure that not set is translated to None
        if self.value is NOTSET:
            out_dict["value"] = None
        return out_dict

    @property
    def has_value(self) -> bool:
        """Whether the Param currently holds a value that is neither NOTSET nor None."""
        return self.value is not NOTSET and self.value is not None

    def serialize(self) -> dict:
        """Return the value, description and schema of this Param as a plain dict."""
        return {"value": self.value, "description": self.description, "schema": self.schema}

    @staticmethod
    def deserialize(data: dict[str, Any], version: int) -> Param:
        """
        Rebuild a Param from the dict produced by ``serialize``.

        :param data: Mapping with the ``value``, ``description`` and ``schema`` keys
        :param version: Version the data was serialized with
        :raises TypeError: if ``version`` is newer than this class' ``__version__``
        """
        if version > Param.__version__:
            raise TypeError("serialized version > class version")

        return Param(default=data["value"], description=data["description"], schema=data["schema"])

167 

168 

class ParamsDict(MutableMapping[str, Any]):
    """
    Mapping of parameter names to :class:`Param` objects for dags or tasks.

    Keys are strings. Any value that is not already a ``Param`` is wrapped in one
    when it is stored. Item access returns the *resolved* value of the Param, while
    ``items()``, ``values()`` and ``get_param()`` expose the Param objects themselves.

    :param dict_obj: A dict or dict like object to init ParamsDict
    :param suppress_exception: Flag to suppress value exceptions while initializing the ParamsDict
    """

    __version__: ClassVar[int] = 1
    __slots__ = ["__dict", "suppress_exception"]

    def __init__(self, dict_obj: MutableMapping | None = None, suppress_exception: bool = False):
        source = dict_obj or {}
        # Wrap every plain value in a Param; keep existing Param objects as they are.
        self.__dict = {
            name: (item if isinstance(item, Param) else Param(item)) for name, item in source.items()
        }
        self.suppress_exception = suppress_exception

    def __bool__(self) -> bool:
        return bool(self.__dict)

    def __eq__(self, other: Any) -> bool:
        # Comparison is done on the dumped (resolved) values, not on Param objects.
        if not isinstance(other, (ParamsDict, dict)):
            return NotImplemented
        mine = self.dump()
        theirs = other.dump() if isinstance(other, ParamsDict) else other
        return mine == theirs

    def __copy__(self) -> ParamsDict:
        return ParamsDict(self.__dict, self.suppress_exception)

    def __deepcopy__(self, memo: dict[int, Any] | None) -> ParamsDict:
        cloned_params = copy.deepcopy(self.__dict, memo)
        return ParamsDict(cloned_params, self.suppress_exception)

    def __contains__(self, o: object) -> bool:
        return o in self.__dict

    def __len__(self) -> int:
        return len(self.__dict)

    def __delitem__(self, v: str) -> None:
        del self.__dict[v]

    def __iter__(self):
        return iter(self.__dict)

    def __repr__(self):
        return repr(self.dump())

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Store ``value`` under ``key``, making sure the stored object is a Param.

        A Param instance replaces whatever was stored. A plain value for an existing
        key is resolved against the existing Param; a plain value for a new key is
        wrapped in a fresh Param.

        :param key: A key which needs to be inserted or updated in the dict
        :param value: A value which needs to be set against the key. It could be of any
            type but will be converted and stored as a Param object eventually.
        """
        if isinstance(value, Param):
            self.__dict[key] = value
            return

        try:
            param = self.__dict[key]
        except KeyError:
            # Unknown key and a plain value: wrap it in a new Param object.
            param = Param(value)
        else:
            try:
                param.resolve(value=value, suppress_exception=self.suppress_exception)
            except ParamValidationError as ve:
                raise ParamValidationError(f"Invalid input for param {key}: {ve}") from None

        self.__dict[key] = param

    def __getitem__(self, key: str) -> Any:
        """
        Look up the Param stored under ``key`` and return its resolved value.

        :param key: The key to fetch
        """
        return self.__dict[key].resolve(suppress_exception=self.suppress_exception)

    def get_param(self, key: str) -> Param:
        """Get the internal :class:`.Param` object for this key."""
        return self.__dict[key]

    def items(self):
        return ItemsView(self.__dict)

    def values(self):
        return ValuesView(self.__dict)

    def update(self, *args, **kwargs) -> None:
        # A lone ParamsDict argument is merged through its internal mapping so the
        # Param objects themselves are carried over.
        only_a_params_dict = len(args) == 1 and not kwargs and isinstance(args[0], ParamsDict)
        if only_a_params_dict:
            return super().update(args[0].__dict)
        super().update(*args, **kwargs)

    def dump(self) -> dict[str, Any]:
        """Dump the ParamsDict object as a dictionary, while suppressing exceptions."""
        dumped = {}
        for name, param in self.items():
            dumped[name] = param.resolve(suppress_exception=True)
        return dumped

    def validate(self) -> dict[str, Any]:
        """Resolve every stored Param and return the resulting name-to-value dict."""
        resolved = {}
        for name, param in self.items():
            try:
                resolved[name] = param.resolve(suppress_exception=self.suppress_exception)
            except ParamValidationError as ve:
                raise ParamValidationError(f"Invalid input for param {name}: {ve}") from None
        return resolved

    def serialize(self) -> dict[str, Any]:
        return self.dump()

    @staticmethod
    def deserialize(data: dict, version: int) -> ParamsDict:
        if version > ParamsDict.__version__:
            raise TypeError("serialized version > class version")

        return ParamsDict(data)

297 

298 

class DagParam(ResolveMixin):
    """DAG run parameter reference.

    This binds a simple Param object to a name within a DAG instance, so that it
    can be resolved during the runtime via the ``{{ context }}`` dictionary. The
    ideal use case of this class is to implicitly convert args passed to a
    method decorated by ``@dag``.

    It can be used to parameterize a DAG. You can overwrite its value by setting
    it on conf when you trigger your DagRun.

    This can also be used in templates by accessing ``{{ context.params }}``.

    **Example**:

        with DAG(...) as dag:
          EmailOperator(subject=dag.param('subject', 'Hi from Airflow!'))

    :param current_dag: Dag being used for parameter.
    :param name: key value which is used to set the parameter
    :param default: Default value used if no parameter was set.
    """

    def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
        # Register the default on the DAG's params only when one was actually given.
        if default is not NOTSET:
            current_dag.params[name] = default
        self._name = name
        self._default = default

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Return an empty tuple: no operator references are reported for a DagParam."""
        return ()

    def resolve(self, context: Context) -> Any:
        """
        Pull DagParam value from DagRun context. This method is run during ``op.execute()``.

        Lookup order: the DagRun ``conf`` entry for this name, then the default given
        at construction, then ``context["params"]``.

        :param context: The task execution context
        :raises AirflowException: if none of the three sources provides a value
        """
        with contextlib.suppress(KeyError):
            # A DagRun's conf may be empty/None (``process_params`` guards against a
            # falsy conf as well); treat that as "nothing overridden" so we fall
            # through to the default instead of failing with TypeError on ``None[...]``.
            return (context["dag_run"].conf or {})[self._name]
        if self._default is not NOTSET:
            return self._default
        with contextlib.suppress(KeyError):
            return context["params"][self._name]
        raise AirflowException(f"No value could be resolved for parameter {self._name}")

340 

341 

def process_params(
    dag: DAG,
    task: Operator,
    dag_run: DagRun | DagRunPydantic | None,
    *,
    suppress_exception: bool,
) -> dict[str, Any]:
    """
    Merge DAG, task and (optionally) DagRun conf params, validate, and return a plain dict.

    Later sources override earlier ones: DAG params first, then task params, then
    ``dag_run.conf`` when ``core.dag_run_conf_overrides_params`` is enabled.

    :param dag: DAG whose ``params`` form the base layer
    :param task: Operator whose ``params`` are layered on top
    :param dag_run: Run whose ``conf`` may override params; may be ``None``
    :param suppress_exception: Passed to the ``ParamsDict`` used for merging and validation
    """
    from airflow.configuration import conf

    merged = ParamsDict(suppress_exception=suppress_exception)

    # An AttributeError raised while merging the DAG-level params is ignored.
    with contextlib.suppress(AttributeError):
        merged.update(dag.params)

    if task.params:
        merged.update(task.params)

    if conf.getboolean("core", "dag_run_conf_overrides_params"):
        if dag_run and dag_run.conf:
            logger.debug("Updating task params (%s) with DagRun.conf (%s)", merged, dag_run.conf)
            merged.update(dag_run.conf)

    return merged.validate()