Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/helper.py: 26%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright (c) ONNX Project Contributors
2#
3# SPDX-License-Identifier: Apache-2.0
4from __future__ import annotations
6import collections.abc
7import functools
8import math
9import numbers
10import typing
11from typing import TYPE_CHECKING, Any, TypeVar
13import google.protobuf.message
14import numpy as np
15import typing_extensions
17import onnx
18from onnx import _mapping, defs
19from onnx.onnx_data_pb import MapProto, OptionalProto, SequenceProto
20from onnx.onnx_pb import (
21 AttributeProto,
22 FunctionProto,
23 GraphProto,
24 ModelProto,
25 NodeProto,
26 OperatorSetIdProto,
27 TensorProto,
28 TensorShapeProto,
29 TrainingInfoProto,
30 TypeProto,
31 ValueInfoProto,
32)
34if TYPE_CHECKING:
35 from collections.abc import Callable, KeysView, Sequence
37 from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
# Shape of one VERSION_TABLE row: (release version, IR version, ai.onnx opset
# version, ai.onnx.ml opset version), optionally followed by the
# ai.onnx.training opset version (present in the rows from 1.7.0 onward).
VersionRowType = tuple[str, int, int, int] | tuple[str, int, int, int, int]
VersionTableType = list[VersionRowType]
# NOTE(review): not referenced in this part of the file; presumably a list of
# (name, name) binding pairs -- confirm against its users.
AssignmentBindingType = list[tuple[str, str]]

# This is a copy of the documented version in https://github.com/onnx/onnx/blob/main/docs/Versioning.md#released-versions
# Both must be updated whenever a new version of ONNX is released.
# Rows are kept in release order: _create_op_set_id_version_map keeps the IR
# version of the first row that mentions a given (domain, opset version).
VERSION_TABLE: VersionTableType = [
    # Release-version, IR version, ai.onnx version, ai.onnx.ml version, (optional) ai.onnx.training version
    ("1.0", 3, 1, 1),
    ("1.1", 3, 5, 1),
    ("1.1.2", 3, 6, 1),
    ("1.2", 3, 7, 1),
    ("1.3", 3, 8, 1),
    ("1.4.1", 4, 9, 1),
    ("1.5.0", 5, 10, 1),
    ("1.6.0", 6, 11, 2),
    ("1.7.0", 7, 12, 2, 1),
    ("1.8.0", 7, 13, 2, 1),
    ("1.8.1", 7, 13, 2, 1),
    ("1.9.0", 7, 14, 2, 1),
    ("1.10.0", 8, 15, 2, 1),
    ("1.10.1", 8, 15, 2, 1),
    ("1.10.2", 8, 15, 2, 1),
    ("1.11.0", 8, 16, 3, 1),
    ("1.12.0", 8, 17, 3, 1),
    ("1.13.0", 8, 18, 3, 1),
    ("1.13.1", 8, 18, 3, 1),
    ("1.14.0", 9, 19, 3, 1),
    ("1.14.1", 9, 19, 3, 1),
    ("1.15.0", 9, 20, 4, 1),
    ("1.16.0", 10, 21, 5, 1),
    ("1.16.1", 10, 21, 5, 1),
    ("1.16.2", 10, 21, 5, 1),
    ("1.17.0", 10, 22, 5, 1),
    ("1.18.0", 11, 23, 5, 1),
    ("1.19.0", 12, 24, 5, 1),
    ("1.19.1", 12, 24, 5, 1),
    ("1.20.0", 13, 25, 5, 1),
    ("1.20.1", 13, 25, 5, 1),
    ("1.21.0", 13, 26, 5, 1),
]

# Maps (opset domain, opset version) to an IR version.
VersionMapType = dict[tuple[str, int], int]
84def _create_op_set_id_version_map(table: VersionTableType) -> VersionMapType:
85 """Create a map from (opset-domain, opset-version) to ir-version from above table."""
86 result: VersionMapType = {}
87 for row in table:
88 ir_version = row[1]
89 for pair in zip(
90 ["ai.onnx", "ai.onnx.ml", "ai.onnx.training"],
91 row[2:],
92 strict=False,
93 ):
94 if pair not in result:
95 result[pair] = ir_version
96 if pair[0] == "ai.onnx":
97 result["ai.onnx.preview", pair[1]] = ir_version
98 if pair[0] == "ai.onnx.training":
99 result["ai.onnx.preview.training", pair[1]] = ir_version
100 return result
# Lookup from (opset domain, opset version) to IR version, built once at import
# time from VERSION_TABLE; consulted by find_min_ir_version_for.
OP_SET_ID_VERSION_MAP = _create_op_set_id_version_map(VERSION_TABLE)
def find_min_ir_version_for(
    opsetidlist: Sequence[OperatorSetIdProto], ignore_unknown: bool = False
) -> int:
    """Given list of opset ids, determine minimum IR version required.

    Args:
        opsetidlist: A sequence (or any iterable) of OperatorSetIdProto. An
            empty ``domain`` is treated as ``"ai.onnx"``.
        ignore_unknown: If True, an unknown (domain, version) pair contributes
            the default minimum IR version (3) instead of raising.

    Returns:
        The minimum IR version required (integer): the largest IR version
        required by any listed opset, or 3 if no opsets are given.

    Raises:
        ValueError: If a (domain, version) pair is not in
            ``OP_SET_ID_VERSION_MAP`` and ``ignore_unknown`` is False.
    """
    default_min_version = 3

    def find_min(domain: str | None, version: int) -> int:
        key = (domain or "ai.onnx", version)
        if key in OP_SET_ID_VERSION_MAP:
            return OP_SET_ID_VERSION_MAP[key]
        if ignore_unknown:
            return default_min_version
        # Name the offending opset so callers can tell which import is unknown.
        raise ValueError(
            f"Unsupported opset-version: domain={key[0]!r}, version={version}."
        )

    # `default=` covers an empty input, including empty non-sequence iterables
    # for which a truthiness test would not work.
    return max(
        (find_min(x.domain, x.version) for x in opsetidlist),
        default=default_min_version,
    )
def make_node(
    op_type: str,
    inputs: Sequence[str],
    outputs: Sequence[str],
    name: str | None = None,
    doc_string: str | None = None,
    domain: str | None = None,
    overload: str | None = None,
    **kwargs: Any,
) -> NodeProto:
    """Construct a NodeProto.

    Args:
        op_type (string): The name of the operator to construct.
        inputs (list of string): Names of the node's inputs, in order.
        outputs (list of string): Names of the node's outputs, in order.
        name (string, default None): Optional unique identifier for the node;
            left unset when None or empty.
        doc_string (string, default None): Optional documentation string;
            left unset when None or empty.
        domain (string, default None): Optional operator domain. When None the
            field keeps its default (empty) value, i.e. the default domain.
        overload (string, default None): Optional field used to resolve calls
            to model-local functions; left unset when None.
        **kwargs (dict): The node attributes, converted with
            :func:`make_attribute` in sorted key order. Entries whose value is
            None are skipped.

    Returns:
        NodeProto
    """
    node = NodeProto()
    node.op_type = op_type
    node.input.extend(inputs)
    node.output.extend(outputs)
    # Empty strings are treated like None for name and doc_string.
    if name:
        node.name = name
    if doc_string:
        node.doc_string = doc_string
    # domain/overload distinguish None from "" on purpose.
    if domain is not None:
        node.domain = domain
    if overload is not None:
        node.overload = overload
    attributes = [
        make_attribute(attr_name, attr_value)
        for attr_name, attr_value in sorted(kwargs.items())
        if attr_value is not None
    ]
    node.attribute.extend(attributes)
    return node
