Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/sqlalchemy/util/_collections.py: 56%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

350 statements  

1# util/_collections.py 

2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: https://www.opensource.org/licenses/mit-license.php 

7# mypy: allow-untyped-defs, allow-untyped-calls 

8 

9"""Collection classes and helpers.""" 

10from __future__ import annotations 

11 

12import operator 

13import threading 

14import types 

15import typing 

16from typing import Any 

17from typing import Callable 

18from typing import cast 

19from typing import Container 

20from typing import Dict 

21from typing import FrozenSet 

22from typing import Generic 

23from typing import Iterable 

24from typing import Iterator 

25from typing import List 

26from typing import Literal 

27from typing import Mapping 

28from typing import NoReturn 

29from typing import Optional 

30from typing import overload 

31from typing import Protocol 

32from typing import Sequence 

33from typing import Set 

34from typing import Tuple 

35from typing import TypeVar 

36from typing import Union 

37from typing import ValuesView 

38import weakref 

39 

40from ._collections_cy import IdentitySet as IdentitySet 

41from ._collections_cy import OrderedSet as OrderedSet 

42from ._collections_cy import unique_list as unique_list # noqa: F401 

43from ._immutabledict_cy import immutabledict as immutabledict 

44from ._immutabledict_cy import ImmutableDictBase as ImmutableDictBase 

45from ._immutabledict_cy import ReadOnlyContainer as ReadOnlyContainer 

46from .typing import is_non_string_iterable 

47 

48 

# Generic type variables used in the annotations of this module.
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT", bound=Any)
_VT = TypeVar("_VT", bound=Any)
_T_co = TypeVar("_T_co", covariant=True)

# Shared frozen sets: one empty, one containing only ``None``.
EMPTY_SET: FrozenSet[Any] = frozenset()
NONE_SET: FrozenSet[Any] = frozenset([None])

56 

57 

def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
    """Merge two lists, keeping the ordering of each as much as possible.

    This is used to reconcile ``vars(cls)`` with ``cls.__annotations__``.

    Example::

        >>> a = ["__tablename__", "id", "x", "created_at"]
        >>> b = ["id", "name", "data", "y", "created_at"]
        >>> merge_lists_w_ordering(a, b)
        ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at']

    This is not necessarily the ordering that things had on the class,
    in this case the class is::

        class User(Base):
            __tablename__ = "users"

            id: Mapped[int] = mapped_column(primary_key=True)
            name: Mapped[str]
            data: Mapped[Optional[str]]
            x = Column(Integer)
            y: Mapped[int]
            created_at: Mapped[datetime.datetime] = mapped_column()

    But things are *mostly* ordered.

    The same result could be reached by building a partial ordering for
    all items in both lists and running topological_sort(), but that is
    too much overhead.

    Background on how this approach was arrived at:
    https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae

    """
    # elements present in both lists; each one is a point at which
    # consumption switches from one list to the other
    shared = set(a)
    shared.intersection_update(b)

    merged: List[Any] = []
    active, waiting = iter(a), iter(b)

    while True:
        switched = False
        for item in active:
            if item in shared:
                # first sighting of a shared element: forget it and
                # continue with the other list; the element itself is
                # emitted when the other iterator reaches it
                shared.remove(item)
                active, waiting = waiting, active
                switched = True
                break
            merged.append(item)

        if not switched:
            # the active list ran out; drain whatever is left of the other
            merged.extend(waiting)
            return merged

112 

113 

def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]:
    """Return ``d`` as an ``immutabledict``.

    A falsy mapping yields the shared ``EMPTY_DICT``; an existing
    ``immutabledict`` is returned unchanged; anything else is passed to
    the ``immutabledict`` constructor.
    """
    if not d:
        return EMPTY_DICT
    if isinstance(d, immutabledict):
        return d
    return immutabledict(d)

121 

122 

# Shared empty immutabledict; returned by coerce_to_immutabledict() for
# falsy input.
EMPTY_DICT: immutabledict[Any, Any] = immutabledict()

124 

125 

