1# util/_preloaded.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""Legacy routines to resolve circular module imports at runtime.
9
10These routines are replaced in 1.4.
11
12"""
13
14from functools import update_wrapper
15
16from . import compat
17
18
19class _memoized_property(object):
20 """vendored version of langhelpers.memoized_property.
21
22 not needed in the 1.4 version of preloaded.
23
24 """
25
26 def __init__(self, fget, doc=None):
27 self.fget = fget
28 self.__doc__ = doc or fget.__doc__
29 self.__name__ = fget.__name__
30
31 def __get__(self, obj, cls):
32 if obj is None:
33 return self
34 obj.__dict__[self.__name__] = result = self.fget(obj)
35 return result
36
37
def _format_argspec_plus(fn, grouped=True):
    """Return formatted argument-signature strings for a function or argspec.

    Vendored copy of ``langhelpers._format_argspec_plus``; the 1.4 version
    of ``preloaded`` does not need it.

    :param fn: a callable, or an already-computed full argspec indexable as
      ``(args, varargs, varkw, defaults, kwonlyargs, ...)``.
    :param grouped: when false, the first and last character (the
      surrounding parentheses) are stripped from every signature string.
    :return: a dict with keys

      * ``args`` -- the full formatted signature;
      * ``self_arg`` -- the first positional argument name, else an
        expression indexing ``*varargs``, else ``None``;
      * ``apply_pos`` -- the signature with defaults omitted, suitable for
        passing the arguments positionally;
      * ``apply_kw`` -- like ``apply_pos`` but with the trailing defaulted
        names rendered as ``name=name``.

    """
    spec = compat.inspect_getfullargspec(fn) if compat.callable(fn) else fn

    args = compat.inspect_formatargspec(*spec)

    positional, varargs, varkw = spec[0], spec[1], spec[2]
    defaults, kwonly = spec[3], spec[4]

    if positional:
        self_arg = positional[0]
    elif varargs:
        self_arg = "%s[0]" % varargs
    else:
        self_arg = None

    apply_pos = compat.inspect_formatargspec(
        positional, varargs, varkw, None, kwonly
    )

    # The last ``num_defaults`` names (positional defaults plus every
    # keyword-only name) are passed by keyword in ``apply_kw``.
    num_defaults = len(defaults or ()) + len(kwonly or ())
    named = positional + kwonly
    defaulted_vals = named[-num_defaults:] if num_defaults else ()

    apply_kw = compat.inspect_formatargspec(
        named,
        varargs,
        varkw,
        defaulted_vals,
        formatvalue=lambda name: "=" + name,
    )

    if grouped:
        trim = lambda text: text  # noqa: E731
    else:
        trim = lambda text: text[1:-1]  # noqa: E731

    return dict(
        args=trim(args),
        self_arg=self_arg,
        apply_pos=trim(apply_pos),
        apply_kw=trim(apply_kw),
    )
93
94
class dependencies(object):
    """Apply imported dependencies as arguments to a function.

    E.g.::

        @util.dependencies(
            "sqlalchemy.sql.widget",
            "sqlalchemy.engine.default"
        )
        def some_func(self, widget, default, arg1, arg2, **kw):
            # ...

    The dependency parameters (immediately after ``self``/``cls`` when
    present) are removed from the decorated function's signature.  At call
    time they are filled in with :class:`._importlater` proxies, which
    forward attribute access to the named modules once
    :meth:`.resolve_all` has been called for a covering path prefix.

    Rationale is so that the impact of a dependency cycle can be
    associated directly with the few functions that cause the cycle,
    and not pollute the module-level namespace.

    """

    def __init__(self, *deps):
        self.import_deps = []
        for dep in deps:
            # "a.b.c" -> proxy for attribute "c" of module "a.b"
            path, _, addtl = dep.rpartition(".")
            self.import_deps.append(dependencies._importlater(path, addtl))

    def __call__(self, fn):
        # ``import_deps`` and ``fn`` must stay locals: the generated lambda
        # source refers to them by name and is evaluated against locals().
        import_deps = self.import_deps
        spec = compat.inspect_getfullargspec(fn)

        spec_zero = list(spec[0])
        # Guard against functions with no positional arguments at all.
        hasself = bool(spec_zero) and spec_zero[0] in ("self", "cls")
        offset = 1 if hasself else 0

        # Inner call: each dependency parameter becomes an expression that
        # pulls the proxy out of ``import_deps``.
        for i, _dep in enumerate(import_deps):
            spec[0][i + offset] = "import_deps[%r]" % i

        inner_spec = _format_argspec_plus(spec, grouped=False)

        # Outer signature: drop the dependency parameters entirely.
        del spec_zero[offset : offset + len(import_deps)]
        spec[0][:] = spec_zero

        outer_spec = _format_argspec_plus(spec, grouped=False)

        code = "lambda %(args)s: fn(%(apply_kw)s)" % {
            "args": outer_spec["args"],
            "apply_kw": inner_spec["apply_kw"],
        }

        decorated = eval(code, locals())
        # Reuse the original default objects (``im_func`` covers Python 2
        # unbound methods).
        decorated.__defaults__ = getattr(fn, "im_func", fn).__defaults__
        return update_wrapper(decorated, fn)

    @classmethod
    def resolve_all(cls, path):
        """Import every pending dependency whose full dotted name starts
        with ``path`` (a plain string-prefix match)."""
        for m in list(dependencies._unresolved):
            if m._full_path.startswith(path):
                m._resolve()

    # Registries shared by all _importlater instances: proxies awaiting
    # resolve_all(), and the one proxy per full dotted name.
    _unresolved = set()
    _by_key = {}

    class _importlater(object):
        """Lazy proxy for attribute ``addtl`` of module ``path``."""

        # NOTE: these two are not consulted; the live registries are
        # ``dependencies._unresolved`` / ``dependencies._by_key``.  Kept so
        # existing references to them keep working.
        _unresolved = set()

        _by_key = {}

        def __new__(cls, path, addtl):
            key = path + "." + addtl
            if key in dependencies._by_key:
                return dependencies._by_key[key]
            else:
                dependencies._by_key[key] = imp = object.__new__(cls)
                return imp

        def __init__(self, path, addtl):
            # __new__ may return a cached instance, and Python still calls
            # __init__ on it.  Do not re-register one that is already set
            # up (and possibly resolved), or it would be put back on the
            # unresolved set.  Check __dict__ directly: hasattr() would go
            # through __getattr__ and try to load the module.
            if "_il_path" in self.__dict__:
                return
            self._il_path = path
            self._il_addtl = addtl
            dependencies._unresolved.add(self)

        @property
        def _full_path(self):
            return self._il_path + "." + self._il_addtl

        @_memoized_property
        def module(self):
            """The resolved target object, computed once and cached.

            :raises ImportError: if resolve_all() has not run for this path.
            """
            if self in dependencies._unresolved:
                raise ImportError(
                    "importlater.resolve_all() hasn't "
                    "been called (this is %s %s)"
                    % (self._il_path, self._il_addtl)
                )

            return getattr(self._initial_import, self._il_addtl)

        def _resolve(self):
            # Mark as resolved up front so a nested resolve_all() triggered
            # by this very import skips us.  If the import fails, put the
            # proxy back: otherwise it would lack ``_initial_import`` while
            # looking resolved, and ``module`` -> __getattr__ would recurse.
            dependencies._unresolved.discard(self)
            imported = False
            try:
                self._initial_import = compat.import_(
                    self._il_path, globals(), locals(), [self._il_addtl]
                )
                imported = True
            finally:
                if not imported:
                    dependencies._unresolved.add(self)

        def __getattr__(self, key):
            # Reached when computing ``module`` itself raised AttributeError
            # (e.g. the package lacks the requested attribute).
            if key == "module":
                raise ImportError(
                    "Could not resolve module %s" % self._full_path
                )
            try:
                attr = getattr(self.module, key)
            except AttributeError:
                raise AttributeError(
                    "Module %s has no attribute '%s'" % (self._full_path, key)
                )
            # Cache on the proxy so later lookups bypass __getattr__.
            self.__dict__[key] = attr
            return attr