Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/serialization_lib.py: 16%

289 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Object config serialization and deserialization logic.""" 

16 

17import importlib 

18import inspect 

19import threading 

20import types 

21import warnings 

22 

23import numpy as np 

24import tensorflow.compat.v2 as tf 

25 

26from keras.src.saving import object_registration 

27from keras.src.saving.legacy import serialization as legacy_serialization 

28from keras.src.saving.legacy.saved_model.utils import in_tf_saved_model_scope 

29from keras.src.utils import generic_utils 

30 

31# isort: off 

32from tensorflow.python.util import tf_export 

33from tensorflow.python.util.tf_export import keras_export 

34 

# Python scalar types that `serialize_keras_object()` and
# `deserialize_keras_object()` pass through unchanged.
PLAIN_TYPES = (str, int, float, bool)

# Thread-local state: `SHARED_OBJECTS` holds the bookkeeping used by
# `ObjectSharingScope`; `SAFE_MODE` holds the flag written by
# `SafeModeScope` and `enable_unsafe_deserialization()`.
SHARED_OBJECTS = threading.local()
SAFE_MODE = threading.local()

# TODO(nkovela): Debug serialization of decorated functions inside lambdas
# to allow for serialization of custom_gradient.
# When the public-class config's module contains one of these substrings,
# `serialize_keras_object()` returns the object itself instead of a config.
NON_SERIALIZABLE_CLASS_MODULES = ("tensorflow.python.ops.custom_gradient",)

# Keras modules whose exported functions have built-in string names
# (e.g. 'acc' or 'tanh'); searched when resolving a function whose config
# module is "builtins".
BUILTIN_MODULES = (
    "activations",
    "constraints",
    "initializers",
    "losses",
    "metrics",
    "optimizers",
    "regularizers",
)

52 

53 

class Config:
    """Holder for arbitrary keyword arguments that can be serialized.

    Every keyword argument given to the constructor is stored verbatim in
    the `config` dict attribute; `serialize()` runs that dict through
    `serialize_keras_object()`.
    """

    def __init__(self, **config):
        self.config = config

    def serialize(self):
        """Return `self.config` serialized via `serialize_keras_object()`."""
        stored = self.config
        return serialize_keras_object(stored)

60 

61 

class SafeModeScope:
    """Scope to propagate safe mode flag to nested deserialization calls.

    Entering the scope writes `safe_mode` into the thread-local safe-mode
    flag; leaving it writes back whatever `in_safe_mode()` reported on entry
    (which may be `None` if the flag had never been set).
    """

    def __init__(self, safe_mode=True):
        self.safe_mode = safe_mode

    def __enter__(self):
        # Snapshot the outer value first so nested scopes unwind correctly.
        outer_value = in_safe_mode()
        self.original_value = outer_value
        SAFE_MODE.safe_mode = self.safe_mode

    def __exit__(self, *args, **kwargs):
        restored = self.original_value
        SAFE_MODE.safe_mode = restored

74 

75 

@keras_export("keras.__internal__.enable_unsafe_deserialization")
def enable_unsafe_deserialization():
    """Disables safe mode, allowing deserialization of lambdas.

    The flag is stored in thread-local state, so it applies to the calling
    thread. While it is set, `deserialize_keras_object()` uses it instead of
    its own `safe_mode` argument.
    """
    setattr(SAFE_MODE, "safe_mode", False)

80 

81 

def in_safe_mode():
    """Return this thread's safe-mode flag, or `None` if it was never set."""
    try:
        return SAFE_MODE.safe_mode
    except AttributeError:
        return None

84 

85 

