Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/sqlalchemy/util/_collections.py: 56%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

354 statements  

1# util/_collections.py 

2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: https://www.opensource.org/licenses/mit-license.php 

7# mypy: allow-untyped-defs, allow-untyped-calls 

8 

9"""Collection classes and helpers.""" 

10from __future__ import annotations 

11 

12import operator 

13import threading 

14import types 

15import typing 

16from typing import Any 

17from typing import Callable 

18from typing import cast 

19from typing import Container 

20from typing import Dict 

21from typing import FrozenSet 

22from typing import Generic 

23from typing import Iterable 

24from typing import Iterator 

25from typing import List 

26from typing import Mapping 

27from typing import NoReturn 

28from typing import Optional 

29from typing import overload 

30from typing import Protocol 

31from typing import Sequence 

32from typing import Set 

33from typing import Tuple 

34from typing import TypeVar 

35from typing import Union 

36from typing import ValuesView 

37import weakref 

38 

39from ._collections_cy import IdentitySet as IdentitySet 

40from ._collections_cy import OrderedSet as OrderedSet 

41from ._collections_cy import unique_list as unique_list # noqa: F401 

42from ._immutabledict_cy import immutabledict as immutabledict 

43from ._immutabledict_cy import ImmutableDictBase as ImmutableDictBase 

44from ._immutabledict_cy import ReadOnlyContainer as ReadOnlyContainer 

45from .typing import is_non_string_iterable 

46from .typing import Literal 

47 

48 

# Type variables used by the generic collection classes and helpers
# defined in this module.
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT", bound=Any)
_VT = TypeVar("_VT", bound=Any)
_T_co = TypeVar("_T_co", covariant=True)

# Module-level frozenset constants: an empty set, and a set whose only
# member is None.
EMPTY_SET: FrozenSet[Any] = frozenset()
NONE_SET: FrozenSet[Any] = frozenset([None])

56 

57 

def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
    """Merge two lists, keeping the ordering of each as much as possible.

    This is used to reconcile vars(cls) with cls.__annotations__.

    Example::

        >>> a = ["__tablename__", "id", "x", "created_at"]
        >>> b = ["id", "name", "data", "y", "created_at"]
        >>> merge_lists_w_ordering(a, b)
        ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at']

    This is not necessarily the ordering that things had on the class,
    in this case the class is::

        class User(Base):
            __tablename__ = "users"

            id: Mapped[int] = mapped_column(primary_key=True)
            name: Mapped[str]
            data: Mapped[Optional[str]]
            x = Column(Integer)
            y: Mapped[int]
            created_at: Mapped[datetime.datetime] = mapped_column()

    But things are *mostly* ordered.

    The same result could be reached by building a partial ordering for
    all items in both lists and running topological_sort(), but that
    is too much overhead.

    Background on how this approach was arrived at:
    https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae

    """
    # elements present in both lists; each acts once as a "switch point"
    shared = set(a) & set(b)

    merged: List[Any] = []

    active, waiting = iter(a), iter(b)

    switched = True
    while switched:
        switched = False
        for item in active:
            if item in shared:
                # first sighting of a shared element: do not emit it yet,
                # hand over to the other list, which emits it when it
                # reaches it (it is no longer in "shared" by then)
                shared.remove(item)
                active, waiting = waiting, active
                switched = True
                break

            merged.append(item)

    # the active iterator ran dry; whatever the other one still holds
    # goes on the end
    merged.extend(waiting)

    return merged

112 

113 

def coerce_to_immutabledict(d: Mapping[_KT, _VT]) -> immutabledict[_KT, _VT]:
    """Return the given mapping as an ``immutabledict``.

    A falsy mapping yields the module-level ``EMPTY_DICT``; a mapping that
    is already an ``immutabledict`` is returned as is; anything else is
    passed to the ``immutabledict`` constructor.

    """
    if not d:
        return EMPTY_DICT
    if isinstance(d, immutabledict):
        return d
    return immutabledict(d)

