Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/compose.py: 4%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

304 statements  

1# Copyright (c) ONNX Project Contributors 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4from __future__ import annotations 

5 

6from typing import TYPE_CHECKING 

7 

8from onnx import ( 

9 AttributeProto, 

10 GraphProto, 

11 ModelProto, 

12 TensorProto, 

13 checker, 

14 helper, 

15 utils, 

16) 

17 

18if TYPE_CHECKING: 

19 from collections.abc import MutableMapping 

20 

21 

def check_overlapping_names(
    g1: GraphProto, g2: GraphProto, io_map: list[tuple[str, str]] | None = None
) -> list[tuple[str, list[str]]]:
    """Checks whether there are name collisions between two graphs.

    Returns a list of tuples where the first element represents the member
    containing overlapping names (one of: "edge", "value_info", "initializer",
    "sparse_initializer"), and the second element contains the sorted list of
    non-empty names that appear in both graphs in that category. Categories
    without collisions are omitted, so an empty list means no overlap was found.

    Optionally, it takes an io_map, representing the output/inputs to be
    connected. If provided, edge names of g2 that appear as the input side of an
    io_map pair are ignored when looking for overlapping edges.

    Raises:
        TypeError: if ``g1`` or ``g2`` is not a ``GraphProto``.
    """
    if not isinstance(g1, GraphProto):
        raise TypeError("g1 argument is not an ONNX graph")
    if not isinstance(g2, GraphProto):
        raise TypeError("g2 argument is not an ONNX graph")

    def _overlapping(c1: list[str], c2: list[str]) -> list[str]:
        # Empty names are not identifiers, so they never count as a collision.
        # Sorting makes the result (and error messages built from it)
        # independent of string-hash randomization.
        return sorted((set(c1) & set(c2)) - {""})

    def _edge_names(graph: GraphProto, exclude: set[str] | None = None) -> list[str]:
        # All non-empty node input and output names, minus ``exclude``.
        if exclude is None:
            exclude = set()
        return [
            edge
            for n in graph.node
            for edge in (*n.input, *n.output)
            if edge != "" and edge not in exclude
        ]

    result: list[tuple[str, list[str]]] = []

    if not io_map:
        io_map = []
    io_map_inputs = {elem[1] for elem in io_map}

    # Edges already cover input/output
    overlap = _overlapping(_edge_names(g1), _edge_names(g2, exclude=io_map_inputs))
    if overlap:
        result.append(("edge", overlap))

    overlap = _overlapping(
        [e.name for e in g1.value_info], [e.name for e in g2.value_info]
    )
    if overlap:
        result.append(("value_info", overlap))

    overlap = _overlapping(
        [e.name for e in g1.initializer], [e.name for e in g2.initializer]
    )
    if overlap:
        result.append(("initializer", overlap))

    # Values and indices names are compared separately, as before.
    overlap = _overlapping(
        [e.values.name for e in g1.sparse_initializer],
        [e.values.name for e in g2.sparse_initializer],
    ) + _overlapping(
        [e.indices.name for e in g1.sparse_initializer],
        [e.indices.name for e in g2.sparse_initializer],
    )
    if overlap:
        result.append(("sparse_initializer", overlap))

    return result

89 

90 

