Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorboard/compat/tensorflow_stub/tensor_shape.py: 26%
317 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Helper classes for tensor shape inference."""
17from . import compat, dtypes
18from tensorboard.compat.proto import tensor_shape_pb2
# @tf_export("Dimension")
class Dimension:
    """Represents the value of one dimension in a TensorShape.

    A Dimension is either *known*, wrapping a non-negative ``int``, or
    *unknown*, wrapping ``None``.

    Arithmetic with an unknown operand produces an unknown Dimension.
    Ordering comparisons with an unknown operand produce ``None``.
    """

    def __init__(self, value):
        """Creates a new Dimension with the given value.

        Args:
          value: ``None`` for an unknown dimension. Otherwise, anything that
            ``int()`` converts to a non-negative integer. Non-text values must
            convert losslessly; str/bytes are parsed by ``int()`` and skip
            that check.

        Raises:
          TypeError: If `value` is a ``dtypes.DType``. ``int()`` itself may
            also raise TypeError or ValueError for unconvertible input.
          ValueError: If a non-text value changes when converted to ``int``
            (e.g. 2.5), or if the resulting integer is negative.
        """
        if value is None:
            self._value = None
            return
        if isinstance(value, dtypes.DType):
            raise TypeError("Cannot convert %s to Dimension" % value)
        self._value = int(value)
        # A parsed string never compares equal to its int, so the lossless
        # check applies only to non-text input.
        is_text = isinstance(value, compat.bytes_or_text_types)
        if not is_text and self._value != value:
            raise ValueError("Ambiguous dimension: %s" % value)
        if self._value < 0:
            raise ValueError("Dimension %d must be >= 0" % self._value)

    def _combine(self, other, op):
        """Applies a binary integer function to two Dimension values.

        Args:
          other: A Dimension. The caller has already converted it.
          op: Callable ``op(self_value, other_value) -> int``.

        Returns:
          ``Dimension(None)`` if either side is unknown, otherwise
          ``Dimension(op(self.value, other.value))``.
        """
        if self._value is None or other.value is None:
            return Dimension(None)
        return Dimension(op(self._value, other.value))

    def _compare(self, other, op):
        """Compares two known values with `op`.

        `other` is converted with `as_dimension`; conversion errors propagate.
        Returns ``None`` if either value is unknown.
        """
        other = as_dimension(other)
        if self._value is None or other.value is None:
            return None
        return op(self._value, other.value)

    def __repr__(self):
        return "Dimension(%s)" % repr(self._value)

    def __str__(self):
        # Unknown dimensions are rendered as a question mark.
        if self._value is None:
            return "?"
        return str(self._value)

    def __eq__(self, other):
        """Returns True if `other` has the same known value as this Dimension.

        Returns ``None`` (not ``False``) if either value is unknown. Returns
        ``NotImplemented`` if `other` cannot be converted to a Dimension.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is None or other.value is None:
            return None
        return self._value == other.value

    def __ne__(self, other):
        """Returns True if `other` has a different known value from `self`.

        Returns ``None`` if either value is unknown. Returns
        ``NotImplemented`` if `other` cannot be converted to a Dimension.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is None or other.value is None:
            return None
        return self._value != other.value

    def __int__(self):
        # For an unknown dimension this returns None, which makes int() raise.
        return self._value

    # This is needed for Windows.
    # See https://github.com/tensorflow/tensorflow/pull/9780
    __long__ = __int__

    # Allow use in Python 3 range() and anywhere an index is expected.
    __index__ = __int__

    @property
    def value(self):
        """The value of this dimension, or None if it is unknown."""
        return self._value

    def is_convertible_with(self, other):
        """Returns true if `other` is convertible with this Dimension.

        Two known Dimensions are convertible when their values match. An
        unknown Dimension is convertible with every Dimension.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          True if this Dimension and `other` are convertible.
        """
        other = as_dimension(other)
        if self._value is None or other.value is None:
            return True
        return self._value == other.value

    def assert_is_convertible_with(self, other):
        """Raises ValueError unless `other` is convertible with this Dimension.

        Args:
          other: Another Dimension.

        Raises:
          ValueError: If `self` and `other` are not convertible (see
            is_convertible_with).
        """
        if self.is_convertible_with(other):
            return
        raise ValueError(
            "Dimensions %s and %s are not convertible" % (self, other)
        )

    def merge_with(self, other):
        """Returns a Dimension combining the information in `self` and `other`.

        The known value wins if only one side is known. Two unknowns give an
        unknown Dimension. Two different known values raise an error:

        ```python
        tf.Dimension(n)   .merge_with(tf.Dimension(n))    == tf.Dimension(n)
        tf.Dimension(n)   .merge_with(tf.Dimension(None)) == tf.Dimension(n)
        tf.Dimension(None).merge_with(tf.Dimension(n))    == tf.Dimension(n)
        tf.Dimension(None).merge_with(tf.Dimension(None)) == tf.Dimension(None)
        tf.Dimension(n)   .merge_with(tf.Dimension(m))  # ValueError, n != m
        ```

        Args:
          other: Another Dimension.

        Returns:
          A new Dimension containing the combined information.

        Raises:
          ValueError: If `self` and `other` are not convertible.
        """
        other = as_dimension(other)
        self.assert_is_convertible_with(other)
        known = other.value if self._value is None else self._value
        return Dimension(known)

    def __add__(self, other):
        """Returns `self + other`, or an unknown Dimension if either is unknown.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.
        """
        other = as_dimension(other)
        return self._combine(other, lambda a, b: a + b)

    def __radd__(self, other):
        """Returns `other + self` (addition is commutative)."""
        return self + other

    def __sub__(self, other):
        """Returns `self - other`, or an unknown Dimension if either is unknown.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.
        """
        other = as_dimension(other)
        return self._combine(other, lambda a, b: a - b)

    def __rsub__(self, other):
        """Returns `other - self`, or an unknown Dimension if either is unknown.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.
        """
        other = as_dimension(other)
        return self._combine(other, lambda a, b: b - a)

    def __mul__(self, other):
        """Returns `self * other`, or an unknown Dimension if either is unknown.

        Returns ``NotImplemented`` if `other` cannot be converted, so Python
        can try the other operand's reflected method.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        return self._combine(other, lambda a, b: a * b)

    def __rmul__(self, other):
        """Returns `other * self` (multiplication is commutative)."""
        return self * other

    def __floordiv__(self, other):
        """Returns `self // other`, or an unknown Dimension if either is unknown.

        Returns ``NotImplemented`` if `other` cannot be converted.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        return self._combine(other, lambda a, b: a // b)

    def __rfloordiv__(self, other):
        """Returns `other // self`, or an unknown Dimension if either is unknown.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.
        """
        other = as_dimension(other)
        return self._combine(other, lambda a, b: b // a)

    def __div__(self, other):
        """DEPRECATED: Python 2 division hook; use `x // y` instead.

        Kept only for backwards compatibility; it delegates to floor division.
        """
        return self // other

    def __mod__(self, other):
        """Returns `self % other`, or an unknown Dimension if either is unknown.

        Returns ``NotImplemented`` if `other` cannot be converted.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        return self._combine(other, lambda a, b: a % b)

    def __rmod__(self, other):
        """Returns `other % self`, delegating to the converted Dimension's `%`.

        Returns ``NotImplemented`` if `other` cannot be converted.
        """
        try:
            other = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        return other % self

    def __lt__(self, other):
        """Returns `self.value < other.value`, or None if either is unknown."""
        return self._compare(other, lambda a, b: a < b)

    def __le__(self, other):
        """Returns `self.value <= other.value`, or None if either is unknown."""
        return self._compare(other, lambda a, b: a <= b)

    def __gt__(self, other):
        """Returns `self.value > other.value`, or None if either is unknown."""
        return self._compare(other, lambda a, b: a > b)

    def __ge__(self, other):
        """Returns `self.value >= other.value`, or None if either is unknown."""
        return self._compare(other, lambda a, b: a >= b)

    def __reduce__(self):
        # Pickle support: rebuild the Dimension from its raw value.
        return Dimension, (self._value,)
