Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorboard/util/tensor_util.py: 18%
234 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to manipulate TensorProtos."""

import numpy as np

from tensorboard.compat.proto import tensor_pb2
from tensorboard.compat.tensorflow_stub import dtypes, compat, tensor_shape


def ExtractBitsFromFloat16(x):
    return np.asarray(x, dtype=np.float16).view(np.uint16).item()


def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.half_val.extend(
        [ExtractBitsFromFloat16(x) for x in proto_values]
    )


def ExtractBitsFromBFloat16(x):
    return (
        np.asarray(x, dtype=dtypes.bfloat16.as_numpy_dtype)
        .view(np.uint16)
        .item()
    )


def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.half_val.extend(
        [ExtractBitsFromBFloat16(x) for x in proto_values]
    )
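

# Illustration (added here, not part of the original module; the helper name is
# hypothetical): the Extract* functions return the raw bit pattern of a 16-bit
# float as an int, because the TensorProto half_val field stores bits rather
# than float values. For example, 1.0 encodes as 0x3C00 in IEEE 754 binary16.
def _example_extract_bits():
    # sign 0, exponent 01111 (bias 15), mantissa 0 -> 0x3C00 == 15360
    assert ExtractBitsFromFloat16(1.0) == 0x3C00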
def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.float_val.extend([x.item() for x in proto_values])


def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.double_val.extend([x.item() for x in proto_values])


def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.int_val.extend([x.item() for x in proto_values])


def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.int64_val.extend([x.item() for x in proto_values])


def SlowAppendQIntArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.int_val.extend([x[0].item() for x in proto_values])


def SlowAppendUInt32ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.uint32_val.extend([x.item() for x in proto_values])


def SlowAppendUInt64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.uint64_val.extend([x.item() for x in proto_values])


def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.scomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]]
    )


def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.dcomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]]
    )


def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values])


def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.bool_val.extend([x.item() for x in proto_values])
_NP_TO_APPEND_FN = {
    np.float16: SlowAppendFloat16ArrayToTensorProto,
    np.float32: SlowAppendFloat32ArrayToTensorProto,
    np.float64: SlowAppendFloat64ArrayToTensorProto,
    np.int32: SlowAppendIntArrayToTensorProto,
    np.int64: SlowAppendInt64ArrayToTensorProto,
    np.uint8: SlowAppendIntArrayToTensorProto,
    np.uint16: SlowAppendIntArrayToTensorProto,
    np.uint32: SlowAppendUInt32ArrayToTensorProto,
    np.uint64: SlowAppendUInt64ArrayToTensorProto,
    np.int8: SlowAppendIntArrayToTensorProto,
    np.int16: SlowAppendIntArrayToTensorProto,
    np.complex64: SlowAppendComplex64ArrayToTensorProto,
    np.complex128: SlowAppendComplex128ArrayToTensorProto,
    np.object_: SlowAppendObjectArrayToTensorProto,
    np.bool_: SlowAppendBoolArrayToTensorProto,
    dtypes.qint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
    dtypes.quint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
    dtypes.qint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
    dtypes.quint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
    dtypes.qint32.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
    # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16.
}

BACKUP_DICT = {
    dtypes.bfloat16.as_numpy_dtype: SlowAppendBFloat16ArrayToTensorProto
}
def GetFromNumpyDTypeDict(dtype_dict, dtype):
    # NOTE: dtype_dict.get(dtype) always returns None.
    for key, val in dtype_dict.items():
        if key == dtype:
            return val
    for key, val in BACKUP_DICT.items():
        if key == dtype:
            return val
    return None


def GetNumpyAppendFn(dtype):
    # NumPy dtypes for strings are variable length. We cannot compare dtype
    # against a single constant (np.string does not exist) to decide whether
    # dtype is a "string" type. We need to compare dtype.type to be sure it's
    # a string type.
    if dtype.type == np.string_ or dtype.type == np.unicode_:
        return SlowAppendObjectArrayToTensorProto
    return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype)
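

# Illustration (added here, not part of the original module; the helper name is
# hypothetical): a minimal sketch of how the dispatch above behaves. NumPy string
# dtypes such as "<U5" are variable length, so GetNumpyAppendFn matches on
# dtype.type rather than looking the dtype up directly; fixed-size dtypes are
# found by the == comparison in GetFromNumpyDTypeDict, which works where a plain
# dict .get() would not (see the NOTE above).
def _example_append_fn_dispatch():
    # Variable-length unicode dtype: caught by the dtype.type check.
    assert (
        GetNumpyAppendFn(np.dtype("<U5")) is SlowAppendObjectArrayToTensorProto
    )
    # Fixed-size numeric dtypes: resolved via _NP_TO_APPEND_FN.
    assert (
        GetNumpyAppendFn(np.dtype(np.float32))
        is SlowAppendFloat32ArrayToTensorProto
    )
    assert (
        GetNumpyAppendFn(np.dtype(np.uint8)) is SlowAppendIntArrayToTensorProto
    )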
def _GetDenseDimensions(list_of_lists):
    """Returns the inferred dense dimensions of a list of lists."""
    if not isinstance(list_of_lists, (list, tuple)):
        return []
    elif not list_of_lists:
        return [0]
    else:
        return [len(list_of_lists)] + _GetDenseDimensions(list_of_lists[0])


def _FlattenToStrings(nested_strings):
    if isinstance(nested_strings, (list, tuple)):
        for inner in nested_strings:
            for flattened_string in _FlattenToStrings(inner):
                yield flattened_string
    else:
        yield nested_strings
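

# Illustration (added here, not part of the original module; the helper name is
# hypothetical): how the two list helpers above behave on nested Python lists.
def _example_list_helpers():
    # Dimensions are inferred by descending into the first element at each level.
    assert _GetDenseDimensions([[1, 2, 3], [4, 5, 6]]) == [2, 3]
    # _FlattenToStrings yields the leaves of an arbitrarily nested list/tuple.
    assert list(_FlattenToStrings([["a", "b"], ("c",)])) == ["a", "b", "c"]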
_TENSOR_CONTENT_TYPES = frozenset(
    [
        dtypes.float32,
        dtypes.float64,
        dtypes.int32,
        dtypes.uint8,
        dtypes.int16,
        dtypes.int8,
        dtypes.int64,
        dtypes.qint8,
        dtypes.quint8,
        dtypes.qint16,
        dtypes.quint16,
        dtypes.qint32,
        dtypes.uint32,
        dtypes.uint64,
    ]
)
class _Message:
    def __init__(self, message):
        self._message = message

    def __repr__(self):
        return self._message


def _FirstNotNone(l):
    for x in l:
        if x is not None:
            return x
    return None


def _NotNone(v):
    if v is None:
        return _Message("None")
    else:
        return v
def _FilterTuple(v):
    if not isinstance(v, (list, tuple)):
        return v
    if isinstance(v, tuple):
        if not any(isinstance(x, (list, tuple)) for x in v):
            return None
    if isinstance(v, list):
        if not any(isinstance(x, (list, tuple)) for x in v):
            return _FirstNotNone(
                [None if isinstance(x, (list, tuple)) else x for x in v]
            )
    return _FirstNotNone([_FilterTuple(x) for x in v])


def _FilterInt(v):
    if isinstance(v, (list, tuple)):
        return _FirstNotNone([_FilterInt(x) for x in v])
    return (
        None
        if isinstance(v, (compat.integral_types, tensor_shape.Dimension))
        else _NotNone(v)
    )


def _FilterFloat(v):
    if isinstance(v, (list, tuple)):
        return _FirstNotNone([_FilterFloat(x) for x in v])
    return None if isinstance(v, compat.real_types) else _NotNone(v)


def _FilterComplex(v):
    if isinstance(v, (list, tuple)):
        return _FirstNotNone([_FilterComplex(x) for x in v])
    return None if isinstance(v, compat.complex_types) else _NotNone(v)


def _FilterStr(v):
    if isinstance(v, (list, tuple)):
        return _FirstNotNone([_FilterStr(x) for x in v])
    if isinstance(v, compat.bytes_or_text_types):
        return None
    else:
        return _NotNone(v)


def _FilterBool(v):
    if isinstance(v, (list, tuple)):
        return _FirstNotNone([_FilterBool(x) for x in v])
    return None if isinstance(v, bool) else _NotNone(v)
_TF_TO_IS_OK = {
    dtypes.bool: [_FilterBool],
    dtypes.complex128: [_FilterComplex],
    dtypes.complex64: [_FilterComplex],
    dtypes.float16: [_FilterFloat],
    dtypes.float32: [_FilterFloat],
    dtypes.float64: [_FilterFloat],
    dtypes.int16: [_FilterInt],
    dtypes.int32: [_FilterInt],
    dtypes.int64: [_FilterInt],
    dtypes.int8: [_FilterInt],
    dtypes.qint16: [_FilterInt, _FilterTuple],
    dtypes.qint32: [_FilterInt, _FilterTuple],
    dtypes.qint8: [_FilterInt, _FilterTuple],
    dtypes.quint16: [_FilterInt, _FilterTuple],
    dtypes.quint8: [_FilterInt, _FilterTuple],
    dtypes.string: [_FilterStr],
    dtypes.uint16: [_FilterInt],
    dtypes.uint8: [_FilterInt],
}
def _Assertconvertible(values, dtype):
    # If dtype is None or not recognized, assume it's convertible.
    if dtype is None or dtype not in _TF_TO_IS_OK:
        return
    fn_list = _TF_TO_IS_OK.get(dtype)
    mismatch = _FirstNotNone([fn(values) for fn in fn_list])
    if mismatch is not None:
        raise TypeError(
            "Expected %s, got %s of type '%s' instead."
            % (dtype.name, repr(mismatch), type(mismatch).__name__)
        )
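

# Illustration (added here, not part of the original module; the helper name is
# hypothetical): _Assertconvertible is silent when every element passes the
# filters registered for the dtype, and raises TypeError on the first mismatch.
def _example_assert_convertible():
    _Assertconvertible([1, 2, 3], dtypes.int32)  # ints are fine for int32
    try:
        _Assertconvertible([1.0, "oops"], dtypes.float32)
    except TypeError:
        pass  # "Expected float32, got 'oops' of type 'str' instead."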
def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False):
    """Create a TensorProto.

    Args:
      values: Values to put in the TensorProto.
      dtype: Optional tensor_pb2 DataType value.
      shape: List of integers representing the dimensions of the tensor.
      verify_shape: Boolean that enables verification of the shape of values.

    Returns:
      A `TensorProto`. Depending on the type, it may contain data in the
      "tensor_content" attribute, which is not directly useful to Python programs.
      To access the values you should convert the proto back to a numpy ndarray
      with `tensor_util.MakeNdarray(proto)`.

      If `values` is a `TensorProto`, it is immediately returned; `dtype` and
      `shape` are ignored.

    Raises:
      TypeError: if unsupported types are provided.
      ValueError: if arguments have inappropriate values or if verify_shape is
        True and the shape of values does not equal the shape argument.

    make_tensor_proto accepts "values" of a python scalar, a python list, a
    numpy ndarray, or a numpy scalar.

    If "values" is a python scalar or a python list, make_tensor_proto
    first converts it to a numpy ndarray. If dtype is None, the
    conversion tries its best to infer the right numpy data
    type. Otherwise, the resulting numpy array has a data type
    convertible to the given dtype.

    In either case above, the numpy ndarray (either caller-provided or
    auto-converted) must have a type convertible to dtype.

    make_tensor_proto then converts the numpy array to a tensor proto.

    If "shape" is None, the resulting tensor proto represents the numpy
    array precisely.

    Otherwise, "shape" specifies the tensor's shape and the numpy array
    cannot have more elements than "shape" specifies.
    """
    if isinstance(values, tensor_pb2.TensorProto):
        return values

    if dtype:
        dtype = dtypes.as_dtype(dtype)

    is_quantized = dtype in [
        dtypes.qint8,
        dtypes.quint8,
        dtypes.qint16,
        dtypes.quint16,
        dtypes.qint32,
    ]

    # We first convert value to a numpy array or scalar.
    if isinstance(values, (np.ndarray, np.generic)):
        if dtype:
            nparray = values.astype(dtype.as_numpy_dtype)
        else:
            nparray = values
    elif callable(getattr(values, "__array__", None)) or isinstance(
        getattr(values, "__array_interface__", None), dict
    ):
        # If a class has the __array__ method, or an __array_interface__ dict,
        # then it is possible to convert it to a numpy array.
        nparray = np.asarray(values, dtype=dtype)

        # This is the preferred way to create an array from the object, so replace
        # the `values` with the array so that _FlattenToStrings is not run.
        values = nparray
    else:
        if values is None:
            raise ValueError("None values not supported.")
        # If dtype is provided, force the numpy array to be of the type
        # provided, if possible.
        if dtype and dtype.is_numpy_compatible:
            np_dt = dtype.as_numpy_dtype
        else:
            np_dt = None
        # If shape is None, numpy.prod returns None when dtype is not set, but raises
        # exception when dtype is set to np.int64
        if shape is not None and np.prod(shape, dtype=np.int64) == 0:
            nparray = np.empty(shape, dtype=np_dt)
        else:
            _Assertconvertible(values, dtype)
            nparray = np.array(values, dtype=np_dt)
            # We need to pass in quantized values as tuples, so don't apply the
            # shape check to them.
            if (
                list(nparray.shape) != _GetDenseDimensions(values)
                and not is_quantized
            ):
                raise ValueError(
                    """Argument must be a dense tensor: %s"""
                    """ - got shape %s, but wanted %s."""
                    % (values, list(nparray.shape), _GetDenseDimensions(values))
                )
    # python/numpy default float type is float64. We prefer float32 instead.
    if (nparray.dtype == np.float64) and dtype is None:
        nparray = nparray.astype(np.float32)
    # python/numpy default int type is int64. We prefer int32 instead.
    elif (nparray.dtype == np.int64) and dtype is None:
        downcasted_array = nparray.astype(np.int32)
        # Do not down cast if it leads to precision loss.
        if np.array_equal(downcasted_array, nparray):
            nparray = downcasted_array

    # if dtype is provided, it must be convertible with what numpy
    # conversion says.
    numpy_dtype = dtypes.as_dtype(nparray.dtype)
    if numpy_dtype is None:
        raise TypeError("Unrecognized data type: %s" % nparray.dtype)

    # If dtype was specified and is a quantized type, we convert
    # numpy_dtype back into the quantized version.
    if is_quantized:
        numpy_dtype = dtype

    if dtype is not None and (
        not hasattr(dtype, "base_dtype")
        or dtype.base_dtype != numpy_dtype.base_dtype
    ):
        raise TypeError(
            "Inconvertible types: %s vs. %s. Value is %s"
            % (dtype, nparray.dtype, values)
        )

    # If shape is not given, get the shape from the numpy array.
    if shape is None:
        shape = nparray.shape
        is_same_size = True
        shape_size = nparray.size
    else:
        shape = [int(dim) for dim in shape]
        shape_size = np.prod(shape, dtype=np.int64)
        is_same_size = shape_size == nparray.size

        if verify_shape:
            if not nparray.shape == tuple(shape):
                raise TypeError(
                    "Expected Tensor's shape: %s, got %s."
                    % (tuple(shape), nparray.shape)
                )

        if nparray.size > shape_size:
            raise ValueError(
                "Too many elements provided. Needed at most %d, but received %d"
                % (shape_size, nparray.size)
            )

    tensor_proto = tensor_pb2.TensorProto(
        dtype=numpy_dtype.as_datatype_enum,
        tensor_shape=tensor_shape.as_shape(shape).as_proto(),
    )

    if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
        if nparray.size * nparray.itemsize >= (1 << 31):
            raise ValueError(
                "Cannot create a tensor proto whose content is larger than 2GB."
            )
        tensor_proto.tensor_content = nparray.tobytes()
        return tensor_proto

    # If we were not given values as a numpy array, compute the proto_values
    # from the given values directly, to avoid numpy trimming nulls from the
    # strings. Since values could be a list of strings, or a multi-dimensional
    # list of lists that might or might not correspond to the given shape,
    # we flatten it conservatively.
    if numpy_dtype == dtypes.string and not isinstance(values, np.ndarray):
        proto_values = _FlattenToStrings(values)

        # At this point, values may be a list of objects that we could not
        # identify a common type for (hence it was inferred as
        # np.object/dtypes.string). If we are unable to convert it to a
        # string, we raise a more helpful error message.
        #
        # Ideally, we'd be able to convert the elements of the list to a
        # common type, but this type inference requires some thinking and
        # so we defer it for now.
        try:
            str_values = [compat.as_bytes(x) for x in proto_values]
        except TypeError:
            raise TypeError(
                "Failed to convert object of type %s to Tensor. "
                "Contents: %s. Consider casting elements to a "
                "supported type." % (type(values), values)
            )
        tensor_proto.string_val.extend(str_values)
        return tensor_proto

    # TensorFlow expects C order (a.k.a., eigen row major).
    proto_values = nparray.ravel()

    append_fn = GetNumpyAppendFn(proto_values.dtype)
    if append_fn is None:
        raise TypeError(
            "Element type not supported in TensorProto: %s" % numpy_dtype.name
        )
    append_fn(tensor_proto, proto_values)

    return tensor_proto
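

# Illustration (added here, not part of the original module; the helper name is
# hypothetical): a minimal sketch of the two storage paths described above.
# Multi-element arrays whose dtype is in _TENSOR_CONTENT_TYPES are packed as raw
# bytes into tensor_content; other inputs (including scalars) land in the typed
# *_val repeated fields via the SlowAppend* functions.
def _example_make_tensor_proto():
    # A 2x3 float32 array takes the tensor_content fast path.
    arr = np.arange(6, dtype=np.float32).reshape(2, 3)
    proto = make_tensor_proto(arr)
    assert proto.tensor_content == arr.tobytes()
    # A Python scalar with an explicit dtype goes through the slow-append path.
    scalar_proto = make_tensor_proto(3, dtype=dtypes.int64)
    assert list(scalar_proto.int64_val) == [3]
    return proto, scalar_proto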
def make_ndarray(tensor):
    """Create a numpy ndarray from a tensor.

    Create a numpy ndarray with the same shape and data as the tensor.

    Args:
      tensor: A TensorProto.

    Returns:
      A numpy array with the tensor contents.

    Raises:
      TypeError: if tensor has unsupported type.
    """
    shape = [d.size for d in tensor.tensor_shape.dim]
    num_elements = np.prod(shape, dtype=np.int64)
    tensor_dtype = dtypes.as_dtype(tensor.dtype)
    dtype = tensor_dtype.as_numpy_dtype

    if tensor.tensor_content:
        return (
            np.frombuffer(tensor.tensor_content, dtype=dtype)
            .copy()
            .reshape(shape)
        )
    elif tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16:
        # the half_val field of the TensorProto stores the binary representation
        # of the fp16: we need to reinterpret this as a proper float16
        if len(tensor.half_val) == 1:
            tmp = np.array(tensor.half_val[0], dtype=np.uint16)
            tmp.dtype = tensor_dtype.as_numpy_dtype
            return np.repeat(tmp, num_elements).reshape(shape)
        else:
            tmp = np.fromiter(tensor.half_val, dtype=np.uint16)
            tmp.dtype = tensor_dtype.as_numpy_dtype
            return tmp.reshape(shape)
    elif tensor_dtype == dtypes.float32:
        if len(tensor.float_val) == 1:
            return np.repeat(
                np.array(tensor.float_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.fromiter(tensor.float_val, dtype=dtype).reshape(shape)
    elif tensor_dtype == dtypes.float64:
        if len(tensor.double_val) == 1:
            return np.repeat(
                np.array(tensor.double_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.fromiter(tensor.double_val, dtype=dtype).reshape(shape)
    elif tensor_dtype in [
        dtypes.int32,
        dtypes.uint8,
        dtypes.uint16,
        dtypes.int16,
        dtypes.int8,
        dtypes.qint32,
        dtypes.quint8,
        dtypes.qint8,
        dtypes.qint16,
        dtypes.quint16,
    ]:
        if len(tensor.int_val) == 1:
            return np.repeat(
                np.array(tensor.int_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.fromiter(tensor.int_val, dtype=dtype).reshape(shape)
    elif tensor_dtype == dtypes.int64:
        if len(tensor.int64_val) == 1:
            return np.repeat(
                np.array(tensor.int64_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.fromiter(tensor.int64_val, dtype=dtype).reshape(shape)
    elif tensor_dtype == dtypes.string:
        if len(tensor.string_val) == 1:
            return np.repeat(
                np.array(tensor.string_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.array(list(tensor.string_val), dtype=dtype).reshape(shape)
    elif tensor_dtype == dtypes.complex64:
        it = iter(tensor.scomplex_val)
        if len(tensor.scomplex_val) == 2:
            return np.repeat(
                np.array(
                    complex(tensor.scomplex_val[0], tensor.scomplex_val[1]),
                    dtype=dtype,
                ),
                num_elements,
            ).reshape(shape)
        else:
            return np.array(
                [complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype
            ).reshape(shape)
    elif tensor_dtype == dtypes.complex128:
        it = iter(tensor.dcomplex_val)
        if len(tensor.dcomplex_val) == 2:
            return np.repeat(
                np.array(
                    complex(tensor.dcomplex_val[0], tensor.dcomplex_val[1]),
                    dtype=dtype,
                ),
                num_elements,
            ).reshape(shape)
        else:
            return np.array(
                [complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype
            ).reshape(shape)
    elif tensor_dtype == dtypes.bool:
        if len(tensor.bool_val) == 1:
            return np.repeat(
                np.array(tensor.bool_val[0], dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.fromiter(tensor.bool_val, dtype=dtype).reshape(shape)
    else:
        raise TypeError("Unsupported tensor type: %s" % tensor.dtype)
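

# Illustration (added here, not part of the original module; the helper name is
# hypothetical): make_ndarray is the inverse of make_tensor_proto, recovering
# values whether they were packed into tensor_content or into a typed *_val
# field such as half_val (which holds raw float16 bits).
def _example_make_ndarray_roundtrip():
    original = np.array([[1.5, 2.5], [3.5, 4.5]], dtype=np.float32)
    recovered = make_ndarray(make_tensor_proto(original))
    assert recovered.dtype == np.float32
    assert np.array_equal(recovered, original)
    # float16 values travel through half_val and are reinterpreted from bits.
    half_in = np.array([1.0, 2.0], dtype=np.float16)
    half_out = make_ndarray(make_tensor_proto(half_in))
    assert half_out.dtype == np.float16
    assert np.array_equal(half_out, half_in)
    return recovered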