Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/execution_time/comms.py: 76%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

463 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18r""" 

19Communication protocol between the Supervisor and the task process 

20================================================================== 

21 

22* All communication is done over the subprocess's stdin in the form of a binary length-prefixed msgpack frame 

23 (4 byte, big-endian length, followed by the msgpack-encoded _RequestFrame.) Each side uses this same 

24 encoding 

25* Log Messages from the subprocess are sent over the dedicated logs socket (which is line-based JSON) 

26* No messages are sent to the task process except in response to a request. (This is because the task process will 

27 be running user's code, so we can't read from stdin until we enter our code, such as when requesting an XCom 

28 value etc.) 

29* Every request returns a response, even if the frame is otherwise empty. 

30* Requests are written by the subprocess to fd0/stdin. This is making use of the fact that stdin is a 

31 bi-directional socket, and thus we can write to it and don't need a dedicated extra socket for sending 

32 requests. 

33 

34The reason this communication protocol exists, rather than the task process speaking directly to the Task 

35Execution API server is because: 

36 

371. To reduce the number of concurrent HTTP connections on the API server. 

38 

39 The supervisor already has to speak to the API server to heartbeat the running Task, so having the task speak to its 

40 parent process and having all API traffic go through that means that the number of HTTP connections is 

41 "halved". (Not every task will make API calls, so it's not always halved, but it is reduced.) 

42 

432. This means that the user Task code doesn't ever directly see the task identity JWT token. 

44 

45 This is a short lived token tied to one specific task instance try, so it being leaked/exfiltrated is not a 

46 large risk, but it's easy to not give it to the user code, so let's do that. 

47""" # noqa: D400, D205 

48 

49from __future__ import annotations 

50 

51import itertools 

52from collections.abc import Iterator 

53from datetime import datetime 

54from functools import cached_property 

55from pathlib import Path 

56from socket import socket 

57from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload 

58from uuid import UUID 

59 

60import attrs 

61import msgspec 

62import structlog 

63from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter 

64 

65from airflow.sdk.api.datamodels._generated import ( 

66 AssetEventDagRunReference, 

67 AssetEventResponse, 

68 AssetEventsResponse, 

69 AssetResponse, 

70 BundleInfo, 

71 ConnectionResponse, 

72 DagRun, 

73 DagRunStateResponse, 

74 HITLDetailRequest, 

75 InactiveAssetsResponse, 

76 PreviousTIResponse, 

77 PrevSuccessfulDagRunResponse, 

78 TaskBreadcrumbsResponse, 

79 TaskInstance, 

80 TaskInstanceState, 

81 TaskStatesResponse, 

82 TIDeferredStatePayload, 

83 TIRescheduleStatePayload, 

84 TIRetryStatePayload, 

85 TIRunContext, 

86 TISkippedDownstreamTasksStatePayload, 

87 TISuccessStatePayload, 

88 TriggerDAGRunPayload, 

89 UpdateHITLDetailPayload, 

90 VariableResponse, 

91 XComResponse, 

92 XComSequenceIndexResponse, 

93 XComSequenceSliceResponse, 

94) 

95from airflow.sdk.exceptions import ErrorType 

96 

97try: 

98 from socket import recv_fds 

99except ImportError: 

100 # Available on Unix and Windows (so "everywhere") but lets be safe 

101 recv_fds = None # type: ignore[assignment] 

102 

103 

104if TYPE_CHECKING: 

105 from structlog.typing import FilteringBoundLogger as Logger 

106 

# Request model type a CommsDecoder sends to the supervisor.
SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
# Response model type a CommsDecoder decodes from the supervisor's replies.
ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)

109 

110 

def _msgpack_enc_hook(obj: Any) -> Any:
    """
    Convert an object msgspec cannot encode natively into one it can.

    :param obj: The object msgspec handed to the hook.
    :return: A replacement value msgspec knows how to encode.
    :raises NotImplementedError: If ``obj`` is of an unsupported type.
    """
    if isinstance(obj, datetime):
        # Rebuild datetime subclasses (e.g. pendulum.DateTime) as a raw datetime so that msgspec can use its
        # native encoding. ``fold`` is carried over as well: for tz-aware values it disambiguates wall-clock
        # times that occur twice (PEP 495), so dropping it could shift the encoded instant by the DST offset.
        return datetime(
            obj.year,
            obj.month,
            obj.day,
            obj.hour,
            obj.minute,
            obj.second,
            obj.microsecond,
            tzinfo=obj.tzinfo,
            fold=obj.fold,
        )
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, BaseModel):
        return obj.model_dump(exclude_unset=True)

    # Raise a NotImplementedError for other types
    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")

127 

128 

def _new_encoder() -> msgspec.msgpack.Encoder:
    """Build a msgpack encoder that delegates unsupported types to ``_msgpack_enc_hook``."""
    encoder = msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook)
    return encoder

