1"""
Jedi plugin used to infer Django model fields, managers and querysets.
3"""
4from inspect import Parameter
5from typing import Any
6
7from jedi import debug
8from jedi.inference.cache import inference_state_function_cache
9from jedi.inference.base_value import ValueSet, iterator_to_value_set, ValueWrapper
10from jedi.inference.filters import DictFilter, AttributeOverwrite
11from jedi.inference.names import NameWrapper, BaseTreeParamName
12from jedi.inference.compiled.value import EmptyCompiledName
13from jedi.inference.value.instance import TreeInstance
14from jedi.inference.value.klass import ClassMixin
15from jedi.inference.gradual.base import GenericClass
16from jedi.inference.gradual.generics import TupleGenericManager
17from jedi.inference.signature import AbstractSignature
18
19
20mapping = {
21 'IntegerField': (None, 'int'),
22 'BigIntegerField': (None, 'int'),
23 'PositiveIntegerField': (None, 'int'),
24 'SmallIntegerField': (None, 'int'),
25 'CharField': (None, 'str'),
26 'TextField': (None, 'str'),
27 'EmailField': (None, 'str'),
28 'GenericIPAddressField': (None, 'str'),
29 'URLField': (None, 'str'),
30 'FloatField': (None, 'float'),
31 'BinaryField': (None, 'bytes'),
32 'BooleanField': (None, 'bool'),
33 'DecimalField': ('decimal', 'Decimal'),
34 'TimeField': ('datetime', 'time'),
35 'DurationField': ('datetime', 'timedelta'),
36 'DateField': ('datetime', 'date'),
37 'DateTimeField': ('datetime', 'datetime'),
38 'UUIDField': ('uuid', 'UUID'),
39}
40
41_FILTER_LIKE_METHODS = ('create', 'filter', 'exclude', 'update', 'get',
42 'get_or_create', 'update_or_create')
43
44
@inference_state_function_cache()
def _get_deferred_attributes(inference_state):
    """
    Infer an instance of ``django.db.models.query_utils.DeferredAttribute``.

    This is what a model field evaluates to when accessed on the model class
    rather than on an instance. The result is cached per inference state.
    """
    query_utils = inference_state.import_module(
        ('django', 'db', 'models', 'query_utils'))
    deferred_attribute_cls = query_utils.py__getattribute__('DeferredAttribute')
    return deferred_attribute_cls.execute_annotation(None)
50
51
def _infer_scalar_field(inference_state, field_name, field_tree_instance, is_instance):
    """
    Infer the value of a scalar field listed in ``mapping``.

    Returns ``None`` when ``field_tree_instance`` is not one of the mapped
    field classes, or when the target type cannot be found. On class access
    (``is_instance`` false) the ``DeferredAttribute`` instance is returned.
    On instance access an instance of the mapped Python type is returned.
    ``field_name`` is currently unused.
    """
    target = mapping.get(field_tree_instance.py__name__())
    if target is None:
        return None
    module_name, attribute_name = target

    if not is_instance:
        return _get_deferred_attributes(inference_state)

    # A module of ``None`` in ``mapping`` denotes the builtins module.
    if module_name is not None:
        module = inference_state.import_module((module_name,))
    else:
        module = inference_state.builtins_module

    # Only the first value found for the type name is instantiated.
    candidates = module.py__getattribute__(attribute_name)
    return next((c.execute_with_values() for c in candidates), None)
