Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/gamma.py: 56%

90 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""The Gamma distribution class.""" 

16 

17import numpy as np 

18 

19from tensorflow.python.framework import constant_op 

20from tensorflow.python.framework import dtypes 

21from tensorflow.python.framework import ops 

22from tensorflow.python.framework import tensor_shape 

23from tensorflow.python.ops import array_ops 

24from tensorflow.python.ops import check_ops 

25from tensorflow.python.ops import control_flow_ops 

26from tensorflow.python.ops import math_ops 

27from tensorflow.python.ops import nn 

28from tensorflow.python.ops import random_ops 

29from tensorflow.python.ops.distributions import distribution 

30from tensorflow.python.ops.distributions import kullback_leibler 

31from tensorflow.python.ops.distributions import util as distribution_util 

32from tensorflow.python.util import deprecation 

33from tensorflow.python.util.tf_export import tf_export 

34 

35 

# Public API of this module: the Gamma distribution and its softplus variant.
__all__ = ["Gamma", "GammaWithSoftplusConcentrationRate"]

40 

41 

@tf_export(v1=["distributions.Gamma"])
class Gamma(distribution.Distribution):
  """Gamma distribution.

  The Gamma distribution is defined over positive real numbers using
  parameters `concentration` (aka "alpha") and `rate` (aka "beta").

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha, beta, x > 0) = x**(alpha - 1) exp(-x beta) / Z
  Z = Gamma(alpha) beta**(-alpha)
  ```

  where:

  * `concentration = alpha`, `alpha > 0`,
  * `rate = beta`, `beta > 0`,
  * `Z` is the normalizing constant, and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The cumulative distribution function (cdf) is,

  ```none
  cdf(x; alpha, beta, x > 0) = GammaInc(alpha, beta x) / Gamma(alpha)
  ```

  where `GammaInc` is the [lower incomplete Gamma function](
  https://en.wikipedia.org/wiki/Incomplete_gamma_function).

  The parameters can be intuited via their relationship to mean and stddev,

  ```none
  concentration = alpha = (mean / stddev)**2
  rate = beta = mean / stddev**2 = concentration / mean
  ```

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  Warning: The samples of this distribution are always non-negative. However,
  samples that are smaller than `np.finfo(dtype).tiny` are rounded up to that
  value, so it appears more often than it should. This should only be
  noticeable when the `concentration` is very small, or the `rate` is very
  large. See the note in the `tf.random.gamma` docstring.

  Samples of this distribution are reparameterized (pathwise differentiable).
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  dist = tfd.Gamma(concentration=3.0, rate=2.0)
  dist2 = tfd.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
  ```

  Compute the gradients of samples w.r.t. the parameters:

  ```python
  concentration = tf.constant(3.0)
  rate = tf.constant(2.0)
  dist = tfd.Gamma(concentration, rate)
  samples = dist.sample(5)  # Shape [5]
  loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
  # Unbiased stochastic gradients of the loss function
  grads = tf.gradients(loss, [concentration, rate])
  ```

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf](http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               concentration,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="Gamma"):
    """Construct Gamma with `concentration` and `rate` parameters.

    The parameters `concentration` and `rate` must be shaped in a way that
    supports broadcasting (e.g. `concentration + rate` is a valid operation).

    Args:
      concentration: Floating point tensor, the concentration params of the
        distribution(s). Must contain only positive values.
      rate: Floating point tensor, the inverse scale params of the
        distribution(s). Must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if `concentration` and `rate` do not share a common floating
        point dtype (raised by `check_ops.assert_same_float_dtype`).
    """
    # Must stay the first statement: snapshot the constructor arguments
    # before any other local name is bound. They are forwarded to the base
    # class as `parameters`.
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration, rate]) as name:
      # With `validate_args`, the identity ops below are created under control
      # dependencies on the positivity assertions.
      with ops.control_dependencies([
          check_ops.assert_positive(concentration),
          check_ops.assert_positive(rate),
      ] if validate_args else []):
        self._concentration = array_ops.identity(
            concentration, name="concentration")
        self._rate = array_ops.identity(rate, name="rate")
        check_ops.assert_same_float_dtype(
            [self._concentration, self._rate])
    super(Gamma, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration,
                       self._rate],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    # Both parameters share the sample shape, as an int32 tensor.
    return dict(
        zip(("concentration", "rate"), ([ops.convert_to_tensor(
            sample_shape, dtype=dtypes.int32)] * 2)))

  @property
  def concentration(self):
    """Concentration parameter."""
    return self._concentration

  @property
  def rate(self):
    """Rate parameter."""
    return self._rate

  def _batch_shape_tensor(self):
    # Dynamic batch shape: broadcast of the two parameter shapes.
    return array_ops.broadcast_dynamic_shape(
        array_ops.shape(self.concentration),
        array_ops.shape(self.rate))

  def _batch_shape(self):
    # Static counterpart of `_batch_shape_tensor`.
    return array_ops.broadcast_static_shape(
        self.concentration.get_shape(),
        self.rate.get_shape())

  def _event_shape_tensor(self):
    # Events are scalars: empty int32 shape vector.
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  @distribution_util.AppendDocstring(
      """Note: See `tf.random.gamma` docstring for sampling details and
      caveats.""")
  def _sample_n(self, n, seed=None):
    return random_ops.random_gamma(
        shape=[n],
        alpha=self.concentration,
        beta=self.rate,
        dtype=self.dtype,
        seed=seed)

  def _log_prob(self, x):
    # log pdf = unnormalized log density - log Z.
    return self._log_unnormalized_prob(x) - self._log_normalization()

  def _cdf(self, x):
    x = self._maybe_assert_valid_sample(x)
    # Note that igamma returns the regularized incomplete gamma function,
    # which is what we want for the CDF.
    return math_ops.igamma(self.concentration, self.rate * x)

  def _log_unnormalized_prob(self, x):
    x = self._maybe_assert_valid_sample(x)
    # (alpha - 1) * log(x) - beta * x. `xlogy` yields 0 when its first argument
    # is 0, so concentration == 1 stays finite at x == 0.
    return math_ops.xlogy(self.concentration - 1., x) - self.rate * x

  def _log_normalization(self):
    # log Z = lgamma(alpha) - alpha * log(beta).
    return (math_ops.lgamma(self.concentration)
            - self.concentration * math_ops.log(self.rate))

  def _entropy(self):
    # alpha - log(beta) + lgamma(alpha) + (1 - alpha) * digamma(alpha).
    return (self.concentration
            - math_ops.log(self.rate)
            + math_ops.lgamma(self.concentration)
            + ((1. - self.concentration) *
               math_ops.digamma(self.concentration)))

  def _mean(self):
    return self.concentration / self.rate

  def _variance(self):
    return self.concentration / math_ops.square(self.rate)

  def _stddev(self):
    return math_ops.sqrt(self.concentration) / self.rate

  @distribution_util.AppendDocstring(
      """The mode of a gamma distribution is `(concentration - 1) / rate` when
      `concentration > 1`, and `NaN` otherwise. If `self.allow_nan_stats` is
      `False`, an exception will be raised rather than returning `NaN`.""")
  def _mode(self):
    mode = (self.concentration - 1.) / self.rate
    if self.allow_nan_stats:
      # Pass the numpy scalar type itself as `dtype`. Calling it would build a
      # throwaway zero scalar that numpy then has to coerce back to a dtype.
      nan = array_ops.fill(
          self.batch_shape_tensor(),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype),
          name="nan")
      return array_ops.where_v2(self.concentration > 1., mode, nan)
    else:
      # The returned mode depends on an assertion that every concentration
      # exceeds 1.
      return control_flow_ops.with_dependencies([
          check_ops.assert_less(
              array_ops.ones([], self.dtype),
              self.concentration,
              message="mode not defined when any concentration <= 1"),
      ], mode)

  def _maybe_assert_valid_sample(self, x):
    """Check `x` has this distribution's dtype; assert `x > 0` if validating.

    The dtype check always runs. The positivity assertion is attached as a
    control dependency only when `validate_args` is set; otherwise `x` is
    returned unchanged.
    """
    check_ops.assert_same_float_dtype(tensors=[x], dtype=self.dtype)
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x),
    ], x)

