1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import functools
21import inspect
22import logging
23import pkgutil
24import sys
25from collections import defaultdict
26from collections.abc import Callable, Iterator
27from importlib import import_module
28from typing import TYPE_CHECKING
29
30from .file_discovery import (
31 find_path_from_directory as find_path_from_directory,
32)
33
34if sys.version_info >= (3, 12):
35 from importlib import metadata
36else:
37 import importlib_metadata as metadata
38
# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Type alias for an (EntryPoint, Distribution) pair. ``metadata`` is ``importlib.metadata`` on
# Python 3.12+ and the ``importlib_metadata`` package otherwise (see the conditional import above).
EPnD = tuple[metadata.EntryPoint, metadata.Distribution]
42
43if TYPE_CHECKING:
44 from types import ModuleType
45
46
def accepts_context(callback: Callable) -> bool:
    """
    Tell whether ``callback`` can be handed a ``context`` keyword argument.

    This holds when the callable declares a parameter literally named ``context`` or
    collects arbitrary keywords through ``**kwargs``. Callables whose signature cannot be
    introspected are optimistically treated as accepting it.

    :param callback: the callable to inspect
    :return: True if ``context`` can be passed, False otherwise
    """
    try:
        parameters = inspect.signature(callback).parameters
    except (ValueError, TypeError):
        # No introspectable signature (some builtins / C callables): assume it is fine.
        return True
    if "context" in parameters:
        return True
    return any(param.kind == inspect.Parameter.VAR_KEYWORD for param in parameters.values())
55
56
def accepts_keyword_args(func: Callable) -> bool:
    """
    Tell whether ``func`` can be called with at least one keyword argument.

    A callable qualifies when any parameter may be bound by name (positional-or-keyword or
    keyword-only) or when it takes ``**kwargs``. Positional-only parameters and ``*args``
    do not count. If the signature cannot be inspected, True is returned.

    :param func: the callable to inspect
    :return: True if some keyword argument can be passed, False otherwise
    """
    try:
        parameters = inspect.signature(func).parameters.values()
    except (ValueError, TypeError):
        return True
    keyword_kinds = {
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
        inspect.Parameter.VAR_KEYWORD,
    }
    for parameter in parameters:
        if parameter.kind in keyword_kinds:
            return True
    return False
72
73
def import_string(dotted_path: str):
    """
    Import a dotted module path and return the attribute/class designated by the last name in the path.

    Everything before the last dot is imported as a module; the last segment is looked up on it.

    :param dotted_path: path such as ``"package.module.attribute"``
    :return: the object bound to the last segment in the imported module
    :raises ImportError: if the path has no dot or an empty module/attribute part, or if the
        module does not define the attribute. Errors raised while importing the module itself
        propagate unchanged from ``import_module``.
    """
    # TODO: Add support for nested classes. Currently, it only works for top-level classes.
    module_path, sep, class_name = dotted_path.rpartition(".")
    # Reject "nodot", ".attr" and "module." up front. Otherwise they surface as a ValueError
    # from import_module("") or a confusing getattr(module, "") failure.
    if not sep or not module_path or not class_name:
        raise ImportError(f"{dotted_path} doesn't look like a module path")

    module = import_module(module_path)

    try:
        return getattr(module, class_name)
    except AttributeError as err:
        # Chain the original error so tracebacks show the real cause.
        raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class') from err
