Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/compose.py: 4%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright (c) ONNX Project Contributors
2#
3# SPDX-License-Identifier: Apache-2.0
4from __future__ import annotations
6from typing import TYPE_CHECKING
8from onnx import (
9 AttributeProto,
10 GraphProto,
11 ModelProto,
12 TensorProto,
13 checker,
14 helper,
15 utils,
16)
18if TYPE_CHECKING:
19 from collections.abc import MutableMapping
def check_overlapping_names(
    g1: GraphProto, g2: GraphProto, io_map: list[tuple[str, str]] | None = None
) -> list[tuple[str, list[str]]]:
    """Checks whether there are name collisions between two graphs.

    Returns a list of tuples where the first element names the category containing
    overlapping names (one of: "edge", "value_info", "initializer",
    "sparse_initializer"), and the second element is the sorted list of names that
    appear in both graphs within that category. Categories without collisions are
    omitted, so an empty list means no collision was found.

    Optionally, it takes an io_map, representing the outputs of g1 to be connected
    to the inputs of g2. If provided, the g2 input names listed in the io_map are
    excluded from the edge comparison, since those edges are meant to be joined.

    Raises:
        TypeError: if g1 or g2 is not a GraphProto.
    """
    if not isinstance(g1, GraphProto):
        raise TypeError("g1 argument is not an ONNX graph")
    if not isinstance(g2, GraphProto):
        raise TypeError("g2 argument is not an ONNX graph")

    def _overlapping(c1: list[str], c2: list[str]) -> list[str]:
        # Sorted so the result (and any error message built from it) is deterministic;
        # raw set iteration order of str values changes between interpreter runs.
        return sorted(set(c1) & set(c2))

    def _edge_names(graph: GraphProto, exclude: set[str] | None = None) -> list[str]:
        # All non-empty node input/output names, skipping those in `exclude`.
        # Empty names mark omitted optional inputs/outputs and are never edges.
        if exclude is None:
            exclude = set()
        return [
            edge
            for n in graph.node
            for edge in (*n.input, *n.output)
            if edge != "" and edge not in exclude
        ]

    result = []

    io_map_inputs = {elem[1] for elem in io_map} if io_map else set()

    # Edges already cover input/output
    overlap = _overlapping(_edge_names(g1), _edge_names(g2, exclude=io_map_inputs))
    if overlap:
        result.append(("edge", overlap))

    overlap = _overlapping(
        [e.name for e in g1.value_info], [e.name for e in g2.value_info]
    )
    if overlap:
        result.append(("value_info", overlap))

    overlap = _overlapping(
        [e.name for e in g1.initializer], [e.name for e in g2.initializer]
    )
    if overlap:
        result.append(("initializer", overlap))

    # Sparse initializers carry two tensor names each (values and indices).
    overlap = _overlapping(
        [e.values.name for e in g1.sparse_initializer],
        [e.values.name for e in g2.sparse_initializer],
    ) + _overlapping(
        [e.indices.name for e in g1.sparse_initializer],
        [e.indices.name for e in g2.sparse_initializer],
    )
    if overlap:
        result.append(("sparse_initializer", overlap))

    return result
