Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/numpy_helper.py: 9%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright (c) ONNX Project Contributors
2#
3# SPDX-License-Identifier: Apache-2.0
4from __future__ import annotations
6import math
7import sys
8from typing import TYPE_CHECKING, Any
10import ml_dtypes
11import numpy as np
12import numpy.typing as npt
14import onnx.external_data_helper
15from onnx import helper
17if TYPE_CHECKING:
18 from collections.abc import Sequence
def to_float8e8m0(
    x: np.ndarray,
    saturate: bool = True,
    round_mode: str = "up",
) -> np.ndarray:
    """Convert float32 NumPy array to float8e8m0 representation. If the input
    is not a float32 array, it will be cast to one first.

    Only the 8-bit float32 exponent field survives: the sign bit (bit 31) is
    discarded by the shift/mask, and the 23 mantissa bits are used only to
    decide whether the exponent is rounded up by one.

    Args:
        x: Input array to convert.
        saturate: Whether to saturate at the max float8e8m0 value. When True,
            a finite input whose exponent field is 0xFE is never rounded up to
            0xFF (the code that NaN/Inf inputs are mapped to below).
        round_mode: "nearest", "up", or "down".

    Returns:
        np.ndarray: Array of ml_dtypes.float8_e8m0fnu values, same shape as x.

    Raises:
        ValueError: If ``round_mode`` is not "nearest", "up" or "down".
    """
    x_f32 = np.asarray(x, dtype=np.float32)
    # Reinterpret the float32 bit patterns (no value conversion).
    f_bits = x_f32.view(np.uint32)

    # Extract exponent bits (bits 23..30); the sign bit is shifted past the mask.
    exponent = (f_bits >> 23) & 0xFF
    exponent = exponent.astype(
        np.uint16
    )  # use uint16 to prevent overflow during computation

    # Identify NaN or Inf (all-ones exponent field)
    special_mask = exponent == 0xFF  # noqa: PLR2004
    output = np.zeros_like(exponent, dtype=np.uint8)
    output[special_mask] = 0xFF  # Preserve NaN/Inf as max exponent

    # Everything that is not NaN/Inf, including zeros and float32 subnormals
    # (whose exponent field is 0).
    normal_mask = ~special_mask

    if round_mode == "nearest":
        # Get guard, round, sticky, and least significant bits
        # g: top mantissa bit (bit 22); r: bit 21; s: any of bits 0..20.
        g = ((f_bits & 0x400000) > 0).astype(np.uint8)
        r = ((f_bits & 0x200000) > 0).astype(np.uint8)
        s = ((f_bits & 0x1FFFFF) > 0).astype(np.uint8)
        # Tie-break bit: an exact half (g=1, r=s=0) rounds up whenever the
        # exponent field is non-zero.
        lsb = (exponent > 0).astype(np.uint8)

        round_up = (g == 1) & ((r == 1) | (s == 1) | (lsb == 1))

        increment = np.zeros_like(exponent)
        # NaN/Inf entries are never incremented.
        increment[round_up & normal_mask] = 1

        if saturate:
            max_mask = (exponent == 0xFE) & round_up & normal_mask  # noqa: PLR2004
            increment[max_mask] = 0  # Don't overflow past max value

        exponent += increment

    elif round_mode == "up":
        # Any non-zero mantissa bit bumps the exponent (ceiling to a power of two).
        has_fraction = (f_bits & 0x7FFFFF) > 0
        round_up = has_fraction & normal_mask

        if saturate:
            max_mask = (exponent == 0xFE) & round_up  # noqa: PLR2004
            round_up[max_mask] = False

        exponent += round_up.astype(np.uint16)

    elif round_mode == "down":
        pass  # No rounding needed: the mantissa is simply truncated

    else:
        raise ValueError(f"Unsupported rounding mode: {round_mode}")

    # Narrow back to uint8. Nothing wraps here: only entries with an exponent
    # field of at most 0xFE were incremented, so every value is <= 0xFF.
    exponent = exponent.astype(np.uint8)

    # NaN/Inf slots keep the 0xFF written above.
    output[normal_mask] = exponent[normal_mask]

    return output.view(ml_dtypes.float8_e8m0fnu)
