Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/helper.py: 26%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

700 statements  

1# Copyright (c) ONNX Project Contributors 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4from __future__ import annotations 

5 

6import collections.abc 

7import functools 

8import math 

9import numbers 

10import typing 

11from typing import TYPE_CHECKING, Any, TypeVar 

12 

13import google.protobuf.message 

14import numpy as np 

15import typing_extensions 

16 

17import onnx 

18from onnx import _mapping, defs 

19from onnx.onnx_data_pb import MapProto, OptionalProto, SequenceProto 

20from onnx.onnx_pb import ( 

21 AttributeProto, 

22 FunctionProto, 

23 GraphProto, 

24 ModelProto, 

25 NodeProto, 

26 OperatorSetIdProto, 

27 TensorProto, 

28 TensorShapeProto, 

29 TrainingInfoProto, 

30 TypeProto, 

31 ValueInfoProto, 

32) 

33 

34if TYPE_CHECKING: 

35 from collections.abc import Callable, KeysView, Sequence 

36 

37 from google.protobuf.internal.containers import RepeatedCompositeFieldContainer 

38 

# One row of VERSION_TABLE: (release version, IR version, ai.onnx opset,
# ai.onnx.ml opset) optionally followed by the ai.onnx.training opset.
VersionRowType = tuple[str, int, int, int] | tuple[str, int, int, int, int]
VersionTableType = list[VersionRowType]
# List of (str, str) pairs. NOTE(review): its users are outside this chunk.
AssignmentBindingType = list[tuple[str, str]]

# This is a copy of the documented version in https://github.com/onnx/onnx/blob/main/docs/Versioning.md#released-versions
# Both must be updated whenever a new version of ONNX is released.
# Rows are listed in release order; _create_op_set_id_version_map relies on the
# first row that mentions an opset being the earliest one.
VERSION_TABLE: VersionTableType = [
    # Release-version, IR version, ai.onnx version, ai.onnx.ml version, (optional) ai.onnx.training version
    ("1.0", 3, 1, 1),
    ("1.1", 3, 5, 1),
    ("1.1.2", 3, 6, 1),
    ("1.2", 3, 7, 1),
    ("1.3", 3, 8, 1),
    ("1.4.1", 4, 9, 1),
    ("1.5.0", 5, 10, 1),
    ("1.6.0", 6, 11, 2),
    ("1.7.0", 7, 12, 2, 1),
    ("1.8.0", 7, 13, 2, 1),
    ("1.8.1", 7, 13, 2, 1),
    ("1.9.0", 7, 14, 2, 1),
    ("1.10.0", 8, 15, 2, 1),
    ("1.10.1", 8, 15, 2, 1),
    ("1.10.2", 8, 15, 2, 1),
    ("1.11.0", 8, 16, 3, 1),
    ("1.12.0", 8, 17, 3, 1),
    ("1.13.0", 8, 18, 3, 1),
    ("1.13.1", 8, 18, 3, 1),
    ("1.14.0", 9, 19, 3, 1),
    ("1.14.1", 9, 19, 3, 1),
    ("1.15.0", 9, 20, 4, 1),
    ("1.16.0", 10, 21, 5, 1),
    ("1.16.1", 10, 21, 5, 1),
    ("1.16.2", 10, 21, 5, 1),
    ("1.17.0", 10, 22, 5, 1),
    ("1.18.0", 11, 23, 5, 1),
    ("1.19.0", 12, 24, 5, 1),
    ("1.19.1", 12, 24, 5, 1),
    ("1.20.0", 13, 25, 5, 1),
    ("1.20.1", 13, 25, 5, 1),
    ("1.21.0", 13, 26, 5, 1),
]

# Maps (opset domain, opset version) -> IR version.
VersionMapType = dict[tuple[str, int], int]

82 

83 

84def _create_op_set_id_version_map(table: VersionTableType) -> VersionMapType: 

85 """Create a map from (opset-domain, opset-version) to ir-version from above table.""" 

86 result: VersionMapType = {} 

87 for row in table: 

88 ir_version = row[1] 

89 for pair in zip( 

90 ["ai.onnx", "ai.onnx.ml", "ai.onnx.training"], 

91 row[2:], 

92 strict=False, 

93 ): 

94 if pair not in result: 

95 result[pair] = ir_version 

96 if pair[0] == "ai.onnx": 

97 result["ai.onnx.preview", pair[1]] = ir_version 

98 if pair[0] == "ai.onnx.training": 

99 result["ai.onnx.preview.training", pair[1]] = ir_version 

100 return result 

101 

102 

# Lookup table (domain, opset version) -> IR version of the earliest release
# in VERSION_TABLE that lists that opset; built once at import time and used by
# find_min_ir_version_for.
OP_SET_ID_VERSION_MAP = _create_op_set_id_version_map(VERSION_TABLE)

104 

105 

def find_min_ir_version_for(
    opsetidlist: Sequence[OperatorSetIdProto], ignore_unknown: bool = False
) -> int:
    """Given list of opset ids, determine minimum IR version required.

    The result is the largest IR version required by any single opset id, as
    recorded in ``OP_SET_ID_VERSION_MAP``. An empty (or None) domain is treated
    as the default ``"ai.onnx"`` domain.

    Args:
        opsetidlist: A sequence of OperatorSetIdProto. ``None`` or an empty
            sequence means no opsets are specified.
        ignore_unknown: If True, a (domain, version) pair that is not in the
            version map (unknown domain *or* unknown version) contributes the
            default minimum IR version instead of raising.

    Returns:
        The minimum IR version required (integer); 3 when no opsets are given.

    Raises:
        ValueError: if a (domain, version) pair is unknown and
            ``ignore_unknown`` is False.
    """
    default_min_version = 3

    def find_min(domain: str | None, version: int) -> int:
        key = (domain or "ai.onnx", version)
        if key in OP_SET_ID_VERSION_MAP:
            return OP_SET_ID_VERSION_MAP[key]
        if ignore_unknown:
            return default_min_version
        # Name the offending pair so callers can tell which import is unsupported.
        raise ValueError(
            f"Unsupported opset-version: domain {key[0]!r}, version {version}."
        )

    if not opsetidlist:  # None or empty: no opsets specified
        return default_min_version
    # default= also covers non-sequence iterables that turn out to be empty.
    return max(
        (find_min(x.domain, x.version) for x in opsetidlist),
        default=default_min_version,
    )

132 

133 

def make_node(
    op_type: str,
    inputs: Sequence[str],
    outputs: Sequence[str],
    name: str | None = None,
    doc_string: str | None = None,
    domain: str | None = None,
    overload: str | None = None,
    **kwargs: Any,
) -> NodeProto:
    """Construct a NodeProto.

    Args:
        op_type (string): The name of the operator to construct
        inputs (list of string): list of input names
        outputs (list of string): list of output names
        name (string, default None): optional unique identifier for NodeProto;
            only stored when non-empty
        doc_string (string, default None): optional documentation string for
            NodeProto; only stored when non-empty
        domain (string, default None): optional domain for NodeProto.
            If it's None, we will just use default domain (which is empty)
        overload (string, default None): optional field, used to
            resolve calls to model-local functions
        **kwargs (dict): the attributes of the node, added in sorted key
            order; entries whose value is None are skipped. The acceptable
            values are documented in :func:`make_attribute`.

    Returns:
        NodeProto
    """
    node = NodeProto()
    node.op_type = op_type
    node.input.extend(inputs)
    node.output.extend(outputs)
    # name/doc_string are written only when truthy, while domain/overload are
    # written whenever given, so an explicit "" is still recorded for them.
    if name:
        node.name = name
    if doc_string:
        node.doc_string = doc_string
    if domain is not None:
        node.domain = domain
    if overload is not None:
        node.overload = overload
    attributes = [
        make_attribute(attr_name, attr_value)
        for attr_name, attr_value in sorted(kwargs.items())
        if attr_value is not None
    ]
    if attributes:
        node.attribute.extend(attributes)
    return node

