Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py: 35%
425 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helper classes for tensor shape inference."""
16import functools
17import operator
18from typing import Optional, Sequence, Type
20from tensorflow.core.framework import tensor_shape_pb2
21from tensorflow.core.function import trace_type
22from tensorflow.core.protobuf import struct_pb2
23from tensorflow.python import tf2
24from tensorflow.python.eager import monitoring
25from tensorflow.python.platform import tf_logging as logging
26from tensorflow.python.saved_model import nested_structure_coder
27from tensorflow.python.types import trace
28from tensorflow.python.util.tf_export import tf_export
29from tensorflow.tools.docs import doc_controls
# Module-wide override for the TensorShape V2 switch. `None` means neither
# enable_v2_tensorshape() nor disable_v2_tensorshape() has been called yet, in
# which case `TensorShape._v2_behavior` falls back to `tf2.enabled()`.
_TENSORSHAPE_V2_OVERRIDE = None

# Monitoring gauge that enable_v2_tensorshape()/disable_v2_tensorshape() set to
# True/False respectively.
_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/v2_tensorshape",
    "Whether tensor_shape.enable_v2_tensorshape() is called.")
@tf_export(v1=["enable_v2_tensorshape"])
def enable_v2_tensorshape():
  """In TensorFlow 2.0, iterating over a TensorShape instance returns values.

  This enables the new behavior.

  Concretely, `tensor_shape[i]` returned a Dimension instance in V1, but
  in V2 it returns either an integer, or None.

  Examples:

  ```
  #######################
  # If you had this in V1:
  value = tensor_shape[i].value

  # Do this in V2 instead:
  value = tensor_shape[i]

  #######################
  # If you had this in V1:
  for dim in tensor_shape:
    value = dim.value
    print(value)

  # Do this in V2 instead:
  for value in tensor_shape:
    print(value)

  #######################
  # If you had this in V1:
  dim = tensor_shape[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # Do this in V2 instead:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # The V2 suggestion above is more explicit, which will save you from
  # the following trap (present in V1):
  # you might do in-place modifications to `dim` and expect them to be reflected
  # in `tensor_shape[i]`, but they would not be.
  ```
  """
  # The override is module-level state read by `TensorShape._v2_behavior`.
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  _TENSORSHAPE_V2_OVERRIDE = True
  logging.vlog(1, "Enabling v2 tensorshape")
  # Record in the monitoring gauge that the V2 behavior was switched on.
  _api_usage_gauge.get_cell().set(True)
@tf_export(v1=["disable_v2_tensorshape"])
def disable_v2_tensorshape():
  """Disables the V2 TensorShape behavior and reverts to V1 behavior.

  See docstring for `enable_v2_tensorshape` for details about the new behavior.
  """
  # The override is module-level state read by `TensorShape._v2_behavior`.
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  _TENSORSHAPE_V2_OVERRIDE = False
  logging.vlog(1, "Disabling v2 tensorshape")
  # Record in the monitoring gauge that the V2 behavior was switched off.
  _api_usage_gauge.get_cell().set(False)
@tf_export(
    "compat.dimension_value", v1=["dimension_value", "compat.dimension_value"])
def dimension_value(dimension):
  """Unwraps a possibly-`Dimension` object into a plain `int` or `None`.

  `TensorShape` indexing yields `Dimension` objects under V1 behavior and
  plain values under V2 behavior. Passing the result through this helper gives
  the same plain value in both modes:

  ```
  # V1-only code:
  value = tensor_shape[i].value

  # Works under both V1 and V2 behavior:
  value = dimension_value(tensor_shape[i])

  # V2-only code (indexing already yields the plain value):
  value = tensor_shape[i]
  ```

  Args:
    dimension: Either a `Dimension` instance, an integer, or None.

  Returns:
    `dimension.value` when given a `Dimension`; otherwise `dimension` itself,
    unchanged.
  """
  return dimension.value if isinstance(dimension, Dimension) else dimension
@tf_export(
    "compat.dimension_at_index",
    v1=["dimension_at_index", "compat.dimension_at_index"])
def dimension_at_index(shape, index):
  """Fetches the `Dimension` at `index`, tolerating shapes of unknown rank.

  This gives the same `Dimension` object regardless of whether V1 or V2
  `TensorShape` behavior is active:

  ```
  # V1-only code:
  dim = tensor_shape[i]

  # Works under both V1 and V2 behavior:
  dim = dimension_at_index(tensor_shape, i)

  # Also possible, but only when the rank is defined, because `dims` is None
  # for a shape of unknown rank:
  dim = tensor_shape.dims[i]

  # The explicit form recommended for native V2 code:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]
  ```

  Being explicit avoids a V1 trap: the `Dimension` objects handed out are
  instantiated on the fly, so in-place modifications to `dim` are not
  reflected in `tensor_shape[i]`.

  Args:
    shape: A TensorShape instance.
    index: An integer index.

  Returns:
    `Dimension(None)` when `shape` has unknown rank, otherwise
    `shape.dims[index]`.
  """
  assert isinstance(shape, TensorShape)
  if shape.rank is not None:
    return shape.dims[index]
  # Unknown rank: every position is an unknown dimension.
  return Dimension(None)
@tf_export(v1=["Dimension"])
class Dimension(object):
  """One axis length of a `TensorShape`: a non-negative integer, or unknown.

  An unknown length is stored as `None`. Arithmetic involving an unknown
  operand yields an unknown `Dimension`; comparisons involving an unknown
  operand yield `None` rather than a boolean.

  @compatibility(TF2)
  In TF2, members of a `TensorShape` object are integers. The `Dimension` class
  is not part of TF2's data model.

  Please refer to the [TensorShape section of the migration guide]
  (https://www.tensorflow.org/guide/migrate/index#tensorshape) on common code
  patterns adapting Dimension objects to a TF2 syntax.
  @end_compatibility
  """

  __slots__ = ["_value"]

  def __init__(self, value):
    """Creates a Dimension from an int, `None`, a Dimension or an index-like.

    Args:
      value: A non-negative integer, `None` (unknown), another `Dimension`
        (whose value is copied), or any object exposing `__index__`.

    Raises:
      ValueError: If the resulting value is negative.
      TypeError: If `value` is none of the accepted kinds.
    """
    if isinstance(value, int):  # Plain integers are the dominant case.
      if value < 0:
        raise ValueError("Dimension %d must be >= 0" % value)
      self._value = value
      return
    if value is None:
      self._value = None
      return
    if isinstance(value, Dimension):
      self._value = value._value
      return
    try:
      # The int(...) wrapper dates from the Python 2 int/long split.
      # TODO(b/143206389): Remove once we fully migrate to 3.X.
      converted = int(value.__index__())
    except AttributeError:
      raise TypeError(
          "Dimension value must be integer or None or have "
          "an __index__ method, got value '{0!r}' with type '{1!r}'".format(
              value, type(value))) from None
    self._value = converted
    if converted < 0:
      raise ValueError("Dimension %d must be >= 0" % converted)

  def __repr__(self):
    return "Dimension(%s)" % (repr(self._value),)

  def __str__(self):
    # Unknown dimensions are rendered as a question mark.
    if self._value is None:
      return "?"
    return str(self._value)

  # NOTE: defining __eq__ without __hash__ makes instances unhashable under
  # the Python 3 data model.
  def __eq__(self, other):
    """Returns whether `other` has the same known value as this Dimension.

    Returns `None` when either side is unknown, and `NotImplemented` when
    `other` cannot be converted by `as_dimension`.
    """
    try:
      rhs = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    lhs_value, rhs_value = self._value, rhs.value
    if lhs_value is None or rhs_value is None:
      return None
    return lhs_value == rhs_value

  def __ne__(self, other):
    """Returns whether `other` has a different known value from `self`.

    Returns `None` when either side is unknown, and `NotImplemented` when
    `other` cannot be converted by `as_dimension`.
    """
    try:
      rhs = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    lhs_value, rhs_value = self._value, rhs.value
    if lhs_value is None or rhs_value is None:
      return None
    return lhs_value != rhs_value

  def __bool__(self):
    """Equivalent to `bool(self.value)`: False for both 0 and unknown."""
    return bool(self._value)

  def __int__(self):
    # Hands back the stored value as-is, which is None for an unknown
    # dimension.
    return self._value

  # This is needed for Windows.
  # See https://github.com/tensorflow/tensorflow/pull/9780
  def __long__(self):
    return self._value

  def __index__(self):
    # Lets a Dimension be used where an integer index is expected, e.g. in
    # Python 3 `range`.
    return self._value

  @property
  def value(self):
    """The integer size of this dimension, or None if it is unknown."""
    return self._value

  # TODO(b/225058047): Reconsider semantics.
  def is_compatible_with(self, other):
    """Returns whether `other` is compatible with this Dimension.

    Two known Dimensions are compatible when their values are equal. An
    unknown Dimension is compatible with every Dimension.

    Args:
      other: Another Dimension.

    Returns:
      True if this Dimension and `other` are compatible.
    """
    theirs = as_dimension(other).value
    mine = self._value
    if mine is None or theirs is None:
      return True
    return mine == theirs

  def assert_is_compatible_with(self, other):
    """Raises an exception if `other` is not compatible with this Dimension.

    Args:
      other: Another Dimension.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    if self.is_compatible_with(other):
      return
    raise ValueError("Dimensions %s and %s are not compatible" %
                     (self, other))

  def merge_with(self, other):
    """Returns a Dimension combining what `self` and `other` know.

    ```python
    tf.compat.v1.Dimension(n).merge_with(tf.compat.v1.Dimension(n)) ==
    tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(n).merge_with(tf.compat.v1.Dimension(None)) ==
    tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(n)) ==
    tf.compat.v1.Dimension(n)
    # equivalent to tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(None))

    # raises ValueError for n != m
    tf.compat.v1.Dimension(n).merge_with(tf.compat.v1.Dimension(m))
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension containing the combined information of `self` and `other`.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    other = as_dimension(other)
    self.assert_is_compatible_with(other)
    # Prefer our own value; fall back to the other side's when ours is unknown.
    merged = other.value if self._value is None else self._value
    return Dimension(merged)

  def __add__(self, other):
    """Returns the sum of `self` and `other`.

    ```python
    tf.compat.v1.Dimension(m) + tf.compat.v1.Dimension(n) ==
    tf.compat.v1.Dimension(m + n)
    # Any unknown operand gives tf.compat.v1.Dimension(None):
    tf.compat.v1.Dimension(m) + tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`, or
      `NotImplemented` if `other` cannot be converted.
    """
    try:
      rhs = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is not None and rhs.value is not None:
      return Dimension(self._value + rhs.value)
    return Dimension(None)

  def __radd__(self, other):
    """Returns the sum of `other` and `self`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`.
    """
    return self + other

  def __sub__(self, other):
    """Returns the result of subtracting `other` from `self`.

    ```python
    tf.compat.v1.Dimension(m) - tf.compat.v1.Dimension(n) ==
    tf.compat.v1.Dimension(m - n)
    # Any unknown operand gives tf.compat.v1.Dimension(None):
    tf.compat.v1.Dimension(m) - tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `self` minus `other`, or `NotImplemented` if
      `other` cannot be converted.
    """
    try:
      rhs = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is not None and rhs.value is not None:
      return Dimension(self._value - rhs.value)
    return Dimension(None)

  def __rsub__(self, other):
    """Returns the result of subtracting `self` from `other`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `other` minus `self`.
    """
    lhs = as_dimension(other)
    if self._value is not None and lhs.value is not None:
      return Dimension(lhs.value - self._value)
    return Dimension(None)

  def __mul__(self, other):
    """Returns the product of `self` and `other`.

    ```python
    tf.compat.v1.Dimension(m) * tf.compat.v1.Dimension(n) ==
    tf.compat.v1.Dimension(m * n)
    # Any unknown operand gives tf.compat.v1.Dimension(None):
    tf.compat.v1.Dimension(m) * tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the product of `self` and `other`, or
      `NotImplemented` if `other` cannot be converted.
    """
    try:
      rhs = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is not None and rhs.value is not None:
      return Dimension(self._value * rhs.value)
    return Dimension(None)

  def __rmul__(self, other):
    """Returns the product of `other` and `self`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the product of `self` and `other`.
    """
    return self * other

  def __floordiv__(self, other):
    """Returns the quotient of `self` and `other` rounded down.

    ```python
    tf.compat.v1.Dimension(m) // tf.compat.v1.Dimension(n) ==
    tf.compat.v1.Dimension(m // n)
    # Any unknown operand gives tf.compat.v1.Dimension(None):
    tf.compat.v1.Dimension(m) // tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`,
      or `NotImplemented` if `other` cannot be converted.
    """
    try:
      rhs = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is not None and rhs.value is not None:
      return Dimension(self._value // rhs.value)
    return Dimension(None)

  def __rfloordiv__(self, other):
    """Returns the quotient of `other` and `self` rounded down.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `other` and `self`.
    """
    lhs = as_dimension(other)
    if self._value is not None and lhs.value is not None:
      return Dimension(lhs.value // self._value)
    return Dimension(None)

  def __div__(self, other):
    """DEPRECATED: Use `__floordiv__` via `x // y` instead.

    Kept only for backwards compatibility. New code should write `x // y`,
    which states clearly that the result rounds down and is forward compatible
    to Python 3.

    Args:
      other: Another `Dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    return self // other

  def __rdiv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    Exists only to replace the generic
    `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`
    with a message that points the caller at `//`.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', "
                    "please use // instead".format(type(other).__name__))

  def __truediv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    Exists only to replace the generic
    `TypeError: unsupported operand type(s) for /: 'Dimension' and 'int'`
    with a message that points the caller at `//`.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: 'Dimension' and '{}', "
                    "please use // instead".format(type(other).__name__))

  def __rtruediv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    Exists only to replace the generic
    `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`
    with a message that points the caller at `//`.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', "
                    "please use // instead".format(type(other).__name__))

  def __mod__(self, other):
    """Returns `self` modulo `other`.

    ```python
    tf.compat.v1.Dimension(m) % tf.compat.v1.Dimension(n) ==
    tf.compat.v1.Dimension(m % n)
    # Any unknown operand gives tf.compat.v1.Dimension(None):
    tf.compat.v1.Dimension(m) % tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `self` modulo `other`.
    """
    rhs = as_dimension(other)
    if self._value is not None and rhs.value is not None:
      return Dimension(self._value % rhs.value)
    return Dimension(None)

  def __rmod__(self, other):
    """Returns `other` modulo `self`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `other` modulo `self`.
    """
    # Convert first so the work is delegated to Dimension.__mod__.
    lhs = as_dimension(other)
    return lhs % self

  def __lt__(self, other):
    """Returns True if `self` is known to be less than `other`.

    ```python
    (tf.compat.v1.Dimension(m) < tf.compat.v1.Dimension(n)) == (m < n)
    (tf.compat.v1.Dimension(m) < tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(n)) == None
    (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value < other.value` if both are known, otherwise
      None.
    """
    rhs = as_dimension(other)
    if self._value is not None and rhs.value is not None:
      return self._value < rhs.value
    return None

  def __le__(self, other):
    """Returns True if `self` is known to be less than or equal to `other`.

    ```python
    (tf.compat.v1.Dimension(m) <= tf.compat.v1.Dimension(n)) == (m <= n)
    (tf.compat.v1.Dimension(m) <= tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(n)) == None
    (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value <= other.value` if both are known, otherwise
      None.
    """
    rhs = as_dimension(other)
    if self._value is not None and rhs.value is not None:
      return self._value <= rhs.value
    return None

  def __gt__(self, other):
    """Returns True if `self` is known to be greater than `other`.

    ```python
    (tf.compat.v1.Dimension(m) > tf.compat.v1.Dimension(n)) == (m > n)
    (tf.compat.v1.Dimension(m) > tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(n)) == None
    (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value > other.value` if both are known, otherwise
      None.
    """
    rhs = as_dimension(other)
    if self._value is not None and rhs.value is not None:
      return self._value > rhs.value
    return None

  def __ge__(self, other):
    """Returns True if `self` is known to be greater than or equal to `other`.

    ```python
    (tf.compat.v1.Dimension(m) >= tf.compat.v1.Dimension(n)) == (m >= n)
    (tf.compat.v1.Dimension(m) >= tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(n)) == None
    (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value >= other.value` if both are known, otherwise
      None.
    """
    rhs = as_dimension(other)
    if self._value is not None and rhs.value is not None:
      return self._value >= rhs.value
    return None

  def __reduce__(self):
    # Pickle support: rebuild by calling Dimension with the stored value.
    return Dimension, (self._value,)
