1import datetime
2import decimal
3import functools
4import logging
5import time
6import warnings
7from contextlib import contextmanager
8from hashlib import md5
9
10from django.apps import apps
11from django.db import NotSupportedError
12from django.utils.dateparse import parse_time
13
14logger = logging.getLogger("django.db.backends")
15
16
class CursorWrapper:
    """
    Proxy around a database driver cursor.

    Unknown attributes are forwarded to the wrapped cursor. Iteration, the
    fetch-style methods named in WRAP_ERROR_ATTRS, callproc(), execute(), and
    executemany() all go through ``db.wrap_database_errors``.
    """

    # Cursor attributes that __getattr__() returns wrapped by
    # db.wrap_database_errors instead of returning them untouched.
    WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])

    APPS_NOT_READY_WARNING_MSG = (
        "Accessing the database during app initialization is discouraged. To fix this "
        "warning, avoid executing queries in AppConfig.ready() or when your app "
        "modules are imported."
    )

    def __init__(self, cursor, db):
        self.cursor = cursor
        self.db = db

    def __getattr__(self, attr):
        """Look up `attr` on the wrapped cursor, wrapping it when required."""
        cursor_attr = getattr(self.cursor, attr)
        if attr not in CursorWrapper.WRAP_ERROR_ATTRS:
            return cursor_attr
        return self.db.wrap_database_errors(cursor_attr)

    def __iter__(self):
        """Yield from the wrapped cursor inside db.wrap_database_errors."""
        with self.db.wrap_database_errors:
            yield from self.cursor

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Close instead of passing through to avoid backend-specific behavior
        # (#17671). Errors raised by cleanup code aren't useful, so database
        # errors from close() are deliberately discarded.
        try:
            self.close()
        except self.db.Database.Error:
            pass

    # The methods below cannot be provided through __getattr__ because their
    # code has to run when the method is called, not when it is looked up.

    def callproc(self, procname, params=None, kparams=None):
        """
        Call the stored procedure `procname` on the wrapped cursor.

        Raise NotSupportedError if `kparams` is given while the backend's
        ``supports_callproc_kwargs`` feature flag is false.
        """
        # PEP 249 doesn't define keyword parameters for callproc, but the
        # database driver may support them (e.g. oracledb).
        if kparams is not None and not self.db.features.supports_callproc_kwargs:
            raise NotSupportedError(
                "Keyword parameters for callproc are not supported on this "
                "database backend."
            )
        # Raise a warning during app initialization (stored_app_configs is only
        # ever set during testing).
        if not (apps.ready or apps.stored_app_configs):
            warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning)
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            if kparams is not None:
                return self.cursor.callproc(procname, params or (), kparams)
            if params is not None:
                return self.cursor.callproc(procname, params)
            return self.cursor.callproc(procname)

    def execute(self, sql, params=None):
        """Run a single statement through the connection's execute wrappers."""
        return self._execute_with_wrappers(
            sql, params, many=False, executor=self._execute
        )

    def executemany(self, sql, param_list):
        """Run a statement once per parameter set through the execute wrappers."""
        return self._execute_with_wrappers(
            sql, param_list, many=True, executor=self._executemany
        )

    def _execute_with_wrappers(self, sql, params, many, executor):
        """
        Call `executor` nested inside every entry of db.execute_wrappers.

        Each wrapper is invoked as wrapper(execute, sql, params, many, context)
        and the first entry of db.execute_wrappers ends up outermost.
        """
        context = {"connection": self.db, "cursor": self}
        chain = executor
        for wrapper in reversed(self.db.execute_wrappers):
            chain = functools.partial(wrapper, chain)
        return chain(sql, params, many, context)

    def _execute(self, sql, params, *ignored_wrapper_args):
        # Raise a warning during app initialization (stored_app_configs is only
        # ever set during testing).
        if not (apps.ready or apps.stored_app_configs):
            warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning)
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            if params is not None:
                return self.cursor.execute(sql, params)
            # Omit params entirely: its default might be backend specific.
            return self.cursor.execute(sql)

    def _executemany(self, sql, param_list, *ignored_wrapper_args):
        # Raise a warning during app initialization (stored_app_configs is only
        # ever set during testing).
        if not (apps.ready or apps.stored_app_configs):
            warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning)
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            return self.cursor.executemany(sql, param_list)
116
class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that times execute() and executemany(), appends an entry to
    db.queries_log, and emits a debug record on the module logger.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(self, sql, params=None):
        with self.debug_sql(sql, params, use_last_executed_query=True):
            return super().execute(sql, params)

    def executemany(self, sql, param_list):
        with self.debug_sql(sql, param_list, many=True):
            return super().executemany(sql, param_list)

    @contextmanager
    def debug_sql(
        self, sql=None, params=None, use_last_executed_query=False, many=False
    ):
        """
        Time the enclosed block and record the statement afterwards.

        The bookkeeping runs in a ``finally`` clause, so it also happens when
        the enclosed block raises.
        """
        began = time.monotonic()
        try:
            yield
        finally:
            duration = time.monotonic() - began
            if use_last_executed_query:
                sql = self.db.ops.last_executed_query(self.cursor, sql, params)
            if many:
                try:
                    times = len(params)
                except TypeError:
                    # params could be an iterator.
                    times = "?"
                logged_sql = "%s times: %s" % (times, sql)
            else:
                logged_sql = sql
            self.db.queries_log.append(
                {
                    "sql": logged_sql,
                    "time": "%.3f" % duration,
                }
            )
            logger.debug(
                "(%.3f) %s; args=%s; alias=%s",
                duration,
                self.db.ops.format_debug_sql(sql),
                params,
                self.db.alias,
                extra={
                    "duration": duration,
                    "sql": sql,
                    "params": params,
                    "alias": self.db.alias,
                },
            )