131 

132 

class _RequestFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True):
    # One request on the wire. ``array_like=True`` encodes the fields positionally, so their order matters.

    id: int
    """
    The request id, chosen by the sender.

    It allows requests to be pipelined and responses to be matched to the request they answer. This is
    especially handy in the Triggerer, where several async tasks may send requests concurrently.
    """
    body: dict[str, Any] | None

    req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder()

    def as_bytes(self) -> bytearray:
        """Return this frame msgpack-encoded, preceded by its 4-byte big-endian length."""
        # Pattern from https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing
        # Reserve the first 4 bytes for the length; encode_into writes the payload after them and trims the
        # buffer to the end of the message.
        out = bytearray(256)
        self.req_encoder.encode_into(self, out, 4)

        n = len(out) - 4
        if not n < 2**32:
            raise OverflowError(f"Cannot send messages larger than 4GiB {n=}")
        out[0:4] = n.to_bytes(4, byteorder="big")
        return out

157 

158 

class _ResponseFrame(_RequestFrame, frozen=True):
    # A reply from the supervisor: optional ``body`` and ``error`` payloads, both defaulting to None.
    # Field order (id, body, error) is significant because the parent struct is array_like.

    id: int
    """
    Id of the request that this frame answers.
    """
    body: dict[str, Any] | None = None
    error: dict[str, Any] | None = None

166 

167 

@attrs.define()
class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
    """
    Handle communication between the task in this process and the supervisor parent process.

    Each request is written as a length-prefixed msgpack ``_RequestFrame`` to ``socket``; the call then
    blocks until the matching length-prefixed ``_ResponseFrame`` has been read back.
    """

    log: Logger = attrs.field(repr=False, factory=structlog.get_logger)
    # Defaults to fd 0 (stdin), which is a bi-directional socket shared with the supervisor.
    socket: socket = attrs.field(factory=lambda: socket(fileno=0))

    resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field(
        factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False
    )

    # Increasing request ids (0, 1, 2, ...) stamped on outgoing frames.
    id_counter: Iterator[int] = attrs.field(factory=itertools.count)

    # We could be "clever" here and set the default based on the type parameters and a custom
    # `__class_getitem__`, but that's a lot of code for the one subclass we've got currently. So we'll just
    # use a "sort of wrong default".
    body_decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)

    # Error payloads are ErrorResponse objects, so validate against that model directly. Going through the
    # whole ToTask union would additionally require the "type" discriminator to be present in the payload.
    err_decoder: TypeAdapter[ErrorResponse] = attrs.field(
        factory=lambda: TypeAdapter(ErrorResponse), repr=False
    )

    def send(self, msg: SendMsgType) -> ReceiveMsgType | None:
        """
        Send a request to the parent and block until the response is received.

        :param msg: The request message to send.
        :return: The decoded response body; None if the response carried no body, or if ``msg`` is a
            ``ResendLoggingFD`` and ``socket.recv_fds`` is unavailable on this platform.
        :raises AirflowRuntimeError: If the supervisor answered with an error payload.
        :raises EOFError: If the socket closed before a complete response was read.
        """
        frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
        self.socket.sendall(frame.as_bytes())

        if isinstance(msg, ResendLoggingFD):
            if recv_fds is None:
                return None
            # We need special handling here! The server can't send us the fd number, as the number on the
            # supervisor will be different to in this process, so we have to mutate the message ourselves here.
            resp_frame, fds = self._read_frame(maxfds=1)
            resp = self._from_frame(resp_frame)
            if TYPE_CHECKING:
                assert isinstance(resp, SentFDs)
            resp.fds = fds
            # Since we know this is an explicit SentFDs, and since this class is generic SentFDs might not
            # always be in the return type union
            return resp  # type: ignore[return-value]

        return self._get_response()

    async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None:
        """Send a request to the parent without blocking (not implemented in this class)."""
        raise NotImplementedError

    @overload
    def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ...

    @overload
    def _read_frame(self, maxfds: int) -> tuple[_ResponseFrame, list[int]]: ...

    def _read_frame(self, maxfds: int | None = None) -> tuple[_ResponseFrame, list[int]] | _ResponseFrame:
        """
        Get a message from the parent.

        This will block until the whole message has been received.

        :param maxfds: If truthy, also accept up to this many file descriptors alongside the first bytes of
            the frame (via ``socket.recv_fds``).
        :return: The decoded frame, or ``(frame, fds)`` when ``maxfds`` is given.
        :raises EOFError: If the socket is closed before a complete frame has been read.
        """
        if self.socket:
            self.socket.setblocking(True)
        fds = None
        if maxfds:
            len_bytes, fds, _flags, _address = recv_fds(self.socket, 4, maxfds)
        else:
            len_bytes = self.socket.recv(4)

        if len_bytes == b"":
            raise EOFError("Request socket closed before length")

        if len(len_bytes) < 4:
            # A stream socket may return fewer bytes than requested. Finish reading the 4-byte prefix before
            # decoding it; otherwise the frame length (and every following frame) would be misread.
            prefix = bytearray(4)
            got = len(len_bytes)
            prefix[:got] = len_bytes
            if not self._recv_exactly_into(memoryview(prefix)[got:]):
                raise EOFError("Request socket closed before length")
            len_bytes = bytes(prefix)

        length = int.from_bytes(len_bytes, byteorder="big")

        buffer = bytearray(length)
        mv = memoryview(buffer)
        if not self._recv_exactly_into(mv):
            raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})")

        resp = self.resp_decoder.decode(mv)
        if maxfds:
            return resp, fds or []
        return resp

    def _recv_exactly_into(self, view: memoryview) -> bool:
        """
        Fill ``view`` completely with bytes read from the socket.

        :return: True once ``view`` is full; False if the peer closed the socket first.
        """
        pos = 0
        total = len(view)
        while pos < total:
            nread = self.socket.recv_into(view[pos:])
            if nread == 0:
                return False
            pos += nread
        return True

    def _from_frame(self, frame: _ResponseFrame) -> ReceiveMsgType | None:
        """
        Turn a response frame into a message object.

        :raises AirflowRuntimeError: If the frame carries an error payload.
        """
        from airflow.sdk.exceptions import AirflowRuntimeError

        if frame.error is not None:
            err = self.err_decoder.validate_python(frame.error)
            raise AirflowRuntimeError(error=err)

        if frame.body is None:
            return None

        try:
            return self.body_decoder.validate_python(frame.body)
        except Exception:
            self.log.exception("Unable to decode message")
            raise

    def _get_response(self) -> ReceiveMsgType | None:
        """Block for the next response frame and decode it."""
        frame = self._read_frame()
        return self._from_frame(frame)