class FacadeDict(ImmutableDictBase[_KT, _VT]):
    """A dictionary that is not publicly mutable.

    Items are added through :meth:`_insert_item`, which writes to the
    underlying ``dict`` storage directly.
    """

    def __new__(cls, *args: Any) -> FacadeDict[Any, Any]:
        # positional arguments are accepted but not used for allocation
        instance: FacadeDict[Any, Any] = ImmutableDictBase.__new__(cls)
        return instance

    def copy(self) -> NoReturn:
        """Always raise ``NotImplementedError``."""
        message = (
            "an immutabledict shouldn't need to be copied. use dict(d) "
            "if you need a mutable dictionary."
        )
        raise NotImplementedError(message)

    def __reduce__(self) -> Any:
        """Reduce to ``FacadeDict`` plus a plain ``dict`` of the contents."""
        contents = dict(self)
        return FacadeDict, (contents,)

    def _insert_item(self, key: _KT, value: _VT) -> None:
        """Insert an item into the dictionary directly."""
        dict.__setitem__(self, key, value)

    def __repr__(self) -> str:
        return f"FacadeDict({dict.__repr__(self)})"

148 

149 

# Type variable for the ``default`` argument of Properties.get().
_DT = TypeVar("_DT", bound=Any)

# Type variable for the element type of the other operand in
# Properties.__add__().
_F = TypeVar("_F", bound=Any)

153 

154 

class Properties(Generic[_T]):
    """Provide a __getattr__/__setattr__ interface over a dict.

    The dictionary passed to the constructor is used as-is (not copied);
    attribute access, item access and the dict-like helper methods all
    read from and write to that same dictionary.
    """

    __slots__ = ("_data",)

    _data: Dict[str, _T]

    def __init__(self, data: Dict[str, _T]):
        # __setattr__ is overridden to write into the dictionary, so the
        # slot itself must be set through object.__setattr__
        object.__setattr__(self, "_data", data)

    def __len__(self) -> int:
        return len(self._data)

    def __iter__(self) -> Iterator[_T]:
        # iterate over a snapshot of the values, not the live view
        snapshot = list(self._data.values())
        return iter(snapshot)

    def __dir__(self) -> List[str]:
        names = dir(super())
        names += [str(name) for name in self._data.keys()]
        return names

    def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]:
        """Return a list of this object's values followed by ``other``'s."""
        return [*self, *other]

    def __setitem__(self, key: str, obj: _T) -> None:
        self._data[key] = obj

    def __getitem__(self, key: str) -> _T:
        return self._data[key]

    def __delitem__(self, key: str) -> None:
        del self._data[key]

    def __setattr__(self, key: str, obj: _T) -> None:
        # attribute assignment is stored as a dictionary entry
        self._data[key] = obj

    def __getstate__(self) -> Dict[str, Any]:
        return dict(_data=self._data)

    def __setstate__(self, state: Dict[str, Any]) -> None:
        restored = state["_data"]
        object.__setattr__(self, "_data", restored)

    def __getattr__(self, key: str) -> _T:
        # a missing key surfaces as AttributeError, as attribute access
        # is expected to
        try:
            value = self._data[key]
        except KeyError:
            raise AttributeError(key)
        return value

    def __contains__(self, key: str) -> bool:
        return key in self._data

    def as_readonly(self) -> ReadOnlyProperties[_T]:
        """Return an immutable proxy for this :class:`.Properties`."""

        readonly = ReadOnlyProperties(self._data)
        return readonly

    def update(self, value: Dict[str, _T]) -> None:
        """Update the underlying dictionary from ``value``."""
        self._data.update(value)

    @overload
    def get(self, key: str) -> Optional[_T]: ...

    @overload
    def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ...

    def get(
        self, key: str, default: Optional[Union[_DT, _T]] = None
    ) -> Optional[Union[_T, _DT]]:
        """Return the value for ``key`` if present, else ``default``."""
        return self[key] if key in self else default

    def keys(self) -> List[str]:
        """Return the keys as a new list."""
        return [*self._data]

    def values(self) -> List[_T]:
        """Return the values as a new list."""
        return [*self._data.values()]

    def items(self) -> List[Tuple[str, _T]]:
        """Return the ``(key, value)`` pairs as a new list."""
        return [*self._data.items()]

    def has_key(self, key: str) -> bool:
        """Return True if ``key`` is in the underlying dictionary."""
        return key in self._data

    def clear(self) -> None:
        """Remove all entries from the underlying dictionary."""
        self._data.clear()

