Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/param.py: 38%

168 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17from __future__ import annotations 

18 

19import contextlib 

20import copy 

21import json 

22import logging 

23import warnings 

24from typing import TYPE_CHECKING, Any, ClassVar, ItemsView, Iterable, MutableMapping, ValuesView 

25 

26from airflow.exceptions import AirflowException, ParamValidationError, RemovedInAirflow3Warning 

27from airflow.utils.context import Context 

28from airflow.utils.mixins import ResolveMixin 

29from airflow.utils.types import NOTSET, ArgNotSet 

30 

31if TYPE_CHECKING: 

32 from airflow.models.dag import DAG 

33 from airflow.models.dagrun import DagRun 

34 from airflow.models.operator import Operator 

35 

36logger = logging.getLogger(__name__) 

37 

38 

class Param:
    """
    Hold the value of a single param together with the JSON schema used to validate it.

    When the schema is empty every value validates, so resolving simply returns the stored value.

    :param default: The value this Param object holds
    :param description: Optional help text for the Param
    :param schema: The validation schema of the Param, if not given then all kwargs except
        default & description will form the schema
    """

    __version__: ClassVar[int] = 1

    CLASS_IDENTIFIER = "__class"

    def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs):
        # ``resolve`` treats *any* ArgNotSet instance as "no value", so apply the same rule here:
        # an unset sentinel must never be reported as a non-JSON-serializable value.
        if default is not NOTSET and not isinstance(default, ArgNotSet):
            self._warn_if_not_json(default)
        self.value = default
        self.description = description
        # An explicit ``schema`` kwarg wins; otherwise the remaining kwargs *are* the schema.
        self.schema = kwargs.pop("schema") if "schema" in kwargs else kwargs

    def __copy__(self) -> Param:
        """Return a new Param with the same value, description and (shared) schema object."""
        return Param(self.value, self.description, schema=self.schema)

    @staticmethod
    def _warn_if_not_json(value):
        """Emit a ``RemovedInAirflow3Warning`` if ``value`` cannot be dumped with ``json.dumps``."""
        try:
            json.dumps(value)
        except Exception:
            warnings.warn(
                "The use of non-json-serializable params is deprecated and will be removed in "
                "a future release",
                RemovedInAirflow3Warning,
            )

    def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:
        """
        Run the validations and return the Param's final value.

        The candidate is ``value`` when one is passed, otherwise the value already stored on
        this Param. A passed value is first checked for JSON-serializability; if the check
        fails only a deprecation warning is emitted. On success the candidate is stored as
        the Param's new value.

        :param value: The value to be updated for the Param
        :param suppress_exception: To raise an exception or not when the validations fails.
            If true and validations fails, the return value would be None.
        :return: The validated value, or ``None`` when validation failed (or no value was
            available) and ``suppress_exception`` is true.
        :raises ParamValidationError: If no value is passed and none is stored, or if the
            value does not validate against the schema (unless ``suppress_exception`` is true).
        """
        import jsonschema
        from jsonschema import FormatChecker
        from jsonschema.exceptions import ValidationError

        if value is not NOTSET:
            self._warn_if_not_json(value)
        final_val = value if value is not NOTSET else self.value
        if isinstance(final_val, ArgNotSet):
            if suppress_exception:
                return None
            raise ParamValidationError("No value passed and Param has no default value")
        try:
            jsonschema.validate(final_val, self.schema, format_checker=FormatChecker())
        except ValidationError as err:
            if suppress_exception:
                return None
            raise ParamValidationError(err) from None
        # Only a successfully validated value replaces the stored one.
        self.value = final_val
        return final_val

    def dump(self) -> dict:
        """Dump the Param as a dictionary: a class-path marker plus all instance attributes."""
        out_dict = {self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"}
        out_dict.update(self.__dict__)
        return out_dict

    @property
    def has_value(self) -> bool:
        """Whether a value is stored, using the same "unset" rule as :meth:`resolve`."""
        return self.value is not NOTSET and not isinstance(self.value, ArgNotSet)

    def serialize(self) -> dict:
        """Return the value, description and schema as a plain dictionary."""
        return {"value": self.value, "description": self.description, "schema": self.schema}

    @staticmethod
    def deserialize(data: dict[str, Any], version: int) -> Param:
        """
        Rebuild a Param from the dictionary produced by :meth:`serialize`.

        :param data: Mapping with the ``value``, ``description`` and ``schema`` keys
        :param version: Version the data was serialized with
        :raises TypeError: If ``version`` is newer than this class's ``__version__``.
        """
        if version > Param.__version__:
            raise TypeError("serialized version > class version")

        return Param(default=data["value"], description=data["description"], schema=data["schema"])

126 

127 

class ParamsDict(MutableMapping[str, Any]):
    """
    Mapping that holds all params of a dag or task.

    Keys are strings and every stored value is a Param object; plain values are wrapped on
    the way in. Item access resolves (validates) the Param and returns its value. The class
    is meant as an implicit replacement for a plain params dictionary and ideally does not
    need to be used directly.
    """

    __version__: ClassVar[int] = 1
    __slots__ = ["__dict", "suppress_exception"]

    def __init__(self, dict_obj: dict | None = None, suppress_exception: bool = False):
        """
        :param dict_obj: A dict or dict like object to init ParamsDict
        :param suppress_exception: Flag to suppress value exceptions while initializing the ParamsDict
        """
        source = dict_obj or {}
        wrapped: dict[str, Param] = {
            name: entry if isinstance(entry, Param) else Param(entry) for name, entry in source.items()
        }
        self.__dict = wrapped
        self.suppress_exception = suppress_exception

    def __bool__(self) -> bool:
        return len(self.__dict) > 0

    def __eq__(self, other: Any) -> bool:
        # Comparison is done on the dumped (resolved) values, never on the Param objects.
        if not isinstance(other, (ParamsDict, dict)):
            return NotImplemented
        own_values = self.dump()
        other_values = other.dump() if isinstance(other, ParamsDict) else other
        return own_values == other_values

    def __copy__(self) -> ParamsDict:
        # Shallow copy: the new mapping refers to the same Param objects.
        return ParamsDict(dict_obj=self.__dict, suppress_exception=self.suppress_exception)

    def __deepcopy__(self, memo: dict[int, Any] | None) -> ParamsDict:
        cloned_params = copy.deepcopy(self.__dict, memo)
        return ParamsDict(cloned_params, self.suppress_exception)

    def __contains__(self, o: object) -> bool:
        return o in self.__dict

    def __len__(self) -> int:
        return len(self.__dict)

    def __delitem__(self, v: str) -> None:
        del self.__dict[v]

    def __iter__(self):
        return iter(self.__dict)

    def __repr__(self):
        dumped = self.dump()
        return repr(dumped)

    def __setitem__(self, key: str, value: Any) -> None:
        """
        Store ``value`` under ``key``, guaranteeing that the stored object is a Param.

        A Param is stored as is. A plain value for an existing key is pushed through that
        key's Param (and therefore validated); a plain value for a new key is wrapped in a
        fresh Param.

        :param key: A key which needs to be inserted or updated in the dict
        :param value: A value which needs to be set against the key. It could be of any
            type but will be converted and stored as a Param object eventually.
        """
        if isinstance(value, Param):
            self.__dict[key] = value
            return

        if key not in self.__dict:
            # Unknown key and a plain value: wrap it in a new Param object.
            self.__dict[key] = Param(value)
            return

        existing = self.__dict[key]
        try:
            existing.resolve(value=value, suppress_exception=self.suppress_exception)
        except ParamValidationError as ve:
            raise ParamValidationError(f"Invalid input for param {key}: {ve}") from None
        self.__dict[key] = existing

    def __getitem__(self, key: str) -> Any:
        """
        Look up ``key`` and return the resolved value of its Param.

        :param key: The key to fetch
        """
        return self.__dict[key].resolve(suppress_exception=self.suppress_exception)

    def get_param(self, key: str) -> Param:
        """Get the internal :class:`.Param` object for this key"""
        return self.__dict[key]

    def items(self):
        # View over the raw Param objects, not over resolved values.
        return ItemsView(self.__dict)

    def values(self):
        # View over the raw Param objects, not over resolved values.
        return ValuesView(self.__dict)

    def update(self, *args, **kwargs) -> None:
        # A lone ParamsDict argument contributes its Param objects rather than resolved values.
        lone_params_dict = not kwargs and len(args) == 1 and isinstance(args[0], ParamsDict)
        if lone_params_dict:
            args = (args[0].__dict,)
        super().update(*args, **kwargs)

    def dump(self) -> dict[str, Any]:
        """Dumps the ParamsDict object as a dictionary, while suppressing exceptions"""
        dumped = {}
        for name, param in self.items():
            dumped[name] = param.resolve(suppress_exception=True)
        return dumped

    def validate(self) -> dict[str, Any]:
        """Validates & returns all the Params object stored in the dictionary"""
        resolved = {}
        for name, param in self.items():
            try:
                resolved[name] = param.resolve(suppress_exception=self.suppress_exception)
            except ParamValidationError as ve:
                raise ParamValidationError(f"Invalid input for param {name}: {ve}") from None
        return resolved

    def serialize(self) -> dict[str, Any]:
        """Return the same dictionary as :meth:`dump`."""
        return self.dump()

    @staticmethod
    def deserialize(data: dict, version: int) -> ParamsDict:
        """
        Build a ParamsDict from serialized ``data``.

        :raises TypeError: If ``version`` is newer than this class's ``__version__``.
        """
        if version > ParamsDict.__version__:
            raise TypeError("serialized version > class version")

        return ParamsDict(data)

256 

257 

class DagParam(ResolveMixin):
    """DAG run parameter reference.

    This binds a simple Param object to a name within a DAG instance, so that it
    can be resolved during the runtime via the ``{{ context }}`` dictionary. The
    ideal use case of this class is to implicitly convert args passed to a
    method decorated by ``@dag``.

    It can be used to parameterize a DAG. You can overwrite its value by setting
    it on conf when you trigger your DagRun.

    This can also be used in templates by accessing ``{{ context.params }}``.

    **Example**:

        with DAG(...) as dag:
            EmailOperator(subject=dag.param('subject', 'Hi from Airflow!'))

    :param current_dag: Dag being used for parameter.
    :param name: key value which is used to set the parameter
    :param default: Default value used if no parameter was set.
    """

    def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
        # Register the default on the DAG's params only when one was actually given.
        if default is not NOTSET:
            current_dag.params[name] = default
        self._name = name
        self._default = default

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        """Return an empty tuple of references."""
        return ()

    def resolve(self, context: Context) -> Any:
        """Pull DagParam value from DagRun context. This method is run during ``op.execute()``.

        Lookup order: the DagRun ``conf`` entry for this name, then the default given at
        construction, then the entry in ``context["params"]``.

        :param context: Mapping that is read for the ``dag_run`` and ``params`` keys.
        :raises AirflowException: If none of the three sources provides a value.
        """
        with contextlib.suppress(KeyError):
            run_conf = context["dag_run"].conf
            # ``conf`` may be None (or empty); subscripting None would raise TypeError, which is
            # not suppressed here, so only look the name up in a non-empty conf.
            if run_conf:
                return run_conf[self._name]
        if self._default is not NOTSET:
            return self._default
        with contextlib.suppress(KeyError):
            return context["params"][self._name]
        raise AirflowException(f"No value could be resolved for parameter {self._name}")

299 

300 

def process_params(
    dag: DAG,
    task: Operator,
    dag_run: DagRun | None,
    *,
    suppress_exception: bool,
) -> dict[str, Any]:
    """
    Merge, validate params, and convert them into a simple dict.

    Sources are layered in increasing priority: the DAG's params, the task's params and,
    when the ``core.dag_run_conf_overrides_params`` option is enabled, the DagRun's conf.

    :param dag: DAG whose ``params`` form the base layer (a missing attribute is tolerated)
    :param task: Operator whose non-empty ``params`` are applied on top of the DAG's
    :param dag_run: Optional DagRun whose non-empty ``conf`` may override the merged params
    :param suppress_exception: Passed to the ParamsDict that performs the merge and validation
    :return: The dictionary produced by validating the merged params
    """
    from airflow.configuration import conf

    merged = ParamsDict(suppress_exception=suppress_exception)

    # An AttributeError raised while applying the DAG-level params is deliberately ignored.
    with contextlib.suppress(AttributeError):
        merged.update(dag.params)

    if task.params:
        merged.update(task.params)

    conf_overrides = conf.getboolean("core", "dag_run_conf_overrides_params")
    if not (conf_overrides and dag_run and dag_run.conf):
        return merged.validate()

    logger.debug("Updating task params (%s) with DagRun.conf (%s)", merged, dag_run.conf)
    merged.update(dag_run.conf)
    return merged.validate()