def as_dimension(value):
    """Converts the given value to a Dimension.

    Args:
      value: A Dimension, which is returned as-is (same object), or any value
        accepted by the `Dimension` constructor. ``None`` gives an unknown
        Dimension and an integer gives a known one.

    Returns:
      A Dimension corresponding to `value`.
    """
    return value if isinstance(value, Dimension) else Dimension(value)
# @tf_export("TensorShape")
class TensorShape:
    """Represents the shape of a `Tensor`.

    A `TensorShape` is a possibly-partial shape specification. It is one of:

    * *Fully-known:* known rank and a known size for every dimension, e.g.
      `TensorShape([16, 256])`.
    * *Partially-known:* known rank, but at least one dimension of unknown
      size, e.g. `TensorShape([None, 256])`.
    * *Unknown:* unknown rank, and therefore unknown sizes, e.g.
      `TensorShape(None)`.

    A known-rank shape stores its dimensions as a list of `Dimension`
    objects. An unknown-rank shape stores `None`.
    """

    def __init__(self, dims):
        """Creates a new TensorShape with the given dimensions.

        Args:
          dims: One of:
            * ``None`` for an unknown rank;
            * a ``TensorShapeProto``;
            * another ``TensorShape``, whose dimension list is shared, not
              copied;
            * an iterable of values accepted by ``as_dimension``;
            * DEPRECATED: a single non-iterable value, treated as a
              one-element list.

        Raises:
          TypeError: If `dims` is a string or bytes object.
          TypeError, ValueError: From ``as_dimension`` if an element cannot
            be converted to a Dimension.
        """
        # TODO(irving): Eliminate the single integer special case.
        if dims is None:
            self._dims = None
        elif isinstance(dims, compat.bytes_or_text_types):
            raise TypeError(
                "A string has ambiguous TensorShape, please wrap in a "
                "list or convert to an int: %s" % dims
            )
        elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
            if dims.unknown_rank:
                self._dims = None
            else:
                # Protos encode variable-size dimensions as -1.
                self._dims = [
                    as_dimension(None if d.size == -1 else d.size)
                    for d in dims.dim
                ]
        elif isinstance(dims, TensorShape):
            self._dims = dims.dims
        else:
            try:
                it = iter(dims)
            except TypeError:
                # Not iterable: treat it as a singleton dimension.
                self._dims = [as_dimension(dims)]
            else:
                self._dims = [as_dimension(d) for d in it]
        # Rank cache, filled lazily by the `ndims` property.
        self._ndims = None

    def __repr__(self):
        return "TensorShape(%r)" % self._dims

    def __str__(self):
        rank = self.ndims
        if rank is None:
            return "<unknown>"
        if rank == 1:
            # Trailing comma, like a one-element Python tuple.
            return "(%s,)" % self._dims[0]
        return "(%s)" % ", ".join(map(str, self._dims))

    @property
    def dims(self):
        """The list of Dimensions, or None if the rank is unknown."""
        return self._dims

    @dims.setter
    def dims(self, dims):
        self._dims = dims
        # Invalidate the cached rank.
        self._ndims = None

    @property
    def ndims(self):
        """The rank of this shape, or None if it is unknown."""
        if self._dims is None:
            return None
        if self._ndims is None:
            self._ndims = len(self._dims)
        return self._ndims

    def __len__(self):
        """Returns the rank of this shape.

        Raises:
          ValueError: If the rank is unknown.
        """
        if self._dims is None:
            raise ValueError(
                "Cannot take the length of Shape with unknown rank."
            )
        return self.ndims

    def __bool__(self):
        """Returns True if the rank is known; a rank-0 shape is still truthy."""
        return self._dims is not None

    # Python 3 wants __bool__, Python 2.7 wants __nonzero__
    __nonzero__ = __bool__

    def __iter__(self):
        """Iterates over the Dimensions.

        Raises:
          ValueError: If the rank is unknown.
        """
        if self._dims is None:
            raise ValueError("Cannot iterate over a shape with unknown rank.")
        return iter(self._dims)

    def __getitem__(self, key):
        """Returns a dimension or a sub-shape, depending on `key`.

        Args:
          key: An integer index, which returns a single Dimension, or a slice,
            which returns a TensorShape of the selected dimensions.

        Returns:
          If the rank is known, `key` is applied to the dimension list, and a
          slice result is wrapped in a TensorShape.

          If the rank is unknown:
            * an integer key gives `Dimension(None)`;
            * a slice with a known, non-negative `stop` (and non-negative
              `start`) gives an unknown shape of rank `stop - start`;
            * any other slice gives a fully unknown shape.

        Raises:
          ValueError: If the rank is unknown and the slice has a step.
        """
        if self._dims is not None:
            if isinstance(key, slice):
                return TensorShape(self._dims[key])
            return self._dims[key]

        if not isinstance(key, slice):
            return Dimension(None)

        start = 0 if key.start is None else key.start
        stop = key.stop
        if key.step is not None:
            # TODO(mrry): Handle these maybe.
            raise ValueError("Steps are not yet handled")
        if stop is None or start < 0 or stop < 0:
            # NOTE(mrry): Without a known rank, open-ended or negative bounds
            # cannot be resolved, so fall back to a fully unknown shape. This
            # makes TensorShape(None) convertible with TensorShape(None)[1:],
            # which is not strictly true.
            return unknown_shape()
        return unknown_shape(ndims=stop - start)

    def num_elements(self):
        """Returns the product of all dimension sizes.

        Returns None if the shape is not fully defined.
        """
        if not self.is_fully_defined():
            return None
        total = 1
        for d in self._dims:
            total *= d.value
        return total

    def merge_with(self, other):
        """Returns a TensorShape combining the information in `self` and `other`.

        Dimensions are merged pairwise with `Dimension.merge_with()`. If
        `self` has unknown rank, the converted `other` is returned as-is.

        Args:
          other: Another `TensorShape`, or a value accepted by `as_shape`.

        Returns:
          A `TensorShape` holding the combined information.

        Raises:
          ValueError: If the ranks differ or any dimension pair conflicts.
        """
        other = as_shape(other)
        if self._dims is None:
            return other
        try:
            self.assert_same_rank(other)
            # Index `other` instead of iterating it: indexing an unknown-rank
            # shape yields Dimension(None) rather than raising.
            merged = [d.merge_with(other[i]) for i, d in enumerate(self._dims)]
            return TensorShape(merged)
        except ValueError:
            raise ValueError(
                "Shapes %s and %s are not convertible" % (self, other)
            )

    def concatenate(self, other):
        """Returns the concatenation of the dimensions in `self` and `other`.

        *N.B.* If either shape has unknown rank, the result is a fully unknown
        shape, and the other shape's information is discarded.

        Args:
          other: Another `TensorShape`, or a value accepted by `as_shape`.

        Returns:
          A `TensorShape` whose dimensions are `self`'s followed by `other`'s.
        """
        # TODO(mrry): Handle the case where we concatenate a known shape with a
        # completely unknown shape, so that we can use the partial information.
        other = as_shape(other)
        if self._dims is not None and other.dims is not None:
            return TensorShape(self._dims + other.dims)
        return unknown_shape()

    def assert_same_rank(self, other):
        """Raises ValueError if both ranks are known and they differ.

        Args:
          other: Another `TensorShape`, or a value accepted by `as_shape`.

        Raises:
          ValueError: If `self` and `other` have known, different ranks.
        """
        other = as_shape(other)
        mine, theirs = self.ndims, other.ndims
        if mine is None or theirs is None:
            return
        if mine != theirs:
            raise ValueError(
                "Shapes %s and %s must have the same rank" % (self, other)
            )

    def assert_has_rank(self, rank):
        """Raises ValueError if the rank is known and differs from `rank`.

        Args:
          rank: An integer.

        Raises:
          ValueError: If `self` has a known rank other than `rank`.
        """
        current = self.ndims
        if current is not None and current != rank:
            raise ValueError("Shape %s must have rank %d" % (self, rank))

    def with_rank(self, rank):
        """Returns a shape based on `self` with exactly the given rank.

        An unknown-rank shape is promoted to one of rank `rank` with unknown
        dimensions.

        Args:
          rank: An integer.

        Returns:
          A shape at least as specific as `self`, with rank `rank`.

        Raises:
          ValueError: If `self` has a known rank other than `rank`.
        """
        try:
            return self.merge_with(unknown_shape(ndims=rank))
        except ValueError:
            raise ValueError("Shape %s must have rank %d" % (self, rank))

    def with_rank_at_least(self, rank):
        """Returns `self` if its rank is unknown or at least `rank`.

        Args:
          rank: An integer.

        Raises:
          ValueError: If `self` has a known rank smaller than `rank`.
        """
        if self.ndims is None or self.ndims >= rank:
            return self
        raise ValueError(
            "Shape %s must have rank at least %d" % (self, rank)
        )

    def with_rank_at_most(self, rank):
        """Returns `self` if its rank is unknown or at most `rank`.

        Args:
          rank: An integer.

        Raises:
          ValueError: If `self` has a known rank larger than `rank`.
        """
        if self.ndims is None or self.ndims <= rank:
            return self
        raise ValueError(
            "Shape %s must have rank at most %d" % (self, rank)
        )

    def is_convertible_with(self, other):
        """Returns True iff `self` is convertible with `other`.

        Two possibly-partial shapes are convertible if some fully-defined
        shape could be represented by both. Examples:

        * TensorShape(None) is convertible with every shape.
        * TensorShape([None, None]) is convertible with every rank-2 shape and
          with TensorShape(None). It is not convertible with TensorShape([None])
          or TensorShape([None, None, None]).
        * TensorShape([32, None]) is convertible with rank-2 shapes whose
          dimension 0 is 32 or unknown, and with TensorShape(None). It is not
          convertible with TensorShape([32]) or TensorShape([64, None]).

        The relation is reflexive and symmetric, but not transitive.
        TensorShape([32, 784]) and TensorShape([4, 4]) are each convertible
        with TensorShape(None), but not with each other.

        Args:
          other: Another `TensorShape`, or a value accepted by `as_shape`.

        Returns:
          True iff `self` is convertible with `other`.
        """
        other = as_shape(other)
        if self._dims is None or other.dims is None:
            return True
        if self.ndims != other.ndims:
            return False
        return all(
            mine.is_convertible_with(theirs)
            for mine, theirs in zip(self._dims, other.dims)
        )

    def assert_is_convertible_with(self, other):
        """Raises ValueError unless `self` and `other` are convertible.

        Use this to assert that some shape exists that both `self` and
        `other` represent.

        Args:
          other: Another `TensorShape`.

        Raises:
          ValueError: If `self` and `other` are not convertible.
        """
        if self.is_convertible_with(other):
            return
        raise ValueError(
            "Shapes %s and %s are inconvertible" % (self, other)
        )

    def most_specific_convertible_shape(self, other):
        """Returns the most specific TensorShape convertible with both inputs.

        A dimension keeps its value only where both shapes agree on it;
        every other position becomes unknown. For example, the result for
        TensorShape([2, 1]) and TensorShape([5, 1]) is TensorShape([None, 1]).

        Args:
          other: Another `TensorShape`, or a value accepted by `as_shape`.

        Returns:
          The shared shape described above. It is a fully unknown shape if
          either rank is unknown or the ranks differ.
        """
        other = as_shape(other)
        if self._dims is None or other.dims is None or self.ndims != other.ndims:
            return unknown_shape()

        # One shared unknown Dimension fills every non-matching position.
        unknown = Dimension(None)
        merged = []
        for mine, theirs in zip(self._dims, other.dims):
            # The `is not None` guards matter because the `dims` setter can
            # store raw values, including None, instead of Dimensions.
            agree = mine is not None and theirs is not None and mine == theirs
            merged.append(mine if agree else unknown)
        return TensorShape(merged)

    def is_fully_defined(self):
        """Returns True iff the rank and every dimension size are known."""
        if self._dims is None:
            return False
        return all(d.value is not None for d in self._dims)

    def assert_is_fully_defined(self):
        """Raises ValueError unless every dimension has a known value.

        Raises:
          ValueError: If `self` is not fully defined.
        """
        if self.is_fully_defined():
            return
        raise ValueError("Shape %s is not fully defined" % self)

    def as_list(self):
        """Returns a list with one int, or None if unknown, per dimension.

        Raises:
          ValueError: If the rank is unknown.
        """
        if self._dims is None:
            raise ValueError(
                "as_list() is not defined on an unknown TensorShape."
            )
        return [d.value for d in self._dims]

    def as_proto(self):
        """Returns this shape as a `TensorShapeProto`.

        Unknown dimension sizes are encoded as -1. An unknown rank sets
        `unknown_rank=True`.
        """
        if self._dims is None:
            return tensor_shape_pb2.TensorShapeProto(unknown_rank=True)
        proto_dims = [
            tensor_shape_pb2.TensorShapeProto.Dim(
                size=-1 if d.value is None else d.value
            )
            for d in self._dims
        ]
        return tensor_shape_pb2.TensorShapeProto(dim=proto_dims)

    def __eq__(self, other):
        """Returns True if the dimension lists of `self` and `other` are equal.

        Returns ``NotImplemented`` if `other` cannot be converted by
        `as_shape` (TypeError only).
        """
        try:
            other = as_shape(other)
        except TypeError:
            return NotImplemented
        return self._dims == other.dims

    def __ne__(self, other):
        """Returns True if `self` is known to differ from `other`.

        Raises:
          ValueError: If either shape has unknown rank.
        """
        try:
            other = as_shape(other)
        except TypeError:
            return NotImplemented
        if self.ndims is None or other.ndims is None:
            raise ValueError(
                "The inequality of unknown TensorShapes is undefined."
            )
        return self.ndims != other.ndims or self._dims != other.dims

    def __reduce__(self):
        # Pickle support: rebuild from the dimension list (or None).
        return TensorShape, (self._dims,)
