Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/external_data_helper.py: 21%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

172 statements  

1# Copyright (c) ONNX Project Contributors 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4from __future__ import annotations 

5 

6import os 

7import re 

8import sys 

9import uuid 

10import warnings 

11from itertools import chain 

12from typing import IO, TYPE_CHECKING 

13 

14import onnx.checker as onnx_checker 

15import onnx.onnx_cpp2py_export.checker as c_checker 

16from onnx.onnx_pb import ( 

17 AttributeProto, 

18 FunctionProto, 

19 GraphProto, 

20 ModelProto, 

21 TensorProto, 

22) 

23 

24if TYPE_CHECKING: 

25 from collections.abc import Callable, Iterable 

26 

27 

def _open_external_data_fd(
    base_dir: str, location: str, tensor_name: str, read_only: bool
) -> int:
    """Return a CRT file descriptor for an external data file.

    Opening the file (and any path checks that go with it) is delegated to
    the C++ helper ``_open_external_data`` from ``onnx_cpp2py_export.checker``.

    Args:
        base_dir: Base directory handed to the C++ helper (presumably the
            directory ``location`` is relative to — confirm in the C++ code).
        location: Location string taken from the tensor's external_data.
        tensor_name: Tensor name, forwarded to the C++ helper.
        read_only: Forwarded to the C++ helper to select the open mode.
    """
    opener = c_checker._open_external_data
    return opener(base_dir, location, tensor_name, read_only)

33 

34 

35# Security: 3-layer defense against malicious external_data entries (GHSA-538c-55jv-c5g9) 

36# 

37# Layer 1 (here) — Attribute whitelist: Only spec-defined keys are accepted. 

38# Unknown keys are warned and ignored, preventing arbitrary attribute injection (CWE-915). 

39# 

40# Layer 2 (ExternalDataInfo.__init__) — Bounds validation: offset and length must be 

41# non-negative integers. Catches invalid values at parse time (CWE-400). 

42# 

43# Layer 3 (load_external_data_for_tensor) — File-size validation: offset and length are 

44# checked against actual file size before reading. This is the critical safety net that 

45# prevents memory exhaustion regardless of how the model was constructed (CWE-400). 

46# 

47# 'basepath' is included because set_external_data() and model_container 

48# write it to protobuf entries; it must survive save/load round-trips. 

49_ALLOWED_EXTERNAL_DATA_KEYS = frozenset( 

50 {"location", "offset", "length", "checksum", "basepath"} 

51) 

52_SORTED_ALLOWED_KEYS = sorted(_ALLOWED_EXTERNAL_DATA_KEYS) 

53_MAX_UNKNOWN_KEYS_IN_WARNING = 10 

54_MAX_KEY_DISPLAY_LENGTH = 100 

55 

56 

