Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/dirichlet.py: 51% (90 statements)

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Dirichlet distribution class."""

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Dirichlet",
]


_dirichlet_sample_note = """Note: `value` must be a non-negative tensor with
dtype `self.dtype` and be in the `(self.event_shape() - 1)`-simplex, i.e.,
`tf.reduce_sum(value, -1) = 1`. It must have a shape compatible with
`self.batch_shape() + self.event_shape()`."""


@tf_export(v1=["distributions.Dirichlet"])
class Dirichlet(distribution.Distribution):
  """Dirichlet distribution.

  The Dirichlet distribution is defined over the
  [`(k-1)`-simplex](https://en.wikipedia.org/wiki/Simplex) using a positive,
  length-`k` vector `concentration` (`k > 1`). The Dirichlet is identically the
  Beta distribution when `k = 2`.

  #### Mathematical Details

  The Dirichlet is a distribution over the open `(k-1)`-simplex, i.e.,

  ```none
  S^{k-1} = { (x_0, ..., x_{k-1}) in R^k : sum_j x_j = 1 and all_j x_j > 0 }.
  ```

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha) = prod_j x_j**(alpha_j - 1) / Z
  Z = prod_j Gamma(alpha_j) / Gamma(sum_j alpha_j)
  ```

  where:

  * `x in S^{k-1}`, i.e., the `(k-1)`-simplex,
  * `concentration = alpha = [alpha_0, ..., alpha_{k-1}]`, `alpha_j > 0`,
  * `Z` is the normalization constant aka the [multivariate beta function](
    https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
    and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The `concentration` represents mean total counts of class occurrence, i.e.,

  ```none
  concentration = alpha = mean * total_concentration
  ```

  where `mean` is in `S^{k-1}` and `total_concentration` is a positive real
  number representing a mean total count.
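
  For example, `concentration = [1., 2., 3.]` corresponds to
  `mean = [1/6, 2/6, 3/6]` and `total_concentration = 6`.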

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  Warning: Some components of the samples can be zero due to finite precision.
  This happens more often when some of the concentrations are very small.
  Make sure to round the samples to `np.finfo(dtype).tiny` before computing the
  density.
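
  A minimal sketch of such a guard (assuming `dist = tfd.Dirichlet(alpha)` as
  in the examples below):

  ```python
  tiny = np.finfo(np.float32).tiny
  samples = tf.maximum(dist.sample(10), tiny)  # clip exact zeros up to tiny
  log_prob = dist.log_prob(samples)            # finite even near the boundary
  ```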

  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Create a single trivariate Dirichlet, with the 3rd class being three times
  # more frequent than the first. I.e., batch_shape=[], event_shape=[3].
  alpha = [1., 2, 3]
  dist = tfd.Dirichlet(alpha)

  dist.sample([4, 5])  # shape: [4, 5, 3]

  # x has one sample, one batch, three classes:
  x = [.2, .3, .5]  # shape: [3]
  dist.prob(x)  # shape: []

  # x has two samples from one batch:
  x = [[.1, .4, .5],
       [.2, .3, .5]]
  dist.prob(x)  # shape: [2]

  # alpha will be broadcast to shape [5, 7, 3] to match x.
  x = [[...]]  # shape: [5, 7, 3]
  dist.prob(x)  # shape: [5, 7]
  ```

  ```python
  # Create batch_shape=[2], event_shape=[3]:
  alpha = [[1., 2, 3],
           [4, 5, 6]]  # shape: [2, 3]
  dist = tfd.Dirichlet(alpha)

  dist.sample([4, 5])  # shape: [4, 5, 2, 3]

  x = [.2, .3, .5]
  # x will be broadcast as [[.2, .3, .5],
  #                         [.2, .3, .5]],
  # thus matching batch_shape [2, 3].
  dist.prob(x)  # shape: [2]
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  alpha = tf.constant([1.0, 2.0, 3.0])
  dist = tfd.Dirichlet(alpha)
  samples = dist.sample(5)  # Shape [5, 3]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, alpha)
  ```

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name="Dirichlet"):
    """Initialize a batch of Dirichlet distributions.

    Args:
      concentration: Positive floating-point `Tensor` indicating mean number
        of class occurrences; aka "alpha". Implies `self.dtype`, and
        `self.batch_shape`, `self.event_shape`, i.e., if
        `concentration.shape = [N1, N2, ..., Nm, k]` then
        `batch_shape = [N1, N2, ..., Nm]` and
        `event_shape = [k]`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration]) as name:
      self._concentration = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration, name="concentration"),
          validate_args)
      self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
    super(Dirichlet, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration,
                       self._total_concentration],
        name=name)

  @property
  def concentration(self):
    """Concentration parameter; expected counts for that coordinate."""
    return self._concentration

  @property
  def total_concentration(self):
    """Sum of last dim of concentration parameter."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return array_ops.shape(self.concentration)[-1:]

  def _event_shape(self):
    return self.concentration.get_shape().with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
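    # A Dirichlet(alpha) sample can be built from independent Gamma samples:
    # if G_j ~ Gamma(alpha_j, 1), then G / sum_j(G_j) ~ Dirichlet(alpha).
    # `random_gamma` broadcasts against `concentration`, so the result has
    # shape [n] + batch_shape + event_shape.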

    gamma_sample = random_ops.random_gamma(
        shape=[n],
        alpha=self.concentration,
        dtype=self.dtype,
        seed=seed)
    return gamma_sample / math_ops.reduce_sum(gamma_sample, -1, keepdims=True)

  @distribution_util.AppendDocstring(_dirichlet_sample_note)
  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  @distribution_util.AppendDocstring(_dirichlet_sample_note)
  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  def _log_unnormalized_prob(self, x):
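    # Note: `xlogy(c - 1., x)` evaluates to 0 wherever `concentration == 1`,
    # even if the matching sample component is exactly 0 (0 * log(0) -> 0).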

    x = self._maybe_assert_valid_sample(x)
    return math_ops.reduce_sum(math_ops.xlogy(self.concentration - 1., x), -1)

  def _log_normalization(self):
    return special_math_ops.lbeta(self.concentration)

  def _entropy(self):
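    # Differential entropy of Dirichlet(alpha), with alpha0 = sum_j alpha_j:
    #   H = log B(alpha) + (alpha0 - k) * digamma(alpha0)
    #       - sum_j (alpha_j - 1) * digamma(alpha_j)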

    k = math_ops.cast(self.event_shape_tensor()[0], self.dtype)
    return (
        self._log_normalization()
        + ((self.total_concentration - k)
           * math_ops.digamma(self.total_concentration))
        - math_ops.reduce_sum(
            (self.concentration - 1.) * math_ops.digamma(self.concentration),
            axis=-1))

  def _mean(self):
    return self.concentration / self.total_concentration[..., array_ops.newaxis]

  def _covariance(self):
    x = self._variance_scale_term() * self._mean()
    # pylint: disable=invalid-unary-operand-type
    return array_ops.matrix_set_diag(
        -math_ops.matmul(
            x[..., array_ops.newaxis],
            x[..., array_ops.newaxis, :]),  # outer prod
        self._variance())

  def _variance(self):
    scale = self._variance_scale_term()
    x = scale * self._mean()
    return x * (scale - x)

  def _variance_scale_term(self):
    """Helper to `_covariance` and `_variance` which computes a shared scale."""

    return math_ops.rsqrt(1. + self.total_concentration[..., array_ops.newaxis])

  @distribution_util.AppendDocstring(
      """Note: The mode is undefined when any `concentration <= 1`. If
      `self.allow_nan_stats` is `True`, `NaN` is used for undefined modes. If
      `self.allow_nan_stats` is `False` an exception is raised when one or more
      modes are undefined.""")
  def _mode(self):
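    # mode_j = (alpha_j - 1) / (alpha0 - k), defined when every alpha_j > 1.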

    k = math_ops.cast(self.event_shape_tensor()[0], self.dtype)
    mode = (self.concentration - 1.) / (
        self.total_concentration[..., array_ops.newaxis] - k)
    if self.allow_nan_stats:
      nan = array_ops.fill(
          array_ops.shape(mode),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
          name="nan")
      return array_ops.where_v2(
          math_ops.reduce_all(self.concentration > 1., axis=-1), mode, nan)
    return control_flow_ops.with_dependencies([
        check_ops.assert_less(
            array_ops.ones([], self.dtype),
            self.concentration,
            message="Mode undefined when any concentration <= 1"),
    ], mode)

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of the concentration parameter."""
    if not validate_args:
      return concentration
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
        check_ops.assert_rank_at_least(
            concentration, 1,
            message="Concentration parameter must have >=1 dimensions."),
        check_ops.assert_less(
            1, array_ops.shape(concentration)[-1],
            message="Concentration parameter must have event_size >= 2."),
    ], concentration)

  def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample."""
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x, message="samples must be positive"),
        check_ops.assert_near(
            array_ops.ones([], dtype=self.dtype),
            math_ops.reduce_sum(x, -1),
            message="sample last-dimension must sum to `1`"),
    ], x)


@kullback_leibler.RegisterKL(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(d1, d2, name=None):
  """Batchwise KL divergence KL(d1 || d2) with d1 and d2 Dirichlet.

  Args:
    d1: instance of a Dirichlet distribution object.
    d2: instance of a Dirichlet distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_dirichlet_dirichlet".

  Returns:
    Batchwise KL(d1 || d2)
  """
  with ops.name_scope(name, "kl_dirichlet_dirichlet", values=[
      d1.concentration, d2.concentration]):
    # The KL between Dirichlet distributions can be derived as follows. We have
    #
    #   Dir(x; a) = 1 / B(a) * prod_i[x[i]^(a[i] - 1)]
    #
    # where B(a) is the multivariate Beta function:
    #
    #   B(a) = Gamma(a[1]) * ... * Gamma(a[n]) / Gamma(a[1] + ... + a[n])
    #
    # The KL is
    #
    #   KL(Dir(x; a), Dir(x; b)) = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))}
    #
    # so we'll need to know the log density of the Dirichlet. This is
    #
    #   log(Dir(x; a)) = sum_i[(a[i] - 1) log(x[i])] - log B(a)
    #
    # The only term that matters for the expectations is the log(x[i]). To
    # compute the expectation of this term over the Dirichlet density, we can
    # use the following facts about the Dirichlet in exponential family form:
    #   1. log(x[i]) is a sufficient statistic
    #   2. expected sufficient statistics (of any exp family distribution) are
    #      equal to derivatives of the log normalizer with respect to
    #      corresponding natural parameters: E{T[i](x)} = dA/d(eta[i])
    #
    # To proceed, we can rewrite the Dirichlet density in exponential family
    # form as follows:
    #
    #   Dir(x; a) = exp{eta(a) . T(x) - A(a)}
    #
    # where '.' is the dot product of vectors eta and T, and A is a scalar:
    #
    #   eta[i](a) = a[i] - 1
    #   T[i](x) = log(x[i])
    #   A(a) = log B(a)
    #
    # Now, we can use fact (2) above to write
    #
    #   E_Dir(x; a)[log(x[i])]
    #       = dA(a) / da[i]
    #       = d/da[i] log B(a)
    #       = d/da[i] [sum_j lgamma(a[j]) - lgamma(sum_j a[j])]
    #       = digamma(a[i]) - digamma(sum_j a[j])
    #
    # Putting it all together, we have
    #
    #   KL[Dir(x; a) || Dir(x; b)]
    #       = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))}
    #       = E_Dir(x; a){sum_i[(a[i] - b[i]) log(x[i])]} - (lbeta(a) - lbeta(b))
    #       = sum_i[(a[i] - b[i]) * E_Dir(x; a){log(x[i])}] - lbeta(a) + lbeta(b)
    #       = sum_i[(a[i] - b[i]) * (digamma(a[i]) - digamma(sum_j a[j]))]
    #         - lbeta(a) + lbeta(b)

    digamma_sum_d1 = math_ops.digamma(
        math_ops.reduce_sum(d1.concentration, axis=-1, keepdims=True))
    digamma_diff = math_ops.digamma(d1.concentration) - digamma_sum_d1
    concentration_diff = d1.concentration - d2.concentration

    return (math_ops.reduce_sum(concentration_diff * digamma_diff, axis=-1) -
            special_math_ops.lbeta(d1.concentration) +
            special_math_ops.lbeta(d2.concentration))
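
# A quick numerical sanity check of the closed form above (an illustrative
# sketch, not part of this module; assumes SciPy, which is not imported here):
#
#   from scipy.special import digamma, gammaln
#   a, b = np.array([1., 2., 3.]), np.array([2., 2., 2.])
#   lbeta = lambda c: gammaln(c).sum() - gammaln(c.sum())
#   kl = ((a - b) * (digamma(a) - digamma(a.sum()))).sum() - lbeta(a) + lbeta(b)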