Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/api/client.py: 33%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

447 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20import logging 

21import ssl 

22import sys 

23import uuid 

24from functools import cache 

25from http import HTTPStatus 

26from typing import TYPE_CHECKING, Any, TypeVar 

27 

28import certifi 

29import httpx 

30import msgspec 

31import structlog 

32from pydantic import BaseModel 

33from tenacity import ( 

34 before_log, 

35 retry, 

36 retry_if_exception, 

37 stop_after_attempt, 

38 wait_random_exponential, 

39) 

40from uuid6 import uuid7 

41 

42from airflow.sdk import __version__ 

43from airflow.sdk.api.datamodels._generated import ( 

44 API_VERSION, 

45 AssetEventsResponse, 

46 AssetResponse, 

47 ConnectionResponse, 

48 DagRun, 

49 DagRunStateResponse, 

50 DagRunType, 

51 HITLDetailRequest, 

52 HITLDetailResponse, 

53 HITLUser, 

54 InactiveAssetsResponse, 

55 PrevSuccessfulDagRunResponse, 

56 TaskBreadcrumbsResponse, 

57 TaskInstanceState, 

58 TaskStatesResponse, 

59 TerminalStateNonSuccess, 

60 TIDeferredStatePayload, 

61 TIEnterRunningPayload, 

62 TIHeartbeatInfo, 

63 TIRescheduleStatePayload, 

64 TIRetryStatePayload, 

65 TIRunContext, 

66 TISkippedDownstreamTasksStatePayload, 

67 TISuccessStatePayload, 

68 TITerminalStatePayload, 

69 TriggerDAGRunPayload, 

70 ValidationError as RemoteValidationError, 

71 VariablePostBody, 

72 VariableResponse, 

73 XComResponse, 

74 XComSequenceIndexResponse, 

75 XComSequenceSliceResponse, 

76) 

77from airflow.sdk.configuration import conf 

78from airflow.sdk.exceptions import ErrorType 

79from airflow.sdk.execution_time.comms import ( 

80 CreateHITLDetailPayload, 

81 DRCount, 

82 ErrorResponse, 

83 OKResponse, 

84 PreviousDagRunResult, 

85 PreviousTIResult, 

86 SkipDownstreamTasks, 

87 TaskRescheduleStartDate, 

88 TICount, 

89 UpdateHITLDetail, 

90 XComCountResponse, 

91) 

92 

if TYPE_CHECKING:
    # Imports needed only for annotations; kept out of the runtime path.
    from datetime import datetime
    from typing import ParamSpec

    from airflow.sdk.execution_time.comms import RescheduleTask

    P = ParamSpec("P")
    T = TypeVar("T")

    # methodtools doesn't have typestubs, so give the type checker a no-op stub
    # with the same call shape as methodtools.lru_cache.
    def lru_cache(maxsize: int | None = 128):
        def wrapper(f):
            return f

        return wrapper
else:
    from methodtools import lru_cache

110 

111 

112@cache 

113def _get_fqdn(name=""): 

114 """ 

115 Get fully qualified domain name from name. 

116 

117 An empty argument is interpreted as meaning the local host. 

118 This is a patched version of socket.getfqdn() - see https://github.com/python/cpython/issues/49254 

119 """ 

120 import socket 

121 

122 name = name.strip() 

123 if not name or name == "0.0.0.0": 

124 name = socket.gethostname() 

125 try: 

126 addrs = socket.getaddrinfo(name, None, 0, socket.SOCK_DGRAM, 0, socket.AI_CANONNAME) 

127 except OSError: 

128 pass 

129 else: 

130 for addr in addrs: 

131 if addr[3]: 

132 name = addr[3] 

133 break 

134 return name 

135 

136 

def get_hostname():
    """Fetch the hostname using the callable from config or use built-in FQDN as a fallback."""
    hostname_callable = conf.getimport("core", "hostname_callable", fallback=_get_fqdn)
    return hostname_callable()

140 

141 

@cache
def getuser() -> str:
    """
    Get the username of the current user, or error with a nice error message if there's no current user.

    We don't want to fall back to os.getuid() because not having a username
    probably means the rest of the user environment is wrong (e.g. no $HOME).
    Explicit failure is better than silently trying to work badly.

    :raises ValueError: if no username can be determined for the current process.
    """
    import getpass

    try:
        return getpass.getuser()
    except KeyError:
        # `from None`: the bare KeyError from getpass carries no useful detail
        # and would only obscure this actionable message in tracebacks.
        raise ValueError(
            "The user that Airflow is running as has no username; you must run "
            "Airflow as a full user, with a username and home directory, "
            "in order for it to function properly."
        ) from None

