1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18r"""
19Communication protocol between the Supervisor and the task process
20==================================================================
21
* All communication is done over the subprocess's stdin in the form of a binary length-prefixed msgpack frame
23 (4 byte, big-endian length, followed by the msgpack-encoded _RequestFrame.) Each side uses this same
24 encoding
25* Log Messages from the subprocess are sent over the dedicated logs socket (which is line-based JSON)
26* No messages are sent to task process except in response to a request. (This is because the task process will
27 be running user's code, so we can't read from stdin until we enter our code, such as when requesting an XCom
28 value etc.)
29* Every request returns a response, even if the frame is otherwise empty.
30* Requests are written by the subprocess to fd0/stdin. This is making use of the fact that stdin is a
31 bi-directional socket, and thus we can write to it and don't need a dedicated extra socket for sending
32 requests.
33
34The reason this communication protocol exists, rather than the task process speaking directly to the Task
35Execution API server is because:
36
371. To reduce the number of concurrent HTTP connections on the API server.
38
39 The supervisor already has to speak to that to heartbeat the running Task, so having the task speak to its
40 parent process and having all API traffic go through that means that the number of HTTP connections is
41 "halved". (Not every task will make API calls, so it's not always halved, but it is reduced.)
42
432. This means that the user Task code doesn't ever directly see the task identity JWT token.
44
45 This is a short lived token tied to one specific task instance try, so it being leaked/exfiltrated is not a
   large risk, but it's easy to not give it to the user code, so let's do that.
47""" # noqa: D400, D205
48
49from __future__ import annotations
50
51import asyncio
52import itertools
53import threading
54from collections.abc import Iterator
55from datetime import datetime
56from functools import cached_property
57from pathlib import Path
58from socket import socket
59from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload
60from uuid import UUID
61
62import attrs
63import msgspec
64import structlog
65from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
66
67from airflow.sdk.api.datamodels._generated import (
68 AssetEventDagRunReference,
69 AssetEventResponse,
70 AssetEventsResponse,
71 AssetResponse,
72 BundleInfo,
73 ConnectionResponse,
74 DagRun,
75 DagRunStateResponse,
76 HITLDetailRequest,
77 InactiveAssetsResponse,
78 PreviousTIResponse,
79 PrevSuccessfulDagRunResponse,
80 TaskBreadcrumbsResponse,
81 TaskInstance,
82 TaskInstanceState,
83 TaskStatesResponse,
84 TIDeferredStatePayload,
85 TIRescheduleStatePayload,
86 TIRetryStatePayload,
87 TIRunContext,
88 TISkippedDownstreamTasksStatePayload,
89 TISuccessStatePayload,
90 TriggerDAGRunPayload,
91 UpdateHITLDetailPayload,
92 VariableResponse,
93 XComResponse,
94 XComSequenceIndexResponse,
95 XComSequenceSliceResponse,
96)
97from airflow.sdk.exceptions import ErrorType
98
99try:
100 from socket import recv_fds
101except ImportError:
    # Available on Unix and Windows (so "everywhere") but let's be safe
103 recv_fds = None # type: ignore[assignment]
104
105
106if TYPE_CHECKING:
107 from structlog.typing import FilteringBoundLogger as Logger
108
# Type of the pydantic models a CommsDecoder accepts in send()/asend().
SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
# Type of the pydantic models a CommsDecoder returns after decoding a response body.
ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)
111
112
def _msgpack_enc_hook(obj: Any) -> Any:
    """
    Convert objects msgspec cannot encode natively into something it can.

    Handles ``pendulum.DateTime`` (rebuilt as a plain ``datetime``), ``Path`` (converted to ``str``) and
    pydantic models (dumped with ``exclude_unset=True``).

    :param obj: The object msgspec handed to the hook.
    :return: An equivalent value made of types msgspec can encode.
    :raises NotImplementedError: If ``obj`` is none of the supported types.
    """
    import pendulum

    if isinstance(obj, pendulum.DateTime):
        # Rebuild the pendulum subclass as a plain stdlib datetime so that msgspec can use its native
        # datetime encoding.
        return datetime(
            year=obj.year,
            month=obj.month,
            day=obj.day,
            hour=obj.hour,
            minute=obj.minute,
            second=obj.second,
            microsecond=obj.microsecond,
            tzinfo=obj.tzinfo,
        )

    if isinstance(obj, Path):
        return str(obj)

    if isinstance(obj, BaseModel):
        return obj.model_dump(exclude_unset=True)

    # Anything else is unsupported.
    unsupported = type(obj)
    raise NotImplementedError(f"Objects of type {unsupported} are not supported")
129
130
def _new_encoder() -> msgspec.msgpack.Encoder:
    """Create a msgpack encoder that falls back to ``_msgpack_enc_hook`` for types it cannot encode."""
    encoder = msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook)
    return encoder
133
134
class _RequestFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True):
    """A single request on the wire: a request id plus the (dict-dumped) message body."""

    id: int
    """
    The request id, set by the sender.

    This is used to allow "pipelining" of requests and to be able to tie responses to requests, which is
    particularly useful in the Triggerer where multiple async tasks can send requests concurrently.
    """
    body: dict[str, Any] | None

    req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder()

    def as_bytes(self) -> bytearray:
        """
        Serialize this frame as a 4-byte big-endian length followed by the msgpack payload.

        :raises OverflowError: If the encoded payload does not fit in a 32-bit length prefix.
        """
        # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing for inspiration
        prefix_size = 4
        framed = bytearray(256)

        # Encode after the space reserved for the prefix; the length computed below relies on
        # encode_into leaving the buffer sized to exactly prefix + payload.
        self.req_encoder.encode_into(self, framed, prefix_size)

        n = len(framed) - prefix_size
        if n >= 2**32:
            raise OverflowError(f"Cannot send messages larger than 4GiB {n=}")

        framed[:prefix_size] = n.to_bytes(prefix_size, byteorder="big")
        return framed
