Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/api/client.py: 33%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

439 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20import logging 

21import ssl 

22import sys 

23import uuid 

24from functools import cache 

25from http import HTTPStatus 

26from typing import TYPE_CHECKING, Any, TypeVar 

27 

28import certifi 

29import httpx 

30import msgspec 

31import structlog 

32from pydantic import BaseModel 

33from tenacity import ( 

34 before_log, 

35 retry, 

36 retry_if_exception, 

37 stop_after_attempt, 

38 wait_random_exponential, 

39) 

40from uuid6 import uuid7 

41 

42from airflow.sdk import __version__ 

43from airflow.sdk.api.datamodels._generated import ( 

44 API_VERSION, 

45 AssetEventsResponse, 

46 AssetResponse, 

47 ConnectionResponse, 

48 DagRun, 

49 DagRunStateResponse, 

50 DagRunType, 

51 HITLDetailRequest, 

52 HITLDetailResponse, 

53 HITLUser, 

54 InactiveAssetsResponse, 

55 PrevSuccessfulDagRunResponse, 

56 TaskBreadcrumbsResponse, 

57 TaskInstanceState, 

58 TaskStatesResponse, 

59 TerminalStateNonSuccess, 

60 TIDeferredStatePayload, 

61 TIEnterRunningPayload, 

62 TIHeartbeatInfo, 

63 TIRescheduleStatePayload, 

64 TIRetryStatePayload, 

65 TIRunContext, 

66 TISkippedDownstreamTasksStatePayload, 

67 TISuccessStatePayload, 

68 TITerminalStatePayload, 

69 TriggerDAGRunPayload, 

70 ValidationError as RemoteValidationError, 

71 VariablePostBody, 

72 VariableResponse, 

73 XComResponse, 

74 XComSequenceIndexResponse, 

75 XComSequenceSliceResponse, 

76) 

77from airflow.sdk.configuration import conf 

78from airflow.sdk.exceptions import ErrorType 

79from airflow.sdk.execution_time.comms import ( 

80 CreateHITLDetailPayload, 

81 DRCount, 

82 ErrorResponse, 

83 OKResponse, 

84 PreviousDagRunResult, 

85 SkipDownstreamTasks, 

86 TaskRescheduleStartDate, 

87 TICount, 

88 UpdateHITLDetail, 

89 XComCountResponse, 

90) 

91 

if TYPE_CHECKING:
    # Imports used only by annotations; kept out of the runtime import graph.
    from datetime import datetime
    from typing import ParamSpec

    from airflow.sdk.execution_time.comms import RescheduleTask

    P = ParamSpec("P")
    T = TypeVar("T")

    # # methodtools doesn't have typestubs, so give a stub
    # The stub is an identity decorator so the type checker sees the undecorated
    # method signature; at runtime the real methodtools.lru_cache is used instead.
    def lru_cache(maxsize: int | None = 128):
        def wrapper(f):
            return f

        return wrapper
else:
    from methodtools import lru_cache

109 

110 

111@cache 

112def _get_fqdn(name=""): 

113 """ 

114 Get fully qualified domain name from name. 

115 

116 An empty argument is interpreted as meaning the local host. 

117 This is a patched version of socket.getfqdn() - see https://github.com/python/cpython/issues/49254 

118 """ 

119 import socket 

120 

121 name = name.strip() 

122 if not name or name == "0.0.0.0": 

123 name = socket.gethostname() 

124 try: 

125 addrs = socket.getaddrinfo(name, None, 0, socket.SOCK_DGRAM, 0, socket.AI_CANONNAME) 

126 except OSError: 

127 pass 

128 else: 

129 for addr in addrs: 

130 if addr[3]: 

131 name = addr[3] 

132 break 

133 return name 

134 

135 

def get_hostname():
    """Fetch the hostname using the callable from config or use built-in FQDN as a fallback."""
    hostname_callable = conf.getimport("core", "hostname_callable", fallback=_get_fqdn)
    return hostname_callable()

139 

140 

@cache
def getuser() -> str:
    """
    Get the username of the current user, or error with a nice error message if there's no current user.

    We don't want to fall back to os.getuid() because not having a username
    probably means the rest of the user environment is wrong (e.g. no $HOME).
    Explicit failure is better than silently trying to work badly.
    """
    import getpass

    try:
        return getpass.getuser()
    except KeyError:
        # getpass raises KeyError when the uid has no passwd entry.
        message = (
            "The user that Airflow is running as has no username; you must run "
            "Airflow as a full user, with a username and home directory, "
            "in order for it to function properly."
        )
        raise ValueError(message)

