Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/gamma.py: 56%
90 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The Gamma distribution class."""
17import numpy as np
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import check_ops
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops import nn
28from tensorflow.python.ops import random_ops
29from tensorflow.python.ops.distributions import distribution
30from tensorflow.python.ops.distributions import kullback_leibler
31from tensorflow.python.ops.distributions import util as distribution_util
32from tensorflow.python.util import deprecation
33from tensorflow.python.util.tf_export import tf_export
# Public API of this module: the Gamma distribution and its softplus variant.
__all__ = ["Gamma", "GammaWithSoftplusConcentrationRate"]
@tf_export(v1=["distributions.Gamma"])
class Gamma(distribution.Distribution):
  """Gamma distribution.

  The Gamma distribution is defined over positive real numbers using
  parameters `concentration` (aka "alpha") and `rate` (aka "beta").

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha, beta, x > 0) = x**(alpha - 1) exp(-x beta) / Z
  Z = Gamma(alpha) beta**(-alpha)
  ```

  where:

  * `concentration = alpha`, `alpha > 0`,
  * `rate = beta`, `beta > 0`,
  * `Z` is the normalizing constant, and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The cumulative distribution function (cdf) is,

  ```none
  cdf(x; alpha, beta, x > 0) = GammaInc(alpha, beta x) / Gamma(alpha)
  ```

  where `GammaInc` is the [lower incomplete Gamma function](
  https://en.wikipedia.org/wiki/Incomplete_gamma_function).

  The parameters can be intuited via their relationship to mean and stddev,

  ```none
  concentration = alpha = (mean / stddev)**2
  rate = beta = mean / stddev**2 = concentration / mean
  ```

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  Warning: The samples of this distribution are always non-negative. However,
  the samples that are smaller than `np.finfo(dtype).tiny` are rounded
  to this value, so it appears more often than it should.
  This should only be noticeable when the `concentration` is very small, or the
  `rate` is very large. See note in `tf.random.gamma` docstring.

  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  dist = tfd.Gamma(concentration=3.0, rate=2.0)
  dist2 = tfd.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  concentration = tf.constant(3.0)
  rate = tf.constant(2.0)
  dist = tfd.Gamma(concentration, rate)
  samples = dist.sample(5)  # Shape [5]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, [concentration, rate])
  ```

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf](http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               concentration,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="Gamma"):
    """Construct Gamma with `concentration` and `rate` parameters.

    The parameters `concentration` and `rate` must be shaped in a way that
    supports broadcasting (e.g. `concentration + rate` is a valid operation).

    Args:
      concentration: Floating point tensor, the concentration params of the
        distribution(s). Must contain only positive values.
      rate: Floating point tensor, the inverse scale params of the
        distribution(s). Must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `concentration` and `rate` are different dtypes.
    """
    # Must be the first statement: captures exactly the constructor arguments.
    # (Explicit two-argument `super` below is deliberate; zero-argument
    # `super()` would add `__class__` to `locals()`.)
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration, rate]) as name:
      # Positivity assertions are only attached when validate_args is set.
      with ops.control_dependencies([
          check_ops.assert_positive(concentration),
          check_ops.assert_positive(rate),
      ] if validate_args else []):
        self._concentration = array_ops.identity(
            concentration, name="concentration")
        self._rate = array_ops.identity(rate, name="rate")
        check_ops.assert_same_float_dtype(
            [self._concentration, self._rate])
    super(Gamma, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration,
                       self._rate],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    # Both parameters take the requested sample shape as their own shape.
    return dict(
        zip(("concentration", "rate"), ([ops.convert_to_tensor(
            sample_shape, dtype=dtypes.int32)] * 2)))

  @property
  def concentration(self):
    """Concentration parameter."""
    return self._concentration

  @property
  def rate(self):
    """Rate parameter."""
    return self._rate

  # Batch shape is the broadcast of the two parameter shapes.
  def _batch_shape_tensor(self):
    return array_ops.broadcast_dynamic_shape(
        array_ops.shape(self.concentration),
        array_ops.shape(self.rate))

  def _batch_shape(self):
    return array_ops.broadcast_static_shape(
        self.concentration.get_shape(),
        self.rate.get_shape())

  # Scalar distribution: the event shape is empty.
  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  @distribution_util.AppendDocstring(
      """Note: See `tf.random.gamma` docstring for sampling details and
      caveats.""")
  def _sample_n(self, n, seed=None):
    return random_ops.random_gamma(
        shape=[n],
        alpha=self.concentration,
        beta=self.rate,
        dtype=self.dtype,
        seed=seed)

  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  def _cdf(self, x):
    x = self._maybe_assert_valid_sample(x)
    # `igamma` is the *regularized* lower incomplete gamma function
    # P(a, x) = GammaInc(a, x) / Gamma(a), which is exactly the CDF at rate * x.
    return math_ops.igamma(self.concentration, self.rate * x)

  def _log_unnormalized_prob(self, x):
    x = self._maybe_assert_valid_sample(x)
    # (alpha - 1) * log(x) - beta * x. `xlogy` is used instead of a plain
    # product so a zero first argument yields 0 rather than 0 * log(x).
    return math_ops.xlogy(self.concentration - 1., x) - self.rate * x

  def _log_normalization(self):
    # log Z = lgamma(alpha) - alpha * log(beta).
    return (math_ops.lgamma(self.concentration)
            - self.concentration * math_ops.log(self.rate))

  # Entropy: alpha - log(beta) + lgamma(alpha) + (1 - alpha) * digamma(alpha).
  def _entropy(self):
    return (self.concentration
            - math_ops.log(self.rate)
            + math_ops.lgamma(self.concentration)
            + ((1. - self.concentration) *
               math_ops.digamma(self.concentration)))

  def _mean(self):
    return self.concentration / self.rate

  def _variance(self):
    return self.concentration / math_ops.square(self.rate)

  def _stddev(self):
    return math_ops.sqrt(self.concentration) / self.rate

  @distribution_util.AppendDocstring(
      """The mode of a gamma distribution is `(concentration - 1) / rate` when
      `concentration > 1`, and `NaN` otherwise. If `self.allow_nan_stats` is
      `False`, an exception will be raised rather than returning `NaN`.""")
  def _mode(self):
    mode = (self.concentration - 1.) / self.rate
    if self.allow_nan_stats:
      # Pass the numpy scalar *type* as dtype (not an instance of it).
      nan = array_ops.fill(
          self.batch_shape_tensor(),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype),
          name="nan")
      return array_ops.where_v2(self.concentration > 1., mode, nan)
    else:
      return control_flow_ops.with_dependencies([
          check_ops.assert_less(
              array_ops.ones([], self.dtype),
              self.concentration,
              message="mode not defined when any concentration <= 1"),
      ], mode)

  def _maybe_assert_valid_sample(self, x):
    # Always checks that `x` has this distribution's float dtype; the
    # positivity assertion is only attached when validate_args is set.
    check_ops.assert_same_float_dtype(tensors=[x], dtype=self.dtype)
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x),
    ], x)
