Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/initializers/__init__.py: 71%

85 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras initializer serialization / deserialization.""" 

16 

17import threading 

18 

19import tensorflow.compat.v2 as tf 

20 

21from keras.src.initializers import initializers 

22from keras.src.initializers import initializers_v1 

23from keras.src.saving import serialization_lib 

24from keras.src.saving.legacy import serialization as legacy_serialization 

25from keras.src.utils import generic_utils 

26from keras.src.utils import tf_inspect as inspect 

27 

28# isort: off 

29from tensorflow.python import tf2 

30from tensorflow.python.ops import init_ops 

31from tensorflow.python.util.tf_export import keras_export 

32 

# LOCAL.ALL_OBJECTS (the name -> initializer registry built by
# populate_deserializable_objects, together with LOCAL.GENERATED_WITH_V2) is
# a mutable global. It is stored on a threading.local so that each thread
# builds and mutates its own copy instead of concurrently mutating a shared
# dict. A thread that has not built its registry yet has no ALL_OBJECTS
# attribute, which is why deserialize() calls populate_deserializable_objects()
# before every lookup.
LOCAL = threading.local()

36 

37 

def populate_deserializable_objects():
    """Populates dict ALL_OBJECTS with every built-in initializer.

    The registry is stored on the thread-local ``LOCAL``:
    ``LOCAL.ALL_OBJECTS`` maps names (class names, their snake_case
    functional aliases and legacy compatibility aliases) to initializer
    classes or functions. ``LOCAL.GENERATED_WITH_V2`` records the value of
    ``tf.__internal__.tf2.enabled()`` that the registry was built for.

    The call is a no-op when a non-empty registry already exists for the
    current TF2 setting. Otherwise the registry is rebuilt from scratch. The
    new dict is assembled locally and published to ``LOCAL`` only once it is
    complete. A failure part-way through therefore never leaves behind a
    partially filled registry that later calls would accept as finished.
    """
    # Query the TF2 flag once so the recorded flag always matches the branch
    # (V1 vs. V2) used to fill the registry below.
    v2_enabled = tf.__internal__.tf2.enabled()

    # getattr defaults cover threads that have never built a registry.
    if (
        getattr(LOCAL, "ALL_OBJECTS", None)
        and getattr(LOCAL, "GENERATED_WITH_V2", None) == v2_enabled
    ):
        # Objects dict is already generated for the proper TF version:
        # do nothing.
        return

    all_objects = {}

    # Compatibility aliases (need to exist in both V1 and V2).
    all_objects.update(
        {
            "ConstantV2": initializers.Constant,
            "GlorotNormalV2": initializers.GlorotNormal,
            "GlorotUniformV2": initializers.GlorotUniform,
            "HeNormalV2": initializers.HeNormal,
            "HeUniformV2": initializers.HeUniform,
            "IdentityV2": initializers.Identity,
            "LecunNormalV2": initializers.LecunNormal,
            "LecunUniformV2": initializers.LecunUniform,
            "OnesV2": initializers.Ones,
            "OrthogonalV2": initializers.Orthogonal,
            "RandomNormalV2": initializers.RandomNormal,
            "RandomUniformV2": initializers.RandomUniform,
            "TruncatedNormalV2": initializers.TruncatedNormal,
            "VarianceScalingV2": initializers.VarianceScaling,
            "ZerosV2": initializers.Zeros,
        }
    )

    # Out of an abundance of caution we also include these aliases that have
    # a non-zero probability of having been included in saved configs in the
    # past.
    all_objects.update(
        {
            "glorot_normalV2": initializers.GlorotNormal,
            "glorot_uniformV2": initializers.GlorotUniform,
            "he_normalV2": initializers.HeNormal,
            "he_uniformV2": initializers.HeUniform,
            "lecun_normalV2": initializers.LecunNormal,
            "lecun_uniformV2": initializers.LecunUniform,
        }
    )

    if v2_enabled:
        # For V2, entries are generated automatically from initializers.py:
        # every class in that module that subclasses `Initializer`.
        base_cls = initializers.Initializer
        builtin_objs = {}
        generic_utils.populate_dict_with_module_objects(
            builtin_objs,
            [initializers],
            obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls),
        )
    else:
        # V1 initializers.
        builtin_objs = {
            "Constant": tf.compat.v1.constant_initializer,
            "GlorotNormal": tf.compat.v1.glorot_normal_initializer,
            "GlorotUniform": tf.compat.v1.glorot_uniform_initializer,
            "Identity": tf.compat.v1.initializers.identity,
            "Ones": tf.compat.v1.ones_initializer,
            "Orthogonal": tf.compat.v1.orthogonal_initializer,
            "VarianceScaling": tf.compat.v1.variance_scaling_initializer,
            "Zeros": tf.compat.v1.zeros_initializer,
            "HeNormal": initializers_v1.HeNormal,
            "HeUniform": initializers_v1.HeUniform,
            "LecunNormal": initializers_v1.LecunNormal,
            "LecunUniform": initializers_v1.LecunUniform,
            "RandomNormal": initializers_v1.RandomNormal,
            "RandomUniform": initializers_v1.RandomUniform,
            "TruncatedNormal": initializers_v1.TruncatedNormal,
        }

    for key, value in builtin_objs.items():
        all_objects[key] = value
        # Functional aliases; the compatibility aliases below rely on names
        # such as "random_normal" being produced here.
        all_objects[generic_utils.to_snake_case(key)] = value

    # More compatibility aliases.
    all_objects["normal"] = all_objects["random_normal"]
    all_objects["uniform"] = all_objects["random_uniform"]
    all_objects["one"] = all_objects["ones"]
    all_objects["zero"] = all_objects["zeros"]

    # Publish only the fully built registry, together with the flag it was
    # built for.
    LOCAL.ALL_OBJECTS = all_objects
    LOCAL.GENERATED_WITH_V2 = v2_enabled

126 

127 

# For backwards compatibility, we populate this module's namespace with the
# objects from ALL_OBJECTS. We make no guarantees as to whether these objects
# will be using their correct (V1 vs. V2) version. The namespace is filled only
# once, at import time. A later change of the TF2 flag causes the registry to
# be rebuilt on the next deserialize() call, but these module-level names are
# not refreshed.
populate_deserializable_objects()
globals().update(LOCAL.ALL_OBJECTS)

133 

134# Utility functions 

135 

136 

@keras_export("keras.initializers.serialize")
def serialize(initializer, use_legacy_format=False):
    """Serializes an initializer into a config.

    Args:
        initializer: The initializer to serialize.
        use_legacy_format: If truthy, serialize with
            `legacy_serialization.serialize_keras_object`; otherwise with
            `serialization_lib.serialize_keras_object`.

    Returns:
        Whatever the selected serializer returns for `initializer`.
    """
    serializer = (
        legacy_serialization.serialize_keras_object
        if use_legacy_format
        else serialization_lib.serialize_keras_object
    )
    return serializer(initializer)

143 

144 

@keras_export("keras.initializers.deserialize")
def deserialize(config, custom_objects=None, use_legacy_format=False):
    """Return an `Initializer` object from its config.

    The built-in registry is (re)built for the calling thread first, so
    lookups always see entries matching the current TF2 setting.

    Args:
        config: The serialized initializer to resolve.
        custom_objects: Optional mapping of extra names to objects, forwarded
            to the underlying deserializer.
        use_legacy_format: If truthy, use
            `legacy_serialization.deserialize_keras_object`; otherwise
            `serialization_lib.deserialize_keras_object`.

    Returns:
        Whatever the selected deserializer returns for `config`.
    """
    populate_deserializable_objects()

    if use_legacy_format:
        deserializer = legacy_serialization.deserialize_keras_object
    else:
        deserializer = serialization_lib.deserialize_keras_object

    return deserializer(
        config,
        module_objects=LOCAL.ALL_OBJECTS,
        custom_objects=custom_objects,
        printable_module_name="initializer",
    )

163 

164 

@keras_export("keras.initializers.get")
def get(identifier):
    """Retrieve a Keras initializer by the identifier.

    The `identifier` may be the string name of an initializer function or
    class (case-sensitively).

    >>> identifier = 'Ones'
    >>> tf.keras.initializers.get(identifier)
    <...initializers.initializers.Ones...>

    You can also specify the `config` of the initializer by passing a dict
    containing `class_name` and `config` as the identifier. Note that the
    `class_name` must map to an `Initializer` class.

    >>> cfg = {'class_name': 'Ones', 'config': {}}
    >>> tf.keras.initializers.get(cfg)
    <...initializers.initializers.Ones...>

    In the case that the `identifier` is a class, this method will return a
    new instance of the class by calling its constructor with no arguments.
    Any other callable is returned unchanged.

    Args:
        identifier: None, a string name, a config dict, a class, or any other
            callable.

    Returns:
        An initializer based on the input identifier, or None if
        `identifier` is None.

    Raises:
        ValueError: If the input identifier is not a supported type. Errors
            raised by `deserialize` for malformed configs propagate unchanged.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        # A dict without a "module" key is routed to the legacy deserializer
        # (presumably only the newer serialization_lib format emits "module";
        # confirm against serialization_lib).
        use_legacy_format = "module" not in identifier
        return deserialize(identifier, use_legacy_format=use_legacy_format)
    elif isinstance(identifier, str):
        # Wrap the bare name into a minimal config and reuse the dict path.
        config = {"class_name": str(identifier), "config": {}}
        return get(config)
    elif callable(identifier):
        # Classes are instantiated with their default constructor; other
        # callables are returned as-is.
        if inspect.isclass(identifier):
            identifier = identifier()
        return identifier
    else:
        raise ValueError(
            "Could not interpret initializer identifier: " + str(identifier)
        )

215