161 

162 

# Module-level structured logger, named after this module so log output can be filtered.
log = structlog.get_logger(logger_name=__name__)

# Explicit public API of this module.
__all__ = [
    "Client",
    "ConnectionOperations",
    "ServerResponseError",
    "TaskInstanceOperations",
    "get_hostname",
    "getuser",
]

173 

174 

def get_json_error(response: httpx.Response):
    """Raise a ServerResponseError if we can extract error info from the error."""
    if err := ServerResponseError.from_response(response):
        raise err

180 

181 

def raise_on_4xx_5xx(response: httpx.Response):
    # Raise the richer JSON-based error when the body allows it; otherwise
    # fall back to httpx's own status-based exception.
    get_json_error(response)
    return response.raise_for_status()

184 

185 

# Py 3.11+ version
def raise_on_4xx_5xx_with_note(response: httpx.Response):
    try:
        return get_json_error(response) or response.raise_for_status()
    except httpx.HTTPStatusError as e:
        if TYPE_CHECKING:
            assert hasattr(e, "add_note")
        # Attach the correlation id (response header first, falling back to the
        # one we stamped on the request) to aid cross-referencing server logs.
        correlation = response.headers.get("correlation-id", None) or response.request.headers.get(
            "correlation-id", "no-correlation-id"
        )
        e.add_note(f"Correlation-id={correlation}")
        raise

197 

198 

# BaseException.add_note only exists on Python 3.11+; when available, prefer
# the variant that annotates HTTP errors with the correlation id.
if hasattr(BaseException, "add_note"):
    # Py 3.11+
    raise_on_4xx_5xx = raise_on_4xx_5xx_with_note

202 

203 

def add_correlation_id(request: httpx.Request):
    # Stamp each outgoing request with a fresh time-ordered UUIDv7 so client
    # and server logs can be correlated per request.
    request.headers["correlation-id"] = str(uuid7())

206 

207 