160 

161 

# Module-level structured logger for this client module.
log = structlog.get_logger(logger_name=__name__)

# Public API of this module.
__all__ = [
    "Client",
    "ConnectionOperations",
    "ServerResponseError",
    "TaskInstanceOperations",
    "get_hostname",
    "getuser",
]

172 

173 

def get_json_error(response: httpx.Response):
    """Raise a ServerResponseError if we can extract error info from the error."""
    server_error = ServerResponseError.from_response(response)
    # from_response yields nothing when the body is not a recognisable error payload.
    if server_error:
        raise server_error

179 

180 

def raise_on_4xx_5xx(response: httpx.Response):
    """Raise for a JSON-decodable API error first, then fall back to the generic status check."""
    # get_json_error either raises or returns None, so the generic check always runs.
    get_json_error(response)
    return response.raise_for_status()

183 

184 

# Py 3.11+ version
def raise_on_4xx_5xx_with_note(response: httpx.Response):
    """Like raise_on_4xx_5xx, but attaches the request's correlation id as an exception note."""
    try:
        return get_json_error(response) or response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        if TYPE_CHECKING:
            assert hasattr(exc, "add_note")
        # Prefer the id echoed back by the server, else the one we stamped on the request.
        correlation_id = response.headers.get("correlation-id", None) or response.request.headers.get(
            "correlation-id", "no-correlation-id"
        )
        exc.add_note(f"Correlation-id={correlation_id}")
        raise

196 

197 

# add_note only exists on Py 3.11+; when available, use the variant that
# annotates HTTP errors with the correlation id for easier debugging.
if hasattr(BaseException, "add_note"):
    # Py 3.11+
    raise_on_4xx_5xx = raise_on_4xx_5xx_with_note

201 

202 

def add_correlation_id(request: httpx.Request):
    """Stamp each outgoing request with a fresh correlation id (a time-ordered UUIDv7)."""
    request.headers["correlation-id"] = str(uuid7())

205 

206 

