1# Copyright (c) ONNX Project Contributors
2#
3# SPDX-License-Identifier: Apache-2.0
4from __future__ import annotations
5
6import os
7import re
8import sys
9import uuid
10import warnings
11from itertools import chain
12from typing import IO, TYPE_CHECKING
13
14import onnx.checker as onnx_checker
15import onnx.onnx_cpp2py_export.checker as c_checker
16from onnx.onnx_pb import (
17 AttributeProto,
18 FunctionProto,
19 GraphProto,
20 ModelProto,
21 TensorProto,
22)
23
24if TYPE_CHECKING:
25 from collections.abc import Callable, Iterable
26
27
def _open_external_data_fd(
    base_dir: str, location: str, tensor_name: str, read_only: bool
) -> int:
    """Open an external data file through the C++ checker binding.

    This is a thin pass-through to ``c_checker._open_external_data``; all
    arguments are forwarded positionally and unchanged.

    Arguments:
        base_dir: directory forwarded to the binding as the base for ``location``.
        location: external data location taken from the tensor's entries.
        tensor_name: name of the tensor the file belongs to.
        read_only: True when the file is opened for loading, False when it is
            opened for saving.

    Returns:
        A CRT file descriptor; callers in this module wrap it with ``os.fdopen``.
    """
    descriptor = c_checker._open_external_data(
        base_dir, location, tensor_name, read_only
    )
    return descriptor
33
34
35# Security: 3-layer defense against malicious external_data entries (GHSA-538c-55jv-c5g9)
36#
37# Layer 1 (here) — Attribute whitelist: Only spec-defined keys are accepted.
38# Unknown keys are warned and ignored, preventing arbitrary attribute injection (CWE-915).
39#
40# Layer 2 (ExternalDataInfo.__init__) — Bounds validation: offset and length must be
41# non-negative integers. Catches invalid values at parse time (CWE-400).
42#
43# Layer 3 (load_external_data_for_tensor) — File-size validation: offset and length are
44# checked against actual file size before reading. This is the critical safety net that
45# prevents memory exhaustion regardless of how the model was constructed (CWE-400).
46#
47# 'basepath' is included because set_external_data() and model_container
48# write it to protobuf entries; it must survive save/load round-trips.
49_ALLOWED_EXTERNAL_DATA_KEYS = frozenset(
50 {"location", "offset", "length", "checksum", "basepath"}
51)
52_SORTED_ALLOWED_KEYS = sorted(_ALLOWED_EXTERNAL_DATA_KEYS)
53_MAX_UNKNOWN_KEYS_IN_WARNING = 10
54_MAX_KEY_DISPLAY_LENGTH = 100
55
56
class ExternalDataInfo:
    """Parsed view of the key/value entries in ``tensor.external_data``.

    Entries whose key is in ``_ALLOWED_EXTERNAL_DATA_KEYS`` become attributes of
    the same name. All other keys are ignored and reported in a single warning.
    ``offset`` and ``length`` are converted to ``int`` when present and must be
    non-negative.
    """

    def __init__(self, tensor: TensorProto) -> None:
        """Read ``tensor.external_data`` and validate ``offset``/``length``.

        Arguments:
            tensor: a TensorProto whose ``external_data`` entries are parsed.

        Raises:
            ValueError: if ``offset`` or ``length`` is negative, or cannot be
                converted by ``int()``.
        """
        # Defaults used when the corresponding entry is absent.
        self.location = ""
        self.offset = None
        self.length = None
        self.checksum = None
        self.basepath = ""

        rejected: set[str] = set()
        rejected_total = 0
        for item in tensor.external_data:
            key = item.key
            # Layer 1: only whitelisted keys may become attributes (CWE-915).
            if key in _ALLOWED_EXTERNAL_DATA_KEYS:
                setattr(self, key, item.value)
                continue

            rejected_total += 1
            # Remember a bounded number of keys so the warning stays short.
            if len(rejected) >= _MAX_UNKNOWN_KEYS_IN_WARNING:
                continue
            if len(key) > _MAX_KEY_DISPLAY_LENGTH:
                rejected.add(key[:_MAX_KEY_DISPLAY_LENGTH] + "...")
            else:
                rejected.add(key)

        if rejected:
            displayed = sorted(rejected)
            hidden = rejected_total - len(displayed)
            suffix = f" and {hidden} more" if hidden > 0 else ""
            # stacklevel=2 attributes the warning to whoever built this object.
            warnings.warn(
                f"Ignoring unknown external data key(s) {displayed!r}{suffix} "
                f"for tensor {tensor.name!r}. "
                f"Allowed keys: {_SORTED_ALLOWED_KEYS}",
                stacklevel=2,
            )

        # Layer 2: bounds validation at parse time (CWE-400).
        if self.offset is not None:
            self.offset = self._parse_non_negative(self.offset, "offset", tensor)

        if self.length is not None:
            self.length = self._parse_non_negative(self.length, "length", tensor)

    @staticmethod
    def _parse_non_negative(raw, field: str, tensor: TensorProto) -> int:
        """Convert ``raw`` to int and reject negative values for ``field``."""
        value = int(raw)
        if value < 0:
            raise ValueError(
                f"External data {field} must be non-negative, got {value} "
                f"for tensor {tensor.name!r}"
            )
        return value