121 

122 

# Shared empty immutabledict; returned by coerce_to_immutabledict() for
# falsy input.
EMPTY_DICT: immutabledict[Any, Any] = immutabledict()

124 

125 

class FacadeDict(ImmutableDictBase[_KT, _VT]):
    """A dictionary that is not publicly mutable."""

    def __new__(cls, *args: Any) -> FacadeDict[Any, Any]:
        # positional arguments are accepted here but are not forwarded
        # to the base __new__
        instance: FacadeDict[Any, Any] = ImmutableDictBase.__new__(cls)
        return instance

    def copy(self) -> NoReturn:
        """Always raise ``NotImplementedError``; copying is not supported."""
        message = (
            "an immutabledict shouldn't need to be copied. use dict(d) "
            "if you need a mutable dictionary."
        )
        raise NotImplementedError(message)

    def __reduce__(self) -> Any:
        # pickle as "FacadeDict called with a plain dict of the contents"
        contents = dict(self)
        return FacadeDict, (contents,)

    def _insert_item(self, key: _KT, value: _VT) -> None:
        """insert an item into the dictionary directly."""
        # calls dict.__setitem__ directly rather than self[key] = value
        dict.__setitem__(self, key, value)

    def __repr__(self) -> str:
        return f"FacadeDict({dict.__repr__(self)})"

148 

149 

# Additional type variables; within this module _DT types the ``default``
# argument of Properties.get() and _F types the operand of
# Properties.__add__().
_DT = TypeVar("_DT", bound=Any)

_F = TypeVar("_F", bound=Any)

153 

154 

155class Properties(Generic[_T]): 

156 """Provide a __getattr__/__setattr__ interface over a dict.""" 

157 

158 __slots__ = ("_data",) 

159 

160 _data: Dict[str, _T] 

161 

162 def __init__(self, data: Dict[str, _T]): 

163 object.__setattr__(self, "_data", data) 

164 

165 def __len__(self) -> int: 

166 return len(self._data) 

167 

168 def __iter__(self) -> Iterator[_T]: 

169 return iter(list(self._data.values())) 

170 

171 def __dir__(self) -> List[str]: 

172 return dir(super()) + [str(k) for k in self._data.keys()] 

173 

174 def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]: 

175 return list(self) + list(other) 

176 

177 def __setitem__(self, key: str, obj: _T) -> None: 

178 self._data[key] = obj 

179 

180 def __getitem__(self, key: str) -> _T: 

181 return self._data[key] 

182 

183 def __delitem__(self, key: str) -> None: 

184 del self._data[key] 

185 

186 def __setattr__(self, key: str, obj: _T) -> None: 

187 self._data[key] = obj 

188 

189 def __getstate__(self) -> Dict[str, Any]: 

190 return {"_data": self._data} 

191 

192 def __setstate__(self, state: Dict[str, Any]) -> None: 

193 object.__setattr__(self, "_data", state["_data"]) 

194 

195 def __getattr__(self, key: str) -> _T: 

196 try: 

197 return self._data[key] 

198 except KeyError: 

199 raise AttributeError(key) 

200 

201 def __contains__(self, key: str) -> bool: 

202 return key in self._data 

203 

204 def as_readonly(self) -> ReadOnlyProperties[_T]: 

205 """Return an immutable proxy for this :class:`.Properties`.""" 

206 

207 return ReadOnlyProperties(self._data) 

208 

209 def update(self, value: Dict[str, _T]) -> None: 

210 self._data.update(value) 

211 

212 @overload 

213 def get(self, key: str) -> Optional[_T]: ... 

214 

215 @overload 

216 def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ... 

217 

218 def get( 

219 self, key: str, default: Optional[Union[_DT, _T]] = None 

220 ) -> Optional[Union[_T, _DT]]: 

221 if key in self: 

