1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import logging
21import ssl
22import sys
23import uuid
24from functools import cache
25from http import HTTPStatus
26from typing import TYPE_CHECKING, Any, TypeVar
27
28import certifi
29import httpx
30import msgspec
31import structlog
32from pydantic import BaseModel
33from tenacity import (
34 before_log,
35 retry,
36 retry_if_exception,
37 stop_after_attempt,
38 wait_random_exponential,
39)
40from uuid6 import uuid7
41
42from airflow.sdk import __version__
43from airflow.sdk.api.datamodels._generated import (
44 API_VERSION,
45 AssetEventsResponse,
46 AssetResponse,
47 ConnectionResponse,
48 DagRun,
49 DagRunStateResponse,
50 DagRunType,
51 HITLDetailRequest,
52 HITLDetailResponse,
53 HITLUser,
54 InactiveAssetsResponse,
55 PrevSuccessfulDagRunResponse,
56 TaskBreadcrumbsResponse,
57 TaskInstanceState,
58 TaskStatesResponse,
59 TerminalStateNonSuccess,
60 TIDeferredStatePayload,
61 TIEnterRunningPayload,
62 TIHeartbeatInfo,
63 TIRescheduleStatePayload,
64 TIRetryStatePayload,
65 TIRunContext,
66 TISkippedDownstreamTasksStatePayload,
67 TISuccessStatePayload,
68 TITerminalStatePayload,
69 TriggerDAGRunPayload,
70 ValidationError as RemoteValidationError,
71 VariablePostBody,
72 VariableResponse,
73 XComResponse,
74 XComSequenceIndexResponse,
75 XComSequenceSliceResponse,
76)
77from airflow.sdk.configuration import conf
78from airflow.sdk.exceptions import ErrorType
79from airflow.sdk.execution_time.comms import (
80 CreateHITLDetailPayload,
81 DRCount,
82 ErrorResponse,
83 OKResponse,
84 PreviousDagRunResult,
85 SkipDownstreamTasks,
86 TaskRescheduleStartDate,
87 TICount,
88 UpdateHITLDetail,
89 XComCountResponse,
90)
91
if TYPE_CHECKING:
    from datetime import datetime
    from typing import ParamSpec

    from airflow.sdk.execution_time.comms import RescheduleTask

    P = ParamSpec("P")
    T = TypeVar("T")

    # methodtools has no type stubs, so present type-checkers with a no-op
    # decorator that matches the call shape of methodtools.lru_cache.
    def lru_cache(maxsize: int | None = 128):
        def wrapper(f):
            return f

        return wrapper
else:
    from methodtools import lru_cache
109
110
111@cache
112def _get_fqdn(name=""):
113 """
114 Get fully qualified domain name from name.
115
116 An empty argument is interpreted as meaning the local host.
117 This is a patched version of socket.getfqdn() - see https://github.com/python/cpython/issues/49254
118 """
119 import socket
120
121 name = name.strip()
122 if not name or name == "0.0.0.0":
123 name = socket.gethostname()
124 try:
125 addrs = socket.getaddrinfo(name, None, 0, socket.SOCK_DGRAM, 0, socket.AI_CANONNAME)
126 except OSError:
127 pass
128 else:
129 for addr in addrs:
130 if addr[3]:
131 name = addr[3]
132 break
133 return name
134
135
def get_hostname():
    """Fetch the hostname using the callable from config or use built-in FQDN as a fallback."""
    # `[core] hostname_callable` lets deployments plug in their own resolver;
    # when unset, `_get_fqdn` (patched socket.getfqdn) is used.
    return conf.getimport("core", "hostname_callable", fallback=_get_fqdn)()
139
140
@cache
def getuser() -> str:
    """
    Return the username of the current user, failing loudly when there is none.

    We deliberately do not fall back to os.getuid(): a missing username
    usually means the rest of the user environment (e.g. $HOME) is broken
    too, and an explicit failure beats silently misbehaving.
    """
    import getpass

    try:
        return getpass.getuser()
    except KeyError:
        # getpass raises KeyError when the uid has no passwd entry.
        message = (
            "The user that Airflow is running as has no username; you must run "
            "Airflow as a full user, with a username and home directory, "
            "in order for it to function properly."
        )
        raise ValueError(message)
160
161
# Structured logger for this module.
log = structlog.get_logger(logger_name=__name__)

