1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import logging
21import ssl
22import sys
23import uuid
24from functools import cache
25from http import HTTPStatus
26from typing import TYPE_CHECKING, Any, TypeVar
27
28import certifi
29import httpx
30import msgspec
31import structlog
32from pydantic import BaseModel
33from tenacity import (
34 before_log,
35 retry,
36 retry_if_exception,
37 stop_after_attempt,
38 wait_random_exponential,
39)
40from uuid6 import uuid7
41
42from airflow.sdk import __version__
43from airflow.sdk.api.datamodels._generated import (
44 API_VERSION,
45 AssetEventsResponse,
46 AssetResponse,
47 ConnectionResponse,
48 DagRun,
49 DagRunStateResponse,
50 DagRunType,
51 HITLDetailRequest,
52 HITLDetailResponse,
53 HITLUser,
54 InactiveAssetsResponse,
55 PrevSuccessfulDagRunResponse,
56 TaskBreadcrumbsResponse,
57 TaskInstanceState,
58 TaskStatesResponse,
59 TerminalStateNonSuccess,
60 TIDeferredStatePayload,
61 TIEnterRunningPayload,
62 TIHeartbeatInfo,
63 TIRescheduleStatePayload,
64 TIRetryStatePayload,
65 TIRunContext,
66 TISkippedDownstreamTasksStatePayload,
67 TISuccessStatePayload,
68 TITerminalStatePayload,
69 TriggerDAGRunPayload,
70 ValidationError as RemoteValidationError,
71 VariablePostBody,
72 VariableResponse,
73 XComResponse,
74 XComSequenceIndexResponse,
75 XComSequenceSliceResponse,
76)
77from airflow.sdk.configuration import conf
78from airflow.sdk.exceptions import ErrorType
79from airflow.sdk.execution_time.comms import (
80 CreateHITLDetailPayload,
81 DRCount,
82 ErrorResponse,
83 OKResponse,
84 PreviousDagRunResult,
85 PreviousTIResult,
86 SkipDownstreamTasks,
87 TaskRescheduleStartDate,
88 TICount,
89 UpdateHITLDetail,
90 XComCountResponse,
91)
92
if TYPE_CHECKING:
    from datetime import datetime
    from typing import ParamSpec

    from airflow.sdk.execution_time.comms import RescheduleTask

    P = ParamSpec("P")
    T = TypeVar("T")

    # methodtools doesn't have typestubs, so give type checkers a no-op stub
    # with the same call shape as methodtools.lru_cache.
    def lru_cache(maxsize: int | None = 128):
        def wrapper(f):
            return f

        return wrapper
else:
    from methodtools import lru_cache
110
111
112@cache
113def _get_fqdn(name=""):
114 """
115 Get fully qualified domain name from name.
116
117 An empty argument is interpreted as meaning the local host.
118 This is a patched version of socket.getfqdn() - see https://github.com/python/cpython/issues/49254
119 """
120 import socket
121
122 name = name.strip()
123 if not name or name == "0.0.0.0":
124 name = socket.gethostname()
125 try:
126 addrs = socket.getaddrinfo(name, None, 0, socket.SOCK_DGRAM, 0, socket.AI_CANONNAME)
127 except OSError:
128 pass
129 else:
130 for addr in addrs:
131 if addr[3]:
132 name = addr[3]
133 break
134 return name
135
136
def get_hostname():
    """Fetch the hostname using the callable from config or use built-in FQDN as a fallback."""
    hostname_callable = conf.getimport("core", "hostname_callable", fallback=_get_fqdn)
    return hostname_callable()
140
141
@cache
def getuser() -> str:
    """
    Return the username of the current user, failing loudly when there is none.

    We deliberately do not fall back to os.getuid(): a missing username usually
    means the rest of the user environment (e.g. $HOME) is broken too, and an
    explicit failure is better than silently trying to work badly.
    """
    import getpass

    try:
        return getpass.getuser()
    except KeyError:
        message = (
            "The user that Airflow is running as has no username; you must run "
            "Airflow as a full user, with a username and home directory, "
            "in order for it to function properly."
        )
        raise ValueError(message)
161
162
# Module-level structured logger for this client module.
log = structlog.get_logger(logger_name=__name__)

# Names re-exported as the public API of this module.
__all__ = [
    "Client",
    "ConnectionOperations",
    "ServerResponseError",
    "TaskInstanceOperations",
    "get_hostname",
    "getuser",
]
173
174
def get_json_error(response: httpx.Response):
    """Raise a ServerResponseError if we can extract error info from the error."""
    if (err := ServerResponseError.from_response(response)) is not None:
        raise err