164
165
@contextmanager
def debug_transaction(connection, sql):
    """
    Time the enclosed block and, if ``connection.queries_logged`` is true once
    it finishes, append `sql` with its duration to ``connection.queries_log``
    and emit a debug record on the module logger.
    """
    began = time.monotonic()
    try:
        yield
    finally:
        # The flag is checked after the block has run, and the clock is only
        # read again when the entry is actually going to be recorded.
        if connection.queries_logged:
            duration = time.monotonic() - began
            entry = {
                "sql": "%s" % sql,
                "time": "%.3f" % duration,
            }
            connection.queries_log.append(entry)
            logger.debug(
                "(%.3f) %s; args=%s; alias=%s",
                duration,
                sql,
                None,
                connection.alias,
                extra={
                    "duration": duration,
                    "sql": sql,
                    "alias": connection.alias,
                },
            )
193
194
def split_tzname_delta(tzname):
    """
    Split a time zone name into a 3-tuple of (name, sign, offset).

    Return (tzname, None, None) when no "+" or "-" is followed by text that
    parse_time() accepts. An offset without a colon gets ":00" appended.
    """
    for sign in ("+", "-"):
        if sign not in tzname:
            continue
        # Split on the last occurrence of the sign.
        name, _, offset = tzname.rpartition(sign)
        if not (offset and parse_time(offset)):
            continue
        if ":" not in offset:
            offset = f"{offset}:00"
        return name, sign, offset
    return tzname, None, None
207
208
209###############################################
210# Converters from database (string) to Python #
211###############################################
212
213
def typecast_date(s):
    """
    Convert a "YYYY-MM-DD" string to a datetime.date.

    Return None if `s` is null or empty.
    """
    if not s:
        return None
    fields = [int(piece) for piece in s.split("-")]
    return datetime.date(*fields)
218
219
def typecast_time(s):  # does NOT store time zone information
    """
    Convert an "HH:MM:SS[.ffffff]" string to a datetime.time.

    Return None if `s` is null or empty. The fractional part is padded with
    zeros or cut off so that exactly six digits are used as microseconds.
    """
    if not s:
        return None
    hour, minutes, seconds = s.split(":")
    # Separate the fractional part of the seconds, if there is one.
    if "." in seconds:
        seconds, fraction = seconds.split(".")
    else:
        fraction = "0"
    microseconds = int(fraction.ljust(6, "0")[:6])
    return datetime.time(int(hour), int(minutes), int(seconds), microseconds)