def as_dimension(value):
  """Coerces `value` into a `Dimension`.

  A `Dimension` is passed through untouched (the same object is returned).
  Anything else is handed to the `Dimension` constructor, so `None` becomes an
  unknown Dimension and an integer becomes a Dimension with that value.

  Args:
    value: The value to be converted.

  Returns:
    A Dimension corresponding to the given value.
  """
  return value if isinstance(value, Dimension) else Dimension(value)
744@tf_export("TensorShape")
745class TensorShape(trace.TraceType, trace_type.Serializable):
746 """Represents the shape of a `Tensor`.
748 >>> t = tf.constant([[1,2,3],[4,5,6]])
749 >>> t.shape
750 TensorShape([2, 3])
752 `TensorShape` is the *static* shape representation of a Tensor.
753 During eager execution a Tensor always has a fully specified shape but
754 when tracing a `tf.function` it may be one of the following:
756 * *Fully-known shape:* has a known number of dimensions and a known size
757 for each dimension. e.g. `TensorShape([16, 256])`
758 * *Partially-known shape:* has a known number of dimensions, and an unknown
759 size for one or more dimension. e.g. `TensorShape([None, 256])`
760 * *Unknown shape:* has an unknown number of dimensions, and an unknown
761 size in all dimensions. e.g. `TensorShape(None)`
763 During function tracing `t.shape` will return a `TensorShape` object
764 representing the shape of Tensor as it is known during tracing.
765 This static representation will be partially defined in cases where the
766 exact shape depends on the values within the tensors. To get the
767 *dynamic* representation, please use `tf.shape(t)`
768 which will return Tensor representing the fully defined shape of `t`.
769 This way, you can express logic that manipulates the shapes of tensors by
770 building other tensors that depend on the dynamic shape of `t`.
772 Note: `tf.RaggedTensor.shape` also returns a `tf.TensorShape`,
773 the lengths of any ragged dimensions are unknown (`None`).
775 For example, this function prints the `TensorShape` (`t.shape`), when you
776 trace the function, and returns a tensor `tf.shape(t)` for given input `t`:
778 >>> @tf.function
779 ... def get_dynamic_shape(t):
780 ... print("tracing...")
781 ... print(f"static shape is {t.shape}")
782 ... return tf.shape(t)
784 Just calling the function traces it with a fully-specified static shape:
786 >>> result = get_dynamic_shape(tf.constant([[1, 1, 1], [0, 0, 0]]))
787 tracing...
788 static shape is (2, 3)
789 >>> result.numpy()
790 array([2, 3], dtype=int32)
792 But `tf.function` can also trace the function with a partially specified
793 (or even unspecified) shape:
795 >>> cf1 = get_dynamic_shape.get_concrete_function(tf.TensorSpec(
796 ... shape=[None, 2]))
797 tracing...
798 static shape is (None, 2)
799 >>> cf1(tf.constant([[1., 0],[1, 0],[1, 0]])).numpy()
800 array([3, 2], dtype=int32)
802 >>> cf2 = get_dynamic_shape.get_concrete_function(tf.TensorSpec(shape=None))
803 tracing...
804 static shape is <unknown>
805 >>> cf2(tf.constant([[[[[1., 0]]]]])).numpy()
806 array([1, 1, 1, 1, 2], dtype=int32)
808 If a tensor is produced by an operation of type `"Foo"`, its shape
809 may be inferred if there is a registered shape function for
810 `"Foo"`. See [Shape
811 functions](https://www.tensorflow.org/guide/create_op#shape_functions_in_c)
812 for details of shape functions and how to register them. Alternatively,
813 you may set the shape explicitly using `tf.Tensor.ensure_shape`.
814 """
815 __slots__ = ["_dims"]
817 def __init__(self, dims):
818 """Creates a new TensorShape with the given dimensions.
820 Args:
821 dims: A list of Dimensions, or None if the shape is unspecified.
823 Raises:
824 TypeError: If dims cannot be converted to a list of dimensions.
825 """
826 if isinstance(dims, (tuple, list)): # Most common case.
827 self._dims = tuple(as_dimension(d).value for d in dims)
828 elif dims is None:
829 self._dims = None
830 elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
831 if dims.unknown_rank:
832 self._dims = None
833 else:
834 self._dims = tuple(
835 # Protos store variable-size dimensions as -1
836 dim.size if dim.size != -1 else None
837 for dim in dims.dim
838 )
839 elif isinstance(dims, TensorShape):
840 self._dims = dims._dims
841 else:
842 try:
843 dims_iter = iter(dims)
844 except TypeError:
845 # Treat as a singleton dimension
846 self._dims = (as_dimension(dims).value,)
847 else:
848 self._dims = []
849 for d in dims_iter:
850 try:
851 self._dims.append(as_dimension(d).value)
852 except TypeError as e:
853 raise TypeError(
854 "Failed to convert '{0!r}' to a shape: '{1!r}'"
855 "could not be converted to a dimension. A shape should "
856 "either be single dimension (e.g. 10), or an iterable of "
857 "dimensions (e.g. [1, 10, None]).".format(dims, d)) from e
858 self._dims = tuple(self._dims)
860 @property
861 def _v2_behavior(self):
862 if _TENSORSHAPE_V2_OVERRIDE is None:
863 return tf2.enabled()
864 return _TENSORSHAPE_V2_OVERRIDE
866 def __repr__(self):
867 if self._v2_behavior:
868 if self._dims is not None:
869 return f"TensorShape({list(self._dims)})"
870 else:
871 return "TensorShape(None)"
872 else:
873 return f"TensorShape({self.dims})"
875 def __str__(self):
876 if self.rank is None:
877 return "<unknown>"
878 elif self.rank == 1:
879 if self._v2_behavior:
880 return "(%s,)" % self._dims[0]
881 else:
882 return "(%s,)" % self.dims[0]
883 else:
884 if self._v2_behavior:
885 return "(%s)" % ", ".join(str(d) for d in self._dims)
886 else:
887 return "(%s)" % ", ".join(str(d) for d in self.dims)
889 @property
890 def rank(self):
891 """Returns the rank of this shape, or None if it is unspecified."""
892 if self._dims is not None:
893 return len(self._dims)
894 return None
896 @property
897 def dims(self):
898 """Deprecated. Returns list of dimensions for this shape.
900 Suggest `TensorShape.as_list` instead.
902 Returns:
903 A list containing `tf.compat.v1.Dimension`s, or None if the shape is
904 unspecified.
905 """
906 if self._dims is None:
907 return None
908 return [as_dimension(d) for d in self._dims]
  @property
  def ndims(self):
    """Deprecated accessor for `rank`.

    Returns:
      Whatever `self.rank` returns: the number of dimensions, or None if the
      rank is unspecified.
    """
    return self.rank