180
181
def raise_on_4xx_5xx(response: httpx.Response):
    """Response hook: raise a ServerResponseError when possible, else the plain httpx status error."""
    # get_json_error either raises or returns None, so fall through to httpx's check.
    get_json_error(response)
    return response.raise_for_status()
184
185
# Py 3.11+ version
def raise_on_4xx_5xx_with_note(response: httpx.Response):
    """Like raise_on_4xx_5xx, but attach the request's correlation id to the error as a note."""
    try:
        # get_json_error either raises or returns None, so fall through to httpx's check.
        get_json_error(response)
        return response.raise_for_status()
    except httpx.HTTPStatusError as e:
        if TYPE_CHECKING:
            assert hasattr(e, "add_note")
        e.add_note(
            f"Correlation-id={response.headers.get('correlation-id', None) or response.request.headers.get('correlation-id', 'no-correlation-id')}"
        )
        raise
197
198
# BaseException.add_note (PEP 678) exists on Python 3.11+; prefer the hook
# variant that records the correlation id on the raised error.
if hasattr(BaseException, "add_note"):
    # Py 3.11+
    raise_on_4xx_5xx = raise_on_4xx_5xx_with_note
202
203
def add_correlation_id(request: httpx.Request):
    """Request hook: stamp every outgoing request with a fresh UUIDv7 correlation id."""
    correlation_id = uuid7()
    request.headers["correlation-id"] = str(correlation_id)