def merge_graphs(
    g1: GraphProto,
    g2: GraphProto,
    io_map: list[tuple[str, str]],
    inputs: list[str] | None = None,
    outputs: list[str] | None = None,
    prefix1: str | None = None,
    prefix2: str | None = None,
    name: str | None = None,
    doc_string: str | None = None,
) -> GraphProto:
    """Combines two ONNX graphs into a single one.

    The combined graph is defined by connecting the specified set of outputs/inputs. Those inputs/outputs
    not specified in the io_map argument will remain as inputs/outputs of the combined graph.

    Arguments:
        g1 (GraphProto): First graph
        g2 (GraphProto): Second graph
        io_map (list of pairs of string): The pairs of names [(out0, in0), (out1, in1), ...]
                                          representing outputs of the first graph and inputs of the second
                                          to be connected. Names are given without prefix1/prefix2;
                                          the prefixes are applied to them automatically.
        inputs (list of string): Optional list of inputs to be included in the combined graph
                                 By default, all inputs not present in the ``io_map`` argument will be
                                 included in the combined model
        outputs (list of string): Optional list of outputs to be included in the combined graph
                                  By default, all outputs not present in the ``io_map`` argument will be
                                  included in the combined model
        prefix1 (string): Optional prefix to be added to all names in g1
        prefix2 (string): Optional prefix to be added to all names in g2
        name (string): Optional name for the combined graph
                       By default, the name is g1.name and g2.name concatenated with an underscore delimiter
        doc_string (string): Optional docstring for the combined graph
                             If not provided, a default docstring with the concatenation of g1 and g2 docstrings is used

    Returns:
        GraphProto

    Raises:
        TypeError: if g1 or g2 is not a GraphProto.
        ValueError: if an io_map entry does not name an output of g1 or an input of g2,
            or if the two graphs have overlapping names.
    """
    if not isinstance(g1, GraphProto):
        raise TypeError("g1 argument is not an ONNX graph")
    if not isinstance(g2, GraphProto):
        raise TypeError("g2 argument is not an ONNX graph")

    # Prefixing names in the graph if requested, adjusting io_map accordingly.
    # add_prefix_graph (inplace=False by default) already returns a fresh copy, so the
    # caller's graphs are never mutated; an extra CopyFrom here would only duplicate
    # potentially large initializer data.
    if prefix1:
        g1 = add_prefix_graph(g1, prefix=prefix1)
    if prefix2:
        g2 = add_prefix_graph(g2, prefix=prefix2)
    if prefix1 or prefix2:
        io_map = [
            (
                prefix1 + io[0] if prefix1 else io[0],
                prefix2 + io[1] if prefix2 else io[1],
            )
            for io in io_map
        ]

    io_map_g1_outs = {io[0] for io in io_map}
    io_map_g2_ins = {io[1] for io in io_map}
    reversed_io_map = {in_name: out_name for out_name, in_name in io_map}
    g1_outs = {o.name for o in g1.output}
    g2_ins = {i.name for i in g2.input}

    # Check that input/output names specified in the io_map argument are valid input/output names.
    # Done before any subgraph extraction so an invalid io_map fails fast.
    for g1_out_name, g2_in_name in io_map:
        if g1_out_name not in g1_outs:
            raise ValueError(f"Output {g1_out_name} is not present in g1")
        if g2_in_name not in g2_ins:
            raise ValueError(f"Input {g2_in_name} is not present in g2")

    # If necessary extract subgraphs
    if inputs or outputs:
        if not inputs:
            g1_inputs = [i.name for i in g1.input]
            g2_inputs = [i.name for i in g2.input]
        else:
            input_set = set(inputs)
            g1_inputs = [i.name for i in g1.input if i.name in input_set]
            # g2 must keep the inputs fed through io_map, even if not requested.
            g2_inputs = [
                i.name
                for i in g2.input
                if i.name in input_set or i.name in io_map_g2_ins
            ]

        if not outputs:
            g1_outputs = [o.name for o in g1.output]
            g2_outputs = [o.name for o in g2.output]
        else:
            output_set = set(outputs)
            # g1 must keep the outputs feeding io_map, even if not requested.
            g1_outputs = [
                o.name
                for o in g1.output
                if o.name in output_set or o.name in io_map_g1_outs
            ]
            g2_outputs = [o.name for o in g2.output if o.name in output_set]

        if len(g1_inputs) < len(g1.input) or len(g1_outputs) < len(g1.output):
            e1 = utils.Extractor(helper.make_model(g1))
            g1 = e1.extract_model(g1_inputs, g1_outputs).graph

        if len(g2_inputs) < len(g2.input) or len(g2_outputs) < len(g2.output):
            e2 = utils.Extractor(helper.make_model(g2))
            g2 = e2.extract_model(g2_inputs, g2_outputs).graph

    # Check for name collision
    overlapping_names = check_overlapping_names(g1, g2, io_map)
    if len(overlapping_names) > 0:
        category, names = overlapping_names[0]
        raise ValueError(
            "Cant merge two graphs with overlapping names. "
            f"Found repeated {category} names: "
            + ", ".join(names)
            + "\n"
            + "Consider using ``onnx.compose.add_prefix`` to add a prefix to names in one of the graphs."
        )

    g = GraphProto()

    g.node.extend(g1.node)
    g2_nodes_begin = len(g.node)
    g.node.extend(g2.node)
    g2_nodes_end = len(g.node)

    # Rewrite node inputs named in io_map to the matching g1 output, descending into
    # subgraph attributes because they may reference those names from the outer scope.
    def connect_io(sub_graph: GraphProto, start: int, end: int) -> None:
        for node_idx in range(start, end):
            node = sub_graph.node[node_idx]
            for attr in node.attribute:
                if attr.type == AttributeProto.GRAPH:
                    connect_io(attr.g, 0, len(attr.g.node))
                elif attr.type == AttributeProto.GRAPHS:
                    for sub_g in attr.graphs:
                        connect_io(sub_g, 0, len(sub_g.node))

            for index, name_ in enumerate(node.input):
                if name_ in reversed_io_map:
                    node.input[index] = reversed_io_map[name_]

    # Connecting outputs of the first graph with the inputs of the second
    connect_io(g, g2_nodes_begin, g2_nodes_end)

    if inputs:
        input_set = set(inputs)
        g.input.extend([i for i in g1.input if i.name in input_set])
        g.input.extend([i for i in g2.input if i.name in input_set])
    else:
        g.input.extend(g1.input)
        g.input.extend([i for i in g2.input if i.name not in io_map_g2_ins])

    if outputs:
        output_set = set(outputs)
        g.output.extend([o for o in g1.output if o.name in output_set])
        g.output.extend([o for o in g2.output if o.name in output_set])
    else:
        g.output.extend([o for o in g1.output if o.name not in io_map_g1_outs])
        g.output.extend(g2.output)

    g.initializer.extend(g1.initializer)
    g.initializer.extend(
        [init for init in g2.initializer if init.name not in io_map_g2_ins]
    )

    g.sparse_initializer.extend(g1.sparse_initializer)
    g.sparse_initializer.extend(
        [
            init
            for init in g2.sparse_initializer
            if init.values.name not in io_map_g2_ins
        ]
    )

    g.value_info.extend(g1.value_info)
    g.value_info.extend([vi for vi in g2.value_info if vi.name not in io_map_g2_ins])
    # g1 outputs consumed through io_map become intermediate values; keep their type
    # information as value_info unless it is already recorded or they stay outputs.
    value_info_names = {vi.name for vi in g.value_info}
    output_names = {o.name for o in g.output}
    g.value_info.extend(
        [
            o
            for o in g1.output
            if o.name in io_map_g1_outs
            and o.name not in value_info_names
            and o.name not in output_names
        ]
    )

    g.name = name if name is not None else f"{g1.name}_{g2.name}"

    if doc_string is None:
        doc_string = (
            f"Graph combining {g1.name} and {g2.name}\n"
            + g1.name
            + "\n\n"
            + g1.doc_string
            + "\n\n"
            + g2.name
            + "\n\n"
            + g2.doc_string
        )
    g.doc_string = doc_string

    return g
