# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base classes for probability distributions."""

import abc
import contextlib
import types

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util
from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
__all__ = [
    "ReparameterizationType",
    "FULLY_REPARAMETERIZED",
    "NOT_REPARAMETERIZED",
    "Distribution",
]

_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
    "batch_shape",
    "batch_shape_tensor",
    "cdf",
    "covariance",
    "cross_entropy",
    "entropy",
    "event_shape",
    "event_shape_tensor",
    "kl_divergence",
    "log_cdf",
    "log_prob",
    "log_survival_function",
    "mean",
    "mode",
    "prob",
    "sample",
    "stddev",
    "survival_function",
    "variance",
]
class _BaseDistribution(metaclass=abc.ABCMeta):
  """Abstract base class needed for resolving subclass hierarchy."""
  pass
def _copy_fn(fn):
  """Create a deep copy of fn.

  Args:
    fn: a callable

  Returns:
    A `FunctionType`: a deep copy of fn.

  Raises:
    TypeError: if `fn` is not a callable.
  """
  if not callable(fn):
    raise TypeError("fn is not callable: %s" % fn)
  # The blessed way to copy a function. copy.deepcopy fails to create a
  # non-reference copy. Since:
  #   types.FunctionType == type(lambda: None),
  # and the docstring for the function type states:
  #
  #   function(code, globals[, name[, argdefs[, closure]]])
  #
  #   Create a function object from a code object and a dictionary.
  #   ...
  #
  # Here we can use this to create a new function with the old function's
  # code, globals, closure, etc.
  return types.FunctionType(
      code=fn.__code__, globals=fn.__globals__,
      name=fn.__name__, argdefs=fn.__defaults__,
      closure=fn.__closure__)
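
# Illustrative sketch (not part of this module's API): the copy returned by
# `_copy_fn` shares code with the original, but its attributes can be mutated
# independently, which is exactly what `_DistributionMeta` relies on when
# rewriting docstrings:
#
#   def f(x):
#     """Original docstring."""
#     return x
#
#   g = _copy_fn(f)
#   g.__doc__ = "New docstring."
#   assert f.__doc__ == "Original docstring."  # `f` is unaffected.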
def _update_docstring(old_str, append_str):
  """Update old_str by inserting append_str just before the "Args:" section."""
  old_str = old_str or ""
  old_str_lines = old_str.split("\n")

  # Step 0: Prepend spaces to all lines of append_str. This is
  # necessary for correct markdown generation.
  append_str = "\n".join("    %s" % line for line in append_str.split("\n"))

  # Step 1: Find mention of "Args":
  has_args_ix = [
      ix for ix, line in enumerate(old_str_lines)
      if line.strip().lower() == "args:"]
  if has_args_ix:
    final_args_ix = has_args_ix[-1]
    return ("\n".join(old_str_lines[:final_args_ix])
            + "\n\n" + append_str + "\n\n"
            + "\n".join(old_str_lines[final_args_ix:]))
  else:
    return old_str + "\n\n" + append_str
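
# Illustrative sketch: given a Google-style docstring, the appended text is
# spliced in just above the (last) "Args:" line:
#
#   doc = "Does a thing.\n\nArgs:\n  x: input."
#   _update_docstring(doc, "Extra notes.")
#   # The result reads, roughly:
#   #   Does a thing.
#   #
#   #       Extra notes.
#   #
#   #   Args:
#   #     x: input.
#
# If no "Args:" section exists, the text is simply appended at the end.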
def _convert_to_tensor(value, name=None, preferred_dtype=None):
  """Converts to tensor avoiding an eager bug that loses float precision."""
  # TODO(b/116672045): Remove this function.
  if (context.executing_eagerly() and preferred_dtype is not None and
      (preferred_dtype.is_integer or preferred_dtype.is_bool)):
    v = ops.convert_to_tensor(value, name=name)
    if v.dtype.is_floating:
      return v
  return ops.convert_to_tensor(
      value, name=name, preferred_dtype=preferred_dtype)
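
# Illustrative sketch of the workaround above: with an integer or bool
# `preferred_dtype`, a plain eager conversion could lose float precision (the
# bug referenced in the TODO). Converting first without the preferred dtype
# and keeping any floating result sidesteps that, e.g.
# `_convert_to_tensor(1.5, preferred_dtype=dtypes.int32)` keeps 1.5 as a
# float tensor rather than truncating it.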
class _DistributionMeta(abc.ABCMeta):

  def __new__(mcs, classname, baseclasses, attrs):
    """Control the creation of subclasses of the Distribution class.

    The main purpose of this method is to properly propagate docstrings
    from private Distribution methods, like `_log_prob`, into their
    public wrappers as inherited by the Distribution base class
    (e.g. `log_prob`).

    Args:
      classname: The name of the subclass being created.
      baseclasses: A tuple of parent classes.
      attrs: A dict mapping new attributes to their values.

    Returns:
      The class object.

    Raises:
      TypeError: If `Distribution` is not a subclass of `BaseDistribution`, or
        the new class is derived via multiple inheritance and the first
        parent class is not a subclass of `BaseDistribution`.
      AttributeError: If `Distribution` does not implement e.g. `log_prob`.
      ValueError: If a `Distribution` public method lacks a docstring.
    """
    if not baseclasses:  # Nothing to be done for Distribution
      raise TypeError("Expected non-empty baseclass. Does Distribution "
                      "not subclass _BaseDistribution?")
    which_base = [
        base for base in baseclasses
        if base == _BaseDistribution or issubclass(base, Distribution)]
    base = which_base[0]
    if base == _BaseDistribution:  # Nothing to be done for Distribution
      return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)
    if not issubclass(base, Distribution):
      raise TypeError("First parent class declared for %s must be "
                      "Distribution, but saw '%s'" % (classname, base.__name__))
    for attr in _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS:
      special_attr = "_%s" % attr
      class_attr_value = attrs.get(attr, None)
      if attr in attrs:
        # The method is being overridden, do not update its docstring
        continue
      base_attr_value = getattr(base, attr, None)
      if not base_attr_value:
        raise AttributeError(
            "Internal error: expected base class '%s' to implement method '%s'"
            % (base.__name__, attr))
      class_special_attr_value = attrs.get(special_attr, None)
      if class_special_attr_value is None:
        # No _special method available, no need to update the docstring.
        continue
      class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value)
      if not class_special_attr_docstring:
        # No docstring to append.
        continue
      class_attr_value = _copy_fn(base_attr_value)
      class_attr_docstring = tf_inspect.getdoc(base_attr_value)
      if class_attr_docstring is None:
        raise ValueError(
            "Expected base class fn to contain a docstring: %s.%s"
            % (base.__name__, attr))
      class_attr_value.__doc__ = _update_docstring(
          class_attr_value.__doc__,
          ("Additional documentation from `%s`:\n\n%s"
           % (classname, class_special_attr_docstring)))
      attrs[attr] = class_attr_value

    return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)
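
# Illustrative sketch: with the metaclass in place, a subclass documents only
# its private specialization, and the public wrapper's docstring is extended
# automatically:
#
#   class MyDist(Distribution):  # hypothetical subclass
#     def _log_prob(self, value):
#       """Note specific to MyDist."""
#       ...
#
#   # MyDist.log_prob.__doc__ is now Distribution.log_prob's docstring with
#   # "Additional documentation from `MyDist`:\n\nNote specific to MyDist."
#   # inserted just before its "Args:" section.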
@tf_export(v1=["distributions.ReparameterizationType"])
class ReparameterizationType:
  """Instances of this class represent how sampling is reparameterized.

  Two static instances exist in the distributions library, signifying
  one of two possible properties for samples from a distribution:

  `FULLY_REPARAMETERIZED`: Samples from the distribution are fully
    reparameterized, and straight-through gradients are supported.

  `NOT_REPARAMETERIZED`: Samples from the distribution are not fully
    reparameterized, and straight-through gradients are either partially
    unsupported or are not supported at all. In this case, for purposes of
    e.g. RL or variational inference, it is generally safest to wrap the
    sample results in a `stop_gradient` call and use policy
    gradients / surrogate loss instead.
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self, rep_type):
    self._rep_type = rep_type

  def __repr__(self):
    return "<Reparameterization Type: %s>" % self._rep_type

  def __eq__(self, other):
    """Determine if this `ReparameterizationType` is equal to another.

    Since ReparameterizationType instances are constant static global
    instances, equality checks if two instances' id() values are equal.

    Args:
      other: Object to compare against.

    Returns:
      `self is other`.
    """
    return self is other
# Fully reparameterized distribution: samples from a fully
# reparameterized distribution support straight-through gradients with
# respect to all parameters.
FULLY_REPARAMETERIZED = ReparameterizationType("FULLY_REPARAMETERIZED")
tf_export(v1=["distributions.FULLY_REPARAMETERIZED"]).export_constant(
    __name__, "FULLY_REPARAMETERIZED")

# Not reparameterized distribution: samples from a non-
# reparameterized distribution do not support straight-through gradients for
# at least some of the parameters.
NOT_REPARAMETERIZED = ReparameterizationType("NOT_REPARAMETERIZED")
tf_export(v1=["distributions.NOT_REPARAMETERIZED"]).export_constant(
    __name__, "NOT_REPARAMETERIZED")
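
# Illustrative sketch: downstream code typically compares against these two
# static instances (identity equality, per `ReparameterizationType.__eq__`):
#
#   if dist.reparameterization_type == FULLY_REPARAMETERIZED:
#     loss = -dist.log_prob(dist.sample())  # pathwise gradients are valid
#   else:
#     samples = array_ops.stop_gradient(dist.sample())
#     ...  # use a score-function / surrogate-loss estimator instead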
@tf_export(v1=["distributions.Distribution"])
class Distribution(_BaseDistribution, metaclass=_DistributionMeta):
  """A generic probability distribution base class.

  `Distribution` is a base class for constructing and organizing properties
  (e.g., mean, variance) of random variables (e.g., Bernoulli, Gaussian).

  #### Subclassing

  Subclasses are expected to implement a leading-underscore version of the
  same-named function. The argument signature should be identical except for
  the omission of `name="..."`. For example, to enable `log_prob(value,
  name="log_prob")` a subclass should implement `_log_prob(value)`.

  Subclasses can append to public-level docstrings by providing
  docstrings for their method specializations. For example:

  ```python
  @util.AppendDocstring("Some other details.")
  def _log_prob(self, value):
    ...
  ```

  would add the string "Some other details." to the `log_prob` function
  docstring. This is implemented as a simple decorator to avoid the Python
  linter complaining about missing Args/Returns/Raises sections in the
  partial docstrings.

  #### Broadcasting, batching, and shapes

  All distributions support batches of independent distributions of that type.
  The batch shape is determined by broadcasting together the parameters.

  The shape of arguments to `__init__`, `cdf`, `log_cdf`, `prob`, and
  `log_prob` reflect this broadcasting, as does the return value of `sample`
  and `sample_n`.

  `sample_n_shape = [n] + batch_shape + event_shape`, where `sample_n_shape`
  is the shape of the `Tensor` returned from `sample_n`, `n` is the number of
  samples, `batch_shape` defines how many independent distributions there are,
  and `event_shape` defines the shape of samples from each of those
  independent distributions. Samples are independent along the `batch_shape`
  dimensions, but not necessarily so along the `event_shape` dimensions
  (depending on the particulars of the underlying distribution).

  Using the `Uniform` distribution as an example:

  ```python
  minval = 3.0
  maxval = [[4.0, 6.0],
            [10.0, 12.0]]

  # Broadcasting:
  # This instance represents 4 Uniform distributions. Each has a lower bound
  # at 3.0 as the `minval` parameter was broadcast to match `maxval`'s shape.
  u = Uniform(minval, maxval)

  # `event_shape` is `TensorShape([])`.
  event_shape = u.event_shape
  # `event_shape_t` is a `Tensor` which will evaluate to [].
  event_shape_t = u.event_shape_tensor()

  # Sampling returns a sample per distribution. `samples` has shape
  # [5, 2, 2], which is [n] + batch_shape + event_shape, where n=5,
  # batch_shape=[2, 2], and event_shape=[].
  samples = u.sample_n(5)

  # The broadcasting holds across methods. Here we use `cdf` as an example.
  # The same holds for `log_cdf` and the likelihood functions.

  # `cum_prob_broadcast` has shape [2, 2] as the `value` argument was
  # broadcast to the shape of the `Uniform` instance.
  cum_prob_broadcast = u.cdf(4.0)

  # `cum_prob_per_dist` has shape [2, 2], one per distribution. No
  # broadcasting occurred.
  cum_prob_per_dist = u.cdf([[4.0, 5.0],
                             [6.0, 7.0]])

  # INVALID as the `value` argument is not broadcastable to the
  # distribution's shape.
  cum_prob_invalid = u.cdf([4.0, 5.0, 6.0])
  ```

  #### Shapes

  There are three important concepts associated with TensorFlow Distributions
  shapes:
  - Event shape describes the shape of a single draw from the distribution;
    it may be dependent across dimensions. For scalar distributions, the event
    shape is `[]`. For a 5-dimensional MultivariateNormal, the event shape is
    `[5]`.
  - Batch shape describes independent, not identically distributed draws, aka
    a "collection" or "bunch" of distributions.
  - Sample shape describes independent, identically distributed draws of
    batches from the distribution family.

  The event shape and the batch shape are properties of a Distribution object,
  whereas the sample shape is associated with a specific call to `sample` or
  `log_prob`.

  For detailed usage examples of TensorFlow Distributions shapes, see
  [this tutorial](
  https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb)

  #### Parameter values leading to undefined statistics or distributions

  Some distributions do not have well-defined statistics for all
  initialization parameter values. For example, the beta distribution is
  parameterized by positive real numbers `concentration1` and
  `concentration0`, and does not have a well-defined mode if
  `concentration1 < 1` or `concentration0 < 1`.

  The user is given the option of raising an exception or returning `NaN`.

  ```python
  a = tf.exp(tf.matmul(logits, weights_a))
  b = tf.exp(tf.matmul(logits, weights_b))

  # Will raise an exception if ANY batch member has a < 1 or b < 1.
  dist = distributions.Beta(a, b, allow_nan_stats=False)
  mode = dist.mode().eval()

  # Will return NaN for batch members with either a < 1 or b < 1.
  dist = distributions.Beta(a, b, allow_nan_stats=True)  # Default behavior
  mode = dist.mode().eval()
  ```

  In all cases, an exception is raised if *invalid* parameters are passed,
  e.g.

  ```python
  # Will raise an exception if any Op is run.
  negative_a = -1.0 * a  # beta distribution by definition has a > 0.
  dist = distributions.Beta(negative_a, b, allow_nan_stats=True)
  dist.mean().eval()
  ```
  """
  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               dtype,
               reparameterization_type,
               validate_args,
               allow_nan_stats,
               parameters=None,
               graph_parents=None,
               name=None):
    """Constructs the `Distribution`.

    **This is a private method for subclass use.**

    Args:
      dtype: The type of the event samples. `None` implies no type-enforcement.
      reparameterization_type: Instance of `ReparameterizationType`.
        If `distributions.FULLY_REPARAMETERIZED`, this
        `Distribution` can be reparameterized in terms of some standard
        distribution with a function whose Jacobian is constant for the support
        of the standard distribution. If `distributions.NOT_REPARAMETERIZED`,
        then no such reparameterization is available.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      parameters: Python `dict` of parameters used to instantiate this
        `Distribution`.
      graph_parents: Python `list` of graph prerequisites of this
        `Distribution`.
      name: Python `str` name prefixed to Ops created by this class. Default:
        subclass name.

    Raises:
      ValueError: if any member of graph_parents is `None` or not a `Tensor`.
    """
    graph_parents = [] if graph_parents is None else graph_parents
    for i, t in enumerate(graph_parents):
      if t is None or not tensor_util.is_tf_type(t):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
    if not name or name[-1] != "/":  # `name` is not a name scope
      non_unique_name = name or type(self).__name__
      with ops.name_scope(non_unique_name) as name:
        pass
    self._dtype = dtype
    self._reparameterization_type = reparameterization_type
    self._allow_nan_stats = allow_nan_stats
    self._validate_args = validate_args
    self._parameters = parameters or {}
    self._graph_parents = graph_parents
    self._name = name
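
  # Illustrative sketch: a minimal subclass routes its parameters through this
  # constructor (hypothetical `MyUniform`, not part of this module):
  #
  #   class MyUniform(Distribution):
  #     def __init__(self, low, high, validate_args=False,
  #                  allow_nan_stats=True, name="MyUniform"):
  #       parameters = dict(locals())
  #       low = ops.convert_to_tensor(low, name="low")
  #       high = ops.convert_to_tensor(high, name="high")
  #       super(MyUniform, self).__init__(
  #           dtype=low.dtype,
  #           reparameterization_type=FULLY_REPARAMETERIZED,
  #           validate_args=validate_args,
  #           allow_nan_stats=allow_nan_stats,
  #           parameters=parameters,
  #           graph_parents=[low, high],
  #           name=name)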
  @property
  def _parameters(self):
    return self._parameter_dict

  @_parameters.setter
  def _parameters(self, value):
    """Intercept assignments to self._parameters to avoid reference cycles.

    Parameters are often created using locals(), so we need to clean out any
    references to `self` before assigning it to an attribute.

    Args:
      value: A dictionary of parameters to assign to the `_parameters`
        property.
    """
    if "self" in value:
      del value["self"]
    self._parameter_dict = value

  @classmethod
  def param_shapes(cls, sample_shape, name="DistributionParamShapes"):
    """Shapes of parameters given the desired shape of a call to `sample()`.

    This is a class method that describes what key/value arguments are
    required to instantiate the given `Distribution` so that a particular
    shape is returned for that instance's call to `sample()`.

    Subclasses should override class method `_param_shapes`.

    Args:
      sample_shape: `Tensor` or python list/tuple. Desired shape of a call to
        `sample()`.
      name: name to prepend ops with.

    Returns:
      `dict` of parameter name to `Tensor` shapes.
    """
    with ops.name_scope(name, values=[sample_shape]):
      return cls._param_shapes(sample_shape)

  @classmethod
  def param_static_shapes(cls, sample_shape):
    """param_shapes with static (i.e. `TensorShape`) shapes.

    This is a class method that describes what key/value arguments are
    required to instantiate the given `Distribution` so that a particular
    shape is returned for that instance's call to `sample()`. Assumes that the
    sample's shape is known statically.

    Subclasses should override class method `_param_shapes` to return
    constant-valued tensors when constant values are fed.

    Args:
      sample_shape: `TensorShape` or python list/tuple. Desired shape of a
        call to `sample()`.

    Returns:
      `dict` of parameter name to `TensorShape`.

    Raises:
      ValueError: if `sample_shape` is a `TensorShape` and is not fully
        defined.
    """
    if isinstance(sample_shape, tensor_shape.TensorShape):
      if not sample_shape.is_fully_defined():
        raise ValueError("TensorShape sample_shape must be fully defined")
      sample_shape = sample_shape.as_list()

    params = cls.param_shapes(sample_shape)

    static_params = {}
    for name, shape in params.items():
      static_shape = tensor_util.constant_value(shape)
      if static_shape is None:
        raise ValueError(
            "sample_shape must be a fully-defined TensorShape or list/tuple")
      static_params[name] = tensor_shape.TensorShape(static_shape)

    return static_params
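
  # Illustrative sketch: for a scalar distribution whose `_param_shapes` maps
  # every parameter to the sample shape (e.g. a loc/scale family such as
  # `Normal`), one would expect, roughly:
  #
  #   Normal.param_shapes([100])
  #   # ==> {"loc": <Tensor: [100]>, "scale": <Tensor: [100]>}
  #   Normal.param_static_shapes(tensor_shape.TensorShape([100]))
  #   # ==> {"loc": TensorShape([100]), "scale": TensorShape([100])}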
  @staticmethod
  def _param_shapes(sample_shape):
    raise NotImplementedError("_param_shapes not implemented")

  @property
  def name(self):
    """Name prepended to all ops created by this `Distribution`."""
    return self._name

  @property
  def dtype(self):
    """The `DType` of `Tensor`s handled by this `Distribution`."""
    return self._dtype

  @property
  def parameters(self):
    """Dictionary of parameters used to instantiate this `Distribution`."""
    # Remove "self", "__class__", or other special variables. These can appear
    # if the subclass used:
    # `parameters = dict(locals())`.
    return {k: v for k, v in self._parameters.items()
            if not k.startswith("__") and k != "self"}

  @property
  def reparameterization_type(self):
    """Describes how samples from the distribution are reparameterized.

    Currently this is one of the static instances
    `distributions.FULLY_REPARAMETERIZED`
    or `distributions.NOT_REPARAMETERIZED`.

    Returns:
      An instance of `ReparameterizationType`.
    """
    return self._reparameterization_type

  @property
  def allow_nan_stats(self):
    """Python `bool` describing behavior when a stat is undefined.

    Stats return +/- infinity when it makes sense. E.g., the variance of a
    Cauchy distribution is infinity. However, sometimes the statistic is
    undefined, e.g., if a distribution's pdf does not achieve a maximum within
    the support of the distribution, the mode is undefined. If the mean is
    undefined, then by definition the variance is undefined. E.g. the mean for
    Student's T for df = 1 is undefined (no clear way to say it is either + or
    - infinity), so the variance = E[(X - mean)**2] is also undefined.

    Returns:
      allow_nan_stats: Python `bool`.
    """
    return self._allow_nan_stats

  @property
  def validate_args(self):
    """Python `bool` indicating possibly expensive checks are enabled."""
    return self._validate_args
  def copy(self, **override_parameters_kwargs):
    """Creates a deep copy of the distribution.

    Note: the copy distribution may continue to depend on the original
    initialization arguments.

    Args:
      **override_parameters_kwargs: String/value dictionary of initialization
        arguments to override with new values.

    Returns:
      distribution: A new instance of `type(self)` initialized from the union
        of self.parameters and override_parameters_kwargs, i.e.,
        `dict(self.parameters, **override_parameters_kwargs)`.
    """
    parameters = dict(self.parameters, **override_parameters_kwargs)
    return type(self)(**parameters)
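
  # Illustrative sketch: `copy` re-invokes the constructor with the stored
  # parameters plus overrides, so
  #
  #   strict_dist = dist.copy(validate_args=True)
  #
  # is equivalent to `type(dist)(**dict(dist.parameters,
  # validate_args=True))`: same parameterization, argument checking enabled.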
  def _batch_shape_tensor(self):
    raise NotImplementedError(
        "batch_shape_tensor is not implemented: {}".format(type(self).__name__))

  def batch_shape_tensor(self, name="batch_shape_tensor"):
    """Shape of a single sample from a single event index as a 1-D `Tensor`.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution.

    Args:
      name: name to give to the op

    Returns:
      batch_shape: `Tensor`.
    """
    with self._name_scope(name):
      if self.batch_shape.is_fully_defined():
        return ops.convert_to_tensor(self.batch_shape.as_list(),
                                     dtype=dtypes.int32,
                                     name="batch_shape")
      return self._batch_shape_tensor()

  def _batch_shape(self):
    return tensor_shape.TensorShape(None)

  @property
  def batch_shape(self):
    """Shape of a single sample from a single event index as a `TensorShape`.

    May be partially defined or unknown.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution.

    Returns:
      batch_shape: `TensorShape`, possibly unknown.
    """
    return tensor_shape.as_shape(self._batch_shape())

  def _event_shape_tensor(self):
    raise NotImplementedError(
        "event_shape_tensor is not implemented: {}".format(type(self).__name__))

  def event_shape_tensor(self, name="event_shape_tensor"):
    """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.

    Args:
      name: name to give to the op

    Returns:
      event_shape: `Tensor`.
    """
    with self._name_scope(name):
      if self.event_shape.is_fully_defined():
        return ops.convert_to_tensor(self.event_shape.as_list(),
                                     dtype=dtypes.int32,
                                     name="event_shape")
      return self._event_shape_tensor()

  def _event_shape(self):
    return tensor_shape.TensorShape(None)

  @property
  def event_shape(self):
    """Shape of a single sample from a single batch as a `TensorShape`.

    May be partially defined or unknown.

    Returns:
      event_shape: `TensorShape`, possibly unknown.
    """
    return tensor_shape.as_shape(self._event_shape())

  def is_scalar_event(self, name="is_scalar_event"):
    """Indicates that `event_shape == []`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      is_scalar_event: `bool` scalar `Tensor`.
    """
    with self._name_scope(name):
      return ops.convert_to_tensor(
          self._is_scalar_helper(self.event_shape, self.event_shape_tensor),
          name="is_scalar_event")

  def is_scalar_batch(self, name="is_scalar_batch"):
    """Indicates that `batch_shape == []`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      is_scalar_batch: `bool` scalar `Tensor`.
    """
    with self._name_scope(name):
      return ops.convert_to_tensor(
          self._is_scalar_helper(self.batch_shape, self.batch_shape_tensor),
          name="is_scalar_batch")
  def _sample_n(self, n, seed=None):
    raise NotImplementedError("sample_n is not implemented: {}".format(
        type(self).__name__))

  def _call_sample_n(self, sample_shape, seed, name, **kwargs):
    with self._name_scope(name, values=[sample_shape]):
      sample_shape = ops.convert_to_tensor(
          sample_shape, dtype=dtypes.int32, name="sample_shape")
      sample_shape, n = self._expand_sample_shape_to_vector(
          sample_shape, "sample_shape")
      samples = self._sample_n(n, seed, **kwargs)
      batch_event_shape = array_ops.shape(samples)[1:]
      final_shape = array_ops.concat([sample_shape, batch_event_shape], 0)
      samples = array_ops.reshape(samples, final_shape)
      samples = self._set_sample_static_shape(samples, sample_shape)
      return samples

  def sample(self, sample_shape=(), seed=None, name="sample"):
    """Generate samples of the specified shape.

    Note that a call to `sample()` without arguments will generate a single
    sample.

    Args:
      sample_shape: 0D or 1D `int32` `Tensor`. Shape of the generated samples.
      seed: Python integer seed for RNG
      name: name to give to the op.

    Returns:
      samples: a `Tensor` with prepended dimensions `sample_shape`.
    """
    return self._call_sample_n(sample_shape, seed, name)
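
  # Illustrative sketch of the shape contract implemented above: for a
  # distribution with batch_shape=[2, 2] and event_shape=[],
  #
  #   dist.sample()        # shape [2, 2]        (sample_shape=[])
  #   dist.sample(5)       # shape [5, 2, 2]     (sample_shape=[5])
  #   dist.sample([3, 4])  # shape [3, 4, 2, 2]  (sample_shape=[3, 4])
  #
  # `_sample_n` always draws a flat batch of n = prod(sample_shape) samples;
  # `_call_sample_n` then reshapes the leading dimension back to sample_shape.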
  def _log_prob(self, value):
    raise NotImplementedError("log_prob is not implemented: {}".format(
        type(self).__name__))

  def _call_log_prob(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._log_prob(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log(self._prob(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_prob(self, value, name="log_prob"):
    """Log probability density/mass function.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_log_prob(value, name)

  def _prob(self, value):
    raise NotImplementedError("prob is not implemented: {}".format(
        type(self).__name__))

  def _call_prob(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._prob(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.exp(self._log_prob(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def prob(self, value, name="prob"):
    """Probability density/mass function.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_prob(value, name)
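
  # Illustrative sketch: a subclass may implement either `_log_prob` or
  # `_prob`; each public wrapper falls back to the other via log/exp. E.g. a
  # hypothetical rate-parameterized subclass defining only
  #
  #   def _log_prob(self, value):
  #     return math_ops.log(self._rate) - self._rate * value
  #
  # still supports `prob`, which computes math_ops.exp(self._log_prob(value)).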
  def _log_cdf(self, value):
    raise NotImplementedError("log_cdf is not implemented: {}".format(
        type(self).__name__))

  def _call_log_cdf(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._log_cdf(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log(self._cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_cdf(self, value, name="log_cdf"):
    """Log cumulative distribution function.

    Given random variable `X`, the cumulative distribution function `cdf` is:

    ```none
    log_cdf(x) := Log[ P[X <= x] ]
    ```

    Often, a numerical approximation can be used for `log_cdf(x)` that yields
    a more accurate answer than simply taking the logarithm of the `cdf` when
    `x << -1`.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      logcdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_log_cdf(value, name)

  def _cdf(self, value):
    raise NotImplementedError("cdf is not implemented: {}".format(
        type(self).__name__))

  def _call_cdf(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._cdf(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.exp(self._log_cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def cdf(self, value, name="cdf"):
    """Cumulative distribution function.

    Given random variable `X`, the cumulative distribution function `cdf` is:

    ```none
    cdf(x) := P[X <= x]
    ```

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      cdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_cdf(value, name)
  def _log_survival_function(self, value):
    raise NotImplementedError(
        "log_survival_function is not implemented: {}".format(
            type(self).__name__))

  def _call_log_survival_function(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._log_survival_function(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log1p(-self.cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_survival_function(self, value, name="log_survival_function"):
    """Log survival function.

    Given random variable `X`, the survival function is defined:

    ```none
    log_survival_function(x) = Log[ P[X > x] ]
                             = Log[ 1 - P[X <= x] ]
                             = Log[ 1 - cdf(x) ]
    ```

    Typically, different numerical approximations can be used for the log
    survival function, which are more accurate than `1 - cdf(x)` when
    `x >> 1`.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of
      type `self.dtype`.
    """
    return self._call_log_survival_function(value, name)

  def _survival_function(self, value):
    raise NotImplementedError("survival_function is not implemented: {}".format(
        type(self).__name__))

  def _call_survival_function(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._survival_function(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return 1. - self.cdf(value, **kwargs)
        except NotImplementedError:
          raise original_exception

  def survival_function(self, value, name="survival_function"):
    """Survival function.

    Given random variable `X`, the survival function is defined:

    ```none
    survival_function(x) = P[X > x]
                         = 1 - P[X <= x]
                         = 1 - cdf(x).
    ```

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of
      type `self.dtype`.
    """
    return self._call_survival_function(value, name)
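
  # Illustrative sketch: the fallbacks above exploit S(x) = 1 - cdf(x), and
  # `log_survival_function` uses log1p(-cdf(x)) because log1p keeps precision
  # that the explicit subtraction would destroy when cdf(x) is tiny:
  #
  #   c = 1e-20
  #   np.log1p(-c)     # ==> -1e-20 (accurate)
  #   np.log(1.0 - c)  # ==> 0.0    (1.0 - c rounds to 1.0 in float64)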
  def _entropy(self):
    raise NotImplementedError("entropy is not implemented: {}".format(
        type(self).__name__))

  def entropy(self, name="entropy"):
    """Shannon entropy in nats."""
    with self._name_scope(name):
      return self._entropy()

  def _mean(self):
    raise NotImplementedError("mean is not implemented: {}".format(
        type(self).__name__))

  def mean(self, name="mean"):
    """Mean."""
    with self._name_scope(name):
      return self._mean()

  def _quantile(self, value):
    raise NotImplementedError("quantile is not implemented: {}".format(
        type(self).__name__))

  def _call_quantile(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      return self._quantile(value, **kwargs)

  def quantile(self, value, name="quantile"):
    """Quantile function. Aka "inverse cdf" or "percent point function".

    Given random variable `X` and `p in [0, 1]`, the `quantile` is:

    ```none
    quantile(p) := x such that P[X <= x] == p
    ```

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      quantile: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_quantile(value, name)
  def _variance(self):
    raise NotImplementedError("variance is not implemented: {}".format(
        type(self).__name__))

  def variance(self, name="variance"):
    """Variance.

    Variance is defined as,

    ```none
    Var = E[(X - E[X])**2]
    ```

    where `X` is the random variable associated with this distribution, `E`
    denotes expectation, and `Var.shape = batch_shape + event_shape`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      variance: Floating-point `Tensor` with shape identical to
        `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
    """
    with self._name_scope(name):
      try:
        return self._variance()
      except NotImplementedError as original_exception:
        try:
          return math_ops.square(self._stddev())
        except NotImplementedError:
          raise original_exception

  def _stddev(self):
    raise NotImplementedError("stddev is not implemented: {}".format(
        type(self).__name__))

  def stddev(self, name="stddev"):
    """Standard deviation.

    Standard deviation is defined as,

    ```none
    stddev = E[(X - E[X])**2]**0.5
    ```

    where `X` is the random variable associated with this distribution, `E`
    denotes expectation, and `stddev.shape = batch_shape + event_shape`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      stddev: Floating-point `Tensor` with shape identical to
        `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
    """

    with self._name_scope(name):
      try:
        return self._stddev()
      except NotImplementedError as original_exception:
        try:
          return math_ops.sqrt(self._variance())
        except NotImplementedError:
          raise original_exception
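
  # Illustrative sketch: `variance` and `stddev` are mutually derivable, so a
  # subclass may implement just one of `_variance` / `_stddev`. E.g. a
  # hypothetical scale-parameterized subclass defining only
  #
  #   def _stddev(self):
  #     return self._scale
  #
  # still supports `variance`, which returns math_ops.square(self._stddev()).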
  def _covariance(self):
    raise NotImplementedError("covariance is not implemented: {}".format(
        type(self).__name__))

  def covariance(self, name="covariance"):
    """Covariance.

    Covariance is (possibly) defined only for non-scalar-event distributions.

    For example, for a length-`k`, vector-valued distribution, it is
    calculated as,

    ```none
    Cov[i, j] = Covariance(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])]
    ```

    where `Cov` is a (batch of) `k x k` matrix, `0 <= (i, j) < k`, and `E`
    denotes expectation.

    Alternatively, for non-vector, multivariate distributions (e.g.,
    matrix-valued, Wishart), `Covariance` shall return a (batch of) matrices
    under some vectorization of the events, i.e.,

    ```none
    Cov[i, j] = Covariance(Vec(X)_i, Vec(X)_j) = [as above]
    ```

    where `Cov` is a (batch of) `k' x k'` matrices,
    `0 <= (i, j) < k' = reduce_prod(event_shape)`, and `Vec` is some function
    mapping indices of this distribution's event dimensions to indices of a
    length-`k'` vector.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      covariance: Floating-point `Tensor` with shape `[B1, ..., Bn, k', k']`
        where the first `n` dimensions are batch coordinates and
        `k' = reduce_prod(self.event_shape)`.
    """
    with self._name_scope(name):
      return self._covariance()

  def _mode(self):
    raise NotImplementedError("mode is not implemented: {}".format(
        type(self).__name__))

  def mode(self, name="mode"):
    """Mode."""
    with self._name_scope(name):
      return self._mode()
  def _cross_entropy(self, other):
    return kullback_leibler.cross_entropy(
        self, other, allow_nan_stats=self.allow_nan_stats)

  def cross_entropy(self, other, name="cross_entropy"):
    """Computes the (Shannon) cross entropy.

    Denote this distribution (`self`) by `P` and the `other` distribution by
    `Q`. Assuming `P, Q` are absolutely continuous with respect to
    one another and permit densities `p(x) dr(x)` and `q(x) dr(x)`, (Shannon)
    cross entropy is defined as:

    ```none
    H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)
    ```

    where `F` denotes the support of the random variable `X ~ P`.

    Args:
      other: `tfp.distributions.Distribution` instance.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      cross_entropy: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
        representing `n` different calculations of (Shannon) cross entropy.
    """
    with self._name_scope(name):
      return self._cross_entropy(other)

  def _kl_divergence(self, other):
    return kullback_leibler.kl_divergence(
        self, other, allow_nan_stats=self.allow_nan_stats)

  def kl_divergence(self, other, name="kl_divergence"):
    """Computes the Kullback-Leibler divergence.

    Denote this distribution (`self`) by `p` and the `other` distribution by
    `q`. Assuming `p, q` are absolutely continuous with respect to reference
    measure `r`, the KL divergence is defined as:

    ```none
    KL[p, q] = E_p[log(p(X)/q(X))]
             = -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x)
             = H[p, q] - H[p]
    ```

    where `F` denotes the support of the random variable `X ~ p`, `H[., .]`
    denotes (Shannon) cross entropy, and `H[.]` denotes (Shannon) entropy.

    Args:
      other: `tfp.distributions.Distribution` instance.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      kl_divergence: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
        representing `n` different calculations of the Kullback-Leibler
        divergence.
    """
    with self._name_scope(name):
      return self._kl_divergence(other)
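
  # Illustrative sketch: these two methods are tied by the identity
  # KL[p, q] = H[p, q] - H[p], so for distribution pairs with a registered
  # KL one would expect, up to numerics:
  #
  #   p.kl_divergence(q) == p.cross_entropy(q) - p.entropy()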
  def __str__(self):
    return ("tfp.distributions.{type_name}("
            "\"{self_name}\""
            "{maybe_batch_shape}"
            "{maybe_event_shape}"
            ", dtype={dtype})".format(
                type_name=type(self).__name__,
                self_name=self.name,
                maybe_batch_shape=(", batch_shape={}".format(self.batch_shape)
                                   if self.batch_shape.ndims is not None
                                   else ""),
                maybe_event_shape=(", event_shape={}".format(self.event_shape)
                                   if self.event_shape.ndims is not None
                                   else ""),
                dtype=self.dtype.name))

  def __repr__(self):
    return ("<tfp.distributions.{type_name} "
            "'{self_name}'"
            " batch_shape={batch_shape}"
            " event_shape={event_shape}"
            " dtype={dtype}>".format(
                type_name=type(self).__name__,
                self_name=self.name,
                batch_shape=self.batch_shape,
                event_shape=self.event_shape,
                dtype=self.dtype.name))

  @contextlib.contextmanager
  def _name_scope(self, name=None, values=None):
    """Helper function to standardize op scope."""
    with ops.name_scope(self.name):
      with ops.name_scope(name, values=(
          ([] if values is None else values) + self._graph_parents)) as scope:
        yield scope
  def _expand_sample_shape_to_vector(self, x, name):
    """Helper to `sample` which ensures input is 1D."""
    x_static_val = tensor_util.constant_value(x)
    if x_static_val is None:
      prod = math_ops.reduce_prod(x)
    else:
      prod = np.prod(x_static_val, dtype=x.dtype.as_numpy_dtype())

    ndims = x.get_shape().ndims  # != sample_ndims
    if ndims is None:
      # Maybe expand_dims.
      ndims = array_ops.rank(x)
      expanded_shape = util.pick_vector(
          math_ops.equal(ndims, 0),
          np.array([1], dtype=np.int32), array_ops.shape(x))
      x = array_ops.reshape(x, expanded_shape)
    elif ndims == 0:
      # Definitely expand_dims.
      if x_static_val is not None:
        x = ops.convert_to_tensor(
            np.array([x_static_val], dtype=x.dtype.as_numpy_dtype()),
            name=name)
      else:
        x = array_ops.reshape(x, [1])
    elif ndims != 1:
      raise ValueError("Input is neither scalar nor vector.")

    return x, prod
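
  # Illustrative sketch: this helper normalizes `sample_shape` for
  # `_call_sample_n`, returning the shape as a vector together with the total
  # sample count:
  #
  #   _expand_sample_shape_to_vector(<scalar 5>, "sample_shape")
  #   # ==> ([5], 5)
  #   _expand_sample_shape_to_vector(<vector [3, 4]>, "sample_shape")
  #   # ==> ([3, 4], 12)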
  def _set_sample_static_shape(self, x, sample_shape):
    """Helper to `sample`; sets static shape info."""
    # Set shape hints.
    sample_shape = tensor_shape.TensorShape(
        tensor_util.constant_value(sample_shape))

    ndims = x.get_shape().ndims
    sample_ndims = sample_shape.ndims
    batch_ndims = self.batch_shape.ndims
    event_ndims = self.event_shape.ndims

    # Infer rank(x).
    if (ndims is None and
        sample_ndims is not None and
        batch_ndims is not None and
        event_ndims is not None):
      ndims = sample_ndims + batch_ndims + event_ndims
      x.set_shape([None] * ndims)

    # Infer sample shape.
    if ndims is not None and sample_ndims is not None:
      shape = sample_shape.concatenate([None]*(ndims - sample_ndims))
      x.set_shape(x.get_shape().merge_with(shape))

    # Infer event shape.
    if ndims is not None and event_ndims is not None:
      shape = tensor_shape.TensorShape(
          [None]*(ndims - event_ndims)).concatenate(self.event_shape)
      x.set_shape(x.get_shape().merge_with(shape))

    # Infer batch shape.
    if batch_ndims is not None:
      if ndims is not None:
        if sample_ndims is None and event_ndims is not None:
          sample_ndims = ndims - batch_ndims - event_ndims
        elif event_ndims is None and sample_ndims is not None:
          event_ndims = ndims - batch_ndims - sample_ndims
        if sample_ndims is not None and event_ndims is not None:
          shape = tensor_shape.TensorShape([None]*sample_ndims).concatenate(
              self.batch_shape).concatenate([None]*event_ndims)
          x.set_shape(x.get_shape().merge_with(shape))

    return x

  def _is_scalar_helper(self, static_shape, dynamic_shape_fn):
    """Implementation for `is_scalar_batch` and `is_scalar_event`."""
    if static_shape.ndims is not None:
      return static_shape.ndims == 0
    shape = dynamic_shape_fn()
    if (shape.get_shape().ndims is not None and
        shape.get_shape().dims[0].value is not None):
      # If the static_shape is correctly written then we should never execute
      # this branch. We keep it just in case there's some unimagined corner
      # case.
      return shape.get_shape().as_list() == [0]
    return math_ops.equal(array_ops.shape(shape)[0], 0)