1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18r"""
19Communication protocol between the Supervisor and the task process
20==================================================================
21
* All communication is done over the subprocess's stdin in the form of binary length-prefixed msgpack frames
  (a 4-byte big-endian length, followed by the msgpack-encoded _RequestFrame). Each side uses this same
  encoding.
25* Log Messages from the subprocess are sent over the dedicated logs socket (which is line-based JSON)
* No messages are sent to the task process except in response to a request. (This is because the task process
  will be running the user's code, so it can't read from stdin until control is back in our code, such as when
  requesting an XCom value etc.)
29* Every request returns a response, even if the frame is otherwise empty.
30* Requests are written by the subprocess to fd0/stdin. This is making use of the fact that stdin is a
31 bi-directional socket, and thus we can write to it and don't need a dedicated extra socket for sending
32 requests.
33
34The reason this communication protocol exists, rather than the task process speaking directly to the Task
35Execution API server is because:
36
371. To reduce the number of concurrent HTTP connections on the API server.
38
39 The supervisor already has to speak to that to heartbeat the running Task, so having the task speak to its
40 parent process and having all API traffic go through that means that the number of HTTP connections is
41 "halved". (Not every task will make API calls, so it's not always halved, but it is reduced.)
42
432. This means that the user Task code doesn't ever directly see the task identity JWT token.
44
   This is a short-lived token tied to one specific task instance try, so it being leaked/exfiltrated is not a
   large risk, but it's easy not to give it to the user code, so let's do that.
47""" # noqa: D400, D205
48
49from __future__ import annotations
50
51import itertools
52from collections.abc import Iterator
53from datetime import datetime
54from functools import cached_property
55from pathlib import Path
56from socket import socket
57from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Generic, Literal, TypeVar, overload
58from uuid import UUID
59
60import attrs
61import msgspec
62import structlog
63from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_serializer
64
65from airflow.sdk.api.datamodels._generated import (
66 AssetEventDagRunReference,
67 AssetEventResponse,
68 AssetEventsResponse,
69 AssetResponse,
70 BundleInfo,
71 ConnectionResponse,
72 DagRun,
73 DagRunStateResponse,
74 HITLDetailRequest,
75 InactiveAssetsResponse,
76 PrevSuccessfulDagRunResponse,
77 TaskBreadcrumbsResponse,
78 TaskInstance,
79 TaskInstanceState,
80 TaskStatesResponse,
81 TIDeferredStatePayload,
82 TIRescheduleStatePayload,
83 TIRetryStatePayload,
84 TIRunContext,
85 TISkippedDownstreamTasksStatePayload,
86 TISuccessStatePayload,
87 TriggerDAGRunPayload,
88 UpdateHITLDetailPayload,
89 VariableResponse,
90 XComResponse,
91 XComSequenceIndexResponse,
92 XComSequenceSliceResponse,
93)
94from airflow.sdk.exceptions import ErrorType
95
96try:
97 from socket import recv_fds
98except ImportError:
99 # Available on Unix and Windows (so "everywhere") but lets be safe
100 recv_fds = None # type: ignore[assignment]
101
102
103if TYPE_CHECKING:
104 from structlog.typing import FilteringBoundLogger as Logger
105
# Type of the request models a CommsDecoder sends to the other side of the channel.
SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
# Type of the response models a CommsDecoder receives (validated by its ``body_decoder``).
ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)
108
109
def _msgpack_enc_hook(obj: Any) -> Any:
    """
    Convert objects msgspec cannot encode natively into types it can.

    Used as the ``enc_hook`` of the msgpack encoder built by ``_new_encoder``.

    :param obj: The object msgspec could not encode on its own.
    :return: An equivalent value msgspec can encode.
    :raises NotImplementedError: for any type not handled below.
    """
    import pendulum

    if isinstance(obj, pendulum.DateTime):
        # Convert the pendulum DateTime subclass into a plain datetime so that msgspec can use its native
        # encoding. ``fold`` must be carried over as well: for a wall-clock time that occurs twice (when a DST
        # transition moves the clocks back) it selects which occurrence is meant, and therefore which UTC
        # offset is used when the aware datetime is encoded.
        return datetime(
            obj.year,
            obj.month,
            obj.day,
            obj.hour,
            obj.minute,
            obj.second,
            obj.microsecond,
            tzinfo=obj.tzinfo,
            fold=obj.fold,
        )
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, BaseModel):
        return obj.model_dump(exclude_unset=True)

    # Raise a NotImplementedError for other types
    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
126
127
def _new_encoder() -> msgspec.msgpack.Encoder:
    """Build a msgpack encoder that defers unsupported types to ``_msgpack_enc_hook``."""
    hook = _msgpack_enc_hook
    encoder = msgspec.msgpack.Encoder(enc_hook=hook)
    return encoder
130
131
class _RequestFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True):
    """
    One message on the supervisor <-> task channel.

    ``array_like=True`` makes msgspec encode this struct as a positional msgpack array, so the declaration
    order of the fields is part of the wire format.
    """

    # The request id, set by the sender. It allows requests to be "pipelined" and responses to be tied back
    # to their requests, which is particularly useful in the Triggerer where multiple async tasks can send
    # requests concurrently.
    id: int
    body: dict[str, Any] | None

    req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder()

    def as_bytes(self) -> bytearray:
        """
        Encode this frame as a 4-byte big-endian length prefix followed by the msgpack payload.

        :raises OverflowError: if the payload is too large for a 32-bit length prefix.
        """
        # Inspired by https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing: encode the
        # payload starting at offset 4, then patch the length prefix into the first four bytes.
        out = bytearray(256)
        self.req_encoder.encode_into(self, out, 4)

        n = len(out) - 4
        if not n < 2**32:
            raise OverflowError(f"Cannot send messages larger than 4GiB {n=}")
        out[0:4] = n.to_bytes(4, "big")
        return out