class ExternalDataInfo:
    """Parsed, validated view of a tensor's ``external_data`` entries.

    Only keys listed in ``_ALLOWED_EXTERNAL_DATA_KEYS`` become attributes.
    Every other key is ignored and reported through one ``warnings.warn``
    call. When ``offset`` and ``length`` are present, they are converted with
    ``int()`` and rejected if negative.

    Attributes:
        location: Value of the "location" entry, or "" if absent.
        offset: Non-negative int from the "offset" entry, or None if absent.
        length: Non-negative int from the "length" entry, or None if absent.
        checksum: Raw value of the "checksum" entry, or None if absent.
        basepath: Value of the "basepath" entry, or "" if absent.
    """

    def __init__(self, tensor: TensorProto) -> None:
        self.location = ""
        self.offset = None
        self.length = None
        self.checksum = None
        self.basepath = ""

        # Layer 1 (CWE-915): copy whitelisted keys onto self. For everything
        # else, keep a total count plus a bounded, truncated sample to report.
        ignored_total = 0
        ignored_sample: set[str] = set()
        for item in tensor.external_data:
            key = item.key
            if key in _ALLOWED_EXTERNAL_DATA_KEYS:
                setattr(self, key, item.value)
                continue
            ignored_total += 1
            if len(ignored_sample) >= _MAX_UNKNOWN_KEYS_IN_WARNING:
                continue
            # Cap the displayed length so a huge key cannot bloat the warning.
            if len(key) > _MAX_KEY_DISPLAY_LENGTH:
                ignored_sample.add(key[:_MAX_KEY_DISPLAY_LENGTH] + "...")
            else:
                ignored_sample.add(key)

        # warnings.warn stays directly in __init__ so that stacklevel=2
        # attributes the warning to whoever constructed this object.
        if ignored_sample:
            listed = sorted(ignored_sample)
            described = repr(listed)
            remaining = ignored_total - len(listed)
            if remaining > 0:
                described = f"{described} and {remaining} more"
            warnings.warn(
                f"Ignoring unknown external data key(s) {described} "
                f"for tensor {tensor.name!r}. "
                f"Allowed keys: {_SORTED_ALLOWED_KEYS}",
                stacklevel=2,
            )

        # Layer 2 (CWE-400): offset and length, when given, must be
        # non-negative integers. Offset is checked before length.
        for field in ("offset", "length"):
            raw = getattr(self, field)
            if raw is None:
                continue
            number = int(raw)
            setattr(self, field, number)
            if number < 0:
                raise ValueError(
                    f"External data {field} must be non-negative, got {number} "
                    f"for tensor {tensor.name!r}"
                )

107 

108 

109def _validate_external_data_file_bounds( 

110 data_file: IO[bytes], 

111 info: ExternalDataInfo, 

112 tensor_name: str, 

113) -> bytes: 

114 """Validate offset/length against actual file size and read data. 

115 

116 Layer 3 defense-in-depth (CWE-400): prevents memory exhaustion even if the 

117 model was crafted via direct protobuf APIs that bypass Python parsing. 

118 

119 Returns the raw bytes read from the file. 

120 """ 

121 file_size = os.fstat(data_file.fileno()).st_size 

122 

123 if info.offset is not None: 

124 if info.offset > file_size: 

125 raise ValueError( 

126 f"External data offset ({info.offset}) exceeds file size " 

127 f"({file_size}) for tensor {tensor_name!r}" 

128 ) 

129 data_file.seek(info.offset) 

130 

131 if info.length is not None: 

132 read_start = info.offset if info.offset is not None else 0 

133 available = file_size - read_start 

134 if info.length > available: 

135 raise ValueError( 

136 f"External data length ({info.length}) exceeds available data " 

137 f"({available} bytes from offset {read_start}) " 

138 f"for tensor {tensor_name!r}" 

139 ) 

140 return data_file.read(info.length) 

141 return data_file.read() 

142 

143 

def load_external_data_for_tensor(tensor: TensorProto, base_dir: str) -> None:
    """Read a tensor's external data from disk into ``tensor.raw_data``.

    Ideally the tensor holds no raw data beforehand; if it does, that data is
    overwritten. The external_data entries are parsed and validated by
    ``ExternalDataInfo``. The file is opened read-only through the C++
    helper, and the byte range is checked against the file size before it is
    read.

    Arguments:
        tensor: a TensorProto object.
        base_dir: directory that contains the external data.
    """
    parsed = ExternalDataInfo(tensor)
    descriptor = _open_external_data_fd(
        base_dir, parsed.location, tensor.name, read_only=True
    )
    with os.fdopen(descriptor, "rb") as stream:
        payload = _validate_external_data_file_bounds(stream, parsed, tensor.name)
    tensor.raw_data = payload

158 

159 

def load_external_data_for_model(model: ModelProto, base_dir: str) -> None:
    """Load every externally stored tensor of ``model`` into memory, in place.

    Every tensor reported by ``_get_all_tensors`` that uses external data has
    its bytes read into ``raw_data``. It is then switched back to
    ``TensorProto.DEFAULT`` storage and its ``external_data`` entries are
    cleared.

    Arguments:
        model: ModelProto to load external data to
        base_dir: directory that contains external data
    """
    external = (t for t in _get_all_tensors(model) if uses_external_data(t))
    for tensor in external:
        load_external_data_for_tensor(tensor, base_dir)
        # The bytes now live in raw_data, so mark the tensor as embedded ...
        tensor.data_location = TensorProto.DEFAULT
        # ... and drop the external_data entries, which are now stale.
        del tensor.external_data[:]