class TaskInstanceOperations:
    """Operations on a task instance (TI), issued against the execution API server via the shared client."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext:
        """Tell the API server that this TI has started running."""
        body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), unixname=getuser(), start_date=when)

        resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json())
        return TIRunContext.model_validate_json(resp.read())

    def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime, rendered_map_index):
        """Tell the API server that this TI has reached a terminal state."""
        # SUCCESS carries extra payload (outlets/events) and must go through `succeed`.
        if state == TaskInstanceState.SUCCESS:
            raise ValueError("Logic error. SUCCESS state should call the `succeed` function instead")
        # TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing.
        body = TITerminalStatePayload(
            end_date=when, state=TerminalStateNonSuccess(state), rendered_map_index=rendered_map_index
        )
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def retry(self, id: uuid.UUID, end_date: datetime, rendered_map_index):
        """Tell the API server that this TI has failed and reached an up_for_retry state."""
        body = TIRetryStatePayload(end_date=end_date, rendered_map_index=rendered_map_index)
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events, rendered_map_index):
        """Tell the API server that this TI has succeeded."""
        body = TISuccessStatePayload(
            end_date=when,
            task_outlets=task_outlets,
            outlet_events=outlet_events,
            rendered_map_index=rendered_map_index,
        )
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def defer(self, id: uuid.UUID, msg):
        """Tell the API server that this TI has been deferred."""
        # Build the deferred-state payload from msg, dropping the comms discriminator field.
        body = TIDeferredStatePayload(**msg.model_dump(exclude_unset=True, exclude={"type"}))

        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def reschedule(self, id: uuid.UUID, msg: RescheduleTask):
        """Tell the API server that this TI has been rescheduled."""
        # Build the reschedule-state payload from msg, dropping the comms discriminator field.
        body = TIRescheduleStatePayload(**msg.model_dump(exclude_unset=True, exclude={"type"}))

        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def heartbeat(self, id: uuid.UUID, pid: int):
        """Report liveness of this TI (pid and hostname) to the API server."""
        body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
        self.client.put(f"task-instances/{id}/heartbeat", content=body.model_dump_json())

    def skip_downstream_tasks(self, id: uuid.UUID, msg: SkipDownstreamTasks):
        """Tell the API server to skip the downstream tasks of this TI."""
        body = TISkippedDownstreamTasksStatePayload(tasks=msg.tasks)
        self.client.patch(f"task-instances/{id}/skip-downstream", content=body.model_dump_json())

    def set_rtif(self, id: uuid.UUID, body: dict[str, str]) -> OKResponse:
        """Set Rendered Task Instance Fields via the API server."""
        self.client.put(f"task-instances/{id}/rtif", json=body)
        # Any error from the server will anyway be propagated down to the supervisor,
        # so we choose to send a generic response to the supervisor over the server response to
        # decouple from the server response string
        return OKResponse(ok=True)

    def set_rendered_map_index(self, id: uuid.UUID, rendered_map_index: str) -> OKResponse:
        """Set rendered_map_index for a task instance via the API server."""
        self.client.patch(f"task-instances/{id}/rendered-map-index", json=rendered_map_index)
        return OKResponse(ok=True)

    def get_previous_successful_dagrun(self, id: uuid.UUID) -> PrevSuccessfulDagRunResponse:
        """
        Get the previous successful dag run for a given task instance.

        The data from it is used to get values for Task Context.
        """
        resp = self.client.get(f"task-instances/{id}/previous-successful-dagrun")
        return PrevSuccessfulDagRunResponse.model_validate_json(resp.read())

    def get_reschedule_start_date(self, id: uuid.UUID, try_number: int = 1) -> TaskRescheduleStartDate:
        """Get the start date of a task reschedule via the API server."""
        resp = self.client.get(f"task-reschedules/{id}/start_date", params={"try_number": try_number})
        return TaskRescheduleStartDate(start_date=resp.json())

    def get_count(
        self,
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> TICount:
        """Get count of task instances matching the given criteria."""
        params: dict[str, Any]
        params = {
            "dag_id": dag_id,
            "task_ids": task_ids,
            "task_group_id": task_group_id,
            "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None,
            "run_ids": run_ids,
            "states": states,
        }

        # Remove None values from params
        params = {k: v for k, v in params.items() if v is not None}

        # Negative map_index (unmapped sentinel) is deliberately not sent as a filter.
        if map_index is not None and map_index >= 0:
            params.update({"map_index": map_index})

        resp = self.client.get("task-instances/count", params=params)
        return TICount(count=resp.json())

    def get_previous(
        self,
        dag_id: str,
        task_id: str,
        logical_date: datetime | None = None,
        map_index: int = -1,
        state: TaskInstanceState | str | None = None,
    ) -> PreviousTIResult:
        """
        Get the previous task instance matching the given criteria.

        :param dag_id: DAG ID
        :param task_id: Task ID
        :param logical_date: If provided, finds TI with logical_date < this value (before filter)
        :param map_index: Map index to filter by (defaults to -1 for non-mapped tasks)
        :param state: If provided, filters by TaskInstance state
        """
        params: dict[str, Any] = {"map_index": map_index}
        if logical_date:
            params["logical_date"] = logical_date.isoformat()
        if state:
            # Accept either the enum or its raw string value.
            params["state"] = state.value if isinstance(state, TaskInstanceState) else state

        resp = self.client.get(f"task-instances/previous/{dag_id}/{task_id}", params=params)
        return PreviousTIResult(task_instance=resp.json())

    def get_task_states(
        self,
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
    ) -> TaskStatesResponse:
        """Get task states given criteria."""
        params: dict[str, Any]
        params = {
            "dag_id": dag_id,
            "task_ids": task_ids,
            "task_group_id": task_group_id,
            "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None,
            "run_ids": run_ids,
        }

        # Remove None values from params
        params = {k: v for k, v in params.items() if v is not None}

        if map_index is not None and map_index >= 0:
            params.update({"map_index": map_index})

        resp = self.client.get("task-instances/states", params=params)
        return TaskStatesResponse.model_validate_json(resp.read())

    def get_task_breakcrumbs(self, dag_id: str, run_id: str) -> TaskBreadcrumbsResponse:
        """Get task breadcrumbs for a dag run via the API server."""
        # NOTE(review): method name contains a typo ("breakcrumbs") — kept for
        # backward compatibility with existing callers.
        params = {"dag_id": dag_id, "run_id": run_id}
        resp = self.client.get("task-instances/breadcrumbs", params=params)
        return TaskBreadcrumbsResponse.model_validate_json(resp.read())

    def validate_inlets_and_outlets(self, id: uuid.UUID) -> InactiveAssetsResponse:
        """Validate whether there are inactive assets in inlets and outlets of a given task instance."""
        resp = self.client.get(f"task-instances/{id}/validate-inlets-and-outlets")
        return InactiveAssetsResponse.model_validate_json(resp.read())

389 

390 

class ConnectionOperations:
    """Connection lookups against the execution API server."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(self, conn_id: str) -> ConnectionResponse | ErrorResponse:
        """Get a connection from the API server."""
        try:
            resp = self.client.get(f"connections/{conn_id}")
        except ServerResponseError as e:
            if e.response.status_code != HTTPStatus.NOT_FOUND:
                raise
            # A missing connection is an expected condition; report it as a
            # structured error instead of an exception.
            log.debug(
                "Connection not found",
                conn_id=conn_id,
                detail=e.detail,
                status_code=e.response.status_code,
            )
            return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": conn_id})
        return ConnectionResponse.model_validate_json(resp.read())

