1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Copyright the Hypothesis Authors.
5# Individual contributors are listed in AUTHORS.rst and the git log.
6#
7# This Source Code Form is subject to the terms of the Mozilla Public License,
8# v. 2.0. If a copy of the MPL was not distributed with this file, You can
9# obtain one at https://mozilla.org/MPL/2.0/.
10
11import threading
12from collections import OrderedDict
13from dataclasses import dataclass
14from typing import Any, Generic, TypeVar
15
16from hypothesis.errors import InvalidArgument
17
18K = TypeVar("K")
19V = TypeVar("V")
20
21
22@dataclass(slots=True, frozen=False)
23class Entry(Generic[K, V]):
24 key: K
25 value: V
26 score: int
27 pins: int = 0
28
29 @property
30 def sort_key(self) -> tuple[int, ...]:
31 if self.pins == 0:
32 # Unpinned entries are sorted by score.
33 return (0, self.score)
34 else:
35 # Pinned entries sort after unpinned ones. Beyond that, we don't
36 # worry about their relative order.
37 return (1,)
38
39
40class GenericCache(Generic[K, V]):
41 """Generic supertype for cache implementations.
42
43 Defines a dict-like mapping with a maximum size, where as well as mapping
44 to a value, each key also maps to a score. When a write would cause the
45 dict to exceed its maximum size, it first evicts the existing key with
46 the smallest score, then adds the new key to the map. If due to pinning
47 no key can be evicted, ValueError is raised.
48
49 A key has the following lifecycle:
50
51 1. key is written for the first time, the key is given the score
52 self.new_entry(key, value)
53 2. whenever an existing key is read or written, self.on_access(key, value,
54 score) is called. This returns a new score for the key.
55 3. After a key is evicted, self.on_evict(key, value, score) is called.
56
57 The cache will be in a valid state in all of these cases.
58
59 Implementations are expected to implement new_entry and optionally
60 on_access and on_evict to implement a specific scoring strategy.
61 """
62
63 __slots__ = ("_threadlocal", "max_size")
64
65 def __init__(self, max_size: int):
66 if max_size <= 0:
67 raise InvalidArgument("Cache size must be at least one.")
68
69 self.max_size = max_size
70
71 # Implementation: We store a binary heap of Entry objects in self.data,
72 # with the heap property requiring that a parent's score is <= that of
73 # its children. keys_to_index then maps keys to their index in the
74 # heap. We keep these two in sync automatically - the heap is never
75 # reordered without updating the index.
76 self._threadlocal = threading.local()
77
78 @property
79 def keys_to_indices(self) -> dict[K, int]:
80 try:
81 return self._threadlocal.keys_to_indices
82 except AttributeError:
83 self._threadlocal.keys_to_indices = {}
84 return self._threadlocal.keys_to_indices
85
86 @property
87 def data(self) -> list[Entry[K, V]]:
88 try:
89 return self._threadlocal.data
90 except AttributeError:
91 self._threadlocal.data = []
92 return self._threadlocal.data
93
94 def __len__(self) -> int:
95 assert len(self.keys_to_indices) == len(self.data)
96 return len(self.data)
97
98 def __contains__(self, key: K) -> bool:
99 return key in self.keys_to_indices
100
101 def __getitem__(self, key: K) -> V:
102 i = self.keys_to_indices[key]
103 result = self.data[i]
104 self.__entry_was_accessed(i)
105 return result.value
106
107 def __setitem__(self, key: K, value: V) -> None:
108 evicted = None
109 try:
110 i = self.keys_to_indices[key]
111 except KeyError:
112 entry = Entry(key, value, self.new_entry(key, value))
113 if len(self.data) >= self.max_size:
114 evicted = self.data[0]
115 if evicted.pins > 0:
116 raise ValueError(
117 "Cannot increase size of cache where all keys have been pinned."
118 ) from None
119 del self.keys_to_indices[evicted.key]
120 i = 0
121 self.data[0] = entry
122 else:
123 i = len(self.data)
124 self.data.append(entry)
125 self.keys_to_indices[key] = i
126 self.__balance(i)
127 else:
128 entry = self.data[i]
129 assert entry.key == key
130 entry.value = value
131 self.__entry_was_accessed(i)
132
133 if evicted is not None:
134 if self.data[0] is not entry:
135 assert evicted.sort_key <= self.data[0].sort_key
136 self.on_evict(evicted.key, evicted.value, evicted.score)
137
138 def __iter__(self):
139 return iter(self.keys_to_indices)
140
141 def pin(self, key: K, value: V) -> None:
142 """Mark ``key`` as pinned (with the given value). That is, it may not
143 be evicted until ``unpin(key)`` has been called. The same key may be
144 pinned multiple times, possibly changing its value, and will not be
145 unpinned until the same number of calls to unpin have been made.
146 """
147 self[key] = value
148
149 i = self.keys_to_indices[key]
150 entry = self.data[i]
151 entry.pins += 1
152 if entry.pins == 1:
153 self.__balance(i)
154
155 def unpin(self, key: K) -> None:
156 """Undo one previous call to ``pin(key)``. The value stays the same.
157 Once all calls are undone this key may be evicted as normal."""
158 i = self.keys_to_indices[key]
159 entry = self.data[i]
160 if entry.pins == 0:
161 raise ValueError(f"Key {key!r} has not been pinned")
162 entry.pins -= 1
163 if entry.pins == 0:
164 self.__balance(i)
165
166 def is_pinned(self, key: K) -> bool:
167 """Returns True if the key is currently pinned."""
168 i = self.keys_to_indices[key]
169 return self.data[i].pins > 0
170
171 def clear(self) -> None:
172 """Remove all keys, regardless of their pinned status."""
173 del self.data[:]
174 self.keys_to_indices.clear()
175
176 def __repr__(self) -> str:
177 return "{" + ", ".join(f"{e.key!r}: {e.value!r}" for e in self.data) + "}"
178
179 def new_entry(self, key: K, value: V) -> int:
180 """Called when a key is written that does not currently appear in the
181 map.
182
183 Returns the score to associate with the key.
184 """
185 raise NotImplementedError
186
187 def on_access(self, key: K, value: V, score: Any) -> Any:
188 """Called every time a key that is already in the map is read or
189 written.
190
191 Returns the new score for the key.
192 """
193 return score
194
195 def on_evict(self, key: K, value: V, score: Any) -> Any:
196 """Called after a key has been evicted, with the score it had had at
197 the point of eviction."""
198
199 def check_valid(self) -> None:
200 """Debugging method for use in tests.
201
202 Asserts that all of the cache's invariants hold. When everything
203 is working correctly this should be an expensive no-op.
204 """
205 assert len(self.keys_to_indices) == len(self.data)
206 for i, e in enumerate(self.data):
207 assert self.keys_to_indices[e.key] == i
208 for j in [i * 2 + 1, i * 2 + 2]:
209 if j < len(self.data):
210 assert e.sort_key <= self.data[j].sort_key, self.data
211
212 def __entry_was_accessed(self, i: int) -> None:
213 entry = self.data[i]
214 new_score = self.on_access(entry.key, entry.value, entry.score)
215 if new_score != entry.score:
216 entry.score = new_score
217 # changing the score of a pinned entry cannot unbalance the heap, as
218 # we place all pinned entries after unpinned ones, regardless of score.
219 if entry.pins == 0:
220 self.__balance(i)
221
222 def __swap(self, i: int, j: int) -> None:
223 assert i < j
224 assert self.data[j].sort_key < self.data[i].sort_key
225 self.data[i], self.data[j] = self.data[j], self.data[i]
226 self.keys_to_indices[self.data[i].key] = i
227 self.keys_to_indices[self.data[j].key] = j
228
229 def __balance(self, i: int) -> None:
230 """When we have made a modification to the heap such that
231 the heap property has been violated locally around i but previously
232 held for all other indexes (and no other values have been modified),
233 this fixes the heap so that the heap property holds everywhere."""
234 # bubble up (if score is too low for current position)
235 while (parent := (i - 1) // 2) >= 0:
236 if self.__out_of_order(parent, i):
237 self.__swap(parent, i)
238 i = parent
239 else:
240 break
241 # or bubble down (if score is too high for current position)
242 while children := [j for j in (2 * i + 1, 2 * i + 2) if j < len(self.data)]:
243 smallest_child = min(children, key=lambda j: self.data[j].sort_key)
244 if self.__out_of_order(i, smallest_child):
245 self.__swap(i, smallest_child)
246 i = smallest_child
247 else:
248 break
249
250 def __out_of_order(self, i: int, j: int) -> bool:
251 """Returns True if the indices i, j are in the wrong order.
252
253 i must be the parent of j.
254 """
255 assert i == (j - 1) // 2
256 return self.data[j].sort_key < self.data[i].sort_key
257
258
259class LRUReusedCache(GenericCache[K, V]):
260 """The only concrete implementation of GenericCache we use outside of tests
261 currently.
262
263 Adopts a modified least-recently used eviction policy: It evicts the key
264 that has been used least recently, but it will always preferentially evict
265 keys that have never been accessed after insertion. Among keys that have been
266 accessed, it ignores the number of accesses.
267
268 This retains most of the benefits of an LRU cache, but adds an element of
269 scan-resistance to the process: If we end up scanning through a large
270 number of keys without reusing them, this does not evict the existing
271 entries in preference for the new ones.
272 """
273
274 __slots__ = ("__tick",)
275
276 def __init__(self, max_size: int):
277 super().__init__(max_size)
278 self.__tick: int = 0
279
280 def tick(self) -> int:
281 self.__tick += 1
282 return self.__tick
283
284 def new_entry(self, key: K, value: V) -> Any:
285 return (1, self.tick())
286
287 def on_access(self, key: K, value: V, score: Any) -> Any:
288 return (2, self.tick())
289
290
291class LRUCache(Generic[K, V]):
292 """
293 This is a drop-in replacement for a GenericCache (despite the lack of inheritance)
294 in performance critical environments. It turns out that GenericCache's heap
295 balancing for arbitrary scores can be quite expensive compared to the doubly
296 linked list approach of lru_cache or OrderedDict.
297
298 This class is a pure LRU and does not provide any sort of affininty towards
299 the number of accesses beyond recency. If soft-pinning entries which have been
300 accessed at least once is important, use LRUReusedCache.
301 """
302
303 # Here are some nice performance references for lru_cache vs OrderedDict:
304 # https://github.com/python/cpython/issues/72426#issuecomment-1093727671
305 # https://discuss.python.org/t/simplify-lru-cache/18192/6
306 #
307 # We use OrderedDict here because it is unclear to me we can provide the same
308 # api as GenericCache using @lru_cache without messing with lru_cache internals.
309 #
310 # Anecdotally, OrderedDict seems quite competitive with lru_cache, but perhaps
311 # that is localized to our access patterns.
312
313 def __init__(self, max_size: int) -> None:
314 assert max_size > 0
315 self.max_size = max_size
316 self._threadlocal = threading.local()
317
318 @property
319 def cache(self) -> OrderedDict[K, V]:
320 try:
321 return self._threadlocal.cache
322 except AttributeError:
323 self._threadlocal.cache = OrderedDict()
324 return self._threadlocal.cache
325
326 def __setitem__(self, key: K, value: V) -> None:
327 self.cache[key] = value
328 self.cache.move_to_end(key)
329
330 while len(self.cache) > self.max_size:
331 self.cache.popitem(last=False)
332
333 def __getitem__(self, key: K) -> V:
334 val = self.cache[key]
335 self.cache.move_to_end(key)
336 return val
337
338 def __iter__(self):
339 return iter(self.cache)
340
341 def __len__(self) -> int:
342 return len(self.cache)
343
344 def __contains__(self, key: K) -> bool:
345 return key in self.cache
346
347 # implement GenericCache interface, for tests
348 def check_valid(self) -> None:
349 pass