# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Dirichlet distribution class."""

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Dirichlet",
]


_dirichlet_sample_note = """Note: `value` must be a non-negative tensor with
dtype `self.dtype` and be in the `(self.event_shape() - 1)`-simplex, i.e.,
`tf.reduce_sum(value, -1) = 1`. It must have a shape compatible with
`self.batch_shape() + self.event_shape()`."""


@tf_export(v1=["distributions.Dirichlet"])
class Dirichlet(distribution.Distribution):
  """Dirichlet distribution.

  The Dirichlet distribution is defined over the
  [`(k-1)`-simplex](https://en.wikipedia.org/wiki/Simplex) using a positive,
  length-`k` vector `concentration` (`k > 1`). The Dirichlet is identically the
  Beta distribution when `k = 2`.

  #### Mathematical Details

  The Dirichlet is a distribution over the open `(k-1)`-simplex, i.e.,

  ```none
  S^{k-1} = { (x_0, ..., x_{k-1}) in R^k : sum_j x_j = 1 and all_j x_j > 0 }.
  ```

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha) = prod_j x_j**(alpha_j - 1) / Z
  Z = prod_j Gamma(alpha_j) / Gamma(sum_j alpha_j)
  ```

  where:

  * `x in S^{k-1}`, i.e., the `(k-1)`-simplex,
  * `concentration = alpha = [alpha_0, ..., alpha_{k-1}]`, `alpha_j > 0`,
  * `Z` is the normalization constant aka the [multivariate beta function](
    https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
    and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The `concentration` represents mean total counts of class occurrence, i.e.,

  ```none
  concentration = alpha = mean * total_concentration
  ```

  where `mean` is in `S^{k-1}` and `total_concentration` is a positive real
  number representing a mean total count.
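
  The density formula can be checked numerically; the following is a minimal
  sketch, assuming SciPy is available (it is not otherwise used by this
  class):

  ```python
  import numpy as np
  from scipy.special import gammaln
  from scipy.stats import dirichlet as scipy_dirichlet

  alpha = np.array([1., 2., 3.])
  x = np.array([.2, .3, .5])
  # log Z = sum_j lgamma(alpha_j) - lgamma(sum_j alpha_j)
  log_z = np.sum(gammaln(alpha)) - gammaln(np.sum(alpha))
  log_pdf = np.sum((alpha - 1.) * np.log(x)) - log_z
  np.testing.assert_allclose(log_pdf, scipy_dirichlet.logpdf(x, alpha))
  ```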

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  Warning: Some components of the samples can be zero due to finite precision.
  This happens more often when some of the concentrations are very small.
  Make sure to round the samples to `np.finfo(dtype).tiny` before computing the
  density.

  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Create a single trivariate Dirichlet, with the 3rd class being three times
  # more frequent than the first. I.e., batch_shape=[], event_shape=[3].
  alpha = [1., 2, 3]
  dist = tfd.Dirichlet(alpha)

  dist.sample([4, 5])  # shape: [4, 5, 3]

  # x has one sample, one batch, three classes:
  x = [.2, .3, .5]  # shape: [3]
  dist.prob(x)  # shape: []

  # x has two samples from one batch:
  x = [[.1, .4, .5],
       [.2, .3, .5]]
  dist.prob(x)  # shape: [2]

  # alpha will be broadcast to shape [5, 7, 3] to match x.
  x = [[...]]  # shape: [5, 7, 3]
  dist.prob(x)  # shape: [5, 7]
  ```

  ```python
  # Create batch_shape=[2], event_shape=[3]:
  alpha = [[1., 2, 3],
           [4, 5, 6]]   # shape: [2, 3]
  dist = tfd.Dirichlet(alpha)

  dist.sample([4, 5])  # shape: [4, 5, 2, 3]

  x = [.2, .3, .5]
  # x will be broadcast as [[.2, .3, .5],
  #                         [.2, .3, .5]],
  # thus matching batch_shape + event_shape = [2, 3].
  dist.prob(x)  # shape: [2]
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  alpha = tf.constant([1.0, 2.0, 3.0])
  dist = tfd.Dirichlet(alpha)
  samples = dist.sample(5)  # Shape [5, 3]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, alpha)
  ```

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name="Dirichlet"):
    """Initialize a batch of Dirichlet distributions.

    Args:
      concentration: Positive floating-point `Tensor` indicating mean number
        of class occurrences; aka "alpha". Implies `self.dtype`, and
        `self.batch_shape`, `self.event_shape`, i.e., if
        `concentration.shape = [N1, N2, ..., Nm, k]` then
        `batch_shape = [N1, N2, ..., Nm]` and
        `event_shape = [k]`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
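
    For example (a minimal shape sketch; the values are illustrative only):

    ```python
    dist = Dirichlet(concentration=[[1., 2, 3],
                                    [4, 5, 6]])  # concentration.shape: [2, 3]
    dist.batch_shape   # TensorShape([2])
    dist.event_shape   # TensorShape([3])
    ```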
191 """
192 parameters = dict(locals())
193 with ops.name_scope(name, values=[concentration]) as name:
194 self._concentration = self._maybe_assert_valid_concentration(
195 ops.convert_to_tensor(concentration, name="concentration"),
196 validate_args)
197 self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
198 super(Dirichlet, self).__init__(
199 dtype=self._concentration.dtype,
200 validate_args=validate_args,
201 allow_nan_stats=allow_nan_stats,
202 reparameterization_type=distribution.FULLY_REPARAMETERIZED,
203 parameters=parameters,
204 graph_parents=[self._concentration,
205 self._total_concentration],
206 name=name)

  @property
  def concentration(self):
    """Concentration parameter; expected counts for that coordinate."""
    return self._concentration

  @property
  def total_concentration(self):
    """Sum of last dim of concentration parameter."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return array_ops.shape(self.concentration)[-1:]

  def _event_shape(self):
    return self.concentration.get_shape().with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
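    # Sampling uses the standard Gamma construction: draw independent
    # Gamma(concentration_j, 1) variates and normalize them to sum to one
    # along the event dimension. Because `random_gamma` supports implicit
    # reparameterization (Figurnov et al., 2018), the normalized sample is
    # pathwise differentiable w.r.t. `concentration`.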
    gamma_sample = random_ops.random_gamma(
        shape=[n],
        alpha=self.concentration,
        dtype=self.dtype,
        seed=seed)
    return gamma_sample / math_ops.reduce_sum(gamma_sample, -1, keepdims=True)

  @distribution_util.AppendDocstring(_dirichlet_sample_note)
  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  @distribution_util.AppendDocstring(_dirichlet_sample_note)
  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  def _log_unnormalized_prob(self, x):
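    # `xlogy(a, x)` evaluates to 0 wherever `a == 0`, so components with
    # `concentration == 1` contribute nothing even when the corresponding
    # `x` entry is 0, avoiding a spurious `0 * log(0) = nan`.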
    x = self._maybe_assert_valid_sample(x)
    return math_ops.reduce_sum(math_ops.xlogy(self.concentration - 1., x), -1)

  def _log_normalization(self):
    return special_math_ops.lbeta(self.concentration)

  def _entropy(self):
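    # Entropy of Dirichlet(alpha), with alpha_0 = sum_j alpha_j:
    #   H = log B(alpha) + (alpha_0 - k) * digamma(alpha_0)
    #       - sum_j (alpha_j - 1) * digamma(alpha_j)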
    k = math_ops.cast(self.event_shape_tensor()[0], self.dtype)
    return (
        self._log_normalization()
        + ((self.total_concentration - k)
           * math_ops.digamma(self.total_concentration))
        - math_ops.reduce_sum(
            (self.concentration - 1.) * math_ops.digamma(self.concentration),
            axis=-1))

  def _mean(self):
    return self.concentration / self.total_concentration[..., array_ops.newaxis]

  def _covariance(self):
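    # Cov[X_i, X_j] = -mean_i * mean_j / (1 + alpha_0) for i != j, computed
    # below as a negated outer product; the diagonal is then overwritten
    # with the variance via `matrix_set_diag`.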
    x = self._variance_scale_term() * self._mean()
    # pylint: disable=invalid-unary-operand-type
    return array_ops.matrix_set_diag(
        -math_ops.matmul(
            x[..., array_ops.newaxis],
            x[..., array_ops.newaxis, :]),  # outer prod
        self._variance())

  def _variance(self):
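    # Var[X_j] = mean_j * (1 - mean_j) / (1 + alpha_0). With
    # scale = 1 / sqrt(1 + alpha_0) and x = scale * mean, this equals
    # x * (scale - x).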
    scale = self._variance_scale_term()
    x = scale * self._mean()
    return x * (scale - x)

  def _variance_scale_term(self):
    """Helper to `_covariance` and `_variance` which computes a shared scale."""
    return math_ops.rsqrt(1. + self.total_concentration[..., array_ops.newaxis])

  @distribution_util.AppendDocstring(
      """Note: The mode is undefined when any `concentration <= 1`. If
      `self.allow_nan_stats` is `True`, `NaN` is used for undefined modes. If
      `self.allow_nan_stats` is `False` an exception is raised when one or more
      modes are undefined.""")
  def _mode(self):
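    # mode_j = (alpha_j - 1) / (alpha_0 - k), defined only when every
    # alpha_j > 1.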
    k = math_ops.cast(self.event_shape_tensor()[0], self.dtype)
    mode = (self.concentration - 1.) / (
        self.total_concentration[..., array_ops.newaxis] - k)
    if self.allow_nan_stats:
      nan = array_ops.fill(
          array_ops.shape(mode),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
          name="nan")
      # `keepdims=True` keeps the condition's trailing dimension so it
      # broadcasts against `mode`'s event dimension.
      return array_ops.where_v2(
          math_ops.reduce_all(self.concentration > 1., axis=-1, keepdims=True),
          mode, nan)
    return control_flow_ops.with_dependencies([
        check_ops.assert_less(
            array_ops.ones([], self.dtype),
            self.concentration,
            message="Mode undefined when any concentration <= 1"),
    ], mode)

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of the concentration parameter."""
    if not validate_args:
      return concentration
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
        check_ops.assert_rank_at_least(
            concentration, 1,
            message="Concentration parameter must have >=1 dimensions."),
        check_ops.assert_less(
            1, array_ops.shape(concentration)[-1],
            message="Concentration parameter must have event_size >= 2."),
    ], concentration)

  def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample."""
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x, message="samples must be positive"),
        check_ops.assert_near(
            array_ops.ones([], dtype=self.dtype),
            math_ops.reduce_sum(x, -1),
            message="sample last-dimension must sum to `1`"),
    ], x)


@kullback_leibler.RegisterKL(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(d1, d2, name=None):
  """Batchwise KL divergence KL(d1 || d2) with d1 and d2 Dirichlet.

  Args:
    d1: instance of a Dirichlet distribution object.
    d2: instance of a Dirichlet distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_dirichlet_dirichlet".

  Returns:
    Batchwise KL(d1 || d2)
  """
  with ops.name_scope(name, "kl_dirichlet_dirichlet", values=[
      d1.concentration, d2.concentration]):
    # The KL between Dirichlet distributions can be derived as follows. We have
    #
    #   Dir(x; a) = 1 / B(a) * prod_i[x[i]^(a[i] - 1)]
    #
    # where B(a) is the multivariate Beta function:
    #
    #   B(a) = Gamma(a[1]) * ... * Gamma(a[n]) / Gamma(a[1] + ... + a[n])
    #
    # The KL is
    #
    #   KL(Dir(x; a), Dir(x; b)) = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))}
    #
    # so we'll need to know the log density of the Dirichlet. This is
    #
    #   log(Dir(x; a)) = sum_i[(a[i] - 1) log(x[i])] - log B(a)
    #
    # The only term that matters for the expectations is the log(x[i]). To
    # compute the expectation of this term over the Dirichlet density, we can
    # use the following facts about the Dirichlet in exponential family form:
    #   1. log(x[i]) is a sufficient statistic
    #   2. expected sufficient statistics (of any exp family distribution) are
    #      equal to derivatives of the log normalizer with respect to
    #      corresponding natural parameters: E{T[i](x)} = dA/d(eta[i])
    #
    # To proceed, we can rewrite the Dirichlet density in exponential family
    # form as follows:
    #
    #   Dir(x; a) = exp{eta(a) . T(x) - A(a)}
    #
    # where '.' is the dot product of vectors eta and T, and A is a scalar:
    #
    #   eta[i](a) = a[i] - 1
    #   T[i](x) = log(x[i])
    #   A(a) = log B(a)
    #
    # Now, we can use fact (2) above to write
    #
    #   E_Dir(x; a)[log(x[i])]
    #       = dA(a) / da[i]
    #       = d/da[i] log B(a)
    #       = d/da[i] (sum_j lgamma(a[j]) - lgamma(sum_j a[j]))
    #       = digamma(a[i]) - digamma(sum_j a[j])
    #
    # Putting it all together, we have
    #
    #   KL[Dir(x; a) || Dir(x; b)]
    #       = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))}
    #       = E_Dir(x; a){sum_i[(a[i] - b[i]) log(x[i])]} - (lbeta(a) - lbeta(b))
    #       = sum_i[(a[i] - b[i]) * E_Dir(x; a){log(x[i])}] - lbeta(a) + lbeta(b)
    #       = sum_i[(a[i] - b[i]) * (digamma(a[i]) - digamma(sum_j a[j]))]
    #         - lbeta(a) + lbeta(b)

    digamma_sum_d1 = math_ops.digamma(
        math_ops.reduce_sum(d1.concentration, axis=-1, keepdims=True))
    digamma_diff = math_ops.digamma(d1.concentration) - digamma_sum_d1
    concentration_diff = d1.concentration - d2.concentration

    return (math_ops.reduce_sum(concentration_diff * digamma_diff, axis=-1) -
            special_math_ops.lbeta(d1.concentration) +
            special_math_ops.lbeta(d2.concentration))
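
# A minimal numeric cross-check of the closed form above. This is a sketch
# only, kept as a comment so nothing runs at import time; it assumes SciPy
# is available:
#
#   import numpy as np
#   from scipy.special import digamma, gammaln
#
#   def kl_dirichlet(a, b):
#     a, b = np.asarray(a), np.asarray(b)
#     lbeta = lambda c: np.sum(gammaln(c)) - gammaln(np.sum(c))
#     return (np.sum((a - b) * (digamma(a) - digamma(np.sum(a))))
#             - lbeta(a) + lbeta(b))
#
#   kl_dirichlet([1., 2., 3.], [2., 2., 2.])  # ≈ 0.807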