1import copy
2import random
3import string
4from typing import (
5 Any,
6 Callable,
7 Dict,
8 Iterable,
9 List,
10 Mapping,
11 Optional,
12)
13
14import redis
15from redis.typing import ChannelT, PubSubHandler, Subscription
16
17
18def list_or_args(keys: Any, args: Iterable[Any] | None) -> List[Any]:
19 # returns a single new list combining keys and args
20 try:
21 iter(keys)
22 # a string or bytes-like instance can be iterated, but indicates
23 # keys wasn't passed as a list
24 if isinstance(keys, (bytes, str, bytearray, memoryview)):
25 keys = [keys]
26 else:
27 keys = list(keys)
28 except TypeError:
29 keys = [keys]
30 if args:
31 keys.extend(args)
32 return keys
33
34
def parse_pubsub_subscriptions(
    args: tuple[Any, ...], kwargs: Mapping[str, PubSubHandler]
) -> dict[ChannelT, PubSubHandler | None]:
    """Build a ``channel -> handler`` mapping from positional and keyword input.

    Positional arguments are flattened with ``list_or_args``. A ``Subscription``
    contributes ``(name, handler)``; any other value is used as the channel
    with a ``None`` handler. Keyword arguments are merged last, so they
    override positional entries for the same channel.
    """

    def _as_pair(entry):
        # Returns (channel, handler) for a single positional entry.
        if isinstance(entry, Subscription):
            return entry.name, entry.handler
        return entry, None

    flattened = list_or_args(args[0], args[1:]) if args else []
    subscriptions: dict[ChannelT, PubSubHandler | None] = dict(
        _as_pair(entry) for entry in flattened
    )
    subscriptions.update(kwargs)
    return subscriptions
47
48
def pubsub_subscription_args(
    subscriptions: Mapping[ChannelT, PubSubHandler | None],
) -> list[ChannelT | Subscription]:
    """Convert a ``channel -> handler`` mapping back into positional arguments.

    A channel whose handler is ``None`` is emitted as-is; otherwise it is
    wrapped as ``Subscription(channel, handler)``. Mapping order is kept.
    """
    positional: list[ChannelT | Subscription] = []
    for channel, handler in subscriptions.items():
        if handler is None:
            positional.append(channel)
        else:
            positional.append(Subscription(channel, handler))
    return positional
56
57
def nativestr(x):
    """Return ``x`` as native text, mapping the literal ``"null"`` to ``None``.

    ``bytes`` are decoded as UTF-8, with undecodable sequences replaced by
    U+FFFD. Any other value passes through unchanged. If the resulting value
    equals ``"null"``, ``None`` is returned instead.
    """
    if isinstance(x, bytes):
        value = x.decode("utf-8", "replace")
    else:
        value = x
    return None if value == "null" else value
64
65
def delist(x):
    """Apply ``nativestr`` to every element of ``x``.

    ``None`` is returned unchanged. Otherwise a new list is built with each
    element decoded by ``nativestr``.
    """
    return None if x is None else list(map(nativestr, x))
71
72
def parse_to_list(response):
    """Best-effort conversion of each element of ``response`` to a number.

    Returns ``[]`` for ``None``. For each item:

    * ``None`` stays ``None`` and ``float`` instances are kept as-is.
    * Text (after ``nativestr``) equal to ``"infinity"``, ``"nan"`` or
      ``"-infinity"`` (case-insensitive) is kept as a string.
    * Otherwise the item is tried as ``int``, then as ``float``. If both fail,
      the ``nativestr`` form is used, which is ``None`` for ``"null"``.
    """
    if response is None:
        return []

    keep_as_text = {"infinity", "nan", "-infinity"}

    def _convert(item):
        # Converts a single response element following the rules above.
        if item is None:
            return None
        if isinstance(item, float):
            return item
        try:
            text = nativestr(item)
        except TypeError:
            return None
        if isinstance(text, str) and text.lower() in keep_as_text:
            return text
        try:
            return int(item)
        except (ValueError, OverflowError, TypeError):
            pass
        try:
            return float(item)
        except (ValueError, TypeError):
            return text

    return [_convert(item) for item in response]
107
108
def random_string(length=10):
    """
    Return a string of ``length`` characters drawn from ASCII lowercase letters.

    Uses the module-level ``random`` generator, which is not cryptographically
    secure.
    """
    alphabet = string.ascii_lowercase
    chars = [random.choice(alphabet) for _ in range(length)]  # nosec
    return "".join(chars)