915 def __len__(self):
916 """Returns the rank of this shape, or raises ValueError if unspecified."""
917 if self._dims is None:
918 raise ValueError("Cannot take the length of shape with unknown rank.")
919 return len(self._dims)
921 def __bool__(self):
922 """Returns True if this shape contains non-zero information."""
923 return self._dims is not None
925 # Python 3 wants __bool__, Python 2.7 wants __nonzero__
926 __nonzero__ = __bool__
928 def __iter__(self):
929 """Returns `self.dims` if the rank is known, otherwise raises ValueError."""
930 if self._dims is None:
931 raise ValueError("Cannot iterate over a shape with unknown rank.")
932 else:
933 if self._v2_behavior:
934 return iter(d for d in self._dims)
935 else:
936 return iter(d for d in self.dims)
938 def __getitem__(self, key):
939 """Returns the value of a dimension or a shape, depending on the key.
941 Args:
942 key: If `key` is an integer, returns the dimension at that index;
943 otherwise if `key` is a slice, returns a TensorShape whose dimensions
944 are those selected by the slice from `self`.
946 Returns:
947 An integer if `key` is an integer, or a `TensorShape` if `key` is a
948 slice.
950 Raises:
951 ValueError: If `key` is a slice and `self` is completely unknown and
952 the step is set.
953 """
954 if self._dims is not None:
955 if isinstance(key, slice):
956 return TensorShape(self._dims[key])
957 else:
958 if self._v2_behavior:
959 return self._dims[key]
960 else:
961 return self.dims[key]
962 else:
963 if isinstance(key, slice):
964 start = key.start if key.start is not None else 0
965 stop = key.stop
967 if key.step is not None:
968 # TODO(mrry): Handle these maybe.
969 raise ValueError("Steps are not yet handled")
970 if stop is None:
971 # NOTE(mrry): This implies that TensorShape(None) is compatible with
972 # TensorShape(None)[1:], which is obviously not true. It would be
973 # possible to track the number of dimensions symbolically,
974 # and perhaps we should do that.
975 return unknown_shape()
976 elif start < 0 or stop < 0:
977 # TODO(mrry): Handle this better, as it will be useful for handling
978 # suffixes of otherwise unknown shapes.
979 return unknown_shape()
980 else:
981 return unknown_shape(rank=stop - start)
982 else:
983 if self._v2_behavior:
984 return None
985 else:
986 return Dimension(None)
988 def num_elements(self):
989 """Returns the total number of elements, or none for incomplete shapes."""
990 if self.is_fully_defined():
991 return functools.reduce(operator.mul, self.as_list(), 1)
992 else:
993 return None
def merge_with(self, other):
  """Returns a `TensorShape` combining the information in `self` and `other`.

  Corresponding dimensions of `self` and `other` are merged pairwise by
  calling `merge_with` on each dimension, according to the rules below:

  ```python
  Dimension(n).merge_with(Dimension(None)) == Dimension(n)
  Dimension(None).merge_with(Dimension(n)) == Dimension(n)
  Dimension(None).merge_with(Dimension(None)) == Dimension(None)
  # raises ValueError for n != m
  Dimension(n).merge_with(Dimension(m))
  ```
  >> ts = tf.TensorShape([1,2])
  >> ot1 = tf.TensorShape([1,2])
  >> ts.merge_with(ot1).as_list()
  [1,2]

  >> ot2 = tf.TensorShape([1,None])
  >> ts.merge_with(ot2).as_list()
  [1,2]

  >> ot3 = tf.TensorShape([None, None])
  >> ot3.merge_with(ot2).as_list()
  [1, None]

  Args:
    other: Another `TensorShape`, or a value `as_shape` can convert to one.

  Returns:
    A `TensorShape` containing the combined information of `self` and
    `other`. If either side has no dimension information at all, the other
    side is returned as-is.

  Raises:
    ValueError: If `self` and `other` are not compatible.
  """
  other = as_shape(other)
  # A shape with no dimension information contributes nothing to the merge.
  if self.dims is None:
    return other
  if other.dims is None:
    return self
  try:
    self.assert_same_rank(other)
    merged = []
    for mine, theirs in zip(self.dims, other.dims):
      merged.append(mine.merge_with(theirs))
    return TensorShape(merged)
  except ValueError:
    # Rank mismatches and per-dimension conflicts surface as one message.
    raise ValueError("Shapes %s and %s are not compatible" % (self, other))