159
160
class _ResponseFrame(_RequestFrame, frozen=True):
    """A response on the wire: the id of the originating request plus either a body or an error."""

    id: int
    """
    The id of the request this is a response to
    """
    # Response payload as a plain dict; None when the response carries no body.
    body: dict[str, Any] | None = None
    # Error payload; when this is not None, CommsDecoder._from_frame raises instead of returning a body.
    error: dict[str, Any] | None = None
168
169
@attrs.define()
class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
    """
    Handle communication between the task in this process and the supervisor parent process.

    Each request is written to ``socket`` as a 4-byte big-endian length followed by a msgpack-encoded
    ``_RequestFrame``. The reply is read back from the same socket as a ``_ResponseFrame`` and its body is
    validated with ``body_decoder``.
    """

    log: Logger = attrs.field(repr=False, factory=structlog.get_logger)
    # Defaults to a socket object wrapping file descriptor 0 (stdin).
    socket: socket = attrs.field(factory=lambda: socket(fileno=0))

    resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field(
        factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False
    )

    # Source of request ids; every send()/asend() call consumes the next value.
    id_counter: Iterator[int] = attrs.field(factory=itertools.count)

    # We could be "clever" here and set the default to this based type parameters and a custom
    # `__class_getitem__`, but that's a lot of code the one subclass we've got currently. So we'll just use a
    # "sort of wrong default"
    body_decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)

    err_decoder: TypeAdapter[ErrorResponse] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)

    # Threading lock for sync operations
    _thread_lock: threading.Lock = attrs.field(factory=threading.Lock, repr=False)
    # Async lock for async operations
    _async_lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False)

    def send(self, msg: SendMsgType) -> ReceiveMsgType | None:
        """
        Send a request to the parent and block until the response is received.

        :param msg: The request model; it is dumped to a dict and wrapped in a ``_RequestFrame``.
        :return: The decoded response body, or None if the response had no body.
        :raises AirflowRuntimeError: If the response frame carries an error.
        :raises EOFError: If the socket is closed before a full response was read.
        """
        frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
        frame_bytes = frame.as_bytes()

        # We must make sure sockets aren't intermixed between sync and async calls,
        # thus we need a dual locking mechanism to ensure that.
        with self._thread_lock:
            self.socket.sendall(frame_bytes)
            if isinstance(msg, ResendLoggingFD):
                if recv_fds is None:
                    # NOTE(review): the request has already been written at this point, so any reply
                    # the parent sends is left unread on the socket -- confirm this is intended.
                    return None
                # We need special handling here! The server can't send us the fd number, as the number on the
                # supervisor will be different to in this process, so we have to mutate the message ourselves here.
                frame, fds = self._read_frame(maxfds=1)
                resp = self._from_frame(frame)
                if TYPE_CHECKING:
                    assert isinstance(resp, SentFDs)
                resp.fds = fds
                # Since we know this is an explicit SentFDs, and since this class is generic SentFDs might not
                # always be in the return type union
                return resp  # type: ignore[return-value]

            return self._get_response()

    async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None:
        """
        Send a request to the parent without blocking.

        Uses async lock for coroutine safety and thread lock for socket safety.

        :param msg: The request model; it is dumped to a dict and wrapped in a ``_RequestFrame``.
        :return: The decoded response body, or None if the response had no body.
        :raises AirflowRuntimeError: If the response frame carries an error.
        :raises EOFError: If the socket is closed before a full response was read.
        """
        frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
        frame_bytes = frame.as_bytes()

        async with self._async_lock:
            # Acquire the threading lock without blocking the event loop
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._thread_lock.acquire)
            try:
                # Async write to socket
                await loop.sock_sendall(self.socket, frame_bytes)

                if isinstance(msg, ResendLoggingFD):
                    if recv_fds is None:
                        return None
                    # Blocking read in a thread
                    frame, fds = await asyncio.to_thread(self._read_frame, maxfds=1)
                    resp = self._from_frame(frame)
                    if TYPE_CHECKING:
                        assert isinstance(resp, SentFDs)
                    resp.fds = fds
                    return resp  # type: ignore[return-value]

                # Normal blocking read in a thread
                frame = await asyncio.to_thread(self._read_frame)
                return self._from_frame(frame)
            finally:
                self._thread_lock.release()

    @overload
    def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ...

    @overload
    def _read_frame(self, maxfds: int) -> tuple[_ResponseFrame, list[int]]: ...

    def _read_frame(self, maxfds: int | None = None) -> tuple[_ResponseFrame, list[int]] | _ResponseFrame:
        """
        Get a message from the parent.

        This will block until the message has been received.

        :param maxfds: When truthy, the length prefix is read with ``recv_fds`` so that up to this many
            file descriptors sent alongside it are collected, and a ``(frame, fds)`` tuple is returned.
        :return: The decoded response frame, or ``(frame, fds)`` when ``maxfds`` is truthy.
        :raises EOFError: If the socket is closed before the length prefix or the full body was read.
        """
        if self.socket:
            self.socket.setblocking(True)
        fds = None
        if maxfds:
            len_bytes, fds, _flags, _address = recv_fds(self.socket, 4, maxfds)
        else:
            len_bytes = self.socket.recv(4)

        if len_bytes == b"":
            raise EOFError("Request socket closed before length")

        # recv() on a stream socket may return fewer bytes than requested. Keep reading until the whole
        # 4-byte prefix has arrived, otherwise a short read would be parsed as a (wrong) frame length and
        # every later read on this socket would be out of step.
        while len(len_bytes) < 4:
            chunk = self.socket.recv(4 - len(len_bytes))
            if chunk == b"":
                raise EOFError("Request socket closed before length")
            len_bytes += chunk

        length = int.from_bytes(len_bytes, byteorder="big")

        buffer = bytearray(length)
        mv = memoryview(buffer)

        pos = 0
        while pos < length:
            nread = self.socket.recv_into(mv[pos:])
            if nread == 0:
                raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})")
            pos += nread

        resp = self.resp_decoder.decode(mv)
        if maxfds:
            return resp, fds or []
        return resp

    def _from_frame(self, frame: _ResponseFrame) -> ReceiveMsgType | None:
        """
        Turn a response frame into a validated message.

        :param frame: The frame returned by ``_read_frame``.
        :return: The body validated by ``body_decoder``, or None if the frame has no body.
        :raises AirflowRuntimeError: If the frame carries an error payload.
        """
        from airflow.sdk.exceptions import AirflowRuntimeError

        if frame.error is not None:
            err = self.err_decoder.validate_python(frame.error)
            raise AirflowRuntimeError(error=err)

        if frame.body is None:
            return None

        try:
            return self.body_decoder.validate_python(frame.body)
        except Exception:
            # Log with traceback for diagnosis, then let the caller see the original error.
            self.log.exception("Unable to decode message")
            raise

    def _get_response(self) -> ReceiveMsgType | None:
        """Read the next frame from the socket and decode it."""
        frame = self._read_frame()
        return self._from_frame(frame)