# Explicit public API of this module.
__all__ = [
    "Client",
    "ConnectionOperations",
    "ServerResponseError",
    "TaskInstanceOperations",
    "get_hostname",
    "getuser",
]
172
173
def get_json_error(response: httpx.Response):
    """Raise a ServerResponseError if we can extract error info from the error."""
    if (error := ServerResponseError.from_response(response)) is not None:
        raise error
179
180
def raise_on_4xx_5xx(response: httpx.Response):
    """Response hook: raise our JSON-aware error when possible, else defer to httpx's status check."""
    # get_json_error either raises or returns None, so fall through to httpx.
    get_json_error(response)
    return response.raise_for_status()
183
184
# Py 3.11+ version
def raise_on_4xx_5xx_with_note(response: httpx.Response):
    """Like raise_on_4xx_5xx, but annotates HTTP errors with the request's correlation id."""
    try:
        return get_json_error(response) or response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        if TYPE_CHECKING:
            assert hasattr(exc, "add_note")
        # Prefer the server-echoed header; fall back to what we sent.
        correlation = response.headers.get("correlation-id", None) or response.request.headers.get(
            "correlation-id", "no-correlation-id"
        )
        exc.add_note(f"Correlation-id={correlation}")
        raise
196
197
if hasattr(BaseException, "add_note"):
    # Py 3.11+: swap in the variant that attaches the correlation id to
    # HTTP errors via exception notes (PEP 678).
    raise_on_4xx_5xx = raise_on_4xx_5xx_with_note
201
202
def add_correlation_id(request: httpx.Request):
    """Request hook: tag each outgoing request with a fresh UUIDv7 correlation id."""
    correlation_id = uuid7()
    request.headers["correlation-id"] = str(correlation_id)
