Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/sqlalchemy/util/_collections.py: 60%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

364 statements  

1# util/_collections.py 

2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: https://www.opensource.org/licenses/mit-license.php 

7# mypy: allow-untyped-defs, allow-untyped-calls 

8 

9"""Collection classes and helpers.""" 

10from __future__ import annotations 

11 

12import operator 

13import threading 

14import types 

15import typing 

16from typing import Any 

17from typing import Callable 

18from typing import cast 

19from typing import Container 

20from typing import Dict 

21from typing import FrozenSet 

22from typing import Generic 

23from typing import Iterable 

24from typing import Iterator 

25from typing import List 

26from typing import Mapping 

27from typing import NoReturn 

28from typing import Optional 

29from typing import overload 

30from typing import Sequence 

31from typing import Set 

32from typing import Tuple 

33from typing import TypeVar 

34from typing import Union 

35from typing import ValuesView 

36import weakref 

37 

38from ._has_cy import HAS_CYEXTENSION 

39from .typing import is_non_string_iterable 

40from .typing import Literal 

41from .typing import Protocol 

42 

43if typing.TYPE_CHECKING or not HAS_CYEXTENSION: 

44 from ._py_collections import immutabledict as immutabledict 

45 from ._py_collections import IdentitySet as IdentitySet 

46 from ._py_collections import ReadOnlyContainer as ReadOnlyContainer 

47 from ._py_collections import ImmutableDictBase as ImmutableDictBase 

48 from ._py_collections import OrderedSet as OrderedSet 

49 from ._py_collections import unique_list as unique_list 

50else: 

51 from sqlalchemy.cyextension.immutabledict import ( 

52 ReadOnlyContainer as ReadOnlyContainer, 

53 ) 

54 from sqlalchemy.cyextension.immutabledict import ( 

55 ImmutableDictBase as ImmutableDictBase, 

56 ) 

57 from sqlalchemy.cyextension.immutabledict import ( 

58 immutabledict as immutabledict, 

59 ) 

60 from sqlalchemy.cyextension.collections import IdentitySet as IdentitySet 

61 from sqlalchemy.cyextension.collections import OrderedSet as OrderedSet 

62 from sqlalchemy.cyextension.collections import ( # noqa 

63 unique_list as unique_list, 

64 ) 

65 

66 

67_T = TypeVar("_T", bound=Any) 

68_KT = TypeVar("_KT", bound=Any) 

69_VT = TypeVar("_VT", bound=Any) 

70_T_co = TypeVar("_T_co", covariant=True) 

71 

72EMPTY_SET: FrozenSet[Any] = frozenset() 

73NONE_SET: FrozenSet[Any] = frozenset([None]) 

74 

75 

76def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]: 

77 """merge two lists, maintaining ordering as much as possible. 

78 

79 this is to reconcile vars(cls) with cls.__annotations__. 

80 

81 Example:: 

82 

83 >>> a = ["__tablename__", "id", "x", "created_at"] 

84 >>> b = ["id", "name", "data", "y", "created_at"] 

85 >>> merge_lists_w_ordering(a, b) 

86 ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at'] 

87 

88 This is not necessarily the ordering that things had on the class, 

89 in this case the class is:: 

90 

91 class User(Base): 

92 __tablename__ = "users" 

93 

94 id: Mapped[int] = mapped_column(primary_key=True) 

95 name: Mapped[str] 

96 data: Mapped[Optional[str]] 

97 x = Column(Integer) 

98 y: Mapped[int] 

99 created_at: Mapped[datetime.datetime] = mapped_column() 

100 

101 But things are *mostly* ordered. 

102 

103 The algorithm could also be done by creating a partial ordering for 

104 all items in both lists and then using topological_sort(), but that 

105 is too much overhead. 

106 

107 Background on how I came up with this is at: 

108 https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae 

109 

110 """ 

111 overlap = set(a).intersection(b) 

112 

113 result = [] 

114 

115 current, other = iter(a), iter(b) 

116 

117 while True: 

118 for element in current: 

119 if element in overlap: 

120 overlap.discard(element) 

121 other, current = current, other 

122 break 

123 

124 result.append(element) 

125 else: 

126 result.extend(other) 

127 break 

128 

129 return result 

130 

131 

132def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]: 

133 if not d: 

134 return EMPTY_DICT 

135 elif isinstance(d, immutabledict): 

136 return d 

137 else: 

138 return immutabledict(d) 