68
69
@iterator_to_value_set
def _get_foreign_key_values(cls, field_tree_instance):
    """
    Yield the model classes that a relation field definition points to.

    Only the first argument of the field call (``ForeignKey(to, ...)``) is
    inspected, and only when it is passed positionally. It may be:

    * a class, which is yielded directly;
    * the string ``'self'``, Django's spelling for a recursive relationship,
      which yields ``cls`` itself;
    * any other string, looked up as a class name in the module defining
      ``cls``.

    A ``str`` value without a known literal value is skipped instead of
    raising.
    """
    if not isinstance(field_tree_instance, TreeInstance):
        return

    # TODO private access..
    argument_iterator = field_tree_instance._arguments.unpack()
    key, lazy_values = next(argument_iterator, (None, None))
    # A keyword first argument (e.g. ``to=...``) or no argument at all is not
    # handled.
    if key is not None or lazy_values is None:
        return

    for value in lazy_values.infer():
        if value.py__name__() == 'str':
            # Without a default, get_safe_value raises for strings that are
            # not literals (e.g. built at runtime); skip those.
            foreign_key_class_name = value.get_safe_value(default=None)
            if foreign_key_class_name is None:
                continue
            if foreign_key_class_name == 'self':
                yield cls
                continue
            module = cls.get_root_context()
            for v in module.py__getattribute__(foreign_key_class_name):
                if v.is_class():
                    yield v
        elif value.is_class():
            yield value
86
87
def _infer_field(cls, field_name, is_instance):
    """
    Infer the result of accessing model attribute ``field_name`` on ``cls``.

    * Fields listed in ``mapping`` infer to their Python type (instances) or
      ``DeferredAttribute`` (class access).
    * ``ForeignKey`` / ``OneToOneField`` infer to instances of the related
      model.
    * ``ManyToManyField`` infers to a ``RelatedManager`` parametrized with
      the related model.
    * On class access, any relation field infers to ``DeferredAttribute``.

    When none of these apply, the plain inference result of ``field_name``
    is returned and a debug message is logged.
    """
    inference_state = cls.inference_state
    result = field_name.infer()
    for field_tree_instance in result:
        scalar_field = _infer_scalar_field(
            inference_state, field_name, field_tree_instance, is_instance)
        if scalar_field is not None:
            return scalar_field

        field_class_name = field_tree_instance.py__name__()
        many_to_many = field_class_name == 'ManyToManyField'
        if not (field_class_name in ('ForeignKey', 'OneToOneField') or many_to_many):
            continue

        if not is_instance:
            return _get_deferred_attributes(inference_state)

        related = _get_foreign_key_values(cls, field_tree_instance)
        if not many_to_many:
            return related.execute_with_values()

        managers = [_create_manager_for(v, 'RelatedManager') for v in related]
        return ValueSet([manager for manager in managers if manager])

    debug.dbg('django plugin: fail to infer `%s` from class `%s`',
              field_name.string_name, cls.py__name__())
    return result
114
115
class DjangoModelName(NameWrapper):
    """
    Wraps a name found on a Django model class so that inferring it goes
    through ``_infer_field``.
    """

    def __init__(self, cls, name, is_instance):
        super().__init__(name)
        # Whether the attribute is accessed on an instance (True) or on the
        # model class itself (False).
        self._is_instance = is_instance
        self._cls = cls

    def infer(self):
        model_cls, wrapped = self._cls, self._wrapped_name
        return _infer_field(model_cls, wrapped, self._is_instance)
124
125
def _create_manager_for(cls, manager_cls='BaseManager'):
    """
    Build ``manager_cls[cls]`` from ``django.db.models.manager``.

    The parametrized class is executed as an annotation and the first
    resulting value is returned. Returns ``None`` when no class named
    ``manager_cls`` can be found or nothing results from the annotation.
    """
    manager_module = cls.inference_state.import_module(
        ('django', 'db', 'models', 'manager'))
    candidates = manager_module.py__getattribute__(manager_cls)
    for candidate in candidates:
        if not candidate.is_class_mixin():
            continue
        generic_class = GenericClass(
            candidate, TupleGenericManager((ValueSet([cls]),)))
        for manager in generic_class.execute_annotation(None):
            return manager
    return None
