Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/utils/object_identity.py: 45%
136 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""Utilities for collecting objects based on "is" comparison."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
17import collections
18from typing import Any, Set
19import weakref
22# LINT.IfChange
23class _ObjectIdentityWrapper(object):
24 """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped.
26 Since __eq__ is based on object identity, it's safe to also define __hash__
27 based on object ids. This lets us add unhashable types like trackable
28 _ListWrapper objects to object-identity collections.
29 """
31 __slots__ = ["_wrapped", "__weakref__"]
33 def __init__(self, wrapped):
34 self._wrapped = wrapped
36 @property
37 def unwrapped(self):
38 return self._wrapped
40 def _assert_type(self, other):
41 if not isinstance(other, _ObjectIdentityWrapper):
42 raise TypeError("Cannot compare wrapped object with unwrapped object")
44 def __lt__(self, other):
45 self._assert_type(other)
46 return id(self._wrapped) < id(other._wrapped) # pylint: disable=protected-access
48 def __gt__(self, other):
49 self._assert_type(other)
50 return id(self._wrapped) > id(other._wrapped) # pylint: disable=protected-access
52 def __eq__(self, other):
53 if other is None:
54 return False
55 self._assert_type(other)
56 return self._wrapped is other._wrapped # pylint: disable=protected-access
58 def __ne__(self, other):
59 return not self.__eq__(other)
61 def __hash__(self):
62 # Wrapper id() is also fine for weakrefs. In fact, we rely on
63 # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is
64 # weakref.ref(a) in _WeakObjectIdentityWrapper.
65 return id(self._wrapped)
67 def __repr__(self):
68 return "<{} wrapping {!r}>".format(type(self).__name__, self._wrapped)
class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
  """Identity wrapper that holds its target through a weak reference.

  The value stored by the base class is a `weakref.ref` to the target rather
  than the target itself, so identity comparison and hashing operate on that
  weak reference object.
  """

  __slots__ = ()

  def __init__(self, wrapped):
    weak_handle = weakref.ref(wrapped)
    super().__init__(weak_handle)

  @property
  def unwrapped(self):
    """The referenced object, or None once the weak reference is dead."""
    weak_handle = self._wrapped
    return weak_handle()
class Reference(_ObjectIdentityWrapper):
  """Hashable handle to an object that compares by identity.

  Two references are equal only when they refer to the very same object:

  ```python
  x = [1]
  y = [1]

  x_ref1 = Reference(x)
  x_ref2 = Reference(x)
  y_ref2 = Reference(y)

  print(x_ref1 == x_ref2)
  ==> True

  print(x_ref1 == y_ref2)
  ==> False
  ```

  Comparing a reference with an object that is not wrapped (other than None)
  raises TypeError.
  """

  __slots__ = ()

  # Disabling super class' unwrapped field: a property without a getter
  # shadows the inherited one, so callers have to go through deref().
  unwrapped = property()

  def deref(self):
    """Returns the object this reference points at.

    ```python
    x_ref = Reference(x)
    print(x is x_ref.deref())
    ==> True
    ```
    """
    referent = self._wrapped
    return referent
class ObjectIdentityDictionary(collections.abc.MutableMapping):
  """Mutable mapping whose keys are matched with "is" instead of "==".

  Needed because some trackable objects (_ListWrapper) behave exactly like
  built-in Python lists: they are unhashable and, by default, compare by the
  equality of their contents. Keys are stored wrapped so that lookups use
  object identity only.
  """

  __slots__ = ["_storage"]

  def __init__(self):
    # Plain dict keyed by identity wrappers around the user-supplied keys.
    self._storage = {}

  def _wrap_key(self, key):
    """Builds the identity wrapper used as the internal dict key."""
    return _ObjectIdentityWrapper(key)

  def __getitem__(self, key):
    wrapped_key = self._wrap_key(key)
    return self._storage[wrapped_key]

  def __setitem__(self, key, value):
    wrapped_key = self._wrap_key(key)
    self._storage[wrapped_key] = value

  def __delitem__(self, key):
    wrapped_key = self._wrap_key(key)
    del self._storage[wrapped_key]

  def __len__(self):
    return len(self._storage)

  def __iter__(self):
    # Hand back the original key objects, not the internal wrappers.
    yield from (wrapped_key.unwrapped for wrapped_key in self._storage)

  def __repr__(self):
    return f"ObjectIdentityDictionary({self._storage!r})"
