# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

import threading
from collections import OrderedDict
from typing import Any, Generic, TypeVar

import attr

from hypothesis.errors import InvalidArgument

K = TypeVar("K")
V = TypeVar("V")


@attr.s(slots=True)
class Entry(Generic[K, V]):
    key: K = attr.ib()
    value: V = attr.ib()
    score: int = attr.ib()
    pins: int = attr.ib(default=0)

    @property
    def sort_key(self) -> tuple[int, ...]:
        if self.pins == 0:
            # Unpinned entries are sorted by score.
            return (0, self.score)
        else:
            # Pinned entries sort after unpinned ones. Beyond that, we don't
            # worry about their relative order.
            return (1,)
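
# For illustration: an unpinned Entry with score 5 has sort_key == (0, 5),
# while any pinned Entry has sort_key == (1,). Under tuple comparison, every
# pinned entry therefore sorts after every unpinned one, regardless of score.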


class GenericCache(Generic[K, V]):
    """Generic supertype for cache implementations.

    Defines a dict-like mapping with a maximum size, where each key maps to a
    value and also to a score. When a write would cause the dict to exceed its
    maximum size, it first evicts the existing key with the smallest score,
    then adds the new key to the map. If, due to pinning, no key can be
    evicted, a ValueError is raised.

    A key has the following lifecycle:

    1. When the key is written for the first time, it is given the score
       self.new_entry(key, value).
    2. Whenever an existing key is read or written, self.on_access(key, value,
       score) is called. This returns a new score for the key.
    3. After a key is evicted, self.on_evict(key, value, score) is called.

    The cache will be in a valid state in all of these cases.

    Implementations are expected to implement new_entry and optionally
    on_access and on_evict to implement a specific scoring strategy (an
    example subclass is sketched in a comment after this class definition).
    """

    __slots__ = ("_threadlocal", "max_size")

    def __init__(self, max_size: int):
        if max_size <= 0:
            raise InvalidArgument("Cache size must be at least one.")

        self.max_size = max_size

        # Implementation: We store a binary heap of Entry objects in
        # self.data, with the heap property requiring that a parent's score is
        # <= that of its children (data[0] is the root, and the children of
        # data[i] are data[2 * i + 1] and data[2 * i + 2]). keys_to_indices
        # then maps keys to their index in the heap. We keep these two in sync
        # automatically - the heap is never reordered without updating the
        # index.
        self._threadlocal = threading.local()

    @property
    def keys_to_indices(self) -> dict[K, int]:
        try:
            return self._threadlocal.keys_to_indices
        except AttributeError:
            self._threadlocal.keys_to_indices = {}
            return self._threadlocal.keys_to_indices

    @property
    def data(self) -> list[Entry[K, V]]:
        try:
            return self._threadlocal.data
        except AttributeError:
            self._threadlocal.data = []
            return self._threadlocal.data

    def __len__(self) -> int:
        assert len(self.keys_to_indices) == len(self.data)
        return len(self.data)

    def __contains__(self, key: K) -> bool:
        return key in self.keys_to_indices

    def __getitem__(self, key: K) -> V:
        i = self.keys_to_indices[key]
        result = self.data[i]
        self.__entry_was_accessed(i)
        return result.value

    def __setitem__(self, key: K, value: V) -> None:
        evicted = None
        try:
            i = self.keys_to_indices[key]
        except KeyError:
            entry = Entry(key, value, self.new_entry(key, value))
            if len(self.data) >= self.max_size:
                evicted = self.data[0]
                if evicted.pins > 0:
                    raise ValueError(
                        "Cannot increase size of cache where all keys have been pinned."
                    ) from None
                del self.keys_to_indices[evicted.key]
                i = 0
                self.data[0] = entry
            else:
                i = len(self.data)
                self.data.append(entry)
            self.keys_to_indices[key] = i
            self.__balance(i)
        else:
            entry = self.data[i]
            assert entry.key == key
            entry.value = value
            self.__entry_was_accessed(i)

        if evicted is not None:
            if self.data[0] is not entry:
                assert evicted.sort_key <= self.data[0].sort_key
            self.on_evict(evicted.key, evicted.value, evicted.score)
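
    # Note on the eviction path above: on_evict runs only after the new entry
    # is in place and the heap has been rebalanced, so the cache is already
    # back in a valid state by the time code in on_evict can observe it.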

    def __iter__(self):
        return iter(self.keys_to_indices)

    def pin(self, key: K, value: V) -> None:
        """Mark ``key`` as pinned (with the given value). That is, it may not
        be evicted until ``unpin(key)`` has been called. The same key may be
        pinned multiple times, possibly changing its value, and will not be
        unpinned until the same number of calls to unpin have been made.
        """
        self[key] = value

        i = self.keys_to_indices[key]
        entry = self.data[i]
        entry.pins += 1
        if entry.pins == 1:
            # The entry has just become pinned, so its sort_key changed from
            # (0, score) to (1,) and it may need to move down the heap.
            self.__balance(i)

    def unpin(self, key: K) -> None:
        """Undo one previous call to ``pin(key)``. The value stays the same.
        Once all calls are undone this key may be evicted as normal."""
        i = self.keys_to_indices[key]
        entry = self.data[i]
        if entry.pins == 0:
            raise ValueError(f"Key {key!r} has not been pinned")
        entry.pins -= 1
        if entry.pins == 0:
            # No pins remain, so the entry sorts by score again and may need
            # to move back up the heap.
            self.__balance(i)

    def is_pinned(self, key: K) -> bool:
        """Returns True if the key is currently pinned."""
        i = self.keys_to_indices[key]
        return self.data[i].pins > 0
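
    # Rough usage sketch (hypothetical key and values): after
    #
    #     cache.pin("k", 1)
    #     cache.pin("k", 2)
    #
    # the key "k" holds the value 2 and cache.is_pinned("k") is True; it only
    # becomes evictable again after two matching calls to cache.unpin("k").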

    def clear(self) -> None:
        """Remove all keys, regardless of their pinned status."""
        del self.data[:]
        self.keys_to_indices.clear()

    def __repr__(self) -> str:
        return "{" + ", ".join(f"{e.key!r}: {e.value!r}" for e in self.data) + "}"

    def new_entry(self, key: K, value: V) -> int:
        """Called when a key is written that does not currently appear in the
        map.

        Returns the score to associate with the key.
        """
        raise NotImplementedError

    def on_access(self, key: K, value: V, score: Any) -> Any:
        """Called every time a key that is already in the map is read or
        written.

        Returns the new score for the key.
        """
        return score

    def on_evict(self, key: K, value: V, score: Any) -> Any:
        """Called after a key has been evicted, with the score it had at the
        point of eviction."""

    def check_valid(self) -> None:
        """Debugging method for use in tests.

        Asserts that all of the cache's invariants hold. When everything
        is working correctly this should be an expensive no-op.
        """
        assert len(self.keys_to_indices) == len(self.data)
        for i, e in enumerate(self.data):
            assert self.keys_to_indices[e.key] == i
            for j in [i * 2 + 1, i * 2 + 2]:
                if j < len(self.data):
                    assert e.sort_key <= self.data[j].sort_key, self.data

    def __entry_was_accessed(self, i: int) -> None:
        entry = self.data[i]
        new_score = self.on_access(entry.key, entry.value, entry.score)
        if new_score != entry.score:
            entry.score = new_score
            # Changing the score of a pinned entry cannot unbalance the heap,
            # as we place all pinned entries after unpinned ones, regardless
            # of score.
            if entry.pins == 0:
                self.__balance(i)

    def __swap(self, i: int, j: int) -> None:
        assert i < j
        assert self.data[j].sort_key < self.data[i].sort_key
        self.data[i], self.data[j] = self.data[j], self.data[i]
        self.keys_to_indices[self.data[i].key] = i
        self.keys_to_indices[self.data[j].key] = j

    def __balance(self, i: int) -> None:
        """When we have made a modification to the heap such that
        the heap property has been violated locally around i but previously
        held for all other indexes (and no other values have been modified),
        this fixes the heap so that the heap property holds everywhere."""
        # Bubble up (if the score is too low for the current position)...
        while (parent := (i - 1) // 2) >= 0:
            if self.__out_of_order(parent, i):
                self.__swap(parent, i)
                i = parent
            else:
                break
        # ...or bubble down (if the score is too high for the current
        # position).
        while children := [j for j in (2 * i + 1, 2 * i + 2) if j < len(self.data)]:
            smallest_child = min(children, key=lambda j: self.data[j].sort_key)
            if self.__out_of_order(i, smallest_child):
                self.__swap(i, smallest_child)
                i = smallest_child
            else:
                break

    def __out_of_order(self, i: int, j: int) -> bool:
        """Returns True if the indices i, j are in the wrong order.

        i must be the parent of j.
        """
        assert i == (j - 1) // 2
        return self.data[j].sort_key < self.data[i].sort_key
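
# A minimal sketch of a concrete scoring strategy, for illustration only
# (this hypothetical subclass is not part of Hypothesis):
#
#     class FIFOCache(GenericCache[K, V]):
#         """Evicts the least recently *inserted* key, ignoring reads."""
#
#         def __init__(self, max_size: int) -> None:
#             super().__init__(max_size)
#             self._counter = 0
#
#         def new_entry(self, key: K, value: V) -> int:
#             # Monotonically increasing insertion index: older entries have
#             # smaller scores and are therefore evicted first.
#             self._counter += 1
#             return self._counter
#
# The default on_access keeps the existing score, so reads don't affect the
# eviction order of this sketch.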


class LRUReusedCache(GenericCache[K, V]):
    """The only concrete implementation of GenericCache we currently use
    outside of tests.

    Adopts a modified least-recently-used eviction policy: it evicts the key
    that has been used least recently, but always preferentially evicts keys
    that have never been accessed after insertion. Among keys that have been
    accessed, it ignores the number of accesses.

    This retains most of the benefits of an LRU cache, but adds an element of
    scan-resistance to the process: if we end up scanning through a large
    number of keys without reusing them, this does not evict the existing
    entries in preference for the new ones.
    """

    __slots__ = ("__tick",)

    def __init__(self, max_size: int):
        super().__init__(max_size)
        self.__tick: int = 0

    def tick(self) -> int:
        self.__tick += 1
        return self.__tick

    def new_entry(self, key: K, value: V) -> Any:
        # Scores are (generation, tick) tuples: generation 1 for entries that
        # have only ever been inserted, generation 2 once they have been
        # accessed again, so never-reused entries always sort (and evict)
        # first.
        return (1, self.tick())

    def on_access(self, key: K, value: V, score: Any) -> Any:
        return (2, self.tick())
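
# Rough behavioral sketch (hypothetical keys and values):
#
#     cache = LRUReusedCache(max_size=2)
#     cache["a"] = 1          # score (1, 1): inserted, never reused
#     cache["b"] = 2          # score (1, 2): inserted, never reused
#     cache["a"]              # score (2, 3): "a" has now been reused
#     cache["c"] = 3          # evicts "b", the oldest never-reused key,
#                             # even though "a" was inserted earlier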


class LRUCache(Generic[K, V]):
    """
    This is a drop-in replacement for a GenericCache (despite the lack of
    inheritance) in performance-critical environments. It turns out that
    GenericCache's heap balancing for arbitrary scores can be quite expensive
    compared to the doubly-linked-list approach of lru_cache or OrderedDict.

    This class is a pure LRU and does not provide any sort of affinity towards
    the number of accesses beyond recency. If soft-pinning entries which have
    been accessed at least once is important, use LRUReusedCache.
    """

    # Here are some nice performance references for lru_cache vs OrderedDict:
    # https://github.com/python/cpython/issues/72426#issuecomment-1093727671
    # https://discuss.python.org/t/simplify-lru-cache/18192/6
    #
    # We use OrderedDict here because it is unclear to me whether we can
    # provide the same api as GenericCache using @lru_cache without messing
    # with lru_cache internals.
    #
    # Anecdotally, OrderedDict seems quite competitive with lru_cache, but
    # perhaps that is localized to our access patterns.

    def __init__(self, max_size: int) -> None:
        assert max_size > 0
        self.max_size = max_size
        self._threadlocal = threading.local()

    @property
    def cache(self) -> OrderedDict[K, V]:
        try:
            return self._threadlocal.cache
        except AttributeError:
            self._threadlocal.cache = OrderedDict()
            return self._threadlocal.cache

    def __setitem__(self, key: K, value: V) -> None:
        self.cache[key] = value
        # Assigning to an existing key does not change its position in an
        # OrderedDict, so explicitly mark it as most recently used.
        self.cache.move_to_end(key)

        while len(self.cache) > self.max_size:
            # Pop the least recently used item, from the front of the dict.
            self.cache.popitem(last=False)

    def __getitem__(self, key: K) -> V:
        val = self.cache[key]
        self.cache.move_to_end(key)
        return val

    def __iter__(self):
        return iter(self.cache)

    def __len__(self) -> int:
        return len(self.cache)

    def __contains__(self, key: K) -> bool:
        return key in self.cache

    # implement GenericCache interface, for tests
    def check_valid(self) -> None:
        pass
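
# Rough behavioral sketch (hypothetical keys and values):
#
#     cache = LRUCache(max_size=2)
#     cache["a"] = 1
#     cache["b"] = 2
#     cache["a"]              # "a" is now the most recently used key
#     cache["c"] = 3          # evicts "b", the least recently used key
#
# Unlike LRUReusedCache, a long scan of fresh keys here would also evict "a"
# and "b", since recency is the only signal a pure LRU uses.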