Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/numpy_helper.py: 9%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

287 statements  

1# Copyright (c) ONNX Project Contributors 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4from __future__ import annotations 

5 

6import math 

7import sys 

8from typing import TYPE_CHECKING, Any 

9 

10import ml_dtypes 

11import numpy as np 

12import numpy.typing as npt 

13 

14import onnx.external_data_helper 

15from onnx import helper 

16 

17if TYPE_CHECKING: 

18 from collections.abc import Sequence 

19 

20 

def to_float8e8m0(
    x: np.ndarray,
    saturate: bool = True,
    round_mode: str = "up",
) -> np.ndarray:
    """Convert a float32 NumPy array to float8e8m0 (exponent-only) values.

    Non-float32 input is first cast to float32. Only the 8 biased exponent
    bits of each float32 are kept; the sign bit is ignored and the 23
    mantissa bits only influence rounding.

    Args:
        x: Input array to convert.
        saturate: If True, a value whose biased exponent is already 0xFE is
            never rounded up into 0xFF (the NaN encoding).
        round_mode: "nearest", "up", or "down".
            - "down": drop the mantissa (truncate).
            - "up": bump the exponent whenever any mantissa bit is set.
            - "nearest": bump the exponent when the top mantissa bit is set
              and either a lower mantissa bit is set or the biased
              exponent is non-zero.

    Returns:
        np.ndarray: Array of ml_dtypes.float8_e8m0fnu values.

    Raises:
        ValueError: if ``round_mode`` is not one of the supported modes.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)

    # Widen to uint16 so that adding a rounding bump to 0xFE/0xFF cannot wrap.
    exp = ((bits >> 23) & 0xFF).astype(np.uint16)

    # An all-ones float32 exponent means NaN or Inf; both map to 0xFF.
    is_special = exp == 0xFF  # noqa: PLR2004
    is_finite = ~is_special
    result = np.zeros_like(exp, dtype=np.uint8)
    result[is_special] = 0xFF

    if round_mode == "down":
        bump = None  # truncation: the mantissa is simply discarded
    elif round_mode == "up":
        bump = ((bits & 0x7FFFFF) > 0) & is_finite
    elif round_mode == "nearest":
        half_bit = (bits & 0x400000) > 0  # top mantissa bit (0.5)
        below_half = (bits & 0x3FFFFF) > 0  # any of the remaining 22 mantissa bits
        bump = half_bit & (below_half | (exp > 0)) & is_finite
    else:
        raise ValueError(f"Unsupported rounding mode: {round_mode}")

    if bump is not None:
        if saturate:
            # Keep the largest finite exponent from overflowing into NaN.
            bump[(exp == 0xFE) & bump] = False  # noqa: PLR2004
        exp += bump.astype(np.uint16)

    # Finite exponents are at most 0xFF here, so the narrowing cast is lossless.
    narrowed = exp.astype(np.uint8)
    result[is_finite] = narrowed[is_finite]

    return result.view(ml_dtypes.float8_e8m0fnu)

94 

95 

96def _unpack_4bit( 

97 data: npt.NDArray[np.uint8], dims: Sequence[int] 

98) -> npt.NDArray[np.uint8]: 

99 """Convert a packed uint4 array to unpacked uint4 array represented as uint8. 

100 

101 Args: 

102 data: A numpy array. 

103 dims: The dimensions are used to reshape the unpacked buffer. 

104 

105 Returns: 

106 A numpy array of int8/uint8 reshaped to dims. 

107 """ 

108 result = np.empty([data.size * 2], dtype=data.dtype) 

109 array_low = data & np.uint8(0x0F) 

110 array_high = data & np.uint8(0xF0) 

111 array_high >>= np.uint8(4) 

112 result[0::2] = array_low 

113 result[1::2] = array_high 

114 expected_elements = math.prod(dims) 

115 if result.size == expected_elements + 1: 

116 # handle single-element padding due to odd number of elements 

117 result = result[:-1] 

118 if expected_elements > result.size: 

119 raise ValueError( 

120 f"Packed 4-bit data ({data.size} bytes, {result.size} elements unpacked) " 

121 f"is too small for the declared shape {list(dims)} " 

122 f"({expected_elements} elements required)." 

123 ) 

124 result.resize(dims, refcheck=False) 

125 return result 

126 

127 

128def _pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]: 

129 """Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range.""" 

130 # Create a 1D copy 

131 array_flat = array.ravel().view(np.uint8).copy() 

132 size = array.size 

133 odd_sized = size % 2 == 1 

134 if odd_sized: 

135 array_flat.resize([size + 1], refcheck=False) 

136 array_flat &= 0x0F 

137 array_flat[1::2] <<= 4 

138 return array_flat[0::2] | array_flat[1::2] 

139 

140 

141def _unpack_2bit( 

142 data: npt.NDArray[np.uint8], dims: Sequence[int] 

143) -> npt.NDArray[np.uint8]: 

144 """Convert a packed uint2 array to unpacked uint2 array represented as uint8. 