156
157
class _ResponseFrame(_RequestFrame, frozen=True):
    # Reply to a ``_RequestFrame``. It subclasses ``_RequestFrame``, so (presumably via msgspec's inherited
    # struct options) it uses the same positional/array encoding: field declaration order is wire format.
    id: int
    """
    The id of the request this is a response to
    """
    # Response payload; ``None`` when the response carries no body.
    body: dict[str, Any] | None = None
    # Set when the request failed; ``CommsDecoder._from_frame`` validates it and raises AirflowRuntimeError.
    error: dict[str, Any] | None = None
165
166
@attrs.define()
class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
    """
    Handle communication between the task in this process and the supervisor parent process.

    Every request is written to ``socket`` as a length-prefixed ``_RequestFrame`` and is answered with a single
    length-prefixed ``_ResponseFrame``. :meth:`send` writes the request and blocks until that response has been
    read and decoded.
    """

    log: Logger = attrs.field(repr=False, factory=structlog.get_logger)
    # Defaults to fd 0 (stdin), which is used as a bi-directional socket shared with the supervisor.
    socket: socket = attrs.field(factory=lambda: socket(fileno=0))

    resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field(
        factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False
    )

    # Supplies the ``id`` of each outgoing ``_RequestFrame``.
    id_counter: Iterator[int] = attrs.field(factory=itertools.count)

    # We could be "clever" here and derive the default from the type parameters with a custom
    # `__class_getitem__`, but that's a lot of code for the one subclass we've got currently. So we'll just use
    # a "sort of wrong default".
    body_decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)

    # Error payloads are validated as ``ErrorResponse`` directly (matching the annotation). Validating them
    # against the whole ``ToTask`` union would insist on the ``type`` discriminator being present and could
    # yield a model that is not an ``ErrorResponse`` at all.
    err_decoder: TypeAdapter[ErrorResponse] = attrs.field(
        factory=lambda: TypeAdapter(ErrorResponse), repr=False
    )

    def send(self, msg: SendMsgType) -> ReceiveMsgType | None:
        """
        Send a request to the parent and block until the response is received.

        :param msg: The request model to send.
        :return: The decoded response body, or ``None`` if the response had no body (or if a
            ``ResendLoggingFD`` request cannot be serviced because ``socket.recv_fds`` is unavailable).
        :raises AirflowRuntimeError: if the supervisor answered with an error.
        """
        if isinstance(msg, ResendLoggingFD) and recv_fds is None:
            # We cannot receive the file descriptor the supervisor would send back. Bail out *before* writing
            # the request: if it were sent and its reply left unread, that reply would be consumed as the
            # response to the next, unrelated request and every exchange after it would be out of step.
            return None

        frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
        self.socket.sendall(frame.as_bytes())

        if isinstance(msg, ResendLoggingFD):
            # We need special handling here! The server can't send us the fd number, as the number on the
            # supervisor will be different to in this process, so we have to mutate the message ourselves here.
            resp_frame, fds = self._read_frame(maxfds=1)
            resp = self._from_frame(resp_frame)
            if TYPE_CHECKING:
                assert isinstance(resp, SentFDs)
            resp.fds = fds
            # Since we know this is an explicit SentFDs, and since this class is generic SentFDs might not
            # always be in the return type union
            return resp  # type: ignore[return-value]

        return self._get_response()

    async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None:
        """Send a request to the parent without blocking."""
        raise NotImplementedError

    @overload
    def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ...

    @overload
    def _read_frame(self, maxfds: int) -> tuple[_ResponseFrame, list[int]]: ...

    def _read_frame(self, maxfds: int | None = None) -> tuple[_ResponseFrame, list[int]] | _ResponseFrame:
        """
        Get a message from the parent.

        This will block until the message has been received.

        :param maxfds: If set, also receive up to this many file descriptors passed alongside the frame, and
            return ``(frame, fds)`` instead of just the frame.
        :raises EOFError: if the socket is closed before a complete frame has been read.
        """
        if self.socket:
            self.socket.setblocking(True)

        len_bytes, fds = self._read_length_prefix(maxfds)
        length = int.from_bytes(len_bytes, byteorder="big")

        # Read exactly ``length`` bytes into a preallocated buffer, looping over short reads.
        buffer = bytearray(length)
        mv = memoryview(buffer)

        pos = 0
        while pos < length:
            nread = self.socket.recv_into(mv[pos:])
            if nread == 0:
                raise EOFError(f"Request socket closed before response was complete ({self.id_counter=})")
            pos += nread

        resp = self.resp_decoder.decode(mv)
        if maxfds:
            return resp, fds or []
        return resp

    def _read_length_prefix(self, maxfds: int | None) -> tuple[bytes, list[int] | None]:
        """
        Read the 4-byte big-endian length prefix of the next frame, plus any passed file descriptors.

        When ``maxfds`` is set the first read goes through ``recv_fds`` so that file descriptors sent with
        the frame are collected. A stream socket may return fewer bytes than requested, so further plain
        reads are made until all 4 prefix bytes have arrived; otherwise a short read would corrupt the
        framing of this and every following message.

        :raises EOFError: if the socket is closed before the full prefix has been read.
        """
        fds: list[int] | None = None
        if maxfds:
            data, fds, *_ = recv_fds(self.socket, 4, maxfds)
        else:
            data = self.socket.recv(4)

        if data == b"":
            raise EOFError("Request socket closed before length")

        while len(data) < 4:
            more = self.socket.recv(4 - len(data))
            if not more:
                raise EOFError("Request socket closed before length")
            data += more

        return data, fds

    def _from_frame(self, frame: _ResponseFrame) -> ReceiveMsgType | None:
        """
        Decode the body of a response frame, raising if the frame carries an error.

        :return: The validated body, or ``None`` if the frame has no body.
        :raises AirflowRuntimeError: if ``frame.error`` is set.
        """
        from airflow.sdk.exceptions import AirflowRuntimeError

        if frame.error is not None:
            err = self.err_decoder.validate_python(frame.error)
            raise AirflowRuntimeError(error=err)

        if frame.body is None:
            return None

        try:
            return self.body_decoder.validate_python(frame.body)
        except Exception:
            self.log.exception("Unable to decode message")
            raise

    def _get_response(self) -> ReceiveMsgType | None:
        """Block until the next response frame arrives and return its decoded body."""
        frame = self._read_frame()
        return self._from_frame(frame)