def make_operatorsetid(
    domain: str,
    version: int,
) -> OperatorSetIdProto:
    """Construct an OperatorSetIdProto.

    Args:
        domain (string): The domain of the operator set id.
        version (integer): Version of the operator set id.

    Returns:
        OperatorSetIdProto with ``domain`` and ``version`` set.
    """
    opset = OperatorSetIdProto()
    # Assigned left to right: domain first, then version.
    opset.domain, opset.version = domain, version
    return opset
def make_graph(
    nodes: Sequence[NodeProto],
    name: str,
    inputs: Sequence[ValueInfoProto],
    outputs: Sequence[ValueInfoProto],
    initializer: Sequence[TensorProto] | None = None,
    doc_string: str | None = None,
    value_info: Sequence[ValueInfoProto] | None = None,
    sparse_initializer: Sequence[onnx.SparseTensorProto] | None = None,
) -> GraphProto:
    """Construct a GraphProto.

    Args:
        nodes: NodeProto entries, in order.
        name (string): Graph name.
        inputs: ValueInfoProto entries describing the graph inputs.
        outputs: ValueInfoProto entries describing the graph outputs.
        initializer: Optional TensorProto initializers (None means none).
        doc_string (string): Optional graph documentation; left unset when
            None or empty.
        value_info: Optional ValueInfoProto entries (None means none).
        sparse_initializer: Optional onnx.SparseTensorProto initializers
            (None means none).

    Returns:
        GraphProto
    """
    graph = GraphProto()
    graph.node.extend(nodes)
    graph.name = name
    graph.input.extend(inputs)
    graph.output.extend(outputs)
    # Optional collections: None leaves the repeated field empty.
    if initializer is not None:
        graph.initializer.extend(initializer)
    if sparse_initializer is not None:
        graph.sparse_initializer.extend(sparse_initializer)
    if value_info is not None:
        graph.value_info.extend(value_info)
    if doc_string:
        graph.doc_string = doc_string
    return graph
def make_opsetid(domain: str, version: int) -> OperatorSetIdProto:
    """Construct an OperatorSetIdProto.

    Args:
        domain (string): The domain of the operator set id.
        version (integer): Version of the operator set id.

    Returns:
        OperatorSetIdProto with ``domain`` and ``version`` set.
    """
    result = OperatorSetIdProto()
    # Assigned left to right: domain first, then version.
    result.domain, result.version = domain, version
    return result
def make_function(
    domain: str,
    fname: str,
    inputs: Sequence[str],
    outputs: Sequence[str],
    nodes: Sequence[NodeProto],
    opset_imports: Sequence[OperatorSetIdProto],
    attributes: Sequence[str] | None = None,
    attribute_protos: Sequence[AttributeProto] | None = None,
    doc_string: str | None = None,
    overload: str | None = None,
    value_info: Sequence[ValueInfoProto] | None = None,
) -> FunctionProto:
    """Construct a FunctionProto.

    Args:
        domain (string): Domain of the function.
        fname (string): Function name (stored in ``FunctionProto.name``).
        inputs (list of string): Names of the function inputs.
        outputs (list of string): Names of the function outputs.
        nodes: NodeProto entries forming the function body.
        opset_imports: OperatorSetIdProto entries used by the body.
        attributes (list of string): Optional attribute names
            (``FunctionProto.attribute``); None means none.
        attribute_protos: Optional AttributeProto entries
            (``FunctionProto.attribute_proto``); None means none.
        doc_string (string): Optional documentation; left unset when None or
            empty.
        overload (string): Optional overload identifier; left unset when None.
        value_info: Optional ValueInfoProto entries; None means none.

    Returns:
        FunctionProto
    """
    function = FunctionProto()
    function.domain = domain
    function.name = fname
    function.input.extend(inputs)
    function.output.extend(outputs)
    function.node.extend(nodes)
    function.opset_import.extend(opset_imports)
    if attributes is not None:
        function.attribute.extend(attributes)
    if attribute_protos is not None:
        function.attribute_proto.extend(attribute_protos)
    if doc_string:
        function.doc_string = doc_string
    if overload is not None:
        function.overload = overload
    if value_info is not None:
        function.value_info.extend(value_info)
    return function
