# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

import threading
from collections import OrderedDict
from typing import Any, Generic, TypeVar

import attr

from hypothesis.errors import InvalidArgument

K = TypeVar("K")
V = TypeVar("V")


@attr.s(slots=True)
class Entry(Generic[K, V]):
    key: K = attr.ib()
    value: V = attr.ib()
    score: int = attr.ib()
    pins: int = attr.ib(default=0)

    @property
    def sort_key(self) -> tuple[int, ...]:
        if self.pins == 0:
            # Unpinned entries are sorted by score.
            return (0, self.score)
        else:
            # Pinned entries sort after unpinned ones, since tuple comparison
            # gives (1,) > (0, score) for any score. Beyond that, we don't
            # worry about their relative order.
            return (1,)


class GenericCache(Generic[K, V]):
    """Generic supertype for cache implementations.

    Defines a dict-like mapping with a maximum size, where each key maps to a
    value and also to a score. When a write would cause the mapping to exceed
    its maximum size, the existing key with the smallest score is evicted
    first, then the new key is added. If pinning means no key can be evicted,
    ValueError is raised.

    A key has the following lifecycle:

    1. When a key is written for the first time, it is given the score
       self.new_entry(key, value).
    2. Whenever an existing key is read or written, self.on_access(key, value,
       score) is called. This returns the new score for the key.
    3. After a key is evicted, self.on_evict(key, value, score) is called.

    The cache will be in a valid state whenever these hooks are called.

    Subclasses must implement new_entry, and may override on_access and
    on_evict, to define a specific scoring strategy.
    """

    __slots__ = ("_threadlocal", "max_size")

    def __init__(self, max_size: int):
        if max_size <= 0:
            raise InvalidArgument("Cache size must be at least one.")

        self.max_size = max_size

        # Implementation: We store a binary heap of Entry objects in self.data,
        # with the heap property requiring that a parent's score is <= that of
        # its children. keys_to_indices then maps each key to its index in the
        # heap. We keep these two in sync automatically - the heap is never
        # reordered without updating the index.
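        # For the array-backed heap, the parent of index i is (i - 1) // 2 and
        # its children are 2 * i + 1 and 2 * i + 2; __balance, __out_of_order,
        # and check_valid all rely on this layout.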
        self._threadlocal = threading.local()

    @property
    def keys_to_indices(self) -> dict[K, int]:
        try:
            return self._threadlocal.keys_to_indices
        except AttributeError:
            self._threadlocal.keys_to_indices = {}
            return self._threadlocal.keys_to_indices

    @property
    def data(self) -> list[Entry[K, V]]:
        try:
            return self._threadlocal.data
        except AttributeError:
            self._threadlocal.data = []
            return self._threadlocal.data

    def __len__(self) -> int:
        assert len(self.keys_to_indices) == len(self.data)
        return len(self.data)

    def __contains__(self, key: K) -> bool:
        return key in self.keys_to_indices

    def __getitem__(self, key: K) -> V:
        i = self.keys_to_indices[key]
        result = self.data[i]
        self.__entry_was_accessed(i)
        return result.value

    def __setitem__(self, key: K, value: V) -> None:
        evicted = None
        try:
            i = self.keys_to_indices[key]
        except KeyError:
            entry = Entry(key, value, self.new_entry(key, value))
            if len(self.data) >= self.max_size:
                evicted = self.data[0]
                if evicted.pins > 0:
                    raise ValueError(
                        "Cannot increase size of cache where all keys have been pinned."
                    ) from None
                try:
                    del self.keys_to_indices[evicted.key]
                except KeyError:  # pragma: no cover
                    # In theory this can't happen, but it has been observed
                    # when a key object's hash is unstable, i.e.
                    #     id(key1) == id(key2)
                    # but
                    #     hash(key1) != hash(key2)
                    # (see https://github.com/HypothesisWorks/hypothesis/issues/4442).
                    # Rebuild keys_to_indices to match data.
                    self.keys_to_indices.clear()
                    self.keys_to_indices.update(
                        {
                            entry.key: i
                            for i, entry in enumerate(self.data)
                            if entry is not evicted
                        }
                    )
                    assert len(self.keys_to_indices) == len(self.data) - 1
                i = 0
                self.data[0] = entry
            else:
                i = len(self.data)
                self.data.append(entry)
            self.keys_to_indices[key] = i
            self.__balance(i)
        else:
            entry = self.data[i]
            assert entry.key == key
            entry.value = value
            self.__entry_was_accessed(i)

        if evicted is not None:
            if self.data[0] is not entry:
                assert evicted.sort_key <= self.data[0].sort_key
            self.on_evict(evicted.key, evicted.value, evicted.score)

    def __iter__(self):
        return iter(self.keys_to_indices)

    def pin(self, key: K, value: V) -> None:
        """Mark ``key`` as pinned (with the given value). That is, it may not
        be evicted until ``unpin(key)`` has been called. The same key may be
        pinned multiple times, possibly changing its value, and will not be
        unpinned until the same number of calls to unpin have been made.
        """
        self[key] = value

        i = self.keys_to_indices[key]
        entry = self.data[i]
        entry.pins += 1
        if entry.pins == 1:
            self.__balance(i)

    def unpin(self, key: K) -> None:
        """Undo one previous call to ``pin(key)``. The value stays the same.
        Once all calls are undone this key may be evicted as normal."""
        i = self.keys_to_indices[key]
        entry = self.data[i]
        if entry.pins == 0:
            raise ValueError(f"Key {key!r} has not been pinned")
        entry.pins -= 1
        if entry.pins == 0:
            self.__balance(i)

    def is_pinned(self, key: K) -> bool:
        """Returns True if the key is currently pinned."""
        i = self.keys_to_indices[key]
        return self.data[i].pins > 0
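
    # Example of the pin/unpin contract (illustrative only; hypothetical keys):
    #
    #     cache = LRUReusedCache(max_size=2)
    #     cache.pin("a", 1)     # "a" may not be evicted while pinned
    #     cache["b"] = 2
    #     cache["c"] = 3        # evicts "b"; pinned "a" survives
    #     cache.unpin("a")      # "a" becomes evictable again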

    def clear(self) -> None:
        """Remove all keys, regardless of their pinned status."""
        del self.data[:]
        self.keys_to_indices.clear()

    def __repr__(self) -> str:
        return "{" + ", ".join(f"{e.key!r}: {e.value!r}" for e in self.data) + "}"

    def new_entry(self, key: K, value: V) -> int:
        """Called when a key is written that does not currently appear in the
        map.

        Returns the score to associate with the key.
        """
        raise NotImplementedError

    def on_access(self, key: K, value: V, score: Any) -> Any:
        """Called every time a key that is already in the map is read or
        written.

        Returns the new score for the key.
        """
        return score

    def on_evict(self, key: K, value: V, score: Any) -> Any:
        """Called after a key has been evicted, with the score it had at the
        point of eviction."""

    def check_valid(self) -> None:
        """Debugging method for use in tests.

        Asserts that all of the cache's invariants hold. When everything
        is working correctly this should be an expensive no-op.
        """
        assert len(self.keys_to_indices) == len(self.data)
        for i, e in enumerate(self.data):
            assert self.keys_to_indices[e.key] == i
            for j in [i * 2 + 1, i * 2 + 2]:
                if j < len(self.data):
                    assert e.sort_key <= self.data[j].sort_key, self.data

    def __entry_was_accessed(self, i: int) -> None:
        entry = self.data[i]
        new_score = self.on_access(entry.key, entry.value, entry.score)
        if new_score != entry.score:
            entry.score = new_score
            # changing the score of a pinned entry cannot unbalance the heap, as
            # we place all pinned entries after unpinned ones, regardless of score.
            if entry.pins == 0:
                self.__balance(i)

    def __swap(self, i: int, j: int) -> None:
        assert i < j
        assert self.data[j].sort_key < self.data[i].sort_key
        self.data[i], self.data[j] = self.data[j], self.data[i]
        self.keys_to_indices[self.data[i].key] = i
        self.keys_to_indices[self.data[j].key] = j

    def __balance(self, i: int) -> None:
        """When we have made a modification to the heap such that
        the heap property has been violated locally around i but previously
        held for all other indexes (and no other values have been modified),
        this fixes the heap so that the heap property holds everywhere."""
        # bubble up (if score is too low for current position)
        while (parent := (i - 1) // 2) >= 0:
            if self.__out_of_order(parent, i):
                self.__swap(parent, i)
                i = parent
            else:
                break
        # or bubble down (if score is too high for current position)
        while children := [j for j in (2 * i + 1, 2 * i + 2) if j < len(self.data)]:
            smallest_child = min(children, key=lambda j: self.data[j].sort_key)
            if self.__out_of_order(i, smallest_child):
                self.__swap(i, smallest_child)
                i = smallest_child
            else:
                break

    def __out_of_order(self, i: int, j: int) -> bool:
        """Returns True if the indices i, j are in the wrong order.

        i must be the parent of j.
        """
        assert i == (j - 1) // 2
        return self.data[j].sort_key < self.data[i].sort_key
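

# A minimal sketch of the scoring contract described in GenericCache's
# docstring (illustrative only; this class is not part of Hypothesis).
# Scores come from a monotonically increasing counter assigned at insertion,
# and the default on_access leaves them unchanged, so the oldest entry always
# sits at the heap root and is evicted first (i.e. FIFO eviction).
class _ExampleFIFOCache(GenericCache[K, V]):
    __slots__ = ("_counter",)

    def __init__(self, max_size: int):
        super().__init__(max_size)
        self._counter = 0

    def new_entry(self, key: K, value: V) -> int:
        # Later insertions get strictly larger scores, so we evict oldest first.
        self._counter += 1
        return self._counter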


class LRUReusedCache(GenericCache[K, V]):
    """The only concrete implementation of GenericCache we use outside of tests
    currently.

    Adopts a modified least-recently-used eviction policy: it evicts the key
    that has been used least recently, but always preferentially evicts keys
    that have never been accessed after insertion. Among keys that have been
    accessed, it ignores the number of accesses.

    This retains most of the benefits of an LRU cache, but adds an element of
    scan-resistance to the process: if we end up scanning through a large
    number of keys without reusing them, this does not evict the existing
    entries in preference for the new ones.
    """

    __slots__ = ("__tick",)

    def __init__(self, max_size: int):
        super().__init__(max_size)
        self.__tick: int = 0

    def tick(self) -> int:
        self.__tick += 1
        return self.__tick

    def new_entry(self, key: K, value: V) -> Any:
        return (1, self.tick())

    def on_access(self, key: K, value: V, score: Any) -> Any:
        return (2, self.tick())
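
# Illustrative consequence of the scores above (hypothetical keys): new
# entries score (1, tick) and re-accessed entries score (2, tick), so a key
# that has never been read back is always evicted before one that has.
#
#     cache = LRUReusedCache(max_size=2)
#     cache["a"] = 1
#     cache["b"] = 2
#     _ = cache["a"]    # "a" is promoted to a (2, tick) score
#     cache["c"] = 3    # evicts "b", the never-reused key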


class LRUCache(Generic[K, V]):
    """
    A drop-in replacement for GenericCache (despite the lack of inheritance)
    in performance-critical environments. It turns out that GenericCache's
    heap balancing for arbitrary scores can be quite expensive compared to
    the doubly-linked-list approach of lru_cache or OrderedDict.

    This class is a pure LRU: beyond recency, it gives no weight to how often
    an entry has been accessed. If soft-pinning entries which have been
    accessed at least once is important, use LRUReusedCache.
    """

    # Here are some nice performance references for lru_cache vs OrderedDict:
    # https://github.com/python/cpython/issues/72426#issuecomment-1093727671
    # https://discuss.python.org/t/simplify-lru-cache/18192/6
    #
    # We use OrderedDict here because it is unclear to me whether we can
    # provide the same API as GenericCache using @lru_cache without messing
    # with lru_cache internals.
    #
    # Anecdotally, OrderedDict seems quite competitive with lru_cache, but
    # perhaps that is localized to our access patterns.

    def __init__(self, max_size: int) -> None:
        assert max_size > 0
        self.max_size = max_size
        self._threadlocal = threading.local()

    @property
    def cache(self) -> OrderedDict[K, V]:
        try:
            return self._threadlocal.cache
        except AttributeError:
            self._threadlocal.cache = OrderedDict()
            return self._threadlocal.cache

    def __setitem__(self, key: K, value: V) -> None:
        self.cache[key] = value
        self.cache.move_to_end(key)

        while len(self.cache) > self.max_size:
            self.cache.popitem(last=False)

    def __getitem__(self, key: K) -> V:
        val = self.cache[key]
        self.cache.move_to_end(key)
        return val

    def __iter__(self):
        return iter(self.cache)

    def __len__(self) -> int:
        return len(self.cache)

    def __contains__(self, key: K) -> bool:
        return key in self.cache

    # implement GenericCache interface, for tests
    def check_valid(self) -> None:
        pass
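
# Example usage (illustrative only; hypothetical keys): eviction is ordered
# purely by recency of access.
#
#     cache: LRUCache[str, int] = LRUCache(max_size=2)
#     cache["a"] = 1
#     cache["b"] = 2
#     _ = cache["a"]    # touch "a" so it is most recently used
#     cache["c"] = 3    # evicts "b", the least recently used key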