272
273
class StartupDetails(BaseModel):
    # Details needed to start running a task instance in the task process. NOTE(review): presumably this is
    # the first message the task process receives — confirm against the supervisor.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    ti: TaskInstance
    # Path of the Dag file; presumably relative to the bundle described by ``bundle_info`` — TODO confirm.
    dag_rel_path: str
    bundle_info: BundleInfo
    start_date: datetime
    ti_context: TIRunContext
    sentry_integration: str
    # Discriminator value for the ``ToTask`` union.
    type: Literal["StartupDetails"] = "StartupDetails"
284
285
class AssetResult(AssetResponse):
    """Response to asset lookup requests (GetAssetByName / GetAssetByUri)."""

    # Discriminator value for the ``ToTask`` union.
    type: Literal["AssetResult"] = "AssetResult"

    @classmethod
    def from_asset_response(cls, asset_response: AssetResponse) -> AssetResult:
        """
        Get AssetResult from AssetResponse.

        AssetResponse is autogenerated from the API schema, so we need to convert it to AssetResult
        for communication between the Supervisor and the task process.
        """
        # Exclude defaults to avoid sending unnecessary data.
        # Pass the type as AssetResult explicitly so we can then call model_dump_json with exclude_unset=True
        # to avoid sending unset fields (which are defaults in our case).
        return cls(**asset_response.model_dump(exclude_defaults=True), type="AssetResult")
303
304
@attrs.define(kw_only=True)
class AssetEventSourceTaskInstance:
    """Task instance an asset event came from; used by AssetEventResult and AssetEventDagRunReferenceResult."""

    dag_run: DagRun
    task_id: str
    map_index: int

    @property
    def dag_id(self) -> str:
        """Dag id of the source Dag run."""
        run = self.dag_run
        return run.dag_id

    @property
    def run_id(self) -> str:
        """Run id of the source Dag run."""
        run = self.dag_run
        return run.run_id

    def xcom_pull(self, *, key: str = "return_value", default: Any = None) -> Any:
        """
        Pull an XCom value stored by this task instance.

        :param key: The XCom key to read.
        :param default: Returned instead when the fetched value is ``None``.
        """
        from airflow.sdk.execution_time.xcom import XCom

        pulled = XCom.get_value(ti_key=self, key=key)
        return default if pulled is None else pulled
327
328
def _fetch_dag_run(*, dag_id: str, run_id: str) -> DagRun:
    """Ask the supervisor for the Dag run identified by ``dag_id`` / ``run_id`` and return it."""
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    request = GetDagRun(dag_id=dag_id, run_id=run_id)
    dag_run = SUPERVISOR_COMMS.send(request)
    if TYPE_CHECKING:
        assert isinstance(dag_run, DagRunResult)
    return dag_run
336
337
class AssetEventResult(AssetEventResponse):
    """Used in AssetEventsResult."""

    @classmethod
    def from_asset_event_response(cls, asset_event_response: AssetEventResponse) -> AssetEventResult:
        """Copy the non-default fields of an API ``AssetEventResponse`` into an ``AssetEventResult``."""
        payload = asset_event_response.model_dump(exclude_defaults=True)
        return cls(**payload)

    @cached_property
    def source_dag_run(self) -> DagRun | None:
        """Dag run that produced this event, fetched from the supervisor on first access (``None`` if unknown)."""
        dag_id, run_id = self.source_dag_id, self.source_run_id
        if dag_id and run_id:
            return _fetch_dag_run(dag_id=dag_id, run_id=run_id)
        return None

    @cached_property
    def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
        """Task instance that produced this event, or ``None`` if it cannot be identified."""
        task_id, map_index = self.source_task_id, self.source_map_index
        if task_id is None or map_index is None:
            return None
        dag_run = self.source_dag_run
        if dag_run is None:
            return None
        return AssetEventSourceTaskInstance(dag_run=dag_run, task_id=task_id, map_index=map_index)