class TaskInstanceOperations:
    """Execution API operations scoped to TaskInstances (state changes, heartbeats, queries)."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext:
        """Tell the API server that this TI has started running."""
        body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), unixname=getuser(), start_date=when)

        resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json())
        return TIRunContext.model_validate_json(resp.read())

    def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime, rendered_map_index):
        """Tell the API server that this TI has reached a terminal state."""
        if state == TaskInstanceState.SUCCESS:
            raise ValueError("Logic error. SUCCESS state should call the `succeed` function instead")
        # TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing.
        body = TITerminalStatePayload(
            end_date=when, state=TerminalStateNonSuccess(state), rendered_map_index=rendered_map_index
        )
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def retry(self, id: uuid.UUID, end_date: datetime, rendered_map_index):
        """Tell the API server that this TI has failed and reached an up_for_retry state."""
        body = TIRetryStatePayload(end_date=end_date, rendered_map_index=rendered_map_index)
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def succeed(self, id: uuid.UUID, when: datetime, task_outlets, outlet_events, rendered_map_index):
        """Tell the API server that this TI has succeeded."""
        body = TISuccessStatePayload(
            end_date=when,
            task_outlets=task_outlets,
            outlet_events=outlet_events,
            rendered_map_index=rendered_map_index,
        )
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def defer(self, id: uuid.UUID, msg):
        """Tell the API server that this TI has been deferred."""
        # Create a deferred state payload from msg
        body = TIDeferredStatePayload(**msg.model_dump(exclude_unset=True, exclude={"type"}))
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def reschedule(self, id: uuid.UUID, msg: RescheduleTask):
        """Tell the API server that this TI has been rescheduled."""
        # Create a reschedule state payload from msg
        body = TIRescheduleStatePayload(**msg.model_dump(exclude_unset=True, exclude={"type"}))
        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

    def heartbeat(self, id: uuid.UUID, pid: int):
        """Report liveness of the TI process so the server does not consider the task lost."""
        body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
        self.client.put(f"task-instances/{id}/heartbeat", content=body.model_dump_json())

    def skip_downstream_tasks(self, id: uuid.UUID, msg: SkipDownstreamTasks):
        """Tell the API server to skip the downstream tasks of this TI."""
        body = TISkippedDownstreamTasksStatePayload(tasks=msg.tasks)
        self.client.patch(f"task-instances/{id}/skip-downstream", content=body.model_dump_json())

    def set_rtif(self, id: uuid.UUID, body: dict[str, str]) -> OKResponse:
        """Set Rendered Task Instance Fields via the API server."""
        self.client.put(f"task-instances/{id}/rtif", json=body)
        # Any error from the server will anyway be propagated down to the supervisor,
        # so we choose to send a generic response to the supervisor over the server response to
        # decouple from the server response string
        return OKResponse(ok=True)

    def set_rendered_map_index(self, id: uuid.UUID, rendered_map_index: str) -> OKResponse:
        """Set rendered_map_index for a task instance via the API server."""
        self.client.patch(f"task-instances/{id}/rendered-map-index", json=rendered_map_index)
        return OKResponse(ok=True)

    def get_previous_successful_dagrun(self, id: uuid.UUID) -> PrevSuccessfulDagRunResponse:
        """
        Get the previous successful dag run for a given task instance.

        The data from it is used to get values for Task Context.
        """
        resp = self.client.get(f"task-instances/{id}/previous-successful-dagrun")
        return PrevSuccessfulDagRunResponse.model_validate_json(resp.read())

    def get_reschedule_start_date(self, id: uuid.UUID, try_number: int = 1) -> TaskRescheduleStartDate:
        """Get the start date of a task reschedule via the API server."""
        resp = self.client.get(f"task-reschedules/{id}/start_date", params={"try_number": try_number})
        return TaskRescheduleStartDate(start_date=resp.json())

    @staticmethod
    def _build_filter_params(
        *,
        dag_id: str,
        map_index: int | None,
        task_ids: list[str] | None,
        task_group_id: str | None,
        logical_dates: list[datetime] | None,
        run_ids: list[str] | None,
        states: list[str] | None = None,
    ) -> dict[str, Any]:
        """Build the query params shared by the TI count/states endpoints, omitting unset filters."""
        params: dict[str, Any] = {
            "dag_id": dag_id,
            "task_ids": task_ids,
            "task_group_id": task_group_id,
            "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None,
            "run_ids": run_ids,
            "states": states,
        }
        # Remove None values from params
        params = {k: v for k, v in params.items() if v is not None}
        # A negative map_index means "not mapped", so it is not sent as a filter.
        if map_index is not None and map_index >= 0:
            params["map_index"] = map_index
        return params

    def get_count(
        self,
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> TICount:
        """Get count of task instances matching the given criteria."""
        params = self._build_filter_params(
            dag_id=dag_id,
            map_index=map_index,
            task_ids=task_ids,
            task_group_id=task_group_id,
            logical_dates=logical_dates,
            run_ids=run_ids,
            states=states,
        )
        resp = self.client.get("task-instances/count", params=params)
        return TICount(count=resp.json())

    def get_task_states(
        self,
        dag_id: str,
        map_index: int | None = None,
        task_ids: list[str] | None = None,
        task_group_id: str | None = None,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
    ) -> TaskStatesResponse:
        """Get task states given criteria."""
        params = self._build_filter_params(
            dag_id=dag_id,
            map_index=map_index,
            task_ids=task_ids,
            task_group_id=task_group_id,
            logical_dates=logical_dates,
            run_ids=run_ids,
        )
        resp = self.client.get("task-instances/states", params=params)
        return TaskStatesResponse.model_validate_json(resp.read())

    def get_task_breadcrumbs(self, dag_id: str, run_id: str) -> TaskBreadcrumbsResponse:
        """Get the task breadcrumbs for all task instances of the given dag run."""
        params = {"dag_id": dag_id, "run_id": run_id}
        resp = self.client.get("task-instances/breadcrumbs", params=params)
        return TaskBreadcrumbsResponse.model_validate_json(resp.read())

    # Backwards-compatible alias for the historical (misspelled) method name.
    get_task_breakcrumbs = get_task_breadcrumbs

    def validate_inlets_and_outlets(self, id: uuid.UUID) -> InactiveAssetsResponse:
        """Validate whether there're inactive assets in inlets and outlets of a given task instance."""
        resp = self.client.get(f"task-instances/{id}/validate-inlets-and-outlets")
        return InactiveAssetsResponse.model_validate_json(resp.read())

362 

363 

