Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/providers/standard/operators/python.py: 24%


#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import inspect
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import textwrap
import types
import warnings
from abc import ABCMeta, abstractmethod
from collections.abc import Callable, Collection, Container, Iterable, Mapping, Sequence
from functools import cache
from itertools import chain
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, NamedTuple, cast

import lazy_object_proxy
from packaging.requirements import InvalidRequirement, Requirement
from packaging.specifiers import InvalidSpecifier
from packaging.version import InvalidVersion

from airflow.exceptions import (
    AirflowConfigException,
    AirflowProviderDeprecationWarning,
    DeserializingResultError,
)
from airflow.models.variable import Variable
from airflow.providers.common.compat.sdk import AirflowException, AirflowSkipException, context_merge
from airflow.providers.standard.hooks.package_index import PackageIndexHook
from airflow.providers.standard.utils.python_virtualenv import (
    _execute_in_subprocess,
    prepare_virtualenv,
    write_python_script,
)
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator
from airflow.utils import hashlib_wrapper
from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import KeywordParameters

if AIRFLOW_V_3_0_PLUS:
    from airflow.providers.standard.operators.branch import BaseBranchOperator
    from airflow.providers.standard.utils.skipmixin import SkipMixin
else:
    from airflow.models.skipmixin import SkipMixin
    from airflow.operators.branch import BaseBranchOperator  # type: ignore[no-redef]


log = logging.getLogger(__name__)

if TYPE_CHECKING:
    from typing import Literal

    from pendulum.datetime import DateTime

    from airflow.providers.common.compat.sdk import Context
    from airflow.sdk.execution_time.callback_runner import ExecutionCallableRunner
    from airflow.sdk.execution_time.context import OutletEventAccessorsProtocol

    _SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"]


@cache
def _parse_version_info(text: str) -> tuple[int, int, int, str, int]:
    """Parse Python version info from text."""
    parts = text.strip().split(".")
    if len(parts) != 5:
        msg = f"Invalid Python version info, expected 5 components separated by '.', but got {text!r}."
        raise ValueError(msg)
    try:
        return int(parts[0]), int(parts[1]), int(parts[2]), parts[3], int(parts[4])
    except ValueError:
        msg = f"Unable to convert parts {parts} parsed from {text!r} to (int, int, int, str, int)."
        raise ValueError(msg) from None
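
# For example (an illustrative value, not from this module): calling
# _parse_version_info("3.11.4.final.0") returns (3, 11, 4, "final", 0),
# the same shape as tuple(sys.version_info).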



class _PythonVersionInfo(NamedTuple):
    """Provide the same interface as ``sys.version_info``."""

    major: int
    minor: int
    micro: int
    releaselevel: str
    serial: int

    @classmethod
    def from_executable(cls, executable: str) -> _PythonVersionInfo:
        """Parse Python version info from an executable."""
        cmd = [executable, "-c", 'import sys; print(".".join(map(str, sys.version_info)))']
        try:
            result = subprocess.check_output(cmd, text=True)
        except Exception as e:
            raise ValueError(f"Error while executing command {cmd}: {e}")
        return cls(*_parse_version_info(result.strip()))


class PythonOperator(BaseOperator):
    """
    Executes a Python callable.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:PythonOperator`

    When running your callable, Airflow will pass a set of keyword arguments that can be used in your
    function. This set of kwargs corresponds exactly to what you can use in your Jinja templates.
    For this to work, you need to define ``**kwargs`` in your function header, or you can directly add
    the keyword arguments you would like to receive - for example, with the code below your callable
    will get the value of the ``ti`` context variable.

    With explicit arguments:

    .. code-block:: python

        def my_python_callable(ti):
            pass

    With kwargs:

    .. code-block:: python

        def my_python_callable(**kwargs):
            ti = kwargs["ti"]


    :param python_callable: A reference to an object that is callable
    :param op_args: a list of positional arguments that will get unpacked when
        calling your callable
    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
        in your function
    :param templates_dict: a dictionary where the values are templates that
        will get templated by the Airflow engine sometime between
        ``__init__`` and ``execute``, and are made available
        in your callable's context after the template has been applied. (templated)
    :param templates_exts: a list of file extensions to resolve while
        processing templated fields, for example ``['.sql', '.hql']``
    :param show_return_value_in_logs: a bool value whether to show return_value
        logs. Defaults to True, which allows return value log output.
        It can be set to False to prevent log output of the return value when you return huge data
        such as transmitting a large amount of XCom data to the task API.
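
    A minimal usage sketch (task id, callable, and values are illustrative):

    .. code-block:: python

        def greet(name, **kwargs):
            print(f"Hello {name}, run_id is {kwargs['run_id']}")


        greet_task = PythonOperator(
            task_id="greet",
            python_callable=greet,
            op_kwargs={"name": "Airflow"},
        )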

    """

    template_fields: Sequence[str] = ("templates_dict", "op_args", "op_kwargs")
    template_fields_renderers = {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}
    BLUE = "#ffefeb"
    ui_color = BLUE

    # Since we won't mutate the arguments, we should just do the shallow copy;
    # there are some cases where we can't deepcopy the objects (e.g. protobuf).
    shallow_copy_attrs: Sequence[str] = ("python_callable", "op_kwargs")

    def __init__(
        self,
        *,
        python_callable: Callable,
        op_args: Collection[Any] | None = None,
        op_kwargs: Mapping[str, Any] | None = None,
        templates_dict: dict[str, Any] | None = None,
        templates_exts: Sequence[str] | None = None,
        show_return_value_in_logs: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        if not callable(python_callable):
            raise AirflowException("`python_callable` param must be callable")
        self.python_callable = python_callable
        self.op_args = op_args or ()
        self.op_kwargs = op_kwargs or {}
        self.templates_dict = templates_dict
        if templates_exts:
            self.template_ext = templates_exts
        self.show_return_value_in_logs = show_return_value_in_logs

    def execute(self, context: Context) -> Any:
        context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
        self.op_kwargs = self.determine_kwargs(context)

        # This needs to be lazy because subclasses may implement execute_callable
        # by running a separate process that can't use the eager result.
        def __prepare_execution() -> tuple[ExecutionCallableRunner, OutletEventAccessorsProtocol] | None:
            if AIRFLOW_V_3_0_PLUS:
                from airflow.sdk.execution_time.callback_runner import create_executable_runner
                from airflow.sdk.execution_time.context import context_get_outlet_events

                return create_executable_runner, context_get_outlet_events(context)
            from airflow.utils.context import context_get_outlet_events  # type: ignore
            from airflow.utils.operator_helpers import ExecutionCallableRunner  # type: ignore

            return ExecutionCallableRunner, context_get_outlet_events(context)

        self.__prepare_execution = __prepare_execution

        return_value = self.execute_callable()
        if self.show_return_value_in_logs:
            self.log.info("Done. Returned value was: %s", return_value)
        else:
            self.log.info("Done. Returned value not shown")

        return return_value

    def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
        return KeywordParameters.determine(self.python_callable, self.op_args, context).unpacking()

    __prepare_execution: Callable[[], tuple[ExecutionCallableRunner, OutletEventAccessorsProtocol] | None]

    def execute_callable(self) -> Any:
        """
        Call the python callable with the given arguments.

        :return: the return value of the call.
        """
        if (execution_preparation := self.__prepare_execution()) is None:
            return self.python_callable(*self.op_args, **self.op_kwargs)
        create_execution_runner, asset_events = execution_preparation
        runner = create_execution_runner(self.python_callable, asset_events, logger=self.log)
        return runner.run(*self.op_args, **self.op_kwargs)


class BranchPythonOperator(BaseBranchOperator, PythonOperator):
    """
    A workflow can "branch" or follow a path after the execution of this task.

    It derives from the PythonOperator and expects a Python function that returns
    a single task_id, a single task_group_id, or a list of task_ids and/or
    task_group_ids to follow. The task_id(s) and/or task_group_id(s) returned
    should point to a task or task group directly downstream from {self}. All
    other "branches" or directly downstream tasks are marked with a state of
    ``skipped`` so that these paths can't move forward. The ``skipped`` states
    are propagated downstream to allow for the DAG state to fill up and
    the DAG run's state to be inferred.
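
    A minimal sketch (task ids and the param name are illustrative):

    .. code-block:: python

        def choose(**kwargs):
            # Return the task_id (or list of task_ids) to follow.
            return "fast_path" if kwargs["params"].get("fast") else "slow_path"


        branch = BranchPythonOperator(task_id="branch", python_callable=choose)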

    """

    def choose_branch(self, context: Context) -> str | Iterable[str]:
        return PythonOperator.execute(self, context)


class ShortCircuitOperator(PythonOperator, SkipMixin):
    """
    Allows a pipeline to continue based on the result of a ``python_callable``.

    The ShortCircuitOperator is derived from the PythonOperator and evaluates the result of a
    ``python_callable``. If the returned result is False or a falsy value, the pipeline will be
    short-circuited. Downstream tasks will be marked with a state of "skipped" based on the short-circuiting
    mode configured. If the returned result is True or a truthy value, downstream tasks proceed as normal and
    an ``XCom`` of the returned result is pushed.

    The short-circuiting can be configured to either respect or ignore the ``trigger_rule`` set for
    downstream tasks. If ``ignore_downstream_trigger_rules`` is set to True, the default setting, all
    downstream tasks are skipped without considering the ``trigger_rule`` defined for tasks. However, if this
    parameter is set to False, the direct downstream tasks are skipped but the specified ``trigger_rule`` for
    other subsequent downstream tasks is respected. In this mode, the operator assumes the direct downstream
    tasks were purposely meant to be skipped but perhaps not other subsequent tasks.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:ShortCircuitOperator`

    :param ignore_downstream_trigger_rules: If set to True, all downstream tasks from this operator task will
        be skipped. This is the default behavior. If set to False, the direct downstream task(s) will be
        skipped but the ``trigger_rule`` defined for all other downstream tasks will be respected.
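
    A minimal sketch (task id and condition are illustrative):

    .. code-block:: python

        def is_weekday():
            import datetime

            return datetime.datetime.now().weekday() < 5


        only_on_weekdays = ShortCircuitOperator(
            task_id="only_on_weekdays",
            python_callable=is_weekday,
        )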

    """

    inherits_from_skipmixin = True

    def __init__(self, *, ignore_downstream_trigger_rules: bool = True, **kwargs) -> None:
        super().__init__(**kwargs)
        self.ignore_downstream_trigger_rules = ignore_downstream_trigger_rules

    def execute(self, context: Context) -> Any:
        condition = super().execute(context)
        self.log.info("Condition result is %s", condition)

        if condition:
            self.log.info("Proceeding with downstream tasks...")
            return condition

        if not self.downstream_task_ids:
            self.log.info("No downstream tasks; nothing to do.")
            return condition

        dag_run = context["dag_run"]

        def get_tasks_to_skip():
            if self.ignore_downstream_trigger_rules is True:
                tasks = context["task"].get_flat_relatives(upstream=False)
            else:
                tasks = context["task"].get_direct_relatives(upstream=False)
            for t in tasks:
                if not t.is_teardown:
                    yield t

        to_skip = get_tasks_to_skip()

        # This lets us avoid building an intermediate list unless debug logging is enabled.
        if self.log.getEffectiveLevel() <= logging.DEBUG:
            self.log.debug("Downstream task IDs %s", to_skip := list(get_tasks_to_skip()))

        self.log.info("Skipping downstream tasks")
        if AIRFLOW_V_3_0_PLUS:
            self.skip(
                ti=context["ti"],
                tasks=to_skip,
            )
        else:
            if to_skip:
                self.skip(
                    dag_run=context["dag_run"],
                    tasks=to_skip,
                    execution_date=cast("DateTime", dag_run.logical_date),  # type: ignore[call-arg]
                    map_index=context["ti"].map_index,
                )

        self.log.info("Done.")
        # Return the result of the super execute method as-is instead of returning None.
        return condition


def _load_pickle():
    import pickle

    return pickle


def _load_dill():
    try:
        import dill
    except ModuleNotFoundError:
        log.error("Unable to import `dill` module. Please make sure that it is installed.")
        raise
    return dill


def _load_cloudpickle():
    try:
        import cloudpickle
    except ModuleNotFoundError:
        log.error(
            "Unable to import `cloudpickle` module. "
            "Please install it with: pip install 'apache-airflow[cloudpickle]'"
        )
        raise
    return cloudpickle


_SERIALIZERS: dict[_SerializerTypeDef, Any] = {
    "pickle": lazy_object_proxy.Proxy(_load_pickle),
    "dill": lazy_object_proxy.Proxy(_load_dill),
    "cloudpickle": lazy_object_proxy.Proxy(_load_cloudpickle),
}
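
# Illustrative note: lazy_object_proxy.Proxy defers each loader until first
# attribute access, so e.g. _SERIALIZERS["dill"].dumps(obj) only imports dill
# at that call site; a missing optional dependency fails late, not at import.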



class _BasePythonVirtualenvOperator(PythonOperator, metaclass=ABCMeta):
    BASE_SERIALIZABLE_CONTEXT_KEYS = {
        "ds",
        "ds_nodash",
        "expanded_ti_count",
        "inlets",
        "outlets",
        "run_id",
        "task_instance_key_str",
        "test_mode",
        "ts",
        "ts_nodash",
        "ts_nodash_with_tz",
        # The following should be removed when Airflow 2 support is dropped.
        "next_ds",
        "next_ds_nodash",
        "prev_ds",
        "prev_ds_nodash",
        "tomorrow_ds",
        "tomorrow_ds_nodash",
        "yesterday_ds",
        "yesterday_ds_nodash",
    }
    if AIRFLOW_V_3_0_PLUS:
        BASE_SERIALIZABLE_CONTEXT_KEYS.add("task_reschedule_count")

    PENDULUM_SERIALIZABLE_CONTEXT_KEYS = {
        "data_interval_end",
        "data_interval_start",
        "logical_date",
        "prev_data_interval_end_success",
        "prev_data_interval_start_success",
        "prev_start_date_success",
        "prev_end_date_success",
        # The following should be removed when Airflow 2 support is dropped.
        "execution_date",
        "next_execution_date",
        "prev_execution_date",
        "prev_execution_date_success",
    }

    AIRFLOW_SERIALIZABLE_CONTEXT_KEYS = {
        "macros",
        "conf",
        "dag",
        "dag_run",
        "task",
        "params",
        "triggering_asset_events",
        # The following should be removed when Airflow 2 support is dropped.
        "triggering_dataset_events",
    }

    def __init__(
        self,
        *,
        python_callable: Callable,
        serializer: _SerializerTypeDef | None = None,
        op_args: Collection[Any] | None = None,
        op_kwargs: Mapping[str, Any] | None = None,
        string_args: Iterable[str] | None = None,
        templates_dict: dict | None = None,
        templates_exts: list[str] | None = None,
        expect_airflow: bool = True,
        skip_on_exit_code: int | Container[int] | None = None,
        env_vars: dict[str, str] | None = None,
        inherit_env: bool = True,
        **kwargs,
    ):
        if (
            not isinstance(python_callable, types.FunctionType)
            or isinstance(python_callable, types.LambdaType)
            and python_callable.__name__ == "<lambda>"
        ):
            raise ValueError(f"{type(self).__name__} only supports functions for python_callable arg")
        if inspect.isgeneratorfunction(python_callable):
            raise ValueError(f"{type(self).__name__} does not support using 'yield' in python_callable")
        super().__init__(
            python_callable=python_callable,
            op_args=op_args,
            op_kwargs=op_kwargs,
            templates_dict=templates_dict,
            templates_exts=templates_exts,
            **kwargs,
        )
        self.string_args = string_args or []

        serializer = serializer or "pickle"
        if serializer not in _SERIALIZERS:
            msg = (
                f"Unsupported serializer {serializer!r}. Expected one of {', '.join(map(repr, _SERIALIZERS))}"
            )
            raise AirflowException(msg)

        self.pickling_library = _SERIALIZERS[serializer]
        self.serializer: _SerializerTypeDef = serializer

        self.expect_airflow = expect_airflow
        self.skip_on_exit_code = (
            skip_on_exit_code
            if isinstance(skip_on_exit_code, Container)
            else [skip_on_exit_code]
            if skip_on_exit_code is not None
            else []
        )
        self.env_vars = env_vars
        self.inherit_env = inherit_env

    @abstractmethod
    def _iter_serializable_context_keys(self):
        pass

    def execute(self, context: Context) -> Any:
        serializable_keys = set(self._iter_serializable_context_keys())
        new = {k: v for k, v in context.items() if k in serializable_keys}
        serializable_context = cast("Context", new)
        return super().execute(context=serializable_context)

    def get_python_source(self):
        """Return the source of self.python_callable."""
        return textwrap.dedent(inspect.getsource(self.python_callable))

    def _write_args(self, file: Path):
        def resolve_proxies(obj):
            """Recursively replace lazy_object_proxy.Proxy instances with their resolved values."""
            if isinstance(obj, lazy_object_proxy.Proxy):
                return obj.__wrapped__  # force evaluation
            if isinstance(obj, dict):
                return {k: resolve_proxies(v) for k, v in obj.items()}
            if isinstance(obj, list):
                return [resolve_proxies(v) for v in obj]
            return obj

        if self.op_args or self.op_kwargs:
            self.log.info("Use %r as serializer.", self.serializer)
            file.write_bytes(
                self.pickling_library.dumps({"args": self.op_args, "kwargs": resolve_proxies(self.op_kwargs)})
            )

    def _write_string_args(self, file: Path):
        file.write_text("\n".join(map(str, self.string_args)))

    def _read_result(self, path: Path):
        if path.stat().st_size == 0:
            return None
        try:
            return self.pickling_library.loads(path.read_bytes())
        except ValueError as value_error:
            raise DeserializingResultError() from value_error

    def __deepcopy__(self, memo):
        # Module objects can't be copied _at all_.
        memo[id(self.pickling_library)] = self.pickling_library
        return super().__deepcopy__(memo)

    def _execute_python_callable_in_subprocess(self, python_path: Path):
        with TemporaryDirectory(prefix="venv-call") as tmp:
            tmp_dir = Path(tmp)
            op_kwargs: dict[str, Any] = dict(self.op_kwargs)
            if self.templates_dict:
                op_kwargs["templates_dict"] = self.templates_dict
            input_path = tmp_dir / "script.in"
            output_path = tmp_dir / "script.out"
            string_args_path = tmp_dir / "string_args.txt"
            script_path = tmp_dir / "script.py"
            termination_log_path = tmp_dir / "termination.log"
            airflow_context_path = tmp_dir / "airflow_context.json"

            self._write_args(input_path)
            self._write_string_args(string_args_path)

            jinja_context = {
                "op_args": self.op_args,
                "op_kwargs": op_kwargs,
                "expect_airflow": self.expect_airflow,
                "pickling_library": self.serializer,
                "python_callable": self.python_callable.__name__,
                "python_callable_source": self.get_python_source(),
            }

            if inspect.getfile(self.python_callable) == self.dag.fileloc:
                jinja_context["modified_dag_module_name"] = get_unique_dag_module_name(self.dag.fileloc)

            write_python_script(
                jinja_context=jinja_context,
                filename=os.fspath(script_path),
                render_template_as_native_obj=self.dag.render_template_as_native_obj,
            )

            env_vars = dict(os.environ) if self.inherit_env else {}
            if fd := os.getenv("__AIRFLOW_SUPERVISOR_FD"):
                env_vars["__AIRFLOW_SUPERVISOR_FD"] = fd
            if self.env_vars:
                env_vars.update(self.env_vars)

            try:
                cmd: list[str] = [
                    os.fspath(python_path),
                    os.fspath(script_path),
                    os.fspath(input_path),
                    os.fspath(output_path),
                    os.fspath(string_args_path),
                    os.fspath(termination_log_path),
                    os.fspath(airflow_context_path),
                ]
                _execute_in_subprocess(
                    cmd=cmd,
                    env=env_vars,
                )
            except subprocess.CalledProcessError as e:
                if e.returncode in self.skip_on_exit_code:
                    raise AirflowSkipException(f"Process exited with code {e.returncode}. Skipping.")
                if termination_log_path.exists() and termination_log_path.stat().st_size > 0:
                    error_msg = f"Process returned non-zero exit status {e.returncode}.\n"
                    with open(termination_log_path) as file:
                        error_msg += file.read()
                    raise AirflowException(error_msg) from None
                raise

            if 0 in self.skip_on_exit_code:
                raise AirflowSkipException("Process exited with code 0. Skipping.")

            return self._read_result(output_path)

    def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
        keyword_params = KeywordParameters.determine(self.python_callable, self.op_args, context)
        if AIRFLOW_V_3_0_PLUS:
            return keyword_params.unpacking()
        return keyword_params.serializing()  # type: ignore[attr-defined]


class PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
    """
    Run a function in a virtualenv that is created and destroyed automatically.

    The function must be defined using def, and not be part of a class (with
    certain caveats). All imports must happen inside the function
    and no variables outside the scope may be referenced. A global scope
    variable named virtualenv_string_args will be available (populated by
    string_args). In addition, one can pass things through op_args and op_kwargs, and one
    can use a return value.
    Note that if your virtualenv runs in a different Python major version than Airflow,
    you cannot use return values, op_args, op_kwargs, or use any macros that are being provided to
    Airflow through plugins. You can use string_args though.
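
    A minimal usage sketch (the requirement, names, and values are illustrative):

    .. code-block:: python

        def render(name: str):
            # All imports must happen inside the function.
            import jinja2

            return jinja2.Template("Hello {{ name }}!").render(name=name)


        render_task = PythonVirtualenvOperator(
            task_id="render_greeting",
            python_callable=render,
            requirements=["jinja2"],
            op_kwargs={"name": "Airflow"},
        )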


    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:PythonVirtualenvOperator`

    :param python_callable: A python function with no references to outside variables,
        defined with def, which will be run in a virtual environment.
    :param requirements: Either a list of requirement strings, or a (templated)
        "requirements file" as specified by pip.
    :param python_version: The Python version to run the virtual environment with. Note that
        both 2 and 2.7 are acceptable forms.
    :param serializer: Which serializer to use to serialize the args and result. It can be one of the following:

        - ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library.
        - ``"cloudpickle"``: Use cloudpickle to serialize more complex types;
          this requires including cloudpickle in your requirements.
        - ``"dill"``: Use dill to serialize more complex types;
          this requires including dill in your requirements.
    :param system_site_packages: Whether to include
        system_site_packages in your virtual environment.
        See virtualenv documentation for more information.
    :param pip_install_options: a list of pip install options when installing requirements.
        See 'pip install -h' for available options.
    :param op_args: A list of positional arguments to pass to python_callable.
    :param op_kwargs: A dict of keyword arguments to pass to python_callable.
    :param string_args: Strings that are present in the global var virtualenv_string_args,
        available to python_callable at runtime as a list[str]. Note that args are split
        by newline.
    :param templates_dict: a dictionary where the values are templates that
        will get templated by the Airflow engine sometime between
        ``__init__`` and ``execute``, and are made available
        in your callable's context after the template has been applied
    :param templates_exts: a list of file extensions to resolve while
        processing templated fields, for example ``['.sql', '.hql']``
    :param expect_airflow: expect Airflow to be installed in the target environment. If true, the operator
        will raise a warning if Airflow is not installed, and it will attempt to load Airflow
        macros when starting.
    :param skip_on_exit_code: If python_callable exits with this exit code, leave the task
        in ``skipped`` state (default: None). If set to ``None``, any non-zero
        exit code will be treated as a failure.
    :param index_urls: an optional list of index urls to load Python packages from.
        If not provided the system pip conf will be used to source packages from.
    :param index_urls_from_connection_ids: An optional list of ``PackageIndex`` connection IDs.
        Will be appended to ``index_urls``.
    :param venv_cache_path: Optional path to the virtual environment parent folder in which the
        virtual environment will be cached; creates a sub-folder ``venv-{hash}`` where the hash will be
        replaced with a checksum of the requirements. If not provided the virtual environment will be
        created and deleted in a temp folder for every execution.
    :param env_vars: A dictionary containing additional environment variables to set for the virtual
        environment when it is executed.
    :param inherit_env: Whether to inherit the current environment variables when executing the virtual
        environment. If set to ``True``, the virtual environment will inherit the environment variables
        of the parent process (``os.environ``). If set to ``False``, the virtual environment will be
        executed with a clean environment.
    """

    template_fields: Sequence[str] = tuple(
        {"requirements", "index_urls", "index_urls_from_connection_ids", "venv_cache_path"}.union(
            PythonOperator.template_fields
        )
    )
    template_ext: Sequence[str] = (".txt",)

    def __init__(
        self,
        *,
        python_callable: Callable,
        requirements: None | Iterable[str] | str = None,
        python_version: str | None = None,
        serializer: _SerializerTypeDef | None = None,
        system_site_packages: bool = True,
        pip_install_options: list[str] | None = None,
        op_args: Collection[Any] | None = None,
        op_kwargs: Mapping[str, Any] | None = None,
        string_args: Iterable[str] | None = None,
        templates_dict: dict | None = None,
        templates_exts: list[str] | None = None,
        expect_airflow: bool = True,
        skip_on_exit_code: int | Container[int] | None = None,
        index_urls: None | Collection[str] | str = None,
        index_urls_from_connection_ids: None | Collection[str] | str = None,
        venv_cache_path: None | os.PathLike[str] = None,
        env_vars: dict[str, str] | None = None,
        inherit_env: bool = True,
        **kwargs,
    ):
        if (
            python_version
            and str(python_version)[0] != str(sys.version_info.major)
            and (op_args or op_kwargs)
        ):
            raise AirflowException(
                "Passing op_args or op_kwargs is not supported across different Python "
                "major versions for PythonVirtualenvOperator. Please use string_args. "
                f"Sys version: {sys.version_info}. Virtual environment version: {python_version}"
            )
        if python_version is not None and not isinstance(python_version, str):
            raise AirflowException(
                "Passing non-string types (e.g. int or float) as python_version not supported"
            )
        if not requirements:
            self.requirements: list[str] = []
        elif isinstance(requirements, str):
            self.requirements = [requirements]
        else:
            self.requirements = list(requirements)
        self.python_version = python_version
        self.system_site_packages = system_site_packages
        self.pip_install_options = pip_install_options
        if isinstance(index_urls, str):
            self.index_urls: list[str] | None = [index_urls]
        elif isinstance(index_urls, Collection):
            self.index_urls = list(index_urls)
        else:
            self.index_urls = None
        if isinstance(index_urls_from_connection_ids, str):
            self.index_urls_from_connection_ids: list[str] | None = [index_urls_from_connection_ids]
        elif isinstance(index_urls_from_connection_ids, Collection):
            self.index_urls_from_connection_ids = list(index_urls_from_connection_ids)
        else:
            self.index_urls_from_connection_ids = None
        self.venv_cache_path = venv_cache_path
        super().__init__(
            python_callable=python_callable,
            serializer=serializer,
            op_args=op_args,
            op_kwargs=op_kwargs,
            string_args=string_args,
            templates_dict=templates_dict,
            templates_exts=templates_exts,
            expect_airflow=expect_airflow,
            skip_on_exit_code=skip_on_exit_code,
            env_vars=env_vars,
            inherit_env=inherit_env,
            **kwargs,
        )

    def _requirements_list(self, exclude_cloudpickle: bool = False) -> list[str]:
        """Prepare a list of requirements that need to be installed for the virtual environment."""
        requirements = [str(dependency) for dependency in self.requirements]
        if not self.system_site_packages:
            if (
                self.serializer == "cloudpickle"
                and not exclude_cloudpickle
                and "cloudpickle" not in requirements
            ):
                requirements.append("cloudpickle")
            elif self.serializer == "dill" and "dill" not in requirements:
                requirements.append("dill")
        requirements.sort()  # Ensure the hash is stable.
        return requirements

    def _prepare_venv(self, venv_path: Path) -> None:
        """Prepare the requirements and install the virtual environment."""
        requirements_file = venv_path / "requirements.txt"
        requirements_file.write_text("\n".join(self._requirements_list()))
        prepare_virtualenv(
            venv_directory=str(venv_path),
            python_bin=f"python{self.python_version}" if self.python_version else "python",
            system_site_packages=self.system_site_packages,
            requirements_file_path=str(requirements_file),
            pip_install_options=self.pip_install_options,
            index_urls=self.index_urls,
        )

    def _calculate_cache_hash(self, exclude_cloudpickle: bool = False) -> tuple[str, str]:
        """
        Generate the hash of the cache folder to use.

        The following factors are used as input for the hash:

        - (sorted) list of requirements
        - pip install options
        - flag of system site packages
        - python version
        - Variable to override the hash with a cache key
        - Index URLs

        Returns a hash and the data dict which is the base for the hash as text.
        """
        hash_dict = {
            "requirements_list": self._requirements_list(exclude_cloudpickle=exclude_cloudpickle),
            "pip_install_options": self.pip_install_options,
            "index_urls": self.index_urls,
            "cache_key": str(Variable.get("PythonVirtualenvOperator.cache_key", "")),
            "python_version": self.python_version,
            "system_site_packages": self.system_site_packages,
        }
        hash_text = json.dumps(hash_dict, sort_keys=True)
        hash_object = hashlib_wrapper.md5(hash_text.encode())
        requirements_hash = hash_object.hexdigest()
        return requirements_hash[:8], hash_text
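
    # Illustrative note (hash value made up): for a given hash_text, the cache
    # folder used by _ensure_venv_cache_exists below is named
    # f"venv-{md5(hash_text).hexdigest()[:8]}", e.g. "venv-1a2b3c4d", so any
    # change to the hash inputs selects a different cached environment.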


    def _ensure_venv_cache_exists(self, venv_cache_path: Path) -> Path:
        """Ensure a valid virtual environment is set up, creating it in place if needed."""
        cache_hash, hash_data = self._calculate_cache_hash()
        venv_path = venv_cache_path / f"venv-{cache_hash}"
        self.log.info("Python virtual environment will be cached in %s", venv_path)
        venv_path.parent.mkdir(parents=True, exist_ok=True)
        with open(f"{venv_path}.lock", "w") as f:
            # Ensure that the cache is not built by parallel workers.
            import fcntl

            fcntl.flock(f, fcntl.LOCK_EX)

            hash_marker = venv_path / "install_complete_marker.json"
            try:
                if venv_path.exists():
                    if hash_marker.exists():
                        previous_hash_data = hash_marker.read_text(encoding="utf8")
                        if previous_hash_data == hash_data:
                            self.log.info("Reusing cached Python virtual environment in %s", venv_path)
                            return venv_path

                        _, hash_data_before_upgrade = self._calculate_cache_hash(exclude_cloudpickle=True)
                        if previous_hash_data == hash_data_before_upgrade:
                            self.log.warning(
                                "Found a previous virtual environment with outdated dependencies in %s, "
                                "deleting and re-creating.",
                                venv_path,
                            )
                        else:
                            self.log.error(
                                "Unicorn alert: Found a previous virtual environment in %s "
                                "with the same hash but different parameters. Previous setup: '%s' / "
                                "Requested venv setup: '%s'. Please report a bug to airflow!",
                                venv_path,
                                previous_hash_data,
                                hash_data,
                            )
                    else:
                        self.log.warning(
                            "Found a previous (probably partially installed) virtual environment in %s, "
                            "deleting and re-creating.",
                            venv_path,
                        )

                    shutil.rmtree(venv_path)

                venv_path.mkdir(parents=True)
                self._prepare_venv(venv_path)
                hash_marker.write_text(hash_data, encoding="utf8")
            except Exception as e:
                shutil.rmtree(venv_path)
                raise AirflowException(f"Unable to create new virtual environment in {venv_path}") from e
            self.log.info("New Python virtual environment created in %s", venv_path)
            return venv_path

    def _cleanup_python_pycache_dir(self, cache_dir_path: Path) -> None:
        try:
            shutil.rmtree(cache_dir_path)
            self.log.debug("The directory %s has been deleted.", cache_dir_path)
        except FileNotFoundError:
            self.log.warning("Failed to delete %s. The directory does not exist.", cache_dir_path)
        except PermissionError:
            self.log.warning("Permission denied to delete the directory %s.", cache_dir_path)

    def _retrieve_index_urls_from_connection_ids(self):
        """Retrieve index URLs from Package Index connections."""
        if self.index_urls is None:
            self.index_urls = []
        for conn_id in self.index_urls_from_connection_ids:
            conn_url = PackageIndexHook(conn_id).get_connection_url()
            self.index_urls.append(conn_url)

    def execute_callable(self):
        if self.index_urls_from_connection_ids:
            self._retrieve_index_urls_from_connection_ids()

        if self.venv_cache_path:
            venv_path = self._ensure_venv_cache_exists(Path(self.venv_cache_path))
            python_path = venv_path / "bin" / "python"
            return self._execute_python_callable_in_subprocess(python_path)

        with TemporaryDirectory(prefix="venv") as tmp_dir:
            tmp_path = Path(tmp_dir)
            custom_pycache_prefix = Path(sys.pycache_prefix or "")
            r_path = tmp_path.relative_to(tmp_path.anchor)
            venv_python_cache_dir = Path.cwd() / custom_pycache_prefix / r_path
            self._prepare_venv(tmp_path)
            python_path = tmp_path / "bin" / "python"
            result = self._execute_python_callable_in_subprocess(python_path)
            self._cleanup_python_pycache_dir(venv_python_cache_dir)
            return result

    def _iter_serializable_context_keys(self):
        yield from self.BASE_SERIALIZABLE_CONTEXT_KEYS

        found_airflow = found_pendulum = False

        if self.system_site_packages:
            # If we're using system packages, assume both are present.
            found_airflow = found_pendulum = True
        else:
            for raw_str in chain.from_iterable(req.splitlines() for req in self.requirements):
                line = raw_str.strip()
                # Skip blank lines and full-line comments.
                if not line or line.startswith("#"):
                    continue

                # Strip off any inline comment,
                # e.g. turn "foo==1.2.3  # comment" into "foo==1.2.3".
                req_str = re.sub(r"#.*$", "", line).strip()

                try:
                    req = Requirement(req_str)
                except (InvalidRequirement, InvalidSpecifier, InvalidVersion) as e:
                    raise ValueError(f"Invalid requirement '{raw_str}': {e}") from e

                if req.name == "apache-airflow":
                    found_airflow = found_pendulum = True
                    break
                elif req.name == "pendulum":
                    found_pendulum = True

        if found_airflow:
            yield from self.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS
            yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS
        elif found_pendulum:
            yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS


class BranchPythonVirtualenvOperator(BaseBranchOperator, PythonVirtualenvOperator):
    """
    A workflow can "branch" or follow a path after the execution of this task in a virtual environment.

    It derives from the PythonVirtualenvOperator and expects a Python function that returns
    a single task_id, a single task_group_id, or a list of task_ids and/or
    task_group_ids to follow. The task_id(s) and/or task_group_id(s) returned
    should point to a task or task group directly downstream from {self}. All
    other "branches" or directly downstream tasks are marked with a state of
    ``skipped`` so that these paths can't move forward. The ``skipped`` states
    are propagated downstream to allow for the DAG state to fill up and
    the DAG run's state to be inferred.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:BranchPythonVirtualenvOperator`
    """

    def choose_branch(self, context: Context) -> str | Iterable[str]:
        return PythonVirtualenvOperator.execute(self, context)


class ExternalPythonOperator(_BasePythonVirtualenvOperator):
    """
    Run a function in a virtualenv that is not re-created.

    The environment is reused as-is, without the overhead of creating the virtual environment
    (with certain caveats).

    The function must be defined using def, and not be
    part of a class. All imports must happen inside the function
    and no variables outside the scope may be referenced. A global scope
    variable named virtualenv_string_args will be available (populated by
    string_args). In addition, one can pass things through op_args and op_kwargs, and one
    can use a return value.
    Note that if your virtual environment runs in a different Python major version than Airflow,
    you cannot use return values, op_args, op_kwargs, or use any macros that are being provided to
    Airflow through plugins. You can use string_args though.

    If Airflow is installed in the external environment in a different version than the version
    used by the operator, the operator will fail.
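
    A minimal usage sketch (the interpreter path and callable are illustrative):

    .. code-block:: python

        def get_pandas_version():
            # pandas must already be installed in the target environment.
            import pandas

            return pandas.__version__


        external_task = ExternalPythonOperator(
            task_id="run_in_existing_venv",
            python="/opt/venvs/data/bin/python",
            python_callable=get_pandas_version,
        )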


    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:ExternalPythonOperator`

    :param python: Full path string (file-system specific) that points to a Python binary inside
        a virtual environment that should be used (in the ``VENV/bin`` folder). Should be an absolute
        path (so it usually starts with "/" or "X:/" depending on the filesystem/os used).
    :param python_callable: A python function with no references to outside variables,
        defined with def, which will be run in a virtual environment.
    :param serializer: Which serializer to use to serialize the args and result. It can be one of the following:

        - ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library.
        - ``"cloudpickle"``: Use cloudpickle to serialize more complex types;
          this requires including cloudpickle in your requirements.
        - ``"dill"``: Use dill to serialize more complex types;
          this requires including dill in your requirements.
    :param op_args: A list of positional arguments to pass to python_callable.
    :param op_kwargs: A dict of keyword arguments to pass to python_callable.
    :param string_args: Strings that are present in the global var virtualenv_string_args,
        available to python_callable at runtime as a list[str]. Note that args are split
        by newline.
    :param templates_dict: a dictionary where the values are templates that
        will get templated by the Airflow engine sometime between
        ``__init__`` and ``execute``, and are made available
        in your callable's context after the template has been applied
    :param templates_exts: a list of file extensions to resolve while
        processing templated fields, for example ``['.sql', '.hql']``
    :param expect_airflow: expect Airflow to be installed in the target environment. If true, the operator
        will raise a warning if Airflow is not installed, and it will attempt to load Airflow
        macros when starting.
    :param skip_on_exit_code: If python_callable exits with this exit code, leave the task
        in ``skipped`` state (default: None). If set to ``None``, any non-zero
        exit code will be treated as a failure.
    :param env_vars: A dictionary containing additional environment variables to set for the virtual
        environment when it is executed.
    :param inherit_env: Whether to inherit the current environment variables when executing the virtual
        environment. If set to ``True``, the virtual environment will inherit the environment variables
        of the parent process (``os.environ``). If set to ``False``, the virtual environment will be
        executed with a clean environment.
    """

    template_fields: Sequence[str] = tuple({"python"}.union(PythonOperator.template_fields))

    def __init__(
        self,
        *,
        python: str,
        python_callable: Callable,
        serializer: _SerializerTypeDef | None = None,
        op_args: Collection[Any] | None = None,
        op_kwargs: Mapping[str, Any] | None = None,
        string_args: Iterable[str] | None = None,
        templates_dict: dict | None = None,
        templates_exts: list[str] | None = None,
        expect_airflow: bool = True,
        expect_pendulum: bool = False,
        skip_on_exit_code: int | Container[int] | None = None,
        env_vars: dict[str, str] | None = None,
        inherit_env: bool = True,
        **kwargs,
    ):
        if not python:
            raise ValueError("Python Path must be defined in ExternalPythonOperator")
        self.python = python
        self.expect_pendulum = expect_pendulum
        super().__init__(
            python_callable=python_callable,
            serializer=serializer,
            op_args=op_args,
            op_kwargs=op_kwargs,
            string_args=string_args,
            templates_dict=templates_dict,
            templates_exts=templates_exts,
            expect_airflow=expect_airflow,
            skip_on_exit_code=skip_on_exit_code,
            env_vars=env_vars,
            inherit_env=inherit_env,
            **kwargs,
        )

    def execute_callable(self):
        python_path = Path(self.python)
        if not python_path.exists():
            raise ValueError(f"Python Path '{python_path}' must exist")
        if not python_path.is_file():
            raise ValueError(f"Python Path '{python_path}' must be a file")
        if not python_path.is_absolute():
            raise ValueError(f"Python Path '{python_path}' must be an absolute path.")
        python_version = _PythonVersionInfo.from_executable(self.python)
        if python_version.major != sys.version_info.major and (self.op_args or self.op_kwargs):
            raise AirflowException(
                "Passing op_args or op_kwargs is not supported across different Python "
                "major versions for ExternalPythonOperator. Please use string_args. "
                f"Sys version: {sys.version_info}. "
                f"Virtual environment version: {python_version}"
            )
        return self._execute_python_callable_in_subprocess(python_path)

    def _iter_serializable_context_keys(self):
        yield from self.BASE_SERIALIZABLE_CONTEXT_KEYS
        if self.expect_airflow and self._get_airflow_version_from_target_env():
            yield from self.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS
            yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS
        elif self._is_pendulum_installed_in_target_env():
            yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS

    def _is_pendulum_installed_in_target_env(self) -> bool:
        try:
            subprocess.check_call([self.python, "-c", "import pendulum"])
            return True
        except Exception as e:
            if self.expect_pendulum:
                self.log.warning("When checking for Pendulum installed in virtual environment got %s", e)
                self.log.warning(
                    "Pendulum is not properly installed in the virtual environment. "
                    "Pendulum context keys will not be available. "
                    "Please install Pendulum or Airflow in your virtual environment to access them."
                )
            return False

    @property
    def _external_airflow_version_script(self):
        """
        Return a Python script which determines the version of Apache Airflow.

        Importing airflow as a module might take a while; as a result,
        obtaining the version that way can take up to 1 second.
        On the other hand, `importlib.metadata.version` retrieves the package version pretty fast,
        somewhere below 100ms; this includes the overhead of the new subprocess.

        Possible side effect: it might happen that `importlib.metadata` is not available (Python < 3.8)
        and the backport `importlib_metadata` is missing as well, which might indicate that the venv
        doesn't contain `apache-airflow` or that something is wrong with the environment.
        """
        return textwrap.dedent(
            """
            try:
                from importlib.metadata import version
            except ImportError:
                from importlib_metadata import version
            print(version("apache-airflow"))
            """
        )

    def _get_airflow_version_from_target_env(self) -> str | None:
        from airflow import __version__ as airflow_version

        try:
            result = subprocess.check_output(
                [self.python, "-c", self._external_airflow_version_script],
                text=True,
            )
            target_airflow_version = result.strip()
            if target_airflow_version != airflow_version:
                raise AirflowConfigException(
                    f"The version of Airflow installed for the {self.python} "
                    f"({target_airflow_version}) is different than the runtime Airflow version: "
                    f"{airflow_version}. Make sure your environment has the same Airflow version "
                    f"installed as the Airflow runtime."
                )
            return target_airflow_version
        except Exception as e:
            if self.expect_airflow:
                self.log.warning("When checking for Airflow installed in virtual environment got %s", e)
                self.log.warning(
                    "This means that Airflow is not properly installed by %s. "
                    "Airflow context keys will not be available. "
                    "Please install Airflow %s in your environment to access them.",
                    self.python,
                    airflow_version,
                )
            return None


class BranchExternalPythonOperator(BaseBranchOperator, ExternalPythonOperator):
    """
    A workflow can "branch" or follow a path after the execution of this task.

    Extends ExternalPythonOperator, so it expects a path to a Python interpreter in a
    virtual environment that should be used (in the ``VENV/bin`` folder). It should be an
    absolute path, so it can run in a separate virtual environment similarly to
    ExternalPythonOperator.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:BranchExternalPythonOperator`
    """

    def choose_branch(self, context: Context) -> str | Iterable[str]:
        return ExternalPythonOperator.execute(self, context)


def get_current_context() -> Mapping[str, Any]:
    """
    Retrieve the execution context dictionary without altering the user method's signature.

    This is the simplest method of retrieving the execution context dictionary.

    **Old style:**

    .. code:: python

        def my_task(**context):
            ti = context["ti"]

    **New style:**

    .. code:: python

        from airflow.providers.standard.operators.python import get_current_context


        def my_task():
            context = get_current_context()
            ti = context["ti"]

    The current context will only have a value if this method is called after an operator
    has started to execute.
    """
    if AIRFLOW_V_3_0_PLUS:
        warnings.warn(
            "Using get_current_context from the standard provider is deprecated and will be removed. "
            "Please import `from airflow.sdk import get_current_context` and use it instead.",
            AirflowProviderDeprecationWarning,
            stacklevel=2,
        )

        from airflow.sdk import get_current_context

        return get_current_context()
    return _get_current_context()


def _get_current_context() -> Mapping[str, Any]:
    # Airflow 2.x
    # TODO: To be removed when Airflow 2 support is dropped
    from airflow.models.taskinstance import _CURRENT_CONTEXT  # type: ignore[attr-defined]

    if not _CURRENT_CONTEXT:
        raise RuntimeError(
            "Current context was requested but no context was found! Are you running within an Airflow task?"
        )
    return _CURRENT_CONTEXT[-1]