362
363
class AssetEventsResult(AssetEventsResponse):
    """Response to GetAssetEventByAsset / GetAssetEventByAssetAlias requests."""

    # Discriminator value for the ``ToTask`` union.
    type: Literal["AssetEventsResult"] = "AssetEventsResult"

    @classmethod
    def from_asset_events_response(cls, asset_events_response: AssetEventsResponse) -> AssetEventsResult:
        """
        Get AssetEventsResult from AssetEventsResponse.

        AssetEventsResponse is autogenerated from the API schema, so we need to convert it to AssetEventsResult
        for communication between the Supervisor and the task process.
        """
        # Exclude defaults to avoid sending unnecessary data.
        # Pass the type as AssetEventsResult explicitly so we can then call model_dump_json with exclude_unset=True
        # to avoid sending unset fields (which are defaults in our case).
        return cls(
            **asset_events_response.model_dump(exclude_defaults=True),
            type="AssetEventsResult",
        )

    def iter_asset_event_results(self) -> Iterator[AssetEventResult]:
        """Lazily yield each entry of ``asset_events`` converted to an ``AssetEventResult``."""
        return (AssetEventResult.from_asset_event_response(event) for event in self.asset_events)
387
388
class AssetEventDagRunReferenceResult(AssetEventDagRunReference):
    """``AssetEventDagRunReference`` with lazy lookups of the event's source Dag run and task instance."""

    @classmethod
    def from_asset_event_dag_run_reference(
        cls,
        asset_event_dag_run_reference: AssetEventDagRunReference,
    ) -> AssetEventDagRunReferenceResult:
        """Copy the non-default fields of an API ``AssetEventDagRunReference`` into this result type."""
        data = asset_event_dag_run_reference.model_dump(exclude_defaults=True)
        return cls(**data)

    @cached_property
    def source_dag_run(self) -> DagRun | None:
        """Source Dag run, fetched from the supervisor on first access; ``None`` if dag/run id is missing."""
        has_source = bool(self.source_dag_id) and bool(self.source_run_id)
        if not has_source:
            return None
        return _fetch_dag_run(dag_id=self.source_dag_id, run_id=self.source_run_id)

    @cached_property
    def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
        """Source task instance, or ``None`` if the task id, map index or Dag run is unavailable."""
        if self.source_task_id is not None and self.source_map_index is not None:
            dag_run = self.source_dag_run
            if dag_run is not None:
                return AssetEventSourceTaskInstance(
                    dag_run=dag_run,
                    task_id=self.source_task_id,
                    map_index=self.source_map_index,
                )
        return None
414
415
class InactiveAssetsResult(InactiveAssetsResponse):
    """Response of InactiveAssets requests."""

    # Discriminator value for the ``ToTask`` union.
    type: Literal["InactiveAssetsResult"] = "InactiveAssetsResult"

    @classmethod
    def from_inactive_assets_response(
        cls, inactive_assets_response: InactiveAssetsResponse
    ) -> InactiveAssetsResult:
        """
        Get InactiveAssetsResult from InactiveAssetsResponse.

        InactiveAssetsResponse is autogenerated from the API schema, so we need to convert it to InactiveAssetsResult
        for communication between the Supervisor and the task process.
        """
        return cls(**inactive_assets_response.model_dump(exclude_defaults=True), type="InactiveAssetsResult")
432
433
class XComResult(XComResponse):
    """Response to GetXCom requests, tagged for the ``ToTask`` union."""

    type: Literal["XComResult"] = "XComResult"

    @classmethod
    def from_xcom_response(cls, xcom_response: XComResponse) -> XComResult:
        """
        Build an XComResult from the autogenerated API ``XComResponse``.

        The API model has no ``type`` discriminator, so it is re-wrapped in this class before being sent
        between the Supervisor and the task process.
        """
        fields = xcom_response.model_dump(exclude_defaults=True)
        return cls(**fields, type="XComResult")
448
449
class XComCountResponse(BaseModel):
    # Number of XCom values available (reply to GetXComCount).
    len: int
    # NOTE: the discriminator value differs from the class name; it is part of the wire format, so keep it.
    type: Literal["XComLengthResponse"] = "XComLengthResponse"
453
454
class XComSequenceIndexResult(BaseModel):
    """Single item of a mapped XCom sequence, tagged for the ``ToTask`` union."""

    root: JsonValue
    type: Literal["XComSequenceIndexResult"] = "XComSequenceIndexResult"

    @classmethod
    def from_response(cls, response: XComSequenceIndexResponse) -> XComSequenceIndexResult:
        """Wrap the API ``XComSequenceIndexResponse`` value in this discriminated model."""
        value = response.root
        return cls(root=value, type="XComSequenceIndexResult")
462
463
class XComSequenceSliceResult(BaseModel):
    """Slice of a mapped XCom sequence, tagged for the ``ToTask`` union."""

    root: list[JsonValue]
    type: Literal["XComSequenceSliceResult"] = "XComSequenceSliceResult"

    @classmethod
    def from_response(cls, response: XComSequenceSliceResponse) -> XComSequenceSliceResult:
        """Wrap the API ``XComSequenceSliceResponse`` values in this discriminated model."""
        values = response.root
        return cls(root=values, type="XComSequenceSliceResult")
