Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/trackable/autotrackable.py: 27%
64 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Dependency tracking for trackable objects."""
17import warnings
19from absl import logging
21from tensorflow.python.eager import def_function
22from tensorflow.python.eager import function as defun
23from tensorflow.python.trackable import base
24from tensorflow.python.trackable import data_structures
25from tensorflow.python.types import core as core_types
26from tensorflow.python.util.tf_export import tf_export
@tf_export("__internal__.tracking.AutoTrackable", v1=[])
class AutoTrackable(base.Trackable):
  """Manages dependencies on other objects.

  `Trackable` objects may have dependencies: other `Trackable` objects
  which should be saved if the object declaring the dependency is saved. A
  correctly saveable program has a dependency graph such that if changing a
  global variable affects an object (e.g. changes the behavior of any of its
  methods) then there is a chain of dependencies from the influenced object to
  the variable.

  Dependency edges have names, and are created implicitly when a
  `Trackable` object is assigned to an attribute of an `AutoTrackable`
  object. For example:

  ```
  obj = AutoTrackable()
  obj.v = ResourceVariable(0.)
  ```

  The `AutoTrackable` object `obj` now has a dependency named "v" on a
  variable.

  `Trackable` objects may specify `Tensor`s to be saved and restored
  directly (e.g. a `Variable` indicating how to save itself) rather than through
  dependencies on other objects. See
  `Trackable._gather_saveables_for_checkpoint` for details.
  """

  def __setattr__(self, name, value):
    """Support self.foo = trackable syntax.

    Re-assigning the exact object an attribute already holds is a no-op.
    Otherwise, unless `_self_setattr_tracking` is set to a falsy value, the
    value is passed through `data_structures.sticky_attribute_assignment`
    (presumably recording it as a dependency named `name`, possibly wrapping
    it) and the result is what gets stored.
    """
    try:
      if getattr(self, name) is value:
        # Short circuit for `self.$x = self.$x`.
        return
    except AttributeError:
      pass

    # Missing flag means tracking is enabled.
    if getattr(self, "_self_setattr_tracking", True):
      value = data_structures.sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    super(AutoTrackable, self).__setattr__(name, value)

  def __delattr__(self, name):
    """Stops tracking `name`, then deletes the attribute itself."""
    self._delete_tracking(name)
    super(AutoTrackable, self).__delattr__(name)

  def _no_dependency(self, value):
    """Override to allow TrackableBase to disable dependency tracking.

    Wraps `value` in `data_structures.NoDependency`.
    """
    return data_structures.NoDependency(value)

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns all children of a trackable, including functions.

    For any save type other than `base.SaveType.SAVEDMODEL` this defers to
    the base implementation. For SavedModel export it also scans the
    object's attributes for `def_function.Function` and
    `defun.ConcreteFunction` instances, traces the generic ones, and merges
    them with the checkpoint dependencies.

    Args:
      save_type: A `base.SaveType` value.
      **kwargs: Forwarded to the base implementation for non-SavedModel saves.

    Returns:
      A dict mapping child names to child objects.

    Raises:
      ValueError: If a checkpoint dependency and a function attribute share a
        name but are different objects.
    """
    if save_type != base.SaveType.SAVEDMODEL:
      return super(AutoTrackable, self)._trackable_children(
          save_type, **kwargs)

    functions = {}
    # Read the current verbosity before entering `try` so the `finally`
    # clause can always restore it and never refers to an unbound name.
    logging_verbosity = logging.get_verbosity()
    try:
      # We get the attributes, suppressing warnings and exceptions.
      logging.set_verbosity(logging.FATAL)
      for attribute_name in dir(self):
        try:
          # A fresh filter context per attribute keeps any filter changes a
          # getter makes from leaking into the next getter.
          with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            attribute_value = getattr(self, attribute_name, None)
        except Exception:  # pylint: disable=broad-except
          # NOTE: If we make the exception catching here less broad, we might
          # need to revisit `finally` block below.
          # We really don't want to throw an exception just because some
          # object's attribute accessor is broken.
          attribute_value = None
        if isinstance(attribute_value, (def_function.Function,
                                        defun.ConcreteFunction)):
          functions[attribute_name] = attribute_value
    finally:
      logging.set_verbosity(logging_verbosity)

    # Trace concrete functions to force side-effects:
    #   1. populate the cache for functions that have an input_signature
    #      and have not been called
    #   2. force side effects of creation of concrete functions, e.g. create
    #      variables on first run.
    for fn in functions.values():
      if isinstance(fn, core_types.GenericFunction):
        fn._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access

    # Additional dependencies may have been generated during function tracing
    # (e.g. captured variables). Make sure we return those too.
    children = {}
    for name, child in self._checkpoint_dependencies:
      if isinstance(child, (core_types.GenericFunction,
                            core_types.ConcreteFunction)):
        # Skip "tracked" functions for now since there may be objects that
        # automatically track functions that should not be saved.
        # TODO(kathywu): remove once `_list_functions_for_serialization` has
        # been fully deprecated.
        continue

      if name in functions and child is not functions[name]:
        raise ValueError(
            "Can't save object because it has multiple children with the same "
            f"name. Object: {self}, attribute name: {name}, child 1: "
            f"{child}, child 2: {functions[name]}")

      children[name] = child

    # Any name shared with a dependency refers to the same object here (the
    # check above raised otherwise), so this merge cannot lose a child.
    children.update(functions)
    return children

  def _delete_tracking(self, name):
    """Removes the tracking of name.

    After calling `_maybe_initialize_trackable`, drops `name` from
    `_unconditional_dependency_names` and removes the first entry with that
    name from `_unconditional_checkpoint_dependencies`. Names that are not in
    `_unconditional_dependency_names` are left untouched.
    """
    self._maybe_initialize_trackable()
    if name in self._unconditional_dependency_names:
      del self._unconditional_dependency_names[name]
      for index, (dep_name, _) in enumerate(
          self._unconditional_checkpoint_dependencies):
        if dep_name == name:
          # Deleting from the list being iterated is safe only because we
          # stop iterating immediately afterwards.
          del self._unconditional_checkpoint_dependencies[index]
          break

  def _add_trackable_child(self, name, value):
    """Adds `value` as a child named `name`.

    Delegates to `self.__setattr__`, so the child is tracked exactly as an
    ordinary `self.<name> = value` assignment would be.
    """
    self.__setattr__(name, value)