240 

241 

class OrderedProperties(Properties[_T]):
    """A :class:`.Properties` that creates its own ``OrderedDict`` as
    backing store."""

    __slots__ = ()

    def __init__(self):
        backing = OrderedDict()
        Properties.__init__(self, backing)

250 

251 

class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]):
    """Provide immutable dict/object attribute access to an underlying
    dictionary.

    Combines ``ReadOnlyContainer`` with :class:`.Properties`; no methods
    or instance attributes are added here.
    """

    __slots__ = ()

256 

257 

def _ordered_dictionary_sort(d, key=None):
    """Sort an OrderedDict in-place.

    :param d: the dictionary to reorder.
    :param key: optional sort key function applied to the dictionary's
     keys, passed through to ``sorted()``.
    """

    ordered_keys = sorted(d, key=key)

    # capture the pairs in their new order before emptying the dictionary
    pairs = []
    for name in ordered_keys:
        pairs.append((name, d[name]))

    d.clear()

    d.update(pairs)

266 

267 

# ``OrderedDict`` is an alias of the builtin ``dict`` in this module.
OrderedDict = dict
# Public name for the in-place dictionary sort defined above.
sort_dictionary = _ordered_dictionary_sort

270 

271 

class WeakSequence(Sequence[_T]):
    """A sequence that stores weak references to its elements.

    An element's entry is removed from the sequence when its weak
    reference callback fires.
    """

    def __init__(self, __elements: Sequence[_T] = (), /):
        # adapted from weakref.WeakKeyDictionary, prevent reference
        # cycles in the collection itself
        def _remove(item, selfref=weakref.ref(self)):
            owner = selfref()
            if owner is not None:
                owner._storage.remove(item)

        self._remove = _remove

        refs = []
        for element in __elements:
            refs.append(weakref.ref(element, _remove))
        self._storage = refs

    def append(self, item):
        """Append a weak reference to ``item``."""
        ref = weakref.ref(item, self._remove)
        self._storage.append(ref)

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        # dereference lazily and skip references whose target is gone
        dereferenced = (ref() for ref in self._storage)
        return (obj for obj in dereferenced if obj is not None)

    def __getitem__(self, index):
        ref = self._storage[index]
        return ref()

299 

300 

# ``OrderedIdentitySet`` is an alias of ``IdentitySet``.
OrderedIdentitySet = IdentitySet

302 

303 

class PopulateDict(Dict[_KT, _VT]):
    """A dict which populates missing values via a creation function.

    Note the creation function takes a key, unlike
    collections.defaultdict.

    """

    def __init__(self, creator: Callable[[_KT], _VT]):
        self.creator = creator

    def __missing__(self, key: Any) -> Any:
        # build the value for the missing key, store it, then hand it back
        created = self.creator(key)
        self[key] = created
        return created

318 

319 

class WeakPopulateDict(Dict[_KT, _VT]):
    """Like PopulateDict, but assumes a self + a method and does not create
    a reference cycle.

    The bound method given to the constructor is split into its plain
    function and a weak reference to the object it was bound to.

    """

    def __init__(self, creator_method: types.MethodType):
        self.creator = creator_method.__func__
        owner = creator_method.__self__
        self.weakself = weakref.ref(owner)

    def __missing__(self, key: Any) -> Any:
        # call the plain function with the dereferenced owner
        created = self.creator(self.weakself(), key)
        self[key] = created
        return created

334 

335 