def merge_models(
    m1: ModelProto,
    m2: ModelProto,
    io_map: list[tuple[str, str]],
    inputs: list[str] | None = None,
    outputs: list[str] | None = None,
    prefix1: str | None = None,
    prefix2: str | None = None,
    name: str | None = None,
    doc_string: str | None = None,
    producer_name: str | None = "onnx.compose.merge_models",
    producer_version: str | None = "1.0",
    domain: str | None = "",
    model_version: int | None = 1,
) -> ModelProto:
    """Combines two ONNX models into a single one.

    The combined model is defined by connecting the specified set of outputs/inputs.
    Those inputs/outputs not specified in the io_map argument will remain as
    inputs/outputs of the combined model.

    Both models should have the same IR version, and same operator sets imported.
    The merged model imports each operator-set domain exactly once.

    Arguments:
        m1 (ModelProto): First model
        m2 (ModelProto): Second model
        io_map (list of pairs of string): The pairs of names [(out0, in0), (out1, in1), ...]
                                          representing outputs of the first graph and inputs of the second
                                          to be connected
        inputs (list of string): Optional list of inputs to be included in the combined graph
                                 By default, all inputs not present in the ``io_map`` argument will be
                                 included in the combined model
        outputs (list of string): Optional list of outputs to be included in the combined graph
                                  By default, all outputs not present in the ``io_map`` argument will be
                                  included in the combined model
        prefix1 (string): Optional prefix to be added to all names in m1
        prefix2 (string): Optional prefix to be added to all names in m2
        name (string): Optional name for the combined graph
                       By default, the name is g1.name and g2.name concatenated with an underscore delimiter
        doc_string (string): Optional docstring for the combined graph
                             If not provided, a default docstring with the concatenation of g1 and g2 docstrings is used
        producer_name (string): Optional producer name for the combined model. Default: 'onnx.compose.merge_models'
        producer_version (string): Optional producer version for the combined model. Default: "1.0"
        domain (string): Optional domain of the combined model. Default: ""
        model_version (int): Optional version of the graph encoded. Default: 1

    Returns:
        ModelProto

    Raises:
        TypeError: if m1 or m2 is not a ModelProto.
        ValueError: on IR version mismatch, conflicting opset versions for a domain,
            conflicting metadata property values, overlapping local function names,
            or any error raised by ``merge_graphs``.
    """
    if not isinstance(m1, ModelProto):
        raise TypeError("m1 argument is not an ONNX model")
    if not isinstance(m2, ModelProto):
        raise TypeError("m2 argument is not an ONNX model")

    if m1.ir_version != m2.ir_version:
        raise ValueError(
            f"IR version mismatch {m1.ir_version} != {m2.ir_version}."
            " Both models should have the same IR version"
        )
    ir_version = m1.ir_version

    # Every domain must be imported with the same version by both models.
    opset_import_map: MutableMapping[str, int] = {}
    for entry in list(m1.opset_import) + list(m2.opset_import):
        if entry.domain in opset_import_map:
            found_version = opset_import_map[entry.domain]
            if entry.version != found_version:
                raise ValueError(
                    "Can't merge two models with different operator set ids for a given domain. "
                    f"Got: {m1.opset_import} and {m2.opset_import}"
                )
        else:
            opset_import_map[entry.domain] = entry.version

    # One entry per domain (first-seen order). Passing the concatenated lists would
    # list every domain shared by both models (e.g. the default one) twice.
    opset_imports = [
        helper.make_opsetid(opset_domain, opset_version)
        for opset_domain, opset_version in opset_import_map.items()
    ]

    # Prefixing names in the graph if requested, adjusting io_map accordingly.
    # add_prefix (inplace=False by default) already returns a fresh copy, so the
    # caller's models are never mutated; no additional copy is required.
    if prefix1:
        m1 = add_prefix(m1, prefix=prefix1)
    if prefix2:
        m2 = add_prefix(m2, prefix=prefix2)
    if prefix1 or prefix2:
        io_map = [
            (
                prefix1 + io[0] if prefix1 else io[0],
                prefix2 + io[1] if prefix2 else io[1],
            )
            for io in io_map
        ]

    graph = merge_graphs(
        m1.graph,
        m2.graph,
        io_map,
        inputs=inputs,
        outputs=outputs,
        name=name,
        doc_string=doc_string,
    )
    model = helper.make_model(
        graph,
        producer_name=producer_name,
        producer_version=producer_version,
        domain=domain,
        model_version=model_version,
        opset_imports=opset_imports,
        ir_version=ir_version,
    )

    # Merging model metadata props; a key present in both must carry the same value.
    model_props = {}
    for meta_entry in m1.metadata_props:
        model_props[meta_entry.key] = meta_entry.value
    for meta_entry in m2.metadata_props:
        if meta_entry.key in model_props:
            value = model_props[meta_entry.key]
            if value != meta_entry.value:
                raise ValueError(
                    "Can't merge models with different values for the same model metadata property."
                    f" Found: property = {meta_entry.key}, with values {value} and {meta_entry.value}."
                )
        else:
            model_props[meta_entry.key] = meta_entry.value
    helper.set_model_props(model, model_props)

    # Merging functions
    function_overlap = list(
        {f.name for f in m1.functions} & {f.name for f in m2.functions}
    )
    if function_overlap:
        raise ValueError(
            "Can't merge models with overlapping local function names."
            " Found in both graphs: " + ", ".join(function_overlap)
        )
    model.functions.MergeFrom(m1.functions)
    model.functions.MergeFrom(m2.functions)

    checker.check_model(model)
    return model
