Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/serialization.py: 22%

205 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Legacy serialization logic for Keras models.""" 

16 

17import threading 

18import weakref 

19 

20import tensorflow.compat.v2 as tf 

21 

22from keras.src.utils import tf_contextlib 

23from keras.src.utils import tf_inspect 

24 

25# isort: off 

26from tensorflow.python.util.tf_export import keras_export 

27 

# When True, a `NotImplementedError` raised by `get_config()` on a custom
# model or layer is tolerated instead of propagated. It is only switched on
# while saving to SavedModel, where the config is not actually required.
_SKIP_FAILED_SERIALIZATION = False

# Marker key placed in the config produced for a layer that has no defined
# config, so such placeholder configs can be recognised later.
_LAYER_UNDEFINED_CONFIG_KEY = "layer was saved without config"

# Config key under which a unique, per-object ID is stored for shared objects.
#
# With this ID the loader can tell whether a config describes a brand new
# object that must be created, or merely refers back to an object that has
# already been created. Without it, shared objects could not be re-created
# as shared when rebuilding the network.
SHARED_OBJECT_KEY = "shared_object_id"

# Per-thread holders for shared-object state. Their attributes only exist in
# threads that have set them, so they are read through accessor functions
# that supply defaults rather than being initialised here.
SHARED_OBJECT_DISABLED = threading.local()
SHARED_OBJECT_LOADING = threading.local()
SHARED_OBJECT_SAVING = threading.local()

47 

48 

# The thread-local holders only carry attributes in threads that have set
# them, so they cannot be initialised once at import time. Accessor functions
# read them with a fallback default instead.
def _shared_object_disabled():
    """Return whether shared-object handling is disabled for this thread.

    Returns:
        The `disabled` flag stored on `SHARED_OBJECT_DISABLED` for the
        current thread, or `False` if this thread has never set it.
    """
    try:
        return SHARED_OBJECT_DISABLED.disabled
    except AttributeError:
        return False

55 

56 

def _shared_object_loading_scope():
    """Get the current shared object loading scope in a threadsafe manner.

    Returns:
        The loading scope stored on `SHARED_OBJECT_LOADING` for the current
        thread, or a fresh `NoopLoadingScope` if this thread has not
        installed one.
    """
    return getattr(SHARED_OBJECT_LOADING, "scope", NoopLoadingScope())

60 

61 

def _shared_object_saving_scope():
    """Return this thread's active shared-object saving scope, if any.

    Returns:
        The scope stored on `SHARED_OBJECT_SAVING` for the current thread, or
        `None` when no saving scope has been set (or it has been cleared).
    """
    try:
        return SHARED_OBJECT_SAVING.scope
    except AttributeError:
        return None

65 

66 

class DisableSharedObjectScope:
    """A context manager for disabling handling of shared objects.

    Disables shared object handling for both saving and loading.

    Created primarily for use with `clone_model`, which does extra surgery that
    is incompatible with shared objects.

    On exit, the disabled flag and the thread's loading and saving scopes are
    restored to the values they had on entry. Nesting this scope therefore
    does not re-enable shared-object handling while an outer instance is
    still active.
    """

    def __enter__(self):
        # Remember the previous flag instead of assuming it was `False`, so
        # that leaving a nested scope keeps handling disabled for the outer
        # one.
        self._orig_disabled = _shared_object_disabled()
        SHARED_OBJECT_DISABLED.disabled = True
        self._orig_loading_scope = _shared_object_loading_scope()
        self._orig_saving_scope = _shared_object_saving_scope()

    def __exit__(self, *args, **kwargs):
        SHARED_OBJECT_DISABLED.disabled = self._orig_disabled
        SHARED_OBJECT_LOADING.scope = self._orig_loading_scope
        SHARED_OBJECT_SAVING.scope = self._orig_saving_scope

85 

86 

class NoopLoadingScope:
    """Default loading scope that never records or shares anything.

    Stands in when no real loading scope is active, so serialization code
    that does not care about shared objects (e.g. when deserializing a single
    object) can call `get` and `set` unconditionally.
    """

    def get(self, unused_object_id):
        """Always report that nothing was loaded under the given ID."""
        del unused_object_id  # Intentionally ignored.
        return None

    def set(self, object_id, obj):
        """Discard the object; this scope keeps no state."""
        del object_id, obj  # Intentionally ignored.

99 

100 

class SharedObjectLoadingScope:
    """A context manager for keeping track of loaded objects.

    During the deserialization process, we may come across objects that are
    shared across multiple layers. In order to accurately restore the network
    structure to its original state, `SharedObjectLoadingScope` allows us to
    re-use shared objects rather than cloning them.

    Scopes may be nested. Each scope keeps its own ID-to-object table, and on
    exit the loading scope that was active on entry is reinstated for the
    current thread.
    """

    def __init__(self):
        # Initialised here as well as in `__enter__`, so `get`/`set` do not
        # raise `AttributeError` when used on a scope that was never entered.
        self._obj_ids_to_obj = {}
        self._previous_scope = None
        self._installed = False

    def __enter__(self):
        if _shared_object_disabled():
            return NoopLoadingScope()

        # Remember whichever scope was active so it can be restored on exit.
        # Otherwise leaving a nested scope would drop the outer scope's
        # tracking for the rest of its lifetime.
        self._previous_scope = _shared_object_loading_scope()
        self._installed = True
        SHARED_OBJECT_LOADING.scope = self
        self._obj_ids_to_obj = {}
        return self

    def get(self, object_id):
        """Given a shared object ID, returns a previously instantiated object.

        Args:
            object_id: shared object ID to use when attempting to find
                already-loaded object.

        Returns:
            The object, if we've seen this ID before. Else, `None`.
        """
        # Explicitly check for `None` internally to make external calling
        # code a bit cleaner.
        if object_id is None:
            return None
        return self._obj_ids_to_obj.get(object_id)

    def set(self, object_id, obj):
        """Stores an instantiated object for future lookup and sharing.

        Calls with an `object_id` of `None` are ignored.
        """
        if object_id is None:
            return
        self._obj_ids_to_obj[object_id] = obj

    def __exit__(self, *args, **kwargs):
        if self._installed:
            SHARED_OBJECT_LOADING.scope = self._previous_scope
            self._installed = False
        else:
            # This scope never installed itself, because shared-object
            # handling was disabled on entry. Keep the historical behaviour
            # of resetting the thread to a no-op scope.
            SHARED_OBJECT_LOADING.scope = NoopLoadingScope()

144 

145 

class SharedObjectConfig(dict):
    """A `dict` config that counts how many times it has been referenced.

    When the same object is serialized more than once, the same
    `SharedObjectConfig` is handed back and its reference count is bumped.
    On the second reference the object's shared ID is written into the
    config under `SHARED_OBJECT_KEY`, allowing proper shared-object
    reconstruction at load time.

    Subclassing `collections.UserDict` or implementing a `Mapping` would
    normally be preferable. However, Python's `json` encoder does not support
    `Mapping`s, and JSON compatibility is essential for serialization.
    Subclassing `dict` is safe here because no core `dict` method is
    overridden; this class only adds reference counting.

    Args:
        base_config: Initial contents, passed to the `dict` constructor.
        object_id: ID written under `SHARED_OBJECT_KEY` once the config is
            referenced a second time.
        **kwargs: Extra entries forwarded to the `dict` constructor.
    """

    def __init__(self, base_config, object_id, **kwargs):
        super().__init__(base_config, **kwargs)
        self.object_id = object_id
        self.ref_count = 1

    def increment_ref_count(self):
        """Record one more reference to this config.

        The shared-object ID is attached only when the count goes from one
        to two, i.e. when sharing is actually observed. Omitting it from
        singly-referenced configs makes backwards-compatibility breakage less
        likely.
        """
        self.ref_count += 1
        if self.ref_count == 2:
            self[SHARED_OBJECT_KEY] = self.object_id

176 

177 

class SharedObjectSavingScope:
    """Keeps track of shared object configs when serializing.

    While the scope is active on the current thread, configs created for
    objects are remembered. Asking for an object's config again returns the
    existing `SharedObjectConfig` and bumps its reference count. Scopes
    opened inside an already-active scope defer to the outermost one.
    """

    def __enter__(self):
        if _shared_object_disabled():
            return None

        outer_scope = _shared_object_saving_scope()
        if outer_scope is not None:
            # Serialization can be triggered from many layers for many
            # reasons, so a saving scope may be opened within another. There
            # is no use case (yet) for distinct nested scopes, so the inner
            # one defers to the outermost scope and leaves it installed.
            self._passthrough = True
            return outer_scope

        self._passthrough = False
        SHARED_OBJECT_SAVING.scope = self
        # Weak keys, so recording an object here does not keep it alive.
        self._shared_objects_config = weakref.WeakKeyDictionary()
        self._next_id = 0
        return self

    def get_config(self, obj):
        """Return the `SharedObjectConfig` already recorded for `obj`, if any.

        A hit counts as another reference: the config's reference count is
        incremented before it is returned.

        Args:
            obj: The object whose `SharedObjectConfig` should be looked up.

        Returns:
            The previously created `SharedObjectConfig` for `obj`, or `None`
            if `obj` has not been seen or cannot be used as a key.
        """
        try:
            found = self._shared_objects_config[obj]
        except (TypeError, KeyError):
            # `TypeError` is raised for objects that cannot be used as keys
            # (e.g. a subclass of `AbstractBaseClass` that has not overridden
            # `__hash__`). Such objects simply get no shared-object support.
            return None
        found.increment_ref_count()
        return found

    def create_config(self, base_config, obj):
        """Create, record and return a new `SharedObjectConfig` for `obj`."""
        new_config = SharedObjectConfig(base_config, self._next_id)
        self._next_id += 1
        try:
            self._shared_objects_config[obj] = new_config
        except TypeError:
            # Objects that cannot be used as keys are serialized without
            # shared-object support; the new config is still returned.
            pass
        return new_config

    def __exit__(self, *args, **kwargs):
        # Pass-through (nested) scopes leave the outer scope installed. Any
        # other scope, including one entered while handling was disabled
        # (which never sets `_passthrough`), clears the thread-local slot.
        if getattr(self, "_passthrough", False):
            return
        SHARED_OBJECT_SAVING.scope = None

242 

243 

def serialize_keras_class_and_config(
    cls_name, cls_config, obj=None, shared_object_id=None
):
    """Returns the serialization of the class with the given config.

    Args:
        cls_name: Registered name of the object's class.
        cls_config: The object's config.
        obj: Optional object described by the config. When it is given and a
            `SharedObjectSavingScope` is active, it is used to detect repeated
            serialization of the same object.
        shared_object_id: Optional shared-object ID to keep. Some load paths
            call this function with a config that already has one.

    Returns:
        A plain `dict` with `class_name` and `config` keys (plus the shared
        ID when given). When a saving scope is active and `obj` is provided,
        the result is instead the scope's `SharedObjectConfig` for `obj`:
        either newly created from that dict, or the one recorded earlier
        (whose reference count is then bumped).
    """
    serialized = {"class_name": cls_name, "config": cls_config}
    if shared_object_id is not None:
        serialized[SHARED_OBJECT_KEY] = shared_object_id

    if obj is None:
        return serialized
    saving_scope = _shared_object_saving_scope()
    if saving_scope is None:
        return serialized

    # Reusing the config recorded for an already-seen object stores an extra
    # ID field in it, so the shared relationship can be re-created on load.
    existing = saving_scope.get_config(obj)
    if existing is not None:
        return existing
    return saving_scope.create_config(serialized, obj)

267 

268 

@tf_contextlib.contextmanager
def skip_failed_serialization():
    """Tolerate unimplemented `get_config()` methods while active.

    Inside this context, `serialize_keras_object` returns a placeholder
    config instead of re-raising a `NotImplementedError` from `get_config()`.
    The previous setting is restored on exit, even if the body raises.

    Note: the flag is a module-level global, not a thread-local.
    """
    global _SKIP_FAILED_SERIALIZATION
    saved_flag = _SKIP_FAILED_SERIALIZATION
    _SKIP_FAILED_SERIALIZATION = True
    try:
        yield
    finally:
        _SKIP_FAILED_SERIALIZATION = saved_flag

278 

279 

@keras_export("keras.utils.legacy.serialize_keras_object")
def serialize_keras_object(instance):
    """Serialize a Keras object into a JSON-compatible representation.

    Calls to `serialize_keras_object` while underneath the
    `SharedObjectSavingScope` context manager will cause any objects re-used
    across multiple layers to be saved with a special shared object ID. This
    allows the network to be re-created properly during deserialization.

    Args:
        instance: The object to serialize.

    Returns:
        A dict-like, JSON-compatible representation of the object's config.
        For objects without `get_config()` but with a `__name__` (e.g.
        functions), the registered name string is returned. `None` is
        returned for `None`.

    Raises:
        NotImplementedError: If `get_config()` is not implemented and
            failed serializations are not being skipped.
        ValueError: If the object has neither `get_config()` nor `__name__`.
    """
    from keras.src.saving import object_registration

    _, instance = tf.__internal__.decorator.unwrap(instance)
    if instance is None:
        return None

    if hasattr(instance, "get_config"):
        # Look up the registered name once; it is used for both the
        # placeholder config and the real one.
        name = object_registration.get_registered_name(instance.__class__)
        try:
            config = instance.get_config()
        except NotImplementedError:
            if _SKIP_FAILED_SERIALIZATION:
                return serialize_keras_class_and_config(
                    name, {_LAYER_UNDEFINED_CONFIG_KEY: True}
                )
            raise

        serialization_config = {}
        for key, item in config.items():
            if isinstance(item, str):
                serialization_config[key] = item
                continue

            # Any object of a different type needs to be converted to string
            # or dict for serialization (e.g. custom functions, custom
            # classes).
            try:
                serialized_item = serialize_keras_object(item)
            except ValueError:
                # Not serializable as a Keras object: keep the raw value.
                serialization_config[key] = item
                continue
            if isinstance(serialized_item, dict) and not isinstance(
                item, dict
            ):
                # Mark nested objects serialized on the item's behalf so that
                # deserialization knows to rebuild them.
                serialized_item["__passive_serialization__"] = True
            serialization_config[key] = serialized_item

        return serialize_keras_class_and_config(
            name, serialization_config, instance
        )
    if hasattr(instance, "__name__"):
        return object_registration.get_registered_name(instance)
    raise ValueError(
        f"Cannot serialize {instance} because it doesn't implement "
        "`get_config()`."
    )

339 

340 

def class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name="object",
):
    """Returns the class and config for a serialized keras object.

    Args:
        config: The serialized object, a dict with `class_name` and `config`
            keys.
        module_objects: Optional dict of built-in objects used to resolve
            `class_name`.
        custom_objects: Optional dict of user-provided objects used to resolve
            `class_name` and custom function names found in the config.
        printable_module_name: Human-readable kind of object, used in error
            messages.

    Returns:
        A `(cls, cls_config)` tuple. A list-valued `config["config"]` (legacy
        `Sequential` format) is returned as-is. A dict-valued one is returned
        as a shallow copy in which passively serialized items (marked with
        `__passive_serialization__`) are deserialized, and strings naming
        registered custom functions are replaced by those functions. The
        caller's `config` is not modified.

    Raises:
        ValueError: If `config` is not a dict with `class_name` and `config`
            keys, or if `class_name` cannot be resolved.
    """
    from keras.src.saving import object_registration

    if (
        not isinstance(config, dict)
        or "class_name" not in config
        or "config" not in config
    ):
        raise ValueError(
            f"Improper config format for {config}. "
            "Expecting python dict contains `class_name` and `config` as keys"
        )

    class_name = config["class_name"]
    cls = object_registration.get_registered_object(
        class_name, custom_objects, module_objects
    )
    if cls is None:
        raise ValueError(
            f"Unknown {printable_module_name}: '{class_name}'. "
            "Please ensure you are using a `keras.utils.custom_object_scope` "
            "and that this object is included in the scope. See "
            "https://www.tensorflow.org/guide/keras/save_and_serialize"
            "#registering_the_custom_object for details."
        )

    cls_config = config["config"]
    # Check if `cls_config` is a list. If it is a list, return the class and
    # the associated class configs for recursive deserialization. This case
    # happens with the old version of sequential model (e.g. `keras_version`
    # == "2.0.6"), which is serialized in a different structure, for example
    # "{'class_name': 'Sequential',
    #   'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}".
    if isinstance(cls_config, list):
        return (cls, cls_config)

    # Work on a shallow copy so the caller's config is not rewritten in
    # place. In-place rewriting would also mean a second deserialization of
    # the same config saw already-deserialized values.
    if isinstance(cls_config, dict):
        cls_config = dict(cls_config)

    deserialized_objects = {}
    for key, item in cls_config.items():
        if key == "name":
            # Assume that the value of 'name' is a string that should not be
            # deserialized as a function. This avoids the corner case where
            # cls_config['name'] has an identical name to a custom function
            # and gets converted into that function.
            continue
        if isinstance(item, dict) and "__passive_serialization__" in item:
            deserialized_objects[key] = deserialize_keras_object(
                item,
                module_objects=module_objects,
                custom_objects=custom_objects,
                printable_module_name="config_item",
            )
        elif isinstance(item, str):
            # Handle custom functions here. When saving functions, we only
            # save the function's name as a string. If we find a matching
            # string in the custom objects during deserialization, we convert
            # the string back to the original function.
            # Note that a string field could have a naming conflict with a
            # custom function name, but this should be a rare case. This issue
            # does not occur if a string field has a naming conflict with a
            # custom object, since the config of an object will always be a
            # dict.
            # TODO(momernick): Should this also have 'module_objects'?
            registered = object_registration.get_registered_object(
                item, custom_objects
            )
            if tf_inspect.isfunction(registered):
                deserialized_objects[key] = registered

    cls_config.update(deserialized_objects)
    return (cls, cls_config)

418 

419 

@keras_export("keras.utils.legacy.deserialize_keras_object")
def deserialize_keras_object(
    identifier,
    module_objects=None,
    custom_objects=None,
    printable_module_name="object",
):
    """Turns the serialized form of a Keras object back into an actual object.

    This function is for mid-level library implementers rather than end users.

    Importantly, this utility requires you to provide the dict of
    `module_objects` to use for looking up the object config; this is not
    populated by default. If you need a deserialization utility that has
    preexisting knowledge of built-in Keras objects, use e.g.
    `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`,
    etc.

    Calling `deserialize_keras_object` while underneath the
    `SharedObjectLoadingScope` context manager will cause any already-seen
    shared objects to be returned as-is rather than creating a new object.

    Args:
        identifier: the serialized form of the object. This is a config dict,
            a name string, a function (returned unchanged), or `None`
            (returned unchanged).
        module_objects: A dictionary of built-in objects to look the name up
            in. Generally, `module_objects` is provided by midlevel library
            implementers. When `None`, string identifiers are looked up only
            among custom and registered objects.
        custom_objects: A dictionary of custom objects to look the name up in.
            Generally, `custom_objects` is provided by the end user.
        printable_module_name: A human-readable string representing the type
            of the object. Printed in case of exception.

    Returns:
        The deserialized object. A string naming a class yields a new instance
        created with no arguments; a string naming anything else (e.g. a
        function) yields that object itself.

    Raises:
        ValueError: If the identifier names an unknown object or cannot be
            interpreted.

    Example:

    A mid-level library implementer might want to implement a utility for
    retrieving an object from its config, as such:

    ```python
    def deserialize(config, custom_objects=None):
        return deserialize_keras_object(
            config,
            module_objects=globals(),
            custom_objects=custom_objects,
            printable_module_name="MyObjectType",
        )
    ```

    This is how e.g. `keras.layers.deserialize()` is implemented.
    """
    from keras.src.saving import object_registration

    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name
        )

        # If this object has already been loaded (i.e. it's shared between
        # multiple objects), return the already-loaded object.
        shared_object_id = config.get(SHARED_OBJECT_KEY)
        shared_object = _shared_object_loading_scope().get(shared_object_id)
        if shared_object is not None:
            return shared_object

        custom_objects = custom_objects or {}
        if hasattr(cls, "from_config"):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            if "custom_objects" in arg_spec.args:
                # Later entries win: explicitly passed objects override the
                # thread-local scope, which overrides global registrations.
                tlco = object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
                deserialized_obj = cls.from_config(
                    cls_config,
                    custom_objects={
                        **object_registration._GLOBAL_CUSTOM_OBJECTS,
                        **tlco,
                        **custom_objects,
                    },
                )
            else:
                with object_registration.CustomObjectScope(custom_objects):
                    deserialized_obj = cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class. In this case,
            # by convention, `config` holds the kwargs of the function.
            with object_registration.CustomObjectScope(custom_objects):
                deserialized_obj = cls(**cls_config)

        # Add object to shared objects, in case we find it referenced again.
        _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

        return deserialized_obj

    if isinstance(identifier, str):
        object_name = identifier
        thread_local_objects = (
            object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
        )
        # Lookup order: explicit `custom_objects`, the thread-local custom
        # object scope, globally registered objects, then `module_objects`.
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in thread_local_objects:
            obj = thread_local_objects[object_name]
        elif object_name in object_registration._GLOBAL_CUSTOM_OBJECTS:
            obj = object_registration._GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` defaults to `None`. Treat that as empty so an
            # unknown name produces the descriptive `ValueError` below rather
            # than an `AttributeError` on `None.get`.
            obj = (module_objects or {}).get(object_name)
            if obj is None:
                raise ValueError(
                    f"Unknown {printable_module_name}: '{object_name}'. "
                    "Please ensure you are using a "
                    "`keras.utils.custom_object_scope` "
                    "and that this object is included in the scope. See "
                    "https://www.tensorflow.org/guide/keras/save_and_serialize"
                    "#registering_the_custom_object for details."
                )

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj

    if tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier

    raise ValueError(
        "Could not interpret serialized "
        f"{printable_module_name}: {identifier}"
    )

559 

560 

def validate_config(config):
    """Return whether `config` looks like a usable layer config.

    A config qualifies when it is a `dict` that does not contain the
    placeholder key written for layers saved without a config.
    """
    if not isinstance(config, dict):
        return False
    return _LAYER_UNDEFINED_CONFIG_KEY not in config

566 

567 

def is_default(method):
    """Return whether `method` is marked by the `default` wrapper.

    Reads the `_is_default` attribute. Objects without that attribute are
    reported as not default.
    """
    try:
        return method._is_default
    except AttributeError:
        return False

571