139 

140 

141EMPTY_DICT: immutabledict[Any, Any] = immutabledict() 

142 

143 

144class FacadeDict(ImmutableDictBase[_KT, _VT]): 

145 """A dictionary that is not publicly mutable.""" 

146 

147 def __new__(cls, *args: Any) -> FacadeDict[Any, Any]: 

148 new = ImmutableDictBase.__new__(cls) 

149 return new 

150 

151 def copy(self) -> NoReturn: 

152 raise NotImplementedError( 

153 "an immutabledict shouldn't need to be copied. use dict(d) " 

154 "if you need a mutable dictionary." 

155 ) 

156 

157 def __reduce__(self) -> Any: 

158 return FacadeDict, (dict(self),) 

159 

160 def _insert_item(self, key: _KT, value: _VT) -> None: 

161 """insert an item into the dictionary directly.""" 

162 dict.__setitem__(self, key, value) 

163 

164 def __repr__(self) -> str: 

165 return "FacadeDict(%s)" % dict.__repr__(self) 

166 

167 

168_DT = TypeVar("_DT", bound=Any) 

169 

170_F = TypeVar("_F", bound=Any) 

171 

172 

173class Properties(Generic[_T]): 

174 """Provide a __getattr__/__setattr__ interface over a dict.""" 

175 

176 __slots__ = ("_data",) 

177 

178 _data: Dict[str, _T] 

179 

180 def __init__(self, data: Dict[str, _T]): 

181 object.__setattr__(self, "_data", data) 

182 

183 def __len__(self) -> int: 

184 return len(self._data) 

185 

186 def __iter__(self) -> Iterator[_T]: 

187 return iter(list(self._data.values())) 

188 

189 def __dir__(self) -> List[str]: 

190 return dir(super()) + [str(k) for k in self._data.keys()] 

191 

192 def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]: 

193 return list(self) + list(other) 

194 

195 def __setitem__(self, key: str, obj: _T) -> None: 

196 self._data[key] = obj 

197 

198 def __getitem__(self, key: str) -> _T: 

199 return self._data[key] 

200 

201 def __delitem__(self, key: str) -> None: 

202 del self._data[key] 

203 

204 def __setattr__(self, key: str, obj: _T) -> None: 

205 self._data[key] = obj 

206 

207 def __getstate__(self) -> Dict[str, Any]: 

208 return {"_data": self._data} 

209 

210 def __setstate__(self, state: Dict[str, Any]) -> None: 

211 object.__setattr__(self, "_data", state["_data"]) 

212 

213 def __getattr__(self, key: str) -> _T: 

214 try: 

215 return self._data[key] 

216 except KeyError: 

217 raise AttributeError(key) 

218 

219 def __contains__(self, key: str) -> bool: 

220 return key in self._data 

221 

222 def as_readonly(self) -> ReadOnlyProperties[_T]: 

223 """Return an immutable proxy for this :class:`.Properties`.""" 

224 

225 return ReadOnlyProperties(self._data) 

226 

227 def update(self, value: Dict[str, _T]) -> None: 

228 self._data.update(value) 

229 

230 @overload 

231 def get(self, key: str) -> Optional[_T]: ... 

232 

233 @overload 

234 def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ... 

235 

236 def get( 

237 self, key: str, default: Optional[Union[_DT, _T]] = None 

238 ) -> Optional[Union[_T, _DT]]: 

239 if key in self: 

240 return self[key] 

241 else: 

242 return default 

243 

244 def keys(self) -> List[str]: 

245 return list(self._data) 

246 

247 def values(self) -> List[_T]: 

248 return list(self._data.values()) 

249 

250 def items(self) -> List[Tuple[str, _T]]: 

251 return list(self._data.items()) 

252 

253 def has_key(self, key: str) -> bool: 

254 return key in self._data 

255 

256 def clear(self) -> None: 

257 self._data.clear() 

258 

259 

260class OrderedProperties(Properties[_T]): 

261 """Provide a __getattr__/__setattr__ interface with an OrderedDict 

262 as backing store.""" 

263 

264 __slots__ = () 

265 

266 def __init__(self): 

267 Properties.__init__(self, OrderedDict()) 

268 

269 

270class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]): 

271 """Provide immutable dict/object attribute to an underlying dictionary.""" 

272 

273 __slots__ = () 

274 

275 

276def _ordered_dictionary_sort(d, key=None): 

