Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/sqlalchemy/util/_collections.py: 60%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

368 statements  

1# util/_collections.py 

2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: https://www.opensource.org/licenses/mit-license.php 

7# mypy: allow-untyped-defs, allow-untyped-calls 

8 

9"""Collection classes and helpers.""" 

10from __future__ import annotations 

11 

12import operator 

13import threading 

14import types 

15import typing 

16from typing import Any 

17from typing import Callable 

18from typing import cast 

19from typing import Container 

20from typing import Dict 

21from typing import FrozenSet 

22from typing import Generic 

23from typing import Iterable 

24from typing import Iterator 

25from typing import List 

26from typing import Mapping 

27from typing import NoReturn 

28from typing import Optional 

29from typing import overload 

30from typing import Sequence 

31from typing import Set 

32from typing import Tuple 

33from typing import TypeVar 

34from typing import Union 

35from typing import ValuesView 

36import weakref 

37 

38from ._has_cy import HAS_CYEXTENSION 

39from .typing import is_non_string_iterable 

40from .typing import Literal 

41from .typing import Protocol 

42 

43if typing.TYPE_CHECKING or not HAS_CYEXTENSION: 

44 from ._py_collections import immutabledict as immutabledict 

45 from ._py_collections import IdentitySet as IdentitySet 

46 from ._py_collections import ReadOnlyContainer as ReadOnlyContainer 

47 from ._py_collections import ImmutableDictBase as ImmutableDictBase 

48 from ._py_collections import OrderedSet as OrderedSet 

49 from ._py_collections import unique_list as unique_list 

50else: 

51 from sqlalchemy.cyextension.immutabledict import ( 

52 ReadOnlyContainer as ReadOnlyContainer, 

53 ) 

54 from sqlalchemy.cyextension.immutabledict import ( 

55 ImmutableDictBase as ImmutableDictBase, 

56 ) 

57 from sqlalchemy.cyextension.immutabledict import ( 

58 immutabledict as immutabledict, 

59 ) 

60 from sqlalchemy.cyextension.collections import IdentitySet as IdentitySet 

61 from sqlalchemy.cyextension.collections import OrderedSet as OrderedSet 

62 from sqlalchemy.cyextension.collections import ( # noqa 

63 unique_list as unique_list, 

64 ) 

65 

66 

# Type variables shared by the generic collection classes in this module.
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT", bound=Any)
_VT = TypeVar("_VT", bound=Any)
_T_co = TypeVar("_T_co", covariant=True)

# Shared immutable constants: the empty set, and a set holding only None.
EMPTY_SET: FrozenSet[Any] = frozenset()
NONE_SET: FrozenSet[Any] = frozenset((None,))

74 

75 

76def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]: 

77 """merge two lists, maintaining ordering as much as possible. 

78 

79 this is to reconcile vars(cls) with cls.__annotations__. 

80 

81 Example:: 

82 

83 >>> a = ["__tablename__", "id", "x", "created_at"] 

84 >>> b = ["id", "name", "data", "y", "created_at"] 

85 >>> merge_lists_w_ordering(a, b) 

86 ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at'] 

87 

88 This is not necessarily the ordering that things had on the class, 

89 in this case the class is:: 

90 

91 class User(Base): 

92 __tablename__ = "users" 

93 

94 id: Mapped[int] = mapped_column(primary_key=True) 

95 name: Mapped[str] 

96 data: Mapped[Optional[str]] 

97 x = Column(Integer) 

98 y: Mapped[int] 

99 created_at: Mapped[datetime.datetime] = mapped_column() 

100 

101 But things are *mostly* ordered. 

102 

103 The algorithm could also be done by creating a partial ordering for 

104 all items in both lists and then using topological_sort(), but that 

105 is too much overhead. 

106 

107 Background on how I came up with this is at: 

108 https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae 

109 

110 """ 

111 overlap = set(a).intersection(b) 

112 

113 result = [] 

114 

115 current, other = iter(a), iter(b) 

116 

117 while True: 

118 for element in current: 

119 if element in overlap: 

120 overlap.discard(element) 

121 other, current = current, other 

122 break 

123 

124 result.append(element) 

125 else: 

126 result.extend(other) 

127 break 

128 

129 return result 

130 

131 

# Shared empty immutabledict, returned for any falsy input to
# coerce_to_immutabledict().
EMPTY_DICT: immutabledict[Any, Any] = immutabledict()


