Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/distributions/uniform.py: 53% (66 statements)
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Uniform distribution class."""

import math

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["distributions.Uniform"])
class Uniform(distribution.Distribution):
  """Uniform distribution with `low` and `high` parameters.

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; a, b) = I[a <= x < b] / Z
  Z = b - a
  ```

  where

  - `low = a`,
  - `high = b`,
  - `Z` is the normalizing constant, and
  - `I[predicate]` is the [indicator function](
    https://en.wikipedia.org/wiki/Indicator_function) for `predicate`.

  The parameters `low` and `high` must be shaped in a way that supports
  broadcasting (e.g., `high - low` is a valid operation).

  #### Examples

  ```python
  # Without broadcasting:
  u1 = Uniform(low=3.0, high=4.0)  # a single uniform distribution [3, 4]
  u2 = Uniform(low=[1.0, 2.0],
               high=[3.0, 4.0])  # 2 distributions [1, 3], [2, 4]
  u3 = Uniform(low=[[1.0, 2.0],
                    [3.0, 4.0]],
               high=[[1.5, 2.5],
                     [3.5, 4.5]])  # 4 distributions
  ```

  ```python
  # With broadcasting:
  u1 = Uniform(low=3.0, high=[5.0, 6.0, 7.0])  # 3 distributions
  ```
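
  A further minimal sketch of the closed-form values these parameters imply
  (`prob`, `cdf`, `mean`, and `variance` are the standard `Distribution`
  wrappers around the `_prob`, `_cdf`, `_mean`, and `_variance` methods
  defined below):

  ```python
  u = Uniform(low=0.0, high=2.0)
  u.prob(1.0)    # ==> 0.5, i.e. 1 / (high - low)
  u.cdf(1.5)     # ==> 0.75, i.e. (x - low) / (high - low)
  u.mean()       # ==> 1.0, i.e. (low + high) / 2
  u.variance()   # ==> 0.333..., i.e. (high - low)**2 / 12
  ```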

  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               low=0.,
               high=1.,
               validate_args=False,
               allow_nan_stats=True,
               name="Uniform"):
    """Initialize a batch of Uniform distributions.

    Args:
      low: Floating point tensor, lower boundary of the output interval. Must
        have `low < high`.
      high: Floating point tensor, upper boundary of the output interval. Must
        have `low < high`.
      validate_args: Python `bool`, default `False`. When `True`, distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False`, invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      InvalidArgumentError: if `low >= high` and `validate_args=True`.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[low, high]) as name:
      with ops.control_dependencies([
          check_ops.assert_less(
              low, high, message="uniform not defined when low >= high.")
      ] if validate_args else []):
        self._low = array_ops.identity(low, name="low")
        self._high = array_ops.identity(high, name="high")
        check_ops.assert_same_float_dtype([self._low, self._high])
    super(Uniform, self).__init__(
        dtype=self._low.dtype,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._low,
                       self._high],
        name=name)
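
  # A minimal illustrative sketch of the `validate_args` contract documented
  # above (assuming eager execution, where `check_ops.assert_less` is
  # evaluated at construction time):
  #
  #   Uniform(low=1., high=0., validate_args=True)  # fails the low < high check
  #   Uniform(low=1., high=0.)                      # unchecked; range() is negative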

  @staticmethod
  def _param_shapes(sample_shape):
    return dict(
        zip(("low", "high"),
            ([ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2)))

  @property
  def low(self):
    """Lower boundary of the output interval."""
    return self._low

  @property
  def high(self):
    """Upper boundary of the output interval."""
    return self._high

  def range(self, name="range"):
    """`high - low`."""
    with self._name_scope(name):
      return self.high - self.low

  def _batch_shape_tensor(self):
    return array_ops.broadcast_dynamic_shape(
        array_ops.shape(self.low),
        array_ops.shape(self.high))

  def _batch_shape(self):
    return array_ops.broadcast_static_shape(
        self.low.get_shape(),
        self.high.get_shape())

  def _event_shape_tensor(self):
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.TensorShape([])
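
  # `_sample_n` below draws `u ~ Uniform[0, 1)` with `random_uniform` and
  # applies the reparameterization `low + (high - low) * u`, which is why the
  # constructor declares the distribution FULLY_REPARAMETERIZED.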

  def _sample_n(self, n, seed=None):
    shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    samples = random_ops.random_uniform(shape=shape,
                                        dtype=self.dtype,
                                        seed=seed)
    return self.low + self.range() * samples
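
  # `_prob` below evaluates I[low <= x < high] / (high - low): NaN inputs
  # propagate to NaN outputs, points outside [low, high) get density 0, and
  # points inside get the constant density 1 / range().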

  def _prob(self, x):
    broadcasted_x = x * array_ops.ones(
        self.batch_shape_tensor(), dtype=x.dtype)
    return array_ops.where_v2(
        math_ops.is_nan(broadcasted_x), broadcasted_x,
        array_ops.where_v2(
            math_ops.logical_or(broadcasted_x < self.low,
                                broadcasted_x >= self.high),
            array_ops.zeros_like(broadcasted_x),
            array_ops.ones_like(broadcasted_x) / self.range()))
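
  # `_cdf` below is 0 for x < low, (x - low) / (high - low) on [low, high),
  # and 1 for x >= high.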

  def _cdf(self, x):
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(x), self.batch_shape_tensor())
    zeros = array_ops.zeros(broadcast_shape, dtype=self.dtype)
    ones = array_ops.ones(broadcast_shape, dtype=self.dtype)
    broadcasted_x = x * ones
    result_if_not_big = array_ops.where_v2(
        x < self.low, zeros, (broadcasted_x - self.low) / self.range())
    return array_ops.where_v2(x >= self.high, ones, result_if_not_big)
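
  # Closed-form statistics of Uniform(low, high) used below:
  # entropy = log(high - low), mean = (low + high) / 2,
  # variance = (high - low)**2 / 12, stddev = (high - low) / sqrt(12).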

  def _entropy(self):
    return math_ops.log(self.range())

  def _mean(self):
    return (self.low + self.high) / 2.

  def _variance(self):
    return math_ops.square(self.range()) / 12.

  def _stddev(self):
    return self.range() / math.sqrt(12.)