107
108
109def _validate_external_data_file_bounds(
110 data_file: IO[bytes],
111 info: ExternalDataInfo,
112 tensor_name: str,
113) -> bytes:
114 """Validate offset/length against actual file size and read data.
115
116 Layer 3 defense-in-depth (CWE-400): prevents memory exhaustion even if the
117 model was crafted via direct protobuf APIs that bypass Python parsing.
118
119 Returns the raw bytes read from the file.
120 """
121 file_size = os.fstat(data_file.fileno()).st_size
122
123 if info.offset is not None:
124 if info.offset > file_size:
125 raise ValueError(
126 f"External data offset ({info.offset}) exceeds file size "
127 f"({file_size}) for tensor {tensor_name!r}"
128 )
129 data_file.seek(info.offset)
130
131 if info.length is not None:
132 read_start = info.offset if info.offset is not None else 0
133 available = file_size - read_start
134 if info.length > available:
135 raise ValueError(
136 f"External data length ({info.length}) exceeds available data "
137 f"({available} bytes from offset {read_start}) "
138 f"for tensor {tensor_name!r}"
139 )
140 return data_file.read(info.length)
141 return data_file.read()
142
143
def load_external_data_for_tensor(tensor: TensorProto, base_dir: str) -> None:
    """Read a tensor's external data file into ``tensor.raw_data``.

    Any raw data the tensor already holds is overwritten. The tensor's
    ``external_data`` entries and ``data_location`` are left untouched.

    Arguments:
        tensor: a TensorProto object.
        base_dir: directory that contains the external data.
    """
    ext_info = ExternalDataInfo(tensor)
    # Read-only open through the C++ helper; the returned descriptor is owned
    # and closed by the file object created below.
    descriptor = _open_external_data_fd(
        base_dir, ext_info.location, tensor.name, True
    )
    with os.fdopen(descriptor, "rb") as stream:
        payload = _validate_external_data_file_bounds(stream, ext_info, tensor.name)
        tensor.raw_data = payload
158
159
def load_external_data_for_model(model: ModelProto, base_dir: str) -> None:
    """Load every externally stored tensor of ``model`` into memory.

    Each tensor marked as external gets its ``raw_data`` filled from disk, its
    ``data_location`` reset to DEFAULT and its ``external_data`` entries removed.

    Arguments:
        model: ModelProto to load external data to
        base_dir: directory that contains external data
    """
    for tensor in _get_all_tensors(model):
        if not uses_external_data(tensor):
            continue
        load_external_data_for_tensor(tensor, base_dir)
        # The data now lives in raw_data, so the tensor is no longer external.
        tensor.data_location = TensorProto.DEFAULT
        del tensor.external_data[:]
174
175
def set_external_data(
    tensor: TensorProto,
    location: str,
    offset: int | None = None,
    length: int | None = None,
    checksum: str | None = None,
    basepath: str | None = None,
) -> None:
    """Mark ``tensor`` as external and rewrite its ``external_data`` entries.

    Existing entries are discarded and ``data_location`` is set to EXTERNAL.
    One entry is then added per argument that is not None, with the value
    stored as a string.

    Arguments:
        tensor: TensorProto to update; it must have the ``raw_data`` field set.
        location: value stored under the ``location`` key.
        offset: optional value for the ``offset`` key, converted with ``int()``.
        length: optional value for the ``length`` key, converted with ``int()``.
        checksum: optional value for the ``checksum`` key.
        basepath: optional value for the ``basepath`` key.

    Raises:
        ValueError: if the tensor has no ``raw_data`` field.
    """
    if not tensor.HasField("raw_data"):
        raise ValueError(
            f"Tensor {tensor.name} does not have raw_data field. Cannot set external data for this tensor."
        )

    del tensor.external_data[:]
    tensor.data_location = TensorProto.EXTERNAL

    fields = (
        ("location", location),
        ("offset", None if offset is None else int(offset)),
        ("length", None if length is None else int(length)),
        ("checksum", checksum),
        ("basepath", basepath),
    )
    for key, value in fields:
        if value is None:
            continue
        item = tensor.external_data.add()
        item.key = key
        item.value = str(value)
