Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/categorical.py: 34%

105 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""The Categorical distribution class.""" 

16 

17from tensorflow.python.framework import constant_op 

18from tensorflow.python.framework import dtypes 

19from tensorflow.python.framework import ops 

20from tensorflow.python.framework import tensor_shape 

21from tensorflow.python.ops import array_ops 

22from tensorflow.python.ops import math_ops 

23from tensorflow.python.ops import nn_ops 

24from tensorflow.python.ops import random_ops 

25from tensorflow.python.ops.distributions import distribution 

26from tensorflow.python.ops.distributions import kullback_leibler 

27from tensorflow.python.ops.distributions import util as distribution_util 

28from tensorflow.python.util import deprecation 

29from tensorflow.python.util.tf_export import tf_export 

30 

31 

32def _broadcast_cat_event_and_params(event, params, base_dtype): 

33 """Broadcasts the event or distribution parameters.""" 

34 if event.dtype.is_integer: 

35 pass 

36 elif event.dtype.is_floating: 

37 # When `validate_args=True` we've already ensured int/float casting 

38 # is closed. 

39 event = math_ops.cast(event, dtype=dtypes.int32) 

40 else: 

41 raise TypeError("`value` should have integer `dtype` or " 

42 "`self.dtype` ({})".format(base_dtype)) 

43 shape_known_statically = ( 

44 params.shape.ndims is not None and 

45 params.shape[:-1].is_fully_defined() and 

46 event.shape.is_fully_defined()) 

47 if not shape_known_statically or params.shape[:-1] != event.shape: 

48 params *= array_ops.ones_like(event[..., array_ops.newaxis], 

49 dtype=params.dtype) 

50 params_shape = array_ops.shape(params)[:-1] 

51 event *= array_ops.ones(params_shape, dtype=event.dtype) 

52 if params.shape.ndims is not None: 

53 event.set_shape(tensor_shape.TensorShape(params.shape[:-1])) 

54 

55 return event, params 

56 

57 

@tf_export(v1=["distributions.Categorical"])
class Categorical(distribution.Distribution):
  """Categorical distribution.

  The Categorical distribution is parameterized by either probabilities or
  log-probabilities of a set of `K` classes. It is defined over the integers
  `{0, 1, ..., K-1}`.

  The Categorical distribution is closely related to the `OneHotCategorical` and
  `Multinomial` distributions. The Categorical distribution can be intuited as
  generating samples according to `argmax{ OneHotCategorical(probs) }` itself
  being identical to `argmax{ Multinomial(probs, total_count=1) }`.

  #### Mathematical Details

  The probability mass function (pmf) is,

  ```none
  pmf(k; pi) = prod_j pi_j**[k == j]
  ```

  #### Pitfalls

  The number of classes, `K`, must not exceed:
  - the largest integer representable by the parameters' dtype (`logits` /
    `probs`), i.e., `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
    tf.float16: 2**11,
    tf.float32: 2**24,
    tf.float64: 2**53 }[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  Creates a 3-class distribution with the 2nd class being most likely.

  ```python
  dist = Categorical(probs=[0.1, 0.5, 0.4])
  n = 1e4
  empirical_prob = tf.cast(
      tf.histogram_fixed_width(
        dist.sample(int(n)),
        [0., 2],
        nbins=3),
      dtype=tf.float32) / n
  # ==> array([ 0.1005,  0.5037,  0.3958], dtype=float32)
  ```

  Creates a 3-class distribution with the 2nd class being most likely.
  Parameterized by [logits](https://en.wikipedia.org/wiki/Logit) rather than
  probabilities.

  ```python
  dist = Categorical(logits=np.log([0.1, 0.5, 0.4]))
  n = 1e4
  empirical_prob = tf.cast(
      tf.histogram_fixed_width(
        dist.sample(int(n)),
        [0., 2],
        nbins=3),
      dtype=tf.float32) / n
  # ==> array([0.1045,  0.5047, 0.3908], dtype=float32)
  ```

  Creates a 3-class distribution with the 3rd class being most likely.
  The distribution functions can be evaluated on counts.

  ```python
  # counts is a scalar.
  p = [0.1, 0.4, 0.5]
  dist = Categorical(probs=p)
  dist.prob(0)  # Shape []

  # p will be broadcast to [[0.1, 0.4, 0.5], [0.1, 0.4, 0.5]] to match counts.
  counts = [1, 0]
  dist.prob(counts)  # Shape [2]

  # p will be broadcast to shape [5, 7, 3, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7, 3]
  ```

  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(
      self,
      logits=None,
      probs=None,
      dtype=dtypes.int32,
      validate_args=False,
      allow_nan_stats=True,
      name="Categorical"):
    """Initialize Categorical distributions using class log-probabilities.

    Args:
      logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of logits for each class. Only one of `logits` or
        `probs` should be passed in.
      probs: An N-D `Tensor`, `N >= 1`, representing the probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of probabilities for each class. Only one of
        `logits` or `probs` should be passed in.
      dtype: The type of the event samples (default: int32).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Must be the first statement: captures exactly the constructor arguments.
    parameters = dict(locals())
    with ops.name_scope(name, values=[logits, probs]) as name:
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          validate_args=validate_args,
          multidimensional=True,
          name=name)

      if validate_args:
        # Enforces the `K` limits documented under "Pitfalls".
        self._logits = distribution_util.embed_check_categorical_event_shape(
            self._logits)

      # Batch rank = logits rank - 1; use a constant when the rank is static.
      logits_shape_static = self._logits.get_shape().with_rank_at_least(1)
      if logits_shape_static.ndims is not None:
        self._batch_rank = ops.convert_to_tensor(
            logits_shape_static.ndims - 1,
            dtype=dtypes.int32,
            name="batch_rank")
      else:
        with ops.name_scope(name="batch_rank"):
          self._batch_rank = array_ops.rank(self._logits) - 1

      # Number of classes K = size of the last logits axis.
      logits_shape = array_ops.shape(self._logits, name="logits_shape")
      if tensor_shape.dimension_value(logits_shape_static[-1]) is not None:
        self._event_size = ops.convert_to_tensor(
            logits_shape_static.dims[-1].value,
            dtype=dtypes.int32,
            name="event_size")
      else:
        with ops.name_scope(name="event_size"):
          self._event_size = logits_shape[self._batch_rank]

      # Batch shape = all logits axes except the last.
      if logits_shape_static[:-1].is_fully_defined():
        self._batch_shape_val = constant_op.constant(
            logits_shape_static[:-1].as_list(),
            dtype=dtypes.int32,
            name="batch_shape")
      else:
        with ops.name_scope(name="batch_shape"):
          self._batch_shape_val = logits_shape[:-1]
    super(Categorical, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits,
                       self._probs],
        name=name)

  @property
  def event_size(self):
    """Scalar `int32` tensor: the number of classes."""
    return self._event_size

  @property
  def logits(self):
    """Vector of coordinatewise logits."""
    return self._logits

  @property
  def probs(self):
    """Vector of coordinatewise probabilities."""
    return self._probs

  def _batch_shape_tensor(self):
    # Fresh op wrapping the batch shape computed in `__init__`.
    return array_ops.identity(self._batch_shape_val)

  def _batch_shape(self):
    # Static batch shape: every logits axis except the trailing class axis.
    return self.logits.get_shape()[:-1]

  def _event_shape_tensor(self):
    # Each event is a single class index, i.e. a scalar.
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    # `random_ops.multinomial` takes 2-D logits, so flatten the batch dims
    # into one leading axis of shape [-1, event_size].
    if self.logits.get_shape().ndims == 2:
      logits_2d = self.logits
    else:
      logits_2d = array_ops.reshape(self.logits, [-1, self.event_size])
    # Use a 64-bit output when the event dtype is wider than 4 bytes.
    sample_dtype = dtypes.int64 if self.dtype.size > 4 else dtypes.int32
    draws = random_ops.multinomial(
        logits_2d, n, seed=seed, output_dtype=sample_dtype)
    # Put the sample axis first, then restore the batch shape: [n] + batch.
    draws = array_ops.reshape(
        array_ops.transpose(draws),
        array_ops.concat([[n], self.batch_shape_tensor()], 0))
    return math_ops.cast(draws, self.dtype)

  def _cdf(self, k):
    k = ops.convert_to_tensor(k, name="k")
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=dtypes.int32)

    k, probs = _broadcast_cat_event_and_params(
        k, self.probs, base_dtype=self.dtype.base_dtype)

    # batch-flatten everything in order to use `sequence_mask()`.
    batch_flattened_probs = array_ops.reshape(probs,
                                              (-1, self._event_size))
    batch_flattened_k = array_ops.reshape(k, [-1])

    # NOTE(review): `sequence_mask(k, K)` keeps class indices j < k, so this
    # sums P(X = j) for j < k, i.e. it appears to return P(X < k) rather than
    # the conventional P(X <= k). Left unchanged because callers may depend
    # on it -- confirm intended semantics.
    to_sum_over = array_ops.where(
        array_ops.sequence_mask(batch_flattened_k, self._event_size),
        batch_flattened_probs,
        array_ops.zeros_like(batch_flattened_probs))
    batch_flattened_cdf = math_ops.reduce_sum(to_sum_over, axis=-1)
    # Reshape back to the shape of the argument.
    return array_ops.reshape(batch_flattened_cdf, array_ops.shape(k))

  def _log_prob(self, k):
    k = ops.convert_to_tensor(k, name="k")
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=dtypes.int32)
    k, logits = _broadcast_cat_event_and_params(
        k, self.logits, base_dtype=self.dtype.base_dtype)

    # log p(k) is the negated sparse softmax cross-entropy of label k.
    # pylint: disable=invalid-unary-operand-type
    return -nn_ops.sparse_softmax_cross_entropy_with_logits(
        labels=k,
        logits=logits)

  def _entropy(self):
    # H = -sum_k p_k * log p_k over the class axis.
    return -math_ops.reduce_sum(
        nn_ops.log_softmax(self.logits) * self.probs, axis=-1)

  def _mode(self):
    # `_batch_rank` equals the index of the last (class) axis.
    ret = math_ops.argmax(self.logits, axis=self._batch_rank)
    ret = math_ops.cast(ret, self.dtype)
    ret.set_shape(self.batch_shape)
    return ret

324 

325 

@kullback_leibler.RegisterKL(Categorical, Categorical)
def _kl_categorical_categorical(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Categorical.

  Computes `sum_k p_a(k) * (log p_a(k) - log p_b(k))` over the last (class)
  axis. Log-probabilities come from `log_softmax` of each distribution's
  logits, and `p_a` from `softmax(a.logits)`.

  Args:
    a: instance of a Categorical distribution object.
    b: instance of a Categorical distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_categorical_categorical".

  Returns:
    Batchwise KL(a || b)
  """
  with ops.name_scope(name, "kl_categorical_categorical",
                      values=[a.logits, b.logits]):
    # KL(a || b) = sum_k probs_a * log(probs_a / probs_b)
    #            = sum_k probs_a * (log_softmax(logits_a) - log_softmax(logits_b))
    log_ratio = (nn_ops.log_softmax(a.logits) -
                 nn_ops.log_softmax(b.logits))
    return math_ops.reduce_sum(nn_ops.softmax(a.logits) * log_ratio,
                               axis=-1)