Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py: 22%

493 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Python utilities required by Keras.""" 

16 

17import binascii 

18import codecs 

19import importlib 

20import marshal 

21import os 

22import re 

23import sys 

24import threading 

25import time 

26import types as python_types 

27import warnings 

28import weakref 

29 

30import numpy as np 

31 

32from tensorflow.python.keras.utils import tf_contextlib 

33from tensorflow.python.keras.utils import tf_inspect 

34from tensorflow.python.util import nest 

35from tensorflow.python.util import tf_decorator 

36from tensorflow.python.util.tf_export import keras_export 

37 

# Registry mapping registered names (e.g. 'package>name') to objects. It is
# populated by `register_keras_serializable` and temporarily extended by
# `CustomObjectScope`.
_GLOBAL_CUSTOM_OBJECTS = {}
# Reverse registry mapping registered objects back to their registered name;
# written by `register_keras_serializable`, read by `get_registered_name`.
_GLOBAL_CUSTOM_NAMES = {}

# Flag that determines whether to skip the NotImplementedError when calling
# get_config in custom models and layers. This is only enabled when saving to
# SavedModel, when the config isn't required. It is toggled by the
# `skip_failed_serialization` context manager.
_SKIP_FAILED_SERIALIZATION = False
# If a layer does not have a defined config, then the returned config will be a
# dictionary with the below key.
_LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config'

48 

49 

@keras_export('keras.utils.custom_object_scope',  # pylint: disable=g-classes-have-attributes
              'keras.utils.CustomObjectScope')
class CustomObjectScope(object):
  """Exposes custom classes/functions to Keras deserialization internals.

  Inside `with custom_object_scope(objects_dict)`, the given objects are merged
  into the global custom-object registry. Keras methods such as
  `tf.keras.models.load_model` or `tf.keras.models.model_from_config` can then
  resolve any custom object (e.g. a custom layer or metric) referenced by a
  saved config. On exit the registry is restored to exactly the contents it
  had on entry.

  Example:

  Consider a custom regularizer `my_regularizer`:

  ```python
  layer = Dense(3, kernel_regularizer=my_regularizer)
  config = layer.get_config()  # Config contains a reference to `my_regularizer`
  ...
  # Later:
  with custom_object_scope({'my_regularizer': my_regularizer}):
    layer = Dense.from_config(config)
  ```

  Args:
    *args: Dictionary or dictionaries of `{name: object}` pairs. Later
      dictionaries override earlier ones on key collisions.
  """

  def __init__(self, *args):
    self.custom_objects = args
    self.backup = None

  def __enter__(self):
    # Snapshot the registry first so __exit__ can put it back untouched.
    self.backup = dict(_GLOBAL_CUSTOM_OBJECTS)
    for mapping in self.custom_objects:
      _GLOBAL_CUSTOM_OBJECTS.update(mapping)
    return self

  def __exit__(self, *args, **kwargs):
    del args, kwargs  # Exception details are not needed; never suppressed.
    _GLOBAL_CUSTOM_OBJECTS.clear()
    _GLOBAL_CUSTOM_OBJECTS.update(self.backup)

90 

91 

@keras_export('keras.utils.get_custom_objects')
def get_custom_objects():
  """Retrieves a live reference to the global dictionary of custom objects.

  Updating and clearing custom objects using `custom_object_scope`
  is preferred, but `get_custom_objects` can
  be used to directly access the current collection of custom objects.

  The returned dict is the module-level registry itself, not a copy. Mutating
  it changes what later registry lookups in this module (e.g.
  `get_registered_object`) see.

  Example:

  ```python
  get_custom_objects().clear()
  get_custom_objects()['MyObject'] = MyObject
  ```

  Returns:
      Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`).
  """
  return _GLOBAL_CUSTOM_OBJECTS

111 

112 

# Store a unique, per-object ID for shared objects.
#
# We store a unique ID for each object so that we may, at loading time,
# re-create the network properly. Without this ID, we would have no way of
# determining whether a config is a description of a new object that
# should be created or is merely a reference to an already-created object.
SHARED_OBJECT_KEY = 'shared_object_id'


# Per-thread state for shared-object handling. Attributes set on these
# `threading.local` objects are only visible to the thread that set them:
#   SHARED_OBJECT_DISABLED.disabled -> bool, set by DisableSharedObjectScope.
#   SHARED_OBJECT_LOADING.scope -> the active loading scope
#     (SharedObjectLoadingScope or NoopLoadingScope).
#   SHARED_OBJECT_SAVING.scope -> the active SharedObjectSavingScope, or None.
SHARED_OBJECT_DISABLED = threading.local()
SHARED_OBJECT_LOADING = threading.local()
SHARED_OBJECT_SAVING = threading.local()

125 

126 

# A `threading.local` attribute only exists in the thread that assigned it, so
# these values cannot be initialised once at import time. The accessor
# functions below read them and fall back to a default when the current thread
# has not set them yet.
def _shared_object_disabled():
  """Returns whether shared-object handling is disabled for this thread."""
  try:
    return SHARED_OBJECT_DISABLED.disabled
  except AttributeError:
    # Nothing set on this thread yet: handling is enabled by default.
    return False

133 

134 

def _shared_object_loading_scope():
  """Get the current shared object loading scope in a threadsafe manner.

  Returns:
    The loading scope installed on the current thread, or a new
    `NoopLoadingScope` if this thread has not installed one.
  """
  try:
    return SHARED_OBJECT_LOADING.scope
  except AttributeError:
    # Only build the no-op default when it is actually needed, instead of
    # constructing one eagerly on every call as a `getattr` default would.
    return NoopLoadingScope()

138 

139 

def _shared_object_saving_scope():
  """Returns this thread's active `SharedObjectSavingScope`, or `None`."""
  try:
    return SHARED_OBJECT_SAVING.scope
  except AttributeError:
    # No saving scope has ever been installed on this thread.
    return None

143 

144 

class DisableSharedObjectScope(object):
  """A context manager for disabling handling of shared objects.

  While active, the current thread's disabled flag is set. As a result, newly
  entered `SharedObjectLoadingScope` / `SharedObjectSavingScope` contexts do
  not install themselves.

  Created primarily for use with `clone_model`, which does extra surgery that
  is incompatible with shared objects.

  On exit, the disabled flag and this thread's loading and saving scopes are
  restored to the values captured on entry. Nesting is therefore safe: leaving
  an inner `DisableSharedObjectScope` does not re-enable shared-object handling
  while an outer one is still active.
  """

  def __enter__(self):
    # Capture the previous flag before overwriting it so __exit__ restores the
    # outer state instead of unconditionally re-enabling sharing.
    self._orig_disabled = _shared_object_disabled()
    SHARED_OBJECT_DISABLED.disabled = True
    self._orig_loading_scope = _shared_object_loading_scope()
    self._orig_saving_scope = _shared_object_saving_scope()

  def __exit__(self, *args, **kwargs):
    SHARED_OBJECT_DISABLED.disabled = self._orig_disabled
    SHARED_OBJECT_LOADING.scope = self._orig_loading_scope
    SHARED_OBJECT_SAVING.scope = self._orig_saving_scope

163 

164 

class NoopLoadingScope(object):
  """Default loading scope that never shares anything.

  Used whenever no `SharedObjectLoadingScope` is active (e.g. when
  deserializing a single object). Callers can then call `get`/`set`
  unconditionally, without checking whether sharing is in effect.
  """

  def get(self, unused_object_id):
    """Always reports that nothing was loaded for the ID (returns `None`)."""
    del unused_object_id
    return None

  def set(self, object_id, obj):
    """Discards `obj`; nothing is recorded."""
    del object_id, obj

177 

178 

class SharedObjectLoadingScope(object):
  """A context manager for keeping track of loaded objects.

  Deserialization can meet objects that are shared by several layers. To
  restore the network structure faithfully, this scope remembers each object
  by its shared object ID. A later reference to the same ID then reuses that
  object instead of building a clone.

  On exit the thread's loading scope is reset to a fresh `NoopLoadingScope`.
  Any scope that was active before entry is not restored.
  """

  def __enter__(self):
    if _shared_object_disabled():
      # Sharing is switched off: hand back a no-op scope and leave the
      # thread-local loading scope as it is.
      return NoopLoadingScope()
    SHARED_OBJECT_LOADING.scope = self
    self._obj_ids_to_obj = {}
    return self

  def get(self, object_id):
    """Looks up an object previously stored under a shared object ID.

    Args:
      object_id: Shared object ID to look for; `None` is accepted so callers
        need not special-case configs without an ID.

    Returns:
      The stored object, or `None` if the ID is `None` or unknown.
    """
    if object_id is not None:
      return self._obj_ids_to_obj.get(object_id)
    return None

  def set(self, object_id, obj):
    """Remembers `obj` under `object_id`; a `None` ID is ignored."""
    if object_id is not None:
      self._obj_ids_to_obj[object_id] = obj

  def __exit__(self, *args, **kwargs):
    SHARED_OBJECT_LOADING.scope = NoopLoadingScope()

222 

223 

class SharedObjectConfig(dict):
  """A config dict that counts how often it has been referenced.

  Once the same config is referenced a second time, the shared object ID is
  written into it. At load time that ID lets the shared object be rebuilt once
  and reused.

  `collections.UserDict` / `collections.Mapping` would normally be cleaner
  bases, but Python's json encoder only handles real `dict`s, and this object
  must stay JSON-serializable. Subclassing `dict` is safe here because no core
  dict methods are overridden; only reference counting is added.
  """

  def __init__(self, base_config, object_id, **kwargs):
    self.ref_count = 1
    self.object_id = object_id
    super().__init__(base_config, **kwargs)

  def increment_ref_count(self):
    """Records one more reference, tagging the config with its ID on reuse."""
    # The ID is attached only when the object is actually shared (second
    # reference). Configs of singly-referenced objects stay byte-for-byte as
    # before, which limits backwards-compatibility breakage.
    first_reuse = self.ref_count == 1
    self.ref_count += 1
    if first_reuse:
      self[SHARED_OBJECT_KEY] = self.object_id

254 

255 

class SharedObjectSavingScope(object):
  """Keeps track of shared object configs when serializing.

  While installed as this thread's saving scope,
  `serialize_keras_class_and_config` routes configs through `get_config` /
  `create_config`. An object serialized more than once therefore yields the
  same `SharedObjectConfig`, which then carries a shared object ID.

  Nested scopes collapse onto the outermost one. Entering while another scope
  is active returns that outer scope and leaves it installed on exit.

  When shared-object handling is disabled, `__enter__` returns `None` and
  `__exit__` still clears this thread's saving scope.
  """

  def __enter__(self):
    if _shared_object_disabled():
      return None

    # Serialization may open a saving scope inside another one. There is no
    # use case for distinct nested scopes yet, so the outermost scope wins and
    # this one becomes a pass-through.
    outer_scope = _shared_object_saving_scope()
    self._passthrough = outer_scope is not None
    if self._passthrough:
      return outer_scope

    SHARED_OBJECT_SAVING.scope = self
    self._shared_objects_config = weakref.WeakKeyDictionary()
    self._next_id = 0
    return self

  def get_config(self, obj):
    """Returns the `SharedObjectConfig` already recorded for `obj`, if any.

    A hit counts as another reference to `obj` and bumps its ref count.

    Args:
      obj: The object whose `SharedObjectConfig` should be retrieved.

    Returns:
      The recorded `SharedObjectConfig`, or `None` when `obj` has not been seen
      or cannot be tracked (unhashable / not weak-referenceable).
    """
    try:
      existing = self._shared_objects_config.get(obj)
    except TypeError:
      # e.g. an `AbstractBaseClass` subclass without `__hash__`: continue
      # without shared-object support for this object.
      existing = None
    if existing is None:
      return None
    existing.increment_ref_count()
    return existing

  def create_config(self, base_config, obj):
    """Creates, records and returns a new `SharedObjectConfig` for `obj`."""
    new_config = SharedObjectConfig(base_config, self._next_id)
    self._next_id += 1
    try:
      self._shared_objects_config[obj] = new_config
    except TypeError:
      # Unhashable or not weak-referenceable: the config is still returned,
      # it just will not be shared.
      pass
    return new_config

  def __exit__(self, *args, **kwargs):
    if getattr(self, '_passthrough', False):
      return
    SHARED_OBJECT_SAVING.scope = None

318 

319 

def serialize_keras_class_and_config(
    cls_name, cls_config, obj=None, shared_object_id=None):
  """Returns the serialization of the class with the given config.

  Args:
    cls_name: Name to store under 'class_name'.
    cls_config: Value to store under 'config'.
    obj: The object being serialized. When given and a saving scope is active,
      it is used to detect and record sharing.
    shared_object_id: An ID already known on some load paths, kept in the
      result under `SHARED_OBJECT_KEY`.

  Returns:
    A plain dict, or, under an active `SharedObjectSavingScope` with `obj`
    given, the scope's `SharedObjectConfig` for `obj`. An existing config is
    returned as-is in that case; `cls_name`/`cls_config` are not re-applied.
  """
  base_config = {'class_name': cls_name, 'config': cls_config}

  # Some load-path branches call this helper too; keep any ID they carry.
  if shared_object_id is not None:
    base_config[SHARED_OBJECT_KEY] = shared_object_id

  saving_scope = _shared_object_saving_scope()
  if saving_scope is None or obj is None:
    return base_config

  # Reuse the config if this object was already serialized in this scope; the
  # scope then stamps it with an ID so sharing can be rebuilt at load time.
  existing = saving_scope.get_config(obj)
  if existing is not None:
    return existing
  return saving_scope.create_config(base_config, obj)

342 

343 

@keras_export('keras.utils.register_keras_serializable')
def register_keras_serializable(package='Custom', name=None):
  """Registers an object with the Keras serialization framework.

  The returned decorator adds the decorated class or function to the Keras
  custom-object registry under the key 'package>name'. It can then be
  serialized and deserialized without an entry in a user-provided custom
  object dict. The reverse mapping (object -> key) is recorded as well.

  Classes must implement `get_config()` to be registered; functions have no
  such requirement.

  Args:
    package: The package that this class belongs to.
    name: The name to serialize this class under in this package. If None, the
      object's `__name__` is used.

  Returns:
    A decorator that registers its argument and returns it unchanged.

  Raises:
    ValueError (from the decorator): if a class lacks `get_config`, or if
      either the key or the object is already registered.
  """

  def decorator(arg):
    """Registers a class with the Keras serialization framework."""
    class_name = arg.__name__ if name is None else name
    registered_name = package + '>' + class_name

    if tf_inspect.isclass(arg) and not hasattr(arg, 'get_config'):
      raise ValueError(
          'Cannot register a class that does not have a get_config() method.')

    # Refuse duplicates in either direction so the two registries stay
    # one-to-one.
    if registered_name in _GLOBAL_CUSTOM_OBJECTS:
      previous_obj = _GLOBAL_CUSTOM_OBJECTS[registered_name]
      raise ValueError('%s has already been registered to %s' %
                       (registered_name, previous_obj))
    if arg in _GLOBAL_CUSTOM_NAMES:
      previous_name = _GLOBAL_CUSTOM_NAMES[arg]
      raise ValueError('%s has already been registered to %s' %
                       (arg, previous_name))

    _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
    _GLOBAL_CUSTOM_NAMES[arg] = registered_name
    return arg

  return decorator

391 

392 

@keras_export('keras.utils.get_registered_name')
def get_registered_name(obj):
  """Returns the name registered to an object within the Keras framework.

  Part of the Keras serialization/deserialization framework: maps an object
  to the string it is serialized as.

  Args:
    obj: The object to look up.

  Returns:
    The registered name if `obj` was registered via
    `register_keras_serializable`, otherwise `obj.__name__`.
  """
  try:
    return _GLOBAL_CUSTOM_NAMES[obj]
  except KeyError:
    return obj.__name__

412 

413 

@tf_contextlib.contextmanager
def skip_failed_serialization():
  """Context manager that makes `serialize_keras_object` tolerate failures.

  While active, `_SKIP_FAILED_SERIALIZATION` is True. An object whose
  `get_config` raises `NotImplementedError` is then serialized with a
  placeholder config (`{_LAYER_UNDEFINED_CONFIG_KEY: True}`) instead of
  propagating the error. The previous flag value is restored on exit, even if
  the body raises.

  Note: the flag is a plain module global, not thread-local.
  """
  global _SKIP_FAILED_SERIALIZATION
  saved_flag = _SKIP_FAILED_SERIALIZATION
  _SKIP_FAILED_SERIALIZATION = True
  try:
    yield
  finally:
    _SKIP_FAILED_SERIALIZATION = saved_flag

423 

424 

@keras_export('keras.utils.get_registered_object')
def get_registered_object(name, custom_objects=None, module_objects=None):
  """Returns the class associated with `name` if it is registered with Keras.

  Part of the Keras serialization/deserialization framework: maps a string
  back to the object it stands for. Sources are searched in order: the global
  registry, then `custom_objects`, then `module_objects`.

  Example:
  ```
  def from_config(cls, config, custom_objects=None):
    if 'my_custom_object_name' in config:
      config['hidden_cls'] = tf.keras.utils.get_registered_object(
          config['my_custom_object_name'], custom_objects=custom_objects)
  ```

  Args:
    name: The name to look up.
    custom_objects: A dictionary of custom objects to look the name up in.
      Generally, custom_objects is provided by the user.
    module_objects: A dictionary of custom objects to look the name up in.
      Generally, module_objects is provided by midlevel library implementers.

  Returns:
    An instantiable class associated with 'name', or None if no such class
    exists.
  """
  if name in _GLOBAL_CUSTOM_OBJECTS:
    return _GLOBAL_CUSTOM_OBJECTS[name]
  # Optional mappings may be None or empty; skip those.
  for lookup in (custom_objects, module_objects):
    if lookup and name in lookup:
      return lookup[name]
  return None

459 

460 

# pylint: disable=g-bad-exception-name
class CustomMaskWarning(Warning):
  """Warning category emitted by `serialize_keras_object` for mask layers.

  Raised when a layer that supports masking still uses the default
  `get_config`.
  """
# pylint: enable=g-bad-exception-name

465 

466 

@keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
  """Serialize a Keras object into a JSON-compatible representation.

  Calls to `serialize_keras_object` while underneath the
  `SharedObjectSavingScope` context manager will cause any objects re-used
  across multiple layers to be saved with a special shared object ID. This
  allows the network to be re-created properly during deserialization.

  Args:
    instance: The object to serialize.

  Returns:
    One of:
    - `None` if `instance` (after unwrapping decorators) is `None`.
    - A dict-like, JSON-compatible config for objects with `get_config`.
    - The registered name (a string) for objects that have a `__name__` but
      no `get_config`, such as functions.

  Raises:
    ValueError: If `instance` has neither `get_config` nor `__name__`.
    NotImplementedError: If `instance.get_config()` raises it and failed
      serialization is not being skipped (see `skip_failed_serialization`).
  """
  _, instance = tf_decorator.unwrap(instance)
  if instance is None:
    return None

  # pylint: disable=protected-access
  #
  # For v1 layers, checking supports_masking is not enough. We have to also
  # check whether compute_mask has been overridden.
  supports_masking = (getattr(instance, 'supports_masking', False)
                      or (hasattr(instance, 'compute_mask')
                          and not is_default(instance.compute_mask)))
  # Only inspect `get_config` when it exists. Otherwise an object exposing
  # masking attributes without `get_config` would raise AttributeError here
  # instead of reaching the `__name__` / ValueError handling below.
  if (supports_masking and hasattr(instance, 'get_config')
      and is_default(instance.get_config)):
    warnings.warn('Custom mask layers require a config and must override '
                  'get_config. When loading, the custom mask layer must be '
                  'passed to the custom_objects argument.',
                  category=CustomMaskWarning)
  # pylint: enable=protected-access

  if hasattr(instance, 'get_config'):
    name = get_registered_name(instance.__class__)
    try:
      config = instance.get_config()
    except NotImplementedError:
      if _SKIP_FAILED_SERIALIZATION:
        return serialize_keras_class_and_config(
            name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
      raise

    serialization_config = {}
    for key, item in config.items():
      if isinstance(item, str):
        serialization_config[key] = item
        continue

      # Any object of a different type needs to be converted to string or dict
      # for serialization (e.g. custom functions, custom classes). Values that
      # cannot be serialized (ValueError) are stored unchanged.
      try:
        serialized_item = serialize_keras_object(item)
      except ValueError:
        serialization_config[key] = item
        continue
      if isinstance(serialized_item, dict) and not isinstance(item, dict):
        # Marks nested configs that came from objects (not plain dicts), so
        # deserialization knows to rebuild them into objects.
        serialized_item['__passive_serialization__'] = True
      serialization_config[key] = serialized_item

    return serialize_keras_class_and_config(
        name, serialization_config, instance)
  if hasattr(instance, '__name__'):
    return get_registered_name(instance)
  raise ValueError('Cannot serialize', instance)

531 

532 

def get_custom_objects_by_name(item, custom_objects=None):
  """Looks `item` up in the global registry, then in `custom_objects`.

  Args:
    item: The name to look up.
    custom_objects: Optional mapping of names to objects; may be None.

  Returns:
    The matching object, or `None` if `item` is found in neither place.
  """
  if item in _GLOBAL_CUSTOM_OBJECTS:
    return _GLOBAL_CUSTOM_OBJECTS[item]
  if not custom_objects or item not in custom_objects:
    return None
  return custom_objects[item]

540 

541 

def class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Returns the class and config for a serialized keras object.

  Args:
    config: A dict with at least 'class_name' and 'config' keys, e.g. as
      produced by `serialize_keras_class_and_config`.
    module_objects: Optional mapping of built-in names to objects.
    custom_objects: Optional mapping of user-provided names to objects.
    printable_module_name: Human-readable kind of object, used in errors.

  Returns:
    A `(cls, cls_config)` tuple. `cls` is the object registered under
    `config['class_name']`, and `cls_config` is `config['config']`.

    When `cls_config` is a dict, some entries are replaced in place. This
    mutates the caller's nested dict.
    - Entries marked with `__passive_serialization__` are deserialized.
    - String entries naming a registered function are replaced by that
      function.

  Raises:
    ValueError: If `config` is malformed or its class name cannot be resolved.
  """
  if (not isinstance(config, dict)
      or 'class_name' not in config
      or 'config' not in config):
    raise ValueError('Improper config format: ' + str(config))

  class_name = config['class_name']
  cls = get_registered_object(class_name, custom_objects, module_objects)
  if cls is None:
    raise ValueError(
        'Unknown {}: {}. Please ensure this object is '
        'passed to the `custom_objects` argument. See '
        'https://www.tensorflow.org/guide/keras/save_and_serialize'
        '#registering_the_custom_object for details.'
        .format(printable_module_name, class_name))

  cls_config = config['config']
  # Check if `cls_config` is a list. If it is a list, return the class and the
  # associated class configs for recursively deserialization. This case will
  # happen on the old version of sequential model (e.g. `keras_version` ==
  # "2.0.6"), which is serialized in a different structure, for example
  # "{'class_name': 'Sequential',
  #   'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}".
  if isinstance(cls_config, list):
    return (cls, cls_config)

  deserialized_objects = {}
  for key, item in cls_config.items():
    if key == 'name':
      # Assume that the value of 'name' is a string that should not be
      # deserialized as a function. This avoids the corner case where
      # cls_config['name'] has an identical name to a custom function and
      # gets converted into that function.
      deserialized_objects[key] = item
    elif isinstance(item, dict) and '__passive_serialization__' in item:
      deserialized_objects[key] = deserialize_keras_object(
          item,
          module_objects=module_objects,
          custom_objects=custom_objects,
          printable_module_name='config_item')
    elif isinstance(item, str):
      # Handle custom functions here. When saving functions, we only save the
      # function's name as a string. If we find a matching string in the custom
      # objects during deserialization, we convert the string back to the
      # original function.
      # Note that a potential issue is that a string field could have a naming
      # conflict with a custom function name, but this should be a rare case.
      # This issue does not occur if a string field has a naming conflict with
      # a custom object, since the config of an object will always be a dict.
      # TODO(momernick): Should this also have 'module_objects'?
      registered = get_registered_object(item, custom_objects)  # look up once
      if tf_inspect.isfunction(registered):
        deserialized_objects[key] = registered

  # Write back after iterating so the dict is not modified mid-iteration.
  for key, value in deserialized_objects.items():
    cls_config[key] = value

  return (cls, cls_config)

603 

604 

@keras_export('keras.utils.deserialize_keras_object')
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object.

  This function is for mid-level library implementers rather than end users.

  Importantly, this utility requires you to provide the dict of `module_objects`
  to use for looking up the object config; this is not populated by default.
  If you need a deserialization utility that has preexisting knowledge of
  built-in Keras objects, use e.g. `keras.layers.deserialize(config)`,
  `keras.metrics.deserialize(config)`, etc.

  Calling `deserialize_keras_object` while underneath the
  `SharedObjectLoadingScope` context manager will cause any already-seen shared
  objects to be returned as-is rather than creating a new object.

  Lookup order differs by identifier kind. For a string identifier,
  `custom_objects` is checked before the global registry. For a dict
  identifier, the class is resolved via `get_registered_object`, which checks
  the global registry first.

  Args:
    identifier: the serialized form of the object. Can be:
      - a config dict;
      - a name (string), where classes are instantiated with no arguments and
        other objects are returned as-is;
      - a function, returned unchanged;
      - `None`, which yields `None`.
    module_objects: A dictionary of built-in objects to look the name up in.
      Generally, `module_objects` is provided by midlevel library implementers.
    custom_objects: A dictionary of custom objects to look the name up in.
      Generally, `custom_objects` is provided by the end user.
    printable_module_name: A human-readable string representing the type of the
      object. Printed in case of exception.

  Returns:
    The deserialized object.

  Raises:
    ValueError: If a name cannot be resolved or `identifier` has an
      unsupported type.

  Example:

  A mid-level library implementer might want to implement a utility for
  retrieving an object from its config, as such:

  ```python
  def deserialize(config, custom_objects=None):
     return deserialize_keras_object(
       config,
       module_objects=globals(),
       custom_objects=custom_objects,
       printable_module_name="MyObjectType",
     )
  ```

  This is how e.g. `keras.layers.deserialize()` is implemented.
  """
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    config = identifier
    (cls, cls_config) = class_and_config_for_serialized_keras_object(
        config, module_objects, custom_objects, printable_module_name)

    # If this object has already been loaded (i.e. it's shared between multiple
    # objects), return the already-loaded object.
    shared_object_id = config.get(SHARED_OBJECT_KEY)
    shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
    if shared_object is not None:
      return shared_object

    custom_objects = custom_objects or {}
    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      if 'custom_objects' in arg_spec.args:
        # User-supplied objects take precedence over globally registered ones.
        deserialized_obj = cls.from_config(
            cls_config,
            custom_objects={**_GLOBAL_CUSTOM_OBJECTS, **custom_objects})
      else:
        with CustomObjectScope(custom_objects):
          deserialized_obj = cls.from_config(cls_config)
    else:
      # Then `cls` may be a function returning a class.
      # in this case by convention `config` holds
      # the kwargs of the function.
      with CustomObjectScope(custom_objects):
        deserialized_obj = cls(**cls_config)

    # Add object to shared objects, in case we find it referenced again.
    _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

    return deserialized_obj

  elif isinstance(identifier, str):
    object_name = identifier
    if custom_objects and object_name in custom_objects:
      obj = custom_objects.get(object_name)
    elif object_name in _GLOBAL_CUSTOM_OBJECTS:
      obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
    else:
      # `module_objects` defaults to None. Treat that as an empty mapping so an
      # unknown name raises the descriptive ValueError below instead of an
      # AttributeError on `None.get`.
      obj = (module_objects or {}).get(object_name)
      if obj is None:
        raise ValueError(
            'Unknown {}: {}. Please ensure this object is '
            'passed to the `custom_objects` argument. See '
            'https://www.tensorflow.org/guide/keras/save_and_serialize'
            '#registering_the_custom_object for details.'
            .format(printable_module_name, object_name))

    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj
  elif tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier
  else:
    raise ValueError('Could not interpret serialized %s: %s' %
                     (printable_module_name, identifier))

722 

723 

def func_dump(func):
  """Serializes a user defined function.

  Args:
    func: the function to serialize.

  Returns:
    A tuple `(code, defaults, closure)`:
    - `code` is the base64-encoded `marshal` dump of `func.__code__`, as an
      ASCII `str`.
    - `defaults` is `func.__defaults__`.
    - `closure` is a tuple of the closure cells' contents, or `None` if `func`
      has no closure.
  """
  # Encode the marshalled bytes verbatim on every platform. Do not rewrite
  # bytes here (e.g. to normalise Windows path separators): a byte-level
  # replace on the binary marshal stream changes every 0x5C byte, including
  # opcodes, opcode arguments and constants, and so corrupts the function.
  raw_code = marshal.dumps(func.__code__)
  code = codecs.encode(raw_code, 'base64').decode('ascii')
  defaults = func.__defaults__
  if func.__closure__:
    closure = tuple(cell.cell_contents for cell in func.__closure__)
  else:
    closure = None
  return code, defaults, closure

745 

746 

def func_load(code, defaults=None, closure=None, globs=None):
  """Deserializes a user defined function.

  Inverse of `func_dump`: `code` is expected to be base64-encoded `marshal`
  output. If it is not valid ASCII/base64, it is instead encoded with
  'raw_unicode_escape' and passed to `marshal.loads` directly.

  SECURITY: `marshal` is not secure against maliciously constructed data, and
  the returned function runs whatever bytecode it was given. Only call this
  on input from a trusted source.

  Args:
    code: bytecode of the function, or a `(code, defaults, closure)`
      tuple/list as returned by `func_dump`.
    defaults: defaults of the function (tuple or list).
    closure: closure of the function: an iterable of values or cell objects.
    globs: dictionary of global objects. Defaults to this module's globals.

  Returns:
    A function object.
  """
  if isinstance(code, (tuple, list)):  # unpack previous dump
    code, defaults, closure = code
  # Serialized configs may carry `defaults` as a list (e.g. after a JSON round
  # trip), but `FunctionType` requires `argdefs` to be a tuple. Normalise it
  # however `defaults` was supplied, not only when it came from a packed dump.
  if isinstance(defaults, list):
    defaults = tuple(defaults)

  def ensure_value_to_cell(value):
    """Ensures that a value is converted to a python cell object.

    Args:
        value: Any value that needs to be casted to the cell type

    Returns:
        A value wrapped as a cell object (see function "func_load")
    """

    def dummy_fn():
      # pylint: disable=pointless-statement
      value  # just access it so it gets captured in .__closure__

    cell_value = dummy_fn.__closure__[0]
    if not isinstance(value, type(cell_value)):
      return cell_value
    return value

  if closure is not None:
    closure = tuple(ensure_value_to_cell(_) for _ in closure)
  try:
    raw_code = codecs.decode(code.encode('ascii'), 'base64')
  except (UnicodeEncodeError, binascii.Error):
    raw_code = code.encode('raw_unicode_escape')
  code = marshal.loads(raw_code)
  if globs is None:
    globs = globals()
  return python_types.FunctionType(
      code, globs, name=code.co_name, argdefs=defaults, closure=closure)

794 

795 

def has_arg(fn, name, accept_all=False):
  """Tells whether `fn` can be called with `name` as a keyword argument.

  Args:
    fn: Callable to inspect.
    name: Keyword argument name to look for.
    accept_all: If truthy, a `**kwargs` parameter on `fn` also counts as
      accepting `name`.

  Returns:
    bool: True if `name` is a positional-or-keyword or keyword-only parameter
    of `fn`, or if `accept_all` is set and `fn` takes `**kwargs`.
  """
  spec = tf_inspect.getfullargspec(fn)
  if name in spec.args or name in spec.kwonlyargs:
    return True
  return bool(accept_all) and spec.varkw is not None

812 

813 

@keras_export('keras.utils.Progbar')
class Progbar(object):
  """Displays a progress bar.

  Args:
    target: Total number of steps expected, None if unknown.
    width: Progress bar width on screen.
    verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
    stateful_metrics: Iterable of string names of metrics that should *not* be
      averaged over time. Metrics in this list will be displayed as-is. All
      others will be averaged by the progbar before display.
    interval: Minimum visual progress update interval (in seconds).
    unit_name: Display name for step counts (usually "step" or "sample").
  """

  def __init__(self,
               target,
               width=30,
               verbose=1,
               interval=0.05,
               stateful_metrics=None,
               unit_name='step'):
    self.target = target
    self.width = width
    self.verbose = verbose
    self.interval = interval
    self.unit_name = unit_name
    if stateful_metrics:
      self.stateful_metrics = set(stateful_metrics)
    else:
      self.stateful_metrics = set()

    # When True, verbose=1 output redraws the bar in place ('\b' + '\r');
    # otherwise each update starts on a new line. Note that
    # `'posix' in sys.modules` holds on any POSIX platform, so this is
    # effectively always enabled there.
    self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
                              sys.stdout.isatty()) or
                             'ipykernel' in sys.modules or
                             'posix' in sys.modules or
                             'PYCHARM_HOSTED' in os.environ)
    self._total_width = 0
    self._seen_so_far = 0
    # We use a dict + list to avoid garbage collection
    # issues found in OrderedDict
    self._values = {}
    self._values_order = []
    self._start = time.time()
    self._last_update = 0

    self._time_after_first_step = None

  def update(self, current, values=None, finalize=None):
    """Updates the progress bar.

    Args:
      current: Index of current step.
      values: List of tuples: `(name, value_for_last_step)`. If `name` is in
        `stateful_metrics`, `value_for_last_step` will be displayed as-is.
        Else, an average of the metric over time will be displayed.
      finalize: Whether this is the last update for the progress bar. If
        `None`, defaults to `current >= self.target`.
    """
    if finalize is None:
      if self.target is None:
        finalize = False
      else:
        finalize = current >= self.target

    values = values or []
    for k, v in values:
      if k not in self._values_order:
        self._values_order.append(k)
      if k not in self.stateful_metrics:
        # In the case that progress bar doesn't have a target value in the first
        # epoch, both on_batch_end and on_epoch_end will be called, which will
        # cause 'current' and 'self._seen_so_far' to have the same value. Force
        # the minimal value to 1 here, otherwise stateful_metric will be 0s.
        value_base = max(current - self._seen_so_far, 1)
        if k not in self._values:
          self._values[k] = [v * value_base, value_base]
        else:
          self._values[k][0] += v * value_base
          self._values[k][1] += value_base
      else:
        # Stateful metrics output a numeric value. This representation
        # means "take an average from a single value" but keeps the
        # numeric formatting.
        self._values[k] = [v, 1]
    self._seen_so_far = current

    now = time.time()
    info = ' - %.0fs' % (now - self._start)
    if self.verbose == 1:
      if now - self._last_update < self.interval and not finalize:
        return

      prev_total_width = self._total_width
      if self._dynamic_display:
        sys.stdout.write('\b' * prev_total_width)
        sys.stdout.write('\r')
      else:
        sys.stdout.write('\n')

      if self.target is not None:
        bar = self._format_step_count(current) + ' ['
        prog = float(current) / self.target
        prog_width = int(self.width * prog)
        if prog_width > 0:
          bar += ('=' * (prog_width - 1))
          if current < self.target:
            bar += '>'
          else:
            bar += '='
        bar += ('.' * (self.width - prog_width))
        bar += ']'
      else:
        bar = self._format_step_count(current)

      self._total_width = len(bar)
      sys.stdout.write(bar)

      time_per_unit = self._estimate_step_duration(current, now)

      if self.target is None or finalize:
        if time_per_unit >= 1 or time_per_unit == 0:
          info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
        elif time_per_unit >= 1e-3:
          info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
        else:
          info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)
      else:
        eta = time_per_unit * (self.target - current)
        if eta > 3600:
          eta_format = '%d:%02d:%02d' % (eta // 3600,
                                         (eta % 3600) // 60, eta % 60)
        elif eta > 60:
          eta_format = '%d:%02d' % (eta // 60, eta % 60)
        else:
          eta_format = '%ds' % eta

        # While in progress, the ETA replaces the elapsed-time prefix.
        info = ' - ETA: %s' % eta_format

      for k in self._values_order:
        info += ' - %s:' % k
        if isinstance(self._values[k], list):
          avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
          info += self._format_average(avg)
        else:
          info += ' %s' % self._values[k]

      self._total_width += len(info)
      # Pad with spaces so a shorter line fully overwrites the previous one.
      if prev_total_width > self._total_width:
        info += (' ' * (prev_total_width - self._total_width))

      if finalize:
        info += '\n'

      sys.stdout.write(info)
      sys.stdout.flush()

    elif self.verbose == 2:
      if finalize:
        # `target` may be None here when the caller passes `finalize=True`
        # explicitly; `_format_step_count` handles that case instead of
        # calling `np.log10(None)`.
        info = self._format_step_count(current) + info
        for k in self._values_order:
          info += ' - %s:' % k
          avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
          # Same formatting as verbose=1, so negative averages are shown in
          # fixed-point rather than forced into scientific notation.
          info += self._format_average(avg)
        info += '\n'

        sys.stdout.write(info)
        sys.stdout.flush()

    self._last_update = now

  def add(self, n, values=None):
    """Advances the bar by `n` steps; `values` is forwarded to `update`."""
    self.update(self._seen_so_far + n, values)

  def _format_step_count(self, current):
    """Formats the step counter as `current/target`, or `current/Unknown`."""
    if self.target is None:
      return '%7d/Unknown' % current
    numdigits = int(np.log10(self.target)) + 1
    return ('%' + str(numdigits) + 'd/%d') % (current, self.target)

  @staticmethod
  def _format_average(avg):
    """Formats a metric average: fixed-point unless its magnitude is tiny."""
    if abs(avg) > 1e-3:
      return ' %.4f' % avg
    return ' %.4e' % avg

  def _estimate_step_duration(self, current, now):
    """Estimate the duration of a single step.

    Given the step number `current` and the corresponding time `now`
    this function returns an estimate for how long a single step
    takes. If this is called before one step has been completed
    (i.e. `current == 0`) then zero is given as an estimate. The duration
    estimate ignores the duration of the (assumed to be non-representative)
    first step for estimates when more steps are available (i.e. `current>1`).

    Args:
      current: Index of current step.
      now: The current time.

    Returns:
      Estimate of the duration of a single step.
    """
    if current:
      # there are a few special scenarios here:
      # 1) somebody is calling the progress bar without ever supplying step 1
      # 2) somebody is calling the progress bar and supplies step one multiple
      #    times, e.g. as part of a finalizing call
      # in these cases, we just fall back to the simple calculation
      if self._time_after_first_step is not None and current > 1:
        time_per_unit = (now - self._time_after_first_step) / (current - 1)
      else:
        time_per_unit = (now - self._start) / current

      if current == 1:
        self._time_after_first_step = now
      return time_per_unit
    else:
      return 0

  def _update_stateful_metrics(self, stateful_metrics):
    """Adds `stateful_metrics` to the set of metrics displayed un-averaged."""
    self.stateful_metrics = self.stateful_metrics.union(stateful_metrics)

1030 

1031 

def make_batches(size, batch_size):
  """Splits `range(size)` into consecutive `[start, end)` index pairs.

  Args:
    size: Integer, total size of the data to slice into batches.
    batch_size: Integer, batch size.

  Returns:
    A list of `(start, end)` tuples; the final batch may be shorter.
  """
  batch_count = int(np.ceil(size / float(batch_size)))
  batches = []
  for batch_index in range(batch_count):
    begin = batch_index * batch_size
    end = min(size, (batch_index + 1) * batch_size)
    batches.append((begin, end))
  return batches

1045 

1046 

def slice_arrays(arrays, start=None, stop=None):
  """Slice an array or list of arrays.

  This takes an array-like, or a list of
  array-likes, and outputs:
    - arrays[start:stop] if `arrays` is an array-like
    - [x[start:stop] for x in arrays] if `arrays` is a list

  Can also work on list/array of indices: `slice_arrays(x, indices)`

  Args:
    arrays: Single array or list of arrays.
    start: can be an integer index (start index) or a list/array of indices
    stop: integer (stop index); should be None if `start` was a list.

  Returns:
    A slice of the array(s). `None` entries of a list input stay `None`, and
    entries that cannot be indexed become `None`. A single input that cannot
    be indexed yields `[None]`.

  Raises:
    ValueError: If the value of start is a list and stop is not None.
  """
  if arrays is None:
    return [None]
  if isinstance(start, list) and stop is not None:
    raise ValueError('The stop argument has to be None if the value of start '
                     'is a list.')
  elif isinstance(arrays, list):
    if hasattr(start, '__len__'):
      # hdf5 datasets only support list objects as indices
      if hasattr(start, 'shape'):
        start = start.tolist()
      return [None if x is None else x[start] for x in arrays]
    return [
        None if x is None else
        None if not hasattr(x, '__getitem__') else x[start:stop] for x in arrays
    ]
  else:
    if hasattr(start, '__len__'):
      if hasattr(start, 'shape'):
        start = start.tolist()
      return arrays[start]
    # Check that the *container* supports indexing. Previously this tested
    # `start`, which is an int or None at this point, so plain slicing was
    # unreachable and every single array-like returned `[None]`.
    if hasattr(arrays, '__getitem__'):
      return arrays[start:stop]
    return [None]

1091 

1092 

def to_list(x):
  """Wraps `x` in a one-element list unless it already is a list.

  A list argument is returned as the very same object (not a copy); any other
  value, including tuples and tensors, becomes `[x]`.

  Args:
    x: target object to be normalized.

  Returns:
    A list.
  """
  return x if isinstance(x, list) else [x]

1108 

1109 

def to_snake_case(name):
  """Converts a CamelCase name to snake_case.

  If the result starts with an underscore, it is prefixed with `private`.

  Args:
    name: String to convert, typically a class name.

  Returns:
    The snake_case string. An empty `name` yields an empty string.
  """
  intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
  insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
  # If the class is private the name starts with "_" which is not secure
  # for creating scopes. We prefix the name with "private" in this case.
  # `startswith` (instead of `insecure[0]`) avoids an IndexError on ''.
  if not insecure.startswith('_'):
    return insecure
  return 'private' + insecure

1118 

1119 

def is_all_none(structure):
  """Returns True when every flattened element of `structure` is None.

  Args:
    structure: A (possibly nested) structure passed to `nest.flatten`.

  Returns:
    True if no flattened element is anything other than `None` (including
    when there are no elements at all), else False.
  """
  # Only identity checks (`is None`) are evaluated, so elements such as
  # Tensors are never coerced to bool, which `any(flattened)` would do.
  return all(element is None for element in nest.flatten(structure))

1127 

1128 

def check_for_unexpected_keys(name, input_dict, expected_values):
  """Raises if `input_dict` contains keys outside `expected_values`.

  Args:
    name: Label for the dictionary, used in the error message.
    input_dict: Dictionary whose keys are checked.
    expected_values: Iterable of permitted keys.

  Raises:
    ValueError: If any key of `input_dict` is not in `expected_values`.
  """
  unexpected_keys = set(input_dict.keys()).difference(expected_values)
  if not unexpected_keys:
    return
  raise ValueError('Unknown entries in {} dictionary: {}. Only expected '
                   'following keys: {}'.format(name, list(unexpected_keys),
                                               expected_values))

1135 

1136 

def validate_kwargs(kwargs,
                    allowed_kwargs,
                    error_message='Keyword argument not understood:'):
  """Checks that all keyword arguments are in the set of allowed keys.

  Args:
    kwargs: Mapping (or iterable) of keyword-argument names to check.
    allowed_kwargs: Container of permitted names.
    error_message: First argument of the raised `TypeError`.

  Raises:
    TypeError: With args `(error_message, name)` for the first name that is
      not in `allowed_kwargs`.
  """
  for kwarg_name in kwargs:
    if kwarg_name in allowed_kwargs:
      continue
    raise TypeError(error_message, kwarg_name)

1144 

1145 

def validate_config(config):
  """Determines whether config appears to be a valid layer config.

  Returns False for anything that is not a dict, and for dicts that carry
  the `_LAYER_UNDEFINED_CONFIG_KEY` marker key.
  """
  if not isinstance(config, dict):
    return False
  return _LAYER_UNDEFINED_CONFIG_KEY not in config

1149 

1150 

def default(method):
  """Marks `method` as the default implementation, so overrides can be detected.

  Sets `method._is_default = True` and returns the same object.
  """
  setattr(method, '_is_default', True)
  return method

1155 

1156 

def is_default(method):
  """Check if a method is decorated with the `default` wrapper.

  Returns the method's `_is_default` attribute, or False if it has none.
  """
  try:
    return method._is_default  # pylint: disable=protected-access
  except AttributeError:
    return False

1160 

1161 

def populate_dict_with_module_objects(target_dict, modules, obj_filter):
  """Copies selected attributes of each module into `target_dict`.

  Args:
    target_dict: Dict updated in place, keyed by attribute name.
    modules: Iterable of objects (usually modules) whose `dir()` is scanned.
    obj_filter: Predicate; an attribute is copied only if it returns truthy.
  """
  for module in modules:
    # Lazily pair each name with its value so lookups interleave with the
    # filter calls in the same order as a plain nested loop.
    members = ((attr_name, getattr(module, attr_name))
               for attr_name in dir(module))
    for attr_name, candidate in members:
      if obj_filter(candidate):
        target_dict[attr_name] = candidate

1168 

1169 

class LazyLoader(python_types.ModuleType):
  """Lazily import a module, mainly to avoid pulling in large dependencies.

  The real module named `name` is imported on the first attribute access
  that misses this object's own `__dict__`.
  """

  def __init__(self, local_name, parent_module_globals, name):
    # Name under which the real module is stored in the parent's globals.
    self._local_name = local_name
    # Globals dict that receives the real module once it is imported.
    self._parent_module_globals = parent_module_globals
    super().__init__(name)

  def _load(self):
    """Imports the target module, rebinds it in the parent, and returns it."""
    loaded = importlib.import_module(self.__name__)
    self._parent_module_globals[self._local_name] = loaded
    # Copy the module's namespace onto this proxy so later lookups through a
    # retained reference hit `__dict__` directly instead of `__getattr__`.
    vars(self).update(vars(loaded))
    return loaded

  def __getattr__(self, item):
    # Only invoked for attributes not found by normal lookup.
    return getattr(self._load(), item)

1192 

1193 

# Aliases

# Lower-case alias: `custom_object_scope` and `CustomObjectScope` are the same
# object, so either spelling can be used by callers.
custom_object_scope = CustomObjectScope  # pylint: disable=invalid-name
1196custom_object_scope = CustomObjectScope # pylint: disable=invalid-name