def merge_graphs(
    g1: GraphProto,
    g2: GraphProto,
    io_map: list[tuple[str, str]],
    inputs: list[str] | None = None,
    outputs: list[str] | None = None,
    prefix1: str | None = None,
    prefix2: str | None = None,
    name: str | None = None,
    doc_string: str | None = None,
) -> GraphProto:
    """Combines two ONNX graphs into a single one.

    The combined graph is defined by connecting the specified set of outputs/inputs. Those inputs/outputs
    not specified in the io_map argument will remain as inputs/outputs of the combined graph.

    Arguments:
        g1 (GraphProto): First graph
        g2 (GraphProto): Second graph
        io_map (list of pairs of string): The pairs of names [(out0, in0), (out1, in1), ...]
                                          representing outputs of the first graph and inputs of the second
                                          to be connected
        inputs (list of string): Optional list of inputs to be included in the combined graph
                                 By default, all inputs not present in the ``io_map`` argument will be
                                 included in the combined model
        outputs (list of string): Optional list of outputs to be included in the combined graph
                                  By default, all outputs not present in the ``io_map`` argument will be
                                  included in the combined model
        prefix1 (string): Optional prefix to be added to all names in g1
        prefix2 (string): Optional prefix to be added to all names in g2
        name (string): Optional name for the combined graph
                       By default, the name is g1.name and g2.name concatenated with an underscore delimiter
        doc_string (string): Optional docstring for the combined graph
                             If not provided, a default docstring with the concatenation of g1 and g2 docstrings is used

    Returns:
        GraphProto

    Raises:
        TypeError: if ``g1`` or ``g2`` is not a ``GraphProto``.
        ValueError: if an ``io_map`` pair names an output missing from ``g1`` or an
            input missing from ``g2`` (names are compared after prefixing), or if
            the two graphs have overlapping names.
    """
    if not isinstance(g1, GraphProto):
        raise TypeError("g1 argument is not an ONNX graph")
    if not isinstance(g2, GraphProto):
        raise TypeError("g2 argument is not an ONNX graph")

    # Prefixing names in the graph if requested, adjusting io_map accordingly.
    # add_prefix_graph works on a copy by default (inplace=False), so the
    # caller's graphs are not mutated and no extra explicit copy is needed.
    if prefix1 or prefix2:
        if prefix1:
            g1 = add_prefix_graph(g1, prefix=prefix1)
        if prefix2:
            g2 = add_prefix_graph(g2, prefix=prefix2)
        io_map = [
            (
                prefix1 + io[0] if prefix1 else io[0],
                prefix2 + io[1] if prefix2 else io[1],
            )
            for io in io_map
        ]

    io_map_g1_outs = {io[0] for io in io_map}
    io_map_g2_ins = {io[1] for io in io_map}
    reversed_io_map = {in_name: out_name for out_name, in_name in io_map}

    # Check that input/output names specified in the io_map argument are valid
    # input/output names. This only needs the full graph interfaces, so it is
    # done before any (potentially expensive) subgraph extraction.
    g1_outs = {o.name for o in g1.output}
    g2_ins = {i.name for i in g2.input}
    for g1_out_name, g2_in_name in io_map:
        if g1_out_name not in g1_outs:
            raise ValueError(f"Output {g1_out_name} is not present in g1")
        if g2_in_name not in g2_ins:
            raise ValueError(f"Input {g2_in_name} is not present in g2")

    # If necessary extract subgraphs. Connected (io_map) inputs of g2 and
    # outputs of g1 are always kept so the two graphs can still be wired.
    if inputs or outputs:
        if not inputs:
            g1_inputs = [i.name for i in g1.input]
            g2_inputs = [i.name for i in g2.input]
        else:
            input_set = set(inputs)
            g1_inputs = [i.name for i in g1.input if i.name in input_set]
            g2_inputs = [
                i.name
                for i in g2.input
                if i.name in input_set or i.name in io_map_g2_ins
            ]

        if not outputs:
            g1_outputs = [o.name for o in g1.output]
            g2_outputs = [o.name for o in g2.output]
        else:
            output_set = set(outputs)
            g1_outputs = [
                o.name
                for o in g1.output
                if o.name in output_set or o.name in io_map_g1_outs
            ]
            g2_outputs = [o.name for o in g2.output if o.name in output_set]

        if len(g1_inputs) < len(g1.input) or len(g1_outputs) < len(g1.output):
            e1 = utils.Extractor(helper.make_model(g1))
            g1 = e1.extract_model(g1_inputs, g1_outputs).graph

        if len(g2_inputs) < len(g2.input) or len(g2_outputs) < len(g2.output):
            e2 = utils.Extractor(helper.make_model(g2))
            g2 = e2.extract_model(g2_inputs, g2_outputs).graph

    # Check for name collision
    overlapping_names = check_overlapping_names(g1, g2, io_map)
    if len(overlapping_names) > 0:
        category, names = overlapping_names[0]
        raise ValueError(
            "Cant merge two graphs with overlapping names. "
            f"Found repeated {category} names: "
            + ", ".join(names)
            + "\n"
            + "Consider using ``onnx.compose.add_prefix`` to add a prefix to names in one of the graphs."
        )

    g = GraphProto()

    g.node.extend(g1.node)
    g2_nodes_begin = len(g.node)
    g.node.extend(g2.node)
    g2_nodes_end = len(g.node)

    # Rewrite node inputs that refer to connected g2 inputs so they read the
    # corresponding g1 outputs instead, descending into graph attributes.
    def connect_io(sub_graph: GraphProto, start: int, end: int) -> None:
        for node_idx in range(start, end):
            node = sub_graph.node[node_idx]
            for attr in node.attribute:
                if attr.type == AttributeProto.GRAPH:
                    connect_io(attr.g, 0, len(attr.g.node))
                elif attr.type == AttributeProto.GRAPHS:
                    for sub_g in attr.graphs:
                        connect_io(sub_g, 0, len(sub_g.node))

            for index, name_ in enumerate(node.input):
                if name_ in reversed_io_map:
                    node.input[index] = reversed_io_map[name_]

    # Connecting outputs of the first graph with the inputs of the second
    connect_io(g, g2_nodes_begin, g2_nodes_end)

    if inputs:
        input_set = set(inputs)
        g.input.extend([i for i in g1.input if i.name in input_set])
        g.input.extend([i for i in g2.input if i.name in input_set])
    else:
        g.input.extend(g1.input)
        g.input.extend([i for i in g2.input if i.name not in io_map_g2_ins])

    if outputs:
        output_set = set(outputs)
        g.output.extend([o for o in g1.output if o.name in output_set])
        g.output.extend([o for o in g2.output if o.name in output_set])
    else:
        g.output.extend([o for o in g1.output if o.name not in io_map_g1_outs])
        g.output.extend(g2.output)

    # g2 initializers / value infos named after a connected input are dropped,
    # since that input is now fed by a g1 output.
    g.initializer.extend(g1.initializer)
    g.initializer.extend(
        [init for init in g2.initializer if init.name not in io_map_g2_ins]
    )

    g.sparse_initializer.extend(g1.sparse_initializer)
    g.sparse_initializer.extend(
        [
            init
            for init in g2.sparse_initializer
            if init.values.name not in io_map_g2_ins
        ]
    )

    g.value_info.extend(g1.value_info)
    g.value_info.extend([vi for vi in g2.value_info if vi.name not in io_map_g2_ins])
    # Connected g1 outputs that are no longer graph outputs keep their type
    # information as value infos.
    value_info_names = {vi.name for vi in g.value_info}
    output_names = {o.name for o in g.output}
    g.value_info.extend(
        [
            o
            for o in g1.output
            if o.name in io_map_g1_outs
            and o.name not in value_info_names
            and o.name not in output_names
        ]
    )

    g.name = name if name is not None else f"{g1.name}_{g2.name}"

    if doc_string is None:
        doc_string = (
            f"Graph combining {g1.name} and {g2.name}\n"
            + g1.name
            + "\n\n"
            + g1.doc_string
            + "\n\n"
            + g2.name
            + "\n\n"
            + g2.doc_string
        )
    g.doc_string = doc_string

    return g