def make_model(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Construct a ModelProto.

    ``ir_version`` is initialised to ``onnx.IR_VERSION`` (a keyword argument
    may override it), and ``graph`` is copied into the model.

    Args:
        graph (GraphProto): The graph, typically returned by *make_graph*.
        **kwargs: Extra fields for the model. Two keys are handled specially:
            ``opset_imports`` (extends ``opset_import``; when absent, a single
            default-domain import at ``defs.onnx_opset_version()`` is added)
            and ``functions`` (extends ``functions`` when given). Every other
            key is assigned with ``setattr``.

    Returns:
        ModelProto
    """
    model = ModelProto()
    # Record the IR version this helper was built against.
    model.ir_version = onnx.IR_VERSION
    model.graph.CopyFrom(graph)

    opset_imports: Sequence[OperatorSetIdProto] | None = kwargs.pop(
        "opset_imports", None
    )
    if opset_imports is None:
        default_import = model.opset_import.add()
        default_import.version = defs.onnx_opset_version()
    else:
        model.opset_import.extend(opset_imports)

    functions: Sequence[FunctionProto] | None = kwargs.pop("functions", None)
    if functions is not None:
        model.functions.extend(functions)

    # NOTE(review): plain setattr; protobuf generally rejects direct assignment
    # to repeated/message fields, so such keys likely fail here -- confirm.
    for field_name, field_value in kwargs.items():
        setattr(model, field_name, field_value)
    return model
def make_model_gen_version(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Like :func:`make_model`, but infer ``ir_version`` when it is not given.

    The inferred value is the minimum IR version supporting the opsets in
    ``kwargs["opset_imports"]`` (an empty list when absent), as computed by
    :func:`find_min_ir_version_for` on a best-effort basis. An explicit
    ``ir_version`` keyword is passed through unchanged.
    """
    if "ir_version" not in kwargs:
        requested_imports = kwargs.get("opset_imports", [])
        kwargs["ir_version"] = find_min_ir_version_for(requested_imports)
    return make_model(graph, **kwargs)
def set_metadata_props(
    proto: (
        ModelProto
        | GraphProto
        | FunctionProto
        | NodeProto
        | TensorProto
        | ValueInfoProto
    ),
    dict_value: dict[str, str],
) -> None:
    """Replace ``proto.metadata_props`` with the pairs in ``dict_value``.

    Existing entries are removed first; new entries are appended in the
    dictionary's iteration order.
    """
    props = proto.metadata_props
    del props[:]
    for key, value in dict_value.items():
        entry = props.add()
        entry.key, entry.value = key, value
def set_model_props(model: ModelProto, dict_value: dict[str, str]) -> None:
    """Replace the model's ``metadata_props`` with the pairs in ``dict_value``.

    Thin ModelProto-specific alias of :func:`set_metadata_props`.
    """
    set_metadata_props(proto=model, dict_value=dict_value)
def make_tensor(
    name: str,
    data_type: int,
    dims: Sequence[int],
    vals: Sequence[int | float] | bytes | np.ndarray,
    raw: bool = False,
) -> TensorProto:
    """Make a TensorProto with specified arguments.

    If raw is False, this function chooses the proto field that stores the
    values based on data_type. If raw is True, the values are stored in the
    "raw_data" proto field; ``vals`` must then be ``bytes`` (used as-is) or a
    ``numpy.ndarray`` (serialized little-endian, with 4-bit and 2-bit types
    packed first).

    Args:
        name: tensor name
        data_type: a value such as onnx.TensorProto.FLOAT
        dims: shape
        vals: values
        raw: if True, vals contains the serialized content of the tensor,
            otherwise, vals should be a list of values of the type defined by ``data_type``.

    Returns:
        TensorProto

    Raises:
        TypeError: If raw is True and data_type is STRING, or if raw is True
            and vals is neither bytes nor a numpy.ndarray.
        ValueError: If the raw byte length or the number of values does not
            match ``dims``.
    """
    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name
    tensor.dims.extend(dims)

    if data_type == TensorProto.STRING and raw:
        raise TypeError("Can not use raw_data to store string type.")

    np_dtype = tensor_dtype_to_np_dtype(data_type)

    if raw:
        # NumPy doesn't have INT2/INT4/FP4. It is packed in couples to UINT8 buffers.
        # Bytes per element: 0.5 for 4-bit types, 0.25 for 2-bit types, else the
        # NumPy itemsize; the total is rounded up to whole bytes.
        if data_type in {TensorProto.UINT4, TensorProto.INT4, TensorProto.FLOAT4E2M1}:
            expected_size_bytes = 0.5
        elif data_type in {TensorProto.UINT2, TensorProto.INT2}:
            expected_size_bytes = 0.25
        else:
            expected_size_bytes = np_dtype.itemsize
        expected_size_bytes *= math.prod(dims)
        expected_size_bytes = math.ceil(expected_size_bytes)
        if isinstance(vals, np.ndarray):
            # Sub-byte types are packed before serialization so the byte count
            # matches expected_size_bytes.
            if data_type in {
                TensorProto.INT4,
                TensorProto.UINT4,
                TensorProto.FLOAT4E2M1,
            }:
                vals = onnx.numpy_helper._pack_4bitx2(vals)
            elif data_type in {TensorProto.UINT2, TensorProto.INT2}:
                vals = onnx.numpy_helper._pack_2bitx4(vals)

            raw_data = onnx.numpy_helper.tobytes_little_endian(vals)
        elif isinstance(vals, bytes):
            # Bytes are trusted to already be in the serialized layout.
            raw_data = vals
        else:
            raise TypeError(
                f"Raw data must be bytes or numpy.ndarray, but got {type(vals)}."
            )
        if len(raw_data) != expected_size_bytes:
            raise ValueError(
                f"Raw data size does not match tensor's size. Expected {expected_size_bytes} bytes, but got {len(raw_data)} bytes."
            )
        tensor.raw_data = raw_data
        return tensor

    assert not raw, "Bug: raw should be False at this point."

    # Normalize vals to a flat numpy array of the element type.
    if data_type == TensorProto.STRING:
        vals = np.array(vals).flatten()
        if len(vals) != 0:
            vals = np.vectorize(_to_bytes)(vals)  # Convert to bytes
    elif data_type in {
        TensorProto.FLOAT8E4M3FN,
        TensorProto.FLOAT8E4M3FNUZ,
        TensorProto.FLOAT8E5M2,
        TensorProto.FLOAT8E5M2FNUZ,
    }:
        # Float8 values are by default casted using saturating cast.
        vals = onnx.numpy_helper.saturate_cast(np.asarray(vals), np_dtype).flatten()
    elif data_type == TensorProto.FLOAT8E8M0:
        vals = onnx.numpy_helper.to_float8e8m0(
            np.asarray(vals), saturate=True, round_mode="up"
        ).flatten()
    else:
        vals = np.asarray(vals, dtype=np_dtype).flatten()

    # The element count is checked before any packing/reinterpretation below.
    expected_elements = math.prod(dims)
    if len(vals) != expected_elements:
        raise ValueError(
            f"Number of values ({len(vals)}) does not match tensor "
            f"dimensions requiring {expected_elements} elements."
        )
    # Reinterpret or pack the values into the representation expected by the
    # proto field returned by tensor_dtype_to_field:
    # - complex types are viewed as interleaved float pairs;
    # - 16-bit and 8-bit float types are viewed as their unsigned bit patterns;
    # - sub-byte types are packed;
    # - bool is widened to uint8.
    if data_type == TensorProto.COMPLEX128:
        vals = vals.view(np.float64)  # type: ignore[union-attr]
    elif data_type == TensorProto.COMPLEX64:
        vals = vals.view(np.float32)  # type: ignore[union-attr]
    elif data_type in {TensorProto.BFLOAT16, TensorProto.FLOAT16}:
        vals = vals.view(np.uint16)  # type: ignore[union-attr]
    elif data_type in {
        TensorProto.FLOAT8E4M3FN,
        TensorProto.FLOAT8E4M3FNUZ,
        TensorProto.FLOAT8E5M2,
        TensorProto.FLOAT8E5M2FNUZ,
        TensorProto.FLOAT8E8M0,
    }:
        vals = vals.view(np.uint8)  # type: ignore[union-attr]
    elif data_type in {TensorProto.UINT4, TensorProto.INT4, TensorProto.FLOAT4E2M1}:
        # Convert to packed 4-bit representation
        vals = onnx.numpy_helper._pack_4bitx2(vals)  # type: ignore[arg-type]
    elif data_type in {TensorProto.UINT2, TensorProto.INT2}:
        # Convert to packed 2-bit representation
        vals = onnx.numpy_helper._pack_2bitx4(vals)  # type: ignore[arg-type]
    elif data_type == TensorProto.BOOL:
        vals = vals.astype(np.uint8)  # type: ignore[union-attr]

    field = tensor_dtype_to_field(data_type)
    getattr(tensor, field).extend(vals)
    return tensor
def make_sparse_tensor(
    values: TensorProto, indices: TensorProto, dims: Sequence[int]
) -> onnx.SparseTensorProto:
    """Construct a SparseTensorProto.

    Args:
        values (TensorProto): The non-default values (copied into the result).
        indices (TensorProto): The indices of those values (copied into the
            result).
        dims: The dense shape.

    Returns:
        SparseTensorProto
    """
    result = onnx.SparseTensorProto()
    result.values.CopyFrom(values)
    result.indices.CopyFrom(indices)
    result.dims.extend(dims)
    return result
def make_sequence(
    name: str,
    elem_type: SequenceProto.DataType,
    values: Sequence[Any],
) -> SequenceProto:
    """Make a SequenceProto with specified value arguments.

    Args:
        name: Name stored on the returned SequenceProto.
        elem_type: A ``SequenceProto.DataType`` value selecting which repeated
            field receives ``values``.
        values: Protos appended to the selected field.

    Returns:
        The constructed SequenceProto. For ``SequenceProto.UNDEFINED`` only
        ``name`` and ``elem_type`` are set and ``values`` is ignored.

    Raises:
        TypeError: If ``elem_type`` is not a supported element type.
    """
    sequence = SequenceProto()
    sequence.name = name
    sequence.elem_type = elem_type  # type: ignore[assignment]

    if elem_type == SequenceProto.UNDEFINED:
        return sequence

    attribute: RepeatedCompositeFieldContainer | None = None
    if elem_type == SequenceProto.TENSOR:
        attribute = sequence.tensor_values
    elif elem_type == SequenceProto.SPARSE_TENSOR:
        attribute = sequence.sparse_tensor_values
    elif elem_type == SequenceProto.SEQUENCE:
        attribute = sequence.sequence_values
    elif elem_type == SequenceProto.MAP:
        attribute = sequence.map_values
    elif elem_type == SequenceProto.OPTIONAL:
        # Compare against SequenceProto's own enum (previously OptionalProto's).
        attribute = sequence.optional_values
    else:
        raise TypeError("The element type in the input sequence is not supported.")

    attribute.extend(values)
    return sequence
def make_map(
    name: str, key_type: int, keys: list[Any], values: SequenceProto
) -> MapProto:
    """Make a Map with specified key-value pair arguments.

    Criteria for conversion:
    - Keys and Values must have the same number of elements
    - Every key in keys must be of the same type
    - Every value in values must be of the same type

    Args:
        name: Name stored on the returned MapProto.
        key_type: ``TensorProto.STRING`` or one of the 8/16/32/64-bit signed
            or unsigned integer ``TensorProto`` data types.
        keys: Map keys. For STRING maps, ``str`` keys are UTF-8 encoded and
            ``bytes`` keys are used as-is.
        values: SequenceProto holding the map values (copied into the result).

    Returns:
        MapProto

    Raises:
        TypeError: If ``key_type`` is not a supported key type (previously
            the keys were silently dropped).
    """
    valid_key_int_types = {
        TensorProto.INT8,
        TensorProto.INT16,
        TensorProto.INT32,
        TensorProto.INT64,
        TensorProto.UINT8,
        TensorProto.UINT16,
        TensorProto.UINT32,
        TensorProto.UINT64,
    }
    map_proto = MapProto()
    map_proto.name = name
    map_proto.key_type = key_type
    if key_type == TensorProto.STRING:
        # string_keys holds bytes; encode str keys like make_attribute does.
        map_proto.string_keys.extend(_to_bytes(key) for key in keys)
    elif key_type in valid_key_int_types:
        map_proto.keys.extend(keys)
    else:
        raise TypeError(
            f"Unsupported map key type {key_type}; expected TensorProto.STRING "
            "or an integer TensorProto data type."
        )
    map_proto.values.CopyFrom(values)
    return map_proto
def make_optional(
    name: str,
    elem_type: OptionalProto.DataType,
    value: google.protobuf.message.Message | None,
) -> OptionalProto:
    """Make an OptionalProto with specified value arguments.

    Args:
        name: Name stored on the returned OptionalProto.
        elem_type: An ``OptionalProto.DataType`` value selecting which field
            receives ``value``.
        value: The contained proto (copied into the result). Ignored when
            ``elem_type`` is ``OptionalProto.UNDEFINED``; required otherwise.

    Returns:
        OptionalProto

    Raises:
        TypeError: If ``elem_type`` is not supported, or if ``value`` is None
            while ``elem_type`` is not ``OptionalProto.UNDEFINED``.
    """
    optional = OptionalProto()
    optional.name = name
    optional.elem_type = elem_type  # type: ignore[assignment]

    if elem_type == OptionalProto.UNDEFINED:
        return optional
    attribute: google.protobuf.message.Message | None = None
    if elem_type == OptionalProto.TENSOR:
        attribute = optional.tensor_value
    elif elem_type == OptionalProto.SPARSE_TENSOR:
        attribute = optional.sparse_tensor_value
    elif elem_type == OptionalProto.SEQUENCE:
        attribute = optional.sequence_value
    elif elem_type == OptionalProto.MAP:
        attribute = optional.map_value
    elif elem_type == OptionalProto.OPTIONAL:
        attribute = optional.optional_value
    else:
        raise TypeError("The element type in the input optional is not supported.")

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if value is None:
        raise TypeError(
            f"A value is required for optional '{name}' with element type {elem_type}."
        )
    attribute.CopyFrom(value)  # type: ignore[arg-type]
    return optional
599def _to_bytes(value: str | bytes) -> bytes:
600 """Coerce a string (or bytes) value into UTF-8 bytes."""
601 if isinstance(value, str):
602 return value.encode("utf-8")
603 return value
def make_attribute(
    key: str,
    value: Any,
    doc_string: str | None = None,
    attr_type: int | None = None,
) -> AttributeProto:
    """Makes an AttributeProto based on the value type.

    Args:
        key: Name of the attribute.
        value: The attribute value.
            - Integral numbers become INT and other real numbers FLOAT.
            - ``str``/``bytes`` become STRING; ``str`` is UTF-8 encoded.
            - TensorProto, SparseTensorProto, GraphProto and TypeProto map to
              the matching singular type.
            - Any other iterable is materialized as a list and stored as the
              corresponding repeated type.
        doc_string: Optional documentation string; left unset when None or
            empty.
        attr_type: Optional expected ``AttributeProto`` type. It is required
            for an empty iterable, whose element type cannot be inferred.
            Otherwise it is checked against the type derived from ``value``.

    Returns:
        AttributeProto

    Raises:
        ValueError: If the type of an iterable value cannot be inferred.
        TypeError: If ``value`` is not an accepted attribute value, if
            ``attr_type`` cannot hold an iterable value, or if ``attr_type``
            does not match the type derived from ``value``.
    """
    attr = AttributeProto()
    attr.name = key
    if doc_string:
        attr.doc_string = doc_string

    # Singular cases. Integral is tested before Real because every Integral
    # (including bool) is also a Real.
    if isinstance(value, numbers.Integral):
        attr.i = int(value)
        attr.type = AttributeProto.INT
    elif isinstance(value, numbers.Real):
        attr.f = float(value)
        attr.type = AttributeProto.FLOAT
    elif isinstance(value, (str, bytes)):
        # Encode strings into utf-8
        attr.s = _to_bytes(value)
        attr.type = AttributeProto.STRING
    elif isinstance(value, TensorProto):
        attr.t.CopyFrom(value)
        attr.type = AttributeProto.TENSOR
    elif isinstance(value, onnx.SparseTensorProto):
        attr.sparse_tensor.CopyFrom(value)
        attr.type = AttributeProto.SPARSE_TENSOR
    elif isinstance(value, GraphProto):
        attr.g.CopyFrom(value)
        attr.type = AttributeProto.GRAPH
    elif isinstance(value, TypeProto):
        attr.tp.CopyFrom(value)
        attr.type = AttributeProto.TYPE_PROTO
    # Iterable cases (str/bytes were handled above and never reach here).
    elif isinstance(value, collections.abc.Iterable):
        value = list(value)
        if len(value) == 0 and attr_type is None:
            raise ValueError(
                f"Could not infer attribute `{key}` type from empty iterator"
            )
        if attr_type is None:
            types = {type(v) for v in value}
            for exp_t, exp_enum in (
                (numbers.Integral, AttributeProto.INTS),
                (numbers.Real, AttributeProto.FLOATS),
                ((str, bytes), AttributeProto.STRINGS),
                (TensorProto, AttributeProto.TENSORS),
                (onnx.SparseTensorProto, AttributeProto.SPARSE_TENSORS),
                (GraphProto, AttributeProto.GRAPHS),
                (TypeProto, AttributeProto.TYPE_PROTOS),
            ):
                if all(issubclass(t, exp_t) for t in types):
                    attr_type = exp_enum
                    break
            if attr_type is None:
                raise ValueError(
                    "Could not infer the attribute type from the elements of the passed Iterable value."
                )

        if attr_type == AttributeProto.INTS:
            attr.ints.extend(value)
            attr.type = AttributeProto.INTS
        elif attr_type == AttributeProto.FLOATS:
            attr.floats.extend(value)
            attr.type = AttributeProto.FLOATS
        elif attr_type == AttributeProto.STRINGS:
            attr.strings.extend(_to_bytes(v) for v in value)
            attr.type = AttributeProto.STRINGS
        elif attr_type == AttributeProto.TENSORS:
            attr.tensors.extend(value)
            attr.type = AttributeProto.TENSORS
        elif attr_type == AttributeProto.SPARSE_TENSORS:
            attr.sparse_tensors.extend(value)
            attr.type = AttributeProto.SPARSE_TENSORS
        elif attr_type == AttributeProto.GRAPHS:
            attr.graphs.extend(value)
            attr.type = AttributeProto.GRAPHS
        elif attr_type == AttributeProto.TYPE_PROTOS:
            attr.type_protos.extend(value)
            attr.type = AttributeProto.TYPE_PROTOS
        else:
            # Inference above only yields repeated types, so this is reached
            # only when the caller passed a non-repeated attr_type (e.g. INT)
            # for an iterable value. Report it instead of a bare AssertionError.
            raise TypeError(
                f"Specified attribute type '{_attr_type_to_str(attr_type)}'({attr_type}) "
                f"cannot hold the iterable value of attribute `{key}`."
            )
    else:
        raise TypeError(f"'{value}' is not an accepted attribute value.")

    if attr_type is not None and attr.type != attr_type:
        raise TypeError(
            f"Inferred attribute type '{_attr_type_to_str(attr.type)}'({attr.type}) mismatched with specified type '{_attr_type_to_str(attr_type)}'({attr_type})"
        )
    return attr
def make_attribute_ref(
    name: str,
    attr_type: AttributeProto.AttributeType,
    doc_string: str | None = None,
    *,
    ref_attr_name: str | None = None,
) -> AttributeProto:
    """Build an AttributeProto that refers to an attribute of the enclosing function.

    The attribute stores no value of its own. It records ``ref_attr_name``,
    which falls back to ``name`` when omitted, as the parent attribute it
    stands for. Reference attributes are only valid inside a function
    (sub-graph).

    Args:
        name: Name of the attribute inside the function body.
        attr_type: The attribute's type.
        doc_string: Optional documentation; left unset when None or empty.
        ref_attr_name: Name of the referenced parent-function attribute.

    Returns:
        AttributeProto

    Raises:
        ValueError: If the effective ``ref_attr_name`` is empty.
    """
    target = name if ref_attr_name is None else ref_attr_name
    if not target:
        raise ValueError("ref_attr_name must be non-empty")

    ref = AttributeProto()
    ref.name = name
    ref.type = attr_type  # type: ignore[assignment]
    ref.ref_attr_name = target
    if doc_string:
        ref.doc_string = doc_string
    return ref
734def get_attribute_value(attr: AttributeProto) -> Any: # noqa: PLR0911
735 if attr.ref_attr_name:
736 raise ValueError(f"Cannot get value of reference attribute: {attr}")
737 if attr.type == AttributeProto.FLOAT:
738 return attr.f
739 if attr.type == AttributeProto.INT:
740 return attr.i
741 if attr.type == AttributeProto.STRING:
742 return attr.s
743 if attr.type == AttributeProto.TENSOR:
744 return attr.t
745 if attr.type == AttributeProto.SPARSE_TENSOR:
746 return attr.sparse_tensor
747 if attr.type == AttributeProto.GRAPH:
748 return attr.g
749 if attr.type == AttributeProto.TYPE_PROTO:
750 return attr.tp
751 if attr.type == AttributeProto.FLOATS:
752 return list(attr.floats)
753 if attr.type == AttributeProto.INTS:
754 return list(attr.ints)
755 if attr.type == AttributeProto.STRINGS:
756 return list(attr.strings)
757 if attr.type == AttributeProto.TENSORS:
758 return list(attr.tensors)
759 if attr.type == AttributeProto.SPARSE_TENSORS:
760 return list(attr.sparse_tensors)
761 if attr.type == AttributeProto.GRAPHS:
762 return list(attr.graphs)
763 if attr.type == AttributeProto.TYPE_PROTOS:
764 return list(attr.type_protos)
765 if attr.type == AttributeProto.UNDEFINED:
766 return None
767 raise ValueError(f"Unsupported ONNX attribute: {attr}")
def get_node_attr_value(node: NodeProto, attr_name: str) -> Any:
    """Return the value of the single attribute of ``node`` named ``attr_name``.

    Raises:
        ValueError: If ``node`` has no such attribute or more than one.
    """
    matches = [a for a in node.attribute if a.name == attr_name]
    if not matches:
        raise ValueError(f"Node has no attribute with name {attr_name}")
    if len(matches) > 1:
        raise ValueError(f"Node has multiple attributes with name {attr_name}")
    (only,) = matches
    return get_attribute_value(only)
def make_empty_tensor_value_info(name: str) -> ValueInfoProto:
    """Return a ValueInfoProto that carries only ``name`` (no type information)."""
    info = ValueInfoProto()
    info.name = name
    return info
def make_tensor_type_proto(
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    shape_denotation: list[str] | None = None,
) -> TypeProto:
    """Makes a Tensor TypeProto based on the data type and shape.

    Args:
        elem_type: Element data type (a ``TensorProto`` data type value).
        shape: Tensor dimensions, or None to leave the shape unset. Each entry
            is one of:
            - an integer: a static size. Any ``numbers.Integral``, e.g. a
              NumPy integer scalar, is accepted.
            - a str: a symbolic dimension name.
            - None: an unknown dimension.
            An empty sequence produces an explicit rank-0 shape.
        shape_denotation: Optional per-dimension denotations. When non-empty
            it must have the same length as ``shape``.

    Returns:
        TypeProto whose ``tensor_type`` is populated.

    Raises:
        ValueError: If ``shape_denotation`` length differs from ``shape``, or
            a shape entry is not an integer, str, or None.
    """
    type_proto = TypeProto()
    tensor_type_proto = type_proto.tensor_type
    tensor_type_proto.elem_type = elem_type
    tensor_shape_proto = tensor_type_proto.shape

    if shape is not None:
        # You might think this is a no-op (extending a normal Python
        # list by [] certainly is), but protobuf lists work a little
        # differently; if a field is never set, it is omitted from the
        # resulting protobuf; a list that is explicitly set to be
        # empty will get an (empty) entry in the protobuf. This
        # difference is visible to our consumers, so make sure we emit
        # an empty shape!
        tensor_shape_proto.dim.extend([])

        if shape_denotation and len(shape_denotation) != len(shape):
            raise ValueError(
                "Invalid shape_denotation. Must be of the same length as shape."
            )

        for i, d in enumerate(shape):
            dim = tensor_shape_proto.dim.add()
            if d is None:
                pass
            elif isinstance(d, int):
                dim.dim_value = d
            elif isinstance(d, numbers.Integral):
                # e.g. numpy.int64 from `array.shape` arithmetic.
                dim.dim_value = int(d)
            elif isinstance(d, str):
                dim.dim_param = d
            else:
                raise ValueError(
                    f"Invalid item in shape: {d}. Needs to be of int or str."
                )

            if shape_denotation:
                dim.denotation = shape_denotation[i]

    return type_proto
def make_tensor_value_info(
    name: str,
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    doc_string: str = "",
    shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
    """Makes a ValueInfoProto describing a dense tensor.

    Args:
        name: Value name.
        elem_type: Element data type.
        shape: Dimensions, as accepted by :func:`make_tensor_type_proto`.
        doc_string: Optional documentation; left unset when empty.
        shape_denotation: Optional per-dimension denotations.

    Returns:
        ValueInfoProto whose ``type`` holds the tensor type.
    """
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string

    info.type.CopyFrom(make_tensor_type_proto(elem_type, shape, shape_denotation))
    return info
def make_sparse_tensor_type_proto(
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    shape_denotation: list[str] | None = None,
) -> TypeProto:
    """Makes a SparseTensor TypeProto based on the data type and shape.

    Args:
        elem_type: Element data type (a ``TensorProto`` data type value).
        shape: Dense dimensions, or None to leave the shape unset. Each entry
            is one of:
            - an integer: a static size. Any ``numbers.Integral``, e.g. a
              NumPy integer scalar, is accepted.
            - a str: a symbolic dimension name.
            - None: an unknown dimension.
            An empty sequence produces an explicit rank-0 shape.
        shape_denotation: Optional per-dimension denotations. When non-empty
            it must have the same length as ``shape``.

    Returns:
        TypeProto whose ``sparse_tensor_type`` is populated.

    Raises:
        ValueError: If ``shape_denotation`` length differs from ``shape``, or
            a shape entry is not an integer, str, or None.
    """
    type_proto = TypeProto()
    sparse_tensor_type_proto = type_proto.sparse_tensor_type
    sparse_tensor_type_proto.elem_type = elem_type
    sparse_tensor_shape_proto = sparse_tensor_type_proto.shape

    if shape is not None:
        # You might think this is a no-op (extending a normal Python
        # list by [] certainly is), but protobuf lists work a little
        # differently; if a field is never set, it is omitted from the
        # resulting protobuf; a list that is explicitly set to be
        # empty will get an (empty) entry in the protobuf. This
        # difference is visible to our consumers, so make sure we emit
        # an empty shape!
        sparse_tensor_shape_proto.dim.extend([])

        if shape_denotation and len(shape_denotation) != len(shape):
            raise ValueError(
                "Invalid shape_denotation. Must be of the same length as shape."
            )

        for i, d in enumerate(shape):
            dim = sparse_tensor_shape_proto.dim.add()
            if d is None:
                pass
            elif isinstance(d, int):
                dim.dim_value = d
            elif isinstance(d, numbers.Integral):
                # e.g. numpy.int64 from `array.shape` arithmetic.
                dim.dim_value = int(d)
            elif isinstance(d, str):
                dim.dim_param = d
            else:
                # Same wording as make_tensor_type_proto.
                raise ValueError(
                    f"Invalid item in shape: {d}. Needs to be of int or str."
                )

            if shape_denotation:
                dim.denotation = shape_denotation[i]

    return type_proto
def make_sparse_tensor_value_info(
    name: str,
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    doc_string: str = "",
    shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
    """Makes a ValueInfoProto describing a sparse tensor.

    Args:
        name: Value name.
        elem_type: Element data type.
        shape: Dense dimensions, as accepted by
            :func:`make_sparse_tensor_type_proto`.
        doc_string: Optional documentation; left unset when empty.
        shape_denotation: Optional per-dimension denotations.

    Returns:
        ValueInfoProto whose ``type.sparse_tensor_type`` is populated.
    """
    sparse_type = make_sparse_tensor_type_proto(elem_type, shape, shape_denotation)
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string
    info.type.sparse_tensor_type.CopyFrom(sparse_type.sparse_tensor_type)
    return info
def make_sequence_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes a TypeProto for a sequence whose elements have ``inner_type_proto``."""
    result = TypeProto()
    element_type = result.sequence_type.elem_type
    element_type.CopyFrom(inner_type_proto)
    return result
def make_optional_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes a TypeProto for an optional wrapping ``inner_type_proto``."""
    result = TypeProto()
    wrapped_type = result.optional_type.elem_type
    wrapped_type.CopyFrom(inner_type_proto)
    return result
def make_map_type_proto(
    key_type: int,
    value_type: TypeProto,
) -> TypeProto:
    """Makes a map TypeProto from a key data type and a value TypeProto (copied)."""
    result = TypeProto()
    map_type = result.map_type
    map_type.key_type = key_type
    map_type.value_type.CopyFrom(value_type)
    return result
def make_value_info(
    name: str,
    type_proto: TypeProto,
    doc_string: str = "",
) -> ValueInfoProto:
    """Makes a ValueInfoProto named ``name`` whose type is a copy of ``type_proto``.

    ``doc_string`` is stored only when non-empty.
    """
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string
    info.type.CopyFrom(type_proto)
    return info
959def _sanitize_str(s: str | bytes) -> str:
960 if isinstance(s, str):
961 sanitized = s
962 elif isinstance(s, bytes):
963 sanitized = s.decode("utf-8", errors="ignore")
964 else:
965 sanitized = str(s)
966 if len(sanitized) < 64: # noqa: PLR2004
967 return sanitized
968 return sanitized[:64] + f"...<+len={(len(sanitized) - 64)}>"
def make_tensor_sequence_value_info(
    name: str,
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    doc_string: str = "",
    elem_shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
    """Makes a ValueInfoProto for a sequence of tensors.

    Args:
        name: Value name.
        elem_type: Element data type of each tensor in the sequence.
        shape: Dimensions of each tensor, as accepted by
            :func:`make_tensor_type_proto`.
        doc_string: Optional documentation; left unset when empty.
        elem_shape_denotation: Optional per-dimension denotations for the
            element tensors.

    Returns:
        ValueInfoProto whose ``type.sequence_type`` is populated.
    """
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string

    element_type = make_tensor_type_proto(elem_type, shape, elem_shape_denotation)
    wrapper = make_sequence_type_proto(element_type)
    info.type.sequence_type.CopyFrom(wrapper.sequence_type)
    return info
def printable_attribute(
    attr: AttributeProto, subgraphs: bool = False
) -> str | tuple[str, list[GraphProto]]:
    """Render an AttributeProto as a one-line ``name = value`` string.

    Args:
        attr: the attribute to render.
        subgraphs: when True, also return the GraphProto values held by the
            attribute (its ``g`` field or ``graphs`` list) so the caller can
            print them separately.

    Returns:
        The rendered text, or a ``(text, graphs)`` tuple when ``subgraphs``
        is True.
    """

    def fmt_float(value: float) -> str:
        # Pin the precision so the output does not depend on the Python
        # version's default float repr.
        return f"{value:.15g}"

    def fmt_list(fmt: Callable[[Any], str], items: Sequence[Any]) -> str:
        return "[" + ", ".join(fmt(item) for item in items) + "]"

    def bracketed(labels: list[str]) -> list[str]:
        # Every label except the last gets a trailing comma, and the brackets
        # are separate tokens, e.g. "[ <a>, <b> ]" once space-joined.
        return ["[", *(f"{label}," for label in labels[:-1]), *labels[-1:], "]"]

    # Singular fields are detected with HasField; repeated fields by being
    # non-empty. attr.type is not consulted, so an attribute whose list is
    # empty renders as <Unknown>. Graph values are collected so the caller
    # can print them after this line.
    found_graphs: list[GraphProto] = []
    value: list[str]
    if attr.HasField("f"):
        value = [fmt_float(attr.f)]
    elif attr.HasField("i"):
        value = [str(attr.i)]
    elif attr.HasField("s"):
        value = [repr(_sanitize_str(attr.s))]
    elif attr.HasField("t"):
        if attr.t.dims:
            value = ["<Tensor>"]
        else:
            # Rank-0 tensor: show the contents of its typed storage field.
            storage_field = tensor_dtype_to_field(attr.t.data_type)
            value = [f"<Scalar Tensor {getattr(attr.t, storage_field)}>"]
    elif attr.HasField("g"):
        value = [f"<graph {attr.g.name}>"]
        found_graphs.append(attr.g)
    elif attr.HasField("tp"):
        value = [f"<Type Proto {attr.tp}>"]
    elif attr.HasField("sparse_tensor"):
        value = ["<Sparse Tensor>"]
    elif attr.floats:
        value = [fmt_list(fmt_float, attr.floats)]
    elif attr.ints:
        value = [fmt_list(str, attr.ints)]
    elif attr.strings:
        value = [str(list(map(_sanitize_str, attr.strings)))]
    elif attr.tensors:
        value = ["[<Tensor>, ...]"]
    elif attr.sparse_tensors:
        value = ["[<Sparse Tensor>, ...]"]
    elif attr.type_protos:
        value = bracketed([f"<Type Proto {tp}>" for tp in attr.type_protos])
    elif attr.graphs:
        value = bracketed([f"<graph {g.name}>" for g in attr.graphs])
        found_graphs.extend(attr.graphs)
    else:
        value = ["<Unknown>"]

    text = " ".join([attr.name, "=", *value])
    return (text, found_graphs) if subgraphs else text
def printable_dim(dim: TensorShapeProto.Dimension) -> str:
    """Render one shape dimension.

    Returns the value of whichever member of the ``value`` oneof is set, as a
    string, or ``"?"`` when neither is set.
    """
    set_field = dim.WhichOneof("value")
    return "?" if set_field is None else str(getattr(dim, set_field))
def printable_type(t: TypeProto) -> str:
    """Render a TypeProto compactly.

    A tensor type prints as its element type name, followed by
    ``, d0xd1x...`` when a non-empty shape is present, or ``, scalar`` when
    the shape is present but has no dimensions. An unset type prints as the
    empty string; any other kind prints as ``Unknown type <kind>``.
    """
    kind = t.WhichOneof("value")
    if kind is None:
        return ""
    if kind != "tensor_type":
        return f"Unknown type {kind}"

    tensor_type = t.tensor_type
    rendered: str = TensorProto.DataType.Name(tensor_type.elem_type)  # type: ignore[arg-type]
    if not tensor_type.HasField("shape"):
        return rendered
    dims = tensor_type.shape.dim
    if not dims:
        return rendered + ", scalar"
    return rendered + ", " + "x".join(printable_dim(d) for d in dims)
def printable_value_info(v: ValueInfoProto) -> str:
    """Render a ValueInfoProto as ``%name[type]``.

    The bracketed type is omitted when the ``type`` field is not set.
    """
    s = f"%{v.name}"
    # A protobuf sub-message object is always truthy, so a plain truthiness
    # test on v.type can never skip the bracket; check field presence instead.
    if v.HasField("type"):
        s = f"{s}[{printable_type(v.type)}]"
    return s
def printable_tensor_proto(t: TensorProto) -> str:
    """Render a TensorProto header as ``%name[ELEM_TYPE, d0xd1x...]``.

    A tensor with no dimensions is rendered as ``%name[ELEM_TYPE, scalar]``.
    Only the name, element type and dims are shown, never the data.
    """
    s = f"%{t.name}["
    s += TensorProto.DataType.Name(t.data_type)  # type: ignore[arg-type]
    # dims is a repeated field, so it is never None; an empty list means rank 0.
    if t.dims:
        s += ", " + "x".join(map(str, t.dims))
    else:
        s += ", scalar"
    s += "]"
    return s
def printable_node(
    node: NodeProto, prefix: str = "", subgraphs: bool = False
) -> str | tuple[str, list[GraphProto]]:
    """Render a NodeProto as ``%out, ... = Op[attrs](%in, ...)``.

    Args:
        node: the node to render.
        prefix: text prepended to the rendered line.
        subgraphs: when True, also return every GraphProto referenced by the
            node's attributes so the caller can print them separately.

    Returns:
        The rendered line, or a ``(line, graphs)`` tuple when ``subgraphs``
        is True.

    Raises:
        TypeError: if ``printable_attribute`` returns an unexpected shape.
    """
    nested_graphs: list[GraphProto] = []
    attr_texts = []
    for attr in node.attribute:
        if not subgraphs:
            rendered = printable_attribute(attr)
            if not isinstance(rendered, str):
                raise TypeError(f"printed must be an instance of {str}.")
            attr_texts.append(rendered)
            continue
        rendered_pair = printable_attribute(attr, subgraphs)
        if not isinstance(rendered_pair[1], list):
            raise TypeError(
                f"printed_attr_subgraphs[1] must be an instance of {list}."
            )
        nested_graphs.extend(rendered_pair[1])
        attr_texts.append(rendered_pair[0])

    inputs_text = ", ".join(f"%{name}" for name in node.input)
    if node.attribute:
        # Sorting makes the output independent of attribute order in the proto.
        attrs_text = ", ".join(sorted(attr_texts))
        call_text = f"{node.op_type}[{attrs_text}]({inputs_text})"
    else:
        call_text = f"{node.op_type}({inputs_text})"

    pieces = []
    if node.output:
        pieces.extend([", ".join(f"%{name}" for name in node.output), "="])
    pieces.append(call_text)

    line = prefix + " ".join(pieces)
    return (line, nested_graphs) if subgraphs else line
@typing_extensions.deprecated(
    "Deprecated since 1.19. Consider using onnx.printer.to_text() instead."
)
def printable_graph(graph: GraphProto, prefix: str = "") -> str:
    """Display a GraphProto as a string.

    .. deprecated:: 1.19
        Consider using :func:`onnx.printer.to_text` instead.

    Args:
        graph (GraphProto): the graph to display
        prefix (string): prefix of every line

    Returns:
        string
    """
    return _printable_graph(graph, prefix)


def _printable_graph(graph: GraphProto, prefix: str = "") -> str:
    """Implementation of ``printable_graph`` without the deprecation wrapper.

    Nested subgraphs are rendered by recursing into this helper rather than
    into the decorated public function. A single ``printable_graph`` call
    therefore does not re-enter the deprecation wrapper once per subgraph.
    """
    content = []
    indent = prefix + " "
    # header
    header = ["graph", graph.name]
    initializers = {t.name for t in graph.initializer}
    if len(graph.input):
        header.append("(")
        in_strs = []  # required inputs
        in_with_init_strs: list = []  # optional inputs with initializer providing default value
        for inp in graph.input:
            if inp.name not in initializers:
                in_strs.append(printable_value_info(inp))
            else:
                in_with_init_strs.append(printable_value_info(inp))
        if in_strs:
            content.append(prefix + " ".join(header))
            header = []
            content.extend(prefix + " " + line for line in in_strs)
        header.append(")")

        if in_with_init_strs:
            header.append("optional inputs with matching initializers (")
            content.append(prefix + " ".join(header))
            header = []
            content.extend(prefix + " " + line for line in in_with_init_strs)
            header.append(")")

        # From IR 4 onwards an initializer is not required to have a matching
        # graph input, so also output the name, type and shape of those.
        if len(in_with_init_strs) < len(initializers):
            graph_inputs = {i.name for i in graph.input}
            init_strs = [
                printable_tensor_proto(i)
                for i in graph.initializer
                if i.name not in graph_inputs
            ]
            header.append("initializers (")
            content.append(prefix + " ".join(header))
            header = []
            content.extend(prefix + " " + line for line in init_strs)
            header.append(")")

    header.append("{")
    content.append(prefix + " ".join(header))
    graphs: list[GraphProto] = []
    # body
    for node in graph.node:
        contents_subgraphs = printable_node(node, indent, subgraphs=True)
        if not isinstance(contents_subgraphs[1], list):
            raise TypeError(f"contents_subgraphs[1] must be an instance of {list}.")
        content.append(contents_subgraphs[0])
        graphs.extend(contents_subgraphs[1])
    # tail
    tail = ["return"]
    if len(graph.output):
        tail.append(", ".join([f"%{out.name}" for out in graph.output]))
    content.append(indent + " ".join(tail))
    # closing bracket
    content.append(prefix + "}")
    # Subgraphs collected from node attributes are printed after the parent.
    content.extend("\n" + _printable_graph(g) for g in graphs)
    return "\n".join(content)
def strip_doc_string(proto: google.protobuf.message.Message) -> None:
    """Recursively clear every ``doc_string`` field of ``proto``, in place.

    Walks all message-typed fields, both singular fields that are set and
    every element of repeated ones.

    Raises:
        TypeError: if ``proto`` is not a protobuf Message.
    """
    if not isinstance(proto, google.protobuf.message.Message):
        raise TypeError(
            f"proto must be an instance of {google.protobuf.message.Message}."
        )
    for field in proto.DESCRIPTOR.fields:
        field_name = field.name
        if field_name == "doc_string":
            proto.ClearField(field_name)
            continue
        if field.type != field.TYPE_MESSAGE:
            continue
        # NOTE(review): a map field with message values is also reported as
        # repeated + TYPE_MESSAGE, and iterating it yields keys. ONNX protos
        # appear not to use such maps; confirm before reusing this elsewhere.
        if field.label == field.LABEL_REPEATED:
            for child in getattr(proto, field_name):
                strip_doc_string(child)
        elif proto.HasField(field_name):
            strip_doc_string(getattr(proto, field_name))
def make_training_info(
    algorithm: GraphProto,
    algorithm_bindings: AssignmentBindingType,
    initialization: GraphProto | None,
    initialization_bindings: AssignmentBindingType | None,
) -> TrainingInfoProto:
    """Build a TrainingInfoProto.

    Args:
        algorithm: graph copied into ``algorithm``.
        algorithm_bindings: ``(key, value)`` pairs, each added as an
            ``update_binding`` entry.
        initialization: optional graph copied into ``initialization`` when
            truthy.
        initialization_bindings: optional ``(key, value)`` pairs, each added
            as an ``initialization_binding`` entry when the list is non-empty.

    Returns:
        The new TrainingInfoProto.
    """

    def add_bindings(container, pairs) -> None:
        for key, value in pairs:
            entry = container.add()
            entry.key = key
            entry.value = value

    info = TrainingInfoProto()
    info.algorithm.CopyFrom(algorithm)
    add_bindings(info.update_binding, algorithm_bindings)

    if initialization:
        info.initialization.CopyFrom(initialization)
    if initialization_bindings:
        add_bindings(info.initialization_binding, initialization_bindings)

    return info
# The following functions map between TensorProto data types, numpy dtypes and storage fields.
def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
    """Look up the numpy dtype that corresponds to a TensorProto data type.

    Handy when building a tensor from ``TensorProto.data_type``.

    Args:
        tensor_dtype: a ``TensorProto.DataType`` value.

    Returns:
        The ``np_dtype`` registered for ``tensor_dtype`` in
        ``_mapping.TENSOR_TYPE_MAP``. An unregistered value fails in the
        lookup (presumably KeyError).
    """
    entry = _mapping.TENSOR_TYPE_MAP[tensor_dtype]
    return entry.np_dtype
def tensor_dtype_to_storage_tensor_dtype(tensor_dtype: int) -> int:
    """Look up the data type used to store values of a TensorProto data type.

    Args:
        tensor_dtype: a ``TensorProto.DataType`` value.

    Returns:
        The ``storage_dtype`` registered for ``tensor_dtype`` in
        ``_mapping.TENSOR_TYPE_MAP``.
    """
    entry = _mapping.TENSOR_TYPE_MAP[tensor_dtype]
    return entry.storage_dtype
def tensor_dtype_to_string(tensor_dtype: int) -> str:
    """Look up the registered name of a TensorProto data type.

    Args:
        tensor_dtype: a ``TensorProto.DataType`` value.

    Returns:
        The ``name`` registered for ``tensor_dtype`` in
        ``_mapping.TENSOR_TYPE_MAP``.
    """
    entry = _mapping.TENSOR_TYPE_MAP[tensor_dtype]
    return entry.name
@functools.lru_cache(None)
def tensor_dtype_to_field(tensor_dtype: int) -> str:
    """Return the TensorProto field that holds values of ``tensor_dtype``.

    ``tensor_dtype`` is first reduced to its storage data type through
    ``_mapping.TENSOR_TYPE_MAP``; that storage type is then mapped to a
    field name. Results are memoized.

    Args:
        tensor_dtype: a ``TensorProto.DataType`` value.

    Returns:
        A field name such as ``"float_data"`` or ``"int64_data"``.
    """
    storage_dtype = _mapping.TENSOR_TYPE_MAP[tensor_dtype].storage_dtype
    # UINT32 maps to "uint64_data" alongside UINT64. This looks intentional
    # (there is no uint32 data field); verify against onnx.proto before
    # changing it.
    field_for_storage = {
        int(TensorProto.FLOAT): "float_data",
        int(TensorProto.INT32): "int32_data",
        int(TensorProto.INT64): "int64_data",
        int(TensorProto.DOUBLE): "double_data",
        int(TensorProto.UINT32): "uint64_data",
        int(TensorProto.UINT64): "uint64_data",
        int(TensorProto.STRING): "string_data",
    }
    return field_for_storage[storage_dtype]
@functools.lru_cache(None)
def np_dtype_to_tensor_dtype(np_dtype: np.dtype) -> TensorProto.DataType:
    """Map a numpy dtype to the matching TensorProto data type.

    Useful when converting numpy arrays to tensors. Results are memoized, so
    ``np_dtype`` must be hashable.

    Args:
        np_dtype: the numpy dtype to convert.

    Returns:
        The data type whose registered ``np_dtype`` equals ``np_dtype``. If
        several share it, the last one in ``_mapping.TENSOR_TYPE_MAP`` wins.
        Any other numpy unicode-string dtype maps to ``TensorProto.STRING``.

    Raises:
        ValueError: if no TensorProto data type corresponds to ``np_dtype``.
    """
    reverse_map = {
        entry.np_dtype: tensor_type
        for tensor_type, entry in _mapping.TENSOR_TYPE_MAP.items()
    }
    tensor_type = reverse_map.get(np_dtype)
    if tensor_type is not None:
        return typing.cast("TensorProto.DataType", tensor_type)
    if not np.issubdtype(np_dtype, np.str_):
        raise ValueError(
            f"Unable to convert type {np_dtype!r} into TensorProto element type."
        )
    return TensorProto.STRING  # type: ignore[return-value]
def get_all_tensor_dtypes() -> KeysView[int]:
    """List every TensorProto data type registered in ``_mapping.TENSOR_TYPE_MAP``.

    Returns:
        A keys view over the registered data type values.
    """
    type_map = _mapping.TENSOR_TYPE_MAP
    return type_map.keys()
# Reverse lookup from AttributeProto.AttributeType numeric value to its enum name.
_ATTRIBUTE_TYPE_TO_STR: dict[int, str] = dict(
    zip(AttributeProto.AttributeType.values(), AttributeProto.AttributeType.keys())
)
def _attr_type_to_str(attr_type: int) -> str:
    """Convert an AttributeProto type to its enum name.

    Args:
        attr_type: AttributeProto type.

    Returns:
        The name for ``attr_type``. Unknown values fall back to the first
        name declared in ``AttributeProto.AttributeType`` (presumably
        ``UNDEFINED``).
    """
    # _ATTRIBUTE_TYPE_TO_STR is keyed by every AttributeType value, so one
    # hashed lookup replaces a linear scan of AttributeType.values().
    try:
        return _ATTRIBUTE_TYPE_TO_STR[attr_type]
    except KeyError:
        return AttributeProto.AttributeType.keys()[0]