Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/execution_time/comms.py: 75%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

461 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18r""" 

19Communication protocol between the Supervisor and the task process 

20================================================================== 

21 

22* All communication is done over the subprocess's stdin in the form of a binary length-prefixed msgpack frame 

23 (4 byte, big-endian length, followed by the msgpack-encoded _RequestFrame.) Each side uses this same 

24 encoding 

25* Log Messages from the subprocess are sent over the dedicated logs socket (which is line-based JSON) 

26* No messages are sent to task process except in response to a request. (This is because the task process will 

27 be running user's code, so we can't read from stdin until we enter our code, such as when requesting an XCom 

28 value etc.) 

29* Every request returns a response, even if the frame is otherwise empty. 

30* Requests are written by the subprocess to fd0/stdin. This is making use of the fact that stdin is a 

31 bi-directional socket, and thus we can write to it and don't need a dedicated extra socket for sending 

32 requests. 

33 

34The reason this communication protocol exists, rather than the task process speaking directly to the Task 

35Execution API server is because: 

36 

371. To reduce the number of concurrent HTTP connections on the API server. 

38 

39 The supervisor already has to speak to that to heartbeat the running Task, so having the task speak to its 

40 parent process and having all API traffic go through that means that the number of HTTP connections is 

41 "halved". (Not every task will make API calls, so it's not always halved, but it is reduced.) 

42 

432. This means that the user Task code doesn't ever directly see the task identity JWT token. 

44 

45 This is a short lived token tied to one specific task instance try, so it being leaked/exfiltrated is not a 

46 large risk, but it's easy to not give it to the user code, so let's do that. 

47""" # noqa: D400, D205 

48 

49from __future__ import annotations 

50 

51import itertools 

52from collections.abc import Iterator 

53from datetime import datetime 

54from functools import cached_property 

55from pathlib import Path 

56from socket import socket 

57from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload 

58from uuid import UUID 

59 

60import attrs 

61import msgspec 

62import structlog 

63from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_serializer 

64 

65from airflow.sdk.api.datamodels._generated import ( 

66 AssetEventDagRunReference, 

67 AssetEventResponse, 

68 AssetEventsResponse, 

69 AssetResponse, 

70 BundleInfo, 

71 ConnectionResponse, 

72 DagRun, 

73 DagRunStateResponse, 

74 HITLDetailRequest, 

75 InactiveAssetsResponse, 

76 PrevSuccessfulDagRunResponse, 

77 TaskBreadcrumbsResponse, 

78 TaskInstance, 

79 TaskInstanceState, 

80 TaskStatesResponse, 

81 TIDeferredStatePayload, 

82 TIRescheduleStatePayload, 

83 TIRetryStatePayload, 

84 TIRunContext, 

85 TISkippedDownstreamTasksStatePayload, 

86 TISuccessStatePayload, 

87 TriggerDAGRunPayload, 

88 UpdateHITLDetailPayload, 

89 VariableResponse, 

90 XComResponse, 

91 XComSequenceIndexResponse, 

92 XComSequenceSliceResponse, 

93) 

94from airflow.sdk.exceptions import ErrorType 

95 

96try: 

97 from socket import recv_fds 

98except ImportError: 

99 # Available on Unix and Windows (so "everywhere") but let's be safe 

100 recv_fds = None # type: ignore[assignment] 

101 

102 

103if TYPE_CHECKING: 

104 from structlog.typing import FilteringBoundLogger as Logger 

105 

# Type of the pydantic message models a ``CommsDecoder`` sends (see ``CommsDecoder.send``).
SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
# Type of the pydantic message models a ``CommsDecoder`` decodes from response frames.
ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)

108 

109 

def _msgpack_enc_hook(obj: Any) -> Any:
    """
    Turn an object msgspec cannot encode by itself into one it can.

    Supported inputs are pendulum datetimes, filesystem paths and pydantic models.
    Anything else is rejected with ``NotImplementedError``.
    """
    import pendulum

    if isinstance(obj, pendulum.DateTime):
        # Rebuild the pendulum subclass as a plain stdlib datetime so that msgspec can
        # fall back on its native datetime encoding.
        return datetime(
            year=obj.year,
            month=obj.month,
            day=obj.day,
            hour=obj.hour,
            minute=obj.minute,
            second=obj.second,
            microsecond=obj.microsecond,
            tzinfo=obj.tzinfo,
        )

    if isinstance(obj, Path):
        return str(obj)

    if isinstance(obj, BaseModel):
        # Only explicitly-set fields are put on the wire.
        return obj.model_dump(exclude_unset=True)

    # Nothing matched: report the offending type.
    unsupported = type(obj)
    raise NotImplementedError(f"Objects of type {unsupported} are not supported")

126 

127 

def _new_encoder() -> msgspec.msgpack.Encoder:
    """Create a msgpack encoder that uses ``_msgpack_enc_hook`` for types msgspec cannot encode natively."""
    return msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook)

130 

131 

