Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/_mapping.py: 91%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

11 statements  

1# Copyright (c) ONNX Project Contributors 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4from __future__ import annotations 

5 

6from typing import NamedTuple 

7 

8import ml_dtypes 

9import numpy as np 

10 

11from onnx.onnx_pb import TensorProto 

12 

13 

14class TensorDtypeMap(NamedTuple): 

15 np_dtype: np.dtype 

16 storage_dtype: int 

17 name: str 

18 

19 

20# tensor_dtype: (numpy type, storage type, string name) 

21# The storage type is the type used to store the tensor in the *_data field of 

22# a TensorProto. All available fields are float_data, int32_data, int64_data, 

23# string_data, uint64_data and double_data. 

24TENSOR_TYPE_MAP: dict[int, TensorDtypeMap] = { 

25 int(TensorProto.FLOAT): TensorDtypeMap( 

26 np.dtype("float32"), int(TensorProto.FLOAT), "TensorProto.FLOAT" 

27 ), 

28 int(TensorProto.UINT8): TensorDtypeMap( 

29 np.dtype("uint8"), int(TensorProto.INT32), "TensorProto.UINT8" 

30 ), 

31 int(TensorProto.INT8): TensorDtypeMap( 

32 np.dtype("int8"), int(TensorProto.INT32), "TensorProto.INT8" 

33 ), 

34 int(TensorProto.UINT16): TensorDtypeMap( 

35 np.dtype("uint16"), int(TensorProto.INT32), "TensorProto.UINT16" 

36 ), 

37 int(TensorProto.INT16): TensorDtypeMap( 

38 np.dtype("int16"), int(TensorProto.INT32), "TensorProto.INT16" 

39 ), 

40 int(TensorProto.INT32): TensorDtypeMap( 

41 np.dtype("int32"), int(TensorProto.INT32), "TensorProto.INT32" 

42 ), 

43 int(TensorProto.INT64): TensorDtypeMap( 

44 np.dtype("int64"), int(TensorProto.INT64), "TensorProto.INT64" 

45 ), 

46 int(TensorProto.BOOL): TensorDtypeMap( 

47 np.dtype("bool"), int(TensorProto.INT32), "TensorProto.BOOL" 

48 ), 

49 int(TensorProto.FLOAT16): TensorDtypeMap( 

50 np.dtype("float16"), int(TensorProto.INT32), "TensorProto.FLOAT16" 

51 ), 

52 int(TensorProto.BFLOAT16): TensorDtypeMap( 

53 np.dtype(ml_dtypes.bfloat16), 

54 int(TensorProto.INT32), 

55 "TensorProto.BFLOAT16", 

56 ), 

57 int(TensorProto.DOUBLE): TensorDtypeMap( 

58 np.dtype("float64"), int(TensorProto.DOUBLE), "TensorProto.DOUBLE" 

59 ), 

60 int(TensorProto.COMPLEX64): TensorDtypeMap( 

61 np.dtype("complex64"), int(TensorProto.FLOAT), "TensorProto.COMPLEX64" 

62 ), 

63 int(TensorProto.COMPLEX128): TensorDtypeMap( 

64 np.dtype("complex128"), 

65 int(TensorProto.DOUBLE), 

66 "TensorProto.COMPLEX128", 

67 ), 

68 int(TensorProto.UINT32): TensorDtypeMap( 

69 np.dtype("uint32"), int(TensorProto.UINT64), "TensorProto.UINT32" 

70 ), 

71 int(TensorProto.UINT64): TensorDtypeMap( 

72 np.dtype("uint64"), int(TensorProto.UINT64), "TensorProto.UINT64" 

73 ), 

74 int(TensorProto.STRING): TensorDtypeMap( 

75 np.dtype("object"), int(TensorProto.STRING), "TensorProto.STRING" 

76 ), 

77 int(TensorProto.FLOAT8E4M3FN): TensorDtypeMap( 

78 np.dtype(ml_dtypes.float8_e4m3fn), 

79 int(TensorProto.INT32), 

80 "TensorProto.FLOAT8E4M3FN", 

81 ), 

82 int(TensorProto.FLOAT8E4M3FNUZ): TensorDtypeMap( 

83 np.dtype(ml_dtypes.float8_e4m3fnuz), 

84 int(TensorProto.INT32), 

85 "TensorProto.FLOAT8E4M3FNUZ", 

86 ), 

87 int(TensorProto.FLOAT8E5M2): TensorDtypeMap( 

88 np.dtype(ml_dtypes.float8_e5m2), 

89 int(TensorProto.INT32), 

90 "TensorProto.FLOAT8E5M2", 

91 ), 

92 int(TensorProto.FLOAT8E5M2FNUZ): TensorDtypeMap( 

93 np.dtype(ml_dtypes.float8_e5m2fnuz), 

94 int(TensorProto.INT32), 

95 "TensorProto.FLOAT8E5M2FNUZ", 

96 ), 

97 int(TensorProto.UINT4): TensorDtypeMap( 

98 np.dtype(ml_dtypes.uint4), int(TensorProto.INT32), "TensorProto.UINT4" 

99 ), 

100 int(TensorProto.INT4): TensorDtypeMap( 

101 np.dtype(ml_dtypes.int4), int(TensorProto.INT32), "TensorProto.INT4" 

102 ), 

103 int(TensorProto.FLOAT4E2M1): TensorDtypeMap( 

104 np.dtype(ml_dtypes.float4_e2m1fn), 

105 int(TensorProto.INT32), 

106 "TensorProto.FLOAT4E2M1", 

107 ), 

108 int(TensorProto.FLOAT8E8M0): TensorDtypeMap( 

109 np.dtype(ml_dtypes.float8_e8m0fnu), 

110 int(TensorProto.INT32), 

111 "TensorProto.FLOAT8E8M0", 

112 ), 

113 int(TensorProto.UINT2): TensorDtypeMap( 

114 np.dtype(ml_dtypes.uint2), int(TensorProto.INT32), "TensorProto.UINT2" 

115 ), 

116 int(TensorProto.INT2): TensorDtypeMap( 

117 np.dtype(ml_dtypes.int2), int(TensorProto.INT32), "TensorProto.INT2" 

118 ), 

119}