174 

175 

def set_external_data(
    tensor: TensorProto,
    location: str,
    offset: int | None = None,
    length: int | None = None,
    checksum: str | None = None,
    basepath: str | None = None,
) -> None:
    """Mark ``tensor`` as externally stored and rewrite its external_data.

    Existing external_data entries are discarded and ``data_location`` is set
    to ``TensorProto.EXTERNAL``. One entry is then written for each argument
    that is not None, in the order location, offset, length, checksum,
    basepath. Values are stored as strings; offset and length go through
    ``int()`` first. ``raw_data`` itself is not modified.

    Raises:
        ValueError: If the tensor has no ``raw_data`` field set.
    """
    if not tensor.HasField("raw_data"):
        raise ValueError(
            f"Tensor {tensor.name} does not have raw_data field. Cannot set external data for this tensor."
        )

    del tensor.external_data[:]
    tensor.data_location = TensorProto.EXTERNAL
    # Built after the reset above so that a failing int() conversion leaves
    # the tensor in the same state as before this refactor.
    fields = (
        ("location", location),
        ("offset", None if offset is None else int(offset)),
        ("length", None if length is None else int(length)),
        ("checksum", checksum),
        ("basepath", basepath),
    )
    for key, value in fields:
        if value is None:
            continue
        entry = tensor.external_data.add()
        entry.key = key
        entry.value = str(value)

202 

203 

def convert_model_to_external_data(
    model: ModelProto,
    all_tensors_to_one_file: bool = True,
    location: str | None = None,
    size_threshold: int = 1024,
    convert_attribute: bool = False,
) -> None:
    """Mark tensors with raw data as external data. Call this before 'save_model'.

    'save_model' saves all the tensors data as external data after this
    function has been called. This function changes only the tensors'
    external_data metadata; it does not write any file.

    Arguments:
        model (ModelProto): Model to be converted.
        all_tensors_to_one_file (bool): If true, save all tensors to one external
            file specified by location. If false, save each tensor to a file named
            after the tensor (or a random UUID if the name is not a valid file name).
        location: The external file, relative to the model, that all tensors are
            saved to. The path is relative to the model path. If not specified,
            a random "<uuid>.data" name is used. Only used when
            all_tensors_to_one_file is true.
        size_threshold: Only tensors whose raw data is at least size_threshold
            bytes are converted. To convert every tensor with raw data, set
            size_threshold=0.
        convert_attribute (bool): If true, convert all tensors (initializers and
            attribute tensors) to external data. If false, convert only
            initializer tensors.

    Raises:
        ValueError: If location is an absolute path.
        FileExistsError: If a file already exists at location.
    """
    tensors = (
        _get_all_tensors(model)
        if convert_attribute
        else _get_initializer_tensors(model)
    )

    def _should_convert(tensor: TensorProto) -> bool:
        # Compare the payload size in bytes. sys.getsizeof() would also count
        # the Python bytes-object header, so tensors smaller than the
        # threshold would be converted.
        return (
            tensor.HasField("raw_data")
            and len(tensor.raw_data) >= size_threshold
        )

    if all_tensors_to_one_file:
        if location:
            if os.path.isabs(location):
                raise ValueError(
                    "location must be a relative path that is relative to the model path."
                )
            # NOTE(review): os.path.exists resolves `location` against the
            # current working directory, not the model's directory — confirm
            # this is intended.
            if os.path.exists(location):
                raise FileExistsError(f"External data file exists in {location}.")
            file_name = location
        else:
            file_name = str(uuid.uuid1()) + ".data"
        for tensor in tensors:
            if _should_convert(tensor):
                set_external_data(tensor, file_name)
    else:
        for tensor in tensors:
            if _should_convert(tensor):
                tensor_location = tensor.name
                if not _is_valid_filename(tensor_location):
                    # The tensor name cannot be used as a file name.
                    tensor_location = str(uuid.uuid1())
                set_external_data(tensor, tensor_location)