def _unpack_4bit(
    data: npt.NDArray[np.uint8], dims: Sequence[int]
) -> npt.NDArray[np.uint8]:
    """Split each packed byte into two 4-bit values stored one per uint8.

    The low nibble of byte ``i`` becomes element ``2*i`` and the high nibble
    becomes element ``2*i + 1``. Trailing unpacked elements beyond
    ``prod(dims)`` (for example the pad nibble of an odd-length tensor) are
    dropped before reshaping.

    Args:
        data: A 1-D numpy array of packed bytes.
        dims: The dimensions used to reshape the unpacked buffer.

    Returns:
        A numpy array with ``data``'s dtype, reshaped to ``dims``.

    Raises:
        ValueError: If ``data`` unpacks to fewer than ``prod(dims)`` elements.
    """
    unpacked = np.empty(2 * data.size, dtype=data.dtype)
    unpacked[0::2] = data & np.uint8(0x0F)
    unpacked[1::2] = (data & np.uint8(0xF0)) >> np.uint8(4)

    required = math.prod(dims)
    if required > unpacked.size:
        raise ValueError(
            f"Packed 4-bit data ({data.size} bytes, {unpacked.size} elements unpacked) "
            f"is too small for the declared shape {list(dims)} "
            f"({required} elements required)."
        )
    return unpacked[:required].reshape(dims)
def _pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
    """Pack ``array`` two values per byte as 4-bit nibbles.

    The array is flattened (C order) and its storage is viewed as uint8, so
    this is intended for 1-byte element types; only the low nibble of each
    value is kept. Even-indexed elements go to the low nibble of each output
    byte and odd-indexed elements to the high nibble. An odd element count is
    padded with a zero high nibble. Elements must be in the correct range.
    """
    nibbles = array.ravel().view(np.uint8).copy()
    if array.size % 2 == 1:
        # Grow the private copy by one zero so pairs line up.
        nibbles.resize([array.size + 1], refcheck=False)
    nibbles &= 0x0F
    low, high = nibbles[0::2], nibbles[1::2]
    return low | (high << 4)
def _unpack_2bit(
    data: npt.NDArray[np.uint8], dims: Sequence[int]
) -> npt.NDArray[np.uint8]:
    """Split each packed byte into four 2-bit values stored one per uint8.

    Bits ``2k..2k+1`` of byte ``i`` become element ``4*i + k``. Trailing
    unpacked elements beyond ``prod(dims)`` (padding when the element count is
    not a multiple of four) are dropped before reshaping.

    Args:
        data: A 1-D numpy array of packed bytes.
        dims: The dimensions used to reshape the unpacked buffer.

    Returns:
        A numpy array with ``data``'s dtype, reshaped to ``dims``.

    Raises:
        ValueError: If ``data`` unpacks to fewer than ``prod(dims)`` elements.
    """
    crumbs = np.empty(4 * data.size, dtype=data.dtype)
    for lane in range(4):
        crumbs[lane::4] = (data >> (2 * lane)) & 0x03

    required = math.prod(dims)
    if required > crumbs.size:
        raise ValueError(
            f"Packed 2-bit data ({data.size} bytes, {crumbs.size} elements unpacked) "
            f"is too small for the declared shape {list(dims)} "
            f"({required} elements required)."
        )
    return crumbs[:required].reshape(dims)
