Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/utils/object_identity.py: 50%
117 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""Utilities for collecting objects based on "is" comparison."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
17import collections
18import weakref
21# LINT.IfChange
22class _ObjectIdentityWrapper:
23 """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped.
25 Since __eq__ is based on object identity, it's safe to also define __hash__
26 based on object ids. This lets us add unhashable types like trackable
27 _ListWrapper objects to object-identity collections.
28 """
30 __slots__ = ["_wrapped", "__weakref__"]
32 def __init__(self, wrapped):
33 self._wrapped = wrapped
35 @property
36 def unwrapped(self):
37 return self._wrapped
39 def _assert_type(self, other):
40 if not isinstance(other, _ObjectIdentityWrapper):
41 raise TypeError(
42 "Cannot compare wrapped object with unwrapped object. "
43 "Expect the object to be `_ObjectIdentityWrapper`. "
44 f"Got: {other}"
45 )
47 def __lt__(self, other):
48 self._assert_type(other)
49 return id(self._wrapped) < id(other._wrapped)
51 def __gt__(self, other):
52 self._assert_type(other)
53 return id(self._wrapped) > id(other._wrapped)
55 def __eq__(self, other):
56 if other is None:
57 return False
58 self._assert_type(other)
59 return self._wrapped is other._wrapped
61 def __ne__(self, other):
62 return not self.__eq__(other)
64 def __hash__(self):
65 # Wrapper id() is also fine for weakrefs. In fact, we rely on
66 # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is
67 # weakref.ref(a) in _WeakObjectIdentityWrapper.
68 return id(self._wrapped)
70 def __repr__(self):
71 return f"<{type(self).__name__} wrapping {self._wrapped!r}>"
class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
    """Identity wrapper that holds only a weak reference to its object.

    ``unwrapped`` returns ``None`` once the referent no longer exists.
    """

    __slots__ = ()

    def __init__(self, wrapped):
        ref = weakref.ref(wrapped)
        super().__init__(ref)

    @property
    def unwrapped(self):
        """The referent, or ``None`` if it has been garbage collected."""
        deref = self._wrapped
        return deref()
class Reference(_ObjectIdentityWrapper):
    """A hashable reference to an object, compared by identity.

    Two references are equal iff they refer to the same object, regardless
    of whether the referenced objects themselves compare equal.

    ```python
    x = [1]
    y = [1]

    x_ref1 = Reference(x)
    x_ref2 = Reference(x)
    y_ref2 = Reference(y)

    print(x_ref1 == x_ref2)
    ==> True

    print(x_ref1 == y_ref2)
    ==> False
    ```

    Comparing a reference with an object that is not itself an identity
    wrapper (e.g. `x_ref1 == y`) raises `TypeError`; only `None` is accepted
    and compares unequal.
    """

    __slots__ = ()

    # Disable the inherited `unwrapped` property; use `deref()` instead.
    unwrapped = property()

    def deref(self):
        """Returns the referenced object.

        ```python
        x_ref = Reference(x)
        print(x is x_ref.deref())
        ==> True
        ```
        """
        return self._wrapped
class ObjectIdentityDictionary(collections.abc.MutableMapping):
    """Mutable mapping whose keys are matched by identity ("is").

    Trackable objects such as `_ListWrapper` behave exactly like built-in
    Python lists: they are unhashable and, by default, compare equal based on
    their contents. Keying on identity sidesteps both problems.
    """

    __slots__ = ["_storage"]

    def __init__(self):
        # Maps identity wrappers of the user's keys to the stored values.
        self._storage = {}

    def _wrap_key(self, key):
        """Wrap `key` so the underlying dict hashes/compares it by identity."""
        return _ObjectIdentityWrapper(key)

    def __getitem__(self, key):
        wrapped_key = self._wrap_key(key)
        return self._storage[wrapped_key]

    def __setitem__(self, key, value):
        wrapped_key = self._wrap_key(key)
        self._storage[wrapped_key] = value

    def __delitem__(self, key):
        wrapped_key = self._wrap_key(key)
        del self._storage[wrapped_key]

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        # Hand back the original keys, not their wrappers.
        yield from (wrapper.unwrapped for wrapper in self._storage)

    def __repr__(self):
        return f"ObjectIdentityDictionary({self._storage!r})"