412 

413 

class VariableOperations:
    """Airflow Variable operations against the execution API server."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(self, key: str) -> VariableResponse | ErrorResponse:
        """Get a variable from the API server."""
        try:
            resp = self.client.get(f"variables/{key}")
        except ServerResponseError as e:
            if e.response.status_code != HTTPStatus.NOT_FOUND:
                raise
            log.error(
                "Variable not found",
                key=key,
                detail=e.detail,
                status_code=e.response.status_code,
            )
            return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"key": key})
        return VariableResponse.model_validate_json(resp.read())

    def set(self, key: str, value: str | None, description: str | None = None) -> OKResponse:
        """Set an Airflow Variable via the API server."""
        payload = VariablePostBody(val=value, description=description)
        self.client.put(f"variables/{key}", content=payload.model_dump_json())
        # Any error from the server will anyway be propagated down to the supervisor,
        # so we choose to send a generic response to the supervisor over the server response to
        # decouple from the server response string
        return OKResponse(ok=True)

    def delete(
        self,
        key: str,
    ) -> OKResponse:
        """Delete a variable with given key via the API server."""
        self.client.delete(f"variables/{key}")
        # Server errors propagate as exceptions; on success a generic OK keeps the
        # supervisor decoupled from the server's response body.
        return OKResponse(ok=True)

455 

456 

class XComOperations:
    """XCom operations against the execution API server."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def head(self, dag_id: str, run_id: str, task_id: str, key: str) -> XComCountResponse:
        """
        Get the number of mapped XCom values.

        :raises RuntimeError: if the response has no parseable ``Content-Range`` header.
        """
        resp = self.client.head(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}")

        # Use .get(): a missing header must surface as the RuntimeError below,
        # not as a bare KeyError from the headers mapping.
        content_range = resp.headers.get("Content-Range")
        if not content_range or not content_range.startswith("map_indexes "):
            raise RuntimeError(f"Unable to parse Content-Range header from HEAD {resp.request.url}")
        return XComCountResponse(len=int(content_range[len("map_indexes ") :]))

    def get(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        map_index: int | None = None,
        include_prior_dates: bool = False,
    ) -> XComResponse:
        """Get a XCom value from the API server."""
        # TODO: check if we need to use map_index as params in the uri
        # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
        params = {}
        if map_index is not None and map_index >= 0:
            params.update({"map_index": map_index})
        if include_prior_dates:
            params.update({"include_prior_dates": include_prior_dates})
        try:
            resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
        except ServerResponseError as e:
            if e.response.status_code == HTTPStatus.NOT_FOUND:
                log.error(
                    "XCom not found",
                    dag_id=dag_id,
                    run_id=run_id,
                    task_id=task_id,
                    key=key,
                    map_index=map_index,
                    detail=e.detail,
                    status_code=e.response.status_code,
                )
                # Airflow 2.x just ignores the absence of an XCom and moves on with a return value of None
                # Hence returning with key as `key` and value as `None`, so that the message is sent back to task runner
                # and the default value of None in xcom_pull is used.
                return XComResponse(key=key, value=None)
            raise
        return XComResponse.model_validate_json(resp.read())

    def set(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        value,
        map_index: int | None = None,
        mapped_length: int | None = None,
    ) -> OKResponse:
        """Set a XCom value via the API server."""
        # TODO: check if we need to use map_index as params in the uri
        # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
        params = {}
        if map_index is not None and map_index >= 0:
            params = {"map_index": map_index}
        if mapped_length is not None and mapped_length >= 0:
            params["mapped_length"] = mapped_length
        self.client.post(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params, json=value)
        # Any error from the server will anyway be propagated down to the supervisor,
        # so we choose to send a generic response to the supervisor over the server response to
        # decouple from the server response string
        return OKResponse(ok=True)

    def delete(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        map_index: int | None = None,
    ) -> OKResponse:
        """Delete a XCom with given key via the API server."""
        params = {}
        if map_index is not None and map_index >= 0:
            params = {"map_index": map_index}
        self.client.delete(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
        # Any error from the server will anyway be propagated down to the supervisor,
        # so we choose to send a generic response to the supervisor over the server response to
        # decouple from the server response string
        return OKResponse(ok=True)

    def get_sequence_item(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        offset: int,
    ) -> XComSequenceIndexResponse | ErrorResponse:
        """Get one item of a sequence XCom by offset, or a structured error if it is missing."""
        try:
            resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}")
        except ServerResponseError as e:
            if e.response.status_code == HTTPStatus.NOT_FOUND:
                log.error(
                    "XCom not found",
                    dag_id=dag_id,
                    run_id=run_id,
                    task_id=task_id,
                    key=key,
                    offset=offset,
                    detail=e.detail,
                    status_code=e.response.status_code,
                )
                return ErrorResponse(
                    error=ErrorType.XCOM_NOT_FOUND,
                    detail={
                        "dag_id": dag_id,
                        "run_id": run_id,
                        "task_id": task_id,
                        "key": key,
                        "offset": offset,
                    },
                )
            raise
        return XComSequenceIndexResponse.model_validate_json(resp.read())

    def get_sequence_slice(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        start: int | None,
        stop: int | None,
        step: int | None,
        include_prior_dates: bool = False,
    ) -> XComSequenceSliceResponse:
        """Get a slice of a sequence XCom; None bounds are omitted from the query."""
        params = {}
        if start is not None:
            params["start"] = start
        if stop is not None:
            params["stop"] = stop
        if step is not None:
            params["step"] = step
        if include_prior_dates:
            params["include_prior_dates"] = include_prior_dates
        resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/slice", params=params)
        return XComSequenceSliceResponse.model_validate_json(resp.read())