205
206
class TaskInstanceOperations:
    """Operations on TaskInstances, exposed under ``client.task_instances``."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext:
        """Tell the API server that this TI has started running."""
        body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), unixname=getuser(), start_date=when)

        resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json())
        return TIRunContext.model_validate_json(resp.read())

    def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime, rendered_map_index):
        """Tell the API server that this TI has reached a terminal state."""
        # SUCCESS carries extra payload (outlets/events) and must go through `succeed`.
        if state == TaskInstanceState.SUCCESS:
            raise ValueError("Logic error. SUCCESS state should call the `succeed` function instead")
        # TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing.
        body = TITerminalStatePayload(
            end_date=when, state=TerminalStateNonSuccess(state), rendered_map_index=rendered_map_index
        )
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def retry(self, id: uuid.UUID, end_date: datetime, rendered_map_index):
        """Tell the API server that this TI has failed and reached a up_for_retry state."""
        body = TIRetryStatePayload(end_date=end_date, rendered_map_index=rendered_map_index)
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events, rendered_map_index):
        """Tell the API server that this TI has succeeded."""
        body = TISuccessStatePayload(
            end_date=when,
            task_outlets=task_outlets,
            outlet_events=outlet_events,
            rendered_map_index=rendered_map_index,
        )
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def defer(self, id: uuid.UUID, msg):
        """Tell the API server that this TI has been deferred."""
        # `type` is the comms-envelope discriminator, not payload data, so strip it.
        body = TIDeferredStatePayload(**msg.model_dump(exclude_unset=True, exclude={"type"}))

        # Create a deferred state payload from msg
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def reschedule(self, id: uuid.UUID, msg: RescheduleTask):
        """Tell the API server that this TI has been rescheduled."""
        body = TIRescheduleStatePayload(**msg.model_dump(exclude_unset=True, exclude={"type"}))

        # Create a reschedule state payload from msg
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def heartbeat(self, id: uuid.UUID, pid: int):
        """Send a liveness heartbeat (pid and hostname) for this TI to the API server."""
        body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
        self.client.put(f"task-instances/{id}/heartbeat", content=body.model_dump_json())

    def skip_downstream_tasks(self, id: uuid.UUID, msg: SkipDownstreamTasks):
        """Tell the API server to skip the downstream tasks of this TI."""
        body = TISkippedDownstreamTasksStatePayload(tasks=msg.tasks)
        self.client.patch(f"task-instances/{id}/skip-downstream", content=body.model_dump_json())

    def set_rtif(self, id: uuid.UUID, body: dict[str, str]) -> OKResponse:
        """Set Rendered Task Instance Fields via the API server."""
        self.client.put(f"task-instances/{id}/rtif", json=body)
        # Any error from the server will anyway be propagated down to the supervisor,
        # so we choose to send a generic response to the supervisor over the server response to
        # decouple from the server response string
        return OKResponse(ok=True)

    def set_rendered_map_index(self, id: uuid.UUID, rendered_map_index: str) -> OKResponse:
        """Set rendered_map_index for a task instance via the API server."""
        self.client.patch(f"task-instances/{id}/rendered-map-index", json=rendered_map_index)
        return OKResponse(ok=True)

    def get_previous_successful_dagrun(self, id: uuid.UUID) -> PrevSuccessfulDagRunResponse:
        """
        Get the previous successful dag run for a given task instance.

        The data from it is used to get values for Task Context.
        """
        resp = self.client.get(f"task-instances/{id}/previous-successful-dagrun")
        return PrevSuccessfulDagRunResponse.model_validate_json(resp.read())

    def get_reschedule_start_date(self, id: uuid.UUID, try_number: int = 1) -> TaskRescheduleStartDate:
        """Get the start date of a task reschedule via the API server."""
        resp = self.client.get(f"task-reschedules/{id}/start_date", params={"try_number": try_number})
        return TaskRescheduleStartDate(start_date=resp.json())

    def get_count(
        self,
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> TICount:
        """Get count of task instances matching the given criteria."""
        params: dict[str, Any]
        # Dates are serialized to ISO-8601 strings for the query string.
        params = {
            "dag_id": dag_id,
            "task_ids": task_ids,
            "task_group_id": task_group_id,
            "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None,
            "run_ids": run_ids,
            "states": states,
        }

        # Remove None values from params
        params = {k: v for k, v in params.items() if v is not None}

        # Negative map_index is the "not mapped" sentinel and is not sent.
        if map_index is not None and map_index >= 0:
            params.update({"map_index": map_index})

        resp = self.client.get("task-instances/count", params=params)
        return TICount(count=resp.json())

    def get_task_states(
        self,
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
    ) -> TaskStatesResponse:
        """Get task states given criteria."""
        params: dict[str, Any]
        params = {
            "dag_id": dag_id,
            "task_ids": task_ids,
            "task_group_id": task_group_id,
            "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None,
            "run_ids": run_ids,
        }

        # Remove None values from params
        params = {k: v for k, v in params.items() if v is not None}

        if map_index is not None and map_index >= 0:
            params.update({"map_index": map_index})

        resp = self.client.get("task-instances/states", params=params)
        return TaskStatesResponse.model_validate_json(resp.read())

    # NOTE(review): "breakcrumbs" looks like a typo for "breadcrumbs", but the
    # method name is part of the public interface, so it is kept as-is.
    def get_task_breakcrumbs(self, dag_id: str, run_id: str) -> TaskBreadcrumbsResponse:
        """Fetch task breadcrumbs for the given dag run from the API server."""
        params = {"dag_id": dag_id, "run_id": run_id}
        resp = self.client.get("task-instances/breadcrumbs", params=params)
        return TaskBreadcrumbsResponse.model_validate_json(resp.read())

    def validate_inlets_and_outlets(self, id: uuid.UUID) -> InactiveAssetsResponse:
        """Validate whether there're inactive assets in inlets and outlets of a given task instance."""
        resp = self.client.get(f"task-instances/{id}/validate-inlets-and-outlets")
        return InactiveAssetsResponse.model_validate_json(resp.read())