def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]:
    """Return ``d`` as an ``immutabledict``, copying only when needed.

    :param d: the mapping to coerce.
    :return: the shared ``EMPTY_DICT`` when ``d`` is falsy, ``d`` itself
     when it is already an ``immutabledict``, otherwise a new
     ``immutabledict`` built from ``d``.

    """
    if not d:
        return EMPTY_DICT
    if isinstance(d, immutabledict):
        return d
    return immutabledict(d)

142 

143 

class FacadeDict(ImmutableDictBase[_KT, _VT]):
    """A dictionary that is not publicly mutable.

    Entries are added only through the private :meth:`_insert_item` hook.

    """

    def __new__(cls, *args: Any) -> FacadeDict[Any, Any]:
        # positional arguments are accepted but not used at this stage.
        # NOTE(review): presumably the inherited __init__ consumes them,
        # which __reduce__ relies on - confirm against ImmutableDictBase.
        return ImmutableDictBase.__new__(cls)

    def copy(self) -> NoReturn:
        """Always raise ``NotImplementedError``; copying is unsupported."""
        raise NotImplementedError(
            "an immutabledict shouldn't need to be copied. use dict(d) if "
            "you need a mutable dictionary."
        )

    def __reduce__(self) -> Any:
        # pickle as a plain dict snapshot, rebuilt through FacadeDict()
        contents = dict(self)
        return FacadeDict, (contents,)

    def _insert_item(self, key: _KT, value: _VT) -> None:
        """Store ``value`` under ``key`` directly, via ``dict.__setitem__``."""
        dict.__setitem__(self, key, value)

    def __repr__(self) -> str:
        return f"FacadeDict({dict.__repr__(self)})"

166 

167 

_DT = TypeVar("_DT", bound=Any)

_F = TypeVar("_F", bound=Any)


class Properties(Generic[_T]):
    """Expose the entries of a dict through both attribute access and
    item access.

    The dictionary passed to the constructor is used as-is (not copied)
    as the backing store.

    """

    __slots__ = ("_data",)

    _data: Dict[str, _T]

    def __init__(self, data: Dict[str, _T]):
        # go around our own __setattr__, which would otherwise try to
        # store "_data" as an entry of the backing dictionary
        object.__setattr__(self, "_data", data)

    def __len__(self) -> int:
        return len(self._data)

    def __iter__(self) -> Iterator[_T]:
        # iterate a snapshot of the values, not a live view
        snapshot = list(self._data.values())
        return iter(snapshot)

    def __dir__(self) -> List[str]:
        names = dir(super())
        return names + [str(name) for name in self._data.keys()]

    def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]:
        """Return a list of this collection's values followed by the
        items yielded by iterating ``other``."""
        mine = list(self)
        theirs = list(other)
        return mine + theirs

    def __setitem__(self, key: str, obj: _T) -> None:
        self._data[key] = obj

    def __getitem__(self, key: str) -> _T:
        return self._data[key]

    def __delitem__(self, key: str) -> None:
        del self._data[key]

    def __setattr__(self, key: str, obj: _T) -> None:
        # attribute assignment writes an entry into the backing dict
        self._data[key] = obj

    def __getstate__(self) -> Dict[str, Any]:
        return dict(_data=self._data)

    def __setstate__(self, state: Dict[str, Any]) -> None:
        object.__setattr__(self, "_data", state["_data"])

    def __getattr__(self, key: str) -> _T:
        # only called when normal attribute lookup fails; a missing entry
        # is reported as AttributeError rather than KeyError
        try:
            value = self._data[key]
        except KeyError:
            raise AttributeError(key)
        return value

    def __contains__(self, key: str) -> bool:
        return key in self._data

    def as_readonly(self) -> ReadOnlyProperties[_T]:
        """Return an immutable proxy for this :class:`.Properties`.

        The proxy is constructed over the same backing dictionary.

        """
        return ReadOnlyProperties(self._data)

    def update(self, value: Dict[str, _T]) -> None:
        """Merge the entries of ``value`` into the backing dictionary."""
        self._data.update(value)

    @overload
    def get(self, key: str) -> Optional[_T]: ...

    @overload
    def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ...

    def get(
        self, key: str, default: Optional[Union[_DT, _T]] = None
    ) -> Optional[Union[_T, _DT]]:
        """Return the entry for ``key``, or ``default`` if not present."""
        return self[key] if key in self else default

    def keys(self) -> List[str]:
        """Return the keys as a new list."""
        return list(self._data)

    def values(self) -> List[_T]:
        """Return the values as a new list."""
        return list(self._data.values())

    def items(self) -> List[Tuple[str, _T]]:
        """Return the ``(key, value)`` pairs as a new list."""
        return list(self._data.items())

    def has_key(self, key: str) -> bool:
        """Return True if ``key`` is present in the backing dictionary."""
        return key in self._data

    def clear(self) -> None:
        """Remove every entry from the backing dictionary."""
        self._data.clear()