class _RequestFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True):
    # Frame wrapping one request; ``as_bytes`` produces the length-prefixed wire form.
    id: int
    """
    The request id, set by the sender.

    This is used to allow "pipeling" of requests and to be able to tie response to requests, which is
    particularly useful in the Triggerer where multiple async tasks can send a requests concurrently.
    """
    # Message payload as a plain dict (``CommsDecoder.send`` passes ``msg.model_dump()``), or None.
    body: dict[str, Any] | None

    # A single encoder created when the class is defined and stored on the class (ClassVar).
    req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder()

    def as_bytes(self) -> bytearray:
        """
        Encode this frame as msgpack, preceded by its length as a 4-byte big-endian integer.

        Raises ``OverflowError`` when the encoded frame is 2**32 bytes or larger, since the length
        would not fit in the 4-byte prefix.
        """
        # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing for inspiration
        buffer = bytearray(256)

        # Encode at offset 4 so the first four bytes stay free for the length prefix.
        self.req_encoder.encode_into(self, buffer, 4)

        # Size of the encoded frame: everything after the four reserved bytes.
        n = len(buffer) - 4
        if n >= 2**32:
            raise OverflowError(f"Cannot send messages larger than 4GiB {n=}")
        buffer[:4] = n.to_bytes(4, byteorder="big")

        return buffer

156 

157 

class _ResponseFrame(_RequestFrame, frozen=True):
    # Frame wrapping one response: the request frame fields plus an optional ``error`` payload.
    id: int
    """
    The id of the request this is a response to
    """
    # Unlike in ``_RequestFrame``, ``body`` defaults to None here.
    body: dict[str, Any] | None = None
    # When not None, ``CommsDecoder._from_frame`` raises ``AirflowRuntimeError`` built from it.
    error: dict[str, Any] | None = None

165 

166 

@attrs.define()
class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
    """Handle communication between the task in this process and the supervisor parent process."""

    log: Logger = attrs.field(repr=False, factory=structlog.get_logger)
    # Defaults to a socket object wrapping file descriptor 0 (stdin).
    socket: socket = attrs.field(factory=lambda: socket(fileno=0))

    resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field(
        factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False
    )

    # Source of request ids; every ``send`` takes the next value.
    id_counter: Iterator[int] = attrs.field(factory=itertools.count)

    # We could be "clever" here and set the default to this based type parameters and a custom
    # `__class_getitem__`, but that's a lot of code the one subclass we've got currently. So we'll just use a
    # "sort of wrong default"
    body_decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)

    # NOTE(review): the default validates error payloads against the whole ``ToTask`` union rather than
    # ``ErrorResponse`` alone. ``ErrorResponse`` is a member of that union; confirm this is intended.
    err_decoder: TypeAdapter[ErrorResponse] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)

    def send(self, msg: SendMsgType) -> ReceiveMsgType | None:
        """
        Send a request to the parent and block until the response is received.

        Returns the decoded response message, or None when the response frame has no body.
        Raises ``AirflowRuntimeError`` when the response frame carries an error, and ``EOFError``
        when the socket is closed before a full response arrives.
        """
        frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
        frame_bytes = frame.as_bytes()

        self.socket.sendall(frame_bytes)
        if isinstance(msg, ResendLoggingFD):
            if recv_fds is None:
                # NOTE(review): no response frame is read on this path; confirm the supervisor does not
                # send one when file-descriptor passing is unavailable.
                return None
            # We need special handling here! The server can't send us the fd number, as the number on the
            # supervisor will be different to in this process, so we have to mutate the message ourselves here.
            frame, fds = self._read_frame(maxfds=1)
            resp = self._from_frame(frame)
            if TYPE_CHECKING:
                assert isinstance(resp, SentFDs)
            resp.fds = fds
            # This is known to be a SentFDs response, but because this class is generic SentFDs might not
            # always be part of the declared return type union.
            return resp  # type: ignore[return-value]

        return self._get_response()

    async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None:
        """Send a request to the parent without blocking."""
        raise NotImplementedError

    @overload
    def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ...

    @overload
    def _read_frame(self, maxfds: int) -> tuple[_ResponseFrame, list[int]]: ...

    def _read_frame(self, maxfds: int | None = None) -> tuple[_ResponseFrame, list[int]] | _ResponseFrame:
        """
        Get a message from the parent.

        This will block until the message has been received.

        When ``maxfds`` is truthy, up to that many file descriptors are received alongside the length
        prefix and the result is a ``(frame, fds)`` tuple; otherwise only the frame is returned.
        Raises ``EOFError`` if the socket is closed before the length prefix or the frame is complete.
        """
        if self.socket:
            self.socket.setblocking(True)
        fds = None
        if maxfds:
            len_bytes, fds, _flags, _address = recv_fds(self.socket, 4, maxfds)
        else:
            len_bytes = self.socket.recv(4)

        if len_bytes == b"":
            raise EOFError("Request socket closed before length")

        # A stream socket may return fewer bytes than requested, so keep reading until the whole
        # 4-byte prefix has arrived. Decoding a short prefix would give a wrong frame length and
        # leave the stream out of sync.
        while len(len_bytes) < 4:
            more = self.socket.recv(4 - len(len_bytes))
            if not more:
                raise EOFError("Request socket closed before length")
            len_bytes += more

        length = int.from_bytes(len_bytes, byteorder="big")

        buffer = bytearray(length)
        mv = memoryview(buffer)

        # Fill the pre-sized buffer in place; each recv_into may deliver only part of the frame.
        pos = 0
        while pos < length:
            nread = self.socket.recv_into(mv[pos:])
            if nread == 0:
                raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})")
            pos += nread

        resp = self.resp_decoder.decode(mv)
        if maxfds:
            return resp, fds or []
        return resp

    def _from_frame(self, frame) -> ReceiveMsgType | None:
        """
        Turn a response frame into a message object.

        Raises ``AirflowRuntimeError`` when the frame has an error payload, returns None when the frame
        has no body, and otherwise validates the body with ``body_decoder``. Validation failures are
        logged and re-raised.
        """
        from airflow.sdk.exceptions import AirflowRuntimeError

        if frame.error is not None:
            err = self.err_decoder.validate_python(frame.error)
            raise AirflowRuntimeError(error=err)

        if frame.body is None:
            return None

        try:
            return self.body_decoder.validate_python(frame.body)
        except Exception:
            self.log.exception("Unable to decode message")
            raise

    def _get_response(self) -> ReceiveMsgType | None:
        """Read the next response frame from the socket and decode it."""
        frame = self._read_frame()
        return self._from_frame(frame)

