Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/multinomial.py: 46%

85 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""The Multinomial distribution class.""" 

16 

17from tensorflow.python.framework import dtypes 

18from tensorflow.python.framework import ops 

19from tensorflow.python.ops import array_ops 

20from tensorflow.python.ops import check_ops 

21from tensorflow.python.ops import control_flow_ops 

22from tensorflow.python.ops import map_fn 

23from tensorflow.python.ops import math_ops 

24from tensorflow.python.ops import nn_ops 

25from tensorflow.python.ops import random_ops 

26from tensorflow.python.ops.distributions import distribution 

27from tensorflow.python.ops.distributions import util as distribution_util 

28from tensorflow.python.util import deprecation 

29from tensorflow.python.util.tf_export import tf_export 

30 

31 

# Public API of this module: only the distribution class is exported.
__all__ = [
    "Multinomial",
]

35 

36 

# Extra documentation appended to the docstring of `Multinomial._log_prob`
# through `distribution_util.AppendDocstring`. This is a runtime string, so its
# text (including line breaks) is kept exactly as authored.
_multinomial_sample_note = """For each batch of counts, `value = [n_0, ...
,n_{k-1}]`, `P[value]` is the probability that after sampling `self.total_count`
draws from this Multinomial distribution, the number of draws falling in class
`j` is `n_j`. Since this definition is [exchangeable](
https://en.wikipedia.org/wiki/Exchangeable_random_variables); different
sequences have the same counts so the probability includes a combinatorial
coefficient.

Note: `value` must be a non-negative tensor with dtype `self.dtype`, have no
fractional components, and such that
`tf.reduce_sum(value, -1) = self.total_count`. Its shape must be broadcastable
with `self.probs` and `self.total_count`."""

49 

50 

@tf_export(v1=["distributions.Multinomial"])
class Multinomial(distribution.Distribution):
  """Multinomial distribution.

  This Multinomial distribution is parameterized by `probs`, a (batch of)
  length-`K` `prob` (probability) vectors (`K > 1`) such that
  `tf.reduce_sum(probs, -1) = 1`, and a `total_count` number of trials, i.e.,
  the number of trials per draw from the Multinomial. It is defined over a
  (batch of) length-`K` vector `counts` such that
  `tf.reduce_sum(counts, -1) = total_count`. The Multinomial is identically the
  Binomial distribution when `K = 2`.

  #### Mathematical Details

  The Multinomial is a distribution over `K`-class counts, i.e., a length-`K`
  vector of non-negative integer `counts = n = [n_0, ..., n_{K-1}]`.

  The probability mass function (pmf) is,

  ```none
  pmf(n; pi, N) = prod_j (pi_j)**n_j / Z
  Z = (prod_j n_j!) / N!
  ```

  where:
  * `probs = pi = [pi_0, ..., pi_{K-1}]`, `pi_j > 0`, `sum_j pi_j = 1`,
  * `total_count = N`, `N` a non-negative integer,
  * `Z` is the normalization constant, and,
  * `N!` denotes `N` factorial.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  #### Pitfalls

  The number of classes, `K`, must not exceed:
  - the largest integer representable by `self.dtype`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
    tf.float16: 2**11,
    tf.float32: 2**24,
    tf.float64: 2**53 }[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  Create a 3-class distribution in which the 3rd class is most likely to be
  drawn, using logits.

  ```python
  logits = [-50., -43, 0]
  dist = Multinomial(total_count=4., logits=logits)
  ```

  Create a 3-class distribution in which the 3rd class is most likely to be
  drawn, using probabilities.

  ```python
  p = [.2, .3, .5]
  dist = Multinomial(total_count=4., probs=p)
  ```

  The distribution functions can be evaluated on counts.

  ```python
  # counts same shape as p.
  counts = [1., 0, 3]
  dist.prob(counts)  # Shape []

  # p will be broadcast to [[.2, .3, .5], [.2, .3, .5]] to match counts.
  counts = [[1., 2, 1], [2, 2, 0]]
  dist.prob(counts)  # Shape [2]

  # p will be broadcast to shape [5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7]
  ```

  Create a 2-batch of 3-class distributions.

  ```python
  p = [[.1, .2, .7], [.3, .3, .4]]  # Shape [2, 3]
  dist = Multinomial(total_count=[4., 5], probs=p)

  counts = [[2., 1, 1], [3, 1, 1]]
  dist.prob(counts)  # Shape [2]

  dist.sample(5)  # Shape [5, 2, 3]
  ```
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               total_count,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape broadcastable
        to `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
        `N1 x ... x Nm` different Multinomial distributions. Its components
        should be equal to integer values.
      logits: Floating point tensor representing unnormalized log-probabilities
        of a positive event with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0`, and the same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. Only one of `logits` or `probs` should be passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0` and same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. `probs`'s components in the last portion of its shape
        should sum to `1`. Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Must stay the first statement: it snapshots exactly the constructor
    # arguments (plus `self`) before any other local is created.
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, logits, probs]) as name:
      self._total_count = ops.convert_to_tensor(total_count, name="total_count")
      if validate_args:
        self._total_count = (
            distribution_util.embed_check_nonnegative_integer_form(
                self._total_count))
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          multidimensional=True,
          validate_args=validate_args,
          name=name)
      # Cached mean, N * p. Its shape is also what the batch/event shape
      # methods below are derived from.
      self._mean_val = self._total_count[..., array_ops.newaxis] * self._probs
    super(Multinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count,
                       self._logits,
                       self._probs],
        name=name)

  @property
  def total_count(self):
    """Number of trials used to construct a sample."""
    return self._total_count

  @property
  def logits(self):
    """Vector of coordinatewise logits."""
    return self._logits

  @property
  def probs(self):
    """Probability of drawing a `1` in that coordinate."""
    return self._probs

  def _batch_shape_tensor(self):
    """Dynamic batch shape: every dimension of the mean except the last."""
    return array_ops.shape(self._mean_val)[:-1]

  def _batch_shape(self):
    """Static batch shape: every dimension of the mean except the last."""
    return self._mean_val.get_shape().with_rank_at_least(1)[:-1]

  def _event_shape_tensor(self):
    """Dynamic event shape: the last dimension of the mean, i.e. `[K]`."""
    return array_ops.shape(self._mean_val)[-1:]

  def _event_shape(self):
    """Static event shape: the last dimension of the mean, i.e. `[K]`."""
    return self._mean_val.get_shape().with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
    """Draw `n` count vectors from every distribution in the batch.

    Each batch member is sampled separately through `map_fn`: `n * total_count`
    class indices are drawn from its logits, split into `n` groups, and each
    group is turned into a length-`K` vector of per-class counts.

    Args:
      n: Number of samples to draw per batch member.
      seed: Optional seed forwarded to `random_ops.multinomial`.

    Returns:
      Tensor of counts, reshaped to `[n, B1, ..., Bm, k]`.
    """
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    k = self.event_shape_tensor()[0]

    # Broadcast total_count and logits against each other so both carry the
    # full batch shape.
    n_draws = array_ops.ones_like(
        self.logits[..., 0], dtype=n_draws.dtype) * n_draws
    logits = array_ops.ones_like(
        n_draws[..., array_ops.newaxis], dtype=self.logits.dtype) * self.logits

    # Flatten the batch dimensions so map_fn can iterate over batch members.
    flat_logits = array_ops.reshape(logits, [-1, k])  # [B1B2...Bm, k]
    # Pre-multiplied by n: one categorical call per member draws all the
    # class indices needed for its n samples.
    flat_ndraws = n * array_ops.reshape(n_draws, [-1])  # [B1B2...Bm]

    def _sample_single(args):
      """Samples `n` count vectors for a single batch member."""
      member_logits, member_draws = args[0], args[1]  # [K], []
      draws = random_ops.multinomial(
          member_logits[array_ops.newaxis, ...], member_draws,
          seed=seed)  # [1, n * total_count]
      draws = array_ops.reshape(draws, shape=[n, -1])  # [n, total_count]
      # One-hot in `self.dtype` so the per-element result agrees with the
      # `dtype` declared to map_fn below, whatever the distribution dtype is.
      return math_ops.reduce_sum(
          array_ops.one_hot(draws, depth=k, dtype=self.dtype),
          axis=-2)  # [n, k]

    x = map_fn.map_fn(
        _sample_single, [flat_logits, flat_ndraws],
        dtype=self.dtype)  # [B1B2...Bm, n, k]

    # Move the sample dimension to the front, then restore the batch shape.
    x = array_ops.transpose(x, perm=[1, 0, 2])
    final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
    x = array_ops.reshape(x, final_shape)  # [n, B1, B2,..., Bm, k]
    return x

  @distribution_util.AppendDocstring(_multinomial_sample_note)
  def _log_prob(self, counts):
    return self._log_unnormalized_prob(counts) - self._log_normalization(counts)

  def _log_unnormalized_prob(self, counts):
    """Returns `sum_j counts_j * log_softmax(logits)_j` over the last axis."""
    counts = self._maybe_assert_valid_sample(counts)
    return math_ops.reduce_sum(counts * nn_ops.log_softmax(self.logits), -1)

  def _log_normalization(self, counts):
    """Returns the negated log of the combinatorial coefficient for `counts`."""
    counts = self._maybe_assert_valid_sample(counts)
    return -distribution_util.log_combinations(self.total_count, counts)

  def _mean(self):
    """Returns the cached mean, `total_count * probs`."""
    return array_ops.identity(self._mean_val)

  def _covariance(self):
    """Returns the covariance matrix of the counts.

    Off-diagonal entries are the negated outer product of the mean with the
    probabilities; the diagonal is overwritten with the variance.
    """
    # Broadcast probs to the full batch shape.
    p = self.probs * array_ops.ones_like(
        self.total_count)[..., array_ops.newaxis]
    # pylint: disable=invalid-unary-operand-type
    return array_ops.matrix_set_diag(
        -math_ops.matmul(
            self._mean_val[..., array_ops.newaxis],
            p[..., array_ops.newaxis, :]),  # outer product
        self._variance())

  def _variance(self):
    """Returns the per-class variance, `mean - mean * probs`."""
    # Broadcast probs to the full batch shape.
    p = self.probs * array_ops.ones_like(
        self.total_count)[..., array_ops.newaxis]
    return self._mean_val - self._mean_val * p

  def _maybe_assert_valid_sample(self, counts):
    """Optionally validates `counts` and returns them.

    When `self.validate_args` is false, `counts` is returned untouched (no
    conversion is performed). Otherwise the returned tensor carries checks
    that `counts` is in non-negative integer form and that it sums to
    `self.total_count` along the last axis.
    """
    if not self.validate_args:
      return counts
    counts = distribution_util.embed_check_nonnegative_integer_form(counts)
    return control_flow_ops.with_dependencies([
        check_ops.assert_equal(
            self.total_count, math_ops.reduce_sum(counts, -1),
            message="counts must sum to `self.total_count`"),
    ], counts)