277 """Sort an OrderedDict in-place.""" 

278 

279 items = [(k, d[k]) for k in sorted(d, key=key)] 

280 

281 d.clear() 

282 

283 d.update(items) 

284 

285 

286OrderedDict = dict 

287sort_dictionary = _ordered_dictionary_sort 

288 

289 

290class WeakSequence(Sequence[_T]): 

291 def __init__(self, __elements: Sequence[_T] = ()): 

292 # adapted from weakref.WeakKeyDictionary, prevent reference 

293 # cycles in the collection itself 

294 def _remove(item, selfref=weakref.ref(self)): 

295 self = selfref() 

296 if self is not None: 

297 self._storage.remove(item) 

298 

299 self._remove = _remove 

300 self._storage = [ 

301 weakref.ref(element, _remove) for element in __elements 

302 ] 

303 

304 def append(self, item): 

305 self._storage.append(weakref.ref(item, self._remove)) 

306 

307 def __len__(self): 

308 return len(self._storage) 

309 

310 def __iter__(self): 

311 return ( 

312 obj for obj in (ref() for ref in self._storage) if obj is not None 

313 ) 

314 

315 def __getitem__(self, index): 

316 return self._storage[index]() 

317 

318 

319class OrderedIdentitySet(IdentitySet): 

320 def __init__(self, iterable: Optional[Iterable[Any]] = None): 

321 IdentitySet.__init__(self) 

322 self._members = OrderedDict() 

323 if iterable: 

324 for o in iterable: 

325 self.add(o) 

326 

327 

328class PopulateDict(Dict[_KT, _VT]): 

329 """A dict which populates missing values via a creation function. 

330 

331 Note the creation function takes a key, unlike 

332 collections.defaultdict. 

333 

334 """ 

335 

336 def __init__(self, creator: Callable[[_KT], _VT]): 

337 self.creator = creator 

338 

339 def __missing__(self, key: Any) -> Any: 

340 self[key] = val = self.creator(key) 

341 return val 

342 

343 

344class WeakPopulateDict(Dict[_KT, _VT]): 

345 """Like PopulateDict, but assumes a self + a method and does not create 

346 a reference cycle. 

347 

348 """ 

349 

350 def __init__(self, creator_method: types.MethodType): 

351 self.creator = creator_method.__func__ 

352 weakself = creator_method.__self__ 

353 self.weakself = weakref.ref(weakself) 

354 

355 def __missing__(self, key: Any) -> Any: 

356 self[key] = val = self.creator(self.weakself(), key) 

357 return val 

358 

359 

360# Define collections that are capable of storing 

361# ColumnElement objects as hashable keys/elements. 

362# At this point, these are mostly historical, things 

363# used to be more complicated. 

364column_set = set 

365column_dict = dict 

366ordered_column_set = OrderedSet 

367 

368 

369class UniqueAppender(Generic[_T]): 

370 """Appends items to a collection ensuring uniqueness. 

371 

372 Additional appends() of the same object are ignored. Membership is 

373 determined by identity (``is a``) not equality (``==``). 

374 """ 

375 

376 __slots__ = "data", "_data_appender", "_unique" 

377 

378 data: Union[Iterable[_T], Set[_T], List[_T]] 

379 _data_appender: Callable[[_T], None] 

380 _unique: Dict[int, Literal[True]] 

381 

382 def __init__( 

383 self, 

384 data: Union[Iterable[_T], Set[_T], List[_T]], 

385 via: Optional[str] = None, 

386 ): 

387 self.data = data 

388 self._unique = {} 

389 if via: 

390 self._data_appender = getattr(data, via) 

391 elif hasattr(data, "append"): 

392 self._data_appender = cast("List[_T]", data).append 

393 elif hasattr(data, "add"): 

394 self._data_appender = cast("Set[_T]", data).add 

395 

396 def append(self, item: _T) -> None: 

397 id_ = id(item) 

398 if id_ not in self._unique: 

399 self._data_appender(item) 

400 self._unique[id_] = True 

401 

402 def __iter__(self) -> Iterator[_T]: 

403 return iter(self.data) 

404 

405 

406def coerce_generator_arg(arg: Any) -> List[Any]: 

407 if len(arg) == 1 and isinstance(arg[0], types.GeneratorType): 

408 return list(arg[0]) 

409 else: 

410 return cast("List[Any]", arg) 

411 

412 

413def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]: 

414 if x is None: 

415 return default # type: ignore 