258 

259 

class OrderedProperties(Properties[_T]):
    """A :class:`.Properties` whose backing store is a new, empty
    ``OrderedDict`` created at construction time."""

    __slots__ = ()

    def __init__(self):
        backing = OrderedDict()
        Properties.__init__(self, backing)

268 

269 

class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]):
    """Immutable dict/attribute-style access to an underlying dictionary.

    Combines ``ReadOnlyContainer`` with :class:`.Properties`; this class
    adds no members of its own.

    """

    __slots__ = ()

274 

275 

def _ordered_dictionary_sort(d, key=None):
    """Sort an OrderedDict in-place.

    :param d: the dictionary to reorder; it is emptied and refilled.
    :param key: optional sort key function applied to each dictionary
     key, as for ``sorted()``.

    """
    ordered_keys = sorted(d, key=key)
    # capture the pairs in their new order before the dict is emptied
    pairs = [(name, d[name]) for name in ordered_keys]

    d.clear()
    d.update(pairs)


# ``OrderedDict`` is the plain builtin dict.
OrderedDict = dict
sort_dictionary = _ordered_dictionary_sort

288 

289 

class WeakSequence(Sequence[_T]):
    """A sequence that stores only weak references to its elements.

    Each element is wrapped in a ``weakref.ref`` whose callback removes
    that reference from the internal storage list, so the length shrinks
    as elements are garbage collected.

    :param __elements: optional initial elements; each must support weak
     references.

    """

    def __init__(self, __elements: Sequence[_T] = ()):
        # adapted from weakref.WeakKeyDictionary, prevent reference
        # cycles in the collection itself: the callback reaches the
        # collection only through a weak reference.
        def _remove(item, selfref=weakref.ref(self)):
            self = selfref()
            if self is not None:
                self._storage.remove(item)

        self._remove = _remove
        self._storage = [
            weakref.ref(element, _remove) for element in __elements
        ]

    def append(self, item):
        """Add a weak reference to ``item`` at the end of the sequence."""
        self._storage.append(weakref.ref(item, self._remove))

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        # skip references whose referent is already gone
        return (
            obj for obj in (ref() for ref in self._storage) if obj is not None
        )

    def __getitem__(self, index):
        """Return the element at ``index``.

        The stored weak reference is dereferenced on return, so the
        result is ``None`` if the referent is no longer alive.

        :raises IndexError: if ``index`` is out of range.

        """
        try:
            ref = self._storage[index]
        except IndexError as err:
            # list indexing raises IndexError, not KeyError; catching the
            # right exception lets the descriptive message be produced
            raise IndexError("Index %s out of range" % index) from err
        else:
            return ref()

322 

323 

class OrderedIdentitySet(IdentitySet):
    """An ``IdentitySet`` whose ``_members`` store is an ``OrderedDict``.

    :param iterable: optional initial members, added one at a time in
     iteration order.

    """

    def __init__(self, iterable: Optional[Iterable[Any]] = None):
        IdentitySet.__init__(self)
        self._members = OrderedDict()
        # a falsy argument (None, empty collection) contributes nothing
        for member in iterable or ():
            self.add(member)

331 

332 

class PopulateDict(Dict[_KT, _VT]):
    """A dict that fills in missing values by calling a creation function.

    Unlike ``collections.defaultdict``, the creation function receives
    the missing key as its argument.

    :param creator: callable invoked with the missing key; its return
     value is stored under that key and returned.

    """

    def __init__(self, creator: Callable[[_KT], _VT]):
        self.creator = creator

    def __missing__(self, key: Any) -> Any:
        created = self.creator(key)
        self[key] = created
        return created

347 

348 