471
472
class ConnectionResult(ConnectionResponse):
    """Response to GetConnection requests, tagged for the ``ToTask`` union."""

    type: Literal["ConnectionResult"] = "ConnectionResult"

    @classmethod
    def from_conn_response(cls, connection_response: ConnectionResponse) -> ConnectionResult:
        """
        Build a ConnectionResult from the autogenerated API ``ConnectionResponse``.

        Only non-default values are copied (dumped by alias so aliased fields round-trip), and the ``type``
        discriminator is set explicitly so later dumps with ``exclude_unset=True`` still include it.
        """
        dumped = connection_response.model_dump(exclude_defaults=True, by_alias=True)
        return cls(**dumped, type="ConnectionResult")
490
491
class VariableResult(VariableResponse):
    """Response to GetVariable requests, tagged for the ``ToTask`` union."""

    type: Literal["VariableResult"] = "VariableResult"

    @classmethod
    def from_variable_response(cls, variable_response: VariableResponse) -> VariableResult:
        """
        Build a VariableResult from the autogenerated API ``VariableResponse``.

        The API model lacks the ``type`` discriminator needed between the Supervisor and the task process.
        """
        dumped = variable_response.model_dump(exclude_defaults=True)
        return cls(**dumped, type="VariableResult")
504
505
class DagRunResult(DagRun):
    """A Dag run sent to the task process, tagged for the ``ToTask`` union."""

    type: Literal["DagRunResult"] = "DagRunResult"

    @classmethod
    def from_api_response(cls, dr_response: DagRun) -> DagRunResult:
        """
        Build the result model from the autogenerated API ``DagRun``.

        The API model has no discriminator field, which Supervisor <-> task communication requires.
        """
        dumped = dr_response.model_dump(exclude_defaults=True)
        return cls(**dumped, type="DagRunResult")
519
520
class DagRunStateResult(DagRunStateResponse):
    """Response to GetDagRunState requests, tagged for the ``ToTask`` union."""

    type: Literal["DagRunStateResult"] = "DagRunStateResult"

    # TODO: Create a convert api_response to result classes so we don't need to do this
    # for all the classes above
    @classmethod
    def from_api_response(cls, dr_state_response: DagRunStateResponse) -> DagRunStateResult:
        """
        Build the result model from the autogenerated API ``DagRunStateResponse``.

        The API model has no discriminator field, which Supervisor <-> task communication requires.
        """
        dumped = dr_state_response.model_dump(exclude_defaults=True)
        return cls(**dumped, type="DagRunStateResult")
536
537
class PreviousDagRunResult(BaseModel):
    """Response containing previous Dag run information."""

    # ``None`` when there is no previous Dag run to report.
    dag_run: DagRun | None = None
    # Discriminator value for the ``ToTask`` union.
    type: Literal["PreviousDagRunResult"] = "PreviousDagRunResult"
543
544
class PrevSuccessfulDagRunResult(PrevSuccessfulDagRunResponse):
    """Response to GetPrevSuccessfulDagRun requests, tagged for the ``ToTask`` union."""

    type: Literal["PrevSuccessfulDagRunResult"] = "PrevSuccessfulDagRunResult"

    @classmethod
    def from_dagrun_response(cls, prev_dag_run: PrevSuccessfulDagRunResponse) -> PrevSuccessfulDagRunResult:
        """
        Build the result model from the autogenerated API ``PrevSuccessfulDagRunResponse``.

        Only non-default values are carried over; the ``type`` discriminator is set explicitly.
        """
        dumped = prev_dag_run.model_dump(exclude_defaults=True)
        return cls(**dumped, type="PrevSuccessfulDagRunResult")
557
558
class TaskRescheduleStartDate(BaseModel):
    """Response containing the first reschedule date for a task instance."""

    # Required but nullable: ``None`` presumably means no reschedule was found — confirm with the supervisor.
    start_date: AwareDatetime | None
    # Discriminator value for the ``ToTask`` union.
    type: Literal["TaskRescheduleStartDate"] = "TaskRescheduleStartDate"
564
565
class TICount(BaseModel):
    """Response containing count of Task Instances matching certain filters."""

    # Reply to GetTICount.
    count: int
    # Discriminator value for the ``ToTask`` union.
    type: Literal["TICount"] = "TICount"
571
572
class TaskStatesResult(TaskStatesResponse):
    """Response to GetTaskStates requests, tagged for the ``ToTask`` union."""

    type: Literal["TaskStatesResult"] = "TaskStatesResult"

    @classmethod
    def from_api_response(cls, task_states_response: TaskStatesResponse) -> TaskStatesResult:
        """
        Build the result model from the autogenerated API ``TaskStatesResponse``.

        The API model has no discriminator field, which Supervisor <-> task communication requires.
        """
        dumped = task_states_response.model_dump(exclude_defaults=True)
        return cls(**dumped, type="TaskStatesResult")
586
587
class TaskBreadcrumbsResult(TaskBreadcrumbsResponse):
    """Response to GetTaskBreadcrumbs requests, tagged for the ``ToTask`` union."""

    type: Literal["TaskBreadcrumbsResult"] = "TaskBreadcrumbsResult"

    @classmethod
    def from_api_response(cls, response: TaskBreadcrumbsResponse) -> TaskBreadcrumbsResult:
        """
        Build the result model from the autogenerated API ``TaskBreadcrumbsResponse``.

        The API model has no discriminator field, which Supervisor <-> task communication requires.
        """
        dumped = response.model_dump(exclude_defaults=True)
        return cls(**dumped, type="TaskBreadcrumbsResult")
