Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/bitstring/array_.py: 20%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3import math
4import numbers
5from collections.abc import Sized
6from bitstring.exceptions import CreationError
7from typing import Union, List, Iterable, Any, Optional, BinaryIO, overload, TextIO
8from bitstring.bits import Bits, BitsType
9from bitstring.bitarray_ import BitArray
10from bitstring.dtypes import Dtype, dtype_register
11from bitstring import utils
12from bitstring.bitstring_options import Options, Colour
13import copy
14import array
15import operator
16import io
17import sys
# Every Python type that a single Array element may be read as, or built from.
ElementType = Union[
    float,
    str,
    int,
    bytes,
    bool,
    Bits,
]

# Shared bitstring options object (Array.pp() reads options.no_color from it).
options = Options()
class Array:
    """Return an Array whose elements are initialised according to the dtype.

    The dtype can be a Dtype object, any fixed-length bitstring format string, or a
    typecode as used in the struct module.

        a = Array('>H', [1, 15, 105])
        b = Array('int5', [-9, 0, 4])

    The Array data is stored compactly as a BitArray object and the Array behaves very like
    a list of items of the given format. Both the Array data and dtype properties can be freely
    modified after creation. If the data length is not a multiple of the dtype's bit length then
    the Array will have 'trailing_bits' which will prevent some methods from appending to the
    Array.

    Methods:

    append() -- Append a single item to the end of the Array.
    astype() -- Return a new Array with the current values converted to another dtype.
    byteswap() -- Change byte endianness of all items.
    count() -- Count the number of occurrences of a value.
    equals() -- Return True if format and all items equal those of an Array or array.array.
    extend() -- Append new items to the end of the Array from an iterable.
    fromfile() -- Append items read from a file object.
    insert() -- Insert an item at a given position.
    pop() -- Remove and return an item.
    pp() -- Pretty print the Array.
    reverse() -- Reverse the order of all items.
    tobytes() -- Return Array data as bytes object, padding with zero bits at the end if needed.
    tofile() -- Write Array data to a file, padding with zero bits at the end if needed.
    tolist() -- Return Array items as a list.

    Special methods:

    Also available are the operators [], ==, !=, <, <=, >, >=, +, -, *, /, //, %, <<, >>,
    &, |, ^, unary - and abs(), plus the mutating operators [], +=, -=, *=, /=, //=, %=,
    <<=, >>=, &=, |=, ^=. Arithmetic and comparison operators act element-wise, with either
    a scalar or another Array of the same length (+ does not concatenate - use extend()).
    == and != return an Array of bools; use equals() for a single boolean result.

    Properties:

    data -- The BitArray binary data of the Array. Can be freely modified.
    dtype -- The Dtype of the elements. Can be freely modified.
    itemsize -- The length *in bits* of a single item. Read only.
    trailing_bits -- If the data length is not a multiple of the item length, this BitArray
                     gives the leftovers at the end of the data.

    """

    def __init__(self, dtype: Union[str, Dtype], initializer: Optional[Union[int, Array, array.array, Iterable, Bits, bytes, bytearray, memoryview, BinaryIO]] = None,
                 trailing_bits: Optional[BitsType] = None) -> None:
        """Create an Array of the given dtype, optionally populated from initializer.

        dtype -- A Dtype or format string describing a fixed-length format.
        initializer -- One of:
            an integer: create that many elements (all bits zero),
            Bits, bytes, bytearray or memoryview: used as the raw binary data,
            an opened binary file (io.BufferedReader): its data is read via fromfile(),
            any other iterable (including Array and array.array): appended via extend().
        trailing_bits -- Extra bits appended after the initial data.

        If dtype is a Dtype whose scale is 'auto', a scale is calculated from the values
        in initializer, which must then be an iterable of values.

        Raises CreationError if the dtype is unsuitable for an Array, and TypeError if an
        'auto' scale is requested with an initializer that isn't an iterable of values.
        """
        self.data = BitArray()
        if isinstance(dtype, Dtype) and dtype.scale == 'auto':
            # A scale can only be derived from actual values, not from a count, raw data or a file.
            if isinstance(initializer, (numbers.Integral, Bits, bytes, bytearray, memoryview, io.BufferedReader, BinaryIO)):
                raise TypeError("An Array with an 'auto' scale factor can only be created from an iterable of values.")
            auto_scale = self._calculate_auto_scale(initializer, dtype.name, dtype.length)
            dtype = Dtype(dtype.name, dtype.length, scale=auto_scale)
        try:
            self._set_dtype(dtype)
        except ValueError as e:
            raise CreationError(e) from e

        if isinstance(initializer, numbers.Integral):
            self.data = BitArray(initializer * self._dtype.bitlength)
        elif isinstance(initializer, (Bits, bytes, bytearray, memoryview)):
            # Raw binary data is used as-is rather than being interpreted as values.
            self.data += initializer
        elif isinstance(initializer, io.BufferedReader):
            self.fromfile(initializer)
        elif initializer is not None:
            self.extend(initializer)

        if trailing_bits is not None:
            self.data += BitArray._create_from_bitstype(trailing_bits)

    # Lazily built cache mapping '<name><length>' to the largest value that format can
    # represent; only formats listed here support an 'auto' scale.
    _largest_values = None

    @staticmethod
    def _calculate_auto_scale(initializer, name: str, length: Optional[int]) -> float:
        """Return a power-of-two scale so the largest input magnitude fits the format.

        The scale is 2**(floor(log2(max |x|)) - floor(log2(format max))), saturated to
        the range 2**-127 .. 2**127 (the values representable in E8M0 format).
        NaN and infinite inputs are ignored when finding the largest magnitude. If no
        finite non-zero value exists, 1.0 is returned.

        Raises ValueError if the format doesn't support 'auto' scaling or if
        initializer yields no values.
        """
        if Array._largest_values is None:
            Array._largest_values = {
                'mxint8': Bits('0b01111111').mxint8,  # 1.0 + 63.0/64.0,
                'e2m1mxfp4': Bits('0b0111').e2m1mxfp4,  # 6.0
                'e2m3mxfp6': Bits('0b011111').e2m3mxfp6,  # 7.5
                'e3m2mxfp6': Bits('0b011111').e3m2mxfp6,  # 28.0
                'e4m3mxfp8': Bits('0b01111110').e4m3mxfp8,  # 448.0
                'e5m2mxfp8': Bits('0b01111011').e5m2mxfp8,  # 57344.0
                'p4binary8': Bits('0b01111110').p4binary8,  # 224.0
                'p3binary8': Bits('0b01111110').p3binary8,  # 49152.0
                'float16': Bits('0x7bff').float16,  # 65504.0
                # The bfloat range is so large the scaling algorithm doesn't work well, so I'm disallowing it.
                # 'bfloat16': Bits('0x7f7f').bfloat16,  # 3.38953139e38,
            }
        format_key = f'{name}{length}'
        if format_key not in Array._largest_values:
            raise ValueError(f"Can't calculate auto scale for format '{name}{length}'. "
                             f"This feature is only available for these formats: {list(Array._largest_values.keys())}.")
        float_values = Array('float64', initializer).tolist()
        if not float_values:
            raise ValueError("Can't calculate an 'auto' scale with an empty Array initializer.")
        # Non-finite values are skipped: inf would make log2 overflow, and NaN would make the
        # result depend on where it appears in the input (max() never replaces a leading NaN).
        max_float_value = max((abs(x) for x in float_values if math.isfinite(x)), default=0.0)
        if max_float_value == 0:
            # This special case isn't covered in the standard. I'm choosing to return no scale.
            return 1.0
        # We need to find the largest power of 2 that is less than the max value
        log2 = math.floor(math.log2(max_float_value))
        lp2 = math.floor(math.log2(Array._largest_values[format_key]))
        # Saturate at values representable in E8M0 format.
        lg_scale = max(-127, min(127, log2 - lp2))
        return 2 ** lg_scale

    @property
    def itemsize(self) -> int:
        """The length in bits of a single item. Read only."""
        return self._dtype.bitlength

    @property
    def trailing_bits(self) -> BitArray:
        """The bits at the end of the data that don't form a whole item (may be empty)."""
        trailing_bit_length = len(self.data) % self._dtype.bitlength
        return BitArray() if trailing_bit_length == 0 else self.data[-trailing_bit_length:]

    @property
    def dtype(self) -> Dtype:
        """The Dtype used to interpret each item. Assigning a str or Dtype reinterprets the data."""
        return self._dtype

    @dtype.setter
    def dtype(self, new_dtype: Union[str, Dtype]) -> None:
        self._set_dtype(new_dtype)

    def _set_dtype(self, new_dtype: Union[str, Dtype]) -> None:
        """Validate new_dtype and make it the Array's dtype.

        The current dtype is only replaced after validation passes, so a rejected value
        leaves the Array unchanged.

        Raises ValueError if new_dtype can't be parsed, isn't fixed length, or has an
        'auto' scale.
        """
        if isinstance(new_dtype, Dtype):
            dtype = new_dtype
        else:
            try:
                dtype = Dtype(new_dtype)
            except ValueError:
                # Fall back to a struct-module style token such as '>H'.
                name_length = utils.parse_single_struct_token(new_dtype)
                if name_length is None:
                    raise ValueError(f"Inappropriate Dtype for Array: '{new_dtype}'.")
                dtype = Dtype(name_length[0], name_length[1])
        if dtype.length is None:
            raise ValueError(f"A fixed length format is needed for an Array, received '{new_dtype}'.")
        if dtype.scale == 'auto':
            raise ValueError("A Dtype with an 'auto' scale factor can only be used when creating a new Array.")
        self._dtype = dtype

    def _create_element(self, value: ElementType) -> Bits:
        """Create Bits from value according to the Array's dtype.

        Raises ValueError if the built bits aren't exactly one item long.
        """
        b = self._dtype.build(value)
        if len(b) != self._dtype.bitlength:
            raise ValueError(f"The value {value!r} has the wrong length for the format '{self._dtype}'.")
        return b

    def __len__(self) -> int:
        """Number of whole items; trailing bits are not counted."""
        return len(self.data) // self._dtype.bitlength

    @overload
    def __getitem__(self, key: slice) -> Array:
        ...

    @overload
    def __getitem__(self, key: int) -> ElementType:
        ...

    def __getitem__(self, key: Union[slice, int]) -> Union[Array, ElementType]:
        """Return one item for an integer key, or a new Array of the same dtype for a slice."""
        item_bits = self._dtype.bitlength
        if isinstance(key, slice):
            start, stop, step = key.indices(len(self))
            a = self.__class__(self._dtype)
            if step == 1:
                a.data = self.data[start * item_bits: stop * item_bits]
            else:
                d = BitArray()
                for s in range(start * item_bits, stop * item_bits, step * item_bits):
                    d.append(self.data[s: s + item_bits])
                a.data = d
            return a
        if key < 0:
            key += len(self)
        if key < 0 or key >= len(self):
            raise IndexError(f"Index {key} out of range for Array of length {len(self)}.")
        return self._dtype.read_fn(self.data, start=item_bits * key)

    @overload
    def __setitem__(self, key: slice, value: Iterable[ElementType]) -> None:
        ...

    @overload
    def __setitem__(self, key: int, value: ElementType) -> None:
        ...

    def __setitem__(self, key: Union[slice, int], value: Union[Iterable[ElementType], ElementType]) -> None:
        """Set one item, or replace a slice with the items from an iterable.

        Every new value is encoded before the data is touched, so an unencodable value
        leaves the Array unchanged. For extended slices (step != 1) the number of values
        must equal the length of the slice, otherwise ValueError is raised.
        """
        item_bits = self._dtype.bitlength
        if isinstance(key, slice):
            start, stop, step = key.indices(len(self))
            if not isinstance(value, Iterable):
                raise TypeError("Can only assign an iterable to a slice.")
            if step == 1:
                new_data = BitArray()
                for x in value:
                    new_data += self._create_element(x)
                self.data[start * item_bits: stop * item_bits] = new_data
                return
            items_in_slice = len(range(start, stop, step))
            if not isinstance(value, Sized):
                value = list(value)
            if len(value) != items_in_slice:
                raise ValueError(f"Can't assign {len(value)} values to an extended slice of length {items_in_slice}.")
            elements = [self._create_element(v) for v in value]
            for s, element in zip(range(start, stop, step), elements):
                self.data.overwrite(element, s * item_bits)
            return
        if key < 0:
            key += len(self)
        if key < 0 or key >= len(self):
            raise IndexError(f"Index {key} out of range for Array of length {len(self)}.")
        self.data.overwrite(self._create_element(value), item_bits * key)

    def __delitem__(self, key: Union[slice, int]) -> None:
        """Delete one item or a slice of items."""
        item_bits = self._dtype.bitlength
        if isinstance(key, slice):
            start, stop, step = key.indices(len(self))
            if step == 1:
                del self.data[start * item_bits: stop * item_bits]
                return
            # Delete from the highest index downwards or the earlier positions would shift.
            positions = range(start, stop, step)
            for s in (reversed(positions) if step > 0 else positions):
                del self.data[s * item_bits: (s + 1) * item_bits]
            return
        if key < 0:
            key += len(self)
        if key < 0 or key >= len(self):
            raise IndexError(f"Index {key} out of range for Array of length {len(self)}.")
        start = item_bits * key
        del self.data[start: start + item_bits]

    def __repr__(self) -> str:
        list_str = f"{self.tolist()}"
        trailing_bit_length = len(self.data) % self._dtype.bitlength
        final_str = "" if trailing_bit_length == 0 else ", trailing_bits=" + repr(
            self.data[-trailing_bit_length:])
        return f"Array('{self._dtype}', {list_str}{final_str})"

    def astype(self, dtype: Union[str, Dtype]) -> Array:
        """Return Array with elements of new dtype, initialised from current Array."""
        return self.__class__(dtype, self.tolist())

    def tolist(self) -> List[ElementType]:
        """Return the items as a list. Trailing bits are not included."""
        item_bits = self._dtype.bitlength
        return [self._dtype.read_fn(self.data, start=i * item_bits) for i in range(len(self))]

    def append(self, x: ElementType) -> None:
        """Append x as a new item. Raises ValueError if the Array has trailing bits."""
        if len(self.data) % self._dtype.bitlength != 0:
            raise ValueError("Cannot append to Array as its length is not a multiple of the format length.")
        self.data += self._create_element(x)

    def extend(self, iterable: Union[Array, array.array, Iterable[Any]]) -> None:
        """Append every item from iterable.

        An Array must have the same dtype name and length. If only its scale differs,
        its values are re-encoded; otherwise its raw data is appended directly.
        An array.array must match the dtype of its native-endian typecode.
        Other iterables are encoded item by item. If any item fails, nothing is appended.

        Raises ValueError if this Array has trailing bits or the array.array typecode is
        unsuitable; TypeError for a mismatched Array format or a str.
        """
        item_bits = self._dtype.bitlength
        if len(self.data) % item_bits != 0:
            raise ValueError(f"Cannot extend Array as its data length ({len(self.data)} bits) is not a multiple of the format length ({item_bits} bits).")
        if isinstance(iterable, Array):
            if self._dtype.name != iterable._dtype.name or self._dtype.length != iterable._dtype.length:
                raise TypeError(
                    f"Cannot extend an Array with format '{self._dtype}' from an Array of format '{iterable._dtype}'.")
            if self._dtype.scale == iterable._dtype.scale:
                # Identical encodings: no need to iterate over the elements, just append the data.
                self.data.append(iterable.data)
                return
            # Same format but a different scale: the same bits mean different values,
            # so the values must be re-encoded with this Array's scale.
            values = iterable.tolist()
        elif isinstance(iterable, array.array):
            # array.array types are always native-endian, hence the '='
            name_value = utils.parse_single_struct_token('=' + iterable.typecode)
            if name_value is None:
                raise ValueError(f"Cannot extend from array with typecode {iterable.typecode}.")
            other_dtype = dtype_register.get_dtype(*name_value, scale=None)
            if self._dtype.name != other_dtype.name or self._dtype.length != other_dtype.length:
                raise ValueError(
                    f"Cannot extend an Array with format '{self._dtype}' from an array with typecode '{iterable.typecode}'.")
            self.data += iterable.tobytes()
            return
        else:
            if isinstance(iterable, str):
                raise TypeError("Can't extend an Array with a str.")
            values = iterable
        new_data = BitArray()
        for item in values:
            new_data.append(self._create_element(item))
        self.data.append(new_data)

    def insert(self, i: int, x: ElementType) -> None:
        """Insert a new element into the Array at position i.

        Follows list.insert(): negative indices count from the end, and out-of-range
        indices insert at the start or end. Trailing bits stay at the end.
        """
        n = len(self)
        if i < 0:
            i = max(i + n, 0)
        i = min(i, n)
        self.data.insert(self._create_element(x), i * self._dtype.bitlength)

    def pop(self, i: int = -1) -> ElementType:
        """Return and remove an element of the Array.

        Default is to return and remove the final element.
        Raises IndexError if the Array is empty or i is out of range.
        """
        if len(self) == 0:
            raise IndexError("Can't pop from an empty Array.")
        x = self[i]
        del self[i]
        return x

    def byteswap(self) -> None:
        """Change the endianness in-place of all items in the Array.

        If the Array format is not a whole number of bytes a ValueError will be raised.
        """
        if self._dtype.bitlength % 8 != 0:
            raise ValueError(
                f"byteswap can only be used for whole-byte elements. The '{self._dtype}' format is {self._dtype.bitlength} bits long.")
        self.data.byteswap(self.itemsize // 8)

    def count(self, value: ElementType) -> int:
        """Return count of Array items that equal value.

        value -- The quantity to compare each Array element to. Type should be appropriate for the Array format.

        For floating point types using a value of float('nan') will count the number of elements that are NaN.
        """
        try:
            value_is_nan = math.isnan(value)
        except (TypeError, OverflowError):
            # Non-numeric values (str, bytes, Bits...) and ints too large for a float can't be NaN.
            value_is_nan = False
        if value_is_nan:
            return sum(math.isnan(i) for i in self)
        return sum(i == value for i in self)

    def tobytes(self) -> bytes:
        """Return the Array data as a bytes object, padding with zero bits if needed.

        Up to seven zero bits will be added at the end to byte align.
        """
        return self.data.tobytes()

    def tofile(self, f: BinaryIO) -> None:
        """Write the Array data to a file object, padding with zero bits if needed.

        Up to seven zero bits will be added at the end to byte align.
        """
        self.data.tofile(f)

    def fromfile(self, f: BinaryIO, n: Optional[int] = None) -> None:
        """Append whole items read from the file object f.

        n -- Maximum number of items to append; all whole items are appended if None.

        Raises ValueError if the Array has trailing bits. Raises EOFError (after
        appending what was available) if fewer than n whole items could be read.
        """
        item_bits = self._dtype.bitlength
        if len(self.data) % item_bits != 0:
            raise ValueError(f"Cannot extend Array as its data length ({len(self.data)} bits) is not a multiple of the format length ({item_bits} bits).")
        new_data = Bits(f)
        max_items = len(new_data) // item_bits
        items_to_append = max_items if n is None else min(n, max_items)
        self.data += new_data[0: items_to_append * item_bits]
        if n is not None and items_to_append < n:
            raise EOFError(f"Only {items_to_append} were appended, not the {n} items requested.")

    def reverse(self) -> None:
        """Reverse the order of the items in place.

        Raises ValueError if the Array has trailing bits.
        """
        item_bits = self._dtype.bitlength
        if len(self.data) % item_bits != 0:
            raise ValueError(f"Cannot reverse the items in the Array as its data length ({len(self.data)} bits) is not a multiple of the format length ({item_bits} bits).")
        n = len(self)
        for i in range(n // 2):
            lo = i * item_bits
            hi = (n - 1 - i) * item_bits
            temp = self.data[lo: lo + item_bits]
            self.data[lo: lo + item_bits] = self.data[hi: hi + item_bits]
            self.data[hi: hi + item_bits] = temp

    def pp(self, fmt: Optional[str] = None, width: int = 120,
           show_offset: bool = True, stream: Optional[TextIO] = None) -> None:
        """Pretty-print the Array contents.

        fmt -- Data format string. Defaults to current Array dtype.
        width -- Max width of printed lines in characters. Defaults to 120. A single group will always
                 be printed per line even if it exceeds the max width.
        show_offset -- If True shows the element offset in the first column of each line.
        stream -- A TextIO object with a write() method. Defaults to sys.stdout, looked up
                  when pp() is called.

        Raises ValueError if fmt has more than two tokens or two tokens with different lengths.
        """
        if stream is None:
            # Resolved at call time rather than as a default argument so that later
            # redirections of sys.stdout (e.g. contextlib.redirect_stdout) are honoured.
            stream = sys.stdout
        colour = Colour(not options.no_color)
        sep = ' '
        dtype2 = None
        tidy_fmt = None
        if fmt is None:
            dtype1 = self.dtype
            tidy_fmt = "dtype='" + colour.purple + str(self.dtype) + "'" + colour.off
        else:
            token_list = utils.preprocess_tokens(fmt)
            if len(token_list) not in [1, 2]:
                raise ValueError(f"Only one or two tokens can be used in an Array.pp() format - '{fmt}' has {len(token_list)} tokens.")
            name1, length1 = utils.parse_name_length_token(token_list[0])
            dtype1 = Dtype(name1, length1)
            if len(token_list) == 2:
                name2, length2 = utils.parse_name_length_token(token_list[1])
                dtype2 = Dtype(name2, length2)

        token_length = dtype1.bitlength
        if dtype2 is not None:
            # For two types we're OK as long as they don't have different lengths given.
            if dtype1.bitlength is not None and dtype2.bitlength is not None and dtype1.bitlength != dtype2.bitlength:
                raise ValueError(f"Two different format lengths specified ('{fmt}'). Either specify just one, or two the same length.")
            if token_length is None:
                token_length = dtype2.bitlength
        if token_length is None:
            token_length = self.itemsize

        trailing_bit_length = len(self.data) % token_length
        format_sep = " : "  # String to insert on each line between multiple formats
        if tidy_fmt is None:
            tidy_fmt = colour.purple + str(dtype1) + colour.off
            if dtype2 is not None:
                tidy_fmt += ', ' + colour.blue + str(dtype2) + colour.off
            tidy_fmt = "fmt='" + tidy_fmt + "'"
        data = self.data if trailing_bit_length == 0 else self.data[0: -trailing_bit_length]
        length = len(self.data) // token_length
        len_str = colour.green + str(length) + colour.off
        stream.write(f"<{self.__class__.__name__} {tidy_fmt}, length={len_str}, itemsize={token_length} bits, total data size={(len(self.data) + 7) // 8} bytes> [\n")
        data._pp(dtype1, dtype2, token_length, width, sep, format_sep, show_offset, stream, False, token_length)
        stream.write("]")
        if trailing_bit_length != 0:
            stream.write(" + trailing_bits = " + str(self.data[-trailing_bit_length:]))
        stream.write("\n")

    def equals(self, other: Any) -> bool:
        """Return True if format and all Array items are equal.

        For another Array the dtype name, dtype length and raw data (including trailing
        bits) must match. For an array.array this Array must have no trailing bits, and
        the item size in bits, the length and the item values must match.
        Anything else compares unequal.
        """
        if isinstance(other, Array):
            if self._dtype.length != other._dtype.length:
                return False
            if self._dtype.name != other._dtype.name:
                return False
            if self.data != other.data:
                return False
            return True
        elif isinstance(other, array.array):
            # Assume we are comparing with an array type
            if self.trailing_bits:
                return False
            # array's itemsize is in bytes, not bits.
            if self.itemsize != other.itemsize * 8:
                return False
            if len(self) != len(other):
                return False
            if self.tolist() != other.tolist():
                return False
            return True
        return False

    def __iter__(self) -> Iterable[ElementType]:
        """Yield each whole item in order; trailing bits are skipped."""
        item_bits = self._dtype.bitlength
        for start in range(0, len(self) * item_bits, item_bits):
            yield self._dtype.read_fn(self.data, start=start)

    def __copy__(self) -> Array:
        a_copy = self.__class__(self._dtype)
        a_copy.data = copy.copy(self.data)
        return a_copy

    @staticmethod
    def _build_from_operands(target: Array, op, operands, description: str) -> BitArray:
        """Apply op to each tuple in operands and encode each result with target's dtype.

        Per-element failures are collected rather than raised immediately, so the single
        ValueError raised at the end reports how many elements failed and the first error.
        description completes the message, e.g. 'to Array' or 'between Arrays'.
        """
        new_data = BitArray()
        failures = index = 0
        msg = ''
        for i, args in enumerate(operands):
            try:
                new_data.append(target._create_element(op(*args)))
            except (CreationError, ZeroDivisionError, ValueError) as e:
                if failures == 0:
                    msg = str(e)
                    index = i
                failures += 1
        if failures != 0:
            raise ValueError(f"Applying operator '{op.__name__}' {description} caused {failures} errors. "
                             f'First error at index {index} was: "{msg}"')
        return new_data

    def _apply_op_to_all_elements(self, op, value: Union[int, float, None], is_comparison: bool = False) -> Array:
        """Apply op with value to each element of the Array and return a new Array.

        If value is None, op is treated as unary. Comparisons produce a 'bool' Array.
        Trailing bits are not carried over to the result.
        """
        new_array = self.__class__('bool' if is_comparison else self._dtype)
        item_bits = self._dtype.bitlength
        read = self._dtype.read_fn
        if value is None:
            operands = ((read(self.data, start=item_bits * i),) for i in range(len(self)))
        else:
            operands = ((read(self.data, start=item_bits * i), value) for i in range(len(self)))
        new_array.data = self._build_from_operands(new_array, op, operands, 'to Array')
        return new_array

    def _apply_op_to_all_elements_inplace(self, op, value: Union[int, float]) -> Array:
        """Apply op with value to each element of the Array in place and return self.

        The data is only replaced if every element succeeds; trailing bits are dropped.
        """
        # This isn't really being done in-place, but it's simpler and faster for now?
        item_bits = self._dtype.bitlength
        read = self._dtype.read_fn
        operands = ((read(self.data, start=item_bits * i), value) for i in range(len(self)))
        self.data = self._build_from_operands(self, op, operands, 'to Array')
        return self

    def _apply_bitwise_op_to_all_elements(self, op, value: BitsType) -> Array:
        """Apply op with value to each element of the Array as an unsigned integer and return a new Array"""
        a_copy = self[:]
        a_copy._apply_bitwise_op_to_all_elements_inplace(op, value)
        return a_copy

    def _apply_bitwise_op_to_all_elements_inplace(self, op, value: BitsType) -> Array:
        """Apply op with value to each element's raw bits in place and return self.

        Raises ValueError unless value is exactly one item long.
        """
        item_bits = self._dtype.bitlength
        value = BitArray._create_from_bitstype(value)
        if len(value) != item_bits:
            raise ValueError(f"Bitwise op needs a bitstring of length {item_bits} to match format {self._dtype}.")
        for start in range(0, len(self) * item_bits, item_bits):
            self.data[start: start + item_bits] = op(self.data[start: start + item_bits], value)
        return self

    def _apply_op_between_arrays(self, op, other: Array, is_comparison: bool = False) -> Array:
        """Apply op element-wise between two equal-length Arrays and return a new Array.

        The result dtype is 'bool' for comparisons, otherwise chosen by _promotetype().
        Raises ValueError for different lengths or if any element fails.
        """
        if len(self) != len(other):
            msg = f"Cannot operate element-wise on Arrays with different lengths ({len(self)} and {len(other)})."
            if op in [operator.add, operator.iadd]:
                msg += " Use extend() method to concatenate Arrays."
            if op in [operator.eq, operator.ne]:
                msg += " Use equals() method to compare Arrays for a single boolean result."
            raise ValueError(msg)
        if is_comparison:
            new_type = dtype_register.get_dtype('bool', 1)
        else:
            new_type = self._promotetype(self._dtype, other._dtype)
        new_array = self.__class__(new_type)
        self_bits = self._dtype.bitlength
        other_bits = other._dtype.bitlength
        operands = ((self._dtype.read_fn(self.data, start=self_bits * i),
                     other._dtype.read_fn(other.data, start=other_bits * i))
                    for i in range(len(self)))
        new_array.data = self._build_from_operands(new_array, op, operands, 'between Arrays')
        return new_array

    @classmethod
    def _promotetype(cls, type1: Dtype, type2: Dtype) -> Dtype:
        """When combining types which one wins?

        1. We only deal with types representing floats or integers.
        2. One of the two types gets returned. We never create a new one.
        3. Floating point types always win against integer types.
        4. Signed integer types always win against unsigned integer types.
        5. Longer types win against shorter types.
        6. In a tie the first type wins against the second type.

        Raises ValueError if either type is not a float or integer (or bool) type.
        """
        def is_float(x): return x.return_type is float
        def is_int(x): return x.return_type is int or x.return_type is bool
        if not (is_float(type1) or is_int(type1)) or not (is_float(type2) or is_int(type2)):
            raise ValueError(f"Only integer and floating point types can be combined - not '{type1}' and '{type2}'.")
        # If same type choose the widest (the first on a tie, per rule 6)
        if type1.name == type2.name:
            return type2 if type2.length > type1.length else type1
        # We choose floats above integers, irrespective of the widths
        if is_float(type1) and is_int(type2):
            return type1
        if is_int(type1) and is_float(type2):
            return type2
        if is_float(type1) and is_float(type2):
            return type2 if type2.length > type1.length else type1
        assert is_int(type1) and is_int(type2)
        if type1.is_signed and not type2.is_signed:
            return type1
        if type2.is_signed and not type1.is_signed:
            return type2
        return type2 if type2.length > type1.length else type1

    # Operators between Arrays or an Array and scalar value.
    # Note: the in-place forms with an Array operand return a new (type-promoted) Array.

    def __add__(self, other: Union[int, float, Array]) -> Array:
        """Add int or float to all elements."""
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.add, other)
        return self._apply_op_to_all_elements(operator.add, other)

    def __iadd__(self, other: Union[int, float, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.add, other)
        return self._apply_op_to_all_elements_inplace(operator.add, other)

    def __isub__(self, other: Union[int, float, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.sub, other)
        return self._apply_op_to_all_elements_inplace(operator.sub, other)

    def __sub__(self, other: Union[int, float, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.sub, other)
        return self._apply_op_to_all_elements(operator.sub, other)

    def __mul__(self, other: Union[int, float, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.mul, other)
        return self._apply_op_to_all_elements(operator.mul, other)

    def __imul__(self, other: Union[int, float, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.mul, other)
        return self._apply_op_to_all_elements_inplace(operator.mul, other)

    def __floordiv__(self, other: Union[int, float, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.floordiv, other)
        return self._apply_op_to_all_elements(operator.floordiv, other)

    def __ifloordiv__(self, other: Union[int, float, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.floordiv, other)
        return self._apply_op_to_all_elements_inplace(operator.floordiv, other)

    def __truediv__(self, other: Union[int, float, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.truediv, other)
        return self._apply_op_to_all_elements(operator.truediv, other)

    def __itruediv__(self, other: Union[int, float, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.truediv, other)
        return self._apply_op_to_all_elements_inplace(operator.truediv, other)

    def __rshift__(self, other: Union[int, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.rshift, other)
        return self._apply_op_to_all_elements(operator.rshift, other)

    def __lshift__(self, other: Union[int, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.lshift, other)
        return self._apply_op_to_all_elements(operator.lshift, other)

    def __irshift__(self, other: Union[int, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.rshift, other)
        return self._apply_op_to_all_elements_inplace(operator.rshift, other)

    def __ilshift__(self, other: Union[int, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.lshift, other)
        return self._apply_op_to_all_elements_inplace(operator.lshift, other)

    def __mod__(self, other: Union[int, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.mod, other)
        return self._apply_op_to_all_elements(operator.mod, other)

    def __imod__(self, other: Union[int, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.mod, other)
        return self._apply_op_to_all_elements_inplace(operator.mod, other)

    # Bitwise operators (applied to each element's raw bits)

    def __and__(self, other: BitsType) -> Array:
        return self._apply_bitwise_op_to_all_elements(operator.iand, other)

    def __iand__(self, other: BitsType) -> Array:
        return self._apply_bitwise_op_to_all_elements_inplace(operator.iand, other)

    def __or__(self, other: BitsType) -> Array:
        return self._apply_bitwise_op_to_all_elements(operator.ior, other)

    def __ior__(self, other: BitsType) -> Array:
        return self._apply_bitwise_op_to_all_elements_inplace(operator.ior, other)

    def __xor__(self, other: BitsType) -> Array:
        return self._apply_bitwise_op_to_all_elements(operator.ixor, other)

    def __ixor__(self, other: BitsType) -> Array:
        return self._apply_bitwise_op_to_all_elements_inplace(operator.ixor, other)

    # Reverse operators between a scalar value and an Array

    def __rmul__(self, other: Union[int, float]) -> Array:
        return self._apply_op_to_all_elements(operator.mul, other)

    def __radd__(self, other: Union[int, float]) -> Array:
        return self._apply_op_to_all_elements(operator.add, other)

    def __rsub__(self, other: Union[int, float]) -> Array:
        # Compute other - element directly. The previous form, (-A) + other, failed for
        # unsigned formats, where a negated non-zero element can't be encoded even when
        # the final difference can.
        def sub(element, scalar):
            return scalar - element
        return self._apply_op_to_all_elements(sub, other)

    # Reverse operators between a scalar and something that can be a BitArray.

    def __rand__(self, other: BitsType) -> Array:
        return self._apply_bitwise_op_to_all_elements(operator.iand, other)

    def __ror__(self, other: BitsType) -> Array:
        return self._apply_bitwise_op_to_all_elements(operator.ior, other)

    def __rxor__(self, other: BitsType) -> Array:
        return self._apply_bitwise_op_to_all_elements(operator.ixor, other)

    # Comparison operators (element-wise, returning a 'bool' Array)

    def __lt__(self, other: Union[int, float, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.lt, other, is_comparison=True)
        return self._apply_op_to_all_elements(operator.lt, other, is_comparison=True)

    def __gt__(self, other: Union[int, float, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.gt, other, is_comparison=True)
        return self._apply_op_to_all_elements(operator.gt, other, is_comparison=True)

    def __ge__(self, other: Union[int, float, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.ge, other, is_comparison=True)
        return self._apply_op_to_all_elements(operator.ge, other, is_comparison=True)

    def __le__(self, other: Union[int, float, Array]) -> Array:
        if isinstance(other, Array):
            return self._apply_op_between_arrays(operator.le, other, is_comparison=True)
        return self._apply_op_to_all_elements(operator.le, other, is_comparison=True)

    def _eq_ne(self, op, other: Any) -> Array:
        """Shared implementation of == and !=.

        Scalars (int, float, str, Bits) are compared with every element. Anything else is
        first converted to an Array of this dtype and compared element-wise.
        """
        if isinstance(other, (int, float, str, Bits)):
            return self._apply_op_to_all_elements(op, other, is_comparison=True)
        other = self.__class__(self.dtype, other)
        return self._apply_op_between_arrays(op, other, is_comparison=True)

    def __eq__(self, other: Any) -> Array:
        return self._eq_ne(operator.eq, other)

    def __ne__(self, other: Any) -> Array:
        return self._eq_ne(operator.ne, other)

    # Unary operators

    def __neg__(self):
        return self._apply_op_to_all_elements(operator.neg, None)

    def __abs__(self):
        return self._apply_op_to_all_elements(operator.abs, None)