181 

182 

def make_operatorsetid(
    domain: str,
    version: int,
) -> OperatorSetIdProto:
    """Construct an OperatorSetIdProto.

    Same result as :func:`make_opsetid`, which it delegates to.

    Args:
        domain (string): The domain of the operator set id
        version (integer): Version of operator set id
    Returns:
        OperatorSetIdProto
    """
    return make_opsetid(domain, version)

199 

200 

def make_graph(
    nodes: Sequence[NodeProto],
    name: str,
    inputs: Sequence[ValueInfoProto],
    outputs: Sequence[ValueInfoProto],
    initializer: Sequence[TensorProto] | None = None,
    doc_string: str | None = None,
    value_info: Sequence[ValueInfoProto] | None = None,
    sparse_initializer: Sequence[onnx.SparseTensorProto] | None = None,
) -> GraphProto:
    """Construct a GraphProto

    Args:
        nodes: list of NodeProto
        name (string): graph name
        inputs: list of ValueInfoProto
        outputs: list of ValueInfoProto
        initializer: list of TensorProto (None means no initializers)
        doc_string (string): graph documentation; only stored when non-empty
        value_info: list of ValueInfoProto (None means none)
        sparse_initializer: list of onnx.SparseTensorProto (None means none)
    Returns:
        GraphProto
    """
    graph = GraphProto()
    graph.node.extend(nodes)
    graph.name = name
    graph.input.extend(inputs)
    graph.output.extend(outputs)
    # Optional collections: None is treated as an empty collection.
    graph.initializer.extend(() if initializer is None else initializer)
    graph.sparse_initializer.extend(
        () if sparse_initializer is None else sparse_initializer
    )
    graph.value_info.extend(() if value_info is None else value_info)
    if doc_string:
        graph.doc_string = doc_string
    return graph

242 

243 

def make_opsetid(domain: str, version: int) -> OperatorSetIdProto:
    """Construct an OperatorSetIdProto.

    Args:
        domain (string): The domain of the operator set id
        version (integer): Version of operator set id
    Returns:
        OperatorSetIdProto
    """
    result = OperatorSetIdProto()
    # Assign field-by-field (rather than via constructor kwargs) so invalid
    # values such as None are rejected by protobuf.
    result.domain = domain
    result.version = version
    return result

257 

258 

def make_function(
    domain: str,
    fname: str,
    inputs: Sequence[str],
    outputs: Sequence[str],
    nodes: Sequence[NodeProto],
    opset_imports: Sequence[OperatorSetIdProto],
    attributes: Sequence[str] | None = None,
    attribute_protos: Sequence[AttributeProto] | None = None,
    doc_string: str | None = None,
    overload: str | None = None,
    value_info: Sequence[ValueInfoProto] | None = None,
) -> FunctionProto:
    """Construct a FunctionProto.

    Args:
        domain: domain of the function.
        fname: function name, stored in ``FunctionProto.name``.
        inputs: names of the function inputs.
        outputs: names of the function outputs.
        nodes: nodes forming the function body.
        opset_imports: opset imports of the function.
        attributes: attribute names, stored in ``attribute`` (None means none).
        attribute_protos: attributes given as AttributeProto, stored in
            ``attribute_proto`` (None means none).
        doc_string: optional documentation; only stored when non-empty.
        overload: optional overload identifier; stored whenever not None.
        value_info: optional list of ValueInfoProto (None means none).

    Returns:
        FunctionProto
    """
    function = FunctionProto()
    function.domain = domain
    function.name = fname
    function.input.extend(inputs)
    function.output.extend(outputs)
    function.node.extend(nodes)
    function.opset_import.extend(opset_imports)
    function.attribute.extend(() if attributes is None else attributes)
    function.attribute_proto.extend(
        () if attribute_protos is None else attribute_protos
    )
    if doc_string:
        function.doc_string = doc_string
    if overload is not None:
        function.overload = overload
    function.value_info.extend(() if value_info is None else value_info)
    return function

293 

294 