601
602
class DRCount(BaseModel):
    """Response containing count of Dag Runs matching certain filters."""

    # Reply to GetDRCount.
    count: int
    # Discriminator value for the ``ToTask`` union.
    type: Literal["DRCount"] = "DRCount"
608
609
class ErrorResponse(BaseModel):
    # Describes a failed request. ``CommsDecoder`` validates a response frame's ``error`` payload as this
    # model and raises it as ``AirflowRuntimeError``.
    error: ErrorType = ErrorType.GENERIC_ERROR
    # Optional free-form details about the failure.
    detail: dict | None = None
    type: Literal["ErrorResponse"] = "ErrorResponse"
614
615
class OKResponse(BaseModel):
    # Generic acknowledgement sent to the task process.
    ok: bool
    # Discriminator value for the ``ToTask`` union.
    type: Literal["OKResponse"] = "OKResponse"
619
620
class SentFDs(BaseModel):
    # Reply to ResendLoggingFD. fd numbers differ between processes, so ``CommsDecoder.send`` overwrites
    # ``fds`` with the descriptors actually received over the socket in this process.
    type: Literal["SentFDs"] = "SentFDs"
    fds: list[int]
624
625
class CreateHITLDetailPayload(HITLDetailRequest):
    """Add the input request part of a Human-in-the-loop response."""

    # NOTE(review): this model is a member of both the ToTask and ToSupervisor unions.
    type: Literal["CreateHITLDetailPayload"] = "CreateHITLDetailPayload"
630
631
class HITLDetailRequestResult(HITLDetailRequest):
    """Response to CreateHITLDetailPayload request."""

    type: Literal["HITLDetailRequestResult"] = "HITLDetailRequestResult"

    @classmethod
    def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequestResult:
        """
        Build the result model from the API ``HITLDetailRequest``.

        The API model is re-wrapped here to add the ``type`` discriminator that the tagged-union
        deserialization between the Supervisor and the task process relies on.
        """
        dumped = hitl_request.model_dump(exclude_defaults=True)
        return cls(**dumped, type="HITLDetailRequestResult")
647
648
# Every message the supervisor can send to the task process. Pydantic picks the concrete model from the
# ``type`` literal each member declares (``Field(discriminator="type")``), so each member needs a distinct
# ``type`` value. ``CommsDecoder`` uses this as its default ``body_decoder``.
ToTask = Annotated[
    AssetResult
    | AssetEventsResult
    | ConnectionResult
    | DagRunResult
    | DagRunStateResult
    | DRCount
    | ErrorResponse
    | PrevSuccessfulDagRunResult
    | SentFDs
    | StartupDetails
    | TaskRescheduleStartDate
    | TICount
    | TaskBreadcrumbsResult
    | TaskStatesResult
    | VariableResult
    | XComCountResponse
    | XComResult
    | XComSequenceIndexResult
    | XComSequenceSliceResult
    | InactiveAssetsResult
    | CreateHITLDetailPayload
    | HITLDetailRequestResult
    | OKResponse
    | PreviousDagRunResult,
    Field(discriminator="type"),
]
676
677
class TaskState(BaseModel):
    """
    Update a task's state.

    If a process exits without sending one of these the state will be derived from the exit code:
    - 0 = SUCCESS
    - anything else = FAILED
    """

    # Only these states can be reported with this message; other transitions have dedicated messages
    # (SucceedTask, DeferTask, RetryTask, RescheduleTask, SkipDownstreamTasks).
    state: Literal[
        TaskInstanceState.FAILED,
        TaskInstanceState.SKIPPED,
        TaskInstanceState.REMOVED,
    ]
    end_date: datetime | None = None
    type: Literal["TaskState"] = "TaskState"
    rendered_map_index: str | None = None
695
696
class SucceedTask(TISuccessStatePayload):
    """Update a task's state to success. Includes task_outlets and outlet_events for registering asset events."""

    # Discriminator value for the ``ToSupervisor`` union.
    type: Literal["SucceedTask"] = "SucceedTask"
701
702
class DeferTask(TIDeferredStatePayload):
    """Update a task instance state to deferred."""

    type: Literal["DeferTask"] = "DeferTask"

    @field_serializer("trigger_kwargs", "next_kwargs", check_fields=True)
    def _serde_kwarg_fields(self, val: str | dict[str, Any] | None, _info):
        """
        Serialize ``trigger_kwargs`` / ``next_kwargs`` with Airflow's ``BaseSerialization``.

        Values that are not dicts (``None`` or an encrypted string) and dicts that already have the encoded
        ``{"__type": ..., "__var": ...}`` shape are passed through untouched.
        """
        from airflow.serialization.serialized_objects import BaseSerialization

        is_dict = isinstance(val, dict)
        already_encoded = is_dict and val.keys() == {"__type", "__var"}
        if not is_dict or already_encoded:
            return val
        return BaseSerialization.serialize(val or {})
720
721
class RetryTask(TIRetryStatePayload):
    """Update a task instance state to up_for_retry."""

    # Discriminator value for the ``ToSupervisor`` union.
    type: Literal["RetryTask"] = "RetryTask"
726
727
class RescheduleTask(TIRescheduleStatePayload):
    """Update a task instance state to reschedule/up_for_reschedule."""

    # Discriminator value for the ``ToSupervisor`` union.
    type: Literal["RescheduleTask"] = "RescheduleTask"
