1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18r"""
19Communication protocol between the Supervisor and the task process
20==================================================================
21
* All communication is done over the subprocess's stdin in the form of a binary length-prefixed msgpack frame
  (4 byte, big-endian length, followed by the msgpack-encoded _RequestFrame.) Each side uses this same
  encoding.
25* Log Messages from the subprocess are sent over the dedicated logs socket (which is line-based JSON)
26* No messages are sent to task process except in response to a request. (This is because the task process will
27 be running user's code, so we can't read from stdin until we enter our code, such as when requesting an XCom
28 value etc.)
29* Every request returns a response, even if the frame is otherwise empty.
30* Requests are written by the subprocess to fd0/stdin. This is making use of the fact that stdin is a
31 bi-directional socket, and thus we can write to it and don't need a dedicated extra socket for sending
32 requests.
33
34The reason this communication protocol exists, rather than the task process speaking directly to the Task
35Execution API server is because:
36
371. To reduce the number of concurrent HTTP connections on the API server.
38
39 The supervisor already has to speak to that to heartbeat the running Task, so having the task speak to its
40 parent process and having all API traffic go through that means that the number of HTTP connections is
41 "halved". (Not every task will make API calls, so it's not always halved, but it is reduced.)
42
432. This means that the user Task code doesn't ever directly see the task identity JWT token.
44
   This is a short lived token tied to one specific task instance try, so it being leaked/exfiltrated is not a
   large risk, but it's easy to not give it to the user code, so let's do that.
47""" # noqa: D400, D205
48
49from __future__ import annotations
50
51import itertools
52from collections.abc import Iterator
53from datetime import datetime
54from functools import cached_property
55from pathlib import Path
56from socket import socket
57from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload
58from uuid import UUID
59
60import attrs
61import msgspec
62import structlog
63from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
64
65from airflow.sdk.api.datamodels._generated import (
66 AssetEventDagRunReference,
67 AssetEventResponse,
68 AssetEventsResponse,
69 AssetResponse,
70 BundleInfo,
71 ConnectionResponse,
72 DagRun,
73 DagRunStateResponse,
74 HITLDetailRequest,
75 InactiveAssetsResponse,
76 PreviousTIResponse,
77 PrevSuccessfulDagRunResponse,
78 TaskBreadcrumbsResponse,
79 TaskInstance,
80 TaskInstanceState,
81 TaskStatesResponse,
82 TIDeferredStatePayload,
83 TIRescheduleStatePayload,
84 TIRetryStatePayload,
85 TIRunContext,
86 TISkippedDownstreamTasksStatePayload,
87 TISuccessStatePayload,
88 TriggerDAGRunPayload,
89 UpdateHITLDetailPayload,
90 VariableResponse,
91 XComResponse,
92 XComSequenceIndexResponse,
93 XComSequenceSliceResponse,
94)
95from airflow.sdk.exceptions import ErrorType
96
97try:
98 from socket import recv_fds
99except ImportError:
100 # Available on Unix and Windows (so "everywhere") but lets be safe
101 recv_fds = None # type: ignore[assignment]
102
103
104if TYPE_CHECKING:
105 from structlog.typing import FilteringBoundLogger as Logger
106
# Type parameters of ``CommsDecoder``: the model type it sends and the model type it receives.
# Both are restricted to pydantic models.
SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)
109
110
def _msgpack_enc_hook(obj: Any) -> Any:
    """
    Convert an object msgspec cannot encode natively into one that it can.

    Used as the ``enc_hook`` of the encoder created by ``_new_encoder``.

    :param obj: the object msgspec did not know how to encode
    :return: a plain ``datetime`` for ``pendulum.DateTime``, a ``str`` for ``Path`` and a ``dict`` (unset
        fields excluded) for pydantic models
    :raises NotImplementedError: for any other type
    """
    import pendulum

    if isinstance(obj, pendulum.DateTime):
        # Convert the pendulum DateTime subclass into a raw datetime so that msgspec can use its native
        # encoding. ``fold`` has to be carried over as well: for a wall-clock time that occurs twice (the
        # end of DST) it is what tells the tzinfo which of the two UTC offsets applies.
        return datetime(
            obj.year,
            obj.month,
            obj.day,
            obj.hour,
            obj.minute,
            obj.second,
            obj.microsecond,
            tzinfo=obj.tzinfo,
            fold=obj.fold,
        )
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, BaseModel):
        return obj.model_dump(exclude_unset=True)

    # Raise a NotImplementedError for other types
    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
127
128
def _new_encoder() -> msgspec.msgpack.Encoder:
    """Create a msgpack encoder that hands unsupported types to ``_msgpack_enc_hook``."""
    encoder = msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook)
    return encoder
