Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/django/core/cache/backends/redis.py: 4%
157 statements
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 06:13 +0000
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 06:13 +0000
1"""Redis cache backend."""
3import pickle
4import random
5import re
7from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
8from django.utils.functional import cached_property
9from django.utils.module_loading import import_string
class RedisSerializer:
    """Serialize and deserialize cache values via pickle.

    Plain ``int`` values are passed through unpickled; every other object
    — including ``bool``, which is an ``int`` subclass — is pickled with
    the configured protocol so it round-trips with its exact type.
    """

    def __init__(self, protocol=None):
        # Fall back to the newest pickle protocol this Python supports.
        self.protocol = protocol if protocol is not None else pickle.HIGHEST_PROTOCOL

    def dumps(self, obj):
        # Exact-type check on purpose: int subclasses such as bool must be
        # pickled, not stored raw, to preserve their type on load.
        return obj if type(obj) is int else pickle.dumps(obj, self.protocol)

    def loads(self, data):
        # A value stored raw parses as an int; anything else is pickled
        # bytes. NOTE: pickle.loads assumes the cache contents are trusted.
        try:
            return int(data)
        except ValueError:
            return pickle.loads(data)
class RedisCacheClient:
    """Low-level client mapping cache operations onto redis-py.

    Keeps one lazily created connection pool per configured server.
    Writes always target the first server; reads are spread over the
    remaining servers when more than one is configured.
    """

    def __init__(
        self,
        servers,
        serializer=None,
        pool_class=None,
        parser_class=None,
        **options,
    ):
        # Imported lazily so this module can be imported without redis-py
        # installed.
        import redis

        self._lib = redis
        self._servers = servers
        self._pools = {}
        self._client = self._lib.Redis

        # serializer, pool_class, and parser_class may each arrive as a
        # dotted-path string; resolve strings to objects before use.
        if isinstance(pool_class, str):
            pool_class = import_string(pool_class)
        self._pool_class = pool_class or self._lib.ConnectionPool

        if isinstance(serializer, str):
            serializer = import_string(serializer)
        if callable(serializer):
            serializer = serializer()
        self._serializer = serializer or RedisSerializer()

        if isinstance(parser_class, str):
            parser_class = import_string(parser_class)
        parser_class = parser_class or self._lib.connection.DefaultParser

        self._pool_options = {"parser_class": parser_class, **options}

    def _get_connection_pool_index(self, write):
        """Pick a server index: 0 for writes or a lone server, otherwise a
        random read replica."""
        if write or len(self._servers) == 1:
            return 0
        return random.randint(1, len(self._servers) - 1)

    def _get_connection_pool(self, write):
        """Return the pool for the chosen server, creating it on first use."""
        pool_index = self._get_connection_pool_index(write)
        if pool_index not in self._pools:
            self._pools[pool_index] = self._pool_class.from_url(
                self._servers[pool_index],
                **self._pool_options,
            )
        return self._pools[pool_index]

    def get_client(self, key=None, *, write=False):
        # key is unused here, but kept in the signature so a custom client
        # subclass can route on it (e.g. for sharding).
        return self._client(connection_pool=self._get_connection_pool(write))

    def add(self, key, value, timeout):
        conn = self.get_client(key, write=True)
        payload = self._serializer.dumps(value)

        if timeout != 0:
            return bool(conn.set(key, payload, ex=timeout, nx=True))
        # timeout == 0 means "expire immediately": report whether the add
        # would have succeeded, then remove the key.
        if added := bool(conn.set(key, payload, nx=True)):
            conn.delete(key)
        return added

    def get(self, key, default):
        raw = self.get_client(key).get(key)
        if raw is None:
            return default
        return self._serializer.loads(raw)

    def set(self, key, value, timeout):
        conn = self.get_client(key, write=True)
        payload = self._serializer.dumps(value)
        if timeout == 0:
            # A zero timeout expires the key immediately, i.e. deletes it.
            conn.delete(key)
        else:
            conn.set(key, payload, ex=timeout)

    def touch(self, key, timeout):
        conn = self.get_client(key, write=True)
        if timeout is None:
            # No timeout: make the key persistent instead.
            return bool(conn.persist(key))
        return bool(conn.expire(key, timeout))

    def delete(self, key):
        return bool(self.get_client(key, write=True).delete(key))

    def get_many(self, keys):
        values = self.get_client(None).mget(keys)
        # Skip missing keys (mget returns None for them).
        return {
            key: self._serializer.loads(raw)
            for key, raw in zip(keys, values)
            if raw is not None
        }

    def has_key(self, key):
        return bool(self.get_client(key).exists(key))

    def incr(self, key, delta):
        conn = self.get_client(key, write=True)
        if not conn.exists(key):
            raise ValueError("Key '%s' not found." % key)
        return conn.incr(key, delta)

    def set_many(self, data, timeout):
        conn = self.get_client(None, write=True)
        pipeline = conn.pipeline()
        pipeline.mset({key: self._serializer.dumps(value) for key, value in data.items()})

        if timeout is not None:
            # mset() cannot attach expirations, so queue an EXPIRE per key
            # in the same pipeline.
            for key in data:
                pipeline.expire(key, timeout)
        pipeline.execute()

    def delete_many(self, keys):
        self.get_client(None, write=True).delete(*keys)

    def clear(self):
        return bool(self.get_client(None, write=True).flushdb())
