/src/sentencepiece/third_party/absl/random/bernoulli_distribution.h
Line | Count | Source |
1 | | // Copyright 2017 The Abseil Authors. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // https://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #ifndef ABSL_RANDOM_BERNOULLI_DISTRIBUTION_H_ |
16 | | #define ABSL_RANDOM_BERNOULLI_DISTRIBUTION_H_ |
17 | | |
18 | | #include <cassert> |
19 | | #include <cstdint> |
20 | | #include <istream> |
21 | | #include <ostream> |
22 | | |
23 | | #include "absl/base/config.h" |
24 | | #include "absl/base/optimization.h" |
25 | | #include "absl/random/internal/fast_uniform_bits.h" |
26 | | #include "absl/random/internal/iostream_state_saver.h" |
27 | | |
28 | | namespace absl { |
29 | | ABSL_NAMESPACE_BEGIN |
30 | | |
31 | | // absl::bernoulli_distribution is a drop-in replacement for |
32 | | // std::bernoulli_distribution. It guarantees that (given a perfect |
33 | | // UniformRandomBitGenerator) the acceptance probability is *exactly* equal to |
34 | | // the given double. |
35 | | // |
36 | | // The implementation assumes that double is IEEE 754. |
37 | | class bernoulli_distribution { |
38 | | public: |
39 | | using result_type = bool; |
40 | | |
41 | | class param_type { |
42 | | public: |
43 | | using distribution_type = bernoulli_distribution; |
44 | | |
45 | 0 | explicit param_type(double p = 0.5) : prob_(p) { |
46 | 0 | assert(p >= 0.0 && p <= 1.0); |
47 | 0 | } |
48 | | |
49 | 0 | double p() const { return prob_; } |
50 | | |
51 | 0 | friend bool operator==(const param_type& p1, const param_type& p2) { |
52 | 0 | return p1.p() == p2.p(); |
53 | 0 | } |
54 | 0 | friend bool operator!=(const param_type& p1, const param_type& p2) { |
55 | 0 | return p1.p() != p2.p(); |
56 | 0 | } |
57 | | |
58 | | private: |
59 | | double prob_; |
60 | | }; |
61 | | |
62 | 0 | bernoulli_distribution() : bernoulli_distribution(0.5) {} |
63 | | |
64 | 0 | explicit bernoulli_distribution(double p) : param_(p) {} |
65 | | |
66 | 0 | explicit bernoulli_distribution(param_type p) : param_(p) {} |
67 | | |
68 | | // no-op |
69 | 0 | void reset() {} |
70 | | |
71 | | template <typename URBG> |
72 | 0 | bool operator()(URBG& g) { // NOLINT(runtime/references) |
73 | 0 | return Generate(param_.p(), g); |
74 | 0 | } |
75 | | |
76 | | template <typename URBG> |
77 | | bool operator()(URBG& g, // NOLINT(runtime/references) |
78 | | const param_type& param) { |
79 | | return Generate(param.p(), g); |
80 | | } |
81 | | |
82 | 0 | param_type param() const { return param_; } |
83 | 0 | void param(const param_type& param) { param_ = param; } |
84 | | |
85 | 0 | double p() const { return param_.p(); } |
86 | | |
87 | 0 | result_type(min)() const { return false; } |
88 | 0 | result_type(max)() const { return true; } |
89 | | |
90 | | friend bool operator==(const bernoulli_distribution& d1, |
91 | 0 | const bernoulli_distribution& d2) { |
92 | 0 | return d1.param_ == d2.param_; |
93 | 0 | } |
94 | | |
95 | | friend bool operator!=(const bernoulli_distribution& d1, |
96 | 0 | const bernoulli_distribution& d2) { |
97 | 0 | return d1.param_ != d2.param_; |
98 | 0 | } |
99 | | |
100 | | private: |
101 | | static constexpr uint64_t kP32 = static_cast<uint64_t>(1) << 32; |
102 | | |
103 | | template <typename URBG> |
104 | | static bool Generate(double p, URBG& g); // NOLINT(runtime/references) |
105 | | |
106 | | param_type param_; |
107 | | }; |
108 | | |
109 | | template <typename CharT, typename Traits> |
110 | | std::basic_ostream<CharT, Traits>& operator<<( |
111 | | std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references) |
112 | | const bernoulli_distribution& x) { |
113 | | auto saver = random_internal::make_ostream_state_saver(os); |
114 | | os.precision(random_internal::stream_precision_helper<double>::kPrecision); |
115 | | os << x.p(); |
116 | | return os; |
117 | | } |
118 | | |
119 | | template <typename CharT, typename Traits> |
120 | | std::basic_istream<CharT, Traits>& operator>>( |
121 | | std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references) |
122 | | bernoulli_distribution& x) { // NOLINT(runtime/references) |
123 | | auto saver = random_internal::make_istream_state_saver(is); |
124 | | auto p = random_internal::read_floating_point<double>(is); |
125 | | if (!is.fail()) { |
126 | | x.param(bernoulli_distribution::param_type(p)); |
127 | | } |
128 | | return is; |
129 | | } |
130 | | |
131 | | template <typename URBG> |
132 | | bool bernoulli_distribution::Generate(double p, |
133 | 0 | URBG& g) { // NOLINT(runtime/references) |
134 | 0 | random_internal::FastUniformBits<uint32_t> fast_u32; |
135 | | |
136 | 0 | while (true) { |
137 | | // There are two aspects of the definition of `c` below that are worth |
138 | | // commenting on. First, because `p` is in the range [0, 1], `c` is in the |
139 | | // range [0, 2^32] which does not fit in a uint32_t and therefore requires |
140 | | // 64 bits. |
141 | | // |
142 | | // Second, `c` is constructed by first casting explicitly to a signed |
143 | | // integer and then casting explicitly to an unsigned integer of the same |
144 | | // size. This is done because the hardware conversion instructions produce |
145 | | // signed integers from double; if taken as a uint64_t the conversion would |
146 | | // be wrong for doubles greater than 2^63 (not relevant in this use-case). |
147 | | // If converted directly to an unsigned integer, the compiler would end up |
148 | | // emitting code to handle such large values that are not relevant due to |
149 | | // the known bounds on `c`. To avoid these extra instructions this |
150 | | // implementation converts first to the signed type and then converts to |
151 | | // unsigned (which is a no-op). |
152 | 0 | const uint64_t c = static_cast<uint64_t>(static_cast<int64_t>(p * kP32)); |
153 | 0 | const uint32_t v = fast_u32(g); |
154 | | // FAST PATH: this path fails with probability 1/2^32. Note that simply |
155 | | // returning v <= c would approximate P very well (up to an absolute error |
156 | | // of 1/2^32); the slow path (taken in that range of possible error, in the |
157 | | // case of equality) eliminates the remaining error. |
158 | 0 | if (ABSL_PREDICT_TRUE(v != c)) return v < c; |
159 | | |
160 | | // It is guaranteed that `q` is strictly less than 1, because if `q` were |
161 | | // greater than or equal to 1, the same would be true for `p`. Certainly `p` |
162 | | // cannot be greater than 1, and if `p == 1`, then the fast path would |
163 | | // necessarily have been taken already. |
164 | 0 | const double q = static_cast<double>(c) / kP32; |
165 | | |
166 | | // The probability of acceptance on the fast path is `q` and so the |
167 | | // probability of acceptance here should be `p - q`. |
168 | | // |
169 | | // Note that `q` is obtained from `p` via some shifts and conversions, the |
170 | | // upshot of which is that `q` is simply `p` with some of the |
171 | | // least-significant bits of its mantissa set to zero. This means that the |
172 | | // difference `p - q` will not have any rounding errors. To see why, pretend |
173 | | // that double has 10 bits of resolution and q is obtained from `p` in such |
174 | | // a way that the 4 least-significant bits of its mantissa are set to zero. |
175 | | // For example: |
176 | | // p = 1.1100111011 * 2^-1 |
177 | | // q = 1.1100110000 * 2^-1 |
178 | | // p - q = 1.011 * 2^-8 |
179 | | // The difference `p - q` has exactly the nonzero mantissa bits that were |
180 | | // "lost" in `q` producing a number which is certainly representable in a |
181 | | // double. |
182 | 0 | const double left = p - q; |
183 | | |
184 | | // By construction, the probability of being on this slow path is 1/2^32, so |
185 | | // P(accept in slow path) = P(accept| in slow path) * P(slow path), |
186 | | // which means the probability of acceptance here is `left * kP32`: |
187 | 0 | const double here = left * kP32; |
188 | | |
189 | | // The simplest way to compute the result of this trial is to repeat the |
190 | | // whole algorithm with the new probability. This terminates because even |
191 | | // given arbitrarily unfriendly "random" bits, each iteration either |
192 | | // multiplies a tiny probability by 2^32 (if c == 0) or strips off some |
193 | | // number of nonzero mantissa bits. That process is bounded. |
194 | 0 | if (here == 0) return false; |
195 | 0 | p = here; |
196 | 0 | } |
197 | 0 | } |
198 | | |
199 | | ABSL_NAMESPACE_END |
200 | | } // namespace absl |
201 | | |
202 | | #endif // ABSL_RANDOM_BERNOULLI_DISTRIBUTION_H_ |