272 

273 

class StartupDetails(BaseModel):
    """Details describing the task instance to run; a member of the ``ToTask`` union."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    ti: TaskInstance
    dag_rel_path: str
    bundle_info: BundleInfo
    start_date: datetime
    ti_context: TIRunContext
    sentry_integration: str
    # Discriminator tag for the ``ToTask`` union.
    type: Literal["StartupDetails"] = "StartupDetails"

284 

285 

class AssetResult(AssetResponse):
    """Asset details sent to the task process; an ``AssetResponse`` tagged for the ``ToTask`` union."""

    type: Literal["AssetResult"] = "AssetResult"

    @classmethod
    def from_asset_response(cls, asset_response: AssetResponse) -> AssetResult:
        """
        Build an AssetResult out of an AssetResponse.

        AssetResponse is autogenerated from the API schema, so it has to be converted to an
        AssetResult before it can travel between the Supervisor and the task process.
        """
        # Carry over only non-default values so no unnecessary data is sent.
        payload = asset_response.model_dump(exclude_defaults=True)
        # ``type`` is passed explicitly so it counts as a set field and survives a later dump
        # done with exclude_unset=True.
        return cls(**payload, type="AssetResult")

303 

304 

@attrs.define(kw_only=True)
class AssetEventSourceTaskInstance:
    """Task instance that an asset event originated from. Used in AssetEventResult."""

    dag_run: DagRun
    task_id: str
    map_index: int

    @property
    def dag_id(self) -> str:
        """Dag id, taken from the attached Dag run."""
        return self.dag_run.dag_id

    @property
    def run_id(self) -> str:
        """Run id, taken from the attached Dag run."""
        return self.dag_run.run_id

    def xcom_pull(self, *, key: str = "return_value", default: Any = None) -> Any:
        """Return the XCom value stored under ``key`` for this task instance, or ``default`` if it is None."""
        from airflow.sdk.execution_time.xcom import XCom

        pulled = XCom.get_value(ti_key=self, key=key)
        return default if pulled is None else pulled

327 

328 

def _fetch_dag_run(*, dag_id: str, run_id: str) -> DagRun:
    """Send a ``GetDagRun`` request through ``SUPERVISOR_COMMS`` and return the response."""
    # Imported here rather than at module level; presumably to avoid a circular import -- TODO confirm.
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    response = SUPERVISOR_COMMS.send(GetDagRun(dag_id=dag_id, run_id=run_id))
    if TYPE_CHECKING:
        # Type-checker-only narrowing; nothing is verified at runtime.
        assert isinstance(response, DagRunResult)
    return response

336 

337 

class AssetEventResult(AssetEventResponse):
    """Single asset event with lazy access to its source Dag run and task instance. Used in AssetEventsResult."""

    @classmethod
    def from_asset_event_response(cls, asset_event_response: AssetEventResponse) -> AssetEventResult:
        """Build an AssetEventResult from the non-default fields of an AssetEventResponse."""
        payload = asset_event_response.model_dump(exclude_defaults=True)
        return cls(**payload)

    @cached_property
    def source_dag_run(self) -> DagRun | None:
        """Dag run that produced the event, fetched on first access; None without both a source Dag id and run id."""
        if self.source_dag_id and self.source_run_id:
            return _fetch_dag_run(dag_id=self.source_dag_id, run_id=self.source_run_id)
        return None

    @cached_property
    def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
        """Task instance that produced the event, or None when the task id, map index or Dag run is missing."""
        task_id, map_index = self.source_task_id, self.source_map_index
        if task_id is None or map_index is None:
            return None
        run = self.source_dag_run
        if run is None:
            return None
        return AssetEventSourceTaskInstance(dag_run=run, task_id=task_id, map_index=map_index)

362 

363 

class AssetEventsResult(AssetEventsResponse):
    """Response to GetAssetEvent request."""

    type: Literal["AssetEventsResult"] = "AssetEventsResult"

    @classmethod
    def from_asset_events_response(cls, asset_events_response: AssetEventsResponse) -> AssetEventsResult:
        """
        Build an AssetEventsResult out of an AssetEventsResponse.

        AssetEventsResponse is autogenerated from the API schema, so it has to be converted to an
        AssetEventsResult before it can travel between the Supervisor and the task process.
        """
        # Carry over only non-default values so no unnecessary data is sent.
        payload = asset_events_response.model_dump(exclude_defaults=True)
        # ``type`` is passed explicitly so it counts as a set field and survives a later dump
        # done with exclude_unset=True.
        return cls(**payload, type="AssetEventsResult")

    def iter_asset_event_results(self) -> Iterator[AssetEventResult]:
        """Lazily convert every entry of ``asset_events`` into an AssetEventResult."""
        convert = AssetEventResult.from_asset_event_response
        return (convert(item) for item in self.asset_events)

387 

388 

class AssetEventDagRunReferenceResult(AssetEventDagRunReference):
    """AssetEventDagRunReference with lazy access to its source Dag run and task instance."""

    @classmethod
    def from_asset_event_dag_run_reference(
        cls,
        asset_event_dag_run_reference: AssetEventDagRunReference,
    ) -> AssetEventDagRunReferenceResult:
        """Build an instance from the non-default fields of an AssetEventDagRunReference."""
        payload = asset_event_dag_run_reference.model_dump(exclude_defaults=True)
        return cls(**payload)

    @cached_property
    def source_dag_run(self) -> DagRun | None:
        """Source Dag run, fetched on first access; None without both a source Dag id and run id."""
        if self.source_dag_id and self.source_run_id:
            return _fetch_dag_run(dag_id=self.source_dag_id, run_id=self.source_run_id)
        return None

    @cached_property
    def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
        """Source task instance, or None when the task id, map index or Dag run is missing."""
        task_id, map_index = self.source_task_id, self.source_map_index
        if task_id is None or map_index is None:
            return None
        run = self.source_dag_run
        if run is None:
            return None
        return AssetEventSourceTaskInstance(dag_run=run, task_id=task_id, map_index=map_index)

414 

415 

class InactiveAssetsResult(InactiveAssetsResponse):
    """Response of InactiveAssets requests."""

    type: Literal["InactiveAssetsResult"] = "InactiveAssetsResult"

    @classmethod
    def from_inactive_assets_response(
        cls, inactive_assets_response: InactiveAssetsResponse
    ) -> InactiveAssetsResult:
        """
        Build an InactiveAssetsResult out of an InactiveAssetsResponse.

        InactiveAssetsResponse is autogenerated from the API schema, so it has to be converted to an
        InactiveAssetsResult before it can travel between the Supervisor and the task process.
        """
        payload = inactive_assets_response.model_dump(exclude_defaults=True)
        return cls(**payload, type="InactiveAssetsResult")

432 

433 

class XComResult(XComResponse):
    """Response to ReadXCom request."""

    type: Literal["XComResult"] = "XComResult"

    @classmethod
    def from_xcom_response(cls, xcom_response: XComResponse) -> XComResult:
        """
        Build an XComResult out of an XComResponse.

        XComResponse is autogenerated from the API schema, so it has to be converted to an
        XComResult before it can travel between the Supervisor and the task process.
        """
        payload = xcom_response.model_dump(exclude_defaults=True)
        return cls(**payload, type="XComResult")

448 

449 

class XComCountResponse(BaseModel):
    """Response carrying a length value; a member of the ``ToTask`` union."""

    len: int
    # NOTE: the wire tag is "XComLengthResponse", which differs from the class name.
    type: Literal["XComLengthResponse"] = "XComLengthResponse"

453 

454 

class XComSequenceIndexResult(BaseModel):
    """Single JSON value wrapped with a discriminator tag for the ``ToTask`` union."""

    root: JsonValue
    type: Literal["XComSequenceIndexResult"] = "XComSequenceIndexResult"

    @classmethod
    def from_response(cls, response: XComSequenceIndexResponse) -> XComSequenceIndexResult:
        """Wrap the ``root`` value of an XComSequenceIndexResponse in a tagged result."""
        value = response.root
        return cls(root=value, type="XComSequenceIndexResult")

462 

463 

class XComSequenceSliceResult(BaseModel):
    """List of JSON values wrapped with a discriminator tag for the ``ToTask`` union."""

    root: list[JsonValue]
    type: Literal["XComSequenceSliceResult"] = "XComSequenceSliceResult"

    @classmethod
    def from_response(cls, response: XComSequenceSliceResponse) -> XComSequenceSliceResult:
        """Wrap the ``root`` list of an XComSequenceSliceResponse in a tagged result."""
        values = response.root
        return cls(root=values, type="XComSequenceSliceResult")

471 

472 

class ConnectionResult(ConnectionResponse):
    """ConnectionResponse tagged with a discriminator for the ``ToTask`` union."""

    type: Literal["ConnectionResult"] = "ConnectionResult"

    @classmethod
    def from_conn_response(cls, connection_response: ConnectionResponse) -> ConnectionResult:
        """
        Build a ConnectionResult out of a ConnectionResponse.

        ConnectionResponse is autogenerated from the API schema, so it has to be converted to a
        ConnectionResult before it can travel between the Supervisor and the task process.
        """
        # Carry over only non-default values (dumped under their aliases) so no unnecessary data is sent.
        payload = connection_response.model_dump(exclude_defaults=True, by_alias=True)
        # ``type`` is passed explicitly so it counts as a set field and survives a later dump
        # done with exclude_unset=True.
        return cls(**payload, type="ConnectionResult")

490 

491 

class VariableResult(VariableResponse):
    """VariableResponse tagged with a discriminator for the ``ToTask`` union."""

    type: Literal["VariableResult"] = "VariableResult"

    @classmethod
    def from_variable_response(cls, variable_response: VariableResponse) -> VariableResult:
        """
        Build a VariableResult out of a VariableResponse.

        VariableResponse is autogenerated from the API schema, so it has to be converted to a
        VariableResult before it can travel between the Supervisor and the task process.
        """
        payload = variable_response.model_dump(exclude_defaults=True)
        return cls(**payload, type="VariableResult")

504 

505 

class DagRunResult(DagRun):
    """DagRun tagged with a discriminator for the ``ToTask`` union."""

    type: Literal["DagRunResult"] = "DagRunResult"

    @classmethod
    def from_api_response(cls, dr_response: DagRun) -> DagRunResult:
        """
        Build the result class from an API response.

        The API response model is autogenerated from the API schema and has no discriminator field,
        which communication between the Supervisor and the task process needs, so it is converted here.
        """
        payload = dr_response.model_dump(exclude_defaults=True)
        return cls(**payload, type="DagRunResult")

519 

520 

class DagRunStateResult(DagRunStateResponse):
    """DagRunStateResponse tagged with a discriminator for the ``ToTask`` union."""

    type: Literal["DagRunStateResult"] = "DagRunStateResult"

    # TODO: Create a convert api_response to result classes so we don't need to do this
    # for all the classes above
    @classmethod
    def from_api_response(cls, dr_state_response: DagRunStateResponse) -> DagRunStateResult:
        """
        Build the result class from an API response.

        The API response model is autogenerated from the API schema and has no discriminator field,
        which communication between the Supervisor and the task process needs, so it is converted here.
        """
        payload = dr_state_response.model_dump(exclude_defaults=True)
        return cls(**payload, type="DagRunStateResult")

536 

537 

class PreviousDagRunResult(BaseModel):
    """Response containing previous Dag run information."""

    # None by default, i.e. when no Dag run is supplied.
    dag_run: DagRun | None = None
    # Discriminator tag for the ``ToTask`` union.
    type: Literal["PreviousDagRunResult"] = "PreviousDagRunResult"

543 

544 

class PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse):
    """PrevSuccessfulDagRunResponse tagged with a discriminator for the ``ToTask`` union."""

    type: Literal["PrevSuccessfulDagRunResult"] = "PrevSuccessfulDagRunResult"

    @classmethod
    def from_dagrun_response(cls, prev_dag_run: PrevSuccessfulDagRunResponse) -> PrevSuccessfulDagRunResult:
        """
        Build a result object out of a response object.

        PrevSuccessfulDagRunResponse is autogenerated from the API schema, so it has to be converted to a
        PrevSuccessfulDagRunResult before it can travel between the Supervisor and the task process.
        """
        payload = prev_dag_run.model_dump(exclude_defaults=True)
        return cls(**payload, type="PrevSuccessfulDagRunResult")

557 

558 

class TaskRescheduleStartDate(BaseModel):
    """Response containing the first reschedule date for a task instance."""

    # Required field, but may be None.
    start_date: AwareDatetime | None
    # Discriminator tag for the ``ToTask`` union.
    type: Literal["TaskRescheduleStartDate"] = "TaskRescheduleStartDate"

564 

565 

class TICount(BaseModel):
    """Response containing count of Task Instances matching certain filters."""

    count: int
    # Discriminator tag for the ``ToTask`` union.
    type: Literal["TICount"] = "TICount"

571 

572 

class TaskStatesResult(TaskStatesResponse):
    """TaskStatesResponse tagged with a discriminator for the ``ToTask`` union."""

    type: Literal["TaskStatesResult"] = "TaskStatesResult"

    @classmethod
    def from_api_response(cls, task_states_response: TaskStatesResponse) -> TaskStatesResult:
        """
        Build the result class from an API response.

        The API response model is autogenerated from the API schema and has no discriminator field,
        which communication between the Supervisor and the task process needs, so it is converted here.
        """
        payload = task_states_response.model_dump(exclude_defaults=True)
        return cls(**payload, type="TaskStatesResult")

586 

587 

class TaskBreadcrumbsResult(TaskBreadcrumbsResponse):
    """TaskBreadcrumbsResponse tagged with a discriminator for the ``ToTask`` union."""

    type: Literal["TaskBreadcrumbsResult"] = "TaskBreadcrumbsResult"

    @classmethod
    def from_api_response(cls, response: TaskBreadcrumbsResponse) -> TaskBreadcrumbsResult:
        """
        Build the result class from an API response.

        The API response model is autogenerated from the API schema and has no discriminator field,
        which communication between the Supervisor and the task process needs, so it is converted here.
        """
        payload = response.model_dump(exclude_defaults=True)
        return cls(**payload, type="TaskBreadcrumbsResult")

601 

602 

class DRCount(BaseModel):
    """Response containing count of Dag Runs matching certain filters."""

    count: int
    # Discriminator tag for the ``ToTask`` union.
    type: Literal["DRCount"] = "DRCount"

608 

609 

class ErrorResponse(BaseModel):
    """Error description; ``CommsDecoder._from_frame`` wraps the decoded error payload in ``AirflowRuntimeError``."""

    # Error category; falls back to the generic error type when not given.
    error: ErrorType = ErrorType.GENERIC_ERROR
    # Optional free-form details.
    detail: dict | None = None
    # Discriminator tag for the ``ToTask`` union.
    type: Literal["ErrorResponse"] = "ErrorResponse"

614 

615 

class OKResponse(BaseModel):
    """Response carrying a single boolean ``ok`` flag; a member of the ``ToTask`` union."""

    ok: bool
    # Discriminator tag for the ``ToTask`` union.
    type: Literal["OKResponse"] = "OKResponse"

619 

620 

class SentFDs(BaseModel):
    """Response to ``ResendLoggingFD``; ``CommsDecoder.send`` overwrites ``fds`` with the descriptors it received."""

    # Discriminator tag for the ``ToTask`` union.
    type: Literal["SentFDs"] = "SentFDs"
    # File descriptor numbers; replaced on the receiving side because the sender's numbers differ.
    fds: list[int]

624 

625 

class CreateHITLDetailPayload(HITLDetailRequest):
    """Add the input request part of a Human-in-the-loop response."""

    # Discriminator tag; this model is listed in both the ``ToTask`` and ``ToSupervisor`` unions.
    type: Literal["CreateHITLDetailPayload"] = "CreateHITLDetailPayload"

630 

631 

class HITLDetailRequestResult(HITLDetailRequest):
    """Response to CreateHITLDetailPayload request."""

    type: Literal["HITLDetailRequestResult"] = "HITLDetailRequestResult"

    @classmethod
    def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequestResult:
        """
        Build a HITLDetailRequestResult out of a HITLDetailRequest (API response).

        HITLDetailRequest is the API response model. It is converted to HITLDetailRequestResult for
        communication between the Supervisor and task process, which adds the discriminator field
        that tagged-union deserialization requires.
        """
        payload = hitl_request.model_dump(exclude_defaults=True)
        return cls(**payload, type="HITLDetailRequestResult")

647 

648 

# Tagged union of the message models that ``CommsDecoder`` decodes by default. Pydantic picks the
# concrete model from the value of each message's ``type`` field.
ToTask = Annotated[
    AssetResult
    | AssetEventsResult
    | ConnectionResult
    | DagRunResult
    | DagRunStateResult
    | DRCount
    | ErrorResponse
    | PrevSuccessfulDagRunResult
    | SentFDs
    | StartupDetails
    | TaskRescheduleStartDate
    | TICount
    | TaskBreadcrumbsResult
    | TaskStatesResult
    | VariableResult
    | XComCountResponse
    | XComResult
    | XComSequenceIndexResult
    | XComSequenceSliceResult
    | InactiveAssetsResult
    | CreateHITLDetailPayload
    | HITLDetailRequestResult
    | OKResponse
    | PreviousDagRunResult,
    Field(discriminator="type"),
]

676 

677 

class TaskState(BaseModel):
    """
    Update a task's state.

    If a process exits without sending one of these the state will be derived from the exit code:
    - 0 = SUCCESS
    - anything else = FAILED
    """

    # Only these three states are accepted by this message.
    state: Literal[
        TaskInstanceState.FAILED,
        TaskInstanceState.SKIPPED,
        TaskInstanceState.REMOVED,
    ]
    end_date: datetime | None = None
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["TaskState"] = "TaskState"
    rendered_map_index: str | None = None

695 

696 

class SucceedTask(TISuccessStatePayload):
    """Update a task's state to success. Includes task_outlets and outlet_events for registering asset events."""

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["SucceedTask"] = "SucceedTask"