260 

261 

def convert_model_from_external_data(model: ModelProto) -> None:
    """Switch every externally stored tensor back to embedded storage.

    Only tensor metadata changes: the external_data entries are removed and
    ``data_location`` becomes ``TensorProto.DEFAULT``. No file is read, so
    each such tensor must already hold its bytes in ``raw_data``. A later
    save_model then embeds that data.

    Arguments:
        model (ModelProto): Model to be converted.

    Raises:
        ValueError: If an externally stored tensor has no ``raw_data``. Tensors
            processed before that point remain converted.
    """
    for tensor in _get_all_tensors(model):
        if not uses_external_data(tensor):
            continue
        if not tensor.HasField("raw_data"):
            raise ValueError("raw_data field doesn't exist.")
        del tensor.external_data[:]
        tensor.data_location = TensorProto.DEFAULT

276 

277 

def save_external_data(tensor: TensorProto, base_path: str) -> None:
    """Write ``tensor.raw_data`` to the file named in its ``external_data``.

    The file is opened read/write through the C++ helper, which receives
    ``base_path`` and the location (presumably it also validates the name and
    keeps it inside ``base_path`` — confirm in the C++ code).

    - If external_data specifies an offset, a file that is too short is
      zero-padded up to that offset and the data is written there.
    - Otherwise the data is appended at the end of the file.

    Afterwards the tensor's external_data is rewritten by set_external_data
    with only location, the actual offset and the number of bytes written.
    Any checksum/basepath entries are not carried over.

    Arguments:
        tensor (TensorProto): Tensor object to be serialized
        base_path: System path of a folder where tensor data is to be stored

    Raises:
        onnx.checker.ValidationError: If the tensor has no raw_data field.
        ValueError: From ExternalDataInfo, for a negative or non-integer
            offset/length.
    """
    info = ExternalDataInfo(tensor)

    if not tensor.HasField("raw_data"):
        raise onnx_checker.ValidationError("raw_data field doesn't exist.")

    fd = _open_external_data_fd(base_path, info.location, tensor.name, read_only=False)
    with os.fdopen(fd, "r+b") as sink:
        end = sink.seek(0, os.SEEK_END)
        target = info.offset
        if target is not None:
            if target > end:
                # Zero-fill the gap between the current end and the offset.
                sink.write(b"\0" * (target - end))
            sink.seek(target)
        start = sink.tell()
        sink.write(tensor.raw_data)
        written = sink.tell() - start
        set_external_data(tensor, info.location, start, written)

307 

308 

def _get_all_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
    """Return a lazy iterator over every tensor in the model.

    Initializers (including those of nested subgraphs) come first. They are
    followed by tensors stored in node attributes of the main graph and of
    the model's functions.
    """
    sources = (
        _get_initializer_tensors(onnx_model_proto),
        _get_attribute_tensors(onnx_model_proto),
    )
    return chain.from_iterable(sources)

315 

316 

def _recursive_attribute_processor(
    attribute: AttributeProto, func: Callable[[GraphProto], Iterable[TensorProto]]
) -> Iterable[TensorProto]:
    """Apply ``func`` to each subgraph held by ``attribute`` and chain the results.

    A GRAPH attribute contributes ``attribute.g``, and a GRAPHS attribute
    contributes every graph in ``attribute.graphs``. Any other attribute type
    yields nothing.
    """
    kind = attribute.type
    if kind == AttributeProto.GRAPH:
        subgraphs = [attribute.g]
    elif kind == AttributeProto.GRAPHS:
        subgraphs = attribute.graphs
    else:
        return
    for subgraph in subgraphs:
        yield from func(subgraph)

326 

327 

