Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/util.py: 14%
437 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for probability distributions."""
17import functools
18import hashlib
20import numpy as np
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.framework import tensor_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import array_ops_stack
29from tensorflow.python.ops import check_ops
30from tensorflow.python.ops import cond as tf_cond
31from tensorflow.python.ops import control_flow_ops
32from tensorflow.python.ops import linalg_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops import nn
35from tensorflow.python.util import tf_inspect
38def assert_integer_form(x,
39 data=None,
40 summarize=None,
41 message=None,
42 int_dtype=None,
43 name="assert_integer_form"):
44 """Assert that x has integer components (or floats equal to integers).
46 Args:
47 x: Floating-point `Tensor`
48 data: The tensors to print out if the condition is `False`. Defaults to
49 error message and first few entries of `x`.
50 summarize: Print this many entries of each tensor.
51 message: A string to prefix to the default message.
52 int_dtype: A `tf.dtype` used to cast the float to. The default (`None`)
53 implies the smallest possible signed int will be used for casting.
54 name: A name for this operation (optional).
56 Returns:
57 Op raising `InvalidArgumentError` if `cast(x, int_dtype) != x`.
58 """
59 with ops.name_scope(name, values=[x, data]):
60 x = ops.convert_to_tensor(x, name="x")
61 if x.dtype.is_integer:
62 return control_flow_ops.no_op()
63 message = message or "{} has non-integer components".format(x)
64 if int_dtype is None:
65 try:
66 int_dtype = {
67 dtypes.float16: dtypes.int16,
68 dtypes.float32: dtypes.int32,
69 dtypes.float64: dtypes.int64,
70 }[x.dtype.base_dtype]
71 except KeyError:
72 raise TypeError("Unrecognized type {}".format(x.dtype.name))
73 return check_ops.assert_equal(
74 x,
75 math_ops.cast(math_ops.cast(x, int_dtype), x.dtype),
76 data=data,
77 summarize=summarize,
78 message=message,
79 name=name)
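# Usage sketch (illustrative only, assuming `import tensorflow as tf`): the
# returned op passes for integer-valued floats and raises otherwise.
#   assert_integer_form(tf.constant([1., 2., 3.]))  # passes.
#   assert_integer_form(tf.constant([1.5]))         # raises InvalidArgumentError.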
82def assert_symmetric(matrix):
83 matrix_t = array_ops.matrix_transpose(matrix)
84 return control_flow_ops.with_dependencies(
85 [check_ops.assert_equal(matrix, matrix_t)], matrix)
88def embed_check_nonnegative_integer_form(
89 x, name="embed_check_nonnegative_integer_form"):
90 """Assert x is a non-negative tensor, and optionally of integers."""
91 with ops.name_scope(name, values=[x]):
92 x = ops.convert_to_tensor(x, name="x")
93 assertions = [
94 check_ops.assert_non_negative(
95 x, message="'{}' must be non-negative.".format(x)),
96 ]
97 if not x.dtype.is_integer:
98 assertions += [
99 assert_integer_form(
100 x,
101 message="'{}' cannot contain fractional components.".format(x)),
102 ]
103 return control_flow_ops.with_dependencies(assertions, x)
106def same_dynamic_shape(a, b):
107 """Returns whether a and b have the same dynamic shape.
109 Args:
110 a: `Tensor`
111 b: `Tensor`
113 Returns:
114 `bool` `Tensor` representing if both tensors have the same shape.
115 """
116 a = ops.convert_to_tensor(a, name="a")
117 b = ops.convert_to_tensor(b, name="b")
119 # Here we can't just do math_ops.equal(a.shape, b.shape), since
120 # static shape inference may break the equality comparison between
121 # shape(a) and shape(b) in math_ops.equal.
122 def all_shapes_equal():
123 return math_ops.reduce_all(
124 math_ops.equal(
125 array_ops.concat(
126 [array_ops.shape(a), array_ops.shape(b)], 0),
127 array_ops.concat(
128 [array_ops.shape(b), array_ops.shape(a)], 0)))
130 # One of the shapes isn't fully defined, so we need to use the dynamic
131 # shape.
132 return tf_cond.cond(
133 math_ops.equal(array_ops.rank(a), array_ops.rank(b)),
134 all_shapes_equal, lambda: constant_op.constant(False))
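# Illustrative sketch (assuming `import tensorflow as tf`): the comparison is
# done on the run-time shapes, so differing ranks simply yield `False`.
#   same_dynamic_shape(tf.zeros([2, 3]), tf.zeros([2, 3]))  # ==> True
#   same_dynamic_shape(tf.zeros([2, 3]), tf.zeros([3, 2]))  # ==> False
#   same_dynamic_shape(tf.zeros([2, 3]), tf.zeros([2]))     # ==> False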
137def maybe_get_static_value(x, dtype=None):
138 """Helper which tries to return a static value.
140 Given `x`, extract its value statically, optionally casting to a specific
141 dtype. If this is not possible, None is returned.
143 Args:
144 x: `Tensor` for which to extract a value statically.
145 dtype: Optional dtype to cast to.
147 Returns:
148 Statically inferred value if possible, otherwise None.
149 """
150 if x is None:
151 return x
152 try:
153 # This returns an np.ndarray.
154 x_ = tensor_util.constant_value(x)
155 except TypeError:
156 x_ = x
157 if x_ is None or dtype is None:
158 return x_
159 return np.array(x_, dtype)
162def get_logits_and_probs(logits=None,
163 probs=None,
164 multidimensional=False,
165 validate_args=False,
166 name="get_logits_and_probs",
167 dtype=None):
168 """Converts logit to probabilities (or vice-versa), and returns both.
170 Args:
171 logits: Floating-point `Tensor` representing log-odds.
172 probs: Floating-point `Tensor` representing probabilities.
173 multidimensional: Python `bool`, default `False`. If `True`, `logits` or
174 `probs` is treated as a `[N1, N2, ..., k]`-dimensional tensor whose last
175 dimension represents the logit or probability of each of `shape[-1]`
176 classes.
177 validate_args: Python `bool`, default `False`. When `True`, either assert `0
178 <= probs <= 1` (if not `multidimensional`) or that the last dimension of
179 `probs` sums to one.
180 name: A name for this operation (optional).
181 dtype: `tf.DType` to prefer when converting args to `Tensor`s.
183 Returns:
184 logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
185 `1`, then the corresponding entry in the returned logit will be `-Inf` and
186 `Inf` respectively.
188 Raises:
189 ValueError: if neither `probs` nor `logits` were passed in, or both were.
190 """
191 with ops.name_scope(name, values=[probs, logits]):
192 if (probs is None) == (logits is None):
193 raise ValueError("Must pass probs or logits, but not both.")
195 if probs is None:
196 logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype)
197 if not logits.dtype.is_floating:
198 raise TypeError("logits must having floating type.")
199 # We can early return since we constructed probs and therefore know
200 # they're valid.
201 if multidimensional:
202 if validate_args:
203 logits = embed_check_categorical_event_shape(logits)
204 return logits, nn.softmax(logits, name="probs")
205 return logits, math_ops.sigmoid(logits, name="probs")
207 probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
208 if not probs.dtype.is_floating:
209 raise TypeError("probs must having floating type.")
211 if validate_args:
212 with ops.name_scope("validate_probs"):
213 one = constant_op.constant(1., probs.dtype)
214 dependencies = [check_ops.assert_non_negative(probs)]
215 if multidimensional:
216 probs = embed_check_categorical_event_shape(probs)
217 dependencies += [
218 check_ops.assert_near(
219 math_ops.reduce_sum(probs, -1),
220 one,
221 message="probs does not sum to 1.")
222 ]
223 else:
224 dependencies += [
225 check_ops.assert_less_equal(
226 probs, one, message="probs has components greater than 1.")
227 ]
228 probs = control_flow_ops.with_dependencies(dependencies, probs)
230 with ops.name_scope("logits"):
231 if multidimensional:
232 # Here we don't compute the multidimensional case in a manner consistent
233 # with the unidimensional case. We do so following the TF convention.
234 # Typically, you might expect to see
235 # logits = log(probs) - log(probs[pivot]). A side-effect of
236 # being consistent with the TF approach is that the unidimensional case
237 # implicitly handles the second dimension but the multidimensional case
238 # explicitly keeps the pivot dimension.
239 return math_ops.log(probs), probs
240 return math_ops.log(probs) - math_ops.log1p(-1. * probs), probs
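# Worked sketch (illustrative values, not from the library docs): in the
# unidimensional case, logits and probs are related by the logit/sigmoid pair.
#   logits, probs = get_logits_and_probs(logits=[0., 1.])
#   # probs ==> sigmoid([0., 1.]) ~= [0.5, 0.7310586]
#   logits, probs = get_logits_and_probs(probs=[0.5, 0.7310586])
#   # logits ==> log(p) - log1p(-p) ~= [0., 1.]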
243def _is_known_unsigned_by_dtype(dt):
244 """Helper returning True if dtype is known to be unsigned."""
245 return {
246 dtypes.bool: True,
247 dtypes.uint8: True,
248 dtypes.uint16: True,
249 }.get(dt.base_dtype, False)
252def _is_known_signed_by_dtype(dt):
253 """Helper returning True if dtype is known to be signed."""
254 return {
255 dtypes.float16: True,
256 dtypes.float32: True,
257 dtypes.float64: True,
258 dtypes.int8: True,
259 dtypes.int16: True,
260 dtypes.int32: True,
261 dtypes.int64: True,
262 }.get(dt.base_dtype, False)
265def _is_known_dtype(dt):
266 """Helper returning True if dtype is known."""
267 return _is_known_unsigned_by_dtype(dt) or _is_known_signed_by_dtype(dt)
270def _largest_integer_by_dtype(dt):
271 """Helper returning the largest integer exactly representable by dtype."""
272 if not _is_known_dtype(dt):
273 raise TypeError("Unrecognized dtype: {}".format(dt.name))
274 if dt.is_floating:
275 return int(2**(np.finfo(dt.as_numpy_dtype).nmant + 1))
276 if dt.is_integer:
277 return np.iinfo(dt.as_numpy_dtype).max
278 if dt.base_dtype == dtypes.bool:
279 return int(1)
280 # We actually can't land here but keep the case for completeness.
281 raise TypeError("Unrecognized dtype: {}".format(dt.name))
284def _smallest_integer_by_dtype(dt):
285 """Helper returning the smallest integer exactly representable by dtype."""
286 if not _is_known_dtype(dt):
287 raise TypeError("Unrecognized dtype: {}".format(dt.name))
288 if _is_known_unsigned_by_dtype(dt):
289 return 0
290 return -1 * _largest_integer_by_dtype(dt)
293def _is_integer_like_by_dtype(dt):
294 """Helper returning True if dtype.is_integer or is `bool`."""
295 if not _is_known_dtype(dt):
296 raise TypeError("Unrecognized dtype: {}".format(dt.name))
297 return dt.is_integer or dt.base_dtype == dtypes.bool
300def embed_check_categorical_event_shape(
301 categorical_param, name="embed_check_categorical_event_shape"):
302 """Embeds checks that categorical distributions don't have too many classes.
304 A categorical-type distribution is one which, e.g., returns the class label
305 rather than a one-hot encoding. E.g., `Categorical(probs)`.
307 Since distributions output samples in the same dtype as the parameters, we
308 must ensure that casting doesn't lose precision. That is, the
309 `parameter.dtype` implies a maximum number of classes. However, since shape is
310 `int32` and categorical variables are presumed to be indexes into a `Tensor`,
311 we must also ensure that the number of classes is no larger than the largest
312 possible `int32` index, i.e., `2**31-1`.
314 In other words the number of classes, `K`, must satisfy the following
315 condition:
317 ```python
318 K <= min(
319 int(2**31 - 1), # Largest int32 index.
320 {
321 dtypes.float16: int(2**11), # Largest int as a float16.
322 dtypes.float32: int(2**24),
323 dtypes.float64: int(2**53),
324 }.get(categorical_param.dtype.base_dtype, 0))
325 ```
327 Args:
328 categorical_param: Floating-point `Tensor` representing parameters of
329 distribution over categories. The rightmost shape is presumed to be the
330 number of categories.
331 name: A name for this operation (optional).
333 Returns:
334 categorical_param: Input `Tensor` with appropriate assertions embedded.
336 Raises:
337 TypeError: if `categorical_param` has an unknown `dtype`.
338 ValueError: if we can statically identify `categorical_param` as being too
339 large (for being closed under int32/float casting).
340 """
341 with ops.name_scope(name, values=[categorical_param]):
342 x = ops.convert_to_tensor(categorical_param, name="categorical_param")
343 # The size must not exceed both of:
344 # - The largest possible int32 (since categorical values are presumed to be
345 # indexes into a Tensor).
346 # - The largest possible integer exactly representable under the given
347 # floating-point dtype (since we need to cast to/from).
348 #
349 # The chosen floating-point thresholds are 2**(1 + mantissa_bits).
350 # For more details, see:
351 # https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation
352 x_dtype = x.dtype.base_dtype
353 max_event_size = (
354 _largest_integer_by_dtype(x_dtype) if x_dtype.is_floating else 0)
355 if max_event_size == 0:
356 raise TypeError("Unable to validate size of unrecognized dtype "
357 "({}).".format(x_dtype.name))
358 try:
359 x_shape_static = x.get_shape().with_rank_at_least(1)
360 except ValueError:
361 raise ValueError("A categorical-distribution parameter must have "
362 "at least 1 dimension.")
363 if tensor_shape.dimension_value(x_shape_static[-1]) is not None:
364 event_size = x_shape_static.dims[-1].value
365 if event_size < 2:
366 raise ValueError("A categorical-distribution parameter must have at "
367 "least 2 events.")
368 if event_size > max_event_size:
369 raise ValueError("Number of classes exceeds `dtype` precision, i.e., "
370 "{} implies shape ({}) cannot exceed {}.".format(
371 x_dtype.name, event_size, max_event_size))
372 return x
373 else:
374 event_size = array_ops.shape(x, name="x_shape")[-1]
375 return control_flow_ops.with_dependencies([
376 check_ops.assert_rank_at_least(
377 x,
378 1,
379 message=("A categorical-distribution parameter must have "
380 "at least 1 dimension.")),
381 check_ops.assert_greater_equal(
382 array_ops.shape(x)[-1],
383 2,
384 message=("A categorical-distribution parameter must have at "
385 "least 2 events.")),
386 check_ops.assert_less_equal(
387 event_size,
388 max_event_size,
389 message="Number of classes exceeds `dtype` precision, "
390 "i.e., {} dtype cannot exceed {} shape.".format(
391 x_dtype.name, max_event_size)),
392 ], x)
395def embed_check_integer_casting_closed(x,
396 target_dtype,
397 assert_nonnegative=True,
398 name="embed_check_casting_closed"):
399 """Ensures integers remain unaffected despite casting to/from int/float types.
401 Example integer-types: `uint8`, `int32`, `bool`.
402 Example floating-types: `float32`, `float64`.
404 The largest possible integer representable by an IEEE754 floating-point is
405 `2**(1 + mantissa_bits)` yet the largest possible integer as an int-type is
406 `2**(bits - 1) - 1`. This function ensures that a `Tensor` purporting to have
407 integer-form values can be cast to some other type without loss of precision.
409 The smallest representable integer is the negative of the largest
410 representable integer, except for types: `uint8`, `uint16`, `bool`. For these
411 types, the smallest representable integer is `0`.
413 Args:
414 x: `Tensor` representing integer-form values.
415 target_dtype: TF `dtype` under which `x` should have identical values.
416 assert_nonnegative: `bool` indicating `x` should contain nonnegative values.
417 name: A name for this operation (optional).
419 Returns:
420 x: Input `Tensor` with appropriate assertions embedded.
422 Raises:
423 TypeError: if `x` is neither integer- nor floating-type.
424 TypeError: if `target_dtype` is neither integer- nor floating-type.
425 TypeError: if neither `x` nor `target_dtype` are integer-type.
426 """
428 with ops.name_scope(name, values=[x]):
429 x = ops.convert_to_tensor(x, name="x")
430 if (not _is_integer_like_by_dtype(x.dtype) and not x.dtype.is_floating):
431 raise TypeError("{}.dtype must be floating- or "
432 "integer-type.".format(x.dtype.name))
433 if (not _is_integer_like_by_dtype(target_dtype) and
434 not target_dtype.is_floating):
435 raise TypeError("target_dtype ({}) must be floating- or "
436 "integer-type.".format(target_dtype.name))
437 if (not _is_integer_like_by_dtype(x.dtype) and
438 not _is_integer_like_by_dtype(target_dtype)):
439 raise TypeError("At least one of {}.dtype ({}) and target_dtype ({}) "
440 "must be integer-type.".format(x, x.dtype.name,
441 target_dtype.name))
443 assertions = []
444 if assert_nonnegative:
445 assertions += [
446 check_ops.assert_non_negative(
447 x, message="Elements must be non-negative."),
448 ]
450 if x.dtype.is_floating:
451 # Being here means _is_integer_like_by_dtype(target_dtype) = True.
452 # Since this check implies the magnitude check below, it is the only one we need.
453 assertions += [
454 assert_integer_form(
455 x,
456 int_dtype=target_dtype,
457 message="Elements must be {}-equivalent.".format(
458 target_dtype.name)),
459 ]
460 else:
461 if (_largest_integer_by_dtype(x.dtype) >
462 _largest_integer_by_dtype(target_dtype)):
463 # Cast may lose integer precision.
464 assertions += [
465 check_ops.assert_less_equal(
466 x,
467 _largest_integer_by_dtype(target_dtype),
468 message=("Elements cannot exceed {}.".format(
469 _largest_integer_by_dtype(target_dtype)))),
470 ]
471 if (not assert_nonnegative and (_smallest_integer_by_dtype(
472 x.dtype) < _smallest_integer_by_dtype(target_dtype))):
473 assertions += [
474 check_ops.assert_greater_equal(
475 x,
476 _smallest_integer_by_dtype(target_dtype),
477 message=("Elements cannot be smaller than {}.".format(
478 _smallest_integer_by_dtype(target_dtype)))),
479 ]
481 if not assertions:
482 return x
483 return control_flow_ops.with_dependencies(assertions, x)
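# Minimal sketch (hypothetical inputs, graph-era semantics assumed, with
# `import tensorflow as tf`): a float32 tensor of small nonnegative integers
# embeds the checks and passes through unchanged.
#   x = embed_check_integer_casting_closed(
#       tf.constant([0., 1., 2.]), target_dtype=tf.int32)
#   # Fractional or negative entries would instead trigger an assertion.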
486def log_combinations(n, counts, name="log_combinations"):
487 """Multinomial coefficient.
489 Given `n` and `counts`, where `counts` has last dimension `k`, we compute
490 the multinomial coefficient as:
492 ```n! / prod_i n_i!```
494 where `i` runs over all `k` classes.
496 Args:
497 n: Floating-point `Tensor` broadcastable with `counts`. This represents `n`
498 outcomes.
499 counts: Floating-point `Tensor` broadcastable with `n`. This represents
500 counts in `k` classes, where `k` is the last dimension of the tensor.
501 name: A name for this operation (optional).
503 Returns:
504 `Tensor` representing the multinomial coefficient between `n` and `counts`.
505 """
506 # First a bit about the number of ways counts could have come in:
507 # E.g. if counts = [1, 2], then this is 3 choose 2.
508 # In general, this is (sum counts)! / prod(counts!)
509 # The sum should be along the last dimension of counts. This is the
510 # "distribution" dimension. Here n a priori represents the sum of counts.
511 with ops.name_scope(name, values=[n, counts]):
512 n = ops.convert_to_tensor(n, name="n")
513 counts = ops.convert_to_tensor(counts, name="counts")
514 total_permutations = math_ops.lgamma(n + 1)
515 counts_factorial = math_ops.lgamma(counts + 1)
516 redundant_permutations = math_ops.reduce_sum(counts_factorial, axis=[-1])
517 return total_permutations - redundant_permutations
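# Numeric check (plain-Python sketch): with n = 3 and counts = [1, 2] the
# multinomial coefficient is 3! / (1! * 2!) = 3, so the result is log(3).
#   import math
#   math.lgamma(3 + 1) - (math.lgamma(1 + 1) + math.lgamma(2 + 1))
#   # ==> 1.0986... == math.log(3)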
520def matrix_diag_transform(matrix, transform=None, name=None):
521 """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.
523 Create a trainable covariance defined by a Cholesky factor:
525 ```python
526 # Transform network layer into 2 x 2 array.
527 matrix_values = tf.contrib.layers.fully_connected(activations, 4)
528 matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
530 # Make the diagonal positive. If the upper triangle was zero, this would be a
531 # valid Cholesky factor.
532 chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)
534 # LinearOperatorLowerTriangular ignores the upper triangle.
535 operator = LinearOperatorLowerTriangular(chol)
536 ```
538 Example of heteroskedastic 2-D linear regression.
540 ```python
541 tfd = tfp.distributions
543 # Get a trainable Cholesky factor.
544 matrix_values = tf.contrib.layers.fully_connected(activations, 4)
545 matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
546 chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)
548 # Get a trainable mean.
549 mu = tf.contrib.layers.fully_connected(activations, 2)
551 # This is a fully trainable multivariate normal!
552 dist = tfd.MultivariateNormalTriL(mu, chol)
554 # Standard log loss. Minimizing this will "train" mu and chol, and then dist
555 # will be a distribution predicting labels as multivariate Gaussians.
556 loss = -1 * tf.reduce_mean(dist.log_prob(labels))
557 ```
559 Args:
560 matrix: Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
561 equal.
562 transform: Element-wise function mapping `Tensors` to `Tensors`. To be
563 applied to the diagonal of `matrix`. If `None`, `matrix` is returned
564 unchanged. Defaults to `None`.
565 name: A name to give created ops. Defaults to "matrix_diag_transform".
567 Returns:
568 A `Tensor` with same shape and `dtype` as `matrix`.
569 """
570 with ops.name_scope(name, "matrix_diag_transform", [matrix]):
571 matrix = ops.convert_to_tensor(matrix, name="matrix")
572 if transform is None:
573 return matrix
574 # Replace the diag with transformed diag.
575 diag = array_ops.matrix_diag_part(matrix)
576 transformed_diag = transform(diag)
577 transformed_mat = array_ops.matrix_set_diag(matrix, transformed_diag)
579 return transformed_mat
582def rotate_transpose(x, shift, name="rotate_transpose"):
583 """Circularly moves dims left or right.
585 Effectively identical to:
587 ```python
588 numpy.transpose(x, numpy.roll(numpy.arange(len(x.shape)), shift))
589 ```
591 When `validate_args=False` additional graph-runtime checks are
592 performed. These checks entail moving data from GPU to CPU.
594 Example:
596 ```python
597 x = tf.random.normal([1, 2, 3, 4]) # Tensor of shape [1, 2, 3, 4].
598 rotate_transpose(x, -1).shape == [2, 3, 4, 1]
599 rotate_transpose(x, -2).shape == [3, 4, 1, 2]
600 rotate_transpose(x, 1).shape == [4, 1, 2, 3]
601 rotate_transpose(x, 2).shape == [3, 4, 1, 2]
602 rotate_transpose(x, 7).shape == rotate_transpose(x, 3).shape # [2, 3, 4, 1]
603 rotate_transpose(x, -7).shape == rotate_transpose(x, -3).shape # [4, 1, 2, 3]
604 ```
606 Args:
607 x: `Tensor`.
608 shift: `Tensor`. Number of dimensions to transpose left (shift<0) or
609 transpose right (shift>0).
610 name: Python `str`. The name to give this op.
612 Returns:
613 rotated_x: Input `Tensor` with dimensions circularly rotated by shift.
615 Raises:
616 TypeError: if shift is not integer type.
617 """
618 with ops.name_scope(name, values=[x, shift]):
619 x = ops.convert_to_tensor(x, name="x")
620 shift = ops.convert_to_tensor(shift, name="shift")
621 # We do not assign back to preserve constant-ness.
622 check_ops.assert_integer(shift)
623 shift_value_static = tensor_util.constant_value(shift)
624 ndims = x.get_shape().ndims
625 if ndims is not None and shift_value_static is not None:
626 if ndims < 2:
627 return x
628 shift_value_static = np.sign(shift_value_static) * (
629 abs(shift_value_static) % ndims)
630 if shift_value_static == 0:
631 return x
632 perm = np.roll(np.arange(ndims), shift_value_static)
633 return array_ops.transpose(x, perm=perm)
634 else:
635 # Consider if we always had a positive shift, and some specified
636 # direction.
637 # When shifting left we want the new array:
638 # last(x, n-shift) + first(x, shift)
639 # and if shifting right then we want:
640 # last(x, shift) + first(x, n-shift)
641 # Observe that last(a) == slice(a, n) and first(a) == slice(0, a).
642 # Also, we can encode direction and shift as one: direction * shift.
643 # Combining these facts, we have:
644 # a = cond(shift<0, -shift, n-shift)
645 # last(x, n-a) + first(x, a) == x[a:n] + x[0:a]
646 # Finally, we transform shift by modulo length so it can be specified
647 # independently from the array upon which it operates (like python).
648 ndims = array_ops.rank(x)
649 shift = array_ops.where_v2(
650 math_ops.less(shift, 0),
651 math_ops.mod(-shift, ndims), # pylint: disable=invalid-unary-operand-type
652 ndims - math_ops.mod(shift, ndims))
653 first = math_ops.range(0, shift)
654 last = math_ops.range(shift, ndims)
655 perm = array_ops.concat([last, first], 0)
656 return array_ops.transpose(x, perm=perm)
659def pick_vector(cond, true_vector, false_vector, name="pick_vector"):
660 """Picks possibly different length row `Tensor`s based on condition.
662 Value `Tensor`s should have exactly one dimension.
664 If `cond` is a python Boolean or `tf.constant` then either `true_vector` or
665 `false_vector` is immediately returned. I.e., no graph nodes are created and
666 no validation happens.
668 Args:
669 cond: `Tensor`. Must have `dtype=tf.bool` and be scalar.
670 true_vector: `Tensor` of one dimension. Returned when cond is `True`.
671 false_vector: `Tensor` of one dimension. Returned when cond is `False`.
672 name: Python `str`. The name to give this op.
673 Example: `pick_vector(tf.less(0, 5), tf.range(10, 12), tf.range(15, 18))`
674 returns `[10, 11]`, while `pick_vector(tf.less(5, 0), tf.range(10, 12),
675 tf.range(15, 18))` returns `[15, 16, 17]`.
677 Returns:
678 true_or_false_vector: `Tensor`.
680 Raises:
681 TypeError: if `cond.dtype != tf.bool`
682 TypeError: if `cond` is not a constant and
683 `true_vector.dtype != false_vector.dtype`
684 """
685 with ops.name_scope(name, values=(cond, true_vector, false_vector)):
686 cond = ops.convert_to_tensor(cond, name="cond")
687 if cond.dtype != dtypes.bool:
688 raise TypeError("%s.dtype=%s which is not %s" %
689 (cond, cond.dtype, dtypes.bool))
690 cond_value_static = tensor_util.constant_value(cond)
691 if cond_value_static is not None:
692 return true_vector if cond_value_static else false_vector
693 true_vector = ops.convert_to_tensor(true_vector, name="true_vector")
694 false_vector = ops.convert_to_tensor(false_vector, name="false_vector")
695 if true_vector.dtype != false_vector.dtype:
696 raise TypeError(
697 "%s.dtype=%s does not match %s.dtype=%s" %
698 (true_vector, true_vector.dtype, false_vector, false_vector.dtype))
699 n = array_ops.shape(true_vector)[0]
700 return array_ops.slice(
701 array_ops.concat([true_vector, false_vector], 0),
702 [array_ops.where_v2(cond, 0, n)], [array_ops.where(cond, n, -1)])
705def prefer_static_broadcast_shape(shape1,
706 shape2,
707 name="prefer_static_broadcast_shape"):
708 """Convenience function which statically broadcasts shape when possible.
710 Args:
711 shape1: `1-D` integer `Tensor`. Already converted to tensor!
712 shape2: `1-D` integer `Tensor`. Already converted to tensor!
713 name: A string name to prepend to created ops.
715 Returns:
716 The broadcast shape, either as `TensorShape` (if broadcast can be done
717 statically), or as a `Tensor`.
718 """
719 with ops.name_scope(name, values=[shape1, shape2]):
721 def make_shape_tensor(x):
722 return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32)
724 def get_tensor_shape(s):
725 if isinstance(s, tensor_shape.TensorShape):
726 return s
727 s_ = tensor_util.constant_value(make_shape_tensor(s))
728 if s_ is not None:
729 return tensor_shape.TensorShape(s_)
730 return None
732 def get_shape_tensor(s):
733 if not isinstance(s, tensor_shape.TensorShape):
734 return make_shape_tensor(s)
735 if s.is_fully_defined():
736 return make_shape_tensor(s.as_list())
737 raise ValueError("Cannot broadcast from partially "
738 "defined `TensorShape`.")
740 shape1_ = get_tensor_shape(shape1)
741 shape2_ = get_tensor_shape(shape2)
742 if shape1_ is not None and shape2_ is not None:
743 return array_ops.broadcast_static_shape(shape1_, shape2_)
745 shape1_ = get_shape_tensor(shape1)
746 shape2_ = get_shape_tensor(shape2)
747 return array_ops.broadcast_dynamic_shape(shape1_, shape2_)
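# Sketch (assuming `import tensorflow as tf`): when both shapes are statically
# known the result is a `TensorShape`; otherwise a shape `Tensor` is returned.
#   prefer_static_broadcast_shape(tf.constant([2, 1]), tf.constant([2, 3]))
#   # ==> TensorShape([2, 3])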
750def prefer_static_rank(x):
751 """Return static rank of tensor `x` if available, else `tf.rank(x)`.
753 Args:
754 x: `Tensor` (already converted).
756 Returns:
757 Numpy array (if static rank is obtainable), else `Tensor`.
758 """
759 return prefer_static_value(array_ops.rank(x))
762def prefer_static_shape(x):
763 """Return static shape of tensor `x` if available, else `tf.shape(x)`.
765 Args:
766 x: `Tensor` (already converted).
768 Returns:
769 Numpy array (if static shape is obtainable), else `Tensor`.
770 """
771 return prefer_static_value(array_ops.shape(x))
774def prefer_static_value(x):
775 """Return static value of tensor `x` if available, else `x`.
777 Args:
778 x: `Tensor` (already converted).
780 Returns:
781 Numpy array (if static value is obtainable), else `Tensor`.
782 """
783 static_x = tensor_util.constant_value(x)
784 if static_x is not None:
785 return static_x
786 return x
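# Sketch of the prefer_static_* helpers (assuming `import tensorflow as tf`):
#   prefer_static_value(tf.constant(7))     # ==> numpy 7 (static value known)
#   prefer_static_rank(tf.zeros([2, 3]))    # ==> 2
#   prefer_static_shape(tf.zeros([2, 3]))   # ==> array([2, 3])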
789def gen_new_seed(seed, salt):
790 """Generate a new seed, from the given seed and salt."""
791 if seed is None:
792 return None
793 string = (str(seed) + salt).encode("utf-8")
794 return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
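# The derivation is deterministic, e.g. (illustrative):
#   gen_new_seed(123, salt="foo") == gen_new_seed(123, salt="foo")  # ==> True
#   gen_new_seed(None, salt="foo")                                  # ==> None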
797def fill_triangular(x, upper=False, name=None):
798 """Creates a (batch of) triangular matrix from a vector of inputs.
800 Created matrix can be lower- or upper-triangular. (It is more efficient to
801 create the matrix as upper or lower, rather than transpose.)
803 Triangular matrix elements are filled in a clockwise spiral. See example,
804 below.
806 If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is
807 `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
808 `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.
810 Example:
812 ```python
813 fill_triangular([1, 2, 3, 4, 5, 6])
814 # ==> [[4, 0, 0],
815 # [6, 5, 0],
816 # [3, 2, 1]]
818 fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
819 # ==> [[1, 2, 3],
820 # [0, 5, 6],
821 # [0, 0, 4]]
822 ```
824 For comparison, a pure numpy version of this function can be found in
825 `util_test.py`, function `_fill_triangular`.
827 Args:
828 x: `Tensor` representing lower (or upper) triangular elements.
829 upper: Python `bool` representing whether output matrix should be upper
830 triangular (`True`) or lower triangular (`False`, default).
831 name: Python `str`. The name to give this op.
833 Returns:
834 tril: `Tensor` with lower (or upper) triangular elements filled from `x`.
836 Raises:
837 ValueError: if `x` cannot be mapped to a triangular matrix.
838 """
840 with ops.name_scope(name, "fill_triangular", values=[x]):
841 x = ops.convert_to_tensor(x, name="x")
842 if tensor_shape.dimension_value(
843 x.shape.with_rank_at_least(1)[-1]) is not None:
844 # Formula derived by solving for n: m = n(n+1)/2.
845 m = np.int32(x.shape.dims[-1].value)
846 n = np.sqrt(0.25 + 2. * m) - 0.5
847 if n != np.floor(n):
848 raise ValueError("Input right-most shape ({}) does not "
849 "correspond to a triangular matrix.".format(m))
850 n = np.int32(n)
851 static_final_shape = x.shape[:-1].concatenate([n, n])
852 else:
853 m = array_ops.shape(x)[-1]
854 # For derivation, see above. Casting automatically lops off the 0.5, so we
855 # omit it. We don't validate n is an integer because this has
856 # graph-execution cost; an error will be thrown from the reshape, below.
857 n = math_ops.cast(
858 math_ops.sqrt(0.25 + math_ops.cast(2 * m, dtype=dtypes.float32)),
859 dtype=dtypes.int32)
860 static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate(
861 [None, None])
862 # We now concatenate the "tail" of `x` to `x` (and reverse one of them).
863 #
864 # We do this based on the insight that the input `x` provides `ceil(n/2)`
865 # rows of an `n x n` matrix, some of which will get zeroed out being on the
866 # wrong side of the diagonal. The first row will not get zeroed out at all,
867 # and we need `floor(n/2)` more rows, so the first is what we omit from
868 # `x_tail`. If we then stack those `ceil(n/2)` rows with the `floor(n/2)`
869 # rows provided by a reversed tail, it is exactly the other set of elements
870 # of the reversed tail which will be zeroed out for being on the wrong side
871 # of the diagonal further up/down the matrix. And, in doing-so, we've filled
872 # the triangular matrix in a clock-wise spiral pattern. Neat!
873 #
874 # Try it out in numpy:
875 # n = 3
876 # x = np.arange(n * (n + 1) / 2)
877 # m = x.shape[0]
878 # n = np.int32(np.sqrt(.25 + 2 * m) - .5)
879 # x_tail = x[(m - (n**2 - m)):]
880 # np.concatenate([x_tail, x[::-1]], 0).reshape(n, n) # lower
881 # # ==> array([[3, 4, 5],
882 # [5, 4, 3],
883 # [2, 1, 0]])
884 # np.concatenate([x, x_tail[::-1]], 0).reshape(n, n) # upper
885 # # ==> array([[0, 1, 2],
886 # [3, 4, 5],
887 # [5, 4, 3]])
888 #
889 # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
890 # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
891 # Furthermore observe that:
892 # m - (n**2 - m)
893 # = n**2 / 2 + n / 2 - (n**2 - (n**2 / 2 + n / 2))
894 # = 2 (n**2 / 2 + n / 2) - n**2
895 # = n**2 + n - n**2
896 # = n
897 ndims = prefer_static_rank(x)
898 if upper:
899 x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
900 else:
901 x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])]
902 new_shape = (
903 static_final_shape.as_list() if static_final_shape.is_fully_defined()
904 else array_ops.concat([array_ops.shape(x)[:-1], [n, n]], axis=0))
905 x = array_ops.reshape(array_ops.concat(x_list, axis=-1), new_shape)
906 x = array_ops.matrix_band_part(
907 x, num_lower=(0 if upper else -1), num_upper=(-1 if upper else 0))
908 x.set_shape(static_final_shape)
909 return x
912def fill_triangular_inverse(x, upper=False, name=None):
913 """Creates a vector from a (batch of) triangular matrix.
915 The vector is created from the lower-triangular or upper-triangular portion
916 depending on the value of the parameter `upper`.
918 If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
919 `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.
921 Example:
923 ```python
924 fill_triangular_inverse(
925 [[4, 0, 0],
926 [6, 5, 0],
927 [3, 2, 1]])
929 # ==> [1, 2, 3, 4, 5, 6]
931 fill_triangular_inverse(
932 [[1, 2, 3],
933 [0, 5, 6],
934 [0, 0, 4]], upper=True)
936 # ==> [1, 2, 3, 4, 5, 6]
937 ```
939 Args:
940 x: `Tensor` representing lower (or upper) triangular elements.
941 upper: Python `bool` representing whether output matrix should be upper
942 triangular (`True`) or lower triangular (`False`, default).
943 name: Python `str`. The name to give this op.
945 Returns:
946 flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
947 (or upper) triangular elements from `x`.
948 """
950 with ops.name_scope(name, "fill_triangular_inverse", values=[x]):
951 x = ops.convert_to_tensor(x, name="x")
952 if tensor_shape.dimension_value(
953 x.shape.with_rank_at_least(2)[-1]) is not None:
954 n = np.int32(x.shape.dims[-1].value)
955 m = np.int32((n * (n + 1)) // 2)
956 static_final_shape = x.shape[:-2].concatenate([m])
957 else:
958 n = array_ops.shape(x)[-1]
959 m = (n * (n + 1)) // 2
960 static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate(
961 [None])
962 ndims = prefer_static_rank(x)
963 if upper:
964 initial_elements = x[..., 0, :]
965 triangular_portion = x[..., 1:, :]
966 else:
967 initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2])
968 triangular_portion = x[..., :-1, :]
969 rotated_triangular_portion = array_ops.reverse(
970 array_ops.reverse(triangular_portion, axis=[ndims - 1]),
971 axis=[ndims - 2])
972 consolidated_matrix = triangular_portion + rotated_triangular_portion
973 end_sequence = array_ops.reshape(
974 consolidated_matrix,
975 array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0))
976 y = array_ops.concat([initial_elements, end_sequence[..., :m - n]], axis=-1)
977 y.set_shape(static_final_shape)
978 return y
981def tridiag(below=None, diag=None, above=None, name=None):
982 """Creates a matrix with values set above, below, and on the diagonal.
984 Example:
986 ```python
987 tridiag(below=[1., 2., 3.],
988 diag=[4., 5., 6., 7.],
989 above=[8., 9., 10.])
990 # ==> array([[ 4., 8., 0., 0.],
991 # [ 1., 5., 9., 0.],
992 # [ 0., 2., 6., 10.],
993 # [ 0., 0., 3., 7.]], dtype=float32)
994 ```
996 Warning: This Op is intended for convenience, not efficiency.
998 Args:
999 below: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the below
1000 diagonal part. `None` is logically equivalent to `below = 0`.
1001 diag: `Tensor` of shape `[B1, ..., Bb, d]` corresponding to the diagonal
1002 part. `None` is logically equivalent to `diag = 0`.
1003 above: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the above
1004 diagonal part. `None` is logically equivalent to `above = 0`.
1005 name: Python `str`. The name to give this op.
1007 Returns:
1008 tridiag: `Tensor` with values set above, below and on the diagonal.
1010 Raises:
1011 ValueError: if all inputs are `None`.
1012 """
1014 def _pad(x):
1015 """Prepends and appends a zero to every vector in a batch of vectors."""
1016 shape = array_ops.concat([array_ops.shape(x)[:-1], [1]], axis=0)
1017 z = array_ops.zeros(shape, dtype=x.dtype)
1018 return array_ops.concat([z, x, z], axis=-1)
1020 def _add(*x):
1021 """Adds list of Tensors, ignoring `None`."""
1022 s = None
1023 for y in x:
1024 if y is None:
1025 continue
1026 elif s is None:
1027 s = y
1028 else:
1029 s += y
1030 if s is None:
1031 raise ValueError("Must specify at least one of `below`, `diag`, `above`.")
1032 return s
1034 with ops.name_scope(name, "tridiag", [below, diag, above]):
1035 if below is not None:
1036 below = ops.convert_to_tensor(below, name="below")
1037 below = array_ops.matrix_diag(_pad(below))[..., :-1, 1:]
1038 if diag is not None:
1039 diag = ops.convert_to_tensor(diag, name="diag")
1040 diag = array_ops.matrix_diag(diag)
1041 if above is not None:
1042 above = ops.convert_to_tensor(above, name="above")
1043 above = array_ops.matrix_diag(_pad(above))[..., 1:, :-1]
1044 # TODO(jvdillon): Consider using scatter_nd instead of creating three full
1045 # matrices.
1046 return _add(below, diag, above)
1049def reduce_weighted_logsumexp(logx,
1050 w=None,
1051 axis=None,
1052 keep_dims=False,
1053 return_sign=False,
1054 name=None):
1055 """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.
1057 If all weights `w` are known to be positive, it is more efficient to directly
1058 use `reduce_logsumexp`, i.e., `tf.reduce_logsumexp(logx + tf.math.log(w))` is
1059 more
1060 efficient than `du.reduce_weighted_logsumexp(logx, w)`.
1062 Reduces `input_tensor` along the dimensions given in `axis`.
1063 Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
1064 entry in `axis`. If `keep_dims` is true, the reduced dimensions
1065 are retained with length 1.
1067 If `axis` has no entries, all dimensions are reduced, and a
1068 tensor with a single element is returned.
1070 This function is more numerically stable than log(sum(w * exp(input))). It
1071 avoids overflows caused by taking the exp of large inputs and underflows
1072 caused by taking the log of small inputs.
1074 For example:
1076 ```python
1077 x = tf.constant([[0., 0, 0],
1078 [0, 0, 0]])
1080 w = tf.constant([[-1., 1, 1],
1081 [1, 1, 1]])
1083 du.reduce_weighted_logsumexp(x, w)
1084 # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)
1086 du.reduce_weighted_logsumexp(x, w, axis=0)
1087 # ==> [log(-1+1), log(1+1), log(1+1)]
1089 du.reduce_weighted_logsumexp(x, w, axis=1)
1090 # ==> [log(-1+1+1), log(1+1+1)]
1092 du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
1093 # ==> [[log(-1+1+1)], [log(1+1+1)]]
1095 du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
1096 # ==> log(-1+5)
1097 ```
1099 Args:
1100 logx: The tensor to reduce. Should have numeric type.
1101 w: The weight tensor. Should have numeric type identical to `logx`.
1102 axis: The dimensions to reduce. If `None` (the default), reduces all
1103 dimensions. Must be in the range `[-rank(input_tensor),
1104 rank(input_tensor))`.
1105 keep_dims: If true, retains reduced dimensions with length 1.
1106 return_sign: If `True`, returns the sign of the result.
1107 name: A name for the operation (optional).
1109 Returns:
1110 lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
1111 sign: (Optional) The sign of `sum(weight * exp(x))`.
1112 """
1113 with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]):
1114 logx = ops.convert_to_tensor(logx, name="logx")
1115 if w is None:
1116 lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
1117 if return_sign:
1118 sgn = array_ops.ones_like(lswe)
1119 return lswe, sgn
1120 return lswe
1121 w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w")
1122 log_absw_x = logx + math_ops.log(math_ops.abs(w))
1123 max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True)
1124 # If the largest element is `-inf` or `inf` then we don't bother subtracting
1125 # off the max. We do this because otherwise we'd get `inf - inf = NaN`. That
1126 # this is ok follows from the fact that we're actually free to subtract any
1127 # value we like, so long as we add it back after taking the `log(sum(...))`.
1128 max_log_absw_x = array_ops.where_v2(
1129 math_ops.is_inf(max_log_absw_x), array_ops.zeros_like(max_log_absw_x),
1130 max_log_absw_x)
1131 wx_over_max_absw_x = (
1132 math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x))
1133 sum_wx_over_max_absw_x = math_ops.reduce_sum(
1134 wx_over_max_absw_x, axis=axis, keepdims=keep_dims)
1135 if not keep_dims:
1136 max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis)
1137 sgn = math_ops.sign(sum_wx_over_max_absw_x)
1138 lswe = max_log_absw_x + math_ops.log(sgn * sum_wx_over_max_absw_x)
1139 if return_sign:
1140 return lswe, sgn
1141 return lswe
1144# TODO(jvdillon): Merge this test back into:
1145# tensorflow/python/ops/softplus_op_test.py
1146# once TF core is accepting new ops.
1147def softplus_inverse(x, name=None):
1148 """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).
1150 Mathematically this op is equivalent to:
1152 ```none
1153 softplus_inverse = log(exp(x) - 1.)
1154 ```
1156 Args:
1157 x: `Tensor`. Non-negative (not enforced), floating-point.
1158 name: A name for the operation (optional).
1160 Returns:
1161 `Tensor`. Has the same type/shape as input `x`.
1162 """
1163 with ops.name_scope(name, "softplus_inverse", values=[x]):
1164 x = ops.convert_to_tensor(x, name="x")
1165 # We begin by deriving a more numerically stable softplus_inverse:
1166 # x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
1167 # ==> exp{x} = 1 + exp{y} (1)
1168 # ==> y = Log[exp{x} - 1] (2)
1169 # = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
1170 # = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
1171 # = Log[1 - exp{-x}] + x (3)
1172 # (2) is the "obvious" inverse, but (3) is more stable than (2) for large x.
1173 # For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x} will
1174 # be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0.
1175 #
1176 # In addition to the numerically stable derivation above, we clamp
1177 # small/large values to be congruent with the logic in:
1178 # tensorflow/core/kernels/softplus_op.h
1179 #
1180 # Finally, we set the input to one whenever the input is too large or too
1181 # small. This ensures that no unchosen codepath is +/- inf. This is
1182 # necessary to ensure the gradient doesn't get NaNs. Recall that the
1183 # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
1184 # thus an `inf` in an unselected path results in `0*inf=nan`. We are careful
1185 # to overwrite `x` with ones only when we will never actually use this
1186 # value. Note that we use ones and not zeros since `log(expm1(0.)) = -inf`.
1187 threshold = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2.
1188 is_too_small = math_ops.less(x, np.exp(threshold))
1189 is_too_large = math_ops.greater(x, -threshold)
1190 too_small_value = math_ops.log(x)
1191 too_large_value = x
1192 # This `where` will ultimately be a NOP because we won't select this
1193 # codepath whenever we used the surrogate `ones_like`.
1194 x = array_ops.where_v2(
1195 math_ops.logical_or(is_too_small, is_too_large), array_ops.ones_like(x),
1196 x)
1197 y = x + math_ops.log(-math_ops.expm1(-x)) # == log(expm1(x))
1198 return array_ops.where_v2(
1199 is_too_small, too_small_value,
1200 array_ops.where_v2(is_too_large, too_large_value, y))
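# Numerical sanity check (plain-numpy sketch of identity (3) above):
# softplus_inverse(softplus(y)) recovers y away from the clamped regions.
#   import numpy as np
#   x = np.log1p(np.exp(2.0))        # softplus(2.0) ==> 2.12692801...
#   x + np.log(-np.expm1(-x))        # ==> 2.0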
1203# TODO(b/35290280): Add unit-tests.
1204def dimension_size(x, axis):
1205 """Returns the size of a specific dimension."""
1206 # Since tf.gather isn't "constant-in, constant-out", we must first check the
1207 # static shape or fallback to dynamic shape.
1208 s = tensor_shape.dimension_value(
1209 x.shape.with_rank_at_least(np.abs(axis))[axis])
1210 if s is not None:
1211 return s
1212 return array_ops.shape(x)[axis]
1215def process_quadrature_grid_and_probs(quadrature_grid_and_probs,
1216 dtype,
1217 validate_args,
1218 name=None):
1219 """Validates quadrature grid, probs or computes them as necessary.
1221 Args:
1222 quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
1223 representing the sample points and the corresponding (possibly
1224 normalized) weight. When `None`, defaults to:
1225 `np.polynomial.hermite.hermgauss(deg=8)`.
1226 dtype: The expected `dtype` of `grid` and `probs`.
1227 validate_args: Python `bool`, default `False`. When `True` distribution
1228 parameters are checked for validity despite possibly degrading runtime
1229 performance. When `False` invalid inputs may silently render incorrect
1230 outputs.
1231 name: Python `str` name prefixed to Ops created by this class.
1233 Returns:
1234 quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
1235 representing the sample points and the corresponding (possibly
1236 normalized) weight.
1238 Raises:
1239 ValueError: if `quadrature_grid_and_probs is not None` and
1240 `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
1241 """
1242 with ops.name_scope(name, "process_quadrature_grid_and_probs",
1243 [quadrature_grid_and_probs]):
1244 if quadrature_grid_and_probs is None:
1245 grid, probs = np.polynomial.hermite.hermgauss(deg=8)
1246 grid = grid.astype(dtype.as_numpy_dtype)
1247 probs = probs.astype(dtype.as_numpy_dtype)
1248 probs /= np.linalg.norm(probs, ord=1, keepdims=True)
1249 grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
1250 probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
1251 return grid, probs
1253 grid, probs = tuple(quadrature_grid_and_probs)
1254 grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
1255 probs = ops.convert_to_tensor(probs, name="unnormalized_probs", dtype=dtype)
1256 probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs")
1258 def _static_event_size(x):
1259 """Returns the static size of a specific dimension or `None`."""
1260 return tensor_shape.dimension_value(x.shape.with_rank_at_least(1)[-1])
1262 m, n = _static_event_size(probs), _static_event_size(grid)
1263 if m is not None and n is not None:
1264 if m != n:
1265 raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
1266 "same-length zero-th-dimension `Tensor`s "
1267 "(saw lengths {}, {})".format(m, n))
1268 elif validate_args:
1269 assertions = [
1270 check_ops.assert_equal(
1271 dimension_size(probs, axis=-1),
1272 dimension_size(grid, axis=-1),
1273 message=("`quadrature_grid_and_probs` must be a `tuple` of "
1274 "same-length zero-th-dimension `Tensor`s")),
1275 ]
1276 with ops.control_dependencies(assertions):
1277 grid = array_ops.identity(grid)
1278 probs = array_ops.identity(probs)
1279 return grid, probs
1282def pad(x, axis, front=False, back=False, value=0, count=1, name=None):
1283 """Pads `value` to the front and/or back of a `Tensor` dim, `count` times.
1285 Args:
1286 x: `Tensor` input.
1287 axis: Scalar `int`-like `Tensor` representing the single dimension to pad.
1288 (Negative indexing is supported.)
1289 front: Python `bool`; if `True` the beginning of the `axis` dimension is
1290 padded with `value`, `count` times. If `False` no front padding is made.
1291 back: Python `bool`; if `True` the end of the `axis` dimension is padded
1292 with `value`, `count` times. If `False` no end padding is made.
1293 value: Scalar `int`-like `Tensor` representing the actual value added to the
1294 front and/or back of the `axis` dimension of `x`.
1295 count: Scalar `int`-like `Tensor` representing number of elements added to
1296 the front and/or back of the `axis` dimension of `x`. E.g., if `front =
1297 back = True` then `2 * count` elements are added.
1298 name: Python `str` name prefixed to Ops created by this function.
1300 Returns:
1301 pad: The padded version of input `x`.
1303 Raises:
1304 ValueError: if both `front` and `back` are `False`.
1305 TypeError: if `count` is not `int`-like.
1306 """
1307 with ops.name_scope(name, "pad", [x, value, count]):
1308 x = ops.convert_to_tensor(x, name="x")
1309 value = ops.convert_to_tensor(value, dtype=x.dtype, name="value")
1310 count = ops.convert_to_tensor(count, name="count")
1311 if not count.dtype.is_integer:
1312 raise TypeError("`count.dtype` (`{}`) must be `int`-like.".format(
1313 count.dtype.name))
1314 if not front and not back:
1315 raise ValueError("At least one of `front`, `back` must be `True`.")
1316 ndims = (
1317 x.shape.ndims if x.shape.ndims is not None else array_ops.rank(
1318 x, name="ndims"))
1319 axis = ops.convert_to_tensor(axis, name="axis")
1320 axis_ = tensor_util.constant_value(axis)
1321 if axis_ is not None:
1322 axis = axis_
1323 if axis < 0:
1324 axis = ndims + axis
1325 count_ = tensor_util.constant_value(count)
1326 if axis_ >= 0 or x.shape.ndims is not None:
1327 head = x.shape[:axis]
1328 middle = tensor_shape.TensorShape(None if count_ is None else (
1329 tensor_shape.dimension_at_index(x.shape, axis) + count_ *
1330 (front + back)))
1331 tail = x.shape[axis + 1:]
1332 final_shape = head.concatenate(middle.concatenate(tail))
1333 else:
1334 final_shape = None
1335 else:
1336 axis = array_ops.where_v2(axis < 0, ndims + axis, axis)
1337 final_shape = None
1338 x = array_ops.pad(
1339 x,
1340 paddings=array_ops.one_hot(
1341 indices=array_ops_stack.stack(
1342 [axis if front else -1, axis if back else -1]),
1343 depth=ndims,
1344 axis=0,
1345 on_value=count,
1346 dtype=dtypes.int32),
1347 constant_values=value)
1348 if final_shape is not None:
1349 x.set_shape(final_shape)
1350 return x
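# Illustrative sketch (assuming `import tensorflow as tf`): pad both ends of
# the only axis with the value 9, twice on each side.
#   pad(tf.constant([1, 2, 3]), axis=0, front=True, back=True, value=9, count=2)
#   # ==> [9, 9, 1, 2, 3, 9, 9]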
1353def parent_frame_arguments():
1354 """Returns parent frame arguments.
1356 When called inside a function, returns a dictionary with the caller's function
1357 arguments. These are positional arguments and keyword arguments (**kwargs),
1358 while variable arguments (*varargs) are excluded.
1360 When called at global scope, this will return an empty dictionary, since there
1361 are no arguments.
1363 WARNING: If caller function argument names are overloaded before invoking
1364 this method, then values will reflect the overloaded value. For this reason,
1365 we recommend calling `parent_frame_arguments` at the beginning of the
1366 function.
1367 """
1368 # All arguments and the names used for *varargs, and **kwargs
1369 arg_names, variable_arg_name, keyword_arg_name, local_vars = (
1370 tf_inspect._inspect.getargvalues( # pylint: disable=protected-access
1371 # Get the first frame of the caller of this method.
1372 tf_inspect._inspect.stack()[1][0])) # pylint: disable=protected-access
1374 # Remove the *varargs, and flatten the **kwargs. Both are
1375 # nested lists.
1376 local_vars.pop(variable_arg_name, {})
1377 keyword_args = local_vars.pop(keyword_arg_name, {})
1379 final_args = {}
1380 # Copy over arguments and their values. In general, local_vars
1381 # may contain more than just the arguments, since this method
1382 # can be called anywhere in a function.
1383 for arg_name in arg_names:
1384 final_args[arg_name] = local_vars.pop(arg_name)
1385 final_args.update(keyword_args)
1387 return final_args
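# Minimal sketch (names are illustrative only):
#   def f(a, b=2, *args, **kwargs):
#     return parent_frame_arguments()
#   f(1, c=3)  # ==> {'a': 1, 'b': 2, 'c': 3}  (*args dropped, **kwargs flattened)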
1390class AppendDocstring:
1391 """Helper class to promote private subclass docstring to public counterpart.
1393 Example:
1395 ```python
1396 class TransformedDistribution(Distribution):
1397 @distribution_util.AppendDocstring(
1398 additional_note="A special note!",
1399 kwargs_dict={"foo": "An extra arg."})
1400 def _prob(self, y, foo=None):
1401 pass
1402 ```
1404 In this case, the `AppendDocstring` decorator appends the `additional_note` to
1405 the docstring of `prob` (not `_prob`) and adds a new `kwargs`
1406 section with each dictionary item as a bullet-point.
1408 For a more detailed example, see `TransformedDistribution`.
1409 """
1411 def __init__(self, additional_note="", kwargs_dict=None):
1412 """Initializes the AppendDocstring object.
1414 Args:
1415 additional_note: Python string added as additional docstring to public
1416 version of function.
1417 kwargs_dict: Python string/string dictionary representing specific kwargs
1418 expanded from the **kwargs input.
1420 Raises:
1421 ValueError: if kwargs_dict.key contains whitespace.
1422 ValueError: if kwargs_dict.value contains newlines.
1423 """
1424 self._additional_note = additional_note
1425 if kwargs_dict:
1426 bullets = []
1427 for key in sorted(kwargs_dict.keys()):
1428 value = kwargs_dict[key]
1429 if any(x.isspace() for x in key):
1430 raise ValueError("Parameter name \"%s\" contains whitespace." % key)
1431 value = value.lstrip()
1432 if "\n" in value:
1433 raise ValueError(
1434 "Parameter description for \"%s\" contains newlines." % key)
1435 bullets.append("* `%s`: %s" % (key, value))
1436 self._additional_note += ("\n\n##### `kwargs`:\n\n" + "\n".join(bullets))
1438 def __call__(self, fn):
1440 @functools.wraps(fn)
1441 def _fn(*args, **kwargs):
1442 return fn(*args, **kwargs)
1444 if _fn.__doc__ is None:
1445 _fn.__doc__ = self._additional_note
1446 else:
1447 _fn.__doc__ += "\n%s" % self._additional_note
1448 return _fn