Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/sqlalchemy/util/_collections.py: 56%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

352 statements  

1# util/_collections.py 

2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: https://www.opensource.org/licenses/mit-license.php 

7# mypy: allow-untyped-defs, allow-untyped-calls 

8 

9"""Collection classes and helpers.""" 

10from __future__ import annotations 

11 

12import operator 

13import threading 

14import types 

15import typing 

16from typing import Any 

17from typing import Callable 

18from typing import cast 

19from typing import Container 

20from typing import Dict 

21from typing import FrozenSet 

22from typing import Generic 

23from typing import Iterable 

24from typing import Iterator 

25from typing import List 

26from typing import Mapping 

27from typing import NoReturn 

28from typing import Optional 

29from typing import overload 

30from typing import Protocol 

31from typing import Sequence 

32from typing import Set 

33from typing import Tuple 

34from typing import TypeVar 

35from typing import Union 

36from typing import ValuesView 

37import weakref 

38 

39from ._collections_cy import IdentitySet as IdentitySet 

40from ._collections_cy import OrderedSet as OrderedSet 

41from ._collections_cy import unique_list as unique_list # noqa: F401 

42from ._immutabledict_cy import immutabledict as immutabledict 

43from ._immutabledict_cy import ImmutableDictBase as ImmutableDictBase 

44from ._immutabledict_cy import ReadOnlyContainer as ReadOnlyContainer 

45from .typing import is_non_string_iterable 

46from .typing import Literal 

47 

48 

# Generic type variables shared by the collection classes in this module.
# ``bound=Any`` places no real constraint on the bound type.
_T = TypeVar("_T", bound=Any)
# Key type for the mapping-like classes (FacadeDict, PopulateDict, LRUCache).
_KT = TypeVar("_KT", bound=Any)
# Value type for the mapping-like classes.
_VT = TypeVar("_VT", bound=Any)
# Covariant result type, used by the _CreateFuncType protocol.
_T_co = TypeVar("_T_co", covariant=True)

# Shared immutable sets that callers can reuse instead of building new ones.
EMPTY_SET: FrozenSet[Any] = frozenset()
NONE_SET: FrozenSet[Any] = frozenset([None])

56 

57 

def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
    """Merge two lists, keeping the relative ordering of each as much as
    possible.

    This is used to reconcile ``vars(cls)`` with ``cls.__annotations__``.

    Example::

        >>> a = ['__tablename__', 'id', 'x', 'created_at']
        >>> b = ['id', 'name', 'data', 'y', 'created_at']
        >>> merge_lists_w_ordering(a, b)
        ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at']

    The result is not necessarily the order the names had on the class.
    In this case the class is::

        class User(Base):
            __tablename__ = "users"

            id: Mapped[int] = mapped_column(primary_key=True)
            name: Mapped[str]
            data: Mapped[Optional[str]]
            x = Column(Integer)
            y: Mapped[int]
            created_at: Mapped[datetime.datetime] = mapped_column()

    Still, the names come out *mostly* ordered.

    Building a partial ordering of every element from both lists and
    running topological_sort() would also work, but costs too much.

    Background on how this algorithm came about:
    https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae

    """
    # Elements present in both lists act as "switch points" between them.
    shared = set(a).intersection(b)
    merged: List[Any] = []

    active, idle = iter(a), iter(b)

    while True:
        switched = False
        for item in active:
            if item in shared:
                # First sighting of a shared element: hand control to the
                # other list.  The element itself is appended later, when
                # the other list reaches it.
                shared.discard(item)
                active, idle = idle, active
                switched = True
                break
            merged.append(item)
        if not switched:
            # The active list is exhausted; the rest of the other list
            # follows in its own order.
            merged.extend(idle)
            return merged

112 

113 

def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]:
    """Return *d* as an :class:`immutabledict`.

    An empty (falsy) mapping yields the shared ``EMPTY_DICT``.  A mapping
    that already is an ``immutabledict`` is returned unchanged.  Anything
    else is copied into a new ``immutabledict``.
    """
    if d:
        return d if isinstance(d, immutabledict) else immutabledict(d)
    return EMPTY_DICT

121 

122 