416 if not is_non_string_iterable(x): 

417 return [x] 

418 elif isinstance(x, list): 

419 return x 

420 else: 

421 return list(x) 

422 

423 

424def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool: 

425 r"""return True if any items of set\_ are present in iterable. 

426 

427 Goes through special effort to ensure __hash__ is not called 

428 on items in iterable that don't support it. 

429 

430 """ 

431 return any(i in set_ for i in iterable if i.__hash__) 

432 

433 

434def to_set(x): 

435 if x is None: 

436 return set() 

437 if not isinstance(x, set): 

438 return set(to_list(x)) 

439 else: 

440 return x 

441 

442 

443def to_column_set(x: Any) -> Set[Any]: 

444 if x is None: 

445 return column_set() 

446 if not isinstance(x, column_set): 

447 return column_set(to_list(x)) 

448 else: 

449 return x 

450 

451 

452def update_copy( 

453 d: Dict[Any, Any], _new: Optional[Dict[Any, Any]] = None, **kw: Any 

454) -> Dict[Any, Any]: 

455 """Copy the given dict and update with the given values.""" 

456 

457 d = d.copy() 

458 if _new: 

459 d.update(_new) 

460 d.update(**kw) 

461 return d 

462 

463 

464def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]: 

465 """Given an iterator of which further sub-elements may also be 

466 iterators, flatten the sub-elements into a single iterator. 

467 

468 """ 

469 elem: _T 

470 for elem in x: 

471 if not isinstance(elem, str) and hasattr(elem, "__iter__"): 

472 yield from flatten_iterator(elem) 

473 else: 

474 yield elem 

475 

476 

477class LRUCache(typing.MutableMapping[_KT, _VT]): 

478 """Dictionary with 'squishy' removal of least 

479 recently used items. 

480 

481 Note that either get() or [] should be used here, but 

482 generally its not safe to do an "in" check first as the dictionary 

483 can change subsequent to that call. 

484 

485 """ 

486 

487 __slots__ = ( 

488 "capacity", 

489 "threshold", 

490 "size_alert", 

491 "_data", 

492 "_counter", 

493 "_mutex", 

494 ) 

495 

496 capacity: int 

497 threshold: float 

498 size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]] 

499 

500 def __init__( 

501 self, 

502 capacity: int = 100, 

503 threshold: float = 0.5, 

504 size_alert: Optional[Callable[..., None]] = None, 

505 ): 

506 self.capacity = capacity 

507 self.threshold = threshold 

508 self.size_alert = size_alert 

509 self._counter = 0 

510 self._mutex = threading.Lock() 

511 self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {} 

512 

513 def _inc_counter(self): 

514 self._counter += 1 

515 return self._counter 

516 

517 @overload 

518 def get(self, key: _KT) -> Optional[_VT]: ... 

519 

520 @overload 

521 def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ... 

522 

523 def get( 

524 self, key: _KT, default: Optional[Union[_VT, _T]] = None 

525 ) -> Optional[Union[_VT, _T]]: 

526 item = self._data.get(key) 

527 if item is not None: 

528 item[2][0] = self._inc_counter() 

529 return item[1] 

530 else: 

531 return default 

532 

533 def __getitem__(self, key: _KT) -> _VT: 

534 item = self._data[key] 

535 item[2][0] = self._inc_counter() 

536 return item[1] 

537 

538 def __iter__(self) -> Iterator[_KT]: 

539 return iter(self._data) 

540 

541 def __len__(self) -> int: 

542 return len(self._data) 

543 

544 def values(self) -> ValuesView[_VT]: 

545 return typing.ValuesView({k: i[1] for k, i in self._data.items()}) 

546 

547 def __setitem__(self, key: _KT, value: _VT) -> None: 

548 self._data[key] = (key, value, [self._inc_counter()]) 

549 self._manage_size() 

550 

551 def __delitem__(self, __v: _KT) -> None: 

552 del self._data[__v] 

553 

554 @property 

555 def size_threshold(self) -> float: 

556 return self.capacity + self.capacity * self.threshold 

557 

558 def _manage_size(self) -> None: 

559 if not self._mutex.acquire(False): 

560 return 

561 try: 

562 size_alert = bool(self.size_alert) 

563 while len(self) > self.capacity + self.capacity * self.threshold: 

564 if size_alert: 

565 size_alert = False 

566 self.size_alert(self) # type: ignore 

