Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/common_shapes.py: 19%
32 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A library of common shape functions."""
16import itertools
18from tensorflow.python.framework import tensor_shape
def _broadcast_shape_helper(shape_x, shape_y):
  """Helper function for is_broadcast_compatible and broadcast_shape.

  The two shapes are aligned at their trailing (innermost) dimensions. The
  shorter shape is padded on the left with size-1 dimensions. Each aligned
  pair is then combined following NumPy broadcasting rules.

  Args:
    shape_x: A `TensorShape`. `shape_x.dims` must be a reversible sequence.
      Both public callers in this module return early when `ndims` is None,
      so they never pass a shape of unknown rank.
    shape_y: A `TensorShape`, with the same requirement as `shape_x`.

  Returns:
    None if the shapes are not broadcast compatible. Otherwise a list with one
    entry per output dimension, outermost first. Each entry is the broadcast
    dimension, or None when a pair involves an unknown size and the other
    size is not known to be greater than 1.
  """
  # Walk both shapes from the innermost dimension outwards so they line up on
  # the right, padding the shorter one with size-1 dimensions. Then restore
  # outermost-first order.
  aligned_pairs = list(
      itertools.zip_longest(
          reversed(shape_x.dims),
          reversed(shape_y.dims),
          fillvalue=tensor_shape.Dimension(1)))
  aligned_pairs.reverse()

  # Combine each pair according to the NumPy broadcasting rules:
  # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
  result = []
  for left, right in aligned_pairs:
    left_size, right_size = left.value, right.value

    if left_size is not None and right_size is not None:
      # Both sizes are known.
      if left_size == 1:
        result.append(right)  # `left` is stretched to match `right`.
      elif right_size == 1:
        result.append(left)  # `right` is stretched to match `left`.
      elif left_size == right_size:
        # Same size on both sides, so the output keeps that size.
        result.append(left.merge_with(right))
      else:
        return None  # Two different sizes, neither of which is 1.
      continue

    # At least one size is unknown. A known size greater than 1 is taken as
    # the output size, on the assumption that the program is correct and the
    # unknown side will be broadcast to it. Otherwise the output is unknown.
    # TODO(mrry): If we eliminate the shape checks in C++, we must still
    # assert that the unknown dim is either 1 or the same as the known dim.
    if left_size is not None and left_size > 1:
      result.append(left)
    elif right_size is not None and right_size > 1:
      result.append(right)
    else:
      result.append(None)

  return result
def is_broadcast_compatible(shape_x, shape_y):
  """Returns True if `shape_x` and `shape_y` are broadcast compatible.

  The check is conservative. If `ndims` of either shape is None (unknown
  rank), the result is False without inspecting any dimensions, even though a
  common broadcast shape might exist.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    True if both ranks are known and a shape exists that both `shape_x` and
    `shape_y` can be broadcasted to. False otherwise.
  """
  # Short-circuit: the per-dimension helper needs known ranks, so it is only
  # consulted when both `ndims` are available.
  ranks_known = shape_x.ndims is not None and shape_y.ndims is not None
  return ranks_known and _broadcast_shape_helper(shape_x, shape_y) is not None
def broadcast_shape(shape_x, shape_y):
  """Returns the broadcasted shape between `shape_x` and `shape_y`.

  If `ndims` of either shape is None (unknown rank), the result is
  `tensor_shape.unknown_shape()`. Otherwise the dimensions are combined by
  `_broadcast_shape_helper` and wrapped in a `tensor_shape.TensorShape`.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    A `TensorShape` representing the broadcasted shape.

  Raises:
    ValueError: If the two shapes can not be broadcasted.
  """
  ranks_known = shape_x.ndims is not None and shape_y.ndims is not None
  if not ranks_known:
    return tensor_shape.unknown_shape()

  broadcast_dims = _broadcast_shape_helper(shape_x, shape_y)
  if broadcast_dims is not None:
    return tensor_shape.TensorShape(broadcast_dims)

  # The helper signals incompatibility by returning None.
  raise ValueError('Incompatible shapes for broadcasting. Two shapes are '
                   'compatible if for each dimension pair they are either '
                   'equal or one of them is 1. '
                   f'Received: {shape_x} and {shape_y}.')