Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/_mapping.py: 91%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright (c) ONNX Project Contributors
2#
3# SPDX-License-Identifier: Apache-2.0
4from __future__ import annotations
6from typing import NamedTuple
8import ml_dtypes
9import numpy as np
11from onnx.onnx_pb import TensorProto
class TensorDtypeMap(NamedTuple):
    """Describe how one ONNX tensor element type maps onto numpy and protobuf storage.

    Instances are immutable ``(np_dtype, storage_dtype, name)`` tuples.
    """

    np_dtype: np.dtype  # numpy dtype used to represent values of this element type
    storage_dtype: int  # TensorProto element type whose *_data field holds the values
    name: str  # display label for the element type, e.g. "TensorProto.FLOAT"
# Lookup table from an ONNX tensor element type (as an int) to a TensorDtypeMap
# holding (numpy dtype, storage type, display name).
# The storage type is the element type whose *_data field of a TensorProto
# holds the raw values. The available fields are float_data, int32_data,
# int64_data, string_data, uint64_data and double_data.
#
# Each row below reads (ONNX element type, numpy dtype spec, storage element
# type, display name). The dict is built from the rows in the order listed.
TENSOR_TYPE_MAP: dict[int, TensorDtypeMap] = {
    int(onnx_type): TensorDtypeMap(np.dtype(dtype_spec), int(storage_type), label)
    for onnx_type, dtype_spec, storage_type, label in (
        (TensorProto.FLOAT, "float32", TensorProto.FLOAT, "TensorProto.FLOAT"),
        (TensorProto.UINT8, "uint8", TensorProto.INT32, "TensorProto.UINT8"),
        (TensorProto.INT8, "int8", TensorProto.INT32, "TensorProto.INT8"),
        (TensorProto.UINT16, "uint16", TensorProto.INT32, "TensorProto.UINT16"),
        (TensorProto.INT16, "int16", TensorProto.INT32, "TensorProto.INT16"),
        (TensorProto.INT32, "int32", TensorProto.INT32, "TensorProto.INT32"),
        (TensorProto.INT64, "int64", TensorProto.INT64, "TensorProto.INT64"),
        (TensorProto.BOOL, "bool", TensorProto.INT32, "TensorProto.BOOL"),
        (TensorProto.FLOAT16, "float16", TensorProto.INT32, "TensorProto.FLOAT16"),
        (
            TensorProto.BFLOAT16,
            ml_dtypes.bfloat16,
            TensorProto.INT32,
            "TensorProto.BFLOAT16",
        ),
        (TensorProto.DOUBLE, "float64", TensorProto.DOUBLE, "TensorProto.DOUBLE"),
        # Complex values are stored in the real-valued field of matching width.
        (
            TensorProto.COMPLEX64,
            "complex64",
            TensorProto.FLOAT,
            "TensorProto.COMPLEX64",
        ),
        (
            TensorProto.COMPLEX128,
            "complex128",
            TensorProto.DOUBLE,
            "TensorProto.COMPLEX128",
        ),
        (TensorProto.UINT32, "uint32", TensorProto.UINT64, "TensorProto.UINT32"),
        (TensorProto.UINT64, "uint64", TensorProto.UINT64, "TensorProto.UINT64"),
        (TensorProto.STRING, "object", TensorProto.STRING, "TensorProto.STRING"),
        # Narrow float/int types come from ml_dtypes and live in int32_data.
        (
            TensorProto.FLOAT8E4M3FN,
            ml_dtypes.float8_e4m3fn,
            TensorProto.INT32,
            "TensorProto.FLOAT8E4M3FN",
        ),
        (
            TensorProto.FLOAT8E4M3FNUZ,
            ml_dtypes.float8_e4m3fnuz,
            TensorProto.INT32,
            "TensorProto.FLOAT8E4M3FNUZ",
        ),
        (
            TensorProto.FLOAT8E5M2,
            ml_dtypes.float8_e5m2,
            TensorProto.INT32,
            "TensorProto.FLOAT8E5M2",
        ),
        (
            TensorProto.FLOAT8E5M2FNUZ,
            ml_dtypes.float8_e5m2fnuz,
            TensorProto.INT32,
            "TensorProto.FLOAT8E5M2FNUZ",
        ),
        (TensorProto.UINT4, ml_dtypes.uint4, TensorProto.INT32, "TensorProto.UINT4"),
        (TensorProto.INT4, ml_dtypes.int4, TensorProto.INT32, "TensorProto.INT4"),
        (
            TensorProto.FLOAT4E2M1,
            ml_dtypes.float4_e2m1fn,
            TensorProto.INT32,
            "TensorProto.FLOAT4E2M1",
        ),
        (
            TensorProto.FLOAT8E8M0,
            ml_dtypes.float8_e8m0fnu,
            TensorProto.INT32,
            "TensorProto.FLOAT8E8M0",
        ),
        (TensorProto.UINT2, ml_dtypes.uint2, TensorProto.INT32, "TensorProto.UINT2"),
        (TensorProto.INT2, ml_dtypes.int2, TensorProto.INT32, "TensorProto.INT2"),
    )
}