Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/serialization.py: 48%

105 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Layer serialization/deserialization functions.""" 

16 

17import threading 

18 

19import tensorflow.compat.v2 as tf 

20 

21from keras.src.engine import base_layer 

22from keras.src.engine import input_layer 

23from keras.src.engine import input_spec 

24from keras.src.layers import activation 

25from keras.src.layers import attention 

26from keras.src.layers import convolutional 

27from keras.src.layers import core 

28from keras.src.layers import locally_connected 

29from keras.src.layers import merging 

30from keras.src.layers import pooling 

31from keras.src.layers import regularization 

32from keras.src.layers import reshaping 

33from keras.src.layers import rnn 

34from keras.src.layers.normalization import batch_normalization 

35from keras.src.layers.normalization import batch_normalization_v1 

36from keras.src.layers.normalization import group_normalization 

37from keras.src.layers.normalization import layer_normalization 

38from keras.src.layers.normalization import unit_normalization 

39from keras.src.layers.preprocessing import category_encoding 

40from keras.src.layers.preprocessing import discretization 

41from keras.src.layers.preprocessing import hashed_crossing 

42from keras.src.layers.preprocessing import hashing 

43from keras.src.layers.preprocessing import image_preprocessing 

44from keras.src.layers.preprocessing import integer_lookup 

45from keras.src.layers.preprocessing import ( 

46 normalization as preprocessing_normalization, 

47) 

48from keras.src.layers.preprocessing import string_lookup 

49from keras.src.layers.preprocessing import text_vectorization 

50from keras.src.layers.rnn import cell_wrappers 

51from keras.src.layers.rnn import gru 

52from keras.src.layers.rnn import lstm 

53from keras.src.metrics import base_metric 

54from keras.src.saving import serialization_lib 

55from keras.src.saving.legacy import serialization as legacy_serialization 

56from keras.src.saving.legacy.saved_model import json_utils 

57from keras.src.utils import generic_utils 

58from keras.src.utils import tf_inspect as inspect 

59 

60# isort: off 

61from tensorflow.python.util.tf_export import keras_export 

62 

# Modules that `populate_deserializable_objects()` scans for subclasses of
# `base_layer.Layer`; matching classes are added to `LOCAL.ALL_OBJECTS`.
# NOTE(review): presumably each class is keyed by its class name -- confirm
# against `generic_utils.populate_dict_with_module_objects`.
ALL_MODULES = (
    base_layer,
    input_layer,
    activation,
    attention,
    convolutional,
    core,
    locally_connected,
    merging,
    batch_normalization_v1,
    group_normalization,
    layer_normalization,
    unit_normalization,
    pooling,
    image_preprocessing,
    regularization,
    reshaping,
    rnn,
    hashing,
    hashed_crossing,
    category_encoding,
    discretization,
    integer_lookup,
    preprocessing_normalization,
    string_lookup,
    text_vectorization,
)
# Modules scanned after `ALL_MODULES`, and only when
# `tf.__internal__.tf2.enabled()` is true, so that their classes overwrite the
# V1 entries already registered from `ALL_MODULES`.
ALL_V2_MODULES = (
    batch_normalization,
    layer_normalization,
    cell_wrappers,
    gru,
    lstm,
)
# ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
# thread-local to avoid concurrent mutations.
# `populate_deserializable_objects()` stores two attributes on this object:
# `ALL_OBJECTS` (the registry dict) and `GENERATED_WITH_V2` (the value of
# `tf.__internal__.tf2.enabled()` when the registry was last built).
LOCAL = threading.local()

100 

101 

def populate_deserializable_objects():
    """Fills the thread-local `LOCAL.ALL_OBJECTS` dict with built-in layers.

    The dict receives, in order: the `Layer` subclasses found in
    `ALL_MODULES` (then those in `ALL_V2_MODULES` when TF2 behavior is
    enabled), the `BatchNormalizationV1`/`BatchNormalizationV2` aliases,
    `Input`, `InputSpec`, the model classes, the feature-column layers and
    the function versions of the merging layers.

    The dict is rebuilt only when it is empty or was generated under a
    different `tf.__internal__.tf2.enabled()` value than the current one;
    otherwise the call returns without doing anything.
    """
    # First call on this thread: create the attributes read below.
    if not hasattr(LOCAL, "ALL_OBJECTS"):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    v2_enabled = tf.__internal__.tf2.enabled()
    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
        # Already generated for the proper TF version: nothing to do.
        return

    registry = {}
    LOCAL.ALL_OBJECTS = registry
    LOCAL.GENERATED_WITH_V2 = v2_enabled

    layer_cls = base_layer.Layer

    def _is_layer_class(candidate):
        # Keep only classes deriving from the base `Layer`.
        return inspect.isclass(candidate) and issubclass(candidate, layer_cls)

    generic_utils.populate_dict_with_module_objects(
        registry, ALL_MODULES, obj_filter=_is_layer_class
    )
    if v2_enabled:
        # Overwrite certain V1 objects with their V2 versions.
        generic_utils.populate_dict_with_module_objects(
            registry, ALL_V2_MODULES, obj_filter=_is_layer_class
        )

    # Backward-compatibility aliases: TF 1.13 used "BatchNormalizationV1" and
    # "BatchNormalizationV2" as the class names of the v1 and v2 versions of
    # BatchNormalization. Map both to their canonical classes.
    registry.update(
        {
            "BatchNormalizationV1": batch_normalization_v1.BatchNormalization,
            "BatchNormalizationV2": batch_normalization.BatchNormalization,
        }
    )

    # Imported here rather than at module level to prevent circular
    # dependencies.
    from keras.src import models
    from keras.src.feature_column.sequence_feature_column import (
        SequenceFeatures,
    )
    from keras.src.premade_models.linear import LinearModel
    from keras.src.premade_models.wide_deep import WideDeepModel

    registry.update(
        {
            "Input": input_layer.Input,
            "InputSpec": input_spec.InputSpec,
            "Functional": models.Functional,
            "Model": models.Model,
            "SequenceFeatures": SequenceFeatures,
            "Sequential": models.Sequential,
            "LinearModel": LinearModel,
            "WideDeepModel": WideDeepModel,
        }
    )

    # `DenseFeatures` comes from a different module depending on TF2 mode.
    if v2_enabled:
        from keras.src.feature_column.dense_features_v2 import DenseFeatures
    else:
        from keras.src.feature_column.dense_features import DenseFeatures
    registry["DenseFeatures"] = DenseFeatures

    # Merging layers, function versions.
    for fn_name in (
        "add",
        "subtract",
        "multiply",
        "average",
        "maximum",
        "minimum",
        "concatenate",
        "dot",
    ):
        registry[fn_name] = getattr(merging, fn_name)

189 

190 

@keras_export("keras.layers.serialize")
def serialize(layer, use_legacy_format=False):
    """Converts a `Layer` object into a JSON-compatible representation.

    Args:
        layer: The `Layer` object to serialize.
        use_legacy_format: Boolean. When true, the object is serialized with
            `legacy_serialization.serialize_keras_object` instead of
            `serialization_lib.serialize_keras_object`. Defaults to `False`.

    Returns:
        A JSON-serializable dict representing the object's config.

    Raises:
        ValueError: If `layer` is a `Metric` instance.

    Example:

    ```python
    from pprint import pprint
    model = tf.keras.models.Sequential()
    model.add(tf.keras.Input(shape=(16,)))
    model.add(tf.keras.layers.Dense(32, activation='relu'))

    pprint(tf.keras.layers.serialize(model))
    # prints the configuration of the model, as a dict.
    ```
    """
    # Metrics have their own serialization API and are rejected here.
    if isinstance(layer, base_metric.Metric):
        message = (
            f"Cannot serialize {layer} since it is a metric. "
            "Please use the `keras.metrics.serialize()` and "
            "`keras.metrics.deserialize()` APIs to serialize "
            "and deserialize metrics."
        )
        raise ValueError(message)

    serializer = (
        legacy_serialization.serialize_keras_object
        if use_legacy_format
        else serialization_lib.serialize_keras_object
    )
    return serializer(layer)

223 

224 

@keras_export("keras.layers.deserialize")
def deserialize(config, custom_objects=None, use_legacy_format=False):
    """Builds a layer from a config dictionary.

    Args:
        config: dict of the form {'class_name': str, 'config': dict}
        custom_objects: dict mapping class names (or function names) of
            custom (non-Keras) objects to class/functions
        use_legacy_format: Boolean. When true, the config is handled by
            `legacy_serialization.deserialize_keras_object` instead of
            `serialization_lib.deserialize_keras_object`. Defaults to
            `False`.

    Returns:
        Layer instance (may be Model, Sequential, Network, Layer...)

    Raises:
        ValueError: If `config` is empty.

    Example:

    ```python
    # Configuration of Dense(32, activation='relu')
    config = {
      'class_name': 'Dense',
      'config': {
        'activation': 'relu',
        'activity_regularizer': None,
        'bias_constraint': None,
        'bias_initializer': {'class_name': 'Zeros', 'config': {}},
        'bias_regularizer': None,
        'dtype': 'float32',
        'kernel_constraint': None,
        'kernel_initializer': {'class_name': 'GlorotUniform',
                               'config': {'seed': None}},
        'kernel_regularizer': None,
        'name': 'dense',
        'trainable': True,
        'units': 32,
        'use_bias': True
      }
    }
    dense_layer = tf.keras.layers.deserialize(config)
    ```
    """
    # The registry is refreshed before validation, as in the original flow.
    populate_deserializable_objects()

    if not config:
        message = f"Cannot deserialize empty config. Received: config={config}"
        raise ValueError(message)

    deserializer = (
        legacy_serialization.deserialize_keras_object
        if use_legacy_format
        else serialization_lib.deserialize_keras_object
    )
    return deserializer(
        config,
        module_objects=LOCAL.ALL_OBJECTS,
        custom_objects=custom_objects,
        printable_module_name="layer",
    )

282 

283 

def get_builtin_layer(class_name):
    """Returns the object registered under `class_name`, else `None`.

    Args:
        class_name: Name to look up in the thread-local registry of built-in
            layers.

    Returns:
        The registered class (or function), or `None` if `class_name` is not
        registered.
    """
    # Always go through `populate_deserializable_objects()` rather than only
    # when the registry attribute is missing. It returns immediately when the
    # registry is already built for the current
    # `tf.__internal__.tf2.enabled()` value, and rebuilds it otherwise. This
    # keeps the lookup consistent with `deserialize()` and
    # `deserialize_from_json()` and avoids using a registry generated under a
    # different TF2 setting.
    populate_deserializable_objects()
    return LOCAL.ALL_OBJECTS.get(class_name)

289 

290 

def deserialize_from_json(json_string, custom_objects=None):
    """Builds a layer from its JSON string representation.

    Args:
        json_string: JSON string to decode.
        custom_objects: Optional custom objects, forwarded both to the JSON
            decoder and to `deserialize()`.

    Returns:
        The result of calling `deserialize()` on the decoded config.
    """
    # Make sure the built-in registry exists before decoding.
    populate_deserializable_objects()
    builtin_objects = LOCAL.ALL_OBJECTS
    decoded_config = json_utils.decode_and_deserialize(
        json_string,
        module_objects=builtin_objects,
        custom_objects=custom_objects,
    )
    return deserialize(decoded_config, custom_objects=custom_objects)

300