class WeakPopulateDict(Dict[_KT, _VT]):
    """Like PopulateDict, but built from a bound method, and avoiding a
    reference cycle.

    The bound method is split into its underlying function and a weak
    reference to the object it is bound to.

    :param creator_method: a bound method taking the missing key.

    """

    def __init__(self, creator_method: types.MethodType):
        owner = creator_method.__self__
        self.creator = creator_method.__func__
        self.weakself = weakref.ref(owner)

    def __missing__(self, key: Any) -> Any:
        # re-bind by hand: the owner is fetched through the weak
        # reference (None if it is gone) and passed as first argument
        created = self.creator(self.weakself(), key)
        self[key] = created
        return created

363 

364 

# Collection types capable of storing ColumnElement objects as
# hashable keys/elements.  These aliases are mostly historical at this
# point; things used to be more complicated.
column_set = set
column_dict = dict
ordered_column_set = OrderedSet

372 

373 

class UniqueAppender(Generic[_T]):
    """Append items to a collection while ensuring uniqueness.

    Appending the same object again is a no-op.  Membership is decided
    by identity (``is a``), not by equality (``==``): the ``id()`` of
    every appended item is recorded.

    :param data: the target collection.
    :param via: optional name of the method on ``data`` used to add an
     item.  When omitted, ``data.append`` is used if present, otherwise
     ``data.add`` if present.

    """

    __slots__ = ("data", "_data_appender", "_unique")

    data: Union[Iterable[_T], Set[_T], List[_T]]
    _data_appender: Callable[[_T], None]
    _unique: Dict[int, Literal[True]]

    def __init__(
        self,
        data: Union[Iterable[_T], Set[_T], List[_T]],
        via: Optional[str] = None,
    ):
        self.data = data
        self._unique = {}

        if via:
            appender = getattr(data, via)
        elif hasattr(data, "append"):
            appender = cast("List[_T]", data).append
        elif hasattr(data, "add"):
            appender = cast("Set[_T]", data).add
        else:
            # no usable method found; _data_appender stays unset
            return
        self._data_appender = appender

    def append(self, item: _T) -> None:
        """Add ``item`` to the collection unless this exact object was
        already appended through this appender."""
        identity = id(item)
        if identity in self._unique:
            return
        self._data_appender(item)
        self._unique[identity] = True

    def __iter__(self) -> Iterator[_T]:
        return iter(self.data)

409 

410 

def coerce_generator_arg(arg: Any) -> List[Any]:
    """Expand an argument sequence that consists of a single generator.

    :param arg: a sized, indexable collection of arguments.
    :return: a list of the generator's items when ``arg`` holds exactly
     one element and that element is a generator; otherwise ``arg``
     itself, unchanged.

    """
    sole_generator = len(arg) == 1 and isinstance(
        arg[0], types.GeneratorType
    )
    return list(arg[0]) if sole_generator else cast("List[Any]", arg)

416 

417 

def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
    """Coerce ``x`` to a list.

    :param x: the value to coerce.
    :param default: returned as-is when ``x`` is None; note this is
     itself None unless the caller passes something else.
    :return: ``default`` for None; a one-element list when
     ``is_non_string_iterable(x)`` is false; ``x`` itself when it is
     already a list; otherwise ``list(x)``.

    """
    if x is None:
        return default  # type: ignore
    if not is_non_string_iterable(x):
        return [x]
    if isinstance(x, list):
        return x
    return list(x)

427 

428 

def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool:
    r"""Return True if any item of ``iterable`` is present in set\_.

    Items whose ``__hash__`` is falsy (i.e. unhashable objects) are
    skipped before the membership test, so that ``__hash__`` is never
    invoked on objects that don't support it.

    """
    for candidate in iterable:
        if candidate.__hash__ and candidate in set_:
            return True
    return False

437 

438 

def to_set(x):
    """Coerce ``x`` to a set.

    :return: a new empty set for None; ``x`` itself when it is already a
     set; otherwise a new set built from ``to_list(x)``.

    """
    if x is None:
        return set()
    if isinstance(x, set):
        return x
    return set(to_list(x))

446 

447 

def to_column_set(x: Any) -> Set[Any]:
    """Coerce ``x`` to a ``column_set``.

    :return: a new empty ``column_set`` for None; ``x`` itself when it
     is already a ``column_set``; otherwise a new ``column_set`` built
     from ``to_list(x)``.

    """
    if x is None:
        return column_set()
    if isinstance(x, column_set):
        return x
    return column_set(to_list(x))

455 

456 

