# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Beta distribution class."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Beta",
    "BetaWithSoftplusConcentration",
]


_beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in
`[0, 1]`. It must have a shape compatible with `self.batch_shape`."""


@tf_export(v1=["distributions.Beta"])
class Beta(distribution.Distribution):
  """Beta distribution.

  The Beta distribution is defined over the `(0, 1)` interval using parameters
  `concentration1` (aka "alpha") and `concentration0` (aka "beta").

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
  Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
  ```

  where:

  * `concentration1 = alpha`,
  * `concentration0 = beta`,
  * `Z` is the normalization constant, and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The concentration parameters represent mean total counts of a `1` or a `0`,
  i.e.,

  ```none
  concentration1 = alpha = mean * total_concentration
  concentration0 = beta = (1. - mean) * total_concentration
  ```

  where `mean` is in `(0, 1)` and `total_concentration` is a positive real
  number representing a mean `total_count = concentration1 + concentration0`.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  Warning: The samples can be zero due to finite precision.
  This happens more often when some of the concentrations are very small.
  Make sure to round the samples to `np.finfo(dtype).tiny` before computing the
  density.
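
  For example (a minimal sketch, not from the original docstring; assumes
  `tf` is TensorFlow, `np` is NumPy, and `dist` is a `Beta` instance):

  ```python
  samples = dist.sample(10)
  # Nudge exact zeros up to the smallest positive normal float before
  # evaluating the density.
  tiny = np.finfo(samples.dtype.as_numpy_dtype).tiny
  log_prob = dist.log_prob(tf.maximum(samples, tiny))
  ```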

  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Create a batch of three Beta distributions.
  alpha = [1, 2, 3]
  beta = [1, 2, 3]
  dist = tfd.Beta(alpha, beta)

  dist.sample([4, 5])  # Shape [4, 5, 3]

  # `x` has three batch entries, each with two samples.
  x = [[.1, .4, .5],
       [.2, .3, .5]]
  # Calculate the probability of each pair of samples under the corresponding
  # distribution in `dist`.
  dist.prob(x)  # Shape [2, 3]
  ```
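
  Closed-form statistics follow directly from the parameters (a small
  illustrative example, not part of the original docstring; values shown are
  the exact closed-form results):

  ```python
  dist = tfd.Beta(concentration1=2., concentration0=3.)
  dist.mean()      # alpha / (alpha + beta) = 0.4
  dist.variance()  # mean * (1 - mean) / (1 + alpha + beta) = 0.04
  dist.mode()      # (alpha - 1) / (alpha + beta - 2) = 1/3
  ```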

  ```python
  # Create batch_shape=[2, 3] via parameter broadcast:
  alpha = [[1.], [2]]  # Shape [2, 1]
  beta = [3., 4, 5]  # Shape [3]
  dist = tfd.Beta(alpha, beta)

  # alpha broadcast as: [[1., 1, 1],
  #                      [2, 2, 2]]
  # beta broadcast as:  [[3., 4, 5],
  #                      [3, 4, 5]]
  # batch_shape [2, 3]
  dist.sample([4, 5])  # Shape [4, 5, 2, 3]

  x = [.2, .3, .5]
  # x will be broadcast as [[.2, .3, .5],
  #                         [.2, .3, .5]],
  # thus matching batch_shape [2, 3].
  dist.prob(x)  # Shape [2, 3]
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  alpha = tf.constant(1.0)
  beta = tf.constant(2.0)
  dist = tfd.Beta(alpha, beta)
  samples = dist.sample(5)  # Shape [5]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, [alpha, beta])
  ```

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               concentration1=None,
               concentration0=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Beta"):
    """Initialize a batch of Beta distributions.

    Args:
      concentration1: Positive floating-point `Tensor` indicating mean
        number of successes; aka "alpha". Implies `self.dtype` and
        `self.batch_shape`, i.e.,
        `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
      concentration0: Positive floating-point `Tensor` indicating mean
        number of failures; aka "beta". Otherwise has same semantics as
        `concentration1`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration1, concentration0]) as name:
      self._concentration1 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration1, name="concentration1"),
          validate_args)
      self._concentration0 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration0, name="concentration0"),
          validate_args)
      check_ops.assert_same_float_dtype([
          self._concentration1, self._concentration0])
      self._total_concentration = self._concentration1 + self._concentration0
    super(Beta, self).__init__(
        dtype=self._total_concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration1,
                       self._concentration0,
                       self._total_concentration],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    return dict(zip(
        ["concentration1", "concentration0"],
        [ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2))

  @property
  def concentration1(self):
    """Concentration parameter associated with a `1` outcome."""
    return self._concentration1

  @property
  def concentration0(self):
    """Concentration parameter associated with a `0` outcome."""
    return self._concentration0

  @property
  def total_concentration(self):
    """Sum of concentration parameters."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    expanded_concentration1 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration1
    expanded_concentration0 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration0
    gamma1_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration1,
        dtype=self.dtype,
        seed=seed)
    gamma2_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration0,
        dtype=self.dtype,
        seed=distribution_util.gen_new_seed(seed, "beta"))
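    # If G1 ~ Gamma(concentration1, 1) and G2 ~ Gamma(concentration0, 1) are
    # independent, then G1 / (G1 + G2) ~ Beta(concentration1, concentration0),
    # which is the standard Gamma-ratio construction used here.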

    beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
    return beta_sample

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_cdf(self, x):
    return math_ops.log(self._cdf(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _cdf(self, x):
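    # The Beta CDF is the regularized incomplete beta function I_x(a, b),
    # which `math_ops.betainc` computes.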

    return math_ops.betainc(self.concentration1, self.concentration0, x)

  def _log_unnormalized_prob(self, x):
    x = self._maybe_assert_valid_sample(x)
    return (math_ops.xlogy(self.concentration1 - 1., x) +
            (self.concentration0 - 1.) * math_ops.log1p(-x))  # pylint: disable=invalid-unary-operand-type

  def _log_normalization(self):
    return (math_ops.lgamma(self.concentration1)
            + math_ops.lgamma(self.concentration0)
            - math_ops.lgamma(self.total_concentration))

  def _entropy(self):
    return (
        self._log_normalization()
        - (self.concentration1 - 1.) * math_ops.digamma(self.concentration1)
        - (self.concentration0 - 1.) * math_ops.digamma(self.concentration0)
        + ((self.total_concentration - 2.) *
           math_ops.digamma(self.total_concentration)))

  def _mean(self):
    return self._concentration1 / self._total_concentration

  def _variance(self):
    return self._mean() * (1. - self._mean()) / (1. + self.total_concentration)

  @distribution_util.AppendDocstring(
      """Note: The mode is undefined when `concentration1 <= 1` or
      `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN`
      is used for undefined modes. If `self.allow_nan_stats` is `False` an
      exception is raised when one or more modes are undefined.""")
  def _mode(self):
    mode = (self.concentration1 - 1.) / (self.total_concentration - 2.)
    if self.allow_nan_stats:
      nan = array_ops.fill(
          self.batch_shape_tensor(),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
          name="nan")
      is_defined = math_ops.logical_and(self.concentration1 > 1.,
                                        self.concentration0 > 1.)
      return array_ops.where_v2(is_defined, mode, nan)
    return control_flow_ops.with_dependencies([
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration1,
            message="Mode undefined for concentration1 <= 1."),
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration0,
            message="Mode undefined for concentration0 <= 1.")
    ], mode)

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of a concentration parameter."""
    if not validate_args:
      return concentration
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)

  def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample."""
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x, message="sample must be positive"),
        check_ops.assert_less(
            x,
            array_ops.ones([], self.dtype),
            message="sample must be less than `1`."),
    ], x)


class BetaWithSoftplusConcentration(Beta):
  """Beta with softplus transform of `concentration1` and `concentration0`."""

  @deprecation.deprecated(
      "2019-01-01",
      "Use `tfd.Beta(tf.nn.softplus(concentration1), "
      "tf.nn.softplus(concentration0))` instead.",
      warn_once=True)
  def __init__(self,
               concentration1,
               concentration0,
               validate_args=False,
               allow_nan_stats=True,
               name="BetaWithSoftplusConcentration"):
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration1,
                                      concentration0]) as name:
      super(BetaWithSoftplusConcentration, self).__init__(
          concentration1=nn.softplus(concentration1,
                                     name="softplus_concentration1"),
          concentration0=nn.softplus(concentration0,
                                     name="softplus_concentration0"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters


@kullback_leibler.RegisterKL(Beta, Beta)
def _kl_beta_beta(d1, d2, name=None):
  """Calculate the batchwise KL divergence KL(d1 || d2) with d1 and d2 Beta.

  Args:
    d1: instance of a Beta distribution object.
    d2: instance of a Beta distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_beta_beta".

  Returns:
    Batchwise KL(d1 || d2)
  """
  def delta(fn, is_property=True):
    fn1 = getattr(d1, fn)
    fn2 = getattr(d2, fn)
    return (fn2 - fn1) if is_property else (fn2() - fn1())
  with ops.name_scope(name, "kl_beta_beta", values=[
      d1.concentration1,
      d1.concentration0,
      d1.total_concentration,
      d2.concentration1,
      d2.concentration0,
      d2.total_concentration,
  ]):
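    # Closed form, with B the beta function, psi the digamma function,
    # a = concentration1, b = concentration0, and t = a + b:
    #   KL(d1 || d2) = log B(a2, b2) - log B(a1, b1)
    #                  + (a1 - a2) psi(a1) + (b1 - b2) psi(b1)
    #                  + (t2 - t1) psi(t1)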

    return (delta("_log_normalization", is_property=False)
            - math_ops.digamma(d1.concentration1) * delta("concentration1")
            - math_ops.digamma(d1.concentration0) * delta("concentration0")
            + (math_ops.digamma(d1.total_concentration)
               * delta("total_concentration")))
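

# Example usage (an illustrative sketch, not part of the original module;
# assumes a TF1-style graph or eager context):
#
#   d1 = Beta(concentration1=1., concentration0=2.)
#   d2 = Beta(concentration1=3., concentration0=4.)
#   kl = kullback_leibler.kl_divergence(d1, d2)  # dispatches to _kl_beta_beta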