# Shared empty immutabledict; coerce_to_immutabledict() returns this object
# for any empty mapping.
EMPTY_DICT: immutabledict[Any, Any] = immutabledict()

124 

125 

class FacadeDict(ImmutableDictBase[_KT, _VT]):
    """A dictionary that is not publicly mutable.

    Owning code adds entries through :meth:`_insert_item`, which writes with
    ``dict.__setitem__`` directly rather than going through this class.
    """

    def __new__(cls, *args: Any) -> FacadeDict[Any, Any]:
        # Positional arguments (such as the plain dict handed back by
        # __reduce__) are accepted, but __new__ itself does not use them.
        instance: FacadeDict[Any, Any] = ImmutableDictBase.__new__(cls)
        return instance

    def copy(self) -> NoReturn:
        """Always raises; build ``dict(d)`` for a mutable copy instead."""
        raise NotImplementedError(
            "an immutabledict shouldn't need to be copied. use dict(d) "
            "if you need a mutable dictionary."
        )

    def __reduce__(self) -> Any:
        # Pickle as "rebuild a FacadeDict from a plain-dict snapshot".
        snapshot = dict(self)
        return FacadeDict, (snapshot,)

    def _insert_item(self, key: _KT, value: _VT) -> None:
        """insert an item into the dictionary directly."""
        dict.__setitem__(self, key, value)

    def __repr__(self) -> str:
        body = dict.__repr__(self)
        return f"FacadeDict({body})"

148 

149 

# Type of the ``default`` argument accepted by Properties.get().
_DT = TypeVar("_DT", bound=Any)

# Element type of the right-hand operand in Properties.__add__().
_F = TypeVar("_F", bound=Any)

153 

154 

class Properties(Generic[_T]):
    """Provide a __getattr__/__setattr__ interface over a dict.

    Entries of the wrapped dictionary are reachable both as items
    (``props["name"]``) and as attributes (``props.name``).  Iterating
    yields the *values*, not the keys.  The dictionary passed to the
    constructor is stored as-is (not copied), so changes are shared with
    it in both directions.
    """

    __slots__ = ("_data",)

    _data: Dict[str, _T]

    def __init__(self, data: Dict[str, _T]):
        # object.__setattr__ bypasses our own __setattr__, which would
        # otherwise store "_data" as an entry inside ``data``.
        object.__setattr__(self, "_data", data)

    def __len__(self) -> int:
        return len(self._data)

    def __iter__(self) -> Iterator[_T]:
        # Iterate over a snapshot of the values.
        return iter(list(self._data.values()))

    def __dir__(self) -> List[str]:
        """List the real attributes of this object plus the dict keys."""
        # super().__dir__() lists this instance's real attributes (methods
        # such as keys()/values(), slots, ...).  Note that ``dir(super())``
        # would instead describe the ``super`` proxy object itself.
        return list(super().__dir__()) + [str(k) for k in self._data.keys()]

    def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]:
        """Concatenate the values of both objects into a new list."""
        return list(self) + list(other)

    def __setitem__(self, key: str, obj: _T) -> None:
        self._data[key] = obj

    def __getitem__(self, key: str) -> _T:
        return self._data[key]

    def __delitem__(self, key: str) -> None:
        del self._data[key]

    def __setattr__(self, key: str, obj: _T) -> None:
        # Attribute assignment writes an entry into the wrapped dict.
        self._data[key] = obj

    def __getstate__(self) -> Dict[str, Any]:
        return {"_data": self._data}

    def __setstate__(self, state: Dict[str, Any]) -> None:
        object.__setattr__(self, "_data", state["_data"])

    def __getattr__(self, key: str) -> _T:
        # Only reached when normal attribute lookup fails; fall back to
        # the wrapped dict and translate a miss into AttributeError.
        try:
            return self._data[key]
        except KeyError:
            raise AttributeError(key)

    def __contains__(self, key: str) -> bool:
        return key in self._data

    def as_readonly(self) -> ReadOnlyProperties[_T]:
        """Return an immutable proxy for this :class:`.Properties`."""

        return ReadOnlyProperties(self._data)

    def update(self, value: Dict[str, _T]) -> None:
        self._data.update(value)

    @overload
    def get(self, key: str) -> Optional[_T]: ...

    @overload
    def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ...

    def get(
        self, key: str, default: Optional[Union[_DT, _T]] = None
    ) -> Optional[Union[_T, _DT]]:
        """Return the entry for *key*, or *default* if it is absent."""
        if key in self:
            return self[key]
        else:
            return default

    def keys(self) -> List[str]:
        return list(self._data)

    def values(self) -> List[_T]:
        return list(self._data.values())

    def items(self) -> List[Tuple[str, _T]]:
        return list(self._data.items())

    def has_key(self, key: str) -> bool:
        return key in self._data

    def clear(self) -> None:
        self._data.clear()

