Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/object_identity.py: 50%

117 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1"""Utilities for collecting objects based on "is" comparison.""" 

2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15# ============================================================================== 

16 

17import collections 

18import weakref 

19 

20 

21# LINT.IfChange 

22class _ObjectIdentityWrapper: 

23 """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped. 

24 

25 Since __eq__ is based on object identity, it's safe to also define __hash__ 

26 based on object ids. This lets us add unhashable types like trackable 

27 _ListWrapper objects to object-identity collections. 

28 """ 

29 

30 __slots__ = ["_wrapped", "__weakref__"] 

31 

32 def __init__(self, wrapped): 

33 self._wrapped = wrapped 

34 

35 @property 

36 def unwrapped(self): 

37 return self._wrapped 

38 

39 def _assert_type(self, other): 

40 if not isinstance(other, _ObjectIdentityWrapper): 

41 raise TypeError( 

42 "Cannot compare wrapped object with unwrapped object. " 

43 "Expect the object to be `_ObjectIdentityWrapper`. " 

44 f"Got: {other}" 

45 ) 

46 

47 def __lt__(self, other): 

48 self._assert_type(other) 

49 return id(self._wrapped) < id(other._wrapped) 

50 

51 def __gt__(self, other): 

52 self._assert_type(other) 

53 return id(self._wrapped) > id(other._wrapped) 

54 

55 def __eq__(self, other): 

56 if other is None: 

57 return False 

58 self._assert_type(other) 

59 return self._wrapped is other._wrapped 

60 

61 def __ne__(self, other): 

62 return not self.__eq__(other) 

63 

64 def __hash__(self): 

65 # Wrapper id() is also fine for weakrefs. In fact, we rely on 

66 # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is 

67 # weakref.ref(a) in _WeakObjectIdentityWrapper. 

68 return id(self._wrapped) 

69 

70 def __repr__(self): 

71 return f"<{type(self).__name__} wrapping {self._wrapped!r}>" 

72 

73 

class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
    """Identity wrapper that holds only a weak reference to its object.

    The inherited `_wrapped` slot stores a `weakref.ref` to the object rather
    than the object itself, so the inherited identity comparison and hashing
    operate on that `weakref.ref` instance.
    """

    __slots__ = ()

    def __init__(self, wrapped):
        super().__init__(weakref.ref(wrapped))

    @property
    def unwrapped(self):
        """The referenced object, or `None` once it has been collected."""
        # Calling a `weakref.ref` returns the referent, or None if it is dead.
        return self._wrapped()

84 

85 

class Reference(_ObjectIdentityWrapper):
    """Reference that refers an object.

    Two references compare equal only when they refer to the very same
    object. Comparing a reference with an unwrapped object raises
    `TypeError`.

    ```python
    x = [1]
    y = [1]

    x_ref1 = Reference(x)
    x_ref2 = Reference(x)
    y_ref2 = Reference(y)

    print(x_ref1 == x_ref2)
    ==> True

    print(x_ref1 == y_ref2)
    ==> False
    ```
    """

    __slots__ = ()

    # Disabling super class' unwrapped field.
    unwrapped = property()

    def deref(self):
        """Returns the referenced object.

        ```python
        x_ref = Reference(x)
        print(x is x_ref.deref())
        ==> True
        ```
        """
        return self._wrapped

120 

121 

class ObjectIdentityDictionary(collections.abc.MutableMapping):
    """A mutable mapping data structure which compares using "is".

    This is necessary because we have trackable objects (_ListWrapper) which
    have behavior identical to built-in Python lists (including being unhashable
    and comparing based on the equality of their contents by default).
    """

    __slots__ = ["_storage"]

    def __init__(self):
        # Maps wrapped keys (see `_wrap_key`) to the caller's values.
        self._storage = {}

    def _wrap_key(self, key):
        """Wraps `key` so it is hashed and compared by identity."""
        return _ObjectIdentityWrapper(key)

    def __getitem__(self, key):
        return self._storage[self._wrap_key(key)]

    def __setitem__(self, key, value):
        self._storage[self._wrap_key(key)] = value

    def __delitem__(self, key):
        del self._storage[self._wrap_key(key)]

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        # Yield the caller's original key objects, not the internal wrappers.
        for key in self._storage:
            yield key.unwrapped

    def __repr__(self):
        return f"ObjectIdentityDictionary({repr(self._storage)})"

