Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/pymysql/cursors.py: 19%

275 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:28 +0000

1import re 

2import warnings 

3from . import err 

4 

5 

#: Pattern that :meth:`Cursor.executemany` uses to recognise a simple bulk
#: ``INSERT``/``REPLACE ... VALUES (...)`` statement.
#:
#: Group 1 is the statement up to and including ``VALUES``, group 2 is the
#: parenthesised placeholder list, and group 3 is an optional
#: ``ON DUPLICATE ...`` tail.  Only this simple form is batched, which makes
#: it suitable for loading large datasets.
RE_INSERT_VALUES = re.compile(
    r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)"
    r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))"
    r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
    flags=re.IGNORECASE | re.DOTALL,
)

15 

16 

class Cursor:
    """
    This is the object used to interact with the database.

    Do not create an instance of a Cursor yourself. Call
    connections.Connection.cursor().

    See `Cursor <https://www.python.org/dev/peps/pep-0249/#cursor-objects>`_ in
    the specification.
    """

    #: Max statement size which :meth:`executemany` generates.
    #:
    #: Max size of allowed statement is max_allowed_packet - packet_header_size.
    #: Default value of max_allowed_packet is 1048576.
    max_stmt_length = 1024000

    def __init__(self, connection):
        """Bind the cursor to *connection* and reset all result state."""
        self.connection = connection
        self.warning_count = 0
        self.description = None
        self.rownumber = 0
        self.rowcount = -1
        self.arraysize = 1
        # Initialised here so that reading ``lastrowid`` before the first
        # query yields None instead of falling through to ``__getattr__``
        # and raising AttributeError.
        self.lastrowid = None
        self._executed = None
        self._result = None
        self._rows = None

    def close(self):
        """
        Closing a cursor just exhausts all remaining data.

        Calling it on an already closed cursor is a no-op.
        """
        conn = self.connection
        if conn is None:
            return
        try:
            while self.nextset():
                pass
        finally:
            # Detach even if draining the remaining result sets failed.
            self.connection = None

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        del exc_info
        self.close()

    def _get_db(self):
        """Return the connection, raising ProgrammingError once closed."""
        if not self.connection:
            raise err.ProgrammingError("Cursor closed")
        return self.connection

    def _check_executed(self):
        """Raise ProgrammingError unless a query has been executed."""
        if not self._executed:
            raise err.ProgrammingError("execute() first")

    def _conv_row(self, row):
        """Return *row* unchanged; subclasses may convert it."""
        return row

    def setinputsizes(self, *args):
        """Does nothing, required by DB API."""

    def setoutputsizes(self, *args):
        """Does nothing, required by DB API."""

    def _nextset(self, unbuffered=False):
        """Get the next query set.

        :return: True when the cursor advanced to another result set,
            None when there is none (or the connection has moved on to a
            result this cursor does not own).
        """
        conn = self._get_db()
        current_result = self._result
        if current_result is None or current_result is not conn._result:
            return None
        if not current_result.has_next:
            return None
        self._result = None
        self._clear_result()
        conn.next_result(unbuffered=unbuffered)
        self._do_get_result()
        return True

    def nextset(self):
        """Advance to the next result set; see :meth:`_nextset`."""
        return self._nextset(False)

    def _escape_args(self, args, conn):
        """Escape *args* with *conn*, preserving the tuple/dict shape."""
        if isinstance(args, (tuple, list)):
            return tuple(conn.literal(arg) for arg in args)
        elif isinstance(args, dict):
            return {key: conn.literal(val) for (key, val) in args.items()}
        else:
            # If it's not a dictionary let's try escaping it anyways.
            # Worst case it will throw a Value error
            return conn.escape(args)

    def mogrify(self, query, args=None):
        """
        Returns the exact string that would be sent to the database by calling the
        execute() method.

        :param query: Query to mogrify.
        :type query: str

        :param args: Parameters used with query. (optional)
        :type args: tuple, list or dict

        :return: The query with argument binding applied.
        :rtype: str

        This method follows the extension to the DB API 2.0 followed by Psycopg.
        """
        conn = self._get_db()

        # With no args the query is passed through untouched, so a literal
        # "%" in it is not interpreted.
        if args is not None:
            query = query % self._escape_args(args, conn)

        return query

    def execute(self, query, args=None):
        """Execute a query.

        :param query: Query to execute.
        :type query: str

        :param args: Parameters used with query. (optional)
        :type args: tuple, list or dict

        :return: Number of affected rows.
        :rtype: int

        If args is a list or tuple, %s can be used as a placeholder in the query.
        If args is a dict, %(name)s can be used as a placeholder in the query.
        """
        # Drain any result sets left over from the previous statement.
        while self.nextset():
            pass

        query = self.mogrify(query, args)

        result = self._query(query)
        self._executed = query
        return result

    def executemany(self, query, args):
        """Run several data against one query.

        :param query: Query to execute.
        :type query: str

        :param args: Sequence of sequences or mappings. It is used as parameter.
        :type args: tuple or list

        :return: Number of rows affected, if any.
        :rtype: int or None

        This method improves performance on multiple-row INSERT and
        REPLACE. Otherwise it is equivalent to looping over args with
        execute().
        """
        if not args:
            return

        m = RE_INSERT_VALUES.match(query)
        if m:
            # Formatting with an empty tuple collapses "%%" to "%" in the
            # prefix (and fails if the prefix itself holds a placeholder).
            q_prefix = m.group(1) % ()
            q_values = m.group(2).rstrip()
            q_postfix = m.group(3) or ""
            assert q_values[0] == "(" and q_values[-1] == ")"
            return self._do_execute_many(
                q_prefix,
                q_values,
                q_postfix,
                args,
                self.max_stmt_length,
                self._get_db().encoding,
            )

        self.rowcount = sum(self.execute(query, arg) for arg in args)
        return self.rowcount

    def _do_execute_many(
        self, prefix, values, postfix, args, max_stmt_length, encoding
    ):
        """Send *args* as multi-row statements no longer than *max_stmt_length*.

        Each statement is ``prefix + "v1,v2,..." + postfix`` where every
        ``v`` is *values* formatted with one escaped element of *args*.

        :return: Sum of the row counts reported by each statement sent.
        """
        conn = self._get_db()
        escape = self._escape_args
        if isinstance(prefix, str):
            prefix = prefix.encode(encoding)
        if isinstance(postfix, str):
            postfix = postfix.encode(encoding)
        sql = bytearray(prefix)
        args = iter(args)
        # The first row is appended without a leading comma.
        v = values % escape(next(args), conn)
        if isinstance(v, str):
            v = v.encode(encoding, "surrogateescape")
        sql += v
        rows = 0
        for arg in args:
            v = values % escape(arg, conn)
            if isinstance(v, str):
                v = v.encode(encoding, "surrogateescape")
            # "+ 1" accounts for the separating comma.
            if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
                # Flush what has been accumulated and start a new statement.
                rows += self.execute(sql + postfix)
                sql = bytearray(prefix)
            else:
                sql += b","
            sql += v
        rows += self.execute(sql + postfix)
        self.rowcount = rows
        return rows

    def callproc(self, procname, args=()):
        """Execute stored procedure procname with args.

        :param procname: Name of procedure to execute on server.
        :type procname: str

        :param args: Sequence of parameters to use with procedure.
        :type args: tuple or list

        Returns the original args.

        Compatibility warning: PEP-249 specifies that any modified
        parameters must be returned. This is currently impossible
        as they are only available by storing them in a server
        variable and then retrieved by a query. Since stored
        procedures return zero or more result sets, there is no
        reliable way to get at OUT or INOUT parameters via callproc.
        The server variables are named @_procname_n, where procname
        is the parameter above and n is the position of the parameter
        (from zero). Once all result sets generated by the procedure
        have been fetched, you can issue a SELECT @_procname_0, ...
        query using .execute() to get any OUT or INOUT values.

        Compatibility warning: The act of calling a stored procedure
        itself creates an empty result set. This appears after any
        result sets generated by the procedure. This is non-standard
        behavior with respect to the DB-API. Be sure to use nextset()
        to advance through all result sets; otherwise you may get
        disconnected.
        """
        conn = self._get_db()
        if args:
            # Store every argument in a server variable @_<procname>_<index>.
            fmt = f"@_{procname}_%d=%s"
            self._query(
                "SET %s"
                % ",".join(
                    fmt % (index, conn.escape(arg)) for index, arg in enumerate(args)
                )
            )
            self.nextset()

        q = "CALL {}({})".format(
            procname,
            ",".join(["@_%s_%d" % (procname, i) for i in range(len(args))]),
        )
        self._query(q)
        self._executed = q
        return args

    def fetchone(self):
        """Fetch the next row, or None when no rows are left."""
        self._check_executed()
        if self._rows is None or self.rownumber >= len(self._rows):
            return None
        result = self._rows[self.rownumber]
        self.rownumber += 1
        return result

    def fetchmany(self, size=None):
        """Fetch several rows.

        :param size: Number of rows to fetch; ``arraysize`` is used when
            it is None or 0.
        """
        self._check_executed()
        if self._rows is None:
            # Django expects () for EOF.
            # https://github.com/django/django/blob/0c1518ee429b01c145cf5b34eab01b0b92f8c246/django/db/backends/mysql/features.py#L8
            return ()
        end = self.rownumber + (size or self.arraysize)
        result = self._rows[self.rownumber : end]
        self.rownumber = min(end, len(self._rows))
        return result

    def fetchall(self):
        """Fetch all the rows that have not been fetched yet."""
        self._check_executed()
        if self._rows is None:
            return []
        if self.rownumber:
            result = self._rows[self.rownumber :]
        else:
            result = self._rows
        self.rownumber = len(self._rows)
        return result

    def scroll(self, value, mode="relative"):
        """Move the row position within the buffered result.

        :param value: Offset (``relative``) or target index (``absolute``).
        :param mode: ``"relative"`` or ``"absolute"``.

        :raises ProgrammingError: if *mode* is unknown.
        :raises IndexError: if the target position is outside the result,
            including when there are no rows at all.
        """
        self._check_executed()
        if mode == "relative":
            r = self.rownumber + value
        elif mode == "absolute":
            r = value
        else:
            raise err.ProgrammingError("unknown scroll mode %s" % mode)

        rows = self._rows
        # A statement without a row set leaves _rows as None; every position
        # is out of range then, rather than a TypeError from len(None).
        if rows is None or not (0 <= r < len(rows)):
            raise IndexError("out of range")
        self.rownumber = r

    def _query(self, q):
        """Send *q* to the server and load its result; return the row count."""
        conn = self._get_db()
        self._clear_result()
        conn.query(q)
        self._do_get_result()
        return self.rowcount

    def _clear_result(self):
        """Reset every per-result attribute before a new result is loaded."""
        self.rownumber = 0
        self._result = None

        self.rowcount = 0
        self.warning_count = 0
        self.description = None
        self.lastrowid = None
        self._rows = None

    def _do_get_result(self):
        """Copy the connection's current result into the cursor attributes."""
        conn = self._get_db()

        self._result = result = conn._result

        self.rowcount = result.affected_rows
        self.warning_count = result.warning_count
        self.description = result.description
        self.lastrowid = result.insert_id
        self._rows = result.rows

    def __iter__(self):
        return self

    def __next__(self):
        row = self.fetchone()
        if row is None:
            raise StopIteration
        return row

    def __getattr__(self, name):
        # DB-API 2.0 optional extension says these errors can be accessed
        # via Connection object. But MySQLdb had defined them on Cursor object.
        if name in (
            "Warning",
            "Error",
            "InterfaceError",
            "DatabaseError",
            "DataError",
            "OperationalError",
            "IntegrityError",
            "InternalError",
            "ProgrammingError",
            "NotSupportedError",
        ):
            # Deprecated since v1.1
            warnings.warn(
                "PyMySQL errors should be accessed from `pymysql` package",
                DeprecationWarning,
                stacklevel=2,
            )
            return getattr(err, name)
        raise AttributeError(name)