131
132
class _RequestFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True):
    """A single request on the wire: a request id plus the (dict-encoded) message body."""

    id: int
    """
    The request id, set by the sender.

    This is used to allow "pipelining" of requests and to be able to tie responses to requests, which is
    particularly useful in the Triggerer where multiple async tasks can send requests concurrently.
    """
    body: dict[str, Any] | None

    # One encoder shared by every frame
    req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder()

    def as_bytes(self) -> bytearray:
        """
        Serialise this frame as a 4 byte big-endian length followed by the msgpack payload.

        :raises OverflowError: if the payload length does not fit into the 4 byte prefix
        """
        # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing for inspiration
        prefix_len = 4
        framed = bytearray(256)

        # Encode after the space reserved for the prefix; the length computed below relies on the buffer
        # ending exactly where the encoded message ends once this call returns.
        self.req_encoder.encode_into(self, framed, prefix_len)

        n = len(framed) - prefix_len
        if n >= 2**32:
            raise OverflowError(f"Cannot send messages larger than 4GiB {n=}")

        framed[:prefix_len] = n.to_bytes(prefix_len, byteorder="big")
        return framed
157
158
class _ResponseFrame(_RequestFrame, frozen=True):
    """Frame sent in reply to a ``_RequestFrame``; it can carry a body, an error, or neither."""

    id: int
    """
    The id of the request this is a response to
    """
    # Response payload; ``CommsDecoder._from_frame`` returns None when this is None
    body: dict[str, Any] | None = None
    # Error payload; when set, ``CommsDecoder._from_frame`` raises AirflowRuntimeError instead of returning
    error: dict[str, Any] | None = None
166
167
@attrs.define()
class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
    """
    Handle communication between the task in this process and the supervisor parent process.

    Requests are written to ``socket`` as length-prefixed msgpack frames and the matching response is read
    back from the same socket.
    """

    log: Logger = attrs.field(repr=False, factory=structlog.get_logger)
    # Defaults to file descriptor 0 (stdin) wrapped in a socket object
    socket: socket = attrs.field(factory=lambda: socket(fileno=0))

    resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field(
        factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False
    )

    # Source of the request ids put into each outgoing frame
    id_counter: Iterator[int] = attrs.field(factory=itertools.count)

    # We could be "clever" here and set the default to this based type parameters and a custom
    # `__class_getitem__`, but that's a lot of code the one subclass we've got currently. So we'll just use a
    # "sort of wrong default"
    body_decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)

    # Validates the ``error`` payload of a response frame (ErrorResponse is a member of ToTask)
    err_decoder: TypeAdapter[ErrorResponse] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)

    def send(self, msg: SendMsgType) -> ReceiveMsgType | None:
        """
        Send a request to the parent and block until the response is received.

        :param msg: the request to send; its ``model_dump()`` becomes the frame body
        :return: the decoded response, or None if the response frame had no body. None is also returned,
            without anything being sent, for a ``ResendLoggingFD`` request when ``socket.recv_fds`` is not
            available.
        :raises AirflowRuntimeError: if the response frame carries an error
        :raises EOFError: if the socket is closed before a complete response was read
        """
        wants_fds = isinstance(msg, ResendLoggingFD)
        if wants_fds and recv_fds is None:
            # We have no way of receiving the file descriptor, so don't ask for it at all. Every request
            # gets a response, and one that is never read would be picked up as the reply to whatever
            # request is sent next.
            return None

        frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
        self.socket.sendall(frame.as_bytes())

        if wants_fds:
            # We need special handling here! The server can't send us the fd number, as the number on the
            # supervisor will be different to in this process, so we have to mutate the message ourselves here.
            resp_frame, fds = self._read_frame(maxfds=1)
            resp = self._from_frame(resp_frame)
            if TYPE_CHECKING:
                assert isinstance(resp, SentFDs)
            resp.fds = fds
            # Since we know this is an explicit SentFDs, and since this class is generic SentFDs might not
            # always be in the return type union
            return resp  # type: ignore[return-value]

        return self._get_response()

    async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None:
        """Send a request to the parent without blocking."""
        raise NotImplementedError

    @overload
    def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ...

    @overload
    def _read_frame(self, maxfds: int) -> tuple[_ResponseFrame, list[int]]: ...

    def _read_frame(self, maxfds: int | None = None) -> tuple[_ResponseFrame, list[int]] | _ResponseFrame:
        """
        Get a message from the parent.

        This will block until the message has been received.

        :param maxfds: when truthy, also accept up to this many file descriptors sent along with the
            length prefix
        :return: the decoded frame, or ``(frame, fds)`` when ``maxfds`` was given
        :raises EOFError: if the socket is closed before the length prefix or the payload is complete
        """
        if self.socket:
            self.socket.setblocking(True)
        fds = None
        if maxfds:
            len_bytes, fds, _flags, _address = recv_fds(self.socket, 4, maxfds)
        else:
            len_bytes = self.socket.recv(4)

        if len_bytes == b"":
            raise EOFError("Request socket closed before length")

        # A receive call may hand back fewer bytes than were asked for. Keep reading until the whole
        # 4 byte prefix is here, otherwise the length would be computed from a partial prefix and every
        # later read on this socket would be out of step.
        while len(len_bytes) < 4:
            more = self.socket.recv(4 - len(len_bytes))
            if more == b"":
                raise EOFError("Request socket closed before length")
            len_bytes += more

        length = int.from_bytes(len_bytes, byteorder="big")

        buffer = bytearray(length)
        mv = memoryview(buffer)

        # Fill the buffer in place; each recv_into call may deliver only part of the payload
        pos = 0
        while pos < length:
            nread = self.socket.recv_into(mv[pos:])
            if nread == 0:
                raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})")
            pos += nread

        resp = self.resp_decoder.decode(mv)
        if maxfds:
            return resp, fds or []
        return resp

    def _from_frame(self, frame) -> ReceiveMsgType | None:
        """
        Turn a response frame into a message object.

        :return: the body validated by ``body_decoder``, or None if the frame has no body
        :raises AirflowRuntimeError: if the frame has its ``error`` field set
        """
        from airflow.sdk.exceptions import AirflowRuntimeError

        if frame.error is not None:
            err = self.err_decoder.validate_python(frame.error)
            raise AirflowRuntimeError(error=err)

        if frame.body is None:
            return None

        try:
            return self.body_decoder.validate_python(frame.body)
        except Exception:
            # Log for visibility, then let the caller deal with the failure
            self.log.exception("Unable to decode message")
            raise

    def _get_response(self) -> ReceiveMsgType | None:
        """Read the next frame from the socket and decode it."""
        frame = self._read_frame()
        return self._from_frame(frame)
