1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import inspect
21import json
22import logging
23import os
24import re
25import shutil
26import subprocess
27import sys
28import textwrap
29import types
30import warnings
31from abc import ABCMeta, abstractmethod
32from collections.abc import Callable, Collection, Container, Iterable, Mapping, Sequence
33from functools import cache
34from itertools import chain
35from pathlib import Path
36from tempfile import TemporaryDirectory
37from typing import TYPE_CHECKING, Any, NamedTuple, cast
38
39import lazy_object_proxy
40from packaging.requirements import InvalidRequirement, Requirement
41from packaging.specifiers import InvalidSpecifier
42from packaging.version import InvalidVersion
43
44from airflow.exceptions import (
45 AirflowConfigException,
46 AirflowProviderDeprecationWarning,
47 DeserializingResultError,
48)
49from airflow.models.variable import Variable
50from airflow.providers.common.compat.sdk import AirflowException, AirflowSkipException, context_merge
51from airflow.providers.common.compat.standard.operators import (
52 BaseAsyncOperator,
53 is_async_callable,
54)
55from airflow.providers.standard.hooks.package_index import PackageIndexHook
56from airflow.providers.standard.utils.python_virtualenv import (
57 _execute_in_subprocess,
58 prepare_virtualenv,
59 write_python_script,
60)
61from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS
62from airflow.utils import hashlib_wrapper
63from airflow.utils.file import get_unique_dag_module_name
64from airflow.utils.operator_helpers import KeywordParameters
65
66if AIRFLOW_V_3_0_PLUS:
67 from airflow.providers.standard.operators.branch import BaseBranchOperator
68 from airflow.providers.standard.utils.skipmixin import SkipMixin
69else:
70 from airflow.models.skipmixin import SkipMixin
71 from airflow.operators.branch import BaseBranchOperator # type: ignore[no-redef]
72
73
74log = logging.getLogger(__name__)
75
76if TYPE_CHECKING:
77 from typing import Literal
78
79 from pendulum.datetime import DateTime
80
81 from airflow.providers.common.compat.sdk import Context
82 from airflow.sdk.execution_time.callback_runner import (
83 AsyncExecutionCallableRunner,
84 ExecutionCallableRunner,
85 )
86 from airflow.sdk.execution_time.context import OutletEventAccessorsProtocol
87
88 _SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"]
89
90
91@cache
92def _parse_version_info(text: str) -> tuple[int, int, int, str, int]:
    """Parse Python version info from a text string."""
94 parts = text.strip().split(".")
95 if len(parts) != 5:
96 msg = f"Invalid Python version info, expected 5 components separated by '.', but got {text!r}."
97 raise ValueError(msg)
98 try:
99 return int(parts[0]), int(parts[1]), int(parts[2]), parts[3], int(parts[4])
100 except ValueError:
101 msg = f"Unable to convert parts {parts} parsed from {text!r} to (int, int, int, str, int)."
102 raise ValueError(msg) from None
103
104
105class _PythonVersionInfo(NamedTuple):
106 """Provide the same interface as ``sys.version_info``."""
107
108 major: int
109 minor: int
110 micro: int
111 releaselevel: str
112 serial: int
113
114 @classmethod
115 def from_executable(cls, executable: str) -> _PythonVersionInfo:
116 """Parse python version info from an executable."""
117 cmd = [executable, "-c", 'import sys; print(".".join(map(str, sys.version_info)))']
118 try:
119 result = subprocess.check_output(cmd, text=True)
120 except Exception as e:
121 raise ValueError(f"Error while executing command {cmd}: {e}")
122 return cls(*_parse_version_info(result.strip()))
123
124
125class PythonOperator(BaseAsyncOperator):
126 """
127 Base class for all Python operators.
128
129 .. seealso::
130 For more information on how to use this operator, take a look at the guide:
131 :ref:`howto/operator:PythonOperator`
132
    When running your callable, Airflow will pass a set of keyword arguments that can be used in your
    function. This set of kwargs corresponds exactly to what you can use in your Jinja templates.
    For this to work, you need to define ``**kwargs`` in your function header, or you can directly add the
    keyword arguments you would like to receive - for example, with the code below your callable will get
    the value of the ``ti`` context variable.
138
139 With explicit arguments:
140
141 .. code-block:: python
142
143 def my_python_callable(ti):
144 pass
145
146 With kwargs:
147
148 .. code-block:: python
149
150 def my_python_callable(**kwargs):
151 ti = kwargs["ti"]
152
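    A minimal instantiation sketch (the task id and ``op_kwargs`` values are illustrative):

    .. code-block:: python

        print_task = PythonOperator(
            task_id="print_value",
            python_callable=my_python_callable,
            op_kwargs={"value": 42},
        )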
153
154 :param python_callable: A reference to an object that is callable
155 :param op_args: a list of positional arguments that will get unpacked when
156 calling your callable
157 :param op_kwargs: a dictionary of keyword arguments that will get unpacked
158 in your function
    :param templates_dict: a dictionary where the values are templates that
        will get templated by the Airflow engine sometime between
        ``__init__`` and ``execute`` and are made available
        in your callable's context after the template has been applied. (templated)
    :param templates_exts: a list of file extensions to resolve while
        processing templated fields, for example ``['.sql', '.hql']``
    :param show_return_value_in_logs: a bool value whether to show return_value
        logs. Defaults to True, which allows return value log output.
        It can be set to False to prevent log output of the return value when you return huge data,
        such as when transmitting a large amount of XCom data to the Task API.
169 """
170
171 template_fields: Sequence[str] = ("templates_dict", "op_args", "op_kwargs")
172 template_fields_renderers = {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}
173 BLUE = "#ffefeb"
174 ui_color = BLUE
175
    # Since we won't mutate the arguments, we should just do the shallow copy;
    # there are some cases where we can't deepcopy the objects (e.g. protobuf).
178 shallow_copy_attrs: Sequence[str] = ("python_callable", "op_kwargs")
179
180 def __init__(
181 self,
182 *,
183 python_callable: Callable,
184 op_args: Collection[Any] | None = None,
185 op_kwargs: Mapping[str, Any] | None = None,
186 templates_dict: dict[str, Any] | None = None,
187 templates_exts: Sequence[str] | None = None,
188 show_return_value_in_logs: bool = True,
189 **kwargs,
190 ) -> None:
191 super().__init__(**kwargs)
192 if not callable(python_callable):
193 raise AirflowException("`python_callable` param must be callable")
194 self.python_callable = python_callable
195 self.op_args = op_args or ()
196 self.op_kwargs = op_kwargs or {}
197 self.templates_dict = templates_dict
198 if templates_exts:
199 self.template_ext = templates_exts
200 self.show_return_value_in_logs = show_return_value_in_logs
201
202 @property
203 def is_async(self) -> bool:
204 return is_async_callable(self.python_callable)
205
206 def execute(self, context) -> Any:
207 if self.is_async:
208 return BaseAsyncOperator.execute(self, context)
209
210 context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
211 self.op_kwargs = self.determine_kwargs(context)
212
213 # This needs to be lazy because subclasses may implement execute_callable
214 # by running a separate process that can't use the eager result.
215 def __prepare_execution() -> tuple[ExecutionCallableRunner, OutletEventAccessorsProtocol] | None:
216 if AIRFLOW_V_3_0_PLUS:
217 from airflow.sdk.execution_time.callback_runner import create_executable_runner
218 from airflow.sdk.execution_time.context import context_get_outlet_events
219
220 return create_executable_runner, context_get_outlet_events(context)
221 from airflow.utils.context import context_get_outlet_events # type: ignore
222 from airflow.utils.operator_helpers import ExecutionCallableRunner # type: ignore
223
224 return ExecutionCallableRunner, context_get_outlet_events(context)
225
226 self.__prepare_execution = __prepare_execution
227
228 return_value = self.execute_callable()
229 if self.show_return_value_in_logs:
230 self.log.info("Done. Returned value was: %s", return_value)
231 else:
232 self.log.info("Done. Returned value not shown")
233
234 return return_value
235
236 def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
237 return KeywordParameters.determine(self.python_callable, self.op_args, context).unpacking()
238
239 __prepare_execution: Callable[[], tuple[ExecutionCallableRunner, OutletEventAccessorsProtocol] | None]
240
241 def execute_callable(self) -> Any:
242 """
243 Call the python callable with the given arguments.
244
245 :return: the return value of the call.
246 """
247 if (execution_preparation := self.__prepare_execution()) is None:
248 return self.python_callable(*self.op_args, **self.op_kwargs)
249 create_execution_runner, asset_events = execution_preparation
250 runner = create_execution_runner(self.python_callable, asset_events, logger=self.log)
251 return runner.run(*self.op_args, **self.op_kwargs)
252
253 if AIRFLOW_V_3_2_PLUS:
254
255 async def aexecute(self, context):
256 context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
257 self.op_kwargs = self.determine_kwargs(context)
258
259 # This needs to be lazy because subclasses may implement execute_callable
260 # by running a separate process that can't use the eager result.
261 def __prepare_execution() -> (
262 tuple[AsyncExecutionCallableRunner, OutletEventAccessorsProtocol] | None
263 ):
264 from airflow.sdk.execution_time.callback_runner import create_async_executable_runner
265 from airflow.sdk.execution_time.context import context_get_outlet_events
266
267 return (
268 cast("AsyncExecutionCallableRunner", create_async_executable_runner),
269 context_get_outlet_events(context),
270 )
271
272 self.__prepare_execution = __prepare_execution
273
274 return_value = await self.aexecute_callable()
275 if self.show_return_value_in_logs:
276 self.log.info("Done. Returned value was: %s", return_value)
277 else:
278 self.log.info("Done. Returned value not shown")
279
280 return return_value
281
282 async def aexecute_callable(self) -> Any:
283 """
284 Call the python callable with the given arguments.
285
286 :return: the return value of the call.
287 """
288 if (execution_preparation := self.__prepare_execution()) is None:
289 return await self.python_callable(*self.op_args, **self.op_kwargs)
290 create_execution_runner, asset_events = execution_preparation
291 runner = create_execution_runner(self.python_callable, asset_events, logger=self.log)
292 return await runner.run(*self.op_args, **self.op_kwargs)
293
294
295class BranchPythonOperator(BaseBranchOperator, PythonOperator):
296 """
297 A workflow can "branch" or follow a path after the execution of this task.
298
299 It derives the PythonOperator and expects a Python function that returns
300 a single task_id, a single task_group_id, or a list of task_ids and/or
301 task_group_ids to follow. The task_id(s) and/or task_group_id(s) returned
302 should point to a task or task group directly downstream from {self}. All
303 other "branches" or directly downstream tasks are marked with a state of
304 ``skipped`` so that these paths can't move forward. The ``skipped`` states
305 are propagated downstream to allow for the DAG state to fill up and
306 the DAG run's state to be inferred.
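
    A minimal sketch of a branching callable (the task ids are illustrative):

    .. code-block:: python

        def pick_branch(logical_date, **kwargs):
            return "weekday_path" if logical_date.weekday() < 5 else "weekend_path"


        branching = BranchPythonOperator(task_id="branching", python_callable=pick_branch)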
307 """
308
309 def choose_branch(self, context: Context) -> str | Iterable[str]:
310 return PythonOperator.execute(self, context)
311
312
313class ShortCircuitOperator(PythonOperator, SkipMixin):
314 """
315 Allows a pipeline to continue based on the result of a ``python_callable``.
316
317 The ShortCircuitOperator is derived from the PythonOperator and evaluates the result of a
318 ``python_callable``. If the returned result is False or a falsy value, the pipeline will be
319 short-circuited. Downstream tasks will be marked with a state of "skipped" based on the short-circuiting
320 mode configured. If the returned result is True or a truthy value, downstream tasks proceed as normal and
321 an ``XCom`` of the returned result is pushed.
322
323 The short-circuiting can be configured to either respect or ignore the ``trigger_rule`` set for
324 downstream tasks. If ``ignore_downstream_trigger_rules`` is set to True, the default setting, all
325 downstream tasks are skipped without considering the ``trigger_rule`` defined for tasks. However, if this
326 parameter is set to False, the direct downstream tasks are skipped but the specified ``trigger_rule`` for
327 other subsequent downstream tasks are respected. In this mode, the operator assumes the direct downstream
328 tasks were purposely meant to be skipped but perhaps not other subsequent tasks.
329
330 .. seealso::
331 For more information on how to use this operator, take a look at the guide:
332 :ref:`howto/operator:ShortCircuitOperator`
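
    A minimal usage sketch (the task id, callable, and param name are illustrative):

    .. code-block:: python

        check = ShortCircuitOperator(
            task_id="check_condition",
            python_callable=lambda params: params.get("run_downstream", False),
            ignore_downstream_trigger_rules=False,
        )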
333
334 :param ignore_downstream_trigger_rules: If set to True, all downstream tasks from this operator task will
335 be skipped. This is the default behavior. If set to False, the direct, downstream task(s) will be
336 skipped but the ``trigger_rule`` defined for all other downstream tasks will be respected.
337 """
338
339 inherits_from_skipmixin = True
340
341 def __init__(self, *, ignore_downstream_trigger_rules: bool = True, **kwargs) -> None:
342 super().__init__(**kwargs)
343 self.ignore_downstream_trigger_rules = ignore_downstream_trigger_rules
344
345 def execute(self, context: Context) -> Any:
346 condition = super().execute(context)
347 self.log.info("Condition result is %s", condition)
348
349 if condition:
350 self.log.info("Proceeding with downstream tasks...")
351 return condition
352
353 if not self.downstream_task_ids:
354 self.log.info("No downstream tasks; nothing to do.")
355 return condition
356
357 dag_run = context["dag_run"]
358
359 def get_tasks_to_skip():
360 if self.ignore_downstream_trigger_rules is True:
361 tasks = context["task"].get_flat_relatives(upstream=False)
362 else:
363 tasks = context["task"].get_direct_relatives(upstream=False)
364 for t in tasks:
365 if not t.is_teardown:
366 yield t
367
368 to_skip = get_tasks_to_skip()
369
        # this lets us avoid building an intermediate list unless debug logging is enabled
371 if self.log.getEffectiveLevel() <= logging.DEBUG:
372 self.log.debug("Downstream task IDs %s", to_skip := list(get_tasks_to_skip()))
373
374 self.log.info("Skipping downstream tasks")
375 if AIRFLOW_V_3_0_PLUS:
376 self.skip(
377 ti=context["ti"],
378 tasks=to_skip,
379 )
380 else:
381 if to_skip:
382 self.skip(
383 dag_run=context["dag_run"],
384 tasks=to_skip,
385 execution_date=cast("DateTime", dag_run.logical_date), # type: ignore[call-arg]
386 map_index=context["ti"].map_index,
387 )
388
389 self.log.info("Done.")
390 # returns the result of the super execute method as it is instead of returning None
391 return condition
392
393
394def _load_pickle():
395 import pickle
396
397 return pickle
398
399
400def _load_dill():
401 try:
402 import dill
403 except ModuleNotFoundError:
        log.error("Unable to import `dill` module. Please make sure that it is installed.")
405 raise
406 return dill
407
408
409def _load_cloudpickle():
410 try:
411 import cloudpickle
412 except ModuleNotFoundError:
413 log.error(
414 "Unable to import `cloudpickle` module. "
415 "Please install it with: pip install 'apache-airflow[cloudpickle]'"
416 )
417 raise
418 return cloudpickle
419
420
421_SERIALIZERS: dict[_SerializerTypeDef, Any] = {
422 "pickle": lazy_object_proxy.Proxy(_load_pickle),
423 "dill": lazy_object_proxy.Proxy(_load_dill),
424 "cloudpickle": lazy_object_proxy.Proxy(_load_cloudpickle),
425}
426
427
428class _BasePythonVirtualenvOperator(PythonOperator, metaclass=ABCMeta):
429 BASE_SERIALIZABLE_CONTEXT_KEYS = {
430 "ds",
431 "ds_nodash",
432 "expanded_ti_count",
433 "inlets",
434 "outlets",
435 "run_id",
436 "task_instance_key_str",
437 "test_mode",
438 "ts",
439 "ts_nodash",
440 "ts_nodash_with_tz",
441 # The following should be removed when Airflow 2 support is dropped.
442 "next_ds",
443 "next_ds_nodash",
444 "prev_ds",
445 "prev_ds_nodash",
446 "tomorrow_ds",
447 "tomorrow_ds_nodash",
448 "yesterday_ds",
449 "yesterday_ds_nodash",
450 }
451 if AIRFLOW_V_3_0_PLUS:
452 BASE_SERIALIZABLE_CONTEXT_KEYS.add("task_reschedule_count")
453
454 PENDULUM_SERIALIZABLE_CONTEXT_KEYS = {
455 "data_interval_end",
456 "data_interval_start",
457 "logical_date",
458 "prev_data_interval_end_success",
459 "prev_data_interval_start_success",
460 "prev_start_date_success",
461 "prev_end_date_success",
462 # The following should be removed when Airflow 2 support is dropped.
463 "execution_date",
464 "next_execution_date",
465 "prev_execution_date",
466 "prev_execution_date_success",
467 }
468
469 AIRFLOW_SERIALIZABLE_CONTEXT_KEYS = {
470 "macros",
471 "conf",
472 "dag",
473 "dag_run",
474 "task",
475 "params",
476 "triggering_asset_events",
477 # The following should be removed when Airflow 2 support is dropped.
478 "triggering_dataset_events",
479 }
480
481 def __init__(
482 self,
483 *,
484 python_callable: Callable,
485 serializer: _SerializerTypeDef | None = None,
486 op_args: Collection[Any] | None = None,
487 op_kwargs: Mapping[str, Any] | None = None,
488 string_args: Iterable[str] | None = None,
489 templates_dict: dict | None = None,
490 templates_exts: list[str] | None = None,
491 expect_airflow: bool = True,
492 skip_on_exit_code: int | Container[int] | None = None,
493 env_vars: dict[str, str] | None = None,
494 inherit_env: bool = True,
495 **kwargs,
496 ):
497 if (
498 not isinstance(python_callable, types.FunctionType)
499 or isinstance(python_callable, types.LambdaType)
500 and python_callable.__name__ == "<lambda>"
501 ):
502 raise ValueError(f"{type(self).__name__} only supports functions for python_callable arg")
503 if inspect.isgeneratorfunction(python_callable):
504 raise ValueError(f"{type(self).__name__} does not support using 'yield' in python_callable")
505 super().__init__(
506 python_callable=python_callable,
507 op_args=op_args,
508 op_kwargs=op_kwargs,
509 templates_dict=templates_dict,
510 templates_exts=templates_exts,
511 **kwargs,
512 )
513 self.string_args = string_args or []
514
515 serializer = serializer or "pickle"
516 if serializer not in _SERIALIZERS:
517 msg = (
518 f"Unsupported serializer {serializer!r}. Expected one of {', '.join(map(repr, _SERIALIZERS))}"
519 )
520 raise AirflowException(msg)
521
522 self.pickling_library = _SERIALIZERS[serializer]
523 self.serializer: _SerializerTypeDef = serializer
524
525 self.expect_airflow = expect_airflow
526 self.skip_on_exit_code = (
527 skip_on_exit_code
528 if isinstance(skip_on_exit_code, Container)
529 else [skip_on_exit_code]
530 if skip_on_exit_code is not None
531 else []
532 )
533 self.env_vars = env_vars
534 self.inherit_env = inherit_env
535
536 @abstractmethod
537 def _iter_serializable_context_keys(self):
538 pass
539
540 def execute(self, context: Context) -> Any:
541 serializable_keys = set(self._iter_serializable_context_keys())
542 new = {k: v for k, v in context.items() if k in serializable_keys}
543 serializable_context = cast("Context", new)
544 # Store bundle_path for subprocess execution
545 self._bundle_path = self._get_bundle_path_from_context(context)
546 return super().execute(context=serializable_context)
547
548 def _get_bundle_path_from_context(self, context: Context) -> str | None:
549 """
550 Extract bundle_path from the task instance's bundle_instance.
551
552 :param context: The task execution context
553 :return: Path to the bundle root directory, or None if not in a bundle
554 """
555 if not AIRFLOW_V_3_0_PLUS:
556 return None
557
558 # In Airflow 3.x, the RuntimeTaskInstance has a bundle_instance attribute
559 # that contains the bundle information including its path
560 ti = context["ti"]
561 if bundle_instance := getattr(ti, "bundle_instance", None):
562 return bundle_instance.path
563
564 return None
565
566 def get_python_source(self):
567 """Return the source of self.python_callable."""
568 return textwrap.dedent(inspect.getsource(self.python_callable))
569
570 def _write_args(self, file: Path):
571 def resolve_proxies(obj):
572 """Recursively replaces lazy_object_proxy.Proxy instances with their resolved values."""
573 if isinstance(obj, lazy_object_proxy.Proxy):
574 return obj.__wrapped__ # force evaluation
575 if isinstance(obj, dict):
576 return {k: resolve_proxies(v) for k, v in obj.items()}
577 if isinstance(obj, list):
578 return [resolve_proxies(v) for v in obj]
579 return obj
580
581 if self.op_args or self.op_kwargs:
            self.log.info("Using %r as serializer.", self.serializer)
583 file.write_bytes(
584 self.pickling_library.dumps({"args": self.op_args, "kwargs": resolve_proxies(self.op_kwargs)})
585 )
586
587 def _write_string_args(self, file: Path):
588 file.write_text("\n".join(map(str, self.string_args)))
589
590 def _read_result(self, path: Path):
591 if path.stat().st_size == 0:
592 return None
593 try:
594 return self.pickling_library.loads(path.read_bytes())
595 except ValueError as value_error:
596 raise DeserializingResultError() from value_error
597
598 def __deepcopy__(self, memo):
        # module objects can't be copied *at all*
600 memo[id(self.pickling_library)] = self.pickling_library
601 return super().__deepcopy__(memo)
602
603 def _execute_python_callable_in_subprocess(self, python_path: Path):
604 with TemporaryDirectory(prefix="venv-call") as tmp:
605 tmp_dir = Path(tmp)
606 op_kwargs: dict[str, Any] = dict(self.op_kwargs)
607 if self.templates_dict:
608 op_kwargs["templates_dict"] = self.templates_dict
609 input_path = tmp_dir / "script.in"
610 output_path = tmp_dir / "script.out"
611 string_args_path = tmp_dir / "string_args.txt"
612 script_path = tmp_dir / "script.py"
613 termination_log_path = tmp_dir / "termination.log"
614 airflow_context_path = tmp_dir / "airflow_context.json"
615
616 self._write_args(input_path)
617 self._write_string_args(string_args_path)
618
619 jinja_context = {
620 "op_args": self.op_args,
621 "op_kwargs": op_kwargs,
622 "expect_airflow": self.expect_airflow,
623 "pickling_library": self.serializer,
624 "python_callable": self.python_callable.__name__,
625 "python_callable_source": self.get_python_source(),
626 }
627
628 if inspect.getfile(self.python_callable) == self.dag.fileloc:
629 jinja_context["modified_dag_module_name"] = get_unique_dag_module_name(self.dag.fileloc)
630
631 write_python_script(
632 jinja_context=jinja_context,
633 filename=os.fspath(script_path),
634 render_template_as_native_obj=self.dag.render_template_as_native_obj,
635 )
636
637 env_vars = dict(os.environ) if self.inherit_env else {}
638 if fd := os.getenv("__AIRFLOW_SUPERVISOR_FD"):
639 env_vars["__AIRFLOW_SUPERVISOR_FD"] = fd
640 if self.env_vars:
641 env_vars.update(self.env_vars)
642
643 # Add bundle_path to PYTHONPATH for subprocess to import Dag bundle modules
644 if self._bundle_path:
645 bundle_path = self._bundle_path
646 existing_pythonpath = env_vars.get("PYTHONPATH", "")
647 if existing_pythonpath:
648 # Append bundle_path after existing PYTHONPATH
649 env_vars["PYTHONPATH"] = f"{existing_pythonpath}{os.pathsep}{bundle_path}"
650 else:
651 env_vars["PYTHONPATH"] = bundle_path
652
653 try:
654 cmd: list[str] = [
655 os.fspath(python_path),
656 os.fspath(script_path),
657 os.fspath(input_path),
658 os.fspath(output_path),
659 os.fspath(string_args_path),
660 os.fspath(termination_log_path),
661 os.fspath(airflow_context_path),
662 ]
663 _execute_in_subprocess(
664 cmd=cmd,
665 env=env_vars,
666 )
667 except subprocess.CalledProcessError as e:
668 if e.returncode in self.skip_on_exit_code:
669 raise AirflowSkipException(f"Process exited with code {e.returncode}. Skipping.")
670 if termination_log_path.exists() and termination_log_path.stat().st_size > 0:
671 error_msg = f"Process returned non-zero exit status {e.returncode}.\n"
672 with open(termination_log_path) as file:
673 error_msg += file.read()
674 raise AirflowException(error_msg) from None
675 raise
676
677 if 0 in self.skip_on_exit_code:
678 raise AirflowSkipException("Process exited with code 0. Skipping.")
679
680 return self._read_result(output_path)
681
682 def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
683 keyword_params = KeywordParameters.determine(self.python_callable, self.op_args, context)
684 if AIRFLOW_V_3_0_PLUS:
685 return keyword_params.unpacking()
686 return keyword_params.serializing() # type: ignore[attr-defined]
687
688
689class PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
690 """
691 Run a function in a virtualenv that is created and destroyed automatically.
692
    The function (with certain caveats) must be defined using def, and not be
    part of a class. All imports must happen inside the function
    and no variables outside the scope may be referenced. A global scope
    variable named virtualenv_string_args will be available (populated by
    string_args). In addition, one can pass arguments through op_args and op_kwargs, and one
    can use a return value.
699 Note that if your virtualenv runs in a different Python major version than Airflow,
700 you cannot use return values, op_args, op_kwargs, or use any macros that are being provided to
701 Airflow through plugins. You can use string_args though.
702
703 .. seealso::
704 For more information on how to use this operator, take a look at the guide:
705 :ref:`howto/operator:PythonVirtualenvOperator`
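
    A minimal usage sketch (the requirement pin and the callable are illustrative):

    .. code-block:: python

        def render_report():
            # all imports must happen inside the callable
            import pandas as pd

            return pd.__version__


        report = PythonVirtualenvOperator(
            task_id="render_report",
            python_callable=render_report,
            requirements=["pandas==2.2.2"],
            system_site_packages=False,
        )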
706
707 :param python_callable: A python function with no references to outside variables,
708 defined with def, which will be run in a virtual environment.
709 :param requirements: Either a list of requirement strings, or a (templated)
710 "requirements file" as specified by pip.
    :param python_version: The Python version to run the virtual environment with. Note that
        both major-only (e.g. ``3``) and major.minor (e.g. ``3.9``) forms are accepted.
    :param serializer: Which serializer to use to serialize the args and result. It can be one of the following:

        - ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library.
        - ``"cloudpickle"``: Use cloudpickle to serialize more complex types;
          this requires including cloudpickle in your requirements.
        - ``"dill"``: Use dill to serialize more complex types;
          this requires including dill in your requirements.
720 :param system_site_packages: Whether to include
721 system_site_packages in your virtual environment.
722 See virtualenv documentation for more information.
723 :param pip_install_options: a list of pip install options when installing requirements
724 See 'pip install -h' for available options
725 :param op_args: A list of positional arguments to pass to python_callable.
726 :param op_kwargs: A dict of keyword arguments to pass to python_callable.
727 :param string_args: Strings that are present in the global var virtualenv_string_args,
728 available to python_callable at runtime as a list[str]. Note that args are split
729 by newline.
    :param templates_dict: a dictionary where the values are templates that
        will get templated by the Airflow engine sometime between
        ``__init__`` and ``execute`` and are made available
        in your callable's context after the template has been applied
    :param templates_exts: a list of file extensions to resolve while
        processing templated fields, for example ``['.sql', '.hql']``
    :param expect_airflow: expect Airflow to be installed in the target environment. If true, the operator
        will raise a warning if Airflow is not installed, and it will attempt to load Airflow
        macros when starting.
739 :param skip_on_exit_code: If python_callable exits with this exit code, leave the task
740 in ``skipped`` state (default: None). If set to ``None``, any non-zero
741 exit code will be treated as a failure.
    :param index_urls: an optional list of index urls to load Python packages from.
        If not provided, the system pip conf will be used to source packages from.
    :param index_urls_from_connection_ids: An optional list of ``PackageIndex`` connection IDs.
        Will be appended to ``index_urls``.
    :param venv_cache_path: Optional path to the virtual environment parent folder in which the
        virtual environment will be cached; a sub-folder venv-{hash} is created, where the hash is replaced
        with a checksum of the requirements. If not provided, the virtual environment will be created and
        deleted in a temp folder for every execution.
750 :param env_vars: A dictionary containing additional environment variables to set for the virtual
751 environment when it is executed.
752 :param inherit_env: Whether to inherit the current environment variables when executing the virtual
753 environment. If set to ``True``, the virtual environment will inherit the environment variables
754 of the parent process (``os.environ``). If set to ``False``, the virtual environment will be
755 executed with a clean environment.
756 """
757
758 template_fields: Sequence[str] = tuple(
759 {"requirements", "index_urls", "index_urls_from_connection_ids", "venv_cache_path"}.union(
760 PythonOperator.template_fields
761 )
762 )
763 template_ext: Sequence[str] = (".txt",)
764
765 def __init__(
766 self,
767 *,
768 python_callable: Callable,
769 requirements: None | Iterable[str] | str = None,
770 python_version: str | None = None,
771 serializer: _SerializerTypeDef | None = None,
772 system_site_packages: bool = True,
773 pip_install_options: list[str] | None = None,
774 op_args: Collection[Any] | None = None,
775 op_kwargs: Mapping[str, Any] | None = None,
776 string_args: Iterable[str] | None = None,
777 templates_dict: dict | None = None,
778 templates_exts: list[str] | None = None,
779 expect_airflow: bool = True,
780 skip_on_exit_code: int | Container[int] | None = None,
781 index_urls: None | Collection[str] | str = None,
782 index_urls_from_connection_ids: None | Collection[str] | str = None,
783 venv_cache_path: None | os.PathLike[str] = None,
784 env_vars: dict[str, str] | None = None,
785 inherit_env: bool = True,
786 **kwargs,
787 ):
788 if (
789 python_version
790 and str(python_version)[0] != str(sys.version_info.major)
791 and (op_args or op_kwargs)
792 ):
793 raise AirflowException(
794 "Passing op_args or op_kwargs is not supported across different Python "
                "major versions for PythonVirtualenvOperator. Please use string_args. "
                f"Sys version: {sys.version_info}. Virtual environment version: {python_version}"
797 )
798 if python_version is not None and not isinstance(python_version, str):
799 raise AirflowException(
800 "Passing non-string types (e.g. int or float) as python_version not supported"
801 )
802 if not requirements:
803 self.requirements: list[str] = []
804 elif isinstance(requirements, str):
805 self.requirements = [requirements]
806 else:
807 self.requirements = list(requirements)
808 self.python_version = python_version
809 self.system_site_packages = system_site_packages
810 self.pip_install_options = pip_install_options
811 if isinstance(index_urls, str):
812 self.index_urls: list[str] | None = [index_urls]
813 elif isinstance(index_urls, Collection):
814 self.index_urls = list(index_urls)
815 else:
816 self.index_urls = None
817 if isinstance(index_urls_from_connection_ids, str):
818 self.index_urls_from_connection_ids: list[str] | None = [index_urls_from_connection_ids]
819 elif isinstance(index_urls_from_connection_ids, Collection):
820 self.index_urls_from_connection_ids = list(index_urls_from_connection_ids)
821 else:
822 self.index_urls_from_connection_ids = None
823 self.venv_cache_path = venv_cache_path
824 super().__init__(
825 python_callable=python_callable,
826 serializer=serializer,
827 op_args=op_args,
828 op_kwargs=op_kwargs,
829 string_args=string_args,
830 templates_dict=templates_dict,
831 templates_exts=templates_exts,
832 expect_airflow=expect_airflow,
833 skip_on_exit_code=skip_on_exit_code,
834 env_vars=env_vars,
835 inherit_env=inherit_env,
836 **kwargs,
837 )
838
839 def _requirements_list(self, exclude_cloudpickle: bool = False) -> list[str]:
840 """Prepare a list of requirements that need to be installed for the virtual environment."""
841 requirements = [str(dependency) for dependency in self.requirements]
842 if not self.system_site_packages:
843 if (
844 self.serializer == "cloudpickle"
845 and not exclude_cloudpickle
846 and "cloudpickle" not in requirements
847 ):
848 requirements.append("cloudpickle")
849 elif self.serializer == "dill" and "dill" not in requirements:
850 requirements.append("dill")
851 requirements.sort() # Ensure a hash is stable
852 return requirements
853
854 def _prepare_venv(self, venv_path: Path) -> None:
        """Prepare the requirements and install the virtual environment."""
856 requirements_file = venv_path / "requirements.txt"
857 requirements_file.write_text("\n".join(self._requirements_list()))
858 prepare_virtualenv(
859 venv_directory=str(venv_path),
860 python_bin=f"python{self.python_version}" if self.python_version else "python",
861 system_site_packages=self.system_site_packages,
862 requirements_file_path=str(requirements_file),
863 pip_install_options=self.pip_install_options,
864 index_urls=self.index_urls,
865 )
866
867 def _calculate_cache_hash(self, exclude_cloudpickle: bool = False) -> tuple[str, str]:
868 """
869 Generate the hash of the cache folder to use.
870
871 The following factors are used as input for the hash:
872 - (sorted) list of requirements
873 - pip install options
874 - flag of system site packages
875 - python version
876 - Variable to override the hash with a cache key
877 - Index URLs
878
879 Returns a hash and the data dict which is the base for the hash as text.
880 """
881 hash_dict = {
882 "requirements_list": self._requirements_list(exclude_cloudpickle=exclude_cloudpickle),
883 "pip_install_options": self.pip_install_options,
884 "index_urls": self.index_urls,
885 "cache_key": str(Variable.get("PythonVirtualenvOperator.cache_key", "")),
886 "python_version": self.python_version,
887 "system_site_packages": self.system_site_packages,
888 }
889 hash_text = json.dumps(hash_dict, sort_keys=True)
890 hash_object = hashlib_wrapper.md5(hash_text.encode())
891 requirements_hash = hash_object.hexdigest()
892 return requirements_hash[:8], hash_text
893
894 def _ensure_venv_cache_exists(self, venv_cache_path: Path) -> Path:
        """Ensure a valid virtual environment is set up, creating it in place if needed."""
896 cache_hash, hash_data = self._calculate_cache_hash()
897 venv_path = venv_cache_path / f"venv-{cache_hash}"
898 self.log.info("Python virtual environment will be cached in %s", venv_path)
899 venv_path.parent.mkdir(parents=True, exist_ok=True)
900 with open(f"{venv_path}.lock", "w") as f:
            # Ensure that the cache is not built by parallel workers
902 import fcntl
903
904 fcntl.flock(f, fcntl.LOCK_EX)
905
906 hash_marker = venv_path / "install_complete_marker.json"
907 try:
908 if venv_path.exists():
909 if hash_marker.exists():
910 previous_hash_data = hash_marker.read_text(encoding="utf8")
911 if previous_hash_data == hash_data:
912 self.log.info("Reusing cached Python virtual environment in %s", venv_path)
913 return venv_path
914
915 _, hash_data_before_upgrade = self._calculate_cache_hash(exclude_cloudpickle=True)
916 if previous_hash_data == hash_data_before_upgrade:
917 self.log.warning(
                                "Found a previous virtual environment with outdated dependencies in %s, "
919 "deleting and re-creating.",
920 venv_path,
921 )
922 else:
923 self.log.error(
924 "Unicorn alert: Found a previous virtual environment in %s "
925 "with the same hash but different parameters. Previous setup: '%s' / "
926 "Requested venv setup: '%s'. Please report a bug to airflow!",
927 venv_path,
928 previous_hash_data,
929 hash_data,
930 )
931 else:
932 self.log.warning(
                            "Found a previous (probably partially installed) virtual environment in %s, "
934 "deleting and re-creating.",
935 venv_path,
936 )
937
938 shutil.rmtree(venv_path)
939
940 venv_path.mkdir(parents=True)
941 self._prepare_venv(venv_path)
942 hash_marker.write_text(hash_data, encoding="utf8")
943 except Exception as e:
944 shutil.rmtree(venv_path)
945 raise AirflowException(f"Unable to create new virtual environment in {venv_path}") from e
946 self.log.info("New Python virtual environment created in %s", venv_path)
947 return venv_path
948
949 def _cleanup_python_pycache_dir(self, cache_dir_path: Path) -> None:
950 try:
951 shutil.rmtree(cache_dir_path)
952 self.log.debug("The directory %s has been deleted.", cache_dir_path)
953 except FileNotFoundError:
            self.log.warning("Failed to delete %s. The directory does not exist.", cache_dir_path)
955 except PermissionError:
956 self.log.warning("Permission denied to delete the directory %s.", cache_dir_path)
957
958 def _retrieve_index_urls_from_connection_ids(self):
959 """Retrieve index URLs from Package Index connections."""
960 if self.index_urls is None:
961 self.index_urls = []
962 for conn_id in self.index_urls_from_connection_ids:
963 conn_url = PackageIndexHook(conn_id).get_connection_url()
964 self.index_urls.append(conn_url)
965
966 def execute_callable(self):
967 if self.index_urls_from_connection_ids:
968 self._retrieve_index_urls_from_connection_ids()
969
970 if self.venv_cache_path:
971 venv_path = self._ensure_venv_cache_exists(Path(self.venv_cache_path))
972 python_path = venv_path / "bin" / "python"
973 return self._execute_python_callable_in_subprocess(python_path)
974
975 with TemporaryDirectory(prefix="venv") as tmp_dir:
976 tmp_path = Path(tmp_dir)
977 custom_pycache_prefix = Path(sys.pycache_prefix or "")
978 r_path = tmp_path.relative_to(tmp_path.anchor)
979 venv_python_cache_dir = Path.cwd() / custom_pycache_prefix / r_path
980 self._prepare_venv(tmp_path)
981 python_path = tmp_path / "bin" / "python"
982 result = self._execute_python_callable_in_subprocess(python_path)
983 self._cleanup_python_pycache_dir(venv_python_cache_dir)
984 return result
985
986 def _iter_serializable_context_keys(self):
987 yield from self.BASE_SERIALIZABLE_CONTEXT_KEYS
988
989 found_airflow = found_pendulum = False
990
991 if self.system_site_packages:
992 # If we're using system packages, assume both are present
993 found_airflow = found_pendulum = True
994 else:
995 for raw_str in chain.from_iterable(req.splitlines() for req in self.requirements):
996 line = raw_str.strip()
997 # Skip blank lines and full‐line comments
998 if not line or line.startswith("#"):
999 continue
1000
1001 # Strip off any inline comment
1002 # e.g. turn "foo==1.2.3 # comment" → "foo==1.2.3"
1003 req_str = re.sub(r"#.*$", "", line).strip()
1004
1005 try:
1006 req = Requirement(req_str)
1007 except (InvalidRequirement, InvalidSpecifier, InvalidVersion) as e:
1008 raise ValueError(f"Invalid requirement '{raw_str}': {e}") from e
1009
1010 if req.name == "apache-airflow":
1011 found_airflow = found_pendulum = True
1012 break
1013 elif req.name == "pendulum":
1014 found_pendulum = True
1015
1016 if found_airflow:
1017 yield from self.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS
1018 yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS
1019 elif found_pendulum:
1020 yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS
1021
1022
1023class BranchPythonVirtualenvOperator(BaseBranchOperator, PythonVirtualenvOperator):
1024 """
1025 A workflow can "branch" or follow a path after the execution of this task in a virtual environment.
1026
1027 It derives the PythonVirtualenvOperator and expects a Python function that returns
1028 a single task_id, a single task_group_id, or a list of task_ids and/or
1029 task_group_ids to follow. The task_id(s) and/or task_group_id(s) returned
1030 should point to a task or task group directly downstream from {self}. All
1031 other "branches" or directly downstream tasks are marked with a state of
1032 ``skipped`` so that these paths can't move forward. The ``skipped`` states
1033 are propagated downstream to allow for the DAG state to fill up and
1034 the DAG run's state to be inferred.
1035
1036 .. seealso::
1037 For more information on how to use this operator, take a look at the guide:
1038 :ref:`howto/operator:BranchPythonVirtualenvOperator`
1039 """
1040
1041 def choose_branch(self, context: Context) -> str | Iterable[str]:
1042 return PythonVirtualenvOperator.execute(self, context)
1043
1044
1045class ExternalPythonOperator(_BasePythonVirtualenvOperator):
1046 """
    Run a function in an existing virtualenv that is not re-created.

    The environment is reused as-is, without the overhead of creating a virtual environment (with certain caveats).
1050
1051 The function must be defined using def, and not be
1052 part of a class. All imports must happen inside the function
1053 and no variables outside the scope may be referenced. A global scope
1054 variable named virtualenv_string_args will be available (populated by
    string_args). In addition, one can pass arguments through op_args and op_kwargs, and one
    can use a return value.
1057 Note that if your virtual environment runs in a different Python major version than Airflow,
1058 you cannot use return values, op_args, op_kwargs, or use any macros that are being provided to
1059 Airflow through plugins. You can use string_args though.
1060
    If Airflow is installed in the external environment in a different version than the version
    used by the operator, the operator will fail.
1063
1064 .. seealso::
1065 For more information on how to use this operator, take a look at the guide:
1066 :ref:`howto/operator:ExternalPythonOperator`
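
    A minimal usage sketch (the interpreter path and the callable are illustrative):

    .. code-block:: python

        def report_python_version():
            import sys

            return sys.version


        run_in_venv = ExternalPythonOperator(
            task_id="run_in_venv",
            python="/opt/venvs/reporting/bin/python",
            python_callable=report_python_version,
        )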
1067
    :param python: Full path string (file-system specific) that points to a Python binary inside
        a virtual environment that should be used (in the ``VENV/bin`` folder). It should be an absolute
        path (so it usually starts with "/" or "X:/" depending on the filesystem/OS used).
1071 :param python_callable: A python function with no references to outside variables,
1072 defined with def, which will be run in a virtual environment.
    :param serializer: Which serializer to use to serialize the args and result. It can be one of the following:

        - ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library.
        - ``"cloudpickle"``: Use cloudpickle to serialize more complex types;
          this requires including cloudpickle in your requirements.
        - ``"dill"``: Use dill to serialize more complex types;
          this requires including dill in your requirements.
1080 :param op_args: A list of positional arguments to pass to python_callable.
1081 :param op_kwargs: A dict of keyword arguments to pass to python_callable.
1082 :param string_args: Strings that are present in the global var virtualenv_string_args,
1083 available to python_callable at runtime as a list[str]. Note that args are split
1084 by newline.
    :param templates_dict: a dictionary where the values are templates that
        will get templated by the Airflow engine sometime between
        ``__init__`` and ``execute`` and are made available
        in your callable's context after the template has been applied
    :param templates_exts: a list of file extensions to resolve while
        processing templated fields, for example ``['.sql', '.hql']``
    :param expect_airflow: expect Airflow to be installed in the target environment. If true, the operator
        will raise a warning if Airflow is not installed, and it will attempt to load Airflow
        macros when starting.
1094 :param skip_on_exit_code: If python_callable exits with this exit code, leave the task
1095 in ``skipped`` state (default: None). If set to ``None``, any non-zero
1096 exit code will be treated as a failure.
1097 :param env_vars: A dictionary containing additional environment variables to set for the virtual
1098 environment when it is executed.
1099 :param inherit_env: Whether to inherit the current environment variables when executing the virtual
1100 environment. If set to ``True``, the virtual environment will inherit the environment variables
1101 of the parent process (``os.environ``). If set to ``False``, the virtual environment will be
1102 executed with a clean environment.
1103 """
1104
1105 template_fields: Sequence[str] = tuple({"python"}.union(PythonOperator.template_fields))
1106
1107 def __init__(
1108 self,
1109 *,
1110 python: str,
1111 python_callable: Callable,
1112 serializer: _SerializerTypeDef | None = None,
1113 op_args: Collection[Any] | None = None,
1114 op_kwargs: Mapping[str, Any] | None = None,
1115 string_args: Iterable[str] | None = None,
1116 templates_dict: dict | None = None,
1117 templates_exts: list[str] | None = None,
1118 expect_airflow: bool = True,
1119 expect_pendulum: bool = False,
1120 skip_on_exit_code: int | Container[int] | None = None,
1121 env_vars: dict[str, str] | None = None,
1122 inherit_env: bool = True,
1123 **kwargs,
1124 ):
1125 if not python:
1126 raise ValueError("Python Path must be defined in ExternalPythonOperator")
1127 self.python = python
1128 self.expect_pendulum = expect_pendulum
1129 super().__init__(
1130 python_callable=python_callable,
1131 serializer=serializer,
1132 op_args=op_args,
1133 op_kwargs=op_kwargs,
1134 string_args=string_args,
1135 templates_dict=templates_dict,
1136 templates_exts=templates_exts,
1137 expect_airflow=expect_airflow,
1138 skip_on_exit_code=skip_on_exit_code,
1139 env_vars=env_vars,
1140 inherit_env=inherit_env,
1141 **kwargs,
1142 )
1143
1144 def execute_callable(self):
1145 python_path = Path(self.python)
1146 if not python_path.exists():
            raise ValueError(f"Python Path '{python_path}' must exist")
1148 if not python_path.is_file():
1149 raise ValueError(f"Python Path '{python_path}' must be a file")
1150 if not python_path.is_absolute():
1151 raise ValueError(f"Python Path '{python_path}' must be an absolute path.")
1152 python_version = _PythonVersionInfo.from_executable(self.python)
1153 if python_version.major != sys.version_info.major and (self.op_args or self.op_kwargs):
1154 raise AirflowException(
1155 "Passing op_args or op_kwargs is not supported across different Python "
                "major versions for ExternalPythonOperator. Please use string_args. "
1157 f"Sys version: {sys.version_info}. "
1158 f"Virtual environment version: {python_version}"
1159 )
1160 return self._execute_python_callable_in_subprocess(python_path)
1161
1162 def _iter_serializable_context_keys(self):
1163 yield from self.BASE_SERIALIZABLE_CONTEXT_KEYS
1164 if self.expect_airflow and self._get_airflow_version_from_target_env():
1165 yield from self.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS
1166 yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS
1167 elif self._is_pendulum_installed_in_target_env():
1168 yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS
1169
1170 def _is_pendulum_installed_in_target_env(self) -> bool:
1171 try:
1172 subprocess.check_call([self.python, "-c", "import pendulum"])
1173 return True
1174 except Exception as e:
1175 if self.expect_pendulum:
1176 self.log.warning("When checking for Pendulum installed in virtual environment got %s", e)
1177 self.log.warning(
                    "Pendulum is not properly installed in the virtual environment. "
                    "Pendulum context keys will not be available. "
                    "Please install Pendulum or Airflow in your virtual environment to access them."
1181 )
1182 return False
1183
1184 @property
1185 def _external_airflow_version_script(self):
1186 """
        Return a Python script which determines the version of Apache Airflow installed in the target environment.

        Importing airflow as a module might take a while; as a result,
        obtaining the version that way could take up to 1 second.
        On the other hand, ``importlib.metadata.version`` retrieves the package version pretty fast,
        typically well below 100 ms including the new subprocess overhead.

        Possible side effect: ``importlib.metadata`` might not be available (Python < 3.8), and neither
        the backport ``importlib_metadata``, which might indicate that the venv doesn't contain
        ``apache-airflow`` or that something is wrong with the environment.
1197 """
1198 return textwrap.dedent(
1199 """
1200 try:
1201 from importlib.metadata import version
1202 except ImportError:
1203 from importlib_metadata import version
1204 print(version("apache-airflow"))
1205 """
1206 )
1207
1208 def _get_airflow_version_from_target_env(self) -> str | None:
1209 from airflow import __version__ as airflow_version
1210
1211 try:
1212 result = subprocess.check_output(
1213 [self.python, "-c", self._external_airflow_version_script],
1214 text=True,
1215 )
1216 target_airflow_version = result.strip()
1217 if target_airflow_version != airflow_version:
1218 raise AirflowConfigException(
1219 f"The version of Airflow installed for the {self.python} "
1220 f"({target_airflow_version}) is different than the runtime Airflow version: "
1221 f"{airflow_version}. Make sure your environment has the same Airflow version "
1222 f"installed as the Airflow runtime."
1223 )
1224 return target_airflow_version
1225 except Exception as e:
1226 if self.expect_airflow:
1227 self.log.warning("When checking for Airflow installed in virtual environment got %s", e)
1228 self.log.warning(
1229 "This means that Airflow is not properly installed by %s. "
1230 "Airflow context keys will not be available. "
                    "Please install Airflow %s in your environment to access them.",
1232 self.python,
1233 airflow_version,
1234 )
1235 return None
1236
1237
1238class BranchExternalPythonOperator(BaseBranchOperator, ExternalPythonOperator):
1239 """
1240 A workflow can "branch" or follow a path after the execution of this task.
1241
    Extends ExternalPythonOperator, so it expects a ``python`` argument: the path to a Python binary inside
    a virtual environment that should be used (in the ``VENV/bin`` folder). It should be an absolute path,
    so the operator can run in a separate virtual environment, similarly to ExternalPythonOperator.
1245
1246 .. seealso::
1247 For more information on how to use this operator, take a look at the guide:
1248 :ref:`howto/operator:BranchExternalPythonOperator`
1249 """
1250
1251 def choose_branch(self, context: Context) -> str | Iterable[str]:
1252 return ExternalPythonOperator.execute(self, context)
1253
1254
1255def get_current_context() -> Mapping[str, Any]:
1256 """
1257 Retrieve the execution context dictionary without altering user method's signature.
1258
1259 This is the simplest method of retrieving the execution context dictionary.
1260
1261 **Old style:**
1262
1263 .. code:: python
1264
1265 def my_task(**context):
1266 ti = context["ti"]
1267
1268 **New style:**
1269
1270 .. code:: python
1271
1272 from airflow.providers.standard.operators.python import get_current_context
1273
1274
1275 def my_task():
1276 context = get_current_context()
1277 ti = context["ti"]
1278
    Current context will only have a value if this method is called after an operator
    has started to execute.
1281 """
1282 if AIRFLOW_V_3_0_PLUS:
1283 warnings.warn(
            "Using get_current_context from the standard provider is deprecated and will be removed. "
            "Please import `from airflow.sdk import get_current_context` and use it instead.",
1286 AirflowProviderDeprecationWarning,
1287 stacklevel=2,
1288 )
1289
1290 from airflow.sdk import get_current_context
1291
1292 return get_current_context()
1293 return _get_current_context()
1294
1295
1296def _get_current_context() -> Mapping[str, Any]:
1297 # Airflow 2.x
1298 # TODO: To be removed when Airflow 2 support is dropped
1299 from airflow.models.taskinstance import _CURRENT_CONTEXT # type: ignore[attr-defined]
1300
1301 if not _CURRENT_CONTEXT:
1302 raise RuntimeError(
1303 "Current context was requested but no context was found! Are you running within an Airflow task?"
1304 )
1305 return _CURRENT_CONTEXT[-1]