202
203
def convert_model_to_external_data(
    model: ModelProto,
    all_tensors_to_one_file: bool = True,
    location: str | None = None,
    size_threshold: int = 1024,
    convert_attribute: bool = False,
) -> None:
    """Mark tensors holding raw data as external data. Call this before 'save_model'.

    Nothing is written to disk here; the tensors are only tagged so that a
    later 'save_model' stores their data externally.

    Arguments:
        model (ModelProto): Model to be converted.
        all_tensors_to_one_file (bool): If true, all tensors share one external file given by location.
            If false, each tensor gets a file named after the tensor (or a generated UUID when the
            name is not a valid filename); location is not used in that case.
        location: external file, relative to the model path, that all tensors are saved to.
            If not specified, a generated "<uuid>.data" name is used.
            The existence check is made on the path exactly as given.
        size_threshold: Only tensors whose raw data size is >= size_threshold are converted.
            The size is measured with sys.getsizeof on the raw_data bytes object.
            To convert every tensor with raw data set size_threshold=0.
        convert_attribute (bool): If true, convert all tensors to external data
            If false, convert only non-attribute tensors to external data

    Raise:
        ValueError: If location is not a relative path.
        FileExistsError: If a file already exists in location.
    """
    if convert_attribute:
        candidates = _get_all_tensors(model)
    else:
        candidates = _get_initializer_tensors(model)

    # Name of the single shared file, or None when each tensor gets its own.
    shared_file = None
    if all_tensors_to_one_file:
        shared_file = str(uuid.uuid1()) + ".data"
        if location:
            if os.path.isabs(location):
                raise ValueError(
                    "location must be a relative path that is relative to the model path."
                )
            if os.path.exists(location):
                raise FileExistsError(f"External data file exists in {location}.")
            shared_file = location

    for tensor in candidates:
        eligible = (
            tensor.HasField("raw_data")
            and sys.getsizeof(tensor.raw_data) >= size_threshold
        )
        if not eligible:
            continue

        if shared_file is not None:
            target = shared_file
        else:
            target = tensor.name
            if not _is_valid_filename(target):
                # Tensor names that cannot be used as filenames get a UUID.
                target = str(uuid.uuid1())
        set_external_data(tensor, target)
260
261
def convert_model_from_external_data(model: ModelProto) -> None:
    """Turn every externally marked tensor of ``model`` back into embedded data.

    After this call 'save_model' stores all tensor data inside the model.
    The raw data must already be present on each external tensor.

    Arguments:
        model (ModelProto): Model to be converted.

    Raises:
        ValueError: If a tensor marked as external has no raw_data field.
    """
    for tensor in _get_all_tensors(model):
        if not uses_external_data(tensor):
            continue
        if not tensor.HasField("raw_data"):
            raise ValueError("raw_data field doesn't exist.")
        del tensor.external_data[:]
        tensor.data_location = TensorProto.DEFAULT
276
277
def save_external_data(tensor: TensorProto, base_path: str) -> None:
    """Write ``tensor.raw_data`` to the file named by its `external_data` entries.

    The file is opened through `_open_external_data_fd` in write mode. When the
    entries carry an offset, the file is zero-padded up to it if it is shorter
    and the data is written at that offset; otherwise the data is appended.
    The tensor's entries are then rewritten with the location, the offset
    actually used and the number of bytes written.

    Arguments:
        tensor (TensorProto): Tensor object to be serialized
        base_path: System path of a folder where tensor data is to be stored

    Raises:
        onnx.checker.ValidationError: If the tensor has no raw_data field.
        ValueError: If the offset or length entries are invalid.
    """
    info = ExternalDataInfo(tensor)

    if not tensor.HasField("raw_data"):
        raise onnx_checker.ValidationError("raw_data field doesn't exist.")

    descriptor = _open_external_data_fd(base_path, info.location, tensor.name, False)
    with os.fdopen(descriptor, "r+b") as out:
        # Start at the end of the file so that, without an offset, data is appended.
        out.seek(0, os.SEEK_END)
        requested = info.offset
        if requested is not None:
            current_size = out.tell()
            if requested > current_size:
                # Pad file to required offset if needed
                out.write(b"\0" * (requested - current_size))
            out.seek(requested)

        start = out.tell()
        out.write(tensor.raw_data)
        written = out.tell() - start
        set_external_data(tensor, info.location, start, written)
307
308
def _get_all_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
    """Return an iterator over initializer tensors, then attribute tensors, of a model."""
    initializers = _get_initializer_tensors(onnx_model_proto)
    attribute_tensors = _get_attribute_tensors(onnx_model_proto)
    return chain(initializers, attribute_tensors)
