Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/categorical.py: 34%
105 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The Categorical distribution class."""
17from tensorflow.python.framework import constant_op
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import ops
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import math_ops
23from tensorflow.python.ops import nn_ops
24from tensorflow.python.ops import random_ops
25from tensorflow.python.ops.distributions import distribution
26from tensorflow.python.ops.distributions import kullback_leibler
27from tensorflow.python.ops.distributions import util as distribution_util
28from tensorflow.python.util import deprecation
29from tensorflow.python.util.tf_export import tf_export
def _broadcast_cat_event_and_params(event, params, base_dtype):
  """Casts a categorical `event` and broadcasts it against `params`.

  Integer events are left untouched; floating-point events are cast to
  `int32`. Unless the static shapes already prove that the batch shape of
  `params` (all but its last dimension) equals the shape of `event`, both
  tensors are multiplied by ones shaped like the other so that they share a
  common broadcast shape.

  Args:
    event: `Tensor` of class indices with an integer or floating dtype.
    params: `Tensor` of per-class parameters (the callers in this file pass
      `probs` or `logits`); its last dimension indexes the classes.
    base_dtype: dtype quoted in the error message.

  Returns:
    The tuple `(event, params)` after casting/broadcasting.

  Raises:
    TypeError: if `event` has neither an integer nor a floating dtype.
  """
  event_dtype = event.dtype
  if not (event_dtype.is_integer or event_dtype.is_floating):
    raise TypeError("`value` should have integer `dtype` or "
                    "`self.dtype` ({})".format(base_dtype))
  if not event_dtype.is_integer:
    # Floating event. When `validate_args=True` we've already ensured
    # int/float casting is closed.
    event = math_ops.cast(event, dtype=dtypes.int32)

  # Broadcast unless every static shape is known and they already agree.
  needs_broadcast = (
      params.shape.ndims is None or
      not params.shape[:-1].is_fully_defined() or
      not event.shape.is_fully_defined() or
      params.shape[:-1] != event.shape)
  if needs_broadcast:
    params *= array_ops.ones_like(event[..., array_ops.newaxis],
                                  dtype=params.dtype)
    batch_shape = array_ops.shape(params)[:-1]
    event *= array_ops.ones(batch_shape, dtype=event.dtype)
    if params.shape.ndims is not None:
      # Recover static shape information lost by the dynamic multiply.
      event.set_shape(tensor_shape.TensorShape(params.shape[:-1]))

  return event, params
@tf_export(v1=["distributions.Categorical"])
class Categorical(distribution.Distribution):
  """Categorical distribution.

  The Categorical distribution is parameterized by either probabilities or
  log-probabilities of a set of `K` classes. It is defined over the integers
  `{0, 1, ..., K-1}`.

  The Categorical distribution is closely related to the `OneHotCategorical` and
  `Multinomial` distributions. The Categorical distribution can be intuited as
  generating samples according to `argmax{ OneHotCategorical(probs) }` itself
  being identical to `argmax{ Multinomial(probs, total_count=1) }`.

  #### Mathematical Details

  The probability mass function (pmf) is,

  ```none
  pmf(k; pi) = prod_j pi_j**[k == j]
  ```

  #### Pitfalls

  The number of classes, `K`, must not exceed:
  - the largest integer representable by the dtype of `logits`/`probs`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
    tf.float16: 2**11,
    tf.float32: 2**24,
    tf.float64: 2**53 }[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  Creates a 3-class distribution with the 2nd class being most likely.

  ```python
  dist = Categorical(probs=[0.1, 0.5, 0.4])
  n = 1e4
  empirical_prob = tf.cast(
      tf.histogram_fixed_width(
        dist.sample(int(n)),
        [0., 2],
        nbins=3),
      dtype=tf.float32) / n
  # ==> array([ 0.1005,  0.5037,  0.3958], dtype=float32)
  ```

  Creates a 3-class distribution with the 2nd class being most likely.
  Parameterized by [logits](https://en.wikipedia.org/wiki/Logit) rather than
  probabilities.

  ```python
  dist = Categorical(logits=np.log([0.1, 0.5, 0.4]))
  n = 1e4
  empirical_prob = tf.cast(
      tf.histogram_fixed_width(
        dist.sample(int(n)),
        [0., 2],
        nbins=3),
      dtype=tf.float32) / n
  # ==> array([0.1045,  0.5047, 0.3908], dtype=float32)
  ```

  Creates a 3-class distribution with the 3rd class being most likely.
  The distribution functions can be evaluated on (batches of) class indices,
  called `counts` below.

  ```python
  # counts is a scalar.
  p = [0.1, 0.4, 0.5]
  dist = Categorical(probs=p)
  dist.prob(0)  # Shape []

  # p will be broadcast to [[0.1, 0.4, 0.5], [0.1, 0.4, 0.5]] to match counts.
  counts = [1, 0]
  dist.prob(counts)  # Shape [2]

  # p will be broadcast to shape [5, 7, 3, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7, 3]
  ```

  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(
      self,
      logits=None,
      probs=None,
      dtype=dtypes.int32,
      validate_args=False,
      allow_nan_stats=True,
      name="Categorical"):
    """Initialize Categorical distributions using class log-probabilities or probabilities.

    Args:
      logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of logits for each class. Only one of `logits` or
        `probs` should be passed in.
      probs: An N-D `Tensor`, `N >= 1`, representing the probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of probabilities for each class. Only one of
        `logits` or `probs` should be passed in.
      dtype: The type of the event samples (default: int32).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[logits, probs]) as name:
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          validate_args=validate_args,
          multidimensional=True,
          name=name)

      if validate_args:
        # Checks the number of classes against the limits in "Pitfalls".
        self._logits = distribution_util.embed_check_categorical_event_shape(
            self._logits)

      # Batch rank = rank(logits) - 1; use a constant when known statically.
      logits_shape_static = self._logits.get_shape().with_rank_at_least(1)
      if logits_shape_static.ndims is not None:
        self._batch_rank = ops.convert_to_tensor(
            logits_shape_static.ndims - 1,
            dtype=dtypes.int32,
            name="batch_rank")
      else:
        with ops.name_scope(name="batch_rank"):
          self._batch_rank = array_ops.rank(self._logits) - 1

      # Event size = number of classes = size of the last logits dimension.
      logits_shape = array_ops.shape(self._logits, name="logits_shape")
      if tensor_shape.dimension_value(logits_shape_static[-1]) is not None:
        self._event_size = ops.convert_to_tensor(
            logits_shape_static.dims[-1].value,
            dtype=dtypes.int32,
            name="event_size")
      else:
        with ops.name_scope(name="event_size"):
          self._event_size = logits_shape[self._batch_rank]

      # Batch shape = every logits dimension except the last.
      if logits_shape_static[:-1].is_fully_defined():
        self._batch_shape_val = constant_op.constant(
            logits_shape_static[:-1].as_list(),
            dtype=dtypes.int32,
            name="batch_shape")
      else:
        with ops.name_scope(name="batch_shape"):
          self._batch_shape_val = logits_shape[:-1]
    super(Categorical, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits,
                       self._probs],
        name=name)

  @property
  def event_size(self):
    """Scalar `int32` tensor: the number of classes."""
    return self._event_size

  @property
  def logits(self):
    """Tensor of logits; the last dimension indexes the classes."""
    return self._logits

  @property
  def probs(self):
    """Tensor of probabilities; the last dimension indexes the classes."""
    return self._probs

  def _batch_shape_tensor(self):
    """Dynamic batch shape: all but the last dimension of `logits`."""
    return array_ops.identity(self._batch_shape_val)

  def _batch_shape(self):
    """Static batch shape: all but the last dimension of `logits`."""
    return self.logits.get_shape()[:-1]

  def _event_shape_tensor(self):
    """Events are scalars, so the event shape is empty."""
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    """Events are scalars, so the static event shape is `[]`."""
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    """Draws `n` samples per batch member; result shape `[n] + batch_shape`."""
    # `multinomial` expects rank-2 logits: flatten all batch dims into one.
    if self.logits.get_shape().ndims == 2:
      logits_2d = self.logits
    else:
      logits_2d = array_ops.reshape(self.logits, [-1, self.event_size])
    # Sample in int64 when the requested dtype is wider than 4 bytes.
    sample_dtype = dtypes.int64 if self.dtype.size > 4 else dtypes.int32
    draws = random_ops.multinomial(
        logits_2d, n, seed=seed, output_dtype=sample_dtype)
    # `draws` is [flat_batch, n]; move samples first, then restore batch dims.
    draws = array_ops.reshape(
        array_ops.transpose(draws),
        array_ops.concat([[n], self.batch_shape_tensor()], 0))
    return math_ops.cast(draws, self.dtype)

  def _cdf(self, k):
    """Sums `probs` over the classes selected by `sequence_mask(k, K)`."""
    k = ops.convert_to_tensor(k, name="k")
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=dtypes.int32)

    k, probs = _broadcast_cat_event_and_params(
        k, self.probs, base_dtype=self.dtype.base_dtype)

    # batch-flatten everything in order to use `sequence_mask()`.
    batch_flattened_probs = array_ops.reshape(probs,
                                              (-1, self._event_size))
    batch_flattened_k = array_ops.reshape(k, [-1])

    # `sequence_mask(k, K)` is True for class indices j < k, so this sums the
    # probabilities of classes strictly below k, i.e. P[X < k].
    # NOTE(review): P[X <= k] would need `k + 1`; confirm which convention
    # callers rely on before changing it.
    to_sum_over = array_ops.where(
        array_ops.sequence_mask(batch_flattened_k, self._event_size),
        batch_flattened_probs,
        array_ops.zeros_like(batch_flattened_probs))
    batch_flattened_cdf = math_ops.reduce_sum(to_sum_over, axis=-1)
    # Reshape back to the shape of the argument.
    return array_ops.reshape(batch_flattened_cdf, array_ops.shape(k))

  def _log_prob(self, k):
    """Log-pmf of class index `k`, via negated sparse softmax cross-entropy."""
    k = ops.convert_to_tensor(k, name="k")
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=dtypes.int32)
    k, logits = _broadcast_cat_event_and_params(
        k, self.logits, base_dtype=self.dtype.base_dtype)

    # pylint: disable=invalid-unary-operand-type
    return -nn_ops.sparse_softmax_cross_entropy_with_logits(
        labels=k,
        logits=logits)

  def _entropy(self):
    """Entropy: `-sum(probs * log_softmax(logits))` over the class axis."""
    return -math_ops.reduce_sum(
        nn_ops.log_softmax(self.logits) * self.probs, axis=-1)

  def _mode(self):
    """Most likely class index for each batch member."""
    # `_batch_rank` is rank(logits) - 1, i.e. the last (class) axis.
    ret = math_ops.argmax(self.logits, axis=self._batch_rank)
    ret = math_ops.cast(ret, self.dtype)
    ret.set_shape(self.batch_shape)
    return ret
@kullback_leibler.RegisterKL(Categorical, Categorical)
def _kl_categorical_categorical(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Categorical.

  Registered with `kullback_leibler.RegisterKL` for the
  (Categorical, Categorical) pair.

  Args:
    a: instance of a Categorical distribution object.
    b: instance of a Categorical distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_categorical_categorical".

  Returns:
    Batchwise KL(a || b), reduced over the last (class) axis of the logits.
  """
  with ops.name_scope(name, "kl_categorical_categorical",
                      values=[a.logits, b.logits]):
    # KL(a || b) = sum_k p_a[k] * (log p_a[k] - log p_b[k]), with the
    # log-probabilities obtained from the log-softmax of each set of logits.
    log_prob_ratio = (nn_ops.log_softmax(a.logits) -
                      nn_ops.log_softmax(b.logits))
    return math_ops.reduce_sum(nn_ops.softmax(a.logits) * log_prob_ratio,
                               axis=-1)