Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/object_registration.py: 57%

51 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Python utilities required by Keras.""" 

16 

17import inspect 

18import threading 

19 

20# isort: off 

21from tensorflow.python.util.tf_export import keras_export 

22 

# Process-wide registries filled in by `register_keras_serializable`:
# the first maps a registered "package>name" key to its object, the second
# maps the object back to that key.
_GLOBAL_CUSTOM_OBJECTS = dict()
_GLOBAL_CUSTOM_NAMES = dict()
# Per-thread overrides: `CustomObjectScope` temporarily stores its entries in
# this object's attribute dictionary.
_THREAD_LOCAL_CUSTOM_OBJECTS = threading.local()

27 

28 

@keras_export(
    "keras.saving.custom_object_scope",
    "keras.utils.custom_object_scope",
    "keras.utils.CustomObjectScope",
)
class CustomObjectScope:
    """Exposes custom classes/functions to Keras deserialization internals.

    Under a scope `with custom_object_scope(objects_dict)`, Keras methods such
    as `tf.keras.models.load_model` or `tf.keras.models.model_from_config`
    will be able to deserialize any custom object referenced by a
    saved config (e.g. a custom layer or metric).

    The entries are stored per thread and are consulted by
    `get_registered_object` before globally registered objects. On exit, the
    thread's scoped custom objects are restored to exactly what they were on
    the matching entry. The same instance may be entered again while it is
    already active (nested use); each exit restores the state captured by its
    own entry.

    Example:

    Consider a custom regularizer `my_regularizer`:

    ```python
    layer = Dense(3, kernel_regularizer=my_regularizer)
    # Config contains a reference to `my_regularizer`
    config = layer.get_config()
    ...
    # Later:
    with custom_object_scope({'my_regularizer': my_regularizer}):
        layer = Dense.from_config(config)
    ```

    Args:
        *args: Dictionary or dictionaries of `{name: object}` pairs. They are
            applied in order, so later dictionaries win on key collisions.
            `None` entries are ignored.
    """

    def __init__(self, *args):
        self.custom_objects = args
        # Snapshot taken by the most recent `__enter__`; kept for backward
        # compatibility. Restoration itself uses `_backups`.
        self.backup = None
        # One snapshot per active `__enter__`, so nested re-entry of the same
        # instance restores every level correctly.
        self._backups = []

    def __enter__(self):
        local_objects = _THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
        backup = local_objects.copy()
        try:
            for objects in self.custom_objects:
                if objects is not None:
                    local_objects.update(objects)
        except BaseException:
            # `__exit__` is not called when `__enter__` raises, so undo any
            # partial update here rather than leaking it into this thread.
            local_objects.clear()
            local_objects.update(backup)
            raise
        self._backups.append(backup)
        self.backup = backup
        return self

    def __exit__(self, *args, **kwargs):
        if not self._backups:
            # Fail before touching the thread-local state.
            raise RuntimeError(
                "CustomObjectScope exited without a matching __enter__()."
            )
        backup = self._backups.pop()
        local_objects = _THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
        local_objects.clear()
        local_objects.update(backup)

73 

74 

@keras_export(
    "keras.saving.get_custom_objects", "keras.utils.get_custom_objects"
)
def get_custom_objects():
    """Retrieves a live reference to the global dictionary of custom objects.

    Objects registered with `register_keras_serializable` are stored here under
    their `"package>name"` key. The returned dictionary is the global registry
    itself, not a copy. Mutating it directly therefore affects later lookups
    made through `get_registered_object`.

    Custom objects set using `custom_object_scope` are not added to the
    global dictionary of custom objects, and will not appear in the returned
    dictionary.

    Example:

    ```python
    get_custom_objects().clear()
    get_custom_objects()['MyObject'] = MyObject
    ```

    Returns:
        Global dictionary mapping registered names to the registered classes
        and functions.
    """
    return _GLOBAL_CUSTOM_OBJECTS

96 

97 

@keras_export(
    "keras.saving.register_keras_serializable",
    "keras.utils.register_keras_serializable",
)
def register_keras_serializable(package="Custom", name=None):
    """Registers an object with the Keras serialization framework.

    This decorator injects the decorated class or function into the Keras custom
    object dictionary, so that it can be serialized and deserialized without
    needing an entry in the user-provided custom object dict. It also records
    the object's registered name, which `get_registered_name` returns as the
    object's serializable string key.

    Note that to be serialized and deserialized, classes must implement the
    `get_config()` method. Functions do not have this requirement.

    The object will be registered under the key `'package>name'`, where `name`
    defaults to the object name if not passed. Registering another object
    under an existing key replaces the earlier entry.

    Example:

    ```python
    # Note that `'my_package'` is used as the `package` argument here, and since
    # the `name` argument is not provided, `'MyDense'` is used as the `name`.
    @keras.saving.register_keras_serializable('my_package')
    class MyDense(keras.layers.Dense):
        pass

    assert keras.saving.get_registered_object('my_package>MyDense') == MyDense
    assert keras.saving.get_registered_name(MyDense) == 'my_package>MyDense'
    ```

    Args:
        package: The package that this class belongs to. This is used for the
            `key` (which is `"package>name"`) to identify the class. Note that
            this is the first argument passed into the decorator.
        name: The name to serialize this class under in this package. If not
            provided or `None`, the class' name will be used (note that this is
            the case when the decorator is used with only one argument, which
            becomes the `package`).

    Returns:
        A decorator that registers the decorated class or function under the
        passed names and returns it unchanged.

    Raises:
        ValueError: (raised by the returned decorator) if the decorated object
            is a class without a `get_config()` method.
    """

    def decorator(arg):
        """Registers a class or function with the Keras serialization framework."""
        class_name = name if name is not None else arg.__name__
        registered_name = package + ">" + class_name

        # Classes must be reconstructible from their config; plain functions
        # are serialized by name only.
        if inspect.isclass(arg) and not hasattr(arg, "get_config"):
            raise ValueError(
                "Cannot register a class that does not have a "
                "get_config() method."
            )

        # Record both directions: key -> object for deserialization and
        # object -> key for serialization.
        _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
        _GLOBAL_CUSTOM_NAMES[arg] = registered_name

        return arg

    return decorator

159 

160 

@keras_export(
    "keras.saving.get_registered_name", "keras.utils.get_registered_name"
)
def get_registered_name(obj):
    """Looks up the string name Keras uses to serialize `obj`.

    Part of the Keras serialization/deserialization machinery. Objects
    registered through `register_keras_serializable` map to their
    `"package>name"` key. Anything else falls back to its plain Python
    `__name__`.

    Args:
        obj: The object to look up.

    Returns:
        The registered name of `obj` if it was registered, otherwise
        `obj.__name__`.
    """
    try:
        return _GLOBAL_CUSTOM_NAMES[obj]
    except KeyError:
        pass
    # Not registered: fall back to the default Python name (kept outside the
    # `except` so no exception context is attached if this lookup fails).
    return obj.__name__

182 

183 

@keras_export(
    "keras.saving.get_registered_object", "keras.utils.get_registered_object"
)
def get_registered_object(name, custom_objects=None, module_objects=None):
    """Resolves `name` to the object registered for it with Keras, if any.

    Used by Keras serialization/deserialization to turn string names back into
    the objects they refer to. Sources are searched in this priority order:

    1. objects exposed on the current thread by `custom_object_scope`;
    2. objects registered globally (e.g. via `register_keras_serializable`);
    3. `custom_objects`, if given and non-empty;
    4. `module_objects`, if given and non-empty.

    Example:

    ```python
    def from_config(cls, config, custom_objects=None):
        if 'my_custom_object_name' in config:
            config['hidden_cls'] = tf.keras.saving.get_registered_object(
                config['my_custom_object_name'], custom_objects=custom_objects)
    ```

    Args:
        name: The name to look up.
        custom_objects: Optional dictionary of custom objects to look the name
            up in. Generally provided by the user.
        module_objects: Optional dictionary of custom objects to look the name
            up in. Generally provided by midlevel library implementers.

    Returns:
        The object (typically an instantiable class) associated with `name`,
        or `None` if no source contains it.
    """
    scoped = _THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
    # The built-in registries always exist, so they are checked unconditionally.
    for registry in (scoped, _GLOBAL_CUSTOM_OBJECTS):
        if name in registry:
            return registry[name]
    # Caller-supplied mappings are optional; empty or `None` ones are skipped.
    for registry in (custom_objects, module_objects):
        if registry and name in registry:
            return registry[name]
    return None

223 

224 

# Aliases
# Lowercase, function-style name for `CustomObjectScope`, so callers can write
# `with custom_object_scope({...}):`.
custom_object_scope = CustomObjectScope

227