240 

241 

class OrderedProperties(Properties[_T]):
    """A :class:`.Properties` whose backing store is created internally.

    The store is this module's ``OrderedDict`` name, so entries keep their
    insertion order.
    """

    __slots__ = ()

    def __init__(self):
        backing_store = OrderedDict()
        Properties.__init__(self, backing_store)

250 

251 

class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]):
    """Provide an immutable dict/attribute-style view of an underlying
    dictionary."""

    # No storage beyond Properties' ``_data`` slot.  ReadOnlyContainer is
    # listed first so it precedes Properties in the MRO -- presumably so its
    # methods take over Properties' mutators; TODO confirm against
    # ReadOnlyContainer's definition.
    __slots__ = ()

256 

257 

258def _ordered_dictionary_sort(d, key=None): 

259 """Sort an OrderedDict in-place.""" 

260 

261 items = [(k, d[k]) for k in sorted(d, key=key)] 

262 

263 d.clear() 

264 

265 d.update(items) 

266 

267 

# Historical alias: a plain ``dict`` is used wherever this module (e.g.
# OrderedProperties) asks for an ``OrderedDict``.
OrderedDict = dict
# Public name for the in-place dictionary sort helper.
sort_dictionary = _ordered_dictionary_sort

270 

271 

class WeakSequence(Sequence[_T]):
    """A sequence that holds only weak references to its elements.

    Each stored weakref carries a callback that removes it from the
    underlying storage once its referent is collected.
    """

    def __init__(self, __elements: Sequence[_T] = (), /):
        # adapted from weakref.WeakKeyDictionary, prevent reference
        # cycles in the collection itself: the callback holds only a weak
        # reference back to this sequence.
        def _remove(item, selfref=weakref.ref(self)):
            self = selfref()
            if self is not None:
                self._storage.remove(item)

        self._remove = _remove
        self._storage = [
            weakref.ref(element, _remove) for element in __elements
        ]

    def append(self, item):
        """Append a weak reference to *item*."""
        self._storage.append(weakref.ref(item, self._remove))

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        # Skip any reference whose referent is already gone.
        return (
            obj for obj in (ref() for ref in self._storage) if obj is not None
        )

    def __getitem__(self, index):
        """Return the element stored at *index*.

        :raises IndexError: if *index* is out of range.
        """
        try:
            obj = self._storage[index]
        except IndexError:
            # List indexing signals out-of-range with IndexError (never
            # KeyError); re-raise it with a message naming the index.
            raise IndexError("Index %s out of range" % index) from None
        else:
            return obj()

304 

305 

# Alias: ``OrderedIdentitySet`` is simply the same class as ``IdentitySet``.
OrderedIdentitySet = IdentitySet

307 

308 

class PopulateDict(Dict[_KT, _VT]):
    """A dict that fills in missing values with a creation function.

    Unlike ``collections.defaultdict``, the creation function receives the
    missing key as its argument.  The created value is stored under that
    key before it is returned.
    """

    def __init__(self, creator: Callable[[_KT], _VT]):
        self.creator = creator

    def __missing__(self, key: Any) -> Any:
        # Invoked by dict.__getitem__ when *key* is absent.
        created = self.creator(key)
        self[key] = created
        return created

323 

324 

class WeakPopulateDict(Dict[_KT, _VT]):
    """Like PopulateDict, but assumes a self + a method and does not create
    a reference cycle.

    The bound method is split into its plain function and a weak reference
    to its instance.  Missing values are produced by calling
    ``function(instance, key)``.
    """

    def __init__(self, creator_method: types.MethodType):
        unbound = creator_method.__func__
        owner = creator_method.__self__
        self.creator = unbound
        self.weakself = weakref.ref(owner)

    def __missing__(self, key: Any) -> Any:
        # NOTE(review): if the owner has been garbage collected,
        # ``self.weakself()`` is None and that is what the creator receives.
        owner = self.weakself()
        created = self.creator(owner, key)
        self[key] = created
        return created

