1from collections.abc import Container, Iterable, Sized, Hashable
2from functools import reduce
3from typing import Generic, TypeVar
4from pyrsistent._pmap import pmap
5
# Covariant type variable used to parameterize PBag's element type.
T_co = TypeVar('T_co', covariant=True)
7
8
def _add_to_counters(counters, element):
    """
    Return ``counters`` with the count of ``element`` increased by one.

    Elements not yet present start from a count of 0. The result is whatever
    ``counters.set`` returns, so with a pmap the input map is left untouched.
    Used as the folding step for ``reduce`` when building or extending a bag.
    """
    previous = counters.get(element, 0)
    return counters.set(element, previous + 1)
11
12
class PBag(Generic[T_co]):
    """
    A persistent bag/multiset type.

    Requires elements to be hashable, and allows duplicates, but has no
    ordering. Bags are hashable.

    Do not instantiate directly, instead use the factory functions :py:func:`b`
    or :py:func:`pbag` to create an instance.

    Some examples:

    >>> s = pbag([1, 2, 3, 1])
    >>> s2 = s.add(4)
    >>> s3 = s2.remove(1)
    >>> s
    pbag([1, 1, 2, 3])
    >>> s2
    pbag([1, 1, 2, 3, 4])
    >>> s3
    pbag([1, 2, 3, 4])
    """

    # A bag is stored as a persistent map from element to its multiplicity.
    # The methods below never store a count of zero or less: such entries are
    # removed instead. Every "modifying" method returns a new PBag.
    __slots__ = ('_counts', '__weakref__')

    def __init__(self, counts):
        # ``counts`` is expected to be a pmap of element -> positive count;
        # building it is the job of the factory functions.
        self._counts = counts

    def add(self, element):
        """
        Add an element to the bag.

        >>> s = pbag([1])
        >>> s2 = s.add(1)
        >>> s3 = s.add(2)
        >>> s2
        pbag([1, 1])
        >>> s3
        pbag([1, 2])
        """
        return PBag(_add_to_counters(self._counts, element))

    def update(self, iterable):
        """
        Update bag with all elements in iterable.

        Returns this same bag when ``iterable`` yields no elements (``None``
        is also accepted and treated as empty).

        >>> s = pbag([1])
        >>> s.update([1, 2])
        pbag([1, 1, 2])
        >>> s.update(x for x in []) is s
        True
        """
        if iterable is None:
            # Kept for backward compatibility: None has always meant "nothing".
            return self
        # Fold first and detect emptiness afterwards instead of testing
        # ``iterable`` for truthiness. An iterable's truth value does not say
        # whether it yields anything: an empty generator is truthy, and a
        # one-element NumPy array holding 0 is falsy.
        counts = reduce(_add_to_counters, iterable, self._counts)
        if counts is self._counts:
            # reduce() hands back the initializer untouched when nothing was
            # folded in, so the bag is unchanged.
            return self
        return PBag(counts)

    def remove(self, element):
        """
        Remove an element from the bag.

        Only one occurrence is removed. Raises KeyError if the element is not
        in the bag.

        >>> s = pbag([1, 1, 2])
        >>> s2 = s.remove(1)
        >>> s3 = s.remove(2)
        >>> s2
        pbag([1, 2])
        >>> s3
        pbag([1, 1])
        """
        if element not in self._counts:
            raise KeyError(element)
        elif self._counts[element] == 1:
            # Drop the key entirely rather than storing a zero count.
            newc = self._counts.remove(element)
        else:
            newc = self._counts.set(element, self._counts[element] - 1)
        return PBag(newc)

    def count(self, element):
        """
        Return the number of times an element appears.


        >>> pbag([]).count('non-existent')
        0
        >>> pbag([1, 1, 2]).count(1)
        2
        """
        return self._counts.get(element, 0)

    def __len__(self):
        """
        Return the length including duplicates.

        >>> len(pbag([1, 1, 2]))
        3
        """
        # Sums the multiplicities, so this is linear in the number of
        # distinct elements.
        return sum(self._counts.itervalues())

    def __iter__(self):
        """
        Return an iterator of all elements, including duplicates.

        >>> list(pbag([1, 1, 2]))
        [1, 1, 2]
        >>> list(pbag([1, 2]))
        [1, 2]
        """
        for elt, count in self._counts.iteritems():
            for _ in range(count):
                yield elt

    def __contains__(self, elt):
        """
        Check if an element is in the bag.

        >>> 1 in pbag([1, 1, 2])
        True
        >>> 0 in pbag([1, 2])
        False
        """
        return elt in self._counts

    def __repr__(self):
        return "pbag({0})".format(list(self))

    def __eq__(self, other):
        """
        Check if two bags are equivalent, honoring the number of duplicates,
        and ignoring insertion order.

        Raises TypeError when ``other`` is not exactly a PBag; subclasses of
        PBag are rejected too.

        >>> pbag([1, 1, 2]) == pbag([1, 2])
        False
        >>> pbag([2, 1, 0]) == pbag([0, 1, 2])
        True
        """
        if type(other) is not PBag:
            raise TypeError("Can only compare PBag with PBags")
        return self._counts == other._counts

    def __lt__(self, other):
        """Bags have no ordering: every ordering comparison raises TypeError."""
        raise TypeError('PBags are not orderable')

    __le__ = __lt__
    __gt__ = __lt__
    __ge__ = __lt__

    # Multiset-style operations similar to collections.Counter

    def __add__(self, other):
        """
        Combine elements from two PBags.

        >>> pbag([1, 2, 2]) + pbag([2, 3, 3])
        pbag([1, 2, 2, 2, 3, 3])
        """
        if not isinstance(other, PBag):
            return NotImplemented
        result = self._counts.evolver()
        for elem, other_count in other._counts.iteritems():
            result[elem] = self.count(elem) + other_count
        return PBag(result.persistent())

    def __sub__(self, other):
        """
        Remove elements from one PBag that are present in another.

        >>> pbag([1, 2, 2, 2, 3]) - pbag([2, 3, 3, 4])
        pbag([1, 2, 2])
        """
        if not isinstance(other, PBag):
            return NotImplemented
        result = self._counts.evolver()
        for elem, other_count in other._counts.iteritems():
            newcount = self.count(elem) - other_count
            if newcount > 0:
                result[elem] = newcount
            elif elem in self:
                # Count dropped to zero or below: remove instead of storing it.
                result.remove(elem)
        return PBag(result.persistent())

    def __or__(self, other):
        """
        Union: Keep elements that are present in either of two PBags.

        Each element's count is the larger of its two counts.

        >>> pbag([1, 2, 2, 2]) | pbag([2, 3, 3])
        pbag([1, 2, 2, 2, 3, 3])
        """
        if not isinstance(other, PBag):
            return NotImplemented
        result = self._counts.evolver()
        for elem, other_count in other._counts.iteritems():
            count = self.count(elem)
            newcount = max(count, other_count)
            result[elem] = newcount
        return PBag(result.persistent())

    def __and__(self, other):
        """
        Intersection: Only keep elements that are present in both PBags.

        Each element's count is the smaller of its two counts.

        >>> pbag([1, 2, 2, 2]) & pbag([2, 3, 3])
        pbag([2])
        """
        if not isinstance(other, PBag):
            return NotImplemented
        result = pmap().evolver()
        for elem, count in self._counts.iteritems():
            newcount = min(count, other.count(elem))
            if newcount > 0:
                result[elem] = newcount
        return PBag(result.persistent())

    def __hash__(self):
        """
        Hash based on value of elements.

        >>> m = pmap({pbag([1, 2]): "it's here!"})
        >>> m[pbag([2, 1])]
        "it's here!"
        >>> pbag([1, 1, 2]) in m
        False
        """
        return hash(self._counts)