611 

612 

class AssetOperations:
    """Asset lookups against the execution API server."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(self, name: str | None = None, uri: str | None = None) -> AssetResponse | ErrorResponse:
        """Get Asset value from the API server."""
        # Name takes precedence when both identifiers are supplied.
        if name:
            endpoint, params = "assets/by-name", {"name": name}
        elif uri:
            endpoint, params = "assets/by-uri", {"uri": uri}
        else:
            raise ValueError("Either `name` or `uri` must be provided")

        try:
            resp = self.client.get(endpoint, params=params)
        except ServerResponseError as e:
            if e.response.status_code != HTTPStatus.NOT_FOUND:
                raise
            log.error(
                "Asset not found",
                params=params,
                detail=e.detail,
                status_code=e.response.status_code,
            )
            return ErrorResponse(error=ErrorType.ASSET_NOT_FOUND, detail=params)

        return AssetResponse.model_validate_json(resp.read())

644 

645 

class AssetEventOperations:
    """Asset-event queries against the execution API server."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(
        self,
        name: str | None = None,
        uri: str | None = None,
        alias_name: str | None = None,
        after: datetime | None = None,
        before: datetime | None = None,
        ascending: bool = True,
        limit: int | None = None,
    ) -> AssetEventsResponse:
        """Get Asset event from the API server."""
        # Filters shared by both endpoints; insertion order mirrors the query
        # the server has always received.
        shared: dict[str, Any] = {}
        if after:
            shared["after"] = after.isoformat()
        if before:
            shared["before"] = before.isoformat()
        shared["ascending"] = ascending
        if limit:
            shared["limit"] = limit

        if name or uri:
            resp = self.client.get("asset-events/by-asset", params={"name": name, "uri": uri, **shared})
        elif alias_name:
            resp = self.client.get("asset-events/by-asset-alias", params={"name": alias_name, **shared})
        else:
            raise ValueError("Either `name`, `uri` or `alias_name` must be provided")

        return AssetEventsResponse.model_validate_json(resp.read())

683 

684 