273
274
class StartupDetails(BaseModel):
    """
    Details about the task instance the task process is to run.

    Member of the ``ToTask`` union.
    """

    # NOTE(review): presumably needed because one of the field types is not a pydantic-native type — confirm
    model_config = ConfigDict(arbitrary_types_allowed=True)

    ti: TaskInstance
    # presumably the path of the Dag file relative to the bundle — TODO confirm
    dag_rel_path: str
    bundle_info: BundleInfo
    start_date: datetime
    ti_context: TIRunContext
    # NOTE(review): the meaning of this string is not visible in this file
    sentry_integration: str
    # Discriminator value for the ``ToTask`` tagged union
    type: Literal["StartupDetails"] = "StartupDetails"
285
286
class AssetResult(AssetResponse):
    """Asset information sent to the task process (an ``AssetResponse`` plus the ``type`` discriminator)."""

    type: Literal["AssetResult"] = "AssetResult"

    @classmethod
    def from_asset_response(cls, asset_response: AssetResponse) -> AssetResult:
        """
        Build an AssetResult out of an AssetResponse.

        AssetResponse is autogenerated from the API schema, so it has to be converted to AssetResult
        before it can be used for communication between the Supervisor and the task process.
        """
        # Leave out the defaults so no unnecessary data is sent. The type is passed explicitly so that it
        # counts as "set" and survives a later dump with exclude_unset=True.
        fields = asset_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="AssetResult")
304
305
@attrs.define(kw_only=True)
class AssetEventSourceTaskInstance:
    """
    The task instance an asset event originated from.

    Returned by the ``source_task_instance`` properties of AssetEventResult and
    AssetEventDagRunReferenceResult.
    """

    dag_run: DagRun
    task_id: str
    map_index: int

    @property
    def dag_id(self) -> str:
        """Dag id, taken from the Dag run."""
        return self.dag_run.dag_id

    @property
    def run_id(self) -> str:
        """Run id, taken from the Dag run."""
        return self.dag_run.run_id

    def xcom_pull(self, *, key: str = "return_value", default: Any = None) -> Any:
        """
        Look up an XCom value of this task instance.

        :param key: the XCom key to look up
        :param default: what to return when the lookup yields None
        """
        from airflow.sdk.execution_time.xcom import XCom

        value = XCom.get_value(ti_key=self, key=key)
        return default if value is None else value
328
329
def _fetch_dag_run(*, dag_id: str, run_id: str) -> DagRun:
    """Send a ``GetDagRun`` request for the given Dag run over SUPERVISOR_COMMS and return the response."""
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    request = GetDagRun(dag_id=dag_id, run_id=run_id)
    dag_run = SUPERVISOR_COMMS.send(request)
    if TYPE_CHECKING:
        assert isinstance(dag_run, DagRunResult)
    return dag_run
337
338
class AssetEventResult(AssetEventResponse):
    """Used in AssetEventsResult."""

    @classmethod
    def from_asset_event_response(cls, asset_event_response: AssetEventResponse) -> AssetEventResult:
        """Build an AssetEventResult from the non-default fields of an AssetEventResponse."""
        fields = asset_event_response.model_dump(exclude_defaults=True)
        return cls(**fields)

    @cached_property
    def source_dag_run(self) -> DagRun | None:
        """The Dag run the event came from, or None when the source dag id or run id is missing."""
        if not (self.source_dag_id and self.source_run_id):
            return None
        return _fetch_dag_run(dag_id=self.source_dag_id, run_id=self.source_run_id)

    @cached_property
    def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
        """The task instance the event came from, or None when any part of its identity is missing."""
        task_id = self.source_task_id
        map_index = self.source_map_index
        # Check the cheap fields first so the Dag run is only fetched when it is going to be used
        if task_id is None or map_index is None:
            return None
        dag_run = self.source_dag_run
        if dag_run is None:
            return None
        return AssetEventSourceTaskInstance(dag_run=dag_run, task_id=task_id, map_index=map_index)