222 return self[key] 

223 else: 

224 return default 

225 

226 def keys(self) -> List[str]: 

227 return list(self._data) 

228 

229 def values(self) -> List[_T]: 

230 return list(self._data.values()) 

231 

232 def items(self) -> List[Tuple[str, _T]]: 

233 return list(self._data.items()) 

234 

235 def has_key(self, key: str) -> bool: 

236 return key in self._data 

237 

238 def clear(self) -> None: 

239 self._data.clear() 

240 

241 

class OrderedProperties(Properties[_T]):
    """Provide a __getattr__/__setattr__ interface with an OrderedDict
    as backing store."""

    __slots__ = ()

    def __init__(self):
        # start out with a fresh, empty OrderedDict as the backing store
        backing = OrderedDict()
        Properties.__init__(self, backing)

250 

251 

class ReadOnlyProperties(ReadOnlyContainer, Properties[_T]):
    """Provide immutable dict/object attribute to an underlying dictionary."""

    # no additional slots declared; storage comes from the base classes
    __slots__ = ()

256 

257 

258def _ordered_dictionary_sort(d, key=None): 

259 """Sort an OrderedDict in-place.""" 

260 

261 items = [(k, d[k]) for k in sorted(d, key=key)] 

262 

263 d.clear() 

264 

265 d.update(items) 

266 

267 

# OrderedDict is an alias of the built-in dict, and sort_dictionary is the
# public name for _ordered_dictionary_sort().
OrderedDict = dict
sort_dictionary = _ordered_dictionary_sort

270 

271 

class WeakSequence(Sequence[_T]):
    """A sequence that holds only weak references to its elements.

    Each element is stored as a ``weakref.ref`` whose callback removes
    that reference from the internal storage list once the referent has
    been garbage collected.  Elements must therefore support weak
    references.

    """

    def __init__(self, __elements: Sequence[_T] = (), /):
        # adapted from weakref.WeakKeyDictionary, prevent reference
        # cycles in the collection itself
        def _remove(item, selfref=weakref.ref(self)):
            self = selfref()
            if self is not None:
                self._storage.remove(item)

        self._remove = _remove
        self._storage = [
            weakref.ref(element, _remove) for element in __elements
        ]

    def append(self, item):
        """Add a weak reference to ``item`` at the end of the sequence."""
        self._storage.append(weakref.ref(item, self._remove))

    def __len__(self):
        # counts stored references, whether or not the referent is alive
        return len(self._storage)

    def __iter__(self):
        # dereference each entry, skipping referents that are gone
        return (
            obj for obj in (ref() for ref in self._storage) if obj is not None
        )

    def __getitem__(self, index):
        """Return the object referred to at position ``index``.

        :raises IndexError: if ``index`` is out of range.

        """
        try:
            obj = self._storage[index]
        except IndexError:
            # list lookups raise IndexError (not KeyError) for a bad
            # position; re-raise with this collection's own message
            raise IndexError("Index %s out of range" % (index,)) from None
        else:
            return obj()

304 

305 

# OrderedIdentitySet is an alias of IdentitySet.
OrderedIdentitySet = IdentitySet

307 

308 

class PopulateDict(Dict[_KT, _VT]):
    """A dict which populates missing values via a creation function.

    Note the creation function takes a key, unlike
    collections.defaultdict.

    """

    def __init__(self, creator: Callable[[_KT], _VT]):
        self.creator = creator

    def __missing__(self, key: Any) -> Any:
        # build the value for the absent key, remember it, and hand it back
        created = self.creator(key)
        self[key] = created
        return created

323 

324 

class WeakPopulateDict(Dict[_KT, _VT]):
    """Like PopulateDict, but assumes a self + a method and does not create
    a reference cycle.

    """

    def __init__(self, creator_method: types.MethodType):
        # keep the plain function plus only a weak reference to the object
        # the method was bound to
        self.creator = creator_method.__func__
        owner = creator_method.__self__
        self.weakself = weakref.ref(owner)

    def __missing__(self, key: Any) -> Any:
        # call the function with the dereferenced owner and the key
        created = self.creator(self.weakself(), key)
        self[key] = created
        return created

