Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/base_wrapper.py: 38%

39 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Base class for wrapper layers. 

16 

17Wrappers are layers that augment the functionality of another layer. 

18""" 

19 

20 

21import copy 

22 

23from keras.src.engine.base_layer import Layer 

24from keras.src.saving import serialization_lib 

25from keras.src.saving.legacy import serialization as legacy_serialization 

26 

27# isort: off 

28from tensorflow.python.util.tf_export import keras_export 

29 

30 

@keras_export("keras.layers.Wrapper")
class Wrapper(Layer):
    """Abstract wrapper base class.

    Wrappers take another layer and augment it in various ways.
    Do not use this class as a layer, it is only an abstract base class.
    Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.

    Args:
        layer: The layer to be wrapped. Must be an instance of `Layer`.
        **kwargs: Forwarded unchanged to the base `Layer` constructor.

    Raises:
        ValueError: If `layer` is not an instance of `Layer`.
    """

    def __init__(self, layer, **kwargs):
        # Use an explicit check instead of `assert`: assert statements are
        # removed when Python runs with `-O`, which would silently disable
        # this validation.
        if not isinstance(layer, Layer):
            raise ValueError(
                f"Layer {layer} supplied to wrapper is"
                " not a supported layer type. Please"
                " ensure wrapped layer is a valid Keras layer."
            )
        # NOTE(review): the wrapped layer is assigned before the base
        # constructor runs; presumably intentional, so the order is kept.
        self.layer = layer
        super().__init__(**kwargs)

    def build(self, input_shape=None):
        """Build the wrapped layer if needed, then mark this wrapper built.

        Args:
            input_shape: Passed through unchanged to `self.layer.build`.
        """
        if not self.layer.built:
            self.layer.build(input_shape)
            # Set explicitly so the flag holds even if the inner `build`
            # override does not set it itself.
            self.layer.built = True
        self.built = True

    @property
    def activity_regularizer(self):
        """The wrapped layer's `activity_regularizer`, or `None` if absent."""
        return getattr(self.layer, "activity_regularizer", None)

    def get_config(self):
        """Return the base `Layer` config plus the serialized wrapped layer.

        The wrapped layer is serialized with `serialization_lib`. If that
        raises `TypeError`, the legacy serializer is used instead.

        Returns:
            dict: The base config merged with a `"layer"` entry. The `"layer"`
            entry overrides any same-named key from the base config.
        """
        try:
            layer_config = serialization_lib.serialize_keras_object(self.layer)
        except TypeError:  # Case of incompatible custom wrappers
            layer_config = legacy_serialization.serialize_keras_object(
                self.layer
            )
        return {**super().get_config(), "layer": layer_config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Create a wrapper from a config produced by `get_config`.

        Args:
            config: Dict holding a `"layer"` entry (the serialized wrapped
                layer) plus remaining keyword arguments for `cls`. The
                caller's dict is not mutated.
            custom_objects: Optional mapping forwarded to layer
                deserialization.

        Returns:
            A new instance of `cls` wrapping the deserialized layer.
        """
        # Function-local import; presumably avoids an import cycle with
        # `keras.src.layers` -- confirm before hoisting to module level.
        from keras.src.layers import deserialize as deserialize_layer

        # Deep-copy so that popping "layer" below leaves the input untouched.
        config = copy.deepcopy(config)
        # NOTE(review): this looks for "module" in the wrapper's own config,
        # not in config["layer"]; verify which dict the serializer writes
        # "module" into before changing this check.
        use_legacy_format = "module" not in config
        layer = deserialize_layer(
            config.pop("layer"),
            custom_objects=custom_objects,
            use_legacy_format=use_legacy_format,
        )
        return cls(layer, **config)

93