def __add__(self, other):
  """Implements `self + other` by delegating to `self.concatenate(other)`."""
  joined = self.concatenate(other)
  return joined
def __radd__(self, other):
  """Implements `other + self` when `other` is the left-hand operand.

  `other` is wrapped in a `TensorShape` unless it already is one, and `self`
  is then concatenated onto it.
  """
  left = other if isinstance(other, TensorShape) else TensorShape(other)
  return left.concatenate(self)
def concatenate(self, other):
  """Returns the dimensions of `self` followed by the dimensions of `other`.

  *N.B.* If either `self` or `other` is completely unknown, the result is a
  completely unknown shape and the information carried by the other operand
  is discarded. In future, we might support concatenation that preserves
  this information for use with slicing.

  Args:
    other: Another `TensorShape`, or a value `as_shape` can convert to one.

  Returns:
    A `TensorShape` whose dimensions are the concatenation of the
    dimensions in `self` and `other`.
  """
  # TODO(mrry): Handle the case where we concatenate a known shape with a
  # completely unknown shape, so that we can use the partial information.
  other = as_shape(other)
  if self.dims is not None and other.dims is not None:
    return TensorShape(self.dims + other.dims)
  return unknown_shape()
def assert_same_rank(self, other):
  """Raises an exception if `self` and `other` have different known ranks.

  The check passes silently when either rank is unknown (`None`).

  Args:
    other: Another `TensorShape`, or a value `as_shape` can convert to one.

  Raises:
    ValueError: If `self` and `other` both have a known rank and the ranks
      differ.
  """
  other = as_shape(other)
  both_known = self.rank is not None and other.rank is not None
  if both_known and self.rank != other.rank:
    raise ValueError("Shapes %s and %s must have the same rank" %
                     (self, other))