362
363
class ConnectionOperations:
    """Operations on Connections, exposed under ``client.connections``."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(self, conn_id: str) -> ConnectionResponse | ErrorResponse:
        """Fetch a connection by id, mapping a server-side 404 onto an ErrorResponse."""
        try:
            resp = self.client.get(f"connections/{conn_id}")
        except ServerResponseError as e:
            if e.response.status_code != HTTPStatus.NOT_FOUND:
                raise
            log.debug(
                "Connection not found",
                conn_id=conn_id,
                detail=e.detail,
                status_code=e.response.status_code,
            )
            return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": conn_id})
        return ConnectionResponse.model_validate_json(resp.read())
385
386
class VariableOperations:
    """Operations on Airflow Variables, exposed under ``client.variables``."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(self, key: str) -> VariableResponse | ErrorResponse:
        """Fetch a variable by key, mapping a server-side 404 onto an ErrorResponse."""
        try:
            resp = self.client.get(f"variables/{key}")
        except ServerResponseError as e:
            if e.response.status_code != HTTPStatus.NOT_FOUND:
                raise
            log.error(
                "Variable not found",
                key=key,
                detail=e.detail,
                status_code=e.response.status_code,
            )
            return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"key": key})
        return VariableResponse.model_validate_json(resp.read())

    def set(self, key: str, value: str | None, description: str | None = None) -> OKResponse:
        """Create or update an Airflow Variable via the API server."""
        body = VariablePostBody(val=value, description=description)
        self.client.put(f"variables/{key}", content=body.model_dump_json())
        # Server errors propagate to the supervisor on their own, so report a
        # generic OK and stay decoupled from the server's response wording.
        return OKResponse(ok=True)

    def delete(self, key: str) -> OKResponse:
        """Delete the variable with the given key via the API server."""
        self.client.delete(f"variables/{key}")
        # Same rationale as `set`: a generic OK, real errors propagate on their own.
        return OKResponse(ok=True)
428
429
class XComOperations:
    """Operations on XComs, exposed under ``client.xcoms``."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def head(self, dag_id: str, run_id: str, task_id: str, key: str) -> XComCountResponse:
        """
        Get the number of mapped XCom values for the given key.

        The server reports the count in a ``Content-Range: map_indexes <n>`` header.

        :raises RuntimeError: if the header is missing or not in the expected form.
        """
        resp = self.client.head(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}")

        # Use .get() so a missing header falls through to the explicit
        # RuntimeError below instead of surfacing as a bare KeyError.
        content_range: str | None = resp.headers.get("Content-Range")
        if not content_range or not content_range.startswith("map_indexes "):
            raise RuntimeError(f"Unable to parse Content-Range header from HEAD {resp.request.url}")
        return XComCountResponse(len=int(content_range[len("map_indexes ") :]))

    def get(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        map_index: int | None = None,
        include_prior_dates: bool = False,
    ) -> XComResponse:
        """
        Get a XCom value from the API server.

        A missing XCom is not an error: it returns ``XComResponse(key=key, value=None)``
        to preserve Airflow 2.x ``xcom_pull`` semantics.
        """
        # TODO: check if we need to use map_index as params in the uri
        # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
        params = {}
        # Negative map_index is the "not mapped" sentinel and is not sent.
        if map_index is not None and map_index >= 0:
            params.update({"map_index": map_index})
        if include_prior_dates:
            params.update({"include_prior_dates": include_prior_dates})
        try:
            resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
        except ServerResponseError as e:
            if e.response.status_code == HTTPStatus.NOT_FOUND:
                log.error(
                    "XCom not found",
                    dag_id=dag_id,
                    run_id=run_id,
                    task_id=task_id,
                    key=key,
                    map_index=map_index,
                    detail=e.detail,
                    status_code=e.response.status_code,
                )
                # Airflow 2.x just ignores the absence of an XCom and moves on with a return value of None
                # Hence returning with key as `key` and value as `None`, so that the message is sent back to task runner
                # and the default value of None in xcom_pull is used.
                return XComResponse(key=key, value=None)
            raise
        return XComResponse.model_validate_json(resp.read())

    def set(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        value,
        map_index: int | None = None,
        mapped_length: int | None = None,
    ) -> OKResponse:
        """Set a XCom value via the API server."""
        # TODO: check if we need to use map_index as params in the uri
        # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
        params = {}
        if map_index is not None and map_index >= 0:
            params = {"map_index": map_index}
        if mapped_length is not None and mapped_length >= 0:
            params["mapped_length"] = mapped_length
        self.client.post(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params, json=value)
        # Any error from the server will anyway be propagated down to the supervisor,
        # so we choose to send a generic response to the supervisor over the server response to
        # decouple from the server response string
        return OKResponse(ok=True)

    def delete(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        map_index: int | None = None,
    ) -> OKResponse:
        """Delete a XCom with given key via the API server."""
        params = {}
        if map_index is not None and map_index >= 0:
            params = {"map_index": map_index}
        self.client.delete(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
        # Any error from the server will anyway be propagated down to the supervisor,
        # so we choose to send a generic response to the supervisor over the server response to
        # decouple from the server response string
        return OKResponse(ok=True)

    def get_sequence_item(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        offset: int,
    ) -> XComSequenceIndexResponse | ErrorResponse:
        """Get a single item of a mapped XCom by offset, mapping a 404 onto an ErrorResponse."""
        try:
            resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}")
        except ServerResponseError as e:
            if e.response.status_code == HTTPStatus.NOT_FOUND:
                log.error(
                    "XCom not found",
                    dag_id=dag_id,
                    run_id=run_id,
                    task_id=task_id,
                    key=key,
                    offset=offset,
                    detail=e.detail,
                    status_code=e.response.status_code,
                )
                return ErrorResponse(
                    error=ErrorType.XCOM_NOT_FOUND,
                    detail={
                        "dag_id": dag_id,
                        "run_id": run_id,
                        "task_id": task_id,
                        "key": key,
                        "offset": offset,
                    },
                )
            raise
        return XComSequenceIndexResponse.model_validate_json(resp.read())

    def get_sequence_slice(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        start: int | None,
        stop: int | None,
        step: int | None,
        include_prior_dates: bool = False,
    ) -> XComSequenceSliceResponse:
        """Get a slice of a mapped XCom sequence; None bounds are omitted from the query."""
        params = {}
        if start is not None:
            params["start"] = start
        if stop is not None:
            params["stop"] = stop
        if step is not None:
            params["step"] = step
        if include_prior_dates:
            params["include_prior_dates"] = include_prior_dates
        resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/slice", params=params)
        return XComSequenceSliceResponse.model_validate_json(resp.read())
584
585
class AssetOperations:
    """Operations on Assets, exposed under ``client.assets``."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(self, name: str | None = None, uri: str | None = None) -> AssetResponse | ErrorResponse:
        """Get Asset value from the API server."""
        # Look up by name when given, otherwise by URI; one of the two is required.
        if name:
            endpoint, params = "assets/by-name", {"name": name}
        elif uri:
            endpoint, params = "assets/by-uri", {"uri": uri}
        else:
            raise ValueError("Either `name` or `uri` must be provided")

        try:
            resp = self.client.get(endpoint, params=params)
        except ServerResponseError as e:
            if e.response.status_code != HTTPStatus.NOT_FOUND:
                raise
            log.error(
                "Asset not found",
                params=params,
                detail=e.detail,
                status_code=e.response.status_code,
            )
            return ErrorResponse(error=ErrorType.ASSET_NOT_FOUND, detail=params)

        return AssetResponse.model_validate_json(resp.read())
