/src/sentencepiece/third_party/absl/random/gaussian_distribution.h
Line | Count | Source |
1 | | // Copyright 2017 The Abseil Authors. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // https://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #ifndef ABSL_RANDOM_GAUSSIAN_DISTRIBUTION_H_ |
16 | | #define ABSL_RANDOM_GAUSSIAN_DISTRIBUTION_H_ |
17 | | |
18 | | // absl::gaussian_distribution implements the Ziggurat algorithm |
19 | | // for generating random gaussian numbers. |
20 | | // |
21 | | // Implementation based on "The Ziggurat Method for Generating Random Variables" |
22 | | // by George Marsaglia and Wai Wan Tsang: http://www.jstatsoft.org/v05/i08/ |
23 | | // |
24 | | |
25 | | #include <cmath> |
26 | | #include <cstdint> |
27 | | #include <istream> |
28 | | #include <limits> |
29 | | #include <ostream> |
30 | | #include <type_traits> |
31 | | |
32 | | #include "absl/base/config.h" |
33 | | #include "absl/random/internal/fast_uniform_bits.h" |
34 | | #include "absl/random/internal/generate_real.h" |
35 | | #include "absl/random/internal/iostream_state_saver.h" |
36 | | |
37 | | namespace absl { |
38 | | ABSL_NAMESPACE_BEGIN |
39 | | namespace random_internal { |
40 | | |
41 | | // random_internal::gaussian_distribution_base implements the underlying ziggurat |
42 | | // algorithm using the ziggurat tables generated by the |
43 | | // gaussian_distribution_gentables binary. |
44 | | // |
45 | | // The specific algorithm has some of the improvements suggested by the |
46 | | // 2005 paper, "An Improved Ziggurat Method to Generate Normal Random Samples", |
47 | | // by Jurgen A Doornik. (https://www.doornik.com/research/ziggurat.pdf) |
48 | | class ABSL_DLL gaussian_distribution_base { |
49 | | public: |
50 | | template <typename URBG> |
51 | | inline double zignor(URBG& g); // NOLINT(runtime/references) |
52 | | |
53 | | private: |
54 | | friend class TableGenerator; |
55 | | |
56 | | template <typename URBG> |
57 | | inline double zignor_fallback(URBG& g, // NOLINT(runtime/references) |
58 | | bool neg); |
59 | | |
60 | | // Constants used for the gaussian distribution. kMask selects the low 7 bits used as the strip index in zignor(); kV is not referenced in this header. |
61 | | static constexpr double kR = 3.442619855899; // Start of the tail. |
62 | | static constexpr double kRInv = 0.29047645161474317; // ~= (1.0 / kR). |
63 | | static constexpr double kV = 9.91256303526217e-3; |
64 | | static constexpr uint64_t kMask = 0x07f; |
65 | | |
66 | | // The ziggurat tables store the pdf (f) and inverse-pdf (x) for equal-area |
67 | | // points on one-half of the normal distribution, where the pdf function, |
68 | | // pdf = e^(-1/2 * x^2), assumes that the mean = 0 and stddev = 1. |
69 | | // |
70 | | // These tables hold 2 * (kMask + 2) = 258 doubles (just over 2kb); larger tables |
71 | | // might improve the distributions, but also lead to more cache pollution. |
72 | | // |
73 | | // x = {3.71308, 3.44261, 3.22308, ..., 0} |
74 | | // f = {0.00101, 0.00266, 0.00554, ..., 1} |
75 | | struct Tables { |
76 | | double x[kMask + 2]; |
77 | | double f[kMask + 2]; |
78 | | }; |
79 | | static const Tables zg_; |
80 | | random_internal::FastUniformBits<uint64_t> fast_u64_; |
81 | | }; |
82 | | |
83 | | } // namespace random_internal |
84 | | |
85 | | // absl::gaussian_distribution: |
86 | | // Generates a number conforming to a Gaussian distribution, computed as mean + stddev * zignor(g). |
87 | | template <typename RealType = double> |
88 | | class gaussian_distribution : random_internal::gaussian_distribution_base { |
89 | | public: |
90 | | using result_type = RealType; |
91 | | |
92 | | class param_type { |
93 | | public: |
94 | | using distribution_type = gaussian_distribution; |
95 | | |
96 | | explicit param_type(result_type mean = 0, result_type stddev = 1) |
97 | 0 | : mean_(mean), stddev_(stddev) {} |
98 | | |
99 | | // Returns the mean distribution parameter. The mean specifies the location |
100 | | // of the peak. The default value is 0.0. |
101 | 0 | result_type mean() const { return mean_; } |
102 | | |
103 | | // Returns the standard deviation distribution parameter. The default value is 1.0. |
104 | 0 | result_type stddev() const { return stddev_; } |
105 | | |
106 | | friend bool operator==(const param_type& a, const param_type& b) { |
107 | | return a.mean_ == b.mean_ && a.stddev_ == b.stddev_; |
108 | | } |
109 | | |
110 | | friend bool operator!=(const param_type& a, const param_type& b) { |
111 | | return !(a == b); |
112 | | } |
113 | | |
114 | | private: |
115 | | result_type mean_; |
116 | | result_type stddev_; |
117 | | |
118 | | static_assert( |
119 | | std::is_floating_point<RealType>::value, |
120 | | "Class-template absl::gaussian_distribution<> must be parameterized " |
121 | | "using a floating-point type."); |
122 | | }; |
123 | | |
124 | | gaussian_distribution() : gaussian_distribution(0) {} |
125 | | |
126 | | explicit gaussian_distribution(result_type mean, result_type stddev = 1) |
127 | 0 | : param_(mean, stddev) {} |
128 | | |
129 | | explicit gaussian_distribution(const param_type& p) : param_(p) {} |
130 | | |
131 | | void reset() {} |
132 | | |
133 | | // Generating functions: operator()(g) samples with the stored param_; operator()(g, p) samples with p instead. |
134 | | template <typename URBG> |
135 | 0 | result_type operator()(URBG& g) { // NOLINT(runtime/references) |
136 | 0 | return (*this)(g, param_); |
137 | 0 | } |
138 | | |
139 | | template <typename URBG> |
140 | | result_type operator()(URBG& g, // NOLINT(runtime/references) |
141 | | const param_type& p); |
142 | | |
143 | | param_type param() const { return param_; } |
144 | | void param(const param_type& p) { param_ = p; } |
145 | | |
146 | | result_type(min)() const { |
147 | | return -std::numeric_limits<result_type>::infinity(); |
148 | | } |
149 | | result_type(max)() const { |
150 | | return std::numeric_limits<result_type>::infinity(); |
151 | | } |
152 | | |
153 | | result_type mean() const { return param_.mean(); } |
154 | | result_type stddev() const { return param_.stddev(); } |
155 | | |
156 | | friend bool operator==(const gaussian_distribution& a, |
157 | | const gaussian_distribution& b) { |
158 | | return a.param_ == b.param_; |
159 | | } |
160 | | friend bool operator!=(const gaussian_distribution& a, |
161 | | const gaussian_distribution& b) { |
162 | | return a.param_ != b.param_; |
163 | | } |
164 | | |
165 | | private: |
166 | | param_type param_; |
167 | | }; |
168 | | |
169 | | // Implementation details only below. operator()(g, p) returns p.mean() plus p.stddev() |
170 | | // times a zignor(g) variate; operator<< writes mean and stddev separated by os.fill(); |
171 | | // operator>> reads mean then stddev and updates x only if neither read failed. |
172 | | |
173 | | template <typename RealType> |
174 | | template <typename URBG> |
175 | | typename gaussian_distribution<RealType>::result_type |
176 | | gaussian_distribution<RealType>::operator()( |
177 | | URBG& g, // NOLINT(runtime/references) |
178 | 0 | const param_type& p) { |
179 | 0 | return p.mean() + p.stddev() * static_cast<result_type>(zignor(g)); |
180 | 0 | } |
181 | | |
182 | | template <typename CharT, typename Traits, typename RealType> |
183 | | std::basic_ostream<CharT, Traits>& operator<<( |
184 | | std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) |
185 | | const gaussian_distribution<RealType>& x) { |
186 | | auto saver = random_internal::make_ostream_state_saver(os); |
187 | | os.precision(random_internal::stream_precision_helper<RealType>::kPrecision); |
188 | | os << x.mean() << os.fill() << x.stddev(); |
189 | | return os; |
190 | | } |
191 | | |
192 | | template <typename CharT, typename Traits, typename RealType> |
193 | | std::basic_istream<CharT, Traits>& operator>>( |
194 | | std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) |
195 | | gaussian_distribution<RealType>& x) { // NOLINT(runtime/references) |
196 | | using result_type = typename gaussian_distribution<RealType>::result_type; |
197 | | using param_type = typename gaussian_distribution<RealType>::param_type; |
198 | | |
199 | | auto saver = random_internal::make_istream_state_saver(is); |
200 | | auto mean = random_internal::read_floating_point<result_type>(is); |
201 | | if (is.fail()) return is; |
202 | | auto stddev = random_internal::read_floating_point<result_type>(is); |
203 | | if (!is.fail()) { |
204 | | x.param(param_type(mean, stddev)); |
205 | | } |
206 | | return is; |
207 | | } |
208 | | |
209 | | namespace random_internal { |
210 | | |
211 | | template <typename URBG> |
212 | 0 | inline double gaussian_distribution_base::zignor_fallback(URBG& g, bool neg) { |
213 | 0 | using random_internal::GeneratePositiveTag; |
214 | 0 | using random_internal::GenerateRealFromBits; |
215 | | |
216 | | // Tail sampler: redraws (x, y) until 2y >= x^2, then returns x - kR if neg, else kR - x. This fallback path happens approximately 0.05% of the time. |
217 | 0 | double x, y; |
218 | 0 | do { |
219 | | // x = kRInv * log(U(0, 1)), where kRInv = 1/r; y = -log(U(0, 1)) |
220 | 0 | x = kRInv * |
221 | 0 | std::log(GenerateRealFromBits<double, GeneratePositiveTag, false>( |
222 | 0 | fast_u64_(g))); |
223 | 0 | y = -std::log( |
224 | 0 | GenerateRealFromBits<double, GeneratePositiveTag, false>(fast_u64_(g))); |
225 | 0 | } while ((y + y) < (x * x)); |
226 | 0 | return neg ? (x - kR) : (kR - x); |
227 | 0 | } |
228 | | |
229 | | template <typename URBG> |
230 | | inline double gaussian_distribution_base::zignor( |
231 | 0 | URBG& g) { // NOLINT(runtime/references) |
232 | 0 | using random_internal::GeneratePositiveTag; |
233 | 0 | using random_internal::GenerateRealFromBits; |
234 | 0 | using random_internal::GenerateSignedTag; |
235 |
|
236 | 0 | while (true) { |
237 | | // We use a single uint64_t to generate both a double and a strip index. |
238 | | // The strip bits (bits & kMask) are unused when the generated double is > 1/2^5. |
239 | | // This may introduce some bias from the duplicated low bits of small |
240 | | // values (those smaller than 1/2^5, which all end up on the left tail). |
241 | 0 | uint64_t bits = fast_u64_(g); |
242 | 0 | int i = static_cast<int>(bits & kMask); // pick a random strip in [0, kMask] |
243 | 0 | double j = GenerateRealFromBits<double, GenerateSignedTag, false>( |
244 | 0 | bits); // U(-1, 1) |
245 | 0 | const double x = j * zg_.x[i]; |
246 | | |
247 | | // Rectangular box. Handles >97% of all cases. |
248 | | // For any given box, this handles between 75% and 99% of values. |
249 | | // Equivalent to U(0, 1) < (x[i+1] / x[i]), and when i == 0, ~93.5% |
250 | 0 | if (std::abs(x) < zg_.x[i + 1]) { |
251 | 0 | return x; |
252 | 0 | } |
253 | | |
254 | | // i == 0: Base box. Sample the tail via zignor_fallback(); the sign of j picks the side. |
255 | 0 | if (i == 0) { |
256 | | // This path happens about 0.05% of the time. |
257 | 0 | return zignor_fallback(g, j < 0); |
258 | 0 | } |
259 | | |
260 | | // i > 0: Wedge samples using precomputed values; accept x when a uniform height between f[i+1] and f[i] is below exp(-x^2 / 2). |
261 | 0 | double v = GenerateRealFromBits<double, GeneratePositiveTag, false>( |
262 | 0 | fast_u64_(g)); // U(0, 1) |
263 | 0 | if ((zg_.f[i + 1] + v * (zg_.f[i] - zg_.f[i + 1])) < |
264 | 0 | std::exp(-0.5 * x * x)) { |
265 | 0 | return x; |
266 | 0 | } |
267 | | |
268 | | // The wedge was missed; reject the value and try again. |
269 | 0 | } |
270 | 0 | } |
271 | | |
272 | | } // namespace random_internal |
273 | | ABSL_NAMESPACE_END |
274 | | } // namespace absl |
275 | | |
276 | | #endif // ABSL_RANDOM_GAUSSIAN_DISTRIBUTION_H_ |