Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py: 22%
493 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python utilities required by Keras."""
17import binascii
18import codecs
19import importlib
20import marshal
21import os
22import re
23import sys
24import threading
25import time
26import types as python_types
27import warnings
28import weakref
30import numpy as np
32from tensorflow.python.keras.utils import tf_contextlib
33from tensorflow.python.keras.utils import tf_inspect
34from tensorflow.python.util import nest
35from tensorflow.python.util import tf_decorator
36from tensorflow.python.util.tf_export import keras_export
# Mapping from registered serialization names (e.g. 'package>ClassName') to
# the custom classes/functions themselves. Mutated by `CustomObjectScope`,
# `get_custom_objects` and `register_keras_serializable`.
_GLOBAL_CUSTOM_OBJECTS = {}
# Reverse mapping: custom object -> the name it was registered under.
_GLOBAL_CUSTOM_NAMES = {}

# Flag that determines whether to skip the NotImplementedError when calling
# get_config in custom models and layers. This is only enabled when saving to
# SavedModel, when the config isn't required.
_SKIP_FAILED_SERIALIZATION = False
# If a layer does not have a defined config, then the returned config will be a
# dictionary with the below key.
_LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config'
@keras_export('keras.utils.custom_object_scope',  # pylint: disable=g-classes-have-attributes
              'keras.utils.CustomObjectScope')
class CustomObjectScope(object):
  """Exposes custom classes/functions to Keras deserialization internals.

  While inside a `with custom_object_scope(objects_dict)` block, Keras
  deserialization entry points such as `tf.keras.models.load_model` or
  `tf.keras.models.model_from_config` can resolve any custom object
  (custom layer, metric, regularizer, ...) referenced by a saved config.

  Example:

  Consider a custom regularizer `my_regularizer`:

  ```python
  layer = Dense(3, kernel_regularizer=my_regularizer)
  config = layer.get_config()  # Config contains a reference to `my_regularizer`
  ...
  # Later:
  with custom_object_scope({'my_regularizer': my_regularizer}):
    layer = Dense.from_config(config)
  ```

  Args:
    *args: Dictionary or dictionaries of `{name: object}` pairs.
  """

  def __init__(self, *args):
    # Stash the mappings; nothing is made visible until `__enter__`.
    self.custom_objects = args
    self.backup = None

  def __enter__(self):
    # Snapshot the global registry so `__exit__` can restore it exactly.
    self.backup = _GLOBAL_CUSTOM_OBJECTS.copy()
    for mapping in self.custom_objects:
      _GLOBAL_CUSTOM_OBJECTS.update(mapping)
    return self

  def __exit__(self, *args, **kwargs):
    # Restore the registry to its pre-scope contents.
    _GLOBAL_CUSTOM_OBJECTS.clear()
    _GLOBAL_CUSTOM_OBJECTS.update(self.backup)
@keras_export('keras.utils.get_custom_objects')
def get_custom_objects():
  """Retrieves a live reference to the global dictionary of custom objects.

  Prefer updating and clearing custom objects via `custom_object_scope`;
  this accessor exists for direct inspection or mutation of the current
  collection of custom objects.

  Example:

  ```python
  get_custom_objects().clear()
  get_custom_objects()['MyObject'] = MyObject
  ```

  Returns:
    Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`).
  """
  return _GLOBAL_CUSTOM_OBJECTS
# Store a unique, per-object ID for shared objects.
#
# We store a unique ID for each object so that we may, at loading time,
# re-create the network properly. Without this ID, we would have no way of
# determining whether a config is a description of a new object that
# should be created or is merely a reference to an already-created object.
SHARED_OBJECT_KEY = 'shared_object_id'

# Thread-local holders for the shared-object machinery. Each one lazily
# carries a single attribute per thread (see the accessor functions below):
SHARED_OBJECT_DISABLED = threading.local()  # .disabled: bool
SHARED_OBJECT_LOADING = threading.local()  # .scope: loading-scope object
SHARED_OBJECT_SAVING = threading.local()  # .scope: saving scope or None

# Attributes on the threadlocal variable must be set per-thread, thus we
# cannot initialize these globally. Instead, we have accessor functions with
# default values.
def _shared_object_disabled():
  """Returns True when shared-object handling is disabled for this thread."""
  # The attribute is set per-thread; absence means "not disabled".
  try:
    return SHARED_OBJECT_DISABLED.disabled
  except AttributeError:
    return False
def _shared_object_loading_scope():
  """Get the current shared object loading scope in a threadsafe manner."""
  # Falls back to a fresh no-op scope when no loading scope is active on
  # this thread, so callers never need a None check.
  return getattr(SHARED_OBJECT_LOADING, 'scope', NoopLoadingScope())
def _shared_object_saving_scope():
  """Returns this thread's active saving scope, or None when there is none."""
  try:
    return SHARED_OBJECT_SAVING.scope
  except AttributeError:
    # No scope has been installed on this thread yet.
    return None
class DisableSharedObjectScope(object):
  """A context manager for disabling handling of shared objects.

  Disables shared object handling for both saving and loading. Created
  primarily for use with `clone_model`, which does extra surgery that is
  incompatible with shared objects.
  """

  def __enter__(self):
    # Remember the scopes active on this thread so they can be re-installed
    # on exit, then raise the disabled flag.
    self._saved_loading_scope = _shared_object_loading_scope()
    self._saved_saving_scope = _shared_object_saving_scope()
    SHARED_OBJECT_DISABLED.disabled = True

  def __exit__(self, *args, **kwargs):
    SHARED_OBJECT_DISABLED.disabled = False
    SHARED_OBJECT_LOADING.scope = self._saved_loading_scope
    SHARED_OBJECT_SAVING.scope = self._saved_saving_scope
class NoopLoadingScope(object):
  """The default shared object loading scope; every operation is a no-op.

  Simplifies serialization code that does not care about shared objects
  (e.g. when serializing a single object): callers can use the same API
  without checking whether a real scope is active.
  """

  def get(self, unused_object_id):
    """Never finds a previously-loaded object."""
    return None

  def set(self, object_id, obj):
    """Discards the object; nothing is tracked."""
    pass
class SharedObjectLoadingScope(object):
  """A context manager for keeping track of loaded objects.

  During deserialization we may encounter objects that are shared across
  multiple layers. To restore the network structure to its original state,
  `SharedObjectLoadingScope` lets us re-use already-instantiated shared
  objects instead of cloning them.
  """

  def __enter__(self):
    # When shared-object handling is globally disabled, hand back a no-op
    # scope instead of installing ourselves.
    if _shared_object_disabled():
      return NoopLoadingScope()
    global SHARED_OBJECT_LOADING
    SHARED_OBJECT_LOADING.scope = self
    self._obj_ids_to_obj = {}
    return self

  def get(self, object_id):
    """Given a shared object ID, returns a previously instantiated object.

    Args:
      object_id: shared object ID to use when attempting to find an
        already-loaded object.

    Returns:
      The object, if we've seen this ID before. Else, `None`.
    """
    # Accept `None` here so callers don't have to special-case configs that
    # carry no shared-object ID.
    if object_id is None:
      return None
    return self._obj_ids_to_obj.get(object_id)

  def set(self, object_id, obj):
    """Stores an instantiated object for future lookup and sharing."""
    if object_id is not None:
      self._obj_ids_to_obj[object_id] = obj

  def __exit__(self, *args, **kwargs):
    # Leave a no-op scope behind rather than clearing the attribute.
    global SHARED_OBJECT_LOADING
    SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
class SharedObjectConfig(dict):
  """A configuration container that keeps track of references.

  `SharedObjectConfig` automatically attaches a shared object ID to any
  config referenced more than once, allowing proper shared-object
  reconstruction at load time.

  In most cases it would be more proper to subclass something like
  `collections.UserDict` or `collections.Mapping` rather than `dict`
  directly. Unfortunately python's json encoder does not support
  `Mapping`s, and that functionality matters for serialization. Subclassing
  `dict` is safe here since no core methods are overridden — we only add a
  reference-counting method.
  """

  def __init__(self, base_config, object_id, **kwargs):
    super(SharedObjectConfig, self).__init__(base_config, **kwargs)
    # Starts at one: the config itself counts as the first reference.
    self.ref_count = 1
    self.object_id = object_id

  def increment_ref_count(self):
    # Attach the shared object ID only once we've seen the object more than
    # once; attaching lazily keeps single-use configs byte-identical to the
    # legacy format and minimizes backwards-compatibility breakage.
    if self.ref_count == 1:
      self[SHARED_OBJECT_KEY] = self.object_id
    self.ref_count += 1
class SharedObjectSavingScope(object):
  """Keeps track of shared object configs when serializing."""

  def __enter__(self):
    if _shared_object_disabled():
      return None

    global SHARED_OBJECT_SAVING

    # Serialization can happen at a number of layers for a number of
    # reasons. If a saving scope is already open on this thread, defer to
    # the outermost one and ignore this inner scope — there is not (yet) a
    # reasonable use case for nested, distinct scopes.
    outer_scope = _shared_object_saving_scope()
    if outer_scope is not None:
      self._passthrough = True
      return outer_scope
    self._passthrough = False

    SHARED_OBJECT_SAVING.scope = self
    # Weak keys so tracked objects are not kept alive by the scope.
    self._shared_objects_config = weakref.WeakKeyDictionary()
    self._next_id = 0
    return self

  def get_config(self, obj):
    """Gets a `SharedObjectConfig` if one has already been seen for `obj`.

    Args:
      obj: The object for which to retrieve the `SharedObjectConfig`.

    Returns:
      The SharedObjectConfig for a given object, if already seen. Else,
      `None`.
    """
    try:
      existing_config = self._shared_objects_config[obj]
    except (TypeError, KeyError):
      # Unhashable objects (e.g. a subclass of `AbstractBaseClass` that has
      # not overridden `__hash__`) raise `TypeError`; continue without
      # shared-object support for them.
      return None
    existing_config.increment_ref_count()
    return existing_config

  def create_config(self, base_config, obj):
    """Create a new SharedObjectConfig for a given object."""
    new_config = SharedObjectConfig(base_config, self._next_id)
    self._next_id += 1
    try:
      self._shared_objects_config[obj] = new_config
    except TypeError:
      # Unhashable object — skip tracking; see `get_config`.
      pass
    return new_config

  def __exit__(self, *args, **kwargs):
    # Only the outermost scope tears down the thread-local.
    if getattr(self, '_passthrough', False):
      return
    global SHARED_OBJECT_SAVING
    SHARED_OBJECT_SAVING.scope = None
def serialize_keras_class_and_config(
    cls_name, cls_config, obj=None, shared_object_id=None):
  """Returns the serialization of the class with the given config."""
  serialized = {'class_name': cls_name, 'config': cls_config}

  # Some branches of the load path call this function too; preserve any
  # shared object ID the caller already has.
  if shared_object_id is not None:
    serialized[SHARED_OBJECT_KEY] = shared_object_id

  # Under an active `SharedObjectSavingScope`, re-use a previously
  # serialized config for the same object; the scope stores an extra ID
  # field in the config so the shared-object relationship can be re-created
  # at load time.
  scope = _shared_object_saving_scope()
  if scope is not None and obj is not None:
    existing_config = scope.get_config(obj)
    if existing_config is not None:
      return existing_config
    return scope.create_config(serialized, obj)

  return serialized
@keras_export('keras.utils.register_keras_serializable')
def register_keras_serializable(package='Custom', name=None):
  """Registers an object with the Keras serialization framework.

  The decorator injects the decorated class or function into the Keras
  custom object dictionary so it can be serialized and deserialized without
  an entry in the user-provided custom object dict, and records the string
  key Keras will use to identify the object.

  Note that to be serialized and deserialized, classes must implement the
  `get_config()` method. Functions do not have this requirement.

  The object is registered under the key 'package>name', where `name`
  defaults to the object's own name if not passed.

  Args:
    package: The package that this class belongs to.
    name: The name to serialize this class under in this package. If None, the
      class' name will be used.

  Returns:
    A decorator that registers the decorated class with the passed names.
  """

  def decorator(arg):
    """Registers a class with the Keras serialization framework."""
    if name is None:
      class_name = arg.__name__
    else:
      class_name = name
    registered_name = package + '>' + class_name

    # Classes must be serializable; reject those without `get_config`.
    if tf_inspect.isclass(arg) and not hasattr(arg, 'get_config'):
      raise ValueError(
          'Cannot register a class that does not have a get_config() method.')

    # Refuse duplicate registrations in either direction.
    if registered_name in _GLOBAL_CUSTOM_OBJECTS:
      raise ValueError(
          '%s has already been registered to %s' %
          (registered_name, _GLOBAL_CUSTOM_OBJECTS[registered_name]))
    if arg in _GLOBAL_CUSTOM_NAMES:
      raise ValueError('%s has already been registered to %s' %
                       (arg, _GLOBAL_CUSTOM_NAMES[arg]))

    _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
    _GLOBAL_CUSTOM_NAMES[arg] = registered_name
    return arg

  return decorator
@keras_export('keras.utils.get_registered_name')
def get_registered_name(obj):
  """Returns the name registered to an object within the Keras framework.

  Part of the Keras serialization and deserialization framework: maps
  objects to the string names used for serialization/deserialization.

  Args:
    obj: The object to look up.

  Returns:
    The name associated with the object, or the default Python name if the
    object is not registered.
  """
  try:
    return _GLOBAL_CUSTOM_NAMES[obj]
  except KeyError:
    # Unregistered objects fall back to their Python name.
    return obj.__name__
@tf_contextlib.contextmanager
def skip_failed_serialization():
  """Context manager under which `get_config` NotImplementedErrors are skipped.

  While active, objects whose `get_config` raises `NotImplementedError` are
  serialized with a placeholder config instead of failing (see
  `serialize_keras_object`).
  """
  global _SKIP_FAILED_SERIALIZATION
  previous_value = _SKIP_FAILED_SERIALIZATION
  _SKIP_FAILED_SERIALIZATION = True
  try:
    yield
  finally:
    # Restore whatever value was in effect before entering the scope.
    _SKIP_FAILED_SERIALIZATION = previous_value
@keras_export('keras.utils.get_registered_object')
def get_registered_object(name, custom_objects=None, module_objects=None):
  """Returns the class associated with `name` if it is registered with Keras.

  Part of the Keras serialization and deserialization framework: maps
  strings to the objects associated with them.

  Example:
  ```
  def from_config(cls, config, custom_objects=None):
    if 'my_custom_object_name' in config:
      config['hidden_cls'] = tf.keras.utils.get_registered_object(
          config['my_custom_object_name'], custom_objects=custom_objects)
  ```

  Args:
    name: The name to look up.
    custom_objects: A dictionary of custom objects to look the name up in.
      Generally, custom_objects is provided by the user.
    module_objects: A dictionary of custom objects to look the name up in.
      Generally, module_objects is provided by midlevel library implementers.

  Returns:
    An instantiable class associated with 'name', or None if no such class
    exists.
  """
  # Lookup precedence: global registry, then user customs, then modules.
  if name in _GLOBAL_CUSTOM_OBJECTS:
    return _GLOBAL_CUSTOM_OBJECTS[name]
  if custom_objects and name in custom_objects:
    return custom_objects[name]
  if module_objects and name in module_objects:
    return module_objects[name]
  return None
# pylint: disable=g-bad-exception-name
class CustomMaskWarning(Warning):
  """Warning category used when a custom-mask layer lacks a usable config."""
# pylint: enable=g-bad-exception-name
@keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
  """Serialize a Keras object into a JSON-compatible representation.

  Calls to `serialize_keras_object` while underneath the
  `SharedObjectSavingScope` context manager will cause any objects re-used
  across multiple layers to be saved with a special shared object ID. This
  allows the network to be re-created properly during deserialization.

  Args:
    instance: The object to serialize.

  Returns:
    A dict-like, JSON-compatible representation of the object's config.

  Raises:
    ValueError: If `instance` has neither a `get_config` method nor a
      `__name__` attribute.
  """
  # Unwrap any TF decorator so we serialize the underlying object.
  _, instance = tf_decorator.unwrap(instance)
  if instance is None:
    return None

  # pylint: disable=protected-access
  #
  # For v1 layers, checking supports_masking is not enough. We have to also
  # check whether compute_mask has been overridden.
  # NOTE(review): `is_default` is defined elsewhere in this module;
  # presumably it reports whether the method is the base-class default
  # implementation — confirm against the rest of the file.
  supports_masking = (getattr(instance, 'supports_masking', False)
                      or (hasattr(instance, 'compute_mask')
                          and not is_default(instance.compute_mask)))
  if supports_masking and is_default(instance.get_config):
    warnings.warn('Custom mask layers require a config and must override '
                  'get_config. When loading, the custom mask layer must be '
                  'passed to the custom_objects argument.',
                  category=CustomMaskWarning)
  # pylint: enable=protected-access

  if hasattr(instance, 'get_config'):
    name = get_registered_name(instance.__class__)
    try:
      config = instance.get_config()
    except NotImplementedError as e:
      if _SKIP_FAILED_SERIALIZATION:
        # Inside `skip_failed_serialization()`, objects without a config are
        # serialized with a sentinel key instead of failing.
        return serialize_keras_class_and_config(
            name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
      raise e
    serialization_config = {}
    for key, item in config.items():
      if isinstance(item, str):
        serialization_config[key] = item
        continue

      # Any object of a different type needs to be converted to string or dict
      # for serialization (e.g. custom functions, custom classes)
      try:
        serialized_item = serialize_keras_object(item)
        if isinstance(serialized_item, dict) and not isinstance(item, dict):
          # Mark configs that were serialized implicitly (not author-written
          # dicts) so the deserializer knows to recurse into them.
          serialized_item['__passive_serialization__'] = True
        serialization_config[key] = serialized_item
      except ValueError:
        # Unserializable values are passed through unchanged.
        serialization_config[key] = item

    name = get_registered_name(instance.__class__)
    return serialize_keras_class_and_config(
        name, serialization_config, instance)
  if hasattr(instance, '__name__'):
    # Plain functions are serialized by their (registered) name only.
    return get_registered_name(instance)
  raise ValueError('Cannot serialize', instance)
def get_custom_objects_by_name(item, custom_objects=None):
  """Returns the item if it is in either local or global custom objects."""
  # Global registry takes precedence over the per-call dictionary.
  if item in _GLOBAL_CUSTOM_OBJECTS:
    return _GLOBAL_CUSTOM_OBJECTS[item]
  if custom_objects and item in custom_objects:
    return custom_objects[item]
  return None
def class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Returns the class name and config for a serialized keras object.

  Args:
    config: A dict of the form `{'class_name': ..., 'config': ...}`.
    module_objects: A dictionary of built-in objects to look the class name
      up in; generally provided by midlevel library implementers.
    custom_objects: A dictionary of user-provided custom objects.
    printable_module_name: A human-readable object type used in error
      messages.

  Returns:
    A `(cls, cls_config)` tuple. Items inside `cls_config` that were marked
    with `__passive_serialization__` (or that name registered custom
    functions) are deserialized in place.

  Raises:
    ValueError: If `config` is not a properly formatted serialized-object
      dict, or if the class name cannot be resolved.
  """
  if (not isinstance(config, dict)
      or 'class_name' not in config
      or 'config' not in config):
    raise ValueError('Improper config format: ' + str(config))

  class_name = config['class_name']
  cls = get_registered_object(class_name, custom_objects, module_objects)
  if cls is None:
    raise ValueError(
        'Unknown {}: {}. Please ensure this object is '
        'passed to the `custom_objects` argument. See '
        'https://www.tensorflow.org/guide/keras/save_and_serialize'
        '#registering_the_custom_object for details.'
        .format(printable_module_name, class_name))

  cls_config = config['config']
  # Check if `cls_config` is a list. If it is a list, return the class and the
  # associated class configs for recursively deserialization. This case will
  # happen on the old version of sequential model (e.g. `keras_version` ==
  # "2.0.6"), which is serialized in a different structure, for example
  # "{'class_name': 'Sequential',
  #   'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}".
  if isinstance(cls_config, list):
    return (cls, cls_config)

  deserialized_objects = {}
  for key, item in cls_config.items():
    if key == 'name':
      # Assume that the value of 'name' is a string that should not be
      # deserialized as a function. This avoids the corner case where
      # cls_config['name'] has an identical name to a custom function and
      # gets converted into that function.
      deserialized_objects[key] = item
    elif isinstance(item, dict) and '__passive_serialization__' in item:
      # Nested serialized config produced implicitly by
      # `serialize_keras_object`; recurse into it.
      deserialized_objects[key] = deserialize_keras_object(
          item,
          module_objects=module_objects,
          custom_objects=custom_objects,
          printable_module_name='config_item')
    # TODO(momernick): Should this also have 'module_objects'?
    elif (isinstance(item, str) and
          tf_inspect.isfunction(get_registered_object(item, custom_objects))):
      # Handle custom functions here. When saving functions, we only save the
      # function's name as a string. If we find a matching string in the custom
      # objects during deserialization, we convert the string back to the
      # original function.
      # Note that a potential issue is that a string field could have a naming
      # conflict with a custom function name, but this should be a rare case.
      # This issue does not occur if a string field has a naming conflict with
      # a custom object, since the config of an object will always be a dict.
      deserialized_objects[key] = get_registered_object(item, custom_objects)
  # Write the deserialized values back into the config in place.
  for key, item in deserialized_objects.items():
    cls_config[key] = deserialized_objects[key]

  return (cls, cls_config)
@keras_export('keras.utils.deserialize_keras_object')
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object.

  This function is for mid-level library implementers rather than end users.

  Importantly, this utility requires you to provide the dict of
  `module_objects` to use for looking up the object config; this is not
  populated by default. If you need a deserialization utility that has
  preexisting knowledge of built-in Keras objects, use e.g.
  `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`,
  etc.

  Calling `deserialize_keras_object` while underneath the
  `SharedObjectLoadingScope` context manager will cause any already-seen
  shared objects to be returned as-is rather than creating a new object.

  Args:
    identifier: the serialized form of the object.
    module_objects: A dictionary of built-in objects to look the name up in.
      Generally, `module_objects` is provided by midlevel library
      implementers.
    custom_objects: A dictionary of custom objects to look the name up in.
      Generally, `custom_objects` is provided by the end user.
    printable_module_name: A human-readable string representing the type of
      the object. Printed in case of exception.

  Returns:
    The deserialized object.

  Raises:
    ValueError: If a string identifier cannot be resolved, or the identifier
      type cannot be interpreted.

  Example:

  A mid-level library implementer might want to implement a utility for
  retrieving an object from its config, as such:

  ```python
  def deserialize(config, custom_objects=None):
     return deserialize_keras_object(
       identifier,
       module_objects=globals(),
       custom_objects=custom_objects,
       name="MyObjectType",
     )
  ```

  This is how e.g. `keras.layers.deserialize()` is implemented.
  """
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    config = identifier
    (cls, cls_config) = class_and_config_for_serialized_keras_object(
        config, module_objects, custom_objects, printable_module_name)

    # If this object has already been loaded (i.e. it's shared between multiple
    # objects), return the already-loaded object.
    shared_object_id = config.get(SHARED_OBJECT_KEY)
    shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
    if shared_object is not None:
      return shared_object

    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      custom_objects = custom_objects or {}

      if 'custom_objects' in arg_spec.args:
        # `from_config` accepts custom objects directly: merge the global
        # registry with the user-provided dict (user entries win on a key
        # collision, since they are applied last).
        deserialized_obj = cls.from_config(
            cls_config,
            custom_objects=dict(
                list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                list(custom_objects.items())))
      else:
        # Otherwise expose the custom objects through the global scope for
        # the duration of the call.
        with CustomObjectScope(custom_objects):
          deserialized_obj = cls.from_config(cls_config)
    else:
      # Then `cls` may be a function returning a class.
      # in this case by convention `config` holds
      # the kwargs of the function.
      custom_objects = custom_objects or {}
      with CustomObjectScope(custom_objects):
        deserialized_obj = cls(**cls_config)

    # Add object to shared objects, in case we find it referenced again.
    _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

    return deserialized_obj

  elif isinstance(identifier, str):
    object_name = identifier
    # Lookup precedence: user customs, global registry, then modules.
    if custom_objects and object_name in custom_objects:
      obj = custom_objects.get(object_name)
    elif object_name in _GLOBAL_CUSTOM_OBJECTS:
      obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
    else:
      obj = module_objects.get(object_name)
      if obj is None:
        raise ValueError(
            'Unknown {}: {}. Please ensure this object is '
            'passed to the `custom_objects` argument. See '
            'https://www.tensorflow.org/guide/keras/save_and_serialize'
            '#registering_the_custom_object for details.'
            .format(printable_module_name, object_name))

    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj
  elif tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier
  else:
    raise ValueError('Could not interpret serialized %s: %s' %
                     (printable_module_name, identifier))
def func_dump(func):
  """Serializes a user defined function.

  The function's bytecode is marshalled and base64-encoded so the result is
  a plain ASCII string suitable for JSON storage; pair with `func_load` to
  reconstruct the function.

  Args:
    func: the function to serialize.

  Returns:
    A tuple `(code, defaults, closure)` where `code` is the base64-encoded
    marshalled bytecode string, `defaults` is the function's default
    argument tuple (or None), and `closure` is a tuple of captured cell
    contents (or None).
  """
  raw_code = marshal.dumps(func.__code__)
  if os.name == 'nt':
    # On Windows, normalize backslashes embedded in the marshalled code so
    # dumps are portable across platforms.
    raw_code = raw_code.replace(b'\\', b'/')
  # Fix: the base64 encode step was duplicated in both platform branches;
  # it is hoisted here so it appears exactly once.
  code = codecs.encode(raw_code, 'base64').decode('ascii')
  defaults = func.__defaults__
  if func.__closure__:
    closure = tuple(c.cell_contents for c in func.__closure__)
  else:
    closure = None
  return code, defaults, closure
def func_load(code, defaults=None, closure=None, globs=None):
  """Deserializes a user defined function.

  Args:
    code: bytecode of the function.
    defaults: defaults of the function.
    closure: closure of the function.
    globs: dictionary of global objects.

  Returns:
    A function object.
  """
  if isinstance(code, (tuple, list)):
    # Accept a whole `(code, defaults, closure)` tuple from a previous dump.
    code, defaults, closure = code
  if isinstance(defaults, list):
    defaults = tuple(defaults)

  def _as_cell(value):
    """Wraps `value` in a cell object unless it already is one.

    Args:
      value: Any value that needs to be cast to the cell type.

    Returns:
      A value wrapped as a cell object (see function "func_load").
    """

    def _capture():
      # pylint: disable=pointless-statement
      value  # Referenced only so it is captured into __closure__.

    reference_cell = _capture.__closure__[0]
    if isinstance(value, type(reference_cell)):
      return value
    return reference_cell

  if closure is not None:
    closure = tuple(_as_cell(item) for item in closure)
  try:
    raw_code = codecs.decode(code.encode('ascii'), 'base64')
  except (UnicodeEncodeError, binascii.Error):
    # Legacy dumps may not be valid base64/ascii; fall back to a raw escape.
    raw_code = code.encode('raw_unicode_escape')
  code = marshal.loads(raw_code)
  globs = globals() if globs is None else globs
  return python_types.FunctionType(
      code, globs, name=code.co_name, argdefs=defaults, closure=closure)
def has_arg(fn, name, accept_all=False):
  """Checks if a callable accepts a given keyword argument.

  Args:
    fn: Callable to inspect.
    name: Check if `fn` can be called with `name` as a keyword argument.
    accept_all: What to return if there is no parameter called `name` but the
      function accepts a `**kwargs` argument.

  Returns:
    bool, whether `fn` accepts a `name` keyword argument.
  """
  spec = tf_inspect.getfullargspec(fn)
  # A `**kwargs` parameter absorbs any keyword when `accept_all` is set.
  if accept_all and spec.varkw is not None:
    return True
  return name in spec.args or name in spec.kwonlyargs
814@keras_export('keras.utils.Progbar')
815class Progbar(object):
816 """Displays a progress bar.
818 Args:
819 target: Total number of steps expected, None if unknown.
820 width: Progress bar width on screen.
821 verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
822 stateful_metrics: Iterable of string names of metrics that should *not* be
823 averaged over time. Metrics in this list will be displayed as-is. All
824 others will be averaged by the progbar before display.
825 interval: Minimum visual progress update interval (in seconds).
826 unit_name: Display name for step counts (usually "step" or "sample").
827 """
829 def __init__(self,
830 target,
831 width=30,
832 verbose=1,
833 interval=0.05,
834 stateful_metrics=None,
835 unit_name='step'):
836 self.target = target
837 self.width = width
838 self.verbose = verbose
839 self.interval = interval
840 self.unit_name = unit_name
841 if stateful_metrics:
842 self.stateful_metrics = set(stateful_metrics)
843 else:
844 self.stateful_metrics = set()
846 self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
847 sys.stdout.isatty()) or
848 'ipykernel' in sys.modules or
849 'posix' in sys.modules or
850 'PYCHARM_HOSTED' in os.environ)
851 self._total_width = 0
852 self._seen_so_far = 0
853 # We use a dict + list to avoid garbage collection
854 # issues found in OrderedDict
855 self._values = {}
856 self._values_order = []
857 self._start = time.time()
858 self._last_update = 0
860 self._time_after_first_step = None
862 def update(self, current, values=None, finalize=None):
863 """Updates the progress bar.
865 Args:
866 current: Index of current step.
867 values: List of tuples: `(name, value_for_last_step)`. If `name` is in
868 `stateful_metrics`, `value_for_last_step` will be displayed as-is.
869 Else, an average of the metric over time will be displayed.
870 finalize: Whether this is the last update for the progress bar. If
871 `None`, defaults to `current >= self.target`.
872 """
873 if finalize is None:
874 if self.target is None:
875 finalize = False
876 else:
877 finalize = current >= self.target
879 values = values or []
880 for k, v in values:
881 if k not in self._values_order:
882 self._values_order.append(k)
883 if k not in self.stateful_metrics:
884 # In the case that progress bar doesn't have a target value in the first
885 # epoch, both on_batch_end and on_epoch_end will be called, which will
886 # cause 'current' and 'self._seen_so_far' to have the same value. Force
887 # the minimal value to 1 here, otherwise stateful_metric will be 0s.
888 value_base = max(current - self._seen_so_far, 1)
889 if k not in self._values:
890 self._values[k] = [v * value_base, value_base]
891 else:
892 self._values[k][0] += v * value_base
893 self._values[k][1] += value_base
894 else:
895 # Stateful metrics output a numeric value. This representation
896 # means "take an average from a single value" but keeps the
897 # numeric formatting.
898 self._values[k] = [v, 1]
899 self._seen_so_far = current
901 now = time.time()
902 info = ' - %.0fs' % (now - self._start)
903 if self.verbose == 1:
904 if now - self._last_update < self.interval and not finalize:
905 return
907 prev_total_width = self._total_width
908 if self._dynamic_display:
909 sys.stdout.write('\b' * prev_total_width)
910 sys.stdout.write('\r')
911 else:
912 sys.stdout.write('\n')
914 if self.target is not None:
915 numdigits = int(np.log10(self.target)) + 1
916 bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target)
917 prog = float(current) / self.target
918 prog_width = int(self.width * prog)
919 if prog_width > 0:
920 bar += ('=' * (prog_width - 1))
921 if current < self.target:
922 bar += '>'
923 else:
924 bar += '='
925 bar += ('.' * (self.width - prog_width))
926 bar += ']'
927 else:
928 bar = '%7d/Unknown' % current
930 self._total_width = len(bar)
931 sys.stdout.write(bar)
933 time_per_unit = self._estimate_step_duration(current, now)
935 if self.target is None or finalize:
936 if time_per_unit >= 1 or time_per_unit == 0:
937 info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
938 elif time_per_unit >= 1e-3:
939 info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
940 else:
941 info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)
942 else:
943 eta = time_per_unit * (self.target - current)
944 if eta > 3600:
945 eta_format = '%d:%02d:%02d' % (eta // 3600,
946 (eta % 3600) // 60, eta % 60)
947 elif eta > 60:
948 eta_format = '%d:%02d' % (eta // 60, eta % 60)
949 else:
950 eta_format = '%ds' % eta
952 info = ' - ETA: %s' % eta_format
954 for k in self._values_order:
955 info += ' - %s:' % k
956 if isinstance(self._values[k], list):
957 avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
958 if abs(avg) > 1e-3:
959 info += ' %.4f' % avg
960 else:
961 info += ' %.4e' % avg
962 else:
963 info += ' %s' % self._values[k]
965 self._total_width += len(info)
966 if prev_total_width > self._total_width:
967 info += (' ' * (prev_total_width - self._total_width))
969 if finalize:
970 info += '\n'
972 sys.stdout.write(info)
973 sys.stdout.flush()
975 elif self.verbose == 2:
976 if finalize:
977 numdigits = int(np.log10(self.target)) + 1
978 count = ('%' + str(numdigits) + 'd/%d') % (current, self.target)
979 info = count + info
980 for k in self._values_order:
981 info += ' - %s:' % k
982 avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
983 if avg > 1e-3:
984 info += ' %.4f' % avg
985 else:
986 info += ' %.4e' % avg
987 info += '\n'
989 sys.stdout.write(info)
990 sys.stdout.flush()
992 self._last_update = now
994 def add(self, n, values=None):
995 self.update(self._seen_so_far + n, values)
997 def _estimate_step_duration(self, current, now):
998 """Estimate the duration of a single step.
1000 Given the step number `current` and the corresponding time `now`
1001 this function returns an estimate for how long a single step
1002 takes. If this is called before one step has been completed
1003 (i.e. `current == 0`) then zero is given as an estimate. The duration
1004 estimate ignores the duration of the (assumed to be non-representative)
1005 first step for estimates when more steps are available (i.e. `current>1`).
1006 Args:
1007 current: Index of current step.
1008 now: The current time.
1009 Returns: Estimate of the duration of a single step.
1010 """
1011 if current:
1012 # there are a few special scenarios here:
1013 # 1) somebody is calling the progress bar without ever supplying step 1
1014 # 2) somebody is calling the progress bar and supplies step one mulitple
1015 # times, e.g. as part of a finalizing call
1016 # in these cases, we just fall back to the simple calculation
1017 if self._time_after_first_step is not None and current > 1:
1018 time_per_unit = (now - self._time_after_first_step) / (current - 1)
1019 else:
1020 time_per_unit = (now - self._start) / current
1022 if current == 1:
1023 self._time_after_first_step = now
1024 return time_per_unit
1025 else:
1026 return 0
1028 def _update_stateful_metrics(self, stateful_metrics):
1029 self.stateful_metrics = self.stateful_metrics.union(stateful_metrics)
def make_batches(size, batch_size):
  """Returns a list of batch indices (tuples of indices).

  Args:
    size: Integer, total size of the data to slice into batches.
    batch_size: Integer, batch size.

  Returns:
    A list of tuples of array indices.
  """
  num_batches = int(np.ceil(size / float(batch_size)))
  batches = []
  for i in range(num_batches):
    begin = i * batch_size
    # The final batch may be short when `size` is not a multiple of
    # `batch_size`.
    end = min(size, begin + batch_size)
    batches.append((begin, end))
  return batches
def slice_arrays(arrays, start=None, stop=None):
  """Slice an array or list of arrays.

  Given an array-like (or a list of array-likes), this returns
  `arrays[start:stop]` (respectively `[x[start:stop] for x in arrays]`).
  Instead of an integer, `start` may be a list/array of indices, in which
  case `stop` must be None and fancy indexing is applied.

  Args:
    arrays: Single array or list of arrays.
    start: can be an integer index (start index) or a list/array of indices
    stop: integer (stop index); should be None if `start` was a list.

  Returns:
    A slice of the array(s).

  Raises:
    ValueError: If the value of start is a list and stop is not None.
  """
  if arrays is None:
    return [None]
  if isinstance(start, list) and stop is not None:
    raise ValueError('The stop argument has to be None if the value of start '
                     'is a list.')
  fancy = hasattr(start, '__len__')
  if isinstance(arrays, list):
    if fancy:
      # hdf5 datasets only support list objects as indices, so ndarray-like
      # index collections are converted to plain lists first.
      if hasattr(start, 'shape'):
        start = start.tolist()
      return [None if a is None else a[start] for a in arrays]
    sliced = []
    for a in arrays:
      # Entries that are None or not subscriptable come back as None.
      if a is None or not hasattr(a, '__getitem__'):
        sliced.append(None)
      else:
        sliced.append(a[start:stop])
    return sliced
  if fancy:
    if hasattr(start, 'shape'):
      start = start.tolist()
    return arrays[start]
  if hasattr(start, '__getitem__'):
    return arrays[start:stop]
  return [None]
def to_list(x):
  """Normalizes a list/tensor into a list.

  If a tensor is passed, we return
  a list of size 1 containing the tensor.

  Args:
    x: target object to be normalized.

  Returns:
    A list.
  """
  return x if isinstance(x, list) else [x]
def to_snake_case(name):
  """Converts a CamelCase `name` to a snake_case string safe for scope names."""
  with_word_breaks = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
  snake = re.sub('([a-z])([A-Z])', r'\1_\2', with_word_breaks).lower()
  # A leading "_" (private class name) is not secure for creating scopes, so
  # such names are prefixed with "private".
  return snake if snake[0] != '_' else 'private' + snake
def is_all_none(structure):
  """Returns True if every leaf of the flattened `structure` is None."""
  flat = nest.flatten(structure)
  # Python's `any`/`all` are deliberately avoided: leaves may be Tensors,
  # whose truth value cannot be evaluated; `is not None` is always safe.
  for leaf in flat:
    if leaf is not None:
      return False
  return True
def check_for_unexpected_keys(name, input_dict, expected_values):
  """Raises ValueError if `input_dict` has keys outside `expected_values`."""
  unknown = set(input_dict) - set(expected_values)
  if unknown:
    raise ValueError('Unknown entries in {} dictionary: {}. Only expected '
                     'following keys: {}'.format(name, list(unknown),
                                                 expected_values))
def validate_kwargs(kwargs,
                    allowed_kwargs,
                    error_message='Keyword argument not understood:'):
  """Checks that all keyword arguments are in the set of allowed keys."""
  for key in kwargs:
    if key in allowed_kwargs:
      continue
    # Raise on the first unexpected keyword, mirroring dict iteration order.
    raise TypeError(error_message, key)
def validate_config(config):
  """Determines whether config appears to be a valid layer config."""
  if not isinstance(config, dict):
    return False
  # A dict carrying the sentinel key was saved without a real config.
  return _LAYER_UNDEFINED_CONFIG_KEY not in config
def default(method):
  """Decorates a method to detect overrides in subclasses."""
  # Overriding implementations will not carry this marker attribute.
  setattr(method, '_is_default', True)  # pylint: disable=protected-access
  return method
def is_default(method):
  """Check if a method is decorated with the `default` wrapper."""
  try:
    return method._is_default
  except AttributeError:
    return False
def populate_dict_with_module_objects(target_dict, modules, obj_filter):
  """Adds to `target_dict` each attribute of `modules` accepted by `obj_filter`."""
  for module in modules:
    members = ((attr, getattr(module, attr)) for attr in dir(module))
    for attr, value in members:
      if obj_filter(value):
        target_dict[attr] = value
class LazyLoader(python_types.ModuleType):
  """Lazily import a module, mainly to avoid pulling in large dependencies."""

  def __init__(self, local_name, parent_module_globals, name):
    # Remember how the module is referred to from the parent's namespace so
    # the real module can be patched in on first use.
    self._local_name = local_name
    self._parent_module_globals = parent_module_globals
    super(LazyLoader, self).__init__(name)

  def _load(self):
    """Import the target module and patch it into the parent's globals."""
    module = importlib.import_module(self.__name__)
    self._parent_module_globals[self._local_name] = module
    # Copy the real module's attributes onto this proxy so that future
    # lookups through a retained LazyLoader reference bypass __getattr__
    # (which is only invoked for attributes that are missing).
    self.__dict__.update(module.__dict__)
    return module

  def __getattr__(self, item):
    return getattr(self._load(), item)
# Aliases

# Snake-case alias for the `CustomObjectScope` context manager defined
# earlier in this file; both names refer to the same class.
custom_object_scope = CustomObjectScope  # pylint: disable=invalid-name