def update_copy(
    d: Dict[Any, Any], _new: Optional[Dict[Any, Any]] = None, **kw: Any
) -> Dict[Any, Any]:
    """Copy the given dict and update the copy with the given values.

    :param d: the source dictionary; left unmodified.
    :param _new: optional dictionary of entries applied to the copy.
    :param \\**kw: further entries, applied after ``_new``.
    :return: the updated copy, produced by ``d.copy()``.

    """
    merged = d.copy()
    if _new:
        merged.update(_new)
    merged.update(**kw)
    return merged

467 

468 

def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]:
    """Flatten an iterable whose elements may themselves be iterables
    into a single stream of leaf elements.

    Strings are treated as leaves and are never descended into; any
    other element that has an ``__iter__`` attribute is flattened
    recursively.

    """
    item: _T
    for item in x:
        if isinstance(item, str) or not hasattr(item, "__iter__"):
            yield item
        else:
            yield from flatten_iterator(item)

480 

481 

class LRUCache(typing.MutableMapping[_KT, _VT]):
    """Dictionary with 'squishy' removal of least
    recently used items.

    The cache may grow past ``capacity`` up to
    ``capacity + capacity * threshold`` entries; once that size is
    exceeded, the least recently used entries are dropped so that
    ``capacity`` entries remain.

    Note that either get() or [] should be used here, but
    generally it's not safe to do an "in" check first as the dictionary
    can change subsequent to that call.

    :param capacity: number of entries kept after a trim.
    :param threshold: fraction of ``capacity`` by which the cache may
     overshoot before a trim takes place.
    :param size_alert: optional callable invoked with the cache, at most
     once per trim, when the size threshold is exceeded.

    """

    __slots__ = (
        "capacity",
        "threshold",
        "size_alert",
        "_data",
        "_counter",
        "_mutex",
    )

    capacity: int
    threshold: float
    size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]]

    def __init__(
        self,
        capacity: int = 100,
        threshold: float = 0.5,
        size_alert: Optional[Callable[..., None]] = None,
    ):
        self.capacity = capacity
        self.threshold = threshold
        self.size_alert = size_alert
        self._counter = 0
        self._mutex = threading.Lock()
        # each entry is (key, value, [last-use counter]); the counter
        # sits in a one-element list so it can be updated in place
        self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {}

    def _inc_counter(self):
        """Advance the usage counter and return its new value."""
        self._counter += 1
        return self._counter

    @overload
    def get(self, key: _KT) -> Optional[_VT]: ...

    @overload
    def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ...

    def get(
        self, key: _KT, default: Optional[Union[_VT, _T]] = None
    ) -> Optional[Union[_VT, _T]]:
        """Return the value for ``key`` and mark it as just used, or
        return ``default`` if the key is not present."""
        entry = self._data.get(key)
        if entry is None:
            return default
        entry[2][0] = self._inc_counter()
        return entry[1]

    def __getitem__(self, key: _KT) -> _VT:
        entry = self._data[key]
        # a successful lookup counts as a use
        entry[2][0] = self._inc_counter()
        return entry[1]

    def __iter__(self) -> Iterator[_KT]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def values(self) -> ValuesView[_VT]:
        """Return a view over a snapshot of the cached values; usage
        counters are not updated."""
        snapshot = {key: entry[1] for key, entry in self._data.items()}
        return typing.ValuesView(snapshot)

    def __setitem__(self, key: _KT, value: _VT) -> None:
        self._data[key] = (key, value, [self._inc_counter()])
        self._manage_size()

    def __delitem__(self, __v: _KT) -> None:
        del self._data[__v]

    @property
    def size_threshold(self) -> float:
        """The size beyond which the cache is trimmed."""
        return self.capacity + self.capacity * self.threshold

    def _manage_size(self) -> None:
        """Trim the cache back to ``capacity`` entries if it has grown
        past the size threshold."""
        # non-blocking: if the lock is already held, skip the trim
        # rather than wait for it
        if not self._mutex.acquire(blocking=False):
            return
        try:
            alert_pending = bool(self.size_alert)
            while len(self) > self.capacity + self.capacity * self.threshold:
                if alert_pending:
                    # notify at most once per call
                    alert_pending = False
                    self.size_alert(self)  # type: ignore
                # most recently used entries first
                ranked = sorted(
                    self._data.values(),
                    key=operator.itemgetter(2),
                    reverse=True,
                )
                for stale in ranked[self.capacity :]:
                    # tolerate entries already deleted elsewhere
                    self._data.pop(stale[0], None)
        finally:
            self._mutex.release()

585 

586 