156 

157 

class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
    """Like weakref.WeakKeyDictionary, but compares objects with "is"."""

    __slots__ = ["__weakref__"]

    def _wrap_key(self, key):
        """Wraps `key` by identity while holding it only weakly."""
        return _WeakObjectIdentityWrapper(key)

    def __len__(self):
        # Count through `__iter__` so that entries whose key object has been
        # garbage collected are pruned and excluded from the result.
        return sum(1 for _ in self)

    def __iter__(self):
        # Iterate over a snapshot: pruning dead entries mutates `_storage`,
        # and resizing a dict while iterating it raises RuntimeError.
        for key in list(self._storage):
            unwrapped = key.unwrapped
            if unwrapped is None:
                # `key` is already the stored wrapper, so remove it from the
                # storage directly. Going through `del self[key]` would wrap
                # the wrapper a second time and look up a different hash.
                self._storage.pop(key, None)
            else:
                yield unwrapped

178 

179 

class ObjectIdentitySet(collections.abc.MutableSet):
    """Like the built-in set, but compares objects with "is"."""

    __slots__ = ["_storage", "__weakref__"]

    def __init__(self, *args):
        # `list(*args)` accepts either no argument (empty set) or a single
        # iterable of initial members, mirroring the built-in `list()`.
        self._storage = {self._wrap_key(obj) for obj in list(*args)}

    @staticmethod
    def _from_storage(storage):
        """Builds an `ObjectIdentitySet` around an existing wrapper set."""
        result = ObjectIdentitySet()
        result._storage = storage
        return result

    def _wrap_key(self, key):
        """Wraps `key` so it is hashed and compared by identity."""
        return _ObjectIdentityWrapper(key)

    def __contains__(self, key):
        return self._wrap_key(key) in self._storage

    def discard(self, key):
        self._storage.discard(self._wrap_key(key))

    def add(self, key):
        self._storage.add(self._wrap_key(key))

    def update(self, items):
        self._storage.update([self._wrap_key(item) for item in items])

    def clear(self):
        self._storage.clear()

    def intersection(self, items):
        # NOTE(review): unlike `difference`, this returns a plain `set` of
        # internal wrapper objects rather than an `ObjectIdentitySet`. Kept
        # as-is because callers may depend on the current return type.
        return self._storage.intersection(
            [self._wrap_key(item) for item in items]
        )

    def difference(self, items):
        return ObjectIdentitySet._from_storage(
            self._storage.difference([self._wrap_key(item) for item in items])
        )

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        # Iterate over a snapshot so the set may be modified by the consumer
        # while the generator is still live.
        keys = list(self._storage)
        for key in keys:
            yield key.unwrapped

229 

230 

class ObjectIdentityWeakSet(ObjectIdentitySet):
    """Like weakref.WeakSet, but compares objects with "is"."""

    __slots__ = ()

    def _wrap_key(self, key):
        """Wraps `key` by identity while holding it only weakly."""
        return _WeakObjectIdentityWrapper(key)

    def __len__(self):
        # Iterate, discarding old weak refs
        return sum(1 for _ in self)

    def __iter__(self):
        # Snapshot the wrappers because dead ones are removed while iterating.
        keys = list(self._storage)
        for key in keys:
            unwrapped = key.unwrapped
            if unwrapped is None:
                # `key` is already the stored wrapper, so discard it from the
                # storage directly. `self.discard(key)` would wrap it again,
                # producing a different hash, and silently remove nothing.
                self._storage.discard(key)
            else:
                yield unwrapped

251 

252 

253# LINT.ThenChange(//tensorflow/python/util/object_identity.py) 

254