273 

274 

class StartupDetails(BaseModel):
    # Startup information for the task process: task instance, Dag file location, bundle and run context.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    ti: TaskInstance
    dag_rel_path: str
    bundle_info: BundleInfo
    start_date: datetime
    ti_context: TIRunContext
    sentry_integration: str
    # Discriminator tag for the ToTask union.
    type: Literal["StartupDetails"] = "StartupDetails"

285 

286 

class AssetResult(AssetResponse):
    """Result carrying an asset, converted from the API's ``AssetResponse``."""

    type: Literal["AssetResult"] = "AssetResult"

    @classmethod
    def from_asset_response(cls, asset_response: AssetResponse) -> AssetResult:
        """
        Get AssetResult from AssetResponse.

        AssetResponse is autogenerated from the API schema, so we need to convert it to AssetResult
        for communication between the Supervisor and the task process.
        """
        # Exclude defaults to avoid sending unnecessary data.
        # Pass the type as AssetResult explicitly so we can then call model_dump_json with exclude_unset=True
        # to avoid sending unset fields (which are defaults in our case).
        return cls(**asset_response.model_dump(exclude_defaults=True), type="AssetResult")

304 

305 

@attrs.define(kw_only=True)
class AssetEventSourceTaskInstance:
    """Task instance an asset event came from; returned by the asset event result classes."""

    dag_run: DagRun
    task_id: str
    map_index: int

    @property
    def dag_id(self) -> str:
        """Dag id of the owning Dag run."""
        return self.dag_run.dag_id

    @property
    def run_id(self) -> str:
        """Run id of the owning Dag run."""
        return self.dag_run.run_id

    def xcom_pull(self, *, key: str = "return_value", default: Any = None) -> Any:
        """Look up XCom ``key`` for this task instance via ``XCom.get_value``; ``default`` if the value is None."""
        from airflow.sdk.execution_time.xcom import XCom

        found = XCom.get_value(ti_key=self, key=key)
        return default if found is None else found

328 

329 

def _fetch_dag_run(*, dag_id: str, run_id: str) -> DagRun:
    """Request the Dag run identified by ``dag_id``/``run_id`` from the supervisor."""
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    request = GetDagRun(dag_id=dag_id, run_id=run_id)
    dag_run = SUPERVISOR_COMMS.send(request)
    if TYPE_CHECKING:
        assert isinstance(dag_run, DagRunResult)
    return dag_run

337 

338 

