Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/execution_time/comms.py: 73%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

486 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18r""" 

19Communication protocol between the Supervisor and the task process 

20================================================================== 

21 

22* All communication is done over the subprocess's stdin in the form of a binary length-prefixed msgpack frame 

23 (4 byte, big-endian length, followed by the msgpack-encoded _RequestFrame.) Each side uses this same 

24 encoding 

25* Log Messages from the subprocess are sent over the dedicated logs socket (which is line-based JSON) 

26* No messages are sent to task process except in response to a request. (This is because the task process will 

27 be running user's code, so we can't read from stdin until we enter our code, such as when requesting an XCom 

28 value etc.) 

29* Every request returns a response, even if the frame is otherwise empty. 

30* Requests are written by the subprocess to fd0/stdin. This is making use of the fact that stdin is a 

31 bi-directional socket, and thus we can write to it and don't need a dedicated extra socket for sending 

32 requests. 

33 

34The reason this communication protocol exists, rather than the task process speaking directly to the Task 

35Execution API server is because: 

36 

371. To reduce the number of concurrent HTTP connections on the API server. 

38 

39 The supervisor already has to speak to that to heartbeat the running Task, so having the task speak to its 

40 parent process and having all API traffic go through that means that the number of HTTP connections is 

41 "halved". (Not every task will make API calls, so it's not always halved, but it is reduced.) 

42 

432. This means that the user Task code doesn't ever directly see the task identity JWT token. 

44 

45 This is a short lived token tied to one specific task instance try, so it being leaked/exfiltrated is not a 

46 large risk, but it's easy to not give it to the user code, so let's do that. 

47""" # noqa: D400, D205 

48 

49from __future__ import annotations 

50 

51import asyncio 

52import itertools 

53import threading 

54from collections.abc import Iterator 

55from datetime import datetime 

56from functools import cached_property 

57from pathlib import Path 

58from socket import socket 

59from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload 

60from uuid import UUID 

61 

62import attrs 

63import msgspec 

64import structlog 

65from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter 

66 

67from airflow.sdk.api.datamodels._generated import ( 

68 AssetEventDagRunReference, 

69 AssetEventResponse, 

70 AssetEventsResponse, 

71 AssetResponse, 

72 BundleInfo, 

73 ConnectionResponse, 

74 DagRun, 

75 DagRunStateResponse, 

76 HITLDetailRequest, 

77 InactiveAssetsResponse, 

78 PreviousTIResponse, 

79 PrevSuccessfulDagRunResponse, 

80 TaskBreadcrumbsResponse, 

81 TaskInstance, 

82 TaskInstanceState, 

83 TaskStatesResponse, 

84 TIDeferredStatePayload, 

85 TIRescheduleStatePayload, 

86 TIRetryStatePayload, 

87 TIRunContext, 

88 TISkippedDownstreamTasksStatePayload, 

89 TISuccessStatePayload, 

90 TriggerDAGRunPayload, 

91 UpdateHITLDetailPayload, 

92 VariableResponse, 

93 XComResponse, 

94 XComSequenceIndexResponse, 

95 XComSequenceSliceResponse, 

96) 

97from airflow.sdk.exceptions import ErrorType 

98 

99try: 

100 from socket import recv_fds 

101except ImportError: 

102 # Available on Unix and Windows (so "everywhere") but let's be safe 

103 recv_fds = None # type: ignore[assignment] 

104 

105 

106if TYPE_CHECKING: 

107 from structlog.typing import FilteringBoundLogger as Logger 

108 

109SendMsgType = TypeVar("SendMsgType", bound=BaseModel) 

110ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel) 

111 

112 

def _msgpack_enc_hook(obj: Any) -> Any:
    """
    Convert objects msgspec cannot encode natively into ones it can.

    :param obj: The object handed to the hook.
    :return: A plain ``datetime`` for pendulum ``DateTime`` objects, a ``str`` for ``Path`` objects, or a
        ``dict`` (``model_dump(exclude_unset=True)``) for pydantic models.
    :raises NotImplementedError: For any other type.
    """
    import pendulum

    if isinstance(obj, pendulum.DateTime):
        # Convert the pendulum DateTime subclass into a raw datetime so that msgspec can use its native
        # encoding. ``fold`` must be carried over: without it an ambiguous wall-clock time (e.g. during a
        # DST fall-back) would resolve to a different UTC offset than the original object.
        return datetime(
            obj.year,
            obj.month,
            obj.day,
            obj.hour,
            obj.minute,
            obj.second,
            obj.microsecond,
            tzinfo=obj.tzinfo,
            fold=obj.fold,
        )
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, BaseModel):
        return obj.model_dump(exclude_unset=True)

    # Raise a NotImplementedError for other types
    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")

129 

130 

def _new_encoder() -> msgspec.msgpack.Encoder:
    """Create a msgpack encoder that falls back to ``_msgpack_enc_hook`` for unsupported types."""
    encoder = msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook)
    return encoder

133 

134 

class _RequestFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True):
    """Envelope for a single request: a sender-chosen id plus the message body."""

    id: int
    """
    The request id, set by the sender.

    This is used to allow "pipelining" of requests and to be able to tie responses to requests, which is
    particularly useful in the Triggerer where multiple async tasks can send requests concurrently.
    """
    body: dict[str, Any] | None

    # One encoder shared by every frame instance.
    req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder()

    def as_bytes(self) -> bytearray:
        """
        Serialize this frame as a 4-byte big-endian length followed by the msgpack payload.

        :raises OverflowError: If the encoded payload does not fit in a 4-byte length.
        """
        # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing for inspiration
        prefix_size = 4
        out = bytearray(256)

        # Encode after the space reserved for the length prefix.
        self.req_encoder.encode_into(self, out, prefix_size)

        n = len(out) - prefix_size
        if n >= 2**32:
            raise OverflowError(f"Cannot send messages larger than 4GiB {n=}")

        out[:prefix_size] = n.to_bytes(prefix_size, byteorder="big")
        return out