701 

702 

class DeferTask(TIDeferredStatePayload):
    """Update a task instance state to deferred."""

    type: Literal["DeferTask"] = "DeferTask"

    @field_serializer("trigger_kwargs", "next_kwargs", check_fields=True)
    def _serde_kwarg_fields(self, val: str | dict[str, Any] | None, _info):
        """Serialize a kwargs dict with BaseSerialization unless it is already encoded or is not a dict."""
        from airflow.serialization.serialized_objects import BaseSerialization

        # Passed through untouched: None, an encrypted string, or a dict that is already in
        # the encoded {"__type", "__var"} form.
        passthrough = not isinstance(val, dict) or val.keys() == {"__type", "__var"}
        if passthrough:
            return val
        return BaseSerialization.serialize(val or {})

720 

721 

class RetryTask(TIRetryStatePayload):
    """Update a task instance state to up_for_retry."""

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["RetryTask"] = "RetryTask"

726 

727 

class RescheduleTask(TIRescheduleStatePayload):
    """Update a task instance state to reschedule/up_for_reschedule."""

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["RescheduleTask"] = "RescheduleTask"

732 

733 

class SkipDownstreamTasks(TISkippedDownstreamTasksStatePayload):
    """Update state of downstream tasks within a task instance to 'skipped', while updating current task to success state."""

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["SkipDownstreamTasks"] = "SkipDownstreamTasks"