363
364
class AssetEventsResult(AssetEventsResponse):
    """Response to GetAssetEvent request."""

    type: Literal["AssetEventsResult"] = "AssetEventsResult"

    @classmethod
    def from_asset_events_response(cls, asset_events_response: AssetEventsResponse) -> AssetEventsResult:
        """
        Build an AssetEventsResult out of an AssetEventsResponse.

        AssetEventsResponse is autogenerated from the API schema, so it has to be converted to
        AssetEventsResult before it can be used for communication between the Supervisor and the task
        process.
        """
        # Leave out the defaults so no unnecessary data is sent. The type is passed explicitly so that it
        # counts as "set" and survives a later dump with exclude_unset=True.
        fields = asset_events_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="AssetEventsResult")

    def iter_asset_event_results(self) -> Iterator[AssetEventResult]:
        """Lazily convert each contained asset event into an AssetEventResult."""
        return (AssetEventResult.from_asset_event_response(evt) for evt in self.asset_events)
388
389
class AssetEventDagRunReferenceResult(AssetEventDagRunReference):
    """An ``AssetEventDagRunReference`` with lazy access to the source Dag run and task instance."""

    @classmethod
    def from_asset_event_dag_run_reference(
        cls,
        asset_event_dag_run_reference: AssetEventDagRunReference,
    ) -> AssetEventDagRunReferenceResult:
        """Build an instance from the non-default fields of an AssetEventDagRunReference."""
        fields = asset_event_dag_run_reference.model_dump(exclude_defaults=True)
        return cls(**fields)

    @cached_property
    def source_dag_run(self) -> DagRun | None:
        """The Dag run the event came from, or None when the source dag id or run id is missing."""
        if not (self.source_dag_id and self.source_run_id):
            return None
        return _fetch_dag_run(dag_id=self.source_dag_id, run_id=self.source_run_id)

    @cached_property
    def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
        """The task instance the event came from, or None when any part of its identity is missing."""
        task_id = self.source_task_id
        map_index = self.source_map_index
        # Check the cheap fields first so the Dag run is only fetched when it is going to be used
        if task_id is None or map_index is None:
            return None
        dag_run = self.source_dag_run
        if dag_run is None:
            return None
        return AssetEventSourceTaskInstance(dag_run=dag_run, task_id=task_id, map_index=map_index)
415
416
class InactiveAssetsResult(InactiveAssetsResponse):
    """Response of InactiveAssets requests."""

    type: Literal["InactiveAssetsResult"] = "InactiveAssetsResult"

    @classmethod
    def from_inactive_assets_response(
        cls, inactive_assets_response: InactiveAssetsResponse
    ) -> InactiveAssetsResult:
        """
        Build an InactiveAssetsResult out of an InactiveAssetsResponse.

        InactiveAssetsResponse is autogenerated from the API schema, so it has to be converted to
        InactiveAssetsResult before it can be used for communication between the Supervisor and the task
        process.
        """
        fields = inactive_assets_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="InactiveAssetsResult")
433
434
class XComResult(XComResponse):
    """Response to ReadXCom request."""

    type: Literal["XComResult"] = "XComResult"

    @classmethod
    def from_xcom_response(cls, xcom_response: XComResponse) -> XComResult:
        """
        Build an XComResult out of an XComResponse.

        XComResponse is autogenerated from the API schema, so it has to be converted to XComResult
        before it can be used for communication between the Supervisor and the task process.
        """
        fields = xcom_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="XComResult")
449
450
class XComCountResponse(BaseModel):
    """A count of XCom values. Member of the ``ToTask`` union."""

    len: int
    # Note that the discriminator value is not the class name
    type: Literal["XComLengthResponse"] = "XComLengthResponse"
454
455
class XComSequenceIndexResult(BaseModel):
    """A single JSON value taken from an ``XComSequenceIndexResponse``. Member of the ``ToTask`` union."""

    root: JsonValue
    type: Literal["XComSequenceIndexResult"] = "XComSequenceIndexResult"

    @classmethod
    def from_response(cls, response: XComSequenceIndexResponse) -> XComSequenceIndexResult:
        """Wrap the root value of the API response, setting the discriminator explicitly."""
        value = response.root
        return cls(root=value, type="XComSequenceIndexResult")
463
464
class XComSequenceSliceResult(BaseModel):
    """A list of JSON values taken from an ``XComSequenceSliceResponse``. Member of the ``ToTask`` union."""

    root: list[JsonValue]
    type: Literal["XComSequenceSliceResult"] = "XComSequenceSliceResult"

    @classmethod
    def from_response(cls, response: XComSequenceSliceResponse) -> XComSequenceSliceResult:
        """Wrap the root value of the API response, setting the discriminator explicitly."""
        values = response.root
        return cls(root=values, type="XComSequenceSliceResult")
472
473
class ConnectionResult(ConnectionResponse):
    """A ``ConnectionResponse`` plus the ``type`` discriminator. Member of the ``ToTask`` union."""

    type: Literal["ConnectionResult"] = "ConnectionResult"

    @classmethod
    def from_conn_response(cls, connection_response: ConnectionResponse) -> ConnectionResult:
        """
        Build a ConnectionResult out of a ConnectionResponse.

        ConnectionResponse is autogenerated from the API schema, so it has to be converted to
        ConnectionResult before it can be used for communication between the Supervisor and the task
        process.
        """
        # Leave out the defaults so no unnecessary data is sent. The type is passed explicitly so that it
        # counts as "set" and survives a later dump with exclude_unset=True.
        # The dump is keyed by field alias.
        fields = connection_response.model_dump(exclude_defaults=True, by_alias=True)
        return cls(**fields, type="ConnectionResult")