298 

299 

def merge_models(
    m1: ModelProto,
    m2: ModelProto,
    io_map: list[tuple[str, str]],
    inputs: list[str] | None = None,
    outputs: list[str] | None = None,
    prefix1: str | None = None,
    prefix2: str | None = None,
    name: str | None = None,
    doc_string: str | None = None,
    producer_name: str | None = "onnx.compose.merge_models",
    producer_version: str | None = "1.0",
    domain: str | None = "",
    model_version: int | None = 1,
) -> ModelProto:
    """Combines two ONNX models into a single one.

    The combined model is defined by connecting the specified set of outputs/inputs.
    Those inputs/outputs not specified in the io_map argument will remain as
    inputs/outputs of the combined model.

    Both models should have the same IR version, and same operator sets imported.

    Arguments:
        m1 (ModelProto): First model
        m2 (ModelProto): Second model
        io_map (list of pairs of string): The pairs of names [(out0, in0), (out1, in1), ...]
                                          representing outputs of the first graph and inputs of the second
                                          to be connected
        inputs (list of string): Optional list of inputs to be included in the combined graph
                                 By default, all inputs not present in the ``io_map`` argument will be
                                 included in the combined model
        outputs (list of string): Optional list of outputs to be included in the combined graph
                                  By default, all outputs not present in the ``io_map`` argument will be
                                  included in the combined model
        prefix1 (string): Optional prefix to be added to all names in m1
        prefix2 (string): Optional prefix to be added to all names in m2
        name (string): Optional name for the combined graph
                       By default, the name is g1.name and g2.name concatenated with an underscore delimiter
        doc_string (string): Optional docstring for the combined graph
                             If not provided, a default docstring with the concatenation of g1 and g2 docstrings is used
        producer_name (string): Optional producer name for the combined model.
                                Default: 'onnx.compose.merge_models'
        producer_version (string): Optional producer version for the combined model. Default: "1.0"
        domain (string): Optional domain of the combined model. Default: ""
        model_version (int): Optional version of the graph encoded. Default: 1

    Returns:
        ModelProto

    Raises:
        TypeError: if ``m1`` or ``m2`` is not a ``ModelProto``.
        ValueError: on an IR version mismatch, different opset versions for the
            same domain, different values for the same metadata property,
            overlapping local function names, or any error from ``merge_graphs``.
        The result is also validated with ``checker.check_model``.
    """
    if not isinstance(m1, ModelProto):
        raise TypeError("m1 argument is not an ONNX model")
    if not isinstance(m2, ModelProto):
        raise TypeError("m2 argument is not an ONNX model")

    if m1.ir_version != m2.ir_version:
        raise ValueError(
            f"IR version mismatch {m1.ir_version} != {m2.ir_version}."
            " Both models should have the same IR version"
        )
    ir_version = m1.ir_version

    # A domain may be imported by both models only with the same version.
    # Keep a single entry per domain (first-seen order) so the merged model
    # does not carry duplicate opset_import entries.
    opset_import_map: dict[str, int] = {}
    for entry in list(m1.opset_import) + list(m2.opset_import):
        if entry.domain in opset_import_map:
            found_version = opset_import_map[entry.domain]
            if entry.version != found_version:
                raise ValueError(
                    "Can't merge two models with different operator set ids for a given domain. "
                    f"Got: {m1.opset_import} and {m2.opset_import}"
                )
        else:
            opset_import_map[entry.domain] = entry.version
    opset_imports = [
        helper.make_opsetid(opset_domain, opset_version)
        for opset_domain, opset_version in opset_import_map.items()
    ]

    # Prefixing names in the graph if requested, adjusting io_map accordingly.
    # add_prefix copies the model by default (inplace=False), so the caller's
    # models are not mutated and no extra explicit copy is needed.
    if prefix1 or prefix2:
        if prefix1:
            m1 = add_prefix(m1, prefix=prefix1)
        if prefix2:
            m2 = add_prefix(m2, prefix=prefix2)
        io_map = [
            (
                prefix1 + io[0] if prefix1 else io[0],
                prefix2 + io[1] if prefix2 else io[1],
            )
            for io in io_map
        ]

    graph = merge_graphs(
        m1.graph,
        m2.graph,
        io_map,
        inputs=inputs,
        outputs=outputs,
        name=name,
        doc_string=doc_string,
    )
    model = helper.make_model(
        graph,
        producer_name=producer_name,
        producer_version=producer_version,
        domain=domain,
        model_version=model_version,
        opset_imports=opset_imports,
        ir_version=ir_version,
    )

    # Merging model metadata props: equal keys must carry equal values.
    model_props = {}
    for meta_entry in m1.metadata_props:
        model_props[meta_entry.key] = meta_entry.value
    for meta_entry in m2.metadata_props:
        if meta_entry.key in model_props:
            value = model_props[meta_entry.key]
            if value != meta_entry.value:
                raise ValueError(
                    "Can't merge models with different values for the same model metadata property."
                    f" Found: property = {meta_entry.key}, with values {value} and {meta_entry.value}."
                )
        else:
            model_props[meta_entry.key] = meta_entry.value
    helper.set_model_props(model, model_props)

    # Merging functions
    function_overlap = list(
        {f.name for f in m1.functions} & {f.name for f in m2.functions}
    )
    if function_overlap:
        raise ValueError(
            "Can't merge models with overlapping local function names."
            " Found in both graphs: " + ", ".join(function_overlap)
        )
    model.functions.MergeFrom(m1.functions)
    model.functions.MergeFrom(m2.functions)

    checker.check_model(model)
    return model

