Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/utils/object_identity.py: 45%

136 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1"""Utilities for collecting objects based on "is" comparison.""" 

2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

3# 

4# Licensed under the Apache License, Version 2.0 (the "License"); 

5# you may not use this file except in compliance with the License. 

6# You may obtain a copy of the License at 

7# 

8# http://www.apache.org/licenses/LICENSE-2.0 

9# 

10# Unless required by applicable law or agreed to in writing, software 

11# distributed under the License is distributed on an "AS IS" BASIS, 

12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

13# See the License for the specific language governing permissions and 

14# limitations under the License. 

15# ============================================================================== 

16 

17import collections 

18from typing import Any, Set 

19import weakref 

20 

21 

22# LINT.IfChange 

class _ObjectIdentityWrapper(object):
    """Holds one object and compares by identity ("is") instead of equality.

    Two wrappers are equal exactly when they hold the very same object, so a
    hash derived from ``id()`` of the held object is consistent with
    ``__eq__``. This is what allows unhashable values (for example trackable
    ``_ListWrapper`` objects) to be stored in the identity-based collections.

    Comparing a wrapper with anything that is not itself a wrapper raises
    ``TypeError``, except that ``== None`` is False and ``!= None`` is True.
    """

    __slots__ = ["_wrapped", "__weakref__"]

    def __init__(self, wrapped):
        self._wrapped = wrapped

    @property
    def unwrapped(self):
        """Returns the object held by this wrapper."""
        return self._wrapped

    def _assert_type(self, other):
        """Raises TypeError unless `other` is an _ObjectIdentityWrapper."""
        if isinstance(other, _ObjectIdentityWrapper):
            return
        raise TypeError("Cannot compare wrapped object with unwrapped object")

    def __lt__(self, other):
        # Ordering follows id() of the held objects.
        self._assert_type(other)
        other_id = id(other._wrapped)  # pylint: disable=protected-access
        return id(self._wrapped) < other_id

    def __gt__(self, other):
        self._assert_type(other)
        other_id = id(other._wrapped)  # pylint: disable=protected-access
        return id(self._wrapped) > other_id

    def __eq__(self, other):
        # `None` simply never matches; any other non-wrapper is a TypeError.
        if other is not None:
            self._assert_type(other)
            return self._wrapped is other._wrapped  # pylint: disable=protected-access
        return False

    def __ne__(self, other):
        equal = self.__eq__(other)
        return not equal

    def __hash__(self):
        # Hashing on the held object's id() is safe because equality is
        # identity. The weak-reference subclass stores weakref.ref(obj) here
        # and relies on weakref.ref(obj) returning the same ref object for the
        # same live referent, so its id() (and therefore the hash) matches.
        return id(self._wrapped)

    def __repr__(self):
        cls_name = type(self).__name__
        return "<{} wrapping {!r}>".format(cls_name, self._wrapped)

69 

70 

class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
    """Identity wrapper that keeps only a weak reference to its object.

    The stored value is ``weakref.ref(wrapped)``, so the inherited equality
    and hashing operate on that weakref object. Matching two wrappers of the
    same live object relies on ``weakref.ref`` handing back the same ref
    object for the same referent.
    """

    __slots__ = ()

    def __init__(self, wrapped):
        ref = weakref.ref(wrapped)
        super().__init__(ref)

    @property
    def unwrapped(self):
        """Returns the referent, or None once it has been garbage collected."""
        ref = self._wrapped
        return ref()

81 

82 

class Reference(_ObjectIdentityWrapper):
    """Reference that refers to an object by identity.

    Two references are equal exactly when they refer to the same object, and a
    reference is hashable even when the referenced object is not.

    ```python
    x = [1]
    y = [1]

    x_ref1 = Reference(x)
    x_ref2 = Reference(x)
    y_ref = Reference(y)

    print(x_ref1 == x_ref2)
    ==> True

    print(x_ref1 == y_ref)
    ==> False
    ```

    Comparing a reference with an object that is not itself a reference
    (other than ``None``) raises ``TypeError``; wrap both sides first.
    """

    __slots__ = ()

    # Disable the inherited `unwrapped` property (reading it now raises
    # AttributeError); use `deref()` instead.
    unwrapped = property()

    def deref(self):
        """Returns the referenced object.

        ```python
        x_ref = Reference(x)
        print(x is x_ref.deref())
        ==> True
        ```
        """
        return self._wrapped

117 

118 

class ObjectIdentityDictionary(collections.abc.MutableMapping):
    """Mutable mapping whose keys are matched with "is" rather than "==".

    This exists because trackable objects such as ``_ListWrapper`` behave like
    built-in lists: they are unhashable and, by default, compare equal based
    on their contents. Every key is stored wrapped (see ``_wrap_key``), so
    lookups are decided by object identity.
    """

    __slots__ = ["_storage"]

    def __init__(self):
        # Internal dict: wrapped key -> value.
        self._storage = {}

    def _wrap_key(self, key):
        """Returns the identity wrapper used as the internal dict key."""
        return _ObjectIdentityWrapper(key)

    def __getitem__(self, key):
        wrapped = self._wrap_key(key)
        return self._storage[wrapped]

    def __setitem__(self, key, value):
        wrapped = self._wrap_key(key)
        self._storage[wrapped] = value

    def __delitem__(self, key):
        wrapped = self._wrap_key(key)
        del self._storage[wrapped]

    def __len__(self):
        storage = self._storage
        return len(storage)

    def __iter__(self):
        # Yield the original (unwrapped) key objects.
        yield from (wrapper.unwrapped for wrapper in self._storage)

    def __repr__(self):
        # Note: the contents are shown as their internal wrappers.
        return "ObjectIdentityDictionary(%s)" % (repr(self._storage),)