def add_prefix_graph(
    graph: GraphProto,
    prefix: str,
    rename_nodes: bool | None = True,
    rename_edges: bool | None = True,
    rename_inputs: bool | None = True,
    rename_outputs: bool | None = True,
    rename_initializers: bool | None = True,
    rename_value_infos: bool | None = True,
    inplace: bool | None = False,
    name_map: dict[str, str] | None = None,
) -> GraphProto:
    """Adds a prefix to names of elements in a graph: nodes, edges, inputs, outputs,
    initializers, sparse initializer, value infos.

    It can be used as a utility before merging graphs that have overlapping names.
    Empty names are not prefixed. Subgraphs held in node attributes (e.g. If/Loop
    bodies) are processed recursively with the same options, sharing ``name_map``
    so that references to outer-scope names are renamed consistently.

    Arguments:
        graph (GraphProto): Graph
        prefix (str): Prefix to be added to each name in the graph
        rename_nodes (bool): Whether to prefix node names
        rename_edges (bool): Whether to prefix node edge names
        rename_inputs (bool): Whether to prefix input names
        rename_outputs (bool): Whether to prefix output names
        rename_initializers (bool): Whether to prefix initializer and sparse initializer names
        rename_value_infos (bool): Whether to prefix value info names
        inplace (bool): If True, mutates the graph directly.
                        Otherwise, a copy will be created
        name_map: (Dict): old-name -> new-name mapping; shared with (and extended by)
                          the recursive calls made for subgraphs

    Returns:
        GraphProto

    Raises:
        TypeError: if graph is not a GraphProto.
    """
    if not isinstance(graph, GraphProto):
        raise TypeError("graph argument is not an ONNX graph")

    if not inplace:
        g = GraphProto()
        g.CopyFrom(graph)
    else:
        g = graph

    def _prefixed(prefix: str, name: str) -> str:
        return prefix + name if len(name) > 0 else name

    if name_map is None:
        name_map = {}

    # Phase 1: record every name of this graph that must be renamed. This has to be
    # complete BEFORE descending into subgraphs: a subgraph may refer by name to any
    # outer-scope value (node outputs, graph inputs, initializers), and those references
    # are only rewritten if the name is already in the shared name_map.
    if rename_edges:
        # See https://github.com/onnx/onnx/pull/6869#issuecomment-2852719536.
        # Consider only intermediate nodes, that are not connected to graph outputs.
        # Rename graph inputs or outputs separately based on rename_inputs/rename_outputs flags.
        graph_output_names = {o.name for o in g.output}
        for n in g.node:
            for e in n.output:
                if e not in graph_output_names:
                    name_map[e] = _prefixed(prefix, e)

    if rename_inputs:
        for entry in g.input:
            name_map[entry.name] = _prefixed(prefix, entry.name)
    if rename_outputs:
        for entry in g.output:
            name_map[entry.name] = _prefixed(prefix, entry.name)

    if rename_initializers:
        for init in g.initializer:
            name_map[init.name] = _prefixed(prefix, init.name)
        for sparse_init in g.sparse_initializer:
            name_map[sparse_init.values.name] = _prefixed(
                prefix, sparse_init.values.name
            )
            name_map[sparse_init.indices.name] = _prefixed(
                prefix, sparse_init.indices.name
            )

    if rename_value_infos:
        for entry in g.value_info:
            name_map[entry.name] = _prefixed(prefix, entry.name)

    # Phase 2: node names, then subgraphs with the caller's options so the flags are
    # honoured at every nesting level (and subgraphs are visited even when node names
    # themselves are not renamed).
    for n in g.node:
        if rename_nodes:
            n.name = _prefixed(prefix, n.name)
        for attribute in n.attribute:
            sub_graphs = [attribute.g] if attribute.HasField("g") else []
            sub_graphs.extend(attribute.graphs)
            for sub_g in sub_graphs:
                add_prefix_graph(
                    sub_g,
                    prefix,
                    rename_nodes=rename_nodes,
                    rename_edges=rename_edges,
                    rename_inputs=rename_inputs,
                    rename_outputs=rename_outputs,
                    rename_initializers=rename_initializers,
                    rename_value_infos=rename_value_infos,
                    inplace=True,
                    name_map=name_map,
                )

    # Phase 3: apply the mapping to every place a name appears in this graph.
    def _rename_all(names) -> None:
        # `names` is a repeated string field; entries are replaced in place.
        for i, old_name in enumerate(names):
            if old_name in name_map:
                names[i] = name_map[old_name]

    for n in g.node:
        _rename_all(n.output)
        _rename_all(n.input)

    for value in (*g.input, *g.output, *g.initializer, *g.value_info):
        if value.name in name_map:
            value.name = name_map[value.name]

    for sparse_initializer in g.sparse_initializer:
        if sparse_initializer.values.name in name_map:
            sparse_initializer.values.name = name_map[sparse_initializer.values.name]
        if sparse_initializer.indices.name in name_map:
            sparse_initializer.indices.name = name_map[sparse_initializer.indices.name]

    return g