class ObjectSharingScope:
    """Scope to enable detection and reuse of previously seen objects.

    While the scope is active, the `record_object_after_*` helpers and
    `get_shared_object()` use the thread-local id->object and id->config
    maps. Both maps are replaced with empty dicts on entry and on exit.
    """

    def _reset(self, enabled):
        # Toggle sharing and start from fresh, empty bookkeeping maps.
        SHARED_OBJECTS.enabled = enabled
        SHARED_OBJECTS.id_to_obj_map = {}
        SHARED_OBJECTS.id_to_config_map = {}

    def __enter__(self):
        self._reset(True)

    def __exit__(self, *args, **kwargs):
        self._reset(False)

98 

99 

def get_shared_object(obj_id):
    """Retrieve an object previously seen during deserialization.

    Returns `None` when no `ObjectSharingScope` is active, or when nothing
    was recorded under `obj_id`.
    """
    if not getattr(SHARED_OBJECTS, "enabled", False):
        return None
    return SHARED_OBJECTS.id_to_obj_map.get(obj_id)

104 

105 

def record_object_after_serialization(obj, config):
    """Call after serializing an object, to keep track of its config.

    A `"__main__"` module in `config` is always rewritten to `None`, even
    outside a sharing scope. Inside an `ObjectSharingScope`, the first config
    produced for a given `id(obj)` is remembered. If the same object is
    serialized again, both that first config and the new one are tagged with
    `"shared_object_id"`.
    """
    if config["module"] == "__main__":
        config["module"] = None  # Ensures module is None when no module found
    if not getattr(SHARED_OBJECTS, "enabled", False):
        return  # Not in a sharing scope
    obj_id = int(id(obj))
    seen_configs = SHARED_OBJECTS.id_to_config_map
    first_config = seen_configs.get(obj_id)
    if first_config is None:
        # First time this object is serialized in the scope.
        seen_configs[obj_id] = config
        return
    config["shared_object_id"] = obj_id
    first_config["shared_object_id"] = obj_id

119 

120 

def record_object_after_deserialization(obj, obj_id):
    """Call after deserializing an object, to keep track of it in the future.

    Does nothing outside an `ObjectSharingScope`; otherwise stores `obj`
    under `obj_id` so `get_shared_object(obj_id)` can return it later.
    """
    if getattr(SHARED_OBJECTS, "enabled", False):
        SHARED_OBJECTS.id_to_obj_map[obj_id] = obj

126 

127 

