Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/multinomial.py: 46%
85 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The Multinomial distribution class."""
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import ops
19from tensorflow.python.ops import array_ops
20from tensorflow.python.ops import check_ops
21from tensorflow.python.ops import control_flow_ops
22from tensorflow.python.ops import map_fn
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops import nn_ops
25from tensorflow.python.ops import random_ops
26from tensorflow.python.ops.distributions import distribution
27from tensorflow.python.ops.distributions import util as distribution_util
28from tensorflow.python.util import deprecation
29from tensorflow.python.util.tf_export import tf_export
# Names exported by `from ... import *`; only the distribution class is public.
__all__ = ["Multinomial"]
# Extra documentation attached to `Multinomial._log_prob` through the
# `distribution_util.AppendDocstring` decorator.
_multinomial_sample_note = """For each batch of counts,
`value = [n_0, ..., n_{K-1}]`, `P[value]` is the probability that after
sampling `self.total_count` draws from this Multinomial distribution, the
number of draws falling in class `j` is `n_j`. Since this definition is
[exchangeable](https://en.wikipedia.org/wiki/Exchangeable_random_variables),
all draw sequences that yield the same counts are equally likely, so the
probability includes a combinatorial coefficient counting those sequences.

Note: `value` must be a non-negative tensor with dtype `self.dtype`, have no
fractional components, and such that
`tf.reduce_sum(value, -1) = self.total_count`. Its shape must be broadcastable
with `self.probs` and `self.total_count`."""
@tf_export(v1=["distributions.Multinomial"])
class Multinomial(distribution.Distribution):
  """Multinomial distribution.

  This Multinomial distribution is parameterized by `probs`, a (batch of)
  length-`K` `prob` (probability) vectors (`K > 1`) such that
  `tf.reduce_sum(probs, -1) = 1`, and a `total_count` number of trials, i.e.,
  the number of trials per draw from the Multinomial. It is defined over a
  (batch of) length-`K` vector `counts` such that
  `tf.reduce_sum(counts, -1) = total_count`. The Multinomial is identical to
  the Binomial distribution when `K = 2`.

  #### Mathematical Details

  The Multinomial is a distribution over `K`-class counts, i.e., a length-`K`
  vector of non-negative integer `counts = n = [n_0, ..., n_{K-1}]`.

  The probability mass function (pmf) is,

  ```none
  pmf(n; pi, N) = prod_j (pi_j)**n_j / Z
  Z = (prod_j n_j!) / N!
  ```

  where:
  * `probs = pi = [pi_0, ..., pi_{K-1}]`, `pi_j > 0`, `sum_j pi_j = 1`,
  * `total_count = N`, `N` a non-negative integer,
  * `Z` is the normalization constant, and,
  * `N!` denotes `N` factorial.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  #### Pitfalls

  The number of classes, `K`, must not exceed:
  - the largest integer representable by `self.dtype`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
    tf.float16: 2**11,
    tf.float32: 2**24,
    tf.float64: 2**53 }[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  Create a 3-class distribution, in which the 3rd class is most likely to be
  drawn, using logits.

  ```python
  logits = [-50., -43, 0]
  dist = Multinomial(total_count=4., logits=logits)
  ```

  Create a 3-class distribution, in which the 3rd class is most likely to be
  drawn.

  ```python
  p = [.2, .3, .5]
  dist = Multinomial(total_count=4., probs=p)
  ```

  The distribution functions can be evaluated on counts.

  ```python
  # counts same shape as p.
  counts = [1., 0, 3]
  dist.prob(counts)  # Shape []

  # p will be broadcast to [[.2, .3, .5], [.2, .3, .5]] to match counts.
  counts = [[1., 2, 1], [2, 2, 0]]
  dist.prob(counts)  # Shape [2]

  # p will be broadcast to shape [5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7]
  ```

  Create a 2-batch of 3-class distributions.

  ```python
  p = [[.1, .2, .7], [.3, .3, .4]]  # Shape [2, 3]
  dist = Multinomial(total_count=[4., 5], probs=p)

  counts = [[2., 1, 1], [3, 1, 1]]
  dist.prob(counts)  # Shape [2]

  dist.sample(5)  # Shape [5, 2, 3]
  ```
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               total_count,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape broadcastable
        to `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
        `N1 x ... x Nm` different Multinomial distributions. Its components
        should be equal to integer values.
      logits: Floating point tensor representing unnormalized log-probabilities
        of a positive event with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0`, and the same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. Only one of `logits` or `probs` should be passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0` and same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. `probs`'s components in the last portion of its shape
        should sum to `1`. Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Must run before any other local is bound so that only the constructor
    # arguments are recorded.
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, logits, probs]) as name:
      self._total_count = ops.convert_to_tensor(total_count, name="total_count")
      if validate_args:
        self._total_count = (
            distribution_util.embed_check_nonnegative_integer_form(
                self._total_count))
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          multidimensional=True,
          validate_args=validate_args,
          name=name)
      # Mean `N * p`; its shape also defines the batch and event shapes.
      self._mean_val = self._total_count[..., array_ops.newaxis] * self._probs
    super(Multinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count,
                       self._logits,
                       self._probs],
        name=name)

  @property
  def total_count(self):
    """Number of trials used to construct a sample."""
    return self._total_count

  @property
  def logits(self):
    """Vector of coordinatewise logits (unnormalized log-probabilities)."""
    return self._logits

  @property
  def probs(self):
    """Vector of class probabilities; the last dimension has length `K`."""
    return self._probs

  def _batch_shape_tensor(self):
    return array_ops.shape(self._mean_val)[:-1]

  def _batch_shape(self):
    return self._mean_val.get_shape().with_rank_at_least(1)[:-1]

  def _event_shape_tensor(self):
    return array_ops.shape(self._mean_val)[-1:]

  def _event_shape(self):
    return self._mean_val.get_shape().with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
    """Draws `n` count vectors for every batch member.

    Args:
      n: Scalar integer number of samples to draw.
      seed: Optional seed forwarded to `random_ops.multinomial`.

    Returns:
      `Tensor` of dtype `self.dtype` with shape `[n] + batch_shape + [k]`.
    """
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    k = self.event_shape_tensor()[0]

    # Broadcast total_count and logits against each other.
    n_draws = array_ops.ones_like(
        self.logits[..., 0], dtype=n_draws.dtype) * n_draws
    logits = array_ops.ones_like(
        n_draws[..., array_ops.newaxis], dtype=self.logits.dtype) * self.logits

    # Flatten the batch dimensions.
    flat_logits = array_ops.reshape(logits, [-1, k])  # [B1B2...Bm, k]
    flat_ndraws = array_ops.reshape(n_draws, [-1])  # [B1B2...Bm]

    def _sample_single(args):
      """Samples `n` count vectors for one batch member."""
      member_logits, n_draw = args[0], args[1]  # [k], []
      # Draw all n * n_draw class indices at once, then split them into n
      # groups of n_draw indices each.
      x = random_ops.multinomial(member_logits[array_ops.newaxis, ...],
                                 n * n_draw, seed)  # [1, n*n_draw]
      # Use an explicit shape: reshape cannot infer a `-1` dimension for an
      # empty tensor when another requested dimension is 0 (e.g. `n == 0`).
      x = array_ops.reshape(x, shape=[n, n_draw])  # [n, n_draw]
      # `one_hot` defaults to float32; emit `self.dtype` so the result matches
      # the dtype declared to `map_fn` below (e.g. for float64 parameters).
      x = math_ops.reduce_sum(
          array_ops.one_hot(x, depth=k, dtype=self.dtype), axis=-2)  # [n, k]
      return x

    x = map_fn.map_fn(
        _sample_single, [flat_logits, flat_ndraws],
        dtype=self.dtype)  # [B1B2...Bm, n, k]

    # Move the sample dimension to the front and restore the batch shape.
    x = array_ops.transpose(x, perm=[1, 0, 2])
    final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
    x = array_ops.reshape(x, final_shape)  # [n, B1, B2,..., Bm, k]
    return x

  # log pmf = log(prod_j p_j**n_j) - log Z, with Z as in the class docstring.
  @distribution_util.AppendDocstring(_multinomial_sample_note)
  def _log_prob(self, counts):
    return self._log_unnormalized_prob(counts) - self._log_normalization(counts)

  def _log_unnormalized_prob(self, counts):
    """Returns `sum_j counts_j * log(p_j)` over the last dimension."""
    counts = self._maybe_assert_valid_sample(counts)
    return math_ops.reduce_sum(counts * nn_ops.log_softmax(self.logits), -1)

  def _log_normalization(self, counts):
    """Returns `log Z`, i.e. minus `log_combinations(total_count, counts)`."""
    counts = self._maybe_assert_valid_sample(counts)
    return -distribution_util.log_combinations(self.total_count, counts)

  def _mean(self):
    return array_ops.identity(self._mean_val)

  def _covariance(self):
    """Returns the covariance matrix over the last (class) dimension.

    Off-diagonal entries are `-N * p_i * p_j`; the diagonal is `_variance()`.
    """
    p = self.probs * array_ops.ones_like(
        self.total_count)[..., array_ops.newaxis]
    # pylint: disable=invalid-unary-operand-type
    return array_ops.matrix_set_diag(
        -math_ops.matmul(
            self._mean_val[..., array_ops.newaxis],
            p[..., array_ops.newaxis, :]),  # outer product
        self._variance())

  def _variance(self):
    """Returns `N * p * (1 - p)`, broadcast over the batch shape."""
    p = self.probs * array_ops.ones_like(
        self.total_count)[..., array_ops.newaxis]
    return self._mean_val - self._mean_val * p

  def _maybe_assert_valid_sample(self, counts):
    """Check counts for proper shape, values, then return tensor version."""
    if not self.validate_args:
      return counts
    counts = distribution_util.embed_check_nonnegative_integer_form(counts)
    return control_flow_ops.with_dependencies([
        check_ops.assert_equal(
            self.total_count, math_ops.reduce_sum(counts, -1),
            message="counts must sum to `self.total_count`"),
    ], counts)