567 by_counter = sorted( 

568 self._data.values(), 

569 key=operator.itemgetter(2), 

570 reverse=True, 

571 ) 

572 for item in by_counter[self.capacity :]: 

573 try: 

574 del self._data[item[0]] 

575 except KeyError: 

576 # deleted elsewhere; skip 

577 continue 

578 finally: 

579 self._mutex.release() 

580 

581 

582class _CreateFuncType(Protocol[_T_co]): 

583 def __call__(self) -> _T_co: ... 

584 

585 

586class _ScopeFuncType(Protocol): 

587 def __call__(self) -> Any: ... 

588 

589 

590class ScopedRegistry(Generic[_T]): 

591 """A Registry that can store one or multiple instances of a single 

592 class on the basis of a "scope" function. 

593 

594 The object implements ``__call__`` as the "getter", so by 

595 calling ``myregistry()`` the contained object is returned 

596 for the current scope. 

597 

598 :param createfunc: 

599 a callable that returns a new object to be placed in the registry 

600 

601 :param scopefunc: 

602 a callable that will return a key to store/retrieve an object. 

603 """ 

604 

605 __slots__ = "createfunc", "scopefunc", "registry" 

606 

607 createfunc: _CreateFuncType[_T] 

608 scopefunc: _ScopeFuncType 

609 registry: Any 

610 

611 def __init__( 

612 self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any] 

613 ): 

614 """Construct a new :class:`.ScopedRegistry`. 

615 

616 :param createfunc: A creation function that will generate 

617 a new value for the current scope, if none is present. 

618 

619 :param scopefunc: A function that returns a hashable 

620 token representing the current scope (such as, current 

621 thread identifier). 

622 

623 """ 

624 self.createfunc = createfunc 

625 self.scopefunc = scopefunc 

626 self.registry = {} 

627 

628 def __call__(self) -> _T: 

629 key = self.scopefunc() 

630 try: 

631 return self.registry[key] # type: ignore[no-any-return] 

632 except KeyError: 

633 return self.registry.setdefault(key, self.createfunc()) # type: ignore[no-any-return] # noqa: E501 

634 

635 def has(self) -> bool: 

636 """Return True if an object is present in the current scope.""" 

637 

638 return self.scopefunc() in self.registry 

639 

640 def set(self, obj: _T) -> None: 

641 """Set the value for the current scope.""" 

642 

643 self.registry[self.scopefunc()] = obj 

644 

645 def clear(self) -> None: 

646 """Clear the current scope, if any.""" 

647 

648 try: 

649 del self.registry[self.scopefunc()] 

650 except KeyError: 

651 pass 

652 

653 

654class ThreadLocalRegistry(ScopedRegistry[_T]): 

655 """A :class:`.ScopedRegistry` that uses a ``threading.local()`` 

656 variable for storage. 

657 

658 """ 

659 

660 def __init__(self, createfunc: Callable[[], _T]): 

661 self.createfunc = createfunc 

662 self.registry = threading.local() 

663 

664 def __call__(self) -> _T: 

665 try: 

666 return self.registry.value # type: ignore[no-any-return] 

667 except AttributeError: 

668 val = self.registry.value = self.createfunc() 

669 return val 

670 

671 def has(self) -> bool: 

672 return hasattr(self.registry, "value") 

673 

674 def set(self, obj: _T) -> None: 

675 self.registry.value = obj 

676 

677 def clear(self) -> None: 

678 try: 

679 del self.registry.value 

680 except AttributeError: 

681 pass 

682 

683 

684def has_dupes(sequence, target): 

685 """Given a sequence and search object, return True if there's more 

686 than one, False if zero or one of them. 

687 

688 

689 """ 

690 # compare to .index version below, this version introduces less function 

691 # overhead and is usually the same speed. At 15000 items (way bigger than 

692 # a relationship-bound collection in memory usually is) it begins to 

693 # fall behind the other version only by microseconds. 

694 c = 0 

695 for item in sequence: 

696 if item is target: 

697 c += 1 

698 if c > 1: 

699 return True 

700 return False 

701 

702 

703# .index version. the two __contains__ calls as well 

704# as .index() and isinstance() slow this down. 

705# def has_dupes(sequence, target): 

706# if target not in sequence: 

707# return False 

708# elif not isinstance(sequence, collections_abc.Sequence): 

709# return False 

710# 

711# idx = sequence.index(target) 

712# return target in sequence[idx + 1:]