def add_prefix(
    model: ModelProto,
    prefix: str,
    rename_nodes: bool | None = True,
    rename_edges: bool | None = True,
    rename_inputs: bool | None = True,
    rename_outputs: bool | None = True,
    rename_initializers: bool | None = True,
    rename_value_infos: bool | None = True,
    rename_functions: bool | None = True,
    inplace: bool | None = False,
) -> ModelProto:
    """Adds a prefix to names of elements in a graph: nodes, edges, inputs, outputs,
    initializers, sparse initializer, value infos, and local functions.

    It can be used as a utility before merging graphs that have overlapping names.
    Empty names are not prefixed.

    When local functions are renamed, every call site is updated as well: nodes in the
    main graph, in other local functions, and in any subgraph nested in their
    attributes. A node is treated as a call when its (domain, op_type) pair matches a
    function's (domain, name).

    Arguments:
        model (ModelProto): Model
        prefix (str): Prefix to be added to each name in the graph
        rename_nodes (bool): Whether to prefix node names
        rename_edges (bool): Whether to prefix node edge names
        rename_inputs (bool): Whether to prefix input names
        rename_outputs (bool): Whether to prefix output names
        rename_initializers (bool): Whether to prefix initializer and sparse initializer names
        rename_value_infos (bool): Whether to prefix value info names
        rename_functions (bool): Whether to prefix local function names
        inplace (bool): If True, mutates the model directly.
                        Otherwise, a copy will be created

    Returns:
        ModelProto

    Raises:
        TypeError: if model is not a ModelProto.
    """
    if not isinstance(model, ModelProto):
        raise TypeError("model argument is not an ONNX model")

    if not inplace:
        m = ModelProto()
        m.CopyFrom(model)
        model = m

    add_prefix_graph(
        model.graph,
        prefix,
        rename_nodes=rename_nodes,
        rename_edges=rename_edges,
        rename_inputs=rename_inputs,
        rename_outputs=rename_outputs,
        rename_initializers=rename_initializers,
        rename_value_infos=rename_value_infos,
        inplace=True,  # No need to create a copy, since it's a new model
    )

    if rename_functions:
        # Keyed by (domain, name): an op in another domain that merely shares the
        # function's name is not a call to it and must be left alone.
        f_name_map: dict[tuple[str, str], str] = {}
        for f in model.functions:
            new_f_name = prefix + f.name
            f_name_map[(f.domain, f.name)] = new_f_name
            f.name = new_f_name

        def _rename_calls(nodes) -> None:
            # Update call sites in `nodes`, descending into subgraph attributes
            # (e.g. If/Loop bodies), which may call local functions too.
            for n in nodes:
                new_op_type = f_name_map.get((n.domain, n.op_type))
                if new_op_type is not None:
                    n.op_type = new_op_type
                for attr in n.attribute:
                    if attr.HasField("g"):
                        _rename_calls(attr.g.node)
                    for sub_g in attr.graphs:
                        _rename_calls(sub_g.node)

        # Adjust references to local functions in other local function definitions
        for f in model.functions:
            _rename_calls(f.node)
        # Adjust references to local functions in the graph
        _rename_calls(model.graph.node)

    return model