313
314
class StartupDetails(BaseModel):
    """Startup payload (task instance, bundle info, run context, ...); a member of the ``ToTask`` union."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    ti: TaskInstance
    dag_rel_path: str
    bundle_info: BundleInfo
    start_date: datetime
    ti_context: TIRunContext
    sentry_integration: str
    # Discriminator value used to select this model within the tagged union.
    type: Literal["StartupDetails"] = "StartupDetails"
325
326
class AssetResult(AssetResponse):
    """Asset details tagged with a discriminator; a member of the ``ToTask`` union."""

    type: Literal["AssetResult"] = "AssetResult"

    @classmethod
    def from_asset_response(cls, asset_response: AssetResponse) -> AssetResult:
        """
        Build an AssetResult from an AssetResponse.

        AssetResponse is autogenerated from the API schema, so it has to be converted to AssetResult
        before it can be used for communication between the Supervisor and the task process.
        """
        # Drop fields that still hold their defaults so no unnecessary data is sent. ``type`` is passed
        # explicitly so that it counts as "set" and survives a later dump with exclude_unset=True.
        fields = asset_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="AssetResult")
344
345
@attrs.define(kw_only=True)
class AssetEventSourceTaskInstance:
    """Task instance that produced an asset event; built by the ``source_task_instance`` properties."""

    dag_run: DagRun
    task_id: str
    map_index: int

    @property
    def dag_id(self) -> str:
        """Dag id, taken from the associated dag run."""
        run = self.dag_run
        return run.dag_id

    @property
    def run_id(self) -> str:
        """Run id, taken from the associated dag run."""
        run = self.dag_run
        return run.run_id

    def xcom_pull(self, *, key: str = "return_value", default: Any = None) -> Any:
        """
        Fetch an XCom value for this task instance.

        :param key: The XCom key to look up.
        :param default: Value returned when the lookup yields None.
        """
        from airflow.sdk.execution_time.xcom import XCom

        value = XCom.get_value(ti_key=self, key=key)
        return default if value is None else value
368
369
def _fetch_dag_run(*, dag_id: str, run_id: str) -> DagRun:
    """Send a ``GetDagRun`` request over ``SUPERVISOR_COMMS`` and return whatever comes back."""
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    request = GetDagRun(dag_id=dag_id, run_id=run_id)
    dag_run = SUPERVISOR_COMMS.send(request)
    if TYPE_CHECKING:
        assert isinstance(dag_run, DagRunResult)
    return dag_run
377
378
class AssetEventResult(AssetEventResponse):
    """Used in AssetEventsResult."""

    @classmethod
    def from_asset_event_response(cls, asset_event_response: AssetEventResponse) -> AssetEventResult:
        """Build an AssetEventResult from the non-default fields of an AssetEventResponse."""
        fields = asset_event_response.model_dump(exclude_defaults=True)
        return cls(**fields)

    @cached_property
    def source_dag_run(self) -> DagRun | None:
        """Dag run that produced this event, fetched on first access; None without a source dag/run id."""
        if self.source_dag_id and self.source_run_id:
            return _fetch_dag_run(dag_id=self.source_dag_id, run_id=self.source_run_id)
        return None

    @cached_property
    def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
        """Task instance that produced this event, or None when any of the source details is missing."""
        task_id = self.source_task_id
        map_index = self.source_map_index
        if task_id is None or map_index is None:
            return None

        # Only look the dag run up once we know the task details are present.
        dag_run = self.source_dag_run
        if dag_run is None:
            return None

        return AssetEventSourceTaskInstance(dag_run=dag_run, task_id=task_id, map_index=map_index)
403
404
class AssetEventsResult(AssetEventsResponse):
    """Collection of asset events tagged with a discriminator; a member of the ``ToTask`` union."""

    type: Literal["AssetEventsResult"] = "AssetEventsResult"

    @classmethod
    def from_asset_events_response(cls, asset_events_response: AssetEventsResponse) -> AssetEventsResult:
        """
        Build an AssetEventsResult from an AssetEventsResponse.

        AssetEventsResponse is autogenerated from the API schema, so it has to be converted to
        AssetEventsResult before it can be used for communication between the Supervisor and the task
        process.
        """
        # Drop fields that still hold their defaults so no unnecessary data is sent. ``type`` is passed
        # explicitly so that it counts as "set" and survives a later dump with exclude_unset=True.
        fields = asset_events_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="AssetEventsResult")

    def iter_asset_event_results(self) -> Iterator[AssetEventResult]:
        """Return an iterator that converts each contained event to an AssetEventResult."""
        converted = (AssetEventResult.from_asset_event_response(item) for item in self.asset_events)
        return converted
428
429
class AssetEventDagRunReferenceResult(AssetEventDagRunReference):
    """AssetEventDagRunReference extended with lazy lookups of the source dag run and task instance."""

    @classmethod
    def from_asset_event_dag_run_reference(
        cls,
        asset_event_dag_run_reference: AssetEventDagRunReference,
    ) -> AssetEventDagRunReferenceResult:
        """Build a result object from the non-default fields of an AssetEventDagRunReference."""
        fields = asset_event_dag_run_reference.model_dump(exclude_defaults=True)
        return cls(**fields)

    @cached_property
    def source_dag_run(self) -> DagRun | None:
        """Dag run that produced the event, fetched on first access; None without a source dag/run id."""
        if self.source_dag_id and self.source_run_id:
            return _fetch_dag_run(dag_id=self.source_dag_id, run_id=self.source_run_id)
        return None

    @cached_property
    def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
        """Task instance that produced the event, or None when any of the source details is missing."""
        task_id = self.source_task_id
        map_index = self.source_map_index
        if task_id is None or map_index is None:
            return None

        # Only look the dag run up once we know the task details are present.
        dag_run = self.source_dag_run
        if dag_run is None:
            return None

        return AssetEventSourceTaskInstance(dag_run=dag_run, task_id=task_id, map_index=map_index)
455
456
class InactiveAssetsResult(InactiveAssetsResponse):
    """Response of InactiveAssets requests."""

    type: Literal["InactiveAssetsResult"] = "InactiveAssetsResult"

    @classmethod
    def from_inactive_assets_response(
        cls, inactive_assets_response: InactiveAssetsResponse
    ) -> InactiveAssetsResult:
        """
        Build an InactiveAssetsResult from an InactiveAssetsResponse.

        InactiveAssetsResponse is autogenerated from the API schema, so it has to be converted to
        InactiveAssetsResult before it can be used for communication between the Supervisor and the task
        process.
        """
        fields = inactive_assets_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="InactiveAssetsResult")
473
474
class XComResult(XComResponse):
    """XCom payload tagged with a discriminator; a member of the ``ToTask`` union."""

    type: Literal["XComResult"] = "XComResult"

    @classmethod
    def from_xcom_response(cls, xcom_response: XComResponse) -> XComResult:
        """
        Build an XComResult from an XComResponse.

        XComResponse is autogenerated from the API schema, so it has to be converted to XComResult
        before it can be used for communication between the Supervisor and the task process.
        """
        fields = xcom_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="XComResult")
489
490
class XComCountResponse(BaseModel):
    """Response carrying an XCom count in ``len``; a member of the ``ToTask`` union."""

    len: int
    # NOTE: the discriminator value differs from the class name. It is what goes over the wire, so it
    # must not be changed on one side only.
    type: Literal["XComLengthResponse"] = "XComLengthResponse"
494
495
class XComSequenceIndexResult(BaseModel):
    """A single JSON value wrapped with a discriminator; a member of the ``ToTask`` union."""

    root: JsonValue
    type: Literal["XComSequenceIndexResult"] = "XComSequenceIndexResult"

    @classmethod
    def from_response(cls, response: XComSequenceIndexResponse) -> XComSequenceIndexResult:
        """Wrap the ``root`` value of an XComSequenceIndexResponse in a tagged result."""
        value = response.root
        return cls(root=value, type="XComSequenceIndexResult")
503
504
class XComSequenceSliceResult(BaseModel):
    """A list of JSON values wrapped with a discriminator; a member of the ``ToTask`` union."""

    root: list[JsonValue]
    type: Literal["XComSequenceSliceResult"] = "XComSequenceSliceResult"

    @classmethod
    def from_response(cls, response: XComSequenceSliceResponse) -> XComSequenceSliceResult:
        """Wrap the ``root`` list of an XComSequenceSliceResponse in a tagged result."""
        values = response.root
        return cls(root=values, type="XComSequenceSliceResult")
512
513
class ConnectionResult(ConnectionResponse):
    """Connection details tagged with a discriminator; a member of the ``ToTask`` union."""

    type: Literal["ConnectionResult"] = "ConnectionResult"

    @classmethod
    def from_conn_response(cls, connection_response: ConnectionResponse) -> ConnectionResult:
        """
        Build a ConnectionResult from a ConnectionResponse.

        ConnectionResponse is autogenerated from the API schema, so it has to be converted to
        ConnectionResult before it can be used for communication between the Supervisor and the task
        process.
        """
        # Drop fields that still hold their defaults so no unnecessary data is sent, and dump by alias.
        # ``type`` is passed explicitly so that it counts as "set" and survives a later dump with
        # exclude_unset=True.
        fields = connection_response.model_dump(exclude_defaults=True, by_alias=True)
        return cls(**fields, type="ConnectionResult")
531
532
class VariableResult(VariableResponse):
    """Variable value tagged with a discriminator; a member of the ``ToTask`` union."""

    type: Literal["VariableResult"] = "VariableResult"

    @classmethod
    def from_variable_response(cls, variable_response: VariableResponse) -> VariableResult:
        """
        Build a VariableResult from a VariableResponse.

        VariableResponse is autogenerated from the API schema, so it has to be converted to
        VariableResult before it can be used for communication between the Supervisor and the task
        process.
        """
        fields = variable_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="VariableResult")
545
546
class DagRunResult(DagRun):
    """Dag run details tagged with a discriminator; a member of the ``ToTask`` union."""

    type: Literal["DagRunResult"] = "DagRunResult"

    @classmethod
    def from_api_response(cls, dr_response: DagRun) -> DagRunResult:
        """
        Build the result class from the API response.

        The API response model is autogenerated from the API schema; the result class adds the
        discriminator field needed for communication between the Supervisor and the task process.
        """
        fields = dr_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="DagRunResult")
560
561
class DagRunStateResult(DagRunStateResponse):
    """Dag run state tagged with a discriminator; a member of the ``ToTask`` union."""

    type: Literal["DagRunStateResult"] = "DagRunStateResult"

    # TODO: Create a convert api_response to result classes so we don't need to do this
    # for all the classes above
    @classmethod
    def from_api_response(cls, dr_state_response: DagRunStateResponse) -> DagRunStateResult:
        """
        Build the result class from the API response.

        The API response model is autogenerated from the API schema; the result class adds the
        discriminator field needed for communication between the Supervisor and the task process.
        """
        fields = dr_state_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="DagRunStateResult")
577
578
class PreviousDagRunResult(BaseModel):
    """Response containing previous Dag run information."""

    # Defaults to None, i.e. no dag run attached.
    dag_run: DagRun | None = None
    # Discriminator value used to select this model within the ToTask tagged union.
    type: Literal["PreviousDagRunResult"] = "PreviousDagRunResult"
584
585
class PreviousTIResult(BaseModel):
    """Response containing previous task instance data."""

    # Defaults to None, i.e. no task instance attached.
    task_instance: PreviousTIResponse | None = None
    # Discriminator value used to select this model within the ToTask tagged union.
    type: Literal["PreviousTIResult"] = "PreviousTIResult"
591
592
class PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse):
    """PrevSuccessfulDagRunResponse tagged with a discriminator; a member of the ``ToTask`` union."""

    type: Literal["PrevSuccessfulDagRunResult"] = "PrevSuccessfulDagRunResult"

    @classmethod
    def from_dagrun_response(cls, prev_dag_run: PrevSuccessfulDagRunResponse) -> PrevSuccessfulDagRunResult:
        """
        Build a result object from the response object.

        PrevSuccessfulDagRunResponse is autogenerated from the API schema, so it has to be converted to
        PrevSuccessfulDagRunResult before it can be used for communication between the Supervisor and
        the task process.
        """
        fields = prev_dag_run.model_dump(exclude_defaults=True)
        return cls(**fields, type="PrevSuccessfulDagRunResult")
605
606
class TaskRescheduleStartDate(BaseModel):
    """Response containing the first reschedule date for a task instance."""

    # Required field (no default), but None is an accepted value.
    start_date: AwareDatetime | None
    # Discriminator value used to select this model within the ToTask tagged union.
    type: Literal["TaskRescheduleStartDate"] = "TaskRescheduleStartDate"
612
613
class TICount(BaseModel):
    """Response containing count of Task Instances matching certain filters."""

    count: int
    # Discriminator value used to select this model within the ToTask tagged union.
    type: Literal["TICount"] = "TICount"
619
620
class TaskStatesResult(TaskStatesResponse):
    """Task states tagged with a discriminator; a member of the ``ToTask`` union."""

    type: Literal["TaskStatesResult"] = "TaskStatesResult"

    @classmethod
    def from_api_response(cls, task_states_response: TaskStatesResponse) -> TaskStatesResult:
        """
        Build the result class from the API response.

        The API response model is autogenerated from the API schema; the result class adds the
        discriminator field needed for communication between the Supervisor and the task process.
        """
        fields = task_states_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="TaskStatesResult")
634
635
class TaskBreadcrumbsResult(TaskBreadcrumbsResponse):
    """Task breadcrumbs tagged with a discriminator; a member of the ``ToTask`` union."""

    type: Literal["TaskBreadcrumbsResult"] = "TaskBreadcrumbsResult"

    @classmethod
    def from_api_response(cls, response: TaskBreadcrumbsResponse) -> TaskBreadcrumbsResult:
        """
        Build the result class from the API response.

        The API response model is autogenerated from the API schema; the result class adds the
        discriminator field needed for communication between the Supervisor and the task process.
        """
        fields = response.model_dump(exclude_defaults=True)
        return cls(**fields, type="TaskBreadcrumbsResult")
649
650
class DRCount(BaseModel):
    """Response containing count of Dag Runs matching certain filters."""

    count: int
    # Discriminator value used to select this model within the ToTask tagged union.
    type: Literal["DRCount"] = "DRCount"
656
657
class ErrorResponse(BaseModel):
    """Error description; a member of the ``ToTask`` union and the declared type of the error decoder."""

    # Falls back to the generic error type when none is given.
    error: ErrorType = ErrorType.GENERIC_ERROR
    # Optional free-form details about the error.
    detail: dict | None = None
    type: Literal["ErrorResponse"] = "ErrorResponse"
662
663
class OKResponse(BaseModel):
    """Response carrying a single boolean ``ok`` flag; a member of the ``ToTask`` union."""

    ok: bool
    type: Literal["OKResponse"] = "OKResponse"
667
668
class SentFDs(BaseModel):
    """Response to ``ResendLoggingFD``; a member of the ``ToTask`` union."""

    type: Literal["SentFDs"] = "SentFDs"
    # CommsDecoder.send/asend overwrite this with the descriptor numbers actually received in this
    # process, because the numbers on the sending side differ from the ones here.
    fds: list[int]
672
673
class CreateHITLDetailPayload(HITLDetailRequest):
    """Add the input request part of a Human-in-the-loop response."""

    # This model is listed in both the ToTask and the ToSupervisor unions.
    type: Literal["CreateHITLDetailPayload"] = "CreateHITLDetailPayload"
678
679
class HITLDetailRequestResult(HITLDetailRequest):
    """Response to CreateHITLDetailPayload request."""

    type: Literal["HITLDetailRequestResult"] = "HITLDetailRequestResult"

    @classmethod
    def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequestResult:
        """
        Build a HITLDetailRequestResult from a HITLDetailRequest (the API response).

        HITLDetailRequest is the API response model. It is converted to HITLDetailRequestResult for
        communication between the Supervisor and the task process, which adds the discriminator field
        required for tagged-union deserialization.
        """
        fields = hitl_request.model_dump(exclude_defaults=True)
        return cls(**fields, type="HITLDetailRequestResult")
695
696
# Tagged union of every message model that CommsDecoder decodes by default (it is the default
# TypeAdapter target for both response bodies and errors). Pydantic selects the concrete model from
# the value of the ``type`` field, so each member must declare a distinct ``type`` literal.
ToTask = Annotated[
    AssetResult
    | AssetEventsResult
    | ConnectionResult
    | DagRunResult
    | DagRunStateResult
    | DRCount
    | ErrorResponse
    | PrevSuccessfulDagRunResult
    | PreviousTIResult
    | SentFDs
    | StartupDetails
    | TaskRescheduleStartDate
    | TICount
    | TaskBreadcrumbsResult
    | TaskStatesResult
    | VariableResult
    | XComCountResponse
    | XComResult
    | XComSequenceIndexResult
    | XComSequenceSliceResult
    | InactiveAssetsResult
    | CreateHITLDetailPayload
    | HITLDetailRequestResult
    | OKResponse
    | PreviousDagRunResult,
    Field(discriminator="type"),
]
725
726
class TaskState(BaseModel):
    """
    Update a task's state.

    If a process exits without sending one of these the state will be derived from the exit code:
    - 0 = SUCCESS
    - anything else = FAILED
    """

    # Only these three states are accepted by this message; other states have their own message models
    # (e.g. SucceedTask, DeferTask, RetryTask, RescheduleTask).
    state: Literal[
        TaskInstanceState.FAILED,
        TaskInstanceState.SKIPPED,
        TaskInstanceState.REMOVED,
    ]
    end_date: datetime | None = None
    type: Literal["TaskState"] = "TaskState"
    rendered_map_index: str | None = None
744
745
class SucceedTask(TISuccessStatePayload):
    """Update a task's state to success. Includes task_outlets and outlet_events for registering asset events."""

    # Discriminator value used to select this model within the ToSupervisor tagged union.
    type: Literal["SucceedTask"] = "SucceedTask"
