Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/django/core/cache/backends/redis.py: 4%

157 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 06:13 +0000

1"""Redis cache backend.""" 

2 

3import pickle 

4import random 

5import re 

6 

7from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache 

8from django.utils.functional import cached_property 

9from django.utils.module_loading import import_string 

10 

11 

class RedisSerializer:
    """Convert cache values to and from the form stored in Redis.

    Values whose type is exactly ``int`` are stored unchanged, so Redis can
    operate on them directly (the client's ``incr`` issues a native Redis
    INCR). Every other value is pickled.
    """

    def __init__(self, protocol=None):
        # Default to the newest pickle protocol this interpreter supports.
        if protocol is None:
            protocol = pickle.HIGHEST_PROTOCOL
        self.protocol = protocol

    def dumps(self, obj):
        """Return *obj* itself if it is a plain ``int``, otherwise its pickle."""
        # The exact type check (rather than isinstance) is intentional: int
        # subclasses such as bool must go through pickle to keep their type.
        is_plain_int = type(obj) is int
        return obj if is_plain_int else pickle.dumps(obj, self.protocol)

    def loads(self, data):
        """Inverse of ``dumps``: decode an integer, else unpickle the payload.

        NOTE(review): ``pickle.loads`` executes arbitrary code for crafted
        payloads, so this assumes the Redis server's contents are trusted.
        """
        try:
            decoded = int(data)
        except ValueError:
            # Not an integer literal, so the payload must be a pickle.
            decoded = pickle.loads(data)
        return decoded

28 

29 

class RedisCacheClient:
    """Thin wrapper around redis-py used by the Redis cache backend.

    All writes go to the first entry of ``servers``. Reads are spread randomly
    over the remaining entries when there is more than one server; otherwise
    reads also use the first server. One connection pool per server index is
    created lazily and reused.
    """

    def __init__(
        self,
        servers,
        serializer=None,
        pool_class=None,
        parser_class=None,
        **options,
    ):
        """Store configuration; connection pools are only built on first use.

        ``servers`` is an indexable sequence of Redis URLs. ``serializer``,
        ``pool_class`` and ``parser_class`` may each be an object or a dotted
        import-path string. Remaining ``options`` are forwarded, together with
        the parser class, to the pool class's ``from_url()``.
        """
        import redis

        self._lib = redis
        self._servers = servers
        # Server index -> connection pool, filled in by _get_connection_pool().
        self._pools = {}

        self._client = self._lib.Redis

        if isinstance(pool_class, str):
            pool_class = import_string(pool_class)
        self._pool_class = pool_class or self._lib.ConnectionPool

        if isinstance(serializer, str):
            serializer = import_string(serializer)
        # Anything callable (typically a class) is called to build the
        # serializer instance; a non-callable instance is used as-is.
        if callable(serializer):
            serializer = serializer()
        self._serializer = serializer or RedisSerializer()

        if isinstance(parser_class, str):
            parser_class = import_string(parser_class)
        parser_class = parser_class or self._lib.connection.DefaultParser

        self._pool_options = {"parser_class": parser_class, **options}

    def _get_connection_pool_index(self, write):
        """Return the index into ``self._servers`` to use for this operation."""
        # Write to the first server. Read from other servers if there are more,
        # otherwise read from the first server.
        if write or len(self._servers) == 1:
            return 0
        return random.randint(1, len(self._servers) - 1)

    def _get_connection_pool(self, write):
        """Return the cached pool for the chosen server, creating it if needed."""
        index = self._get_connection_pool_index(write)
        if index not in self._pools:
            self._pools[index] = self._pool_class.from_url(
                self._servers[index],
                **self._pool_options,
            )
        return self._pools[index]

    def get_client(self, key=None, *, write=False):
        """Return a Redis client bound to the pool for a read or a write."""
        # key is used so that the method signature remains the same and custom
        # cache client can be implemented which might require the key to select
        # the server, e.g. sharding.
        pool = self._get_connection_pool(write)
        return self._client(connection_pool=pool)

    def add(self, key, value, timeout):
        """Set *key* only if it does not already exist; return True if stored.

        With ``timeout == 0`` the entry must expire immediately. It is set
        (NX) and then deleted at once, so the return value still reports
        whether the add would have succeeded.
        """
        client = self.get_client(key, write=True)
        value = self._serializer.dumps(value)

        if timeout == 0:
            if ret := bool(client.set(key, value, nx=True)):
                client.delete(key)
            return ret
        else:
            return bool(client.set(key, value, ex=timeout, nx=True))

    def get(self, key, default):
        """Return the deserialized value for *key*, or *default* if missing."""
        client = self.get_client(key)
        value = client.get(key)
        return default if value is None else self._serializer.loads(value)

    def set(self, key, value, timeout):
        """Store *value* under *key*; a timeout of 0 deletes the key instead."""
        client = self.get_client(key, write=True)
        value = self._serializer.dumps(value)
        if timeout == 0:
            client.delete(key)
        else:
            client.set(key, value, ex=timeout)

    def touch(self, key, timeout):
        """Update the expiry of *key*; ``None`` removes the expiry (PERSIST)."""
        client = self.get_client(key, write=True)
        if timeout is None:
            return bool(client.persist(key))
        else:
            return bool(client.expire(key, timeout))

    def delete(self, key):
        """Delete *key*; return True if a key was removed."""
        client = self.get_client(key, write=True)
        return bool(client.delete(key))

    def get_many(self, keys):
        """Return a dict of deserialized values for the *keys* that exist.

        *keys* may be any iterable, including a one-shot iterator. It is
        materialized first because it is traversed twice: once for MGET and
        once to pair keys with results.
        """
        keys = list(keys)
        if not keys:
            # MGET with no keys is a Redis error; nothing to look up anyway.
            return {}
        client = self.get_client(None)
        ret = client.mget(keys)
        return {
            k: self._serializer.loads(v) for k, v in zip(keys, ret) if v is not None
        }

    def has_key(self, key):
        """Return True if *key* exists."""
        client = self.get_client(key)
        return bool(client.exists(key))

    def incr(self, key, delta):
        """Increment *key* by *delta* with Redis INCRBY; return the new value.

        Raises ValueError if *key* does not exist.
        NOTE: the existence check and the increment are two separate commands,
        so the pair is not atomic.
        """
        client = self.get_client(key, write=True)
        if not client.exists(key):
            raise ValueError("Key '%s' not found." % key)
        return client.incr(key, delta)

    def set_many(self, data, timeout):
        """Store every key/value pair in *data* in one pipeline."""
        client = self.get_client(None, write=True)
        pipeline = client.pipeline()
        pipeline.mset({k: self._serializer.dumps(v) for k, v in data.items()})

        if timeout is not None:
            # Setting timeout for each key as redis does not support timeout
            # with mset().
            for key in data:
                pipeline.expire(key, timeout)
        pipeline.execute()

    def delete_many(self, keys):
        """Delete all of *keys* (any iterable); empty input is a no-op."""
        keys = list(keys)
        if not keys:
            # DEL with no arguments is a Redis error.
            return
        client = self.get_client(None, write=True)
        client.delete(*keys)

    def clear(self):
        """Flush the selected Redis database (FLUSHDB removes every key in it)."""
        client = self.get_client(None, write=True)
        return bool(client.flushdb())