315
316
def _recursive_attribute_processor(
    attribute: AttributeProto, func: Callable[[GraphProto], Iterable[TensorProto]]
) -> Iterable[TensorProto]:
    """Apply ``func`` to the subgraph(s) held by ``attribute`` and yield its tensors.

    GRAPH attributes contribute ``func(attribute.g)``; GRAPHS attributes
    contribute ``func`` applied to each graph in ``attribute.graphs``, in order.
    Attributes of any other type yield nothing.
    """
    if attribute.type == AttributeProto.GRAPH:
        yield from func(attribute.g)
    if attribute.type == AttributeProto.GRAPHS:
        yield from (
            tensor for subgraph in attribute.graphs for tensor in func(subgraph)
        )
326
327
def _get_initializer_tensors_from_graph(graph: GraphProto, /) -> Iterable[TensorProto]:
    """Yield the initializers of ``graph``, then those of its nested subgraphs.

    Subgraphs are reached through the GRAPH/GRAPHS attributes of the graph's
    nodes and are visited recursively in node and attribute order.
    """
    yield from graph.initializer
    node_attributes = (attr for node in graph.node for attr in node.attribute)
    for attr in node_attributes:
        yield from _recursive_attribute_processor(
            attr, _get_initializer_tensors_from_graph
        )
336
337
def _get_initializer_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
    """Yield the initializer tensors of the model's main graph and its subgraphs."""
    main_graph = onnx_model_proto.graph
    yield from _get_initializer_tensors_from_graph(main_graph)
341
342
def _get_attribute_tensors_from_graph(
    graph_or_function: GraphProto | FunctionProto, /
) -> Iterable[TensorProto]:
    """Yield the tensors stored in node attributes of a graph or function.

    For every attribute of every node this yields the single tensor ``t`` when
    that field is set, then each tensor in ``tensors``, and finally recurses
    into any subgraph(s) the attribute holds.
    """
    node_attributes = (
        attr for node in graph_or_function.node for attr in node.attribute
    )
    for attr in node_attributes:
        if attr.HasField("t"):
            yield attr.t
        yield from attr.tensors
        yield from _recursive_attribute_processor(
            attr, _get_attribute_tensors_from_graph
        )
355
356
def _get_attribute_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
    """Yield attribute tensors of the model's main graph, then of each of its functions."""
    containers = chain((onnx_model_proto.graph,), onnx_model_proto.functions)
    for container in containers:
        yield from _get_attribute_tensors_from_graph(container)
362
363
364def _is_valid_filename(filename: str) -> bool:
365 """Utility to check whether the provided filename is valid."""
366 exp = re.compile('^[^<>:;,?"*|/]+$')
367 match = exp.match(filename)
368 return bool(match)
369
370
def uses_external_data(tensor: TensorProto) -> bool:
    """Tell whether ``tensor`` has ``data_location`` set and equal to EXTERNAL."""
    if not tensor.HasField("data_location"):
        return False
    return tensor.data_location == TensorProto.EXTERNAL
377
378
def remove_external_data_field(tensor: TensorProto, field_key: str) -> None:
    """Removes a field from a Tensor's external_data key-value store.

    Modifies tensor object in place. Every entry whose key equals
    ``field_key`` is removed, including repeated ones.

    Arguments:
        tensor (TensorProto): Tensor object from which value will be removed
        field_key (string): The key of the field to be removed
    """
    entries = tensor.external_data
    # Walk the indices backwards: deleting while iterating forwards shifts the
    # remaining entries and would skip the one right after each deletion.
    for index in range(len(entries) - 1, -1, -1):
        if entries[index].key == field_key:
            del entries[index]
391
392
def write_external_data_tensors(model: ModelProto, filepath: str) -> ModelProto:
    """Write the raw data of every tensor marked as external to disk.

    Writing external data happens in two passes: tensors are first marked for
    serialization elsewhere (size threshold etc.), then this function writes the
    marked tensors that still carry raw data and clears their ``raw_data``.

    Note: saving a tensor rewrites its external_data entries with only the
    location, offset and length, so basepath information is stripped from the
    tensors written here.

    Arguments:
        model (ModelProto): Model object which is the source of tensors to serialize.
        filepath: System path to the directory which should be treated as base path for external data.

    Returns:
        ModelProto: The modified model object.
    """
    for tensor in _get_all_tensors(model):
        marked_with_data = uses_external_data(tensor) and tensor.HasField("raw_data")
        if not marked_with_data:
            continue
        save_external_data(tensor, filepath)
        tensor.ClearField("raw_data")

    return model