def assert_has_rank(self, rank):
  """Raises an exception if `self` is not compatible with the given `rank`.

  A shape whose rank is unknown (`None`) is accepted for any `rank`.

  Args:
    rank: An integer.

  Raises:
    ValueError: If `self` has a known rank different from `rank`.
  """
  if self.rank in (None, rank):
    return
  raise ValueError("Shape %s must have rank %d" % (self, rank))
def with_rank(self, rank):
  """Returns a shape based on `self` with the given rank.

  Implemented by merging `self` with an unknown shape of rank `rank`, which
  promotes a completely unknown shape to one with a known rank.

  Args:
    rank: An integer.

  Returns:
    A shape that is at least as specific as `self` with the given rank.

  Raises:
    ValueError: If `self` does not represent a shape with the given `rank`.
  """
  try:
    promoted = self.merge_with(unknown_shape(rank=rank))
  except ValueError:
    # Replace the merge error with one that names the required rank.
    raise ValueError("Shape %s must have rank %d" % (self, rank))
  return promoted
def with_rank_at_least(self, rank):
  """Returns `self` after checking that its rank is at least `rank`.

  A shape whose rank is unknown (`None`) passes the check.

  Args:
    rank: An integer.

  Returns:
    `self`, unchanged.

  Raises:
    ValueError: If `self` has a known rank smaller than `rank`.
  """
  known_rank = self.rank
  too_shallow = known_rank is not None and known_rank < rank
  if too_shallow:
    raise ValueError("Shape %s must have rank at least %d" % (self, rank))
  return self