class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
  """Like weakref.WeakKeyDictionary, but compares objects with "is".

  Keys are wrapped in weak identity wrappers. Entries whose key object is no
  longer alive are removed lazily, whenever the dictionary is iterated (which
  includes taking its length).
  """

  __slots__ = ["__weakref__"]

  def _wrap_key(self, key):
    """Builds the weak identity wrapper used as the internal dict key."""
    return _WeakObjectIdentityWrapper(key)

  def __len__(self):
    # Iterate, discarding old weak refs, so that only live keys are counted.
    return sum(1 for _ in self)

  def __iter__(self):
    """Yields the live key objects, dropping entries whose key has died."""
    # Iterate over a snapshot: dead entries are removed from the underlying
    # dict inside the loop, which is not allowed while iterating the dict
    # itself.
    for key in list(self._storage):
      unwrapped = key.unwrapped
      if unwrapped is None:
        # `key` is already the stored wrapper, so remove it from the storage
        # directly; going through `del self[key]` would wrap it a second time
        # and miss the entry. pop() tolerates the entry having been removed
        # already (e.g. by another iterator over this dictionary).
        self._storage.pop(key, None)
      else:
        yield unwrapped
class ObjectIdentitySet(collections.abc.MutableSet):
  """Like the built-in set, but compares objects with "is".

  Elements are stored inside identity wrappers, so unhashable objects can be
  members and two distinct but equal objects are kept as separate members.
  """

  __slots__ = ["_storage", "__weakref__"]

  def __init__(self, *args):
    """Creates the set, optionally from a single iterable of initial members."""
    # `list(*args)` accepts either no argument or exactly one iterable.
    self._storage = {self._wrap_key(obj) for obj in list(*args)}

  def __le__(self, other: Set[Any]) -> bool:
    """Returns True if every member of this set is contained in `other`."""
    # Check against the abstract Set type so that other ObjectIdentitySet
    # instances (which are not subclasses of the built-in set) are accepted.
    if not isinstance(other, collections.abc.Set):
      return NotImplemented
    if len(self) > len(other):
      return False
    # Iterate over `self`, not `self._storage`: membership in `other` has to
    # be tested with the original objects, not with the internal wrappers.
    for item in self:
      try:
        contained = item in other
      except TypeError:
        # Raised e.g. for an unhashable member tested against a built-in set,
        # which cannot contain such an object.
        contained = False
      if not contained:
        return False
    return True

  def __ge__(self, other: Set[Any]) -> bool:
    """Returns True if every member of `other` is contained in this set."""
    # See __le__ for why the abstract Set type is used here.
    if not isinstance(other, collections.abc.Set):
      return NotImplemented
    if len(self) < len(other):
      return False
    for item in other:
      if item not in self:
        return False
    return True

  @staticmethod
  def _from_storage(storage):
    """Builds an ObjectIdentitySet directly around a set of wrapped keys."""
    result = ObjectIdentitySet()
    result._storage = storage  # pylint: disable=protected-access
    return result

  def _wrap_key(self, key):
    """Builds the identity wrapper under which `key` is stored."""
    return _ObjectIdentityWrapper(key)

  def __contains__(self, key):
    return self._wrap_key(key) in self._storage

  def discard(self, key):
    """Removes `key` from the set if it is a member."""
    self._storage.discard(self._wrap_key(key))

  def add(self, key):
    """Adds `key` to the set."""
    self._storage.add(self._wrap_key(key))

  def update(self, items):
    """Adds every object in `items` to the set."""
    self._storage.update([self._wrap_key(item) for item in items])

  def clear(self):
    """Removes all members."""
    self._storage.clear()

  def intersection(self, items):
    """Returns the members that also occur in `items`.

    Note that, unlike `difference`, the result is a plain built-in set holding
    the internal wrappers rather than an ObjectIdentitySet.
    """
    return self._storage.intersection([self._wrap_key(item) for item in items])

  def difference(self, items):
    """Returns an ObjectIdentitySet of the members that are not in `items`."""
    return ObjectIdentitySet._from_storage(
        self._storage.difference([self._wrap_key(item) for item in items]))

  def __len__(self):
    return len(self._storage)

  def __iter__(self):
    # Iterate over a snapshot so the set may be modified while iterating.
    keys = list(self._storage)
    for key in keys:
      yield key.unwrapped
class ObjectIdentityWeakSet(ObjectIdentitySet):
  """Like weakref.WeakSet, but compares objects with "is".

  Members are wrapped in weak identity wrappers. Entries whose object is no
  longer alive are removed lazily, whenever the set is iterated (which
  includes taking its length).
  """

  __slots__ = ()

  def _wrap_key(self, key):
    """Builds the weak identity wrapper under which `key` is stored."""
    return _WeakObjectIdentityWrapper(key)

  def __len__(self):
    # Iterate, discarding old weak refs
    return sum(1 for _ in self)

  def __iter__(self):
    """Yields the live members, dropping entries whose object has died."""
    # Iterate over a snapshot because dead entries are removed in the loop.
    for key in list(self._storage):
      unwrapped = key.unwrapped
      if unwrapped is None:
        # `key` is already the stored wrapper, so discard it from the storage
        # directly; `self.discard(key)` would wrap it a second time and never
        # match the stored entry.
        self._storage.discard(key)
      else:
        yield unwrapped
265# LINT.ThenChange(//tensorflow/python/util/object_identity.py)