class ConnectionOperations:
    """Execution API operations for reading Airflow Connections."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(self, conn_id: str) -> ConnectionResponse | ErrorResponse:
        """Fetch the connection with the given id, or an ErrorResponse when it does not exist."""
        try:
            resp = self.client.get(f"connections/{conn_id}")
        except ServerResponseError as err:
            if err.response.status_code != HTTPStatus.NOT_FOUND:
                raise
            # A missing connection is an expected condition, reported back as data.
            log.debug(
                "Connection not found",
                conn_id=conn_id,
                detail=err.detail,
                status_code=err.response.status_code,
            )
            return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": conn_id})
        return ConnectionResponse.model_validate_json(resp.read())

385 

386 

class VariableOperations:
    """Execution API operations for reading, writing and deleting Airflow Variables."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(self, key: str) -> VariableResponse | ErrorResponse:
        """Fetch the variable with the given key, or an ErrorResponse when it does not exist."""
        try:
            resp = self.client.get(f"variables/{key}")
        except ServerResponseError as err:
            if err.response.status_code != HTTPStatus.NOT_FOUND:
                raise
            log.error(
                "Variable not found",
                key=key,
                detail=err.detail,
                status_code=err.response.status_code,
            )
            return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"key": key})
        return VariableResponse.model_validate_json(resp.read())

    def set(self, key: str, value: str | None, description: str | None = None) -> OKResponse:
        """Set an Airflow Variable via the API server."""
        payload = VariablePostBody(val=value, description=description)
        self.client.put(f"variables/{key}", content=payload.model_dump_json())
        # Server-side failures propagate to the supervisor via the raised exception, so a
        # generic OK keeps us decoupled from the exact server response body.
        return OKResponse(ok=True)

    def delete(
        self,
        key: str,
    ) -> OKResponse:
        """Delete a variable with given key via the API server."""
        self.client.delete(f"variables/{key}")
        # Server-side failures propagate to the supervisor via the raised exception, so a
        # generic OK keeps us decoupled from the exact server response body.
        return OKResponse(ok=True)

428 

429 

class XComOperations:
    """Execution API operations for reading, writing and deleting XCom values."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def head(self, dag_id: str, run_id: str, task_id: str, key: str) -> XComCountResponse:
        """Get the number of mapped XCom values."""
        resp = self.client.head(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}")

        # Use .get() so a missing header raises our diagnostic RuntimeError below
        # rather than an opaque KeyError from the headers mapping.
        content_range: str | None = resp.headers.get("Content-Range")
        if not content_range or not content_range.startswith("map_indexes "):
            raise RuntimeError(f"Unable to parse Content-Range header from HEAD {resp.request.url}")
        return XComCountResponse(len=int(content_range[len("map_indexes ") :]))

    def get(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        map_index: int | None = None,
        include_prior_dates: bool = False,
    ) -> XComResponse:
        """Get a XCom value from the API server."""
        # TODO: check if we need to use map_index as params in the uri
        # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
        params = {}
        if map_index is not None and map_index >= 0:
            params.update({"map_index": map_index})
        if include_prior_dates:
            params.update({"include_prior_dates": include_prior_dates})
        try:
            resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
        except ServerResponseError as e:
            if e.response.status_code == HTTPStatus.NOT_FOUND:
                log.error(
                    "XCom not found",
                    dag_id=dag_id,
                    run_id=run_id,
                    task_id=task_id,
                    key=key,
                    map_index=map_index,
                    detail=e.detail,
                    status_code=e.response.status_code,
                )
                # Airflow 2.x just ignores the absence of an XCom and moves on with a return value of None
                # Hence returning with key as `key` and value as `None`, so that the message is sent back to task runner
                # and the default value of None in xcom_pull is used.
                return XComResponse(key=key, value=None)
            raise
        return XComResponse.model_validate_json(resp.read())

    def set(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        value,
        map_index: int | None = None,
        mapped_length: int | None = None,
    ) -> OKResponse:
        """Set a XCom value via the API server."""
        # TODO: check if we need to use map_index as params in the uri
        # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
        params = {}
        if map_index is not None and map_index >= 0:
            params = {"map_index": map_index}
        if mapped_length is not None and mapped_length >= 0:
            params["mapped_length"] = mapped_length
        self.client.post(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params, json=value)
        # Any error from the server will anyway be propagated down to the supervisor,
        # so we choose to send a generic response to the supervisor over the server response to
        # decouple from the server response string
        return OKResponse(ok=True)

    def delete(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        map_index: int | None = None,
    ) -> OKResponse:
        """Delete a XCom with given key via the API server."""
        params = {}
        if map_index is not None and map_index >= 0:
            params = {"map_index": map_index}
        self.client.delete(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
        # Any error from the server will anyway be propagated down to the supervisor,
        # so we choose to send a generic response to the supervisor over the server response to
        # decouple from the server response string
        return OKResponse(ok=True)

    def get_sequence_item(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        offset: int,
    ) -> XComSequenceIndexResponse | ErrorResponse:
        """Get a single item of a mapped XCom by sequence offset."""
        try:
            resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}")
        except ServerResponseError as e:
            if e.response.status_code == HTTPStatus.NOT_FOUND:
                log.error(
                    "XCom not found",
                    dag_id=dag_id,
                    run_id=run_id,
                    task_id=task_id,
                    key=key,
                    offset=offset,
                    detail=e.detail,
                    status_code=e.response.status_code,
                )
                return ErrorResponse(
                    error=ErrorType.XCOM_NOT_FOUND,
                    detail={
                        "dag_id": dag_id,
                        "run_id": run_id,
                        "task_id": task_id,
                        "key": key,
                        "offset": offset,
                    },
                )
            raise
        return XComSequenceIndexResponse.model_validate_json(resp.read())

    def get_sequence_slice(
        self,
        dag_id: str,
        run_id: str,
        task_id: str,
        key: str,
        start: int | None,
        stop: int | None,
        step: int | None,
        include_prior_dates: bool = False,
    ) -> XComSequenceSliceResponse:
        """Get a slice of a mapped XCom sequence; None bounds mean "unbounded"."""
        params = {}
        if start is not None:
            params["start"] = start
        if stop is not None:
            params["stop"] = stop
        if step is not None:
            params["step"] = step
        if include_prior_dates:
            params["include_prior_dates"] = include_prior_dates
        resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/slice", params=params)
        return XComSequenceSliceResponse.model_validate_json(resp.read())

584 

585 

class AssetOperations:
    """Execution API operations for looking up Assets by name or URI."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(self, name: str | None = None, uri: str | None = None) -> AssetResponse | ErrorResponse:
        """Get Asset value from the API server."""
        # `name` wins when both identifiers are supplied.
        if name:
            endpoint, params = "assets/by-name", {"name": name}
        elif uri:
            endpoint, params = "assets/by-uri", {"uri": uri}
        else:
            raise ValueError("Either `name` or `uri` must be provided")

        try:
            resp = self.client.get(endpoint, params=params)
        except ServerResponseError as err:
            if err.response.status_code != HTTPStatus.NOT_FOUND:
                raise
            log.error(
                "Asset not found",
                params=params,
                detail=err.detail,
                status_code=err.response.status_code,
            )
            return ErrorResponse(error=ErrorType.ASSET_NOT_FOUND, detail=params)

        return AssetResponse.model_validate_json(resp.read())