def with_rank_at_most(self, rank):
  """Returns `self` after checking that its rank is at most `rank`.

  A shape whose rank is unknown (`None`) passes the check.

  Args:
    rank: An integer.

  Returns:
    `self`, unchanged.

  Raises:
    ValueError: If `self` has a known rank larger than `rank`.
  """
  known_rank = self.rank
  too_deep = known_rank is not None and known_rank > rank
  if too_deep:
    raise ValueError("Shape %s must have rank at most %d" % (self, rank))
  return self
def is_subtype_of(self, other: trace.TraceType) -> bool:
  """Returns True iff `self` is subtype of `other`.

  Shape A is a subtype of shape B if shape B can successfully represent it:

  * A `TensorShape` of any rank is a subtype of `TensorShape(None)`.

  * TensorShapes of equal ranks are covariant, i.e.
    `TensorShape([A1, A2, ..])` is a subtype of
    `TensorShape([B1, B2, ..])` iff An is a subtype of Bn.

    An is subtype of Bn iff An == Bn or Bn is None.

  * TensorShapes of different defined ranks have no subtyping relation.

  The subtyping relation is reflexive and transitive, but not symmetric.

  Some examples:
  * `TensorShape([32, 784])` is a subtype of `TensorShape(None)`, and
    `TensorShape([4, 4])` is also a subtype of `TensorShape(None)` but
    `TensorShape([32, 784])` and `TensorShape([4, 4])` are not subtypes of
    each other.

  * All two-dimensional shapes are subtypes of `TensorShape([None, None])`,
    such as `TensorShape([32, 784])`. There is no subtype relationship with,
    for example, `TensorShape([None])` or `TensorShape([None, None, None])`.

  * `TensorShape([32, None])` is also a subtype of `TensorShape([None, None])`
    and `TensorShape(None)`. It is not a subtype of, for example,
    `TensorShape([32])`, `TensorShape([32, None, 1])`,
    `TensorShape([64, None])` or `TensorShape([None, 32])`.

  * `TensorShape([32, 784])` is a subtype of itself, and also
    `TensorShape([32, None])`, `TensorShape([None, 784])`,
    `TensorShape([None, None])` and `TensorShape(None)`.
    It has no subtype relation with, for example, `TensorShape([32, 1, 784])`
    or `TensorShape([None])`.

  Args:
    other: Another `TensorShape`. Any other type yields False.

  Returns:
    True iff `self` is subtype of `other`.
  """
  if not isinstance(other, TensorShape):
    return False

  # A shape of unknown rank can represent every shape.
  if other.rank is None:
    return True

  # With a defined rank on `other`, the ranks have to agree exactly.
  if self.rank != other.rank:
    return False

  # Every dimension must either match or be left open (None) in `other`.
  for own_dim, other_dim in zip(self._dims, other._dims):  # pylint: disable=protected-access
    if other_dim is None or own_dim == other_dim:
      continue
    return False
  return True
def most_specific_common_supertype(
    self, others: Sequence[trace.TraceType]) -> Optional["TensorShape"]:
  """Returns the most specific supertype `TensorShape` of self and others.

  * `TensorShape([None, 1])` is the most specific `TensorShape` supertyping
    both `TensorShape([2, 1])` and `TensorShape([5, 1])`. Note that
    `TensorShape(None)` is also a supertype but it is not "most specific".

  * `TensorShape([1, 2, 3])` is the most specific `TensorShape` supertyping
    both `TensorShape([1, 2, 3])` and `TensorShape([1, 2, 3])`. There are
    other less specific TensorShapes that supertype above mentioned
    TensorShapes, e.g. `TensorShape([1, 2, None])`, `TensorShape(None)`.

  * `TensorShape([None, None])` is the most specific `TensorShape`
    supertyping both `TensorShape([2, None])` and `TensorShape([None, 3])`.
    As always, `TensorShape(None)` is also a supertype but not the most
    specific one.

  * `TensorShape(None)` is the only `TensorShape` supertyping both
    `TensorShape([1, 2, 3])` and `TensorShape([1, 2])`. In general, any two
    shapes that have different ranks will only have `TensorShape(None)`
    as a common supertype.

  * `TensorShape(None)` is the only `TensorShape` supertyping both
    `TensorShape([1, 2, 3])` and `TensorShape(None)`. In general, the common
    supertype of any shape with `TensorShape(None)` is `TensorShape(None)`.

  Args:
    others: Sequence of `TensorShape`.

  Returns:
    A `TensorShape` which is the most specific supertype shape of `self`
    and `others`. None if any element of `others` is not a `TensorShape`.
  """
  for candidate in others:
    if not isinstance(candidate, TensorShape):
      return None

  # A rankless shape is already a supertype of everything; hand back a new
  # instance of it.
  if self.rank is None:
    return unknown_shape()

  # Any unknown or mismatching rank leaves the rankless shape as the only
  # common supertype.
  for candidate in others:
    if candidate.dims is None or self.rank != candidate.rank:
      return unknown_shape()

  # Keep a dimension only when every other shape has the same value at that
  # position; otherwise leave it undefined.
  generalized = []
  for axis, dim in enumerate(self._dims):
    shared = all(dim == candidate._dims[axis] for candidate in others)  # pylint: disable=protected-access
    generalized.append(dim if shared else None)
  return TensorShape(generalized)
