Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/reshaping/reshape.py: 27%
45 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Contains the Reshape layer."""
18import numpy as np
19import tensorflow.compat.v2 as tf
21from keras.src.engine.base_layer import Layer
23# isort: off
24from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.Reshape")
class Reshape(Layer):
    """Layer that reshapes inputs into the given shape.

    Input shape:
        Arbitrary. If any non-batch input dimension is unknown, the static
        output shape cannot be fully inferred and every `-1` in
        `target_shape` is reported as `None`.
        Use the keyword argument `input_shape` (tuple of integers, does not
        include the samples/batch size axis) when using this layer as the
        first layer in a model.

    Output shape:
        `(batch_size,) + target_shape`

    Example:

    >>> # as first layer in a Sequential model
    >>> model = tf.keras.Sequential()
    >>> model.add(tf.keras.layers.Reshape((3, 4), input_shape=(12,)))
    >>> # model.output_shape == (None, 3, 4), `None` is the batch size.
    >>> model.output_shape
    (None, 3, 4)

    >>> # as intermediate layer in a Sequential model
    >>> model.add(tf.keras.layers.Reshape((6, 2)))
    >>> model.output_shape
    (None, 6, 2)

    >>> # also supports shape inference using `-1` as dimension
    >>> model.add(tf.keras.layers.Reshape((-1, 2, 2)))
    >>> model.output_shape
    (None, 3, 2, 2)
    """

    def __init__(self, target_shape, **kwargs):
        """Creates a `tf.keras.layers.Reshape` layer instance.

        Args:
            target_shape: Target shape. Tuple of integers, does not include
                the samples dimension (batch size).
            **kwargs: Any additional layer keyword arguments.
        """
        super().__init__(**kwargs)
        # Stored as a tuple so it can be concatenated with the batch dim
        # in `call`.
        self.target_shape = tuple(target_shape)

    def _fix_unknown_dimension(self, input_shape, output_shape):
        """Find and replace a missing dimension in an output shape.

        This is a near direct port of the internal Numpy function
        `_fix_unknown_dimension` in `numpy/core/src/multiarray/shape.c`.

        Args:
            input_shape: Shape of array being reshaped.
            output_shape: Desired shape of the array with at most a single
                negative entry, which indicates a dimension that should be
                derived from the input shape.

        Returns:
            The new output shape (a list) with the negative entry replaced
            by its computed value.

        Raises:
            ValueError: If the total array size of the output_shape is
                different than the input_shape, or more than one unknown
                dimension is specified.
        """
        output_shape = list(output_shape)

        def size_mismatch():
            # Built only on the failure path. Both raise sites run before
            # `output_shape` is mutated, so the message shows the shape as
            # requested.
            return ValueError(
                "total size of new array must be unchanged, "
                "input_shape = {}, output_shape = {}".format(
                    input_shape, output_shape
                )
            )

        known, unknown = 1, None
        for index, dim in enumerate(output_shape):
            if dim >= 0:
                known *= dim
            elif unknown is None:
                unknown = index
            else:
                raise ValueError(
                    "There must be at most one unknown dimension in "
                    f"output_shape. Received: output_shape={output_shape}."
                )

        original = np.prod(input_shape, dtype=int)
        if unknown is None:
            if original != known:
                raise size_mismatch()
            return output_shape

        # The unknown dim is only well-defined when the known dims divide
        # the total element count exactly (and are not zero).
        if known == 0 or original % known != 0:
            raise size_mismatch()
        output_shape[unknown] = original // known
        return output_shape

    def compute_output_shape(self, input_shape):
        """Return the static output shape for `input_shape`.

        The batch dimension is passed through unchanged. When any non-batch
        input dimension is unknown, each `-1` in `target_shape` becomes
        `None`. Otherwise the `-1` is resolved via `_fix_unknown_dimension`,
        which raises `ValueError` on a size mismatch.
        """
        input_shape = tf.TensorShape(input_shape).as_list()
        output_shape = [input_shape[0]]
        if None in input_shape[1:]:
            # Input shape (partially) unknown: replace -1's with None's.
            output_shape += [
                None if s == -1 else s for s in self.target_shape
            ]
        else:
            output_shape += self._fix_unknown_dimension(
                input_shape[1:], self.target_shape
            )
        return tf.TensorShape(output_shape)

    def call(self, inputs):
        """Reshape `inputs` to `(batch_size,) + target_shape`."""
        # The batch size is read dynamically so it may vary between calls.
        result = tf.reshape(
            inputs, (tf.shape(inputs)[0],) + self.target_shape
        )
        if not tf.executing_eagerly():
            # Set the static shape for the result since it might be lost
            # during array_ops reshape, e.g. some `None` dim in the result
            # could be inferred.
            result.set_shape(self.compute_output_shape(inputs.shape))
        return result

    def get_config(self):
        """Return the layer config, adding `target_shape` to the base one."""
        config = {"target_shape": self.target_shape}
        base_config = super().get_config()
        # Later keys win, so `target_shape` overrides any base entry.
        return {**base_config, **config}