Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/reshaping/permute.py: 44%

27 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Contains the Permute layer.""" 

16 

17 

18import copy 

19 

20import tensorflow.compat.v2 as tf 

21 

22from keras.src.engine.base_layer import Layer 

23from keras.src.engine.input_spec import InputSpec 

24 

25# isort: off 

26from tensorflow.python.util.tf_export import keras_export 

27 

28 

@keras_export("keras.layers.Permute")
class Permute(Layer):
    """Permutes the dimensions of the input according to a given pattern.

    Useful e.g. connecting RNNs and convnets.

    Example:

    ```python
    model = Sequential()
    model.add(Permute((2, 1), input_shape=(10, 64)))
    # now: model.output_shape == (None, 64, 10)
    # note: `None` is the batch dimension
    ```

    Args:
      dims: Tuple of integers (any iterable of integers is accepted and is
        converted to a tuple). Permutation pattern does not include the
        samples dimension. Indexing starts at 1.
        For instance, `(2, 1)` permutes the first and second dimensions
        of the input.

    Input shape:
      Arbitrary, except that the input rank (including the samples axis)
      must equal `len(dims) + 1`. Use the keyword argument `input_shape`
      (tuple of integers, does not include the samples axis)
      when using this layer as the first layer in a model.

    Output shape:
      Same as the input shape, but with the dimensions re-ordered according
      to the specified pattern.

    Raises:
      ValueError: If `dims` is not a permutation of `1..len(dims)`.
    """

    def __init__(self, dims, **kwargs):
        super().__init__(**kwargs)
        # Materialize `dims` exactly once and validate the stored tuple.
        # Validating the raw argument breaks for one-shot iterables
        # (e.g. generators): `tuple(dims)` exhausts them, and `len(dims)`
        # then raises `TypeError`.
        self.dims = tuple(dims)
        if sorted(self.dims) != list(range(1, len(self.dims) + 1)):
            raise ValueError(
                "Invalid permutation argument `dims` for Permute Layer. "
                "The set of indices in `dims` must be consecutive and start "
                f"from 1. Received dims={dims}"
            )
        # The extra axis is the leading samples axis, which `dims` excludes.
        self.input_spec = InputSpec(ndim=len(self.dims) + 1)

    def compute_output_shape(self, input_shape):
        """Returns `input_shape` with its non-samples axes reordered by `dims`.

        Args:
          input_shape: Input shape, anything accepted by `tf.TensorShape`.

        Returns:
          A `tf.TensorShape`. Axis 0 is kept in place, and output axis
          `i + 1` takes the size of input axis `self.dims[i]`. Any axes
          after index `len(self.dims)` are copied through unchanged.
        """
        input_shape = tf.TensorShape(input_shape).as_list()
        # Start from a copy so that axis 0 (and any trailing axes) carry over.
        output_shape = copy.copy(input_shape)
        for i, dim in enumerate(self.dims):
            output_shape[i + 1] = input_shape[dim]
        return tf.TensorShape(output_shape)

    def call(self, inputs):
        """Transposes `inputs`, keeping axis 0 first and reordering the rest."""
        # `self.dims` is 1-based, so it already indexes the non-samples axes
        # of `inputs`; prepending 0 keeps the samples axis in place.
        return tf.transpose(inputs, perm=(0,) + self.dims)

    def get_config(self):
        """Returns the base layer config extended with this layer's `dims`."""
        config = {"dims": self.dims}
        base_config = super().get_config()
        # Later entries win, so this layer's keys override the base config.
        return dict(list(base_config.items()) + list(config.items()))

86