# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Normal (Gaussian) distribution class."""

import math

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import special_math
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Normal",
    "NormalWithSoftplusScale",
]


@tf_export(v1=["distributions.Normal"])
class Normal(distribution.Distribution):
  """The Normal distribution with location `loc` and `scale` parameters.

  #### Mathematical details

  The probability density function (pdf) is,

  ```none
  pdf(x; mu, sigma) = exp(-0.5 (x - mu)**2 / sigma**2) / Z
  Z = (2 pi sigma**2)**0.5
  ```

  where `loc = mu` is the mean, `scale = sigma` is the standard deviation, and
  `Z` is the normalization constant.
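
  At `x = mu` the exponential term equals 1, so the pdf reduces to `1 / Z`. A
  quick numeric check with plain `math` (purely illustrative, independent of
  TensorFlow):

  ```python
  import math

  mu, sigma = 0., 3.
  Z = (2. * math.pi * sigma**2) ** 0.5
  pdf_at_mu = 1. / Z  # ~0.1330, matches Normal(loc=0., scale=3.).prob(0.)
  ```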

  The Normal distribution is a member of the [location-scale family](
  https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
  constructed as,

  ```none
  X ~ Normal(loc=0, scale=1)
  Y = loc + scale * X
  ```

  #### Examples

  Examples of initialization of one or a batch of distributions.

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Define a single scalar Normal distribution.
  dist = tfd.Normal(loc=0., scale=3.)

  # Evaluate the cdf at 1, returning a scalar.
  dist.cdf(1.)

  # Define a batch of two scalar valued Normals.
  # The first has mean 1 and standard deviation 11, the second 2 and 22.
  dist = tfd.Normal(loc=[1, 2.], scale=[11, 22.])

  # Evaluate the pdf of the first distribution on 0, and the second on 1.5,
  # returning a length two tensor.
  dist.prob([0, 1.5])

  # Get 3 samples, returning a 3 x 2 tensor.
  dist.sample([3])
  ```

  Arguments are broadcast when possible.

  ```python
  # Define a batch of two scalar valued Normals.
  # Both have mean 1, but different standard deviations.
  dist = tfd.Normal(loc=1., scale=[11, 22.])

  # Evaluate the pdf of both distributions on the same point, 3.0,
  # returning a length 2 tensor.
  dist.prob(3.0)
  ```
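
  The KL divergence between two Normals is available in closed form (it is
  registered for this class via `kullback_leibler.RegisterKL` at the bottom
  of this file). A minimal sketch using the `tfd` alias from above:

  ```python
  a = tfd.Normal(loc=0., scale=1.)
  b = tfd.Normal(loc=1., scale=2.)
  tfd.kl_divergence(a, b)  # scalar, ~0.4431
  ```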
104 """
106 @deprecation.deprecated(
107 "2019-01-01",
108 "The TensorFlow Distributions library has moved to "
109 "TensorFlow Probability "
110 "(https://github.com/tensorflow/probability). You "
111 "should update all references to use `tfp.distributions` "
112 "instead of `tf.distributions`.",
113 warn_once=True)
114 def __init__(self,
115 loc,
116 scale,
117 validate_args=False,
118 allow_nan_stats=True,
119 name="Normal"):
120 """Construct Normal distributions with mean and stddev `loc` and `scale`.
122 The parameters `loc` and `scale` must be shaped in a way that supports
123 broadcasting (e.g. `loc + scale` is a valid operation).
125 Args:
126 loc: Floating point tensor; the means of the distribution(s).
127 scale: Floating point tensor; the stddevs of the distribution(s).
128 Must contain only positive values.
129 validate_args: Python `bool`, default `False`. When `True` distribution
130 parameters are checked for validity despite possibly degrading runtime
131 performance. When `False` invalid inputs may silently render incorrect
132 outputs.
133 allow_nan_stats: Python `bool`, default `True`. When `True`,
134 statistics (e.g., mean, mode, variance) use the value "`NaN`" to
135 indicate the result is undefined. When `False`, an exception is raised
136 if one or more of the statistic's batch members are undefined.
137 name: Python `str` name prefixed to Ops created by this class.
139 Raises:
140 TypeError: if `loc` and `scale` have different `dtype`.
141 """
    parameters = dict(locals())
    with ops.name_scope(name, values=[loc, scale]) as name:
      with ops.control_dependencies([check_ops.assert_positive(scale)] if
                                    validate_args else []):
        self._loc = array_ops.identity(loc, name="loc")
        self._scale = array_ops.identity(scale, name="scale")
        check_ops.assert_same_float_dtype([self._loc, self._scale])
    super(Normal, self).__init__(
        dtype=self._scale.dtype,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._loc, self._scale],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    return dict(
        zip(("loc", "scale"), ([ops.convert_to_tensor(
            sample_shape, dtype=dtypes.int32)] * 2)))

  @property
  def loc(self):
    """Distribution parameter for the mean."""
    return self._loc

  @property
  def scale(self):
    """Distribution parameter for the standard deviation."""
    return self._scale

  def _batch_shape_tensor(self):
    return array_ops.broadcast_dynamic_shape(
        array_ops.shape(self.loc),
        array_ops.shape(self.scale))

  def _batch_shape(self):
    return array_ops.broadcast_static_shape(
        self.loc.get_shape(),
        self.scale.get_shape())

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    # Draw standard-normal noise, then shift and scale it (the
    # reparameterization trick), so gradients flow through `loc` and `scale`.
    shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    sampled = random_ops.random_normal(
        shape=shape, mean=0., stddev=1., dtype=self.loc.dtype, seed=seed)
    return sampled * self.scale + self.loc

  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  def _log_cdf(self, x):
    return special_math.log_ndtr(self._z(x))

  def _cdf(self, x):
    return special_math.ndtr(self._z(x))

  def _log_survival_function(self, x):
    return special_math.log_ndtr(-self._z(x))

  def _survival_function(self, x):
    return special_math.ndtr(-self._z(x))

  def _log_unnormalized_prob(self, x):
    return -0.5 * math_ops.square(self._z(x))

  def _log_normalization(self):
    return 0.5 * math.log(2. * math.pi) + math_ops.log(self.scale)

  def _entropy(self):
    # Use broadcasting rules to calculate the full broadcast scale.
    scale = self.scale * array_ops.ones_like(self.loc)
    return 0.5 * math.log(2. * math.pi * math.e) + math_ops.log(scale)

  def _mean(self):
    return self.loc * array_ops.ones_like(self.scale)

  def _quantile(self, p):
    return self._inv_z(special_math.ndtri(p))

  def _stddev(self):
    return self.scale * array_ops.ones_like(self.loc)

  def _mode(self):
    return self._mean()

  def _z(self, x):
    """Standardize input `x` to a unit normal."""
    with ops.name_scope("standardize", values=[x]):
      return (x - self.loc) / self.scale

  def _inv_z(self, z):
    """Reconstruct input `x` from its normalized version."""
    with ops.name_scope("reconstruct", values=[z]):
      return z * self.scale + self.loc
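
  # Note on `_quantile` above: it inverts the CDF by composing `ndtri` (the
  # standard-normal inverse CDF) with `_inv_z`. An illustrative consequence:
  #
  #   dist = Normal(loc=1., scale=2.)
  #   dist.quantile(0.5)  # == loc == 1., since ndtri(0.5) == 0.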


class NormalWithSoftplusScale(Normal):
  """Normal with softplus applied to `scale`."""

  @deprecation.deprecated(
      "2019-01-01",
      "Use `tfd.Normal(loc, tf.nn.softplus(scale))` "
      "instead.",
      warn_once=True)
  def __init__(self,
               loc,
               scale,
               validate_args=False,
               allow_nan_stats=True,
               name="NormalWithSoftplusScale"):
    parameters = dict(locals())
    with ops.name_scope(name, values=[scale]) as name:
      super(NormalWithSoftplusScale, self).__init__(
          loc=loc,
          scale=nn.softplus(scale, name="softplus_scale"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters
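

# An equivalence sketch for the deprecated class above (illustrative only):
#
#   d1 = NormalWithSoftplusScale(loc=0., scale=-1.)
#   d2 = Normal(loc=0., scale=nn.softplus(-1.))  # softplus(-1.) ~= 0.3133
#   # Both describe a Normal with mean 0 and stddev softplus(-1.).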


@kullback_leibler.RegisterKL(Normal, Normal)
def _kl_normal_normal(n_a, n_b, name=None):
  """Calculate the batched KL divergence KL(n_a || n_b) with n_a and n_b Normal.

  Args:
    n_a: instance of a Normal distribution object.
    n_b: instance of a Normal distribution object.
    name: (optional) Name to use for created operations.
      Default value is "kl_normal_normal".

  Returns:
    Batchwise KL(n_a || n_b)
  """
  with ops.name_scope(name, "kl_normal_normal", [n_a.loc, n_b.loc]):
    one = constant_op.constant(1, dtype=n_a.dtype)
    two = constant_op.constant(2, dtype=n_a.dtype)
    half = constant_op.constant(0.5, dtype=n_a.dtype)
    s_a_squared = math_ops.square(n_a.scale)
    s_b_squared = math_ops.square(n_b.scale)
    ratio = s_a_squared / s_b_squared
    # Closed form:
    #   KL = (mu_a - mu_b)**2 / (2 sigma_b**2)
    #        + 0.5 * (ratio - 1 - log(ratio)), with ratio = sigma_a**2 / sigma_b**2.
    return (math_ops.squared_difference(n_a.loc, n_b.loc) / (two * s_b_squared)
            + half * (ratio - one - math_ops.log(ratio)))
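

# A minimal sanity-check sketch (illustrative, not part of the library): the
# KL registered above can be exercised through
# `kullback_leibler.kl_divergence`, e.g.
#
#   a = Normal(loc=0., scale=1.)
#   b = Normal(loc=1., scale=2.)
#   kl = kullback_leibler.kl_divergence(a, b)
#   # Closed form: 1 / 8 + 0.5 * (0.25 - 1 - log(0.25)) ~= 0.4431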