@keras_export(
    "keras.saving.serialize_keras_object", "keras.utils.serialize_keras_object"
)
def serialize_keras_object(obj):
    """Retrieve the config dict by serializing the Keras object.

    `serialize_keras_object()` serializes a Keras object to a python dictionary
    that represents the object, and is a reciprocal function of
    `deserialize_keras_object()`. See `deserialize_keras_object()` for more
    information about the config format.

    Args:
        obj: the Keras object to serialize.

    Returns:
        A python dict that represents the object. The python dict can be
        deserialized via `deserialize_keras_object()`. `None` and plain
        Python scalars are returned unchanged. Lists, tuples and dicts are
        serialized element-wise and keep their container type.
        `tf.TensorShape`, `tf.DType`, `tf.compat.v1.Dimension` and numpy
        scalars become plain values. A few objects that cannot be serialized
        (op dispatch handlers, and classes whose module matches
        `NON_SERIALIZABLE_CLASS_MODULES`) are returned as-is.
    """
    # Fall back to legacy serialization for all TF1 users or if
    # wrapped by in_tf_saved_model_scope() to explicitly use legacy
    # saved_model logic.
    if not tf.__internal__.tf2.enabled() or in_tf_saved_model_scope():
        return legacy_serialization.serialize_keras_object(obj)

    if obj is None:
        return obj

    if isinstance(obj, PLAIN_TYPES):
        return obj

    if isinstance(obj, (list, tuple)):
        config_arr = [serialize_keras_object(x) for x in obj]
        return tuple(config_arr) if isinstance(obj, tuple) else config_arr
    if isinstance(obj, dict):
        return serialize_dict(obj)

    # Special cases:
    if isinstance(obj, bytes):
        return {
            "class_name": "__bytes__",
            "config": {"value": obj.decode("utf-8")},
        }
    if isinstance(obj, tf.TensorShape):
        # Shapes of unknown rank serialize to `None`; use the public `rank`
        # property rather than the private `_dims` attribute.
        return obj.as_list() if obj.rank is not None else None
    if isinstance(obj, tf.Tensor):
        return {
            "class_name": "__tensor__",
            "config": {
                "value": obj.numpy().tolist(),
                "dtype": obj.dtype.name,
            },
        }
    if type(obj).__module__ == np.__name__:
        if isinstance(obj, np.ndarray) and obj.ndim > 0:
            return {
                "class_name": "__numpy__",
                "config": {
                    "value": obj.tolist(),
                    "dtype": obj.dtype.name,
                },
            }
        else:
            # Treat numpy floats / etc as plain types.
            return obj.item()
    if isinstance(obj, tf.DType):
        return obj.name
    if isinstance(obj, tf.compat.v1.Dimension):
        return obj.value
    if isinstance(obj, types.FunctionType) and obj.__name__ == "<lambda>":
        # The source is only needed for the warning text. `func_dump` below
        # does not depend on it, so an unavailable source must not abort
        # serialization.
        try:
            lambda_source = inspect.getsource(obj)
        except (OSError, TypeError):
            lambda_source = "<source unavailable>"
        warnings.warn(
            "The object being serialized includes a `lambda`. This is unsafe. "
            "In order to reload the object, you will have to pass "
            "`safe_mode=False` to the loading function. "
            "Please avoid using `lambda` in the "
            "future, and use named Python functions instead. "
            f"This is the `lambda` being serialized: {lambda_source}",
            stacklevel=2,
        )
        return {
            "class_name": "__lambda__",
            "config": {
                "value": generic_utils.func_dump(obj),
            },
        }
    if isinstance(obj, tf.TypeSpec):
        ts_config = obj._serialize()
        # TensorShape and tf.DType conversion
        ts_config = list(
            map(
                lambda x: x.as_list()
                if isinstance(x, tf.TensorShape)
                else (x.name if isinstance(x, tf.DType) else x),
                ts_config,
            )
        )
        return {
            "class_name": "__typespec__",
            "spec_name": obj.__class__.__name__,
            "module": obj.__class__.__module__,
            "config": ts_config,
            "registered_name": None,
        }

    inner_config = _get_class_or_fn_config(obj)
    config_with_public_class = serialize_with_public_class(
        obj.__class__, inner_config
    )

    # TODO(nkovela): Add TF ops dispatch handler serialization for
    # ops.EagerTensor that contains nested numpy array.
    # Target: NetworkConstructionTest.test_constant_initializer_with_numpy
    if isinstance(inner_config, str) and inner_config == "op_dispatch_handler":
        return obj

    if config_with_public_class is not None:
        # Special case for non-serializable class modules
        if any(
            mod in config_with_public_class["module"]
            for mod in NON_SERIALIZABLE_CLASS_MODULES
        ):
            return obj

        get_build_and_compile_config(obj, config_with_public_class)
        record_object_after_serialization(obj, config_with_public_class)
        return config_with_public_class

    # Any custom object or otherwise non-exported object
    if isinstance(obj, types.FunctionType):
        module = obj.__module__
    else:
        module = obj.__class__.__module__
    class_name = obj.__class__.__name__

    if module == "builtins":
        registered_name = None
    else:
        if isinstance(obj, types.FunctionType):
            registered_name = object_registration.get_registered_name(obj)
        else:
            registered_name = object_registration.get_registered_name(
                obj.__class__
            )

    config = {
        "module": module,
        "class_name": class_name,
        "config": inner_config,
        "registered_name": registered_name,
    }
    get_build_and_compile_config(obj, config)
    record_object_after_serialization(obj, config)
    return config

281 

282 