# Define collections that are capable of storing
# ColumnElement objects as hashable keys/elements.
# At this point, these are mostly historical; things
# used to be more complicated.  Here they are plain aliases
# of ``set``, ``dict`` and ``OrderedSet``.
column_set = set
column_dict = dict
ordered_column_set = OrderedSet

343 

344 

class UniqueAppender(Generic[_T]):
    """Appends items to a collection ensuring uniqueness.

    Additional appends() of the same object are ignored.  Membership is
    determined by identity (``is a``) not equality (``==``).

    :param data: the collection that receives the items.
    :param via: optional name of the method on ``data`` used to add an
     item; when not given, ``data.append`` is used if present, otherwise
     ``data.add`` if present.
    """

    __slots__ = "data", "_data_appender", "_unique"

    data: Union[Iterable[_T], Set[_T], List[_T]]
    _data_appender: Callable[[_T], None]
    _unique: Dict[int, Literal[True]]

    def __init__(
        self,
        data: Union[Iterable[_T], Set[_T], List[_T]],
        via: Optional[str] = None,
    ):
        self.data = data
        self._unique = {}

        appender: Callable[[_T], None]
        if via:
            appender = getattr(data, via)
        elif hasattr(data, "append"):
            appender = cast("List[_T]", data).append
        elif hasattr(data, "add"):
            appender = cast("Set[_T]", data).add
        else:
            # no way to add items was found; _data_appender is left unset
            return
        self._data_appender = appender

    def append(self, item: _T) -> None:
        """Add ``item`` to the collection unless this object was added
        before."""
        identity = id(item)
        if identity in self._unique:
            return
        self._data_appender(item)
        self._unique[identity] = True

    def __iter__(self) -> Iterator[_T]:
        return iter(self.data)

380 

381 

def coerce_generator_arg(arg: Any) -> List[Any]:
    """Expand a lone generator argument into a list.

    If ``arg`` has exactly one element and that element is a generator,
    return the generator's items as a list; otherwise return ``arg``
    itself unchanged.
    """
    is_single_generator = len(arg) == 1 and isinstance(
        arg[0], types.GeneratorType
    )
    if not is_single_generator:
        return cast("List[Any]", arg)
    return list(arg[0])

387 

388 

def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
    """Coerce ``x`` to a list.

    ``None`` yields ``default``; a value for which
    ``is_non_string_iterable()`` is false is wrapped in a one-element
    list; a list is returned as-is; any other iterable is copied into a
    new list.
    """
    if x is None:
        return default  # type: ignore
    if is_non_string_iterable(x):
        return x if isinstance(x, list) else list(x)
    return [x]

398 

399 

def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool:
    r"""return True if any items of set\_ are present in iterable.

    Items of ``iterable`` whose ``__hash__`` is falsy (i.e. unhashable
    objects) are skipped, so that a membership test is never attempted
    on them.

    """
    for candidate in iterable:
        if candidate.__hash__ and candidate in set_:
            return True
    return False

408 

409 

def to_set(x):
    """Coerce ``x`` to a ``set``.

    ``None`` yields a new empty set, an existing ``set`` is returned
    as-is, and anything else is converted via ``to_list()``.
    """
    if x is None:
        return set()
    if isinstance(x, set):
        return x
    return set(to_list(x))

417 

418 

def to_column_set(x: Any) -> Set[Any]:
    """Coerce ``x`` to a ``column_set``.

    ``None`` yields a new empty ``column_set``, an existing
    ``column_set`` is returned as-is, and anything else is converted via
    ``to_list()``.
    """
    if x is None:
        return column_set()
    if isinstance(x, column_set):
        return x
    return column_set(to_list(x))

426 

427 

def update_copy(
    d: Dict[Any, Any], _new: Optional[Dict[Any, Any]] = None, **kw: Any
) -> Dict[Any, Any]:
    """Copy the given dict and update with the given values.

    :param d: the dictionary to copy; it is not modified.
    :param _new: optional dictionary of updates, applied first.
    :param \\**kw: keyword updates, applied after ``_new``.
    :return: the updated copy.
    """

    merged = d.copy()
    if _new:
        merged.update(_new)
    merged.update(**kw)
    return merged

