Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/dirichlet_multinomial.py: 49%

84 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""The DirichletMultinomial distribution class.""" 

16 

17from tensorflow.python.framework import dtypes 

18from tensorflow.python.framework import ops 

19from tensorflow.python.ops import array_ops 

20from tensorflow.python.ops import check_ops 

21from tensorflow.python.ops import control_flow_ops 

22from tensorflow.python.ops import math_ops 

23from tensorflow.python.ops import random_ops 

24from tensorflow.python.ops import special_math_ops 

25from tensorflow.python.ops.distributions import distribution 

26from tensorflow.python.ops.distributions import util as distribution_util 

27from tensorflow.python.util import deprecation 

28from tensorflow.python.util.tf_export import tf_export 

29 

30 

# Public names of this module.
__all__ = [
    "DirichletMultinomial",
]


# Documentation fragment describing the `counts`/`value` argument. It is
# passed to `distribution_util.AppendDocstring` on both
# `DirichletMultinomial._log_prob` and `DirichletMultinomial._prob`, so the
# text is part of the generated documentation and must be edited with care.
_dirichlet_multinomial_sample_note = """For each batch of counts,
`value = [n_0, ..., n_{K-1}]`, `P[value]` is the probability that after
sampling `self.total_count` draws from this Dirichlet-Multinomial distribution,
the number of draws falling in class `j` is `n_j`. Since this definition is
[exchangeable](https://en.wikipedia.org/wiki/Exchangeable_random_variables);
different sequences have the same counts so the probability includes a
combinatorial coefficient.

Note: `value` must be a non-negative tensor with dtype `self.dtype`, have no
fractional components, and such that
`tf.reduce_sum(value, -1) = self.total_count`. Its shape must be broadcastable
with `self.concentration` and `self.total_count`."""

48 

49 

@tf_export(v1=["distributions.DirichletMultinomial"])
class DirichletMultinomial(distribution.Distribution):
  """Dirichlet-Multinomial compound distribution.

  The Dirichlet-Multinomial distribution is parameterized by a (batch of)
  length-`K` `concentration` vectors (`K > 1`) and a `total_count` number of
  trials, i.e., the number of trials per draw from the DirichletMultinomial. It
  is defined over a (batch of) length-`K` vector `counts` such that
  `tf.reduce_sum(counts, -1) = total_count`. The Dirichlet-Multinomial is
  identically the Beta-Binomial distribution when `K = 2`.

  #### Mathematical Details

  The Dirichlet-Multinomial is a distribution over `K`-class counts, i.e., a
  length-`K` vector of non-negative integer `counts = n = [n_0, ..., n_{K-1}]`.

  The probability mass function (pmf) is,

  ```none
  pmf(n; alpha, N) = Beta(alpha + n) / (prod_j n_j!) / Z
  Z = Beta(alpha) / N!
  ```

  where:

  * `concentration = alpha = [alpha_0, ..., alpha_{K-1}]`, `alpha_j > 0`,
  * `total_count = N`, `N` a positive integer,
  * `N!` is `N` factorial, and,
  * `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the
    [multivariate beta function](
    https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
    and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  Dirichlet-Multinomial is a [compound distribution](
  https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e., its
  samples are generated as follows.

    1. Choose class probabilities:
       `probs = [p_0,...,p_{K-1}] ~ Dir(concentration)`
    2. Draw integers:
       `counts = [n_0,...,n_{K-1}] ~ Multinomial(total_count, probs)`

  The last `concentration` dimension parametrizes a single Dirichlet-Multinomial
  distribution. When calling distribution functions (e.g., `dist.prob(counts)`),
  `concentration`, `total_count` and `counts` are broadcast to the same shape.
  The last dimension of `counts` corresponds to a single Dirichlet-Multinomial
  distribution.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  #### Pitfalls

  The number of classes, `K`, must not exceed:

  - the largest integer representable by `self.dtype`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
    tf.float16: 2**11,
    tf.float32: 2**24,
    tf.float64: 2**53 }[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  ```python
  alpha = [1., 2., 3.]
  n = 2.
  dist = DirichletMultinomial(n, alpha)
  ```

  Creates a 3-class distribution in which the 3rd class is most likely to be
  drawn.
  The distribution functions can be evaluated on counts.

  ```python
  # counts same shape as alpha.
  counts = [0., 0., 2.]
  dist.prob(counts)  # Shape []

  # alpha will be broadcast to [[1., 2., 3.], [1., 2., 3.]] to match counts.
  counts = [[1., 1., 0.], [1., 0., 1.]]
  dist.prob(counts)  # Shape [2]

  # alpha will be broadcast to shape [5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7]
  ```

  Creates a 2-batch of 3-class distributions.

  ```python
  alpha = [[1., 2., 3.], [4., 5., 6.]]  # Shape [2, 3]
  n = [3., 3.]
  dist = DirichletMultinomial(n, alpha)

  # counts will be broadcast to [[2., 1., 0.], [2., 1., 0.]] to match alpha.
  counts = [2., 1., 0.]
  dist.prob(counts)  # Shape [2]
  ```

  """

  # TODO(b/27419586) Change docstring for dtype of concentration once int
  # allowed.
  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               total_count,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name="DirichletMultinomial"):
    """Initialize a batch of DirichletMultinomial distributions.

    Args:
      total_count: Non-negative floating point tensor, whose dtype is the same
        as `concentration`. The shape is broadcastable to `[N1,..., Nm]` with
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different
        Dirichlet multinomial distributions. Its components should be equal to
        integer values.
      concentration: Positive floating point tensor, whose dtype is the
        same as `total_count` with shape broadcastable to `[N1,..., Nm, K]`
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different `K`
        class Dirichlet multinomial distributions.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Must stay the first statement: it snapshots exactly the constructor
    # arguments, before any other local name is bound.
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, concentration]) as name:
      # Why plain broadcasting is sufficient here:
      # * Broadcasting only ever prepends size-1 dimensions. The class
      #   dimension is the last one and batch dimensions lead, so the class
      #   dimension can never be conjured up implicitly; the caller has to
      #   supply it explicitly.
      # * Every computation involving `counts` ends up broadcasting `counts`
      #   against `concentration` anyway.
      self._total_count = ops.convert_to_tensor(total_count, name="total_count")
      if validate_args:
        checked_total_count = (
            distribution_util.embed_check_nonnegative_integer_form(
                self._total_count))
        self._total_count = checked_total_count
      raw_concentration = ops.convert_to_tensor(
          concentration, name="concentration")
      self._concentration = self._maybe_assert_valid_concentration(
          raw_concentration, validate_args)
      # Sum over the class dimension; the result carries the batch shape.
      self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
    super(DirichletMultinomial, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._total_count, self._concentration],
        name=name)

  @property
  def total_count(self):
    """Number of trials used to construct a sample."""
    return self._total_count

  @property
  def concentration(self):
    """Concentration parameter; expected prior counts for that coordinate."""
    return self._concentration

  @property
  def total_concentration(self):
    """Sum of last dim of concentration parameter."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    # The batch shape is whatever remains of `concentration` once the class
    # dimension has been summed away.
    summed_concentration = self.total_concentration
    return array_ops.shape(summed_concentration)

  def _batch_shape(self):
    # Static counterpart of `_batch_shape_tensor`.
    summed_concentration = self.total_concentration
    return summed_concentration.get_shape()

  def _event_shape_tensor(self):
    # Keep only the trailing (class) dimension of `concentration`.
    full_shape = array_ops.shape(self.concentration)
    return full_shape[-1:]

  def _event_shape(self):
    # Event shape depends only on `concentration`, not on `total_count`.
    static_shape = self.concentration.get_shape()
    return static_shape.with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
    # Sampling follows the compound construction from the class docstring:
    # draw class weights first, then draw `total_count` class ids from them
    # and tally the ids into per-class counts.
    #
    # NOTE(review): `random_ops.multinomial` documents `num_samples` as a 0-D
    # value, so this path presumably only supports a scalar `total_count` --
    # confirm before relying on batched `total_count` here.
    draws_per_sample = math_ops.cast(self.total_count, dtype=dtypes.int32)
    num_classes = self.event_shape_tensor()[0]
    gamma_variates = random_ops.random_gamma(
        shape=[n],
        alpha=self.concentration,
        dtype=self.dtype,
        seed=seed)
    # Flatten sample and batch dimensions into rows of `num_classes` logits.
    flat_logits = array_ops.reshape(
        math_ops.log(gamma_variates), shape=[-1, num_classes])
    # Derive a second seed so the two random ops do not share one.
    categorical_seed = distribution_util.gen_new_seed(
        seed, salt="dirichlet_multinomial")
    class_ids = random_ops.multinomial(
        logits=flat_logits,
        num_samples=draws_per_sample,
        seed=categorical_seed)
    # One-hot encode every drawn id and sum over the draw axis to get counts.
    one_hot_ids = array_ops.one_hot(class_ids, depth=num_classes)
    tallies = math_ops.reduce_sum(one_hot_ids, -2)
    # Restore the [n] + batch_shape + [K] layout that was flattened above.
    sample_shape = array_ops.concat(
        [[n], self.batch_shape_tensor(), [num_classes]], 0)
    tallies = array_ops.reshape(tallies, sample_shape)
    return math_ops.cast(tallies, self.dtype)

  @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
  def _log_prob(self, counts):
    counts = self._maybe_assert_valid_sample(counts)
    # log pmf = lbeta(alpha + n) - lbeta(alpha) + log(N! / prod_j n_j!).
    log_beta_with_counts = special_math_ops.lbeta(self.concentration + counts)
    log_beta_prior = special_math_ops.lbeta(self.concentration)
    ordered_log_prob = log_beta_with_counts - log_beta_prior
    log_num_orderings = distribution_util.log_combinations(
        self.total_count, counts)
    return ordered_log_prob + log_num_orderings

  @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
  def _prob(self, counts):
    log_pmf = self._log_prob(counts)
    return math_ops.exp(log_pmf)

  def _mean(self):
    # mean_j = total_count * alpha_j / alpha_0.
    # NOTE: `total_count` is multiplied as-is against a tensor shaped like
    # `concentration` (`[..., K]`), so ordinary trailing-dimension
    # broadcasting applies to `total_count` in this method.
    class_fractions = (
        self.concentration /
        self.total_concentration[..., array_ops.newaxis])
    return self.total_count * class_fractions

  @distribution_util.AppendDocstring(
      """The covariance for each batch member is defined as the following:

      ```none
      Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
      (n + alpha_0) / (1 + alpha_0)
      ```

      where `concentration = alpha` and
      `total_concentration = alpha_0 = sum_j alpha_j`.

      The covariance between elements in a batch is defined as:

      ```none
      Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 *
      (n + alpha_0) / (1 + alpha_0)
      ```
      """)
  def _covariance(self):
    scaled_mean = self._variance_scale_term() * self._mean()
    # Off-diagonal entries are minus the outer product of the scaled mean with
    # itself; the diagonal is then overwritten with the variances.
    as_column = scaled_mean[..., array_ops.newaxis]
    as_row = scaled_mean[..., array_ops.newaxis, :]
    # pylint: disable=invalid-unary-operand-type
    negative_outer_product = -math_ops.matmul(as_column, as_row)
    return array_ops.matrix_set_diag(negative_outer_product, self._variance())

  def _variance(self):
    # With s = scale and m = mean, this evaluates (s*m) * (N*s - s*m), i.e.
    # s**2 * m * (N - m).
    scale = self._variance_scale_term()
    scaled_mean = scale * self._mean()
    scaled_total = self.total_count * scale
    return scaled_mean * (scaled_total - scaled_mean)

  def _variance_scale_term(self):
    """Helper to `_covariance` and `_variance` which computes a shared scale.

    Returns:
      `sqrt((1 + alpha_0 / total_count) / (1 + alpha_0))`, where `alpha_0` is
      `total_concentration` with a trailing axis of size 1 re-added.
    """
    # `total_concentration` has lost the class dimension, so it has to be
    # expanded back whenever it is combined with per-class quantities.
    c0 = self.total_concentration[..., array_ops.newaxis]
    numerator = 1. + c0 / self.total_count
    denominator = 1. + c0
    return math_ops.sqrt(numerator / denominator)

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of the concentration parameter.

    Args:
      concentration: `Tensor` of concentration values.
      validate_args: Python `bool`; when false no checks are attached.

    Returns:
      `concentration` unchanged when `validate_args` is false; otherwise
      `concentration` with the event-shape check embedded and a positivity
      assertion attached as a control dependency.
    """
    if validate_args:
      concentration = distribution_util.embed_check_categorical_event_shape(
          concentration)
      positivity_assertion = check_ops.assert_positive(
          concentration,
          message="Concentration parameter must be positive.")
      return control_flow_ops.with_dependencies(
          [positivity_assertion], concentration)
    return concentration

  def _maybe_assert_valid_sample(self, counts):
    """Attaches validity checks to `counts` when validation is enabled.

    Args:
      counts: Counts to validate.

    Returns:
      `counts` untouched when `self.validate_args` is false; otherwise
      `counts` with the non-negative-integer-form check embedded and an
      assertion that its last dimension sums to `self.total_count`.
    """
    if self.validate_args:
      counts = distribution_util.embed_check_nonnegative_integer_form(counts)
      sum_assertion = check_ops.assert_equal(
          self.total_count, math_ops.reduce_sum(counts, -1),
          message="counts last-dimension must sum to `self.total_count`")
      return control_flow_ops.with_dependencies([sum_assertion], counts)
    return counts

353 ], counts)