class AssetEventResult(AssetEventResponse):
    """Single asset event wrapped for use inside AssetEventsResult."""

    @classmethod
    def from_asset_event_response(cls, asset_event_response: AssetEventResponse) -> AssetEventResult:
        """Copy the non-default fields of an API ``AssetEventResponse`` into this class."""
        fields = asset_event_response.model_dump(exclude_defaults=True)
        return cls(**fields)

    @cached_property
    def source_dag_run(self) -> DagRun | None:
        """The Dag run that emitted this event, fetched once from the supervisor; None if unknown."""
        if self.source_dag_id and self.source_run_id:
            return _fetch_dag_run(dag_id=self.source_dag_id, run_id=self.source_run_id)
        return None

    @cached_property
    def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
        """The task instance that emitted this event, or None if it cannot be identified."""
        if self.source_task_id is not None and self.source_map_index is not None:
            run = self.source_dag_run
            if run is not None:
                return AssetEventSourceTaskInstance(
                    dag_run=run,
                    task_id=self.source_task_id,
                    map_index=self.source_map_index,
                )
        return None

363 

364 

class AssetEventsResult(AssetEventsResponse):
    """Result carrying asset events, converted from the API's ``AssetEventsResponse``."""

    type: Literal["AssetEventsResult"] = "AssetEventsResult"

    @classmethod
    def from_asset_events_response(cls, asset_events_response: AssetEventsResponse) -> AssetEventsResult:
        """
        Get AssetEventsResult from AssetEventsResponse.

        AssetEventsResponse is autogenerated from the API schema, so we need to convert it to AssetEventsResult
        for communication between the Supervisor and the task process.
        """
        # Exclude defaults to avoid sending unnecessary data.
        # Pass the type as AssetEventsResult explicitly so we can then call model_dump_json with exclude_unset=True
        # to avoid sending unset fields (which are defaults in our case).
        return cls(
            **asset_events_response.model_dump(exclude_defaults=True),
            type="AssetEventsResult",
        )

    def iter_asset_event_results(self) -> Iterator[AssetEventResult]:
        """Lazily yield each entry of ``asset_events`` converted to an AssetEventResult."""
        return (AssetEventResult.from_asset_event_response(event) for event in self.asset_events)

388 

389 

class AssetEventDagRunReferenceResult(AssetEventDagRunReference):
    # Wraps the API's AssetEventDagRunReference and adds lazy lookups of the event's source run / task.

    @classmethod
    def from_asset_event_dag_run_reference(
        cls,
        asset_event_dag_run_reference: AssetEventDagRunReference,
    ) -> AssetEventDagRunReferenceResult:
        """Copy the non-default fields of the API reference object into this class."""
        fields = asset_event_dag_run_reference.model_dump(exclude_defaults=True)
        return cls(**fields)

    @cached_property
    def source_dag_run(self) -> DagRun | None:
        """The Dag run that emitted the event, fetched once from the supervisor; None if unknown."""
        if self.source_dag_id and self.source_run_id:
            return _fetch_dag_run(dag_id=self.source_dag_id, run_id=self.source_run_id)
        return None

    @cached_property
    def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
        """The task instance that emitted the event, or None if it cannot be identified."""
        if self.source_task_id is not None and self.source_map_index is not None:
            run = self.source_dag_run
            if run is not None:
                return AssetEventSourceTaskInstance(
                    dag_run=run,
                    task_id=self.source_task_id,
                    map_index=self.source_map_index,
                )
        return None

415 

416 

class InactiveAssetsResult(InactiveAssetsResponse):
    """Response of InactiveAssets requests."""

    type: Literal["InactiveAssetsResult"] = "InactiveAssetsResult"

    @classmethod
    def from_inactive_assets_response(
        cls, inactive_assets_response: InactiveAssetsResponse
    ) -> InactiveAssetsResult:
        """
        Get InactiveAssetsResult from InactiveAssetsResponse.

        InactiveAssetsResponse is autogenerated from the API schema, so we need to convert it to InactiveAssetsResult
        for communication between the Supervisor and the task process.
        """
        return cls(**inactive_assets_response.model_dump(exclude_defaults=True), type="InactiveAssetsResult")

433 

434 

class XComResult(XComResponse):
    """Result carrying a single XCom value, converted from the API's ``XComResponse``."""

    type: Literal["XComResult"] = "XComResult"

    @classmethod
    def from_xcom_response(cls, xcom_response: XComResponse) -> XComResult:
        """
        Build an XComResult from an XComResponse.

        XComResponse is generated from the API schema and lacks the ``type`` discriminator needed on the
        Supervisor <-> task process channel, hence this conversion.
        """
        fields = xcom_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="XComResult")


class XComCountResponse(BaseModel):
    # Number of XCom values. Note the wire tag is "XComLengthResponse", not the class name.
    len: int
    type: Literal["XComLengthResponse"] = "XComLengthResponse"


class XComSequenceIndexResult(BaseModel):
    # One item of an XCom sequence.
    root: JsonValue
    type: Literal["XComSequenceIndexResult"] = "XComSequenceIndexResult"

    @classmethod
    def from_response(cls, response: XComSequenceIndexResponse) -> XComSequenceIndexResult:
        """Copy the API payload and attach the discriminator tag."""
        return cls(type="XComSequenceIndexResult", root=response.root)