438 

439 

def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]:
    """Given an iterator of which further sub-elements may also be
    iterators, flatten the sub-elements into a single iterator.

    Strings are yielded whole rather than being descended into.

    """
    item: _T
    for item in x:
        if isinstance(item, str) or not hasattr(item, "__iter__"):
            yield item
        else:
            yield from flatten_iterator(item)

451 

452 

class LRUCache(typing.MutableMapping[_KT, _VT]):
    """Dictionary with 'squishy' removal of least
    recently used items.

    Entries are only pruned once the size exceeds
    ``capacity + capacity * threshold``; pruning then keeps the
    ``capacity`` most recently used entries.

    Note that either get() or [] should be used here, but
    generally its not safe to do an "in" check first as the dictionary
    can change subsequent to that call.

    """

    __slots__ = (
        "capacity",
        "threshold",
        "size_alert",
        "_data",
        "_counter",
        "_mutex",
    )

    capacity: int
    threshold: float
    size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]]

    def __init__(
        self,
        capacity: int = 100,
        threshold: float = 0.5,
        size_alert: Optional[Callable[..., None]] = None,
    ):
        self.capacity = capacity
        self.threshold = threshold
        self.size_alert = size_alert
        self._counter = 0
        self._mutex = threading.Lock()
        # key -> (key, value, [usage counter]); the counter lives in a
        # one-element list so it can be updated in place on each access
        self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {}

    def _inc_counter(self):
        self._counter = self._counter + 1
        return self._counter

    @overload
    def get(self, key: _KT) -> Optional[_VT]: ...

    @overload
    def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ...

    def get(
        self, key: _KT, default: Optional[Union[_VT, _T]] = None
    ) -> Optional[Union[_VT, _T]]:
        """Return the value for ``key``, marking it as recently used, or
        ``default`` if the key is not present."""
        entry = self._data.get(key)
        if entry is None:
            return default
        entry[2][0] = self._inc_counter()
        return entry[1]

    def __getitem__(self, key: _KT) -> _VT:
        entry = self._data[key]
        # record the access so the entry counts as recently used
        entry[2][0] = self._inc_counter()
        return entry[1]

    def __iter__(self) -> Iterator[_KT]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def values(self) -> ValuesView[_VT]:
        plain: Dict[_KT, _VT] = {}
        for key, entry in self._data.items():
            plain[key] = entry[1]
        return typing.ValuesView(plain)

    def __setitem__(self, key: _KT, value: _VT) -> None:
        stamp = [self._inc_counter()]
        self._data[key] = (key, value, stamp)
        self._manage_size()

    def __delitem__(self, __v: _KT) -> None:
        del self._data[__v]

    @property
    def size_threshold(self) -> float:
        """The size above which pruning takes place."""
        return self.capacity + self.capacity * self.threshold

    def _manage_size(self) -> None:
        """Prune least recently used entries if the cache has grown past
        its threshold.

        Returns immediately, without pruning, if the mutex cannot be
        acquired without blocking.  ``size_alert``, if set, is invoked at
        most once per call.
        """
        if not self._mutex.acquire(False):
            return
        try:
            alert_pending = bool(self.size_alert)
            while len(self) > self.capacity + self.capacity * self.threshold:
                if alert_pending:
                    alert_pending = False
                    self.size_alert(self)  # type: ignore
                # most recently used entries first
                newest_first = sorted(
                    self._data.values(),
                    key=lambda entry: entry[2],
                    reverse=True,
                )
                for stale in newest_first[self.capacity :]:
                    # the key may already have been deleted elsewhere
                    self._data.pop(stale[0], None)
        finally:
            self._mutex.release()

556 

557 

class _CreateFuncType(Protocol[_T_co]):
    """Protocol for a callable that takes no arguments and returns a
    ``_T_co``."""

    def __call__(self) -> _T_co:
        """Return an object of the protocol's type parameter."""
        ...

560 

561 