379 

380 

class DictCursorMixin:
    """Mixin that turns each result row into a mapping keyed by column name."""

    # Override to build rows with OrderedDict or another dict-like type.
    dict_type = dict

    def _do_get_result(self):
        """Load the result, record the column names and convert the rows."""
        super()._do_get_result()

        names = []
        if self.description:
            for field in self._result.fields:
                column = field.name
                # A repeated column name is qualified with its table name.
                if column in names:
                    column = field.table_name + "." + column
                names.append(column)
        self._fields = names

        if names and self._rows:
            self._rows = [self._conv_row(raw) for raw in self._rows]

    def _conv_row(self, row):
        """Pair *row* with the recorded column names; None passes through."""
        if row is not None:
            return self.dict_type(zip(self._fields, row))
        return None

403 

404 

class DictCursor(DictCursorMixin, Cursor):
    """Buffered cursor whose rows are returned as dictionaries."""

407 

408 

class SSCursor(Cursor):
    """
    Unbuffered Cursor, mainly useful for queries that return a lot of data,
    or for connections to remote servers over a slow network.

    Rows are fetched as they are needed rather than being copied into a
    buffer first, so the client uses much less memory and the first rows
    arrive sooner on a slow network or with a very large result set.

    There are limitations, though. The MySQL protocol doesn't support
    returning the total number of rows, so the only way to tell how many rows
    there are is to iterate over every row returned. Also, it currently isn't
    possible to scroll backwards, as only the current row is held in memory.
    """

    def _conv_row(self, row):
        """Return *row* unchanged; subclasses may convert it."""
        return row

    def close(self):
        """Finish the pending unbuffered query, drain results and detach."""
        connection = self.connection
        if connection is None:
            return

        # Only finish the query if the connection's current result is ours.
        pending = self._result
        if pending is not None and pending is connection._result:
            pending._finish_unbuffered_query()

        try:
            while self.nextset():
                pass
        finally:
            self.connection = None

    __del__ = close

    def _query(self, q):
        """Send *q* in unbuffered mode and return the reported row count."""
        connection = self._get_db()
        self._clear_result()
        connection.query(q, unbuffered=True)
        self._do_get_result()
        return self.rowcount

    def nextset(self):
        """Advance to the next result set in unbuffered mode."""
        return self._nextset(unbuffered=True)

    def read_next(self):
        """Read next row."""
        raw = self._result._read_rowdata_packet_unbuffered()
        return self._conv_row(raw)

    def fetchone(self):
        """Fetch next row."""
        self._check_executed()
        row = self.read_next()
        if row is not None:
            self.rownumber += 1
            return row
        # End of the result set: pick up the warning count from the result.
        self.warning_count = self._result.warning_count
        return None

    def fetchall(self):
        """
        Fetch all, as per MySQLdb. Pretty useless for large queries, as
        it is buffered. See fetchall_unbuffered(), if you want an unbuffered
        generator version of this method.
        """
        return list(self.fetchall_unbuffered())

    def fetchall_unbuffered(self):
        """
        Fetch all remaining rows lazily.

        Returns an iterator that calls fetchone() until it yields None,
        which isn't to standard but avoids holding a large result set in
        a list.
        """
        return iter(self.fetchone, None)

    def fetchmany(self, size=None):
        """Fetch many."""
        self._check_executed()
        limit = self.arraysize if size is None else size

        fetched = []
        for _ in range(limit):
            row = self.read_next()
            if row is None:
                self.warning_count = self._result.warning_count
                break
            fetched.append(row)
            self.rownumber += 1

        if fetched:
            return fetched
        # Django expects () for EOF.
        # https://github.com/django/django/blob/0c1518ee429b01c145cf5b34eab01b0b92f8c246/django/db/backends/mysql/features.py#L8
        return ()

    def scroll(self, value, mode="relative"):
        """Skip forward by reading and discarding rows.

        Raises NotSupportedError for a backwards move and ProgrammingError
        for an unknown *mode*.
        """
        self._check_executed()

        if mode not in ("relative", "absolute"):
            raise err.ProgrammingError("unknown scroll mode %s" % mode)

        if mode == "relative":
            if value < 0:
                raise err.NotSupportedError(
                    "Backwards scrolling not supported by this cursor"
                )
            skip = value
            target = self.rownumber + value
        else:
            if value < self.rownumber:
                raise err.NotSupportedError(
                    "Backwards scrolling not supported by this cursor"
                )
            skip = value - self.rownumber
            target = value

        for _ in range(skip):
            self.read_next()
        self.rownumber = target

528 

529 

class SSDictCursor(DictCursorMixin, SSCursor):
    """Unbuffered cursor whose rows are returned as dictionaries."""