159 

160 

class _ResponseFrame(_RequestFrame, frozen=True):
    # Response envelope: same layout as _RequestFrame, plus an optional ``error`` payload.
    id: int
    """
    The id of the request this is a response to
    """
    # Response message as a plain dict; defaults to None.
    body: dict[str, Any] | None = None
    # Error payload as a plain dict; CommsDecoder._from_frame raises when this is not None.
    error: dict[str, Any] | None = None

168 

169 

@attrs.define()
class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
    """
    Handle communication between the task in this process and the supervisor parent process.

    Requests are written to ``socket`` as length-prefixed msgpack frames and the reply is read back from
    the same socket. ``send`` is the blocking entry point and ``asend`` the coroutine one; both hold
    ``_thread_lock`` for the whole request/response exchange so that frames belonging to different
    callers are never interleaved.
    """

    log: Logger = attrs.field(repr=False, factory=structlog.get_logger)
    # Defaults to wrapping fd 0 (stdin), which the module docstring describes as a bi-directional socket.
    socket: socket = attrs.field(factory=lambda: socket(fileno=0))

    resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field(
        factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False
    )

    # Source of request ids; each request takes the next value.
    id_counter: Iterator[int] = attrs.field(factory=itertools.count)

    # We could be "clever" here and set the default to this based type parameters and a custom
    # `__class_getitem__`, but that's a lot of code for the one subclass we've got currently. So we'll just
    # use a "sort of wrong default"
    body_decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)

    # NOTE(review): annotated as an ErrorResponse adapter but built from the whole ``ToTask`` union; this
    # relies on the error payload being tagged "ErrorResponse" -- confirm that is intended.
    err_decoder: TypeAdapter[ErrorResponse] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)

    # Threading lock for sync operations
    _thread_lock: threading.Lock = attrs.field(factory=threading.Lock, repr=False)
    # Async lock for async operations
    _async_lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False)

    def send(self, msg: SendMsgType) -> ReceiveMsgType | None:
        """
        Send a request to the parent and block until the response is received.

        :param msg: The request model; it is serialized with ``model_dump()``.
        :return: The decoded response, or None when the response frame has no body.
        """
        frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
        frame_bytes = frame.as_bytes()

        # We must make sure sockets aren't intermixed between sync and async calls,
        # thus we need a dual locking mechanism to ensure that.
        with self._thread_lock:
            self.socket.sendall(frame_bytes)
            if isinstance(msg, ResendLoggingFD):
                if recv_fds is None:
                    # NOTE(review): the request has already been written at this point, so any reply the
                    # supervisor sends stays unread on the socket -- confirm this is acceptable.
                    return None
                # We need special handling here! The server can't send us the fd number, as the number on the
                # supervisor will be different to in this process, so we have to mutate the message ourselves here.
                frame, fds = self._read_frame(maxfds=1)
                resp = self._from_frame(frame)
                if TYPE_CHECKING:
                    assert isinstance(resp, SentFDs)
                resp.fds = fds
                # Since we know this is an explicit SentFDs, and since this class is generic SentFDs might not
                # always be in the return type union
                return resp  # type: ignore[return-value]

            return self._get_response()

    async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None:
        """
        Send a request to the parent without blocking.

        Uses async lock for coroutine safety and thread lock for socket safety.

        :param msg: The request model; it is serialized with ``model_dump()``.
        :return: The decoded response, or None when the response frame has no body.
        """
        frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
        frame_bytes = frame.as_bytes()

        async with self._async_lock:
            # Acquire the threading lock without blocking the event loop
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._thread_lock.acquire)
            try:
                # Async write to socket
                # NOTE(review): ``_read_frame`` leaves the socket in blocking mode while the asyncio docs
                # ask for a non-blocking socket here -- confirm this is fine for later calls.
                await loop.sock_sendall(self.socket, frame_bytes)

                if isinstance(msg, ResendLoggingFD):
                    if recv_fds is None:
                        return None
                    # Blocking read in a thread
                    frame, fds = await asyncio.to_thread(self._read_frame, maxfds=1)
                    resp = self._from_frame(frame)
                    if TYPE_CHECKING:
                        assert isinstance(resp, SentFDs)
                    resp.fds = fds
                    return resp  # type: ignore[return-value]

                # Normal blocking read in a thread
                frame = await asyncio.to_thread(self._read_frame)
                return self._from_frame(frame)
            finally:
                # Always hand the thread lock back, including on the early return above.
                self._thread_lock.release()

    @overload
    def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ...

    @overload
    def _read_frame(self, maxfds: int) -> tuple[_ResponseFrame, list[int]]: ...

    def _read_frame(self, maxfds: int | None = None) -> tuple[_ResponseFrame, list[int]] | _ResponseFrame:
        """
        Get a message from the parent.

        This will block until the message has been received.

        :param maxfds: When truthy, also accept up to this many file descriptors alongside the length
            prefix (via ``recv_fds``).
        :return: The decoded frame, or ``(frame, fds)`` when ``maxfds`` is truthy.
        :raises EOFError: If the socket is closed before a complete frame has been read.
        """
        if self.socket:
            self.socket.setblocking(True)
        fds = None
        if maxfds:
            len_bytes, fds, _flags, _address = recv_fds(self.socket, 4, maxfds)
        else:
            len_bytes = self.socket.recv(4)

        if len_bytes == b"":
            raise EOFError("Request socket closed before length")

        # A stream socket may hand back fewer bytes than requested, so keep reading until the whole
        # 4-byte prefix is here; decoding a partial prefix would yield a bogus frame length.
        while len(len_bytes) < 4:
            more = self.socket.recv(4 - len(len_bytes))
            if not more:
                raise EOFError("Request socket closed before length")
            len_bytes += more

        length = int.from_bytes(len_bytes, byteorder="big")

        buffer = bytearray(length)
        mv = memoryview(buffer)

        # Fill the buffer in place; each recv_into may deliver only part of the payload.
        pos = 0
        while pos < length:
            nread = self.socket.recv_into(mv[pos:])
            if nread == 0:
                raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})")
            pos += nread

        resp = self.resp_decoder.decode(mv)
        if maxfds:
            return resp, fds or []
        return resp

    def _from_frame(self, frame) -> ReceiveMsgType | None:
        """
        Turn a response frame into a message object.

        :param frame: The decoded response frame.
        :return: The validated body, or None when the frame has no body.
        :raises AirflowRuntimeError: If the frame carries an error payload.
        """
        from airflow.sdk.exceptions import AirflowRuntimeError

        if frame.error is not None:
            err = self.err_decoder.validate_python(frame.error)
            raise AirflowRuntimeError(error=err)

        if frame.body is None:
            return None

        try:
            return self.body_decoder.validate_python(frame.body)
        except Exception:
            # Log for visibility, then let the caller see the original failure.
            self.log.exception("Unable to decode message")
            raise

    def _get_response(self) -> ReceiveMsgType | None:
        """Read the next frame from the socket and decode it."""
        frame = self._read_frame()
        return self._from_frame(frame)