491
492
class VariableResult(VariableResponse):
    """A ``VariableResponse`` plus the ``type`` discriminator. Member of the ``ToTask`` union."""

    type: Literal["VariableResult"] = "VariableResult"

    @classmethod
    def from_variable_response(cls, variable_response: VariableResponse) -> VariableResult:
        """
        Build a VariableResult out of a VariableResponse.

        VariableResponse is autogenerated from the API schema, so it has to be converted to VariableResult
        before it can be used for communication between the Supervisor and the task process.
        """
        fields = variable_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="VariableResult")
505
506
class DagRunResult(DagRun):
    """A ``DagRun`` plus the ``type`` discriminator. Member of the ``ToTask`` union."""

    type: Literal["DagRunResult"] = "DagRunResult"

    @classmethod
    def from_api_response(cls, dr_response: DagRun) -> DagRunResult:
        """
        Create result class from API Response.

        The API response class is autogenerated from the API schema. It has to be converted to the result
        class for communication between the Supervisor and the task process because that needs a
        discriminator field.
        """
        fields = dr_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="DagRunResult")
520
521
class DagRunStateResult(DagRunStateResponse):
    """A ``DagRunStateResponse`` plus the ``type`` discriminator. Member of the ``ToTask`` union."""

    type: Literal["DagRunStateResult"] = "DagRunStateResult"

    # TODO: Create a convert api_response to result classes so we don't need to do this
    # for all the classes above
    @classmethod
    def from_api_response(cls, dr_state_response: DagRunStateResponse) -> DagRunStateResult:
        """
        Create result class from API Response.

        The API response class is autogenerated from the API schema. It has to be converted to the result
        class for communication between the Supervisor and the task process because that needs a
        discriminator field.
        """
        fields = dr_state_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="DagRunStateResult")
537
538
class PreviousDagRunResult(BaseModel):
    """Response containing previous Dag run information."""

    # Defaults to None — presumably meaning no previous Dag run was found; confirm against the supervisor
    dag_run: DagRun | None = None
    # Discriminator value for the ``ToTask`` tagged union
    type: Literal["PreviousDagRunResult"] = "PreviousDagRunResult"
544
545
class PreviousTIResult(BaseModel):
    """Response containing previous task instance data."""

    # Defaults to None — presumably meaning no previous task instance was found; confirm against the supervisor
    task_instance: PreviousTIResponse | None = None
    # Discriminator value for the ``ToTask`` tagged union
    type: Literal["PreviousTIResult"] = "PreviousTIResult"
551
552
class PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse):
    """A ``PrevSuccessfulDagRunResponse`` plus the ``type`` discriminator. Member of the ``ToTask`` union."""

    type: Literal["PrevSuccessfulDagRunResult"] = "PrevSuccessfulDagRunResult"

    @classmethod
    def from_dagrun_response(cls, prev_dag_run: PrevSuccessfulDagRunResponse) -> PrevSuccessfulDagRunResult:
        """
        Get a result object from response object.

        PrevSuccessfulDagRunResponse is autogenerated from the API schema, so it has to be converted to
        PrevSuccessfulDagRunResult before it can be used for communication between the Supervisor and the
        task process.
        """
        fields = prev_dag_run.model_dump(exclude_defaults=True)
        return cls(**fields, type="PrevSuccessfulDagRunResult")
565
566
class TaskRescheduleStartDate(BaseModel):
    """Response containing the first reschedule date for a task instance."""

    # Required field, but the value itself may be None
    start_date: AwareDatetime | None
    # Discriminator value for the ``ToTask`` tagged union
    type: Literal["TaskRescheduleStartDate"] = "TaskRescheduleStartDate"
572
573
class TICount(BaseModel):
    """Response containing count of Task Instances matching certain filters."""

    count: int
    # Discriminator value for the ``ToTask`` tagged union
    type: Literal["TICount"] = "TICount"
579
580
class TaskStatesResult(TaskStatesResponse):
    """A ``TaskStatesResponse`` plus the ``type`` discriminator. Member of the ``ToTask`` union."""

    type: Literal["TaskStatesResult"] = "TaskStatesResult"

    @classmethod
    def from_api_response(cls, task_states_response: TaskStatesResponse) -> TaskStatesResult:
        """
        Create result class from API Response.

        The API response class is autogenerated from the API schema. It has to be converted to the result
        class for communication between the Supervisor and the task process because that needs a
        discriminator field.
        """
        fields = task_states_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="TaskStatesResult")
594
595
class TaskBreadcrumbsResult(TaskBreadcrumbsResponse):
    """A ``TaskBreadcrumbsResponse`` plus the ``type`` discriminator. Member of the ``ToTask`` union."""

    type: Literal["TaskBreadcrumbsResult"] = "TaskBreadcrumbsResult"

    @classmethod
    def from_api_response(cls, response: TaskBreadcrumbsResponse) -> TaskBreadcrumbsResult:
        """
        Create result class from API Response.

        The API response class is autogenerated from the API schema. It has to be converted to the result
        class for communication between the Supervisor and the task process because that needs a
        discriminator field.
        """
        fields = response.model_dump(exclude_defaults=True)
        return cls(**fields, type="TaskBreadcrumbsResult")