92
93
94def qualname(o: object | Callable, use_qualname: bool = False, exclude_module: bool = False) -> str:
95 """
96 Convert an attribute/class/callable to a string.
97
98 By default, returns a string importable by ``import_string`` (includes module path).
99 With exclude_module=True, returns only the qualified name without module prefix,
100 useful for stable identification across deployments where module paths may vary.
101 """
102 if callable(o) and hasattr(o, "__module__"):
103 if exclude_module:
104 if hasattr(o, "__qualname__"):
105 return o.__qualname__
106 if hasattr(o, "__name__"):
107 return o.__name__
108 # Handle functools.partial objects specifically (not just any object with 'func' attr)
109 if isinstance(o, functools.partial):
110 return qualname(o.func, exclude_module=True)
111 return type(o).__qualname__
112 if use_qualname and hasattr(o, "__qualname__"):
113 return f"{o.__module__}.{o.__qualname__}"
114 if hasattr(o, "__name__"):
115 return f"{o.__module__}.{o.__name__}"
116
117 cls = o
118
119 if not isinstance(cls, type): # instance or class
120 cls = type(cls)
121
122 name = cls.__qualname__
123 module = cls.__module__
124
125 if exclude_module:
126 return name
127
128 if module and module != "__builtin__":
129 return f"{module}.{name}"
130
131 return name
132
133
def iter_namespace(ns: ModuleType):
    """
    Iterate over the modules found directly on package ``ns``'s search path.

    Delegates to ``pkgutil.iter_modules``. Each yielded entry's name is prefixed with
    ``"<package name>."``, so it is a fully qualified module name.

    :param ns: an imported package (it must expose ``__path__``)
    """
    search_path = ns.__path__
    prefix = f"{ns.__name__}."
    return pkgutil.iter_modules(search_path, prefix)
136
137
def is_valid_dotpath(path: str) -> bool:
    """
    Check if a string follows valid dotpath format (ie: 'package.subpackage.module').

    Each dot-separated segment must be an ASCII identifier: a letter or underscore followed by
    letters, digits or underscores. The following are rejected:

    - empty strings and empty segments
    - leading or trailing dots
    - any surrounding whitespace, including a trailing newline

    :param path: String to check
    :return: True if ``path`` is a valid dotpath, False otherwise (including for non-str input)
    """
    import re

    if not isinstance(path, str):
        return False

    # Pattern explanation:
    # [a-zA-Z_]                     - Each segment must start with a letter or underscore
    # [a-zA-Z0-9_]*                 - Following chars can be letters, numbers, or underscores
    # (?:\.[a-zA-Z_][a-zA-Z0-9_]*)* - Can be followed by dots and further valid segments
    # re.fullmatch anchors at both ends. Unlike "$" used with re.match, it does not also
    # accept a string that ends in "\n".
    pattern = r"[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*"

    return re.fullmatch(pattern, path) is not None
158
159
@functools.cache
def _get_grouped_entry_points() -> dict[str, list[EPnD]]:
    """
    Index the entry points of every installed distribution by group name.

    The result is memoized with ``functools.cache``, so later calls return the same mapping object.

    :return: ``defaultdict(list)`` mapping group name to ``(EntryPoint, Distribution)`` pairs
    """
    grouped: dict[str, list[EPnD]] = defaultdict(list)
    for distribution in metadata.distributions():
        try:
            for entry_point in distribution.entry_points:
                grouped[entry_point.group].append((entry_point, distribution))
        except Exception as err:
            # Best effort: one package with unreadable metadata must not hide all the others.
            log.warning("Error when retrieving package metadata (skipping it): %s, %s", distribution, err)
    return grouped
170
171
def entry_points_with_dist(group: str) -> Iterator[EPnD]:
    """
    Retrieve entry points of the given group.

    This is like the ``entry_points()`` function from ``importlib.metadata``,
    except it also returns the distribution the entry point was loaded from.

    Note that this may return multiple distributions to the same package if they
    are loaded from different ``sys.path`` entries. The caller site should
    implement appropriate deduplication logic if needed.

    :param group: Filter results to only this entrypoint group
    :return: Iterator of (EntryPoint, Distribution) tuples for the specified group
        (empty if no installed distribution declares that group)
    """
    # Use .get() rather than indexing. The mapping returned by _get_grouped_entry_points is a
    # cached defaultdict; indexing it with an unknown group would insert an empty list into
    # that shared cache as a side effect.
    return iter(_get_grouped_entry_points().get(group, ()))