class _ScopeFuncType(Protocol):
    """Protocol for a callable that takes no arguments and returns any
    value."""

    def __call__(self) -> Any:
        """Return a value of any type."""
        ...

564 

565 

class ScopedRegistry(Generic[_T]):
    """A Registry that can store one or multiple instances of a single
    class on the basis of a "scope" function.

    The object implements ``__call__`` as the "getter", so by
    calling ``myregistry()`` the contained object is returned
    for the current scope.

    :param createfunc:
      a callable that returns a new object to be placed in the registry

    :param scopefunc:
      a callable that will return a key to store/retrieve an object.
    """

    __slots__ = "createfunc", "scopefunc", "registry"

    createfunc: _CreateFuncType[_T]
    scopefunc: _ScopeFuncType
    registry: Any

    def __init__(
        self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any]
    ):
        """Construct a new :class:`.ScopedRegistry`.

        :param createfunc:  A creation function that will generate
          a new value for the current scope, if none is present.

        :param scopefunc:  A function that returns a hashable
          token representing the current scope (such as, current
          thread identifier).

        """
        self.createfunc = createfunc
        self.scopefunc = scopefunc
        self.registry = {}

    def __call__(self) -> _T:
        """Return the object for the current scope, creating and storing
        it first if the scope has none."""
        scope_key = self.scopefunc()
        try:
            return self.registry[scope_key]  # type: ignore[no-any-return]
        except KeyError:
            created = self.createfunc()
            # setdefault() keeps any value stored for this key in the
            # meantime and returns that instead
            return self.registry.setdefault(scope_key, created)  # type: ignore[no-any-return] # noqa: E501

    def has(self) -> bool:
        """Return True if an object is present in the current scope."""

        scope_key = self.scopefunc()
        return scope_key in self.registry

    def set(self, obj: _T) -> None:
        """Set the value for the current scope."""

        scope_key = self.scopefunc()
        self.registry[scope_key] = obj

    def clear(self) -> None:
        """Clear the current scope, if any."""

        self.registry.pop(self.scopefunc(), None)

628 

629 

class ThreadLocalRegistry(ScopedRegistry[_T]):
    """A :class:`.ScopedRegistry` that uses a ``threading.local()``
    variable for storage.

    The object is kept in the ``value`` attribute of the
    ``threading.local()``; no ``scopefunc`` is used.

    """

    def __init__(self, createfunc: Callable[[], _T]):
        self.createfunc = createfunc
        self.registry = threading.local()

    def __call__(self) -> _T:
        """Return the stored object, creating and storing it first if
        none is present."""
        try:
            return self.registry.value  # type: ignore[no-any-return]
        except AttributeError:
            created = self.createfunc()
            self.registry.value = created
            return created

    def has(self) -> bool:
        """Return True if an object is currently stored."""
        return hasattr(self.registry, "value")

    def set(self, obj: _T) -> None:
        """Store ``obj`` as the current object."""
        self.registry.value = obj

    def clear(self) -> None:
        """Remove the stored object, if any."""
        try:
            del self.registry.value
        except AttributeError:
            # nothing was stored
            pass

658 

659 

def has_dupes(sequence, target):
    """Given a sequence and search object, return True if there's more
    than one, False if zero or one of them.

    Matching is by identity (``is``), not equality.

    """
    # compare to .index version below, this version introduces less function
    # overhead and is usually the same speed.  At 15000 items (way bigger than
    # a relationship-bound collection in memory usually is) it begins to
    # fall behind the other version only by microseconds.
    seen_once = False
    for candidate in sequence:
        if candidate is not target:
            continue
        if seen_once:
            # second occurrence found
            return True
        seen_once = True
    return False

677 

678 

679# .index version. the two __contains__ calls as well 

680# as .index() and isinstance() slow this down. 

681# def has_dupes(sequence, target): 

682# if target not in sequence: 

683# return False 

684# elif not isinstance(sequence, collections_abc.Sequence): 

685# return False 

686# 

687# idx = sequence.index(target) 

688# return target in sequence[idx + 1:]