class XComSequenceSliceResult(BaseModel):
    # A list of items taken from an XCom sequence.
    root: list[JsonValue]
    type: Literal["XComSequenceSliceResult"] = "XComSequenceSliceResult"

    @classmethod
    def from_response(cls, response: XComSequenceSliceResponse) -> XComSequenceSliceResult:
        """Copy the API payload and attach the discriminator tag."""
        return cls(type="XComSequenceSliceResult", root=response.root)

472 

473 

class ConnectionResult(ConnectionResponse):
    # ConnectionResponse plus the ``type`` discriminator required on the Supervisor <-> task channel.
    type: Literal["ConnectionResult"] = "ConnectionResult"

    @classmethod
    def from_conn_response(cls, connection_response: ConnectionResponse) -> ConnectionResult:
        """
        Build a ConnectionResult from a ConnectionResponse.

        ConnectionResponse is generated from the API schema, so it has to be converted to ConnectionResult
        before it can be sent between the Supervisor and the task process.
        """
        # Defaults are dropped to keep the payload small, and fields are dumped by alias. The type is passed
        # explicitly so a later model_dump_json(exclude_unset=True) still emits it while skipping unset
        # (default) fields.
        fields = connection_response.model_dump(exclude_defaults=True, by_alias=True)
        return cls(**fields, type="ConnectionResult")


class VariableResult(VariableResponse):
    # VariableResponse plus the ``type`` discriminator required on the Supervisor <-> task channel.
    type: Literal["VariableResult"] = "VariableResult"

    @classmethod
    def from_variable_response(cls, variable_response: VariableResponse) -> VariableResult:
        """
        Build a VariableResult from a VariableResponse.

        VariableResponse is generated from the API schema, so it has to be converted to VariableResult
        before it can be sent between the Supervisor and the task process.
        """
        fields = variable_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="VariableResult")

505 

506 

class DagRunResult(DagRun):
    # DagRun plus the ``type`` discriminator.
    type: Literal["DagRunResult"] = "DagRunResult"

    @classmethod
    def from_api_response(cls, dr_response: DagRun) -> DagRunResult:
        """
        Build the result class from an API response.

        The API model is generated from the API schema and has no discriminator field, which the
        Supervisor <-> task process channel requires, hence this conversion.
        """
        fields = dr_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="DagRunResult")


class DagRunStateResult(DagRunStateResponse):
    # DagRunStateResponse plus the ``type`` discriminator.
    type: Literal["DagRunStateResult"] = "DagRunStateResult"

    # TODO: Create a convert api_response to result classes so we don't need to do this
    # for all the classes above
    @classmethod
    def from_api_response(cls, dr_state_response: DagRunStateResponse) -> DagRunStateResult:
        """
        Build the result class from an API response.

        The API model is generated from the API schema and has no discriminator field, which the
        Supervisor <-> task process channel requires, hence this conversion.
        """
        fields = dr_state_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="DagRunStateResult")


class PreviousDagRunResult(BaseModel):
    """Previous Dag run information; ``dag_run`` is None when there is none."""

    dag_run: DagRun | None = None
    type: Literal["PreviousDagRunResult"] = "PreviousDagRunResult"


class PreviousTIResult(BaseModel):
    """Previous task instance data; ``task_instance`` is None when there is none."""

    task_instance: PreviousTIResponse | None = None
    type: Literal["PreviousTIResult"] = "PreviousTIResult"


class PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse):
    # PrevSuccessfulDagRunResponse plus the ``type`` discriminator.
    type: Literal["PrevSuccessfulDagRunResult"] = "PrevSuccessfulDagRunResult"

    @classmethod
    def from_dagrun_response(cls, prev_dag_run: PrevSuccessfulDagRunResponse) -> PrevSuccessfulDagRunResult:
        """
        Build a result object from the response object.

        PrevSuccessfulDagRunResponse is generated from the API schema, so it is converted to
        PrevSuccessfulDagRunResult for communication between the Supervisor and the task process.
        """
        fields = prev_dag_run.model_dump(exclude_defaults=True)
        return cls(**fields, type="PrevSuccessfulDagRunResult")

565 

566 

class TaskRescheduleStartDate(BaseModel):
    """The first reschedule date recorded for a task instance."""

    start_date: AwareDatetime | None
    type: Literal["TaskRescheduleStartDate"] = "TaskRescheduleStartDate"


class TICount(BaseModel):
    """How many task instances match the requested filters."""

    count: int
    type: Literal["TICount"] = "TICount"


