Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/uniform.py: 53%
66 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Uniform distribution class."""
17import math
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import check_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops import random_ops
27from tensorflow.python.ops.distributions import distribution
28from tensorflow.python.util import deprecation
29from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=["distributions.Uniform"])
class Uniform(distribution.Distribution):
  """Uniform distribution with `low` and `high` parameters.

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; a, b) = I[a <= x < b] / Z
  Z = b - a
  ```

  where

  - `low = a`,
  - `high = b`,
  - `Z` is the normalizing constant, and
  - `I[predicate]` is the [indicator function](
    https://en.wikipedia.org/wiki/Indicator_function) for `predicate`.

  The support is the half-open interval `[low, high)`: `prob(high)` is `0`
  and `cdf(x)` is `1` for every `x >= high`.

  The parameters `low` and `high` must be shaped in a way that supports
  broadcasting (e.g., `high - low` is a valid operation). The batch shape is
  the broadcast of their shapes, and each event is a scalar.

  #### Examples

  ```python
  # Without broadcasting:
  u1 = Uniform(low=3.0, high=4.0)  # a single uniform distribution [3, 4)
  u2 = Uniform(low=[1.0, 2.0],
               high=[3.0, 4.0])  # 2 distributions [1, 3), [2, 4)
  u3 = Uniform(low=[[1.0, 2.0],
                    [3.0, 4.0]],
               high=[[1.5, 2.5],
                     [3.5, 4.5]])  # 4 distributions
  ```

  ```python
  # With broadcasting:
  u1 = Uniform(low=3.0, high=[5.0, 6.0, 7.0])  # 3 distributions
  ```

  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               low=0.,
               high=1.,
               validate_args=False,
               allow_nan_stats=True,
               name="Uniform"):
    """Initialize a batch of Uniform distributions.

    Args:
      low: Floating point tensor, lower boundary of the output interval. Must
        have `low < high`.
      high: Floating point tensor, upper boundary of the output interval. Must
        have `low < high` and the same dtype as `low`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance; here that means asserting `low < high`. When `False`
        invalid inputs may silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      InvalidArgumentError: if `validate_args=True` and `low >= high`. The
        check is attached as a control dependency of the `low`/`high`
        identity ops.
      ValueError: if `low` and `high` do not share a floating-point dtype
        (raised by `check_ops.assert_same_float_dtype`).
    """
    # Must be the first statement so that only the constructor arguments are
    # captured.
    parameters = dict(locals())
    with ops.name_scope(name, values=[low, high]) as name:
      # The `low < high` assertion is only built when validation is requested.
      with ops.control_dependencies([
          check_ops.assert_less(
              low, high, message="uniform not defined when low >= high.")
      ] if validate_args else []):
        self._low = array_ops.identity(low, name="low")
        self._high = array_ops.identity(high, name="high")
        check_ops.assert_same_float_dtype([self._low, self._high])
    super(Uniform, self).__init__(
        dtype=self._low.dtype,
        # Samples are `low + (high - low) * u` (see `_sample_n`), so they are
        # a differentiable function of the parameters.
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._low,
                       self._high],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    # Both parameters take the requested sample shape; the same int32 shape
    # tensor is reused for "low" and "high".
    return dict(
        zip(("low", "high"),
            ([ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2)))

  @property
  def low(self):
    """Lower boundary of the output interval."""
    return self._low

  @property
  def high(self):
    """Upper boundary of the output interval."""
    return self._high

  def range(self, name="range"):
    """Width of the support, `high - low`.

    Args:
      name: Python `str` name for the op scope (passed to `self._name_scope`).

    Returns:
      A `Tensor` equal to `self.high - self.low` (with broadcasting).
    """
    with self._name_scope(name):
      return self.high - self.low

  def _batch_shape_tensor(self):
    # Runtime batch shape: broadcast of the dynamic shapes of `low` and `high`.
    return array_ops.broadcast_dynamic_shape(
        array_ops.shape(self.low),
        array_ops.shape(self.high))

  def _batch_shape(self):
    # Static (graph-construction-time) counterpart of `_batch_shape_tensor`.
    return array_ops.broadcast_static_shape(
        self.low.get_shape(),
        self.high.get_shape())

  def _event_shape_tensor(self):
    # Events are scalars: an empty int32 shape vector.
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    # Draw `n` standard-uniform variates per batch member, i.e. shape
    # [n] + batch_shape, then map them affinely onto the interval from `low`
    # to `high`. The affine form keeps samples differentiable w.r.t. the
    # parameters.
    shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    samples = random_ops.random_uniform(shape=shape,
                                        dtype=self.dtype,
                                        seed=seed)
    return self.low + self.range() * samples

  def _prob(self, x):
    # Broadcast `x` against the batch shape so that every branch below
    # produces the full output shape.
    broadcasted_x = x * array_ops.ones(
        self.batch_shape_tensor(), dtype=x.dtype)
    # NaN inputs are passed through as NaN. Otherwise the density is
    # 1 / (high - low) on [low, high) and 0 elsewhere, including at `high`.
    return array_ops.where_v2(
        math_ops.is_nan(broadcasted_x), broadcasted_x,
        array_ops.where_v2(
            math_ops.logical_or(broadcasted_x < self.low,
                                broadcasted_x >= self.high),
            array_ops.zeros_like(broadcasted_x),
            array_ops.ones_like(broadcasted_x) / self.range()))

  def _cdf(self, x):
    # Build 0/1 constants at the broadcast shape of `x` and the batch so the
    # piecewise result has that shape in every branch.
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(x), self.batch_shape_tensor())
    zeros = array_ops.zeros(broadcast_shape, dtype=self.dtype)
    ones = array_ops.ones(broadcast_shape, dtype=self.dtype)
    broadcasted_x = x * ones
    # The CDF is 0 below `low`, rises linearly on [low, high), and is 1 at or
    # above `high`. A NaN `x` fails both comparisons and so yields NaN.
    result_if_not_big = array_ops.where_v2(
        broadcasted_x < self.low, zeros,
        (broadcasted_x - self.low) / self.range())
    return array_ops.where_v2(
        broadcasted_x >= self.high, ones, result_if_not_big)

  def _entropy(self):
    # Differential entropy: log(high - low).
    return math_ops.log(self.range())

  def _mean(self):
    # Midpoint of the interval.
    return (self.low + self.high) / 2.

  def _variance(self):
    # (high - low)^2 / 12.
    return math_ops.square(self.range()) / 12.

  def _stddev(self):
    # Square root of the variance: (high - low) / sqrt(12).
    return self.range() / math.sqrt(12.)