def get_build_and_compile_config(obj, config):
    """Attach `obj`'s build and compile configs to `config`, in place.

    For each of `get_build_config()` and `get_compile_config()` that `obj`
    defines, a non-`None` result is passed through `serialize_dict()`. It is
    stored under `"build_config"` or `"compile_config"` respectively.
    Always returns `None`.
    """
    for getter_name, target_key in (
        ("get_build_config", "build_config"),
        ("get_compile_config", "compile_config"),
    ):
        if not hasattr(obj, getter_name):
            continue
        extra_config = getattr(obj, getter_name)()
        if extra_config is not None:
            config[target_key] = serialize_dict(extra_config)

293 

294 

def serialize_with_public_class(cls, inner_config=None):
    """Serializes classes from public Keras API or object registration.

    Called to check and retrieve the config of any class that has a public
    Keras API or has been registered as serializable via
    `keras.saving.register_keras_serializable()`.

    Args:
        cls: The class to describe.
        inner_config: Value stored under the `"config"` key of the result.

    Returns:
        A dict with `"module"`, `"class_name"`, `"config"` and
        `"registered_name"` keys. Returns `None` when `cls` has no `keras`
        export and no registered name.
    """
    # This gets the `keras.*` exported name, such as "keras.optimizers.Adam".
    keras_api_name = tf_export.get_canonical_name_for_symbol(
        cls, api_name="keras"
    )

    if keras_api_name is not None:
        # Split the canonical Keras API name into a Keras module and class
        # name, e.g. "keras.optimizers" and "Adam".
        module_path, _, public_class_name = keras_api_name.rpartition(".")
        return {
            "module": module_path,
            "class_name": public_class_name,
            "config": inner_config,
            "registered_name": None,
        }

    # Case of custom or unknown class object: only registered classes are
    # described here.
    registered_name = object_registration.get_registered_name(cls)
    if registered_name is None:
        return None
    return {
        "module": cls.__module__,
        "class_name": cls.__name__,
        "config": inner_config,
        "registered_name": registered_name,
    }

329 

330 

def serialize_with_public_fn(fn, config, fn_module_name=None):
    """Serializes functions from public Keras API or object registration.

    Called to check and retrieve the config of any function that has a public
    Keras API or has been registered as serializable via
    `keras.saving.register_keras_serializable()`. If function's module name is
    already known, returns corresponding config.

    Args:
        fn: The function to describe.
        config: The function's serialized name, stored under `"config"`.
        fn_module_name: Optional known module name; when truthy it is used
            directly and no lookup is performed.

    Returns:
        A function config dict (`"class_name"` is `"function"`), or `None`
        for an unexported, unregistered, non-builtin function.
    """
    if fn_module_name:
        module_path = fn_module_name
        registered_name = config
    else:
        keras_api_name = tf_export.get_canonical_name_for_symbol(
            fn, api_name="keras"
        )
        if keras_api_name:
            module_path = keras_api_name.rpartition(".")[0]
            registered_name = config
        else:
            registered_name = object_registration.get_registered_name(fn)
            if not registered_name and fn.__module__ != "builtins":
                return None
            module_path = fn.__module__
    return {
        "module": module_path,
        "class_name": "function",
        "config": config,
        "registered_name": registered_name,
    }

367 

368 

369def _get_class_or_fn_config(obj): 

370 """Return the object's config depending on its type.""" 

371 # Functions / lambdas: 

372 if isinstance(obj, types.FunctionType): 

373 return obj.__name__ 

374 # All classes: 

375 if hasattr(obj, "get_config"): 

376 config = obj.get_config() 

377 if not isinstance(config, dict): 

378 raise TypeError( 

379 f"The `get_config()` method of {obj} should return " 

380 f"a dict. It returned: {config}" 

381 ) 

382 return serialize_dict(config) 

383 elif hasattr(obj, "__name__"): 

384 return object_registration.get_registered_name(obj) 