class TaskStatesResult(TaskStatesResponse):
    # TaskStatesResponse plus the ``type`` discriminator.
    type: Literal["TaskStatesResult"] = "TaskStatesResult"

    @classmethod
    def from_api_response(cls, task_states_response: TaskStatesResponse) -> TaskStatesResult:
        """
        Build the result class from an API response.

        The API model is generated from the API schema and has no discriminator field, which the
        Supervisor <-> task process channel requires, hence this conversion.
        """
        fields = task_states_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="TaskStatesResult")


class TaskBreadcrumbsResult(TaskBreadcrumbsResponse):
    # TaskBreadcrumbsResponse plus the ``type`` discriminator.
    type: Literal["TaskBreadcrumbsResult"] = "TaskBreadcrumbsResult"

    @classmethod
    def from_api_response(cls, response: TaskBreadcrumbsResponse) -> TaskBreadcrumbsResult:
        """
        Build the result class from an API response.

        The API model is generated from the API schema and has no discriminator field, which the
        Supervisor <-> task process channel requires, hence this conversion.
        """
        fields = response.model_dump(exclude_defaults=True)
        return cls(**fields, type="TaskBreadcrumbsResult")


class DRCount(BaseModel):
    """How many Dag runs match the requested filters."""

    count: int
    type: Literal["DRCount"] = "DRCount"

616 

617 

class ErrorResponse(BaseModel):
    # Error reply: an ErrorType (GENERIC_ERROR unless set) plus optional free-form detail.
    error: ErrorType = ErrorType.GENERIC_ERROR
    detail: dict | None = None
    type: Literal["ErrorResponse"] = "ErrorResponse"


class OKResponse(BaseModel):
    # Simple acknowledgement carrying a boolean flag.
    ok: bool
    type: Literal["OKResponse"] = "OKResponse"


class SentFDs(BaseModel):
    # Reply to ResendLoggingFD; CommsDecoder.send fills ``fds`` with the descriptors received locally.
    type: Literal["SentFDs"] = "SentFDs"
    fds: list[int]

632 

633 

class CreateHITLDetailPayload(HITLDetailRequest):
    """Request to add the input-request part of a Human-in-the-loop response."""

    type: Literal["CreateHITLDetailPayload"] = "CreateHITLDetailPayload"


class HITLDetailRequestResult(HITLDetailRequest):
    """Reply to a CreateHITLDetailPayload request."""

    type: Literal["HITLDetailRequestResult"] = "HITLDetailRequestResult"

    @classmethod
    def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequestResult:
        """
        Build a HITLDetailRequestResult from the HITLDetailRequest API response.

        The API model lacks the discriminator field required for tagged-union deserialization on the
        Supervisor <-> task process channel, so it is added here.
        """
        fields = hitl_request.model_dump(exclude_defaults=True)
        return cls(**fields, type="HITLDetailRequestResult")

655 

656 

# Union of messages sent to the task process. Pydantic selects the concrete model from the payload's
# ``type`` field (a tagged/discriminated union), so every member must declare a unique Literal ``type``.
ToTask = Annotated[
    AssetResult
    | AssetEventsResult
    | ConnectionResult
    | DagRunResult
    | DagRunStateResult
    | DRCount
    | ErrorResponse
    | PrevSuccessfulDagRunResult
    | PreviousTIResult
    | SentFDs
    | StartupDetails
    | TaskRescheduleStartDate
    | TICount
    | TaskBreadcrumbsResult
    | TaskStatesResult
    | VariableResult
    | XComCountResponse
    | XComResult
    | XComSequenceIndexResult
    | XComSequenceSliceResult
    | InactiveAssetsResult
    | CreateHITLDetailPayload
    | HITLDetailRequestResult
    | OKResponse
    | PreviousDagRunResult,
    Field(discriminator="type"),
]

685 

686 

class TaskState(BaseModel):
    """
    Set a task's state to FAILED, SKIPPED or REMOVED.

    When a process exits without sending one of these, the state is derived from its exit code:
    - 0 = SUCCESS
    - anything else = FAILED
    """

    state: Literal[
        TaskInstanceState.FAILED,
        TaskInstanceState.SKIPPED,
        TaskInstanceState.REMOVED,
    ]
    end_date: datetime | None = None
    type: Literal["TaskState"] = "TaskState"
    rendered_map_index: str | None = None


class SucceedTask(TISuccessStatePayload):
    """Mark a task as successful; task_outlets and outlet_events are included to register asset events."""

    type: Literal["SucceedTask"] = "SucceedTask"


class DeferTask(TIDeferredStatePayload):
    """Move a task instance to the deferred state."""

    type: Literal["DeferTask"] = "DeferTask"


class RetryTask(TIRetryStatePayload):
    """Move a task instance to the up_for_retry state."""

    type: Literal["RetryTask"] = "RetryTask"


class RescheduleTask(TIRescheduleStatePayload):
    """Move a task instance to the reschedule/up_for_reschedule state."""

    type: Literal["RescheduleTask"] = "RescheduleTask"


