#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import inspect
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import textwrap
import types
import warnings
from abc import ABCMeta, abstractmethod
from collections.abc import Callable, Collection, Container, Iterable, Mapping, Sequence
from functools import cache
from itertools import chain
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, NamedTuple, cast

import lazy_object_proxy
from packaging.requirements import InvalidRequirement, Requirement
from packaging.specifiers import InvalidSpecifier
from packaging.version import InvalidVersion

from airflow.exceptions import (
    AirflowConfigException,
    AirflowProviderDeprecationWarning,
    DeserializingResultError,
)
from airflow.models.variable import Variable
from airflow.providers.common.compat.sdk import AirflowException, AirflowSkipException, context_merge
from airflow.providers.standard.hooks.package_index import PackageIndexHook
from airflow.providers.standard.utils.python_virtualenv import (
    _execute_in_subprocess,
    prepare_virtualenv,
    write_python_script,
)
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator
from airflow.utils import hashlib_wrapper
from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import KeywordParameters

if AIRFLOW_V_3_0_PLUS:
    from airflow.providers.standard.operators.branch import BaseBranchOperator
    from airflow.providers.standard.utils.skipmixin import SkipMixin
else:
    from airflow.models.skipmixin import SkipMixin
    from airflow.operators.branch import BaseBranchOperator  # type: ignore[no-redef]


log = logging.getLogger(__name__)

if TYPE_CHECKING:
    from typing import Literal

    from pendulum.datetime import DateTime

    from airflow.providers.common.compat.sdk import Context
    from airflow.sdk.execution_time.callback_runner import ExecutionCallableRunner
    from airflow.sdk.execution_time.context import OutletEventAccessorsProtocol

    _SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"]


@cache
def _parse_version_info(text: str) -> tuple[int, int, int, str, int]:
    """Parse Python version info from a string like ``"3.11.4.final.0"``."""
    parts = text.strip().split(".")
    if len(parts) != 5:
        msg = f"Invalid Python version info, expected 5 components separated by '.', but got {text!r}."
        raise ValueError(msg)
    try:
        return int(parts[0]), int(parts[1]), int(parts[2]), parts[3], int(parts[4])
    except ValueError:
        msg = f"Unable to convert parts {parts} parsed from {text!r} to (int, int, int, str, int)."
        raise ValueError(msg) from None


class _PythonVersionInfo(NamedTuple):
    """Provide the same interface as ``sys.version_info``."""

    major: int
    minor: int
    micro: int
    releaselevel: str
    serial: int

    @classmethod
    def from_executable(cls, executable: str) -> _PythonVersionInfo:
        """Parse Python version info from an executable."""
        cmd = [executable, "-c", 'import sys; print(".".join(map(str, sys.version_info)))']
        try:
            result = subprocess.check_output(cmd, text=True)
        except Exception as e:
            raise ValueError(f"Error while executing command {cmd}: {e}") from e
        return cls(*_parse_version_info(result.strip()))


class PythonOperator(BaseOperator):
    """
    Executes a Python callable.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:PythonOperator`

    When running your callable, Airflow will pass a set of keyword arguments that can be used in your
    function. This set of kwargs corresponds exactly to what you can use in your Jinja templates.
    For this to work, you need to define ``**kwargs`` in your function header, or you can directly add
    the keyword arguments you would like to receive - for example, with the code below your callable
    will receive the value of the ``ti`` context variable.

    With explicit arguments:

    .. code-block:: python

        def my_python_callable(ti):
            pass

    With kwargs:

    .. code-block:: python

        def my_python_callable(**kwargs):
            ti = kwargs["ti"]


    :param python_callable: A reference to an object that is callable
    :param op_args: a list of positional arguments that will get unpacked when
        calling your callable
    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
        in your function
    :param templates_dict: a dictionary where the values are templates that
        will get templated by the Airflow engine sometime between
        ``__init__`` and ``execute``, and are made available
        in your callable's context after the template has been applied. (templated)
    :param templates_exts: a list of file extensions to resolve while
        processing templated fields, for example ``['.sql', '.hql']``
    :param show_return_value_in_logs: whether to show the return value in the task logs.
        Defaults to True, which allows the return value to be logged.
        Set it to False to prevent logging the return value when your callable returns large data,
        such as when transmitting a large amount of data via XCom.
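
    For example, a minimal usage sketch combining the parameters above (the task id,
    callable, and template value are illustrative placeholders):

    .. code-block:: python

        def greet(name, templates_dict=None, **kwargs):
            # templates_dict values are rendered before execute() is called.
            print(f"Hello {name}, run_id={templates_dict['run_id']}")


        greet_task = PythonOperator(
            task_id="greet",
            python_callable=greet,
            op_kwargs={"name": "world"},
            templates_dict={"run_id": "{{ run_id }}"},
        )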
162 """
163
164 template_fields: Sequence[str] = ("templates_dict", "op_args", "op_kwargs")
165 template_fields_renderers = {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}
166 BLUE = "#ffefeb"
167 ui_color = BLUE
168
169 # since we won't mutate the arguments, we should just do the shallow copy
170 # there are some cases we can't deepcopy the objects(e.g protobuf).
171 shallow_copy_attrs: Sequence[str] = ("python_callable", "op_kwargs")
172
173 def __init__(
174 self,
175 *,
176 python_callable: Callable,
177 op_args: Collection[Any] | None = None,
178 op_kwargs: Mapping[str, Any] | None = None,
179 templates_dict: dict[str, Any] | None = None,
180 templates_exts: Sequence[str] | None = None,
181 show_return_value_in_logs: bool = True,
182 **kwargs,
183 ) -> None:
184 super().__init__(**kwargs)
185 if not callable(python_callable):
186 raise AirflowException("`python_callable` param must be callable")
187 self.python_callable = python_callable
188 self.op_args = op_args or ()
189 self.op_kwargs = op_kwargs or {}
190 self.templates_dict = templates_dict
191 if templates_exts:
192 self.template_ext = templates_exts
193 self.show_return_value_in_logs = show_return_value_in_logs
194
    def execute(self, context: Context) -> Any:
        context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
        self.op_kwargs = self.determine_kwargs(context)

        # This needs to be lazy because subclasses may implement execute_callable
        # by running a separate process that can't use the eager result.
        def __prepare_execution() -> tuple[ExecutionCallableRunner, OutletEventAccessorsProtocol] | None:
            if AIRFLOW_V_3_0_PLUS:
                from airflow.sdk.execution_time.callback_runner import create_executable_runner
                from airflow.sdk.execution_time.context import context_get_outlet_events

                return create_executable_runner, context_get_outlet_events(context)
            from airflow.utils.context import context_get_outlet_events  # type: ignore
            from airflow.utils.operator_helpers import ExecutionCallableRunner  # type: ignore

            return ExecutionCallableRunner, context_get_outlet_events(context)

        self.__prepare_execution = __prepare_execution

        return_value = self.execute_callable()
        if self.show_return_value_in_logs:
            self.log.info("Done. Returned value was: %s", return_value)
        else:
            self.log.info("Done. Returned value not shown")

        return return_value

    def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
        return KeywordParameters.determine(self.python_callable, self.op_args, context).unpacking()

    __prepare_execution: Callable[[], tuple[ExecutionCallableRunner, OutletEventAccessorsProtocol] | None]

    def execute_callable(self) -> Any:
        """
        Call the python callable with the given arguments.

        :return: the return value of the call.
        """
        if (execution_preparation := self.__prepare_execution()) is None:
            return self.python_callable(*self.op_args, **self.op_kwargs)
        create_execution_runner, asset_events = execution_preparation
        runner = create_execution_runner(self.python_callable, asset_events, logger=self.log)
        return runner.run(*self.op_args, **self.op_kwargs)


class BranchPythonOperator(BaseBranchOperator, PythonOperator):
    """
    A workflow can "branch" or follow a path after the execution of this task.

    It derives from the PythonOperator and expects a Python function that returns
    a single task_id, a single task_group_id, or a list of task_ids and/or
    task_group_ids to follow. The task_id(s) and/or task_group_id(s) returned
    should point to a task or task group directly downstream from {self}. All
    other "branches" or directly downstream tasks are marked with a state of
    ``skipped`` so that these paths can't move forward. The ``skipped`` states
    are propagated downstream to allow for the DAG state to fill up and
    the DAG run's state to be inferred.
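
    For example, a minimal sketch of a branching callable (the task ids below are
    illustrative placeholders for tasks defined elsewhere in the DAG):

    .. code-block:: python

        def choose(ds=None, **kwargs):
            # Return the task_id (or list of task_ids) to follow.
            return "even_day_task" if int(ds[-2:]) % 2 == 0 else "odd_day_task"


        branch = BranchPythonOperator(task_id="branch", python_callable=choose)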
252 """
253
254 def choose_branch(self, context: Context) -> str | Iterable[str]:
255 return PythonOperator.execute(self, context)
256
257
class ShortCircuitOperator(PythonOperator, SkipMixin):
    """
    Allows a pipeline to continue based on the result of a ``python_callable``.

    The ShortCircuitOperator is derived from the PythonOperator and evaluates the result of a
    ``python_callable``. If the returned result is False or a falsy value, the pipeline will be
    short-circuited. Downstream tasks will be marked with a state of "skipped" based on the short-circuiting
    mode configured. If the returned result is True or a truthy value, downstream tasks proceed as normal and
    an ``XCom`` of the returned result is pushed.

    The short-circuiting can be configured to either respect or ignore the ``trigger_rule`` set for
    downstream tasks. If ``ignore_downstream_trigger_rules`` is set to True, the default setting, all
    downstream tasks are skipped without considering the ``trigger_rule`` defined for tasks. However, if this
    parameter is set to False, the direct downstream tasks are skipped but the specified ``trigger_rule`` for
    other subsequent downstream tasks is respected. In this mode, the operator assumes the direct downstream
    tasks were purposely meant to be skipped but perhaps not other subsequent tasks.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:ShortCircuitOperator`

    :param ignore_downstream_trigger_rules: If set to True, all downstream tasks from this operator task will
        be skipped. This is the default behavior. If set to False, the direct, downstream task(s) will be
        skipped but the ``trigger_rule`` defined for all other downstream tasks will be respected.
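
    For example, a minimal sketch that only continues the pipeline on even days (the
    task ids and the downstream task are illustrative placeholders):

    .. code-block:: python

        def is_even_day(ds=None, **kwargs):
            # A falsy return value short-circuits the downstream tasks.
            return int(ds[-2:]) % 2 == 0


        check = ShortCircuitOperator(task_id="check_even_day", python_callable=is_even_day)
        check >> PythonOperator(task_id="report", python_callable=lambda: print("even day"))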
282 """
283
284 inherits_from_skipmixin = True
285
286 def __init__(self, *, ignore_downstream_trigger_rules: bool = True, **kwargs) -> None:
287 super().__init__(**kwargs)
288 self.ignore_downstream_trigger_rules = ignore_downstream_trigger_rules
289
290 def execute(self, context: Context) -> Any:
291 condition = super().execute(context)
292 self.log.info("Condition result is %s", condition)
293
294 if condition:
295 self.log.info("Proceeding with downstream tasks...")
296 return condition
297
298 if not self.downstream_task_ids:
299 self.log.info("No downstream tasks; nothing to do.")
300 return condition
301
302 dag_run = context["dag_run"]
303
304 def get_tasks_to_skip():
305 if self.ignore_downstream_trigger_rules is True:
306 tasks = context["task"].get_flat_relatives(upstream=False)
307 else:
308 tasks = context["task"].get_direct_relatives(upstream=False)
309 for t in tasks:
310 if not t.is_teardown:
311 yield t
312
313 to_skip = get_tasks_to_skip()
314
315 # this lets us avoid an intermediate list unless debug logging
316 if self.log.getEffectiveLevel() <= logging.DEBUG:
317 self.log.debug("Downstream task IDs %s", to_skip := list(get_tasks_to_skip()))
318
319 self.log.info("Skipping downstream tasks")
320 if AIRFLOW_V_3_0_PLUS:
321 self.skip(
322 ti=context["ti"],
323 tasks=to_skip,
324 )
325 else:
326 if to_skip:
327 self.skip(
328 dag_run=context["dag_run"],
329 tasks=to_skip,
330 execution_date=cast("DateTime", dag_run.logical_date), # type: ignore[call-arg]
331 map_index=context["ti"].map_index,
332 )
333
334 self.log.info("Done.")
335 # returns the result of the super execute method as it is instead of returning None
336 return condition
337

def _load_pickle():
    import pickle

    return pickle


def _load_dill():
    try:
        import dill
    except ModuleNotFoundError:
        log.error("Unable to import `dill` module. Please make sure that it is installed.")
        raise
    return dill


def _load_cloudpickle():
    try:
        import cloudpickle
    except ModuleNotFoundError:
        log.error(
            "Unable to import `cloudpickle` module. "
            "Please install it with: pip install 'apache-airflow[cloudpickle]'"
        )
        raise
    return cloudpickle


_SERIALIZERS: dict[_SerializerTypeDef, Any] = {
    "pickle": lazy_object_proxy.Proxy(_load_pickle),
    "dill": lazy_object_proxy.Proxy(_load_dill),
    "cloudpickle": lazy_object_proxy.Proxy(_load_cloudpickle),
}


class _BasePythonVirtualenvOperator(PythonOperator, metaclass=ABCMeta):
    BASE_SERIALIZABLE_CONTEXT_KEYS = {
        "ds",
        "ds_nodash",
        "expanded_ti_count",
        "inlets",
        "outlets",
        "run_id",
        "task_instance_key_str",
        "test_mode",
        "ts",
        "ts_nodash",
        "ts_nodash_with_tz",
        # The following should be removed when Airflow 2 support is dropped.
        "next_ds",
        "next_ds_nodash",
        "prev_ds",
        "prev_ds_nodash",
        "tomorrow_ds",
        "tomorrow_ds_nodash",
        "yesterday_ds",
        "yesterday_ds_nodash",
    }
    if AIRFLOW_V_3_0_PLUS:
        BASE_SERIALIZABLE_CONTEXT_KEYS.add("task_reschedule_count")

    PENDULUM_SERIALIZABLE_CONTEXT_KEYS = {
        "data_interval_end",
        "data_interval_start",
        "logical_date",
        "prev_data_interval_end_success",
        "prev_data_interval_start_success",
        "prev_start_date_success",
        "prev_end_date_success",
        # The following should be removed when Airflow 2 support is dropped.
        "execution_date",
        "next_execution_date",
        "prev_execution_date",
        "prev_execution_date_success",
    }

    AIRFLOW_SERIALIZABLE_CONTEXT_KEYS = {
        "macros",
        "conf",
        "dag",
        "dag_run",
        "task",
        "params",
        "triggering_asset_events",
        # The following should be removed when Airflow 2 support is dropped.
        "triggering_dataset_events",
    }

    def __init__(
        self,
        *,
        python_callable: Callable,
        serializer: _SerializerTypeDef | None = None,
        op_args: Collection[Any] | None = None,
        op_kwargs: Mapping[str, Any] | None = None,
        string_args: Iterable[str] | None = None,
        templates_dict: dict | None = None,
        templates_exts: list[str] | None = None,
        expect_airflow: bool = True,
        skip_on_exit_code: int | Container[int] | None = None,
        env_vars: dict[str, str] | None = None,
        inherit_env: bool = True,
        **kwargs,
    ):
        if (
            not isinstance(python_callable, types.FunctionType)
            or isinstance(python_callable, types.LambdaType)
            and python_callable.__name__ == "<lambda>"
        ):
            raise ValueError(f"{type(self).__name__} only supports functions for python_callable arg")
        if inspect.isgeneratorfunction(python_callable):
            raise ValueError(f"{type(self).__name__} does not support using 'yield' in python_callable")
        super().__init__(
            python_callable=python_callable,
            op_args=op_args,
            op_kwargs=op_kwargs,
            templates_dict=templates_dict,
            templates_exts=templates_exts,
            **kwargs,
        )
        self.string_args = string_args or []

        serializer = serializer or "pickle"
        if serializer not in _SERIALIZERS:
            msg = (
                f"Unsupported serializer {serializer!r}. Expected one of {', '.join(map(repr, _SERIALIZERS))}"
            )
            raise AirflowException(msg)

        self.pickling_library = _SERIALIZERS[serializer]
        self.serializer: _SerializerTypeDef = serializer

        self.expect_airflow = expect_airflow
        self.skip_on_exit_code = (
            skip_on_exit_code
            if isinstance(skip_on_exit_code, Container)
            else [skip_on_exit_code]
            if skip_on_exit_code is not None
            else []
        )
        self.env_vars = env_vars
        self.inherit_env = inherit_env

    @abstractmethod
    def _iter_serializable_context_keys(self):
        pass

    def execute(self, context: Context) -> Any:
        serializable_keys = set(self._iter_serializable_context_keys())
        new = {k: v for k, v in context.items() if k in serializable_keys}
        serializable_context = cast("Context", new)
        return super().execute(context=serializable_context)

    def get_python_source(self):
        """Return the source of self.python_callable."""
        return textwrap.dedent(inspect.getsource(self.python_callable))

    def _write_args(self, file: Path):
        def resolve_proxies(obj):
            """Recursively replace lazy_object_proxy.Proxy instances with their resolved values."""
            if isinstance(obj, lazy_object_proxy.Proxy):
                return obj.__wrapped__  # force evaluation
            if isinstance(obj, dict):
                return {k: resolve_proxies(v) for k, v in obj.items()}
            if isinstance(obj, list):
                return [resolve_proxies(v) for v in obj]
            return obj

        if self.op_args or self.op_kwargs:
            self.log.info("Using %r as serializer.", self.serializer)
            file.write_bytes(
                self.pickling_library.dumps({"args": self.op_args, "kwargs": resolve_proxies(self.op_kwargs)})
            )

    def _write_string_args(self, file: Path):
        file.write_text("\n".join(map(str, self.string_args)))

    def _read_result(self, path: Path):
        if path.stat().st_size == 0:
            return None
        try:
            return self.pickling_library.loads(path.read_bytes())
        except ValueError as value_error:
            raise DeserializingResultError() from value_error

    def __deepcopy__(self, memo):
        # Module objects can't be copied at all.
        memo[id(self.pickling_library)] = self.pickling_library
        return super().__deepcopy__(memo)

    def _execute_python_callable_in_subprocess(self, python_path: Path):
        with TemporaryDirectory(prefix="venv-call") as tmp:
            tmp_dir = Path(tmp)
            op_kwargs: dict[str, Any] = dict(self.op_kwargs)
            if self.templates_dict:
                op_kwargs["templates_dict"] = self.templates_dict
            input_path = tmp_dir / "script.in"
            output_path = tmp_dir / "script.out"
            string_args_path = tmp_dir / "string_args.txt"
            script_path = tmp_dir / "script.py"
            termination_log_path = tmp_dir / "termination.log"
            airflow_context_path = tmp_dir / "airflow_context.json"

            self._write_args(input_path)
            self._write_string_args(string_args_path)

            jinja_context = {
                "op_args": self.op_args,
                "op_kwargs": op_kwargs,
                "expect_airflow": self.expect_airflow,
                "pickling_library": self.serializer,
                "python_callable": self.python_callable.__name__,
                "python_callable_source": self.get_python_source(),
            }

            if inspect.getfile(self.python_callable) == self.dag.fileloc:
                jinja_context["modified_dag_module_name"] = get_unique_dag_module_name(self.dag.fileloc)

            write_python_script(
                jinja_context=jinja_context,
                filename=os.fspath(script_path),
                render_template_as_native_obj=self.dag.render_template_as_native_obj,
            )

            env_vars = dict(os.environ) if self.inherit_env else {}
            if fd := os.getenv("__AIRFLOW_SUPERVISOR_FD"):
                env_vars["__AIRFLOW_SUPERVISOR_FD"] = fd
            if self.env_vars:
                env_vars.update(self.env_vars)

            try:
                cmd: list[str] = [
                    os.fspath(python_path),
                    os.fspath(script_path),
                    os.fspath(input_path),
                    os.fspath(output_path),
                    os.fspath(string_args_path),
                    os.fspath(termination_log_path),
                    os.fspath(airflow_context_path),
                ]
                _execute_in_subprocess(
                    cmd=cmd,
                    env=env_vars,
                )
            except subprocess.CalledProcessError as e:
                if e.returncode in self.skip_on_exit_code:
                    raise AirflowSkipException(f"Process exited with code {e.returncode}. Skipping.")
                if termination_log_path.exists() and termination_log_path.stat().st_size > 0:
                    error_msg = f"Process returned non-zero exit status {e.returncode}.\n"
                    with open(termination_log_path) as file:
                        error_msg += file.read()
                    raise AirflowException(error_msg) from None
                raise

            if 0 in self.skip_on_exit_code:
                raise AirflowSkipException("Process exited with code 0. Skipping.")

            return self._read_result(output_path)

    def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
        keyword_params = KeywordParameters.determine(self.python_callable, self.op_args, context)
        if AIRFLOW_V_3_0_PLUS:
            return keyword_params.unpacking()
        return keyword_params.serializing()  # type: ignore[attr-defined]


class PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
    """
    Run a function in a virtualenv that is created and destroyed automatically.

    The function (with certain caveats) must be defined using def, and not be
    part of a class. All imports must happen inside the function
    and no variables outside the scope may be referenced. A global scope
    variable named virtualenv_string_args will be available (populated by
    string_args). In addition, one can pass arguments through op_args and op_kwargs, and one
    can use a return value.
    Note that if your virtualenv runs in a different Python major version than Airflow,
    you cannot use return values, op_args, op_kwargs, or use any macros that are being provided to
    Airflow through plugins. You can use string_args though.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:PythonVirtualenvOperator`

    :param python_callable: A python function with no references to outside variables,
        defined with def, which will be run in a virtual environment.
    :param requirements: Either a list of requirement strings, or a (templated)
        "requirements file" as specified by pip.
    :param python_version: The Python version to run the virtual environment with. Note that
        both major (e.g. ``3``) and major.minor (e.g. ``3.10``) forms are acceptable.
    :param serializer: Which serializer to use to serialize the args and result. It can be one
        of the following:

        - ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library.
        - ``"cloudpickle"``: Use cloudpickle to serialize more complex types;
          this requires including cloudpickle in your requirements.
        - ``"dill"``: Use dill to serialize more complex types;
          this requires including dill in your requirements.
    :param system_site_packages: Whether to include
        system_site_packages in your virtual environment.
        See virtualenv documentation for more information.
    :param pip_install_options: a list of pip install options when installing requirements.
        See 'pip install -h' for available options.
    :param op_args: A list of positional arguments to pass to python_callable.
    :param op_kwargs: A dict of keyword arguments to pass to python_callable.
    :param string_args: Strings that are present in the global var virtualenv_string_args,
        available to python_callable at runtime as a list[str]. Note that args are split
        by newline.
    :param templates_dict: a dictionary where the values are templates that
        will get templated by the Airflow engine sometime between
        ``__init__`` and ``execute``, and are made available
        in your callable's context after the template has been applied
    :param templates_exts: a list of file extensions to resolve while
        processing templated fields, for example ``['.sql', '.hql']``
    :param expect_airflow: expect Airflow to be installed in the target environment. If true, the operator
        will raise a warning if Airflow is not installed, and it will attempt to load Airflow
        macros when starting.
    :param skip_on_exit_code: If python_callable exits with this exit code, leave the task
        in ``skipped`` state (default: None). If set to ``None``, any non-zero
        exit code will be treated as a failure.
    :param index_urls: an optional list of index URLs to load Python packages from.
        If not provided, the system pip conf will be used to source packages from.
    :param index_urls_from_connection_ids: An optional list of ``PackageIndex`` connection IDs.
        Will be appended to ``index_urls``.
    :param venv_cache_path: Optional path to the virtual environment parent folder in which the
        virtual environment will be cached; creates a sub-folder ``venv-{hash}`` where ``{hash}`` is
        replaced with a checksum of the requirements. If not provided, the virtual environment will be
        created and deleted in a temp folder for every execution.
    :param env_vars: A dictionary containing additional environment variables to set for the virtual
        environment when it is executed.
    :param inherit_env: Whether to inherit the current environment variables when executing the virtual
        environment. If set to ``True``, the virtual environment will inherit the environment variables
        of the parent process (``os.environ``). If set to ``False``, the virtual environment will be
        executed with a clean environment.
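
    For example, a minimal usage sketch (the task id, requirement pin, and callable are
    illustrative placeholders; note that the import happens inside the function):

    .. code-block:: python

        def render(name):
            # colorama is installed into the throwaway virtualenv just for this task.
            from colorama import Fore

            return f"{Fore.GREEN}Hello {name}"


        render_task = PythonVirtualenvOperator(
            task_id="render",
            python_callable=render,
            requirements=["colorama==0.4.6"],
            op_args=["world"],
        )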
671 """
672
673 template_fields: Sequence[str] = tuple(
674 {"requirements", "index_urls", "index_urls_from_connection_ids", "venv_cache_path"}.union(
675 PythonOperator.template_fields
676 )
677 )
678 template_ext: Sequence[str] = (".txt",)
679
    def __init__(
        self,
        *,
        python_callable: Callable,
        requirements: None | Iterable[str] | str = None,
        python_version: str | None = None,
        serializer: _SerializerTypeDef | None = None,
        system_site_packages: bool = True,
        pip_install_options: list[str] | None = None,
        op_args: Collection[Any] | None = None,
        op_kwargs: Mapping[str, Any] | None = None,
        string_args: Iterable[str] | None = None,
        templates_dict: dict | None = None,
        templates_exts: list[str] | None = None,
        expect_airflow: bool = True,
        skip_on_exit_code: int | Container[int] | None = None,
        index_urls: None | Collection[str] | str = None,
        index_urls_from_connection_ids: None | Collection[str] | str = None,
        venv_cache_path: None | os.PathLike[str] = None,
        env_vars: dict[str, str] | None = None,
        inherit_env: bool = True,
        **kwargs,
    ):
        if (
            python_version
            and str(python_version)[0] != str(sys.version_info.major)
            and (op_args or op_kwargs)
        ):
            raise AirflowException(
                "Passing op_args or op_kwargs is not supported across different Python "
                "major versions for PythonVirtualenvOperator. Please use string_args. "
                f"Sys version: {sys.version_info}. Virtual environment version: {python_version}"
            )
        if python_version is not None and not isinstance(python_version, str):
            raise AirflowException(
                "Passing non-string types (e.g. int or float) as python_version not supported"
            )
        if not requirements:
            self.requirements: list[str] = []
        elif isinstance(requirements, str):
            self.requirements = [requirements]
        else:
            self.requirements = list(requirements)
        self.python_version = python_version
        self.system_site_packages = system_site_packages
        self.pip_install_options = pip_install_options
        if isinstance(index_urls, str):
            self.index_urls: list[str] | None = [index_urls]
        elif isinstance(index_urls, Collection):
            self.index_urls = list(index_urls)
        else:
            self.index_urls = None
        if isinstance(index_urls_from_connection_ids, str):
            self.index_urls_from_connection_ids: list[str] | None = [index_urls_from_connection_ids]
        elif isinstance(index_urls_from_connection_ids, Collection):
            self.index_urls_from_connection_ids = list(index_urls_from_connection_ids)
        else:
            self.index_urls_from_connection_ids = None
        self.venv_cache_path = venv_cache_path
        super().__init__(
            python_callable=python_callable,
            serializer=serializer,
            op_args=op_args,
            op_kwargs=op_kwargs,
            string_args=string_args,
            templates_dict=templates_dict,
            templates_exts=templates_exts,
            expect_airflow=expect_airflow,
            skip_on_exit_code=skip_on_exit_code,
            env_vars=env_vars,
            inherit_env=inherit_env,
            **kwargs,
        )

    def _requirements_list(self, exclude_cloudpickle: bool = False) -> list[str]:
        """Prepare a list of requirements that need to be installed for the virtual environment."""
        requirements = [str(dependency) for dependency in self.requirements]
        if not self.system_site_packages:
            if (
                self.serializer == "cloudpickle"
                and not exclude_cloudpickle
                and "cloudpickle" not in requirements
            ):
                requirements.append("cloudpickle")
            elif self.serializer == "dill" and "dill" not in requirements:
                requirements.append("dill")
        requirements.sort()  # Ensure a hash is stable
        return requirements

    def _prepare_venv(self, venv_path: Path) -> None:
        """Prepare the requirements and install the virtual environment."""
        requirements_file = venv_path / "requirements.txt"
        requirements_file.write_text("\n".join(self._requirements_list()))
        prepare_virtualenv(
            venv_directory=str(venv_path),
            python_bin=f"python{self.python_version}" if self.python_version else "python",
            system_site_packages=self.system_site_packages,
            requirements_file_path=str(requirements_file),
            pip_install_options=self.pip_install_options,
            index_urls=self.index_urls,
        )

    def _calculate_cache_hash(self, exclude_cloudpickle: bool = False) -> tuple[str, str]:
        """
        Generate the hash of the cache folder to use.

        The following factors are used as input for the hash:

        - (sorted) list of requirements
        - pip install options
        - flag of system site packages
        - python version
        - Variable to override the hash with a cache key
        - Index URLs

        Returns a hash and the data dict which is the base for the hash as text.
        """
        hash_dict = {
            "requirements_list": self._requirements_list(exclude_cloudpickle=exclude_cloudpickle),
            "pip_install_options": self.pip_install_options,
            "index_urls": self.index_urls,
            "cache_key": str(Variable.get("PythonVirtualenvOperator.cache_key", "")),
            "python_version": self.python_version,
            "system_site_packages": self.system_site_packages,
        }
        hash_text = json.dumps(hash_dict, sort_keys=True)
        hash_object = hashlib_wrapper.md5(hash_text.encode())
        requirements_hash = hash_object.hexdigest()
        return requirements_hash[:8], hash_text

    def _ensure_venv_cache_exists(self, venv_cache_path: Path) -> Path:
        """Ensure a valid virtual environment is set up, creating it in place if needed."""
        cache_hash, hash_data = self._calculate_cache_hash()
        venv_path = venv_cache_path / f"venv-{cache_hash}"
        self.log.info("Python virtual environment will be cached in %s", venv_path)
        venv_path.parent.mkdir(parents=True, exist_ok=True)
        with open(f"{venv_path}.lock", "w") as f:
            # Ensure that the cache is not built by parallel workers.
            import fcntl

            fcntl.flock(f, fcntl.LOCK_EX)

            hash_marker = venv_path / "install_complete_marker.json"
            try:
                if venv_path.exists():
                    if hash_marker.exists():
                        previous_hash_data = hash_marker.read_text(encoding="utf8")
                        if previous_hash_data == hash_data:
                            self.log.info("Reusing cached Python virtual environment in %s", venv_path)
                            return venv_path

                        _, hash_data_before_upgrade = self._calculate_cache_hash(exclude_cloudpickle=True)
                        if previous_hash_data == hash_data_before_upgrade:
                            self.log.warning(
                                "Found a previous virtual environment with outdated dependencies in %s, "
                                "deleting and re-creating.",
                                venv_path,
                            )
                        else:
                            self.log.error(
                                "Unicorn alert: Found a previous virtual environment in %s "
                                "with the same hash but different parameters. Previous setup: '%s' / "
                                "Requested venv setup: '%s'. Please report a bug to airflow!",
                                venv_path,
                                previous_hash_data,
                                hash_data,
                            )
                    else:
                        self.log.warning(
                            "Found a previous (probably partially installed) virtual environment in %s, "
                            "deleting and re-creating.",
                            venv_path,
                        )

                    shutil.rmtree(venv_path)

                venv_path.mkdir(parents=True)
                self._prepare_venv(venv_path)
                hash_marker.write_text(hash_data, encoding="utf8")
            except Exception as e:
                shutil.rmtree(venv_path)
                raise AirflowException(f"Unable to create new virtual environment in {venv_path}") from e
            self.log.info("New Python virtual environment created in %s", venv_path)
            return venv_path

    def _cleanup_python_pycache_dir(self, cache_dir_path: Path) -> None:
        try:
            shutil.rmtree(cache_dir_path)
            self.log.debug("The directory %s has been deleted.", cache_dir_path)
        except FileNotFoundError:
            self.log.warning("Failed to delete %s. The directory does not exist.", cache_dir_path)
        except PermissionError:
            self.log.warning("Permission denied to delete the directory %s.", cache_dir_path)

    def _retrieve_index_urls_from_connection_ids(self):
        """Retrieve index URLs from Package Index connections."""
        if self.index_urls is None:
            self.index_urls = []
        for conn_id in self.index_urls_from_connection_ids:
            conn_url = PackageIndexHook(conn_id).get_connection_url()
            self.index_urls.append(conn_url)

    def execute_callable(self):
        if self.index_urls_from_connection_ids:
            self._retrieve_index_urls_from_connection_ids()

        if self.venv_cache_path:
            venv_path = self._ensure_venv_cache_exists(Path(self.venv_cache_path))
            python_path = venv_path / "bin" / "python"
            return self._execute_python_callable_in_subprocess(python_path)

        with TemporaryDirectory(prefix="venv") as tmp_dir:
            tmp_path = Path(tmp_dir)
            custom_pycache_prefix = Path(sys.pycache_prefix or "")
            r_path = tmp_path.relative_to(tmp_path.anchor)
            venv_python_cache_dir = Path.cwd() / custom_pycache_prefix / r_path
            self._prepare_venv(tmp_path)
            python_path = tmp_path / "bin" / "python"
            result = self._execute_python_callable_in_subprocess(python_path)
            self._cleanup_python_pycache_dir(venv_python_cache_dir)
            return result

    def _iter_serializable_context_keys(self):
        yield from self.BASE_SERIALIZABLE_CONTEXT_KEYS

        found_airflow = found_pendulum = False

        if self.system_site_packages:
            # If we're using system packages, assume both are present
            found_airflow = found_pendulum = True
        else:
            for raw_str in chain.from_iterable(req.splitlines() for req in self.requirements):
                line = raw_str.strip()
                # Skip blank lines and full-line comments
                if not line or line.startswith("#"):
                    continue

                # Strip off any inline comment,
                # e.g. turn "foo==1.2.3  # comment" into "foo==1.2.3"
                req_str = re.sub(r"#.*$", "", line).strip()

                try:
                    req = Requirement(req_str)
                except (InvalidRequirement, InvalidSpecifier, InvalidVersion) as e:
                    raise ValueError(f"Invalid requirement '{raw_str}': {e}") from e

                if req.name == "apache-airflow":
                    found_airflow = found_pendulum = True
                    break
                elif req.name == "pendulum":
                    found_pendulum = True

        if found_airflow:
            yield from self.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS
            yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS
        elif found_pendulum:
            yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS


class BranchPythonVirtualenvOperator(BaseBranchOperator, PythonVirtualenvOperator):
    """
    A workflow can "branch" or follow a path after the execution of this task in a virtual environment.

    It derives from the PythonVirtualenvOperator and expects a Python function that returns
    a single task_id, a single task_group_id, or a list of task_ids and/or
    task_group_ids to follow. The task_id(s) and/or task_group_id(s) returned
    should point to a task or task group directly downstream from {self}. All
    other "branches" or directly downstream tasks are marked with a state of
    ``skipped`` so that these paths can't move forward. The ``skipped`` states
    are propagated downstream to allow for the DAG state to fill up and
    the DAG run's state to be inferred.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:BranchPythonVirtualenvOperator`
    """

    def choose_branch(self, context: Context) -> str | Iterable[str]:
        return PythonVirtualenvOperator.execute(self, context)


class ExternalPythonOperator(_BasePythonVirtualenvOperator):
    """
    Run a function in an existing virtualenv that is not re-created.

    The environment is reused as-is, without the overhead of creating a virtual environment
    (with certain caveats).

    The function must be defined using def, and not be
    part of a class. All imports must happen inside the function
    and no variables outside the scope may be referenced. A global scope
    variable named virtualenv_string_args will be available (populated by
    string_args). In addition, one can pass arguments through op_args and op_kwargs, and one
    can use a return value.
    Note that if your virtual environment runs in a different Python major version than Airflow,
    you cannot use return values, op_args, op_kwargs, or use any macros that are being provided to
    Airflow through plugins. You can use string_args though.

    If Airflow is installed in the external environment in a different version than the version
    used by the operator, the operator will fail.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:ExternalPythonOperator`

    :param python: Full path string (file-system specific) that points to a Python binary inside
        a virtual environment that should be used (in the ``VENV/bin`` folder). Should be an absolute
        path (so it usually starts with "/" or "X:/" depending on the filesystem/os used).
    :param python_callable: A python function with no references to outside variables,
        defined with def, which will be run in a virtual environment.
    :param serializer: Which serializer to use to serialize the args and result. It can be one
        of the following:

        - ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library.
        - ``"cloudpickle"``: Use cloudpickle to serialize more complex types;
          this requires including cloudpickle in your requirements.
        - ``"dill"``: Use dill to serialize more complex types;
          this requires including dill in your requirements.
    :param op_args: A list of positional arguments to pass to python_callable.
    :param op_kwargs: A dict of keyword arguments to pass to python_callable.
    :param string_args: Strings that are present in the global var virtualenv_string_args,
        available to python_callable at runtime as a list[str]. Note that args are split
        by newline.
    :param templates_dict: a dictionary where the values are templates that
        will get templated by the Airflow engine sometime between
        ``__init__`` and ``execute``, and are made available
        in your callable's context after the template has been applied
    :param templates_exts: a list of file extensions to resolve while
        processing templated fields, for example ``['.sql', '.hql']``
    :param expect_airflow: expect Airflow to be installed in the target environment. If true, the operator
        will raise a warning if Airflow is not installed, and it will attempt to load Airflow
        macros when starting.
    :param skip_on_exit_code: If python_callable exits with this exit code, leave the task
        in ``skipped`` state (default: None). If set to ``None``, any non-zero
        exit code will be treated as a failure.
    :param env_vars: A dictionary containing additional environment variables to set for the virtual
        environment when it is executed.
    :param inherit_env: Whether to inherit the current environment variables when executing the virtual
        environment. If set to ``True``, the virtual environment will inherit the environment variables
        of the parent process (``os.environ``). If set to ``False``, the virtual environment will be
        executed with a clean environment.
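
    For example, a minimal usage sketch (the task id, interpreter path, and callable are
    illustrative placeholders; the path must point at an existing virtualenv's Python binary):

    .. code-block:: python

        def compute():
            # Imports must happen inside the function.
            import math

            return math.pi


        compute_task = ExternalPythonOperator(
            task_id="compute",
            python="/opt/venvs/analytics/bin/python",
            python_callable=compute,
        )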
1018 """
1019
1020 template_fields: Sequence[str] = tuple({"python"}.union(PythonOperator.template_fields))
1021
1022 def __init__(
1023 self,
1024 *,
1025 python: str,
1026 python_callable: Callable,
1027 serializer: _SerializerTypeDef | None = None,
1028 op_args: Collection[Any] | None = None,
1029 op_kwargs: Mapping[str, Any] | None = None,
1030 string_args: Iterable[str] | None = None,
1031 templates_dict: dict | None = None,
1032 templates_exts: list[str] | None = None,
1033 expect_airflow: bool = True,
1034 expect_pendulum: bool = False,
1035 skip_on_exit_code: int | Container[int] | None = None,
1036 env_vars: dict[str, str] | None = None,
1037 inherit_env: bool = True,
1038 **kwargs,
1039 ):
1040 if not python:
1041 raise ValueError("Python Path must be defined in ExternalPythonOperator")
1042 self.python = python
1043 self.expect_pendulum = expect_pendulum
1044 super().__init__(
1045 python_callable=python_callable,
1046 serializer=serializer,
1047 op_args=op_args,
1048 op_kwargs=op_kwargs,
1049 string_args=string_args,
1050 templates_dict=templates_dict,
1051 templates_exts=templates_exts,
1052 expect_airflow=expect_airflow,
1053 skip_on_exit_code=skip_on_exit_code,
1054 env_vars=env_vars,
1055 inherit_env=inherit_env,
1056 **kwargs,
1057 )

    def execute_callable(self):
        python_path = Path(self.python)
        if not python_path.exists():
            raise ValueError(f"Python Path '{python_path}' must exist")
        if not python_path.is_file():
            raise ValueError(f"Python Path '{python_path}' must be a file")
        if not python_path.is_absolute():
            raise ValueError(f"Python Path '{python_path}' must be an absolute path.")
        python_version = _PythonVersionInfo.from_executable(self.python)
        if python_version.major != sys.version_info.major and (self.op_args or self.op_kwargs):
            raise AirflowException(
                "Passing op_args or op_kwargs is not supported across different Python "
                "major versions for ExternalPythonOperator. Please use string_args. "
                f"Sys version: {sys.version_info}. "
                f"Virtual environment version: {python_version}"
            )
        return self._execute_python_callable_in_subprocess(python_path)

    def _iter_serializable_context_keys(self):
        yield from self.BASE_SERIALIZABLE_CONTEXT_KEYS
        if self.expect_airflow and self._get_airflow_version_from_target_env():
            yield from self.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS
            yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS
        elif self._is_pendulum_installed_in_target_env():
            yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS

    def _is_pendulum_installed_in_target_env(self) -> bool:
        try:
            subprocess.check_call([self.python, "-c", "import pendulum"])
            return True
        except Exception as e:
            if self.expect_pendulum:
                self.log.warning("When checking for Pendulum installed in virtual environment got %s", e)
                self.log.warning(
                    "Pendulum is not properly installed in the virtual environment, so "
                    "Pendulum context keys will not be available. "
                    "Please install Pendulum or Airflow in your virtual environment to access them."
                )
            return False

    @property
    def _external_airflow_version_script(self):
        """
        Return a Python script which determines the Apache Airflow version in the target environment.

        Importing airflow as a module can take a while, so obtaining the version that way
        could take up to a second. By contrast, ``importlib.metadata.version`` retrieves
        the package version quickly, typically below 100ms including the overhead of the
        new subprocess.

        Possible side effect: ``importlib.metadata`` may be unavailable (Python < 3.8), as may the
        ``importlib_metadata`` backport, which might indicate that the venv doesn't contain
        ``apache-airflow`` or that something is wrong with the environment.
        """
        return textwrap.dedent(
            """
            try:
                from importlib.metadata import version
            except ImportError:
                from importlib_metadata import version
            print(version("apache-airflow"))
            """
        )

    def _get_airflow_version_from_target_env(self) -> str | None:
        from airflow import __version__ as airflow_version

        try:
            result = subprocess.check_output(
                [self.python, "-c", self._external_airflow_version_script],
                text=True,
            )
            target_airflow_version = result.strip()
            if target_airflow_version != airflow_version:
                raise AirflowConfigException(
                    f"The version of Airflow installed for the {self.python} "
                    f"({target_airflow_version}) is different than the runtime Airflow version: "
                    f"{airflow_version}. Make sure your environment has the same Airflow version "
                    f"installed as the Airflow runtime."
                )
            return target_airflow_version
        except Exception as e:
            if self.expect_airflow:
                self.log.warning("When checking for Airflow installed in virtual environment got %s", e)
                self.log.warning(
                    "This means that Airflow is not properly installed by %s. "
                    "Airflow context keys will not be available. "
                    "Please install Airflow %s in your environment to access them.",
                    self.python,
                    airflow_version,
                )
            return None


class BranchExternalPythonOperator(BaseBranchOperator, ExternalPythonOperator):
    """
    A workflow can "branch" or follow a path after the execution of this task.

    Extends ExternalPythonOperator, so it expects a ``python`` argument: the path to a Python
    binary inside a virtual environment that should be used (in the ``VENV/bin`` folder). It
    should be an absolute path, so the branching callable can run in a separate virtual
    environment, similarly to ExternalPythonOperator.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:BranchExternalPythonOperator`
    """

    def choose_branch(self, context: Context) -> str | Iterable[str]:
        return ExternalPythonOperator.execute(self, context)


def get_current_context() -> Mapping[str, Any]:
    """
    Retrieve the execution context dictionary without altering the user method's signature.

    This is the simplest method of retrieving the execution context dictionary.

    **Old style:**

    .. code:: python

        def my_task(**context):
            ti = context["ti"]

    **New style:**

    .. code:: python

        from airflow.providers.standard.operators.python import get_current_context


        def my_task():
            context = get_current_context()
            ti = context["ti"]

    The current context will only have a value if this method is called after an operator
    has started to execute.
    """
    if AIRFLOW_V_3_0_PLUS:
        warnings.warn(
            "Using get_current_context from the standard provider is deprecated and will be removed. "
            "Please import `from airflow.sdk import get_current_context` and use it instead.",
            AirflowProviderDeprecationWarning,
            stacklevel=2,
        )

        from airflow.sdk import get_current_context

        return get_current_context()
    return _get_current_context()


def _get_current_context() -> Mapping[str, Any]:
    # Airflow 2.x
    # TODO: To be removed when Airflow 2 support is dropped
    from airflow.models.taskinstance import _CURRENT_CONTEXT  # type: ignore[attr-defined]

    if not _CURRENT_CONTEXT:
        raise RuntimeError(
            "Current context was requested but no context was found! Are you running within an Airflow task?"
        )
    return _CURRENT_CONTEXT[-1]