136
137
def _new_dict_filter(cls, is_instance):
    """
    Return a ``DictFilter`` over the names defined on ``cls``.

    Each name is wrapped in ``DjangoModelName``. Metaclass and ``type``
    filters are excluded. Filters are walked in reverse order, so for a
    duplicated name the entry from the filter that comes first in
    ``cls.get_filters`` wins.
    """
    filters = list(cls.get_filters(
        is_instance=is_instance,
        include_metaclasses=False,
        include_type_when_class=False,
    ))

    dct: dict[str, Any] = {}
    for filter_ in reversed(filters):
        for name in filter_.values():
            dct[name.string_name] = DjangoModelName(cls, name, is_instance)

    if is_instance:
        # On instances, make ``objects`` a name that infers to nothing. It
        # still shows up in completions, but at least things like
        # ``instance.objects.filter`` are no longer inferred. Hiding it from
        # completions as well is probably not worth the extra work.
        dct['objects'] = EmptyCompiledName(cls.inference_state, 'objects')

    return DictFilter(dct)
159
160
def is_django_model_base(value):
    """Return True if ``value`` is ``ModelBase`` from ``django.db.models.base``."""
    if value.py__name__() != 'ModelBase':
        return False
    return value.get_root_context().py__name__() == 'django.db.models.base'
164
165
def get_metaclass_filters(func):
    """
    Plugin decorator for metaclass filter lookup.

    If any of the metaclasses is Django's ``ModelBase``, the class gets a
    single Django-aware ``DictFilter``. Otherwise ``func`` is called
    unchanged.
    """
    def wrapper(cls, metaclasses, is_instance):
        if any(is_django_model_base(m) for m in metaclasses):
            return [_new_dict_filter(cls, is_instance)]
        return func(cls, metaclasses, is_instance)
    return wrapper
174
175
def tree_name_to_values(func):
    """
    Plugin decorator for name inference that substitutes Django-aware
    wrappers.

    * Filter-like ``_BaseQuerySet`` methods from ``django.db.models.query``
      are wrapped per model, so keyword parameter completion works on calls
      like ``User.objects.filter(...)``.
    * ``BaseManager`` in ``django.db.models.manager`` is wrapped in
      ``ManagerWrapper``.
    * ``Field`` in ``django.db.models.fields`` is wrapped in
      ``FieldWrapper``.
    """
    def in_module(ctx, module_name):
        return ctx.is_module() and ctx.py__name__() == module_name

    def wrapper(inference_state, context, tree_name):
        result = func(inference_state, context, tree_name)
        string_name = tree_name.value

        if string_name in _FILTER_LIKE_METHODS:
            for value in result:
                if value.get_qualified_names() != ('_BaseQuerySet', string_name):
                    continue
                if not in_module(value.parent_context, 'django.db.models.query'):
                    continue
                # The queryset's first generic holds the model classes.
                generics = context.get_value().get_generics()
                if len(generics) >= 1:
                    return ValueSet(QuerySetMethodWrapper(value, model)
                                    for model in generics[0])
        elif string_name == 'BaseManager' \
                and in_module(context, 'django.db.models.manager'):
            return ValueSet(ManagerWrapper(r) for r in result)
        elif string_name == 'Field' \
                and in_module(context, 'django.db.models.fields'):
            return ValueSet(FieldWrapper(r) for r in result)

        return result
    return wrapper
202
203
def _find_fields(cls):
    """
    Yield the class-level names of ``cls`` that infer to ``DeferredAttribute``.

    A name is yielded once for every matching inferred value.
    """
    deferred_attribute_path = (
        'django', 'db', 'models', 'query_utils', 'DeferredAttribute')
    class_filter = _new_dict_filter(cls, is_instance=False)
    for name in class_filter.values():
        for value in name.infer():
            qualified = value.name.get_qualified_names(include_module_names=True)
            if qualified == deferred_attribute_path:
                yield name
210
211
def _get_signatures(cls):
    """Return a one-element list holding the model constructor signature of ``cls``."""
    field_names = list(_find_fields(cls))
    signature = DjangoModelSignature(cls, field_names=field_names)
    return [signature]