750
751
class DeferTask(TIDeferredStatePayload):
    """Update a task instance state to deferred."""

    # Discriminator value used to select this model within the ToSupervisor tagged union.
    type: Literal["DeferTask"] = "DeferTask"
756
757
class RetryTask(TIRetryStatePayload):
    """Update a task instance state to up_for_retry."""

    # Discriminator value used to select this model within the ToSupervisor tagged union.
    type: Literal["RetryTask"] = "RetryTask"
762
763
class RescheduleTask(TIRescheduleStatePayload):
    """Update a task instance state to reschedule/up_for_reschedule."""

    # Discriminator value used to select this model within the ToSupervisor tagged union.
    type: Literal["RescheduleTask"] = "RescheduleTask"
768
769
class SkipDownstreamTasks(TISkippedDownstreamTasksStatePayload):
    """Update state of downstream tasks within a task instance to 'skipped', while updating current task to success state."""

    # Discriminator value used to select this model within the ToSupervisor tagged union.
    type: Literal["SkipDownstreamTasks"] = "SkipDownstreamTasks"
774
775
class GetXCom(BaseModel):
    """Request for an XCom value identified by key, dag, run and task; a ``ToSupervisor`` member."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    # Optional; defaults to None.
    map_index: int | None = None
    include_prior_dates: bool = False
    type: Literal["GetXCom"] = "GetXCom"
784
785
class GetXComCount(BaseModel):
    """Get the number of (mapped) XCom values available."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    # NOTE: the discriminator value differs from the class name. It is what goes over the wire, so it
    # must not be changed on one side only.
    type: Literal["GetNumberXComs"] = "GetNumberXComs"