def _get_initializer_tensors_from_graph(graph: GraphProto, /) -> Iterable[TensorProto]:
    """Yield ``graph``'s initializers, then those of every nested subgraph.

    Subgraphs are found through node attributes and visited depth-first, in
    node order and then attribute order.
    """
    yield from graph.initializer
    for attr in chain.from_iterable(node.attribute for node in graph.node):
        yield from _recursive_attribute_processor(
            attr, _get_initializer_tensors_from_graph
        )

336 

337 

def _get_initializer_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
    """Yield every initializer of the model's main graph and its subgraphs."""
    main_graph = onnx_model_proto.graph
    yield from _get_initializer_tensors_from_graph(main_graph)

341 

342 

def _get_attribute_tensors_from_graph(
    graph_or_function: GraphProto | FunctionProto, /
) -> Iterable[TensorProto]:
    """Yield tensors held in node attributes, descending into subgraphs.

    Attributes are visited in node order, then attribute order. For each one,
    this yields:

    1. its single tensor ``t``, if that field is set;
    2. every entry of ``tensors``;
    3. the attribute tensors of any subgraph it carries.
    """
    attributes = chain.from_iterable(
        node.attribute for node in graph_or_function.node
    )
    for attr in attributes:
        if attr.HasField("t"):
            yield attr.t
        yield from attr.tensors
        yield from _recursive_attribute_processor(
            attr, _get_attribute_tensors_from_graph
        )

355 

356 

def _get_attribute_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
    """Yield attribute tensors from the main graph, then from each model function."""
    containers = chain([onnx_model_proto.graph], onnx_model_proto.functions)
    for container in containers:
        yield from _get_attribute_tensors_from_graph(container)

362 

363 

364def _is_valid_filename(filename: str) -> bool: 

365 """Utility to check whether the provided filename is valid.""" 

366 exp = re.compile('^[^<>:;,?"*|/]+$') 

367 match = exp.match(filename) 

368 return bool(match) 

369 

370 

def uses_external_data(tensor: TensorProto) -> bool:
    """Return True when ``tensor.data_location`` is explicitly set to EXTERNAL."""
    if not tensor.HasField("data_location"):
        return False
    return tensor.data_location == TensorProto.EXTERNAL

377 

378 

def remove_external_data_field(tensor: TensorProto, field_key: str) -> None:
    """Remove every entry with key ``field_key`` from ``tensor.external_data``.

    Modifies the tensor object in place. Does nothing if the key is absent.

    Arguments:
        tensor (TensorProto): Tensor object from which the value will be removed
        field_key (string): The key of the field to be removed
    """
    entries = tensor.external_data
    # Walk indices from the end. A deletion only shifts elements that have
    # already been visited. Deleting while iterating forward (the previous
    # enumerate() loop) skipped the entry right after each deletion, so
    # duplicate keys survived.
    for index in reversed(range(len(entries))):
        if entries[index].key == field_key:
            del entries[index]

391 

392 

def write_external_data_tensors(model: ModelProto, filepath: str) -> ModelProto:
    """Serialize the raw data of tensors whose data location is EXTERNAL.

    Writing external data happens in two passes:

    1. Earlier, tensors that meet the criteria (size threshold etc.) are
       marked as EXTERNAL while still holding ``raw_data``.
    2. Here, those tensors have their ``raw_data`` written out by
       save_external_data, and the in-memory ``raw_data`` is cleared.

    Only tensors that are both marked EXTERNAL and still hold raw data are
    written. save_external_data rewrites each written tensor's external_data
    with only location/offset/length, so basepath information is stripped
    from those tensors.

    Arguments:
        model (ModelProto): Model object which is the source of tensors to serialize.
        filepath: System path to the directory which should be treated as base path for external data.

    Returns:
        ModelProto: The same model object, modified in place.
    """
    marked = (
        tensor
        for tensor in _get_all_tensors(model)
        if uses_external_data(tensor) and tensor.HasField("raw_data")
    )
    for tensor in marked:
        save_external_data(tensor, filepath)
        tensor.ClearField("raw_data")
    return model