206
207
class TaskInstanceOperations:
    """
    Operations the Execution API server exposes for a single TaskInstance (TI).

    Thin, stateless wrapper: each method performs exactly one HTTP call via the
    shared :class:`Client` and, where applicable, parses the response into a
    generated datamodel.
    """

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext:
        """Tell the API server that this TI has started running."""
        body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), unixname=getuser(), start_date=when)

        resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json())
        return TIRunContext.model_validate_json(resp.read())

    def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime, rendered_map_index):
        """Tell the API server that this TI has reached a terminal state."""
        # SUCCESS carries extra payload (task outlets, outlet events) and must go through `succeed`.
        if state == TaskInstanceState.SUCCESS:
            raise ValueError("Logic error. SUCCESS state should call the `succeed` function instead")
        # TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing.
        body = TITerminalStatePayload(
            end_date=when, state=TerminalStateNonSuccess(state), rendered_map_index=rendered_map_index
        )
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def retry(self, id: uuid.UUID, end_date: datetime, rendered_map_index):
        """Tell the API server that this TI has failed and reached an up_for_retry state."""
        body = TIRetryStatePayload(end_date=end_date, rendered_map_index=rendered_map_index)
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events, rendered_map_index):
        """Tell the API server that this TI has succeeded."""
        body = TISuccessStatePayload(
            end_date=when,
            task_outlets=task_outlets,
            outlet_events=outlet_events,
            rendered_map_index=rendered_map_index,
        )
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def defer(self, id: uuid.UUID, msg):
        """Tell the API server that this TI has been deferred."""
        # `type` is a comms-envelope discriminator, not payload data, so exclude it.
        body = TIDeferredStatePayload(**msg.model_dump(exclude_unset=True, exclude={"type"}))

        # Create a deferred state payload from msg
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def reschedule(self, id: uuid.UUID, msg: RescheduleTask):
        """Tell the API server that this TI has been rescheduled."""
        # `type` is a comms-envelope discriminator, not payload data, so exclude it.
        body = TIRescheduleStatePayload(**msg.model_dump(exclude_unset=True, exclude={"type"}))

        # Create a reschedule state payload from msg
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def heartbeat(self, id: uuid.UUID, pid: int):
        """Send a heartbeat for this TI, reporting the running pid and current hostname."""
        body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
        self.client.put(f"task-instances/{id}/heartbeat", content=body.model_dump_json())

    def skip_downstream_tasks(self, id: uuid.UUID, msg: SkipDownstreamTasks):
        """Tell the API server to skip the downstream tasks of this TI."""
        body = TISkippedDownstreamTasksStatePayload(tasks=msg.tasks)
        self.client.patch(f"task-instances/{id}/skip-downstream", content=body.model_dump_json())

    def set_rtif(self, id: uuid.UUID, body: dict[str, str]) -> OKResponse:
        """Set Rendered Task Instance Fields via the API server."""
        self.client.put(f"task-instances/{id}/rtif", json=body)
        # Any error from the server will anyway be propagated down to the supervisor,
        # so we choose to send a generic response to the supervisor over the server response to
        # decouple from the server response string
        return OKResponse(ok=True)

    def set_rendered_map_index(self, id: uuid.UUID, rendered_map_index: str) -> OKResponse:
        """Set rendered_map_index for a task instance via the API server."""
        self.client.patch(f"task-instances/{id}/rendered-map-index", json=rendered_map_index)
        return OKResponse(ok=True)

    def get_previous_successful_dagrun(self, id: uuid.UUID) -> PrevSuccessfulDagRunResponse:
        """
        Get the previous successful dag run for a given task instance.

        The data from it is used to get values for Task Context.
        """
        resp = self.client.get(f"task-instances/{id}/previous-successful-dagrun")
        return PrevSuccessfulDagRunResponse.model_validate_json(resp.read())

    def get_reschedule_start_date(self, id: uuid.UUID, try_number: int = 1) -> TaskRescheduleStartDate:
        """Get the start date of a task reschedule via the API server."""
        resp = self.client.get(f"task-reschedules/{id}/start_date", params={"try_number": try_number})
        return TaskRescheduleStartDate(start_date=resp.json())

    def get_count(
        self,
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> TICount:
        """Get count of task instances matching the given criteria."""
        params: dict[str, Any]
        params = {
            "dag_id": dag_id,
            "task_ids": task_ids,
            "task_group_id": task_group_id,
            "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None,
            "run_ids": run_ids,
            "states": states,
        }

        # Remove None values from params
        params = {k: v for k, v in params.items() if v is not None}

        # Only non-negative map indexes are sent as a filter.
        if map_index is not None and map_index >= 0:
            params.update({"map_index": map_index})

        resp = self.client.get("task-instances/count", params=params)
        return TICount(count=resp.json())

    def get_previous(
        self,
        dag_id: str,
        task_id: str,
        logical_date: datetime | None = None,
        map_index: int = -1,
        state: TaskInstanceState | str | None = None,
    ) -> PreviousTIResult:
        """
        Get the previous task instance matching the given criteria.

        :param dag_id: DAG ID
        :param task_id: Task ID
        :param logical_date: If provided, finds TI with logical_date < this value (before filter)
        :param map_index: Map index to filter by (defaults to -1 for non-mapped tasks)
        :param state: If provided, filters by TaskInstance state
        """
        params: dict[str, Any] = {"map_index": map_index}
        if logical_date:
            params["logical_date"] = logical_date.isoformat()
        if state:
            # Accept either the enum or its raw string value.
            params["state"] = state.value if isinstance(state, TaskInstanceState) else state

        resp = self.client.get(f"task-instances/previous/{dag_id}/{task_id}", params=params)
        return PreviousTIResult(task_instance=resp.json())

    def get_task_states(
        self,
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
    ) -> TaskStatesResponse:
        """Get task states given criteria."""
        params: dict[str, Any]
        params = {
            "dag_id": dag_id,
            "task_ids": task_ids,
            "task_group_id": task_group_id,
            "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None,
            "run_ids": run_ids,
        }

        # Remove None values from params
        params = {k: v for k, v in params.items() if v is not None}

        # Only non-negative map indexes are sent as a filter.
        if map_index is not None and map_index >= 0:
            params.update({"map_index": map_index})

        resp = self.client.get("task-instances/states", params=params)
        return TaskStatesResponse.model_validate_json(resp.read())

    def get_task_breakcrumbs(self, dag_id: str, run_id: str) -> TaskBreadcrumbsResponse:
        """Get task breadcrumbs for the given dag run via the API server."""
        # NOTE(review): the method name misspells "breadcrumbs"; renaming would
        # break existing callers, so it is left as-is.
        params = {"dag_id": dag_id, "run_id": run_id}
        resp = self.client.get("task-instances/breadcrumbs", params=params)
        return TaskBreadcrumbsResponse.model_validate_json(resp.read())

    def validate_inlets_and_outlets(self, id: uuid.UUID) -> InactiveAssetsResponse:
        """Validate whether there are inactive assets in inlets and outlets of a given task instance."""
        resp = self.client.get(f"task-instances/{id}/validate-inlets-and-outlets")
        return InactiveAssetsResponse.model_validate_json(resp.read())
389
390
class ConnectionOperations:
    """Operations for fetching Connections from the API server."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(self, conn_id: str) -> ConnectionResponse | ErrorResponse:
        """Get a connection from the API server."""
        try:
            resp = self.client.get(f"connections/{conn_id}")
        except ServerResponseError as e:
            # Anything other than a 404 is unexpected and propagates to the caller.
            if e.response.status_code != HTTPStatus.NOT_FOUND:
                raise
            log.debug(
                "Connection not found",
                conn_id=conn_id,
                detail=e.detail,
                status_code=e.response.status_code,
            )
            return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": conn_id})
        return ConnectionResponse.model_validate_json(resp.read())
412
413
class VariableOperations:
    """Operations for reading, writing and deleting Airflow Variables via the API server."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(self, key: str) -> VariableResponse | ErrorResponse:
        """Get a variable from the API server."""
        try:
            resp = self.client.get(f"variables/{key}")
        except ServerResponseError as e:
            # Anything other than a 404 is unexpected and propagates to the caller.
            if e.response.status_code != HTTPStatus.NOT_FOUND:
                raise
            log.error(
                "Variable not found",
                key=key,
                detail=e.detail,
                status_code=e.response.status_code,
            )
            return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"key": key})
        return VariableResponse.model_validate_json(resp.read())

    def set(self, key: str, value: str | None, description: str | None = None) -> OKResponse:
        """Set an Airflow Variable via the API server."""
        payload = VariablePostBody(val=value, description=description)
        self.client.put(f"variables/{key}", content=payload.model_dump_json())
        # Any error from the server will anyway be propagated down to the supervisor,
        # so we choose to send a generic response to the supervisor over the server response to
        # decouple from the server response string
        return OKResponse(ok=True)

    def delete(self, key: str) -> OKResponse:
        """Delete a variable with given key via the API server."""
        self.client.delete(f"variables/{key}")
        # Any error from the server will anyway be propagated down to the supervisor,
        # so we choose to send a generic response to the supervisor over the server response to
        # decouple from the server response string
        return OKResponse(ok=True)