385 else: 

386 raise TypeError( 

387 f"Cannot serialize object {obj} of type {type(obj)}. " 

388 "To be serializable, " 

389 "a class must implement the `get_config()` method." 

390 ) 

391 

392 

def serialize_dict(obj):
    """Return a new dict with every value of `obj` serialized, keys kept."""
    serialized = {}
    for key, value in obj.items():
        serialized[key] = serialize_keras_object(value)
    return serialized

395 

396 

@keras_export(
    "keras.saving.deserialize_keras_object",
    "keras.utils.deserialize_keras_object",
)
def deserialize_keras_object(
    config, custom_objects=None, safe_mode=True, **kwargs
):
    """Retrieve the object by deserializing the config dict.

    The config dict is a Python dictionary that consists of a set of key-value
    pairs, and represents a Keras object, such as an `Optimizer`, `Layer`,
    `Metrics`, etc. The saving and loading library uses the following keys to
    record information of a Keras object:

    - `class_name`: String. This is the name of the class,
      as exactly defined in the source
      code, such as "LossesContainer".
    - `config`: Dict. Library-defined or user-defined key-value pairs that store
      the configuration of the object, as obtained by `object.get_config()`.
    - `module`: String. The path of the python module, such as
      "keras.engine.compile_utils". Built-in Keras classes
      expect to have prefix `keras`.
    - `registered_name`: String. The key the class is registered under via
      `keras.saving.register_keras_serializable(package, name)` API. The key has
      the format of '{package}>{name}', where `package` and `name` are the
      arguments passed to `register_keras_serializable()`. If `name` is not
      provided, it uses the class name. If `registered_name` successfully
      resolves to a class (that was registered), the `class_name` and `config`
      values in the dict will not be used. `registered_name` is only used for
      non-built-in classes.

    For example, the following dictionary represents the built-in Adam optimizer
    with the relevant config:

    ```python
    dict_structure = {
        "class_name": "Adam",
        "config": {
            "amsgrad": False,
            "beta_1": 0.8999999761581421,
            "beta_2": 0.9990000128746033,
            "decay": 0.0,
            "epsilon": 1e-07,
            "learning_rate": 0.0010000000474974513,
            "name": "Adam"
        },
        "module": "keras.optimizers",
        "registered_name": None
    }
    # Returns an `Adam` instance identical to the original one.
    deserialize_keras_object(dict_structure)
    ```

    If the class does not have an exported Keras namespace, the library tracks
    it by its `module` and `class_name`. For example:

    ```python
    dict_structure = {
      "class_name": "LossesContainer",
      "config": {
          "losses": [...],
          "total_loss_mean": {...},
      },
      "module": "keras.engine.compile_utils",
      "registered_name": "LossesContainer"
    }

    # Returns a `LossesContainer` instance identical to the original one.
    deserialize_keras_object(dict_structure)
    ```

    And the following dictionary represents a user-customized `MeanSquaredError`
    loss:

    ```python
    @keras.saving.register_keras_serializable(package='my_package')
    class ModifiedMeanSquaredError(keras.losses.MeanSquaredError):
      ...

    dict_structure = {
        "class_name": "ModifiedMeanSquaredError",
        "config": {
            "fn": "mean_squared_error",
            "name": "mean_squared_error",
            "reduction": "auto"
        },
        "registered_name": "my_package>ModifiedMeanSquaredError"
    }
    # Returns the `ModifiedMeanSquaredError` object
    deserialize_keras_object(dict_structure)
    ```

    Args:
        config: Python dict describing the object.
        custom_objects: Python dict containing a mapping between custom
            object names to the corresponding classes or functions.
        safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization.
            When `safe_mode=False`, loading an object has the potential to
            trigger arbitrary code execution. This argument is only
            applicable to the Keras v3 model format. Defaults to `True`.
            If a thread-local safe-mode flag is already set (by
            `SafeModeScope` or `enable_unsafe_deserialization()`), that
            flag takes precedence over this argument.
        **kwargs: Only the deprecated legacy arguments `module_objects` and
            `printable_module_name` are accepted.

    Returns:
        The object described by the `config` dictionary.

    Raises:
        ValueError: on unsupported keyword arguments, on a `lambda` config
            while in safe mode, or on a dict without `class_name` when
            `module_objects` is given.
        TypeError: if the config cannot be parsed or the referenced class or
            function cannot be located or reconstructed.
    """
    safe_scope_arg = in_safe_mode()  # Enforces SafeModeScope
    safe_mode = safe_scope_arg if safe_scope_arg is not None else safe_mode

    module_objects = kwargs.pop("module_objects", None)
    custom_objects = custom_objects or {}
    tlco = object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
    gco = object_registration._GLOBAL_CUSTOM_OBJECTS
    custom_objects = {**custom_objects, **tlco, **gco}

    # Optional deprecated argument for legacy deserialization call
    printable_module_name = kwargs.pop("printable_module_name", "object")
    if kwargs:
        raise ValueError(
            "The following argument(s) are not supported: "
            f"{list(kwargs.keys())}"
        )

    # Fall back to legacy deserialization for all TF1 users or if
    # wrapped by in_tf_saved_model_scope() to explicitly use legacy
    # saved_model logic.
    if not tf.__internal__.tf2.enabled() or in_tf_saved_model_scope():
        return legacy_serialization.deserialize_keras_object(
            config, module_objects, custom_objects, printable_module_name
        )

    if config is None:
        return None

    if (
        isinstance(config, str)
        and custom_objects
        and custom_objects.get(config) is not None
    ):
        # This is to deserialize plain functions which are serialized as
        # string names by legacy saving formats.
        return custom_objects[config]

    if isinstance(config, (list, tuple)):
        return [
            deserialize_keras_object(
                x, custom_objects=custom_objects, safe_mode=safe_mode
            )
            for x in config
        ]

    if module_objects is not None:
        inner_config, fn_module_name, has_custom_object = None, None, False

        if isinstance(config, dict):
            if "config" in config:
                inner_config = config["config"]
            if "class_name" not in config:
                raise ValueError(
                    f"Unknown `config` as a `dict`, config={config}"
                )

            # Check case where config is function or class and in custom objects
            if custom_objects and (
                config["class_name"] in custom_objects
                or config.get("registered_name") in custom_objects
                or (
                    isinstance(inner_config, str)
                    and inner_config in custom_objects
                )
            ):
                has_custom_object = True

            # Case where config is function but not in custom objects
            elif config["class_name"] == "function":
                fn_module_name = config["module"]
                if fn_module_name == "builtins":
                    config = config["config"]
                else:
                    config = config["registered_name"]

            # Case where config is class but not in custom objects
            else:
                if config.get("module", "_") is None:
                    raise TypeError(
                        "Cannot deserialize object of type "
                        f"`{config['class_name']}`. If "
                        f"`{config['class_name']}` is a custom class, please "
                        "register it using the "
                        "`@keras.saving.register_keras_serializable()` "
                        "decorator."
                    )
                config = config["class_name"]
        if not has_custom_object:
            # Return if not found in either module objects or custom objects
            if config not in module_objects:
                # Object has already been deserialized
                return config
            # Forward `safe_mode` like every other recursive call does, so a
            # caller's `safe_mode=False` is not reset to the default here.
            if isinstance(module_objects[config], types.FunctionType):
                return deserialize_keras_object(
                    serialize_with_public_fn(
                        module_objects[config], config, fn_module_name
                    ),
                    custom_objects=custom_objects,
                    safe_mode=safe_mode,
                )
            return deserialize_keras_object(
                serialize_with_public_class(
                    module_objects[config], inner_config=inner_config
                ),
                custom_objects=custom_objects,
                safe_mode=safe_mode,
            )

    if isinstance(config, PLAIN_TYPES):
        return config
    if not isinstance(config, dict):
        raise TypeError(f"Could not parse config: {config}")

    if "class_name" not in config or "config" not in config:
        return {
            key: deserialize_keras_object(
                value, custom_objects=custom_objects, safe_mode=safe_mode
            )
            for key, value in config.items()
        }

    class_name = config["class_name"]
    inner_config = config["config"] or {}

    # Special cases:
    if class_name == "__tensor__":
        return tf.constant(inner_config["value"], dtype=inner_config["dtype"])
    if class_name == "__numpy__":
        return np.array(inner_config["value"], dtype=inner_config["dtype"])
    if class_name == "__bytes__":
        return inner_config["value"].encode("utf-8")
    if class_name == "__lambda__":
        if safe_mode:
            raise ValueError(
                "Requested the deserialization of a `lambda` object. "
                "This carries a potential risk of arbitrary code execution "
                "and thus it is disallowed by default. If you trust the "
                "source of the saved model, you can pass `safe_mode=False` to "
                "the loading function in order to allow `lambda` loading."
            )
        return generic_utils.func_load(inner_config["value"])
    if class_name == "__typespec__":
        obj = _retrieve_class_or_fn(
            config["spec_name"],
            config["registered_name"],
            config["module"],
            obj_type="class",
            full_config=config,
            custom_objects=custom_objects,
        )

        def _restore_spec_field(value):
            # Inverse of the TensorShape / tf.DType conversion done by
            # `serialize_keras_object()`. Lists become shapes, and only
            # strings that name an actual `tf.DType` are turned back into
            # one. Other `tf.dtypes` attributes (e.g. "cast") stay strings.
            if isinstance(value, list):
                return tf.TensorShape(value)
            if isinstance(value, str):
                dtype = getattr(tf.dtypes, value, None)
                if isinstance(dtype, tf.DType):
                    return dtype
            return value

        return obj._deserialize(
            tuple(_restore_spec_field(x) for x in inner_config)
        )

    # Below: classes and functions.
    module = config.get("module", None)
    registered_name = config.get("registered_name", class_name)

    if class_name == "function":
        fn_name = inner_config
        return _retrieve_class_or_fn(
            fn_name,
            registered_name,
            module,
            obj_type="function",
            full_config=config,
            custom_objects=custom_objects,
        )

    # Below, handling of all classes.
    # First, is it a shared object?
    if "shared_object_id" in config:
        obj = get_shared_object(config["shared_object_id"])
        if obj is not None:
            return obj

    cls = _retrieve_class_or_fn(
        class_name,
        registered_name,
        module,
        obj_type="class",
        full_config=config,
        custom_objects=custom_objects,
    )

    if isinstance(cls, types.FunctionType):
        return cls
    if not hasattr(cls, "from_config"):
        raise TypeError(
            f"Unable to reconstruct an instance of '{class_name}' because "
            f"the class is missing a `from_config()` method. "
            f"Full object config: {config}"
        )

    # Instantiate the class from its config inside a custom object scope
    # so that we can catch any custom objects that the config refers to.
    custom_obj_scope = object_registration.custom_object_scope(custom_objects)
    safe_mode_scope = SafeModeScope(safe_mode)
    with custom_obj_scope, safe_mode_scope:
        instance = cls.from_config(inner_config)
        build_config = config.get("build_config", None)
        if build_config:
            instance.build_from_config(build_config)
        compile_config = config.get("compile_config", None)
        if compile_config:
            instance.compile_from_config(compile_config)

    if "shared_object_id" in config:
        record_object_after_deserialization(
            instance, config["shared_object_id"]
        )
    return instance