214
215
def get_metaclass_signatures(func):
    """
    Plugin decorator for metaclass-provided signatures.

    If any of the metaclasses is Django's ``ModelBase``, return the model
    constructor signature built from the model's fields. Otherwise delegate
    to ``func`` with the original arguments.
    """
    def wrapper(cls, metaclasses):
        for metaclass in metaclasses:
            if is_django_model_base(metaclass):
                return _get_signatures(cls)
        # Delegate with the full metaclass collection. Passing the loop
        # variable would hand ``func`` a single value instead of the
        # collection, and raise UnboundLocalError when it is empty.
        return func(cls, metaclasses)
    return wrapper
223
224
class ManagerWrapper(ValueWrapper):
    """Wraps ``BaseManager`` so subscripting it yields ``GenericManagerWrapper`` values."""

    def py__getitem__(self, index_value_set, contextualized_node):
        subscripted = self._wrapped_value.py__getitem__(
            index_value_set, contextualized_node)
        return ValueSet(map(GenericManagerWrapper, subscripted))
232
233
class GenericManagerWrapper(AttributeOverwrite, ClassMixin):
    """
    A subscripted ``BaseManager``.

    When accessed through a class, the manager is re-parametrized with that
    class.
    """

    def py__get__on_class(self, calling_instance, instance, class_value):
        manager_class = calling_instance.class_value
        model_generics = (ValueSet({class_value}),)
        parametrized = manager_class.with_generics(model_generics)
        return parametrized.py__call__(calling_instance._arguments)

    def with_generics(self, generics_tuple):
        wrapped = self._wrapped_value
        return wrapped.with_generics(generics_tuple)
242
243
class FieldWrapper(ValueWrapper):
    """Wraps ``Field`` so subscripting it yields ``GenericFieldWrapper`` values."""

    def py__getitem__(self, index_value_set, contextualized_node):
        subscripted = self._wrapped_value.py__getitem__(
            index_value_set, contextualized_node)
        return ValueSet(map(GenericFieldWrapper, subscripted))
251
252
class GenericFieldWrapper(AttributeOverwrite, ClassMixin):
    """A subscripted ``Field`` whose class access returns the field instance itself."""

    def py__get__on_class(self, calling_instance, instance, class_value):
        # Mostly an optimization: evaluating Field.__get__ here would cause so
        # many function executions that Jedi aborts inference.
        return ValueSet([calling_instance])
258
259
class DjangoModelSignature(AbstractSignature):
    """Model constructor signature with one ``DjangoParamName`` per field name."""

    def __init__(self, value, field_names):
        super().__init__(value)
        self._field_names = field_names

    def get_param_names(self, resolve_stars=False):
        return list(map(DjangoParamName, self._field_names))
267
268
class DjangoParamName(BaseTreeParamName):
    """
    Keyword-only constructor parameter backed by a model field name.

    Inference is delegated to that field name.
    """

    def __init__(self, field_name):
        parent_context = field_name.parent_context
        tree_name = field_name.tree_name
        super().__init__(parent_context, tree_name)
        self._field_name = field_name

    def get_kind(self):
        return Parameter.KEYWORD_ONLY

    def infer(self):
        source_name = self._field_name
        return source_name.infer()
279
280
class QuerySetMethodWrapper(ValueWrapper):
    """
    A queryset method tied to a model class.

    Binding it yields ``QuerySetBoundMethodWrapper`` values for the same
    model.
    """

    def __init__(self, method, model_cls):
        super().__init__(method)
        self._model_cls = model_cls

    def py__get__(self, instance, class_value):
        bound_methods = self._wrapped_value.py__get__(instance, class_value)
        model_cls = self._model_cls
        return ValueSet({QuerySetBoundMethodWrapper(bound, model_cls)
                         for bound in bound_methods})
289
290
class QuerySetBoundMethodWrapper(ValueWrapper):
    """
    A bound queryset method whose signature is the model's field-based
    constructor signature.
    """

    def __init__(self, method, model_cls):
        super().__init__(method)
        self._model_cls = model_cls

    def get_signatures(self):
        model_cls = self._model_cls
        return _get_signatures(model_cls)