# Coverage for tensorflow/python/ops/distributions/dirichlet_multinomial.py: 49% of 84 statements (coverage.py v7.4.0, 2024-01-03).
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The DirichletMultinomial distribution class."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "DirichletMultinomial",
]


_dirichlet_multinomial_sample_note = """For each batch of counts,
`value = [n_0, ..., n_{K-1}]`, `P[value]` is the probability that after
sampling `self.total_count` draws from this Dirichlet-Multinomial distribution,
the number of draws falling in class `j` is `n_j`. Since this definition is
[exchangeable](https://en.wikipedia.org/wiki/Exchangeable_random_variables),
different sequences with the same counts have the same probability, so the
pmf includes a combinatorial coefficient.

Note: `value` must be a non-negative tensor with dtype `self.dtype`, have no
fractional components, and satisfy
`tf.reduce_sum(value, -1) = self.total_count`. Its shape must be broadcastable
with `self.concentration` and `self.total_count`."""


@tf_export(v1=["distributions.DirichletMultinomial"])
class DirichletMultinomial(distribution.Distribution):
  """Dirichlet-Multinomial compound distribution.

  The Dirichlet-Multinomial distribution is parameterized by a (batch of)
  length-`K` `concentration` vectors (`K > 1`) and a `total_count` number of
  trials, i.e., the number of trials per draw from the DirichletMultinomial. It
  is defined over a (batch of) length-`K` vector `counts` such that
  `tf.reduce_sum(counts, -1) = total_count`. The Dirichlet-Multinomial is
  identically the Beta-Binomial distribution when `K = 2`.

  #### Mathematical Details

  The Dirichlet-Multinomial is a distribution over `K`-class counts, i.e., a
  length-`K` vector of non-negative integer `counts = n = [n_0, ..., n_{K-1}]`.

  The probability mass function (pmf) is,

  ```none
  pmf(n; alpha, N) = Beta(alpha + n) / (prod_j n_j!) / Z
  Z = Beta(alpha) / N!
  ```

  where:

  * `concentration = alpha = [alpha_0, ..., alpha_{K-1}]`, `alpha_j > 0`,
  * `total_count = N`, `N` a positive integer,
  * `N!` is `N` factorial, and,
  * `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the
    [multivariate beta function](
    https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
    and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).
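
  As a quick numeric check of the pmf (a NumPy/SciPy sketch added for
  illustration; `log_beta` is a helper defined here, not a TensorFlow API):

  ```python
  import numpy as np
  from scipy.special import gammaln

  def log_beta(x):
    # Multivariate log-beta: sum_j lgamma(x_j) - lgamma(sum_j x_j).
    return np.sum(gammaln(x)) - gammaln(np.sum(x))

  alpha = np.array([1., 2., 3.])  # concentration
  n = np.array([0., 0., 2.])      # counts
  N = n.sum()                     # total_count
  log_pmf = (log_beta(alpha + n) - log_beta(alpha)       # log Beta ratio
             + gammaln(N + 1) - np.sum(gammaln(n + 1)))  # log(N!/prod n_j!)
  np.exp(log_pmf)  # ~0.2857 (= 2/7)
  ```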

  Dirichlet-Multinomial is a [compound distribution](
  https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e., its
  samples are generated as follows (see the sketch after this list).

  1. Choose class probabilities:
     `probs = [p_0,...,p_{K-1}] ~ Dir(concentration)`
  2. Draw integers:
     `counts = [n_0,...,n_{K-1}] ~ Multinomial(total_count, probs)`
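
  A NumPy sketch of this two-step generative process (illustration only;
  `_sample_n` below uses an equivalent Gamma-based construction):

  ```python
  import numpy as np

  alpha = np.array([1., 2., 3.])  # concentration
  total_count = 10
  # 1. Choose class probabilities from the Dirichlet prior.
  probs = np.random.dirichlet(alpha)
  # 2. Draw integer counts from the resulting Multinomial.
  counts = np.random.multinomial(total_count, probs)  # sums to total_count
  ```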

  The last `concentration` dimension parametrizes a single Dirichlet-Multinomial
  distribution. When calling distribution functions (e.g., `dist.prob(counts)`),
  `concentration`, `total_count` and `counts` are broadcast to the same shape.
  The last dimension of `counts` then corresponds to a single
  Dirichlet-Multinomial distribution.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  #### Pitfalls

  The number of classes, `K`, must not exceed:

  - the largest integer representable by `self.dtype`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
      tf.float16: 2**11,
      tf.float32: 2**24,
      tf.float64: 2**53 }[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  ```python
  alpha = [1., 2., 3.]
  n = 2.
  dist = DirichletMultinomial(n, alpha)
  ```

  Creates a 3-class distribution, with the 3rd class the most likely to be
  drawn. The distribution functions can be evaluated on counts.

  ```python
  # counts same shape as alpha.
  counts = [0., 0., 2.]
  dist.prob(counts)  # Shape []

  # alpha will be broadcast to [[1., 2., 3.], [1., 2., 3.]] to match counts.
  counts = [[1., 1., 0.], [1., 0., 1.]]
  dist.prob(counts)  # Shape [2]

  # alpha will be broadcast to shape [5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7]
  ```

  Creates a 2-batch of 3-class distributions.

  ```python
  alpha = [[1., 2., 3.], [4., 5., 6.]]  # Shape [2, 3]
  n = [3., 3.]
  dist = DirichletMultinomial(n, alpha)

  # counts will be broadcast to [[2., 1., 0.], [2., 1., 0.]] to match alpha.
  counts = [2., 1., 0.]
  dist.prob(counts)  # Shape [2]
  ```
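
  Sampling follows the standard `Distribution` API; e.g., continuing the
  2-batch example above:

  ```python
  samples = dist.sample(5)  # Shape [5, 2, 3]; each length-3 slice sums to n.
  ```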
  """

  # TODO(b/27419586) Change docstring for dtype of concentration once int
  # allowed.
  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               total_count,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name="DirichletMultinomial"):
    """Initialize a batch of DirichletMultinomial distributions.

    Args:
      total_count: Non-negative floating point tensor, whose dtype is the same
        as `concentration`. The shape is broadcastable to `[N1,..., Nm]` with
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different
        Dirichlet multinomial distributions. Its components should be
        integer-valued.
      concentration: Positive floating point tensor, whose dtype is the
        same as `total_count`, with shape broadcastable to `[N1,..., Nm, K]`,
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different `K`
        class Dirichlet multinomial distributions.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, concentration]) as name:
      # Broadcasting works because:
      # * The broadcasting convention is to prepend dimensions of size [1], and
      #   we use the last dimension for the distribution, whereas
      #   the batch dimensions are the leading dimensions, which forces the
      #   distribution dimension to be defined explicitly (i.e. it cannot be
      #   created automatically by prepending). This forces enough explicitness.
      # * All calls involving `counts` eventually require a broadcast between
      #   `counts` and concentration.
      self._total_count = ops.convert_to_tensor(total_count, name="total_count")
      if validate_args:
        self._total_count = (
            distribution_util.embed_check_nonnegative_integer_form(
                self._total_count))
      self._concentration = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration,
                                name="concentration"),
          validate_args)
      self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
    super(DirichletMultinomial, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._total_count,
                       self._concentration],
        name=name)

  @property
  def total_count(self):
    """Number of trials used to construct a sample."""
    return self._total_count

  @property
  def concentration(self):
    """Concentration parameter; expected prior counts for that coordinate."""
    return self._concentration

  @property
  def total_concentration(self):
    """Sum of last dim of concentration parameter."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return array_ops.shape(self.concentration)[-1:]

  def _event_shape(self):
    # Event shape depends only on total_concentration, not "n".
    return self.concentration.get_shape().with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
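    # Note (added): if `G_j ~ Gamma(concentration_j, 1)` independently, then
    # `G / sum_j G_j ~ Dirichlet(concentration)`. `multinomial` softmax-
    # normalizes its logits, so passing `log(G)` as unnormalized logits draws
    # classes with Dirichlet-distributed probabilities `G / sum_j G_j`;
    # summing the one-hot draws over the trials dimension yields
    # Dirichlet-Multinomial counts.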
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    k = self.event_shape_tensor()[0]
    unnormalized_logits = array_ops.reshape(
        math_ops.log(random_ops.random_gamma(
            shape=[n],
            alpha=self.concentration,
            dtype=self.dtype,
            seed=seed)),
        shape=[-1, k])
    draws = random_ops.multinomial(
        logits=unnormalized_logits,
        num_samples=n_draws,
        seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
    x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k), -2)
    final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
    x = array_ops.reshape(x, final_shape)
    return math_ops.cast(x, self.dtype)

  @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
  def _log_prob(self, counts):
    counts = self._maybe_assert_valid_sample(counts)
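    # Note (added): this is the log of the class docstring's pmf:
    # lbeta(alpha + n) - lbeta(alpha) == log(Beta(alpha + n) / Beta(alpha)),
    # and `log_combinations` contributes log(N! / prod_j n_j!), the
    # coefficient counting orderings with the same counts.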
    ordered_prob = (
        special_math_ops.lbeta(self.concentration + counts)
        - special_math_ops.lbeta(self.concentration))
    return ordered_prob + distribution_util.log_combinations(
        self.total_count, counts)

  @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
  def _prob(self, counts):
    return math_ops.exp(self._log_prob(counts))

  def _mean(self):
    return self.total_count * (self.concentration /
                               self.total_concentration[..., array_ops.newaxis])

  @distribution_util.AppendDocstring(
      """The covariance for each batch member is defined as the following:

      ```none
      Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
        (n + alpha_0) / (1 + alpha_0)
      ```

      where `concentration = alpha` and
      `total_concentration = alpha_0 = sum_j alpha_j`.

      The covariance between elements in a batch is defined as:

      ```none
      Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 *
        (n + alpha_0) / (1 + alpha_0)
      ```
      """)
  def _covariance(self):
    x = self._variance_scale_term() * self._mean()
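    # Note (added): the off-diagonal entries below are -x_i * x_j
    # = -n * alpha_i * alpha_j / alpha_0**2 * (n + alpha_0) / (1 + alpha_0),
    # matching `Cov(X_i, X_j)` in the docstring above; the diagonal is then
    # overwritten with the variance.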
    # pylint: disable=invalid-unary-operand-type
    return array_ops.matrix_set_diag(
        -math_ops.matmul(
            x[..., array_ops.newaxis],
            x[..., array_ops.newaxis, :]),  # outer prod
        self._variance())

  def _variance(self):
    scale = self._variance_scale_term()
    x = scale * self._mean()
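    # Note (added): with scale = sqrt((n + alpha_0) / (n * (1 + alpha_0))) and
    # mean = n * alpha / alpha_0, the product below expands to
    # n * (alpha_j / alpha_0) * (1 - alpha_j / alpha_0)
    #   * (n + alpha_0) / (1 + alpha_0),
    # the `Var(X_j)` quoted in `_covariance`'s docstring.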
    return x * (self.total_count * scale - x)

  def _variance_scale_term(self):
    """Helper to `_covariance` and `_variance` which computes a shared scale."""
    # We must take care to expand back the last dim whenever we use the
    # total_concentration.
    c0 = self.total_concentration[..., array_ops.newaxis]
    return math_ops.sqrt((1. + c0 / self.total_count) / (1. + c0))

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of the concentration parameter."""
    if not validate_args:
      return concentration
    concentration = distribution_util.embed_check_categorical_event_shape(
        concentration)
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)

  def _maybe_assert_valid_sample(self, counts):
    """Check counts for proper shape, values, then return tensor version."""
    if not self.validate_args:
      return counts
    counts = distribution_util.embed_check_nonnegative_integer_form(counts)
    return control_flow_ops.with_dependencies([
        check_ops.assert_equal(
            self.total_count, math_ops.reduce_sum(counts, -1),
            message="counts last-dimension must sum to `self.total_count`"),
    ], counts)