Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorboard/util/tensor_util.py: 18%

234 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities to manipulate TensorProtos.""" 

16 

17import numpy as np 

18 

19from tensorboard.compat.proto import tensor_pb2 

20from tensorboard.compat.tensorflow_stub import dtypes, compat, tensor_shape 

21 

22 

def ExtractBitsFromFloat16(x):
    """Return the raw 16-bit pattern of ``x`` stored as float16, as a Python int."""
    as_half = np.asarray(x, dtype=np.float16)
    bit_pattern = as_half.view(np.uint16)
    return bit_pattern.item()

25 

26 

def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
    """Append the float16 bit pattern of each value to ``tensor_proto.half_val``."""
    bit_patterns = [
        np.asarray(value, dtype=np.float16).view(np.uint16).item()
        for value in proto_values
    ]
    tensor_proto.half_val.extend(bit_patterns)

31 

32 

def ExtractBitsFromBFloat16(x):
    """Return the raw 16-bit pattern of ``x`` stored as bfloat16, as a Python int."""
    bfloat16_np_dtype = dtypes.bfloat16.as_numpy_dtype
    as_bfloat16 = np.asarray(x, dtype=bfloat16_np_dtype)
    return as_bfloat16.view(np.uint16).item()

39 

40 

def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
    """Append the bfloat16 bit pattern of each value to ``tensor_proto.half_val``."""
    bit_patterns = list(map(ExtractBitsFromBFloat16, proto_values))
    tensor_proto.half_val.extend(bit_patterns)

45 

46 

def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values):
    """Append ``proto_values`` (converted via ``.item()``) to ``float_val``."""
    python_scalars = [element.item() for element in proto_values]
    tensor_proto.float_val.extend(python_scalars)

49 

50 

def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values):
    """Append ``proto_values`` (converted via ``.item()``) to ``double_val``."""
    python_scalars = [element.item() for element in proto_values]
    tensor_proto.double_val.extend(python_scalars)

53 

54 

def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values):
    """Append ``proto_values`` (converted via ``.item()``) to ``int_val``."""
    python_scalars = [element.item() for element in proto_values]
    tensor_proto.int_val.extend(python_scalars)

57 

58 

def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values):
    """Append ``proto_values`` (converted via ``.item()``) to ``int64_val``."""
    python_scalars = [element.item() for element in proto_values]
    tensor_proto.int64_val.extend(python_scalars)

61 

62 

def SlowAppendQIntArrayToTensorProto(tensor_proto, proto_values):
    """Append quantized values to ``tensor_proto.int_val``.

    Each element is indexed with ``[0]`` and converted via ``.item()``;
    presumably each is a one-field structured record (the quantized numpy
    dtypes), so this unwraps its single field.
    """
    unwrapped = [record[0].item() for record in proto_values]
    tensor_proto.int_val.extend(unwrapped)

65 

66 

def SlowAppendUInt32ArrayToTensorProto(tensor_proto, proto_values):
    """Append ``proto_values`` (converted via ``.item()``) to ``uint32_val``."""
    python_scalars = [element.item() for element in proto_values]
    tensor_proto.uint32_val.extend(python_scalars)

69 

70 

def SlowAppendUInt64ArrayToTensorProto(tensor_proto, proto_values):
    """Append ``proto_values`` (converted via ``.item()``) to ``uint64_val``."""
    python_scalars = [element.item() for element in proto_values]
    tensor_proto.uint64_val.extend(python_scalars)

73 

74 

def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values):
    """Append each value to ``scomplex_val`` as an interleaved (real, imag) pair."""
    interleaved = []
    for number in proto_values:
        interleaved.extend((number.real.item(), number.imag.item()))
    tensor_proto.scomplex_val.extend(interleaved)

79 

80 

def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values):
    """Append each value to ``dcomplex_val`` as an interleaved (real, imag) pair."""
    interleaved = []
    for number in proto_values:
        interleaved.extend((number.real.item(), number.imag.item()))
    tensor_proto.dcomplex_val.extend(interleaved)

85 

86 

def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values):
    """Append each value, converted with ``compat.as_bytes``, to ``string_val``."""
    encoded = list(map(compat.as_bytes, proto_values))
    tensor_proto.string_val.extend(encoded)

89 

90 

def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values):
    """Append ``proto_values`` (converted via ``.item()``) to ``bool_val``."""
    python_scalars = [flag.item() for flag in proto_values]
    tensor_proto.bool_val.extend(python_scalars)

93 

94 

# Maps numpy scalar types (and the quantized numpy dtypes) to the function that
# appends a flat array of that type to a TensorProto. GetFromNumpyDTypeDict
# scans these entries in insertion order and returns the first key equal to the
# requested dtype, so the order below is intentionally kept stable.
# bfloat16 is deliberately absent; GetFromNumpyDTypeDict falls back to
# BACKUP_DICT for it.
_NP_TO_APPEND_FN = dict(
    [
        (np.float16, SlowAppendFloat16ArrayToTensorProto),
        (np.float32, SlowAppendFloat32ArrayToTensorProto),
        (np.float64, SlowAppendFloat64ArrayToTensorProto),
        (np.int32, SlowAppendIntArrayToTensorProto),
        (np.int64, SlowAppendInt64ArrayToTensorProto),
        # Narrow unsigned integers share the generic int_val field.
        (np.uint8, SlowAppendIntArrayToTensorProto),
        (np.uint16, SlowAppendIntArrayToTensorProto),
        (np.uint32, SlowAppendUInt32ArrayToTensorProto),
        (np.uint64, SlowAppendUInt64ArrayToTensorProto),
        (np.int8, SlowAppendIntArrayToTensorProto),
        (np.int16, SlowAppendIntArrayToTensorProto),
        (np.complex64, SlowAppendComplex64ArrayToTensorProto),
        (np.complex128, SlowAppendComplex128ArrayToTensorProto),
        (np.object_, SlowAppendObjectArrayToTensorProto),
        (np.bool_, SlowAppendBoolArrayToTensorProto),
        # Quantized types are stored through int_val as well.
        (dtypes.qint8.as_numpy_dtype, SlowAppendQIntArrayToTensorProto),
        (dtypes.quint8.as_numpy_dtype, SlowAppendQIntArrayToTensorProto),
        (dtypes.qint16.as_numpy_dtype, SlowAppendQIntArrayToTensorProto),
        (dtypes.quint16.as_numpy_dtype, SlowAppendQIntArrayToTensorProto),
        (dtypes.qint32.as_numpy_dtype, SlowAppendQIntArrayToTensorProto),
    ]
)

118 

# Secondary lookup table consulted by GetFromNumpyDTypeDict when the primary
# table has no matching key; currently it only covers bfloat16.
BACKUP_DICT = dict(
    [(dtypes.bfloat16.as_numpy_dtype, SlowAppendBFloat16ArrayToTensorProto)]
)

122 

123 

def GetFromNumpyDTypeDict(dtype_dict, dtype):
    """Look up ``dtype`` in ``dtype_dict``, falling back to ``BACKUP_DICT``.

    Keys are compared with ``==`` in insertion order, and the first match wins,
    even if its value is None. A plain ``dtype_dict.get(dtype)`` would always
    return None: the keys are numpy scalar types, while ``dtype`` is a numpy
    dtype that compares equal to them but does not share their hash.

    Args:
      dtype_dict: Mapping from numpy scalar types/dtypes to values.
      dtype: The numpy dtype to look up.

    Returns:
      The value of the first matching key in ``dtype_dict``; otherwise the
      first match in ``BACKUP_DICT``; otherwise None.
    """

    def _scan(table):
        for candidate, value in table.items():
            if candidate == dtype:
                return True, value
        return False, None

    found, value = _scan(dtype_dict)
    if found:
        return value
    return _scan(BACKUP_DICT)[1]

133 

134 

def GetNumpyAppendFn(dtype):
    """Return the function that appends a flat array of ``dtype`` to a TensorProto.

    Args:
      dtype: A ``numpy.dtype``.

    Returns:
      A callable ``fn(tensor_proto, proto_values)``, or None if no append
      function is registered for ``dtype``.
    """
    # numpy dtypes for strings are variable length, so a string dtype cannot
    # be compared against a single constant. Compare its scalar type instead.
    # Use np.bytes_/np.str_: the aliases np.string_/np.unicode_ were removed
    # in NumPy 2.0, and referencing them raises AttributeError there.
    if dtype.type in (np.bytes_, np.str_):
        return SlowAppendObjectArrayToTensorProto
    return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype)

143 

144 

145def _GetDenseDimensions(list_of_lists): 

146 """Returns the inferred dense dimensions of a list of lists.""" 

147 if not isinstance(list_of_lists, (list, tuple)): 

148 return [] 

149 elif not list_of_lists: 

150 return [0] 

151 else: 

152 return [len(list_of_lists)] + _GetDenseDimensions(list_of_lists[0]) 

153 

154 

155def _FlattenToStrings(nested_strings): 

156 if isinstance(nested_strings, (list, tuple)): 

157 for inner in nested_strings: 

158 for flattened_string in _FlattenToStrings(inner): 

159 yield flattened_string 

160 else: 

161 yield nested_strings 

162 

163 

# DataTypes that make_tensor_proto may serialize as raw bytes in
# `tensor_content` instead of a typed repeated field. It does so only when the
# array exactly fills the requested shape and holds more than one element.
_TENSOR_CONTENT_TYPES = frozenset(
    {
        # Floating point.
        dtypes.float32,
        dtypes.float64,
        # Signed integers.
        dtypes.int8,
        dtypes.int16,
        dtypes.int32,
        dtypes.int64,
        # Unsigned integers.
        dtypes.uint8,
        dtypes.uint32,
        dtypes.uint64,
        # Quantized integers.
        dtypes.qint8,
        dtypes.quint8,
        dtypes.qint16,
        dtypes.quint16,
        dtypes.qint32,
    }
)

182 

183 

184class _Message: 

185 def __init__(self, message): 

186 self._message = message 

187 

188 def __repr__(self): 

189 return self._message 

190 

191 

192def _FirstNotNone(l): 

193 for x in l: 

194 if x is not None: 

195 return x 

196 return None 

197 

198 

199def _NotNone(v): 

200 if v is None: 

201 return _Message("None") 

202 else: 

203 return v 

204 

205 

206def _FilterTuple(v): 

207 if not isinstance(v, (list, tuple)): 

208 return v 

209 if isinstance(v, tuple): 

210 if not any(isinstance(x, (list, tuple)) for x in v): 

211 return None 

212 if isinstance(v, list): 

213 if not any(isinstance(x, (list, tuple)) for x in v): 

214 return _FirstNotNone( 

215 [None if isinstance(x, (list, tuple)) else x for x in v] 

216 ) 

217 return _FirstNotNone([_FilterTuple(x) for x in v]) 

218 

219 

def _FilterInt(v):
    """Return the first leaf of ``v`` that is not an integer or Dimension.

    Nested lists/tuples are searched depth-first. A None leaf is reported as a
    placeholder. Returns None when every leaf is acceptable.
    """
    if not isinstance(v, (list, tuple)):
        acceptable = isinstance(
            v, (compat.integral_types, tensor_shape.Dimension)
        )
        return None if acceptable else _NotNone(v)
    return _FirstNotNone([_FilterInt(child) for child in v])

228 

229 

def _FilterFloat(v):
    """Return the first leaf of ``v`` that is not a real number, else None.

    Nested lists/tuples are searched depth-first. A None leaf is reported as a
    placeholder.
    """
    if not isinstance(v, (list, tuple)):
        acceptable = isinstance(v, compat.real_types)
        return None if acceptable else _NotNone(v)
    return _FirstNotNone([_FilterFloat(child) for child in v])

234 

235 

def _FilterComplex(v):
    """Return the first leaf of ``v`` that is not a complex-compatible number.

    Nested lists/tuples are searched depth-first. A None leaf is reported as a
    placeholder. Returns None when every leaf is acceptable.
    """
    if not isinstance(v, (list, tuple)):
        acceptable = isinstance(v, compat.complex_types)
        return None if acceptable else _NotNone(v)
    return _FirstNotNone([_FilterComplex(child) for child in v])

240 

241 

def _FilterStr(v):
    """Return the first leaf of ``v`` that is neither bytes nor text, else None.

    Nested lists/tuples are searched depth-first. A None leaf is reported as a
    placeholder.
    """
    if not isinstance(v, (list, tuple)):
        acceptable = isinstance(v, compat.bytes_or_text_types)
        return None if acceptable else _NotNone(v)
    return _FirstNotNone([_FilterStr(child) for child in v])

249 

250 

251def _FilterBool(v): 

252 if isinstance(v, (list, tuple)): 

253 return _FirstNotNone([_FilterBool(x) for x in v]) 

254 return None if isinstance(v, bool) else _NotNone(v) 

255 

256 

# Maps a DataType to the filter functions _Assertconvertible runs over Python
# input values. Each filter returns None when the input is acceptable and the
# first offending leaf otherwise. Every key gets its own fresh list.
_TF_TO_IS_OK = {
    dtypes.bool: [_FilterBool],
    dtypes.string: [_FilterStr],
    **{dt: [_FilterComplex] for dt in (dtypes.complex64, dtypes.complex128)},
    **{
        dt: [_FilterFloat]
        for dt in (dtypes.float16, dtypes.float32, dtypes.float64)
    },
    **{
        dt: [_FilterInt]
        for dt in (
            dtypes.int8,
            dtypes.int16,
            dtypes.int32,
            dtypes.int64,
            dtypes.uint8,
            dtypes.uint16,
        )
    },
    # Quantized values must be integers wrapped in tuples.
    **{
        dt: [_FilterInt, _FilterTuple]
        for dt in (
            dtypes.qint8,
            dtypes.qint16,
            dtypes.qint32,
            dtypes.quint8,
            dtypes.quint16,
        )
    },
}

277 

278 

279def _Assertconvertible(values, dtype): 

280 # If dtype is None or not recognized, assume it's convertible. 

281 if dtype is None or dtype not in _TF_TO_IS_OK: 

282 return 

283 fn_list = _TF_TO_IS_OK.get(dtype) 

284 mismatch = _FirstNotNone([fn(values) for fn in fn_list]) 

285 if mismatch is not None: 

286 raise TypeError( 

287 "Expected %s, got %s of type '%s' instead." 

288 % (dtype.name, repr(mismatch), type(mismatch).__name__) 

289 ) 

290 

291 

def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False):
    """Create a TensorProto.

    Args:
      values: Values to put in the TensorProto.
      dtype: Optional tensor_pb2 DataType value.
      shape: List of integers representing the dimensions of tensor.
      verify_shape: Boolean that enables verification of a shape of values.

    Returns:
      A `TensorProto`. Depending on the type, it may contain data in the
      "tensor_content" attribute, which is not directly useful to Python
      programs. To access the values you should convert the proto back to a
      numpy ndarray with `make_ndarray(proto)`.

      If `values` is a `TensorProto`, it is immediately returned; `dtype` and
      `shape` are ignored.

    Raises:
      TypeError: if unsupported types are provided, if `values` cannot be
        converted to `dtype`, or if `verify_shape` is True and the shape of
        `values` is not equal to `shape`.
      ValueError: if `values` is None, is a non-dense nested list, has more
        elements than `shape` allows, or would produce a `tensor_content`
        larger than 2GB.

    make_tensor_proto accepts "values" of a python scalar, a python list, a
    numpy ndarray, or a numpy scalar.

    If "values" is a python scalar or a python list, make_tensor_proto
    first converts it to a numpy ndarray. If dtype is None, the
    conversion tries its best to infer the right numpy data
    type. Otherwise, the resulting numpy array has a data type
    convertible to the given dtype.

    In either case above, the numpy ndarray (either the caller provided
    or the auto converted) must have a type convertible to dtype.

    make_tensor_proto then converts the numpy array to a tensor proto.

    If "shape" is None, the resulting tensor proto represents the numpy
    array precisely.

    Otherwise, "shape" specifies the tensor's shape and the numpy array
    can not have more elements than what "shape" specifies.
    """
    if isinstance(values, tensor_pb2.TensorProto):
        return values

    if dtype:
        dtype = dtypes.as_dtype(dtype)

    is_quantized = dtype in [
        dtypes.qint8,
        dtypes.quint8,
        dtypes.qint16,
        dtypes.quint16,
        dtypes.qint32,
    ]

    # We first convert value to a numpy array or scalar.
    if isinstance(values, (np.ndarray, np.generic)):
        if dtype:
            nparray = values.astype(dtype.as_numpy_dtype)
        else:
            nparray = values
    elif callable(getattr(values, "__array__", None)) or isinstance(
        getattr(values, "__array_interface__", None), dict
    ):
        # If a class has the __array__ method, or __array_interface__ dict,
        # then it is possible to convert to numpy array. np.asarray expects a
        # numpy dtype, so pass `dtype.as_numpy_dtype` (as the list/scalar
        # branch below does) rather than the DType object itself. The
        # base_dtype check further down still rejects inconvertible types.
        if dtype and dtype.is_numpy_compatible:
            np_dt = dtype.as_numpy_dtype
        else:
            np_dt = None
        nparray = np.asarray(values, dtype=np_dt)

        # This is the preferred way to create an array from the object, so
        # replace the `values` with the array so that _FlattenToStrings is not
        # run.
        values = nparray
    else:
        if values is None:
            raise ValueError("None values not supported.")
        # if dtype is provided, forces numpy array to be the type
        # provided if possible.
        if dtype and dtype.is_numpy_compatible:
            np_dt = dtype.as_numpy_dtype
        else:
            np_dt = None
        # If shape is None, numpy.prod returns None when dtype is not set, but
        # raises exception when dtype is set to np.int64
        if shape is not None and np.prod(shape, dtype=np.int64) == 0:
            nparray = np.empty(shape, dtype=np_dt)
        else:
            _Assertconvertible(values, dtype)
            nparray = np.array(values, dtype=np_dt)
            # We need to pass in quantized values as tuples, so don't apply
            # the shape check to them.
            if (
                list(nparray.shape) != _GetDenseDimensions(values)
                and not is_quantized
            ):
                raise ValueError(
                    """Argument must be a dense tensor: %s"""
                    """ - got shape %s, but wanted %s."""
                    % (values, list(nparray.shape), _GetDenseDimensions(values))
                )

    # python/numpy default float type is float64. We prefer float32 instead.
    if (nparray.dtype == np.float64) and dtype is None:
        nparray = nparray.astype(np.float32)
    # python/numpy default int type is int64. We prefer int32 instead.
    elif (nparray.dtype == np.int64) and dtype is None:
        downcasted_array = nparray.astype(np.int32)
        # Do not down cast if it leads to precision loss.
        if np.array_equal(downcasted_array, nparray):
            nparray = downcasted_array

    # if dtype is provided, it must be convertible with what numpy
    # conversion says.
    numpy_dtype = dtypes.as_dtype(nparray.dtype)
    if numpy_dtype is None:
        raise TypeError("Unrecognized data type: %s" % nparray.dtype)

    # If dtype was specified and is a quantized type, we convert
    # numpy_dtype back into the quantized version.
    if is_quantized:
        numpy_dtype = dtype

    if dtype is not None and (
        not hasattr(dtype, "base_dtype")
        or dtype.base_dtype != numpy_dtype.base_dtype
    ):
        raise TypeError(
            "Inconvertible types: %s vs. %s. Value is %s"
            % (dtype, nparray.dtype, values)
        )

    # If shape is not given, get the shape from the numpy array.
    if shape is None:
        shape = nparray.shape
        is_same_size = True
        shape_size = nparray.size
    else:
        shape = [int(dim) for dim in shape]
        shape_size = np.prod(shape, dtype=np.int64)
        is_same_size = shape_size == nparray.size

    if verify_shape:
        if not nparray.shape == tuple(shape):
            raise TypeError(
                "Expected Tensor's shape: %s, got %s."
                % (tuple(shape), nparray.shape)
            )

    if nparray.size > shape_size:
        raise ValueError(
            "Too many elements provided. Needed at most %d, but received %d"
            % (shape_size, nparray.size)
        )

    tensor_proto = tensor_pb2.TensorProto(
        dtype=numpy_dtype.as_datatype_enum,
        tensor_shape=tensor_shape.as_shape(shape).as_proto(),
    )

    if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
        if nparray.size * nparray.itemsize >= (1 << 31):
            raise ValueError(
                "Cannot create a tensor proto whose content is larger than 2GB."
            )
        tensor_proto.tensor_content = nparray.tobytes()
        return tensor_proto

    # If we were not given values as a numpy array, compute the proto_values
    # from the given values directly, to avoid numpy trimming nulls from the
    # strings. Since values could be a list of strings, or a multi-dimensional
    # list of lists that might or might not correspond to the given shape,
    # we flatten it conservatively.
    if numpy_dtype == dtypes.string and not isinstance(values, np.ndarray):
        proto_values = _FlattenToStrings(values)

        # At this point, values may be a list of objects that we could not
        # identify a common type for (hence it was inferred as
        # np.object/dtypes.string). If we are unable to convert it to a
        # string, we raise a more helpful error message.
        #
        # Ideally, we'd be able to convert the elements of the list to a
        # common type, but this type inference requires some thinking and
        # so we defer it for now.
        try:
            str_values = [compat.as_bytes(x) for x in proto_values]
        except TypeError as err:
            raise TypeError(
                "Failed to convert object of type %s to Tensor. "
                "Contents: %s. Consider casting elements to a "
                "supported type." % (type(values), values)
            ) from err
        tensor_proto.string_val.extend(str_values)
        return tensor_proto

    # TensorFlow expects C order (a.k.a., eigen row major).
    proto_values = nparray.ravel()

    append_fn = GetNumpyAppendFn(proto_values.dtype)
    if append_fn is None:
        raise TypeError(
            "Element type not supported in TensorProto: %s" % numpy_dtype.name
        )
    append_fn(tensor_proto, proto_values)

    return tensor_proto

497 

498 

def _expand_flat_values(values, shape, num_elements):
    """Reshape a flat array decoded from a typed repeated field to ``shape``.

    A single value is broadcast to every element (this also covers
    ``num_elements == 0``). An empty array yields zeros. A shorter array is
    completed by repeating its last value. This is the same filling convention
    TensorFlow's own ``make_ndarray`` applies.

    Args:
      values: 1-D numpy array of decoded values.
      shape: Target shape, as a list of ints.
      num_elements: Number of elements implied by ``shape``.

    Returns:
      A numpy array with shape ``shape`` and the dtype of ``values``.
    """
    if values.size == 1:
        return np.repeat(values, num_elements).reshape(shape)
    if values.size == 0:
        return np.zeros(shape, dtype=values.dtype)
    if values.size < num_elements:
        values = np.pad(values, (0, num_elements - values.size), mode="edge")
    return values.reshape(shape)


def make_ndarray(tensor):
    """Create a numpy ndarray from a tensor.

    Create a numpy ndarray with the same shape and data as the tensor.

    Raw ``tensor_content`` bytes are decoded directly. A typed repeated field
    (``float_val``, ``int_val``, ...) may hold fewer values than the tensor
    has elements: a single value fills every element, a shorter list is
    completed by repeating its last value, and an empty list yields zeros
    (empty bytes for string tensors).

    Args:
      tensor: A TensorProto.

    Returns:
      A numpy array with the tensor contents.

    Raises:
      TypeError: if tensor has unsupported type.
    """
    shape = [d.size for d in tensor.tensor_shape.dim]
    num_elements = np.prod(shape, dtype=np.int64)
    tensor_dtype = dtypes.as_dtype(tensor.dtype)
    dtype = tensor_dtype.as_numpy_dtype

    if tensor.tensor_content:
        return (
            np.frombuffer(tensor.tensor_content, dtype=dtype)
            .copy()
            .reshape(shape)
        )

    if tensor_dtype == dtypes.string:
        # Object arrays are expanded as a Python list instead of via np.pad.
        # `count` is a Python int: list * numpy-int would broadcast, not repeat.
        values = list(tensor.string_val)
        count = int(num_elements)
        if len(values) == 1:
            values = values * count
        elif len(values) < count:
            last = values[-1] if values else b""
            values.extend([last] * (count - len(values)))
        return np.array(values, dtype=dtype).reshape(shape)

    if tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16:
        # The half_val field stores the binary representation of each 16-bit
        # float, so reinterpret the bits (view) rather than convert the ints.
        values = np.fromiter(tensor.half_val, dtype=np.uint16).view(dtype)
    elif tensor_dtype == dtypes.float32:
        values = np.fromiter(tensor.float_val, dtype=dtype)
    elif tensor_dtype == dtypes.float64:
        values = np.fromiter(tensor.double_val, dtype=dtype)
    elif tensor_dtype in [
        dtypes.int32,
        dtypes.uint8,
        dtypes.uint16,
        dtypes.int16,
        dtypes.int8,
        dtypes.qint32,
        dtypes.quint8,
        dtypes.qint8,
        dtypes.qint16,
        dtypes.quint16,
    ]:
        values = np.fromiter(tensor.int_val, dtype=dtype)
    elif tensor_dtype == dtypes.int64:
        values = np.fromiter(tensor.int64_val, dtype=dtype)
    elif tensor_dtype == dtypes.uint32:
        # make_tensor_proto writes small uint32 tensors to uint32_val.
        values = np.fromiter(tensor.uint32_val, dtype=dtype)
    elif tensor_dtype == dtypes.uint64:
        # make_tensor_proto writes small uint64 tensors to uint64_val.
        values = np.fromiter(tensor.uint64_val, dtype=dtype)
    elif tensor_dtype == dtypes.complex64 or tensor_dtype == dtypes.complex128:
        parts = (
            tensor.scomplex_val
            if tensor_dtype == dtypes.complex64
            else tensor.dcomplex_val
        )
        # Complex values are stored as interleaved (real, imag) pairs.
        it = iter(parts)
        values = np.array(
            [complex(real, imag) for real, imag in zip(it, it)], dtype=dtype
        )
    elif tensor_dtype == dtypes.bool:
        values = np.fromiter(tensor.bool_val, dtype=dtype)
    else:
        raise TypeError("Unsupported tensor type: %s" % tensor.dtype)

    return _expand_flat_values(values, shape, num_elements)