Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_operators.py: 84%

58 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Operator overloads for `RaggedTensor`.""" 

16 

17from tensorflow.python.framework import ops 

18from tensorflow.python.ops import math_ops 

19from tensorflow.python.ops.ragged import ragged_getitem 

20from tensorflow.python.ops.ragged import ragged_tensor 

21from tensorflow.python.util import tf_decorator 

22 

23 

24# ============================================================================= 

25# Equality Docstring 

26# ============================================================================= 

def ragged_eq(self, other):  # pylint: disable=g-doc-args
  """Returns result of elementwise `==` or False if not broadcast-compatible.

  Compares two ragged tensors elementwise for equality if they are
  broadcast-compatible; or returns False if they are not
  [broadcast-compatible](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).

  Note that this behavior differs from `tf.math.equal`, which raises an
  exception if the two ragged tensors are not broadcast-compatible.

  For example:

  >>> rt1 = tf.ragged.constant([[1, 2], [3]])
  >>> rt1 == rt1
  <tf.RaggedTensor [[True, True], [True]]>

  >>> rt2 = tf.ragged.constant([[1, 2], [4]])
  >>> rt1 == rt2
  <tf.RaggedTensor [[True, True], [False]]>

  >>> rt3 = tf.ragged.constant([[1, 2], [3, 4]])
  >>> # rt1 and rt3 are not broadcast-compatible.
  >>> rt1 == rt3
  False

  >>> # You can also compare a `tf.RaggedTensor` to a `tf.Tensor`.
  >>> t = tf.constant([[1, 2], [3, 4]])
  >>> rt1 == t
  False
  >>> t == rt1
  False
  >>> rt4 = tf.ragged.constant([[1, 2], [3, 4]])
  >>> rt4 == t
  <tf.RaggedTensor [[True, True], [True, True]]>
  >>> t == rt4
  <tf.RaggedTensor [[True, True], [True, True]]>

  Args:
    other: The right-hand side of the `==` operator.

  Returns:
    The ragged tensor result of the elementwise `==` operation, or `False` if
    the arguments are not broadcast-compatible.
  """
  # Delegate to the same `==` implementation that `math_ops` provides for
  # dense tensors; this function mainly exists to attach the ragged docstring.
  return math_ops.tensor_equals(self, other)

72 

73 

74# ============================================================================= 

75# Ordering Docstring 

76# ============================================================================= 

def ragged_ge(self, other):  # pylint: disable=g-doc-args
  """Elementwise `>=` comparison of two convertible-to-ragged-tensor values.

  Computes the elementwise `>=` comparison of two values that are convertible
  to ragged tensors, with
  [broadcasting](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
  support. Raises an exception if the two values are not broadcast-compatible.

  For example:

  >>> rt1 = tf.ragged.constant([[1, 2], [3]])
  >>> rt1 >= rt1
  <tf.RaggedTensor [[True, True], [True]]>

  >>> rt2 = tf.ragged.constant([[2, 1], [3]])
  >>> rt1 >= rt2
  <tf.RaggedTensor [[False, True], [True]]>

  >>> rt3 = tf.ragged.constant([[1, 2], [3, 4]])
  >>> # rt1 and rt3 are not broadcast-compatible.
  >>> rt1 >= rt3
  Traceback (most recent call last):
  ...
  InvalidArgumentError: ...

  >>> # You can also compare a `tf.RaggedTensor` to a `tf.Tensor`.
  >>> rt4 = tf.ragged.constant([[1, 2],[3, 4]])
  >>> t1 = tf.constant([[2, 1], [4, 3]])
  >>> rt4 >= t1
  <tf.RaggedTensor [[False, True],
   [False, True]]>
  >>> t1 >= rt4
  <tf.RaggedTensor [[True, False],
   [True, False]]>

  >>> # Compares a `tf.RaggedTensor` to a `tf.Tensor` with broadcasting.
  >>> t2 = tf.constant([[2]])
  >>> rt4 >= t2
  <tf.RaggedTensor [[False, True],
   [True, True]]>
  >>> t2 >= rt4
  <tf.RaggedTensor [[True, True],
   [False, False]]>

  Args:
    other: The right-hand side of the `>=` operator.

  Returns:
    A `tf.RaggedTensor` of dtype `tf.bool` with the shape that `self` and
    `other` broadcast to.

  Raises:
    InvalidArgumentError: If `self` and `other` are not broadcast-compatible.
  """
  # Delegate to `math_ops.greater_equal`; this wrapper exists to carry the
  # ragged-specific docstring on `RaggedTensor.__ge__`.
  return math_ops.greater_equal(self, other)

132 

133 

134# ============================================================================= 

135# Logical Docstring 

136# ============================================================================= 

137 

138 

139# ============================================================================= 

140# Arithmetic Docstring 

141# ============================================================================= 

def ragged_abs(self, name=None):  # pylint: disable=g-doc-args
  r"""Computes the absolute value of a ragged tensor.

  For a ragged tensor holding integer or floating-point values, the result is
  a ragged tensor of the same type in which every element has been replaced
  by its absolute value.

  For a ragged tensor `x` holding complex numbers, the result holds values of
  type `float32` or `float64`: the magnitude of each element of `x`. The
  magnitude of a complex number \\(a + bj\\) is \\(\sqrt{a^2 + b^2}\\).

  For example:

  >>> # real number
  >>> x = tf.ragged.constant([[-2.2, 3.2], [-4.2]])
  >>> tf.abs(x)
  <tf.RaggedTensor [[2.2, 3.2], [4.2]]>

  >>> # complex number
  >>> x = tf.ragged.constant([[-2.2 + 4.7j], [-3.2 + 5.7j], [-4.2 + 6.7j]])
  >>> tf.abs(x)
  <tf.RaggedTensor [[5.189412298131649],
   [6.536818798161687],
   [7.907591289387685]]>

  Args:
    name: A name for the operation (optional).

  Returns:
    A `RaggedTensor` with the same size and type as `x`, holding absolute
    values. For `complex64` and `complex128` inputs, the returned
    `RaggedTensor` has type `float32` and `float64` respectively.
  """
  absolute_values = math_ops.abs(self, name=name)
  return absolute_values

177 

178 

179# =========================================================================== 

def ragged_and(self, y, name=None):  # pylint: disable=g-doc-args
  r"""Returns the truth value of elementwise `x & y`.

  Logical AND function.

  `x` and `y` must either share a shape or have
  [broadcast-compatible](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
  shapes. For instance, `y` may be any of:

  - A single Python boolean. Each element of `x` is then combined with that
    single value using logical AND.
  - A `tf.Tensor` of dtype `tf.bool` with the same shape as `x`, or a
    [broadcast-compatible](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
    shape. The result is then the element-wise logical AND of `x` and `y`.
  - A `tf.RaggedTensor` of dtype `tf.bool` with the same shape as `x`, or a
    [broadcast-compatible](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
    shape. The result is then the element-wise logical AND of `x` and `y`.

  For example:

  >>> # `y` is a Python boolean
  >>> x = tf.ragged.constant([[True, False], [True]])
  >>> y = True
  >>> x & y
  <tf.RaggedTensor [[True, False], [True]]>
  >>> tf.math.logical_and(x, y)  # Equivalent of x & y
  <tf.RaggedTensor [[True, False], [True]]>
  >>> y & x
  <tf.RaggedTensor [[True, False], [True]]>
  >>> tf.math.reduce_all(x & y)  # Reduce to a scalar bool Tensor.
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  >>> # `y` is a tf.Tensor of the same shape.
  >>> x = tf.ragged.constant([[True, False], [True, False]])
  >>> y = tf.constant([[True, False], [False, True]])
  >>> x & y
  <tf.RaggedTensor [[True, False], [False, False]]>

  >>> # `y` is a tf.Tensor of a broadcast-compatible shape.
  >>> x = tf.ragged.constant([[True, False], [True]])
  >>> y = tf.constant([[True], [False]])
  >>> x & y
  <tf.RaggedTensor [[True, False], [False]]>

  >>> # `y` is a `tf.RaggedTensor` of the same shape.
  >>> x = tf.ragged.constant([[True, False], [True]])
  >>> y = tf.ragged.constant([[False, True], [True]])
  >>> x & y
  <tf.RaggedTensor [[False, False], [True]]>

  >>> # `y` is a `tf.RaggedTensor` of a broadcast-compatible shape.
  >>> x = tf.ragged.constant([[[True, True, False]], [[]], [[True, False]]])
  >>> y = tf.ragged.constant([[[True]], [[True]], [[False]]], ragged_rank=1)
  >>> x & y
  <tf.RaggedTensor [[[True, True, False]], [[]], [[False, False]]]>

  Args:
    y: A Python boolean, or a `tf.Tensor` or `tf.RaggedTensor` of dtype
      `tf.bool`.
    name: A name for the operation (optional).

  Returns:
    A `tf.RaggedTensor` of dtype `tf.bool` whose shape is the broadcast of the
    shapes of `x` and `y`.
  """
  # `name` is passed positionally, exactly as the original call did.
  conjunction = math_ops.logical_and(self, y, name)
  return conjunction

248 

249 

250# Helper Methods. 

def _right(operator):
  """Builds the reflected (right-handed) form of a binary `operator`.

  The returned callable accepts `(y, x)` and forwards to `operator(x, y)`,
  swapping the operands. This makes it suitable for reflected special
  methods such as `__radd__`. The callable is wrapped with
  `tf_decorator.make_decorator` using `operator` as the target.
  """
  def _swapped(y, x):
    return operator(x, y)

  return tf_decorator.make_decorator(operator, _swapped)

254 

255 

def ragged_hash(self):
  """Implements the `RaggedTensor.__hash__` operator.

  Raises `TypeError` when all of these hold:
  - `ops.Tensor._USE_EQUALITY` is set;
  - `ops.executing_eagerly_outside_functions()` returns true;
  - `self.row_splits` has no `graph` attribute, or its graph reports
    `building_function`.

  In every other case, the identity-based hash `id(self)` is returned.
  """
  splits_graph = getattr(self.row_splits, "graph", None)
  # pylint: disable=protected-access
  # The conditions are evaluated left to right with short-circuiting, in the
  # same order as the original single `if` test.
  refuse_hash = (
      ops.Tensor._USE_EQUALITY
      and ops.executing_eagerly_outside_functions()
      and (splits_graph is None or splits_graph.building_function))
  # pylint: enable=protected-access
  if refuse_hash:
    raise TypeError("RaggedTensor is unhashable.")
  return id(self)

265 

266 

# Install the Python operator overloads on `RaggedTensor`. Each entry pairs a
# special-method name with its implementation. Reflected ("right-hand")
# variants wrap the forward implementation with `_right`, which swaps the
# operands.
for _op_name, _op_impl in (
    # Indexing
    ("__getitem__", ragged_getitem.ragged_tensor_getitem),

    # Equality
    ("__eq__", ragged_eq),
    ("__ne__", math_ops.tensor_not_equals),
    ("__hash__", ragged_hash),

    # Ordering operators
    ("__ge__", ragged_ge),
    ("__gt__", math_ops.greater),
    ("__le__", math_ops.less_equal),
    ("__lt__", math_ops.less),

    # Logical operators
    ("__and__", ragged_and),
    ("__rand__", _right(ragged_and)),
    ("__invert__", math_ops.logical_not),
    ("__ror__", _right(math_ops.logical_or)),
    ("__or__", math_ops.logical_or),
    ("__xor__", math_ops.logical_xor),
    ("__rxor__", _right(math_ops.logical_xor)),

    # Arithmetic operators
    ("__abs__", ragged_abs),
    ("__add__", math_ops.add),
    ("__radd__", _right(math_ops.add)),
    ("__div__", math_ops.div),
    ("__rdiv__", _right(math_ops.div)),
    ("__floordiv__", math_ops.floordiv),
    ("__rfloordiv__", _right(math_ops.floordiv)),
    # Python's `%` uses floor semantics, hence `floormod`.
    ("__mod__", math_ops.floormod),
    ("__rmod__", _right(math_ops.floormod)),
    ("__mul__", math_ops.multiply),
    ("__rmul__", _right(math_ops.multiply)),
    ("__neg__", math_ops.negative),
    ("__pow__", math_ops.pow),
    ("__rpow__", _right(math_ops.pow)),
    ("__sub__", math_ops.subtract),
    ("__rsub__", _right(math_ops.subtract)),
    ("__truediv__", math_ops.truediv),
    ("__rtruediv__", _right(math_ops.truediv)),
):
  setattr(ragged_tensor.RaggedTensor, _op_name, _op_impl)
# Remove the loop variables so they do not remain as module attributes.
del _op_name, _op_impl

310 

311 

def ragged_bool(self):  # pylint: disable=g-doc-args
  """Raises TypeError when a RaggedTensor is used as a Python bool.

  To prevent a RaggedTensor from being used as a bool, this function always
  raises TypeError when called.

  For example:

  >>> x = tf.ragged.constant([[1, 2], [3]])
  >>> result = True if x else False  # Evaluate x as a bool value.
  Traceback (most recent call last):
  ...
  TypeError: RaggedTensor may not be used as a boolean.

  >>> x = tf.ragged.constant([[1]])
  >>> r = (x == 1)  # tf.RaggedTensor [[True]]
  >>> if r:  # Evaluate r as a bool value.
  ...   pass
  Traceback (most recent call last):
  ...
  TypeError: RaggedTensor may not be used as a boolean.

  Raises:
    TypeError: Always.
  """
  raise TypeError("RaggedTensor may not be used as a boolean.")

335 

336 

# Route both truth-value hooks to `ragged_bool`: `__bool__` is used by
# Python 3 and `__nonzero__` by Python 2. Chained assignment binds the targets
# left to right, so `__bool__` is set first, as before.
ragged_tensor.RaggedTensor.__bool__ = ragged_tensor.RaggedTensor.__nonzero__ = (
    ragged_bool)