Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/bases/xcom.py: 49%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

70 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20import collections 

21from typing import Any, Protocol 

22 

23import structlog 

24 

25from airflow.sdk.execution_time.comms import ( 

26 DeleteXCom, 

27 GetXCom, 

28 GetXComSequenceSlice, 

29 SetXCom, 

30 XComResult, 

31 XComSequenceSliceResult, 

32) 

33 

34# Lightweight wrapper for XCom values 

35_XComValueWrapper = collections.namedtuple("_XComValueWrapper", "value") 

36 

37log = structlog.get_logger(logger_name="task") 

38 

39 

40class TIKeyProtocol(Protocol): 

41 dag_id: str 

42 task_id: str 

43 run_id: str 

44 map_index: int 

45 

46 

47class BaseXCom: 

48 """BaseXcom is an interface now to interact with XCom backends.""" 

49 

50 XCOM_RETURN_KEY = "return_value" 

51 

52 @classmethod 

53 def set( 

54 cls, 

55 key: str, 

56 value: Any, 

57 *, 

58 dag_id: str, 

59 task_id: str, 

60 run_id: str, 

61 map_index: int = -1, 

62 _mapped_length: int | None = None, 

63 ) -> None: 

64 """ 

65 Store an XCom value. 

66 

67 :param key: Key to store the XCom. 

68 :param value: XCom value to store. 

69 :param dag_id: Dag ID. 

70 :param task_id: Task ID. 

71 :param run_id: Dag run ID for the task. 

72 :param map_index: Optional map index to assign XCom for a mapped task. 

73 The default is ``-1`` (set for a non-mapped task). 

74 """ 

75 from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS 

76 

77 value = cls.serialize_value( 

78 value=value, 

79 key=key, 

80 task_id=task_id, 

81 dag_id=dag_id, 

82 run_id=run_id, 

83 map_index=map_index, 

84 ) 

85 

86 SUPERVISOR_COMMS.send( 

87 SetXCom( 

88 key=key, 

89 value=value, 

90 dag_id=dag_id, 

91 task_id=task_id, 

92 run_id=run_id, 

93 map_index=map_index, 

94 mapped_length=_mapped_length, 

95 ), 

96 ) 

97 

98 @classmethod 

99 def _set_xcom_in_db( 

100 cls, 

101 key: str, 

102 value: Any, 

103 *, 

104 dag_id: str, 

105 task_id: str, 

106 run_id: str, 

107 map_index: int = -1, 

108 ) -> None: 

109 """ 

110 Store an XCom value directly in the metadata database. 

111 

112 :param key: Key to store the XCom. 

113 :param value: XCom value to store. 

114 :param dag_id: Dag ID. 

115 :param task_id: Task ID. 

116 :param run_id: Dag run ID for the task. 

117 :param map_index: Optional map index to assign XCom for a mapped task. 

118 The default is ``-1`` (set for a non-mapped task). 

119 """ 

120 from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS 

121 

122 SUPERVISOR_COMMS.send( 

123 SetXCom( 

124 key=key, 

125 value=value, 

126 dag_id=dag_id, 

127 task_id=task_id, 

128 run_id=run_id, 

129 map_index=map_index, 

130 ), 

131 ) 

132 

133 @classmethod 

134 def get_value( 

135 cls, 

136 *, 

137 ti_key: TIKeyProtocol, 

138 key: str, 

139 ) -> Any: 

140 """ 

141 Retrieve an XCom value for a task instance. 

142 

143 This method returns "full" XCom values (i.e. uses ``deserialize_value`` 

144 from the XCom backend). 

145 

146 If there are no results, *None* is returned. If multiple XCom entries 

147 match the criteria, an arbitrary one is returned. 

148 

149 :param ti_key: The TaskInstanceKey to look up the XCom for. 

150 :param key: A key for the XCom. If provided, only XCom with matching 

151 keys will be returned. Pass *None* (default) to remove the filter. 

152 """ 

153 return cls.get_one( 

154 key=key, 

155 task_id=ti_key.task_id, 

156 dag_id=ti_key.dag_id, 

157 run_id=ti_key.run_id, 

158 map_index=ti_key.map_index, 

159 ) 

160 

161 @classmethod 

162 def _get_xcom_db_ref( 

163 cls, 

164 *, 

165 key: str, 

166 dag_id: str, 

167 task_id: str, 

168 run_id: str, 

169 map_index: int | None = None, 

170 ) -> XComResult: 

171 """ 

172 Retrieve an XCom value, optionally meeting certain criteria. 

173 

174 This method returns "full" XCom values (i.e. uses ``deserialize_value`` 

175 from the XCom backend). 

176 

177 If there are no results, *None* is returned. If multiple XCom entries 

178 match the criteria, an arbitrary one is returned. 

179 

180 .. seealso:: ``get_value()`` is a convenience function if you already 

181 have a structured TaskInstance or TaskInstanceKey object available. 

182 

183 :param run_id: Dag run ID for the task. 

184 :param dag_id: Only pull XCom from this Dag. Pass *None* (default) to 

185 remove the filter. 

186 :param task_id: Only XCom from task with matching ID will be pulled. 

187 Pass *None* (default) to remove the filter. 

188 :param map_index: Only XCom from task with matching ID will be pulled. 

189 Pass *None* (default) to remove the filter. 

190 :param key: A key for the XCom. If provided, only XCom with matching 

191 keys will be returned. Pass *None* (default) to remove the filter. 

192 """ 

193 from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS 

194 

195 msg = SUPERVISOR_COMMS.send( 

196 GetXCom( 

197 key=key, 

198 dag_id=dag_id, 

199 task_id=task_id, 

200 run_id=run_id, 

201 map_index=map_index, 

202 ), 

203 ) 

204 

205 if not isinstance(msg, XComResult): 