794
795
class GetXComSequenceItem(BaseModel):
    """Request for one item, selected by ``offset``, of an XCom; a ``ToSupervisor`` member."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    # Required position of the wanted item (presumably an index into a sequence XCom -- confirm).
    offset: int
    type: Literal["GetXComSequenceItem"] = "GetXComSequenceItem"
803
804
class GetXComSequenceSlice(BaseModel):
    """Request for a start/stop/step slice of an XCom; a ``ToSupervisor`` member."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    # The three slice bounds are required fields (no default) but each accepts None.
    start: int | None
    stop: int | None
    step: int | None
    include_prior_dates: bool = False
    type: Literal["GetXComSequenceSlice"] = "GetXComSequenceSlice"
815
816
class SetXCom(BaseModel):
    """Request to store a JSON value as an XCom; a ``ToSupervisor`` member."""

    key: str
    # The value must be JSON-compatible.
    value: JsonValue
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    mapped_length: int | None = None
    type: Literal["SetXCom"] = "SetXCom"
826
827
class DeleteXCom(BaseModel):
    """Request to delete an XCom identified by key, dag, run and task; a ``ToSupervisor`` member."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    type: Literal["DeleteXCom"] = "DeleteXCom"
835
836
class GetConnection(BaseModel):
    """Request for the connection with the given ``conn_id``; a ``ToSupervisor`` member."""

    conn_id: str
    type: Literal["GetConnection"] = "GetConnection"
840
841
class GetVariable(BaseModel):
    """Request for the variable with the given ``key``; a ``ToSupervisor`` member."""

    key: str
    type: Literal["GetVariable"] = "GetVariable"
845
846
class PutVariable(BaseModel):
    """Request to store a variable; a ``ToSupervisor`` member."""

    key: str
    # Both fields are required (no default) but accept None.
    value: str | None
    description: str | None
    type: Literal["PutVariable"] = "PutVariable"
852
853
class DeleteVariable(BaseModel):
    """Request to delete the variable with the given ``key``; a ``ToSupervisor`` member."""

    key: str
    type: Literal["DeleteVariable"] = "DeleteVariable"
857
858
class ResendLoggingFD(BaseModel):
    """
    Request that makes the parent send file descriptors back; a ``ToSupervisor`` member.

    CommsDecoder.send/asend special-case this message: the reply is read with ``recv_fds`` and the
    received descriptor numbers are stored on the resulting ``SentFDs``.
    """

    type: Literal["ResendLoggingFD"] = "ResendLoggingFD"
861
862
class SetRenderedFields(BaseModel):
    """Payload for setting RTIF for a task instance."""

    # We are using a BaseModel here compared to server using RootModel because we
    # have a discriminator running with "type", and RootModel doesn't support type

    # Mapping of field name to its JSON-compatible rendered value.
    rendered_fields: dict[str, JsonValue]
    type: Literal["SetRenderedFields"] = "SetRenderedFields"
871
872
class SetRenderedMapIndex(BaseModel):
    """Payload for setting rendered_map_index for a task instance."""

    rendered_map_index: str
    # Discriminator value used to select this model within the ToSupervisor tagged union.
    type: Literal["SetRenderedMapIndex"] = "SetRenderedMapIndex"
878
879
class TriggerDagRun(TriggerDAGRunPayload):
    """TriggerDAGRunPayload plus the target ``dag_id`` and ``run_id``; a ``ToSupervisor`` member."""

    dag_id: str
    # The Field title only affects the generated schema, not validation.
    run_id: Annotated[str, Field(title="Dag Run Id")]
    type: Literal["TriggerDagRun"] = "TriggerDagRun"
884
885
class GetDagRun(BaseModel):
    """Request for the dag run identified by ``dag_id`` and ``run_id``; a ``ToSupervisor`` member."""

    dag_id: str
    run_id: str
    type: Literal["GetDagRun"] = "GetDagRun"
890
891
class GetDagRunState(BaseModel):
    """Request for the state of the dag run identified by ``dag_id`` and ``run_id``; a ``ToSupervisor`` member."""

    dag_id: str
    run_id: str
    type: Literal["GetDagRunState"] = "GetDagRunState"
896
897
class GetPreviousDagRun(BaseModel):
    """Request for a previous dag run of ``dag_id`` relative to ``logical_date``; a ``ToSupervisor`` member."""

    dag_id: str
    # Must be timezone-aware.
    logical_date: AwareDatetime
    # Optional state filter; defaults to None.
    state: str | None = None
    type: Literal["GetPreviousDagRun"] = "GetPreviousDagRun"
903
904
class GetPreviousTI(BaseModel):
    """Request to get previous task instance."""

    dag_id: str
    task_id: str
    # Must be timezone-aware when given.
    logical_date: AwareDatetime | None = None
    # Defaults to -1 (presumably the "not mapped" marker -- confirm against the API).
    map_index: int = -1
    state: TaskInstanceState | None = None
    type: Literal["GetPreviousTI"] = "GetPreviousTI"
914
915
class GetAssetByName(BaseModel):
    """Request for the asset with the given ``name``; a ``ToSupervisor`` member."""

    name: str
    type: Literal["GetAssetByName"] = "GetAssetByName"
919
920
class GetAssetByUri(BaseModel):
    """Request for the asset with the given ``uri``; a ``ToSupervisor`` member."""

    uri: str
    type: Literal["GetAssetByUri"] = "GetAssetByUri"
924
925
class GetAssetEventByAsset(BaseModel):
    """Request for the events of an asset given by name and/or uri; a ``ToSupervisor`` member."""

    # Both identifiers are required fields (no default) but each accepts None.
    name: str | None
    uri: str | None
    # Optional timezone-aware bounds, result limit and sort direction.
    after: AwareDatetime | None = None
    before: AwareDatetime | None = None
    limit: int | None = None
    ascending: bool = True
    type: Literal["GetAssetEventByAsset"] = "GetAssetEventByAsset"
934
935
class GetAssetEventByAssetAlias(BaseModel):
    """Request for the asset events of the asset alias ``alias_name``; a ``ToSupervisor`` member."""

    alias_name: str
    # Optional timezone-aware bounds, result limit and sort direction.
    after: AwareDatetime | None = None
    before: AwareDatetime | None = None
    limit: int | None = None
    ascending: bool = True
    type: Literal["GetAssetEventByAssetAlias"] = "GetAssetEventByAssetAlias"
943
944
class ValidateInletsAndOutlets(BaseModel):
    """Request to validate the inlets and outlets of task instance ``ti_id``; a ``ToSupervisor`` member."""

    ti_id: UUID
    type: Literal["ValidateInletsAndOutlets"] = "ValidateInletsAndOutlets"
948
949
class GetPrevSuccessfulDagRun(BaseModel):
    """Request for the previous successful dag run for task instance ``ti_id``; a ``ToSupervisor`` member."""

    ti_id: UUID
    type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"
953
954
class GetTaskRescheduleStartDate(BaseModel):
    """Request for the reschedule start date of task instance ``ti_id``; a ``ToSupervisor`` member."""

    ti_id: UUID
    # Defaults to the first try.
    try_number: int = 1
    type: Literal["GetTaskRescheduleStartDate"] = "GetTaskRescheduleStartDate"
959
960
class GetTICount(BaseModel):
    """Request for a count of task instances of ``dag_id``; a ``ToSupervisor`` member."""

    dag_id: str
    # All remaining filters are optional and default to None.
    map_index: int | None = None
    task_ids: list[str] | None = None
    task_group_id: str | None = None
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    states: list[str] | None = None
    type: Literal["GetTICount"] = "GetTICount"
970
971
class GetTaskStates(BaseModel):
    """Request for the states of task instances of ``dag_id``; a ``ToSupervisor`` member."""

    dag_id: str
    # All remaining filters are optional and default to None.
    map_index: int | None = None
    task_ids: list[str] | None = None
    task_group_id: str | None = None
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    type: Literal["GetTaskStates"] = "GetTaskStates"
980
981
class GetTaskBreadcrumbs(BaseModel):
    """Request for the task breadcrumbs of the given dag run; a ``ToSupervisor`` member."""

    dag_id: str
    run_id: str
    type: Literal["GetTaskBreadcrumbs"] = "GetTaskBreadcrumbs"
986
987
class GetDRCount(BaseModel):
    """Request for a count of dag runs of ``dag_id``; a ``ToSupervisor`` member."""

    dag_id: str
    # All remaining filters are optional and default to None.
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    states: list[str] | None = None
    type: Literal["GetDRCount"] = "GetDRCount"
994
995
class GetHITLDetailResponse(BaseModel):
    """Get the response content part of a Human-in-the-loop response."""

    ti_id: UUID
    # Discriminator value used to select this model within the ToSupervisor tagged union.
    type: Literal["GetHITLDetailResponse"] = "GetHITLDetailResponse"
1001
1002
class UpdateHITLDetail(UpdateHITLDetailPayload):
    """Update the response content part of an existing Human-in-the-loop response."""

    # Discriminator value used to select this model within the ToSupervisor tagged union.
    type: Literal["UpdateHITLDetail"] = "UpdateHITLDetail"
1007
1008
class MaskSecret(BaseModel):
    """Add a new value to be redacted in task logs."""

    # This is needed since calls to `mask_secret` in the Task process will otherwise only add the mask value
    # to the child process, but the redaction happens in the parent.
    # We cannot use `string | Iterable | dict` here (would be more intuitive) because a bug in Pydantic
    # https://github.com/pydantic/pydantic/issues/9541 turns an iterable into a ValidatorIterator
    value: JsonValue
    # Optional name associated with the value; defaults to None.
    name: str | None = None
    type: Literal["MaskSecret"] = "MaskSecret"
1019
1020
# Tagged union of the message models listed as going to the supervisor. Pydantic selects the
# concrete model from the value of the ``type`` field, so each member must declare a distinct ``type``
# literal.
ToSupervisor = Annotated[
    DeferTask
    | DeleteXCom
    | GetAssetByName
    | GetAssetByUri
    | GetAssetEventByAsset
    | GetAssetEventByAssetAlias
    | GetConnection
    | GetDagRun
    | GetDagRunState
    | GetDRCount
    | GetPrevSuccessfulDagRun
    | GetPreviousDagRun
    | GetPreviousTI
    | GetTaskRescheduleStartDate
    | GetTICount
    | GetTaskBreadcrumbs
    | GetTaskStates
    | GetVariable
    | GetXCom
    | GetXComCount
    | GetXComSequenceItem
    | GetXComSequenceSlice
    | PutVariable
    | RescheduleTask
    | RetryTask
    | SetRenderedFields
    | SetRenderedMapIndex
    | SetXCom
    | SkipDownstreamTasks
    | SucceedTask
    | ValidateInletsAndOutlets
    | TaskState
    | TriggerDagRun
    | DeleteVariable
    | ResendLoggingFD
    | CreateHITLDetailPayload
    | UpdateHITLDetail
    | GetHITLDetailResponse
    | MaskSecret,
    Field(discriminator="type"),
]