Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/common_shapes.py: 19%

32 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""A library of common shape functions.""" 

16import itertools 

17 

18from tensorflow.python.framework import tensor_shape 

19 

20 

def _broadcast_shape_helper(shape_x, shape_y):
  """Computes the per-dimension result of broadcasting two shapes.

  Shared implementation behind `is_broadcast_compatible` and
  `broadcast_shape`.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    None if the shapes are not broadcast compatible. Otherwise a list with one
    entry per broadcast dimension, leading dimension first. An entry is None
    when at least one of the paired input dimensions is unknown and neither is
    known to be greater than 1.
  """
  # Shapes are aligned on their trailing dimensions, so pair the dimensions up
  # walking from the back and pad the shorter shape with size-1 dimensions.
  x_from_back = reversed(shape_x.dims)
  y_from_back = reversed(shape_y.dims)
  dim_pairs = list(
      itertools.zip_longest(
          x_from_back, y_from_back, fillvalue=tensor_shape.Dimension(1)))
  # Restore leading-dimension-first order for the result.
  dim_pairs.reverse()

  # Combine each pair according to the numpy broadcasting rules.
  # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
  result = []
  for x_dim, y_dim in dim_pairs:
    x_size = x_dim.value
    y_size = y_dim.value

    if x_size is None or y_size is None:
      # At least one side is unknown. When the other side is known to be
      # greater than 1 we trust the program and take that dimension as the
      # output; otherwise the output dimension stays unknown.
      # TODO(mrry): If we eliminate the shape checks in C++, we must still
      # assert that the unknown dim is either 1 or the same as the known dim.
      if x_size is not None and x_size > 1:
        result.append(x_dim)
      elif y_size is not None and y_size > 1:
        result.append(y_dim)
      else:
        result.append(None)
      continue

    # Both sizes are known from here on.
    if x_size == 1:
      # dim_x stretches to match dim_y.
      result.append(y_dim)
    elif y_size == 1:
      # dim_y stretches to match dim_x.
      result.append(x_dim)
    elif x_size == y_size:
      # Equal sizes: the output keeps that size.
      result.append(x_dim.merge_with(y_dim))
    else:
      # Different sizes and neither is 1: not broadcastable.
      return None
  return result

69 

70 

def is_broadcast_compatible(shape_x, shape_y):
  """Checks whether `shape_x` and `shape_y` can be broadcast together.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    True if there is a shape that both `shape_x` and `shape_y` can be
    broadcasted to. False otherwise, including whenever the `ndims` of either
    shape is None.
  """
  # Without a known number of dimensions there is nothing to compare, so the
  # answer is conservatively False.
  if any(shape.ndims is None for shape in (shape_x, shape_y)):
    return False
  broadcast_dims = _broadcast_shape_helper(shape_x, shape_y)
  return broadcast_dims is not None

85 

86 

def broadcast_shape(shape_x, shape_y):
  """Computes the shape that results from broadcasting two shapes.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    A `TensorShape` representing the broadcasted shape. If the `ndims` of
    either input is None, the result of `tensor_shape.unknown_shape()` is
    returned instead.

  Raises:
    ValueError: If the two shapes can not be broadcasted.
  """
  # If either input has no known number of dimensions, nothing can be said
  # about the output.
  if any(shape.ndims is None for shape in (shape_x, shape_y)):
    return tensor_shape.unknown_shape()

  broadcast_dims = _broadcast_shape_helper(shape_x, shape_y)
  if broadcast_dims is not None:
    return tensor_shape.TensorShape(broadcast_dims)

  # The helper signals incompatibility with None.
  raise ValueError('Incompatible shapes for broadcasting. Two shapes are '
                   'compatible if for each dimension pair they are either '
                   'equal or one of them is 1. '
                   f'Received: {shape_x} and {shape_y}.')