609
610
class DRCount(BaseModel):
    """Response containing count of Dag Runs matching certain filters."""

    count: int
    # Discriminator value for the ``ToTask`` tagged union
    type: Literal["DRCount"] = "DRCount"
616
617
class ErrorResponse(BaseModel):
    """
    Description of a failed request.

    ``CommsDecoder._from_frame`` validates the ``error`` payload of a response frame and raises
    AirflowRuntimeError with the result.
    """

    # Kind of error; a generic error unless stated otherwise
    error: ErrorType = ErrorType.GENERIC_ERROR
    # Optional free-form details
    detail: dict | None = None
    # Discriminator value for the ``ToTask`` tagged union
    type: Literal["ErrorResponse"] = "ErrorResponse"
622
623
class OKResponse(BaseModel):
    """Response carrying nothing but a boolean ``ok`` flag. Member of the ``ToTask`` union."""

    ok: bool
    # Discriminator value for the ``ToTask`` tagged union
    type: Literal["OKResponse"] = "OKResponse"
627
628
class SentFDs(BaseModel):
    """
    Response to a ``ResendLoggingFD`` request.

    ``CommsDecoder.send`` overwrites ``fds`` with the file descriptor numbers actually received in this
    process.
    """

    # Discriminator value for the ``ToTask`` tagged union
    type: Literal["SentFDs"] = "SentFDs"
    fds: list[int]
632
633
class CreateHITLDetailPayload(HITLDetailRequest):
    """
    Add the input request part of a Human-in-the-loop response.

    Listed in both the ``ToTask`` and the ``ToSupervisor`` union.
    """

    # Discriminator value for the tagged unions
    type: Literal["CreateHITLDetailPayload"] = "CreateHITLDetailPayload"
638
639
class HITLDetailRequestResult(HITLDetailRequest):
    """Response to CreateHITLDetailPayload request."""

    type: Literal["HITLDetailRequestResult"] = "HITLDetailRequestResult"

    @classmethod
    def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequestResult:
        """
        Build a HITLDetailRequestResult out of a HITLDetailRequest (API response).

        HITLDetailRequest is the API response model. It is converted to HITLDetailRequestResult for
        communication between the Supervisor and the task process, which adds the discriminator field
        the tagged union deserialization requires.
        """
        fields = hitl_request.model_dump(exclude_defaults=True)
        return cls(**fields, type="HITLDetailRequestResult")
655
656
# Tagged union of the message types that can be decoded on the task side; pydantic picks the concrete
# class from the value of the ``type`` field. ``CommsDecoder`` uses it as the default for decoding both
# response bodies and errors.
ToTask = Annotated[
    AssetResult
    | AssetEventsResult
    | ConnectionResult
    | DagRunResult
    | DagRunStateResult
    | DRCount
    | ErrorResponse
    | PrevSuccessfulDagRunResult
    | PreviousTIResult
    | SentFDs
    | StartupDetails
    | TaskRescheduleStartDate
    | TICount
    | TaskBreadcrumbsResult
    | TaskStatesResult
    | VariableResult
    | XComCountResponse
    | XComResult
    | XComSequenceIndexResult
    | XComSequenceSliceResult
    | InactiveAssetsResult
    | CreateHITLDetailPayload
    | HITLDetailRequestResult
    | OKResponse
    | PreviousDagRunResult,
    Field(discriminator="type"),
]
685
686
class TaskState(BaseModel):
    """
    Update a task's state.

    If a process exits without sending one of these the state will be derived from the exit code:
    - 0 = SUCCESS
    - anything else = FAILED
    """

    # Only these three states are accepted here; success, deferral, retry and reschedule each have their
    # own message class (SucceedTask, DeferTask, RetryTask, RescheduleTask).
    state: Literal[
        TaskInstanceState.FAILED,
        TaskInstanceState.SKIPPED,
        TaskInstanceState.REMOVED,
    ]
    end_date: datetime | None = None
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["TaskState"] = "TaskState"
    rendered_map_index: str | None = None
704
705
class SucceedTask(TISuccessStatePayload):
    """Update a task's state to success. Includes task_outlets and outlet_events for registering asset events."""

    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["SucceedTask"] = "SucceedTask"
710
711
class DeferTask(TIDeferredStatePayload):
    """Update a task instance state to deferred."""

    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["DeferTask"] = "DeferTask"
716
717
class RetryTask(TIRetryStatePayload):
    """Update a task instance state to up_for_retry."""

    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["RetryTask"] = "RetryTask"
722
723
class RescheduleTask(TIRescheduleStatePayload):
    """Update a task instance state to reschedule/up_for_reschedule."""

    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["RescheduleTask"] = "RescheduleTask"
728
729
class SkipDownstreamTasks(TISkippedDownstreamTasksStatePayload):
    """Update state of downstream tasks within a task instance to 'skipped', while updating current task to success state."""

    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["SkipDownstreamTasks"] = "SkipDownstreamTasks"
