Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/initializers/__init__.py: 71%
77 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras initializer serialization / deserialization."""
17import threading
19from tensorflow.python import tf2
20from tensorflow.python.keras.initializers import initializers_v1
21from tensorflow.python.keras.initializers import initializers_v2
22from tensorflow.python.keras.utils import generic_utils
23from tensorflow.python.keras.utils import tf_inspect as inspect
24from tensorflow.python.ops import init_ops
25from tensorflow.python.util.tf_export import keras_export
# `LOCAL.ALL_OBJECTS` serves as a mutable, module-wide registry of
# initializers. Keeping it on a `threading.local()` object gives each thread
# its own copy, so concurrent (re)population in one thread cannot mutate the
# registry another thread is reading.
LOCAL = threading.local()
def populate_deserializable_objects():
  """Populates `LOCAL.ALL_OBJECTS` with every built-in initializer.

  The mapping is cached per thread on `LOCAL` and tagged with the value of
  `tf2.enabled()` it was built under (`LOCAL.GENERATED_WITH_V2`). It is rebuilt
  only when it is empty or when that flag has changed since the last build.

  - In TF2 mode the class names map to `initializers_v2` classes.
  - Otherwise they map to the V1 classes from `init_ops` / `initializers_v1`.
  - The `...V2` compatibility aliases are present in both modes.

  The new mapping is assembled in a local dict and published to `LOCAL` only
  once it is complete. An exception raised part-way through therefore cannot
  leave a partially-filled dict behind that a later call would accept as
  valid.
  """
  if not hasattr(LOCAL, 'ALL_OBJECTS'):
    # First call on this thread: thread-local attributes start out unset.
    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = None

  # Read the flag once so the cache check, the branch below and the stored
  # tag all agree.
  v2_enabled = tf2.enabled()
  if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
    # Objects dict is already generated for the proper TF version:
    # do nothing.
    return

  # Compatibility aliases (need to exist in both V1 and V2).
  objects = {
      'ConstantV2': initializers_v2.Constant,
      'GlorotNormalV2': initializers_v2.GlorotNormal,
      'GlorotUniformV2': initializers_v2.GlorotUniform,
      'HeNormalV2': initializers_v2.HeNormal,
      'HeUniformV2': initializers_v2.HeUniform,
      'IdentityV2': initializers_v2.Identity,
      'LecunNormalV2': initializers_v2.LecunNormal,
      'LecunUniformV2': initializers_v2.LecunUniform,
      'OnesV2': initializers_v2.Ones,
      'OrthogonalV2': initializers_v2.Orthogonal,
      'RandomNormalV2': initializers_v2.RandomNormal,
      'RandomUniformV2': initializers_v2.RandomUniform,
      'TruncatedNormalV2': initializers_v2.TruncatedNormal,
      'VarianceScalingV2': initializers_v2.VarianceScaling,
      'ZerosV2': initializers_v2.Zeros,
      # Out of an abundance of caution we also include these aliases that have
      # a non-zero probability of having been included in saved configs in
      # the past.
      'glorot_normalV2': initializers_v2.GlorotNormal,
      'glorot_uniformV2': initializers_v2.GlorotUniform,
      'he_normalV2': initializers_v2.HeNormal,
      'he_uniformV2': initializers_v2.HeUniform,
      'lecun_normalV2': initializers_v2.LecunNormal,
      'lecun_uniformV2': initializers_v2.LecunUniform,
  }

  if v2_enabled:
    # For V2, entries are generated automatically from initializers_v2.py:
    # every class in that module that subclasses `Initializer`.
    version_objs = {}
    base_cls = initializers_v2.Initializer
    generic_utils.populate_dict_with_module_objects(
        version_objs,
        [initializers_v2],
        obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls))
  else:
    # V1 initializers.
    version_objs = {
        'Constant': init_ops.Constant,
        'GlorotNormal': init_ops.GlorotNormal,
        'GlorotUniform': init_ops.GlorotUniform,
        'Identity': init_ops.Identity,
        'Ones': init_ops.Ones,
        'Orthogonal': init_ops.Orthogonal,
        'VarianceScaling': init_ops.VarianceScaling,
        'Zeros': init_ops.Zeros,
        'HeNormal': initializers_v1.HeNormal,
        'HeUniform': initializers_v1.HeUniform,
        'LecunNormal': initializers_v1.LecunNormal,
        'LecunUniform': initializers_v1.LecunUniform,
        'RandomNormal': initializers_v1.RandomNormal,
        'RandomUniform': initializers_v1.RandomUniform,
        'TruncatedNormal': initializers_v1.TruncatedNormal,
    }

  for key, value in version_objs.items():
    objects[key] = value
    # Functional alias under the snake_case name (e.g. 'random_normal', which
    # the compatibility aliases below look up).
    objects[generic_utils.to_snake_case(key)] = value

  # More compatibility aliases.
  objects['normal'] = objects['random_normal']
  objects['uniform'] = objects['random_uniform']
  objects['one'] = objects['ones']
  objects['zero'] = objects['zeros']

  # Publish only after the mapping is fully built.
  LOCAL.ALL_OBJECTS = objects
  LOCAL.GENERATED_WITH_V2 = v2_enabled
# For backwards compatibility, we populate this module's namespace with the
# objects from ALL_OBJECTS. We make no guarantees as to whether these objects
# will be using their correct version.
# NOTE: the registry is built once, at import time, for whichever mode
# `tf2.enabled()` reports then, and it is stored only on the importing
# thread's `LOCAL`.
populate_deserializable_objects()
globals().update(LOCAL.ALL_OBJECTS)
125# Utility functions
@keras_export('keras.initializers.serialize')
def serialize(initializer):
  """Serializes a Keras initializer.

  Args:
    initializer: The initializer object to serialize.

  Returns:
    The result of `generic_utils.serialize_keras_object(initializer)`.
  """
  serialized = generic_utils.serialize_keras_object(initializer)
  return serialized
@keras_export('keras.initializers.deserialize')
def deserialize(config, custom_objects=None):
  """Return an `Initializer` object from its config.

  Names are resolved against `LOCAL.ALL_OBJECTS`, which is refreshed first so
  that it matches the current TF mode on the calling thread.

  Args:
    config: The initializer name or serialized config to resolve.
    custom_objects: Optional mapping of extra names to objects, forwarded to
      `generic_utils.deserialize_keras_object`.

  Returns:
    Whatever `generic_utils.deserialize_keras_object` returns for `config`.
  """
  # Make sure this thread's registry exists and matches tf2.enabled().
  populate_deserializable_objects()
  registry = LOCAL.ALL_OBJECTS
  return generic_utils.deserialize_keras_object(
      config,
      module_objects=registry,
      custom_objects=custom_objects,
      printable_module_name='initializer')
@keras_export('keras.initializers.get')
def get(identifier):
  """Retrieve a Keras initializer by the identifier.

  The `identifier` may be the string name of an initializer function or class
  (case-sensitive).

  >>> identifier = 'Ones'
  >>> tf.keras.initializers.get(identifier)
  <...keras.initializers.initializers_v2.Ones...>

  You can also specify the `config` of the initializer by passing a dict
  containing `class_name` and `config` as the identifier. Note that the
  `class_name` must map to an `Initializer` class.

  >>> cfg = {'class_name': 'Ones', 'config': {}}
  >>> tf.keras.initializers.get(cfg)
  <...keras.initializers.initializers_v2.Ones...>

  If the `identifier` is a class, this method returns a new instance of the
  class built by its no-argument constructor.

  Args:
    identifier: `None`, a string or dict containing the initializer name or
      configuration, an initializer class, or a callable such as an
      initializer instance.

  Returns:
    `None` if `identifier` is `None`. Otherwise, an initializer instance
    based on the input identifier; callables that are not classes are
    returned unchanged.

  Raises:
    ValueError: If the input identifier is not a supported type or in a bad
      format.
  """
  if identifier is None:
    return None
  if isinstance(identifier, (dict, str)):
    # Names and serialized configs are resolved through the registry.
    return deserialize(identifier)
  if callable(identifier):
    # Classes are instantiated with their default constructor; any other
    # callable (e.g. an initializer instance) is used as-is.
    if inspect.isclass(identifier):
      identifier = identifier()
    return identifier
  raise ValueError('Could not interpret initializer identifier: ' +
                   str(identifier))