313 

314 

class StartupDetails(BaseModel):
    """Message in the ``ToTask`` union bundling a task instance with its bundle, Dag path and run context."""

    # NOTE(review): presumably required because one of the field types is not natively handled by
    # pydantic -- confirm which one.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    ti: TaskInstance
    dag_rel_path: str
    bundle_info: BundleInfo
    start_date: datetime
    ti_context: TIRunContext
    sentry_integration: str
    # Discriminator value for the ``ToTask`` tagged union.
    type: Literal["StartupDetails"] = "StartupDetails"

325 

326 

class AssetResult(AssetResponse):
    """Asset details sent to the task process, tagged for the ``ToTask`` union."""

    # NOTE(review): presumably the reply to GetAssetByName / GetAssetByUri -- confirm against the supervisor.
    type: Literal["AssetResult"] = "AssetResult"

    @classmethod
    def from_asset_response(cls, asset_response: AssetResponse) -> AssetResult:
        """
        Build an AssetResult from an AssetResponse.

        AssetResponse is autogenerated from the API schema, so it has to be converted to AssetResult
        before it can travel between the Supervisor and the task process.
        """
        # Drop defaults so no unnecessary data is sent. The type is passed explicitly so that it counts as
        # "set" and survives a later dump with exclude_unset=True.
        non_default_fields = asset_response.model_dump(exclude_defaults=True)
        return cls(**non_default_fields, type="AssetResult")

344 

345 

@attrs.define(kw_only=True)
class AssetEventSourceTaskInstance:
    """Task instance that produced an asset event; built by ``AssetEventResult.source_task_instance``."""

    dag_run: DagRun
    task_id: str
    map_index: int

    @property
    def dag_id(self) -> str:
        """Dag id, taken from the attached Dag run."""
        return self.dag_run.dag_id

    @property
    def run_id(self) -> str:
        """Run id, taken from the attached Dag run."""
        return self.dag_run.run_id

    def xcom_pull(self, *, key: str = "return_value", default: Any = None) -> Any:
        """
        Look up an XCom value using this object as the task instance key.

        :param key: XCom key to read.
        :param default: Value returned when the lookup yields None.
        """
        from airflow.sdk.execution_time.xcom import XCom

        found = XCom.get_value(ti_key=self, key=key)
        return default if found is None else found

368 

369 

def _fetch_dag_run(*, dag_id: str, run_id: str) -> DagRun:
    """Send a ``GetDagRun`` request through ``SUPERVISOR_COMMS`` and return the response."""
    # Imported lazily, inside the function.
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    request = GetDagRun(dag_id=dag_id, run_id=run_id)
    response = SUPERVISOR_COMMS.send(request)
    if TYPE_CHECKING:
        # Narrowing for type checkers only; never executed at runtime.
        assert isinstance(response, DagRunResult)
    return response

377 

378 

class AssetEventResult(AssetEventResponse):
    """Single asset event as yielded by ``AssetEventsResult.iter_asset_event_results``."""

    @classmethod
    def from_asset_event_response(cls, asset_event_response: AssetEventResponse) -> AssetEventResult:
        """Build an AssetEventResult from the non-default fields of an AssetEventResponse."""
        non_default_fields = asset_event_response.model_dump(exclude_defaults=True)
        return cls(**non_default_fields)

    @cached_property
    def source_dag_run(self) -> DagRun | None:
        """Dag run that produced the event, fetched on first access; None without a source dag/run id."""
        if self.source_dag_id and self.source_run_id:
            return _fetch_dag_run(dag_id=self.source_dag_id, run_id=self.source_run_id)
        return None

    @cached_property
    def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
        """Task instance that produced the event, or None when any part of its identity is missing."""
        if self.source_task_id is None or self.source_map_index is None:
            return None
        dag_run = self.source_dag_run
        if dag_run is None:
            return None
        return AssetEventSourceTaskInstance(
            dag_run=dag_run,
            task_id=self.source_task_id,
            map_index=self.source_map_index,
        )

403 

404 