class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
    """Like weakref.WeakKeyDictionary, but compares objects with "is".

    Keys are held through weak references; values are stored directly in the
    underlying dict. Entries whose key has been garbage collected are not
    removed eagerly. They are purged lazily whenever the mapping is iterated,
    which `len()` also does.
    """

    __slots__ = ["__weakref__"]

    def _wrap_key(self, key):
        return _WeakObjectIdentityWrapper(key)

    def __len__(self):
        # Count by iterating `self` (not the raw storage) so that entries
        # whose keys have died are purged and excluded from the count.
        return sum(1 for _ in self)

    def __iter__(self):
        # Snapshot the wrapped keys: dead entries are removed below, and
        # mutating a dict while iterating its live key view raises
        # RuntimeError.
        for wrapped_key in list(self._storage):
            obj = wrapped_key.unwrapped
            if obj is None:
                # `wrapped_key` is already a storage key, so remove it
                # directly. `del self[wrapped_key]` would wrap it a second
                # time (a weak ref to the wrapper itself), which never
                # matches and raises KeyError. `pop` with a default also
                # tolerates an entry the caller deleted mid-iteration.
                self._storage.pop(wrapped_key, None)
            else:
                yield obj
class ObjectIdentitySet(collections.abc.MutableSet):
    """Like the built-in set, but compares objects with "is"."""

    __slots__ = ["_storage", "__weakref__"]

    def __init__(self, *args):
        # Accepts zero or one iterable (`list(*args)` rejects more), like
        # `set([iterable])`.
        self._storage = set(self._wrap_key(obj) for obj in list(*args))

    @classmethod
    def _from_storage(cls, storage):
        """Build an instance of `cls` that adopts an already-wrapped set.

        Using `cls` rather than hard-coding `ObjectIdentitySet` keeps the
        container type matched to the wrapper type inside `storage`. For
        example, weak wrappers stay inside an `ObjectIdentityWeakSet`, whose
        `_wrap_key` produces matching weak wrappers for later lookups.
        """
        result = cls()
        result._storage = storage
        return result

    def _wrap_key(self, key):
        """Wrap `key` so the underlying set hashes/compares it by identity."""
        return _ObjectIdentityWrapper(key)

    def __contains__(self, key):
        return self._wrap_key(key) in self._storage

    def discard(self, key):
        self._storage.discard(self._wrap_key(key))

    def add(self, key):
        self._storage.add(self._wrap_key(key))

    def update(self, items):
        # Wrap everything first so a failure while wrapping leaves the set
        # unchanged.
        self._storage.update([self._wrap_key(item) for item in items])

    def clear(self):
        self._storage.clear()

    def intersection(self, items):
        """Return the wrapped elements shared with `items`.

        NOTE: unlike `difference`, this returns a plain `set` of identity
        wrappers rather than an `ObjectIdentitySet`. Callers may depend on
        that, so the return type is left unchanged.
        """
        return self._storage.intersection(
            [self._wrap_key(item) for item in items]
        )

    def difference(self, items):
        """Return a new set of the same type, without the objects in `items`."""
        return self._from_storage(
            self._storage.difference([self._wrap_key(item) for item in items])
        )

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        # Snapshot so the set may be modified while it is being iterated.
        keys = list(self._storage)
        for key in keys:
            yield key.unwrapped
class ObjectIdentityWeakSet(ObjectIdentitySet):
    """Like weakref.WeakSet, but compares objects with "is".

    Members are held through weak references. Entries whose object has been
    garbage collected are purged lazily whenever the set is iterated, which
    `len()` also does.
    """

    __slots__ = ()

    def _wrap_key(self, key):
        return _WeakObjectIdentityWrapper(key)

    def __len__(self):
        # Iterate `self` so dead weak references are purged and not counted.
        return sum(1 for _ in self)

    def __iter__(self):
        # Snapshot: dead entries are removed from storage during the loop.
        for wrapped_key in list(self._storage):
            obj = wrapped_key.unwrapped
            if obj is None:
                # `wrapped_key` is already a storage element, so remove it
                # directly. `self.discard(wrapped_key)` would wrap it again
                # (a weak ref to the wrapper itself), which never matches,
                # so the dead entry would never be removed.
                self._storage.discard(wrapped_key)
            else:
                yield obj
253# LINT.ThenChange(//tensorflow/python/util/object_identity.py)