444 

445 

def add_prefix_graph(
    graph: GraphProto,
    prefix: str,
    rename_nodes: bool | None = True,
    rename_edges: bool | None = True,
    rename_inputs: bool | None = True,
    rename_outputs: bool | None = True,
    rename_initializers: bool | None = True,
    rename_value_infos: bool | None = True,
    inplace: bool | None = False,
    name_map: dict[str, str] | None = None,
) -> GraphProto:
    """Adds a prefix to names of elements in a graph: nodes, edges, inputs, outputs,
    initializers, sparse initializer, value infos.

    It can be used as a utility before merging graphs that have overlapping names.
    Empty names are not prefixed.

    Graphs held in node attributes (``g`` and ``graphs``) are processed
    recursively with the same rename flags and a shared ``name_map``. This keeps
    references from a subgraph to names of the enclosing graph consistent with
    the renamed enclosing graph.

    Arguments:
        graph (GraphProto): Graph
        prefix (str): Prefix to be added to each name in the graph
        rename_nodes (bool): Whether to prefix node names
        rename_edges (bool): Whether to prefix node edge names
        rename_inputs (bool): Whether to prefix input names
        rename_outputs (bool): Whether to prefix output names
        rename_initializers (bool): Whether to prefix initializer and sparse initializer names
        rename_value_infos (bool): Whether to prefix value info names
        inplace (bool): If True, mutates the graph directly.
                        Otherwise, a copy will be created
        name_map: (Dict): mapping from original to prefixed names. It is shared
                          with, and extended by, the recursive calls on subgraphs.
                          Callers normally leave it as ``None``.

    Returns:
        GraphProto

    Raises:
        TypeError: if ``graph`` is not a ``GraphProto``.
    """
    if not isinstance(graph, GraphProto):
        raise TypeError("graph argument is not an ONNX graph")

    if not inplace:
        g = GraphProto()
        g.CopyFrom(graph)
    else:
        g = graph

    def _prefixed(prefix: str, name: str) -> str:
        return prefix + name if len(name) > 0 else name

    if name_map is None:
        name_map = {}

    # Step 1: register every name of this scope *before* descending into
    # subgraphs. Outer-scope references made from inside a subgraph, including
    # references to initializers, are then renamed as well.
    if rename_edges:
        # See https://github.com/onnx/onnx/pull/6869#issuecomment-2852719536.
        # Consider only intermediate nodes, that are not connected to graph outputs.
        # Rename graph inputs or outputs separately based on rename_inputs/rename_outputs flags.
        graph_output_names = {o.name for o in g.output}
        for n in g.node:
            for e in n.output:
                if e not in graph_output_names:
                    name_map[e] = _prefixed(prefix, e)

    if rename_inputs:
        for entry in g.input:
            name_map[entry.name] = _prefixed(prefix, entry.name)
    if rename_outputs:
        for entry in g.output:
            name_map[entry.name] = _prefixed(prefix, entry.name)

    if rename_initializers:
        for init in g.initializer:
            name_map[init.name] = _prefixed(prefix, init.name)
        for sparse_init in g.sparse_initializer:
            name_map[sparse_init.values.name] = _prefixed(
                prefix, sparse_init.values.name
            )
            name_map[sparse_init.indices.name] = _prefixed(
                prefix, sparse_init.indices.name
            )

    if rename_value_infos:
        for entry in g.value_info:
            name_map[entry.name] = _prefixed(prefix, entry.name)

    def _recurse(sub_graph: GraphProto) -> None:
        # Same flags as this call; the subgraph is already owned by ``g``.
        add_prefix_graph(
            sub_graph,
            prefix,
            rename_nodes=rename_nodes,
            rename_edges=rename_edges,
            rename_inputs=rename_inputs,
            rename_outputs=rename_outputs,
            rename_initializers=rename_initializers,
            rename_value_infos=rename_value_infos,
            inplace=True,
            name_map=name_map,
        )

    # Step 2: rename nodes and process graph attributes. Subgraphs are visited
    # even when rename_nodes is False, because their edges may reference names
    # of this graph that were registered above.
    for n in g.node:
        if rename_nodes:
            n.name = _prefixed(prefix, n.name)
        for attribute in n.attribute:
            if attribute.HasField("g"):
                _recurse(attribute.g)
            for sub_g in attribute.graphs:
                _recurse(sub_g)

    # Step 3: apply the mapping to every name stored in this graph.
    for n in g.node:
        for i, output in enumerate(n.output):
            if output in name_map:
                n.output[i] = name_map[output]
        for i, input_ in enumerate(n.input):
            if input_ in name_map:
                n.input[i] = name_map[input_]

    for in_desc in g.input:
        if in_desc.name in name_map:
            in_desc.name = name_map[in_desc.name]
    for out_desc in g.output:
        if out_desc.name in name_map:
            out_desc.name = name_map[out_desc.name]

    for initializer in g.initializer:
        if initializer.name in name_map:
            initializer.name = name_map[initializer.name]
    for sparse_initializer in g.sparse_initializer:
        if sparse_initializer.values.name in name_map:
            sparse_initializer.values.name = name_map[sparse_initializer.values.name]
        if sparse_initializer.indices.name in name_map:
            sparse_initializer.indices.name = name_map[sparse_initializer.indices.name]

    for value_info in g.value_info:
        if value_info.name in name_map:
            value_info.name = name_map[value_info.name]

    return g