class AssetEventsResult(AssetEventsResponse):
    """Collection of asset events sent to the task process, tagged for the ``ToTask`` union."""

    type: Literal["AssetEventsResult"] = "AssetEventsResult"

    @classmethod
    def from_asset_events_response(cls, asset_events_response: AssetEventsResponse) -> AssetEventsResult:
        """
        Build an AssetEventsResult from an AssetEventsResponse.

        AssetEventsResponse is autogenerated from the API schema, so it has to be converted to
        AssetEventsResult before it can travel between the Supervisor and the task process.
        """
        # Drop defaults so no unnecessary data is sent. The type is passed explicitly so that it counts as
        # "set" and survives a later dump with exclude_unset=True.
        non_default_fields = asset_events_response.model_dump(exclude_defaults=True)
        return cls(**non_default_fields, type="AssetEventsResult")

    def iter_asset_event_results(self) -> Iterator[AssetEventResult]:
        """Lazily convert each entry of ``asset_events`` into an AssetEventResult."""
        convert = AssetEventResult.from_asset_event_response
        return (convert(item) for item in self.asset_events)

428 

429 

class AssetEventDagRunReferenceResult(AssetEventDagRunReference):
    """AssetEventDagRunReference extended with lazy access to its source Dag run and task instance."""

    @classmethod
    def from_asset_event_dag_run_reference(
        cls,
        asset_event_dag_run_reference: AssetEventDagRunReference,
    ) -> AssetEventDagRunReferenceResult:
        """Build an instance from the non-default fields of an AssetEventDagRunReference."""
        non_default_fields = asset_event_dag_run_reference.model_dump(exclude_defaults=True)
        return cls(**non_default_fields)

    @cached_property
    def source_dag_run(self) -> DagRun | None:
        """Dag run that produced the event, fetched on first access; None without a source dag/run id."""
        if self.source_dag_id and self.source_run_id:
            return _fetch_dag_run(dag_id=self.source_dag_id, run_id=self.source_run_id)
        return None

    @cached_property
    def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
        """Task instance that produced the event, or None when any part of its identity is missing."""
        if self.source_task_id is None or self.source_map_index is None:
            return None
        dag_run = self.source_dag_run
        if dag_run is None:
            return None
        return AssetEventSourceTaskInstance(
            dag_run=dag_run,
            task_id=self.source_task_id,
            map_index=self.source_map_index,
        )

455 

456 

class InactiveAssetsResult(InactiveAssetsResponse):
    """Response of InactiveAssets requests."""

    type: Literal["InactiveAssetsResult"] = "InactiveAssetsResult"

    @classmethod
    def from_inactive_assets_response(
        cls, inactive_assets_response: InactiveAssetsResponse
    ) -> InactiveAssetsResult:
        """
        Build an InactiveAssetsResult from an InactiveAssetsResponse.

        InactiveAssetsResponse is autogenerated from the API schema, so it has to be converted to
        InactiveAssetsResult before it can travel between the Supervisor and the task process.
        """
        # The type is passed explicitly so that it counts as a "set" field on the new model.
        non_default_fields = inactive_assets_response.model_dump(exclude_defaults=True)
        return cls(**non_default_fields, type="InactiveAssetsResult")

473 

474 

class XComResult(XComResponse):
    """XCom data sent to the task process, tagged for the ``ToTask`` union."""

    type: Literal["XComResult"] = "XComResult"

    @classmethod
    def from_xcom_response(cls, xcom_response: XComResponse) -> XComResult:
        """
        Build an XComResult from an XComResponse.

        XComResponse is autogenerated from the API schema, so it has to be converted to XComResult
        before it can travel between the Supervisor and the task process.
        """
        # The type is passed explicitly so that it counts as a "set" field on the new model.
        non_default_fields = xcom_response.model_dump(exclude_defaults=True)
        return cls(**non_default_fields, type="XComResult")

489 

490 

class XComCountResponse(BaseModel):
    """Message in the ``ToTask`` union carrying a count in ``len``."""

    len: int
    # NOTE: the discriminator value is "XComLengthResponse", which differs from the class name.
    type: Literal["XComLengthResponse"] = "XComLengthResponse"

494 

495 

class XComSequenceIndexResult(BaseModel):
    """Message in the ``ToTask`` union wrapping a single JSON value in ``root``."""

    root: JsonValue
    type: Literal["XComSequenceIndexResult"] = "XComSequenceIndexResult"

    @classmethod
    def from_response(cls, response: XComSequenceIndexResponse) -> XComSequenceIndexResult:
        """Copy ``response.root`` into a new result, setting the discriminator explicitly."""
        value = response.root
        return cls(root=value, type="XComSequenceIndexResult")

503 

504 

class XComSequenceSliceResult(BaseModel):
    """Message in the ``ToTask`` union wrapping a list of JSON values in ``root``."""

    root: list[JsonValue]
    type: Literal["XComSequenceSliceResult"] = "XComSequenceSliceResult"

    @classmethod
    def from_response(cls, response: XComSequenceSliceResponse) -> XComSequenceSliceResult:
        """Copy ``response.root`` into a new result, setting the discriminator explicitly."""
        values = response.root
        return cls(root=values, type="XComSequenceSliceResult")

512 

513 

class ConnectionResult(ConnectionResponse):
    """Connection details sent to the task process, tagged for the ``ToTask`` union."""

    type: Literal["ConnectionResult"] = "ConnectionResult"

    @classmethod
    def from_conn_response(cls, connection_response: ConnectionResponse) -> ConnectionResult:
        """
        Build a ConnectionResult from a ConnectionResponse.

        ConnectionResponse is autogenerated from the API schema, so it has to be converted to
        ConnectionResult before it can travel between the Supervisor and the task process.
        """
        # Drop defaults so no unnecessary data is sent, and dump by alias. The type is passed explicitly
        # so that it counts as "set" and survives a later dump with exclude_unset=True.
        non_default_fields = connection_response.model_dump(exclude_defaults=True, by_alias=True)
        return cls(**non_default_fields, type="ConnectionResult")