145 

146 Args: 

147 data: A numpy array. 

148 dims: The dimensions are used to reshape the unpacked buffer. 

149 

150 Returns: 

151 A numpy array of int8/uint8 reshaped to dims. 

152 """ 

153 result = np.empty([data.size * 4], dtype=data.dtype) 

154 result[0::4] = data & 0x03 

155 result[1::4] = (data >> 2) & 0x03 

156 result[2::4] = (data >> 4) & 0x03 

157 result[3::4] = (data >> 6) & 0x03 

158 expected_elements = math.prod(dims) 

159 if result.size > expected_elements: 

160 # handle padding due to non multiple of 4 elements 

161 result = result[:expected_elements] 

162 if expected_elements > result.size: 

163 raise ValueError( 

164 f"Packed 2-bit data ({data.size} bytes, {result.size} elements unpacked) " 

165 f"is too small for the declared shape {list(dims)} " 

166 f"({expected_elements} elements required)." 

167 ) 

168 result.resize(dims, refcheck=False) 

169 return result 

170 

171 

172def _pack_2bitx4(array: np.ndarray) -> npt.NDArray[np.uint8]: 

173 """Convert a numpy array to flatten, packed int2/uint2. Elements must be in the correct range.""" 

174 # Create a 1D copy 

175 array_flat = array.ravel().view(np.uint8).copy() 

176 size = array.size 

177 pad_len = size % 4 

178 if pad_len: 

179 array_flat.resize([size + (4 - pad_len)], refcheck=False) 

180 array_flat &= 0x03 

181 array_flat[1::4] <<= 2 

182 array_flat[2::4] <<= 4 

183 array_flat[3::4] <<= 6 

184 return array_flat[0::4] | array_flat[1::4] | array_flat[2::4] | array_flat[3::4] 

185 

186 

def to_array(tensor: onnx.TensorProto, base_dir: str = "") -> np.ndarray:  # noqa: PLR0911
    """Converts a tensor def object to a numpy array.

    This function uses ml_dtypes dtypes when the element type has no native
    NumPy equivalent (the dtype is resolved by
    ``helper.tensor_dtype_to_np_dtype``).

    Args:
        tensor: a TensorProto object.
        base_dir: directory used to locate the payload when the tensor
            stores its data externally.

    Returns:
        arr: the converted array, shaped according to ``tensor.dims``.

    Raises:
        ValueError: if the tensor is a segment, or if packed 4-bit/2-bit data
            is too small for the declared shape.
        TypeError: if the tensor's element type is UNDEFINED.
    """
    if tensor.HasField("segment"):
        raise ValueError("Currently not supporting loading segments.")
    if tensor.data_type == onnx.TensorProto.UNDEFINED:
        raise TypeError("The element type in the input tensor is UNDEFINED.")

    elem_type = tensor.data_type
    target_dtype = helper.tensor_dtype_to_np_dtype(elem_type)
    storage_dtype = helper.tensor_dtype_to_np_dtype(
        helper.tensor_dtype_to_storage_tensor_dtype(elem_type)
    )
    field_name = helper.tensor_dtype_to_field(elem_type)
    shape = tensor.dims

    if elem_type == onnx.TensorProto.STRING:
        decoded = [item.decode("utf-8") for item in getattr(tensor, field_name)]
        return np.asarray(decoded).astype(target_dtype).reshape(shape)

    packed_4bit = {
        onnx.TensorProto.INT4,
        onnx.TensorProto.UINT4,
        onnx.TensorProto.FLOAT4E2M1,
    }
    packed_2bit = {
        onnx.TensorProto.UINT2,
        onnx.TensorProto.INT2,
    }

    # NOTE(review): the helper is called for its side effect on ``tensor``,
    # presumably populating ``raw_data`` from the external file.
    if onnx.external_data_helper.uses_external_data(tensor):
        onnx.external_data_helper.load_external_data_for_tensor(tensor, base_dir)

    if tensor.HasField("raw_data"):
        payload = tensor.raw_data
        if sys.byteorder == "big":
            # raw_data is little-endian; swap it into host byte order.
            payload = np.frombuffer(payload, dtype=target_dtype).byteswap().tobytes()
        if elem_type in packed_4bit:
            packed = np.frombuffer(payload, dtype=np.uint8)
            return _unpack_4bit(packed, shape).view(target_dtype)
        if elem_type in packed_2bit:
            packed = np.frombuffer(payload, dtype=np.uint8)
            return _unpack_2bit(packed, shape).view(target_dtype)
        return np.frombuffer(payload, dtype=target_dtype).reshape(shape)

    def int32_words() -> np.ndarray:
        # Narrow types are stored one value per int32_data entry; reinterpret
        # as unsigned so the narrowing casts below keep the low-order bits.
        return np.array(tensor.int32_data, dtype=np.int32).view(np.uint32)

    if elem_type in {
        onnx.TensorProto.BFLOAT16,
        onnx.TensorProto.FLOAT16,
        onnx.TensorProto.INT16,
        onnx.TensorProto.UINT16,
    }:
        return int32_words().astype(np.uint16).reshape(shape).view(target_dtype)

    if elem_type in {
        onnx.TensorProto.FLOAT8E4M3FN,
        onnx.TensorProto.FLOAT8E4M3FNUZ,
        onnx.TensorProto.FLOAT8E5M2,
        onnx.TensorProto.FLOAT8E5M2FNUZ,
        onnx.TensorProto.FLOAT8E8M0,
        onnx.TensorProto.BOOL,
    }:
        return int32_words().astype(np.uint8).view(target_dtype).reshape(shape)

    if elem_type in packed_4bit:
        return _unpack_4bit(int32_words().astype(np.uint8), shape).view(target_dtype)

    if elem_type in packed_2bit:
        return _unpack_2bit(int32_words().astype(np.uint8), shape).view(target_dtype)

    stored = getattr(tensor, field_name)
    if elem_type in (onnx.TensorProto.COMPLEX64, onnx.TensorProto.COMPLEX128):
        # Viewing as complex pairs consecutive storage values into (real, imag).
        return np.array(stored, dtype=storage_dtype).view(dtype=target_dtype).reshape(shape)

    return np.asarray(stored, dtype=storage_dtype).astype(target_dtype).reshape(shape)

299 

300 

def tobytes_little_endian(array: np.ndarray) -> bytes:
    """Serialize ``array`` to bytes in little-endian byte order.

    Arrays whose dtype is explicitly big-endian, or native-order on a
    big-endian host, are converted to a little-endian copy first. Arrays
    that are already little-endian or byte-order agnostic are serialized
    as they are.

    Args:
        array: a numpy array.

    Returns:
        bytes: Byte representation of passed array in little endian byte order.

    .. versionadded:: 1.20
    """
    order = array.dtype.byteorder
    native_is_big = sys.byteorder == "big"
    stored_big_endian = order == ">" or (order == "=" and native_is_big)
    if not stored_big_endian:
        return array.tobytes()
    little = array.astype(array.dtype.newbyteorder("<"))
    return little.tobytes()

319 

320 

def from_array(array: np.ndarray, /, name: str | None = None) -> onnx.TensorProto:
    """Converts a NumPy array into a TensorProto.

    String-like arrays become STRING tensors whose elements are stored in
    ``string_data``. This covers ``object`` dtype, unicode ``np.str_`` dtype
    and byte-string ``np.bytes_`` dtype. ``str`` elements are UTF-8 encoded
    and ``bytes`` elements are stored unchanged.

    Every other array is mapped to an ONNX element type with
    ``helper.np_dtype_to_tensor_dtype`` and serialized little-endian into
    ``raw_data``. 4-bit types are first packed two per byte, and 2-bit types
    four per byte.

    Args:
        array: a numpy array.
        name: (optional) the name of the tensor. A falsy value (None or "")
            leaves the name unset.

    Returns:
        TensorProto: the converted tensor def.

    Raises:
        NotImplementedError: if a string-like array contains an element that
            is neither ``str`` nor ``bytes``.
    """
    tensor = onnx.TensorProto()
    tensor.dims.extend(array.shape)
    if name:
        tensor.name = name

    # np.bytes_ arrays are routed here as well: their elements are bytes
    # subclasses, and STRING data must live in string_data, not raw_data.
    if (
        array.dtype == object
        or np.issubdtype(array.dtype, np.str_)
        or np.issubdtype(array.dtype, np.bytes_)
    ):
        # Special care for strings.
        tensor.data_type = onnx.TensorProto.STRING
        # TODO: Introduce full string support.
        # Elements are written flat; the shape is carried by ``dims`` above.
        # To feed n-D string data reliably, build a flat array, call
        # astype(object) so elements keep a uniform type regardless of their
        # length, then reshape([x, y, z]).
        for e in array.flatten():
            if isinstance(e, str):
                tensor.string_data.append(e.encode("utf-8"))
            elif isinstance(e, bytes):
                tensor.string_data.append(e)
            else:
                raise NotImplementedError(
                    f"Unrecognized object in the object array, expect a string, or array of bytes: {type(e)}"
                )
        return tensor

    dtype = helper.np_dtype_to_tensor_dtype(array.dtype)
    if dtype in {
        onnx.TensorProto.INT4,
        onnx.TensorProto.UINT4,
        onnx.TensorProto.FLOAT4E2M1,
    }:
        # Pack the array into int4
        array = _pack_4bitx2(array)

    if dtype in {
        onnx.TensorProto.UINT2,
        onnx.TensorProto.INT2,
    }:
        # Pack the array into int2
        array = _pack_2bitx4(array)

    tensor.raw_data = tobytes_little_endian(array)
    tensor.data_type = dtype  # type: ignore[assignment]
    return tensor

377 

378 

379def to_list(sequence: onnx.SequenceProto) -> list[Any]: 

380 """Converts a sequence def to a Python list. 