285 

286 

class GammaWithSoftplusConcentrationRate(Gamma):
  """`Gamma` parameterized by the softplus of `concentration` and `rate`.

  Both raw inputs are passed through `nn.softplus` before being handed to
  `Gamma`. Unconstrained real-valued inputs therefore yield positive
  parameters.
  """

  @deprecation.deprecated(
      "2019-01-01",
      "Use `tfd.Gamma(tf.nn.softplus(concentration), "
      "tf.nn.softplus(rate))` instead.",
      warn_once=True)
  def __init__(self,
               concentration,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="GammaWithSoftplusConcentrationRate"):
    """Build a `Gamma` from softplus-transformed `concentration` and `rate`.

    Args:
      concentration: Floating point tensor; softplus of it is used as the
        `Gamma` concentration.
      rate: Floating point tensor; softplus of it is used as the `Gamma` rate.
      validate_args: Python `bool`, forwarded to `Gamma`.
      allow_nan_stats: Python `bool`, forwarded to `Gamma`.
      name: Python `str` name scope for the created ops.
    """
    # Must stay the first statement so only the call arguments are captured.
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration, rate]) as scope_name:
      positive_concentration = nn.softplus(concentration,
                                           name="softplus_concentration")
      positive_rate = nn.softplus(rate, name="softplus_rate")
      super().__init__(
          concentration=positive_concentration,
          rate=positive_rate,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=scope_name)
    # Report this subclass's own (pre-softplus) arguments, not the ones
    # recorded by `Gamma.__init__`.
    self._parameters = parameters

311 

312 

@kullback_leibler.RegisterKL(Gamma, Gamma)
def _kl_gamma_gamma(g0, g1, name=None):
  """Compute the batchwise KL divergence KL(g0 || g1) between two Gammas.

  Args:
    g0: instance of a Gamma distribution object (the "from" distribution).
    g1: instance of a Gamma distribution object (the "to" distribution).
    name: (optional) Name to use for created operations.
      Default is "kl_gamma_gamma".

  Returns:
    kl_gamma_gamma: `Tensor`. The batchwise KL(g0 || g1).
  """
  with ops.name_scope(name, "kl_gamma_gamma", values=[
      g0.concentration, g0.rate, g1.concentration, g1.rate]):
    # Closed form taken from:
    # http://www.fil.ion.ucl.ac.uk/~wpenny/publications/densities.ps
    # Derivation:
    # http://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions pylint: disable=line-too-long
    #
    #   KL = (a0 - a1) digamma(a0) + lgamma(a1) - lgamma(a0)
    #        + a1 log(b0) - a1 log(b1) + a0 (b1 / b0 - 1)
    alpha0, beta0 = g0.concentration, g0.rate
    alpha1, beta1 = g1.concentration, g1.rate
    # Accumulate term by term, left to right, in the same order as the
    # closed form above.
    kl = (alpha0 - alpha1) * math_ops.digamma(alpha0)
    kl = kl + math_ops.lgamma(alpha1)
    kl = kl - math_ops.lgamma(alpha0)
    kl = kl + alpha1 * math_ops.log(beta0)
    kl = kl - alpha1 * math_ops.log(beta1)
    kl = kl + alpha0 * (beta1 / beta0 - 1.)
    return kl