Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/bernoulli.py: 52%
67 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The Bernoulli distribution class."""
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import ops
19from tensorflow.python.framework import tensor_shape
20from tensorflow.python.ops import array_ops
21from tensorflow.python.ops import math_ops
22from tensorflow.python.ops import nn
23from tensorflow.python.ops import random_ops
24from tensorflow.python.ops.distributions import distribution
25from tensorflow.python.ops.distributions import kullback_leibler
26from tensorflow.python.ops.distributions import util as distribution_util
27from tensorflow.python.util import deprecation
28from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=["distributions.Bernoulli"])
class Bernoulli(distribution.Distribution):
  """Bernoulli distribution over the outcomes `0` and `1`.

  The distribution is described by `probs`, the probability of a `1` outcome
  (as opposed to a `0` outcome), or alternatively by `logits`, the log-odds of
  a `1` outcome.
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               logits=None,
               probs=None,
               dtype=dtypes.int32,
               validate_args=False,
               allow_nan_stats=True,
               name="Bernoulli"):
    """Construct Bernoulli distributions.

    Args:
      logits: An N-D `Tensor` holding the log-odds of a `1` event. Every entry
        parametrizes an independent Bernoulli distribution whose event
        probability is sigmoid(logits). Pass only one of `logits` or `probs`.
      probs: An N-D `Tensor` holding the probability of a `1` event. Every
        entry parameterizes an independent Bernoulli distribution. Pass only
        one of `logits` or `probs`.
      dtype: The type of the event samples. Default: `int32`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither is
        passed.
    """
    # Must stay the first statement: it snapshots exactly the constructor
    # arguments, before any other local name is bound.
    parameters = dict(locals())
    with ops.name_scope(name) as name:
      logits_and_probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          validate_args=validate_args,
          name=name)
      self._logits, self._probs = logits_and_probs
    super(Bernoulli, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits, self._probs],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    """Map the `logits` parameter name to `sample_shape` as an int32 tensor."""
    logits_shape = ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)
    return {"logits": logits_shape}

  @property
  def logits(self):
    """Log-odds of a `1` outcome (vs `0`)."""
    return self._logits

  @property
  def probs(self):
    """Probability of a `1` outcome (vs `0`)."""
    return self._probs

  def _batch_shape_tensor(self):
    """Dynamic batch shape, read from the shape of the logits tensor."""
    return array_ops.shape(self._logits)

  def _batch_shape(self):
    """Static batch shape, read from the logits tensor."""
    return self._logits.get_shape()

  def _event_shape_tensor(self):
    """Dynamic event shape: an empty int32 constant."""
    return array_ops.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    """Static event shape: an empty `TensorShape`."""
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    """Draw `n` samples by thresholding uniform noise against `probs`."""
    # Leading sample dimension `n` followed by the batch dimensions.
    draw_shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    noise = random_ops.random_uniform(
        draw_shape, seed=seed, dtype=self.probs.dtype)
    is_one = math_ops.less(noise, self.probs)
    return math_ops.cast(is_one, self.dtype)

  def _log_prob(self, event):
    """Log-probability of `event` as negated sigmoid cross entropy."""
    if self.validate_args:
      event = distribution_util.embed_check_integer_casting_closed(
          event, target_dtype=dtypes.bool)

    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent behavior for logits = inf/-inf.
    event = math_ops.cast(event, self.logits.dtype)
    logits = self.logits

    event_shape = event.get_shape()
    logits_shape = logits.get_shape()
    shapes_known_equal = (event_shape.is_fully_defined() and
                          logits_shape.is_fully_defined() and
                          event_shape == logits_shape)
    if not shapes_known_equal:
      # sigmoid_cross_entropy_with_logits doesn't broadcast shape, so both
      # operands are expanded against each other here. The right-hand tuple is
      # fully evaluated from the pre-broadcast values before rebinding.
      logits, event = (array_ops.ones_like(event) * logits,
                       array_ops.ones_like(logits) * event)
    cross_entropy = nn.sigmoid_cross_entropy_with_logits(
        labels=event, logits=logits)
    return -cross_entropy

  def _entropy(self):
    """Entropy of the distribution, computed from the logits."""
    logits = self.logits
    return (-logits * (math_ops.sigmoid(logits) - 1) +  # pylint: disable=invalid-unary-operand-type
            nn.softplus(-logits))  # pylint: disable=invalid-unary-operand-type

  def _mean(self):
    """Mean of the distribution: an identity copy of `probs`."""
    return array_ops.identity(self.probs)

  def _variance(self):
    """Variance of the distribution: mean times `1 - probs`."""
    mean = self._mean()
    return mean * (1. - self.probs)

  def _mode(self):
    """Returns `1` if `prob > 0.5` and `0` otherwise."""
    one_is_likelier = self.probs > 0.5
    return math_ops.cast(one_is_likelier, self.dtype)
@kullback_leibler.RegisterKL(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Bernoulli.

  Args:
    a: instance of a Bernoulli distribution object.
    b: instance of a Bernoulli distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_bernoulli_bernoulli".

  Returns:
    Batchwise KL(a || b)
  """
  with ops.name_scope(name, "kl_bernoulli_bernoulli",
                      values=[a.logits, b.logits]):
    # Softplus differences between `b` and `a`, first at the negated logits
    # and then at the logits themselves.
    softplus_gap_neg = nn.softplus(-b.logits) - nn.softplus(-a.logits)
    softplus_gap_pos = nn.softplus(b.logits) - nn.softplus(a.logits)
    # Weight each gap by sigmoid(a.logits) and sigmoid(-a.logits) respectively.
    weighted_neg = math_ops.sigmoid(a.logits) * softplus_gap_neg
    weighted_pos = math_ops.sigmoid(-a.logits) * softplus_gap_pos
    return weighted_neg + weighted_pos