Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/rnn/base_wrapper.py: 38%
39 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base class for wrapper layers.
17Wrappers are layers that augment the functionality of another layer.
18"""
21import copy
23from keras.src.engine.base_layer import Layer
24from keras.src.saving import serialization_lib
25from keras.src.saving.legacy import serialization as legacy_serialization
27# isort: off
28from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.Wrapper")
class Wrapper(Layer):
    """Abstract wrapper base class.

    Wrappers take another layer and augment it in various ways.
    Do not use this class as a layer, it is only an abstract base class.
    Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.

    Args:
        layer: The layer to be wrapped.
    """

    def __init__(self, layer, **kwargs):
        """Store the wrapped layer and initialize the base `Layer`.

        Args:
            layer: The layer to wrap. Must be an instance of `Layer`.
            **kwargs: Forwarded unchanged to `Layer.__init__`.

        Raises:
            ValueError: If `layer` is not an instance of `Layer`.
        """
        # Explicit check instead of `assert`: assertions are removed when
        # Python runs with `-O`, which would let invalid layers through.
        if not isinstance(layer, Layer):
            raise ValueError(
                f"Layer {layer} supplied to wrapper is"
                " not a supported layer type. Please"
                " ensure wrapped layer is a valid Keras layer."
            )
        self.layer = layer
        super().__init__(**kwargs)

    def build(self, input_shape=None):
        """Build the wrapped layer if needed, then mark this wrapper built.

        Args:
            input_shape: Passed unchanged to the wrapped layer's `build`.
        """
        if not self.layer.built:
            self.layer.build(input_shape)
            # Mark the wrapped layer built explicitly, regardless of whether
            # its own `build` did so.
            self.layer.built = True
        self.built = True

    @property
    def activity_regularizer(self):
        """The wrapped layer's `activity_regularizer`, or `None` if absent."""
        return getattr(self.layer, "activity_regularizer", None)

    def get_config(self):
        """Return this wrapper's config with the wrapped layer serialized.

        The wrapped layer is serialized under the ``"layer"`` key, which
        takes precedence over any ``"layer"`` key in the base `Layer` config.

        Returns:
            A dict combining the base `Layer` config and the ``"layer"`` entry.
        """
        try:
            layer_config = serialization_lib.serialize_keras_object(self.layer)
        except TypeError:  # Case of incompatible custom wrappers
            layer_config = legacy_serialization.serialize_keras_object(
                self.layer
            )
        base_config = super().get_config()
        return {**base_config, "layer": layer_config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Recreate a wrapper from a config produced by `get_config`.

        Args:
            config: Config dict. Its ``"layer"`` entry is deserialized into the
                wrapped layer; the remaining entries are passed to the
                constructor as keyword arguments. The dict is not mutated.
            custom_objects: Optional mapping forwarded to layer
                deserialization.

        Returns:
            A new instance of `cls` wrapping the deserialized layer.
        """
        # Imported locally — presumably to avoid a circular import with
        # `keras.src.layers`.
        from keras.src.layers import deserialize as deserialize_layer

        # Avoid mutating the input dict
        config = copy.deepcopy(config)
        # NOTE(review): this looks for "module" in the wrapper's own config,
        # not in config["layer"]; confirm which dict is meant to signal the
        # new (non-legacy) serialization format.
        use_legacy_format = "module" not in config
        layer = deserialize_layer(
            config.pop("layer"),
            custom_objects=custom_objects,
            use_legacy_format=use_legacy_format,
        )
        return cls(layer, **config)