Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/layers/serialization.py: 49%
55 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Layer serialization/deserialization functions.
16"""
17# pylint: disable=wildcard-import
18# pylint: disable=unused-import
20import threading
22from tensorflow.python import tf2
23from tensorflow.python.keras.engine import base_layer
24from tensorflow.python.keras.engine import input_layer
25from tensorflow.python.keras.engine import input_spec
26from tensorflow.python.keras.layers import advanced_activations
27from tensorflow.python.keras.layers import convolutional
28from tensorflow.python.keras.layers import convolutional_recurrent
29from tensorflow.python.keras.layers import core
30from tensorflow.python.keras.layers import dense_attention
31from tensorflow.python.keras.layers import embeddings
32from tensorflow.python.keras.layers import merge
33from tensorflow.python.keras.layers import pooling
34from tensorflow.python.keras.layers import recurrent
35from tensorflow.python.keras.layers import rnn_cell_wrapper_v2
36from tensorflow.python.keras.utils import generic_utils
37from tensorflow.python.keras.utils import tf_inspect as inspect
38from tensorflow.python.util.tf_export import keras_export
# Modules scanned for built-in Layer subclasses when the registry is built.
ALL_MODULES = (
    base_layer,
    input_layer,
    advanced_activations,
    convolutional,
    convolutional_recurrent,
    core,
    dense_attention,
    embeddings,
    merge,
    pooling,
    recurrent,
)
# Modules whose layer classes are layered on top of ALL_MODULES (overwriting
# same-named V1 entries) when TF2 behavior is enabled.
ALL_V2_MODULES = (rnn_cell_wrapper_v2,)
# The layer registry (ALL_OBJECTS) is mutable global state, so it is stored on
# a threading.local() object: every thread builds and mutates its own copy,
# which avoids concurrent mutation of a shared dict.
LOCAL = threading.local()
def populate_deserializable_objects():
  """Populates dict ALL_OBJECTS with every built-in layer.

  The registry lives on the thread-local ``LOCAL`` object. It is rebuilt when
  it is empty, or when it was generated under a different TF2 mode than the
  current ``tf2.enabled()`` value; otherwise this function returns without
  doing anything.

  A rebuild collects the Layer subclasses found in ``ALL_MODULES`` (then in
  ``ALL_V2_MODULES`` when TF2 is enabled, overwriting same-named entries) via
  ``generic_utils.populate_dict_with_module_objects``. It also adds
  ``Input``, ``InputSpec``, ``Functional``, ``Model``, ``Sequential`` and the
  functional merge helpers (``add``, ``subtract``, ``multiply``, ``average``,
  ``maximum``, ``minimum``, ``concatenate``, ``dot``).

  On success, ``LOCAL.ALL_OBJECTS`` holds the new dict and
  ``LOCAL.GENERATED_WITH_V2`` records the TF2 mode it was built for. The new
  registry is only published once it is complete.
  """
  # First use on this thread: create the attributes so the check below and
  # any external readers can rely on them existing.
  if not hasattr(LOCAL, 'ALL_OBJECTS'):
    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = None

  # Take a single snapshot of the TF2 mode so the staleness check, the V2
  # override step and the recorded flag all agree with each other.
  v2_enabled = tf2.enabled()
  if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
    # Objects dict is already generated for the proper TF version:
    # do nothing.
    return

  base_cls = base_layer.Layer

  def _is_layer_class(obj):
    # Keep only classes that derive from the Keras base Layer.
    return inspect.isclass(obj) and issubclass(obj, base_cls)

  # Build into a fresh local dict and publish it only when complete. If a
  # step below raises (e.g. the deferred `models` import), the thread-local
  # registry is not left half-filled with a matching version flag, which the
  # early-return check above would otherwise accept as up to date.
  all_objects = {}
  generic_utils.populate_dict_with_module_objects(
      all_objects, ALL_MODULES, obj_filter=_is_layer_class)

  # Overwrite certain V1 objects with V2 versions
  if v2_enabled:
    generic_utils.populate_dict_with_module_objects(
        all_objects, ALL_V2_MODULES, obj_filter=_is_layer_class)

  # Prevent circular dependencies.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top

  all_objects['Input'] = input_layer.Input
  all_objects['InputSpec'] = input_spec.InputSpec
  all_objects['Functional'] = models.Functional
  all_objects['Model'] = models.Model
  all_objects['Sequential'] = models.Sequential

  # Merge layers, function versions.
  all_objects['add'] = merge.add
  all_objects['subtract'] = merge.subtract
  all_objects['multiply'] = merge.multiply
  all_objects['average'] = merge.average
  all_objects['maximum'] = merge.maximum
  all_objects['minimum'] = merge.minimum
  all_objects['concatenate'] = merge.concatenate
  all_objects['dot'] = merge.dot

  # Publish the completed registry together with the mode it was built for.
  LOCAL.ALL_OBJECTS = all_objects
  LOCAL.GENERATED_WITH_V2 = v2_enabled
@keras_export('keras.layers.serialize')
def serialize(layer):
  """Returns the serialized form of `layer`.

  This is a thin wrapper that delegates all work to
  `generic_utils.serialize_keras_object`.

  Args:
    layer: the object to serialize (presumably a Keras Layer instance).

  Returns:
    Whatever `generic_utils.serialize_keras_object` produces for `layer`.
  """
  serialized = generic_utils.serialize_keras_object(layer)
  return serialized
@keras_export('keras.layers.deserialize')
def deserialize(config, custom_objects=None):
  """Instantiates a layer from a config dictionary.

  Args:
    config: dict of the form {'class_name': str, 'config': dict}
    custom_objects: dict mapping class names (or function names)
      of custom (non-Keras) objects to class/functions

  Returns:
    Layer instance (may be Model, Sequential, Network, Layer...)
  """
  # Ensure this thread's registry of built-in layers exists and matches the
  # current TF2 mode before names in `config` are resolved against it.
  populate_deserializable_objects()
  known_objects = LOCAL.ALL_OBJECTS
  return generic_utils.deserialize_keras_object(
      config,
      module_objects=known_objects,
      custom_objects=custom_objects,
      printable_module_name='layer',
  )