617
618
class AssetEventOperations:
    """Operations on Asset events, exposed under ``client.asset_events``."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(
        self,
        name: str | None = None,
        uri: str | None = None,
        alias_name: str | None = None,
        after: datetime | None = None,
        before: datetime | None = None,
        ascending: bool = True,
        limit: int | None = None,
    ) -> AssetEventsResponse:
        """Get Asset event from the API server."""
        # Filters shared by both lookup endpoints.
        shared: dict[str, Any] = {}
        if after:
            shared["after"] = after.isoformat()
        if before:
            shared["before"] = before.isoformat()
        shared["ascending"] = ascending
        if limit:
            shared["limit"] = limit

        if name or uri:
            resp = self.client.get("asset-events/by-asset", params={"name": name, "uri": uri, **shared})
        elif alias_name:
            resp = self.client.get("asset-events/by-asset-alias", params={"name": alias_name, **shared})
        else:
            raise ValueError("Either `name`, `uri` or `alias_name` must be provided")

        return AssetEventsResponse.model_validate_json(resp.read())
656
657
class DagRunOperations:
    """Operations on DagRuns, exposed under ``client.dag_runs``."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def trigger(
        self,
        dag_id: str,
        run_id: str,
        conf: dict | None = None,
        logical_date: datetime | None = None,
        reset_dag_run: bool = False,
    ) -> OKResponse | ErrorResponse:
        """Trigger a Dag run via the API server."""
        body = TriggerDAGRunPayload(logical_date=logical_date, conf=conf or {}, reset_dag_run=reset_dag_run)

        try:
            self.client.post(
                f"dag-runs/{dag_id}/{run_id}", content=body.model_dump_json(exclude_defaults=True)
            )
        except ServerResponseError as e:
            # 409: a run with this run_id already exists — clear it when asked,
            # otherwise report the conflict back to the caller.
            if e.response.status_code == HTTPStatus.CONFLICT:
                if reset_dag_run:
                    log.info("Dag Run already exists; Resetting Dag Run.", dag_id=dag_id, run_id=run_id)
                    return self.clear(run_id=run_id, dag_id=dag_id)

                log.info("Dag Run already exists!", detail=e.detail, dag_id=dag_id, run_id=run_id)
                return ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS)
            raise

        return OKResponse(ok=True)

    def clear(self, dag_id: str, run_id: str) -> OKResponse:
        """Clear a Dag run via the API server."""
        self.client.post(f"dag-runs/{dag_id}/{run_id}/clear")
        # TODO: Error handling
        return OKResponse(ok=True)

    def get_detail(self, dag_id: str, run_id: str) -> DagRun:
        """Get detail of a dag run."""
        resp = self.client.get(f"dag-runs/{dag_id}/{run_id}")
        return DagRun.model_validate_json(resp.read())

    def get_state(self, dag_id: str, run_id: str) -> DagRunStateResponse:
        """Get the state of a Dag run via the API server."""
        resp = self.client.get(f"dag-runs/{dag_id}/{run_id}/state")
        return DagRunStateResponse.model_validate_json(resp.read())

    def get_count(
        self,
        dag_id: str,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> DRCount:
        """Get count of Dag runs matching the given criteria."""
        # Dates are serialized to ISO-8601 strings for the query string.
        params = {
            "dag_id": dag_id,
            "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None,
            "run_ids": run_ids,
            "states": states,
        }

        # Remove None values from params
        params = {k: v for k, v in params.items() if v is not None}

        resp = self.client.get("dag-runs/count", params=params)
        return DRCount(count=resp.json())

    def get_previous(
        self,
        dag_id: str,
        logical_date: datetime,
        state: str | None = None,
    ) -> PreviousDagRunResult:
        """Get the previous DAG run before the given logical date, optionally filtered by state."""
        params = {
            "dag_id": dag_id,
            "logical_date": logical_date.isoformat(),
        }
        if state:
            params["state"] = state
        resp = self.client.get("dag-runs/previous", params=params)
        return PreviousDagRunResult(dag_run=resp.json())
743
744
class HITLOperations:
    """
    Operations related to Human in the loop. Require Airflow 3.1+.

    :meta: private
    """

    __slots__ = ("client",)

    def __init__(self, client: Client) -> None:
        self.client = client

    def add_response(
        self,
        *,
        ti_id: uuid.UUID,
        options: list[str],
        subject: str,
        body: str | None = None,
        defaults: list[str] | None = None,
        multiple: bool = False,
        params: dict[str, dict[str, Any]] | None = None,
        assigned_users: list[HITLUser] | None = None,
    ) -> HITLDetailRequest:
        """Add a Human-in-the-loop response that waits for human response for a specific Task Instance."""
        request_body = CreateHITLDetailPayload(
            ti_id=ti_id,
            options=options,
            subject=subject,
            body=body,
            defaults=defaults,
            multiple=multiple,
            params=params,
            assigned_users=assigned_users,
        )
        resp = self.client.post(f"/hitlDetails/{ti_id}", content=request_body.model_dump_json())
        return HITLDetailRequest.model_validate_json(resp.read())

    def update_response(
        self,
        *,
        ti_id: uuid.UUID,
        chosen_options: list[str],
        params_input: dict[str, Any],
    ) -> HITLDetailResponse:
        """Update an existing Human-in-the-loop response."""
        request_body = UpdateHITLDetail(
            ti_id=ti_id,
            chosen_options=chosen_options,
            params_input=params_input,
        )
        resp = self.client.patch(f"/hitlDetails/{ti_id}", content=request_body.model_dump_json())
        return HITLDetailResponse.model_validate_json(resp.read())

    def get_detail_response(self, ti_id: uuid.UUID) -> HITLDetailResponse:
        """Get content part of a Human-in-the-loop response for a specific Task Instance."""
        resp = self.client.get(f"/hitlDetails/{ti_id}")
        return HITLDetailResponse.model_validate_json(resp.read())
809
810
class BearerAuth(httpx.Auth):
    """httpx auth flow that attaches a ``Bearer`` token to every outgoing request."""

    def __init__(self, token: str):
        self.token: str = token

    def auth_flow(self, request: httpx.Request):
        # An empty token means "send unauthenticated" rather than a bogus header.
        if self.token:
            request.headers["Authorization"] = f"Bearer {self.token}"
        yield request
819
820
# This exists as an aid for debugging or local running via the `dry_run` argument to Client. It doesn't make
# sense for returning connections etc.
def noop_handler(request: httpx.Request) -> httpx.Response:
    path = request.url.path
    log.debug("Dry-run request", method=request.method, path=path)

    # Anything other than the "TI run" endpoint gets a trivial placeholder body.
    if not (path.startswith("/task-instances/") and path.endswith("/run")):
        return httpx.Response(200, json={"text": "Hello, world!"})

    # Return a fake TI run context so tasks can "start" without a real server.
    fake_dag_run = {
        "dag_id": "test_dag",
        "run_id": "test_run",
        "logical_date": "2021-01-01T00:00:00Z",
        "start_date": "2021-01-01T00:00:00Z",
        "run_type": DagRunType.MANUAL,
        "run_after": "2021-01-01T00:00:00Z",
        "consumed_asset_events": [],
    }
    return httpx.Response(
        200,
        json={
            "dag_run": fake_dag_run,
            "max_tries": 0,
            "should_retry": False,
        },
    )
846
847
# Note: Given the defaults, retry attempts are made after roughly 1, 3, 7 and
# 15 seconds of waiting, and the request finally fails after about 31 seconds.
API_RETRIES = conf.getint("workers", "execution_api_retries")
API_RETRY_WAIT_MIN = conf.getfloat("workers", "execution_api_retry_wait_min")
API_RETRY_WAIT_MAX = conf.getfloat("workers", "execution_api_retry_wait_max")
# Extra CA bundle loaded on top of certifi's defaults (see Client.__init__).
API_SSL_CERT_PATH = conf.get("api", "ssl_cert")
API_TIMEOUT = conf.getfloat("workers", "execution_api_timeout")
854
855
def _should_retry_api_request(exception: BaseException) -> bool:
    """Determine if an API request should be retried based on the exception type."""
    # Transport-level failures (connect errors, timeouts, ...) are retryable;
    # HTTP status errors only when the server answered with a 5xx.
    if isinstance(exception, httpx.RequestError):
        return True
    return isinstance(exception, httpx.HTTPStatusError) and exception.response.status_code >= 500
862
863
class Client(httpx.Client):
    """Task Execution API client: bearer auth, correlation ids, retries, and namespaced operation groups."""

    # NOTE(review): methodtools.lru_cache is applied on top of staticmethod
    # here — presumably it supports wrapping descriptors; confirm against
    # methodtools docs.
    @lru_cache()
    @staticmethod
    def _get_ssl_context_cached(ca_file: str, ca_path: str | None = None) -> ssl.SSLContext:
        """Cache SSL context to prevent memory growth from repeated context creation."""
        ctx = ssl.create_default_context(cafile=ca_file)
        if ca_path:
            # Append the configured extra CA bundle on top of the certifi defaults.
            ctx.load_verify_locations(ca_path)
        return ctx

    def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any):
        # Exactly one of `base_url` / `dry_run` must be supplied.
        if (not base_url) ^ dry_run:
            raise ValueError(f"Can only specify one of {base_url=} or {dry_run=}")
        auth = BearerAuth(token)

        if dry_run:
            # If dry run is requested, install a no op handler so that simple tasks can "heartbeat" using a
            # real client, but just don't make any HTTP requests
            kwargs.setdefault("transport", httpx.MockTransport(noop_handler))
            kwargs.setdefault("base_url", "dry-run://server")
        else:
            kwargs["base_url"] = base_url
            kwargs["verify"] = self._get_ssl_context_cached(certifi.where(), API_SSL_CERT_PATH)

        # Set timeout if not explicitly provided
        kwargs.setdefault("timeout", API_TIMEOUT)

        pyver = f"{'.'.join(map(str, sys.version_info[:3]))}"
        super().__init__(
            auth=auth,
            headers={
                "user-agent": f"apache-airflow-task-sdk/{__version__} (Python/{pyver})",
                "airflow-api-version": API_VERSION,
            },
            event_hooks={"response": [self._update_auth, raise_on_4xx_5xx], "request": [add_correlation_id]},
            **kwargs,
        )

    def _update_auth(self, response: httpx.Response):
        """Response hook: rotate to a refreshed API token when the server issues one."""
        if new_token := response.headers.get("Refreshed-API-Token"):
            log.debug("Execution API issued us a refreshed Task token")
            self.auth = BearerAuth(new_token)

    @retry(
        retry=retry_if_exception(_should_retry_api_request),
        stop=stop_after_attempt(API_RETRIES),
        wait=wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX),
        before_sleep=before_log(log, logging.WARNING),
        reraise=True,
    )
    def request(self, *args, **kwargs):
        """Implement a convenience for httpx.Client.request with a retry layer."""
        # Set content type as convenience if not already set
        if "content" in kwargs and "headers" not in kwargs:
            kwargs["headers"] = {"content-type": "application/json"}

        return super().request(*args, **kwargs)

    # We "group" or "namespace" operations by what they operate on, rather than a flat namespace with all
    # methods on one object prefixed with the object type (`.task_instances.update` rather than
    # `task_instance_update` etc.)
    #
    # Each property below is wrapped in methodtools.lru_cache so the operations
    # object is created once per client and then reused.

    @lru_cache()  # type: ignore[misc]
    @property
    def task_instances(self) -> TaskInstanceOperations:
        """Operations related to TaskInstances."""
        return TaskInstanceOperations(self)

    @lru_cache()  # type: ignore[misc]
    @property
    def dag_runs(self) -> DagRunOperations:
        """Operations related to DagRuns."""
        return DagRunOperations(self)

    @lru_cache()  # type: ignore[misc]
    @property
    def connections(self) -> ConnectionOperations:
        """Operations related to Connections."""
        return ConnectionOperations(self)

    @lru_cache()  # type: ignore[misc]
    @property
    def variables(self) -> VariableOperations:
        """Operations related to Variables."""
        return VariableOperations(self)

    @lru_cache()  # type: ignore[misc]
    @property
    def xcoms(self) -> XComOperations:
        """Operations related to XComs."""
        return XComOperations(self)

    @lru_cache()  # type: ignore[misc]
    @property
    def assets(self) -> AssetOperations:
        """Operations related to Assets."""
        return AssetOperations(self)

    @lru_cache()  # type: ignore[misc]
    @property
    def asset_events(self) -> AssetEventOperations:
        """Operations related to Asset Events."""
        return AssetEventOperations(self)

    @lru_cache()  # type: ignore[misc]
    @property
    def hitl(self):
        """Operations related to HITL Responses."""
        return HITLOperations(self)
