Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/initializers/__init__.py: 71%
85 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras initializer serialization / deserialization."""
17import threading
19import tensorflow.compat.v2 as tf
21from keras.src.initializers import initializers
22from keras.src.initializers import initializers_v1
23from keras.src.saving import serialization_lib
24from keras.src.saving.legacy import serialization as legacy_serialization
25from keras.src.utils import generic_utils
26from keras.src.utils import tf_inspect as inspect
28# isort: off
29from tensorflow.python import tf2
30from tensorflow.python.ops import init_ops
31from tensorflow.python.util.tf_export import keras_export
# The initializer registry (LOCAL.ALL_OBJECTS) is mutated at runtime, so it
# lives in thread-local storage: each thread builds and mutates its own copy
# instead of sharing one dict across threads.
LOCAL = threading.local()
def populate_deserializable_objects():
    """Build the thread-local initializer registry `LOCAL.ALL_OBJECTS`.

    The registry maps names to built-in initializers. It holds class names,
    their snake_case functional aliases, and several legacy compatibility
    aliases. Whether V1 or V2 implementations are registered follows
    `tf.__internal__.tf2.enabled()`.

    The registry is rebuilt only when it is empty or was built under a
    different TF2 setting than the current one; otherwise this call returns
    immediately without touching it.
    """
    v2_enabled = tf.__internal__.tf2.enabled()

    # Thread-local attributes start out missing in every thread; create them
    # on first use in the current thread.
    if not hasattr(LOCAL, "ALL_OBJECTS"):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    # The cached registry already matches the active TF version.
    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
        return

    # Install a brand-new dict up front, then fill it in place.
    registry = {}
    LOCAL.ALL_OBJECTS = registry
    LOCAL.GENERATED_WITH_V2 = v2_enabled

    # Compatibility aliases that must resolve in both V1 and V2 mode.
    registry.update(
        (
            ("ConstantV2", initializers.Constant),
            ("GlorotNormalV2", initializers.GlorotNormal),
            ("GlorotUniformV2", initializers.GlorotUniform),
            ("HeNormalV2", initializers.HeNormal),
            ("HeUniformV2", initializers.HeUniform),
            ("IdentityV2", initializers.Identity),
            ("LecunNormalV2", initializers.LecunNormal),
            ("LecunUniformV2", initializers.LecunUniform),
            ("OnesV2", initializers.Ones),
            ("OrthogonalV2", initializers.Orthogonal),
            ("RandomNormalV2", initializers.RandomNormal),
            ("RandomUniformV2", initializers.RandomUniform),
            ("TruncatedNormalV2", initializers.TruncatedNormal),
            ("VarianceScalingV2", initializers.VarianceScaling),
            ("ZerosV2", initializers.Zeros),
            # Registered out of caution: these spellings may have ended up
            # in previously saved configs.
            ("glorot_normalV2", initializers.GlorotNormal),
            ("glorot_uniformV2", initializers.GlorotUniform),
            ("he_normalV2", initializers.HeNormal),
            ("he_uniformV2", initializers.HeUniform),
            ("lecun_normalV2", initializers.LecunNormal),
            ("lecun_uniformV2", initializers.LecunUniform),
        )
    )

    if v2_enabled:
        # V2: discover every `Initializer` subclass exposed by the
        # `initializers` module.
        base_cls = initializers.Initializer

        def _is_initializer_class(obj):
            return inspect.isclass(obj) and issubclass(obj, base_cls)

        found = {}
        generic_utils.populate_dict_with_module_objects(
            found,
            [initializers],
            obj_filter=_is_initializer_class,
        )
    else:
        # V1: a fixed table mixing TF v1 initializers and Keras V1 classes.
        found = {
            "Constant": tf.compat.v1.constant_initializer,
            "GlorotNormal": tf.compat.v1.glorot_normal_initializer,
            "GlorotUniform": tf.compat.v1.glorot_uniform_initializer,
            "Identity": tf.compat.v1.initializers.identity,
            "Ones": tf.compat.v1.ones_initializer,
            "Orthogonal": tf.compat.v1.orthogonal_initializer,
            "VarianceScaling": tf.compat.v1.variance_scaling_initializer,
            "Zeros": tf.compat.v1.zeros_initializer,
            "HeNormal": initializers_v1.HeNormal,
            "HeUniform": initializers_v1.HeUniform,
            "LecunNormal": initializers_v1.LecunNormal,
            "LecunUniform": initializers_v1.LecunUniform,
            "RandomNormal": initializers_v1.RandomNormal,
            "RandomUniform": initializers_v1.RandomUniform,
            "TruncatedNormal": initializers_v1.TruncatedNormal,
        }

    # Register each name together with its snake_case functional alias.
    for name, obj in found.items():
        registry[name] = obj
        registry[generic_utils.to_snake_case(name)] = obj

    # Short legacy spellings, resolved against the entries registered above.
    for short_name, full_name in (
        ("normal", "random_normal"),
        ("uniform", "random_uniform"),
        ("one", "ones"),
        ("zero", "zeros"),
    ):
        registry[short_name] = registry[full_name]
# For backwards compatibility, every entry of ALL_OBJECTS is also exposed as
# an attribute of this module. These attributes are bound once, at import
# time, so they reflect the TF2 setting active at that moment; there is no
# guarantee they match the V1/V2 version in use later.
populate_deserializable_objects()
globals().update(LOCAL.ALL_OBJECTS)
134# Utility functions
@keras_export("keras.initializers.serialize")
def serialize(initializer, use_legacy_format=False):
    """Serialize an initializer object.

    Args:
        initializer: The initializer object to serialize.
        use_legacy_format: When truthy, use the legacy Keras serializer
            (`legacy_serialization.serialize_keras_object`); otherwise use
            `serialization_lib.serialize_keras_object`.

    Returns:
        Whatever the selected serializer returns for `initializer`
        (presumably a config dict -- see the serializer's docs).
    """
    serializer = (
        legacy_serialization.serialize_keras_object
        if use_legacy_format
        else serialization_lib.serialize_keras_object
    )
    return serializer(initializer)
@keras_export("keras.initializers.deserialize")
def deserialize(config, custom_objects=None, use_legacy_format=False):
    """Return an `Initializer` object from its config.

    The thread-local registry of built-in initializers is refreshed first,
    so lookups resolve against the V1/V2 set matching the current TF mode.

    Args:
        config: The serialized initializer to reconstruct.
        custom_objects: Optional mapping of extra names to objects, forwarded
            unchanged to the underlying deserializer.
        use_legacy_format: When truthy, use the legacy Keras deserializer
            (`legacy_serialization.deserialize_keras_object`); otherwise use
            `serialization_lib.deserialize_keras_object`.

    Returns:
        The object produced by the selected deserializer.
    """
    populate_deserializable_objects()
    deserializer = (
        legacy_serialization.deserialize_keras_object
        if use_legacy_format
        else serialization_lib.deserialize_keras_object
    )
    return deserializer(
        config,
        module_objects=LOCAL.ALL_OBJECTS,
        custom_objects=custom_objects,
        printable_module_name="initializer",
    )
@keras_export("keras.initializers.get")
def get(identifier):
    """Retrieve a Keras initializer by the identifier.

    The `identifier` may be the string name of an initializer function or
    class (case-sensitively).

    >>> identifier = 'Ones'
    >>> tf.keras.initializers.get(identifier)
    <...keras.initializers.initializers.Ones...>

    You can also specify the `config` of the initializer to this function by
    passing a dict containing `class_name` and `config` as the identifier.
    Note that the `class_name` must map to an `Initializer` class.

    >>> cfg = {'class_name': 'Ones', 'config': {}}
    >>> tf.keras.initializers.get(cfg)
    <...keras.initializers.initializers.Ones...>

    If the `identifier` is a class, this function returns a new instance of
    that class created by calling its constructor with no arguments. Any
    other callable is returned unchanged.

    Args:
        identifier: `None`, a string name, a config dict, an initializer
            class, or a callable.

    Returns:
        An initializer based on the input identifier, or `None` when
        `identifier` is `None`.

    Raises:
        ValueError: If `identifier` is not `None`, a dict, a string, or a
            callable. Errors for malformed names or configs, if any, come
            from `deserialize`.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        # Dicts without a "module" key go through the legacy deserializer;
        # dicts with one go through the new serialization library.
        use_legacy_format = "module" not in identifier
        return deserialize(identifier, use_legacy_format=use_legacy_format)
    elif isinstance(identifier, str):
        # Wrap a bare name into a minimal config and resolve it as a dict.
        config = {"class_name": str(identifier), "config": {}}
        return get(config)
    elif callable(identifier):
        # Classes are instantiated; other callables are already usable.
        if inspect.isclass(identifier):
            identifier = identifier()
        return identifier
    else:
        raise ValueError(
            "Could not interpret initializer identifier: " + str(identifier)
        )