339 

340 

# Define collections that are capable of storing
# ColumnElement objects as hashable keys/elements.
# At this point these are mostly historical; things
# used to be more complicated.  They are now plain aliases
# of set, dict and OrderedSet.
column_set = set
column_dict = dict
ordered_column_set = OrderedSet

348 

349 

class UniqueAppender(Generic[_T]):
    """Appends items to a collection ensuring uniqueness.

    Additional appends() of the same object are ignored. Membership is
    determined by identity (``is a``) not equality (``==``).
    """

    __slots__ = "data", "_data_appender", "_unique"

    data: Union[Iterable[_T], Set[_T], List[_T]]
    _data_appender: Callable[[_T], None]
    _unique: Dict[int, Literal[True]]

    def __init__(
        self,
        data: Union[Iterable[_T], Set[_T], List[_T]],
        via: Optional[str] = None,
    ):
        self._unique = {}
        self.data = data
        # choose the method used to add items: the one named by ``via`` if
        # given, else ``append``, else ``add``.  If none of these apply,
        # no appender is assigned.
        if via:
            self._data_appender = getattr(data, via)
        elif hasattr(data, "append"):
            self._data_appender = cast("List[_T]", data).append
        elif hasattr(data, "add"):
            self._data_appender = cast("Set[_T]", data).add

    def append(self, item: _T) -> None:
        """Add ``item`` to the collection unless its id() was seen before."""
        marker = id(item)
        if marker in self._unique:
            return
        self._data_appender(item)
        self._unique[marker] = True

    def __iter__(self) -> Iterator[_T]:
        return iter(self.data)

385 

386 

def coerce_generator_arg(arg: Any) -> List[Any]:
    """Expand a lone generator argument into a list.

    If ``arg`` holds exactly one element and that element is a generator,
    the generator's items are returned as a list; in every other case
    ``arg`` itself is returned unchanged.

    """
    if len(arg) != 1:
        return cast("List[Any]", arg)

    only = arg[0]
    if isinstance(only, types.GeneratorType):
        return list(only)
    return cast("List[Any]", arg)

392 

393 

def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
    """Coerce ``x`` to a list.

    ``None`` yields ``default``.  A value for which
    ``is_non_string_iterable()`` is false is wrapped in a one-element list.
    An existing list is returned as is, and any other iterable is copied
    into a new list.

    """
    if x is None:
        return default  # type: ignore
    if is_non_string_iterable(x):
        return x if isinstance(x, list) else list(x)
    return [x]

403 

404 

def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool:
    r"""return True if any items of set\_ are present in iterable.

    Goes through special effort to ensure __hash__ is not called
    on items in iterable that don't support it.

    """
    for candidate in iterable:
        # a falsy __hash__ (e.g. None on lists and dicts) means the item
        # is skipped without attempting the membership test
        if candidate.__hash__ and candidate in set_:
            return True
    return False

413 

414 

def to_set(x):
    """Coerce ``x`` to a set.

    ``None`` yields a new empty set, an existing set is returned as is,
    and anything else is run through ``to_list()`` and then ``set()``.

    """
    if x is None:
        return set()
    return x if isinstance(x, set) else set(to_list(x))

422 

423 

def to_column_set(x: Any) -> Set[Any]:
    """Coerce ``x`` to a ``column_set``.

    ``None`` yields a new empty ``column_set``, an existing ``column_set``
    is returned as is, and anything else is run through ``to_list()`` and
    then ``column_set()``.

    """
    if x is None:
        return column_set()
    return x if isinstance(x, column_set) else column_set(to_list(x))

431 

432 