567 

568 

def add_prefix(
    model: ModelProto,
    prefix: str,
    rename_nodes: bool | None = True,
    rename_edges: bool | None = True,
    rename_inputs: bool | None = True,
    rename_outputs: bool | None = True,
    rename_initializers: bool | None = True,
    rename_value_infos: bool | None = True,
    rename_functions: bool | None = True,
    inplace: bool | None = False,
) -> ModelProto:
    """Adds a prefix to names of elements in a graph: nodes, edges, inputs, outputs,
    initializers, sparse initializer, value infos, and local functions.

    It can be used as a utility before merging graphs that have overlapping names.
    Empty names are not prefixed.

    When local functions are renamed, every node that calls one of them is
    updated. This covers the main graph, the function bodies, and any graphs
    nested in node attributes. A node calls a local function when both its
    ``domain`` and ``op_type`` match the function's ``domain`` and ``name``.

    Arguments:
        model (ModelProto): Model
        prefix (str): Prefix to be added to each name in the graph
        rename_nodes (bool): Whether to prefix node names
        rename_edges (bool): Whether to prefix node edge names
        rename_inputs (bool): Whether to prefix input names
        rename_outputs (bool): Whether to prefix output names
        rename_initializers (bool): Whether to prefix initializer and sparse initializer names
        rename_value_infos (bool): Whether to prefix value info names
        rename_functions (bool): Whether to prefix local function names
        inplace (bool): If True, mutates the model directly.
                        Otherwise, a copy will be created

    Returns:
        ModelProto

    Raises:
        TypeError: if ``model`` is not a ``ModelProto``.
    """
    if not isinstance(model, ModelProto):
        raise TypeError("model argument is not an ONNX model")

    if not inplace:
        m = ModelProto()
        m.CopyFrom(model)
        model = m

    add_prefix_graph(
        model.graph,
        prefix,
        rename_nodes=rename_nodes,
        rename_edges=rename_edges,
        rename_inputs=rename_inputs,
        rename_outputs=rename_outputs,
        rename_initializers=rename_initializers,
        rename_value_infos=rename_value_infos,
        inplace=True,  # No need to create a copy, since it's a new model
    )

    if rename_functions:
        # Keyed by (domain, original name). An op from another domain that
        # happens to share a function's name must not be renamed.
        f_name_map: dict[tuple[str, str], str] = {}
        for f in model.functions:
            new_f_name = prefix + f.name
            f_name_map[(f.domain, f.name)] = new_f_name
            f.name = new_f_name

        def _rename_calls(nodes) -> None:
            # Rewrite calls to local functions, including inside graph attributes.
            for n in nodes:
                key = (n.domain, n.op_type)
                if key in f_name_map:
                    n.op_type = f_name_map[key]
                for attr in n.attribute:
                    if attr.HasField("g"):
                        _rename_calls(attr.g.node)
                    for sub_g in attr.graphs:
                        _rename_calls(sub_g.node)

        # Adjust references to local functions in other local function
        # definitions
        for f in model.functions:
            _rename_calls(f.node)
        # Adjust references to local functions in the graph
        _rename_calls(model.graph.node)

    return model

