1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import functools
21import inspect
22import logging
23import pkgutil
24import sys
25from collections import defaultdict
26from collections.abc import Callable, Iterator
27from importlib import import_module
28from typing import TYPE_CHECKING
29
30from .dag_file import (
31 MODIFIED_DAG_MODULE_NAME as MODIFIED_DAG_MODULE_NAME,
32 UNUSUAL_MODULE_PREFIX as UNUSUAL_MODULE_PREFIX,
33)
34from .file_discovery import (
35 find_path_from_directory as find_path_from_directory,
36)
37
38if sys.version_info >= (3, 12):
39 from importlib import metadata
40else:
41 import importlib_metadata as metadata
42
43log = logging.getLogger(__name__)
44
45EPnD = tuple[metadata.EntryPoint, metadata.Distribution]
46
47if TYPE_CHECKING:
48 from types import ModuleType
49
50
def accepts_context(callback: Callable) -> bool:
    """
    Tell whether ``callback`` can be called with a ``context`` keyword argument.

    The answer is True when the signature declares a parameter named ``context``
    or a ``**kwargs`` catch-all. When the signature cannot be inspected
    (``inspect.signature`` raises ``ValueError`` or ``TypeError``), True is returned.

    :param callback: Callable to inspect
    """
    try:
        signature = inspect.signature(callback)
    except (ValueError, TypeError):
        # No signature available to inspect; report True.
        return True

    parameters = signature.parameters
    if "context" in parameters:
        return True

    for parameter in parameters.values():
        if parameter.kind == inspect.Parameter.VAR_KEYWORD:
            return True
    return False
59
60
def accepts_keyword_args(func: Callable) -> bool:
    """
    Tell whether ``func`` can receive at least one argument by keyword.

    True is returned when any parameter is positional-or-keyword, keyword-only,
    or a ``**kwargs`` catch-all. When the signature cannot be inspected
    (``inspect.signature`` raises ``ValueError`` or ``TypeError``), True is returned.

    :param func: Callable to inspect
    """
    try:
        signature = inspect.signature(func)
    except (ValueError, TypeError):
        # No signature available to inspect; report True.
        return True

    keyword_capable_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
        inspect.Parameter.VAR_KEYWORD,
    )
    for parameter in signature.parameters.values():
        if parameter.kind in keyword_capable_kinds:
            return True
    return False
76
77
def import_string(dotted_path: str):
    """
    Import a dotted module path and return the attribute/class designated by the last name in the path.

    :param dotted_path: Path of the form ``package.module.attribute``; everything before
        the last dot is imported as a module and the last component is looked up on it
    :return: The attribute found on the imported module
    :raises ImportError: If ``dotted_path`` contains no dot, if the module cannot be
        imported, or if the module does not define the requested attribute. The
        underlying ``ValueError``/``AttributeError`` is attached as ``__cause__``.
    """
    # TODO: Add support for nested classes. Currently, it only works for top-level classes.
    try:
        module_path, class_name = dotted_path.rsplit(".", 1)
    except ValueError as err:
        # rsplit produced a single element: there is no dot to split on.
        raise ImportError(f"{dotted_path} doesn't look like a module path") from err

    module = import_module(module_path)

    try:
        return getattr(module, class_name)
    except AttributeError as err:
        raise ImportError(
            f'Module "{module_path}" does not define a "{class_name}" attribute/class'
        ) from err
96
97
98def qualname(o: object | Callable, use_qualname: bool = False, exclude_module: bool = False) -> str:
99 """
100 Convert an attribute/class/callable to a string.
101
102 By default, returns a string importable by ``import_string`` (includes module path).
103 With exclude_module=True, returns only the qualified name without module prefix,
104 useful for stable identification across deployments where module paths may vary.
105 """
106 if callable(o) and hasattr(o, "__module__"):
107 if exclude_module:
108 if hasattr(o, "__qualname__"):
109 return o.__qualname__
110 if hasattr(o, "__name__"):
111 return o.__name__
112 # Handle functools.partial objects specifically (not just any object with 'func' attr)
113 if isinstance(o, functools.partial):
114 return qualname(o.func, exclude_module=True)
115 return type(o).__qualname__
116 if use_qualname and hasattr(o, "__qualname__"):
117 return f"{o.__module__}.{o.__qualname__}"
118 if hasattr(o, "__name__"):
119 return f"{o.__module__}.{o.__name__}"
120
121 cls = o
122
123 if not isinstance(cls, type): # instance or class
124 cls = type(cls)
125
126 name = cls.__qualname__
127 module = cls.__module__
128
129 if exclude_module:
130 return name
131
132 if module and module != "__builtin__":
133 return f"{module}.{name}"
134
135 return name
136
137
def iter_namespace(ns: ModuleType):
    """
    Iterate over the modules found on the search path of the package ``ns``.

    Delegates to ``pkgutil.iter_modules`` with ``ns.__path__`` as the search path
    and ``ns.__name__`` plus a trailing dot as the name prefix.

    :param ns: Package whose ``__path__`` is scanned
    """
    search_path = ns.__path__
    name_prefix = ns.__name__ + "."
    return pkgutil.iter_modules(search_path, name_prefix)
140
141
def is_valid_dotpath(path: str) -> bool:
    """
    Check if a string follows valid dotpath format (ie: 'package.subpackage.module').

    Only ASCII identifiers are recognised. The whole string must match, so a
    trailing newline or any other extra character makes the path invalid.

    :param path: String to check
    :return: True if ``path`` is a string made of dot-separated identifiers, else False
    """
    import re

    if not isinstance(path, str):
        return False

    # Pattern explanation:
    # [a-zA-Z_]                    - Must start with letter or underscore
    # [a-zA-Z0-9_]*                - Following chars can be letters, numbers, or underscores
    # (\.[a-zA-Z_][a-zA-Z0-9_]*)*  - Can be followed by dots and valid identifiers
    pattern = r"[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)*"

    # fullmatch anchors both ends strictly; ``re.match`` with ``$`` would also
    # accept a string that ends in a newline.
    return re.fullmatch(pattern, path) is not None
162
163
@functools.cache
def _get_grouped_entry_points() -> dict[str, list[EPnD]]:
    """
    Collect the entry points of every discovered distribution, keyed by group name.

    The result is memoised by ``functools.cache``, so the distributions are only
    scanned on the first call. If reading a distribution's entry points raises,
    a warning is logged and the scan moves on to the next distribution; pairs
    already collected before the error are kept.

    :return: ``defaultdict`` mapping a group name to its (EntryPoint, Distribution) pairs
    """
    grouped: dict[str, list[EPnD]] = defaultdict(list)
    for distribution in metadata.distributions():
        try:
            for entry_point in distribution.entry_points:
                grouped[entry_point.group].append((entry_point, distribution))
        except Exception as err:
            log.warning("Error when retrieving package metadata (skipping it): %s, %s", distribution, err)
    return grouped
174
175
def entry_points_with_dist(group: str) -> Iterator[EPnD]:
    """
    Retrieve entry points of the given group.

    This works like ``entry_points()`` from ``importlib.metadata``, but each entry
    point comes paired with the distribution it was read from.

    The same package may appear under several distributions when it is found on
    different ``sys.path`` entries. No deduplication is done here; callers should
    deduplicate if they need to.

    :param group: Filter results to only this entrypoint group
    :return: Iterator of (EntryPoint, Distribution) pairs for the specified group
    """
    grouped = _get_grouped_entry_points()
    matching_pairs = grouped[group]
    return iter(matching_pairs)