157 

158 

class RedisCache(BaseCache):
    """Django cache backend that delegates storage to ``RedisCacheClient``."""

    def __init__(self, server, params):
        super().__init__(params)
        if isinstance(server, str):
            # Accept "url1;url2" or "url1,url2". Drop surrounding whitespace
            # and empty entries (e.g. from a trailing separator). Otherwise
            # they become unusable server URLs, and reads are routed to them.
            self._servers = [
                url.strip() for url in re.split("[;,]", server) if url.strip()
            ]
        else:
            self._servers = server

        self._class = RedisCacheClient
        self._options = params.get("OPTIONS", {})

    @cached_property
    def _cache(self):
        # Built lazily on first access, then reused for this backend instance.
        return self._class(self._servers, **self._options)

    def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
        """Convert a Django timeout into the value passed to the client."""
        if timeout == DEFAULT_TIMEOUT:
            timeout = self.default_timeout
        # The key will be made persistent if None used as a timeout.
        # Non-positive values will cause the key to be deleted.
        return None if timeout is None else max(0, int(timeout))

    def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        key = self.make_and_validate_key(key, version=version)
        return self._cache.add(key, value, self.get_backend_timeout(timeout))

    def get(self, key, default=None, version=None):
        key = self.make_and_validate_key(key, version=version)
        return self._cache.get(key, default)

    def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        key = self.make_and_validate_key(key, version=version)
        self._cache.set(key, value, self.get_backend_timeout(timeout))

    def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
        key = self.make_and_validate_key(key, version=version)
        return self._cache.touch(key, self.get_backend_timeout(timeout))

    def delete(self, key, version=None):
        key = self.make_and_validate_key(key, version=version)
        return self._cache.delete(key)

    def get_many(self, keys, version=None):
        """Return {original_key: value} for the requested keys that exist."""
        # Map each backend key back to the caller's key for the result.
        key_map = {
            self.make_and_validate_key(key, version=version): key for key in keys
        }
        if not key_map:
            # Avoid sending an argument-less MGET to Redis.
            return {}
        ret = self._cache.get_many(key_map.keys())
        return {key_map[k]: v for k, v in ret.items()}

    def has_key(self, key, version=None):
        key = self.make_and_validate_key(key, version=version)
        return self._cache.has_key(key)

    def incr(self, key, delta=1, version=None):
        key = self.make_and_validate_key(key, version=version)
        return self._cache.incr(key, delta)

    def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
        """Store all entries of *data*; always returns [] (no failed keys)."""
        if not data:
            return []
        safe_data = {}
        for key, value in data.items():
            key = self.make_and_validate_key(key, version=version)
            safe_data[key] = value
        self._cache.set_many(safe_data, self.get_backend_timeout(timeout))
        return []

    def delete_many(self, keys, version=None):
        """Delete all of *keys*; empty input (including iterators) is a no-op."""
        if not keys:
            return
        safe_keys = [self.make_and_validate_key(key, version=version) for key in keys]
        # `not keys` is False for an exhausted or empty generator, so check the
        # materialized list too; Redis rejects DEL with no arguments.
        if not safe_keys:
            return
        self._cache.delete_many(safe_keys)

    def clear(self):
        return self._cache.clear()