641 

642 

def expand_out_dim_graph(
    graph: GraphProto,
    dim_idx: int,
    inplace: bool | None = False,
) -> GraphProto:
    """Inserts an extra dimension with extent 1 to each output in the graph.

    Inserts an Unsqueeze node for each output. It can be used as a utility before merging graphs,
    for example when the second one expects a batch dimension.

    Each original output ``X`` is produced internally under the name
    ``X_collapsed_dim_{dim_idx}``. The new Unsqueeze node restores the name
    ``X`` with the extra dimension. Known output shapes are updated, keeping
    symbolic and unknown dimensions. Outputs without a shape keep no shape.

    NOTE(review): an output that no node produces (e.g. an output that is
    directly a graph input or initializer) is not renamed, so its Unsqueeze
    input would be dangling. This was also the case before.

    Arguments:
        graph (GraphProto): Graph
        dim_idx (int): Index of the dimension to be inserted.
                       A negative value means counting dimensions from the back,
                       as in the Unsqueeze operator (``-1`` appends a trailing dimension).
        inplace (bool): If True, mutates the model directly.
                        Otherwise, a copy will be created

    Returns:
        GraphProto

    Raises:
        TypeError: if ``graph`` is not a ``GraphProto``.
    """
    if not isinstance(graph, GraphProto):
        raise TypeError("graph argument is not an ONNX graph")

    if not inplace:
        g = GraphProto()
        g.CopyFrom(graph)
    else:
        g = graph

    suffix = f"_collapsed_dim_{dim_idx}"
    orig_out_names = {output.name for output in g.output}

    # Free each original output name for the Unsqueeze result. Internal
    # consumers keep reading the pre-unsqueeze value.
    for n in g.node:
        for i, out in enumerate(n.output):
            if out in orig_out_names:
                n.output[i] = out + suffix
        for i, inp in enumerate(n.input):
            if inp in orig_out_names:
                n.input[i] = inp + suffix

    expand_dim_k = g.name + "_expand_out_dim_idx"
    g.node.append(
        helper.make_node(
            "Constant",
            inputs=[],
            outputs=[expand_dim_k],
            name=f"{expand_dim_k}-constant",
            value=helper.make_tensor(
                name=f"{expand_dim_k}-value",
                data_type=TensorProto.INT64,
                dims=[
                    1,
                ],
                vals=[
                    dim_idx,
                ],
            ),
        )
    )

    def _expanded_shape(value_info) -> list[int | str | None] | None:
        # None keeps an unknown rank unknown instead of claiming rank 1.
        tensor_type = value_info.type.tensor_type
        if not tensor_type.HasField("shape"):
            return None
        dims: list[int | str | None] = []
        for d in tensor_type.shape.dim:
            if d.HasField("dim_value"):
                dims.append(d.dim_value)
            elif d.HasField("dim_param"):
                dims.append(d.dim_param)
            else:
                dims.append(None)
        # Unsqueeze resolves negative axes against the *output* rank
        # (len(dims) + 1), whereas list.insert resolves them against len(dims).
        pos = dim_idx if dim_idx >= 0 else len(dims) + 1 + dim_idx
        dims.insert(pos, 1)
        return dims

    # Build the new outputs while the old ones are still attached, then swap.
    new_outputs = []
    for o in g.output:
        g.node.append(
            helper.make_node(
                "Unsqueeze",
                inputs=[o.name + suffix, expand_dim_k],
                outputs=[o.name],
                name=f"unsqueeze-{o.name}",
            )
        )
        new_outputs.append(
            helper.make_tensor_value_info(
                o.name,
                o.type.tensor_type.elem_type,
                _expanded_shape(o),
                doc_string=o.doc_string,
            )
        )
    del g.output[:]
    g.output.extend(new_outputs)
    return g

721 

722 

def expand_out_dim(
    model: ModelProto,
    dim_idx: int,
    inplace: bool | None = False,
) -> ModelProto:
    """Inserts an extra dimension with extent 1 to each output in the graph.

    Model-level wrapper around ``expand_out_dim_graph``. Each graph output gets
    an Unsqueeze node, which is useful before merging graphs, for example when
    the second one expects a batch dimension.

    Arguments:
        model (ModelProto): Model
        dim_idx (int): Index of the dimension to be inserted.
                       A negative value means counting dimensions from the back.
        inplace (bool): If True, mutates the model directly.
                        Otherwise, a copy will be created

    Returns:
        ModelProto

    Raises:
        TypeError: if ``model`` is not a ``ModelProto``.
    """
    if not isinstance(model, ModelProto):
        raise TypeError("model argument is not an ONNX model")

    target = model
    if not inplace:
        target = ModelProto()
        target.CopyFrom(model)

    # ``target`` is either the caller's model (inplace) or a private copy, so
    # its graph can always be edited in place.
    expand_out_dim_graph(target.graph, dim_idx, inplace=True)
    return target