153 

154 

class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
    """Like weakref.WeakKeyDictionary, but compares keys with "is".

    Keys are held through weak references, so an entry becomes dead once its
    key has been garbage collected. Dead entries are skipped and purged lazily
    whenever the mapping is iterated, which ``__len__`` also does.
    """

    __slots__ = ["__weakref__"]

    def _wrap_key(self, key):
        """Returns a weak identity wrapper for `key`."""
        return _WeakObjectIdentityWrapper(key)

    def __len__(self):
        # Iterate (rather than counting self._storage) so that only live keys
        # are counted and dead weak refs are discarded along the way.
        return sum(1 for _ in self)

    def __iter__(self):
        # Snapshot the stored wrappers: removing entries while iterating a
        # live keys() view raises RuntimeError.
        for wrapper in list(self._storage):
            obj = wrapper.unwrapped
            if obj is None:
                # Remove the stored wrapper itself. `del self[wrapper]` would
                # wrap it a second time, miss, and raise KeyError. pop() with
                # a default tolerates another iterator having removed it first.
                self._storage.pop(wrapper, None)
            else:
                yield obj

175 

176 

class ObjectIdentitySet(collections.abc.MutableSet):
    """Like the built-in set, but compares objects with "is".

    Elements are stored wrapped (see ``_wrap_key``), so membership is decided
    by object identity and unhashable objects can be members. The comparison
    operators ``<=`` and ``>=`` (and the ``==``, ``<``, ``>`` mixins inherited
    from ``collections.abc.Set``, which build on them) accept any
    ``collections.abc.Set`` as the other operand, including another
    ObjectIdentitySet.
    """

    __slots__ = ["_storage", "__weakref__"]

    def __init__(self, *args):
        # Accepts zero or one iterable; list(*args) raises TypeError for more.
        self._storage = set(self._wrap_key(obj) for obj in list(*args))

    @staticmethod
    def _contains_safely(container, item):
        """Returns `item in container`, treating a TypeError as "absent".

        A TypeError here means the lookup cannot be performed at all (e.g. an
        unhashable object looked up in a builtin set, or an object that cannot
        be weakly referenced looked up in a weak identity set), so `item`
        cannot be a member.
        """
        try:
            return item in container
        except TypeError:
            return False

    def __le__(self, other: Set[Any]) -> bool:
        """Returns whether every element of self is also an element of `other`.

        Returns NotImplemented if `other` is not a collections.abc.Set.
        """
        if not isinstance(other, collections.abc.Set):
            return NotImplemented
        if len(self) > len(other):
            return False
        # Iterate the unwrapped elements, not self._storage: the internal
        # wrapper objects are never members of `other`.
        return all(self._contains_safely(other, item) for item in self)

    def __ge__(self, other: Set[Any]) -> bool:
        """Returns whether every element of `other` is also an element of self.

        Returns NotImplemented if `other` is not a collections.abc.Set.
        """
        if not isinstance(other, collections.abc.Set):
            return NotImplemented
        if len(self) < len(other):
            return False
        return all(self._contains_safely(self, item) for item in other)

    @classmethod
    def _from_storage(cls, storage):
        """Returns a new `cls` instance that adopts `storage` as its contents.

        `storage` must be a builtin set of keys already wrapped by the
        `_wrap_key` of `cls`. Building `cls` (rather than always
        ObjectIdentitySet) keeps subclasses such as the weak-reference set
        returning an instance whose lookups match its own key wrappers.
        """
        result = cls()
        result._storage = storage  # pylint: disable=protected-access
        return result

    def _wrap_key(self, key):
        """Returns the identity wrapper stored internally for `key`."""
        return _ObjectIdentityWrapper(key)

    def __contains__(self, key):
        return self._wrap_key(key) in self._storage

    def discard(self, key):
        self._storage.discard(self._wrap_key(key))

    def add(self, key):
        self._storage.add(self._wrap_key(key))

    def update(self, items):
        # Wrap everything first so a failing wrap leaves the set unchanged.
        self._storage.update([self._wrap_key(item) for item in items])

    def clear(self):
        self._storage.clear()

    def intersection(self, items):
        """Returns the elements of self that are also in `items`.

        NOTE(review): unlike difference(), the result is a builtin set of the
        internal wrapper objects, not an ObjectIdentitySet. Left unchanged for
        compatibility with existing callers.
        """
        return self._storage.intersection([self._wrap_key(item) for item in items])

    def difference(self, items):
        """Returns a new set of the same type as self without `items`."""
        return self._from_storage(
            self._storage.difference([self._wrap_key(item) for item in items]))

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        # Snapshot so the set may be mutated while it is being iterated.
        keys = list(self._storage)
        for key in keys:
            yield key.unwrapped

243 

244 

class ObjectIdentityWeakSet(ObjectIdentitySet):
    """Like weakref.WeakSet, but compares objects with "is".

    Elements are held through weak references. Elements that have been garbage
    collected are skipped and purged lazily during iteration, which
    ``__len__`` also performs.
    """

    __slots__ = ()

    def _wrap_key(self, key):
        """Returns a weak identity wrapper for `key`."""
        return _WeakObjectIdentityWrapper(key)

    def __len__(self):
        # Iterate, discarding old weak refs, so only live elements count.
        return sum(1 for _ in self)

    def __iter__(self):
        # Snapshot first so dead entries can be removed while iterating.
        for wrapper in list(self._storage):
            obj = wrapper.unwrapped
            if obj is None:
                # Remove the stored wrapper itself. self.discard() would wrap
                # it a second time, never find a match, and leak the entry.
                self._storage.discard(wrapper)
            else:
                yield obj

265# LINT.ThenChange(//tensorflow/python/util/object_identity.py)