def _pack_2bitx4(array: np.ndarray) -> npt.NDArray[np.uint8]:
    """Pack ``array`` four values per byte as 2-bit fields.

    The array is flattened (C order) and its storage is viewed as uint8, so
    this is intended for 1-byte element types; only the low two bits of each
    value are kept. Element ``4*i + k`` lands in bits ``2k..2k+1`` of output
    byte ``i``. The tail is zero-padded to a multiple of four elements.
    Elements must be in the correct range.
    """
    crumbs = array.ravel().view(np.uint8).copy()
    remainder = array.size % 4
    if remainder:
        crumbs.resize([array.size + (4 - remainder)], refcheck=False)
    crumbs &= 0x03

    packed = np.zeros(crumbs.size // 4, dtype=np.uint8)
    for lane in range(4):
        packed |= crumbs[lane::4] << (2 * lane)
    return packed
def to_array(tensor: onnx.TensorProto, base_dir: str = "") -> np.ndarray:  # noqa: PLR0911
    """Converts a tensor def object to a numpy array.

    This function uses ml_dtypes if the dtype is not a native numpy dtype.

    The payload is taken from, in this order of precedence:

    * the dtype's storage field (``helper.tensor_dtype_to_field``) for STRING
      tensors, decoded as UTF-8;
    * ``raw_data``, after loading external data if the tensor uses it;
    * ``int32_data`` for 16-bit, 8-bit float, bool and packed 4-/2-bit types;
    * otherwise the dtype's storage field, cast to the target dtype.

    Args:
        tensor: a TensorProto object.
        base_dir: if external tensor exists, base_dir can help to find the path to it

    Returns:
        arr: the converted array, shaped according to ``tensor.dims``.

    Raises:
        ValueError: if ``tensor`` has a ``segment`` field, or if packed
            4-bit/2-bit data unpacks to fewer elements than ``tensor.dims``
            requires.
        TypeError: if the tensor's element type is UNDEFINED.
    """
    if tensor.HasField("segment"):
        raise ValueError("Currently not supporting loading segments.")
    if tensor.data_type == onnx.TensorProto.UNDEFINED:
        raise TypeError("The element type in the input tensor is UNDEFINED.")

    tensor_dtype = tensor.data_type
    np_dtype = helper.tensor_dtype_to_np_dtype(tensor_dtype)
    # Numpy dtype of the proto field that physically stores the values.
    storage_np_dtype = helper.tensor_dtype_to_np_dtype(
        helper.tensor_dtype_to_storage_tensor_dtype(tensor_dtype)
    )
    storage_field = helper.tensor_dtype_to_field(tensor_dtype)
    dims = tensor.dims

    if tensor.data_type == onnx.TensorProto.STRING:
        utf8_strings = getattr(tensor, storage_field)
        ss = [s.decode("utf-8") for s in utf8_strings]
        return np.asarray(ss).astype(np_dtype).reshape(dims)

    # Load raw data from external tensor if it exists
    # NOTE(review): presumably this fills tensor.raw_data in place from the
    # file located via base_dir, i.e. the caller's proto is mutated — confirm
    # in onnx.external_data_helper.
    if onnx.external_data_helper.uses_external_data(tensor):
        onnx.external_data_helper.load_external_data_for_tensor(tensor, base_dir)

    if tensor.HasField("raw_data"):
        # Raw_bytes support: using frombuffer.
        raw_data = tensor.raw_data
        if sys.byteorder == "big":
            # Convert endian from little to big
            raw_data = np.frombuffer(raw_data, dtype=np_dtype).byteswap().tobytes()

        # Packed 4-bit types: two elements per byte.
        if tensor_dtype in {
            onnx.TensorProto.INT4,
            onnx.TensorProto.UINT4,
            onnx.TensorProto.FLOAT4E2M1,
        }:
            data = np.frombuffer(raw_data, dtype=np.uint8)
            return _unpack_4bit(data, dims).view(np_dtype)

        # Packed 2-bit types: four elements per byte.
        if tensor_dtype in {
            onnx.TensorProto.UINT2,
            onnx.TensorProto.INT2,
        }:
            data = np.frombuffer(raw_data, dtype=np.uint8)
            return _unpack_2bit(data, dims).view(np_dtype)

        return np.frombuffer(raw_data, dtype=np_dtype).reshape(dims)

    # 16-bit types in int32_data: each entry holds one value's bit pattern.
    # Reinterpret as uint32, narrow to uint16, then view as the target dtype.
    if tensor_dtype in {
        onnx.TensorProto.BFLOAT16,
        onnx.TensorProto.FLOAT16,
        onnx.TensorProto.INT16,
        onnx.TensorProto.UINT16,
    }:
        return (
            np.array(tensor.int32_data, dtype=np.int32)
            .view(np.uint32)
            .astype(np.uint16)
            .reshape(dims)
            .view(np_dtype)
        )

    # 8-bit float and bool types in int32_data: narrow each entry to a uint8
    # bit pattern, then view as the target dtype.
    if tensor_dtype in {
        onnx.TensorProto.FLOAT8E4M3FN,
        onnx.TensorProto.FLOAT8E4M3FNUZ,
        onnx.TensorProto.FLOAT8E5M2,
        onnx.TensorProto.FLOAT8E5M2FNUZ,
        onnx.TensorProto.FLOAT8E8M0,
        onnx.TensorProto.BOOL,
    }:
        return (
            np.array(tensor.int32_data, dtype=np.int32)
            .view(np.uint32)
            .astype(np.uint8)
            .view(np_dtype)
            .reshape(dims)
        )

    # Packed 4-bit types in int32_data: each entry is narrowed to one packed
    # byte (presumably two nibbles per entry — confirm against the writer).
    if tensor_dtype in {
        onnx.TensorProto.UINT4,
        onnx.TensorProto.INT4,
        onnx.TensorProto.FLOAT4E2M1,
    }:
        data = (
            np.array(tensor.int32_data, dtype=np.int32).view(np.uint32).astype(np.uint8)
        )
        return _unpack_4bit(data, dims).view(np_dtype)

    # Packed 2-bit types in int32_data, narrowed to packed bytes like above.
    if tensor_dtype in {
        onnx.TensorProto.UINT2,
        onnx.TensorProto.INT2,
    }:
        data = (
            np.array(tensor.int32_data, dtype=np.int32).view(np.uint32).astype(np.uint8)
        )
        return _unpack_2bit(data, dims).view(np_dtype)

    data = getattr(tensor, storage_field)
    # Complex values: reinterpret consecutive storage elements as complex
    # numbers (presumably interleaved real/imag pairs) rather than casting.
    if tensor_dtype in (onnx.TensorProto.COMPLEX64, onnx.TensorProto.COMPLEX128):
        return np.array(data, dtype=storage_np_dtype).view(dtype=np_dtype).reshape(dims)

    return np.asarray(data, dtype=storage_np_dtype).astype(np_dtype).reshape(dims)
def tobytes_little_endian(array: np.ndarray) -> bytes:
    """Serialize ``array`` to bytes, guaranteeing little-endian byte order.

    Arrays whose dtype is explicitly big-endian (``>``), or native (``=``) on a
    big-endian host, are first converted to the little-endian variant of the
    same dtype. All other arrays, including those with byte-order-agnostic
    dtypes (``|``), are serialized as-is.

    Args:
        array: a numpy array.

    Returns:
        bytes: Byte representation of passed array in little endian byte order.

    .. versionadded:: 1.20
    """
    order = array.dtype.byteorder
    stored_big_endian = order == ">" or (order == "=" and sys.byteorder == "big")
    if not stored_big_endian:
        return array.tobytes()
    little = array.astype(array.dtype.newbyteorder("<"))
    return little.tobytes()
def from_array(array: np.ndarray, /, name: str | None = None) -> onnx.TensorProto:
    """Converts a numpy array into a TensorProto.

    String-like arrays (object dtype, unicode ``str_`` dtype or byte-string
    ``bytes_`` dtype) are stored element-wise in ``string_data``: ``str``
    elements are UTF-8 encoded and ``bytes`` elements are stored unchanged.
    4-bit and 2-bit element types are packed (two or four values per byte)
    before serialization. Every non-string array is stored as little-endian
    ``raw_data``. ``dims`` always records the original ``array.shape``.

    Args:
        array: a numpy array.
        name: (optional) the name of the tensor.

    Returns:
        TensorProto: the converted tensor def.

    Raises:
        NotImplementedError: if a string-like array contains an element that
            is neither ``str`` nor ``bytes``.
    """
    tensor = onnx.TensorProto()
    tensor.dims.extend(array.shape)
    if name:
        tensor.name = name
    # Byte-string arrays (e.g. dtype "S5") are strings too: their elements are
    # np.bytes_ (a bytes subclass) and are handled by the bytes branch below.
    if (
        array.dtype == object
        or np.issubdtype(array.dtype, np.str_)
        or np.issubdtype(array.dtype, np.bytes_)
    ):
        # Special care for strings.
        tensor.data_type = onnx.TensorProto.STRING
        # TODO: Introduce full string support.
        # We flatten the array in case n-D arrays are specified.
        # If you want more complex shapes then follow the below instructions.
        # Unlike other types where the shape is automatically inferred from
        # nested arrays of values, the only reliable way now to feed strings
        # is to put them into a flat array then specify type astype(object)
        # (otherwise all strings may have different types depending on their length)
        # and then specify shape .reshape([x, y, z])
        flat_array = array.flatten()
        for e in flat_array:
            if isinstance(e, str):
                tensor.string_data.append(e.encode("utf-8"))
            elif isinstance(e, bytes):
                tensor.string_data.append(e)
            else:
                raise NotImplementedError(
                    f"Unrecognized object in the object array, expect a string, or array of bytes: {type(e)}"
                )
        return tensor

    # Unsupported dtypes raise from np_dtype_to_tensor_dtype.
    dtype = helper.np_dtype_to_tensor_dtype(array.dtype)
    if dtype in {
        onnx.TensorProto.INT4,
        onnx.TensorProto.UINT4,
        onnx.TensorProto.FLOAT4E2M1,
    }:
        # Pack the array into int4
        array = _pack_4bitx2(array)

    if dtype in {
        onnx.TensorProto.UINT2,
        onnx.TensorProto.INT2,
    }:
        # Pack the array into int2
        array = _pack_2bitx4(array)

    tensor.raw_data = tobytes_little_endian(array)
    tensor.data_type = dtype  # type: ignore[assignment]
    return tensor
379def to_list(sequence: onnx.SequenceProto) -> list[Any]:
380 """Converts a sequence def to a Python list.
382 Args:
383 sequence: a SequenceProto object.
385 Returns:
386 list: the converted list.
387 """
388 elem_type = sequence.elem_type
389 if elem_type == onnx.SequenceProto.TENSOR:
390 return [to_array(v) for v in sequence.tensor_values]
391 if elem_type == onnx.SequenceProto.SPARSE_TENSOR:
392 return [to_array(v) for v in sequence.sparse_tensor_values] # type: ignore[arg-type]
393 if elem_type == onnx.SequenceProto.SEQUENCE:
394 return [to_list(v) for v in sequence.sequence_values]
395 if elem_type == onnx.SequenceProto.MAP:
396 return [to_dict(v) for v in sequence.map_values]
397 raise TypeError("The element type in the input sequence is not supported.")
400def from_list(
401 lst: list[Any], name: str | None = None, dtype: int | None = None
402) -> onnx.SequenceProto:
403 """Converts a list into a sequence def.
405 Args:
406 lst: a Python list
407 name: (optional) the name of the sequence.
408 dtype: (optional) type of element in the input list, used for specifying
409 sequence values when converting an empty list.
411 Returns:
412 SequenceProto: the converted sequence def.
413 """
414 sequence = onnx.SequenceProto()
415 if name:
416 sequence.name = name
418 if dtype is not None:
419 elem_type = dtype
420 elif len(lst) > 0:
421 first_elem = lst[0]
422 if isinstance(first_elem, dict):
423 elem_type = onnx.SequenceProto.MAP
424 elif isinstance(first_elem, list):
425 elem_type = onnx.SequenceProto.SEQUENCE
426 else:
427 elem_type = onnx.SequenceProto.TENSOR
428 else:
429 # if empty input list and no dtype specified
430 # choose sequence of tensors on default
431 elem_type = onnx.SequenceProto.TENSOR
432 sequence.elem_type = elem_type
434 if (len(lst) > 0) and not all(isinstance(elem, type(lst[0])) for elem in lst):
435 raise TypeError(
436 "The element type in the input list is not the same "
437 "for all elements and therefore is not supported as a sequence."
438 )
440 if elem_type == onnx.SequenceProto.TENSOR:
441 for tensor in lst:
442 sequence.tensor_values.extend([from_array(np.asarray(tensor))])
443 elif elem_type == onnx.SequenceProto.SEQUENCE:
444 for seq in lst:
445 sequence.sequence_values.extend([from_list(seq)])
446 elif elem_type == onnx.SequenceProto.MAP:
447 for mapping in lst:
448 sequence.map_values.extend([from_dict(mapping)])
449 else:
450 raise TypeError(
451 "The element type in the input list is not a tensor, "
452 "sequence, or map and is not supported."
453 )
454 return sequence
def to_dict(map_proto: onnx.MapProto) -> dict[Any, Any]:
    """Convert a MapProto into a Python dictionary.

    String-keyed maps take their keys from ``string_keys``; all other key
    types take them from ``keys``. Values come from ``to_list`` applied to the
    map's value sequence, and keys are paired with values positionally.

    Args:
        map_proto: a MapProto object.

    Returns:
        The converted dictionary.

    Raises:
        IndexError: if the number of keys differs from the number of values.
    """
    string_keyed = map_proto.key_type == onnx.TensorProto.STRING
    keys: list[Any] = list(map_proto.string_keys if string_keyed else map_proto.keys)
    values = to_list(map_proto.values)
    if len(keys) != len(values):
        raise IndexError(
            f"Length of keys and values for MapProto (map name: {map_proto.name}) are not the same."
        )
    # Lengths were verified above, so strict pairing can never fail here.
    return dict(zip(keys, values, strict=True))
def _type_signature(obj: Any) -> tuple[str, Any]:
    """Return a comparable signature of the element type of ``obj``.

    ``np.result_type`` treats arguments that are not arrays or numbers as
    dtype *specifiers*: a ``str`` such as ``"f"`` means float32 and ``"key"``
    raises, while lists, tuples and dicts are parsed as structured-dtype specs.
    Such objects are therefore classified by their Python base type. Everything
    else (Python numbers, NumPy scalars and arrays) is classified by
    ``np.result_type``. The leading tag keeps the two kinds from ever
    comparing equal to each other.
    """
    for base in (str, bytes, list, tuple, dict):
        if isinstance(obj, base):
            return ("py", base)
    return ("np", np.result_type(obj))


def from_dict(dict_: dict[Any, Any], name: str | None = None) -> onnx.MapProto:
    """Converts a Python dictionary into a map def.

    All keys must share one type. ``str``/``bytes`` keys are stored in
    ``string_keys`` (``str`` keys are UTF-8 encoded, matching how
    ``from_array`` fills ``string_data``). Integer keys are stored in ``keys``.
    All values must share one type and are converted with ``from_list``.

    Args:
        dict_: Python dictionary
        name: (optional) the name of the map.

    Returns:
        MapProto: the converted map def.

    Raises:
        ValueError: if ``dict_`` is empty.
        TypeError: if the keys or the values are not all of the same type, or
            if the key type is neither a string nor an integer type.
    """
    map_proto = onnx.MapProto()
    if name:
        map_proto.name = name
    if not dict_:
        raise ValueError("Cannot convert an empty dictionary to MapProto.")
    keys = list(dict_)
    key_signature = _type_signature(keys[0])
    if key_signature in (("py", str), ("py", bytes)):
        key_type = onnx.TensorProto.STRING
    else:
        # Numbers / NumPy scalars: np.result_type yields their dtype.
        key_type = helper.np_dtype_to_tensor_dtype(np.result_type(keys[0]))

    valid_key_int_types = {
        onnx.TensorProto.INT8,
        onnx.TensorProto.INT16,
        onnx.TensorProto.INT32,
        onnx.TensorProto.INT64,
        onnx.TensorProto.UINT8,
        onnx.TensorProto.UINT16,
        onnx.TensorProto.UINT32,
        onnx.TensorProto.UINT64,
    }

    if not all(_type_signature(key) == key_signature for key in keys):
        raise TypeError(
            "The key type in the input dictionary is not the same "
            "for all keys and therefore is not valid as a map."
        )

    values = list(dict_.values())
    value_signature = _type_signature(values[0])
    if not all(_type_signature(val) == value_signature for val in values):
        raise TypeError(
            "The value type in the input dictionary is not the same "
            "for all values and therefore is not valid as a map."
        )

    value_seq = from_list(values)

    map_proto.key_type = key_type  # type: ignore[assignment]
    if key_type == onnx.TensorProto.STRING:
        # string_keys is a bytes field: encode str keys, keep bytes keys as-is.
        map_proto.string_keys.extend(
            [key.encode("utf-8") if isinstance(key, str) else key for key in keys]
        )
    elif key_type in valid_key_int_types:
        map_proto.keys.extend(keys)
    else:
        raise TypeError(f"Unsupported map key type: {key_type}")
    map_proto.values.CopyFrom(value_seq)
    return map_proto
537def to_optional(optional: onnx.OptionalProto) -> Any | None:
538 """Converts an optional def to a Python optional.
540 Args:
541 optional: an OptionalProto object.
543 Returns:
544 opt: the converted optional.
545 """
546 elem_type = optional.elem_type
547 if elem_type == onnx.OptionalProto.UNDEFINED:
548 return None
549 if elem_type == onnx.OptionalProto.TENSOR:
550 return to_array(optional.tensor_value)
551 if elem_type == onnx.OptionalProto.SPARSE_TENSOR:
552 return to_array(optional.sparse_tensor_value) # type: ignore[arg-type]
553 if elem_type == onnx.OptionalProto.SEQUENCE:
554 return to_list(optional.sequence_value)
555 if elem_type == onnx.OptionalProto.MAP:
556 return to_dict(optional.map_value)
557 if elem_type == onnx.OptionalProto.OPTIONAL:
558 return to_optional(optional.optional_value)
559 raise TypeError("The element type in the input optional is not supported.")
562def from_optional(
563 opt: Any | None, name: str | None = None, dtype: int | None = None
564) -> onnx.OptionalProto:
565 """Converts an optional value into a Optional def.
567 Args:
568 opt: a Python optional
569 name: (optional) the name of the optional.
570 dtype: (optional) type of element in the input, used for specifying
571 optional values when converting empty none. dtype must
572 be a valid OptionalProto.DataType value
574 Returns:
575 optional: the converted optional def.
576 """
577 # TODO: create a map and replace conditional branches
578 optional = onnx.OptionalProto()
579 if name:
580 optional.name = name
582 if dtype is not None:
583 # dtype must be a valid onnx.OptionalProto.DataType
584 if dtype not in onnx.OptionalProto.DataType.values():
585 raise TypeError(f"{dtype} must be a valid OptionalProto.DataType.")
586 elem_type = dtype
587 elif isinstance(opt, dict):
588 elem_type = onnx.OptionalProto.MAP
589 elif isinstance(opt, list):
590 elem_type = onnx.OptionalProto.SEQUENCE
591 elif opt is None:
592 elem_type = onnx.OptionalProto.UNDEFINED
593 else:
594 elem_type = onnx.OptionalProto.TENSOR
596 optional.elem_type = elem_type
598 if opt is not None:
599 if elem_type == onnx.OptionalProto.TENSOR:
600 optional.tensor_value.CopyFrom(from_array(opt))
601 elif elem_type == onnx.OptionalProto.SEQUENCE:
602 optional.sequence_value.CopyFrom(from_list(opt))
603 elif elem_type == onnx.OptionalProto.MAP:
604 optional.map_value.CopyFrom(from_dict(opt))
605 else:
606 raise TypeError(
607 "The element type in the input is not a tensor, "
608 "sequence, or map and is not supported."
609 )
610 return optional
def create_random_int(
    input_shape: tuple[int], dtype: np.dtype, seed: int = 1
) -> np.ndarray:
    """Return a seeded random integer array for backend/test/case/node.

    The global ``np.random`` state is reseeded with ``seed`` on every call,
    including calls that then fail the dtype check. Values are drawn with
    ``np.random.randint`` from the intersection of the dtype's range and the
    int32 range; the upper bound is exclusive.

    Args:
        input_shape: The shape for the returned integer array.
        dtype: The NumPy data type for the returned integer array.
        seed: The seed for np.random.

    Returns:
        np.ndarray: Random integer array.

    Raises:
        TypeError: if ``dtype`` is not one of the 8/16/32/64-bit integer types.
    """
    np.random.seed(seed)
    supported = (
        np.uint8,
        np.uint16,
        np.uint32,
        np.uint64,
        np.int8,
        np.int16,
        np.int32,
        np.int64,
    )
    if dtype not in supported:
        raise TypeError(f"{dtype} is not supported by create_random_int.")

    # Keep the bounds inside int32 so randint never overflows.
    target, int32_limits = np.iinfo(dtype), np.iinfo(np.int32)
    low = max(target.min, int32_limits.min)
    high = min(target.max, int32_limits.max)
    return np.random.randint(low, high, size=input_shape).astype(dtype)
def saturate_cast(x: np.ndarray, dtype: np.dtype) -> np.ndarray:
    """Cast ``x`` to ``dtype``, clamping out-of-range values instead of wrapping.

    Integer targets (NumPy integer types plus the ml_dtypes 4-bit and 2-bit
    integers) use ``ml_dtypes.iinfo`` bounds, and ``x`` is rounded before
    clamping. Every other target is treated as floating point and uses
    ``ml_dtypes.finfo`` bounds.
    """
    is_integer_target = np.issubdtype(dtype, np.integer) or dtype in (
        ml_dtypes.int4,
        ml_dtypes.uint4,
        ml_dtypes.int2,
        ml_dtypes.uint2,
    )
    if not is_integer_target:
        limits = ml_dtypes.finfo(dtype)
        return np.clip(x, limits.min, limits.max).astype(dtype)  # type: ignore[no-any-return]

    int_limits = ml_dtypes.iinfo(dtype)
    rounded = np.round(x)
    return np.clip(rounded, int_limits.min, int_limits.max).astype(dtype)  # type: ignore[no-any-return]