def update_copy(
    d: Dict[Any, Any], _new: Optional[Dict[Any, Any]] = None, **kw: Any
) -> Dict[Any, Any]:
    """Copy the given dict and update with the given values."""

    result = d.copy()
    # _new is applied first, so keyword arguments win on conflicts
    if _new:
        result.update(_new)
    result.update(**kw)
    return result

443 

444 

def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]:
    """Given an iterator of which further sub-elements may also be
    iterators, flatten the sub-elements into a single iterator.

    Strings are yielded whole rather than being descended into.

    """
    for item in x:
        if isinstance(item, str) or not hasattr(item, "__iter__"):
            yield item
            continue
        # anything else that has __iter__ is flattened recursively
        yield from flatten_iterator(item)

456 

457 

class LRUCache(typing.MutableMapping[_KT, _VT]):
    """Dictionary with 'squishy' removal of least
    recently used items.

    Note that either get() or [] should be used here, but
    generally its not safe to do an "in" check first as the dictionary
    can change subsequent to that call.

    """

    __slots__ = (
        "capacity",
        "threshold",
        "size_alert",
        "_data",
        "_counter",
        "_mutex",
    )

    capacity: int
    threshold: float
    size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]]

    def __init__(
        self,
        capacity: int = 100,
        threshold: float = 0.5,
        size_alert: Optional[Callable[..., None]] = None,
    ):
        self.capacity = capacity
        self.threshold = threshold
        self.size_alert = size_alert
        self._counter = 0
        self._mutex = threading.Lock()
        # each entry is (key, value, [usage counter]); the counter sits in
        # a one-element list so it can be updated in place
        self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {}

    def _inc_counter(self):
        self._counter += 1
        return self._counter

    @overload
    def get(self, key: _KT) -> Optional[_VT]: ...

    @overload
    def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ...

    def get(
        self, key: _KT, default: Optional[Union[_VT, _T]] = None
    ) -> Optional[Union[_VT, _T]]:
        """Return the value for ``key``, marking it as used; else ``default``."""
        entry = self._data.get(key)
        if entry is None:
            return default
        entry[2][0] = self._inc_counter()
        return entry[1]

    def __getitem__(self, key: _KT) -> _VT:
        entry = self._data[key]
        # a successful lookup refreshes the entry's usage counter
        entry[2][0] = self._inc_counter()
        return entry[1]

    def __iter__(self) -> Iterator[_KT]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def values(self) -> ValuesView[_VT]:
        # a view over a freshly built {key: value} dictionary
        plain = {name: entry[1] for name, entry in self._data.items()}
        return typing.ValuesView(plain)

    def __setitem__(self, key: _KT, value: _VT) -> None:
        stamp = [self._inc_counter()]
        self._data[key] = (key, value, stamp)
        self._manage_size()

    def __delitem__(self, __v: _KT) -> None:
        del self._data[__v]

    @property
    def size_threshold(self) -> float:
        return self.capacity + self.capacity * self.threshold

    def _manage_size(self) -> None:
        # non-blocking acquire: if the lock is already held, do nothing
        if not self._mutex.acquire(blocking=False):
            return
        try:
            alert_pending = bool(self.size_alert)
            while len(self) > self.capacity + self.capacity * self.threshold:
                # size_alert is invoked at most once per call
                if alert_pending:
                    alert_pending = False
                    self.size_alert(self)  # type: ignore
                newest_first = sorted(
                    self._data.values(),
                    key=operator.itemgetter(2),
                    reverse=True,
                )
                # drop everything past the ``capacity`` most recently used
                # entries; keys already deleted elsewhere are skipped
                for stale in newest_first[self.capacity :]:
                    self._data.pop(stale[0], None)
        finally:
            self._mutex.release()

561 

562 

class _CreateFuncType(Protocol[_T_co]):
    """Protocol for a zero-argument callable returning a ``_T_co``."""

    def __call__(self) -> _T_co: ...