732
733
class SkipDownstreamTasks(TISkippedDownstreamTasksStatePayload):
    """Update state of downstream tasks within a task instance to 'skipped', while updating current task to success state."""

    # Discriminator value for the ``ToSupervisor`` union.
    type: Literal["SkipDownstreamTasks"] = "SkipDownstreamTasks"
738
739
class GetXCom(BaseModel):
    # Request a single XCom value; answered with XComResult.
    key: str
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    include_prior_dates: bool = False
    type: Literal["GetXCom"] = "GetXCom"
748
749
class GetXComCount(BaseModel):
    """Get the number of (mapped) XCom values available."""

    key: str
    dag_id: str
    run_id: str
    task_id: str
    # NOTE: the discriminator value differs from the class name; it is wire format, so keep it.
    type: Literal["GetNumberXComs"] = "GetNumberXComs"
758
759
class GetXComSequenceItem(BaseModel):
    # Request one item of a mapped XCom sequence by position.
    key: str
    dag_id: str
    run_id: str
    task_id: str
    # Position in the sequence; presumably negative values index from the end — TODO confirm server side.
    offset: int
    type: Literal["GetXComSequenceItem"] = "GetXComSequenceItem"
767
768
class GetXComSequenceSlice(BaseModel):
    # Request a slice of a mapped XCom sequence.
    key: str
    dag_id: str
    run_id: str
    task_id: str
    # Required but nullable bounds; presumably follow Python ``slice`` semantics — TODO confirm server side.
    start: int | None
    stop: int | None
    step: int | None
    include_prior_dates: bool = False
    type: Literal["GetXComSequenceSlice"] = "GetXComSequenceSlice"
779
780
class SetXCom(BaseModel):
    # Store an XCom value; the value must be JSON-compatible.
    key: str
    value: JsonValue
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    # NOTE(review): presumably the number of items when pushing a mappable sequence — confirm.
    mapped_length: int | None = None
    type: Literal["SetXCom"] = "SetXCom"
790
791
class DeleteXCom(BaseModel):
    # Delete an XCom value.
    key: str
    dag_id: str
    run_id: str
    task_id: str
    map_index: int | None = None
    type: Literal["DeleteXCom"] = "DeleteXCom"
799
800
class GetConnection(BaseModel):
    # Request a connection by id; answered with ConnectionResult.
    conn_id: str
    type: Literal["GetConnection"] = "GetConnection"
804
805
class GetVariable(BaseModel):
    # Request a variable by key; answered with VariableResult.
    key: str
    type: Literal["GetVariable"] = "GetVariable"
809
810
class PutVariable(BaseModel):
    # Create or update a variable. ``value`` and ``description`` are required fields but may be ``None``.
    key: str
    value: str | None
    description: str | None
    type: Literal["PutVariable"] = "PutVariable"
816
817
class DeleteVariable(BaseModel):
    # Delete a variable by key.
    key: str
    type: Literal["DeleteVariable"] = "DeleteVariable"
821
822
class ResendLoggingFD(BaseModel):
    # Ask the supervisor to send a logging file descriptor; the SentFDs reply is handled specially by
    # ``CommsDecoder.send``, which receives the descriptor via ``socket.recv_fds``.
    type: Literal["ResendLoggingFD"] = "ResendLoggingFD"
825
826
class SetRenderedFields(BaseModel):
    """Payload for setting RTIF for a task instance."""

    # We are using a BaseModel here compared to server using RootModel because we
    # have a discriminator running with "type", and RootModel doesn't support type

    # Rendered template field values keyed by field name; values must be JSON-compatible.
    rendered_fields: dict[str, JsonValue]
    type: Literal["SetRenderedFields"] = "SetRenderedFields"
835
836
class SetRenderedMapIndex(BaseModel):
    """Payload for setting rendered_map_index for a task instance."""

    rendered_map_index: str
    # Discriminator value for the ``ToSupervisor`` union.
    type: Literal["SetRenderedMapIndex"] = "SetRenderedMapIndex"
842
843
class TriggerDagRun(TriggerDAGRunPayload):
    # Ask the supervisor to trigger a new run of ``dag_id``. Unlike the base payload, the target Dag id and
    # run id are declared here so they travel in the message itself.
    dag_id: str
    run_id: Annotated[str, Field(title="Dag Run Id")]
    type: Literal["TriggerDagRun"] = "TriggerDagRun"
848
849
class GetDagRun(BaseModel):
    # Request a Dag run; answered with DagRunResult (see ``_fetch_dag_run``).
    dag_id: str
    run_id: str
    type: Literal["GetDagRun"] = "GetDagRun"
854
855
class GetDagRunState(BaseModel):
    # Request the state of a Dag run; presumably answered with DagRunStateResult.
    dag_id: str
    run_id: str
    type: Literal["GetDagRunState"] = "GetDagRunState"
860
861
class GetPreviousDagRun(BaseModel):
    # Request the Dag run preceding ``logical_date``, optionally filtered by ``state``; presumably answered
    # with PreviousDagRunResult.
    dag_id: str
    logical_date: AwareDatetime
    state: str | None = None
    type: Literal["GetPreviousDagRun"] = "GetPreviousDagRun"
867
868
class GetAssetByName(BaseModel):
    # Look up an asset by name.
    name: str
    type: Literal["GetAssetByName"] = "GetAssetByName"
872
873
class GetAssetByUri(BaseModel):
    # Look up an asset by URI.
    uri: str
    type: Literal["GetAssetByUri"] = "GetAssetByUri"