class RedisCache(BaseCache):
    """Django cache backend storing entries in one or more Redis servers."""

    def __init__(self, server, params):
        super().__init__(params)
        # A single string may list several server URLs separated by commas
        # or semicolons; any other iterable is used as-is.
        if isinstance(server, str):
            self._servers = re.split("[;,]", server)
        else:
            self._servers = server

        self._class = RedisCacheClient
        self._options = params.get("OPTIONS", {})

    @cached_property
    def _cache(self):
        # Built once on first access and reused for the backend's lifetime.
        return self._class(self._servers, **self._options)

    def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
        if timeout == DEFAULT_TIMEOUT:
            timeout = self.default_timeout
        if timeout is None:
            # None makes the key persistent.
            return None
        # Clamp negatives to 0, which causes the key to be deleted.
        return max(0, int(timeout))

    def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        return self._cache.add(
            self.make_and_validate_key(key, version=version),
            value,
            self.get_backend_timeout(timeout),
        )

    def get(self, key, default=None, version=None):
        return self._cache.get(self.make_and_validate_key(key, version=version), default)

    def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        self._cache.set(
            self.make_and_validate_key(key, version=version),
            value,
            self.get_backend_timeout(timeout),
        )

    def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
        return self._cache.touch(
            self.make_and_validate_key(key, version=version),
            self.get_backend_timeout(timeout),
        )

    def delete(self, key, version=None):
        return self._cache.delete(self.make_and_validate_key(key, version=version))

    def get_many(self, keys, version=None):
        # Remember which validated key came from which caller-supplied key
        # so the result dict uses the caller's originals.
        key_map = {
            self.make_and_validate_key(key, version=version): key for key in keys
        }
        fetched = self._cache.get_many(key_map.keys())
        return {key_map[validated]: value for validated, value in fetched.items()}

    def has_key(self, key, version=None):
        return self._cache.has_key(self.make_and_validate_key(key, version=version))

    def incr(self, key, delta=1, version=None):
        return self._cache.incr(self.make_and_validate_key(key, version=version), delta)

    def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
        if not data:
            return []
        safe_data = {
            self.make_and_validate_key(key, version=version): value
            for key, value in data.items()
        }
        self._cache.set_many(safe_data, self.get_backend_timeout(timeout))
        # The cache API returns the list of keys that failed to be set.
        return []

    def delete_many(self, keys, version=None):
        if keys:
            self._cache.delete_many(
                [self.make_and_validate_key(key, version=version) for key in keys]
            )

    def clear(self):
        return self._cache.clear()