Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/sets_impl.py: 46%
57 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of tf.sets."""
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import ops
19from tensorflow.python.framework import sparse_tensor
20from tensorflow.python.ops import gen_set_ops
21from tensorflow.python.util import dispatch
22from tensorflow.python.util.tf_export import tf_export
# Element dtypes accepted by the set functions in this module. Inputs whose
# base dtype is not a member are rejected with a `TypeError`.
_VALID_DTYPES = frozenset((
    dtypes.int8,
    dtypes.int16,
    dtypes.int32,
    dtypes.int64,
    dtypes.uint8,
    dtypes.uint16,
    dtypes.string,
))
@tf_export("sets.size", v1=["sets.size", "sets.set_size"])
@dispatch.add_dispatch_support
def set_size(a, validate_indices=True):
  """Count the distinct elements along the last dimension of `a`.

  Args:
    a: `SparseTensor` whose indices are sorted in row-major order.
    validate_indices: Whether the order and range of the sparse indices in `a`
      should be validated. Passing `False` permits undefined behavior when the
      indices turn out to be invalid.

  Returns:
    An `int32` `Tensor` of set sizes. If `a` has rank `n`, the result has rank
    `n-1` and shares the first `n-1` dimensions of `a`; each entry is the
    number of unique elements in the corresponding set of `a`.

  Raises:
    TypeError: If `a` does not convert to a `SparseTensor`, or if the dtype of
      its values is not one of `_VALID_DTYPES`.
  """
  sparse_a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")

  # Only sparse input is accepted; a dense `Tensor` is rejected here.
  if not isinstance(sparse_a, sparse_tensor.SparseTensor):
    raise TypeError("Expected `SparseTensor`, got %s." % sparse_a)

  values_dtype = sparse_a.values.dtype
  if values_dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError(
        f"Invalid dtype `{values_dtype}` not in supported dtypes: "
        f"`{_VALID_DTYPES}`.")

  return gen_set_ops.set_size(
      sparse_a.indices,
      sparse_a.values,
      sparse_a.dense_shape,
      validate_indices)
# Mark each of the following op type names as not differentiable.
for _set_op_type in (
    "SetSize",
    "DenseToDenseSetOperation",
    "DenseToSparseSetOperation",
    "SparseToSparseSetOperation",
):
  ops.NotDifferentiable(_set_op_type)
del _set_op_type  # Keep the loop variable out of the module namespace.
def _convert_to_tensors_or_sparse_tensors(a, b):
  """Convert both inputs to tensors, swapping them when required.

  Args:
    a: `Tensor` or `SparseTensor` with the same dtype as `b`.
    b: `Tensor` or `SparseTensor` with the same dtype as `a`.

  Returns:
    A `(a, b, flipped)` tuple. `a` and `b` are the converted `Tensor` or
    `SparseTensor` inputs. `flipped` is `True` when the two were swapped so
    that a sparse,dense pair becomes dense,sparse (the set ops do not support
    the former), and `False` otherwise.

  Raises:
    TypeError: If the dtype of `a` is not one of `_VALID_DTYPES`, or if the
      base dtypes of `a` and `b` differ.
  """
  lhs = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
  if lhs.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError(
        f"'a' has invalid dtype `{lhs.dtype}` not in supported dtypes: "
        f"`{_VALID_DTYPES}`.")

  rhs = sparse_tensor.convert_to_tensor_or_sparse_tensor(b, name="b")
  if rhs.dtype.base_dtype != lhs.dtype.base_dtype:
    raise TypeError("Types don't match, %s vs %s." % (lhs.dtype, rhs.dtype))

  # Only the sparse,dense combination has to be reordered.
  sparse_then_dense = (
      isinstance(lhs, sparse_tensor.SparseTensor) and
      not isinstance(rhs, sparse_tensor.SparseTensor))
  if sparse_then_dense:
    return rhs, lhs, True
  return lhs, rhs, False
def _set_operation(a, b, set_operation, validate_indices=True):
  """Apply a set operation to the last dimension of `a` and `b`.

  `a` and `b` must agree in every dimension except the last.

  Args:
    a: `Tensor` or `SparseTensor` with the same dtype as `b`. Sparse indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` with the same dtype as `a`. It has to be a
      `SparseTensor` whenever `a` is one. Sparse indices must be sorted in
      row-major order.
    set_operation: String naming the set operation. See
      SetOperationOp::SetOperationFromContext for valid values.
    validate_indices: Whether to validate the order and range of the sparse
      indices in `a` and `b`.

  Returns:
    A `SparseTensor` of the same rank as `a` and `b`, matching them in all but
    the last dimension. The last dimension holds the results of the set
    operation.

  Raises:
    TypeError: If inputs are invalid types.
    ValueError: If `a` is sparse and `b` is dense.
  """
  a_is_sparse = isinstance(a, sparse_tensor.SparseTensor)
  b_is_sparse = isinstance(b, sparse_tensor.SparseTensor)

  # Reject the single unsupported pairing first.
  if a_is_sparse and not b_is_sparse:
    raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
                     "Please flip the order of your inputs.")

  if a_is_sparse:
    result = gen_set_ops.sparse_to_sparse_set_operation(
        a.indices, a.values, a.dense_shape, b.indices, b.values,
        b.dense_shape, set_operation, validate_indices)
  elif b_is_sparse:
    result = gen_set_ops.dense_to_sparse_set_operation(
        a, b.indices, b.values, b.dense_shape, set_operation,
        validate_indices)
  else:
    result = gen_set_ops.dense_to_dense_set_operation(
        a, b, set_operation, validate_indices)

  out_indices, out_values, out_shape = result
  return sparse_tensor.SparseTensor(out_indices, out_values, out_shape)
@tf_export(
    "sets.intersection", v1=["sets.intersection", "sets.set_intersection"])
@dispatch.add_dispatch_support
def set_intersection(a, b, validate_indices=True):
  """Intersect the sets stored in the last dimension of `a` and `b`.

  `a` and `b` must agree in every dimension except the last.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # Encode this array of sets as a sparse tensor:
    # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.sparse.SparseTensor(list(a.keys()), list(a.values()),
                               dense_shape=[2, 2, 2])

    # b = np.array([[{1}, {}], [{4}, {5, 6, 7, 8}]])
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.sparse.SparseTensor(list(b.keys()), list(b.values()),
                               dense_shape=[2, 2, 4])

    # `tf.sets.intersection` operates on each aligned pair of sets.
    tf.sets.intersection(a, b)

    # The result is equivalent to either of:
    #
    # np.array([[{1}, {}], [{4}, {5, 6}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 1),
    #     ((1, 0, 0), 4),
    #     ((1, 1, 0), 5),
    #     ((1, 1, 1), 6),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` with the same dtype as `b`. Sparse indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` with the same dtype as `a`. Sparse indices
      must be sorted in row-major order.
    validate_indices: Whether to validate the order and range of the sparse
      indices in `a` and `b`.

  Returns:
    A `SparseTensor` of the same rank as `a` and `b`, matching them in all but
    the last dimension. The last dimension holds the intersections.

  Raises:
    TypeError: If `a` has an unsupported dtype, or if the dtypes of `a` and `b`
      differ.
  """
  # The swap flag is not needed: this function passes the same operation name
  # regardless of input order.
  lhs, rhs, _ = _convert_to_tensors_or_sparse_tensors(a, b)
  return _set_operation(lhs, rhs, "intersection", validate_indices)
@tf_export("sets.difference", v1=["sets.difference", "sets.set_difference"])
@dispatch.add_dispatch_support
def set_difference(a, b, aminusb=True, validate_indices=True):
  """Subtract the sets stored in the last dimension of `a` and `b`.

  `a` and `b` must agree in every dimension except the last.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # Encode this array of sets as a sparse tensor:
    # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.sparse.SparseTensor(list(a.keys()), list(a.values()),
                               dense_shape=[2, 2, 2])

    # np.array([[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]])
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 3),
        ((0, 1, 0), 2),
        ((1, 0, 0), 4),
        ((1, 0, 1), 5),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.sparse.SparseTensor(list(b.keys()), list(b.values()),
                               dense_shape=[2, 2, 4])

    # `tf.sets.difference` operates on each aligned pair of sets.
    tf.sets.difference(a, b)

    # The result is equivalent to either of:
    #
    # np.array([[{2}, {3}], [{}, {}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 2),
    #     ((0, 1, 0), 3),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` with the same dtype as `b`. Sparse indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` with the same dtype as `a`. Sparse indices
      must be sorted in row-major order.
    aminusb: Whether to subtract `b` from `a`, vs vice versa.
    validate_indices: Whether to validate the order and range of the sparse
      indices in `a` and `b`.

  Returns:
    A `SparseTensor` of the same rank as `a` and `b`, matching them in all but
    the last dimension. The last dimension holds the differences.

  Raises:
    TypeError: If inputs are invalid types, or if `a` and `b` have different
      types.
    errors_impl.InvalidArgumentError: If the shapes of `a` and `b` do not
      match in any dimension other than the last dimension.
  """
  lhs, rhs, swapped = _convert_to_tensors_or_sparse_tensors(a, b)
  # After a swap `lhs` holds the caller's `b`, so the direction of the
  # subtraction has to be inverted to keep the caller's meaning.
  lhs_minus_rhs = (not aminusb) if swapped else aminusb
  operation = "a-b" if lhs_minus_rhs else "b-a"
  return _set_operation(lhs, rhs, operation, validate_indices)
@tf_export("sets.union", v1=["sets.union", "sets.set_union"])
@dispatch.add_dispatch_support
def set_union(a, b, validate_indices=True):
  """Unite the sets stored in the last dimension of `a` and `b`.

  `a` and `b` must agree in every dimension except the last.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # [[{1, 2}, {3}], [{4}, {5, 6}]]
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.sparse.SparseTensor(list(a.keys()), list(a.values()),
                               dense_shape=[2, 2, 2])

    # [[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]]
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 3),
        ((0, 1, 0), 2),
        ((1, 0, 0), 4),
        ((1, 0, 1), 5),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.sparse.SparseTensor(list(b.keys()), list(b.values()),
                               dense_shape=[2, 2, 4])

    # `tf.sets.union` operates on each aligned pair of sets.
    tf.sets.union(a, b)

    # The result is equivalent to either of:
    #
    # np.array([[{1, 2, 3}, {2, 3}], [{4, 5}, {5, 6, 7, 8}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 1),
    #     ((0, 0, 1), 2),
    #     ((0, 0, 2), 3),
    #     ((0, 1, 0), 2),
    #     ((0, 1, 1), 3),
    #     ((1, 0, 0), 4),
    #     ((1, 0, 1), 5),
    #     ((1, 1, 0), 5),
    #     ((1, 1, 1), 6),
    #     ((1, 1, 2), 7),
    #     ((1, 1, 3), 8),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` with the same dtype as `b`. Sparse indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` with the same dtype as `a`. Sparse indices
      must be sorted in row-major order.
    validate_indices: Whether to validate the order and range of the sparse
      indices in `a` and `b`.

  Returns:
    A `SparseTensor` of the same rank as `a` and `b`, matching them in all but
    the last dimension. The last dimension holds the unions.

  Raises:
    TypeError: If `a` has an unsupported dtype, or if the dtypes of `a` and `b`
      differ.
  """
  # The swap flag is not needed: this function passes the same operation name
  # regardless of input order.
  lhs, rhs, _ = _convert_to_tensors_or_sparse_tensors(a, b)
  return _set_operation(lhs, rhs, "union", validate_indices)