381 

382 Args: 

383 sequence: a SequenceProto object. 

384 

385 Returns: 

386 list: the converted list. 

387 """ 

388 elem_type = sequence.elem_type 

389 if elem_type == onnx.SequenceProto.TENSOR: 

390 return [to_array(v) for v in sequence.tensor_values] 

391 if elem_type == onnx.SequenceProto.SPARSE_TENSOR: 

392 return [to_array(v) for v in sequence.sparse_tensor_values] # type: ignore[arg-type] 

393 if elem_type == onnx.SequenceProto.SEQUENCE: 

394 return [to_list(v) for v in sequence.sequence_values] 

395 if elem_type == onnx.SequenceProto.MAP: 

396 return [to_dict(v) for v in sequence.map_values] 

397 raise TypeError("The element type in the input sequence is not supported.") 

398 

399 

def from_list(
    lst: list[Any], name: str | None = None, dtype: int | None = None
) -> onnx.SequenceProto:
    """Converts a list into a sequence def.

    Without an explicit ``dtype``, the element type is inferred from the
    first element:
    - a ``dict`` gives MAP;
    - a ``list`` gives SEQUENCE;
    - anything else, or an empty list, gives TENSOR.

    Args:
        lst: a Python list
        name: (optional) the name of the sequence.
        dtype: (optional) type of element in the input list, used for specifying
            sequence values when converting an empty list.

    Returns:
        SequenceProto: the converted sequence def.

    Raises:
        TypeError: if the elements do not all share the first element's type,
            or if the element type is not TENSOR, SEQUENCE or MAP.
    """
    sequence = onnx.SequenceProto()
    if name:
        sequence.name = name

    if dtype is None:
        if len(lst) == 0:
            # Empty list and no dtype: default to a sequence of tensors.
            elem_type = onnx.SequenceProto.TENSOR
        elif isinstance(lst[0], dict):
            elem_type = onnx.SequenceProto.MAP
        elif isinstance(lst[0], list):
            elem_type = onnx.SequenceProto.SEQUENCE
        else:
            elem_type = onnx.SequenceProto.TENSOR
    else:
        elem_type = dtype
    sequence.elem_type = elem_type

    if len(lst) > 0:
        head_type = type(lst[0])
        if any(not isinstance(item, head_type) for item in lst):
            raise TypeError(
                "The element type in the input list is not the same "
                "for all elements and therefore is not supported as a sequence."
            )

    if elem_type == onnx.SequenceProto.TENSOR:
        sequence.tensor_values.extend([from_array(np.asarray(item)) for item in lst])
    elif elem_type == onnx.SequenceProto.SEQUENCE:
        sequence.sequence_values.extend([from_list(item) for item in lst])
    elif elem_type == onnx.SequenceProto.MAP:
        sequence.map_values.extend([from_dict(item) for item in lst])
    else:
        raise TypeError(
            "The element type in the input list is not a tensor, "
            "sequence, or map and is not supported."
        )
    return sequence

455 

456 

def to_dict(map_proto: onnx.MapProto) -> dict[Any, Any]:
    """Converts a map def to a Python dictionary.

    Keys are read from ``string_keys`` when ``key_type`` is STRING and from
    ``keys`` otherwise. Values come from ``to_list`` applied to
    ``map_proto.values``.

    Args:
        map_proto: a MapProto object.

    Returns:
        The converted dictionary.

    Raises:
        IndexError: if the number of keys differs from the number of values.
    """
    is_string_keyed = map_proto.key_type == onnx.TensorProto.STRING
    keys: list[Any] = list(map_proto.string_keys if is_string_keyed else map_proto.keys)
    values = to_list(map_proto.values)

    if len(keys) != len(values):
        raise IndexError(
            f"Length of keys and values for MapProto (map name: {map_proto.name}) are not the same."
        )
    # Lengths were verified above, so strict pairing cannot fail.
    return dict(zip(keys, values, strict=True))

478 

479 

def from_dict(dict_: dict[Any, Any], name: str | None = None) -> onnx.MapProto:
    """Converts a Python dictionary into a map def.

    The ONNX key type is derived from the NumPy result type of the first
    key. All keys must share that result type, and all values must share
    the result type of the first value. Values are converted with
    ``from_list``.

    Args:
        dict_: Python dictionary
        name: (optional) the name of the map.

    Returns:
        MapProto: the converted map def.

    Raises:
        ValueError: if ``dict_`` is empty.
        TypeError: if keys or values have mixed types, or if the key type is
            neither STRING nor an integer type.
    """
    map_proto = onnx.MapProto()
    if name:
        map_proto.name = name
    if not dict_:
        raise ValueError("Cannot convert an empty dictionary to MapProto.")

    keys = list(dict_)
    first_key_dtype = np.result_type(keys[0])
    key_type = helper.np_dtype_to_tensor_dtype(first_key_dtype)

    if any(np.result_type(k) != first_key_dtype for k in keys):
        raise TypeError(
            "The key type in the input dictionary is not the same "
            "for all keys and therefore is not valid as a map."
        )

    values = list(dict_.values())
    first_value_dtype = np.result_type(values[0])
    if any(np.result_type(v) != first_value_dtype for v in values):
        raise TypeError(
            "The value type in the input dictionary is not the same "
            "for all values and therefore is not valid as a map."
        )

    value_seq = from_list(values)

    integer_key_types = {
        onnx.TensorProto.INT8,
        onnx.TensorProto.INT16,
        onnx.TensorProto.INT32,
        onnx.TensorProto.INT64,
        onnx.TensorProto.UINT8,
        onnx.TensorProto.UINT16,
        onnx.TensorProto.UINT32,
        onnx.TensorProto.UINT64,
    }

    map_proto.key_type = key_type  # type: ignore[assignment]
    if key_type == onnx.TensorProto.STRING:
        map_proto.string_keys.extend(keys)
    elif key_type in integer_key_types:
        map_proto.keys.extend(keys)
    else:
        raise TypeError(f"Unsupported map key type: {key_type}")
    map_proto.values.CopyFrom(value_seq)
    return map_proto

535 

536 

def to_optional(optional: onnx.OptionalProto) -> Any | None:
    """Converts an optional def to a Python optional.

    UNDEFINED yields ``None``. Each other element type is converted with the
    matching helper: ``to_array``, ``to_list``, ``to_dict``, or recursively
    ``to_optional``.

    Args:
        optional: an OptionalProto object.

    Returns:
        opt: the converted optional.

    Raises:
        TypeError: if the element type is not supported.
    """
    converters = {
        onnx.OptionalProto.UNDEFINED: lambda: None,
        onnx.OptionalProto.TENSOR: lambda: to_array(optional.tensor_value),
        onnx.OptionalProto.SPARSE_TENSOR: lambda: to_array(optional.sparse_tensor_value),  # type: ignore[arg-type]
        onnx.OptionalProto.SEQUENCE: lambda: to_list(optional.sequence_value),
        onnx.OptionalProto.MAP: lambda: to_dict(optional.map_value),
        onnx.OptionalProto.OPTIONAL: lambda: to_optional(optional.optional_value),
    }
    convert = converters.get(optional.elem_type)
    if convert is None:
        raise TypeError("The element type in the input optional is not supported.")
    return convert()

560 

561 

def from_optional(
    opt: Any | None, name: str | None = None, dtype: int | None = None
) -> onnx.OptionalProto:
    """Converts an optional value into a Optional def.

    Without an explicit ``dtype``, the element type is inferred from ``opt``:
    - ``dict`` gives MAP;
    - ``list`` gives SEQUENCE;
    - ``None`` gives UNDEFINED;
    - anything else gives TENSOR.

    When ``opt`` is ``None``, no payload is written.

    Args:
        opt: a Python optional
        name: (optional) the name of the optional.
        dtype: (optional) type of element in the input, used for specifying
            optional values when converting empty none. dtype must
            be a valid OptionalProto.DataType value

    Returns:
        optional: the converted optional def.

    Raises:
        TypeError: if ``dtype`` is not a valid OptionalProto.DataType, or if
            a non-None ``opt`` has an element type other than TENSOR,
            SEQUENCE or MAP.
    """
    # TODO: create a map and replace conditional branches
    optional = onnx.OptionalProto()
    if name:
        optional.name = name

    if dtype is None:
        if isinstance(opt, dict):
            elem_type = onnx.OptionalProto.MAP
        elif isinstance(opt, list):
            elem_type = onnx.OptionalProto.SEQUENCE
        elif opt is None:
            elem_type = onnx.OptionalProto.UNDEFINED
        else:
            elem_type = onnx.OptionalProto.TENSOR
    else:
        if dtype not in onnx.OptionalProto.DataType.values():
            raise TypeError(f"{dtype} must be a valid OptionalProto.DataType.")
        elem_type = dtype

    optional.elem_type = elem_type

    if opt is None:
        return optional

    if elem_type == onnx.OptionalProto.TENSOR:
        optional.tensor_value.CopyFrom(from_array(opt))
    elif elem_type == onnx.OptionalProto.SEQUENCE:
        optional.sequence_value.CopyFrom(from_list(opt))
    elif elem_type == onnx.OptionalProto.MAP:
        optional.map_value.CopyFrom(from_dict(opt))
    else:
        raise TypeError(
            "The element type in the input is not a tensor, "
            "sequence, or map and is not supported."
        )
    return optional

611 

612 

def create_random_int(
    input_shape: tuple[int], dtype: np.dtype, seed: int = 1
) -> np.ndarray:
    """Create random integer array for backend/test/case/node.

    The global NumPy RNG is reseeded with ``seed`` on every call, so results
    are reproducible. Values are drawn from the intersection of ``dtype``'s
    range and the int32 range. The upper limit is exclusive, as in
    ``np.random.randint``.

    Args:
        input_shape: The shape for the returned integer array.
        dtype: The NumPy data type for the returned integer array.
        seed: The seed for np.random.

    Returns:
        np.ndarray: Random integer array.

    Raises:
        TypeError: if ``dtype`` is not an 8/16/32/64-bit (un)signed integer.
    """
    np.random.seed(seed)
    supported = (
        np.uint8,
        np.uint16,
        np.uint32,
        np.uint64,
        np.int8,
        np.int16,
        np.int32,
        np.int64,
    )
    if dtype not in supported:
        raise TypeError(f"{dtype} is not supported by create_random_int.")

    # np.random.randint is limited to the int32 range; clamp to it if needed.
    target_limits = np.iinfo(dtype)
    int32_limits = np.iinfo(np.int32)
    low = max(target_limits.min, int32_limits.min)
    high = min(target_limits.max, int32_limits.max)
    return np.random.randint(low, high, size=input_shape).astype(dtype)

642 

643 

def saturate_cast(x: np.ndarray, dtype: np.dtype) -> np.ndarray:
    """Saturate cast for numeric types.

    Values outside the representable range of ``dtype`` are clamped to its
    minimum or maximum before casting.

    For integer targets, including the ml_dtypes 4-bit and 2-bit integers,
    ``x`` is first rounded with ``np.round``. The limits come from
    ``ml_dtypes.iinfo`` for integer targets and ``ml_dtypes.finfo`` for
    everything else.
    """
    narrow_ints = (
        ml_dtypes.int4,
        ml_dtypes.uint4,
        ml_dtypes.int2,
        ml_dtypes.uint2,
    )
    integral_target = np.issubdtype(dtype, np.integer) or dtype in narrow_ints

    if not integral_target:
        limits = ml_dtypes.finfo(dtype)  # type: ignore[assignment]
    else:
        limits = ml_dtypes.iinfo(dtype)
        x = np.round(x)

    clamped = np.clip(x, limits.min, limits.max)
    return clamped.astype(dtype)  # type: ignore[no-any-return]