class GammaWithSoftplusConcentrationRate(Gamma):
  """`Gamma` with softplus of `concentration` and `rate`.

  The raw `concentration` and `rate` arguments are passed through
  `nn.softplus` and the results are used as the parameters of the
  underlying `Gamma`.
  """

  @deprecation.deprecated(
      "2019-01-01",
      "Use `tfd.Gamma(tf.nn.softplus(concentration), "
      "tf.nn.softplus(rate))` instead.",
      warn_once=True)
  def __init__(self,
               concentration,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="GammaWithSoftplusConcentrationRate"):
    """Build a Gamma from softplus-transformed parameters.

    Args:
      concentration: Floating point tensor; softplus of it is used as the
        Gamma concentration.
      rate: Floating point tensor; softplus of it is used as the Gamma rate.
      validate_args: Python `bool`, forwarded to `Gamma`.
      allow_nan_stats: Python `bool`, forwarded to `Gamma`.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Snapshot the raw constructor arguments before any other local exists.
    # Two-argument `super` below keeps `__class__` out of `locals()`.
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration, rate]) as name:
      positive_concentration = nn.softplus(
          concentration, name="softplus_concentration")
      positive_rate = nn.softplus(rate, name="softplus_rate")
      super(GammaWithSoftplusConcentrationRate, self).__init__(
          concentration=positive_concentration,
          rate=positive_rate,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    # Gamma.__init__ built its own `parameters` from the softplus outputs;
    # replace them with this subclass's raw (pre-softplus) arguments.
    self._parameters = parameters
@kullback_leibler.RegisterKL(Gamma, Gamma)
def _kl_gamma_gamma(g0, g1, name=None):
  """Batched KL divergence KL(g0 || g1) between two Gamma distributions.

  With g0 = Gamma(a0, b0) and g1 = Gamma(a1, b1) (concentration, rate):

    KL = (a0 - a1) * digamma(a0) + lgamma(a1) - lgamma(a0)
         + a1 * (log(b0) - log(b1)) + a0 * (b1 / b0 - 1)

  Args:
    g0: instance of a Gamma distribution object.
    g1: instance of a Gamma distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_gamma_gamma".

  Returns:
    kl_gamma_gamma: `Tensor`. The batchwise KL(g0 || g1).
  """
  a0, b0 = g0.concentration, g0.rate
  a1, b1 = g1.concentration, g1.rate
  with ops.name_scope(name, "kl_gamma_gamma", values=[a0, b0, a1, b1]):
    # Result from:
    # http://www.fil.ion.ucl.ac.uk/~wpenny/publications/densities.ps
    # For derivation see:
    # http://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions   pylint: disable=line-too-long
    # Terms are accumulated in the same order as the closed form above so the
    # graph ops are created in an unchanged sequence.
    kl = (a0 - a1) * math_ops.digamma(a0)
    kl = kl + math_ops.lgamma(a1)
    kl = kl - math_ops.lgamma(a0)
    kl = kl + a1 * math_ops.log(b0)
    kl = kl - a1 * math_ops.log(b1)
    kl = kl + a0 * (b1 / b0 - 1.)
    return kl