@doc_controls.do_not_doc_inheritable
def placeholder_value(self, placeholder_context):
  """Always raises: no graph placeholder exists for a `TensorShape`.

  Args:
    placeholder_context: Unused.

  Raises:
    NotImplementedError: Always.
  """
  # Adjacent string literals are joined with nothing in between, so the
  # space separating "supported" and "for" has to be written explicitly.
  raise NotImplementedError("A graph placeholder is not currently supported "
                            "for an object of type: TensorShape.")
@classmethod
def experimental_type_proto(cls) -> Type[tensor_shape_pb2.TensorShapeProto]:
  """Returns the proto class used when serializing a `TensorShape`."""
  proto_type = tensor_shape_pb2.TensorShapeProto
  return proto_type
@classmethod
def experimental_from_proto(
    cls, proto: tensor_shape_pb2.TensorShapeProto) -> "TensorShape":
  """Builds a `TensorShape` from the given serialized proto."""
  restored = TensorShape(proto)
  return restored
def experimental_as_proto(self) -> tensor_shape_pb2.TensorShapeProto:
  """Returns the proto form of this shape, as produced by `as_proto()`."""
  proto = self.as_proto()
  return proto
1300 # TODO(b/216206374): Consider deprecation at TraceType release.
def is_compatible_with(self, other):
  """Returns True iff `self` is compatible with `other`.

  Two possibly-partially-defined shapes are compatible if there
  exists a fully-defined shape that both shapes can represent. Thus,
  compatibility allows the shape inference code to reason about
  partially-defined shapes. For example:

  * TensorShape(None) is compatible with all shapes.

  * TensorShape([None, None]) is compatible with all two-dimensional
    shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is
    not compatible with, for example, TensorShape([None]) or
    TensorShape([None, None, None]).

  * TensorShape([32, None]) is compatible with all two-dimensional shapes
    with size 32 in the 0th dimension, and also TensorShape([None, None])
    and TensorShape(None). It is not compatible with, for example,
    TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]).

  * TensorShape([32, 784]) is compatible with itself, and also
    TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None,
    None]) and TensorShape(None). It is not compatible with, for example,
    TensorShape([32, 1, 784]) or TensorShape([None]).

  The compatibility relation is reflexive and symmetric, but not
  transitive. For example, TensorShape([32, 784]) is compatible with
  TensorShape(None), and TensorShape(None) is compatible with
  TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with
  TensorShape([4, 4]).

  Args:
    other: Another TensorShape, or a value `as_shape` can convert to one.

  Returns:
    True iff `self` is compatible with `other`.
  """
  other = as_shape(other)
  # A shape without dimension information is compatible with anything.
  if self.dims is None or other.dims is None:
    return True
  if self.rank != other.rank:
    return False
  return all(
      mine.is_compatible_with(theirs)
      for mine, theirs in zip(self.dims, other.dims))
def assert_is_compatible_with(self, other):
  """Raises an exception if `self` and `other` are not compatible.

  This method can be used to assert that there exists a shape that both
  `self` and `other` represent; the check itself is delegated to
  `is_compatible_with`.

  Args:
    other: Another TensorShape.

  Raises:
    ValueError: If `self.is_compatible_with(other)` is false.
  """
  if self.is_compatible_with(other):
    return
  raise ValueError("Shapes %s and %s are incompatible" % (self, other))
def most_specific_compatible_shape(self, other):
  """Returns the most specific TensorShape compatible with `self` and `other`.

  * TensorShape([None, 1]) is the most specific TensorShape compatible with
    both TensorShape([2, 1]) and TensorShape([5, 1]). Note that
    TensorShape(None) is also compatible with above mentioned TensorShapes.

  * TensorShape([1, 2, 3]) is the most specific TensorShape compatible with
    both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are more
    less specific TensorShapes compatible with above mentioned TensorShapes,
    e.g. TensorShape([1, 2, None]), TensorShape(None).

  Args:
    other: Another `TensorShape`, or a value `as_shape` can convert to one.

  Returns:
    A `TensorShape` which is the most specific compatible shape of `self`
    and `other`. A completely unknown shape is returned when either side has
    no dimension information or the ranks differ.
  """
  other = as_shape(other)
  dims_missing = self.dims is None or other.dims is None
  if dims_missing or self.rank != other.rank:
    return unknown_shape()

  # A dimension survives only where both sides agree on it.
  common = []
  for mine, theirs in zip(self.dims, other.dims):
    agree = mine is not None and theirs is not None and mine == theirs
    common.append(mine if agree else None)
  return TensorShape(common)
def is_fully_defined(self):
  """Returns True iff `self` has known dimensions and none of them is None."""
  dims = self._dims
  if dims is None:
    return False
  for dim in dims:
    if dim is None:
      return False
  return True
def assert_is_fully_defined(self):
  """Raises an exception if `self` is not fully defined in every dimension.

  Raises:
    ValueError: If `self.is_fully_defined()` is false.
  """
  if self.is_fully_defined():
    return
  raise ValueError("Shape %s is not fully defined" % self)