231
232
def typecast_timestamp(s):  # does NOT store time zone information
    """
    Convert a timestamp string to a datetime.datetime, dropping any offset.

    Accepted shapes include "2005-07-29 15:48:00.590358-05" and
    "2005-07-29 09:56:00-05". Return None if `s` is null or empty; a value
    without a space is handed to typecast_date().
    """
    if not s:
        return None
    if " " not in s:
        return typecast_date(s)
    date_part, time_part = s.split()
    # Remove timezone information; "-" is looked for before "+".
    for sign in ("-", "+"):
        if sign in time_part:
            time_part = time_part.split(sign, 1)[0]
            break
    date_fields = date_part.split("-")
    time_fields = time_part.split(":")
    seconds = time_fields[2]
    # Separate the fractional part of the seconds, if there is one.
    if "." in seconds:
        seconds, fraction = seconds.split(".")
    else:
        fraction = "0"
    return datetime.datetime(
        int(date_fields[0]),
        int(date_fields[1]),
        int(date_fields[2]),
        int(time_fields[0]),
        int(time_fields[1]),
        int(seconds),
        int(fraction.ljust(6, "0")[:6]),
    )
262
263
264###############################################
265# Converters from Python to database (string) #
266###############################################
267
268
def split_identifier(identifier):
    """
    Split an SQL identifier into a two element tuple of (namespace, name).

    The identifier could be a table, column, or sequence name, and it might be
    prefixed by a namespace. When it doesn't split into exactly two parts on
    '"."', the namespace is returned as an empty string.
    """
    pieces = identifier.split('"."')
    if len(pieces) == 2:
        namespace, name = pieces
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
281
282
def truncate_name(identifier, length=None, hash_len=4):
    """
    Shorten an SQL identifier to a repeatable mangled version with the given
    length.

    If a quote stripped name contains a namespace, e.g. USERNAME"."TABLE,
    truncate the table portion only. The identifier is returned unchanged when
    `length` is None or the name part already fits.
    """
    namespace, name = split_identifier(identifier)

    fits = length is None or len(name) <= length
    if fits:
        return identifier

    # The last `hash_len` characters of the result are a digest of the full
    # name.
    digest = names_digest(name, length=hash_len)
    prefix = '%s"."' % namespace if namespace else ""
    kept = name[: length - hash_len]
    return "%s%s%s" % (prefix, kept, digest)
302
303
def names_digest(*args, length):
    """
    Return the first `length` hexadecimal characters of the MD5 digest of the
    given strings, for use in shortening identifying names.

    The arguments are encoded and hashed in order, as if concatenated.
    """
    payload = b"".join(piece.encode() for piece in args)
    return md5(payload, usedforsecurity=False).hexdigest()[:length]
313
314
def format_number(value, max_digits, decimal_places):
    """
    Format a number into a string with the requisite number of digits and
    decimal places.

    Return None for a None value. With `decimal_places` set, the value is
    quantized to that many places; otherwise it is converted in a context
    that traps decimal.Rounded.
    """
    if value is None:
        return None
    # Work on a copy so the shared decimal context isn't modified.
    ctx = decimal.getcontext().copy()
    if max_digits is not None:
        ctx.prec = max_digits
    if decimal_places is None:
        ctx.traps[decimal.Rounded] = 1
        formatted = ctx.create_decimal(value)
    else:
        formatted = value.quantize(
            decimal.Decimal(1).scaleb(-decimal_places), context=ctx
        )
    return "{:f}".format(formatted)
333
334
def strip_quotes(table_name):
    """
    Remove the enclosing double quotes from a quoted table name so it is safe
    for use in index names, sequence names, etc. For example '"USER"."TABLE"'
    (an Oracle naming scheme) becomes 'USER"."TABLE'.

    A name that doesn't both start and end with a double quote is returned
    unchanged.
    """
    if table_name.startswith('"') and table_name.endswith('"'):
        return table_name[1:-1]
    return table_name