617 

618 

class AssetEventOperations:
    """Execution API operations for querying Asset events."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def get(
        self,
        name: str | None = None,
        uri: str | None = None,
        alias_name: str | None = None,
        after: datetime | None = None,
        before: datetime | None = None,
        ascending: bool = True,
        limit: int | None = None,
    ) -> AssetEventsResponse:
        """Get Asset event from the API server."""
        # Filters shared by both endpoints; optional bounds are added only when set.
        shared: dict[str, Any] = {}
        if after:
            shared["after"] = after.isoformat()
        if before:
            shared["before"] = before.isoformat()
        shared["ascending"] = ascending
        if limit:
            shared["limit"] = limit

        if name or uri:
            resp = self.client.get("asset-events/by-asset", params={"name": name, "uri": uri, **shared})
        elif alias_name:
            resp = self.client.get("asset-events/by-asset-alias", params={"name": alias_name, **shared})
        else:
            raise ValueError("Either `name`, `uri` or `alias_name` must be provided")

        return AssetEventsResponse.model_validate_json(resp.read())

656 

657 

class DagRunOperations:
    """Execution API operations for triggering and inspecting Dag runs."""

    __slots__ = ("client",)

    def __init__(self, client: Client):
        self.client = client

    def trigger(
        self,
        dag_id: str,
        run_id: str,
        conf: dict | None = None,
        logical_date: datetime | None = None,
        reset_dag_run: bool = False,
    ) -> OKResponse | ErrorResponse:
        """Trigger a Dag run via the API server."""
        payload = TriggerDAGRunPayload(logical_date=logical_date, conf=conf or {}, reset_dag_run=reset_dag_run)

        try:
            self.client.post(f"dag-runs/{dag_id}/{run_id}", content=payload.model_dump_json(exclude_defaults=True))
        except ServerResponseError as err:
            if err.response.status_code != HTTPStatus.CONFLICT:
                raise
            # A conflict means the run already exists; either reset it or report back.
            if reset_dag_run:
                log.info("Dag Run already exists; Resetting Dag Run.", dag_id=dag_id, run_id=run_id)
                return self.clear(run_id=run_id, dag_id=dag_id)

            log.info("Dag Run already exists!", detail=err.detail, dag_id=dag_id, run_id=run_id)
            return ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS)

        return OKResponse(ok=True)

    def clear(self, dag_id: str, run_id: str) -> OKResponse:
        """Clear a Dag run via the API server."""
        self.client.post(f"dag-runs/{dag_id}/{run_id}/clear")
        # TODO: Error handling
        return OKResponse(ok=True)

    def get_detail(self, dag_id: str, run_id: str) -> DagRun:
        """Get detail of a dag run."""
        resp = self.client.get(f"dag-runs/{dag_id}/{run_id}")
        return DagRun.model_validate_json(resp.read())

    def get_state(self, dag_id: str, run_id: str) -> DagRunStateResponse:
        """Get the state of a Dag run via the API server."""
        resp = self.client.get(f"dag-runs/{dag_id}/{run_id}/state")
        return DagRunStateResponse.model_validate_json(resp.read())

    def get_count(
        self,
        dag_id: str,
        logical_dates: list[datetime] | None = None,
        run_ids: list[str] | None = None,
        states: list[str] | None = None,
    ) -> DRCount:
        """Get count of Dag runs matching the given criteria."""
        raw = {
            "dag_id": dag_id,
            "logical_dates": None if logical_dates is None else [d.isoformat() for d in logical_dates],
            "run_ids": run_ids,
            "states": states,
        }
        # Unset filters are dropped rather than sent as nulls.
        query = {key: val for key, val in raw.items() if val is not None}

        resp = self.client.get("dag-runs/count", params=query)
        return DRCount(count=resp.json())

    def get_previous(
        self,
        dag_id: str,
        logical_date: datetime,
        state: str | None = None,
    ) -> PreviousDagRunResult:
        """Get the previous DAG run before the given logical date, optionally filtered by state."""
        query = {
            "dag_id": dag_id,
            "logical_date": logical_date.isoformat(),
        }
        if state:
            query["state"] = state
        resp = self.client.get("dag-runs/previous", params=query)
        return PreviousDagRunResult(dag_run=resp.json())

743 

744 

class HITLOperations:
    """
    Operations related to Human in the loop. Require Airflow 3.1+.

    :meta: private
    """

    __slots__ = ("client",)

    def __init__(self, client: Client) -> None:
        self.client = client

    def add_response(
        self,
        *,
        ti_id: uuid.UUID,
        options: list[str],
        subject: str,
        body: str | None = None,
        defaults: list[str] | None = None,
        multiple: bool = False,
        params: dict[str, dict[str, Any]] | None = None,
        assigned_users: list[HITLUser] | None = None,
    ) -> HITLDetailRequest:
        """Add a Human-in-the-loop response that waits for human response for a specific Task Instance."""
        payload = CreateHITLDetailPayload(
            ti_id=ti_id,
            options=options,
            subject=subject,
            body=body,
            defaults=defaults,
            multiple=multiple,
            params=params,
            assigned_users=assigned_users,
        )
        response = self.client.post(f"/hitlDetails/{ti_id}", content=payload.model_dump_json())
        return HITLDetailRequest.model_validate_json(response.read())

    def update_response(
        self,
        *,
        ti_id: uuid.UUID,
        chosen_options: list[str],
        params_input: dict[str, Any],
    ) -> HITLDetailResponse:
        """Update an existing Human-in-the-loop response."""
        payload = UpdateHITLDetail(
            ti_id=ti_id,
            chosen_options=chosen_options,
            params_input=params_input,
        )
        response = self.client.patch(f"/hitlDetails/{ti_id}", content=payload.model_dump_json())
        return HITLDetailResponse.model_validate_json(response.read())

    def get_detail_response(self, ti_id: uuid.UUID) -> HITLDetailResponse:
        """Get content part of a Human-in-the-loop response for a specific Task Instance."""
        response = self.client.get(f"/hitlDetails/{ti_id}")
        return HITLDetailResponse.model_validate_json(response.read())

809 

810 

class BearerAuth(httpx.Auth):
    """httpx auth hook that attaches a bearer token to every outgoing request."""

    def __init__(self, token: str):
        self.token: str = token

    def auth_flow(self, request: httpx.Request):
        # Only attach the header when a non-empty token was provided.
        if self.token:
            request.headers["Authorization"] = f"Bearer {self.token}"
        yield request

819 

820 

# This exists as an aid for debugging or local running via the `dry_run` argument to Client. It doesn't make
# sense for returning connections etc.
def noop_handler(request: httpx.Request) -> httpx.Response:
    """Mock-transport handler for dry-run clients: log the call and return a canned response."""
    path = request.url.path
    log.debug("Dry-run request", method=request.method, path=path)

    if path.startswith("/task-instances/") and path.endswith("/run"):
        # The "start running" endpoint must return something shaped like a TIRunContext.
        fake_context = {
            "dag_run": {
                "dag_id": "test_dag",
                "run_id": "test_run",
                "logical_date": "2021-01-01T00:00:00Z",
                "start_date": "2021-01-01T00:00:00Z",
                "run_type": DagRunType.MANUAL,
                "run_after": "2021-01-01T00:00:00Z",
                "consumed_asset_events": [],
            },
            "max_tries": 0,
            "should_retry": False,
        }
        return httpx.Response(200, json=fake_context)
    return httpx.Response(200, json={"text": "Hello, world!"})

846 

847 

# Note: Given defaults make attempts after 1, 3, 7, 15 and fails after 31 seconds
# Retry policy and transport settings for calls to the Execution API, all
# sourced from airflow configuration so deployments can tune them.
API_RETRIES = conf.getint("workers", "execution_api_retries")
API_RETRY_WAIT_MIN = conf.getfloat("workers", "execution_api_retry_wait_min")
API_RETRY_WAIT_MAX = conf.getfloat("workers", "execution_api_retry_wait_max")
# Extra CA certificate (path) to trust in addition to the certifi bundle.
API_SSL_CERT_PATH = conf.get("api", "ssl_cert")
API_TIMEOUT = conf.getfloat("workers", "execution_api_timeout")

854 

855 

def _should_retry_api_request(exception: BaseException) -> bool:
    """Determine if an API request should be retried based on the exception type."""
    # Transport-level failures (connect errors, timeouts, ...) are always retryable.
    if isinstance(exception, httpx.RequestError):
        return True
    # For responses that did arrive, only retry server-side (5xx) errors.
    return isinstance(exception, httpx.HTTPStatusError) and exception.response.status_code >= 500

862 

863 

864class Client(httpx.Client): 

    @lru_cache()
    @staticmethod
    def _get_ssl_context_cached(ca_file: str, ca_path: str | None = None) -> ssl.SSLContext:
        """Cache SSL context to prevent memory growth from repeated context creation."""
        # Base trust on the given CA bundle (certifi), then optionally layer on the
        # deployment-provided certificate path from config.
        ctx = ssl.create_default_context(cafile=ca_file)
        if ca_path:
            ctx.load_verify_locations(ca_path)
        return ctx

873 

874 def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any): 

875 if (not base_url) ^ dry_run: 

876 raise ValueError(f"Can only specify one of {base_url=} or {dry_run=}") 

877 auth = BearerAuth(token) 

878 

879 if dry_run: 

880 # If dry run is requested, install a no op handler so that simple tasks can "heartbeat" using a 

881 # real client, but just don't make any HTTP requests 

882 kwargs.setdefault("transport", httpx.MockTransport(noop_handler)) 

883 kwargs.setdefault("base_url", "dry-run://server") 

884 else: 

885 kwargs["base_url"] = base_url 

886 kwargs["verify"] = self._get_ssl_context_cached(certifi.where(), API_SSL_CERT_PATH) 

887 

888 # Set timeout if not explicitly provided 

889 kwargs.setdefault("timeout", API_TIMEOUT) 

890 

891 pyver = f"{'.'.join(map(str, sys.version_info[:3]))}" 

892 super().__init__( 

893 auth=auth, 

894 headers={ 

895 "user-agent": f"apache-airflow-task-sdk/{__version__} (Python/{pyver})", 

896 "airflow-api-version": API_VERSION, 

897 }, 

898 event_hooks={"response": [self._update_auth, raise_on_4xx_5xx], "request": [add_correlation_id]}, 

899 **kwargs, 

900 ) 

901 

902 def _update_auth(self, response: httpx.Response): 

903 if new_token := response.headers.get("Refreshed-API-Token"): 

904 log.debug("Execution API issued us a refreshed Task token") 

905 self.auth = BearerAuth(new_token) 

906 

907 @retry( 

908 retry=retry_if_exception(_should_retry_api_request), 

909 stop=stop_after_attempt(API_RETRIES), 

910 wait=wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX), 

911 before_sleep=before_log(log, logging.WARNING), 

912 reraise=True, 

913 ) 

914 def request(self, *args, **kwargs): 

915 """Implement a convenience for httpx.Client.request with a retry layer.""" 

916 # Set content type as convenience if not already set 

917 if "content" in kwargs and "headers" not in kwargs: 

918 kwargs["headers"] = {"content-type": "application/json"} 

919 

920 return super().request(*args, **kwargs) 

921 

922 # We "group" or "namespace" operations by what they operate on, rather than a flat namespace with all 

923 # methods on one object prefixed with the object type (`.task_instances.update` rather than 

924 # `task_instance_update` etc.) 

925 

926 @lru_cache() # type: ignore[misc] 

927 @property 

928 def task_instances(self) -> TaskInstanceOperations: 

929 """Operations related to TaskInstances.""" 

930 return TaskInstanceOperations(self) 

931 

932 @lru_cache() # type: ignore[misc] 

933 @property 

934 def dag_runs(self) -> DagRunOperations: 

935 """Operations related to DagRuns.""" 

936 return DagRunOperations(self) 

937 

938 @lru_cache() # type: ignore[misc] 

939 @property 

940 def connections(self) -> ConnectionOperations: 

941 """Operations related to Connections.""" 

942 return ConnectionOperations(self) 

943 

944 @lru_cache() # type: ignore[misc] 

945 @property 

946 def variables(self) -> VariableOperations: 

947 """Operations related to Variables.""" 

948 return VariableOperations(self) 

949 

950 @lru_cache() # type: ignore[misc] 

951 @property 

952 def xcoms(self) -> XComOperations: 

953 """Operations related to XComs.""" 

954 return XComOperations(self) 

955 

956 @lru_cache() # type: ignore[misc] 

957 @property 

958 def assets(self) -> AssetOperations: 

959 """Operations related to Assets.""" 

960 return AssetOperations(self) 

961 

962 @lru_cache() # type: ignore[misc] 

963 @property 

964 def asset_events(self) -> AssetEventOperations: 

965 """Operations related to Asset Events.""" 

966 return AssetEventOperations(self) 

967 

968 @lru_cache() # type: ignore[misc] 

969 @property 

970 def hitl(self): 

971 """Operations related to HITL Responses.""" 

972 return HITLOperations(self) 

973 

974 

# Parsing helper only; callers raise ServerResponseError rather than this model.
class _ErrorBody(BaseModel):
    """Shape of the JSON error payload returned by the API server."""

    detail: list[RemoteValidationError] | str

    def __repr__(self):
        # Delegate straight to the payload's repr for compact log output.
        return f"{self.detail!r}"

981 

982 

class ServerResponseError(httpx.HTTPStatusError):
    """HTTP status error for 4xx/5xx API responses that carries the parsed error body."""

    def __init__(self, message: str, *, request: httpx.Request, response: httpx.Response):
        super().__init__(message, request=request, response=response)

    # Parsed error payload; assigned by ``from_response`` after construction.
    detail: list[RemoteValidationError] | str | dict[str, Any] | None

    def __reduce__(self) -> tuple[Any, ...]:
        # Needed because https://github.com/encode/httpx/pull/3108 isn't merged yet.
        # Rebuild via Exception.__new__ so unpickling bypasses
        # HTTPStatusError.__init__ (which requires keyword-only request/response).
        return Exception.__new__, (type(self),) + self.args, self.__dict__

    @classmethod
    def from_response(cls, response: httpx.Response) -> ServerResponseError | None:
        """Build a ServerResponseError from *response*, or return None when not applicable.

        Returns None — letting the caller fall back to plain httpx error
        handling — when the response is successful, is not a 4xx/5xx, is not
        declared as JSON, or its body cannot be decoded as JSON at all.
        """
        if response.is_success:
            return None
        # 4xx or 5xx error?
        if not (400 <= response.status_code < 600):
            return None

        # NOTE(review): exact equality misses media types with parameters such
        # as "application/json; charset=utf-8" — confirm the server never sends
        # a parameterised content-type, else these errors silently degrade to
        # plain httpx errors.
        if response.headers.get("content-type") != "application/json":
            return None

        detail: list[RemoteValidationError] | dict[str, Any] | None = None
        try:
            body = _ErrorBody.model_validate_json(response.read())

            # A list means the server returned structured validation errors;
            # a string is a free-form error message.
            if isinstance(body.detail, list):
                detail = body.detail
                msg = "Remote server returned validation error"
            else:
                msg = body.detail or "Un-parseable error"
        except Exception:
            # Body didn't match the _ErrorBody schema; try decoding it as
            # arbitrary JSON and keep whatever structure comes back.
            try:
                detail = msgspec.json.decode(response.content)
            except Exception:
                # Fallback to a normal httpx error
                return None
            msg = "Server returned error"

        self = cls(msg, request=response.request, response=response)
        self.detail = detail
        return self