def expand_out_dim_graph(
    graph: GraphProto,
    dim_idx: int,
    inplace: bool | None = False,
) -> GraphProto:
    """Inserts an extra dimension with extent 1 to each output in the graph.

    Inserts an Unsqueeze node for each output. It can be used as a utility before merging graphs,
    for example when the second one expects a batch dimension.

    The declared output shapes are updated to match: symbolic dimensions are kept,
    dimensions of unknown extent stay unknown, and outputs without a declared shape
    remain without one.

    Arguments:
        graph (GraphProto): Graph
        dim_idx (int): Index of the dimension to be inserted.
                       A negative value means counting dimensions from the back
                       of the expanded (output) shape, as in Unsqueeze.
        inplace (bool): If True, mutates the model directly.
                        Otherwise, a copy will be created

    Returns:
        GraphProto

    Raises:
        TypeError: if graph is not a GraphProto.
    """
    if not isinstance(graph, GraphProto):
        raise TypeError("graph argument is not an ONNX graph")

    if not inplace:
        g = GraphProto()
        g.CopyFrom(graph)
    else:
        g = graph

    collapsed_suffix = f"_collapsed_dim_{dim_idx}"
    orig_out_names = {output.name for output in g.output}

    # Redirect producers and internal consumers of each original output to an
    # intermediate name; the original name is re-produced by an Unsqueeze below.
    for n in g.node:
        for i, out in enumerate(n.output):
            if out in orig_out_names:
                n.output[i] = out + collapsed_suffix
        for i, inp in enumerate(n.input):
            if inp in orig_out_names:
                n.input[i] = inp + collapsed_suffix

    expand_dim_k = g.name + "_expand_out_dim_idx"
    g.node.append(
        helper.make_node(
            "Constant",
            inputs=[],
            outputs=[expand_dim_k],
            name=f"{expand_dim_k}-constant",
            value=helper.make_tensor(
                name=f"{expand_dim_k}-value",
                data_type=TensorProto.INT64,
                dims=[
                    1,
                ],
                vals=[
                    dim_idx,
                ],
            ),
        )
    )

    def _dim(d):
        # Keep concrete extents and symbolic names; None means "unknown extent".
        if d.HasField("dim_value"):
            return d.dim_value
        if d.HasField("dim_param"):
            return d.dim_param
        return None

    new_outputs = []
    for o in g.output:
        prev_output = o.name + collapsed_suffix
        g.node.append(
            helper.make_node(
                "Unsqueeze",
                inputs=[prev_output, expand_dim_k],
                outputs=[o.name],
                name=f"unsqueeze-{o.name}",
            )
        )
        tensor_type = o.type.tensor_type
        if tensor_type.HasField("shape"):
            new_shape = [_dim(d) for d in tensor_type.shape.dim]
            # Unsqueeze resolves a negative axis against the OUTPUT rank (r + 1);
            # list.insert would count from the end of the input shape instead.
            axis = dim_idx + len(new_shape) + 1 if dim_idx < 0 else dim_idx
            new_shape.insert(axis, 1)
        else:
            # Unknown rank stays unknown rather than being declared as shape [1].
            new_shape = None
        new_outputs.append(
            helper.make_tensor_value_info(
                o.name,
                tensor_type.elem_type,
                new_shape,
                doc_string=o.doc_string,
            )
        )

    # Replace the outputs only after reading all of them.
    g.ClearField("output")
    g.output.extend(new_outputs)
    return g
def expand_out_dim(
    model: ModelProto,
    dim_idx: int,
    inplace: bool | None = False,
) -> ModelProto:
    """Inserts an extra dimension with extent 1 to each output in the graph.

    Inserts an Unsqueeze node for each output. It can be used as a utility before merging graphs,
    for example when the second one expects a batch dimension. The work is delegated to
    ``expand_out_dim_graph`` applied to the model's main graph.

    Arguments:
        model (ModelProto): Model
        dim_idx (int): Index of the dimension to be inserted.
                       A negative value means counting dimensions from the back.
        inplace (bool): If True, mutates the model directly.
                        Otherwise, a copy will be created

    Returns:
        ModelProto: ``model`` itself when ``inplace`` is true, otherwise a modified copy.

    Raises:
        TypeError: if model is not a ModelProto.
    """
    if not isinstance(model, ModelProto):
        raise TypeError("model argument is not an ONNX model")

    if inplace:
        target = model
    else:
        target = ModelProto()
        target.CopyFrom(model)

    # `target` is either the caller's model (inplace) or our private copy, so its
    # graph can always be edited in place.
    expand_out_dim_graph(target.graph, dim_idx, inplace=True)
    return target
756 return model