877
878
class GetAssetEventByAsset(BaseModel):
    # Request asset events for an asset identified by name and/or URI (both required fields, both nullable).
    name: str | None
    uri: str | None
    # Optional time window and result limit for the events.
    after: AwareDatetime | None = None
    before: AwareDatetime | None = None
    limit: int | None = None
    ascending: bool = True
    type: Literal["GetAssetEventByAsset"] = "GetAssetEventByAsset"
887
888
class GetAssetEventByAssetAlias(BaseModel):
    # Request asset events recorded against an asset alias.
    alias_name: str
    # Optional time window and result limit for the events.
    after: AwareDatetime | None = None
    before: AwareDatetime | None = None
    limit: int | None = None
    ascending: bool = True
    type: Literal["GetAssetEventByAssetAlias"] = "GetAssetEventByAssetAlias"
896
897
class ValidateInletsAndOutlets(BaseModel):
    # Ask the supervisor to validate the inlets/outlets of the task instance ``ti_id``.
    ti_id: UUID
    type: Literal["ValidateInletsAndOutlets"] = "ValidateInletsAndOutlets"
901
902
class GetPrevSuccessfulDagRun(BaseModel):
    # Request the previous successful Dag run relative to task instance ``ti_id``.
    ti_id: UUID
    type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"
906
907
class GetTaskRescheduleStartDate(BaseModel):
    # Request the first reschedule date of a task instance try; answered with TaskRescheduleStartDate.
    ti_id: UUID
    try_number: int = 1
    type: Literal["GetTaskRescheduleStartDate"] = "GetTaskRescheduleStartDate"
912
913
class GetTICount(BaseModel):
    # Count task instances of ``dag_id`` matching the optional filters; answered with TICount.
    dag_id: str
    map_index: int | None = None
    task_ids: list[str] | None = None
    task_group_id: str | None = None
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    states: list[str] | None = None
    type: Literal["GetTICount"] = "GetTICount"
923
924
class GetTaskStates(BaseModel):
    # Request the states of task instances of ``dag_id`` matching the optional filters.
    dag_id: str
    map_index: int | None = None
    task_ids: list[str] | None = None
    task_group_id: str | None = None
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    type: Literal["GetTaskStates"] = "GetTaskStates"
933
934
class GetTaskBreadcrumbs(BaseModel):
    # Request task breadcrumbs for a Dag run; answered with TaskBreadcrumbsResult.
    dag_id: str
    run_id: str
    type: Literal["GetTaskBreadcrumbs"] = "GetTaskBreadcrumbs"
939
940
class GetDRCount(BaseModel):
    # Count Dag runs of ``dag_id`` matching the optional filters; answered with DRCount.
    dag_id: str
    logical_dates: list[AwareDatetime] | None = None
    run_ids: list[str] | None = None
    states: list[str] | None = None
    type: Literal["GetDRCount"] = "GetDRCount"
947
948
class GetHITLDetailResponse(BaseModel):
    """Get the response content part of a Human-in-the-loop response."""

    # Task instance whose Human-in-the-loop response is requested.
    ti_id: UUID
    type: Literal["GetHITLDetailResponse"] = "GetHITLDetailResponse"
954
955
class UpdateHITLDetail(UpdateHITLDetailPayload):
    """Update the response content part of an existing Human-in-the-loop response."""

    # Discriminator value for the ``ToSupervisor`` union.
    type: Literal["UpdateHITLDetail"] = "UpdateHITLDetail"
960
961
class MaskSecret(BaseModel):
    """Add a new value to be redacted in task logs."""

    # This is needed since calls to `mask_secret` in the Task process will otherwise only add the mask value
    # to the child process, but the redaction happens in the parent.
    # We cannot use `string | Iterable | dict here` (would be more intuitive) because bug in Pydantic
    # https://github.com/pydantic/pydantic/issues/9541 turns iterable into a ValidatorIterator
    value: JsonValue
    # Optional name associated with the secret value.
    name: str | None = None
    type: Literal["MaskSecret"] = "MaskSecret"
972
973
# Every request the task process can send to the supervisor. Pydantic picks the concrete model from the
# ``type`` literal each member declares (``Field(discriminator="type")``), so each member needs a distinct
# ``type`` value.
ToSupervisor = Annotated[
    DeferTask
    | DeleteXCom
    | GetAssetByName
    | GetAssetByUri
    | GetAssetEventByAsset
    | GetAssetEventByAssetAlias
    | GetConnection
    | GetDagRun
    | GetDagRunState
    | GetDRCount
    | GetPrevSuccessfulDagRun
    | GetPreviousDagRun
    | GetTaskRescheduleStartDate
    | GetTICount
    | GetTaskBreadcrumbs
    | GetTaskStates
    | GetVariable
    | GetXCom
    | GetXComCount
    | GetXComSequenceItem
    | GetXComSequenceSlice
    | PutVariable
    | RescheduleTask
    | RetryTask
    | SetRenderedFields
    | SetRenderedMapIndex
    | SetXCom
    | SkipDownstreamTasks
    | SucceedTask
    | ValidateInletsAndOutlets
    | TaskState
    | TriggerDagRun
    | DeleteVariable
    | ResendLoggingFD
    | CreateHITLDetailPayload
    | UpdateHITLDetail
    | GetHITLDetailResponse
    | MaskSecret,
    Field(discriminator="type"),
]