class _CreateFuncType(Protocol[_T_co]):
    """Typing protocol: a callable taking no arguments that returns a
    ``_T_co``."""

    def __call__(self) -> _T_co:
        ...

589 

590 

class _ScopeFuncType(Protocol):
    """Typing protocol: a callable taking no arguments that returns any
    value."""

    def __call__(self) -> Any:
        ...

593 

594 

class ScopedRegistry(Generic[_T]):
    """A Registry that can store one or multiple instances of a single
    class on the basis of a "scope" function.

    The object implements ``__call__`` as the "getter", so by
    calling ``myregistry()`` the contained object is returned
    for the current scope.

    :param createfunc:
      a callable that returns a new object to be placed in the registry

    :param scopefunc:
      a callable that will return a key to store/retrieve an object.
    """

    __slots__ = ("createfunc", "scopefunc", "registry")

    createfunc: _CreateFuncType[_T]
    scopefunc: _ScopeFuncType
    registry: Any

    def __init__(
        self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any]
    ):
        """Construct a new :class:`.ScopedRegistry`.

        :param createfunc:  A creation function that will generate
          a new value for the current scope, if none is present.

        :param scopefunc:  A function that returns a hashable
          token representing the current scope (such as, current
          thread identifier).

        """
        self.createfunc = createfunc
        self.scopefunc = scopefunc
        # maps scope token -> object stored for that scope
        self.registry = {}

    def __call__(self) -> _T:
        """Return the object for the current scope, creating and storing
        one via ``createfunc`` if none is present."""
        scope = self.scopefunc()
        registry = self.registry
        try:
            return registry[scope]  # type: ignore[no-any-return]
        except KeyError:
            # setdefault() returns whatever ends up stored under the key
            created = self.createfunc()
            return registry.setdefault(scope, created)  # type: ignore[no-any-return]  # noqa: E501

    def has(self) -> bool:
        """Return True if an object is present in the current scope."""
        scope = self.scopefunc()
        return scope in self.registry

    def set(self, obj: _T) -> None:
        """Set the value for the current scope."""
        scope = self.scopefunc()
        self.registry[scope] = obj

    def clear(self) -> None:
        """Clear the current scope, if any."""
        try:
            scope = self.scopefunc()
            del self.registry[scope]
        except KeyError:
            # nothing stored for this scope
            pass

657 

658 

class ThreadLocalRegistry(ScopedRegistry[_T]):
    """A :class:`.ScopedRegistry` that uses a ``threading.local()``
    variable for storage.

    The object is kept in the ``value`` attribute of the thread-local;
    no scope function is used.

    :param createfunc: a callable that returns a new object, used when
     no value is present.

    """

    def __init__(self, createfunc: Callable[[], _T]):
        self.createfunc = createfunc
        self.registry = threading.local()

    def __call__(self) -> _T:
        """Return the stored object, creating and storing one via
        ``createfunc`` if none is present."""
        local = self.registry
        try:
            return local.value  # type: ignore[no-any-return]
        except AttributeError:
            created = local.value = self.createfunc()
            return created

    def has(self) -> bool:
        """Return True if a value is present on the thread-local."""
        return hasattr(self.registry, "value")

    def set(self, obj: _T) -> None:
        """Store ``obj`` as the value on the thread-local."""
        self.registry.value = obj

    def clear(self) -> None:
        """Remove the stored value, if any."""
        local = self.registry
        try:
            del local.value
        except AttributeError:
            # nothing stored
            pass

687 

688 

def has_dupes(sequence, target):
    """Given a sequence and search object, return True if the object
    occurs in the sequence more than once, False if it occurs zero or
    one times.

    Occurrences are matched by identity (``is``), not equality.

    """
    # compared to the .index version below, this version introduces less
    # function overhead and is usually the same speed.  At 15000 items
    # (way bigger than a relationship-bound collection in memory usually
    # is) it begins to fall behind the other version only by microseconds.
    seen_once = False
    for item in sequence:
        if item is not target:
            continue
        if seen_once:
            # second sighting; no need to scan further
            return True
        seen_once = True
    return False

706 

707 

708# .index version. the two __contains__ calls as well 

709# as .index() and isinstance() slow this down. 

710# def has_dupes(sequence, target): 

711# if target not in sequence: 

712# return False 

713# elif not isinstance(sequence, collections_abc.Sequence): 

714# return False 

715# 

716# idx = sequence.index(target) 

717# return target in sequence[idx + 1:]