455
456
class XComOperations:
    """
    Operations the Execution API server exposes for XCom values.

    Each method performs one HTTP call via the shared :class:`Client`.
    """

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def head(self, dag_id: str, run_id: str, task_id: str, key: str) -> XComCountResponse:
        """
        Get the number of mapped XCom values.

        :raises RuntimeError: if the Content-Range header is missing or not in the
            expected ``map_indexes N`` form.
        """
        resp = self.client.head(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}")

        # content_range: str | None
        # Use .get() so a missing Content-Range header reaches the RuntimeError
        # below instead of raising a bare KeyError from the header lookup.
        if not (content_range := resp.headers.get("Content-Range")) or not content_range.startswith(
            "map_indexes "
        ):
            raise RuntimeError(f"Unable to parse Content-Range header from HEAD {resp.request.url}")
        return XComCountResponse(len=int(content_range[len("map_indexes ") :]))

    def get(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        map_index: int | None = None,
        include_prior_dates: bool = False,
    ) -> XComResponse:
        """Get a XCom value from the API server; a missing XCom yields ``value=None``."""
        # TODO: check if we need to use map_index as params in the uri
        # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
        params: dict[str, Any] = {}
        # Only non-negative map indexes are sent as a filter (-1 denotes "not mapped").
        if map_index is not None and map_index >= 0:
            params["map_index"] = map_index
        if include_prior_dates:
            params["include_prior_dates"] = include_prior_dates
        try:
            resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
        except ServerResponseError as e:
            if e.response.status_code == HTTPStatus.NOT_FOUND:
                log.error(
                    "XCom not found",
                    dag_id=dag_id,
                    run_id=run_id,
                    task_id=task_id,
                    key=key,
                    map_index=map_index,
                    detail=e.detail,
                    status_code=e.response.status_code,
                )
                # Airflow 2.x just ignores the absence of an XCom and moves on with a return value of None
                # Hence returning with key as `key` and value as `None`, so that the message is sent back to task runner
                # and the default value of None in xcom_pull is used.
                return XComResponse(key=key, value=None)
            raise
        return XComResponse.model_validate_json(resp.read())

    def set(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        value,
        map_index: int | None = None,
        mapped_length: int | None = None,
    ) -> OKResponse:
        """Set a XCom value via the API server."""
        # TODO: check if we need to use map_index as params in the uri
        # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
        params: dict[str, Any] = {}
        if map_index is not None and map_index >= 0:
            params["map_index"] = map_index
        if mapped_length is not None and mapped_length >= 0:
            params["mapped_length"] = mapped_length
        self.client.post(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params, json=value)
        # Any error from the server will anyway be propagated down to the supervisor,
        # so we choose to send a generic response to the supervisor over the server response to
        # decouple from the server response string
        return OKResponse(ok=True)

    def delete(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        map_index: int | None = None,
    ) -> OKResponse:
        """Delete a XCom with given key via the API server."""
        params: dict[str, Any] = {}
        if map_index is not None and map_index >= 0:
            params["map_index"] = map_index
        self.client.delete(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
        # Any error from the server will anyway be propagated down to the supervisor,
        # so we choose to send a generic response to the supervisor over the server response to
        # decouple from the server response string
        return OKResponse(ok=True)

    def get_sequence_item(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        offset: int,
    ) -> XComSequenceIndexResponse | ErrorResponse:
        """Get one item of an XCom sequence by offset, or an ErrorResponse when it does not exist."""
        try:
            resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}")
        except ServerResponseError as e:
            if e.response.status_code == HTTPStatus.NOT_FOUND:
                log.error(
                    "XCom not found",
                    dag_id=dag_id,
                    run_id=run_id,
                    task_id=task_id,
                    key=key,
                    offset=offset,
                    detail=e.detail,
                    status_code=e.response.status_code,
                )
                return ErrorResponse(
                    error=ErrorType.XCOM_NOT_FOUND,
                    detail={
                        "dag_id": dag_id,
                        "run_id": run_id,
                        "task_id": task_id,
                        "key": key,
                        "offset": offset,
                    },
                )
            raise
        return XComSequenceIndexResponse.model_validate_json(resp.read())

    def get_sequence_slice(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        start: int | None,
        stop: int | None,
        step: int | None,
        include_prior_dates: bool = False,
    ) -> XComSequenceSliceResponse:
        """Get a slice of an XCom sequence; start/stop/step are forwarded as query parameters."""
        params: dict[str, Any] = {}
        if start is not None:
            params["start"] = start
        if stop is not None:
            params["stop"] = stop
        if step is not None:
            params["step"] = step
        if include_prior_dates:
            params["include_prior_dates"] = include_prior_dates
        resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/slice", params=params)
        return XComSequenceSliceResponse.model_validate_json(resp.read())
