Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/reshaping/reshape.py: 27%

45 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Contains the Reshape layer.""" 

16 

17 

18import numpy as np 

19import tensorflow.compat.v2 as tf 

20 

21from keras.src.engine.base_layer import Layer 

22 

23# isort: off 

24from tensorflow.python.util.tf_export import keras_export 

25 

26 

@keras_export("keras.layers.Reshape")
class Reshape(Layer):
    """Layer that reshapes inputs into the given shape.

    Input shape:
        Arbitrary, although all dimensions in the input shape must be
        known/fixed. Use the keyword argument `input_shape` (tuple of
        integers, does not include the samples/batch size axis) when using
        this layer as the first layer in a model.

    Output shape:
        `(batch_size,) + target_shape`

    Example:

    >>> # as first layer in a Sequential model
    >>> model = tf.keras.Sequential()
    >>> model.add(tf.keras.layers.Reshape((3, 4), input_shape=(12,)))
    >>> # model.output_shape == (None, 3, 4), `None` is the batch size.
    >>> model.output_shape
    (None, 3, 4)

    >>> # as intermediate layer in a Sequential model
    >>> model.add(tf.keras.layers.Reshape((6, 2)))
    >>> model.output_shape
    (None, 6, 2)

    >>> # also supports shape inference using `-1` as dimension
    >>> model.add(tf.keras.layers.Reshape((-1, 2, 2)))
    >>> model.output_shape
    (None, 3, 2, 2)
    """

    def __init__(self, target_shape, **kwargs):
        """Creates a `tf.keras.layers.Reshape` layer instance.

        Args:
            target_shape: Target shape. Tuple of integers, does not include
                the samples dimension (batch size). A `-1` entry marks a
                dimension to be inferred from the input size.
            **kwargs: Any additional layer keyword arguments.
        """
        super().__init__(**kwargs)
        # Stored as a tuple so `call` can prepend the batch dimension with
        # tuple concatenation.
        self.target_shape = tuple(target_shape)

    def _fix_unknown_dimension(self, input_shape, output_shape):
        """Find and replace a missing dimension in an output shape.

        This is a near direct port of the internal Numpy function
        `_fix_unknown_dimension` in `numpy/core/src/multiarray/shape.c`

        Args:
            input_shape: Shape of array being reshaped
            output_shape: Desired shape of the array with at most a single -1
                which indicates a dimension that should be derived from the
                input shape.

        Returns:
            The new output shape (as a list) with a -1 replaced with its
            computed value.

        Raises:
            ValueError: If the total array size of the output_shape is
                different than the input_shape, or more than one unknown
                dimension is specified.
        """
        output_shape = list(output_shape)
        msg = (
            "total size of new array must be unchanged, "
            f"input_shape = {input_shape}, output_shape = {output_shape}"
        )

        # Product of all explicitly given (non-negative) target dims, and the
        # index of the single negative (to-be-inferred) dim, if any.
        known, unknown = 1, None
        for index, dim in enumerate(output_shape):
            if dim >= 0:
                known *= dim
                continue
            if unknown is not None:
                raise ValueError(
                    "There must be at most one unknown dimension in "
                    f"output_shape. Received: output_shape={output_shape}."
                )
            unknown = index

        original = np.prod(input_shape, dtype=int)
        if unknown is not None:
            # A zero-sized known part would make the inferred size ambiguous
            # (and the modulo below would divide by zero).
            if known == 0 or original % known != 0:
                raise ValueError(msg)
            output_shape[unknown] = original // known
        elif original != known:
            raise ValueError(msg)
        return output_shape

    def compute_output_shape(self, input_shape):
        """Returns the static output shape for a given input shape.

        Args:
            input_shape: Input shape including the leading batch dimension;
                anything accepted by `tf.TensorShape`.

        Returns:
            A `tf.TensorShape` equal to `(batch_size,) + target_shape`. When
            every non-batch input dimension is known, a `-1` in
            `target_shape` is resolved to a concrete size; otherwise it is
            reported as `None`.

        Raises:
            ValueError: If all non-batch input dimensions are known and are
                incompatible with `target_shape`.
        """
        input_shape = tf.TensorShape(input_shape).as_list()
        batch_dim, feature_dims = input_shape[0], input_shape[1:]
        if None in feature_dims:
            # Input shape is (partially) unknown, so a `-1` dimension cannot
            # be inferred statically: report it as None.
            target_dims = [s if s != -1 else None for s in self.target_shape]
        else:
            target_dims = self._fix_unknown_dimension(
                feature_dims, self.target_shape
            )
        return tf.TensorShape([batch_dim] + target_dims)

    def call(self, inputs):
        """Reshapes `inputs` to `(batch_size,) + target_shape`.

        The batch size is taken from the dynamic shape of `inputs`.
        """
        batch_size = tf.shape(inputs)[0]
        result = tf.reshape(inputs, (batch_size,) + self.target_shape)
        if not tf.executing_eagerly():
            # Set the static shape of the result, since it might be lost
            # during `tf.reshape`; e.g. some `None` dims in the result can be
            # inferred from the input's static shape.
            result.set_shape(self.compute_output_shape(inputs.shape))
        return result

    def get_config(self):
        """Returns the base layer config extended with `target_shape`.

        Entries from this layer override base-config entries with the same
        key.
        """
        config = {"target_shape": self.target_shape}
        base_config = super().get_config()
        return {**base_config, **config}

149