531 

532 

class VariableResult(VariableResponse):
    """Variable data sent to the task process, tagged for the ``ToTask`` union."""

    type: Literal["VariableResult"] = "VariableResult"

    @classmethod
    def from_variable_response(cls, variable_response: VariableResponse) -> VariableResult:
        """
        Build a VariableResult from a VariableResponse.

        VariableResponse is autogenerated from the API schema, so it has to be converted to
        VariableResult before it can travel between the Supervisor and the task process.
        """
        # The type is passed explicitly so that it counts as a "set" field on the new model.
        non_default_fields = variable_response.model_dump(exclude_defaults=True)
        return cls(**non_default_fields, type="VariableResult")

545 

546 

class DagRunResult(DagRun):
    """Dag run data sent to the task process, tagged for the ``ToTask`` union."""

    type: Literal["DagRunResult"] = "DagRunResult"

    @classmethod
    def from_api_response(cls, dr_response: DagRun) -> DagRunResult:
        """
        Build the result class from an API response.

        The API response model is autogenerated from the API schema; it is converted here because
        communication between the Supervisor and the task process needs a discriminator field.
        """
        non_default_fields = dr_response.model_dump(exclude_defaults=True)
        return cls(**non_default_fields, type="DagRunResult")

560 

561 

class DagRunStateResult(DagRunStateResponse):
    """Dag run state sent to the task process, tagged for the ``ToTask`` union."""

    type: Literal["DagRunStateResult"] = "DagRunStateResult"

    # TODO: Create a convert api_response to result classes so we don't need to do this
    # for all the classes above
    @classmethod
    def from_api_response(cls, dr_state_response: DagRunStateResponse) -> DagRunStateResult:
        """
        Build the result class from an API response.

        The API response model is autogenerated from the API schema; it is converted here because
        communication between the Supervisor and the task process needs a discriminator field.
        """
        non_default_fields = dr_state_response.model_dump(exclude_defaults=True)
        return cls(**non_default_fields, type="DagRunStateResult")

577 

578 

class PreviousDagRunResult(BaseModel):
    """Response containing previous Dag run information."""

    # Defaults to None.
    dag_run: DagRun | None = None
    # Discriminator value for the ``ToTask`` tagged union.
    type: Literal["PreviousDagRunResult"] = "PreviousDagRunResult"

584 

585 

class PreviousTIResult(BaseModel):
    """Response containing previous task instance data."""

    # Defaults to None.
    task_instance: PreviousTIResponse | None = None
    # Discriminator value for the ``ToTask`` tagged union.
    type: Literal["PreviousTIResult"] = "PreviousTIResult"

591 

592 

class PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse):
    """PrevSuccessfulDagRunResponse data sent to the task process, tagged for the ``ToTask`` union."""

    type: Literal["PrevSuccessfulDagRunResult"] = "PrevSuccessfulDagRunResult"

    @classmethod
    def from_dagrun_response(cls, prev_dag_run: PrevSuccessfulDagRunResponse) -> PrevSuccessfulDagRunResult:
        """
        Build a result object from a response object.

        PrevSuccessfulDagRunResponse is autogenerated from the API schema, so it has to be converted to
        PrevSuccessfulDagRunResult before it can travel between the Supervisor and the task process.
        """
        non_default_fields = prev_dag_run.model_dump(exclude_defaults=True)
        return cls(**non_default_fields, type="PrevSuccessfulDagRunResult")

605 

606 

class TaskRescheduleStartDate(BaseModel):
    """Response containing the first reschedule date for a task instance."""

    # Required field (no default), but None is an accepted value.
    start_date: AwareDatetime | None
    # Discriminator value for the ``ToTask`` tagged union.
    type: Literal["TaskRescheduleStartDate"] = "TaskRescheduleStartDate"

612 

613 

class TICount(BaseModel):
    """Response containing count of Task Instances matching certain filters."""

    count: int
    # Discriminator value for the ``ToTask`` tagged union.
    type: Literal["TICount"] = "TICount"

619 

620 

class TaskStatesResult(TaskStatesResponse):
    """TaskStatesResponse data sent to the task process, tagged for the ``ToTask`` union."""

    type: Literal["TaskStatesResult"] = "TaskStatesResult"

    @classmethod
    def from_api_response(cls, task_states_response: TaskStatesResponse) -> TaskStatesResult:
        """
        Build the result class from an API response.

        The API response model is autogenerated from the API schema; it is converted here because
        communication between the Supervisor and the task process needs a discriminator field.
        """
        non_default_fields = task_states_response.model_dump(exclude_defaults=True)
        return cls(**non_default_fields, type="TaskStatesResult")

634 

635 

class TaskBreadcrumbsResult(TaskBreadcrumbsResponse):
    """TaskBreadcrumbsResponse data sent to the task process, tagged for the ``ToTask`` union."""

    type: Literal["TaskBreadcrumbsResult"] = "TaskBreadcrumbsResult"

    @classmethod
    def from_api_response(cls, response: TaskBreadcrumbsResponse) -> TaskBreadcrumbsResult:
        """
        Build the result class from an API response.

        The API response model is autogenerated from the API schema; it is converted here because
        communication between the Supervisor and the task process needs a discriminator field.
        """
        non_default_fields = response.model_dump(exclude_defaults=True)
        return cls(**non_default_fields, type="TaskBreadcrumbsResult")

649 

650 

