Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/param.py: 37%

188 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:35 +0000

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19import contextlib 

20import copy 

21import datetime 

22import json 

23import logging 

24import warnings 

25from typing import TYPE_CHECKING, Any, ClassVar, ItemsView, Iterable, MutableMapping, ValuesView 

26 

27from pendulum.parsing import parse_iso8601 

28 

29from airflow.exceptions import AirflowException, ParamValidationError, RemovedInAirflow3Warning 

30from airflow.utils import timezone 

31from airflow.utils.context import Context 

32from airflow.utils.mixins import ResolveMixin 

33from airflow.utils.types import NOTSET, ArgNotSet 

34 

35if TYPE_CHECKING: 

36 from airflow.models.dag import DAG 

37 from airflow.models.dagrun import DagRun 

38 from airflow.models.operator import Operator 

39 

# Module-level logger; process_params uses it to emit a debug record when
# DagRun.conf overrides the merged params.
logger = logging.getLogger(__name__)

41 

42 

class Param:
    """
    Hold the value of a single param together with the rule set used to validate it.

    Without a rule set every value validates, so resolving simply returns the stored value.

    :param default: The value this Param object holds
    :param description: Optional help text for the Param
    :param schema: The validation schema of the Param, if not given then all kwargs except
        default & description will form the schema
    """

    __version__: ClassVar[int] = 1

    # Key under which ``dump`` stores the dotted path of this class.
    CLASS_IDENTIFIER = "__class"

    def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs):
        if default is not NOTSET:
            self._warn_if_not_json(default)
        # NOTE: ``dump`` copies ``__dict__``, so the attribute order below is part of its output.
        self.value = default
        self.description = description
        # An explicit ``schema=`` wins; otherwise every remaining keyword *is* the schema.
        self.schema = kwargs.pop("schema") if "schema" in kwargs else kwargs

    def __copy__(self) -> Param:
        return Param(self.value, self.description, schema=self.schema)

    @staticmethod
    def _warn_if_not_json(value):
        """Emit a deprecation warning when ``value`` cannot be serialized by ``json.dumps``."""
        try:
            json.dumps(value)
        except Exception:
            warnings.warn(
                "The use of non-json-serializable params is deprecated and will be removed in "
                "a future release",
                RemovedInAirflow3Warning,
            )

    @staticmethod
    def _warn_if_not_rfc3339_dt(value):
        """
        Fallback to iso8601 datetime validation if rfc3339 failed.

        :param value: The value that failed the ``date-time`` format check
        :return: ``value`` unchanged when it parses as an ISO 8601 datetime (deprecation
            warnings are emitted in that case), otherwise ``None``
        """
        try:
            iso8601_value = parse_iso8601(value)
        except Exception:
            return None
        # Anything that did not parse to a full datetime is not accepted by the fallback.
        if not isinstance(iso8601_value, datetime.datetime):
            return None
        warnings.warn(
            f"The use of non-RFC3339 datetime: {value!r} is deprecated "
            "and will be removed in a future release",
            RemovedInAirflow3Warning,
        )
        if timezone.is_naive(iso8601_value):
            warnings.warn(
                "The use naive datetime is deprecated and will be removed in a future release",
                RemovedInAirflow3Warning,
            )
        return value

    def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:
        """
        Run the validations and return the Param's final value.

        The value is first checked for json-serializability; if that fails a deprecation
        warning is emitted. On success the validated value is also stored on the Param.

        :param value: The value to be updated for the Param
        :param suppress_exception: To raise an exception or not when the validations fails.
            If true and validations fails, the return value would be None.
        :raises ParamValidationError: If validation fails, or if no value is passed and the
            Param holds no value, unless ``suppress_exception`` is true
        """
        import jsonschema
        from jsonschema import FormatChecker
        from jsonschema.exceptions import ValidationError

        if value is not NOTSET:
            self._warn_if_not_json(value)
        final_val = value if value is not NOTSET else self.value
        if isinstance(final_val, ArgNotSet):
            if suppress_exception:
                return None
            raise ParamValidationError("No value passed and Param has no default value")
        try:
            jsonschema.validate(final_val, self.schema, format_checker=FormatChecker())
        except ValidationError as err:
            # ``err.schema`` is the (sub)schema that rejected the value. JSON Schema allows
            # boolean schemas, which have no ``.get``; without this guard a failure against
            # such a schema escaped as AttributeError and ignored ``suppress_exception``.
            failed_schema = err.schema
            if not isinstance(failed_schema, bool) and failed_schema.get("format") == "date-time":
                rfc3339_value = self._warn_if_not_rfc3339_dt(final_val)
                if rfc3339_value:
                    self.value = rfc3339_value
                    return rfc3339_value
            if suppress_exception:
                return None
            raise ParamValidationError(err) from None
        self.value = final_val
        return final_val

    def dump(self) -> dict:
        """Dump the Param as a dictionary."""
        out_dict = {self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"}
        out_dict.update(self.__dict__)
        return out_dict

    @property
    def has_value(self) -> bool:
        """Whether a value (default or resolved) is currently stored on this Param."""
        return self.value is not NOTSET

    def serialize(self) -> dict:
        """Return the value, description and schema as a plain dictionary."""
        return {"value": self.value, "description": self.description, "schema": self.schema}

    @staticmethod
    def deserialize(data: dict[str, Any], version: int) -> Param:
        """
        Rebuild a Param from the dictionary produced by :meth:`serialize`.

        :param data: Dictionary with ``value``, ``description`` and ``schema`` keys
        :param version: Version the data was serialized with
        :raises TypeError: If ``version`` is newer than this class's ``__version__``
        """
        if version > Param.__version__:
            raise TypeError("serialized version > class version")

        return Param(default=data["value"], description=data["description"], schema=data["schema"])

156 

157 

class ParamsDict(MutableMapping[str, Any]):
    """
    Mapping that holds all params of a dag or a task.

    Keys are strings and every value is stored as a :class:`Param`; plain values are wrapped
    on the way in. It implicitly replaces the plain params dictionary and ideally does not
    need to be used directly.
    """

    __version__: ClassVar[int] = 1
    __slots__ = ["__dict", "suppress_exception"]

    def __init__(self, dict_obj: MutableMapping | None = None, suppress_exception: bool = False):
        """
        :param dict_obj: A dict or dict like object to init ParamsDict
        :param suppress_exception: Flag to suppress value exceptions while initializing the ParamsDict
        """
        source = dict_obj or {}
        # Wrap anything that is not already a Param; existing Param objects are kept as-is.
        wrapped: dict[str, Param] = {
            name: item if isinstance(item, Param) else Param(item) for name, item in source.items()
        }
        self.__dict = wrapped
        self.suppress_exception = suppress_exception

    def __bool__(self) -> bool:
        return len(self.__dict) > 0

    def __eq__(self, other: Any) -> bool:
        # Only another ParamsDict or a plain dict is comparable; both are compared by
        # their resolved (dumped) values.
        if not isinstance(other, (ParamsDict, dict)):
            return NotImplemented
        own_values = self.dump()
        if isinstance(other, ParamsDict):
            return own_values == other.dump()
        return own_values == other

    def __copy__(self) -> ParamsDict:
        return ParamsDict(dict_obj=self.__dict, suppress_exception=self.suppress_exception)

    def __deepcopy__(self, memo: dict[int, Any] | None) -> ParamsDict:
        cloned = copy.deepcopy(self.__dict, memo)
        return ParamsDict(dict_obj=cloned, suppress_exception=self.suppress_exception)

    def __contains__(self, o: object) -> bool:
        return o in self.__dict

    def __len__(self) -> int:
        return len(self.__dict)

    def __delitem__(self, v: str) -> None:
        del self.__dict[v]

    def __iter__(self):
        return iter(self.__dict)

    def __repr__(self):
        return repr(self.dump())

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Store ``value`` under ``key``, making sure the stored object is always a Param.

        :param key: A key which needs to be inserted or updated in the dict
        :param value: A value which needs to be set against the key. It could be of any
            type but will be converted and stored as a Param object eventually.
        """
        # A Param replaces whatever is stored, without validation.
        if isinstance(value, Param):
            self.__dict[key] = value
            return

        # Unknown key with a plain value: wrap it in a fresh Param.
        if key not in self.__dict:
            self.__dict[key] = Param(value)
            return

        # Known key with a plain value: validate it against the Param already stored there.
        existing = self.__dict[key]
        try:
            existing.resolve(value=value, suppress_exception=self.suppress_exception)
        except ParamValidationError as ve:
            raise ParamValidationError(f"Invalid input for param {key}: {ve}") from None
        self.__dict[key] = existing

    def __getitem__(self, key: str) -> Any:
        """
        Fetch the Param stored under ``key`` and return its resolved value.

        :param key: The key to fetch
        """
        return self.__dict[key].resolve(suppress_exception=self.suppress_exception)

    def get_param(self, key: str) -> Param:
        """Get the internal :class:`.Param` object for this key."""
        return self.__dict[key]

    def items(self):
        """Return a view of ``(key, Param)`` pairs, i.e. the unresolved Param objects."""
        return ItemsView(self.__dict)

    def values(self):
        """Return a view of the stored Param objects (unresolved)."""
        return ValuesView(self.__dict)

    def update(self, *args, **kwargs) -> None:
        """Update from a mapping/iterable; a lone ParamsDict contributes its Param objects."""
        if not kwargs and len(args) == 1 and isinstance(args[0], ParamsDict):
            # Hand over the raw Param objects rather than their resolved values.
            super().update(args[0].__dict)
        else:
            super().update(*args, **kwargs)

    def dump(self) -> dict[str, Any]:
        """Dumps the ParamsDict object as a dictionary, while suppressing exceptions."""
        dumped = {}
        for name, param in self.items():
            dumped[name] = param.resolve(suppress_exception=True)
        return dumped

    def validate(self) -> dict[str, Any]:
        """Validates & returns all the Params object stored in the dictionary."""
        resolved = {}
        try:
            for name, param in self.items():
                resolved[name] = param.resolve(suppress_exception=self.suppress_exception)
        except ParamValidationError as ve:
            # ``name`` is the key whose Param failed validation.
            raise ParamValidationError(f"Invalid input for param {name}: {ve}") from None

        return resolved

    def serialize(self) -> dict[str, Any]:
        """Return the resolved values as a plain dictionary (same as :meth:`dump`)."""
        return self.dump()

    @staticmethod
    def deserialize(data: dict, version: int) -> ParamsDict:
        """
        Rebuild a ParamsDict from serialized data.

        :param data: Dictionary of param names to values
        :param version: Version the data was serialized with
        :raises TypeError: If ``version`` is newer than this class's ``__version__``
        """
        if version > ParamsDict.__version__:
            raise TypeError("serialized version > class version")

        return ParamsDict(data)

286 

287 

class DagParam(ResolveMixin):
    """DAG run parameter reference.

    This binds a simple Param object to a name within a DAG instance, so that it
    can be resolved during the runtime via the ``{{ context }}`` dictionary. The
    ideal use case of this class is to implicitly convert args passed to a
    method decorated by ``@dag``.

    It can be used to parameterize a DAG. You can overwrite its value by setting
    it on conf when you trigger your DagRun.

    This can also be used in templates by accessing ``{{ context.params }}``.

    **Example**:

        with DAG(...) as dag:
            EmailOperator(subject=dag.param('subject', 'Hi from Airflow!'))

    :param current_dag: Dag being used for parameter.
    :param name: key value which is used to set the parameter
    :param default: Default value used if no parameter was set.
    """

    def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
        # A supplied default is also registered on the DAG's params under the same name.
        if default is not NOTSET:
            current_dag.params[name] = default
        self._name = name
        self._default = default

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Return an empty tuple: this object yields no operator references."""
        return ()

    def resolve(self, context: Context) -> Any:
        """Pull DagParam value from DagRun context. This method is run during ``op.execute()``.

        Lookup order: the DagRun's ``conf``, then the default given at construction,
        then ``context["params"]``.

        :param context: The task execution context
        :raises AirflowException: If none of the sources provides a value
        """
        with contextlib.suppress(KeyError):
            run_conf = context["dag_run"].conf
            # ``conf`` may be unset (None) on a run; indexing it would raise TypeError,
            # which is not suppressed here, so only look up when there is something in it.
            if run_conf:
                return run_conf[self._name]
        if self._default is not NOTSET:
            return self._default
        with contextlib.suppress(KeyError):
            return context["params"][self._name]
        raise AirflowException(f"No value could be resolved for parameter {self._name}")

329 

330 

def process_params(
    dag: DAG,
    task: Operator,
    dag_run: DagRun | None,
    *,
    suppress_exception: bool,
) -> dict[str, Any]:
    """
    Merge, validate params, and convert them into a simple dict.

    Sources are applied in increasing precedence: DAG params, task params, and finally
    ``dag_run.conf`` when the ``core.dag_run_conf_overrides_params`` option is enabled.

    :param dag: DAG whose ``params`` form the base layer
    :param task: Operator whose ``params`` override the DAG's
    :param dag_run: Optional run whose ``conf`` may override both
    :param suppress_exception: Passed to the ParamsDict used for merging and validation
    :return: The validated params as a plain dictionary
    """
    from airflow.configuration import conf

    merged = ParamsDict(suppress_exception=suppress_exception)

    # An AttributeError raised while reading or applying the DAG params is ignored.
    with contextlib.suppress(AttributeError):
        merged.update(dag.params)

    if task.params:
        merged.update(task.params)

    # The config option is consulted first; the run and its conf are only inspected
    # when overriding is enabled.
    overrides_enabled = conf.getboolean("core", "dag_run_conf_overrides_params")
    run_conf = dag_run.conf if (overrides_enabled and dag_run) else None
    if run_conf:
        logger.debug("Updating task params (%s) with DagRun.conf (%s)", merged, run_conf)
        merged.update(run_conf)

    return merged.validate()