738 

739 

class GetXCom(BaseModel):
    """Request for an XCom value, addressed by key, Dag id, run id, task id and optional map index."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    include_prior_dates: bool = False
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetXCom"] = "GetXCom"

748 

749 

class GetXComCount(BaseModel):
    """Get the number of (mapped) XCom values available."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    # NOTE: the wire tag is "GetNumberXComs", which differs from the class name.
    type: Literal["GetNumberXComs"] = "GetNumberXComs"

758 

759 

class GetXComSequenceItem(BaseModel):
    """Request for one item of an XCom sequence, selected by ``offset``."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    # Position of the requested item; presumably an index into the sequence -- TODO confirm.
    offset: int
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetXComSequenceItem"] = "GetXComSequenceItem"

767 

768 

class GetXComSequenceSlice(BaseModel):
    """Request for a slice of an XCom sequence, described by ``start``, ``stop`` and ``step``."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    # The three slice fields are required but may each be None.
    start: int | None
    stop: int | None
    step: int | None
    include_prior_dates: bool = False
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetXComSequenceSlice"] = "GetXComSequenceSlice"

779 

780 

class SetXCom(BaseModel):
    """Request to store a JSON-compatible XCom value under a key for a given Dag, run and task."""

    key: str
    value: JsonValue
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    mapped_length: int | None = None
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["SetXCom"] = "SetXCom"

