Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/trackable/layer_utils.py: 36%
70 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities related to layer/model functionality."""
17# TODO(b/110718070): Move these functions back to tensorflow/python/keras/utils
18# once __init__ files no longer require all of tf.keras to be imported together.
20import collections
21import functools
22import weakref
24from tensorflow.python.util import object_identity
26try:
27 # typing module is only used for comment type annotations.
28 import typing # pylint: disable=g-import-not-at-top, unused-import
29except ImportError:
30 pass
def is_layer(obj):
  """Duck-typed check for Layer-like instances.

  Args:
    obj: Any object.

  Returns:
    True if `obj` has an `_is_layer` attribute and is not itself a class;
    False otherwise. Classes are rejected so that a Layer subclass object is
    not mistaken for a layer instance.
  """
  # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
  if not hasattr(obj, "_is_layer"):
    return False
  return not isinstance(obj, type)
def has_weights(obj):
  """Implicit check for objects whose class exposes layer-style weights.

  The attributes are looked up on `type(obj)`, not on `obj`. Only class-level
  definitions (e.g. properties) count. Attributes assigned on the instance do
  not count, nor does anything resolved dynamically through the class's
  `__getattr__`.

  Args:
    obj: Any object.

  Returns:
    True if the class of `obj` defines both `trainable_weights` and
    `non_trainable_weights` and `obj` is not itself a class; False otherwise.
  """
  # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
  obj_type = type(obj)
  has_weight = (hasattr(obj_type, "trainable_weights")
                and hasattr(obj_type, "non_trainable_weights"))
  # Exclude classes themselves; only instances are considered weight holders.
  return has_weight and not isinstance(obj, type)
def invalidate_recursive_cache(key):
  """Creates a setter decorator that invalidates a cache key on every call.

  The decorated method must take `(self, value)`. Before delegating to it, the
  wrapper fetches `self._attribute_sentinel` and calls its `invalidate(key)`.
  The setter's return value is passed through unchanged.

  Args:
    key: The cache key to invalidate whenever the decorated setter runs.

  Returns:
    A decorator for setter methods.
  """
  def decorator(setter):
    @functools.wraps(setter)
    def invalidating_setter(self, value):
      attribute_sentinel = getattr(self, "_attribute_sentinel")  # type: AttributeSentinel
      # Invalidate first so the cache is already stale while the setter runs.
      attribute_sentinel.invalidate(key)
      return setter(self, value)
    return invalidating_setter
  return decorator
class MutationSentinel(object):
  """Holds the cached / not-cached flag for a single property.

  A fresh sentinel reports `in_cached_state == False` because of the
  class-level default below. `mark_as` stores a per-instance value.
  """

  # Shared default until `mark_as` assigns an instance attribute.
  _in_cached_state = False

  @property
  def in_cached_state(self):
    """The most recently recorded cached state (False if never set)."""
    return self._in_cached_state

  def mark_as(self, value):  # type: (MutationSentinel, bool) -> bool
    """Records `value` as the new cached state.

    Args:
      value: Whether the tracked property is now cached.

    Returns:
      True when `value` differs from the prior state, meaning the change may
      need to be propagated upstream; False otherwise.
    """
    previous_state = self._in_cached_state
    self._in_cached_state = value
    return value != previous_state
class AttributeSentinel(object):
  """Manages the cache state of a Layer's attributes.

  Each attribute key maps to its own `MutationSentinel`. Invalidation can
  target a single attribute (`invalidate`, e.g. when that attribute is
  mutated) or every attribute at once (`invalidate_all`, e.g. when a new
  dependency is added). Parent sentinels are held in a weak set and are
  notified so that their caches are invalidated too.
  """

  def __init__(self, always_propagate=False):
    # Parents are referenced weakly so a child never keeps its parent alive.
    self._parents = weakref.WeakSet()
    # Unknown keys lazily get a fresh (not cached) MutationSentinel.
    self.attributes = collections.defaultdict(MutationSentinel)

    # Trackable data-structure containers just pass values through and have
    # no notion of specific attributes, so they would otherwise consider
    # themselves cached. With `always_propagate=True` they forward every state
    # change to their parents, leaving it to the enclosing Layer to stop the
    # propagation.
    self.always_propagate = always_propagate

  def __repr__(self):
    base_repr = super(AttributeSentinel, self).__repr__()
    cache_states = {
        name: sentinel.in_cached_state
        for name, sentinel in self.attributes.items()
    }
    return "{}\n {}".format(base_repr, cache_states)

  def add_parent(self, node):
    # type: (AttributeSentinel, AttributeSentinel) -> None
    """Registers `node` as a parent and invalidates all of its cached state.

    Removal of parents is not tracked. Because this only drives cache
    invalidation, erring on the conservative side is acceptable. `node` has
    implicitly gained a child, so its caches are invalidated. This sentinel's
    own caches are left alone because attributes should not depend on parent
    Layers.
    """
    self._parents.add(node)
    node.invalidate_all()

  def get(self, key):
    # type: (AttributeSentinel, str) -> bool
    """Returns whether `key` is cached (creating an uncached entry if new)."""
    return self.attributes[key].in_cached_state

  def _set(self, key, value):
    # type: (AttributeSentinel, str, bool) -> None
    """Records the state of `key` and invalidates `key` on parents if needed.

    Parents are notified when the state actually changed, or unconditionally
    when `always_propagate` is set.
    """
    changed = self.attributes[key].mark_as(value)
    if not (changed or self.always_propagate):
      return
    for parent in self._parents:  # type: AttributeSentinel
      parent.invalidate(key)

  def mark_cached(self, key):
    # type: (AttributeSentinel, str) -> None
    """Marks `key` as holding a valid cached value."""
    self._set(key, True)

  def invalidate(self, key):
    # type: (AttributeSentinel, str) -> None
    """Marks `key` as no longer cached."""
    self._set(key, False)

  def invalidate_all(self):
    """Marks every known attribute as not cached, then recurses into parents."""
    # Parents may track different keys than this sentinel, so keys are reset
    # locally and each parent is asked to run its own `invalidate_all`.
    for sentinel in self.attributes.values():
      sentinel.mark_as(False)
    for parent in self._parents:
      parent.invalidate_all()
def filter_empty_layer_containers(layer_list):
  """Flattens Layer-like containers into a de-duplicated stream of layers.

  Objects recognized by `is_layer` are yielded as-is; their own sub-layers are
  not traversed. Any other object is treated as a container, and its `layers`
  attribute, if present and non-empty, is expanded in place. Containers
  without layers therefore contribute nothing. Each object (compared by
  identity) is visited at most once. Output follows depth-first order and
  preserves the left-to-right order of the input.

  Args:
    layer_list: An iterable of layers and/or layer containers (list, tuple,
      generator, ...).

  Yields:
    Layer-like objects in traversal order, each at most once.
  """
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  existing = object_identity.ObjectIdentitySet()
  # `to_visit` is a LIFO stack; pushing items in reverse makes `pop()` return
  # them in their original order. Materializing with `list(...)` first makes
  # this work for any iterable. Slicing a tuple yields a tuple, which has no
  # `pop`, and generators cannot be sliced at all.
  to_visit = list(layer_list)[::-1]
  while to_visit:
    obj = to_visit.pop()
    if obj in existing:
      continue
    existing.add(obj)
    if is_layer(obj):
      yield obj
    else:
      sub_layers = getattr(obj, "layers", None) or []
      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      to_visit.extend(list(sub_layers)[::-1])