Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/initializers/__init__.py: 71%

77 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras initializer serialization / deserialization.""" 

16 

17import threading 

18 

19from tensorflow.python import tf2 

20from tensorflow.python.keras.initializers import initializers_v1 

21from tensorflow.python.keras.initializers import initializers_v2 

22from tensorflow.python.keras.utils import generic_utils 

23from tensorflow.python.keras.utils import tf_inspect as inspect 

24from tensorflow.python.ops import init_ops 

25from tensorflow.python.util.tf_export import keras_export 

26 

27 

# LOCAL.ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
# thread-local to avoid concurrent mutations.
# populate_deserializable_objects() lazily fills, per calling thread,
# LOCAL.ALL_OBJECTS (name -> initializer class) and LOCAL.GENERATED_WITH_V2
# (the tf2.enabled() value that mapping was built for).
LOCAL = threading.local()

31 

32 

def populate_deserializable_objects():
  """Populates dict ALL_OBJECTS with every built-in initializer.

  The mapping lives on the thread-local `LOCAL` object, next to
  `LOCAL.GENERATED_WITH_V2`, which records the `tf2.enabled()` value the
  mapping was built for. The mapping is rebuilt when it is empty (e.g. on the
  first call in a thread) or when `tf2.enabled()` no longer matches the
  recorded value; otherwise this function does nothing.

  In both V1 and V2 modes the mapping contains:
    * the `...V2` compatibility aliases,
    * each built-in class name plus its snake_case functional alias,
    * the short aliases `normal`, `uniform`, `one` and `zero`.
  """
  if not hasattr(LOCAL, 'ALL_OBJECTS'):
    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = None

  # Query the flag once, so the cache check, the branch taken below and the
  # recorded GENERATED_WITH_V2 value can never disagree.
  v2_enabled = tf2.enabled()
  if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
    # Objects dict is already generated for the proper TF version:
    # do nothing.
    return

  # Build into a local dict and publish it only once it is complete. If
  # anything below raises, LOCAL keeps its previous state. Otherwise a
  # partially filled dict would be left behind, and the early-return check
  # above would treat it as valid on every later call.
  all_objects = {
      # Compatibility aliases (need to exist in both V1 and V2).
      'ConstantV2': initializers_v2.Constant,
      'GlorotNormalV2': initializers_v2.GlorotNormal,
      'GlorotUniformV2': initializers_v2.GlorotUniform,
      'HeNormalV2': initializers_v2.HeNormal,
      'HeUniformV2': initializers_v2.HeUniform,
      'IdentityV2': initializers_v2.Identity,
      'LecunNormalV2': initializers_v2.LecunNormal,
      'LecunUniformV2': initializers_v2.LecunUniform,
      'OnesV2': initializers_v2.Ones,
      'OrthogonalV2': initializers_v2.Orthogonal,
      'RandomNormalV2': initializers_v2.RandomNormal,
      'RandomUniformV2': initializers_v2.RandomUniform,
      'TruncatedNormalV2': initializers_v2.TruncatedNormal,
      'VarianceScalingV2': initializers_v2.VarianceScaling,
      'ZerosV2': initializers_v2.Zeros,
      # Out of an abundance of caution we also include these aliases that
      # have a non-zero probability of having been included in saved configs
      # in the past.
      'glorot_normalV2': initializers_v2.GlorotNormal,
      'glorot_uniformV2': initializers_v2.GlorotUniform,
      'he_normalV2': initializers_v2.HeNormal,
      'he_uniformV2': initializers_v2.HeUniform,
      'lecun_normalV2': initializers_v2.LecunNormal,
      'lecun_uniformV2': initializers_v2.LecunUniform,
  }

  if v2_enabled:
    # For V2, entries are generated automatically based on the content of
    # initializers_v2.py.
    builtin_objs = {}
    base_cls = initializers_v2.Initializer
    generic_utils.populate_dict_with_module_objects(
        builtin_objs,
        [initializers_v2],
        obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls))
  else:
    # V1 initializers.
    builtin_objs = {
        'Constant': init_ops.Constant,
        'GlorotNormal': init_ops.GlorotNormal,
        'GlorotUniform': init_ops.GlorotUniform,
        'Identity': init_ops.Identity,
        'Ones': init_ops.Ones,
        'Orthogonal': init_ops.Orthogonal,
        'VarianceScaling': init_ops.VarianceScaling,
        'Zeros': init_ops.Zeros,
        'HeNormal': initializers_v1.HeNormal,
        'HeUniform': initializers_v1.HeUniform,
        'LecunNormal': initializers_v1.LecunNormal,
        'LecunUniform': initializers_v1.LecunUniform,
        'RandomNormal': initializers_v1.RandomNormal,
        'RandomUniform': initializers_v1.RandomUniform,
        'TruncatedNormal': initializers_v1.TruncatedNormal,
    }

  for key, value in builtin_objs.items():
    all_objects[key] = value
    # Functional aliases.
    all_objects[generic_utils.to_snake_case(key)] = value

  # More compatibility aliases.
  all_objects['normal'] = all_objects['random_normal']
  all_objects['uniform'] = all_objects['random_uniform']
  all_objects['one'] = all_objects['ones']
  all_objects['zero'] = all_objects['zeros']

  LOCAL.ALL_OBJECTS = all_objects
  LOCAL.GENERATED_WITH_V2 = v2_enabled

117 

118 

# For backwards compatibility, we populate this file with the objects
# from ALL_OBJECTS. We make no guarantees as to whether these objects will
# be using their correct version.
# NOTE: this is a one-time snapshot taken at import time in the importing
# thread. Unlike the lookup table used by deserialize(), which re-runs
# populate_deserializable_objects() on every call, these module globals are
# not refreshed if tf2.enabled() changes afterwards.
populate_deserializable_objects()
globals().update(LOCAL.ALL_OBJECTS)

124 

125# Utility functions 

126 

127 

@keras_export('keras.initializers.serialize')
def serialize(initializer):
  """Serializes a Keras initializer.

  Args:
    initializer: The initializer object to serialize.

  Returns:
    The serialized form produced by `generic_utils.serialize_keras_object`.
  """
  serialized = generic_utils.serialize_keras_object(initializer)
  return serialized

131 

132 

@keras_export('keras.initializers.deserialize')
def deserialize(config, custom_objects=None):
  """Builds an `Initializer` object from its name or serialized config.

  Args:
    config: The initializer name or serialized config to resolve.
    custom_objects: Optional extra name-to-object mapping, forwarded to
      `generic_utils.deserialize_keras_object`.

  Returns:
    The object returned by `generic_utils.deserialize_keras_object` when
    resolving `config` against the built-in initializers.
  """
  # Refresh this thread's lookup table so it matches the active TF version
  # before resolving any names against it.
  populate_deserializable_objects()
  builtin_objects = LOCAL.ALL_OBJECTS
  return generic_utils.deserialize_keras_object(
      config,
      module_objects=builtin_objects,
      custom_objects=custom_objects,
      printable_module_name='initializer')

142 

143 

@keras_export('keras.initializers.get')
def get(identifier):
  """Retrieve a Keras initializer by the identifier.

  The `identifier` may be the string name of an initializer function or class
  (case-sensitively).

  >>> identifier = 'Ones'
  >>> tf.keras.initializers.get(identifier)
  <...keras.initializers.initializers_v2.Ones...>

  You can also specify `config` of the initializer to this function by passing
  a dict containing `class_name` and `config` as an identifier. Also note that
  the `class_name` must map to an `Initializer` class.

  >>> cfg = {'class_name': 'Ones', 'config': {}}
  >>> tf.keras.initializers.get(cfg)
  <...keras.initializers.initializers_v2.Ones...>

  In the case that the `identifier` is a class, this method will return a new
  instance of the class by its constructor.

  Args:
    identifier: `None`, a string name, a config dict, an initializer class,
      or any other callable (such as an initializer instance).

  Returns:
    `None` if `identifier` is `None`. Otherwise, an initializer instance based
    on the input identifier. A non-class callable is returned unchanged.

  Raises:
    ValueError: If the input identifier is not a supported type. String and
      dict identifiers are passed to `deserialize`, and any error it raises
      (e.g. for a bad format) propagates to the caller.
  """
  if identifier is None:
    return None
  if isinstance(identifier, (dict, str)):
    # Both names and serialized configs are resolved by `deserialize`.
    return deserialize(identifier)
  if callable(identifier):
    # Classes are instantiated with no constructor arguments. Any other
    # callable (presumably an initializer instance or function) is returned
    # as is.
    if inspect.isclass(identifier):
      identifier = identifier()
    return identifier
  raise ValueError('Could not interpret initializer identifier: ' +
                   str(identifier))