Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/sets_impl.py: 46%

57 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implementation of tf.sets.""" 

16 

17from tensorflow.python.framework import dtypes 

18from tensorflow.python.framework import ops 

19from tensorflow.python.framework import sparse_tensor 

20from tensorflow.python.ops import gen_set_ops 

21from tensorflow.python.util import dispatch 

22from tensorflow.python.util.tf_export import tf_export 

23 

# Element dtypes accepted by the set functions in this module; inputs whose
# base dtype is not listed here are rejected with a `TypeError`.
_VALID_DTYPES = frozenset((
    dtypes.int8,
    dtypes.int16,
    dtypes.int32,
    dtypes.int64,
    dtypes.uint8,
    dtypes.uint16,
    dtypes.string,
))

28 

29 

@tf_export("sets.size", v1=["sets.size", "sets.set_size"])
@dispatch.add_dispatch_support
def set_size(a, validate_indices=True):
  """Count the unique elements along the last dimension of `a`.

  Args:
    a: `SparseTensor` whose indices are sorted in row-major order.
    validate_indices: Whether the order and range of the sparse indices in `a`
      should be validated. Disabling validation allows undefined behavior if
      the indices turn out to be invalid.

  Returns:
    An `int32` `Tensor` holding the set sizes. If `a` has rank `n`, the result
    has rank `n-1` and shares the first `n-1` dimensions of `a`; each entry is
    the number of unique elements in the corresponding `[0...n-1]` dimension
    of `a`.

  Raises:
    TypeError: If `a` is not a `SparseTensor`, or if the dtype of its values
      is not one of the supported dtypes.
  """
  a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")

  # Only sparse input is accepted; a dense `Tensor` is rejected here.
  if not isinstance(a, sparse_tensor.SparseTensor):
    raise TypeError("Expected `SparseTensor`, got %s." % a)

  dtype_is_supported = a.values.dtype.base_dtype in _VALID_DTYPES
  if not dtype_is_supported:
    raise TypeError(
        f"Invalid dtype `{a.values.dtype}` not in supported dtypes: "
        f"`{_VALID_DTYPES}`.")

  return gen_set_ops.set_size(
      a.indices, a.values, a.dense_shape, validate_indices)

59 

60 

# Mark the set ops as having no gradient. The loop variable is deleted
# afterwards so that it does not linger in the module namespace.
for _op_type in (
    "SetSize",
    "DenseToDenseSetOperation",
    "DenseToSparseSetOperation",
    "SparseToSparseSetOperation",
):
  ops.NotDifferentiable(_op_type)
del _op_type

66 

67 

def _convert_to_tensors_or_sparse_tensors(a, b):
  """Convert both inputs to tensors and put them in a supported order.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`.
    b: `Tensor` or `SparseTensor` of the same type as `a`.

  Returns:
    A tuple `(a, b, flipped)`. `a` and `b` are the converted `Tensor` or
    `SparseTensor` inputs. `flipped` is `True` when the inputs arrived as
    sparse,dense and were swapped into dense,sparse order (the set ops do not
    support sparse,dense); otherwise it is `False` and the order is unchanged.

  Raises:
    TypeError: If `a` has an unsupported dtype, or if the base dtypes of `a`
      and `b` differ.
  """
  a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
  if a.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError(
        f"'a' has invalid dtype `{a.dtype}` not in supported dtypes: "
        f"`{_VALID_DTYPES}`.")

  # `b` only has to match `a`; `a` has already been checked against the
  # supported dtypes.
  b = sparse_tensor.convert_to_tensor_or_sparse_tensor(b, name="b")
  if b.dtype.base_dtype != a.dtype.base_dtype:
    raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))

  sparse_then_dense = (
      isinstance(a, sparse_tensor.SparseTensor) and
      not isinstance(b, sparse_tensor.SparseTensor))
  return (b, a, True) if sparse_then_dense else (a, b, False)

93 

94 

def _set_operation(a, b, set_operation, validate_indices=True):
  """Apply a set operation to the last dimension of `a` and `b`.

  All dimensions of `a` and `b` except the last must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be a
      `SparseTensor` whenever `a` is one. If sparse, indices must be sorted in
      row-major order.
    set_operation: String naming the set operation. See
      SetOperationOp::SetOperationFromContext for valid values.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and the same dimensions
    except for the last one. The elements along the last dimension hold the
    results of the set operation.

  Raises:
    TypeError: If inputs are invalid types.
    ValueError: If `a` is sparse and `b` is dense.
  """
  a_is_sparse = isinstance(a, sparse_tensor.SparseTensor)
  b_is_sparse = isinstance(b, sparse_tensor.SparseTensor)

  # There is no sparse,dense op; the caller has to pass dense,sparse instead.
  if a_is_sparse and not b_is_sparse:
    raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
                     "Please flip the order of your inputs.")

  if a_is_sparse:
    result = gen_set_ops.sparse_to_sparse_set_operation(
        a.indices, a.values, a.dense_shape, b.indices, b.values,
        b.dense_shape, set_operation, validate_indices)
  elif b_is_sparse:
    result = gen_set_ops.dense_to_sparse_set_operation(
        a, b.indices, b.values, b.dense_shape, set_operation, validate_indices)
  else:
    result = gen_set_ops.dense_to_dense_set_operation(
        a, b, set_operation, validate_indices)

  indices, values, shape = result
  return sparse_tensor.SparseTensor(indices, values, shape)

135 

136 

@tf_export(
    "sets.intersection", v1=["sets.intersection", "sets.set_intersection"])
@dispatch.add_dispatch_support
def set_intersection(a, b, validate_indices=True):
  """Intersect the sets stored in the last dimension of `a` and `b`.

  All dimensions of `a` and `b` except the last must match.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # Represent the following array of sets as a sparse tensor:
    # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.sparse.SparseTensor(list(a.keys()), list(a.values()),
                               dense_shape=[2,2,2])

    # b = np.array([[{1}, {}], [{4}, {5, 6, 7, 8}]])
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.sparse.SparseTensor(list(b.keys()), list(b.values()),
                               dense_shape=[2, 2, 4])

    # `tf.sets.intersection` is applied to each aligned pair of sets.
    tf.sets.intersection(a, b)

    # The result will be equivalent to either of:
    #
    # np.array([[{1}, {}], [{4}, {5, 6}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 1),
    #     ((1, 0, 0), 4),
    #     ((1, 1, 0), 5),
    #     ((1, 1, 1), 6),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
      must be sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and the same dimensions
    except for the last one. The elements along the last dimension hold the
    intersections.

  Raises:
    TypeError: If `a` has an unsupported dtype, or if `a` and `b` have
      different dtypes.
  """
  # The inputs may come back swapped (sparse,dense -> dense,sparse). The flag
  # is ignored because intersection does not depend on operand order.
  first, second, _ = _convert_to_tensors_or_sparse_tensors(a, b)
  return _set_operation(first, second, "intersection", validate_indices)

206 

207 

@tf_export("sets.difference", v1=["sets.difference", "sets.set_difference"])
@dispatch.add_dispatch_support
def set_difference(a, b, aminusb=True, validate_indices=True):
  """Subtract the sets stored in the last dimension of `a` and `b`.

  All dimensions of `a` and `b` except the last must match.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # Represent the following array of sets as a sparse tensor:
    # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.sparse.SparseTensor(list(a.keys()), list(a.values()),
                               dense_shape=[2, 2, 2])

    # b = np.array([[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]])
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 3),
        ((0, 1, 0), 2),
        ((1, 0, 0), 4),
        ((1, 0, 1), 5),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.sparse.SparseTensor(list(b.keys()), list(b.values()),
                               dense_shape=[2, 2, 4])

    # `set_difference` is applied to each aligned pair of sets.
    tf.sets.difference(a, b)

    # The result will be equivalent to either of:
    #
    # np.array([[{2}, {3}], [{}, {}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 2),
    #     ((0, 1, 0), 3),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
      must be sorted in row-major order.
    aminusb: Whether to subtract `b` from `a`, vs vice versa.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and the same dimensions
    except for the last one. The elements along the last dimension hold the
    differences.

  Raises:
    TypeError: If `a` has an unsupported dtype, or if `a` and `b` have
      different dtypes.
    errors_impl.InvalidArgumentError: If the shapes of `a` and `b` do not
      match in any dimension other than the last dimension.
  """
  first, second, flipped = _convert_to_tensors_or_sparse_tensors(a, b)
  # When the operands were swapped, the requested direction must be swapped
  # too, so that the result still refers to the caller's `a` and `b`.
  first_minus_second = bool(aminusb) != flipped
  operation = "a-b" if first_minus_second else "b-a"
  return _set_operation(first, second, operation, validate_indices)

287 

288 

@tf_export("sets.union", v1=["sets.union", "sets.set_union"])
@dispatch.add_dispatch_support
def set_union(a, b, validate_indices=True):
  """Unite the sets stored in the last dimension of `a` and `b`.

  All dimensions of `a` and `b` except the last must match.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # [[{1, 2}, {3}], [{4}, {5, 6}]]
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.sparse.SparseTensor(list(a.keys()), list(a.values()),
                               dense_shape=[2, 2, 2])

    # [[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]]
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 3),
        ((0, 1, 0), 2),
        ((1, 0, 0), 4),
        ((1, 0, 1), 5),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.sparse.SparseTensor(list(b.keys()), list(b.values()),
                               dense_shape=[2, 2, 4])

    # `set_union` is applied to each aligned pair of sets.
    tf.sets.union(a, b)

    # The result will be equivalent to either of:
    #
    # np.array([[{1, 2, 3}, {2, 3}], [{4, 5}, {5, 6, 7, 8}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 1),
    #     ((0, 0, 1), 2),
    #     ((0, 0, 2), 3),
    #     ((0, 1, 0), 2),
    #     ((0, 1, 1), 3),
    #     ((1, 0, 0), 4),
    #     ((1, 0, 1), 5),
    #     ((1, 1, 0), 5),
    #     ((1, 1, 1), 6),
    #     ((1, 1, 2), 7),
    #     ((1, 1, 3), 8),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
      must be sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and the same dimensions
    except for the last one. The elements along the last dimension hold the
    unions.

  Raises:
    TypeError: If `a` has an unsupported dtype, or if `a` and `b` have
      different dtypes.
  """
  # The inputs may come back swapped (sparse,dense -> dense,sparse). The flag
  # is ignored because union does not depend on operand order.
  first, second, _ = _convert_to_tensors_or_sparse_tensors(a, b)
  return _set_operation(first, second, "union", validate_indices)