Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/bernoulli.py: 52%

67 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""The Bernoulli distribution class.""" 

16 

17from tensorflow.python.framework import dtypes 

18from tensorflow.python.framework import ops 

19from tensorflow.python.framework import tensor_shape 

20from tensorflow.python.ops import array_ops 

21from tensorflow.python.ops import math_ops 

22from tensorflow.python.ops import nn 

23from tensorflow.python.ops import random_ops 

24from tensorflow.python.ops.distributions import distribution 

25from tensorflow.python.ops.distributions import kullback_leibler 

26from tensorflow.python.ops.distributions import util as distribution_util 

27from tensorflow.python.util import deprecation 

28from tensorflow.python.util.tf_export import tf_export 

29 

30 

@tf_export(v1=["distributions.Bernoulli"])
class Bernoulli(distribution.Distribution):
  """Bernoulli distribution.

  Describes a single binary trial that yields `1` with probability `probs`
  and `0` otherwise. Instances can be built from either the success
  probability (`probs`) or its log-odds (`logits`).
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               logits=None,
               probs=None,
               dtype=dtypes.int32,
               validate_args=False,
               allow_nan_stats=True,
               name="Bernoulli"):
    """Builds a batch of independent Bernoulli distributions.

    Args:
      logits: Optional N-D `Tensor` of log-odds for a `1` outcome; each entry
        describes an independent Bernoulli whose success probability is
        `sigmoid(logits)`. Supply only one of `logits` and `probs`.
      probs: Optional N-D `Tensor` holding, per entry, the probability of a
        `1` outcome for an independent Bernoulli. Supply only one of `logits`
        and `probs`.
      dtype: Data type of drawn samples. Defaults to `int32`.
      validate_args: Python `bool`, default `False`. If `True`, parameter
        validity is checked at the possible cost of runtime performance; if
        `False`, invalid inputs may silently produce incorrect results.
      allow_nan_stats: Python `bool`, default `True`. If `True`, statistics
        such as mean, mode and variance report "`NaN`" when undefined; if
        `False`, an exception is raised when any batch member's statistic is
        undefined.
      name: Python `str` prefix for the names of ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are given, or neither is.
    """
    # Snapshot the arguments before any other local is bound, so the recorded
    # `parameters` hold only what the caller passed (plus `self`).
    parameters = dict(locals())
    with ops.name_scope(name) as name:
      # The helper hands back both parameterizations; presumably it derives
      # the missing one from whichever was supplied.
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits, probs=probs, validate_args=validate_args, name=name)
    # Explicit two-argument super(): a zero-argument super() would make
    # `__class__` a local and leak it into the `locals()` snapshot above.
    super(Bernoulli, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits, self._probs],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    # The only parameter, `logits`, takes the requested sample shape.
    shape = ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)
    return {"logits": shape}

  @property
  def logits(self):
    """The log-odds of observing `1` rather than `0`."""
    return self._logits

  @property
  def probs(self):
    """The probability of observing `1` rather than `0`."""
    return self._probs

  def _batch_shape_tensor(self):
    # The batch shape is the dynamic shape of the logits tensor.
    logits = self._logits
    return array_ops.shape(logits)

  def _batch_shape(self):
    # The batch shape is the static shape of the logits tensor.
    logits = self._logits
    return logits.get_shape()

  def _event_shape_tensor(self):
    # Every event is a scalar, so the event shape is empty.
    return array_ops.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    """Draws `n` samples by thresholding uniform noise against `probs`."""
    probs = self.probs
    full_shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    noise = random_ops.random_uniform(full_shape, seed=seed, dtype=probs.dtype)
    # A draw is `1` exactly when the uniform noise falls below `probs`.
    return math_ops.cast(math_ops.less(noise, probs), self.dtype)

  def _log_prob(self, event):
    """Returns log P(event) as the negated sigmoid cross-entropy."""
    if self.validate_args:
      # Presumably asserts that `event` can be cast to bool without loss.
      event = distribution_util.embed_check_integer_casting_closed(
          event, target_dtype=dtypes.bool)

    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent behavior for logits = inf/-inf.
    logits = self.logits
    event = math_ops.cast(event, logits.dtype)

    event_shape = event.get_shape()
    logits_shape = logits.get_shape()
    shapes_match = (event_shape.is_fully_defined() and
                    logits_shape.is_fully_defined() and
                    event_shape == logits_shape)
    if not shapes_match:
      # sigmoid_cross_entropy_with_logits does not broadcast, so bring both
      # operands to a common shape by multiplying with ones first.
      logits, event = (array_ops.ones_like(event) * logits,
                       array_ops.ones_like(logits) * event)
    return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)

  def _entropy(self):
    # -p*log(p) - (1-p)*log(1-p) with p = sigmoid(logits), rewritten in terms
    # of logits as softplus(-logits) - logits * (sigmoid(logits) - 1).
    logits = self.logits
    return (-logits * (math_ops.sigmoid(logits) - 1) +  # pylint: disable=invalid-unary-operand-type
            nn.softplus(-logits))  # pylint: disable=invalid-unary-operand-type

  def _mean(self):
    # The mean of a Bernoulli is its success probability.
    return array_ops.identity(self.probs)

  def _variance(self):
    # p * (1 - p).
    mean = self._mean()
    return mean * (1. - self.probs)

  def _mode(self):
    """Returns `1` where `probs > 0.5`, else `0`."""
    is_one = self.probs > 0.5
    return math_ops.cast(is_one, self.dtype)

163 

164 

@kullback_leibler.RegisterKL(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(a, b, name=None):
  """Computes the batched KL divergence KL(a || b) between two Bernoullis.

  Args:
    a: instance of a Bernoulli distribution object.
    b: instance of a Bernoulli distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_bernoulli_bernoulli".

  Returns:
    Batchwise KL(a || b)
  """
  with ops.name_scope(name, "kl_bernoulli_bernoulli",
                      values=[a.logits, b.logits]):
    # log(p_a / p_b), using log(sigmoid(x)) = -softplus(-x).
    log_ratio_one = nn.softplus(-b.logits) - nn.softplus(-a.logits)
    # log((1 - p_a) / (1 - p_b)), using log(1 - sigmoid(x)) = -softplus(x).
    log_ratio_zero = nn.softplus(b.logits) - nn.softplus(a.logits)
    # Weight each log-ratio by a's probability of the matching outcome.
    weighted_one = math_ops.sigmoid(a.logits) * log_ratio_one
    weighted_zero = math_ops.sigmoid(-a.logits) * log_ratio_zero
    return weighted_one + weighted_zero