class DagRunOperations:
    """Dag-run operations against the execution API server."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def trigger(
        self,
        dag_id: str,
        run_id: str,
        conf: dict | None = None,
        logical_date: datetime | None = None,
        reset_dag_run: bool = False,
    ) -> OKResponse | ErrorResponse:
        """Trigger a Dag run via the API server."""
        body = TriggerDAGRunPayload(logical_date=logical_date, conf=conf or {}, reset_dag_run=reset_dag_run)

        try:
            self.client.post(
                f"dag-runs/{dag_id}/{run_id}", content=body.model_dump_json(exclude_defaults=True)
            )
        except ServerResponseError as e:
            if e.response.status_code != HTTPStatus.CONFLICT:
                raise
            # CONFLICT means the run already exists; optionally clear (reset) it.
            if reset_dag_run:
                log.info("Dag Run already exists; Resetting Dag Run.", dag_id=dag_id, run_id=run_id)
                return self.clear(run_id=run_id, dag_id=dag_id)

            log.info("Dag Run already exists!", detail=e.detail, dag_id=dag_id, run_id=run_id)
            return ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS)

        return OKResponse(ok=True)

    def clear(self, dag_id: str, run_id: str) -> OKResponse:
        """Clear a Dag run via the API server."""
        self.client.post(f"dag-runs/{dag_id}/{run_id}/clear")
        # TODO: Error handling
        return OKResponse(ok=True)

    def get_detail(self, dag_id: str, run_id: str) -> DagRun:
        """Get detail of a dag run."""
        resp = self.client.get(f"dag-runs/{dag_id}/{run_id}")
        return DagRun.model_validate_json(resp.read())

    def get_state(self, dag_id: str, run_id: str) -> DagRunStateResponse:
        """Get the state of a Dag run via the API server."""
        resp = self.client.get(f"dag-runs/{dag_id}/{run_id}/state")
        return DagRunStateResponse.model_validate_json(resp.read())

    def get_count(
        self,
        dag_id: str,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> DRCount:
        """Get count of Dag runs matching the given criteria."""
        raw = {
            "dag_id": dag_id,
            "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None,
            "run_ids": run_ids,
            "states": states,
        }
        # Unset filters must not be sent to the server.
        query = {key: value for key, value in raw.items() if value is not None}

        resp = self.client.get("dag-runs/count", params=query)
        return DRCount(count=resp.json())

    def get_previous(
        self,
        dag_id: str,
        logical_date: datetime,
        state: str | None = None,
    ) -> PreviousDagRunResult:
        """Get the previous DAG run before the given logical date, optionally filtered by state."""
        query = {
            "dag_id": dag_id,
            "logical_date": logical_date.isoformat(),
        }
        if state:
            query["state"] = state
        resp = self.client.get("dag-runs/previous", params=query)
        return PreviousDagRunResult(dag_run=resp.json())

770 

771 

class HITLOperations:
    """
    Operations related to Human in the loop. Require Airflow 3.1+.

    :meta: private
    """

    __slots__ = ("client",)

    def __init__(self, client: Client) -> None:
        self.client = client

    def add_response(
        self,
        *,
        ti_id: uuid.UUID,
        options: list[str],
        subject: str,
        body: str | None = None,
        defaults: list[str] | None = None,
        multiple: bool = False,
        params: dict[str, dict[str, Any]] | None = None,
        assigned_users: list[HITLUser] | None = None,
    ) -> HITLDetailRequest:
        """Add a Human-in-the-loop response that waits for human response for a specific Task Instance."""
        request_body = CreateHITLDetailPayload(
            ti_id=ti_id,
            options=options,
            subject=subject,
            body=body,
            defaults=defaults,
            multiple=multiple,
            params=params,
            assigned_users=assigned_users,
        )
        resp = self.client.post(f"/hitlDetails/{ti_id}", content=request_body.model_dump_json())
        return HITLDetailRequest.model_validate_json(resp.read())

    def update_response(
        self,
        *,
        ti_id: uuid.UUID,
        chosen_options: list[str],
        params_input: dict[str, Any],
    ) -> HITLDetailResponse:
        """Update an existing Human-in-the-loop response."""
        request_body = UpdateHITLDetail(
            ti_id=ti_id,
            chosen_options=chosen_options,
            params_input=params_input,
        )
        resp = self.client.patch(f"/hitlDetails/{ti_id}", content=request_body.model_dump_json())
        return HITLDetailResponse.model_validate_json(resp.read())

    def get_detail_response(self, ti_id: uuid.UUID) -> HITLDetailResponse:
        """Get content part of a Human-in-the-loop response for a specific Task Instance."""
        resp = self.client.get(f"/hitlDetails/{ti_id}")
        return HITLDetailResponse.model_validate_json(resp.read())

836 

837 

class BearerAuth(httpx.Auth):
    """httpx auth hook that attaches a bearer token to every request."""

    def __init__(self, token: str):
        self.token: str = token

    def auth_flow(self, request: httpx.Request):
        # An empty token means "send unauthenticated" rather than a bogus header.
        if self.token:
            request.headers["Authorization"] = "Bearer " + self.token
        yield request

846 

847 

# This exists as an aid for debugging or local running via the `dry_run` argument to Client. It doesn't make
# sense for returning connections etc.
def noop_handler(request: httpx.Request) -> httpx.Response:
    path = request.url.path
    log.debug("Dry-run request", method=request.method, path=path)

    if not (path.startswith("/task-instances/") and path.endswith("/run")):
        return httpx.Response(200, json={"text": "Hello, world!"})

    # Return a fake context
    fake_run_context = {
        "dag_run": {
            "dag_id": "test_dag",
            "run_id": "test_run",
            "logical_date": "2021-01-01T00:00:00Z",
            "start_date": "2021-01-01T00:00:00Z",
            "run_type": DagRunType.MANUAL,
            "run_after": "2021-01-01T00:00:00Z",
            "consumed_asset_events": [],
        },
        "max_tries": 0,
        "should_retry": False,
    }
    return httpx.Response(200, json=fake_run_context)

873 

874 

# Note: Given defaults make attempts after 1, 3, 7, 15 and fails after 31seconds
API_RETRIES = conf.getint("workers", "execution_api_retries")  # max attempts per API call
API_RETRY_WAIT_MIN = conf.getfloat("workers", "execution_api_retry_wait_min")  # seconds
API_RETRY_WAIT_MAX = conf.getfloat("workers", "execution_api_retry_wait_max")  # seconds
API_SSL_CERT_PATH = conf.get("api", "ssl_cert")  # extra CA bundle trusted alongside certifi
API_TIMEOUT = conf.getfloat("workers", "execution_api_timeout")  # per-request timeout, seconds

881 

882 

def _should_retry_api_request(exception: BaseException) -> bool:
    """Determine if an API request should be retried based on the exception type."""
    # 5xx responses are treated as transient; transport failures (timeouts,
    # connection errors) are always retryable; everything else is not.
    return (
        exception.response.status_code >= 500
        if isinstance(exception, httpx.HTTPStatusError)
        else isinstance(exception, httpx.RequestError)
    )

889 

890 

class Client(httpx.Client):
    """HTTP client for the Task Execution API.

    Extends ``httpx.Client`` with bearer-token auth, automatic retries for
    transient failures, a cached SSL context, and namespaced operation
    accessors (``client.task_instances``, ``client.variables``, ...).
    """

    # @staticmethod must be the OUTERMOST decorator so the cached function object is
    # accessed unbound; the file imports ``cache`` (not ``lru_cache``) from functools.
    @staticmethod
    @cache
    def _get_ssl_context_cached(ca_file: str, ca_path: str | None = None) -> ssl.SSLContext:
        """Cache SSL context to prevent memory growth from repeated context creation."""
        ctx = ssl.create_default_context(cafile=ca_file)
        if ca_path:
            ctx.load_verify_locations(ca_path)
        return ctx

    def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any):
        # Exactly one of base_url / dry_run must be supplied.
        if (not base_url) ^ dry_run:
            raise ValueError(f"Can only specify one of {base_url=} or {dry_run=}")
        auth = BearerAuth(token)

        if dry_run:
            # If dry run is requested, install a no op handler so that simple tasks can "heartbeat" using a
            # real client, but just don't make any HTTP requests
            kwargs.setdefault("transport", httpx.MockTransport(noop_handler))
            kwargs.setdefault("base_url", "dry-run://server")
        else:
            kwargs["base_url"] = base_url
            # Call via the class so all Client instances share one cached SSL context.
            kwargs["verify"] = type(self)._get_ssl_context_cached(certifi.where(), API_SSL_CERT_PATH)

        # Set timeout if not explicitly provided
        kwargs.setdefault("timeout", API_TIMEOUT)

        pyver = f"{'.'.join(map(str, sys.version_info[:3]))}"
        super().__init__(
            auth=auth,
            headers={
                "user-agent": f"apache-airflow-task-sdk/{__version__} (Python/{pyver})",
                "airflow-api-version": API_VERSION,
            },
            event_hooks={"response": [self._update_auth, raise_on_4xx_5xx], "request": [add_correlation_id]},
            **kwargs,
        )

    def _update_auth(self, response: httpx.Response):
        # The server may rotate the task token mid-run; adopt the refreshed one so
        # subsequent requests keep authenticating.
        if new_token := response.headers.get("Refreshed-API-Token"):
            log.debug("Execution API issued us a refreshed Task token")
            self.auth = BearerAuth(new_token)

    @retry(
        retry=retry_if_exception(_should_retry_api_request),
        stop=stop_after_attempt(API_RETRIES),
        wait=wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX),
        before_sleep=before_log(log, logging.WARNING),
        reraise=True,
    )
    def request(self, *args, **kwargs):
        """Implement a convenience for httpx.Client.request with a retry layer."""
        # Set content type as convenience if not already set. Merge with any
        # caller-supplied headers instead of discarding them.
        if kwargs.get("content") is not None:
            headers = kwargs.get("headers") or {}
            if "content-type" not in headers:
                kwargs["headers"] = {**dict(headers), "content-type": "application/json"}

        return super().request(*args, **kwargs)

    # We "group" or "namespace" operations by what they operate on, rather than a flat namespace with all
    # methods on one object prefixed with the object type (`.task_instances.update` rather than
    # `task_instance_update` etc.)
    #
    # ``@property`` must be the OUTERMOST decorator; the inner ``@cache`` memoizes the
    # accessor per Client instance so each operations object is built once. (``@cache``
    # pins ``self`` for the cache's lifetime, which is acceptable here — clients live
    # for the whole process.)

    @property
    @cache  # type: ignore[misc]
    def task_instances(self) -> TaskInstanceOperations:
        """Operations related to TaskInstances."""
        return TaskInstanceOperations(self)

    @property
    @cache  # type: ignore[misc]
    def dag_runs(self) -> DagRunOperations:
        """Operations related to DagRuns."""
        return DagRunOperations(self)

    @property
    @cache  # type: ignore[misc]
    def connections(self) -> ConnectionOperations:
        """Operations related to Connections."""
        return ConnectionOperations(self)

    @property
    @cache  # type: ignore[misc]
    def variables(self) -> VariableOperations:
        """Operations related to Variables."""
        return VariableOperations(self)

    @property
    @cache  # type: ignore[misc]
    def xcoms(self) -> XComOperations:
        """Operations related to XComs."""
        return XComOperations(self)

    @property
    @cache  # type: ignore[misc]
    def assets(self) -> AssetOperations:
        """Operations related to Assets."""
        return AssetOperations(self)

    @property
    @cache  # type: ignore[misc]
    def asset_events(self) -> AssetEventOperations:
        """Operations related to Asset Events."""
        return AssetEventOperations(self)

    @property
    @cache  # type: ignore[misc]
    def hitl(self) -> HITLOperations:
        """Operations related to HITL Responses."""
        return HITLOperations(self)

1003 

1004 

# This is only used for parsing. ServerResponseError is raised instead
class _ErrorBody(BaseModel):
    """Pydantic shape of an error payload returned by the Execution API."""

    detail: list[RemoteValidationError] | str

    def __repr__(self):
        # Delegate straight to the detail payload; the wrapper adds no information.
        return f"{self.detail!r}"

1011 

1012 

class ServerResponseError(httpx.HTTPStatusError):
    """HTTP error carrying the parsed error body from the Execution API.

    Built via :meth:`from_response`; ``detail`` holds the decoded server
    payload (validation errors, free-form dict, or ``None``).
    """

    def __init__(self, message: str, *, request: httpx.Request, response: httpx.Response):
        super().__init__(message, request=request, response=response)

    # Parsed error payload attached by from_response().
    detail: list[RemoteValidationError] | str | dict[str, Any] | None

    def __reduce__(self) -> tuple[Any, ...]:
        # Needed because https://github.com/encode/httpx/pull/3108 isn't merged yet.
        # Recreate via Exception.__new__ with our args, then restore __dict__ (incl. detail).
        return Exception.__new__, (type(self),) + self.args, self.__dict__

    @classmethod
    def from_response(cls, response: httpx.Response) -> ServerResponseError | None:
        """Build a ServerResponseError from an error response, or None if not applicable.

        Returns None for successful responses, non-4xx/5xx statuses, non-JSON
        bodies, and bodies that cannot be decoded at all (callers then fall back
        to the plain httpx error handling).
        """
        if response.is_success:
            return None
        # 4xx or 5xx error?
        if not (400 <= response.status_code < 600):
            return None

        # NOTE(review): exact match — a content-type with parameters such as
        # "application/json; charset=utf-8" falls through to a plain httpx error.
        # Confirm the server never sends parameters here.
        if response.headers.get("content-type") != "application/json":
            return None

        detail: list[RemoteValidationError] | dict[str, Any] | None = None
        try:
            # First attempt: the structured _ErrorBody schema.
            body = _ErrorBody.model_validate_json(response.read())

            if isinstance(body.detail, list):
                detail = body.detail
                msg = "Remote server returned validation error"
            else:
                msg = body.detail or "Un-parseable error"
        except Exception:
            # Second attempt: any valid JSON payload, kept as a raw dict/value.
            try:
                detail = msgspec.json.decode(response.content)
            except Exception:
                # Fallback to a normal httpx error
                return None
            msg = "Server returned error"

        self = cls(msg, request=response.request, response=response)
        self.detail = detail
        return self