565 

566 

class _ScopeFuncType(Protocol):
    """Protocol for a zero-argument callable returning any value."""

    def __call__(self) -> Any: ...

569 

570 

class ScopedRegistry(Generic[_T]):
    """A Registry that can store one or multiple instances of a single
    class on the basis of a "scope" function.

    The object implements ``__call__`` as the "getter", so by
    calling ``myregistry()`` the contained object is returned
    for the current scope.

    :param createfunc:
      a callable that returns a new object to be placed in the registry

    :param scopefunc:
      a callable that will return a key to store/retrieve an object.
    """

    __slots__ = "createfunc", "scopefunc", "registry"

    createfunc: _CreateFuncType[_T]
    scopefunc: _ScopeFuncType
    registry: Any

    def __init__(
        self, createfunc: Callable[[], _T], scopefunc: Callable[[], Any]
    ):
        """Construct a new :class:`.ScopedRegistry`.

        :param createfunc:  A creation function that will generate
          a new value for the current scope, if none is present.

        :param scopefunc:  A function that returns a hashable
          token representing the current scope (such as, current
          thread identifier).

        """
        self.createfunc = createfunc
        self.scopefunc = scopefunc
        self.registry = {}

    def __call__(self) -> _T:
        scope = self.scopefunc()
        try:
            return self.registry[scope]  # type: ignore[no-any-return]
        except KeyError:
            # nothing stored for this scope yet: create a value; setdefault
            # returns an existing entry if one appeared in the meantime
            created = self.createfunc()
            return self.registry.setdefault(scope, created)  # type: ignore[no-any-return]  # noqa: E501

    def has(self) -> bool:
        """Return True if an object is present in the current scope."""

        scope = self.scopefunc()
        return scope in self.registry

    def set(self, obj: _T) -> None:
        """Set the value for the current scope."""

        scope = self.scopefunc()
        self.registry[scope] = obj

    def clear(self) -> None:
        """Clear the current scope, if any."""

        try:
            del self.registry[self.scopefunc()]
        except KeyError:
            # nothing stored for the current scope
            pass

633 

634 

class ThreadLocalRegistry(ScopedRegistry[_T]):
    """A :class:`.ScopedRegistry` that uses a ``threading.local()``
    variable for storage.

    """

    def __init__(self, createfunc: Callable[[], _T]):
        # the object lives on the ``value`` attribute of a threading.local
        self.registry = threading.local()
        self.createfunc = createfunc

    def __call__(self) -> _T:
        try:
            return self.registry.value  # type: ignore[no-any-return]
        except AttributeError:
            # no value on the thread-local yet: create and store one
            created = self.createfunc()
            self.registry.value = created
            return created

    def has(self) -> bool:
        return hasattr(self.registry, "value")

    def set(self, obj: _T) -> None:
        self.registry.value = obj

    def clear(self) -> None:
        try:
            del self.registry.value
        except AttributeError:
            # no value was set
            pass

663 

664 

def has_dupes(sequence, target):
    """Given a sequence and search object, return True if there's more
    than one, False if zero or one of them.

    Matching is by identity (``is``), not equality.

    """
    # compare to .index version below, this version introduces less function
    # overhead and is usually the same speed.  At 15000 items (way bigger than
    # a relationship-bound collection in memory usually is) it begins to
    # fall behind the other version only by microseconds.
    seen_once = False
    for candidate in sequence:
        if candidate is not target:
            continue
        if seen_once:
            # second occurrence found; no need to look further
            return True
        seen_once = True
    return False

682 

683 

684# .index version. the two __contains__ calls as well 

685# as .index() and isinstance() slow this down. 

686# def has_dupes(sequence, target): 

687# if target not in sequence: 

688# return False 

689# elif not isinstance(sequence, collections_abc.Sequence): 

690# return False 

691# 

692# idx = sequence.index(target) 

693# return target in sequence[idx + 1:]