1"""
2Query subclasses which provide extra functionality beyond simple data retrieval.
3"""
4
5from django.core.exceptions import FieldError
6from django.db.models.sql.constants import (
7 GET_ITERATOR_CHUNK_SIZE,
8 NO_RESULTS,
9 ROW_COUNT,
10)
11from django.db.models.sql.query import Query
12
# Query subclasses exported as this module's public API.
__all__ = [
    "DeleteQuery",
    "UpdateQuery",
    "InsertQuery",
    "AggregateQuery",
]
14
15
class DeleteQuery(Query):
    """A DELETE SQL query."""

    compiler = "SQLDeleteCompiler"

    def do_query(self, table, where, using):
        """
        Run a DELETE against ``table`` restricted by ``where``.

        The alias map is narrowed to the single entry for ``table`` before the
        statement is compiled. Return the row count reported by the compiler.
        """
        only_alias = self.alias_map[table]
        self.alias_map = {table: only_alias}
        self.where = where
        compiler = self.get_compiler(using)
        return compiler.execute_sql(ROW_COUNT)

    def delete_batch(self, pk_list, using):
        """
        Delete every row whose primary key value appears in ``pk_list``.

        The keys are consumed in slices of GET_ITERATOR_CHUNK_SIZE, so a long
        list results in several DELETE statements. Return the sum of the row
        counts reported for those statements.
        """
        pk_field = self.get_meta().pk
        lookup = f"{pk_field.attname}__in"
        step = GET_ITERATOR_CHUNK_SIZE
        total = 0
        for start in range(0, len(pk_list), step):
            chunk = pk_list[start : start + step]
            # Each slice gets a fresh WHERE clause of its own.
            self.clear_where()
            self.add_filter(lookup, chunk)
            table_name = self.get_meta().db_table
            total += self.do_query(table_name, self.where, using=using)
        return total
46
47
class UpdateQuery(Query):
    """An UPDATE SQL query."""

    compiler = "SQLUpdateCompiler"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._setup_query()

    def _setup_query(self):
        """
        Run on initialization and at the end of chaining. Any attributes that
        would normally be set in __init__() should go here instead.
        """
        # (field, model, value) triples written by this query's UPDATE.
        self.values = []
        # When not None, indexed by ancestor model to obtain the pk values
        # that restrict that ancestor's update (see get_related_updates()).
        self.related_ids = None
        # Ancestor model -> list of (field, None, value) triples that are
        # applied with a separate UPDATE query per ancestor.
        self.related_updates = {}

    def clone(self):
        """
        Return a copy of this query whose pending updates are independent of
        the original's.

        The base clone() may share mutable attributes with the original. Both
        ``values`` and every per-ancestor list in ``related_updates`` are
        therefore copied, so appending to the clone never alters the original
        (and vice versa).
        """
        obj = super().clone()
        obj.values = list(self.values)
        obj.related_updates = {
            model: list(updates) for model, updates in self.related_updates.items()
        }
        return obj

    def update_batch(self, pk_list, values, using):
        """
        Apply ``values`` to every row whose primary key is in ``pk_list``.

        ``pk_list`` is processed in slices of GET_ITERATOR_CHUNK_SIZE, and
        one UPDATE statement is executed per slice.
        """
        self.add_update_values(values)
        for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
            self.clear_where()
            self.add_filter(
                "pk__in", pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
            )
            self.get_compiler(using).execute_sql(NO_RESULTS)

    def add_update_values(self, values):
        """
        Convert a dictionary of field name to value mappings into an update
        query. This is the entry point for the public update() method on
        querysets.

        Fields whose concrete model differs from this query's concrete model
        are routed to add_related_update() and updated by a separate query.

        Raise FieldError for many-to-many fields and for auto-created,
        non-concrete fields (e.g. reverse relations). Neither kind maps to a
        column that this UPDATE could set.
        """
        meta = self.get_meta()
        values_seq = []
        for name, val in values.items():
            field = meta.get_field(name)
            # A field is "direct" when it was declared explicitly or is backed
            # by a concrete column. Auto-created, non-concrete fields have no
            # column of their own. (Adding "or not field.concrete" here would
            # make this expression always True and disable the check.)
            direct = not field.auto_created or field.concrete
            model = field.model._meta.concrete_model
            if not direct or (field.is_relation and field.many_to_many):
                raise FieldError(
                    "Cannot update model field %r (only non-relations and "
                    "foreign keys permitted)." % field
                )
            if model is not meta.concrete_model:
                # The column lives on another (ancestor) model's table, so it
                # is coalesced into a per-ancestor update instead.
                self.add_related_update(model, field, val)
                continue
            values_seq.append((field, model, val))
        return self.add_update_fields(values_seq)

    def add_update_fields(self, values_seq):
        """
        Append a sequence of (field, model, value) triples to the internal list
        that will be used to generate the UPDATE query. Might be more usefully
        called add_update_targets() to hint at the extra information here.

        Generated fields are skipped. Values that expose resolve_expression()
        are resolved against this query first.
        """
        for field, model, val in values_seq:
            # Omit generated fields.
            if field.generated:
                continue
            if hasattr(val, "resolve_expression"):
                # Resolve expressions here so that annotations are no longer needed
                val = val.resolve_expression(self, allow_joins=False, for_save=True)
            self.values.append((field, model, val))

    def add_related_update(self, model, field, value):
        """
        Queue a (field, None, value) triple for an update of ancestor ``model``.

        Updates are coalesced so that only one update query per ancestor is run.
        """
        self.related_updates.setdefault(model, []).append((field, None, value))

    def get_related_updates(self):
        """
        Return a list of query objects: one for each update required to an
        ancestor model. Each query will have the same filtering conditions as
        the current query but will only update a single table.
        """
        if not self.related_updates:
            return []
        result = []
        for model, values in self.related_updates.items():
            query = UpdateQuery(model)
            query.values = values
            if self.related_ids is not None:
                query.add_filter("pk__in", self.related_ids[model])
            result.append(query)
        return result
143
144
class InsertQuery(Query):
    """An INSERT SQL query."""

    compiler = "SQLInsertCompiler"

    def __init__(
        self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.fields = []
        self.objs = []
        # Default so the attribute always exists, even before insert_values()
        # has been called; previously it was only created by insert_values().
        self.raw = False
        self.on_conflict = on_conflict
        # Fresh lists per instance when the caller passes None (or nothing).
        self.update_fields = update_fields or []
        self.unique_fields = unique_fields or []

    def insert_values(self, fields, objs, raw=False):
        """
        Record the fields and objects that this query will insert.

        ``raw`` is stored as-is for the compiler to consult. Presumably it
        means values are taken straight from the objects without pre-save
        processing; confirm against SQLInsertCompiler.
        """
        self.fields = fields
        self.objs = objs
        self.raw = raw
162
163
class AggregateQuery(Query):
    """
    Use another query (``inner_query``) as the source in the FROM clause and
    select only the requested elements from it.
    """

    compiler = "SQLAggregateCompiler"

    def __init__(self, model, inner_query):
        """Build an aggregate query over ``model`` that wraps ``inner_query``."""
        # The inner query is attached before the base initializer runs,
        # preserving the original statement order.
        self.inner_query = inner_query
        super().__init__(model)