717 

718 

def _retrieve_class_or_fn(
    name, registered_name, module, obj_type, full_config, custom_objects=None
):
    """Locate the class or function described by a serialized config.

    Lookup order:

    1. A registered/custom object via
       `object_registration.get_registered_object()`. The key is `name` for
       functions and `registered_name` for classes.
    2. If `module` is `"keras"` or starts with `"keras."`, the exported Keras
       symbol `f"{module}.{name}"`. Names containing `__internal__.legacy`
       get a `compat.v1.` prefix.
    3. For functions whose `module` is `"builtins"`: the Keras modules in
       `BUILTIN_MODULES`, then `custom_objects` entries keyed as
       `f"{package}>{name}"` (or exactly `name`).
    4. The attribute `name`, falling back to `registered_name`, of the
       imported `module`.

    Args:
        name: Class or function name from the config.
        registered_name: Registration key from the config, may be `None`.
        module: Module path from the config; when falsy only step 1 runs.
        obj_type: `"class"` or `"function"`. It picks the step-1 lookup key,
            gates step 3, and is used in error messages.
        full_config: The complete config dict, quoted in error messages.
        custom_objects: Optional dict mapping custom names to objects.

    Returns:
        The located class or function.

    Raises:
        TypeError: if `module` cannot be imported, or nothing was found.
    """
    # If there is a custom object registered via
    # `register_keras_serializable()`, that takes precedence.
    lookup_key = name if obj_type == "function" else registered_name
    custom_obj = object_registration.get_registered_object(
        lookup_key, custom_objects=custom_objects
    )
    if custom_obj is not None:
        return custom_obj

    if module:
        # If it's a Keras built-in object,
        # we cannot always use direct import, because the exported
        # module name might not match the package structure
        # (e.g. experimental symbols).
        if module == "keras" or module.startswith("keras."):
            api_name = module + "." + name

            # Legacy internal APIs are stored in TF API naming dict
            # with `compat.v1` prefix
            if "__internal__.legacy" in api_name:
                api_name = "compat.v1." + api_name

            obj = tf_export.get_symbol_from_name(api_name)
            if obj is not None:
                return obj

        # Configs of Keras built-in functions do not contain identifying
        # information other than their name (e.g. 'acc' or 'tanh'). This
        # special case searches the Keras modules that contain built-ins to
        # retrieve the corresponding function from the identifying string.
        if obj_type == "function" and module == "builtins":
            for builtin_module in BUILTIN_MODULES:
                obj = tf_export.get_symbol_from_name(
                    "keras." + builtin_module + "." + name
                )
                if obj is not None:
                    return obj

            # Retrieval of registered custom function in a package.
            # Registered keys look like "{package}>{name}", so match on the
            # ">" separator. A bare suffix match would let "pkg>fancy_relu"
            # satisfy "relu". Guard against a missing `custom_objects` and
            # non-string names or keys, which previously crashed here.
            fn_name = full_config["config"]
            if isinstance(fn_name, str) and fn_name:
                package_suffix = ">" + fn_name
                for key, value in (custom_objects or {}).items():
                    if isinstance(key, str) and (
                        key == fn_name or key.endswith(package_suffix)
                    ):
                        return value

        # Otherwise, attempt to retrieve the class object given the `module`
        # and `class_name`. Import the module, find the class.
        try:
            mod = importlib.import_module(module)
        except ModuleNotFoundError as err:
            raise TypeError(
                f"Could not deserialize {obj_type} '{name}' because "
                f"its parent module {module} cannot be imported. "
                f"Full object config: {full_config}"
            ) from err
        module_namespace = vars(mod)
        obj = module_namespace.get(name, None)

        # Special case for keras.metrics.metrics
        if obj is None and registered_name is not None:
            obj = module_namespace.get(registered_name, None)

        if obj is not None:
            return obj

    raise TypeError(
        f"Could not locate {obj_type} '{name}'. "
        "Make sure custom classes are decorated with "
        "`@keras.saving.register_keras_serializable()`. "
        f"Full object config: {full_config}"
    )

798