973
974
# This is only used for parsing. ServerResponseError is raised instead
class _ErrorBody(BaseModel):
    # FastAPI-style error body: `detail` is either a list of validation
    # errors or a plain message string.
    detail: list[RemoteValidationError] | str

    def __repr__(self):
        return repr(self.detail)
981
982
class ServerResponseError(httpx.HTTPStatusError):
    """HTTP status error enriched with the parsed JSON error body from the API server."""

    def __init__(self, message: str, *, request: httpx.Request, response: httpx.Response):
        super().__init__(message, request=request, response=response)

    # Parsed error payload: a list of validation errors, a raw decoded JSON
    # value, or None when only a plain string message was available.
    detail: list[RemoteValidationError] | str | dict[str, Any] | None

    def __reduce__(self) -> tuple[Any, ...]:
        # Needed because https://github.com/encode/httpx/pull/3108 isn't merged yet.
        return Exception.__new__, (type(self),) + self.args, self.__dict__

    @classmethod
    def from_response(cls, response: httpx.Response) -> ServerResponseError | None:
        """Build a ServerResponseError from an error response, or return None when it isn't a parseable JSON error."""
        if response.is_success:
            return None
        # 4xx or 5xx error?
        if not (400 <= response.status_code < 600):
            return None

        # NOTE(review): exact-match misses parameterized content types such as
        # "application/json; charset=utf-8" — confirm the server never sends them.
        if response.headers.get("content-type") != "application/json":
            return None

        detail: list[RemoteValidationError] | dict[str, Any] | None = None
        try:
            body = _ErrorBody.model_validate_json(response.read())

            if isinstance(body.detail, list):
                detail = body.detail
                msg = "Remote server returned validation error"
            else:
                # NOTE(review): when the body carries a plain string it becomes the
                # message and `detail` stays None — confirm this is intentional.
                msg = body.detail or "Un-parseable error"
        except Exception:
            try:
                detail = msgspec.json.decode(response.content)
            except Exception:
                # Fallback to a normal httpx error
                return None
            msg = "Server returned error"

        self = cls(msg, request=response.request, response=response)
        self.detail = detail
        return self