339 

340 

# Collection types used to store ColumnElement objects as hashable
# keys/elements.  Today they are plain aliases of ``set``, ``dict`` and
# ``OrderedSet``, kept mostly for historical reasons; earlier versions of
# this module needed more complicated implementations here.
column_set = set
column_dict = dict
ordered_column_set = OrderedSet

348 

349 

class UniqueAppender(Generic[_T]):
    """Appends items to a collection ensuring uniqueness.

    Repeated ``append()`` calls with the same object are ignored.
    Uniqueness is judged by identity (``id()``), not equality (``==``).
    """

    __slots__ = "data", "_data_appender", "_unique"

    data: Union[Iterable[_T], Set[_T], List[_T]]
    _data_appender: Callable[[_T], None]
    _unique: Dict[int, Literal[True]]

    def __init__(
        self,
        data: Union[Iterable[_T], Set[_T], List[_T]],
        via: Optional[str] = None,
    ):
        self.data = data
        self._unique = {}
        if via:
            # An explicitly named adder method on ``data``.
            self._data_appender = getattr(data, via)
            return
        # Otherwise prefer list-style append(), then set-style add().
        # NOTE(review): if neither exists, _data_appender stays unset and
        # append() will raise AttributeError.
        for method_name in ("append", "add"):
            if hasattr(data, method_name):
                self._data_appender = getattr(data, method_name)
                break

    def append(self, item: _T) -> None:
        """Add *item* to ``data`` unless this exact object was added before."""
        marker = id(item)
        if marker in self._unique:
            return
        self._data_appender(item)
        self._unique[marker] = True

    def __iter__(self) -> Iterator[_T]:
        return iter(self.data)

386 

def coerce_generator_arg(arg: Any) -> List[Any]:
    """Unwrap a lone generator argument into a list.

    If *arg* has exactly one element and that element is a generator, the
    generator is consumed into a new list.  Otherwise *arg* itself is
    returned unchanged (the cast only affects typing).
    """
    single_generator = len(arg) == 1 and isinstance(
        arg[0], types.GeneratorType
    )
    if not single_generator:
        return cast("List[Any]", arg)
    return list(arg[0])

392 

393 

def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
    """Coerce *x* to a list.

    ``None`` yields *default* as-is (which may itself be ``None``).
    Anything for which ``is_non_string_iterable()`` is false is wrapped in
    a one-element list.  An existing ``list`` is returned without copying,
    and any other iterable is materialized with ``list()``.
    """
    if x is None:
        return default  # type: ignore
    if is_non_string_iterable(x):
        return x if isinstance(x, list) else list(x)
    return [x]

403 

404 

def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool:
    r"""Return True if any element of *iterable* is contained in set\_.

    Elements whose ``__hash__`` is falsy (e.g. ``None`` on unhashable
    objects such as lists) are skipped instead of tested, so no attempt is
    made to hash them.

    """
    for candidate in iterable:
        if not candidate.__hash__:
            continue
        if candidate in set_:
            return True
    return False

413 

414 

def to_set(x):
    """Coerce *x* to a ``set``.

    ``None`` gives a new empty set.  An existing ``set`` (including
    subclasses) is returned unchanged.  Anything else is first normalized
    with ``to_list()`` and then converted.
    """
    if x is None:
        return set()
    if isinstance(x, set):
        return x
    return set(to_list(x))

422 

423 

def to_column_set(x: Any) -> Set[Any]:
    """Coerce *x* to a ``column_set``.

    ``None`` gives a new empty ``column_set``.  An existing ``column_set``
    instance is returned unchanged.  Anything else is normalized with
    ``to_list()`` and converted.
    """
    if x is None:
        return column_set()
    return x if isinstance(x, column_set) else column_set(to_list(x))

431 

432 

