1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import inspect
21from collections.abc import Callable, Collection, Mapping
22from typing import Any, TypeVar
23
# Return type of the callable passed to make_kwargs_callable; the wrapper it
# returns is annotated with the same type.
R = TypeVar("R")
25
26
class KeywordParameters:
    """
    Wrapper representing the ``**kwargs`` to pass to a callable.

    Instances are normally built with ``determine()``, which inspects the
    callable's signature and keeps only the keyword arguments it can accept:
    every entry when the callable declares ``**kwargs``, otherwise only the
    entries whose names match parameters that can be passed by keyword.
    The resulting mapping is obtained with ``unpacking()``.
    """

    def __init__(self, kwargs: Mapping[str, Any]) -> None:
        # Stored as-is (not copied); unpacking() hands back this same object.
        self._kwargs = kwargs

    @classmethod
    def determine(
        cls,
        func: Callable[..., Any],
        args: Collection[Any],
        kwargs: Mapping[str, Any],
    ) -> KeywordParameters:
        """
        Inspect ``func``'s signature and select the keyword arguments to pass to it.

        :param func: The callable that will be invoked.
        :param args: The positional arguments that will be passed to ``func``; only
            their count is used, to know which parameters they fill.
        :param kwargs: The candidate keyword arguments.
        :return: A ``KeywordParameters`` wrapping ``kwargs`` unchanged if ``func``
            declares ``**kwargs``; otherwise wrapping a new dict holding only the
            entries that match parameters of ``func`` which can be passed by keyword.
        :raises ValueError: If a key in ``kwargs`` names a positional-or-keyword
            parameter that is already filled by ``args``.
        """
        import itertools

        signature = inspect.signature(func)
        has_wildcard_kwargs = any(p.kind == p.VAR_KEYWORD for p in signature.parameters.values())

        for name, param in itertools.islice(signature.parameters.items(), len(args)):
            # Only positional-or-keyword parameters can clash with a kwarg of the same name:
            # keyword-only parameters are never filled positionally, and a key matching a
            # positional-only or *args/**kwargs parameter cannot bind to it by keyword
            # (it can only ever land in **kwargs, if the callable has one).
            if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
                continue

            # Check if args conflict with names in kwargs.
            if name in kwargs:
                raise ValueError(f"The key {name!r} in args is a part of kwargs and therefore reserved.")

        if has_wildcard_kwargs:
            # If the callable has a **kwargs argument, it's ready to accept all the kwargs.
            return cls(kwargs)

        # If the callable has no **kwargs argument, it only wants the arguments it requested,
        # and only those it can receive by keyword: passing a positional-only or *args
        # parameter name as a keyword would raise TypeError when the callable is invoked.
        keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        filtered_kwargs = {
            name: kwargs[name]
            for name, param in signature.parameters.items()
            if param.kind in keyword_kinds and name in kwargs
        }
        return cls(filtered_kwargs)

    def unpacking(self) -> Mapping[str, Any]:
        """Dump the kwargs mapping to unpack with ``**`` in a function call."""
        return self._kwargs
80
81
def determine_kwargs(
    func: Callable[..., Any],
    args: Collection[Any],
    kwargs: Mapping[str, Any],
) -> Mapping[str, Any]:
    """
    Select, from ``kwargs``, the keyword arguments that ``func`` can be called with.

    This is a convenience shortcut for ``KeywordParameters.determine(...).unpacking()``.

    :param func: The callable that is going to be invoked.
    :param args: The positional arguments that will be passed to ``func``; they tell
        which leading parameters are already taken.
    :param kwargs: The keyword arguments to filter before calling ``func``.
    :return: A mapping of the keyword arguments that are compatible with ``func``.
    :raises ValueError: Propagated from ``KeywordParameters.determine`` when a key in
        ``kwargs`` clashes with a parameter filled by ``args``.
    """
    selected = KeywordParameters.determine(func, args, kwargs)
    return selected.unpacking()
96
97
def make_kwargs_callable(func: Callable[..., R]) -> Callable[..., R]:
    """
    Wrap ``func`` so it tolerates arbitrary keyword arguments.

    The returned callable accepts any positional and keyword arguments. It forwards
    every positional argument unchanged, but only the keyword arguments that
    ``determine_kwargs`` deems acceptable for ``func``. The wrapper carries ``func``'s
    metadata (name, docstring, ``__wrapped__``) via ``functools.wraps``.
    """
    import functools

    def forwarding_wrapper(*args, **kwargs):
        accepted = determine_kwargs(func, args, kwargs)
        return func(*args, **accepted)

    return functools.wraps(func)(forwarding_wrapper)