611
612
class AssetOperations:
    """Operations for looking up Assets via the API server."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(self, name: str | None = None, uri: str | None = None) -> AssetResponse | ErrorResponse:
        """Get Asset value from the API server."""
        # Look up by name when given, otherwise by URI; one of the two is required.
        if name:
            endpoint, params = "assets/by-name", {"name": name}
        elif uri:
            endpoint, params = "assets/by-uri", {"uri": uri}
        else:
            raise ValueError("Either `name` or `uri` must be provided")

        try:
            resp = self.client.get(endpoint, params=params)
        except ServerResponseError as e:
            if e.response.status_code != HTTPStatus.NOT_FOUND:
                raise
            log.error(
                "Asset not found",
                params=params,
                detail=e.detail,
                status_code=e.response.status_code,
            )
            return ErrorResponse(error=ErrorType.ASSET_NOT_FOUND, detail=params)

        return AssetResponse.model_validate_json(resp.read())
644
645
class AssetEventOperations:
    """Operations for querying Asset events via the API server."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(
        self,
        name: str | None = None,
        uri: str | None = None,
        alias_name: str | None = None,
        after: datetime | None = None,
        before: datetime | None = None,
        ascending: bool = True,
        limit: int | None = None,
    ) -> AssetEventsResponse:
        """Get Asset event from the API server."""
        # Shared filters forwarded to either endpoint.
        common_params: dict[str, Any] = {}
        if after:
            common_params["after"] = after.isoformat()
        if before:
            common_params["before"] = before.isoformat()
        common_params["ascending"] = ascending
        if limit:
            common_params["limit"] = limit

        # Query by asset (name and/or uri) or by asset alias; one of the three is required.
        if name or uri:
            query = {"name": name, "uri": uri, **common_params}
            resp = self.client.get("asset-events/by-asset", params=query)
        elif alias_name:
            query = {"name": alias_name, **common_params}
            resp = self.client.get("asset-events/by-asset-alias", params=query)
        else:
            raise ValueError("Either `name`, `uri` or `alias_name` must be provided")

        return AssetEventsResponse.model_validate_json(resp.read())