class DRCount(BaseModel):
    """Response containing count of Dag Runs matching certain filters."""

    count: int
    # Discriminator value for the ``ToTask`` tagged union.
    type: Literal["DRCount"] = "DRCount"

656 

657 

class ErrorResponse(BaseModel):
    """Error payload; ``CommsDecoder._from_frame`` wraps the decoded error in AirflowRuntimeError."""

    # Falls back to the generic error type when none is given.
    error: ErrorType = ErrorType.GENERIC_ERROR
    detail: dict | None = None
    # Discriminator value for the ``ToTask`` tagged union.
    type: Literal["ErrorResponse"] = "ErrorResponse"

662 

663 

class OKResponse(BaseModel):
    """Message in the ``ToTask`` union carrying a single boolean ``ok`` flag."""

    ok: bool
    # Discriminator value for the ``ToTask`` tagged union.
    type: Literal["OKResponse"] = "OKResponse"

667 

668 

class SentFDs(BaseModel):
    """Response to ResendLoggingFD; ``CommsDecoder.send``/``asend`` overwrite ``fds`` on receipt."""

    # Discriminator value for the ``ToTask`` tagged union.
    type: Literal["SentFDs"] = "SentFDs"
    # Replaced in this process with the descriptors actually received alongside the frame.
    fds: list[int]

672 

673 

class CreateHITLDetailPayload(HITLDetailRequest):
    """Add the input request part of a Human-in-the-loop response."""

    # Message tag; this class is also listed in the ``ToTask`` tagged union.
    type: Literal["CreateHITLDetailPayload"] = "CreateHITLDetailPayload"

678 

679 

class HITLDetailRequestResult(HITLDetailRequest):
    """Response to CreateHITLDetailPayload request."""

    type: Literal["HITLDetailRequestResult"] = "HITLDetailRequestResult"

    @classmethod
    def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequestResult:
        """
        Build a HITLDetailRequestResult from a HITLDetailRequest (API response).

        HITLDetailRequest is the API response model. It is converted to HITLDetailRequestResult for
        communication between the Supervisor and the task process, which adds the discriminator field
        required for tagged union deserialization.
        """
        non_default_fields = hitl_request.model_dump(exclude_defaults=True)
        return cls(**non_default_fields, type="HITLDetailRequestResult")

695 

696 

# Tagged union of the message models in this list; pydantic picks the concrete model from the "type"
# field. ``CommsDecoder`` uses it as the default adapter for decoding response bodies.
ToTask = Annotated[
    AssetResult
    | AssetEventsResult
    | ConnectionResult
    | DagRunResult
    | DagRunStateResult
    | DRCount
    | ErrorResponse
    | PrevSuccessfulDagRunResult
    | PreviousTIResult
    | SentFDs
    | StartupDetails
    | TaskRescheduleStartDate
    | TICount
    | TaskBreadcrumbsResult
    | TaskStatesResult
    | VariableResult
    | XComCountResponse
    | XComResult
    | XComSequenceIndexResult
    | XComSequenceSliceResult
    | InactiveAssetsResult
    | CreateHITLDetailPayload
    | HITLDetailRequestResult
    | OKResponse
    | PreviousDagRunResult,
    Field(discriminator="type"),
]

725 

726 

class TaskState(BaseModel):
    """
    Update a task's state.

    If a process exits without sending one of these the state will be derived from the exit code:
    - 0 = SUCCESS
    - anything else = FAILED
    """

    # Only these three states are accepted by this message.
    state: Literal[
        TaskInstanceState.FAILED,
        TaskInstanceState.SKIPPED,
        TaskInstanceState.REMOVED,
    ]
    end_date: datetime | None = None
    # Literal message tag.
    type: Literal["TaskState"] = "TaskState"
    rendered_map_index: str | None = None

744 

745 

class SucceedTask(TISuccessStatePayload):
    """Update a task's state to success. Includes task_outlets and outlet_events for registering asset events."""

    # Literal message tag.
    type: Literal["SucceedTask"] = "SucceedTask"

750 

751 

class DeferTask(TIDeferredStatePayload):
    """Update a task instance state to deferred."""

    # Literal message tag.
    type: Literal["DeferTask"] = "DeferTask"

756 

757 

class RetryTask(TIRetryStatePayload):
    """Update a task instance state to up_for_retry."""

    # Literal message tag.
    type: Literal["RetryTask"] = "RetryTask"

762 

763 

class RescheduleTask(TIRescheduleStatePayload):
    """Update a task instance state to reschedule/up_for_reschedule."""

    # Literal message tag.
    type: Literal["RescheduleTask"] = "RescheduleTask"

768 

769 

class SkipDownstreamTasks(TISkippedDownstreamTasksStatePayload):
    """Update state of downstream tasks within a task instance to 'skipped', while updating current task to success state."""

    # Literal message tag.
    type: Literal["SkipDownstreamTasks"] = "SkipDownstreamTasks"

774 

775 

class GetXCom(BaseModel):
    """Message tagged ``GetXCom`` that identifies an XCom by key, Dag, run, task and optional map index."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    include_prior_dates: bool = False
    # Literal message tag.
    type: Literal["GetXCom"] = "GetXCom"

784 

785 

class GetXComCount(BaseModel):
    """Get the number of (mapped) XCom values available."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    # NOTE: the tag value is "GetNumberXComs", which differs from the class name.
    type: Literal["GetNumberXComs"] = "GetNumberXComs"

794 

795 

