Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/util.py: 14%

437 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for probability distributions.""" 

16 

17import functools 

18import hashlib 

19 

20import numpy as np 

21 

22from tensorflow.python.framework import constant_op 

23from tensorflow.python.framework import dtypes 

24from tensorflow.python.framework import ops 

25from tensorflow.python.framework import tensor_shape 

26from tensorflow.python.framework import tensor_util 

27from tensorflow.python.ops import array_ops 

28from tensorflow.python.ops import array_ops_stack 

29from tensorflow.python.ops import check_ops 

30from tensorflow.python.ops import cond as tf_cond 

31from tensorflow.python.ops import control_flow_ops 

32from tensorflow.python.ops import linalg_ops 

33from tensorflow.python.ops import math_ops 

34from tensorflow.python.ops import nn 

35from tensorflow.python.util import tf_inspect 

36 

37 

38def assert_integer_form(x, 

39 data=None, 

40 summarize=None, 

41 message=None, 

42 int_dtype=None, 

43 name="assert_integer_form"): 

44 """Assert that x has integer components (or floats equal to integers). 

45 

46 Args: 

47 x: Floating-point `Tensor` 

48 data: The tensors to print out if the condition is `False`. Defaults to 

49 error message and first few entries of `x` and its integer-cast value. 

50 summarize: Print this many entries of each tensor. 

51 message: A string to prefix to the default message. 

52 int_dtype: A `tf.dtype` used to cast the float to. The default (`None`) 

53 implies the smallest possible signed int will be used for casting. 

54 name: A name for this operation (optional). 

55 

56 Returns: 

57 Op raising `InvalidArgumentError` if `cast(x, int_dtype) != x`. 

58 """ 

59 with ops.name_scope(name, values=[x, data]): 

60 x = ops.convert_to_tensor(x, name="x") 

61 if x.dtype.is_integer: 

62 return control_flow_ops.no_op() 

63 message = message or "{} has non-integer components".format(x) 

64 if int_dtype is None: 

65 try: 

66 int_dtype = { 

67 dtypes.float16: dtypes.int16, 

68 dtypes.float32: dtypes.int32, 

69 dtypes.float64: dtypes.int64, 

70 }[x.dtype.base_dtype] 

71 except KeyError: 

72 raise TypeError("Unrecognized type {}".format(x.dtype.name)) 

73 return check_ops.assert_equal( 

74 x, 

75 math_ops.cast(math_ops.cast(x, int_dtype), x.dtype), 

76 data=data, 

77 summarize=summarize, 

78 message=message, 

79 name=name) 

80 

81 

82def assert_symmetric(matrix): 

83 matrix_t = array_ops.matrix_transpose(matrix) 

84 return control_flow_ops.with_dependencies( 

85 [check_ops.assert_equal(matrix, matrix_t)], matrix) 

86 

87 

88def embed_check_nonnegative_integer_form( 

89 x, name="embed_check_nonnegative_integer_form"): 

90 """Assert x is a non-negative tensor, and optionally of integers.""" 

91 with ops.name_scope(name, values=[x]): 

92 x = ops.convert_to_tensor(x, name="x") 

93 assertions = [ 

94 check_ops.assert_non_negative( 

95 x, message="'{}' must be non-negative.".format(x)), 

96 ] 

97 if not x.dtype.is_integer: 

98 assertions += [ 

99 assert_integer_form( 

100 x, 

101 message="'{}' cannot contain fractional components.".format(x)), 

102 ] 

103 return control_flow_ops.with_dependencies(assertions, x) 

104 

105 

106def same_dynamic_shape(a, b): 

107 """Returns whether a and b have the same dynamic shape. 

108 

109 Args: 

110 a: `Tensor` 

111 b: `Tensor` 

112 

113 Returns: 

114 `bool` `Tensor` representing if both tensors have the same shape. 

115 """ 

116 a = ops.convert_to_tensor(a, name="a") 

117 b = ops.convert_to_tensor(b, name="b") 

118 

119 # Here we can't just do math_ops.equal(a.shape, b.shape), since 

120 # static shape inference may break the equality comparison between 

121 # shape(a) and shape(b) in math_ops.equal. 

122 def all_shapes_equal(): 

123 return math_ops.reduce_all( 

124 math_ops.equal( 

125 array_ops.concat( 

126 [array_ops.shape(a), array_ops.shape(b)], 0), 

127 array_ops.concat( 

128 [array_ops.shape(b), array_ops.shape(a)], 0))) 

129 

130 # One of the shapes isn't fully defined, so we need to use the dynamic 

131 # shape. 

132 return tf_cond.cond( 

133 math_ops.equal(array_ops.rank(a), array_ops.rank(b)), 

134 all_shapes_equal, lambda: constant_op.constant(False)) 

135 
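A minimal usage sketch of the dynamic-shape comparison above, assuming eager execution and that this module is importable from the path shown in the report header as `du` (the tensor values are illustrative only):

```python
import tensorflow as tf
from tensorflow.python.ops.distributions import util as du  # path assumed from the header above

a = tf.zeros([2, 3])
b = tf.ones([2, 3])
c = tf.ones([2, 4])
du.same_dynamic_shape(a, b)             # ==> bool Tensor: True
du.same_dynamic_shape(a, c)             # ==> False (same rank, different dims)
du.same_dynamic_shape(a, tf.ones([2]))  # ==> False (rank mismatch short-circuits)
```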

136 

137def maybe_get_static_value(x, dtype=None): 

138 """Helper which tries to return a static value. 

139 

140 Given `x`, extract its value statically, optionally casting to a specific 

141 dtype. If this is not possible, None is returned. 

142 

143 Args: 

144 x: `Tensor` for which to extract a value statically. 

145 dtype: Optional dtype to cast to. 

146 

147 Returns: 

148 Statically inferred value if possible, otherwise None. 

149 """ 

150 if x is None: 

151 return x 

152 try: 

153 # This returns an np.ndarray. 

154 x_ = tensor_util.constant_value(x) 

155 except TypeError: 

156 x_ = x 

157 if x_ is None or dtype is None: 

158 return x_ 

159 return np.array(x_, dtype) 

160 
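A short sketch of the static-value extraction, reusing the assumed `tf`/`du` imports from the sketch above:

```python
import numpy as np

du.maybe_get_static_value(None)                    # ==> None
du.maybe_get_static_value(tf.constant([1, 2, 3]))  # ==> np.ndarray([1, 2, 3])
du.maybe_get_static_value(tf.constant([1, 2, 3]),
                          dtype=np.float64)        # ==> np.ndarray([1., 2., 3.])
```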

161 

162def get_logits_and_probs(logits=None, 

163 probs=None, 

164 multidimensional=False, 

165 validate_args=False, 

166 name="get_logits_and_probs", 

167 dtype=None): 

168 """Converts logit to probabilities (or vice-versa), and returns both. 

169 

170 Args: 

171 logits: Floating-point `Tensor` representing log-odds. 

172 probs: Floating-point `Tensor` representing probabilities. 

173 multidimensional: Python `bool`, default `False`. If `True`, `logits` or 

174 `probs` is treated as a `[N1, N2, ..., k]`-dimensional tensor whose last 

175 dimension indexes the `shape[-1]` classes, i.e., one logit or probability 

176 per class. 

177 validate_args: Python `bool`, default `False`. When `True`, either assert `0 

178 <= probs <= 1` (if not `multidimensional`) or that the last dimension of 

179 `probs` sums to one. 

180 name: A name for this operation (optional). 

181 dtype: `tf.DType` to prefer when converting args to `Tensor`s. 

182 

183 Returns: 

184 logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or 

185 `1`, then the corresponding entry in the returned logit will be `-Inf` and 

186 `Inf` respectively. 

187 

188 Raises: 

189 ValueError: if neither `probs` nor `logits` were passed in, or both were. 

190 """ 

191 with ops.name_scope(name, values=[probs, logits]): 

192 if (probs is None) == (logits is None): 

193 raise ValueError("Must pass probs or logits, but not both.") 

194 

195 if probs is None: 

196 logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype) 

197 if not logits.dtype.is_floating: 

198 raise TypeError("logits must have floating type.") 

199 # We can early return since we constructed probs and therefore know 

200 # they're valid. 

201 if multidimensional: 

202 if validate_args: 

203 logits = embed_check_categorical_event_shape(logits) 

204 return logits, nn.softmax(logits, name="probs") 

205 return logits, math_ops.sigmoid(logits, name="probs") 

206 

207 probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype) 

208 if not probs.dtype.is_floating: 

209 raise TypeError("probs must have floating type.") 

210 

211 if validate_args: 

212 with ops.name_scope("validate_probs"): 

213 one = constant_op.constant(1., probs.dtype) 

214 dependencies = [check_ops.assert_non_negative(probs)] 

215 if multidimensional: 

216 probs = embed_check_categorical_event_shape(probs) 

217 dependencies += [ 

218 check_ops.assert_near( 

219 math_ops.reduce_sum(probs, -1), 

220 one, 

221 message="probs does not sum to 1.") 

222 ] 

223 else: 

224 dependencies += [ 

225 check_ops.assert_less_equal( 

226 probs, one, message="probs has components greater than 1.") 

227 ] 

228 probs = control_flow_ops.with_dependencies(dependencies, probs) 

229 

230 with ops.name_scope("logits"): 

231 if multidimensional: 

232 # Here the multidimensional case is intentionally not computed in a manner 

233 # consistent with the unidimensional case; we follow the TF convention 

234 # instead. Typically, you might expect to see 

235 # logits = log(probs) - log(probs[pivot]). A side-effect of the TF 

236 # approach is that the unidimensional case implicitly handles the second 

237 # (complementary) class while the multidimensional case 

238 # explicitly keeps the pivot dimension. 

239 return math_ops.log(probs), probs 

240 return math_ops.log(probs) - math_ops.log1p(-1. * probs), probs 

241 
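A hedged usage sketch: exactly one of `logits`/`probs` is supplied and the other is derived as documented above (same assumed `tf`/`du` imports):

```python
# From logits: probs = sigmoid(logits) in the unidimensional case.
logits, probs = du.get_logits_and_probs(logits=tf.constant([0., 2.]))
# probs ==> [0.5, ~0.881]

# From probs with multidimensional=True: logits = log(probs); pivot dim kept.
logits, probs = du.get_logits_and_probs(
    probs=tf.constant([0.25, 0.75]), multidimensional=True, validate_args=True)
# logits ==> [log(0.25), log(0.75)]
```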

242 

243def _is_known_unsigned_by_dtype(dt): 

244 """Helper returning True if dtype is known to be unsigned.""" 

245 return { 

246 dtypes.bool: True, 

247 dtypes.uint8: True, 

248 dtypes.uint16: True, 

249 }.get(dt.base_dtype, False) 

250 

251 

252def _is_known_signed_by_dtype(dt): 

253 """Helper returning True if dtype is known to be signed.""" 

254 return { 

255 dtypes.float16: True, 

256 dtypes.float32: True, 

257 dtypes.float64: True, 

258 dtypes.int8: True, 

259 dtypes.int16: True, 

260 dtypes.int32: True, 

261 dtypes.int64: True, 

262 }.get(dt.base_dtype, False) 

263 

264 

265def _is_known_dtype(dt): 

266 """Helper returning True if dtype is known.""" 

267 return _is_known_unsigned_by_dtype(dt) or _is_known_signed_by_dtype(dt) 

268 

269 

270def _largest_integer_by_dtype(dt): 

271 """Helper returning the largest integer exactly representable by dtype.""" 

272 if not _is_known_dtype(dt): 

273 raise TypeError("Unrecognized dtype: {}".format(dt.name)) 

274 if dt.is_floating: 

275 return int(2**(np.finfo(dt.as_numpy_dtype).nmant + 1)) 

276 if dt.is_integer: 

277 return np.iinfo(dt.as_numpy_dtype).max 

278 if dt.base_dtype == dtypes.bool: 

279 return int(1) 

280 # We actually can't land here but keep the case for completeness. 

281 raise TypeError("Unrecognized dtype: {}".format(dt.name)) 

282 

283 

284def _smallest_integer_by_dtype(dt): 

285 """Helper returning the smallest integer exactly representable by dtype.""" 

286 if not _is_known_dtype(dt): 

287 raise TypeError("Unrecognized dtype: {}".format(dt.name)) 

288 if _is_known_unsigned_by_dtype(dt): 

289 return 0 

290 return -1 * _largest_integer_by_dtype(dt) 

291 

292 

293def _is_integer_like_by_dtype(dt): 

294 """Helper returning True if dtype.is_integer or is `bool`.""" 

295 if not _is_known_dtype(dt): 

296 raise TypeError("Unrecognized dtype: {}".format(dt.name)) 

297 return dt.is_integer or dt.base_dtype == dtypes.bool 

298 
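The exact-integer thresholds these helpers encode can be sanity-checked directly; a sketch for illustration only, since the helpers are module-private (same assumed `tf`/`du` imports):

```python
du._largest_integer_by_dtype(tf.float16)   # ==> 2048              (2**11)
du._largest_integer_by_dtype(tf.float32)   # ==> 16777216          (2**24)
du._largest_integer_by_dtype(tf.float64)   # ==> 9007199254740992  (2**53)
du._smallest_integer_by_dtype(tf.uint8)    # ==> 0 (unsigned)
du._smallest_integer_by_dtype(tf.float16)  # ==> -2048
```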

299 

300def embed_check_categorical_event_shape( 

301 categorical_param, name="embed_check_categorical_event_shape"): 

302 """Embeds checks that categorical distributions don't have too many classes. 

303 

304 A categorical-type distribution is one which, e.g., returns the class label 

305 rather than a one-hot encoding. E.g., `Categorical(probs)`. 

306 

307 Since distributions output samples in the same dtype as the parameters, we 

308 must ensure that casting doesn't lose precision. That is, the 

309 `parameter.dtype` implies a maximum number of classes. However, since shape is 

310 `int32` and categorical variables are presumed to be indexes into a `Tensor`, 

311 we must also ensure that the number of classes is no larger than the largest 

312 possible `int32` index, i.e., `2**31-1`. 

313 

314 In other words the number of classes, `K`, must satisfy the following 

315 condition: 

316 

317 ```python 

318 K <= min( 

319 int(2**31 - 1), # Largest int32 index. 

320 { 

321 dtypes.float16: int(2**11), # Largest int as a float16. 

322 dtypes.float32: int(2**24), 

323 dtypes.float64: int(2**53), 

324 }.get(categorical_param.dtype.base_dtype, 0)) 

325 ``` 

326 

327 Args: 

328 categorical_param: Floating-point `Tensor` representing parameters of 

329 distribution over categories. The rightmost shape is presumed to be the 

330 number of categories. 

331 name: A name for this operation (optional). 

332 

333 Returns: 

334 categorical_param: Input `Tensor` with appropriate assertions embedded. 

335 

336 Raises: 

337 TypeError: if `categorical_param` has an unknown `dtype`. 

338 ValueError: if we can statically identify `categorical_param` as being too 

339 large (for being closed under int32/float casting). 

340 """ 

341 with ops.name_scope(name, values=[categorical_param]): 

342 x = ops.convert_to_tensor(categorical_param, name="categorical_param") 

343 # The size must not exceed both of: 

344 # - The largest possible int32 (since categorical values are presumed to be 

345 # indexes into a Tensor). 

346 # - The largest possible integer exactly representable under the given 

347 # floating-point dtype (since we need to cast to/from). 

348 # 

349 # The chosen floating-point thresholds are 2**(1 + mantissa_bits). 

350 # For more details, see: 

351 # https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation 

352 x_dtype = x.dtype.base_dtype 

353 max_event_size = ( 

354 _largest_integer_by_dtype(x_dtype) if x_dtype.is_floating else 0) 

355 if max_event_size == 0: 

356 raise TypeError("Unable to validate size of unrecognized dtype " 

357 "({}).".format(x_dtype.name)) 

358 try: 

359 x_shape_static = x.get_shape().with_rank_at_least(1) 

360 except ValueError: 

361 raise ValueError("A categorical-distribution parameter must have " 

362 "at least 1 dimension.") 

363 if tensor_shape.dimension_value(x_shape_static[-1]) is not None: 

364 event_size = x_shape_static.dims[-1].value 

365 if event_size < 2: 

366 raise ValueError("A categorical-distribution parameter must have at " 

367 "least 2 events.") 

368 if event_size > max_event_size: 

369 raise ValueError("Number of classes exceeds `dtype` precision, i.e., " 

370 "{} implies shape ({}) cannot exceed {}.".format( 

371 x_dtype.name, event_size, max_event_size)) 

372 return x 

373 else: 

374 event_size = array_ops.shape(x, name="x_shape")[-1] 

375 return control_flow_ops.with_dependencies([ 

376 check_ops.assert_rank_at_least( 

377 x, 

378 1, 

379 message=("A categorical-distribution parameter must have " 

380 "at least 1 dimension.")), 

381 check_ops.assert_greater_equal( 

382 array_ops.shape(x)[-1], 

383 2, 

384 message=("A categorical-distribution parameter must have at " 

385 "least 2 events.")), 

386 check_ops.assert_less_equal( 

387 event_size, 

388 max_event_size, 

389 message="Number of classes exceeds `dtype` precision, " 

390 "i.e., {} dtype cannot exceed {} shape.".format( 

391 x_dtype.name, max_event_size)), 

392 ], x) 

393 

394 

395def embed_check_integer_casting_closed(x, 

396 target_dtype, 

397 assert_nonnegative=True, 

398 name="embed_check_casting_closed"): 

399 """Ensures integers remain unaffected despite casting to/from int/float types. 

400 

401 Example integer-types: `uint8`, `int32`, `bool`. 

402 Example floating-types: `float32`, `float64`. 

403 

404 The largest possible integer representable by an IEEE754 floating-point is 

405 `2**(1 + mantissa_bits)` yet the largest possible integer as an int-type is 

406 `2**(bits - 1) - 1`. This function ensures that a `Tensor` purporting to have 

407 integer-form values can be cast to some other type without loss of precision. 

408 

409 The smallest representable integer is the negative of the largest 

410 representable integer, except for types: `uint8`, `uint16`, `bool`. For these 

411 types, the smallest representable integer is `0`. 

412 

413 Args: 

414 x: `Tensor` representing integer-form values. 

415 target_dtype: TF `dtype` under which `x` should have identical values. 

416 assert_nonnegative: `bool` indicating `x` should contain nonnegative values. 

417 name: A name for this operation (optional). 

418 

419 Returns: 

420 x: Input `Tensor` with appropriate assertions embedded. 

421 

422 Raises: 

423 TypeError: if `x` is neither integer- nor floating-type. 

424 TypeError: if `target_dtype` is neither integer- nor floating-type. 

425 TypeError: if neither `x` nor `target_dtype` are integer-type. 

426 """ 

427 

428 with ops.name_scope(name, values=[x]): 

429 x = ops.convert_to_tensor(x, name="x") 

430 if (not _is_integer_like_by_dtype(x.dtype) and not x.dtype.is_floating): 

431 raise TypeError("{}.dtype must be floating- or " 

432 "integer-type.".format(x.dtype.name)) 

433 if (not _is_integer_like_by_dtype(target_dtype) and 

434 not target_dtype.is_floating): 

435 raise TypeError("target_dtype ({}) must be floating- or " 

436 "integer-type.".format(target_dtype.name)) 

437 if (not _is_integer_like_by_dtype(x.dtype) and 

438 not _is_integer_like_by_dtype(target_dtype)): 

439 raise TypeError("At least one of {}.dtype ({}) and target_dtype ({}) " 

440 "must be integer-type.".format(x, x.dtype.name, 

441 target_dtype.name)) 

442 

443 assertions = [] 

444 if assert_nonnegative: 

445 assertions += [ 

446 check_ops.assert_non_negative( 

447 x, message="Elements must be non-negative."), 

448 ] 

449 

450 if x.dtype.is_floating: 

451 # Being here means _is_integer_like_by_dtype(target_dtype) = True. 

452 # Since this check implies the magnitude check below, it alone suffices. 

453 assertions += [ 

454 assert_integer_form( 

455 x, 

456 int_dtype=target_dtype, 

457 message="Elements must be {}-equivalent.".format( 

458 target_dtype.name)), 

459 ] 

460 else: 

461 if (_largest_integer_by_dtype(x.dtype) > 

462 _largest_integer_by_dtype(target_dtype)): 

463 # Cast may lose integer precision. 

464 assertions += [ 

465 check_ops.assert_less_equal( 

466 x, 

467 _largest_integer_by_dtype(target_dtype), 

468 message=("Elements cannot exceed {}.".format( 

469 _largest_integer_by_dtype(target_dtype)))), 

470 ] 

471 if (not assert_nonnegative and (_smallest_integer_by_dtype( 

472 x.dtype) < _smallest_integer_by_dtype(target_dtype))): 

473 assertions += [ 

474 check_ops.assert_greater_equal( 

475 x, 

476 _smallest_integer_by_dtype(target_dtype), 

477 message=("Elements cannot be smaller than {}.".format( 

478 _smallest_integer_by_dtype(target_dtype)))), 

479 ] 

480 

481 if not assertions: 

482 return x 

483 return control_flow_ops.with_dependencies(assertions, x) 

484 

485 

486def log_combinations(n, counts, name="log_combinations"): 

487 """Multinomial coefficient. 

488 

489 Given `n` and `counts`, where `counts` has last dimension `k`, we compute 

490 the multinomial coefficient as: 

491 

492 ```n! / prod_i n_i!``` 

493 

494 where `i` runs over all `k` classes. 

495 

496 Args: 

497 n: Floating-point `Tensor` broadcastable with `counts`. This represents `n` 

498 outcomes. 

499 counts: Floating-point `Tensor` broadcastable with `n`. This represents 

500 counts in `k` classes, where `k` is the last dimension of the tensor. 

501 name: A name for this operation (optional). 

502 

503 Returns: 

504 `Tensor` representing the multinomial coefficient between `n` and `counts`. 

505 """ 

506 # First a bit about the number of ways counts could have come in: 

507 # E.g. if counts = [1, 2], then this is 3 choose 2. 

508 # In general, this is (sum counts)! / prod(counts!) 

509 # The sum should be along the last dimension of counts. This is the 

510 # "distribution" dimension. Here n a priori represents the sum of counts. 

511 with ops.name_scope(name, values=[n, counts]): 

512 n = ops.convert_to_tensor(n, name="n") 

513 counts = ops.convert_to_tensor(counts, name="counts") 

514 total_permutations = math_ops.lgamma(n + 1) 

515 counts_factorial = math_ops.lgamma(counts + 1) 

516 redundant_permutations = math_ops.reduce_sum(counts_factorial, axis=[-1]) 

517 return total_permutations - redundant_permutations 

518 
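A worked example: for `n = 3` outcomes split into counts `[1, 2]`, the multinomial coefficient is `3! / (1! 2!) = 3`, so the function returns `log(3)` (same assumed `tf`/`du` imports):

```python
du.log_combinations(3., tf.constant([1., 2.]))
# ==> lgamma(4) - (lgamma(2) + lgamma(3)) = log(6) - log(2) ~= 1.0986 = log(3)
```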

519 

520def matrix_diag_transform(matrix, transform=None, name=None): 

521 """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged. 

522 

523 Create a trainable covariance defined by a Cholesky factor: 

524 

525 ```python 

526 # Transform network layer into 2 x 2 array. 

527 matrix_values = tf.contrib.layers.fully_connected(activations, 4) 

528 matrix = tf.reshape(matrix_values, (batch_size, 2, 2)) 

529 

530 # Make the diagonal positive. If the upper triangle was zero, this would be a 

531 # valid Cholesky factor. 

532 chol = matrix_diag_transform(matrix, transform=tf.nn.softplus) 

533 

534 # LinearOperatorLowerTriangular ignores the upper triangle. 

535 operator = LinearOperatorLowerTriangular(chol) 

536 ``` 

537 

538 Example of heteroskedastic 2-D linear regression. 

539 

540 ```python 

541 tfd = tfp.distributions 

542 

543 # Get a trainable Cholesky factor. 

544 matrix_values = tf.contrib.layers.fully_connected(activations, 4) 

545 matrix = tf.reshape(matrix_values, (batch_size, 2, 2)) 

546 chol = matrix_diag_transform(matrix, transform=tf.nn.softplus) 

547 

548 # Get a trainable mean. 

549 mu = tf.contrib.layers.fully_connected(activations, 2) 

550 

551 # This is a fully trainable multivariate normal! 

552 dist = tfd.MultivariateNormalTriL(mu, chol) 

553 

554 # Standard log loss. Minimizing this will "train" mu and chol, and then dist 

555 # will be a distribution predicting labels as multivariate Gaussians. 

556 loss = -1 * tf.reduce_mean(dist.log_prob(labels)) 

557 ``` 

558 

559 Args: 

560 matrix: Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are 

561 equal. 

562 transform: Element-wise function mapping `Tensors` to `Tensors`. To be 

563 applied to the diagonal of `matrix`. If `None`, `matrix` is returned 

564 unchanged. Defaults to `None`. 

565 name: A name to give created ops. Defaults to "matrix_diag_transform". 

566 

567 Returns: 

568 A `Tensor` with same shape and `dtype` as `matrix`. 

569 """ 

570 with ops.name_scope(name, "matrix_diag_transform", [matrix]): 

571 matrix = ops.convert_to_tensor(matrix, name="matrix") 

572 if transform is None: 

573 return matrix 

574 # Replace the diag with transformed diag. 

575 diag = array_ops.matrix_diag_part(matrix) 

576 transformed_diag = transform(diag) 

577 transformed_mat = array_ops.matrix_set_diag(matrix, transformed_diag) 

578 

579 return transformed_mat 

580 

581 

582def rotate_transpose(x, shift, name="rotate_transpose"): 

583 """Circularly moves dims left or right. 

584 

585 Effectively identical to: 

586 

587 ```python 

588 numpy.transpose(x, numpy.roll(numpy.arange(len(x.shape)), shift)) 

589 ``` 

590 

591 When the rank of `x` or the value of `shift` is not known statically, additional 

592 graph-runtime work is performed. This entails moving data from GPU to CPU. 

593 

594 Example: 

595 

596 ```python 

597 x = tf.random.normal([1, 2, 3, 4]) # Tensor of shape [1, 2, 3, 4]. 

598 rotate_transpose(x, -1).shape == [2, 3, 4, 1] 

599 rotate_transpose(x, -2).shape == [3, 4, 1, 2] 

600 rotate_transpose(x, 1).shape == [4, 1, 2, 3] 

601 rotate_transpose(x, 2).shape == [3, 4, 1, 2] 

602 rotate_transpose(x, 7).shape == rotate_transpose(x, 3).shape # [2, 3, 4, 1] 

603 rotate_transpose(x, -7).shape == rotate_transpose(x, -3).shape # [4, 1, 2, 3] 

604 ``` 

605 

606 Args: 

607 x: `Tensor`. 

608 shift: `Tensor`. Number of dimensions to transpose left (shift<0) or 

609 transpose right (shift>0). 

610 name: Python `str`. The name to give this op. 

611 

612 Returns: 

613 rotated_x: Input `Tensor` with dimensions circularly rotated by shift. 

614 

615 Raises: 

616 TypeError: if shift is not integer type. 

617 """ 

618 with ops.name_scope(name, values=[x, shift]): 

619 x = ops.convert_to_tensor(x, name="x") 

620 shift = ops.convert_to_tensor(shift, name="shift") 

621 # We do not assign back to preserve constant-ness. 

622 check_ops.assert_integer(shift) 

623 shift_value_static = tensor_util.constant_value(shift) 

624 ndims = x.get_shape().ndims 

625 if ndims is not None and shift_value_static is not None: 

626 if ndims < 2: 

627 return x 

628 shift_value_static = np.sign(shift_value_static) * ( 

629 abs(shift_value_static) % ndims) 

630 if shift_value_static == 0: 

631 return x 

632 perm = np.roll(np.arange(ndims), shift_value_static) 

633 return array_ops.transpose(x, perm=perm) 

634 else: 

635 # Consider if we always had a positive shift, and some specified 

636 # direction. 

637 # When shifting left we want the new array: 

638 # last(x, n-shift) + first(x, shift) 

639 # and if shifting right then we want: 

640 # last(x, shift) + first(x, n-shift) 

641 # Observe that last(a) == slice(a, n) and first(a) == slice(0, a). 

642 # Also, we can encode direction and shift as one: direction * shift. 

643 # Combining these facts, we have: 

644 # a = cond(shift<0, -shift, n-shift) 

645 # last(x, n-a) + first(x, a) == x[a:n] + x[0:a] 

646 # Finally, we transform shift by modulo length so it can be specified 

647 # independently from the array upon which it operates (like python). 

648 ndims = array_ops.rank(x) 

649 shift = array_ops.where_v2( 

650 math_ops.less(shift, 0), 

651 math_ops.mod(-shift, ndims), # pylint: disable=invalid-unary-operand-type 

652 ndims - math_ops.mod(shift, ndims)) 

653 first = math_ops.range(0, shift) 

654 last = math_ops.range(shift, ndims) 

655 perm = array_ops.concat([last, first], 0) 

656 return array_ops.transpose(x, perm=perm) 

657 

658 

659def pick_vector(cond, true_vector, false_vector, name="pick_vector"): 

660 """Picks possibly different length row `Tensor`s based on condition. 

661 

662 Value `Tensor`s should have exactly one dimension. 

663 

664 If `cond` is a python Boolean or `tf.constant` then either `true_vector` or 

665 `false_vector` is immediately returned. I.e., no graph nodes are created and 

666 no validation happens. 

667 

668 Args: 

669 cond: `Tensor`. Must have `dtype=tf.bool` and be scalar. 

670 true_vector: `Tensor` of one dimension. Returned when cond is `True`. 

671 false_vector: `Tensor` of one dimension. Returned when cond is `False`. 

672 name: Python `str`. The name to give this op. 

673 

674 Example: 

675 ```python 
pick_vector(tf.less(0, 5), tf.range(10, 12), tf.range(15, 18))  # [10, 11]
pick_vector(tf.less(5, 0), tf.range(10, 12), tf.range(15, 18))  # [15, 16, 17]
``` 

676 

677 Returns: 

678 true_or_false_vector: `Tensor`. 

679 

680 Raises: 

681 TypeError: if `cond.dtype != tf.bool` 

682 TypeError: if `cond` is not a constant and 

683 `true_vector.dtype != false_vector.dtype` 

684 """ 

685 with ops.name_scope(name, values=(cond, true_vector, false_vector)): 

686 cond = ops.convert_to_tensor(cond, name="cond") 

687 if cond.dtype != dtypes.bool: 

688 raise TypeError("%s.dtype=%s which is not %s" % 

689 (cond, cond.dtype, dtypes.bool)) 

690 cond_value_static = tensor_util.constant_value(cond) 

691 if cond_value_static is not None: 

692 return true_vector if cond_value_static else false_vector 

693 true_vector = ops.convert_to_tensor(true_vector, name="true_vector") 

694 false_vector = ops.convert_to_tensor(false_vector, name="false_vector") 

695 if true_vector.dtype != false_vector.dtype: 

696 raise TypeError( 

697 "%s.dtype=%s does not match %s.dtype=%s" % 

698 (true_vector, true_vector.dtype, false_vector, false_vector.dtype)) 

699 n = array_ops.shape(true_vector)[0] 

700 return array_ops.slice( 

701 array_ops.concat([true_vector, false_vector], 0), 

702 [array_ops.where_v2(cond, 0, n)], [array_ops.where(cond, n, -1)]) 

703 

704 

705def prefer_static_broadcast_shape(shape1, 

706 shape2, 

707 name="prefer_static_broadcast_shape"): 

708 """Convenience function which statically broadcasts shape when possible. 

709 

710 Args: 

711 shape1: `1-D` integer `Tensor`. Already converted to tensor! 

712 shape2: `1-D` integer `Tensor`. Already converted to tensor! 

713 name: A string name to prepend to created ops. 

714 

715 Returns: 

716 The broadcast shape, either as `TensorShape` (if broadcast can be done 

717 statically), or as a `Tensor`. 

718 """ 

719 with ops.name_scope(name, values=[shape1, shape2]): 

720 

721 def make_shape_tensor(x): 

722 return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32) 

723 

724 def get_tensor_shape(s): 

725 if isinstance(s, tensor_shape.TensorShape): 

726 return s 

727 s_ = tensor_util.constant_value(make_shape_tensor(s)) 

728 if s_ is not None: 

729 return tensor_shape.TensorShape(s_) 

730 return None 

731 

732 def get_shape_tensor(s): 

733 if not isinstance(s, tensor_shape.TensorShape): 

734 return make_shape_tensor(s) 

735 if s.is_fully_defined(): 

736 return make_shape_tensor(s.as_list()) 

737 raise ValueError("Cannot broadcast from partially " 

738 "defined `TensorShape`.") 

739 

740 shape1_ = get_tensor_shape(shape1) 

741 shape2_ = get_tensor_shape(shape2) 

742 if shape1_ is not None and shape2_ is not None: 

743 return array_ops.broadcast_static_shape(shape1_, shape2_) 

744 

745 shape1_ = get_shape_tensor(shape1) 

746 shape2_ = get_shape_tensor(shape2) 

747 return array_ops.broadcast_dynamic_shape(shape1_, shape2_) 

748 

749 

750def prefer_static_rank(x): 

751 """Return static rank of tensor `x` if available, else `tf.rank(x)`. 

752 

753 Args: 

754 x: `Tensor` (already converted). 

755 

756 Returns: 

757 Numpy array (if static rank is obtainable), else `Tensor`. 

758 """ 

759 return prefer_static_value(array_ops.rank(x)) 

760 

761 

762def prefer_static_shape(x): 

763 """Return static shape of tensor `x` if available, else `tf.shape(x)`. 

764 

765 Args: 

766 x: `Tensor` (already converted). 

767 

768 Returns: 

769 Numpy array (if static shape is obtainable), else `Tensor`. 

770 """ 

771 return prefer_static_value(array_ops.shape(x)) 

772 

773 

774def prefer_static_value(x): 

775 """Return static value of tensor `x` if available, else `x`. 

776 

777 Args: 

778 x: `Tensor` (already converted). 

779 

780 Returns: 

781 Numpy array (if static value is obtainable), else `Tensor`. 

782 """ 

783 static_x = tensor_util.constant_value(x) 

784 if static_x is not None: 

785 return static_x 

786 return x 

787 

788 

789def gen_new_seed(seed, salt): 

790 """Generate a new seed, from the given seed and salt.""" 

791 if seed is None: 

792 return None 

793 string = (str(seed) + salt).encode("utf-8") 

794 return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF 

795 
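A behavioural sketch: the result is a deterministic 31-bit integer derived from the `(seed, salt)` pair, and `None` propagates (the salt strings below are illustrative only):

```python
du.gen_new_seed(123, salt="my_salt")   # ==> fixed int in [0, 2**31); stable across calls
du.gen_new_seed(123, salt="other")     # ==> a different fixed int
du.gen_new_seed(None, salt="my_salt")  # ==> None
```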

796 

797def fill_triangular(x, upper=False, name=None): 

798 """Creates a (batch of) triangular matrix from a vector of inputs. 

799 

800 Created matrix can be lower- or upper-triangular. (It is more efficient to 

801 create the matrix as upper or lower, rather than transpose.) 

802 

803 Triangular matrix elements are filled in a clockwise spiral. See example, 

804 below. 

805 

806 If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is 

807 `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e., 

808 `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`. 

809 

810 Example: 

811 

812 ```python 

813 fill_triangular([1, 2, 3, 4, 5, 6]) 

814 # ==> [[4, 0, 0], 

815 # [6, 5, 0], 

816 # [3, 2, 1]] 

817 

818 fill_triangular([1, 2, 3, 4, 5, 6], upper=True) 

819 # ==> [[1, 2, 3], 

820 # [0, 5, 6], 

821 # [0, 0, 4]] 

822 ``` 

823 

824 For comparison, a pure numpy version of this function can be found in 

825 `util_test.py`, function `_fill_triangular`. 

826 

827 Args: 

828 x: `Tensor` representing lower (or upper) triangular elements. 

829 upper: Python `bool` representing whether output matrix should be upper 

830 triangular (`True`) or lower triangular (`False`, default). 

831 name: Python `str`. The name to give this op. 

832 

833 Returns: 

834 tril: `Tensor` with lower (or upper) triangular elements filled from `x`. 

835 

836 Raises: 

837 ValueError: if `x` cannot be mapped to a triangular matrix. 

838 """ 

839 

840 with ops.name_scope(name, "fill_triangular", values=[x]): 

841 x = ops.convert_to_tensor(x, name="x") 

842 if tensor_shape.dimension_value( 

843 x.shape.with_rank_at_least(1)[-1]) is not None: 

844 # Formula derived by solving for n: m = n(n+1)/2. 

845 m = np.int32(x.shape.dims[-1].value) 

846 n = np.sqrt(0.25 + 2. * m) - 0.5 

847 if n != np.floor(n): 

848 raise ValueError("Input right-most shape ({}) does not " 

849 "correspond to a triangular matrix.".format(m)) 

850 n = np.int32(n) 

851 static_final_shape = x.shape[:-1].concatenate([n, n]) 

852 else: 

853 m = array_ops.shape(x)[-1] 

854 # For derivation, see above. Casting automatically lops off the 0.5, so we 

855 # omit it. We don't validate n is an integer because this has 

856 # graph-execution cost; an error will be thrown from the reshape, below. 

857 n = math_ops.cast( 

858 math_ops.sqrt(0.25 + math_ops.cast(2 * m, dtype=dtypes.float32)), 

859 dtype=dtypes.int32) 

860 static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate( 

861 [None, None]) 

862 # We now concatenate the "tail" of `x` to `x` (and reverse one of them). 

863 # 

864 # We do this based on the insight that the input `x` provides `ceil(n/2)` 

865 # rows of an `n x n` matrix, some of which will get zeroed out being on the 

866 # wrong side of the diagonal. The first row will not get zeroed out at all, 

867 # and we need `floor(n/2)` more rows, so the first is what we omit from 

868 # `x_tail`. If we then stack those `ceil(n/2)` rows with the `floor(n/2)` 

869 # rows provided by a reversed tail, it is exactly the other set of elements 

870 # of the reversed tail which will be zeroed out for being on the wrong side 

871 # of the diagonal further up/down the matrix. And, in doing-so, we've filled 

872 # the triangular matrix in a clock-wise spiral pattern. Neat! 

873 # 

874 # Try it out in numpy: 

875 # n = 3 

876 # x = np.arange(n * (n + 1) / 2) 

877 # m = x.shape[0] 

878 # n = np.int32(np.sqrt(.25 + 2 * m) - .5) 

879 # x_tail = x[(m - (n**2 - m)):] 

880 # np.concatenate([x_tail, x[::-1]], 0).reshape(n, n) # lower 

881 # # ==> array([[3, 4, 5], 

882 # [5, 4, 3], 

883 # [2, 1, 0]]) 

884 # np.concatenate([x, x_tail[::-1]], 0).reshape(n, n) # upper 

885 # # ==> array([[0, 1, 2], 

886 # [3, 4, 5], 

887 # [5, 4, 3]]) 

888 # 

889 # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't 

890 # correctly handle `m == n == 1`. Hence, we do nonnegative indexing. 

891 # Furthermore observe that: 

892 # m - (n**2 - m) 

893 # = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 + n / 2) 

894 # = 2 (n**2 / 2 + n / 2) - n**2 

895 # = n**2 + n - n**2 

896 # = n 

897 ndims = prefer_static_rank(x) 

898 if upper: 

899 x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])] 

900 else: 

901 x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])] 

902 new_shape = ( 

903 static_final_shape.as_list() if static_final_shape.is_fully_defined() 

904 else array_ops.concat([array_ops.shape(x)[:-1], [n, n]], axis=0)) 

905 x = array_ops.reshape(array_ops.concat(x_list, axis=-1), new_shape) 

906 x = array_ops.matrix_band_part( 

907 x, num_lower=(0 if upper else -1), num_upper=(-1 if upper else 0)) 

908 x.set_shape(static_final_shape) 

909 return x 

910 

911 

912def fill_triangular_inverse(x, upper=False, name=None): 

913 """Creates a vector from a (batch of) triangular matrix. 

914 

915 The vector is created from the lower-triangular or upper-triangular portion 

916 depending on the value of the parameter `upper`. 

917 

918 If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is 

919 `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`. 

920 

921 Example: 

922 

923 ```python 

924 fill_triangular_inverse( 

925 [[4, 0, 0], 

926 [6, 5, 0], 

927 [3, 2, 1]]) 

928 

929 # ==> [1, 2, 3, 4, 5, 6] 

930 

931 fill_triangular_inverse( 

932 [[1, 2, 3], 

933 [0, 5, 6], 

934 [0, 0, 4]], upper=True) 

935 

936 # ==> [1, 2, 3, 4, 5, 6] 

937 ``` 

938 

939 Args: 

940 x: `Tensor` representing lower (or upper) triangular elements. 

941 upper: Python `bool` representing whether output matrix should be upper 

942 triangular (`True`) or lower triangular (`False`, default). 

943 name: Python `str`. The name to give this op. 

944 

945 Returns: 

946 flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower 

947 (or upper) triangular elements from `x`. 

948 """ 

949 

950 with ops.name_scope(name, "fill_triangular_inverse", values=[x]): 

951 x = ops.convert_to_tensor(x, name="x") 

952 if tensor_shape.dimension_value( 

953 x.shape.with_rank_at_least(2)[-1]) is not None: 

954 n = np.int32(x.shape.dims[-1].value) 

955 m = np.int32((n * (n + 1)) // 2) 

956 static_final_shape = x.shape[:-2].concatenate([m]) 

957 else: 

958 n = array_ops.shape(x)[-1] 

959 m = (n * (n + 1)) // 2 

960 static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate( 

961 [None]) 

962 ndims = prefer_static_rank(x) 

963 if upper: 

964 initial_elements = x[..., 0, :] 

965 triangular_portion = x[..., 1:, :] 

966 else: 

967 initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2]) 

968 triangular_portion = x[..., :-1, :] 

969 rotated_triangular_portion = array_ops.reverse( 

970 array_ops.reverse(triangular_portion, axis=[ndims - 1]), 

971 axis=[ndims - 2]) 

972 consolidated_matrix = triangular_portion + rotated_triangular_portion 

973 end_sequence = array_ops.reshape( 

974 consolidated_matrix, 

975 array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0)) 

976 y = array_ops.concat([initial_elements, end_sequence[..., :m - n]], axis=-1) 

977 y.set_shape(static_final_shape) 

978 return y 

979 

980 

981def tridiag(below=None, diag=None, above=None, name=None): 

982 """Creates a matrix with values set above, below, and on the diagonal. 

983 

984 Example: 

985 

986 ```python 

987 tridiag(below=[1., 2., 3.], 

988 diag=[4., 5., 6., 7.], 

989 above=[8., 9., 10.]) 

990 # ==> array([[ 4., 8., 0., 0.], 

991 # [ 1., 5., 9., 0.], 

992 # [ 0., 2., 6., 10.], 

993 # [ 0., 0., 3., 7.]], dtype=float32) 

994 ``` 

995 

996 Warning: This Op is intended for convenience, not efficiency. 

997 

998 Args: 

999 below: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the below 

1000 diagonal part. `None` is logically equivalent to `below = 0`. 

1001 diag: `Tensor` of shape `[B1, ..., Bb, d]` corresponding to the diagonal 

1002 part. `None` is logically equivalent to `diag = 0`. 

1003 above: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the above 

1004 diagonal part. `None` is logically equivalent to `above = 0`. 

1005 name: Python `str`. The name to give this op. 

1006 

1007 Returns: 

1008 tridiag: `Tensor` with values set above, below and on the diagonal. 

1009 

1010 Raises: 

1011 ValueError: if all inputs are `None`. 

1012 """ 

1013 

1014 def _pad(x): 

1015 """Prepends and appends a zero to every vector in a batch of vectors.""" 

1016 shape = array_ops.concat([array_ops.shape(x)[:-1], [1]], axis=0) 

1017 z = array_ops.zeros(shape, dtype=x.dtype) 

1018 return array_ops.concat([z, x, z], axis=-1) 

1019 

1020 def _add(*x): 

1021 """Adds list of Tensors, ignoring `None`.""" 

1022 s = None 

1023 for y in x: 

1024 if y is None: 

1025 continue 

1026 elif s is None: 

1027 s = y 

1028 else: 

1029 s += y 

1030 if s is None: 

1031 raise ValueError("Must specify at least one of `below`, `diag`, `above`.") 

1032 return s 

1033 

1034 with ops.name_scope(name, "tridiag", [below, diag, above]): 

1035 if below is not None: 

1036 below = ops.convert_to_tensor(below, name="below") 

1037 below = array_ops.matrix_diag(_pad(below))[..., :-1, 1:] 

1038 if diag is not None: 

1039 diag = ops.convert_to_tensor(diag, name="diag") 

1040 diag = array_ops.matrix_diag(diag) 

1041 if above is not None: 

1042 above = ops.convert_to_tensor(above, name="above") 

1043 above = array_ops.matrix_diag(_pad(above))[..., 1:, :-1] 

1044 # TODO(jvdillon): Consider using scatter_nd instead of creating three full 

1045 # matrices. 

1046 return _add(below, diag, above) 

1047 

1048 

1049def reduce_weighted_logsumexp(logx, 

1050 w=None, 

1051 axis=None, 

1052 keep_dims=False, 

1053 return_sign=False, 

1054 name=None): 

1055 """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`. 

1056 

1057 If all weights `w` are known to be positive, it is more efficient to use 

1058 `reduce_logsumexp` directly, i.e., prefer 

1059 `tf.reduce_logsumexp(logx + tf.math.log(w))` over 

1060 `du.reduce_weighted_logsumexp(logx, w)`. 

1061 

1062 Reduces `input_tensor` along the dimensions given in `axis`. 

1063 Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each 

1064 entry in `axis`. If `keep_dims` is true, the reduced dimensions 

1065 are retained with length 1. 

1066 

1067 If `axis` has no entries, all dimensions are reduced, and a 

1068 tensor with a single element is returned. 

1069 

1070 This function is more numerically stable than log(sum(w * exp(input))). It 

1071 avoids overflows caused by taking the exp of large inputs and underflows 

1072 caused by taking the log of small inputs. 

1073 

1074 For example: 

1075 

1076 ```python 

1077 x = tf.constant([[0., 0, 0], 

1078 [0, 0, 0]]) 

1079 

1080 w = tf.constant([[-1., 1, 1], 

1081 [1, 1, 1]]) 

1082 

1083 du.reduce_weighted_logsumexp(x, w) 

1084 # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4) 

1085 

1086 du.reduce_weighted_logsumexp(x, w, axis=0) 

1087 # ==> [log(-1+1), log(1+1), log(1+1)] 

1088 

1089 du.reduce_weighted_logsumexp(x, w, axis=1) 

1090 # ==> [log(-1+1+1), log(1+1+1)] 

1091 

1092 du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True) 

1093 # ==> [[log(-1+1+1)], [log(1+1+1)]] 

1094 

1095 du.reduce_weighted_logsumexp(x, w, axis=[0, 1]) 

1096 # ==> log(-1+5) 

1097 ``` 

1098 

1099 Args: 

1100 logx: The tensor to reduce. Should have numeric type. 

1101 w: The weight tensor. Should have numeric type identical to `logx`. 

1102 axis: The dimensions to reduce. If `None` (the default), reduces all 

1103 dimensions. Must be in the range `[-rank(input_tensor), 

1104 rank(input_tensor))`. 

1105 keep_dims: If true, retains reduced dimensions with length 1. 

1106 return_sign: If `True`, returns the sign of the result. 

1107 name: A name for the operation (optional). 

1108 

1109 Returns: 

1110 lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor. 

1111 sign: (Optional) The sign of `sum(weight * exp(x))`. 

1112 """ 

1113 with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]): 

1114 logx = ops.convert_to_tensor(logx, name="logx") 

1115 if w is None: 

1116 lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims) 

1117 if return_sign: 

1118 sgn = array_ops.ones_like(lswe) 

1119 return lswe, sgn 

1120 return lswe 

1121 w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w") 

1122 log_absw_x = logx + math_ops.log(math_ops.abs(w)) 

1123 max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True) 

1124 # If the largest element is `-inf` or `inf` then we don't bother subtracting 

1125 # off the max. We do this because otherwise we'd get `inf - inf = NaN`. That 

1126 # this is ok follows from the fact that we're actually free to subtract any 

1127 # value we like, so long as we add it back after taking the `log(sum(...))`. 

1128 max_log_absw_x = array_ops.where_v2( 

1129 math_ops.is_inf(max_log_absw_x), array_ops.zeros_like(max_log_absw_x), 

1130 max_log_absw_x) 

1131 wx_over_max_absw_x = ( 

1132 math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x)) 

1133 sum_wx_over_max_absw_x = math_ops.reduce_sum( 

1134 wx_over_max_absw_x, axis=axis, keepdims=keep_dims) 

1135 if not keep_dims: 

1136 max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis) 

1137 sgn = math_ops.sign(sum_wx_over_max_absw_x) 

1138 lswe = max_log_absw_x + math_ops.log(sgn * sum_wx_over_max_absw_x) 

1139 if return_sign: 

1140 return lswe, sgn 

1141 return lswe 

1142 

1143 

1144# TODO(jvdillon): Merge this test back into: 

1145# tensorflow/python/ops/softplus_op_test.py 

1146# once TF core is accepting new ops. 

1147def softplus_inverse(x, name=None): 

1148 """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)). 

1149 

1150 Mathematically this op is equivalent to: 

1151 

1152 ```none 

1153 softplus_inverse = log(exp(x) - 1.) 

1154 ``` 

1155 

1156 Args: 

1157 x: `Tensor`. Non-negative (not enforced), floating-point. 

1158 name: A name for the operation (optional). 

1159 

1160 Returns: 

1161 `Tensor`. Has the same type/shape as input `x`. 

1162 """ 

1163 with ops.name_scope(name, "softplus_inverse", values=[x]): 

1164 x = ops.convert_to_tensor(x, name="x") 

1165 # We begin by deriving a more numerically stable softplus_inverse: 

1166 # x = softplus(y) = Log[1 + exp{y}], (which means x > 0). 

1167 # ==> exp{x} = 1 + exp{y} (1) 

1168 # ==> y = Log[exp{x} - 1] (2) 

1169 # = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}] 

1170 # = Log[(1 - exp{-x}) / 1] + Log[exp{x}] 

1171 # = Log[1 - exp{-x}] + x (3) 

1172 # (2) is the "obvious" inverse, but (3) is more stable than (2) for large x. 

1173 # For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x} will 

1174 # be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0. 

1175 # 

1176 # In addition to the numerically stable derivation above, we clamp 

1177 # small/large values to be congruent with the logic in: 

1178 # tensorflow/core/kernels/softplus_op.h 

1179 # 

1180 # Finally, we set the input to one whenever the input is too large or too 

1181 # small. This ensures that no unchosen codepath is +/- inf. This is 

1182 # necessary to ensure the gradient doesn't get NaNs. Recall that the 

1183 # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false` 

1184 # thus an `inf` in an unselected path results in `0*inf=nan`. We are careful 

1185 # to overwrite `x` with ones only when we will never actually use this 

1186 # value. Note that we use ones and not zeros since `log(expm1(0.)) = -inf`. 

1187 threshold = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2. 

1188 is_too_small = math_ops.less(x, np.exp(threshold)) 

1189 is_too_large = math_ops.greater(x, -threshold) 

1190 too_small_value = math_ops.log(x) 

1191 too_large_value = x 

1192 # This `where` will ultimately be a NOP because we won't select this 

1193 # codepath whenever we used the surrogate `ones_like`. 

1194 x = array_ops.where_v2( 

1195 math_ops.logical_or(is_too_small, is_too_large), array_ops.ones_like(x), 

1196 x) 

1197 y = x + math_ops.log(-math_ops.expm1(-x)) # == log(expm1(x)) 

1198 return array_ops.where_v2( 

1199 is_too_small, too_small_value, 

1200 array_ops.where_v2(is_too_large, too_large_value, y)) 

1201 
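A round-trip sketch: applying `softplus_inverse` after `softplus` recovers the input to floating-point tolerance, including inputs where the naive `log(exp(x) - 1.)` would lose precision (same assumed `tf`/`du` imports):

```python
x = tf.constant([-5., 0.1, 10.])
du.softplus_inverse(tf.nn.softplus(x))
# ==> approximately [-5., 0.1, 10.]
```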

1202 

1203# TODO(b/35290280): Add unit-tests. 

1204def dimension_size(x, axis): 

1205 """Returns the size of a specific dimension.""" 

1206 # Since tf.gather isn't "constant-in, constant-out", we must first check the 

1207 # static shape or fallback to dynamic shape. 

1208 s = tensor_shape.dimension_value( 

1209 x.shape.with_rank_at_least(np.abs(axis))[axis]) 

1210 if s is not None: 

1211 return s 

1212 return array_ops.shape(x)[axis] 

1213 

1214 

1215def process_quadrature_grid_and_probs(quadrature_grid_and_probs, 

1216 dtype, 

1217 validate_args, 

1218 name=None): 

1219 """Validates quadrature grid, probs or computes them as necessary. 

1220 

1221 Args: 

1222 quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s 

1223 representing the sample points and the corresponding (possibly 

1224 normalized) weight. When `None`, defaults to: 

1225 `np.polynomial.hermite.hermgauss(deg=8)`. 

1226 dtype: The expected `dtype` of `grid` and `probs`. 

1227 validate_args: Python `bool`, default `False`. When `True` distribution 

1228 parameters are checked for validity despite possibly degrading runtime 

1229 performance. When `False` invalid inputs may silently render incorrect 

1230 outputs. 

1231 name: Python `str` name prefixed to Ops created by this class. 

1232 

1233 Returns: 

1234 quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s 

1235 representing the sample points and the corresponding (possibly 

1236 normalized) weight. 

1237 

1238 Raises: 

1239 ValueError: if `quadrature_grid_and_probs is not None` and 

1240 `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])` 

1241 """ 

1242 with ops.name_scope(name, "process_quadrature_grid_and_probs", 

1243 [quadrature_grid_and_probs]): 

1244 if quadrature_grid_and_probs is None: 

1245 grid, probs = np.polynomial.hermite.hermgauss(deg=8) 

1246 grid = grid.astype(dtype.as_numpy_dtype) 

1247 probs = probs.astype(dtype.as_numpy_dtype) 

1248 probs /= np.linalg.norm(probs, ord=1, keepdims=True) 

1249 grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype) 

1250 probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype) 

1251 return grid, probs 

1252 

1253 grid, probs = tuple(quadrature_grid_and_probs) 

1254 grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype) 

1255 probs = ops.convert_to_tensor(probs, name="unnormalized_probs", dtype=dtype) 

1256 probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs") 

1257 

1258 def _static_event_size(x): 

1259 """Returns the static size of a specific dimension or `None`.""" 

1260 return tensor_shape.dimension_value(x.shape.with_rank_at_least(1)[-1]) 

1261 

1262 m, n = _static_event_size(probs), _static_event_size(grid) 

1263 if m is not None and n is not None: 

1264 if m != n: 

1265 raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of " 

1266 "same-length zero-th-dimension `Tensor`s " 

1267 "(saw lengths {}, {})".format(m, n)) 

1268 elif validate_args: 

1269 assertions = [ 

1270 check_ops.assert_equal( 

1271 dimension_size(probs, axis=-1), 

1272 dimension_size(grid, axis=-1), 

1273 message=("`quadrature_grid_and_probs` must be a `tuple` of " 

1274 "same-length zero-th-dimension `Tensor`s")), 

1275 ] 

1276 with ops.control_dependencies(assertions): 

1277 grid = array_ops.identity(grid) 

1278 probs = array_ops.identity(probs) 

1279 return grid, probs 

1280 

1281 

1282def pad(x, axis, front=False, back=False, value=0, count=1, name=None): 

1283 """Pads `value` to the front and/or back of a `Tensor` dim, `count` times. 

1284 

1285 Args: 

1286 x: `Tensor` input. 

1287 axis: Scalar `int`-like `Tensor` representing the single dimension to pad. 

1288 (Negative indexing is supported.) 

1289 front: Python `bool`; if `True` the beginning of the `axis` dimension is 

1290 padded with `value`, `count` times. If `False` no front padding is made. 

1291 back: Python `bool`; if `True` the end of the `axis` dimension is padded 

1292 with `value`, `count` times. If `False` no end padding is made. 

1293 value: Scalar `int`-like `Tensor` representing the actual value added to the 

1294 front and/or back of the `axis` dimension of `x`. 

1295 count: Scalar `int`-like `Tensor` representing number of elements added to 

1296 the front and/or back of the `axis` dimension of `x`. E.g., if `front = 

1297 back = True` then `2 * count` elements are added. 

1298 name: Python `str` name prefixed to Ops created by this function. 

1299 

1300 Returns: 

1301 pad: The padded version of input `x`. 

1302 

1303 Raises: 

1304 ValueError: if both `front` and `back` are `False`. 

1305 TypeError: if `count` is not `int`-like. 

1306 """ 

1307 with ops.name_scope(name, "pad", [x, value, count]): 

1308 x = ops.convert_to_tensor(x, name="x") 

1309 value = ops.convert_to_tensor(value, dtype=x.dtype, name="value") 

1310 count = ops.convert_to_tensor(count, name="count") 

1311 if not count.dtype.is_integer: 

1312 raise TypeError("`count.dtype` (`{}`) must be `int`-like.".format( 

1313 count.dtype.name)) 

1314 if not front and not back: 

1315 raise ValueError("At least one of `front`, `back` must be `True`.") 

1316 ndims = ( 

1317 x.shape.ndims if x.shape.ndims is not None else array_ops.rank( 

1318 x, name="ndims")) 

1319 axis = ops.convert_to_tensor(axis, name="axis") 

1320 axis_ = tensor_util.constant_value(axis) 

1321 if axis_ is not None: 

1322 axis = axis_ 

1323 if axis < 0: 

1324 axis = ndims + axis 

1325 count_ = tensor_util.constant_value(count) 

1326 if axis_ >= 0 or x.shape.ndims is not None: 

1327 head = x.shape[:axis] 

1328 middle = tensor_shape.TensorShape(None if count_ is None else ( 

1329 tensor_shape.dimension_at_index(x.shape, axis) + count_ * 

1330 (front + back))) 

1331 tail = x.shape[axis + 1:] 

1332 final_shape = head.concatenate(middle.concatenate(tail)) 

1333 else: 

1334 final_shape = None 

1335 else: 

1336 axis = array_ops.where_v2(axis < 0, ndims + axis, axis) 

1337 final_shape = None 

1338 x = array_ops.pad( 

1339 x, 

1340 paddings=array_ops.one_hot( 

1341 indices=array_ops_stack.stack( 

1342 [axis if front else -1, axis if back else -1]), 

1343 depth=ndims, 

1344 axis=0, 

1345 on_value=count, 

1346 dtype=dtypes.int32), 

1347 constant_values=value) 

1348 if final_shape is not None: 

1349 x.set_shape(final_shape) 

1350 return x 

1351 
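A shape sketch: padding `count` elements at the back of the last axis grows that axis by `count` and leaves the rest of the shape unchanged (same assumed `tf`/`du` imports):

```python
x = tf.reshape(tf.range(6.), [2, 3])
y = du.pad(x, axis=-1, back=True, value=0., count=2)
# y.shape ==> [2, 5]; the two appended columns are filled with 0.
```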

1352 

1353def parent_frame_arguments(): 

1354 """Returns parent frame arguments. 

1355 

1356 When called inside a function, returns a dictionary with the caller's function 

1357 arguments. These are positional arguments and keyword arguments (**kwargs), 

1358 while variable arguments (*varargs) are excluded. 

1359 

1360 When called at global scope, this will return an empty dictionary, since there 

1361 are no arguments. 

1362 

1363 WARNING: If the caller's argument values have been reassigned before invoking 

1364 this method, the returned values reflect the reassigned values. For this reason, 

1365 we recommend calling `parent_frame_arguments` at the beginning of the 

1366 function. 

1367 """ 

1368 # All arguments and the names used for *varargs, and **kwargs 

1369 arg_names, variable_arg_name, keyword_arg_name, local_vars = ( 

1370 tf_inspect._inspect.getargvalues( # pylint: disable=protected-access 

1371 # Get the first frame of the caller of this method. 

1372 tf_inspect._inspect.stack()[1][0])) # pylint: disable=protected-access 

1373 

1374 # Remove the *varargs, and flatten the **kwargs. Both are 

1375 # nested lists. 

1376 local_vars.pop(variable_arg_name, {}) 

1377 keyword_args = local_vars.pop(keyword_arg_name, {}) 

1378 

1379 final_args = {} 

1380 # Copy over arguments and their values. In general, local_vars 

1381 # may contain more than just the arguments, since this method 

1382 # can be called anywhere in a function. 

1383 for arg_name in arg_names: 

1384 final_args[arg_name] = local_vars.pop(arg_name) 

1385 final_args.update(keyword_args) 

1386 

1387 return final_args 

1388 

1389 

1390class AppendDocstring: 

1391 """Helper class to promote private subclass docstring to public counterpart. 

1392 

1393 Example: 

1394 

1395 ```python 

1396 class TransformedDistribution(Distribution): 

1397 @distribution_util.AppendDocstring( 

1398 additional_note="A special note!", 

1399 kwargs_dict={"foo": "An extra arg."}) 

1400 def _prob(self, y, foo=None): 

1401 pass 

1402 ``` 

1403 

1404 In this case, the `AppendDocstring` decorator appends the `additional_note` to 

1405 the docstring of `prob` (not `_prob`) and adds a new `kwargs` 

1406 section with each dictionary item as a bullet-point. 

1407 

1408 For a more detailed example, see `TransformedDistribution`. 

1409 """ 

1410 

1411 def __init__(self, additional_note="", kwargs_dict=None): 

1412 """Initializes the AppendDocstring object. 

1413 

1414 Args: 

1415 additional_note: Python string added as additional docstring to public 

1416 version of function. 

1417 kwargs_dict: Python string/string dictionary representing specific kwargs 

1418 expanded from the **kwargs input. 

1419 

1420 Raises: 

1421 ValueError: if kwargs_dict.key contains whitespace. 

1422 ValueError: if kwargs_dict.value contains newlines. 

1423 """ 

1424 self._additional_note = additional_note 

1425 if kwargs_dict: 

1426 bullets = [] 

1427 for key in sorted(kwargs_dict.keys()): 

1428 value = kwargs_dict[key] 

1429 if any(x.isspace() for x in key): 

1430 raise ValueError("Parameter name \"%s\" contains whitespace." % key) 

1431 value = value.lstrip() 

1432 if "\n" in value: 

1433 raise ValueError( 

1434 "Parameter description for \"%s\" contains newlines." % key) 

1435 bullets.append("* `%s`: %s" % (key, value)) 

1436 self._additional_note += ("\n\n##### `kwargs`:\n\n" + "\n".join(bullets)) 

1437 

1438 def __call__(self, fn): 

1439 

1440 @functools.wraps(fn) 

1441 def _fn(*args, **kwargs): 

1442 return fn(*args, **kwargs) 

1443 

1444 if _fn.__doc__ is None: 

1445 _fn.__doc__ = self._additional_note 

1446 else: 

1447 _fn.__doc__ += "\n%s" % self._additional_note 

1448 return _fn