234
235
# Declare PBag a virtual subclass of the collection ABCs whose protocols it
# implements (__contains__, __iter__, __len__, __hash__).
for _abc_type in (Container, Iterable, Sized, Hashable):
    _abc_type.register(PBag)
del _abc_type
240
241
def b(*elements):
    """
    Build a persistent bag from positional arguments.

    ``b(x, y, ...)`` is shorthand for ``pbag((x, y, ...))``. Repeated
    arguments are kept and counted.

    >>> b(1, 2, 3, 2)
    pbag([1, 2, 2, 3])
    """
    # The varargs tuple is already an iterable, so pass it straight through.
    items = elements
    return pbag(items)
253
254
def pbag(elements):
    """
    Convert an iterable to a persistent bag.

    Takes an iterable with elements to insert. If it yields nothing (or is
    ``None``), the shared empty bag is returned.

    >>> pbag([1, 2, 3, 2])
    pbag([1, 2, 2, 3])
    >>> pbag(x for x in []) is pbag([])
    True
    """
    if elements is None:
        # Kept for backward compatibility: None has always meant "empty".
        return _EMPTY_PBAG
    # Build the counts first and test the result for emptiness, not
    # ``elements`` itself. An arbitrary iterable's truth value does not say
    # whether it yields anything: an empty generator is truthy, and a
    # one-element NumPy array holding 0 is falsy, so its element would be lost.
    counts = reduce(_add_to_counters, elements, pmap())
    if not counts:
        return _EMPTY_PBAG
    return PBag(counts)
267
268
# Shared empty bag: pbag() returns this singleton for empty input instead of
# allocating a fresh empty bag each time.
_EMPTY_PBAG = PBag(counts=pmap())
270