class SkipDownstreamTasks(TISkippedDownstreamTasksStatePayload):
    """Mark downstream tasks of a task instance as 'skipped' while moving the current task to success."""

    type: Literal["SkipDownstreamTasks"] = "SkipDownstreamTasks"

734 

735 

class GetXCom(BaseModel):
    # Fetch one XCom value by key for a task of a Dag run.
    key: str
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    include_prior_dates: bool = False
    type: Literal["GetXCom"] = "GetXCom"


class GetXComCount(BaseModel):
    """Ask how many (mapped) XCom values are available."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    # Note: the wire tag differs from the class name.
    type: Literal["GetNumberXComs"] = "GetNumberXComs"


class GetXComSequenceItem(BaseModel):
    # Fetch the item at ``offset`` of an XCom sequence.
    key: str
    dag_id: str
    run_id: str
    task_id: str
    offset: int
    type: Literal["GetXComSequenceItem"] = "GetXComSequenceItem"


class GetXComSequenceSlice(BaseModel):
    # Fetch a start/stop/step slice of an XCom sequence; all three bounds must be given (None allowed).
    key: str
    dag_id: str
    run_id: str
    task_id: str
    start: int | None
    stop: int | None
    step: int | None
    include_prior_dates: bool = False
    type: Literal["GetXComSequenceSlice"] = "GetXComSequenceSlice"


class SetXCom(BaseModel):
    # Store a JSON-compatible XCom value.
    key: str
    value: JsonValue
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    mapped_length: int | None = None
    type: Literal["SetXCom"] = "SetXCom"


class DeleteXCom(BaseModel):
    # Remove an XCom value.
    key: str
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    type: Literal["DeleteXCom"] = "DeleteXCom"

795 

796 

class GetConnection(BaseModel):
    # Look up a connection by id.
    conn_id: str
    type: Literal["GetConnection"] = "GetConnection"


class GetVariable(BaseModel):
    # Look up a variable by key.
    key: str
    type: Literal["GetVariable"] = "GetVariable"


class PutVariable(BaseModel):
    # Create or update a variable; value and description are required fields but may be None.
    key: str
    value: str | None
    description: str | None
    type: Literal["PutVariable"] = "PutVariable"


class DeleteVariable(BaseModel):
    # Remove a variable by key.
    key: str
    type: Literal["DeleteVariable"] = "DeleteVariable"


class ResendLoggingFD(BaseModel):
    # Ask the supervisor to pass a logging file descriptor again; answered with SentFDs.
    type: Literal["ResendLoggingFD"] = "ResendLoggingFD"

821 

822 

class SetRenderedFields(BaseModel):
    """Payload that sets the rendered template fields (RTIF) of a task instance."""

    # Unlike the server, which uses a RootModel, this is a BaseModel: the "type" discriminator is needed
    # here and RootModel cannot carry it.

    rendered_fields: dict[str, JsonValue]
    type: Literal["SetRenderedFields"] = "SetRenderedFields"


class SetRenderedMapIndex(BaseModel):
    """Payload that sets the rendered_map_index of a task instance."""

    rendered_map_index: str
    type: Literal["SetRenderedMapIndex"] = "SetRenderedMapIndex"


class TriggerDagRun(TriggerDAGRunPayload):
    # The API payload plus the target dag_id / run_id and the discriminator tag.
    dag_id: str
    run_id: Annotated[str, Field(title="Dag Run Id")]
    type: Literal["TriggerDagRun"] = "TriggerDagRun"

844 

845 

class GetDagRun(BaseModel):
    # Fetch a Dag run by dag_id / run_id.
    dag_id: str
    run_id: str
    type: Literal["GetDagRun"] = "GetDagRun"


class GetDagRunState(BaseModel):
    # Fetch the state of a Dag run.
    dag_id: str
    run_id: str
    type: Literal["GetDagRunState"] = "GetDagRunState"


class GetPreviousDagRun(BaseModel):
    # Fetch the Dag run preceding ``logical_date``, optionally filtered by state.
    dag_id: str
    logical_date: AwareDatetime
    state: str | None = None
    type: Literal["GetPreviousDagRun"] = "GetPreviousDagRun"


class GetPreviousTI(BaseModel):
    """Ask for the previous task instance."""

    dag_id: str
    task_id: str
    logical_date: AwareDatetime | None = None
    map_index: int = -1
    state: TaskInstanceState | None = None
    type: Literal["GetPreviousTI"] = "GetPreviousTI"

874 

875 

class GetAssetByName(BaseModel):
    # Look up an asset by name.
    name: str
    type: Literal["GetAssetByName"] = "GetAssetByName"


class GetAssetByUri(BaseModel):
    # Look up an asset by URI.
    uri: str
    type: Literal["GetAssetByUri"] = "GetAssetByUri"


class GetAssetEventByAsset(BaseModel):
    # Query events of an asset (by name and/or uri) with optional time window, limit and ordering.
    name: str | None
    uri: str | None
    after: AwareDatetime | None = None
    before: AwareDatetime | None = None
    limit: int | None = None
    ascending: bool = True
    type: Literal["GetAssetEventByAsset"] = "GetAssetEventByAsset"


class GetAssetEventByAssetAlias(BaseModel):
    # Query events of an asset alias with optional time window, limit and ordering.
    alias_name: str
    after: AwareDatetime | None = None
    before: AwareDatetime | None = None
    limit: int | None = None
    ascending: bool = True
    type: Literal["GetAssetEventByAssetAlias"] = "GetAssetEventByAssetAlias"

903 

904 

class ValidateInletsAndOutlets(BaseModel):
    # Validate the inlets/outlets of the task instance ``ti_id``.
    ti_id: UUID
    type: Literal["ValidateInletsAndOutlets"] = "ValidateInletsAndOutlets"


class GetPrevSuccessfulDagRun(BaseModel):
    # Fetch the previous successful Dag run relative to task instance ``ti_id``.
    ti_id: UUID
    type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"


class GetTaskRescheduleStartDate(BaseModel):
    # Fetch the first reschedule date of ``ti_id`` for the given try.
    ti_id: UUID
    try_number: int = 1
    type: Literal["GetTaskRescheduleStartDate"] = "GetTaskRescheduleStartDate"


class GetTICount(BaseModel):
    # Count task instances of ``dag_id`` matching the optional filters.
    dag_id: str
    map_index: int | None = None
    task_ids: list[str] | None = None
    task_group_id: str | None = None
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    states: list[str] | None = None
    type: Literal["GetTICount"] = "GetTICount"


class GetTaskStates(BaseModel):
    # Fetch task states of ``dag_id`` matching the optional filters.
    dag_id: str
    map_index: int | None = None
    task_ids: list[str] | None = None
    task_group_id: str | None = None
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    type: Literal["GetTaskStates"] = "GetTaskStates"


class GetTaskBreadcrumbs(BaseModel):
    # Fetch task breadcrumbs of a Dag run.
    dag_id: str
    run_id: str
    type: Literal["GetTaskBreadcrumbs"] = "GetTaskBreadcrumbs"


class GetDRCount(BaseModel):
    # Count Dag runs of ``dag_id`` matching the optional filters.
    dag_id: str
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    states: list[str] | None = None
    type: Literal["GetDRCount"] = "GetDRCount"

954 

955 

class GetHITLDetailResponse(BaseModel):
    """Fetch the response-content part of a Human-in-the-loop response."""

    ti_id: UUID
    type: Literal["GetHITLDetailResponse"] = "GetHITLDetailResponse"


class UpdateHITLDetail(UpdateHITLDetailPayload):
    """Update the response-content part of an existing Human-in-the-loop response."""

    type: Literal["UpdateHITLDetail"] = "UpdateHITLDetail"


class MaskSecret(BaseModel):
    """Register a new value to be redacted in task logs."""

    # Required because `mask_secret` calls in the task process would otherwise only register the mask in the
    # child process, while redaction is performed by the parent.
    # `string | Iterable | dict` would be the more natural type, but a Pydantic bug
    # (https://github.com/pydantic/pydantic/issues/9541) turns iterables into a ValidatorIterator.
    value: JsonValue
    name: str | None = None
    type: Literal["MaskSecret"] = "MaskSecret"

979 

980 

# Union of requests the task process sends to the supervisor. Pydantic selects the concrete model from the
# payload's ``type`` field (a tagged/discriminated union), so every member needs a unique Literal ``type``.
ToSupervisor = Annotated[
    DeferTask
    | DeleteXCom
    | GetAssetByName
    | GetAssetByUri
    | GetAssetEventByAsset
    | GetAssetEventByAssetAlias
    | GetConnection
    | GetDagRun
    | GetDagRunState
    | GetDRCount
    | GetPrevSuccessfulDagRun
    | GetPreviousDagRun
    | GetPreviousTI
    | GetTaskRescheduleStartDate
    | GetTICount
    | GetTaskBreadcrumbs
    | GetTaskStates
    | GetVariable
    | GetXCom
    | GetXComCount
    | GetXComSequenceItem
    | GetXComSequenceSlice
    | PutVariable
    | RescheduleTask
    | RetryTask
    | SetRenderedFields
    | SetRenderedMapIndex
    | SetXCom
    | SkipDownstreamTasks
    | SucceedTask
    | ValidateInletsAndOutlets
    | TaskState
    | TriggerDagRun
    | DeleteVariable
    | ResendLoggingFD
    | CreateHITLDetailPayload
    | UpdateHITLDetail
    | GetHITLDetailResponse
    | MaskSecret,
    Field(discriminator="type"),
]