Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/trackable/layer_utils.py: 36%

70 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities related to layer/model functionality.""" 

16 

17# TODO(b/110718070): Move these functions back to tensorflow/python/keras/utils 

18# once __init__ files no longer require all of tf.keras to be imported together. 

19 

20import collections 

21import functools 

22import weakref 

23 

24from tensorflow.python.util import object_identity 

25 

26try: 

27 # typing module is only used for comment type annotations. 

28 import typing # pylint: disable=g-import-not-at-top, unused-import 

29except ImportError: 

30 pass 

31 

32 

def is_layer(obj):
  """Duck-typed check for Layer-like objects.

  An object counts as Layer-like when it has an `_is_layer` attribute and is
  an instance rather than a class object.

  Args:
    obj: Any object.

  Returns:
    True if `obj` is a Layer-like instance, False otherwise.
  """
  # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
  if not hasattr(obj, "_is_layer"):
    return False
  # Layer classes themselves also carry `_is_layer`; only instances qualify.
  return not isinstance(obj, type)

37 

38 

def has_weights(obj):
  """Duck-typed check for objects exposing Layer-style weight properties.

  Returns True for an *instance* whose type defines both `trainable_weights`
  and `non_trainable_weights`. The attributes are looked up on `type(obj)`,
  not on `obj`. As a result, attributes set only on the instance do not
  count, and a property defined on the class is not evaluated by the check
  (presumably so no weight computation is triggered — verify intent).
  Class objects themselves are rejected.

  Args:
    obj: Any object.

  Returns:
    True if `obj` is an instance whose class defines both weight attributes,
    False otherwise.
  """
  # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
  obj_type = type(obj)
  has_weight = (hasattr(obj_type, "trainable_weights")
                and hasattr(obj_type, "non_trainable_weights"))

  return has_weight and not isinstance(obj, type)

46 

47 

def invalidate_recursive_cache(key):
  """Builds a setter decorator that invalidates a cached attribute first.

  The decorated function must be a method taking `(self, value)` on an object
  that has an `_attribute_sentinel` attribute. Each call first invokes
  `invalidate(key)` on that sentinel and then runs the original setter.

  Args:
    key: Name of the cached attribute to invalidate.

  Returns:
    A decorator for setter methods. The wrapped setter returns whatever the
    original setter returns.
  """
  def decorator(setter):
    def invalidating_setter(self, value):
      attribute_sentinel = getattr(self, "_attribute_sentinel")  # type: AttributeSentinel
      attribute_sentinel.invalidate(key)
      return setter(self, value)

    return functools.wraps(setter)(invalidating_setter)

  return decorator

58 

59 

class MutationSentinel(object):
  """Records whether a single attribute is currently in a cached state."""
  # Class-level default: a fresh sentinel reports "not cached" until
  # `mark_as(True)` is called on it.
  _in_cached_state = False

  def mark_as(self, value):  # type: (MutationSentinel, bool) -> bool
    """Sets the cached state and reports whether it changed.

    Args:
      value: The new cached state.

    Returns:
      True if `value` differs from the previous state, meaning the change may
      affect anything that depends on this attribute. False otherwise.
    """
    previous_state = self._in_cached_state
    changed = (value != previous_state)
    self._in_cached_state = value
    return changed

  @property
  def in_cached_state(self):
    """The most recent value passed to `mark_as` (False initially)."""
    return self._in_cached_state

72 

73 

class AttributeSentinel(object):
  """Container for managing attribute cache state within a Layer.

  The cache can be invalidated either on an individual basis (for instance when
  an attribute is mutated) or a layer-wide basis (such as when a new dependency
  is added).

  Each sentinel keeps one `MutationSentinel` per attribute key, created lazily
  on first access. It also holds a weak set of parent sentinels to which
  invalidations are propagated.
  """

  def __init__(self, always_propagate=False):
    # Parents are held weakly, so a child sentinel does not keep its parents
    # alive.
    self._parents = weakref.WeakSet()
    self.attributes = collections.defaultdict(MutationSentinel)

    # The trackable data structure containers are simple pass throughs. They
    # don't know or care about particular attributes. As a result, they will
    # consider themselves to be in a cached state, so it's up to the Layer
    # which contains them to terminate propagation.
    self.always_propagate = always_propagate

  def __repr__(self):
    return "{}\n {}".format(
        super(AttributeSentinel, self).__repr__(),
        {k: v.in_cached_state for k, v in self.attributes.items()})

  def add_parent(self, node):
    # type: (AttributeSentinel, AttributeSentinel) -> None
    """Registers `node` as a parent and invalidates everything it caches."""

    # Properly tracking removal is quite challenging; however since this is only
    # used to invalidate a cache it's alright to be overly conservative. We need
    # to invalidate the cache of `node` (since it has implicitly gained a child)
    # but we don't need to invalidate self since attributes should not depend on
    # parent Layers.
    self._parents.add(node)
    node.invalidate_all()

  def get(self, key):
    # type: (AttributeSentinel, str) -> bool
    """Returns whether `key` is currently cached (False for unseen keys)."""
    return self.attributes[key].in_cached_state

  def _set(self, key, value):
    # type: (AttributeSentinel, str, bool) -> None
    """Sets the cached state of `key` and invalidates `key` on the parents.

    Parents are notified whenever the local state actually changes (in either
    direction), or unconditionally when `always_propagate` is set.
    """
    may_affect_upstream = self.attributes[key].mark_as(value)
    if may_affect_upstream or self.always_propagate:
      for node in self._parents:  # type: AttributeSentinel
        node.invalidate(key)

  def mark_cached(self, key):
    # type: (AttributeSentinel, str) -> None
    """Marks `key` as cached on this sentinel."""
    self._set(key, True)

  def invalidate(self, key):
    # type: (AttributeSentinel, str) -> None
    """Marks `key` as not cached here and propagates to parents."""
    self._set(key, False)

  def invalidate_all(self):
    # type: (AttributeSentinel) -> None
    """Marks every known key uncached on this sentinel and all its ancestors."""
    # Parents may have different keys than their children, so each sentinel
    # invalidates its own keys locally (without per-key propagation) and the
    # walk continues through its parents.
    #
    # The walk is iterative and tracks visited sentinels by identity. A shared
    # ancestor (diamond-shaped parent graph) is therefore processed once
    # rather than once per path, and a cycle in the parent graph terminates
    # instead of recursing without bound. The dict values keep each visited
    # sentinel alive, so its id() cannot be reused during the walk.
    visited = {}
    to_visit = [self]
    while to_visit:
      sentinel = to_visit.pop()
      if id(sentinel) in visited:
        continue
      visited[id(sentinel)] = sentinel

      for mutation_sentinel in sentinel.attributes.values():
        mutation_sentinel.mark_as(False)

      to_visit.extend(sentinel._parents)  # pylint: disable=protected-access

135 

136 

def filter_empty_layer_containers(layer_list):
  """Flattens Layer-like containers into a de-duplicated stream of layers.

  Walks `layer_list` depth-first, in order:
  - Objects recognised by `is_layer` are yielded.
  - Any other object is treated as a container and replaced in the walk by
    the contents of its `layers` attribute, if present and non-empty.
    Containers without layers therefore contribute nothing.

  Every object is visited at most once, compared by identity. A layer or
  container that appears several times is yielded or expanded only once.

  Args:
    layer_list: An iterable of layers and/or Layer-like containers. Lists,
      tuples and other iterables are all accepted.

  Yields:
    Each distinct layer found, in depth-first order of first appearance.
  """
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  existing = object_identity.ObjectIdentitySet()
  # Materialise as a list first: slicing a tuple yields a tuple, which has no
  # `.pop()`. The reversal makes `.pop()` return items in their original order.
  to_visit = list(layer_list)[::-1]
  while to_visit:
    obj = to_visit.pop()
    if obj in existing:
      continue
    existing.add(obj)
    if is_layer(obj):
      yield obj
    else:
      sub_layers = getattr(obj, "layers", None) or []

      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      to_visit.extend(reversed(list(sub_layers)))