790 

791 

class DeleteXCom(BaseModel):
    """Request to delete an XCom, addressed by key, Dag id, run id, task id and optional map index."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["DeleteXCom"] = "DeleteXCom"

799 

800 

class GetConnection(BaseModel):
    """Request for the connection identified by ``conn_id``."""

    conn_id: str
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetConnection"] = "GetConnection"

804 

805 

class GetVariable(BaseModel):
    """Request for the variable stored under ``key``."""

    key: str
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetVariable"] = "GetVariable"

809 

810 

class PutVariable(BaseModel):
    """Request to store a variable under ``key``."""

    key: str
    # Both fields are required but may be None.
    value: str | None
    description: str | None
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["PutVariable"] = "PutVariable"

816 

817 

class DeleteVariable(BaseModel):
    """Request to delete the variable stored under ``key``."""

    key: str
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["DeleteVariable"] = "DeleteVariable"

821 

822 

class ResendLoggingFD(BaseModel):
    """Request handled specially by ``CommsDecoder.send``: its reply is read together with a file descriptor."""

    # Discriminator tag for the ``ToSupervisor`` union; the message has no other fields.
    type: Literal["ResendLoggingFD"] = "ResendLoggingFD"

825 

826 

class SetRenderedFields(BaseModel):
    """Payload for setting RTIF for a task instance."""

    # We are using a BaseModel here compared to server using RootModel because we
    # have a discriminator running with "type", and RootModel doesn't support type

    # Mapping of field name to its JSON-compatible value.
    rendered_fields: dict[str, JsonValue]
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["SetRenderedFields"] = "SetRenderedFields"

835 

836 

class SetRenderedMapIndex(BaseModel):
    """Payload for setting rendered_map_index for a task instance."""

    rendered_map_index: str
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["SetRenderedMapIndex"] = "SetRenderedMapIndex"

842 

843 

class TriggerDagRun(TriggerDAGRunPayload):
    """TriggerDAGRunPayload extended with the target ``dag_id`` and ``run_id`` and a discriminator tag."""

    dag_id: str
    run_id: Annotated[str, Field(title="Dag Run Id")]
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["TriggerDagRun"] = "TriggerDagRun"

848 

849 

class GetDagRun(BaseModel):
    """Request for a Dag run identified by Dag id and run id (sent by ``_fetch_dag_run`` in this module)."""

    dag_id: str
    run_id: str
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetDagRun"] = "GetDagRun"

854 

855 

class GetDagRunState(BaseModel):
    """Request for the state of a Dag run identified by Dag id and run id."""

    dag_id: str
    run_id: str
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetDagRunState"] = "GetDagRunState"

860 

861 

class GetPreviousDagRun(BaseModel):
    """Request for a previous Dag run of ``dag_id`` relative to ``logical_date``."""

    dag_id: str
    logical_date: AwareDatetime
    # Optional state filter -- presumably restricts which previous run qualifies; TODO confirm.
    state: str | None = None
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetPreviousDagRun"] = "GetPreviousDagRun"

867 

868 

class GetAssetByName(BaseModel):
    """Request for an asset looked up by ``name``."""

    name: str
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetAssetByName"] = "GetAssetByName"

872 

873 

class GetAssetByUri(BaseModel):
    """Request for an asset looked up by ``uri``."""

    uri: str
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetAssetByUri"] = "GetAssetByUri"

877 

878 

class GetAssetEventByAsset(BaseModel):
    """Request for asset events of an asset given by name and/or URI, with optional filters."""

    # Both identifiers are required fields but may be None.
    name: str | None
    uri: str | None
    # Optional time bounds, result limit and sort direction -- exact semantics are applied by the
    # receiver; TODO confirm.
    after: AwareDatetime | None = None
    before: AwareDatetime | None = None
    limit: int | None = None
    ascending: bool = True
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetAssetEventByAsset"] = "GetAssetEventByAsset"

887 

888 

class GetAssetEventByAssetAlias(BaseModel):
    """Request for asset events of an asset alias given by ``alias_name``, with optional filters."""

    alias_name: str
    # Optional time bounds, result limit and sort direction -- exact semantics are applied by the
    # receiver; TODO confirm.
    after: AwareDatetime | None = None
    before: AwareDatetime | None = None
    limit: int | None = None
    ascending: bool = True
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetAssetEventByAssetAlias"] = "GetAssetEventByAssetAlias"

896 

897 

class ValidateInletsAndOutlets(BaseModel):
    """Request concerning the inlets and outlets of the task instance ``ti_id`` (presumably validation -- TODO confirm)."""

    ti_id: UUID
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["ValidateInletsAndOutlets"] = "ValidateInletsAndOutlets"

901 

902 

class GetPrevSuccessfulDagRun(BaseModel):
    """Request for the previous successful Dag run, keyed by the task instance id ``ti_id``."""

    ti_id: UUID
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"

906 

907 

class GetTaskRescheduleStartDate(BaseModel):
    """Request for the reschedule start date of task instance ``ti_id`` for a given try number."""

    ti_id: UUID
    # Defaults to the first try.
    try_number: int = 1
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetTaskRescheduleStartDate"] = "GetTaskRescheduleStartDate"

912 

913 

class GetTICount(BaseModel):
    """Request for a count of task instances in ``dag_id``; the remaining fields are optional filters."""

    dag_id: str
    map_index: int | None = None
    task_ids: list[str] | None = None
    task_group_id: str | None = None
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    states: list[str] | None = None
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetTICount"] = "GetTICount"

923 

924 

class GetTaskStates(BaseModel):
    """Request for task states in ``dag_id``; the remaining fields are optional filters."""

    dag_id: str
    map_index: int | None = None
    task_ids: list[str] | None = None
    task_group_id: str | None = None
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetTaskStates"] = "GetTaskStates"

933 

934 

class GetTaskBreadcrumbs(BaseModel):
    """Request for task breadcrumbs of the Dag run identified by Dag id and run id."""

    dag_id: str
    run_id: str
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetTaskBreadcrumbs"] = "GetTaskBreadcrumbs"

939 

940 

class GetDRCount(BaseModel):
    """Request for a count of Dag runs of ``dag_id``; the remaining fields are optional filters."""

    dag_id: str
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    states: list[str] | None = None
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetDRCount"] = "GetDRCount"

947 

948 

class GetHITLDetailResponse(BaseModel):
    """Get the response content part of a Human-in-the-loop response."""

    # Task instance the Human-in-the-loop detail belongs to.
    ti_id: UUID
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetHITLDetailResponse"] = "GetHITLDetailResponse"

954 

955 

class UpdateHITLDetail(UpdateHITLDetailPayload):
    """Update the response content part of an existing Human-in-the-loop response."""

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["UpdateHITLDetail"] = "UpdateHITLDetail"

960 

961 

class MaskSecret(BaseModel):
    """Add a new value to be redacted in task logs."""

    # This is needed since calls to `mask_secret` in the Task process will otherwise only add the mask value
    # to the child process, but the redaction happens in the parent.
    # We cannot use `string | Iterable | dict here` (would be more intuitive) because bug in Pydantic
    # https://github.com/pydantic/pydantic/issues/9541 turns iterable into a ValidatorIterator
    value: JsonValue
    # Optional name accompanying the value.
    name: str | None = None
    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["MaskSecret"] = "MaskSecret"

972 

973 

# Tagged union of the request message models defined above. Pydantic picks the concrete model from
# the value of each message's ``type`` field.
ToSupervisor = Annotated[
    DeferTask
    | DeleteXCom
    | GetAssetByName
    | GetAssetByUri
    | GetAssetEventByAsset
    | GetAssetEventByAssetAlias
    | GetConnection
    | GetDagRun
    | GetDagRunState
    | GetDRCount
    | GetPrevSuccessfulDagRun
    | GetPreviousDagRun
    | GetTaskRescheduleStartDate
    | GetTICount
    | GetTaskBreadcrumbs
    | GetTaskStates
    | GetVariable
    | GetXCom
    | GetXComCount
    | GetXComSequenceItem
    | GetXComSequenceSlice
    | PutVariable
    | RescheduleTask
    | RetryTask
    | SetRenderedFields
    | SetRenderedMapIndex
    | SetXCom
    | SkipDownstreamTasks
    | SucceedTask
    | ValidateInletsAndOutlets
    | TaskState
    | TriggerDagRun
    | DeleteVariable
    | ResendLoggingFD
    | CreateHITLDetailPayload
    | UpdateHITLDetail
    | GetHITLDetailResponse
    | MaskSecret,
    Field(discriminator="type"),
]