734
735
class GetXCom(BaseModel):
    """Request to get an XCom value, identified by key, Dag, run and task."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    # Optional; None by default
    map_index: int | None = None
    # NOTE(review): how this flag widens the lookup is decided outside this file
    include_prior_dates: bool = False
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetXCom"] = "GetXCom"
744
745
class GetXComCount(BaseModel):
    """Get the number of (mapped) XCom values available."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    # Note that the discriminator value is not the class name
    type: Literal["GetNumberXComs"] = "GetNumberXComs"
754
755
class GetXComSequenceItem(BaseModel):
    """Request to get a single item, at position ``offset``, of an XCom sequence."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    # NOTE(review): whether negative offsets are supported is not visible in this file
    offset: int
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetXComSequenceItem"] = "GetXComSequenceItem"
763
764
class GetXComSequenceSlice(BaseModel):
    """Request to get a slice (``start``/``stop``/``step``) of an XCom sequence."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    # The three slice fields have no default, so each must be passed explicitly, but each may be None
    start: int | None
    stop: int | None
    step: int | None
    include_prior_dates: bool = False
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetXComSequenceSlice"] = "GetXComSequenceSlice"
775
776
class SetXCom(BaseModel):
    """Request to store an XCom value for the given key, Dag, run and task."""

    key: str
    # Any JSON-compatible value
    value: JsonValue
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    # NOTE(review): presumably the number of mapped task instances this value expands to — confirm
    mapped_length: int | None = None
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["SetXCom"] = "SetXCom"
786
787
class DeleteXCom(BaseModel):
    """Request to delete an XCom value, identified by key, Dag, run and task."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["DeleteXCom"] = "DeleteXCom"
795
796
class GetConnection(BaseModel):
    """Request to get the connection with the given ``conn_id``."""

    conn_id: str
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetConnection"] = "GetConnection"
800
801
class GetVariable(BaseModel):
    """Request to get the variable with the given ``key``."""

    key: str
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetVariable"] = "GetVariable"
805
806
class PutVariable(BaseModel):
    """Request to store a variable."""

    key: str
    # ``value`` and ``description`` have no default, so both must be passed explicitly, but may be None
    value: str | None
    description: str | None
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["PutVariable"] = "PutVariable"
812
813
class DeleteVariable(BaseModel):
    """Request to delete the variable with the given ``key``."""

    key: str
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["DeleteVariable"] = "DeleteVariable"
817
818
class ResendLoggingFD(BaseModel):
    """
    Request for a file descriptor to be sent over the socket.

    ``CommsDecoder.send`` treats this message specially: it reads the response together with the file
    descriptor that accompanies it and returns a ``SentFDs``.
    """

    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["ResendLoggingFD"] = "ResendLoggingFD"
821
822
class SetRenderedFields(BaseModel):
    """Payload for setting RTIF for a task instance."""

    # We are using a BaseModel here compared to server using RootModel because we
    # have a discriminator running with "type", and RootModel doesn't support type

    # Mapping of field name to a JSON-compatible value
    rendered_fields: dict[str, JsonValue]
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["SetRenderedFields"] = "SetRenderedFields"
831
832
class SetRenderedMapIndex(BaseModel):
    """Payload for setting rendered_map_index for a task instance."""

    rendered_map_index: str
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["SetRenderedMapIndex"] = "SetRenderedMapIndex"
838
839
class TriggerDagRun(TriggerDAGRunPayload):
    """Request to trigger a Dag run: a ``TriggerDAGRunPayload`` plus the Dag id and run id to use."""

    dag_id: str
    run_id: Annotated[str, Field(title="Dag Run Id")]
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["TriggerDagRun"] = "TriggerDagRun"
844
845
class GetDagRun(BaseModel):
    """Request to get a Dag run; ``_fetch_dag_run`` sends it and expects a ``DagRunResult`` back."""

    dag_id: str
    run_id: str
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetDagRun"] = "GetDagRun"
850
851
class GetDagRunState(BaseModel):
    """Request to get the state of a Dag run."""

    dag_id: str
    run_id: str
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetDagRunState"] = "GetDagRunState"
856
857
class GetPreviousDagRun(BaseModel):
    """Request to get the Dag run preceding ``logical_date`` for a Dag."""

    dag_id: str
    # Must be timezone-aware
    logical_date: AwareDatetime
    # Optional state filter; NOTE(review): accepted values are not constrained here
    state: str | None = None
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetPreviousDagRun"] = "GetPreviousDagRun"
863
864
class GetPreviousTI(BaseModel):
    """Request to get previous task instance."""

    dag_id: str
    task_id: str
    # Must be timezone-aware when given
    logical_date: AwareDatetime | None = None
    # presumably -1 stands for "not a mapped task instance" — TODO confirm
    map_index: int = -1
    # Optional state filter
    state: TaskInstanceState | None = None
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetPreviousTI"] = "GetPreviousTI"
874
875
class GetAssetByName(BaseModel):
    """Request to get an asset by its name."""

    name: str
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetAssetByName"] = "GetAssetByName"
879
880
class GetAssetByUri(BaseModel):
    """Request to get an asset by its URI."""

    uri: str
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetAssetByUri"] = "GetAssetByUri"
884
885
class GetAssetEventByAsset(BaseModel):
    """Request to get the events of an asset identified by name and/or URI."""

    # ``name`` and ``uri`` have no default, so both must be passed explicitly, but either may be None
    name: str | None
    uri: str | None
    # Optional time window (timezone-aware) and result limit
    after: AwareDatetime | None = None
    before: AwareDatetime | None = None
    limit: int | None = None
    # Sort direction; NOTE(review): the sort key is not visible in this file
    ascending: bool = True
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetAssetEventByAsset"] = "GetAssetEventByAsset"
894
895
class GetAssetEventByAssetAlias(BaseModel):
    """Request to get the asset events of an asset alias."""

    alias_name: str
    # Optional time window (timezone-aware) and result limit
    after: AwareDatetime | None = None
    before: AwareDatetime | None = None
    limit: int | None = None
    # Sort direction; NOTE(review): the sort key is not visible in this file
    ascending: bool = True
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetAssetEventByAssetAlias"] = "GetAssetEventByAssetAlias"
903
904
class ValidateInletsAndOutlets(BaseModel):
    """Request to validate the inlets and outlets of the task instance with id ``ti_id``."""

    ti_id: UUID
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["ValidateInletsAndOutlets"] = "ValidateInletsAndOutlets"
908
909
class GetPrevSuccessfulDagRun(BaseModel):
    """Request to get the previous successful Dag run for the task instance with id ``ti_id``."""

    ti_id: UUID
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"
913
914
class GetTaskRescheduleStartDate(BaseModel):
    """Request to get the reschedule start date of a task instance try."""

    ti_id: UUID
    # Defaults to the first try
    try_number: int = 1
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetTaskRescheduleStartDate"] = "GetTaskRescheduleStartDate"
919
920
class GetTICount(BaseModel):
    """Request to count the task instances of a Dag, optionally narrowed by the filter fields below."""

    dag_id: str
    # Every filter is optional and None by default
    map_index: int | None = None
    task_ids: list[str] | None = None
    task_group_id: str | None = None
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    states: list[str] | None = None
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetTICount"] = "GetTICount"
930
931
class GetTaskStates(BaseModel):
    """Request to get the states of the tasks of a Dag, optionally narrowed by the filter fields below."""

    dag_id: str
    # Every filter is optional and None by default
    map_index: int | None = None
    task_ids: list[str] | None = None
    task_group_id: str | None = None
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetTaskStates"] = "GetTaskStates"
940
941
class GetTaskBreadcrumbs(BaseModel):
    """Request to get the task breadcrumbs of a Dag run."""

    dag_id: str
    run_id: str
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetTaskBreadcrumbs"] = "GetTaskBreadcrumbs"
946
947
class GetDRCount(BaseModel):
    """Request to count the runs of a Dag, optionally narrowed by the filter fields below."""

    dag_id: str
    # Every filter is optional and None by default
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    states: list[str] | None = None
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetDRCount"] = "GetDRCount"
954
955
class GetHITLDetailResponse(BaseModel):
    """Get the response content part of a Human-in-the-loop response."""

    # Id of the task instance the Human-in-the-loop detail belongs to
    ti_id: UUID
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["GetHITLDetailResponse"] = "GetHITLDetailResponse"
961
962
class UpdateHITLDetail(UpdateHITLDetailPayload):
    """Update the response content part of an existing Human-in-the-loop response."""

    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["UpdateHITLDetail"] = "UpdateHITLDetail"
967
968
class MaskSecret(BaseModel):
    """Add a new value to be redacted in task logs."""

    # This is needed since calls to `mask_secret` in the Task process will otherwise only add the mask value
    # to the child process, but the redaction happens in the parent.
    # We cannot use `string | Iterable | dict here` (would be more intuitive) because bug in Pydantic
    # https://github.com/pydantic/pydantic/issues/9541 turns iterable into a ValidatorIterator
    value: JsonValue
    # Optional name that goes with the value
    name: str | None = None
    # Discriminator value for the ``ToSupervisor`` tagged union
    type: Literal["MaskSecret"] = "MaskSecret"
979
980
# Tagged union of the message types addressed to the supervisor; pydantic picks the concrete class from
# the value of the ``type`` field.
ToSupervisor = Annotated[
    DeferTask
    | DeleteXCom
    | GetAssetByName
    | GetAssetByUri
    | GetAssetEventByAsset
    | GetAssetEventByAssetAlias
    | GetConnection
    | GetDagRun
    | GetDagRunState
    | GetDRCount
    | GetPrevSuccessfulDagRun
    | GetPreviousDagRun
    | GetPreviousTI
    | GetTaskRescheduleStartDate
    | GetTICount
    | GetTaskBreadcrumbs
    | GetTaskStates
    | GetVariable
    | GetXCom
    | GetXComCount
    | GetXComSequenceItem
    | GetXComSequenceSlice
    | PutVariable
    | RescheduleTask
    | RetryTask
    | SetRenderedFields
    | SetRenderedMapIndex
    | SetXCom
    | SkipDownstreamTasks
    | SucceedTask
    | ValidateInletsAndOutlets
    | TaskState
    | TriggerDagRun
    | DeleteVariable
    | ResendLoggingFD
    | CreateHITLDetailPayload
    | UpdateHITLDetail
    | GetHITLDetailResponse
    | MaskSecret,
    Field(discriminator="type"),
]