def update_copy(d, _new=None, **kw):
    """Return a copy of *d* (made with ``d.copy()``) updated with new values.

    The copy is updated first from *_new*, when it is truthy, and then from
    the keyword arguments.  *d* itself is left untouched.
    """
    result = d.copy()
    if _new:
        result.update(_new)
    result.update(**kw)
    return result

441 

442 

def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]:
    """Lazily yield the leaves of a nested iterable, depth first.

    Any element that has an ``__iter__`` attribute is descended into
    recursively, except ``str`` instances, which are yielded whole.
    """
    for elem in x:
        if isinstance(elem, str) or not hasattr(elem, "__iter__"):
            yield elem
        else:
            yield from flatten_iterator(elem)

454 

455 

class LRUCache(typing.MutableMapping[_KT, _VT]):
    """Dictionary with 'squishy' removal of least recently used items.

    Once the number of entries exceeds
    ``capacity + capacity * threshold``, the cache is pruned back to the
    ``capacity`` most recently used entries.

    Read with either get() or ``[]``.  An ``in`` check followed by a read
    is generally not safe, because the dictionary may change in between.

    """

    __slots__ = (
        "capacity",
        "threshold",
        "size_alert",
        "_data",
        "_counter",
        "_mutex",
    )

    capacity: int
    threshold: float
    size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]]

    def __init__(
        self,
        capacity: int = 100,
        threshold: float = 0.5,
        size_alert: Optional[Callable[..., None]] = None,
    ):
        self.capacity = capacity
        self.threshold = threshold
        self.size_alert = size_alert
        # Source of increasing "last used" stamps.
        self._counter = 0
        # Held only while pruning; see _manage_size().
        self._mutex = threading.Lock()
        # key -> (key, value, [last-used stamp]).  The stamp lives in a
        # one-element list so a read can refresh it in place.
        self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {}

    def _inc_counter(self):
        """Advance the usage counter and return its new value."""
        self._counter += 1
        return self._counter

    @overload
    def get(self, key: _KT) -> Optional[_VT]: ...

    @overload
    def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ...

    def get(
        self, key: _KT, default: Optional[Union[_VT, _T]] = None
    ) -> Optional[Union[_VT, _T]]:
        """Return the value for *key* (marking it as used) or *default*."""
        entry = self._data.get(key)
        if entry is None:
            return default
        entry[2][0] = self._inc_counter()
        return entry[1]

    def __getitem__(self, key: _KT) -> _VT:
        _key, value, stamp = self._data[key]
        stamp[0] = self._inc_counter()
        return value

    def __iter__(self) -> Iterator[_KT]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def values(self) -> ValuesView[_VT]:
        # Snapshot of the current values; usage stamps are not refreshed.
        current = {}
        for key, (_key, value, _stamp) in self._data.items():
            current[key] = value
        return typing.ValuesView(current)

    def __setitem__(self, key: _KT, value: _VT) -> None:
        stamp = [self._inc_counter()]
        self._data[key] = (key, value, stamp)
        self._manage_size()

    def __delitem__(self, __v: _KT) -> None:
        del self._data[__v]

    @property
    def size_threshold(self) -> float:
        """Entry count above which the cache gets pruned."""
        return self.capacity + self.capacity * self.threshold

    def _manage_size(self) -> None:
        """Prune the cache back to ``capacity`` entries when it grows past
        the threshold.

        ``size_alert``, if set, is called with the cache once per pruning
        pass, before anything is removed.
        """
        # Non-blocking acquire: if another caller is already pruning,
        # leave the work to it.
        if not self._mutex.acquire(False):
            return
        try:
            alert_pending = bool(self.size_alert)
            while len(self) > self.capacity + self.capacity * self.threshold:
                if alert_pending:
                    alert_pending = False
                    self.size_alert(self)  # type: ignore
                newest_first = sorted(
                    self._data.values(),
                    key=operator.itemgetter(2),
                    reverse=True,
                )
                for stale in newest_first[self.capacity :]:
                    # The entry may already have been deleted elsewhere.
                    self._data.pop(stale[0], None)
        finally:
            self._mutex.release()

559 

560 

class _CreateFuncType(Protocol[_T_co]):
    """Typing protocol for a zero-argument callable returning ``_T_co``.

    Used to annotate ``ScopedRegistry.createfunc``.
    """

    def __call__(self) -> _T_co: ...