116
117
def decode_dict_keys(obj):
    """Return a shallow copy of ``obj`` whose ``bytes`` keys are UTF-8 decoded.

    The input mapping is not modified. Each decoded key is re-inserted after
    the existing entries. If the decoded text already exists as a key, its
    value is replaced by the value that was stored under the ``bytes`` key.
    """
    decoded = copy.copy(obj)
    for key in obj.keys():
        if not isinstance(key, bytes):
            continue
        text_key = key.decode("utf-8")
        decoded[text_key] = decoded.pop(key)
    return decoded
126
127
def get_protocol_version(client):
    """Return the ``protocol`` entry of ``client``'s connection kwargs.

    Standalone clients (sync or asyncio ``Redis``) are read via their
    connection pool, and cluster clients via their nodes manager. Returns
    ``None`` when the key is absent or the client type is not recognised.
    """
    if isinstance(client, redis.Redis) or isinstance(client, redis.asyncio.Redis):
        connection_kwargs = client.connection_pool.connection_kwargs
    elif isinstance(client, redis.cluster.AbstractRedisCluster):
        connection_kwargs = client.nodes_manager.connection_kwargs
    else:
        return None
    return connection_kwargs.get("protocol")
133
134
def get_legacy_responses(client):
    """Return the user-supplied ``legacy_responses`` flag for ``client``.

    The flag is looked up in the same ``connection_kwargs`` that
    :func:`get_protocol_version` reads: the connection pool for standalone
    (sync or asyncio) clients, or the nodes manager for cluster clients.
    ``True`` is returned when the flag is absent or the client type is not
    recognised.
    """
    if isinstance(client, redis.Redis) or isinstance(client, redis.asyncio.Redis):
        connection_kwargs = client.connection_pool.connection_kwargs
    elif isinstance(client, redis.cluster.AbstractRedisCluster):
        connection_kwargs = client.nodes_manager.connection_kwargs
    else:
        return True
    return connection_kwargs.get("legacy_responses", True)
148
149
def apply_module_callbacks(
    user_protocol: Optional[int],
    legacy_responses: bool,
    *,
    common: Dict[str, Callable[..., Any]],
    resp2: Dict[str, Callable[..., Any]],
    resp3: Dict[str, Callable[..., Any]],
    resp2_unified: Optional[Dict[str, Callable[..., Any]]] = None,
    resp3_unified: Optional[Dict[str, Callable[..., Any]]] = None,
    resp3_to_resp2_legacy: Optional[Dict[str, Callable[..., Any]]] = None,
) -> Dict[str, Callable[..., Any]]:
    """Merge ``common`` with the callback overlay for (protocol, legacy flag).

    Mirrors the core-callback selection in
    :func:`redis._parsers.response_callbacks.get_response_callbacks`.
    Overlay choice:

    * legacy, protocol unset   -> ``resp3_to_resp2_legacy`` (empty if falsy)
    * legacy, protocol 3/"3"   -> ``resp3``
    * legacy, any other value  -> ``resp2``
    * unified, unset or 3/"3"  -> ``resp3_unified`` (``resp3`` if ``None``)
    * unified, any other value -> ``resp2_unified`` (``resp2`` if ``None``)

    A new dict is returned; none of the input dicts is modified.
    """
    wants_resp3 = user_protocol in (3, "3")
    if not legacy_responses:
        if user_protocol is None or wants_resp3:
            overlay = resp3 if resp3_unified is None else resp3_unified
        else:
            overlay = resp2 if resp2_unified is None else resp2_unified
    elif user_protocol is None:
        overlay = resp3_to_resp2_legacy or {}
    elif wants_resp3:
        overlay = resp3
    else:
        overlay = resp2

    merged: Dict[str, Callable[..., Any]] = dict(common)
    merged.update(overlay)
    return merged
185
186
def at_most_one_value_set(iterable: Iterable[Any]) -> bool:
    """
    Report whether no more than one value in ``iterable`` is truthy.

    Every value is consumed and passed through ``bool()``; evaluation does
    not stop early.

    Args:
        iterable: Values to inspect.

    Returns:
        True if zero or one value is truthy, False otherwise.

    Raises:
        Whatever ``bool()`` raises for a value, e.g. when a type's
        ``__bool__`` or ``__len__`` raises.
    """
    truthy_count = sum(map(bool, iterable))
    return truthy_count <= 1