class GetXComSequenceItem(BaseModel):
    """Message tagged ``GetXComSequenceItem`` that identifies an XCom and an ``offset`` into it."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    offset: int
    # Literal message tag.
    type: Literal["GetXComSequenceItem"] = "GetXComSequenceItem"

803 

804 

class GetXComSequenceSlice(BaseModel):
    """Message tagged ``GetXComSequenceSlice`` that identifies an XCom and start/stop/step bounds."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    # The three bounds are required fields, but each accepts None.
    start: int | None
    stop: int | None
    step: int | None
    include_prior_dates: bool = False
    # Literal message tag.
    type: Literal["GetXComSequenceSlice"] = "GetXComSequenceSlice"

815 

816 

class SetXCom(BaseModel):
    """Message tagged ``SetXCom`` carrying a JSON ``value`` plus the key/Dag/run/task it belongs to."""

    key: str
    value: JsonValue
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    mapped_length: int | None = None
    # Literal message tag.
    type: Literal["SetXCom"] = "SetXCom"

826 

827 

class DeleteXCom(BaseModel):
    """Message tagged ``DeleteXCom`` that identifies an XCom by key, Dag, run, task and optional map index."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    # Literal message tag.
    type: Literal["DeleteXCom"] = "DeleteXCom"

835 

836 

class GetConnection(BaseModel):
    """Message tagged ``GetConnection`` carrying a ``conn_id``."""

    conn_id: str
    # Literal message tag.
    type: Literal["GetConnection"] = "GetConnection"

840 

841 

class GetVariable(BaseModel):
    """Message tagged ``GetVariable`` carrying a variable ``key``."""

    key: str
    # Literal message tag.
    type: Literal["GetVariable"] = "GetVariable"

845 

846 

class PutVariable(BaseModel):
    """Message tagged ``PutVariable`` carrying a variable ``key``, ``value`` and ``description``."""

    key: str
    # Both fields are required, but each accepts None.
    value: str | None
    description: str | None
    # Literal message tag.
    type: Literal["PutVariable"] = "PutVariable"

852 

853 

class DeleteVariable(BaseModel):
    """Message tagged ``DeleteVariable`` carrying a variable ``key``."""

    key: str
    # Literal message tag.
    type: Literal["DeleteVariable"] = "DeleteVariable"

857 

858 

class ResendLoggingFD(BaseModel):
    """Field-less request; ``CommsDecoder.send``/``asend`` read its reply together with a file descriptor."""

    # Literal message tag.
    type: Literal["ResendLoggingFD"] = "ResendLoggingFD"

861 

862 

class SetRenderedFields(BaseModel):
    """Payload for setting RTIF for a task instance."""

    # We are using a BaseModel here compared to server using RootModel because we
    # have a discriminator running with "type", and RootModel doesn't support type

    # Mapping of field name to JSON value.
    rendered_fields: dict[str, JsonValue]
    # Literal message tag (the discriminator mentioned above).
    type: Literal["SetRenderedFields"] = "SetRenderedFields"

871 

872 

class SetRenderedMapIndex(BaseModel):
    """Payload for setting rendered_map_index for a task instance."""

    rendered_map_index: str
    # Literal message tag.
    type: Literal["SetRenderedMapIndex"] = "SetRenderedMapIndex"

878 

879 

class TriggerDagRun(TriggerDAGRunPayload):
    """TriggerDAGRunPayload extended with the target ``dag_id`` and ``run_id`` and a message tag."""

    dag_id: str
    # The title only affects the generated schema metadata.
    run_id: Annotated[str, Field(title="Dag Run Id")]
    # Literal message tag.
    type: Literal["TriggerDagRun"] = "TriggerDagRun"

884 

885 

class GetDagRun(BaseModel):
    """Request for the Dag run identified by ``dag_id`` and ``run_id``."""

    dag_id: str
    run_id: str

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetDagRun"] = "GetDagRun"

890 

891 

class GetDagRunState(BaseModel):
    """Request for the state of the Dag run identified by ``dag_id`` and ``run_id``."""

    dag_id: str
    run_id: str

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetDagRunState"] = "GetDagRunState"

896 

897 

class GetPreviousDagRun(BaseModel):
    """Request for a previous run of ``dag_id`` relative to ``logical_date``."""

    dag_id: str

    # Must be timezone-aware (``AwareDatetime``).
    logical_date: AwareDatetime

    # Optional state filter; ``None`` by default.
    state: str | None = None

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetPreviousDagRun"] = "GetPreviousDagRun"

903 

904 

class GetPreviousTI(BaseModel):
    """Request for the previous task instance of a task."""

    dag_id: str
    task_id: str

    # Optional; when given it must be timezone-aware (``AwareDatetime``).
    logical_date: AwareDatetime | None = None

    # Defaults to -1.
    # NOTE(review): presumably -1 denotes an unmapped task; confirm.
    map_index: int = -1

    # Optional state filter; ``None`` by default.
    state: TaskInstanceState | None = None

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetPreviousTI"] = "GetPreviousTI"

914 

915 

class GetAssetByName(BaseModel):
    """Request for the asset with the given ``name``."""

    name: str

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetAssetByName"] = "GetAssetByName"

919 

920 

class GetAssetByUri(BaseModel):
    """Request for the asset with the given ``uri``."""

    uri: str

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetAssetByUri"] = "GetAssetByUri"

924 

925 

class GetAssetEventByAsset(BaseModel):
    """Request for the events of an asset selected by ``name`` and/or ``uri``."""

    # Neither field has a default, so both must be supplied by the sender,
    # but either may be passed explicitly as ``None``.
    name: str | None
    uri: str | None

    # Optional timezone-aware bounds on the events.
    after: AwareDatetime | None = None
    before: AwareDatetime | None = None

    # Optional cap on the number of events, and the sort direction
    # (ascending unless stated otherwise).
    limit: int | None = None
    ascending: bool = True

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetAssetEventByAsset"] = "GetAssetEventByAsset"

934 

935 

class GetAssetEventByAssetAlias(BaseModel):
    """Request for the asset events selected through the alias ``alias_name``."""

    alias_name: str

    # Optional timezone-aware bounds on the events.
    after: AwareDatetime | None = None
    before: AwareDatetime | None = None

    # Optional cap on the number of events, and the sort direction
    # (ascending unless stated otherwise).
    limit: int | None = None
    ascending: bool = True

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetAssetEventByAssetAlias"] = "GetAssetEventByAssetAlias"

943 

944 

class ValidateInletsAndOutlets(BaseModel):
    """Request to validate the inlets and outlets of the task instance ``ti_id``."""

    ti_id: UUID

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["ValidateInletsAndOutlets"] = "ValidateInletsAndOutlets"

948 

949 

class GetPrevSuccessfulDagRun(BaseModel):
    """Request for the previous successful Dag run, looked up via the task instance ``ti_id``."""

    ti_id: UUID

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"

953 

954 

class GetTaskRescheduleStartDate(BaseModel):
    """Request for the reschedule start date of the task instance ``ti_id``."""

    ti_id: UUID

    # Which try of the task instance is meant; defaults to 1.
    try_number: int = 1

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetTaskRescheduleStartDate"] = "GetTaskRescheduleStartDate"

959 

960 

class GetTICount(BaseModel):
    """Request for a count of task instances in ``dag_id``."""

    dag_id: str

    # Every field below is optional and defaults to ``None``.
    # NOTE(review): presumably each one narrows which task instances are
    # counted; confirm against the code that handles this message.
    map_index: int | None = None
    task_ids: list[str] | None = None
    task_group_id: str | None = None
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    states: list[str] | None = None

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetTICount"] = "GetTICount"

970 

971 

class GetTaskStates(BaseModel):
    """Request for the states of tasks in ``dag_id``."""

    dag_id: str

    # Every field below is optional and defaults to ``None``.
    # NOTE(review): presumably each one narrows which tasks are reported;
    # confirm against the code that handles this message.
    map_index: int | None = None
    task_ids: list[str] | None = None
    task_group_id: str | None = None
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetTaskStates"] = "GetTaskStates"

980 

981 

class GetTaskBreadcrumbs(BaseModel):
    """Request for the task breadcrumbs of the Dag run given by ``dag_id`` and ``run_id``."""

    dag_id: str
    run_id: str

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetTaskBreadcrumbs"] = "GetTaskBreadcrumbs"

986 

987 

class GetDRCount(BaseModel):
    """Request for a count of Dag runs of ``dag_id``."""

    dag_id: str

    # Every field below is optional and defaults to ``None``.
    # NOTE(review): presumably each one narrows which Dag runs are counted;
    # confirm against the code that handles this message.
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    states: list[str] | None = None

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetDRCount"] = "GetDRCount"

994 

995 

class GetHITLDetailResponse(BaseModel):
    """Request the response content of a Human-in-the-loop response for ``ti_id``."""

    ti_id: UUID

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["GetHITLDetailResponse"] = "GetHITLDetailResponse"

1001 

1002 

class UpdateHITLDetail(UpdateHITLDetailPayload):
    """Update the response content of an already existing Human-in-the-loop response."""

    # All payload fields are inherited from ``UpdateHITLDetailPayload``; this
    # subclass only adds the discriminator tag for the ``ToSupervisor`` union.
    type: Literal["UpdateHITLDetail"] = "UpdateHITLDetail"

1007 

1008 

class MaskSecret(BaseModel):
    """Register an additional value that must be redacted in task logs."""

    # Why this message exists: a `mask_secret` call made in the Task process
    # would otherwise register the mask value in that child process only,
    # whereas the redaction itself is carried out by the parent.
    #
    # Why ``JsonValue``: `str | Iterable | dict` would read more naturally, but
    # a Pydantic bug (https://github.com/pydantic/pydantic/issues/9541) turns an
    # iterable into a ValidatorIterator, so that annotation cannot be used here.
    value: JsonValue

    # Optional name accompanying the value; ``None`` by default.
    name: str | None = None

    # Discriminator tag for the ``ToSupervisor`` union.
    type: Literal["MaskSecret"] = "MaskSecret"

1019 

1020 

# Discriminated union of every message model accepted under this alias.
# Pydantic picks the concrete model by matching the incoming "type" value
# against each member's ``Literal`` tag, so every member must declare a unique
# ``type`` field. Member order is left untouched on purpose: it is observable
# through the generated schema and ``typing.get_args``.
ToSupervisor = Annotated[
    DeferTask
    | DeleteXCom
    | GetAssetByName
    | GetAssetByUri
    | GetAssetEventByAsset
    | GetAssetEventByAssetAlias
    | GetConnection
    | GetDagRun
    | GetDagRunState
    | GetDRCount
    | GetPrevSuccessfulDagRun
    | GetPreviousDagRun
    | GetPreviousTI
    | GetTaskRescheduleStartDate
    | GetTICount
    | GetTaskBreadcrumbs
    | GetTaskStates
    | GetVariable
    | GetXCom
    | GetXComCount
    | GetXComSequenceItem
    | GetXComSequenceSlice
    | PutVariable
    | RescheduleTask
    | RetryTask
    | SetRenderedFields
    | SetRenderedMapIndex
    | SetXCom
    | SkipDownstreamTasks
    | SucceedTask
    | ValidateInletsAndOutlets
    | TaskState
    | TriggerDagRun
    | DeleteVariable
    | ResendLoggingFD
    | CreateHITLDetailPayload
    | UpdateHITLDetail
    | GetHITLDetailResponse
    | MaskSecret,
    Field(discriminator="type"),
]