Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/tensor_util.py: 14%
504 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities to create TensorProtos."""
16import typing
17from typing import Protocol
18import numpy as np
20from tensorflow.core.framework import tensor_pb2
21from tensorflow.core.framework import tensor_shape_pb2
22from tensorflow.python.client import pywrap_tf_session as c_api
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import errors_impl
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.types import core
27from tensorflow.python.types import internal
28from tensorflow.python.util import compat
29from tensorflow.python.util import nest
30from tensorflow.python.util.tf_export import tf_export
32# Fallback in case fast_tensor_util is not properly compiled.
33# pylint: disable=g-import-not-at-top
34try:
35 from tensorflow.python.framework import fast_tensor_util
36 _FAST_TENSOR_UTIL_AVAILABLE = True
37except ImportError:
38 _FAST_TENSOR_UTIL_AVAILABLE = False
39# pylint: enable=g-import-not-at-top
def ExtractBitsFromFloat16(x):
  """Returns the half-precision bit pattern of `x` as a Python int.

  Args:
    x: A scalar convertible to `np.float16`.

  Returns:
    The 16 bits of `x` as stored in an `np.float16`, reinterpreted as an
    unsigned integer.
  """
  as_half = np.asarray(x, dtype=np.float16)
  return as_half.view(np.uint16).item()
def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends float16 values to `tensor_proto.half_val`, one element at a time.

  Each value is stored as its 16-bit pattern (see `ExtractBitsFromFloat16`).
  """
  encoded = list(map(ExtractBitsFromFloat16, proto_values))
  tensor_proto.half_val.extend(encoded)
def _MediumAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends float16 values via `fast_tensor_util` as raw uint16 bit patterns.

  The values are converted to float16 and then reinterpreted (not converted)
  as uint16 before being handed to the compiled helper.
  """
  # TODO: Remove the conversion if cython supports np.float16_t
  half_bits = np.asarray(proto_values, dtype=np.float16).view(np.uint16)
  fast_tensor_util.AppendFloat16ArrayToTensorProto(tensor_proto, half_bits)
def ExtractBitsFromBFloat16(x):
  """Returns the 16-bit pattern of `x` encoded as bfloat16, as a Python int."""
  bfloat16_type = dtypes.bfloat16.as_numpy_dtype
  as_bfloat16 = np.asarray(x, dtype=bfloat16_type)
  return as_bfloat16.view(np.uint16).item()
def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends bfloat16 values to `tensor_proto.half_val` as 16-bit patterns.

  Pure-Python counterpart of `FastAppendBFloat16ArrayToTensorProto`.
  """
  encoded = list(map(ExtractBitsFromBFloat16, proto_values))
  tensor_proto.half_val.extend(encoded)
def FastAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends bfloat16 values via `fast_tensor_util` as raw uint16 patterns."""
  as_bfloat16 = np.asarray(
      proto_values, dtype=dtypes.bfloat16.as_numpy_dtype)
  fast_tensor_util.AppendBFloat16ArrayToTensorProto(
      tensor_proto, as_bfloat16.view(np.uint16))
def ExtractBitsFromFloat8e5m2(x):
  """Returns the 8-bit pattern of `x` encoded as float8_e5m2, as a Python int."""
  float8_type = dtypes.float8_e5m2.as_numpy_dtype
  as_float8 = np.asarray(x, dtype=float8_type)
  return as_float8.view(np.uint8).item()
def SlowAppendFloat8e5m2ArrayToTensorProto(tensor_proto, proto_values):
  """Appends float8_e5m2 values to `tensor_proto.float8_val`.

  Pure-Python counterpart of `FastAppendFloat8e5m2ArrayToTensorProto`. Each
  value is stored as one byte holding its 8-bit pattern.

  `MakeNdarray` decodes float8 tensors from `float8_val` (one uint8 per
  element), so the values must be written there rather than to `half_val`.

  Args:
    tensor_proto: The `TensorProto` being filled.
    proto_values: An iterable of values convertible to float8_e5m2.
  """
  # float8_val is a bytes field; every extracted pattern lies in [0, 255].
  tensor_proto.float8_val += bytes(
      map(ExtractBitsFromFloat8e5m2, proto_values))
def FastAppendFloat8e5m2ArrayToTensorProto(tensor_proto, proto_values):
  """Appends float8_e5m2 values via `fast_tensor_util` as raw uint8 patterns."""
  as_float8 = np.asarray(
      proto_values, dtype=dtypes.float8_e5m2.as_numpy_dtype)
  fast_tensor_util.AppendFloat8ArrayToTensorProto(
      tensor_proto, as_float8.view(np.uint8))
def ExtractBitsFromFloat8e4m3fn(x):
  """Returns the 8-bit pattern of `x` encoded as float8_e4m3fn, as a Python int."""
  float8_type = dtypes.float8_e4m3fn.as_numpy_dtype
  as_float8 = np.asarray(x, dtype=float8_type)
  return as_float8.view(np.uint8).item()
def SlowAppendFloat8e4m3fnArrayToTensorProto(tensor_proto, proto_values):
  """Appends float8_e4m3fn values to `tensor_proto.float8_val`.

  Pure-Python counterpart of `FastAppendFloat8e4m3fnArrayToTensorProto`. Each
  value is stored as one byte holding its 8-bit pattern.

  `MakeNdarray` decodes float8 tensors from `float8_val` (one uint8 per
  element), so the values must be written there rather than to `half_val`.

  Args:
    tensor_proto: The `TensorProto` being filled.
    proto_values: An iterable of values convertible to float8_e4m3fn.
  """
  # float8_val is a bytes field; every extracted pattern lies in [0, 255].
  tensor_proto.float8_val += bytes(
      map(ExtractBitsFromFloat8e4m3fn, proto_values))
def FastAppendFloat8e4m3fnArrayToTensorProto(tensor_proto, proto_values):
  """Appends float8_e4m3fn values via `fast_tensor_util` as raw uint8 patterns."""
  as_float8 = np.asarray(
      proto_values, dtype=dtypes.float8_e4m3fn.as_numpy_dtype)
  fast_tensor_util.AppendFloat8ArrayToTensorProto(
      tensor_proto, as_float8.view(np.uint8))
if _FAST_TENSOR_UTIL_AVAILABLE:
  # Maps NumPy scalar types to the function that appends an array of that type
  # (make_tensor_proto passes a raveled array) to a TensorProto. Entries are
  # found by GetNumpyAppendFn through an equality scan (GetFromNumpyDTypeDict).
  # float16/bfloat16/float8 go through small Python wrappers that reinterpret
  # the values as unsigned-integer bit patterns before calling the compiled
  # helper.
  _NP_TO_APPEND_FN = {
      dtypes.bfloat16.as_numpy_dtype:
          FastAppendBFloat16ArrayToTensorProto,
      dtypes.float8_e5m2.as_numpy_dtype:
          FastAppendFloat8e5m2ArrayToTensorProto,
      dtypes.float8_e4m3fn.as_numpy_dtype:
          FastAppendFloat8e4m3fnArrayToTensorProto,
      np.float16:
          _MediumAppendFloat16ArrayToTensorProto,
      np.float32:
          fast_tensor_util.AppendFloat32ArrayToTensorProto,
      np.float64:
          fast_tensor_util.AppendFloat64ArrayToTensorProto,
      np.int32:
          fast_tensor_util.AppendInt32ArrayToTensorProto,
      np.int64:
          fast_tensor_util.AppendInt64ArrayToTensorProto,
      np.uint8:
          fast_tensor_util.AppendUInt8ArrayToTensorProto,
      np.uint16:
          fast_tensor_util.AppendUInt16ArrayToTensorProto,
      np.uint32:
          fast_tensor_util.AppendUInt32ArrayToTensorProto,
      np.uint64:
          fast_tensor_util.AppendUInt64ArrayToTensorProto,
      np.int8:
          fast_tensor_util.AppendInt8ArrayToTensorProto,
      np.int16:
          fast_tensor_util.AppendInt16ArrayToTensorProto,
      np.complex64:
          fast_tensor_util.AppendComplex64ArrayToTensorProto,
      np.complex128:
          fast_tensor_util.AppendComplex128ArrayToTensorProto,
      np.object_:
          fast_tensor_util.AppendObjectArrayToTensorProto,
      np.bool_:
          fast_tensor_util.AppendBoolArrayToTensorProto,
      # Quantized types reuse the appender of their underlying integer type.
      dtypes.qint8.as_numpy_dtype:
          fast_tensor_util.AppendInt8ArrayToTensorProto,
      dtypes.quint8.as_numpy_dtype:
          fast_tensor_util.AppendUInt8ArrayToTensorProto,
      dtypes.qint16.as_numpy_dtype:
          fast_tensor_util.AppendInt16ArrayToTensorProto,
      dtypes.quint16.as_numpy_dtype:
          fast_tensor_util.AppendUInt16ArrayToTensorProto,
      dtypes.qint32.as_numpy_dtype:
          fast_tensor_util.AppendInt32ArrayToTensorProto,
  }
else:
  # Pure-Python fallbacks, used when fast_tensor_util could not be imported.
  # Each converts every NumPy element to a Python scalar with `.item()` and
  # extends the matching repeated field of the TensorProto.

  def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values):
    """Appends float32 values to `float_val`."""
    tensor_proto.float_val.extend([x.item() for x in proto_values])

  def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values):
    """Appends float64 values to `double_val`."""
    tensor_proto.double_val.extend([x.item() for x in proto_values])

  def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values):
    """Appends small integer values (8/16/32-bit) to `int_val`."""
    tensor_proto.int_val.extend([x.item() for x in proto_values])

  def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values):
    """Appends int64 values to `int64_val`."""
    tensor_proto.int64_val.extend([x.item() for x in proto_values])

  def SlowAppendQIntArrayToTensorProto(tensor_proto, proto_values):
    """Appends quantized values to `int_val`.

    `.item()` on a quantized element yields a sequence whose first entry is
    the integer value.
    """
    tensor_proto.int_val.extend([x.item()[0] for x in proto_values])

  def SlowAppendUInt32ArrayToTensorProto(tensor_proto, proto_values):
    """Appends uint32 values to `uint32_val`."""
    tensor_proto.uint32_val.extend([x.item() for x in proto_values])

  def SlowAppendUInt64ArrayToTensorProto(tensor_proto, proto_values):
    """Appends uint64 values to `uint64_val`."""
    tensor_proto.uint64_val.extend([x.item() for x in proto_values])

  def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values):
    """Appends complex64 values to `scomplex_val` as interleaved (re, im)."""
    tensor_proto.scomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]])

  def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values):
    """Appends complex128 values to `dcomplex_val` as interleaved (re, im)."""
    tensor_proto.dcomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]])

  def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values):
    """Appends string-like values to `string_val` as bytes."""
    tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values])

  def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values):
    """Appends boolean values to `bool_val`."""
    tensor_proto.bool_val.extend([x.item() for x in proto_values])

  _NP_TO_APPEND_FN = {
      dtypes.bfloat16.as_numpy_dtype:
          SlowAppendBFloat16ArrayToTensorProto,
      dtypes.float8_e5m2.as_numpy_dtype:
          SlowAppendFloat8e5m2ArrayToTensorProto,
      dtypes.float8_e4m3fn.as_numpy_dtype:
          SlowAppendFloat8e4m3fnArrayToTensorProto,
      np.float16:
          SlowAppendFloat16ArrayToTensorProto,
      np.float32:
          SlowAppendFloat32ArrayToTensorProto,
      np.float64:
          SlowAppendFloat64ArrayToTensorProto,
      np.int32:
          SlowAppendIntArrayToTensorProto,
      np.int64:
          SlowAppendInt64ArrayToTensorProto,
      np.uint8:
          SlowAppendIntArrayToTensorProto,
      np.uint16:
          SlowAppendIntArrayToTensorProto,
      np.uint32:
          SlowAppendUInt32ArrayToTensorProto,
      np.uint64:
          SlowAppendUInt64ArrayToTensorProto,
      np.int8:
          SlowAppendIntArrayToTensorProto,
      np.int16:
          SlowAppendIntArrayToTensorProto,
      np.complex64:
          SlowAppendComplex64ArrayToTensorProto,
      np.complex128:
          SlowAppendComplex128ArrayToTensorProto,
      np.object_:
          SlowAppendObjectArrayToTensorProto,
      np.bool_:
          SlowAppendBoolArrayToTensorProto,
      dtypes.qint8.as_numpy_dtype:
          SlowAppendQIntArrayToTensorProto,
      dtypes.quint8.as_numpy_dtype:
          SlowAppendQIntArrayToTensorProto,
      dtypes.qint16.as_numpy_dtype:
          SlowAppendQIntArrayToTensorProto,
      dtypes.quint16.as_numpy_dtype:
          SlowAppendQIntArrayToTensorProto,
      dtypes.qint32.as_numpy_dtype:
          SlowAppendQIntArrayToTensorProto,
  }
def GetFromNumpyDTypeDict(dtype_dict, dtype):
  """Looks up `dtype` in `dtype_dict` by equality instead of by hash.

  Args:
    dtype_dict: A dict keyed by NumPy scalar types (e.g. `np.float32`).
    dtype: The dtype to look up; compared with `==` against each key.

  Returns:
    The value of the first key comparing equal to `dtype`, or None.
  """
  # NOTE: dtype_dict.get(dtype) always returns None, so scan the items.
  return next(
      (value for key, value in dtype_dict.items() if key == dtype), None)
def GetNumpyAppendFn(dtype):
  """Returns the function that appends arrays of `dtype` to a TensorProto.

  Args:
    dtype: A `np.dtype`.

  Returns:
    A callable taking `(tensor_proto, proto_values)`, or None when `dtype` has
    no registered append function.
  """
  # NumPy string dtypes are variable length, so there is no single constant to
  # compare the dtype against; classify them by their scalar type instead.
  if dtype.type in (np.bytes_, np.str_):
    return (fast_tensor_util.AppendObjectArrayToTensorProto
            if _FAST_TENSOR_UTIL_AVAILABLE else
            SlowAppendObjectArrayToTensorProto)
  return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype)
def TensorShapeProtoToList(shape):
  """Convert a TensorShapeProto to a list.

  Args:
    shape: A TensorShapeProto.

  Returns:
    List of integers representing the dimensions of the tensor, in order.
  """
  return [dimension.size for dimension in shape.dim]
def _GetDenseDimensions(list_of_lists):
  """Returns the inferred dense dimensions of a list of lists.

  Follows the first element at every nesting level: each list/tuple
  contributes its length, and an empty one contributes 0 and ends the walk.
  A non-list/tuple value contributes nothing.
  """
  dims = []
  node = list_of_lists
  while isinstance(node, (list, tuple)):
    if not node:
      dims.append(0)
      break
    dims.append(len(node))
    node = node[0]
  return dims
def _FlattenToStrings(nested_strings):
  """Yields the leaves of arbitrarily nested lists/tuples, depth-first.

  Anything that is not a list or tuple (e.g. a `str` or `bytes`) is a leaf and
  is yielded unchanged.
  """
  if not isinstance(nested_strings, (list, tuple)):
    yield nested_strings
    return
  for inner in nested_strings:
    yield from _FlattenToStrings(inner)
# dtypes that make_tensor_proto serializes as raw bytes in
# `TensorProto.tensor_content` (instead of the typed repeated fields) when the
# values fill the requested shape and there is more than one element.
_TENSOR_CONTENT_TYPES = frozenset({
    # Floating point.
    dtypes.float16, dtypes.float32, dtypes.float64,
    dtypes.float8_e5m2, dtypes.float8_e4m3fn,
    # Signed and unsigned integers.
    dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
    dtypes.uint8, dtypes.uint32, dtypes.uint64,
    # Quantized integers.
    dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16, dtypes.qint32,
})
# pylint: disable=invalid-name
def _check_failed(v):
  """Signals a type-check failure for the offending value `v`.

  `_AssertCompatible` catches the resulting ValueError and unpacks its single
  argument to build the TypeError message.
  """
  # NB. none of the _check_* functions could raise a ValueError on their own,
  # so a ValueError escaping them always originates here.
  failure = ValueError(v)
  raise failure
def _check_quantized(values):
  """Checks values destined for a quantized dtype.

  Tuples are leaves whose entries must pass `_check_int`; lists are recursed
  into; anything else fails.
  """
  # Cannot rely on `nest` because the leaves are tuples.
  if not isinstance(values, (list, tuple)):
    _check_failed(values)
  if isinstance(values, tuple):
    for entry in values:
      _check_int(entry)
  else:
    for entry in values:
      _check_quantized(entry)
def _generate_isinstance_check(expected_types):
  """Builds a checker that rejects leaves not matching `expected_types`.

  Args:
    expected_types: A type or (nested) tuple of types accepted by
      `isinstance`/`issubclass`.

  Returns:
    A function taking a possibly nested structure. It flattens it with
    `nest.flatten` and calls `_check_failed` on the first leaf that is neither
    an instance of `expected_types` nor a NumPy array whose element type is a
    subclass of `expected_types`.
  """
  def _is_expected(leaf):
    if isinstance(leaf, expected_types):
      return True
    return (isinstance(leaf, np.ndarray) and
            issubclass(leaf.dtype.type, expected_types))

  def inner(values):
    for leaf in nest.flatten(values):
      if not _is_expected(leaf):
        _check_failed(leaf)

  return inner
# Leaf-type checkers used by _AssertCompatible (selected via _TF_TO_IS_OK).
_check_int = _generate_isinstance_check(
    (compat.integral_types, tensor_shape.Dimension))
_check_float = _generate_isinstance_check(compat.real_types)
_check_complex = _generate_isinstance_check(compat.complex_types)
_check_str = _generate_isinstance_check(compat.bytes_or_text_types)
# np.bool_ is not a subclass of Python's bool, so accept it explicitly;
# otherwise e.g. [np.True_] or [np.array([True])] is rejected for dtype=bool
# even though np.array(..., dtype=bool) handles it directly.
_check_bool = _generate_isinstance_check((bool, np.bool_))
def _check_not_tensor(values):
  """Fails on the first leaf of `values` that is a symbolic tensor."""
  for leaf in nest.flatten(values):
    if isinstance(leaf, core.Symbol):
      _check_failed(leaf)
# pylint: enable=invalid-name
# Maps a TF dtype to the function validating Python values for it. dtypes not
# listed here get a category-based checker in _AssertCompatible.
_TF_TO_IS_OK = {
    dtypes.bool: _check_bool,
    dtypes.string: _check_str,
    **dict.fromkeys((dtypes.complex64, dtypes.complex128), _check_complex),
    **dict.fromkeys((dtypes.float16, dtypes.float32, dtypes.float64),
                    _check_float),
    **dict.fromkeys((dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
                     dtypes.uint8, dtypes.uint16, dtypes.uint32,
                     dtypes.uint64), _check_int),
    **dict.fromkeys((dtypes.qint8, dtypes.qint16, dtypes.qint32,
                     dtypes.quint8, dtypes.quint16), _check_quantized),
}
def _AssertCompatible(values, dtype):
  """Raises TypeError unless every leaf of `values` suits `dtype`.

  Args:
    values: A Python value or nested list/tuple of values.
    dtype: A `DType`, or None to only reject symbolic tensors.

  Raises:
    TypeError: If the checker chosen for `dtype` rejects a leaf.
  """
  if dtype is None:
    checker = _check_not_tensor
  else:
    checker = _TF_TO_IS_OK.get(dtype)
    if checker is None:
      # There isn't a specific checker, so pick the closest category.
      if dtype.is_integer:
        checker = _check_int
      elif dtype.is_floating:
        checker = _check_float
      elif dtype.is_complex:
        checker = _check_complex
      elif dtype.is_quantized:
        checker = _check_quantized
      else:
        checker = _check_not_tensor

  try:
    checker(values)
  except ValueError as e:
    # _check_failed raises ValueError with the offending value as sole arg.
    [mismatch] = e.args
    if dtype is None:
      raise TypeError("Expected any non-tensor type, but got a tensor instead.")
    raise TypeError(f"Expected {dtype.name}, but got {mismatch} of type "
                    f"'{type(mismatch).__name__}'.")
def _is_array_like(obj):  # pylint: disable=invalid-name
  """Check if a given object is array-like.

  True for objects exposing `__array__`, a dict `__array_interface__`, or the
  buffer protocol (except plain `bytes`). Symbolic tensors that are not eager
  values are never array-like.
  """
  # Tensor implements __array__ only so it can inform the user that it is not
  # a valid array.
  if isinstance(obj, core.Symbol) and not isinstance(obj, core.Value):
    return False

  # TODO(slebedev): an object could also implement C-level array interface.
  if callable(getattr(obj, "__array__", None)):
    return True
  if isinstance(getattr(obj, "__array_interface__", None), dict):
    return True

  try:
    memoryview(obj)
  except TypeError:
    return False
  return not isinstance(obj, bytes)
# pylint: disable=invalid-name
@tf_export("make_tensor_proto")
def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False,
                      allow_broadcast=False):
  """Create a TensorProto.

  In TensorFlow 2.0, representing tensors as protos should no longer be a
  common workflow. That said, this utility function is still useful for
  generating TF Serving request protos:

  ```python
    request = tensorflow_serving.apis.predict_pb2.PredictRequest()
    request.model_spec.name = "my_model"
    request.model_spec.signature_name = "serving_default"
    request.inputs["images"].CopyFrom(tf.make_tensor_proto(X_new))
  ```

  `make_tensor_proto` accepts "values" of a python scalar, a python list, a
  numpy ndarray, or a numpy scalar. Objects implementing the NumPy array
  interface or the buffer protocol are first converted with `np.asarray`.

  If "values" is a python scalar or a python list, make_tensor_proto
  first converts it to a numpy ndarray. If dtype is None, the
  conversion tries its best to infer the right numpy data type; a float64
  result is stored as float32, and an int64 result as int32 when that loses
  no precision. Otherwise, the resulting numpy array has a data type
  compatible with the given dtype.

  In either case above, the numpy ndarray (either the caller provided
  or the auto-converted) must have a type compatible with dtype.

  `make_tensor_proto` then converts the numpy array to a tensor proto.

  If "shape" is None, the resulting tensor proto represents the numpy
  array precisely.

  Otherwise, "shape" specifies the tensor's shape and the numpy array
  cannot have more elements than what "shape" specifies.

  Args:
    values:         Values to put in the TensorProto.
    dtype:          Optional tensor_pb2 DataType value (anything accepted by
      `tf.dtypes.as_dtype`).
    shape:          List of integers representing the dimensions of tensor.
    verify_shape:   Boolean that enables verification of a shape of values.
    allow_broadcast:  Boolean that enables allowing scalars and 1 length vector
        broadcasting. Cannot be true when verify_shape is true.

  Returns:
    A `TensorProto`. Depending on the type, it may contain data in the
    "tensor_content" attribute, which is not directly useful to Python programs.
    To access the values you should convert the proto back to a numpy ndarray
    with `tf.make_ndarray(proto)`.

    If `values` is a `TensorProto`, it is immediately returned; `dtype` and
    `shape` are ignored.

  Raises:
    TypeError:  if unsupported types are provided, if `dtype` is not
      compatible with the values, or if the values' shape does not match
      `shape` (with `verify_shape=True`, an exact shape match is required;
      with `allow_broadcast=True`, the element counts must match unless the
      values are a scalar or a length-1 vector).
    ValueError: if arguments have inappropriate values, e.g. `values` is
      None, both `allow_broadcast` and `verify_shape` are set, nested Python
      values are not dense, more values are given than `shape` holds, or the
      serialized content would exceed 2GB.

  """
  if allow_broadcast and verify_shape:
    raise ValueError("allow_broadcast and verify_shape are not both allowed.")
  if isinstance(values, tensor_pb2.TensorProto):
    return values

  if dtype:
    dtype = dtypes.as_dtype(dtype)

  is_quantized = (
      dtype in [
          dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16,
          dtypes.qint32
      ])

  if _is_array_like(values):
    values = np.asarray(values)

  # We first convert value to a numpy array or scalar.
  if isinstance(values, (np.ndarray, np.generic)):
    if dtype and dtype.is_numpy_compatible:
      nparray = values.astype(dtype.as_numpy_dtype)
    else:
      nparray = values
  else:
    if values is None:
      raise ValueError("None values not supported.")
    # if dtype is provided, forces numpy array to be the type
    # provided if possible.
    if dtype and dtype.is_numpy_compatible:
      np_dt = dtype.as_numpy_dtype
    else:
      np_dt = None
    # If shape is None, numpy.prod returns None when dtype is not set, but
    # raises exception when dtype is set to np.int64
    if shape is not None and np.prod(shape, dtype=np.int64) == 0:
      nparray = np.empty(shape, dtype=np_dt)
    else:
      _AssertCompatible(values, dtype)
      nparray = np.array(values, dtype=np_dt)
      # We need to pass in quantized values as tuples, so don't apply the shape
      # check to them.
      if (list(nparray.shape) != _GetDenseDimensions(values) and
          not is_quantized):
        raise ValueError(f"Expected values {values} to be a dense tensor with "
                         f"shape {_GetDenseDimensions(values)}, but got shape "
                         f"{list(nparray.shape)}.")

    # python/numpy default float type is float64. We prefer float32 instead.
    if (nparray.dtype == np.float64) and dtype is None:
      nparray = nparray.astype(np.float32)
    # python/numpy default int type is int64. We prefer int32 instead.
    elif (nparray.dtype == np.int64) and dtype is None:
      downcasted_array = nparray.astype(np.int32)
      # Do not down cast if it leads to precision loss.
      if np.array_equal(downcasted_array, nparray):
        nparray = downcasted_array

  # if dtype is provided, it must be compatible with what numpy
  # conversion says.
  numpy_dtype = dtypes.as_dtype(nparray.dtype)
  if numpy_dtype is None:
    # NOTE(review): defensive; dtypes.as_dtype is expected to raise rather
    # than return None for unknown types — confirm.
    raise TypeError(f"Unrecognized data type: {nparray.dtype}.")

  # If dtype was specified and is a quantized type, we convert
  # numpy_dtype back into the quantized version.
  if is_quantized:
    numpy_dtype = dtype

  if dtype is not None and (not hasattr(dtype, "base_dtype") or
                            dtype.base_dtype != numpy_dtype.base_dtype):
    raise TypeError(f"`dtype` {dtype} is not compatible with {values} of "
                    f"dtype {nparray.dtype}.")

  # If shape is not given, get the shape from the numpy array.
  if shape is None:
    shape = nparray.shape
    is_same_size = True
    shape_size = nparray.size
  else:
    shape = [int(dim) for dim in shape]
    shape_size = np.prod(shape, dtype=np.int64)
    is_same_size = shape_size == nparray.size

    if allow_broadcast:
      if nparray.shape == (1,) or nparray.shape == tuple():
        pass
      elif nparray.size != shape_size:
        raise TypeError(f"Expected Tensor's shape: {tuple(shape)}, but got "
                        f"{nparray.shape}.")

    else:
      if verify_shape and nparray.shape != tuple(shape):
        raise TypeError(f"Expected Tensor's shape: {tuple(shape)}, but got "
                        f"{nparray.shape}.")

      if nparray.size > shape_size:
        raise ValueError("Too many elements provided. Takes at most "
                         f"{shape_size:d}, but got {nparray.size:d}.")

  tensor_proto = tensor_pb2.TensorProto(
      dtype=numpy_dtype.as_datatype_enum,
      tensor_shape=tensor_shape.as_shape(shape).as_proto())

  # Fast path: dense multi-element data of a fixed-width type is stored as raw
  # bytes in tensor_content.
  if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
    if nparray.size * nparray.itemsize >= (1 << 31):
      raise ValueError(
          "Cannot create a tensor proto whose content is larger than 2GB.")
    tensor_proto.tensor_content = nparray.tobytes()
    return tensor_proto

  # If we were not given values as a numpy array, compute the proto_values
  # from the given values directly, to avoid numpy trimming nulls from the
  # strings. Since values could be a list of strings, or a multi-dimensional
  # list of lists that might or might not correspond to the given shape,
  # we flatten it conservatively.
  if numpy_dtype == dtypes.string and not isinstance(values, np.ndarray):
    proto_values = _FlattenToStrings(values)

    # At this point, values may be a list of objects that we could not
    # identify a common type for (hence it was inferred as
    # np.object_/dtypes.string). If we are unable to convert it to a
    # string, we raise a more helpful error message.
    #
    # Ideally, we'd be able to convert the elements of the list to a
    # common type, but this type inference requires some thinking and
    # so we defer it for now.
    try:
      str_values = [compat.as_bytes(x) for x in proto_values]
    except TypeError:
      raise TypeError(f"Failed to convert elements of {values} to Tensor. "
                      "Consider casting elements to a supported type. See "
                      "https://www.tensorflow.org/api_docs/python/tf/dtypes "
                      "for supported TF dtypes.")
    tensor_proto.string_val.extend(str_values)
    return tensor_proto

  # TensorFlow expects C order (a.k.a., eigen row major).
  proto_values = nparray.ravel()

  append_fn = GetNumpyAppendFn(proto_values.dtype)
  if append_fn is None:
    raise TypeError(
        f"Element type not supported in TensorProto: {numpy_dtype.name}.")
  append_fn(tensor_proto, proto_values)

  return tensor_proto
# pylint: enable=invalid-name
@tf_export("make_ndarray")
def MakeNdarray(tensor):
  """Create a numpy ndarray from a tensor.

  Create a numpy ndarray with the same shape and data as the tensor.

  For example:

  ```python
  # Tensor a has shape (2,3)
  a = tf.constant([[1,2,3],[4,5,6]])
  proto_tensor = tf.make_tensor_proto(a)  # convert `tensor a` to a proto tensor
  tf.make_ndarray(proto_tensor) # output: array([[1, 2, 3],
  #                                              [4, 5, 6]], dtype=int32)
  # output has shape (2,3)
  ```

  When the proto does not use `tensor_content` and stores fewer values than
  its shape requires, the last stored value is repeated to fill the array.
  If it stores no values at all, the result is zero-filled (empty strings for
  string tensors).

  Args:
    tensor: A TensorProto.

  Returns:
    A numpy array with the tensor contents.

  Raises:
    TypeError: if tensor has unsupported type.

  """
  shape = [d.size for d in tensor.tensor_shape.dim]
  num_elements = np.prod(shape, dtype=np.int64)
  tensor_dtype = dtypes.as_dtype(tensor.dtype)
  dtype = tensor_dtype.as_numpy_dtype

  if tensor.tensor_content:
    # frombuffer returns a read-only view of the proto's bytes; copy it so the
    # caller gets a writable array.
    return (np.frombuffer(tensor.tensor_content,
                          dtype=dtype).copy().reshape(shape))

  if tensor_dtype == dtypes.string:
    # np.pad throws on these arrays of type np.object_.
    values = list(tensor.string_val)
    padding = num_elements - len(values)
    if padding > 0:
      last = values[-1] if values else ""
      values.extend([last] * padding)
    return np.array(values, dtype=dtype).reshape(shape)

  if tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16:
    # The half_val field of the TensorProto stores the binary representation
    # of each 16-bit float: reinterpret the uint16 bits as the real type.
    values = np.fromiter(tensor.half_val, dtype=np.uint16).view(dtype)
  elif tensor_dtype == dtypes.float8_e5m2 or tensor_dtype == dtypes.float8_e4m3fn:
    # float8_val holds one byte per element; reinterpret it the same way.
    values = np.fromiter(tensor.float8_val, dtype=np.uint8).view(dtype)
  elif tensor_dtype == dtypes.float32:
    values = np.fromiter(tensor.float_val, dtype=dtype)
  elif tensor_dtype == dtypes.float64:
    values = np.fromiter(tensor.double_val, dtype=dtype)
  elif tensor_dtype in [
      dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16, dtypes.int8,
      dtypes.qint32, dtypes.quint8, dtypes.qint8, dtypes.qint16, dtypes.quint16
  ]:
    values = np.fromiter(tensor.int_val, dtype=dtype)
  elif tensor_dtype == dtypes.int64:
    values = np.fromiter(tensor.int64_val, dtype=dtype)
  elif tensor_dtype == dtypes.uint32:
    values = np.fromiter(tensor.uint32_val, dtype=dtype)
  elif tensor_dtype == dtypes.uint64:
    values = np.fromiter(tensor.uint64_val, dtype=dtype)
  elif tensor_dtype == dtypes.complex64:
    # Real and imaginary parts are interleaved; pair them up.
    it = iter(tensor.scomplex_val)
    values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
  elif tensor_dtype == dtypes.complex128:
    it = iter(tensor.dcomplex_val)
    values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
  elif tensor_dtype == dtypes.bool:
    values = np.fromiter(tensor.bool_val, dtype=dtype)
  else:
    raise TypeError(f"Unsupported tensor type: {tensor.dtype}. See "
                    "https://www.tensorflow.org/api_docs/python/tf/dtypes "
                    "for supported TF dtypes.")

  if values.size == 0:
    return np.zeros(shape, dtype)

  if values.size != num_elements:
    # Broadcast-style protos store a prefix of the values; repeat the last one.
    values = np.pad(values, (0, num_elements - values.size), "edge")

  return values.reshape(shape)
def ShapeEquals(tensor_proto, shape):
  """Returns True if "tensor_proto" has the given "shape".

  Args:
    tensor_proto: A TensorProto.
    shape: A tensor shape, expressed as a TensorShapeProto, list, or tuple.

  Returns:
    True if "tensor_proto" has the given "shape" (same rank and the same size
    in every dimension), otherwise False.

  Raises:
    TypeError: If "tensor_proto" is not a TensorProto, or shape is not a
      TensorShapeProto, list, or tuple.
  """
  if not isinstance(tensor_proto, tensor_pb2.TensorProto):
    raise TypeError("`tensor_proto` must be a tensor_pb2.TensorProto object, "
                    f"but got type {type(tensor_proto)}.")
  if isinstance(shape, tensor_shape_pb2.TensorShapeProto):
    shape = [d.size for d in shape.dim]
  elif not isinstance(shape, (list, tuple)):
    raise TypeError("`shape` must be a list or tuple, but got type "
                    f"{type(shape)}.")
  tensor_shape_list = [d.size for d in tensor_proto.tensor_shape.dim]
  # Compare ranks first: zip() stops at the shorter sequence, so on its own it
  # would report e.g. [2, 3] as equal to [2].
  if len(tensor_shape_list) != len(shape):
    return False
  return all(x == y for x, y in zip(tensor_shape_list, shape))
def _ConstantValue(tensor, partial):
  """Statically evaluates `tensor` by pattern-matching on its producing op.

  Only a fixed set of op types is understood. Their results are derived from
  `Const` attributes, static shape information, or (recursively, through
  `constant_value`) statically known inputs. Anything else yields None.

  Args:
    tensor: A symbolic tensor (`core.Symbol`).
    partial: Forwarded to `constant_value` for Pack/Unpack inputs, Split's
      value input, and StopGradient/identity-like ops.

  Returns:
    A numpy value, or None if it cannot be computed.

  Raises:
    TypeError: If `tensor` is not a `core.Symbol`.
  """
  # TODO(touts): Support Variables?
  if not isinstance(tensor, core.Symbol):
    raise TypeError(f"{tensor!r} must be a Tensor, but got {type(tensor)}.")

  op = tensor.op
  op_type = op.type

  if op_type == "Const":
    return MakeNdarray(op.get_attr("value"))

  if op_type in ("Shape", "Size", "Rank"):
    input_shape = op.inputs[0].get_shape()
    if op_type == "Rank":
      if input_shape.ndims is None:
        return None
      return np.ndarray(
          shape=(),
          buffer=np.array([input_shape.ndims], dtype=np.int32),
          dtype=np.int32)
    if not input_shape.is_fully_defined():
      return None
    static_dims = [dim.value for dim in input_shape.dims]
    if op_type == "Shape":
      return np.array(static_dims, dtype=tensor.dtype.as_numpy_dtype)
    return np.prod(static_dims, dtype=np.int32)

  if op_type == "Range":
    range_args = []
    for position in range(3):  # start, limit, delta
      arg = constant_value(op.inputs[position])
      if arg is None:
        return None
      range_args.append(arg)
    start, limit, delta = range_args
    return np.arange(start, limit, delta, dtype=tensor.dtype.as_numpy_dtype)

  if op_type == "Cast":
    pre_cast = constant_value(op.inputs[0])
    if pre_cast is None:
      return None
    cast_dtype = dtypes.as_dtype(op.get_attr("DstT"))
    return pre_cast.astype(cast_dtype.as_numpy_dtype)

  if op_type in ("Concat", "ConcatV2"):
    # Concat takes the axis first; ConcatV2 takes it last.
    if op_type == "Concat":
      axis_input, piece_inputs = op.inputs[0], op.inputs[1:]
    else:
      axis_input, piece_inputs = op.inputs[-1], op.inputs[:-1]
    axis = constant_value(axis_input)
    if axis is None:
      return None
    pieces = []
    for piece_input in piece_inputs:
      piece = constant_value(piece_input)
      if piece is None:
        return None
      pieces.append(piece)
    return np.concatenate(pieces, axis=axis)

  if op_type == "Pack":
    # Some imported GraphDefs have Pack ops with zero inputs. Those are invalid
    # and shouldn't be produced, but to deal sensibly with them here we check
    # and return None.
    if not op.inputs:
      return None
    # We can't handle axis != 0 Packs at the moment.
    if op.get_attr("axis") != 0:
      return None
    packed = []
    for element_input in op.inputs:
      element = constant_value(element_input, partial)
      if element is None and not partial:
        return None
      packed.append(element)
    try:
      return np.array(packed)
    except ValueError:
      # If partial=True, some of the elements of values may be None.
      return np.array(packed, dtype=object)

  if op_type == "Unpack":
    # We can't handle axis != 0 Unpacks at the moment.
    if op.get_attr("axis") != 0:
      return None
    source = constant_value(op.inputs[0], partial)
    if source is None:
      return None
    return source[tensor.value_index]

  if op_type == "Split":
    # Both inputs are evaluated before either is checked.
    axis = constant_value(op.inputs[0])
    source = constant_value(op.inputs[1], partial)
    if source is None or axis is None:
      return None
    pieces = np.split(source, op.get_attr("num_split"), axis)
    return pieces[tensor.value_index]

  if op_type == "Fill":
    fill_shape = tensor.shape
    fill_value = constant_value(op.inputs[1])
    if not fill_shape.is_fully_defined() or fill_value is None:
      return None
    return np.full(fill_shape.as_list(), fill_value, dtype=fill_value.dtype)

  if op_type in ("Equal", "NotEqual"):
    lhs = constant_value(op.inputs[0])
    if lhs is None:
      return None
    rhs = constant_value(op.inputs[1])
    if rhs is None:
      return None
    compare = np.equal if op_type == "Equal" else np.not_equal
    return compare(lhs, rhs)

  if op_type in ("StopGradient", "CheckNumericsV2", "DebugIdentityV2",
                 "Identity"):
    # Value-preserving ops: the result is the input's static value.
    return constant_value(op.inputs[0], partial)

  return None
@tf_export("get_static_value")
def constant_value(tensor, partial=False):  # pylint: disable=invalid-name
  """Returns the constant value of the given tensor, if efficiently calculable.

  This function attempts to partially evaluate the given tensor, and
  returns its value as a numpy ndarray if this succeeds.

  Example usage:

  >>> a = tf.constant(10)
  >>> tf.get_static_value(a)
  10
  >>> b = tf.constant(20)
  >>> tf.get_static_value(tf.add(a, b))
  30

  >>> # `tf.Variable` is not supported.
  >>> c = tf.Variable(30)
  >>> print(tf.get_static_value(c))
  None

  Using `partial` option is most relevant when calling `get_static_value` inside
  a `tf.function`. Setting it to `True` will return the results but for the
  values that cannot be evaluated will be `None`. For example:

  ```python
  class Foo:
    def __init__(self):
      self.a = tf.Variable(1)
      self.b = tf.constant(2)

    @tf.function
    def bar(self, partial):
      packed = tf.raw_ops.Pack(values=[self.a, self.b])
      static_val = tf.get_static_value(packed, partial=partial)
      tf.print(static_val)

  f = Foo()
  f.bar(partial=True)  # `array([None, array(2, dtype=int32)], dtype=object)`
  f.bar(partial=False)  # `None`
  ```

  Compatibility(V1): If `constant_value(tensor)` returns a non-`None` result, it
  will no longer be possible to feed a different value for `tensor`. This allows
  the result of this function to influence the graph that is constructed, and
  permits static shape optimizations.

  Args:
    tensor: The Tensor to be evaluated. Inputs for which `is_tensor` is False
      are returned unchanged.
    partial: If True, the returned numpy array is allowed to have partially
      evaluated values. Values that can't be evaluated will be None.

  Returns:
    A numpy ndarray containing the constant value of the given `tensor`,
    or None if it cannot be calculated.
  """
  if isinstance(tensor, core.Value):
    try:
      return tensor.numpy()
    except errors_impl.UnimplementedError:
      # Some EagerTensors may not implement .numpy/resolve, e.g. parallel
      # tensors with multiple components on different devices.
      return None
  if not is_tensor(tensor):
    return tensor
  if not isinstance(tensor, core.Symbol):
    return None
  static_value = _ConstantValue(tensor, partial)
  if static_value is None:
    return None
  # The caller may now depend on the constant value of `tensor`, so we
  # conservatively prevent it from being fed.
  tensor.graph.prevent_feeding(tensor)
  return static_value
def constant_value_as_shape(tensor):  # pylint: disable=invalid-name
  """A version of `constant_value()` that returns a `TensorShape`.

  This version should be used when a constant tensor value is
  interpreted as a (possibly partial) shape, e.g. in the shape
  function for `tf.reshape()`. By explicitly requesting a
  `TensorShape` as the return value, it is possible to represent
  unknown dimensions; by contrast, `constant_value()` is
  all-or-nothing.

  Args:
    tensor: The rank-0 or rank-1 Tensor to be evaluated (eager or symbolic).

  Returns:
    A `TensorShape` based on the constant value of the given `tensor`.
    Entries equal to -1 (or otherwise unknown) become unknown dimensions.

  Raises:
    ValueError: If the shape is rank-0 and is not statically known to be -1.
  """
  if isinstance(tensor, core.Value):
    value = tensor.numpy()
    if np.ndim(value) == 0:
      # A scalar shape is only meaningful as -1 ("unknown shape"); apply the
      # same contract as the symbolic branch below instead of failing with a
      # TypeError when iterating a 0-d value.
      if value != -1:
        raise ValueError(
            f"Received a scalar value '{value}' as shape; require a statically "
            "known scalar with value '-1' to describe an unknown shape.")
      return tensor_shape.unknown_shape()
    return tensor_shape.TensorShape(
        [dim if dim != -1 else None for dim in value])

  if tensor.get_shape().ndims == 0:
    value = constant_value(tensor)
    if value is None:
      raise ValueError(
          "Received a scalar with unknown value as shape; require a statically "
          "known scalar with value '-1' to describe an unknown shape.")
    if value != -1:
      raise ValueError(
          f"Received a scalar value '{value}' as shape; require a statically "
          "known scalar with value '-1' to describe an unknown shape.")
    return tensor_shape.unknown_shape()

  shape = tensor.get_shape().with_rank(1)
  if shape == [0]:
    return tensor_shape.TensorShape([])
  elif tensor.op.type == "Cast":
    pre_cast = constant_value_as_shape(tensor.op.inputs[0])
    if pre_cast.dims is None:
      # the input to cast has a totally undefined shape; just return that.
      return pre_cast
    cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))
    if cast_dtype not in (dtypes.int32, dtypes.int64):
      return tensor_shape.unknown_shape(shape.dims[0].value)
    # Encode unknown dims as -1, cast like the op would, then decode any
    # negative result back to an unknown dimension.
    dest_dtype_shape_array = np.array(
        [x if x is not None else -1 for x in pre_cast.as_list()]).astype(
            cast_dtype.as_numpy_dtype)
    return tensor_shape.TensorShape([
        x if x >= 0 else None
        for x in dest_dtype_shape_array])
  elif tensor.op.type == "Shape":
    return tensor.op.inputs[0].get_shape()
  elif tensor.op.type == "Pack":
    ret = tensor_shape.TensorShape([])  # Empty list.
    # Since we expect rank 1 inputs, Pack's axis must be zero, otherwise it
    # would not be rank 1.
    assert tensor.op.get_attr("axis") == 0
    for pack_input in tensor.op.inputs:
      # `pack_input` must be a scalar. Attempt to evaluate it, and append it
      # to `ret`.
      pack_input_val = constant_value(pack_input)
      if pack_input_val is None or pack_input_val < 0:
        new_dim = tensor_shape.Dimension(None)
      else:
        new_dim = tensor_shape.Dimension(pack_input_val)
      ret = ret.concatenate([new_dim])
    return ret
  elif tensor.op.type == "Concat":
    # We assume that `tensor.op.inputs[0]` evaluates to 0, as this is
    # the only legal value when concatenating vectors, and it will
    # have been checked by a previous shape function.
    ret = tensor_shape.TensorShape([])  # Empty list.
    for concat_input in tensor.op.inputs[1:]:
      # `concat_input` must be a vector. Attempt to evaluate it as a shape,
      # and concatenate it with `ret`.
      ret = ret.concatenate(constant_value_as_shape(concat_input))
    return ret
  elif tensor.op.type == "ConcatV2":
    # We assume that `tensor.op.inputs[-1]` evaluates to 0, as this is
    # the only legal value when concatenating vectors, and it will
    # have been checked by a previous shape function.
    ret = tensor_shape.TensorShape([])  # Empty list.
    for concat_input in tensor.op.inputs[:-1]:
      # `concat_input` must be a vector. Attempt to evaluate it as a shape,
      # and concatenate it with `ret`.
      ret = ret.concatenate(constant_value_as_shape(concat_input))
    return ret
  elif tensor.op.type == "StridedSlice":
    try:
      begin = constant_value(tensor.op.inputs[1])
      end = constant_value(tensor.op.inputs[2])
      strides = constant_value(tensor.op.inputs[3])
      if begin is not None and end is not None and strides is not None:
        begin = begin[0]
        end = end[0]
        strides = strides[0]
        begin_mask = tensor.op.get_attr("begin_mask")
        if begin_mask == 1:
          begin = None
        end_mask = tensor.op.get_attr("end_mask")
        if end_mask == 1:
          end = None

        ellipsis_mask = tensor.op.get_attr("ellipsis_mask")
        new_axis_mask = tensor.op.get_attr("new_axis_mask")
        shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask")
        # Only a plain 1-D slice (optionally open at either end) is handled.
        valid_attributes = (not ellipsis_mask and not new_axis_mask and
                            not shrink_axis_mask and (not begin_mask or
                                                      (begin_mask == 1)) and
                            (not end_mask or (end_mask == 1)))
        if valid_attributes:  # additional inputs not supported
          prev = constant_value_as_shape(tensor.op.inputs[0])
          prev = prev[begin:end:strides]
          ret = tensor_shape.TensorShape(prev)
          return ret

    # ValueError could come from get_attr or slicing prev; TypeError from
    # slicing prev. Either way, fall back to the generic evaluation below.
    except (ValueError, TypeError):
      pass
  elif (tensor.op.type == "Placeholder" and
        tensor.op.graph.building_function and
        hasattr(tensor.op.graph, "internal_captures")):
    # If we are inside a FuncGraph try to lookup the constant value of the
    # corresponding external capture. Note that we only look at captures and
    # not the fed inputs because those can be fed different values in different
    # instantiations of the function call or different iterations of a
    # tf.while_loop.
    for i, capture in enumerate(tensor.op.graph.internal_captures):
      if capture is tensor:
        external_capture = tensor.op.graph.external_captures[i]
        return constant_value_as_shape(external_capture)

  # Generic fallback: a vector of known length whose entries are filled in
  # from the constant value when it can be computed.
  ret = tensor_shape.unknown_shape(shape.dims[0].value)
  value = constant_value(tensor)
  if value is not None:
    ret = ret.merge_with(
        tensor_shape.TensorShape([d if d >= 0 else None for d in value]))
  return ret
@typing.runtime_checkable
class IsTensorLike(Protocol):
  """Structural marker for objects that should be treated as TF types.

  Because this protocol is runtime-checkable, `isinstance(obj, IsTensorLike)`
  succeeds for any object that has an `is_tensor_like` attribute, whatever
  its base classes are.
  """

  def is_tensor_like(self):  # pylint: disable=invalid-name
    ...
# Classes whose instances `is_tf_type` accepts. `IsTensorLike` is a
# runtime-checkable protocol, so it matches any object that has an
# `is_tensor_like` attribute.
tf_type_classes = (internal.NativeObject, core.Tensor, IsTensorLike)


# TODO(mdan): Deprecate in favor of more static-friendly types.
@tf_export("is_tensor")
def is_tf_type(x):  # pylint: disable=invalid-name
  """Checks whether `x` is a TF-native type that can be passed to many TF ops.

  Use `is_tensor` to differentiate types that can be ingested by TensorFlow ops
  without any conversion (e.g., `tf.Tensor`, `tf.SparseTensor`, and
  `tf.RaggedTensor`) from types that need to be converted into tensors before
  they are ingested (e.g., numpy `ndarray` and Python scalars).

  For example, in the following code block:

  ```python
  if not tf.is_tensor(t):
    t = tf.convert_to_tensor(t)
  return t.shape, t.dtype
  ```

  we check to make sure that `t` is a tensor (and convert it if not) before
  accessing its `shape` and `dtype`. (But note that not all TensorFlow native
  types have shapes or dtypes; `tf.data.Dataset` is an example of a TensorFlow
  native type that has neither shape nor dtype.)

  Args:
    x: A python object to check.

  Returns:
    `True` if `x` is a TensorFlow-native type.
  """
  return isinstance(x, tf_type_classes)


# Deprecated alias for tensor_util.is_tf_type.
is_tensor = is_tf_type
def try_evaluate_constant(tensor):  # pylint: disable=invalid-name
  """Asks the C API to fold a symbolic tensor into a constant.

  Args:
    tensor: a symbolic Tensor.

  Returns:
    ndarray if the evaluation succeeds, or None if it fails.
  """
  graph = tensor.graph
  # pylint: disable=protected-access
  with graph._c_graph.get() as c_graph:
    evaluated = c_api.TF_TryEvaluateConstant_wrapper(
        c_graph, tensor._as_tf_output())
  # pylint: enable=protected-access
  return evaluated