def as_shape(shape):
    """Converts the given object to a TensorShape.

    A TensorShape argument is returned unchanged (same object). Anything else
    is passed to the TensorShape constructor.
    """
    return shape if isinstance(shape, TensorShape) else TensorShape(shape)
def unknown_shape(ndims=None):
    """Returns a shape whose dimension sizes are all unknown.

    Args:
      ndims: (Optional) The rank of the shape. If None, the rank itself is
        unknown as well.

    Returns:
      `TensorShape(None)` when `ndims` is None. Otherwise a TensorShape with
      `ndims` unknown dimensions, all sharing one `Dimension(None)` instance.
    """
    return TensorShape(None if ndims is None else [Dimension(None)] * ndims)
# Retained as a module-level constant in case other modules import it.
# Note that it is mutable through the TensorShape `dims` setter, so it must
# not be handed out as a shared return value.
_SCALAR_SHAPE = TensorShape([])


def scalar():
    """Returns a shape representing a scalar, i.e. a rank-0 TensorShape.

    A new TensorShape is built on every call. TensorShape has a public `dims`
    setter, and `TensorShape(shape)` shares `shape`'s dimension list, so
    returning one shared instance would let any caller's mutation change the
    result of every later `scalar()` call.

    Returns:
      A fresh `TensorShape([])`.
    """
    return TensorShape([])
def vector(length):
    """Returns a rank-1 shape representing a vector.

    Args:
      length: The number of elements, or None if unknown.

    Returns:
      A TensorShape with the single dimension `length`.
    """
    dims = [length]
    return TensorShape(dims)
def matrix(rows, cols):
    """Returns a rank-2 shape representing a matrix.

    Args:
      rows: The number of rows, or None if unknown.
      cols: The number of columns, or None if unknown.

    Returns:
      A TensorShape with dimensions `(rows, cols)`.
    """
    dims = [rows, cols]
    return TensorShape(dims)