def as_list(self):
  """Returns the dimensions as a new list of integers or `None` entries.

  Returns:
    A list of integers or `None` for each dimension.

  Raises:
    ValueError: If `self` is an unknown shape with an unknown rank.
  """
  dims = self._dims
  if dims is not None:
    return list(dims)
  raise ValueError("as_list() is not defined on an unknown TensorShape.")
def as_proto(self):
  """Returns this shape as a `TensorShapeProto`.

  A shape without dimension information becomes a proto with
  `unknown_rank=True`. Otherwise each dimension becomes a `Dim` entry, with
  size -1 standing in for an unknown (`None`) dimension.
  """
  if self._dims is None:
    return tensor_shape_pb2.TensorShapeProto(unknown_rank=True)

  proto_dims = []
  for d in self._dims:
    size = -1 if d is None else d
    proto_dims.append(tensor_shape_pb2.TensorShapeProto.Dim(size=size))
  return tensor_shape_pb2.TensorShapeProto(dim=proto_dims)
def __eq__(self, other):
  """Returns True if `self` is equivalent to `other`.

  It first tries to convert `other` to `TensorShape`. If that conversion
  raises `TypeError`, `NotImplemented` is returned. Otherwise the underlying
  dimension values of both shapes are compared.

  * Two *Fully known* shapes, return True iff each element is equal.
  >>> t_a = tf.TensorShape([1,2])
  >>> a = [1, 2]
  >>> t_b = tf.TensorShape([1,2])
  >>> t_c = tf.TensorShape([1,2,3])
  >>> t_a.__eq__(a)
  True
  >>> t_a.__eq__(t_b)
  True
  >>> t_a.__eq__(t_c)
  False

  * Two *Partially-known* shapes, return True iff each element is equal.
  >>> p_a = tf.TensorShape([1,None])
  >>> p_b = tf.TensorShape([1,None])
  >>> p_c = tf.TensorShape([2,None])
  >>> p_a.__eq__(p_b)
  True
  >>> t_a.__eq__(p_a)
  False
  >>> p_a.__eq__(p_c)
  False

  * Two *Unknown shape*, return True.
  >>> unk_a = tf.TensorShape(None)
  >>> unk_b = tf.TensorShape(None)
  >>> unk_a.__eq__(unk_b)
  True
  >>> unk_a.__eq__(t_a)
  False

  Args:
    other: A `TensorShape` or type that can be converted to `TensorShape`.

  Returns:
    True if the dimensions are all equal, False if they are not, or
    `NotImplemented` if `other` can not be converted to `TensorShape`.
  """
  try:
    other = as_shape(other)
  except TypeError:
    # Let Python try the reflected comparison instead of failing here.
    return NotImplemented
  else:
    return self._dims == other._dims  # pylint: disable=protected-access
def __hash__(self):
  """Returns the hash of the underlying `_dims` value."""
  dims = self._dims
  return hash(dims)
def __reduce__(self):
  """Returns `(TensorShape, (self.dims,))` as the reconstruction recipe."""
  ctor_args = (self.dims,)
  return TensorShape, ctor_args
def __concat__(self, other):
  """Implements the `__concat__` hook by delegating to `concatenate`."""
  joined = self.concatenate(other)
  return joined
# Register `TensorShape` with `trace_type` as a serializable type.
# NOTE(review): presumably this relies on the `experimental_*_proto` methods
# defined on the class -- confirm against `trace_type.register_serializable`.
trace_type.register_serializable(TensorShape)
class _TensorShapeCodec:
  """Codec converting `TensorShape` to and from `StructuredValue` protos."""

  def can_encode(self, pyobj):
    """Returns True iff `pyobj` is a `TensorShape`."""
    is_shape = isinstance(pyobj, TensorShape)
    return is_shape

  def do_encode(self, tensor_shape_value, encode_fn):
    """Returns a `StructuredValue` holding the proto of `tensor_shape_value`.

    `encode_fn` is not used.
    """
    del encode_fn
    shape_proto = tensor_shape_value.as_proto()
    wrapper = struct_pb2.StructuredValue()
    wrapper.tensor_shape_value.CopyFrom(shape_proto)
    return wrapper

  def can_decode(self, value):
    """Returns whether `value` has its `tensor_shape_value` field set."""
    has_shape = value.HasField("tensor_shape_value")
    return has_shape

  def do_decode(self, value, decode_fn):
    """Returns a `TensorShape` built from `value.tensor_shape_value`.

    `decode_fn` is not used.
    """
    del decode_fn
    shape_proto = value.tensor_shape_value
    return TensorShape(shape_proto)
# Make the codec above available to `nested_structure_coder` at import time.
nested_structure_coder.register_codec(_TensorShapeCodec())
def as_shape(shape):
  """Converts the given object to a TensorShape.

  Args:
    shape: A `TensorShape`, or any value accepted by the `TensorShape`
      constructor.

  Returns:
    `shape` itself when it is already a `TensorShape`, otherwise a new
    `TensorShape` built from it.
  """
  return shape if isinstance(shape, TensorShape) else TensorShape(shape)
def unknown_shape(rank=None, **kwargs):
  """Returns an unknown TensorShape, optionally with a known rank.

  Args:
    rank: (Optional) If specified, the number of dimensions in the shape.
    **kwargs: For backwards compatibility. Only `ndims` is recognized, and
      it is used as the rank when `rank` itself is None.

  Returns:
    An unknown TensorShape: rankless when no rank is given, otherwise one
    with `rank` unknown dimensions.

  Raises:
    TypeError: In case of invalid arguments, i.e. any keyword left over
      after `ndims` has been consumed.
  """
  if rank is None and "ndims" in kwargs:
    rank = kwargs.pop("ndims")
  # Anything still in kwargs at this point is not a supported argument.
  if kwargs:
    raise TypeError("Unknown argument: %s" % kwargs)
  dims = None if rank is None else [Dimension(None)] * rank
  return TensorShape(dims)