563 

564 

class _ScopeFuncType(Protocol):
    """Typing protocol for a zero-argument callable that returns the key
    identifying the current scope.

    Used to annotate ``ScopedRegistry.scopefunc``.
    """

    def __call__(self) -> Any: ...

567 

568 

class ScopedRegistry(Generic[_T]):
    """A Registry that can store one or multiple instances of a single
    class on the basis of a "scope" function.

    Calling the registry (``myregistry()``) acts as the getter.  It returns
    the object stored for the current scope and creates it first if none
    is present.

    :param createfunc:
      a callable that returns a new object to be placed in the registry

    :param scopefunc:
      a callable that will return a key to store/retrieve an object.
    """

    __slots__ = "createfunc", "scopefunc", "registry"

    createfunc: _CreateFuncType[_T]
    scopefunc: _ScopeFuncType
    registry: Any

    def __init__(
        self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any]
    ):
        """Construct a new :class:`.ScopedRegistry`.

        :param createfunc: A creation function that will generate
          a new value for the current scope, if none is present.

        :param scopefunc: A function that returns a hashable
          token representing the current scope (such as, current
          thread identifier).

        """
        self.createfunc = createfunc
        self.scopefunc = scopefunc
        # scope key -> object created for that scope
        self.registry = {}

    def __call__(self) -> _T:
        scope_key = self.scopefunc()
        try:
            existing = self.registry[scope_key]
        except KeyError:
            # setdefault() keeps a value that was stored between the lookup
            # above and this point (presumably by another thread), although
            # createfunc() has still been called.
            return self.registry.setdefault(scope_key, self.createfunc())  # type: ignore[no-any-return] # noqa: E501
        return existing  # type: ignore[no-any-return]

    def has(self) -> bool:
        """Return True if an object is present in the current scope."""

        return self.scopefunc() in self.registry

    def set(self, obj: _T) -> None:
        """Set the value for the current scope."""

        scope_key = self.scopefunc()
        self.registry[scope_key] = obj

    def clear(self) -> None:
        """Clear the current scope, if any."""

        registry = self.registry
        try:
            scope_key = self.scopefunc()
            del registry[scope_key]
        except KeyError:
            pass

631 

632 

class ThreadLocalRegistry(ScopedRegistry[_T]):
    """A :class:`.ScopedRegistry` that keeps its object in a
    ``threading.local()`` instead of a scope-keyed dict.

    The thread itself is therefore the scope, and no scope function is
    used.
    """

    def __init__(self, createfunc: Callable[[], _T]):
        self.createfunc = createfunc
        # Each thread sees its own ``value`` attribute on this object.
        self.registry = threading.local()

    def __call__(self) -> _T:
        try:
            return self.registry.value  # type: ignore[no-any-return]
        except AttributeError:
            # Nothing stored yet for this thread: create and remember it.
            created = self.createfunc()
            self.registry.value = created
            return created

    def has(self) -> bool:
        """Return True if this thread already has a stored object."""
        return hasattr(self.registry, "value")

    def set(self, obj: _T) -> None:
        """Store *obj* for the current thread."""
        self.registry.value = obj

    def clear(self) -> None:
        """Remove the current thread's object, if there is one."""
        try:
            del self.registry.value
        except AttributeError:
            # nothing stored for this thread
            pass

661 

662 

def has_dupes(sequence, target):
    """Return True if *target* occurs more than once in *sequence*.

    Occurrences are matched by identity (``is``), not equality.  Zero or
    one occurrence gives False.

    """
    # Compared to the .index() based version below, this loop has less
    # function-call overhead and is usually about as fast.  At 15000 items
    # (far more than a relationship-bound collection in memory usually
    # holds) it falls behind the other version only by microseconds.
    seen = False
    for item in sequence:
        if item is not target:
            continue
        if seen:
            return True
        seen = True
    return False

680 

681 

682# .index version. the two __contains__ calls as well 

683# as .index() and isinstance() slow this down. 

684# def has_dupes(sequence, target): 

685# if target not in sequence: 

686# return False 

687# elif not isinstance(sequence, collections_abc.Sequence): 

688# return False 

689# 

690# idx = sequence.index(target) 

691# return target in sequence[idx + 1:]