# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Beta distribution class."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "Beta",
    "BetaWithSoftplusConcentration",
]


_beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in
`[0, 1]`. It must have a shape compatible with `self.batch_shape()`."""


@tf_export(v1=["distributions.Beta"])
class Beta(distribution.Distribution):
  """Beta distribution.

  The Beta distribution is defined over the `(0, 1)` interval using parameters
  `concentration1` (aka "alpha") and `concentration0` (aka "beta").

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
  Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
  ```

  where:

  * `concentration1 = alpha`,
  * `concentration0 = beta`,
  * `Z` is the normalization constant, and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The concentration parameters represent mean total counts of a `1` or a `0`,
  i.e.,

  ```none
  concentration1 = alpha = mean * total_concentration
  concentration0 = beta = (1. - mean) * total_concentration
  ```

  where `mean` in `(0, 1)` and `total_concentration` is a positive real number
  representing a mean `total_count = concentration1 + concentration0`.
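
  For example, a mean of `0.75` with `total_concentration = 4.` corresponds
  to `concentration1 = 3.` and `concentration0 = 1.`.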

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  Warning: The samples can be zero due to finite precision.
  This happens more often when some of the concentrations are very small.
  Make sure to clip any zero samples up to `np.finfo(dtype).tiny` before
  computing the density.
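
  A minimal sketch of such a guard, assuming `dist` is a `Beta` instance as
  in the examples below (the clipping bounds are illustrative, not part of
  this API):

  ```python
  tiny = np.finfo(np.float32).tiny
  samples = dist.sample(100)
  # Clamp exact zeros up to the smallest positive normal float so that
  # `log_prob` stays finite; leave the upper end untouched.
  safe_samples = tf.clip_by_value(samples, tiny, 1.)
  log_prob = dist.log_prob(safe_samples)
  ```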

  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Create a batch of three Beta distributions.
  alpha = [1, 2, 3]
  beta = [1, 2, 3]
  dist = tfd.Beta(alpha, beta)

  dist.sample([4, 5])  # Shape [4, 5, 3]

  # `x` has three batch entries, each with two samples.
  x = [[.1, .4, .5],
       [.2, .3, .5]]
  # Calculate the probability of each pair of samples under the corresponding
  # distribution in `dist`.
  dist.prob(x)  # Shape [2, 3]
  ```

  ```python
  # Create batch_shape=[2, 3] via parameter broadcast:
  alpha = [[1.], [2]]  # Shape [2, 1]
  beta = [3., 4, 5]    # Shape [3]
  dist = tfd.Beta(alpha, beta)

  # alpha broadcast as: [[1., 1, 1],
  #                      [2, 2, 2]]
  # beta broadcast as:  [[3., 4, 5],
  #                      [3, 4, 5]]
  # batch_shape [2, 3]
  dist.sample([4, 5])  # Shape [4, 5, 2, 3]

  x = [.2, .3, .5]
  # x will be broadcast as [[.2, .3, .5],
  #                         [.2, .3, .5]],
  # thus matching batch_shape [2, 3].
  dist.prob(x)  # Shape [2, 3]
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  alpha = tf.constant(1.0)
  beta = tf.constant(2.0)
  dist = tfd.Beta(alpha, beta)
  samples = dist.sample(5)  # Shape [5]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, [alpha, beta])
  ```

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               concentration1=None,
               concentration0=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Beta"):
    """Initialize a batch of Beta distributions.

    Args:
      concentration1: Positive floating-point `Tensor` indicating mean
        number of successes; aka "alpha". Implies `self.dtype` and
        `self.batch_shape`, i.e.,
        `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
      concentration0: Positive floating-point `Tensor` indicating mean
        number of failures; aka "beta". Otherwise has same semantics as
        `concentration1`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration1, concentration0]) as name:
      self._concentration1 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration1, name="concentration1"),
          validate_args)
      self._concentration0 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration0, name="concentration0"),
          validate_args)
      check_ops.assert_same_float_dtype([
          self._concentration1, self._concentration0])
      self._total_concentration = self._concentration1 + self._concentration0
    super(Beta, self).__init__(
        dtype=self._total_concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration1,
                       self._concentration0,
                       self._total_concentration],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    return dict(zip(
        ["concentration1", "concentration0"],
        [ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2))

  @property
  def concentration1(self):
    """Concentration parameter associated with a `1` outcome."""
    return self._concentration1

  @property
  def concentration0(self):
    """Concentration parameter associated with a `0` outcome."""
    return self._concentration0

  @property
  def total_concentration(self):
    """Sum of concentration parameters."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    expanded_concentration1 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration1
    expanded_concentration0 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration0
    gamma1_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration1,
        dtype=self.dtype,
        seed=seed)
    gamma2_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration0,
        dtype=self.dtype,
        seed=distribution_util.gen_new_seed(seed, "beta"))
    beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
    return beta_sample
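
  # Note: `_sample_n` above uses the standard construction
  # X = G1 / (G1 + G2) ~ Beta(a, b) with independent G1 ~ Gamma(a, 1) and
  # G2 ~ Gamma(b, 1). Broadcasting each concentration against
  # `total_concentration` first gives both gamma draws the full batch shape,
  # and the implicit reparameterization of the gamma sampler (Figurnov et
  # al., 2018) makes the ratio pathwise differentiable in both parameters.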

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_cdf(self, x):
    return math_ops.log(self._cdf(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _cdf(self, x):
    return math_ops.betainc(self.concentration1, self.concentration0, x)
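
  # `_cdf` above evaluates the regularized incomplete beta function
  # I_x(a, b) = B(x; a, b) / B(a, b), which `betainc` computes directly.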

  def _log_unnormalized_prob(self, x):
    x = self._maybe_assert_valid_sample(x)
    return (math_ops.xlogy(self.concentration1 - 1., x) +
            (self.concentration0 - 1.) * math_ops.log1p(-x))  # pylint: disable=invalid-unary-operand-type
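
  # `xlogy(0., 0.)` evaluates to 0 rather than NaN, so a sample of exactly 0
  # is handled gracefully when `concentration1 == 1`; `log1p(-x)` preserves
  # precision for x near 0.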

  def _log_normalization(self):
    return (math_ops.lgamma(self.concentration1)
            + math_ops.lgamma(self.concentration0)
            - math_ops.lgamma(self.total_concentration))
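
  # This is log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b), the log of
  # the normalization constant `Z` from the class docstring.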

  def _entropy(self):
    return (
        self._log_normalization()
        - (self.concentration1 - 1.) * math_ops.digamma(self.concentration1)
        - (self.concentration0 - 1.) * math_ops.digamma(self.concentration0)
        + ((self.total_concentration - 2.) *
           math_ops.digamma(self.total_concentration)))
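
  # Closed-form differential entropy of Beta(a, b):
  #   H = log B(a, b) - (a - 1) digamma(a) - (b - 1) digamma(b)
  #       + (a + b - 2) digamma(a + b).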

  def _mean(self):
    return self._concentration1 / self._total_concentration

  def _variance(self):
    return self._mean() * (1. - self._mean()) / (1. + self.total_concentration)
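
  # Equivalent to the textbook form Var = a * b / ((a + b)**2 * (a + b + 1)),
  # written here in terms of the already-computed mean.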

  @distribution_util.AppendDocstring(
      """Note: The mode is undefined when `concentration1 <= 1` or
      `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN`
      is used for undefined modes. If `self.allow_nan_stats` is `False` an
      exception is raised when one or more modes are undefined.""")
  def _mode(self):
    mode = (self.concentration1 - 1.) / (self.total_concentration - 2.)
    if self.allow_nan_stats:
      nan = array_ops.fill(
          self.batch_shape_tensor(),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
          name="nan")
      is_defined = math_ops.logical_and(self.concentration1 > 1.,
                                        self.concentration0 > 1.)
      return array_ops.where_v2(is_defined, mode, nan)
    return control_flow_ops.with_dependencies([
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration1,
            message="Mode undefined for concentration1 <= 1."),
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration0,
            message="Mode undefined for concentration0 <= 1.")
    ], mode)
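
  # The candidate mode (a - 1) / (a + b - 2) is the interior stationary
  # point of the density; it is a maximum only when both concentrations
  # exceed 1, hence the NaN fill / assertions above.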

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of a concentration parameter."""
    if not validate_args:
      return concentration
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)

  def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample."""
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x, message="sample must be positive"),
        check_ops.assert_less(
            x,
            array_ops.ones([], self.dtype),
            message="sample must be less than `1`."),
    ], x)


class BetaWithSoftplusConcentration(Beta):
  """Beta with softplus transform of `concentration1` and `concentration0`."""

  @deprecation.deprecated(
      "2019-01-01",
      "Use `tfd.Beta(tf.nn.softplus(concentration1), "
      "tf.nn.softplus(concentration0))` instead.",
      warn_once=True)
  def __init__(self,
               concentration1,
               concentration0,
               validate_args=False,
               allow_nan_stats=True,
               name="BetaWithSoftplusConcentration"):
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration1,
                                      concentration0]) as name:
      super(BetaWithSoftplusConcentration, self).__init__(
          concentration1=nn.softplus(concentration1,
                                     name="softplus_concentration1"),
          concentration0=nn.softplus(concentration0,
                                     name="softplus_concentration0"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters


@kullback_leibler.RegisterKL(Beta, Beta)
def _kl_beta_beta(d1, d2, name=None):
  """Calculate the batchwise KL divergence KL(d1 || d2) with d1 and d2 Beta.

  Args:
    d1: instance of a Beta distribution object.
    d2: instance of a Beta distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_beta_beta".

  Returns:
    Batchwise KL(d1 || d2)
  """
  def delta(fn, is_property=True):
    fn1 = getattr(d1, fn)
    fn2 = getattr(d2, fn)
    return (fn2 - fn1) if is_property else (fn2() - fn1())
  with ops.name_scope(name, "kl_beta_beta", values=[
      d1.concentration1,
      d1.concentration0,
      d1.total_concentration,
      d2.concentration1,
      d2.concentration0,
      d2.total_concentration,
  ]):
    return (delta("_log_normalization", is_property=False)
            - math_ops.digamma(d1.concentration1) * delta("concentration1")
            - math_ops.digamma(d1.concentration0) * delta("concentration0")
            + (math_ops.digamma(d1.total_concentration)
               * delta("total_concentration")))
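

# The closed form computed above is
#   KL(d1 || d2) = log B(a2, b2) - log B(a1, b1)
#                  + (a1 - a2) digamma(a1) + (b1 - b2) digamma(b1)
#                  + ((a2 + b2) - (a1 + b1)) digamma(a1 + b1),
# where a = concentration1, b = concentration0, and B is the Beta function.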