206 raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") 

207 

208 return msg 

209 

210 @classmethod 

211 def get_one( 

212 cls, 

213 *, 

214 key: str, 

215 dag_id: str, 

216 task_id: str, 

217 run_id: str, 

218 map_index: int | None = None, 

219 include_prior_dates: bool = False, 

220 ) -> Any | None: 

221 """ 

222 Retrieve an XCom value, optionally meeting certain criteria. 

223 

224 This method returns "full" XCom values (i.e. uses ``deserialize_value`` 

225 from the XCom backend). 

226 

227 If there are no results, *None* is returned. If multiple XCom entries 

228 match the criteria, an arbitrary one is returned. 

229 

230 .. seealso:: ``get_value()`` is a convenience function if you already 

231 have a structured TaskInstance or TaskInstanceKey object available. 

232 

233 :param run_id: Dag run ID for the task. 

234 :param dag_id: Only pull XCom from this Dag. Pass *None* (default) to 

235 remove the filter. 

236 :param task_id: Only XCom from task with matching ID will be pulled. 

237 Pass *None* (default) to remove the filter. 

238 :param map_index: Only XCom from task with matching ID will be pulled. 

239 Pass *None* (default) to remove the filter. 

240 :param key: A key for the XCom. If provided, only XCom with matching 

241 keys will be returned. Pass *None* (default) to remove the filter. 

242 :param include_prior_dates: If *False* (default), only XCom from the 

243 specified Dag run is returned. If *True*, the latest matching XCom is 

244 returned regardless of the run it belongs to. 

245 """ 

246 from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS 

247 

248 msg = SUPERVISOR_COMMS.send( 

249 GetXCom( 

250 key=key, 

251 dag_id=dag_id, 

252 task_id=task_id, 

253 run_id=run_id, 

254 map_index=map_index, 

255 include_prior_dates=include_prior_dates, 

256 ), 

257 ) 

258 

259 if not isinstance(msg, XComResult): 

260 raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") 

261 

262 if msg.value is not None: 

263 return cls.deserialize_value(msg) 

264 log.warning( 

265 "No XCom value found; defaulting to None.", 

266 key=key, 

267 dag_id=dag_id, 

268 task_id=task_id, 

269 run_id=run_id, 

270 map_index=map_index, 

271 ) 

272 return None 

273 

274 @classmethod 

275 def get_all( 

276 cls, 

277 *, 

278 key: str, 

279 dag_id: str, 

280 task_id: str, 

281 run_id: str, 

282 include_prior_dates: bool = False, 

283 ) -> Any: 

284 """ 

285 Retrieve all XCom values for a task, typically from all map indexes. 

286 

287 XComSequenceSliceResult can never have *None* in it, it returns an empty list 

288 if no values were found. 

289 

290 This is particularly useful for getting all XCom values from all map 

291 indexes of a mapped task at once. 

292 

293 :param key: A key for the XCom. Only XComs with this key will be returned. 

294 :param run_id: Dag run ID for the task. 

295 :param dag_id: Dag ID to pull XComs from. 

296 :param task_id: Task ID to pull XComs from. 

297 :param include_prior_dates: If *False* (default), only XComs from the 

298 specified Dag run are returned. If *True*, the latest matching XComs are 

299 returned regardless of the run they belong to. 

300 :return: List of all XCom values if found. 

301 """ 

302 from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS 

303 

304 msg = SUPERVISOR_COMMS.send( 

305 msg=GetXComSequenceSlice( 

306 key=key, 

307 dag_id=dag_id, 

308 task_id=task_id, 

309 run_id=run_id, 

310 start=None, 

311 stop=None, 

312 step=None, 

313 include_prior_dates=include_prior_dates, 

314 ), 

315 ) 

316 

317 if not isinstance(msg, XComSequenceSliceResult): 

318 raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}") 

319 

320 if not msg.root: 

321 return None 

322 

323 return [cls.deserialize_value(_XComValueWrapper(value)) for value in msg.root] 

324 

325 @staticmethod 

326 def serialize_value( 

327 value: Any, 

328 *, 

329 key: str | None = None, 

330 task_id: str | None = None, 

331 dag_id: str | None = None, 

332 run_id: str | None = None, 

333 map_index: int | None = None, 

334 ) -> str: 

335 """Serialize XCom value to JSON str.""" 

336 from airflow.serialization.serde import serialize 

337 

338 # return back the value for BaseXCom, custom backends will implement this 

339 return serialize(value) # type: ignore[return-value] 

340 

341 @staticmethod 

342 def deserialize_value(result) -> Any: 

343 """Deserialize XCom value from str objects.""" 

344 from airflow.serialization.serde import deserialize 

345 

346 return deserialize(result.value) 

347 

348 @classmethod 

349 def purge(cls, xcom: XComResult, *args) -> None: 

350 """Purge an XCom entry from underlying storage implementations.""" 

351 pass 

352 

353 @classmethod 

354 def delete( 

355 cls, 

356 key: str, 

357 task_id: str, 

358 dag_id: str, 

359 run_id: str, 

360 map_index: int | None = None, 

361 ) -> None: 

362 """Delete an Xcom entry, for custom xcom backends, it gets the path associated with the data on the backend and purges it.""" 

363 from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS 

364 

365 xcom_result = cls._get_xcom_db_ref( 

366 key=key, 

367 dag_id=dag_id, 

368 task_id=task_id, 

369 run_id=run_id, 

370 map_index=map_index, 

371 ) 

372 cls.purge(xcom_result) 

373 SUPERVISOR_COMMS.send( 

374 DeleteXCom( 

375 key=key, 

376 dag_id=dag_id, 

377 task_id=task_id, 

378 run_id=run_id, 

379 map_index=map_index, 

380 ), 

381 )