683
684
class DagRunOperations:
    """Operations the Execution API server exposes for Dag runs."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def trigger(
        self,
        dag_id: str,
        run_id: str,
        conf: dict | None = None,
        logical_date: datetime | None = None,
        reset_dag_run: bool = False,
    ) -> OKResponse | ErrorResponse:
        """
        Trigger a Dag run via the API server.

        :param dag_id: Dag to trigger
        :param run_id: run id for the new run
        :param conf: run configuration (note: this parameter shadows the
            module-level ``conf`` inside this method)
        :param logical_date: optional logical date for the new run
        :param reset_dag_run: if the run already exists, clear it instead of
            returning a DAGRUN_ALREADY_EXISTS error
        """
        body = TriggerDAGRunPayload(logical_date=logical_date, conf=conf or {}, reset_dag_run=reset_dag_run)

        try:
            self.client.post(
                f"dag-runs/{dag_id}/{run_id}", content=body.model_dump_json(exclude_defaults=True)
            )
        except ServerResponseError as e:
            # CONFLICT means a run with this run_id already exists.
            if e.response.status_code == HTTPStatus.CONFLICT:
                if reset_dag_run:
                    log.info("Dag Run already exists; Resetting Dag Run.", dag_id=dag_id, run_id=run_id)
                    return self.clear(run_id=run_id, dag_id=dag_id)

                log.info("Dag Run already exists!", detail=e.detail, dag_id=dag_id, run_id=run_id)
                return ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS)
            raise

        return OKResponse(ok=True)

    def clear(self, dag_id: str, run_id: str) -> OKResponse:
        """Clear a Dag run via the API server."""
        self.client.post(f"dag-runs/{dag_id}/{run_id}/clear")
        # TODO: Error handling
        return OKResponse(ok=True)

    def get_detail(self, dag_id: str, run_id: str) -> DagRun:
        """Get detail of a dag run."""
        resp = self.client.get(f"dag-runs/{dag_id}/{run_id}")
        return DagRun.model_validate_json(resp.read())

    def get_state(self, dag_id: str, run_id: str) -> DagRunStateResponse:
        """Get the state of a Dag run via the API server."""
        resp = self.client.get(f"dag-runs/{dag_id}/{run_id}/state")
        return DagRunStateResponse.model_validate_json(resp.read())

    def get_count(
        self,
        dag_id: str,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> DRCount:
        """Get count of Dag runs matching the given criteria."""
        params = {
            "dag_id": dag_id,
            "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None,
            "run_ids": run_ids,
            "states": states,
        }

        # Remove None values from params
        params = {k: v for k, v in params.items() if v is not None}

        resp = self.client.get("dag-runs/count", params=params)
        return DRCount(count=resp.json())

    def get_previous(
        self,
        dag_id: str,
        logical_date: datetime,
        state: str | None = None,
    ) -> PreviousDagRunResult:
        """Get the previous DAG run before the given logical date, optionally filtered by state."""
        params = {
            "dag_id": dag_id,
            "logical_date": logical_date.isoformat(),
        }
        if state:
            params["state"] = state
        resp = self.client.get("dag-runs/previous", params=params)
        return PreviousDagRunResult(dag_run=resp.json())
770
771
class HITLOperations:
    """
    Operations related to Human in the loop. Require Airflow 3.1+.

    :meta: private
    """

    __slots__ = ("client",)

    def __init__(self, client: Client) -> None:
        self.client = client

    def add_response(
        self,
        *,
        ti_id: uuid.UUID,
        options: list[str],
        subject: str,
        body: str | None = None,
        defaults: list[str] | None = None,
        multiple: bool = False,
        params: dict[str, dict[str, Any]] | None = None,
        assigned_users: list[HITLUser] | None = None,
    ) -> HITLDetailRequest:
        """
        Add a Human-in-the-loop response that waits for human response for a specific Task Instance.

        All arguments are keyword-only; they are forwarded unchanged into the
        ``CreateHITLDetailPayload`` sent to the server.
        """
        payload = CreateHITLDetailPayload(
            ti_id=ti_id,
            options=options,
            subject=subject,
            body=body,
            defaults=defaults,
            multiple=multiple,
            params=params,
            assigned_users=assigned_users,
        )
        resp = self.client.post(
            f"/hitlDetails/{ti_id}",
            content=payload.model_dump_json(),
        )
        return HITLDetailRequest.model_validate_json(resp.read())

    def update_response(
        self,
        *,
        ti_id: uuid.UUID,
        chosen_options: list[str],
        params_input: dict[str, Any],
    ) -> HITLDetailResponse:
        """Update an existing Human-in-the-loop response with the options and params a human chose."""
        payload = UpdateHITLDetail(
            ti_id=ti_id,
            chosen_options=chosen_options,
            params_input=params_input,
        )
        resp = self.client.patch(
            f"/hitlDetails/{ti_id}",
            content=payload.model_dump_json(),
        )
        return HITLDetailResponse.model_validate_json(resp.read())

    def get_detail_response(self, ti_id: uuid.UUID) -> HITLDetailResponse:
        """Get content part of a Human-in-the-loop response for a specific Task Instance."""
        resp = self.client.get(f"/hitlDetails/{ti_id}")
        return HITLDetailResponse.model_validate_json(resp.read())
836
837
class BearerAuth(httpx.Auth):
    """httpx auth flow that attaches a static bearer token to every request."""

    def __init__(self, token: str):
        self.token: str = token

    def auth_flow(self, request: httpx.Request):
        # An empty token means "send no Authorization header at all".
        if self.token:
            request.headers["Authorization"] = f"Bearer {self.token}"
        yield request
846
847
848# This exists as an aid for debugging or local running via the `dry_run` argument to Client. It doesn't make
849# sense for returning connections etc.
def noop_handler(request: httpx.Request) -> httpx.Response:
    """Answer dry-run requests locally with just enough fake data to satisfy simple tasks."""
    path = request.url.path
    log.debug("Dry-run request", method=request.method, path=path)

    if path.startswith("/task-instances/") and path.endswith("/run"):
        # Return a fake context so the "TI start" call parses cleanly.
        fake_ti_context = {
            "dag_run": {
                "dag_id": "test_dag",
                "run_id": "test_run",
                "logical_date": "2021-01-01T00:00:00Z",
                "start_date": "2021-01-01T00:00:00Z",
                "run_type": DagRunType.MANUAL,
                "run_after": "2021-01-01T00:00:00Z",
                "consumed_asset_events": [],
            },
            "max_tries": 0,
            "should_retry": False,
        }
        return httpx.Response(200, json=fake_ti_context)
    return httpx.Response(200, json={"text": "Hello, world!"})
873
874
# Note: Given defaults make attempts after 1, 3, 7, 15 and fails after 31 seconds
API_RETRIES = conf.getint("workers", "execution_api_retries")  # max attempts per request
API_RETRY_WAIT_MIN = conf.getfloat("workers", "execution_api_retry_wait_min")
API_RETRY_WAIT_MAX = conf.getfloat("workers", "execution_api_retry_wait_max")
API_SSL_CERT_PATH = conf.get("api", "ssl_cert")  # extra CA path loaded on top of certifi's bundle
API_TIMEOUT = conf.getfloat("workers", "execution_api_timeout")
881
882
def _should_retry_api_request(exception: BaseException) -> bool:
    """Determine if an API request should be retried based on the exception type."""
    # Server-side (5xx) failures are worth retrying; 4xx responses are not.
    if isinstance(exception, httpx.HTTPStatusError):
        return exception.response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR
    # Transport-level problems (timeouts, connection errors, ...) are retryable.
    return isinstance(exception, httpx.RequestError)
889
890
class Client(httpx.Client):
    """
    httpx Client pre-configured for the Task Execution API.

    Adds bearer-token auth (with automatic pick-up of server-issued refreshed
    tokens), a correlation id on every request, translation of JSON error
    bodies into :class:`ServerResponseError`, and a retry layer around
    ``request``.
    """

    @lru_cache()
    @staticmethod
    def _get_ssl_context_cached(ca_file: str, ca_path: str | None = None) -> ssl.SSLContext:
        """Cache SSL context to prevent memory growth from repeated context creation."""
        ctx = ssl.create_default_context(cafile=ca_file)
        if ca_path:
            ctx.load_verify_locations(ca_path)
        return ctx

    def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any):
        # Exactly one of `base_url` (a real server) or `dry_run` (a local mock
        # transport) must be chosen; the XOR raises when they are inconsistent.
        if (not base_url) ^ dry_run:
            raise ValueError(f"Can only specify one of {base_url=} or {dry_run=}")
        auth = BearerAuth(token)

        if dry_run:
            # If dry run is requested, install a no op handler so that simple tasks can "heartbeat" using a
            # real client, but just don't make any HTTP requests
            kwargs.setdefault("transport", httpx.MockTransport(noop_handler))
            kwargs.setdefault("base_url", "dry-run://server")
        else:
            kwargs["base_url"] = base_url
            # Call via the class to avoid binding lru_cache wires to this instance.
            kwargs["verify"] = type(self)._get_ssl_context_cached(certifi.where(), API_SSL_CERT_PATH)

        # Set timeout if not explicitly provided
        kwargs.setdefault("timeout", API_TIMEOUT)

        pyver = f"{'.'.join(map(str, sys.version_info[:3]))}"
        super().__init__(
            auth=auth,
            headers={
                "user-agent": f"apache-airflow-task-sdk/{__version__} (Python/{pyver})",
                "airflow-api-version": API_VERSION,
            },
            event_hooks={"response": [self._update_auth, raise_on_4xx_5xx], "request": [add_correlation_id]},
            **kwargs,
        )

    def _update_auth(self, response: httpx.Response):
        """Response hook: adopt a refreshed token when the server issues one."""
        if new_token := response.headers.get("Refreshed-API-Token"):
            log.debug("Execution API issued us a refreshed Task token")
            self.auth = BearerAuth(new_token)

    @retry(
        retry=retry_if_exception(_should_retry_api_request),
        stop=stop_after_attempt(API_RETRIES),
        wait=wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX),
        before_sleep=before_log(log, logging.WARNING),
        reraise=True,
    )
    def request(self, *args, **kwargs):
        """Implement a convenience for httpx.Client.request with a retry layer."""
        # Set content type as convenience if not already set
        if kwargs.get("content", None) is not None and "content-type" not in (
            kwargs.get("headers", {}) or {}
        ):
            # Merge into any caller-supplied headers instead of replacing the
            # whole mapping (replacing it would silently drop every other
            # header the caller passed).
            headers = dict(kwargs.get("headers") or {})
            headers["content-type"] = "application/json"
            kwargs["headers"] = headers

        return super().request(*args, **kwargs)

    # We "group" or "namespace" operations by what they operate on, rather than a flat namespace with all
    # methods on one object prefixed with the object type (`.task_instances.update` rather than
    # `task_instance_update` etc.)
    # Each property is wrapped in methodtools.lru_cache so the operations object
    # is created once per client and reused.

    @lru_cache()  # type: ignore[misc]
    @property
    def task_instances(self) -> TaskInstanceOperations:
        """Operations related to TaskInstances."""
        return TaskInstanceOperations(self)

    @lru_cache()  # type: ignore[misc]
    @property
    def dag_runs(self) -> DagRunOperations:
        """Operations related to DagRuns."""
        return DagRunOperations(self)

    @lru_cache()  # type: ignore[misc]
    @property
    def connections(self) -> ConnectionOperations:
        """Operations related to Connections."""
        return ConnectionOperations(self)

    @lru_cache()  # type: ignore[misc]
    @property
    def variables(self) -> VariableOperations:
        """Operations related to Variables."""
        return VariableOperations(self)

    @lru_cache()  # type: ignore[misc]
    @property
    def xcoms(self) -> XComOperations:
        """Operations related to XComs."""
        return XComOperations(self)

    @lru_cache()  # type: ignore[misc]
    @property
    def assets(self) -> AssetOperations:
        """Operations related to Assets."""
        return AssetOperations(self)

    @lru_cache()  # type: ignore[misc]
    @property
    def asset_events(self) -> AssetEventOperations:
        """Operations related to Asset Events."""
        return AssetEventOperations(self)

    @lru_cache()  # type: ignore[misc]
    @property
    def hitl(self) -> HITLOperations:
        """Operations related to HITL Responses."""
        return HITLOperations(self)
1003
1004
1005# This is only used for parsing. ServerResponseError is raised instead
class _ErrorBody(BaseModel):
    """Parsed shape of the JSON error body the API server returns."""

    detail: list[RemoteValidationError] | str

    def __repr__(self):
        return f"{self.detail!r}"
1011
1012
class ServerResponseError(httpx.HTTPStatusError):
    """An httpx.HTTPStatusError enriched with the parsed JSON error ``detail`` from the API server."""

    def __init__(self, message: str, *, request: httpx.Request, response: httpx.Response):
        super().__init__(message, request=request, response=response)

    # Parsed error payload; populated by `from_response`, not by `__init__`.
    detail: list[RemoteValidationError] | str | dict[str, Any] | None

    def __reduce__(self) -> tuple[Any, ...]:
        # Needed because https://github.com/encode/httpx/pull/3108 isn't merged yet.
        return Exception.__new__, (type(self),) + self.args, self.__dict__

    @classmethod
    def from_response(cls, response: httpx.Response) -> ServerResponseError | None:
        """Build a ServerResponseError from an error response, or return None to fall back to plain httpx errors."""
        if response.is_success:
            return None
        # 4xx or 5xx error?
        if not (400 <= response.status_code < 600):
            return None

        # NOTE(review): exact equality skips content types carrying parameters,
        # e.g. "application/json; charset=utf-8" — confirm the server never
        # sends those, otherwise such bodies fall back to plain httpx errors.
        if response.headers.get("content-type") != "application/json":
            return None

        detail: list[RemoteValidationError] | dict[str, Any] | None = None
        try:
            body = _ErrorBody.model_validate_json(response.read())

            if isinstance(body.detail, list):
                detail = body.detail
                msg = "Remote server returned validation error"
            else:
                msg = body.detail or "Un-parseable error"
        except Exception:
            # Body didn't match _ErrorBody; keep the raw decoded JSON as detail.
            try:
                detail = msgspec.json.decode(response.content)
            except Exception:
                # Fallback to a normal httpx error
                return None
            msg = "Server returned error"

        self = cls(msg, request=response.request, response=response)
        self.detail = detail
        return self