def make_model(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Construct a ModelProto

    Args:
        graph (GraphProto): *make_graph* returns
        **kwargs: any attribute to add to the returned instance.
            ``opset_imports`` and ``functions`` are extended into the
            corresponding repeated fields. When ``opset_imports`` is absent
            (or None), a single default-domain import at
            ``defs.onnx_opset_version()`` is added. Every other keyword is
            assigned with ``setattr``.
    Returns:
        ModelProto
    """
    model = ModelProto()
    # Touch model.ir_version so it is stored as the version from which it is
    # generated.
    model.ir_version = onnx.IR_VERSION
    model.graph.CopyFrom(graph)

    opset_imports: Sequence[OperatorSetIdProto] | None = kwargs.pop(
        "opset_imports", None
    )
    if opset_imports is None:
        # Default import: empty (ai.onnx) domain at the current opset version.
        default_import = model.opset_import.add()
        default_import.version = defs.onnx_opset_version()
    else:
        model.opset_import.extend(opset_imports)

    functions: Sequence[FunctionProto] | None = kwargs.pop("functions", None)
    if functions is not None:
        model.functions.extend(functions)

    for field_name, field_value in kwargs.items():
        # TODO: Does this work with repeated fields?
        # NOTE(review): protobuf generally rejects direct assignment to
        # repeated and message fields, so such keywords likely raise here.
        setattr(model, field_name, field_value)
    return model

328 

329 

def make_model_gen_version(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Construct a ModelProto like :func:`make_model`, inferring ``ir_version``.

    This is an extension of :func:`make_model`. When ``ir_version`` is not among
    ``kwargs``, it is set on a best-effort basis to
    :func:`find_min_ir_version_for` applied to ``kwargs["opset_imports"]``, or
    to an empty list when that keyword is absent. Unknown opsets therefore
    raise ValueError.

    NOTE(review): without ``opset_imports`` the inferred value is the default
    minimum (3), while :func:`make_model` then adds a default-domain import at
    ``defs.onnx_opset_version()``. Confirm this pairing is intended.

    Args:
        graph: the model graph.
        **kwargs: forwarded to :func:`make_model`.

    Returns:
        ModelProto
    """
    if "ir_version" not in kwargs:
        kwargs["ir_version"] = find_min_ir_version_for(
            kwargs.get("opset_imports", [])
        )
    return make_model(graph, **kwargs)

339 

340 

def set_metadata_props(
    proto: (
        ModelProto
        | GraphProto
        | FunctionProto
        | NodeProto
        | TensorProto
        | ValueInfoProto
    ),
    dict_value: dict[str, str],
) -> None:
    """Replace ``proto.metadata_props`` with the entries of ``dict_value``.

    All existing entries are removed first. One key/value entry is then
    appended per dict item, in the dict's iteration order. ``proto`` is
    modified in place.
    """
    proto.ClearField("metadata_props")
    for prop_key, prop_value in dict_value.items():
        prop = proto.metadata_props.add()
        prop.key = prop_key
        prop.value = prop_value

357 

358 

def set_model_props(model: ModelProto, dict_value: dict[str, str]) -> None:
    """Replace the ``metadata_props`` of ``model`` with ``dict_value``.

    Alias of :func:`set_metadata_props` for ModelProto.
    """
    set_metadata_props(proto=model, dict_value=dict_value)

361 

362 

def make_tensor(
    name: str,
    data_type: int,
    dims: Sequence[int],
    vals: Sequence[int | float] | bytes | np.ndarray,
    raw: bool = False,
) -> TensorProto:
    """Make a TensorProto with specified arguments. If raw is False, this
    function will choose the corresponding proto field to store the
    values based on data_type. If raw is True, use "raw_data" proto
    field to store the values, and values should be of type bytes in
    this case (a numpy.ndarray is also accepted and serialized to
    little-endian bytes).

    Args:
        name: tensor name
        data_type: a value such as onnx.TensorProto.FLOAT
        dims: shape
        vals: values
        raw: if True, vals contains the serialized content of the tensor,
            otherwise, vals should be a list of values of the type defined by ``data_type``.

    Returns:
        TensorProto

    Raises:
        TypeError: if ``raw`` is True and ``data_type`` is STRING, or if
            ``raw`` is True and ``vals`` is neither bytes nor numpy.ndarray.
        ValueError: if the raw byte count, or the number of values, does not
            match ``dims``.
    """
    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name
    tensor.dims.extend(dims)

    if data_type == TensorProto.STRING and raw:
        raise TypeError("Can not use raw_data to store string type.")

    np_dtype = tensor_dtype_to_np_dtype(data_type)

    if raw:
        # NumPy doesn't have INT2/INT4/FP4. It is packed in couples to UINT8 buffers.
        if data_type in {TensorProto.UINT4, TensorProto.INT4, TensorProto.FLOAT4E2M1}:
            expected_size_bytes = 0.5
        elif data_type in {TensorProto.UINT2, TensorProto.INT2}:
            expected_size_bytes = 0.25
        else:
            expected_size_bytes = np_dtype.itemsize
        # math.prod([]) == 1, so empty dims (a scalar) count as one element.
        expected_size_bytes *= math.prod(dims)
        # Round up: a trailing partially-filled byte still occupies a byte.
        expected_size_bytes = math.ceil(expected_size_bytes)
        if isinstance(vals, np.ndarray):
            # Sub-byte types are packed first so the serialized length
            # matches expected_size_bytes.
            if data_type in {
                TensorProto.INT4,
                TensorProto.UINT4,
                TensorProto.FLOAT4E2M1,
            }:
                vals = onnx.numpy_helper._pack_4bitx2(vals)
            elif data_type in {TensorProto.UINT2, TensorProto.INT2}:
                vals = onnx.numpy_helper._pack_2bitx4(vals)

            raw_data = onnx.numpy_helper.tobytes_little_endian(vals)
        elif isinstance(vals, bytes):
            # Bytes are stored as given, without any conversion.
            raw_data = vals
        else:
            raise TypeError(
                f"Raw data must be bytes or numpy.ndarray, but got {type(vals)}."
            )
        if len(raw_data) != expected_size_bytes:
            raise ValueError(
                f"Raw data size does not match tensor's size. Expected {expected_size_bytes} bytes, but got {len(raw_data)} bytes."
            )
        tensor.raw_data = raw_data
        return tensor

    assert not raw, "Bug: raw should be False at this point."

    # Typed-field path: first normalize vals to a flat NumPy array.
    if data_type == TensorProto.STRING:
        vals = np.array(vals).flatten()
        if len(vals) != 0:
            vals = np.vectorize(_to_bytes)(vals)  # Convert to bytes
    elif data_type in {
        TensorProto.FLOAT8E4M3FN,
        TensorProto.FLOAT8E4M3FNUZ,
        TensorProto.FLOAT8E5M2,
        TensorProto.FLOAT8E5M2FNUZ,
    }:
        # Float8 values are by default casted using saturating cast.
        vals = onnx.numpy_helper.saturate_cast(np.asarray(vals), np_dtype).flatten()
    elif data_type == TensorProto.FLOAT8E8M0:
        vals = onnx.numpy_helper.to_float8e8m0(
            np.asarray(vals), saturate=True, round_mode="up"
        ).flatten()
    else:
        vals = np.asarray(vals, dtype=np_dtype).flatten()

    expected_elements = math.prod(dims)
    if len(vals) != expected_elements:
        raise ValueError(
            f"Number of values ({len(vals)}) does not match tensor "
            f"dimensions requiring {expected_elements} elements."
        )
    # Reinterpret or pack the values into the element representation written
    # to the storage field: complex -> interleaved real/imag floats, 16-/8-bit
    # floats -> their raw unsigned-integer bit patterns, sub-byte types ->
    # packed bytes, bool -> uint8.
    if data_type == TensorProto.COMPLEX128:
        vals = vals.view(np.float64)  # type: ignore[union-attr]
    elif data_type == TensorProto.COMPLEX64:
        vals = vals.view(np.float32)  # type: ignore[union-attr]
    elif data_type in {TensorProto.BFLOAT16, TensorProto.FLOAT16}:
        vals = vals.view(np.uint16)  # type: ignore[union-attr]
    elif data_type in {
        TensorProto.FLOAT8E4M3FN,
        TensorProto.FLOAT8E4M3FNUZ,
        TensorProto.FLOAT8E5M2,
        TensorProto.FLOAT8E5M2FNUZ,
        TensorProto.FLOAT8E8M0,
    }:
        vals = vals.view(np.uint8)  # type: ignore[union-attr]
    elif data_type in {TensorProto.UINT4, TensorProto.INT4, TensorProto.FLOAT4E2M1}:
        # Convert to packed 4-bit representation
        vals = onnx.numpy_helper._pack_4bitx2(vals)  # type: ignore[arg-type]
    elif data_type in {TensorProto.UINT2, TensorProto.INT2}:
        # Convert to packed 2-bit representation
        vals = onnx.numpy_helper._pack_2bitx4(vals)  # type: ignore[arg-type]
    elif data_type == TensorProto.BOOL:
        vals = vals.astype(np.uint8)  # type: ignore[union-attr]

    # tensor_dtype_to_field picks the TensorProto repeated field (by name)
    # used for this data_type.
    field = tensor_dtype_to_field(data_type)
    getattr(tensor, field).extend(vals)
    return tensor

484 

485 

def make_sparse_tensor(
    values: TensorProto, indices: TensorProto, dims: Sequence[int]
) -> onnx.SparseTensorProto:
    """Construct a SparseTensorProto

    ``values`` and ``indices`` are copied into the result, and ``dims`` is
    appended to its ``dims`` field.

    Args:
        values (TensorProto): the values
        indices (TensorProto): the indices
        dims: the shape

    Returns:
        SparseTensorProto
    """
    result = onnx.SparseTensorProto()
    result.values.CopyFrom(values)
    result.indices.CopyFrom(indices)
    result.dims.extend(dims)
    return result

504 

505 

def make_sequence(
    name: str,
    elem_type: SequenceProto.DataType,
    values: Sequence[Any],
) -> SequenceProto:
    """Make a Sequence with specified value arguments.

    Args:
        name: name stored on the SequenceProto.
        elem_type: a ``SequenceProto.DataType``; selects the repeated field
            that receives ``values``. With ``UNDEFINED``, ``values`` is ignored.
        values: elements to append to the selected field.

    Returns:
        SequenceProto

    Raises:
        TypeError: if ``elem_type`` is not a supported element type.
    """
    sequence = SequenceProto()
    sequence.name = name
    sequence.elem_type = elem_type  # type: ignore[assignment]

    if elem_type == SequenceProto.UNDEFINED:
        return sequence

    attribute: RepeatedCompositeFieldContainer | None = None
    if elem_type == SequenceProto.TENSOR:
        attribute = sequence.tensor_values
    elif elem_type == SequenceProto.SPARSE_TENSOR:
        attribute = sequence.sparse_tensor_values
    elif elem_type == SequenceProto.SEQUENCE:
        attribute = sequence.sequence_values
    elif elem_type == SequenceProto.MAP:
        attribute = sequence.map_values
    elif elem_type == SequenceProto.OPTIONAL:
        # Compare against this message's own DataType enum, not OptionalProto's.
        attribute = sequence.optional_values
    else:
        raise TypeError("The element type in the input sequence is not supported.")

    attribute.extend(values)
    return sequence

535 

536 

def make_map(
    name: str, key_type: int, keys: list[Any], values: SequenceProto
) -> MapProto:
    """Make a Map with specified key-value pair arguments.

    Criteria for conversion (not validated here):
    - Keys and Values must have the same number of elements
    - Every key in keys must be of the same type
    - Every value in values must be of the same type

    Args:
        name: name stored on the MapProto.
        key_type: TensorProto data type of the keys. Must be STRING or one of
            INT8/16/32/64, UINT8/16/32/64.
        keys: the keys. They are stored in ``string_keys`` for STRING and in
            ``keys`` for the integer key types.
        values: SequenceProto holding the values; copied into the map.

    Returns:
        MapProto

    Raises:
        TypeError: if ``key_type`` is not a supported key type.
    """
    valid_key_int_types = {
        TensorProto.INT8,
        TensorProto.INT16,
        TensorProto.INT32,
        TensorProto.INT64,
        TensorProto.UINT8,
        TensorProto.UINT16,
        TensorProto.UINT32,
        TensorProto.UINT64,
    }
    map_proto = MapProto()
    map_proto.name = name
    map_proto.key_type = key_type
    if key_type == TensorProto.STRING:
        map_proto.string_keys.extend(keys)
    elif key_type in valid_key_int_types:
        map_proto.keys.extend(keys)
    else:
        # Reject instead of silently dropping the keys, which would yield a
        # map with values but no keys.
        raise TypeError(
            f"Unsupported map key type {key_type}: expected STRING or an integer "
            "TensorProto data type."
        )
    map_proto.values.CopyFrom(values)
    return map_proto

566 

567 

def make_optional(
    name: str,
    elem_type: OptionalProto.DataType,
    value: google.protobuf.message.Message | None,
) -> OptionalProto:
    """Make an Optional with specified value arguments.

    Args:
        name: name stored on the OptionalProto.
        elem_type: an ``OptionalProto.DataType``; selects the field that
            receives ``value``. With ``UNDEFINED``, ``value`` is ignored.
        value: message copied into the selected field. Required unless
            ``elem_type`` is ``UNDEFINED``.

    Returns:
        OptionalProto

    Raises:
        TypeError: if ``elem_type`` is not a supported element type.
        ValueError: if ``value`` is None for a defined ``elem_type``.
    """
    optional = OptionalProto()
    optional.name = name
    optional.elem_type = elem_type  # type: ignore[assignment]

    if elem_type == OptionalProto.UNDEFINED:
        return optional
    attribute: google.protobuf.message.Message | None = None
    if elem_type == OptionalProto.TENSOR:
        attribute = optional.tensor_value
    elif elem_type == OptionalProto.SPARSE_TENSOR:
        attribute = optional.sparse_tensor_value
    elif elem_type == OptionalProto.SEQUENCE:
        attribute = optional.sequence_value
    elif elem_type == OptionalProto.MAP:
        attribute = optional.map_value
    elif elem_type == OptionalProto.OPTIONAL:
        attribute = optional.optional_value
    else:
        raise TypeError("The element type in the input optional is not supported.")

    # Explicit check (not assert) so validation survives `python -O`.
    if value is None:
        raise ValueError(
            "make_optional requires a value unless elem_type is OptionalProto.UNDEFINED."
        )
    attribute.CopyFrom(value)  # type: ignore[arg-type]
    return optional

597 

598 

599def _to_bytes(value: str | bytes) -> bytes: 

600 """Coerce a string (or bytes) value into UTF-8 bytes.""" 

601 if isinstance(value, str): 

602 return value.encode("utf-8") 

603 return value 

604 

605 

def make_attribute(
    key: str,
    value: Any,
    doc_string: str | None = None,
    attr_type: int | None = None,
) -> AttributeProto:
    """Makes an AttributeProto based on the value type.

    Scalars map to singular attribute types: integral -> INT, real -> FLOAT,
    str/bytes -> STRING (UTF-8 encoded), TensorProto, SparseTensorProto,
    GraphProto and TypeProto. Any other iterable is materialized into a list.
    It then maps to the matching repeated type, inferred from the element
    types unless ``attr_type`` is given.

    Args:
        key: attribute name.
        value: attribute value.
        doc_string: optional documentation; only stored when non-empty.
        attr_type: optional expected ``AttributeProto`` type. It is required
            to build an attribute from an empty iterable; otherwise it is
            checked against the type derived from ``value``.

    Returns:
        AttributeProto

    Raises:
        ValueError: if the value is an empty iterable and no ``attr_type`` is
            given, or if no repeated type matches all elements.
        TypeError: if the value is not an accepted attribute value, or if it
            does not match ``attr_type``.
    """
    attr = AttributeProto()
    attr.name = key
    if doc_string:
        attr.doc_string = doc_string

    # Singular cases. Integral is tested before Real because every Integral is
    # also a Real. str/bytes are tested before the generic Iterable branch
    # because they are iterable themselves.
    if isinstance(value, numbers.Integral):
        attr.i = int(value)
        attr.type = AttributeProto.INT
    elif isinstance(value, numbers.Real):
        attr.f = float(value)
        attr.type = AttributeProto.FLOAT
    elif isinstance(value, (str, bytes)):
        # Encode strings into utf-8
        attr.s = _to_bytes(value)
        attr.type = AttributeProto.STRING
    elif isinstance(value, TensorProto):
        attr.t.CopyFrom(value)
        attr.type = AttributeProto.TENSOR
    elif isinstance(value, onnx.SparseTensorProto):
        attr.sparse_tensor.CopyFrom(value)
        attr.type = AttributeProto.SPARSE_TENSOR
    elif isinstance(value, GraphProto):
        attr.g.CopyFrom(value)
        attr.type = AttributeProto.GRAPH
    elif isinstance(value, TypeProto):
        attr.tp.CopyFrom(value)
        attr.type = AttributeProto.TYPE_PROTO
    # Iterable cases
    elif isinstance(value, collections.abc.Iterable):
        value = list(value)
        if len(value) == 0 and attr_type is None:
            raise ValueError(
                f"Could not infer attribute `{key}` type from empty iterator"
            )
        if attr_type is None:
            types = {type(v) for v in value}
            # Same ordering rationale as the singular cases above.
            for exp_t, exp_enum in (
                (numbers.Integral, AttributeProto.INTS),
                (numbers.Real, AttributeProto.FLOATS),
                ((str, bytes), AttributeProto.STRINGS),
                (TensorProto, AttributeProto.TENSORS),
                (onnx.SparseTensorProto, AttributeProto.SPARSE_TENSORS),
                (GraphProto, AttributeProto.GRAPHS),
                (TypeProto, AttributeProto.TYPE_PROTOS),
            ):
                if all(issubclass(t, exp_t) for t in types):
                    attr_type = exp_enum
                    break
            if attr_type is None:
                raise ValueError(
                    "Could not infer the attribute type from the elements of the passed Iterable value."
                )

        if attr_type == AttributeProto.INTS:
            attr.ints.extend(value)
            attr.type = AttributeProto.INTS
        elif attr_type == AttributeProto.FLOATS:
            attr.floats.extend(value)
            attr.type = AttributeProto.FLOATS
        elif attr_type == AttributeProto.STRINGS:
            attr.strings.extend(_to_bytes(v) for v in value)
            attr.type = AttributeProto.STRINGS
        elif attr_type == AttributeProto.TENSORS:
            attr.tensors.extend(value)
            attr.type = AttributeProto.TENSORS
        elif attr_type == AttributeProto.SPARSE_TENSORS:
            attr.sparse_tensors.extend(value)
            attr.type = AttributeProto.SPARSE_TENSORS
        elif attr_type == AttributeProto.GRAPHS:
            attr.graphs.extend(value)
            attr.type = AttributeProto.GRAPHS
        elif attr_type == AttributeProto.TYPE_PROTOS:
            attr.type_protos.extend(value)
            attr.type = AttributeProto.TYPE_PROTOS
        else:
            # Inference always yields a repeated type, so this is reached only
            # when the caller passed a non-repeated attr_type (e.g. INT)
            # together with an iterable value.
            raise TypeError(
                f"Specified attribute type '{_attr_type_to_str(attr_type)}'({attr_type}) "
                f"is not a repeated type, but the value of attribute `{key}` is an iterable."
            )
    else:
        raise TypeError(f"'{value}' is not an accepted attribute value.")

    if attr_type is not None and attr.type != attr_type:
        raise TypeError(
            f"Inferred attribute type '{_attr_type_to_str(attr.type)}'({attr.type}) mismatched with specified type '{_attr_type_to_str(attr_type)}'({attr_type})"
        )
    return attr

698 

699 

def make_attribute_ref(
    name: str,
    attr_type: AttributeProto.AttributeType,
    doc_string: str | None = None,
    *,
    ref_attr_name: str | None = None,
) -> AttributeProto:
    """Make an AttributeProto holding a reference to the parent function's attribute.

    The result stores no value. Instead, ``ref_attr_name`` names the parent
    function's attribute that supplies the value at instantiation time. When
    ``ref_attr_name`` is omitted, ``name`` is used. Reference attributes are
    only valid inside a function (sub-graph).

    Args:
        name: The name of this attribute as used inside the function body.
        attr_type: The type of the attribute.
        doc_string: Optional human-readable documentation for the attribute;
            only stored when non-empty.
        ref_attr_name: The name of the parent function's attribute being referenced.

    Raises:
        ValueError: if the effective referenced name is empty.
    """
    referenced = name if ref_attr_name is None else ref_attr_name
    if not referenced:
        raise ValueError("ref_attr_name must be non-empty")

    attr = AttributeProto()
    attr.name = name
    attr.type = attr_type  # type: ignore[assignment]
    attr.ref_attr_name = referenced
    if doc_string:
        attr.doc_string = doc_string
    return attr

732 

733 

734def get_attribute_value(attr: AttributeProto) -> Any: # noqa: PLR0911 

735 if attr.ref_attr_name: 

736 raise ValueError(f"Cannot get value of reference attribute: {attr}") 

737 if attr.type == AttributeProto.FLOAT: 

738 return attr.f 

739 if attr.type == AttributeProto.INT: 

740 return attr.i 

741 if attr.type == AttributeProto.STRING: 

742 return attr.s 

743 if attr.type == AttributeProto.TENSOR: 

744 return attr.t 

745 if attr.type == AttributeProto.SPARSE_TENSOR: 

746 return attr.sparse_tensor 

747 if attr.type == AttributeProto.GRAPH: 

748 return attr.g 

749 if attr.type == AttributeProto.TYPE_PROTO: 

750 return attr.tp 

751 if attr.type == AttributeProto.FLOATS: 

752 return list(attr.floats) 

753 if attr.type == AttributeProto.INTS: 

754 return list(attr.ints) 

755 if attr.type == AttributeProto.STRINGS: 

756 return list(attr.strings) 

757 if attr.type == AttributeProto.TENSORS: 

758 return list(attr.tensors) 

759 if attr.type == AttributeProto.SPARSE_TENSORS: 

760 return list(attr.sparse_tensors) 

761 if attr.type == AttributeProto.GRAPHS: 

762 return list(attr.graphs) 

763 if attr.type == AttributeProto.TYPE_PROTOS: 

764 return list(attr.type_protos) 

765 if attr.type == AttributeProto.UNDEFINED: 

766 return None 

767 raise ValueError(f"Unsupported ONNX attribute: {attr}") 

768 

769 

def get_node_attr_value(node: NodeProto, attr_name: str) -> Any:
    """Return the value of the single attribute of ``node`` named ``attr_name``.

    Raises:
        ValueError: if no attribute, or more than one attribute, has that name.
    """
    candidates = [attribute for attribute in node.attribute if attribute.name == attr_name]
    if not candidates:
        raise ValueError(f"Node has no attribute with name {attr_name}")
    if len(candidates) > 1:
        raise ValueError(f"Node has multiple attributes with name {attr_name}")
    (only,) = candidates
    return get_attribute_value(only)

777 

778 

def make_empty_tensor_value_info(name: str) -> ValueInfoProto:
    """Return a ValueInfoProto with only ``name`` set (no type information)."""
    info = ValueInfoProto()
    info.name = name
    return info

783 

784 

def make_tensor_type_proto(
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    shape_denotation: list[str] | None = None,
) -> TypeProto:
    """Makes a Tensor TypeProto based on the data type and shape.

    Args:
        elem_type: TensorProto data type of the elements.
        shape: ``None`` leaves the shape field unset (``shape_denotation`` is
            then ignored). Otherwise it gives one entry per dimension: an
            integer (Python or NumPy) sets ``dim_value``, a str sets
            ``dim_param``, and ``None`` leaves that dimension's value unset.
        shape_denotation: optional per-dimension denotations; must have the
            same length as ``shape`` when non-empty.

    Returns:
        TypeProto with ``tensor_type`` populated.

    Raises:
        ValueError: if ``shape_denotation`` and ``shape`` lengths differ, or if
            a shape entry is not None, an integer, or a str.
    """
    type_proto = TypeProto()
    tensor_type_proto = type_proto.tensor_type
    tensor_type_proto.elem_type = elem_type
    tensor_shape_proto = tensor_type_proto.shape

    if shape is not None:
        # You might think this is a no-op (extending a normal Python
        # list by [] certainly is), but protobuf lists work a little
        # differently; if a field is never set, it is omitted from the
        # resulting protobuf; a list that is explicitly set to be
        # empty will get an (empty) entry in the protobuf. This
        # difference is visible to our consumers, so make sure we emit
        # an empty shape!
        tensor_shape_proto.dim.extend([])

        if shape_denotation and len(shape_denotation) != len(shape):
            raise ValueError(
                "Invalid shape_denotation. Must be of the same length as shape."
            )

        for i, d in enumerate(shape):
            dim = tensor_shape_proto.dim.add()
            if d is None:
                pass
            elif isinstance(d, numbers.Integral):
                # Accept NumPy integer scalars (e.g. elements of an int array)
                # as well as Python ints; normalize to a plain int.
                dim.dim_value = int(d)
            elif isinstance(d, str):
                dim.dim_param = d
            else:
                raise ValueError(
                    f"Invalid item in shape: {d}. Needs to be of int or str."
                )

            if shape_denotation:
                dim.denotation = shape_denotation[i]

    return type_proto

828 

829 

def make_tensor_value_info(
    name: str,
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    doc_string: str = "",
    shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
    """Makes a ValueInfoProto based on the data type and shape.

    The tensor type is built by :func:`make_tensor_type_proto` and wrapped by
    :func:`make_value_info`. ``doc_string`` is only stored when non-empty.
    """
    return make_value_info(
        name,
        make_tensor_type_proto(elem_type, shape, shape_denotation),
        doc_string,
    )

846 

847 

def make_sparse_tensor_type_proto(
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    shape_denotation: list[str] | None = None,
) -> TypeProto:
    """Makes a SparseTensor TypeProto based on the data type and shape.

    Args:
        elem_type: TensorProto data type of the elements.
        shape: ``None`` leaves the shape field unset (``shape_denotation`` is
            then ignored). Otherwise it gives one entry per dimension: an
            integer (Python or NumPy) sets ``dim_value``, a str sets
            ``dim_param``, and ``None`` leaves that dimension's value unset.
        shape_denotation: optional per-dimension denotations; must have the
            same length as ``shape`` when non-empty.

    Returns:
        TypeProto with ``sparse_tensor_type`` populated.

    Raises:
        ValueError: if ``shape_denotation`` and ``shape`` lengths differ, or if
            a shape entry is not None, an integer, or a str.
    """
    type_proto = TypeProto()
    sparse_tensor_type_proto = type_proto.sparse_tensor_type
    sparse_tensor_type_proto.elem_type = elem_type
    sparse_tensor_shape_proto = sparse_tensor_type_proto.shape

    if shape is not None:
        # You might think this is a no-op (extending a normal Python
        # list by [] certainly is), but protobuf lists work a little
        # differently; if a field is never set, it is omitted from the
        # resulting protobuf; a list that is explicitly set to be
        # empty will get an (empty) entry in the protobuf. This
        # difference is visible to our consumers, so make sure we emit
        # an empty shape!
        sparse_tensor_shape_proto.dim.extend([])

        if shape_denotation and len(shape_denotation) != len(shape):
            raise ValueError(
                "Invalid shape_denotation. Must be of the same length as shape."
            )

        for i, d in enumerate(shape):
            dim = sparse_tensor_shape_proto.dim.add()
            if d is None:
                pass
            elif isinstance(d, numbers.Integral):
                # Accept NumPy integer scalars (e.g. elements of an int array)
                # as well as Python ints; normalize to a plain int.
                dim.dim_value = int(d)
            elif isinstance(d, str):
                dim.dim_param = d
            else:
                raise ValueError(
                    f"Invalid item in shape: {d}. Needs to be of int or text."
                )

            if shape_denotation:
                dim.denotation = shape_denotation[i]

    return type_proto

891 

892 

def make_sparse_tensor_value_info(
    name: str,
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    doc_string: str = "",
    shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
    """Makes a SparseTensor ValueInfoProto based on the data type and shape.

    Only the ``sparse_tensor_type`` sub-message produced by
    :func:`make_sparse_tensor_type_proto` is copied into the result's type.
    ``doc_string`` is only stored when non-empty.
    """
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string

    sparse_type = make_sparse_tensor_type_proto(
        elem_type, shape, shape_denotation
    ).sparse_tensor_type
    info.type.sparse_tensor_type.CopyFrom(sparse_type)
    return info

913 

914 

def make_sequence_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes a sequence TypeProto whose element type is a copy of ``inner_type_proto``."""
    result = TypeProto()
    result.sequence_type.elem_type.CopyFrom(inner_type_proto)
    return result

922 

923 

def make_optional_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes an optional TypeProto whose element type is a copy of ``inner_type_proto``."""
    result = TypeProto()
    result.optional_type.elem_type.CopyFrom(inner_type_proto)
    return result

931 

932 

def make_map_type_proto(
    key_type: int,
    value_type: TypeProto,
) -> TypeProto:
    """Makes a map TypeProto from a key element type and a value TypeProto (copied)."""
    result = TypeProto()
    map_type = result.map_type
    map_type.key_type = key_type
    map_type.value_type.CopyFrom(value_type)
    return result

942 

943 

def make_value_info(
    name: str,
    type_proto: TypeProto,
    doc_string: str = "",
) -> ValueInfoProto:
    """Create a ValueInfoProto named *name* carrying a copy of *type_proto*.

    Args:
        name: name stored on the ValueInfoProto.
        type_proto: type copied into ``type``.
        doc_string: optional documentation; written only when non-empty.

    Returns:
        The populated ValueInfoProto.
    """
    result = ValueInfoProto()
    result.name = name
    if doc_string:
        result.doc_string = doc_string
    result.type.CopyFrom(type_proto)
    return result

957 

958 

959def _sanitize_str(s: str | bytes) -> str: 

960 if isinstance(s, str): 

961 sanitized = s 

962 elif isinstance(s, bytes): 

963 sanitized = s.decode("utf-8", errors="ignore") 

964 else: 

965 sanitized = str(s) 

966 if len(sanitized) < 64: # noqa: PLR2004 

967 return sanitized 

968 return sanitized[:64] + f"...<+len={(len(sanitized) - 64)}>" 

969 

970 

def make_tensor_sequence_value_info(
    name: str,
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    doc_string: str = "",
    elem_shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
    """Build a ValueInfoProto typed as a sequence of tensors.

    Args:
        name: name stored on the ValueInfoProto.
        elem_type: element data type of every tensor in the sequence.
        shape: shape of every tensor in the sequence, forwarded to
            :func:`make_tensor_type_proto`.
        doc_string: optional documentation; written only when non-empty.
        elem_shape_denotation: optional per-dimension denotations for the
            element tensors.

    Returns:
        The populated ValueInfoProto.
    """
    element_type = make_tensor_type_proto(elem_type, shape, elem_shape_denotation)
    # make_sequence_type_proto sets nothing but sequence_type, so copying the
    # whole TypeProto via make_value_info yields the same message.
    return make_value_info(name, make_sequence_type_proto(element_type), doc_string)

989 

990 

def printable_attribute(
    attr: AttributeProto, subgraphs: bool = False
) -> str | tuple[str, list[GraphProto]]:
    """Format one AttributeProto as ``<name> = <value>`` for display.

    The populated field is found by probing the singular fields with
    ``HasField`` and the repeated fields by non-emptiness, in a fixed
    priority order. Graph-valued attributes are shown only by name; the
    graphs themselves are collected so the caller can print them later.

    Args:
        attr: the attribute to format.
        subgraphs: when True, also return the graphs referenced by *attr*.

    Returns:
        The formatted text, or ``(text, graphs)`` when *subgraphs* is True.
    """

    def fmt_float(value: float) -> str:
        # Pin the precision so the text is identical across Python versions.
        return f"{value:.15g}"

    def fmt_spaced_list(items: list[str]) -> str:
        # Bracketed list with a space inside each bracket: "[ a, b ]".
        return "[ " + ", ".join(items) + " ]"

    # NOTE: the HasField() probes rely on the singular fields having explicit
    # presence (proto2-style optional). If that ever changes, dispatch on
    # attr.type instead.
    nested: list[GraphProto] = []
    if attr.HasField("f"):
        value = fmt_float(attr.f)
    elif attr.HasField("i"):
        value = str(attr.i)
    elif attr.HasField("s"):
        value = repr(_sanitize_str(attr.s))
    elif attr.HasField("t"):
        if len(attr.t.dims) > 0:
            value = "<Tensor>"
        else:
            # A tensor without dims is a scalar: show the stored field value.
            field = tensor_dtype_to_field(attr.t.data_type)
            value = f"<Scalar Tensor {getattr(attr.t, field)}>"
    elif attr.HasField("g"):
        value = f"<graph {attr.g.name}>"
        nested.append(attr.g)
    elif attr.HasField("tp"):
        value = f"<Type Proto {attr.tp}>"
    elif attr.HasField("sparse_tensor"):
        value = "<Sparse Tensor>"
    elif attr.floats:
        value = "[" + ", ".join(fmt_float(x) for x in attr.floats) + "]"
    elif attr.ints:
        value = "[" + ", ".join(str(x) for x in attr.ints) + "]"
    elif attr.strings:
        value = str([_sanitize_str(x) for x in attr.strings])
    elif attr.tensors:
        value = "[<Tensor>, ...]"
    elif attr.sparse_tensors:
        value = "[<Sparse Tensor>, ...]"
    elif attr.type_protos:
        value = fmt_spaced_list([f"<Type Proto {tp}>" for tp in attr.type_protos])
    elif attr.graphs:
        value = fmt_spaced_list([f"<graph {g.name}>" for g in attr.graphs])
        nested.extend(attr.graphs)
    else:
        value = "<Unknown>"

    text = f"{attr.name} = {value}"
    return (text, nested) if subgraphs else text

1069 

1070 

def printable_dim(dim: TensorShapeProto.Dimension) -> str:
    """Render one shape dimension: its value or parameter name, else ``"?"``.

    Args:
        dim: the dimension to render.

    Returns:
        ``str`` of whichever member of the ``value`` oneof is set, or ``"?"``
        when neither is set.
    """
    kind = dim.WhichOneof("value")
    return "?" if kind is None else str(getattr(dim, kind))

1076 

1077 

def printable_type(t: TypeProto) -> str:
    """Render a TypeProto for display.

    Tensor types become ``"<ELEM_TYPE>"``. When a shape is present, they
    become ``"<ELEM_TYPE>, d0xd1x..."``, or ``"<ELEM_TYPE>, scalar"`` for a
    rank-0 shape. An unset type renders as ``""``; any other kind renders
    as ``"Unknown type <kind>"``.

    Args:
        t: the type to render.

    Returns:
        The rendered text.
    """
    kind = t.WhichOneof("value")
    if kind is None:
        return ""
    if kind != "tensor_type":
        return f"Unknown type {kind}"

    tensor_type = t.tensor_type
    text: str = TensorProto.DataType.Name(tensor_type.elem_type)  # type: ignore[arg-type]
    if tensor_type.HasField("shape"):
        dims = tensor_type.shape.dim
        if len(dims):
            text += ", " + "x".join(printable_dim(d) for d in dims)
        else:
            text += ", scalar"
    return text

1090 

1091 

def printable_value_info(v: ValueInfoProto) -> str:
    """Render a ValueInfoProto as ``%name[<type>]``.

    The bracketed part comes from :func:`printable_type`. It is appended
    whenever ``v.type`` is truthy; otherwise only ``%name`` is returned.

    Args:
        v: the value info to render.

    Returns:
        The rendered text.
    """
    label = f"%{v.name}"
    if not v.type:
        return label
    return f"{label}[{printable_type(v.type)}]"

1097 

1098 

def printable_tensor_proto(t: TensorProto) -> str:
    """Render a TensorProto header as ``%name[DTYPE, d0xd1...]``.

    A tensor with no dims is shown as ``%name[DTYPE, scalar]``. The tensor
    data itself is not printed.

    Args:
        t: the tensor to render.

    Returns:
        The rendered text.
    """
    parts = [f"%{t.name}[", TensorProto.DataType.Name(t.data_type)]  # type: ignore[arg-type]
    if t.dims is not None:
        parts.append(", " + "x".join(map(str, t.dims)) if len(t.dims) else ", scalar")
    parts.append("]")
    return "".join(parts)

1109 

1110 

def printable_node(
    node: NodeProto, prefix: str = "", subgraphs: bool = False
) -> str | tuple[str, list[GraphProto]]:
    """Render a NodeProto as ``%out, ... = Op[attrs](%in, ...)``.

    Attributes are formatted with :func:`printable_attribute` and sorted
    alphabetically. The ``[...]`` part is omitted when the node has no
    attributes, and the ``%out = `` part is omitted when it has no outputs.

    Args:
        node: the node to render.
        prefix: text prepended to the line (used for indentation).
        subgraphs: when True, also return the graphs referenced by the
            node's attributes.

    Returns:
        The rendered line, or ``(line, graphs)`` when *subgraphs* is True.

    Raises:
        TypeError: if :func:`printable_attribute` returns an unexpected shape.
    """
    nested: list[GraphProto] = []
    rendered: list[str] = []
    for attribute in node.attribute:
        if not subgraphs:
            text = printable_attribute(attribute)
            if not isinstance(text, str):
                raise TypeError(f"printed must be an instance of {str}.")
            rendered.append(text)
            continue
        pair = printable_attribute(attribute, subgraphs)
        if not isinstance(pair[1], list):
            raise TypeError(
                f"printed_attr_subgraphs[1] must be an instance of {list}."
            )
        nested.extend(pair[1])
        rendered.append(pair[0])

    args = ", ".join(f"%{name}" for name in node.input)
    if node.attribute:
        call = f"{node.op_type}[{', '.join(sorted(rendered))}]({args})"
    else:
        call = f"{node.op_type}({args})"

    pieces: list[str] = []
    if node.output:
        pieces.append(", ".join(f"%{name}" for name in node.output))
        pieces.append("=")
    pieces.append(call)
    line = prefix + " ".join(pieces)
    return (line, nested) if subgraphs else line

1144 

1145 

@typing_extensions.deprecated(
    "Deprecated since 1.19. Consider using onnx.printer.to_text() instead."
)
def printable_graph(graph: GraphProto, prefix: str = "") -> str:
    """Display a GraphProto as a string.

    .. deprecated:: 1.19
        Consider using :func:`onnx.printer.to_text` instead.

    The header lists, in order:

    * inputs without an initializer;
    * inputs backed by an initializer;
    * initializers with no matching input.

    The body has one line per node, followed by the ``return`` line. Graphs
    referenced by node attributes are printed after the closing brace.

    Args:
        graph (GraphProto): the graph to display
        prefix (string): prefix of every line

    Returns:
        string
    """
    lines: list[str] = []
    node_prefix = prefix + " "
    init_names = {tensor.name for tensor in graph.initializer}

    def emit_section(words: list[str], entries: list[str]) -> list[str]:
        # Flush the pending header words as one line, then one prefixed line
        # per entry. The returned ")" starts the next header line.
        lines.append(prefix + " ".join(words))
        lines.extend(prefix + " " + entry for entry in entries)
        return [")"]

    pending = ["graph", graph.name]
    if graph.input:
        pending.append("(")
        required = [
            printable_value_info(vi) for vi in graph.input if vi.name not in init_names
        ]
        # Optional inputs: an initializer provides their default value.
        defaulted = [
            printable_value_info(vi) for vi in graph.input if vi.name in init_names
        ]
        if required:
            pending = emit_section(pending, required)
        if defaulted:
            pending.append("optional inputs with matching initializers (")
            pending = emit_section(pending, defaulted)
        # From IR 4 onwards an initializer need not have a matching graph
        # input, so list the remaining ones (name, type and shape) as well.
        if len(defaulted) < len(init_names):
            input_names = {vi.name for vi in graph.input}
            standalone = [
                printable_tensor_proto(tensor)
                for tensor in graph.initializer
                if tensor.name not in input_names
            ]
            pending.append("initializers (")
            pending = emit_section(pending, standalone)
    pending.append("{")
    lines.append(prefix + " ".join(pending))

    nested: list[GraphProto] = []
    for node in graph.node:
        rendered = printable_node(node, node_prefix, subgraphs=True)
        if not isinstance(rendered[1], list):
            raise TypeError(f"contents_subgraphs[1] must be an instance of {list}.")
        lines.append(rendered[0])
        nested.extend(rendered[1])

    returned = ", ".join(f"%{out.name}" for out in graph.output)
    lines.append(node_prefix + ("return " + returned if graph.output else "return"))
    lines.append(prefix + "}")
    lines.extend("\n" + printable_graph(sub) for sub in nested)
    return "\n".join(lines)

1227 

1228 

def strip_doc_string(proto: google.protobuf.message.Message) -> None:
    """Empties `doc_string` field on any nested protobuf messages

    Walks every field of *proto*. Any field named ``doc_string`` is cleared.
    Message-typed fields are recursed into: each element of a repeated
    field, or a singular field only when it is set. The message is modified
    in place.

    Args:
        proto: the message to clean.

    Raises:
        TypeError: if *proto* is not a protobuf message.
    """
    if not isinstance(proto, google.protobuf.message.Message):
        raise TypeError(
            f"proto must be an instance of {google.protobuf.message.Message}."
        )
    for field in proto.DESCRIPTOR.fields:
        field_name = field.name
        if field_name == "doc_string":
            proto.ClearField(field_name)
            continue
        if field.type != field.TYPE_MESSAGE:
            continue
        if field.label == field.LABEL_REPEATED:
            for child in getattr(proto, field_name):
                strip_doc_string(child)
        elif proto.HasField(field_name):
            strip_doc_string(getattr(proto, field_name))

1244 

1245 

def make_training_info(
    algorithm: GraphProto,
    algorithm_bindings: AssignmentBindingType,
    initialization: GraphProto | None,
    initialization_bindings: AssignmentBindingType | None,
) -> TrainingInfoProto:
    """Construct a TrainingInfoProto.

    Args:
        algorithm: graph copied into ``algorithm``.
        algorithm_bindings: ``(key, value)`` pairs appended, in order, to
            ``update_binding``.
        initialization: optional graph copied into ``initialization``; pass
            ``None`` to leave it unset.
        initialization_bindings: optional ``(key, value)`` pairs appended, in
            order, to ``initialization_binding``. They are only used when
            *initialization* is given.

    Returns:
        The new TrainingInfoProto.
    """
    training_info = TrainingInfoProto()
    training_info.algorithm.CopyFrom(algorithm)
    for key, value in algorithm_bindings:
        binding = training_info.update_binding.add()
        binding.key = key
        binding.value = value

    # Explicit None check: "was an initialization graph provided" must not
    # depend on the truthiness of a protobuf message object.
    if initialization is not None:
        training_info.initialization.CopyFrom(initialization)
        for key, value in initialization_bindings or ():
            binding = training_info.initialization_binding.add()
            binding.key = key
            binding.value = value

    return training_info

1268 

1269 

1270# Following functions are used for mapping 

def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
    """Convert a TensorProto's data_type to corresponding numpy dtype. It can be used while making tensor.

    Args:
        tensor_dtype: TensorProto's data_type

    Returns:
        numpy's data_type, as recorded in ``_mapping.TENSOR_TYPE_MAP``
    """
    entry = _mapping.TENSOR_TYPE_MAP[tensor_dtype]
    return entry.np_dtype

1281 

1282 

def tensor_dtype_to_storage_tensor_dtype(tensor_dtype: int) -> int:
    """Convert a TensorProto's data_type to corresponding data_type for storage.

    Args:
        tensor_dtype: TensorProto's data_type

    Returns:
        data_type for storage, as recorded in ``_mapping.TENSOR_TYPE_MAP``
    """
    entry = _mapping.TENSOR_TYPE_MAP[tensor_dtype]
    return entry.storage_dtype

1293 

1294 

def tensor_dtype_to_string(tensor_dtype: int) -> str:
    """Get the name of given TensorProto's data_type.

    Args:
        tensor_dtype: TensorProto's data_type

    Returns:
        the name of data_type, as recorded in ``_mapping.TENSOR_TYPE_MAP``
    """
    entry = _mapping.TENSOR_TYPE_MAP[tensor_dtype]
    return entry.name

1305 

1306 

@functools.lru_cache(None)
def tensor_dtype_to_field(tensor_dtype: int) -> str:
    """Convert a TensorProto's data_type to corresponding field name for storage. It can be used while making tensors.

    The data type is first mapped to its storage data type through
    ``_mapping.TENSOR_TYPE_MAP``. That storage type then selects the
    TensorProto field. Results are memoized.

    Args:
        tensor_dtype: TensorProto's data_type

    Returns:
        field name, e.g. ``"float_data"``

    Raises:
        KeyError: if either lookup has no entry (presumably for data types
            whose storage type is not listed below -- confirm).
    """
    field_by_storage_dtype = {
        int(TensorProto.FLOAT): "float_data",
        int(TensorProto.INT32): "int32_data",
        int(TensorProto.INT64): "int64_data",
        int(TensorProto.DOUBLE): "double_data",
        # UINT32 deliberately shares the uint64_data field with UINT64.
        int(TensorProto.UINT32): "uint64_data",
        int(TensorProto.UINT64): "uint64_data",
        int(TensorProto.STRING): "string_data",
    }
    storage_dtype = _mapping.TENSOR_TYPE_MAP[tensor_dtype].storage_dtype
    return field_by_storage_dtype[storage_dtype]

1329 

1330 

@functools.lru_cache(None)
def np_dtype_to_tensor_dtype(np_dtype: np.dtype) -> TensorProto.DataType:
    """Convert a numpy's dtype to corresponding tensor type. It can be used while converting numpy arrays to tensors.

    Besides ``np.dtype`` instances, any non-``None`` value accepted by
    ``np.dtype()`` is normalized first. Examples are the scalar type
    ``np.float32`` and the string ``"int64"``. Any numpy string dtype maps
    to ``TensorProto.STRING``. Results are memoized per input.

    Args:
        np_dtype: numpy's data_type

    Returns:
        TensorsProto's data_type

    Raises:
        ValueError: if no TensorProto element type corresponds to *np_dtype*.
        TypeError: if numpy does not understand *np_dtype* as a dtype.
    """
    if np_dtype is not None and not isinstance(np_dtype, np.dtype):
        # Scalar types such as np.float32 compare equal to np.dtype("float32")
        # but hash differently, so without normalization they never match a
        # key of the lookup table below.
        np_dtype = np.dtype(np_dtype)

    _np_dtype_to_tensor_dtype = {
        v.np_dtype: k for k, v in _mapping.TENSOR_TYPE_MAP.items()
    }
    if np_dtype in _np_dtype_to_tensor_dtype:
        return typing.cast("TensorProto.DataType", _np_dtype_to_tensor_dtype[np_dtype])
    if np.issubdtype(np_dtype, np.str_):
        return TensorProto.STRING  # type: ignore[return-value]

    raise ValueError(
        f"Unable to convert type {np_dtype!r} into TensorProto element type."
    )

1352 

1353 

def get_all_tensor_dtypes() -> KeysView[int]:
    """Get all tensor types from TensorProto.

    Returns:
        all tensor types from TensorProto, as a live keys view of
        ``_mapping.TENSOR_TYPE_MAP``
    """
    type_map = _mapping.TENSOR_TYPE_MAP
    return type_map.keys()

1361 

1362 

# Reverse lookup for the AttributeType enum: numeric code -> enum name.
_ATTRIBUTE_TYPE_TO_STR: dict[int, str] = {
    code: type_name for type_name, code in AttributeProto.AttributeType.items()
}

1366 

1367 

def _attr_type_to_str(attr_type: int) -> str:
    """Convert AttributeProto type to string.

    Args:
        attr_type: AttributeProto type.

    Returns:
        The enum name for *attr_type*. For a code that is not a member of
        ``AttributeProto.AttributeType``, the enum's first (index 0) name is
        returned instead.
    """
    known_codes = AttributeProto.AttributeType.values()
    if attr_type not in known_codes:
        return AttributeProto.AttributeType.keys()[0]
    return _ATTRIBUTE_TYPE_TO_STR[attr_type]