Coverage Report

Created: 2026-03-14 06:47

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/rust/registry/src/index.crates.io-1949cf8c6b5b557f/half-1.8.3/src/bfloat.rs
Line
Count
Source
1
#[cfg(feature = "bytemuck")]
2
use bytemuck::{Pod, Zeroable};
3
use core::{
4
    cmp::Ordering,
5
    fmt::{
6
        Binary, Debug, Display, Error, Formatter, LowerExp, LowerHex, Octal, UpperExp, UpperHex,
7
    },
8
    iter::{Product, Sum},
9
    num::{FpCategory, ParseFloatError},
10
    ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign},
11
    str::FromStr,
12
};
13
#[cfg(feature = "serde")]
14
use serde::{Deserialize, Serialize};
15
#[cfg(feature = "zerocopy")]
16
use zerocopy::{AsBytes, FromBytes};
17
18
pub(crate) mod convert;
19
20
/// A 16-bit floating point type implementing the [`bfloat16`] format.
21
///
22
/// The [`bfloat16`] floating point format is a truncated 16-bit version of the IEEE 754 standard
23
/// `binary32`, a.k.a [`f32`]. [`bf16`] has approximately the same dynamic range as [`f32`] by
24
/// having a lower precision than [`f16`][crate::f16]. While [`f16`][crate::f16] has a precision of
25
/// 11 bits, [`bf16`] has a precision of only 8 bits.
26
///
27
/// Like [`f16`][crate::f16], [`bf16`] is intended for compact storage rather than calculations.
27
/// Its arithmetic operators simply round-trip through [`f32`], so longer computations should be
28
/// performed with [`f32`] or higher-precision types and converted to/from [`bf16`] as necessary.
30
///
31
/// [`bfloat16`]: https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
32
#[allow(non_camel_case_types)]
33
#[derive(Clone, Copy, Default)]
34
#[repr(transparent)]
35
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
36
#[cfg_attr(feature = "bytemuck", derive(Zeroable, Pod))]
37
#[cfg_attr(feature = "zerocopy", derive(AsBytes, FromBytes))]
38
pub struct bf16(u16);
39
40
impl bf16 {
41
    /// Constructs a [`bf16`] value from the raw bits.
42
    #[inline]
43
0
    pub const fn from_bits(bits: u16) -> bf16 {
44
0
        bf16(bits)
45
0
    }
46
47
    /// Constructs a [`bf16`] value from a 32-bit floating point value.
48
    ///
49
    /// If the 32-bit value is too large to fit, ±∞ will result. NaN values are preserved.
50
    /// Subnormal values that are too tiny to be represented will result in ±0. All other values
51
    /// are truncated and rounded to the nearest representable value.
52
    #[inline]
53
0
    pub fn from_f32(value: f32) -> bf16 {
54
0
        bf16(convert::f32_to_bf16(value))
55
0
    }
56
57
    /// Constructs a [`bf16`] value from a 64-bit floating point value.
58
    ///
59
    /// If the 64-bit value is too large to fit, ±∞ will result. NaN values are preserved.
60
    /// 64-bit subnormal values are too tiny to be represented and result in ±0. Exponents that
61
    /// underflow the minimum exponent will result in subnormals or ±0. All other values are
62
    /// truncated and rounded to the nearest representable value.
63
    #[inline]
64
0
    pub fn from_f64(value: f64) -> bf16 {
65
0
        bf16(convert::f64_to_bf16(value))
66
0
    }
67
68
    /// Converts a [`bf16`] into the underlying bit representation.
69
    #[inline]
70
0
    pub const fn to_bits(self) -> u16 {
71
0
        self.0
72
0
    }
73
74
    /// Returns the memory representation of the underlying bit representation as a byte array in
75
    /// little-endian byte order.
76
    ///
77
    /// # Examples
78
    ///
79
    /// ```rust
80
    /// # use half::prelude::*;
81
    /// let bytes = bf16::from_f32(12.5).to_le_bytes();
82
    /// assert_eq!(bytes, [0x48, 0x41]);
83
    /// ```
84
    #[inline]
85
0
    pub const fn to_le_bytes(self) -> [u8; 2] {
86
0
        self.0.to_le_bytes()
87
0
    }
88
89
    /// Returns the memory representation of the underlying bit representation as a byte array in
90
    /// big-endian (network) byte order.
91
    ///
92
    /// # Examples
93
    ///
94
    /// ```rust
95
    /// # use half::prelude::*;
96
    /// let bytes = bf16::from_f32(12.5).to_be_bytes();
97
    /// assert_eq!(bytes, [0x41, 0x48]);
98
    /// ```
99
    #[inline]
100
0
    pub const fn to_be_bytes(self) -> [u8; 2] {
101
0
        self.0.to_be_bytes()
102
0
    }
103
104
    /// Returns the memory representation of the underlying bit representation as a byte array in
105
    /// native byte order.
106
    ///
107
    /// As the target platform's native endianness is used, portable code should use
108
    /// [`to_be_bytes`][bf16::to_be_bytes] or [`to_le_bytes`][bf16::to_le_bytes], as appropriate,
109
    /// instead.
110
    ///
111
    /// # Examples
112
    ///
113
    /// ```rust
114
    /// # use half::prelude::*;
115
    /// let bytes = bf16::from_f32(12.5).to_ne_bytes();
116
    /// assert_eq!(bytes, if cfg!(target_endian = "big") {
117
    ///     [0x41, 0x48]
118
    /// } else {
119
    ///     [0x48, 0x41]
120
    /// });
121
    /// ```
122
    #[inline]
123
0
    pub const fn to_ne_bytes(self) -> [u8; 2] {
124
0
        self.0.to_ne_bytes()
125
0
    }
126
127
    /// Creates a floating point value from its representation as a byte array in little endian.
128
    ///
129
    /// # Examples
130
    ///
131
    /// ```rust
132
    /// # use half::prelude::*;
133
    /// let value = bf16::from_le_bytes([0x48, 0x41]);
134
    /// assert_eq!(value, bf16::from_f32(12.5));
135
    /// ```
136
    #[inline]
137
0
    pub const fn from_le_bytes(bytes: [u8; 2]) -> bf16 {
138
0
        bf16::from_bits(u16::from_le_bytes(bytes))
139
0
    }
140
141
    /// Creates a floating point value from its representation as a byte array in big endian.
142
    ///
143
    /// # Examples
144
    ///
145
    /// ```rust
146
    /// # use half::prelude::*;
147
    /// let value = bf16::from_be_bytes([0x41, 0x48]);
148
    /// assert_eq!(value, bf16::from_f32(12.5));
149
    /// ```
150
    #[inline]
151
0
    pub const fn from_be_bytes(bytes: [u8; 2]) -> bf16 {
152
0
        bf16::from_bits(u16::from_be_bytes(bytes))
153
0
    }
154
155
    /// Creates a floating point value from its representation as a byte array in native endian.
156
    ///
157
    /// As the target platform's native endianness is used, portable code likely wants to use
158
    /// [`from_be_bytes`][bf16::from_be_bytes] or [`from_le_bytes`][bf16::from_le_bytes], as
159
    /// appropriate instead.
160
    ///
161
    /// # Examples
162
    ///
163
    /// ```rust
164
    /// # use half::prelude::*;
165
    /// let value = bf16::from_ne_bytes(if cfg!(target_endian = "big") {
166
    ///     [0x41, 0x48]
167
    /// } else {
168
    ///     [0x48, 0x41]
169
    /// });
170
    /// assert_eq!(value, bf16::from_f32(12.5));
171
    /// ```
172
    #[inline]
173
0
    pub const fn from_ne_bytes(bytes: [u8; 2]) -> bf16 {
174
0
        bf16::from_bits(u16::from_ne_bytes(bytes))
175
0
    }
176
177
    /// Converts a [`bf16`] value into an [`f32`] value.
178
    ///
179
    /// This conversion is lossless as all values can be represented exactly in [`f32`].
180
    #[inline]
181
0
    pub fn to_f32(self) -> f32 {
182
0
        convert::bf16_to_f32(self.0)
183
0
    }
184
185
    /// Converts a [`bf16`] value into an [`f64`] value.
186
    ///
187
    /// This conversion is lossless as all values can be represented exactly in [`f64`].
188
    #[inline]
189
0
    pub fn to_f64(self) -> f64 {
190
0
        convert::bf16_to_f64(self.0)
191
0
    }
192
193
    /// Returns `true` if this value is NaN and `false` otherwise.
194
    ///
195
    /// # Examples
196
    ///
197
    /// ```rust
198
    /// # use half::prelude::*;
199
    ///
200
    /// let nan = bf16::NAN;
201
    /// let f = bf16::from_f32(7.0_f32);
202
    ///
203
    /// assert!(nan.is_nan());
204
    /// assert!(!f.is_nan());
205
    /// ```
206
    #[inline]
207
0
    pub const fn is_nan(self) -> bool {
208
0
        self.0 & 0x7FFFu16 > 0x7F80u16
209
0
    }
210
211
    /// Returns `true` if this value is ±∞ and `false` otherwise.
212
    ///
213
    /// # Examples
214
    ///
215
    /// ```rust
216
    /// # use half::prelude::*;
217
    ///
218
    /// let f = bf16::from_f32(7.0f32);
219
    /// let inf = bf16::INFINITY;
220
    /// let neg_inf = bf16::NEG_INFINITY;
221
    /// let nan = bf16::NAN;
222
    ///
223
    /// assert!(!f.is_infinite());
224
    /// assert!(!nan.is_infinite());
225
    ///
226
    /// assert!(inf.is_infinite());
227
    /// assert!(neg_inf.is_infinite());
228
    /// ```
229
    #[inline]
230
0
    pub const fn is_infinite(self) -> bool {
231
0
        self.0 & 0x7FFFu16 == 0x7F80u16
232
0
    }
233
234
    /// Returns `true` if this number is neither infinite nor NaN.
235
    ///
236
    /// # Examples
237
    ///
238
    /// ```rust
239
    /// # use half::prelude::*;
240
    ///
241
    /// let f = bf16::from_f32(7.0f32);
242
    /// let inf = bf16::INFINITY;
243
    /// let neg_inf = bf16::NEG_INFINITY;
244
    /// let nan = bf16::NAN;
245
    ///
246
    /// assert!(f.is_finite());
247
    ///
248
    /// assert!(!nan.is_finite());
249
    /// assert!(!inf.is_finite());
250
    /// assert!(!neg_inf.is_finite());
251
    /// ```
252
    #[inline]
253
0
    pub const fn is_finite(self) -> bool {
254
0
        self.0 & 0x7F80u16 != 0x7F80u16
255
0
    }
256
257
    /// Returns `true` if the number is neither zero, infinite, subnormal, or NaN.
258
    ///
259
    /// # Examples
260
    ///
261
    /// ```rust
262
    /// # use half::prelude::*;
263
    ///
264
    /// let min = bf16::MIN_POSITIVE;
265
    /// let max = bf16::MAX;
266
    /// let lower_than_min = bf16::from_f32(1.0e-39_f32);
267
    /// let zero = bf16::from_f32(0.0_f32);
268
    ///
269
    /// assert!(min.is_normal());
270
    /// assert!(max.is_normal());
271
    ///
272
    /// assert!(!zero.is_normal());
273
    /// assert!(!bf16::NAN.is_normal());
274
    /// assert!(!bf16::INFINITY.is_normal());
275
    /// // Values between 0 and `min` are subnormal.
276
    /// assert!(!lower_than_min.is_normal());
277
    /// ```
278
    #[inline]
279
0
    pub const fn is_normal(self) -> bool {
280
0
        let exp = self.0 & 0x7F80u16;
281
0
        exp != 0x7F80u16 && exp != 0
282
0
    }
283
284
    /// Returns the floating point category of the number.
285
    ///
286
    /// If only one property is going to be tested, it is generally faster to use the specific
287
    /// predicate instead.
288
    ///
289
    /// # Examples
290
    ///
291
    /// ```rust
292
    /// use std::num::FpCategory;
293
    /// # use half::prelude::*;
294
    ///
295
    /// let num = bf16::from_f32(12.4_f32);
296
    /// let inf = bf16::INFINITY;
297
    ///
298
    /// assert_eq!(num.classify(), FpCategory::Normal);
299
    /// assert_eq!(inf.classify(), FpCategory::Infinite);
300
    /// ```
301
0
    pub const fn classify(self) -> FpCategory {
302
0
        let exp = self.0 & 0x7F80u16;
303
0
        let man = self.0 & 0x007Fu16;
304
0
        match (exp, man) {
305
0
            (0, 0) => FpCategory::Zero,
306
0
            (0, _) => FpCategory::Subnormal,
307
0
            (0x7F80u16, 0) => FpCategory::Infinite,
308
0
            (0x7F80u16, _) => FpCategory::Nan,
309
0
            _ => FpCategory::Normal,
310
        }
311
0
    }
312
313
    /// Returns a number that represents the sign of `self`.
314
    ///
315
    /// * 1.0 if the number is positive, +0.0 or [`INFINITY`][bf16::INFINITY]
316
    /// * −1.0 if the number is negative, −0.0` or [`NEG_INFINITY`][bf16::NEG_INFINITY]
317
    /// * [`NAN`][bf16::NAN] if the number is NaN
318
    ///
319
    /// # Examples
320
    ///
321
    /// ```rust
322
    /// # use half::prelude::*;
323
    ///
324
    /// let f = bf16::from_f32(3.5_f32);
325
    ///
326
    /// assert_eq!(f.signum(), bf16::from_f32(1.0));
327
    /// assert_eq!(bf16::NEG_INFINITY.signum(), bf16::from_f32(-1.0));
328
    ///
329
    /// assert!(bf16::NAN.signum().is_nan());
330
    /// ```
331
0
    pub const fn signum(self) -> bf16 {
332
0
        if self.is_nan() {
333
0
            self
334
0
        } else if self.0 & 0x8000u16 != 0 {
335
0
            Self::NEG_ONE
336
        } else {
337
0
            Self::ONE
338
        }
339
0
    }
340
341
    /// Returns `true` if and only if `self` has a positive sign, including +0.0, NaNs with a
342
    /// positive sign bit and +∞.
343
    ///
344
    /// # Examples
345
    ///
346
    /// ```rust
347
    /// # use half::prelude::*;
348
    ///
349
    /// let nan = bf16::NAN;
350
    /// let f = bf16::from_f32(7.0_f32);
351
    /// let g = bf16::from_f32(-7.0_f32);
352
    ///
353
    /// assert!(f.is_sign_positive());
354
    /// assert!(!g.is_sign_positive());
355
    /// // NaN can be either positive or negative
356
    /// assert!(nan.is_sign_positive() != nan.is_sign_negative());
357
    /// ```
358
    #[inline]
359
0
    pub const fn is_sign_positive(self) -> bool {
360
0
        self.0 & 0x8000u16 == 0
361
0
    }
362
363
    /// Returns `true` if and only if `self` has a negative sign, including −0.0, NaNs with a
364
    /// negative sign bit and −∞.
365
    ///
366
    /// # Examples
367
    ///
368
    /// ```rust
369
    /// # use half::prelude::*;
370
    ///
371
    /// let nan = bf16::NAN;
372
    /// let f = bf16::from_f32(7.0f32);
373
    /// let g = bf16::from_f32(-7.0f32);
374
    ///
375
    /// assert!(!f.is_sign_negative());
376
    /// assert!(g.is_sign_negative());
377
    /// // NaN can be either positive or negative
378
    /// assert!(nan.is_sign_positive() != nan.is_sign_negative());
379
    /// ```
380
    #[inline]
381
0
    pub const fn is_sign_negative(self) -> bool {
382
0
        self.0 & 0x8000u16 != 0
383
0
    }
384
385
    /// Returns a number composed of the magnitude of `self` and the sign of `sign`.
386
    ///
387
    /// Equal to `self` if the sign of `self` and `sign` are the same, otherwise equal to `-self`.
388
    /// If `self` is NaN, then NaN with the sign of `sign` is returned.
389
    ///
390
    /// # Examples
391
    ///
392
    /// ```
393
    /// # use half::prelude::*;
394
    /// let f = bf16::from_f32(3.5);
395
    ///
396
    /// assert_eq!(f.copysign(bf16::from_f32(0.42)), bf16::from_f32(3.5));
397
    /// assert_eq!(f.copysign(bf16::from_f32(-0.42)), bf16::from_f32(-3.5));
398
    /// assert_eq!((-f).copysign(bf16::from_f32(0.42)), bf16::from_f32(3.5));
399
    /// assert_eq!((-f).copysign(bf16::from_f32(-0.42)), bf16::from_f32(-3.5));
400
    ///
401
    /// assert!(bf16::NAN.copysign(bf16::from_f32(1.0)).is_nan());
402
    /// ```
403
    #[inline]
404
0
    pub const fn copysign(self, sign: bf16) -> bf16 {
405
0
        bf16((sign.0 & 0x8000u16) | (self.0 & 0x7FFFu16))
406
0
    }
407
408
    /// Returns the maximum of the two numbers.
409
    ///
410
    /// If one of the arguments is NaN, then the other argument is returned.
411
    ///
412
    /// # Examples
413
    ///
414
    /// ```
415
    /// # use half::prelude::*;
416
    /// let x = bf16::from_f32(1.0);
417
    /// let y = bf16::from_f32(2.0);
418
    ///
419
    /// assert_eq!(x.max(y), y);
420
    /// ```
421
    #[inline]
422
0
    pub fn max(self, other: bf16) -> bf16 {
423
0
        if other > self && !other.is_nan() {
424
0
            other
425
        } else {
426
0
            self
427
        }
428
0
    }
429
430
    /// Returns the minimum of the two numbers.
431
    ///
432
    /// If one of the arguments is NaN, then the other argument is returned.
433
    ///
434
    /// # Examples
435
    ///
436
    /// ```
437
    /// # use half::prelude::*;
438
    /// let x = bf16::from_f32(1.0);
439
    /// let y = bf16::from_f32(2.0);
440
    ///
441
    /// assert_eq!(x.min(y), x);
442
    /// ```
443
    #[inline]
444
0
    pub fn min(self, other: bf16) -> bf16 {
445
0
        if other < self && !other.is_nan() {
446
0
            other
447
        } else {
448
0
            self
449
        }
450
0
    }
451
452
    /// Restrict a value to a certain interval unless it is NaN.
453
    ///
454
    /// Returns `max` if `self` is greater than `max`, and `min` if `self` is less than `min`.
455
    /// Otherwise this returns `self`.
456
    ///
457
    /// Note that this function returns NaN if the initial value was NaN as well.
458
    ///
459
    /// # Panics
460
    /// Panics if `min > max`, `min` is NaN, or `max` is NaN.
461
    ///
462
    /// # Examples
463
    ///
464
    /// ```
465
    /// # use half::prelude::*;
466
    /// assert!(bf16::from_f32(-3.0).clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)) == bf16::from_f32(-2.0));
467
    /// assert!(bf16::from_f32(0.0).clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)) == bf16::from_f32(0.0));
468
    /// assert!(bf16::from_f32(2.0).clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)) == bf16::from_f32(1.0));
469
    /// assert!(bf16::NAN.clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)).is_nan());
470
    /// ```
471
    #[inline]
472
0
    pub fn clamp(self, min: bf16, max: bf16) -> bf16 {
473
0
        assert!(min <= max);
474
0
        let mut x = self;
475
0
        if x < min {
476
0
            x = min;
477
0
        }
478
0
        if x > max {
479
0
            x = max;
480
0
        }
481
0
        x
482
0
    }
483
484
    /// Approximate number of [`bf16`] significant digits in base 10
485
    pub const DIGITS: u32 = 2;
486
    /// [`bf16`]
487
    /// [machine epsilon](https://en.wikipedia.org/wiki/Machine_epsilon) value
488
    ///
489
    /// This is the difference between 1.0 and the next largest representable number.
490
    pub const EPSILON: bf16 = bf16(0x3C00u16);
491
    /// [`bf16`] positive Infinity (+∞)
492
    pub const INFINITY: bf16 = bf16(0x7F80u16);
493
    /// Number of [`bf16`] significant digits in base 2
494
    pub const MANTISSA_DIGITS: u32 = 8;
495
    /// Largest finite [`bf16`] value
496
    pub const MAX: bf16 = bf16(0x7F7F);
497
    /// Maximum possible [`bf16`] power of 10 exponent
498
    pub const MAX_10_EXP: i32 = 38;
499
    /// Maximum possible [`bf16`] power of 2 exponent
500
    pub const MAX_EXP: i32 = 128;
501
    /// Smallest finite [`bf16`] value
502
    pub const MIN: bf16 = bf16(0xFF7F);
503
    /// Minimum possible normal [`bf16`] power of 10 exponent
504
    pub const MIN_10_EXP: i32 = -37;
505
    /// One greater than the minimum possible normal [`bf16`] power of 2 exponent
506
    pub const MIN_EXP: i32 = -125;
507
    /// Smallest positive normal [`bf16`] value
508
    pub const MIN_POSITIVE: bf16 = bf16(0x0080u16);
509
    /// [`bf16`] Not a Number (NaN)
510
    pub const NAN: bf16 = bf16(0x7FC0u16);
511
    /// [`bf16`] negative infinity (-∞).
512
    pub const NEG_INFINITY: bf16 = bf16(0xFF80u16);
513
    /// The radix or base of the internal representation of [`bf16`]
514
    pub const RADIX: u32 = 2;
515
516
    /// Minimum positive subnormal [`bf16`] value
517
    pub const MIN_POSITIVE_SUBNORMAL: bf16 = bf16(0x0001u16);
518
    /// Maximum subnormal [`bf16`] value
519
    pub const MAX_SUBNORMAL: bf16 = bf16(0x007Fu16);
520
521
    /// [`bf16`] 1
522
    pub const ONE: bf16 = bf16(0x3F80u16);
523
    /// [`bf16`] 0
524
    pub const ZERO: bf16 = bf16(0x0000u16);
525
    /// [`bf16`] -0
526
    pub const NEG_ZERO: bf16 = bf16(0x8000u16);
527
    /// [`bf16`] -1
528
    pub const NEG_ONE: bf16 = bf16(0xBF80u16);
529
530
    /// [`bf16`] Euler's number (ℯ)
531
    pub const E: bf16 = bf16(0x402Eu16);
532
    /// [`bf16`] Archimedes' constant (π)
533
    pub const PI: bf16 = bf16(0x4049u16);
534
    /// [`bf16`] 1/π
535
    pub const FRAC_1_PI: bf16 = bf16(0x3EA3u16);
536
    /// [`bf16`] 1/√2
537
    pub const FRAC_1_SQRT_2: bf16 = bf16(0x3F35u16);
538
    /// [`bf16`] 2/π
539
    pub const FRAC_2_PI: bf16 = bf16(0x3F23u16);
540
    /// [`bf16`] 2/√π
541
    pub const FRAC_2_SQRT_PI: bf16 = bf16(0x3F90u16);
542
    /// [`bf16`] π/2
543
    pub const FRAC_PI_2: bf16 = bf16(0x3FC9u16);
544
    /// [`bf16`] π/3
545
    pub const FRAC_PI_3: bf16 = bf16(0x3F86u16);
546
    /// [`bf16`] π/4
547
    pub const FRAC_PI_4: bf16 = bf16(0x3F49u16);
548
    /// [`bf16`] π/6
549
    pub const FRAC_PI_6: bf16 = bf16(0x3F06u16);
550
    /// [`bf16`] π/8
551
    pub const FRAC_PI_8: bf16 = bf16(0x3EC9u16);
552
    /// [`bf16`] 𝗅𝗇 10
553
    pub const LN_10: bf16 = bf16(0x4013u16);
554
    /// [`bf16`] 𝗅𝗇 2
555
    pub const LN_2: bf16 = bf16(0x3F31u16);
556
    /// [`bf16`] 𝗅𝗈𝗀₁₀ℯ
557
    pub const LOG10_E: bf16 = bf16(0x3EDEu16);
558
    /// [`bf16`] 𝗅𝗈𝗀₁₀2
559
    pub const LOG10_2: bf16 = bf16(0x3E9Au16);
560
    /// [`bf16`] 𝗅𝗈𝗀₂ℯ
561
    pub const LOG2_E: bf16 = bf16(0x3FB9u16);
562
    /// [`bf16`] 𝗅𝗈𝗀₂10
563
    pub const LOG2_10: bf16 = bf16(0x4055u16);
564
    /// [`bf16`] √2
565
    pub const SQRT_2: bf16 = bf16(0x3FB5u16);
566
}
567
568
impl From<bf16> for f32 {
569
    #[inline]
570
0
    fn from(x: bf16) -> f32 {
571
0
        x.to_f32()
572
0
    }
573
}
574
575
impl From<bf16> for f64 {
576
    #[inline]
577
0
    fn from(x: bf16) -> f64 {
578
0
        x.to_f64()
579
0
    }
580
}
581
582
impl From<i8> for bf16 {
583
    #[inline]
584
0
    fn from(x: i8) -> bf16 {
585
        // Convert to f32, then to bf16
586
0
        bf16::from_f32(f32::from(x))
587
0
    }
588
}
589
590
impl From<u8> for bf16 {
591
    #[inline]
592
0
    fn from(x: u8) -> bf16 {
593
        // Convert to f32, then to bf16
594
0
        bf16::from_f32(f32::from(x))
595
0
    }
596
}
597
598
impl PartialEq for bf16 {
    /// Compares two values with floating point semantics rather than raw bit equality.
    ///
    /// NaN is never equal to anything (including itself), and `+0.0` equals `-0.0`.
    fn eq(&self, other: &bf16) -> bool {
        let either_nan = self.is_nan() || other.is_nan();
        let same_bits = self.0 == other.0;
        // With the sign bits masked off, both operands being zero leaves nothing set.
        let both_zero = (self.0 | other.0) & 0x7FFFu16 == 0;
        !either_nan && (same_bits || both_zero)
    }
}
607
608
impl PartialOrd for bf16 {
609
0
    fn partial_cmp(&self, other: &bf16) -> Option<Ordering> {
610
0
        if self.is_nan() || other.is_nan() {
611
0
            None
612
        } else {
613
0
            let neg = self.0 & 0x8000u16 != 0;
614
0
            let other_neg = other.0 & 0x8000u16 != 0;
615
0
            match (neg, other_neg) {
616
0
                (false, false) => Some(self.0.cmp(&other.0)),
617
                (false, true) => {
618
0
                    if (self.0 | other.0) & 0x7FFFu16 == 0 {
619
0
                        Some(Ordering::Equal)
620
                    } else {
621
0
                        Some(Ordering::Greater)
622
                    }
623
                }
624
                (true, false) => {
625
0
                    if (self.0 | other.0) & 0x7FFFu16 == 0 {
626
0
                        Some(Ordering::Equal)
627
                    } else {
628
0
                        Some(Ordering::Less)
629
                    }
630
                }
631
0
                (true, true) => Some(other.0.cmp(&self.0)),
632
            }
633
        }
634
0
    }
635
636
0
    fn lt(&self, other: &bf16) -> bool {
637
0
        if self.is_nan() || other.is_nan() {
638
0
            false
639
        } else {
640
0
            let neg = self.0 & 0x8000u16 != 0;
641
0
            let other_neg = other.0 & 0x8000u16 != 0;
642
0
            match (neg, other_neg) {
643
0
                (false, false) => self.0 < other.0,
644
0
                (false, true) => false,
645
0
                (true, false) => (self.0 | other.0) & 0x7FFFu16 != 0,
646
0
                (true, true) => self.0 > other.0,
647
            }
648
        }
649
0
    }
650
651
0
    fn le(&self, other: &bf16) -> bool {
652
0
        if self.is_nan() || other.is_nan() {
653
0
            false
654
        } else {
655
0
            let neg = self.0 & 0x8000u16 != 0;
656
0
            let other_neg = other.0 & 0x8000u16 != 0;
657
0
            match (neg, other_neg) {
658
0
                (false, false) => self.0 <= other.0,
659
0
                (false, true) => (self.0 | other.0) & 0x7FFFu16 == 0,
660
0
                (true, false) => true,
661
0
                (true, true) => self.0 >= other.0,
662
            }
663
        }
664
0
    }
665
666
0
    fn gt(&self, other: &bf16) -> bool {
667
0
        if self.is_nan() || other.is_nan() {
668
0
            false
669
        } else {
670
0
            let neg = self.0 & 0x8000u16 != 0;
671
0
            let other_neg = other.0 & 0x8000u16 != 0;
672
0
            match (neg, other_neg) {
673
0
                (false, false) => self.0 > other.0,
674
0
                (false, true) => (self.0 | other.0) & 0x7FFFu16 != 0,
675
0
                (true, false) => false,
676
0
                (true, true) => self.0 < other.0,
677
            }
678
        }
679
0
    }
680
681
0
    fn ge(&self, other: &bf16) -> bool {
682
0
        if self.is_nan() || other.is_nan() {
683
0
            false
684
        } else {
685
0
            let neg = self.0 & 0x8000u16 != 0;
686
0
            let other_neg = other.0 & 0x8000u16 != 0;
687
0
            match (neg, other_neg) {
688
0
                (false, false) => self.0 >= other.0,
689
0
                (false, true) => true,
690
0
                (true, false) => (self.0 | other.0) & 0x7FFFu16 == 0,
691
0
                (true, true) => self.0 <= other.0,
692
            }
693
        }
694
0
    }
695
}
696
697
impl FromStr for bf16 {
698
    type Err = ParseFloatError;
699
0
    fn from_str(src: &str) -> Result<bf16, ParseFloatError> {
700
0
        f32::from_str(src).map(bf16::from_f32)
701
0
    }
702
}
703
704
impl Debug for bf16 {
705
0
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
706
0
        write!(f, "{:?}", self.to_f32())
707
0
    }
708
}
709
710
impl Display for bf16 {
711
0
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
712
0
        write!(f, "{}", self.to_f32())
713
0
    }
714
}
715
716
impl LowerExp for bf16 {
717
0
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
718
0
        write!(f, "{:e}", self.to_f32())
719
0
    }
720
}
721
722
impl UpperExp for bf16 {
723
0
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
724
0
        write!(f, "{:E}", self.to_f32())
725
0
    }
726
}
727
728
impl Binary for bf16 {
729
0
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
730
0
        write!(f, "{:b}", self.0)
731
0
    }
732
}
733
734
impl Octal for bf16 {
735
0
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
736
0
        write!(f, "{:o}", self.0)
737
0
    }
738
}
739
740
impl LowerHex for bf16 {
741
0
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
742
0
        write!(f, "{:x}", self.0)
743
0
    }
744
}
745
746
impl UpperHex for bf16 {
747
0
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
748
0
        write!(f, "{:X}", self.0)
749
0
    }
750
}
751
752
impl Neg for bf16 {
753
    type Output = Self;
754
755
0
    fn neg(self) -> Self::Output {
756
0
        Self(self.0 ^ 0x8000)
757
0
    }
758
}
759
760
impl Add for bf16 {
761
    type Output = Self;
762
763
0
    fn add(self, rhs: Self) -> Self::Output {
764
0
        Self::from_f32(Self::to_f32(self) + Self::to_f32(rhs))
765
0
    }
766
}
767
768
impl Add<&bf16> for bf16 {
769
    type Output = <bf16 as Add<bf16>>::Output;
770
771
    #[inline]
772
0
    fn add(self, rhs: &bf16) -> Self::Output {
773
0
        self.add(*rhs)
774
0
    }
775
}
776
777
impl Add<&bf16> for &bf16 {
778
    type Output = <bf16 as Add<bf16>>::Output;
779
780
    #[inline]
781
0
    fn add(self, rhs: &bf16) -> Self::Output {
782
0
        (*self).add(*rhs)
783
0
    }
784
}
785
786
impl Add<bf16> for &bf16 {
787
    type Output = <bf16 as Add<bf16>>::Output;
788
789
    #[inline]
790
0
    fn add(self, rhs: bf16) -> Self::Output {
791
0
        (*self).add(rhs)
792
0
    }
793
}
794
795
impl AddAssign for bf16 {
796
    #[inline]
797
0
    fn add_assign(&mut self, rhs: Self) {
798
0
        *self = (*self).add(rhs);
799
0
    }
800
}
801
802
impl AddAssign<&bf16> for bf16 {
803
    #[inline]
804
0
    fn add_assign(&mut self, rhs: &bf16) {
805
0
        *self = (*self).add(rhs);
806
0
    }
807
}
808
809
impl Sub for bf16 {
810
    type Output = Self;
811
812
0
    fn sub(self, rhs: Self) -> Self::Output {
813
0
        Self::from_f32(Self::to_f32(self) - Self::to_f32(rhs))
814
0
    }
815
}
816
817
impl Sub<&bf16> for bf16 {
818
    type Output = <bf16 as Sub<bf16>>::Output;
819
820
    #[inline]
821
0
    fn sub(self, rhs: &bf16) -> Self::Output {
822
0
        self.sub(*rhs)
823
0
    }
824
}
825
826
impl Sub<&bf16> for &bf16 {
827
    type Output = <bf16 as Sub<bf16>>::Output;
828
829
    #[inline]
830
0
    fn sub(self, rhs: &bf16) -> Self::Output {
831
0
        (*self).sub(*rhs)
832
0
    }
833
}
834
835
impl Sub<bf16> for &bf16 {
836
    type Output = <bf16 as Sub<bf16>>::Output;
837
838
    #[inline]
839
0
    fn sub(self, rhs: bf16) -> Self::Output {
840
0
        (*self).sub(rhs)
841
0
    }
842
}
843
844
impl SubAssign for bf16 {
845
    #[inline]
846
0
    fn sub_assign(&mut self, rhs: Self) {
847
0
        *self = (*self).sub(rhs);
848
0
    }
849
}
850
851
impl SubAssign<&bf16> for bf16 {
852
    #[inline]
853
0
    fn sub_assign(&mut self, rhs: &bf16) {
854
0
        *self = (*self).sub(rhs);
855
0
    }
856
}
857
858
impl Mul for bf16 {
859
    type Output = Self;
860
861
0
    fn mul(self, rhs: Self) -> Self::Output {
862
0
        Self::from_f32(Self::to_f32(self) * Self::to_f32(rhs))
863
0
    }
864
}
865
866
impl Mul<&bf16> for bf16 {
867
    type Output = <bf16 as Mul<bf16>>::Output;
868
869
    #[inline]
870
0
    fn mul(self, rhs: &bf16) -> Self::Output {
871
0
        self.mul(*rhs)
872
0
    }
873
}
874
875
impl Mul<&bf16> for &bf16 {
876
    type Output = <bf16 as Mul<bf16>>::Output;
877
878
    #[inline]
879
0
    fn mul(self, rhs: &bf16) -> Self::Output {
880
0
        (*self).mul(*rhs)
881
0
    }
882
}
883
884
impl Mul<bf16> for &bf16 {
885
    type Output = <bf16 as Mul<bf16>>::Output;
886
887
    #[inline]
888
0
    fn mul(self, rhs: bf16) -> Self::Output {
889
0
        (*self).mul(rhs)
890
0
    }
891
}
892
893
impl MulAssign for bf16 {
894
    #[inline]
895
0
    fn mul_assign(&mut self, rhs: Self) {
896
0
        *self = (*self).mul(rhs);
897
0
    }
898
}
899
900
impl MulAssign<&bf16> for bf16 {
901
    #[inline]
902
0
    fn mul_assign(&mut self, rhs: &bf16) {
903
0
        *self = (*self).mul(rhs);
904
0
    }
905
}
906
907
impl Div for bf16 {
908
    type Output = Self;
909
910
0
    fn div(self, rhs: Self) -> Self::Output {
911
0
        Self::from_f32(Self::to_f32(self) / Self::to_f32(rhs))
912
0
    }
913
}
914
915
impl Div<&bf16> for bf16 {
916
    type Output = <bf16 as Div<bf16>>::Output;
917
918
    #[inline]
919
0
    fn div(self, rhs: &bf16) -> Self::Output {
920
0
        self.div(*rhs)
921
0
    }
922
}
923
924
impl Div<&bf16> for &bf16 {
925
    type Output = <bf16 as Div<bf16>>::Output;
926
927
    #[inline]
928
0
    fn div(self, rhs: &bf16) -> Self::Output {
929
0
        (*self).div(*rhs)
930
0
    }
931
}
932
933
impl Div<bf16> for &bf16 {
934
    type Output = <bf16 as Div<bf16>>::Output;
935
936
    #[inline]
937
0
    fn div(self, rhs: bf16) -> Self::Output {
938
0
        (*self).div(rhs)
939
0
    }
940
}
941
942
impl DivAssign for bf16 {
943
    #[inline]
944
0
    fn div_assign(&mut self, rhs: Self) {
945
0
        *self = (*self).div(rhs);
946
0
    }
947
}
948
949
impl DivAssign<&bf16> for bf16 {
950
    #[inline]
951
0
    fn div_assign(&mut self, rhs: &bf16) {
952
0
        *self = (*self).div(rhs);
953
0
    }
954
}
955
956
impl Rem for bf16 {
957
    type Output = Self;
958
959
0
    fn rem(self, rhs: Self) -> Self::Output {
960
0
        Self::from_f32(Self::to_f32(self) % Self::to_f32(rhs))
961
0
    }
962
}
963
964
impl Rem<&bf16> for bf16 {
965
    type Output = <bf16 as Rem<bf16>>::Output;
966
967
    #[inline]
968
0
    fn rem(self, rhs: &bf16) -> Self::Output {
969
0
        self.rem(*rhs)
970
0
    }
971
}
972
973
impl Rem<&bf16> for &bf16 {
974
    type Output = <bf16 as Rem<bf16>>::Output;
975
976
    #[inline]
977
0
    fn rem(self, rhs: &bf16) -> Self::Output {
978
0
        (*self).rem(*rhs)
979
0
    }
980
}
981
982
/// Remainder of a borrowed [`bf16`] divided by an owned [`bf16`].
impl Rem<bf16> for &bf16 {
    type Output = <bf16 as Rem<bf16>>::Output;

    /// Copies the dividend out of the reference and forwards to the by-value implementation.
    #[inline]
    fn rem(self, rhs: bf16) -> Self::Output {
        let dividend = *self;
        dividend % rhs
    }
}
990
991
impl RemAssign for bf16 {
992
    #[inline]
993
0
    fn rem_assign(&mut self, rhs: Self) {
994
0
        *self = (*self).rem(rhs);
995
0
    }
996
}
997
998
/// In-place remainder of a [`bf16`] by a borrowed [`bf16`].
impl RemAssign<&bf16> for bf16 {
    /// Overwrites `self` with the remainder of its current value divided by `rhs`.
    #[inline]
    fn rem_assign(&mut self, rhs: &bf16) {
        let remainder = *self % rhs;
        *self = remainder;
    }
}
1004
1005
impl Product for bf16 {
1006
    #[inline]
1007
0
    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
1008
0
        bf16::from_f32(iter.map(|f| f.to_f32()).product())
1009
0
    }
1010
}
1011
1012
impl<'a> Product<&'a bf16> for bf16 {
1013
    #[inline]
1014
0
    fn product<I: Iterator<Item = &'a bf16>>(iter: I) -> Self {
1015
0
        bf16::from_f32(iter.map(|f| f.to_f32()).product())
1016
0
    }
1017
}
1018
1019
impl Sum for bf16 {
1020
    #[inline]
1021
0
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
1022
0
        bf16::from_f32(iter.map(|f| f.to_f32()).sum())
1023
0
    }
1024
}
1025
1026
impl<'a> Sum<&'a bf16> for bf16 {
1027
    #[inline]
1028
0
    fn sum<I: Iterator<Item = &'a bf16>>(iter: I) -> Self {
1029
0
        bf16::from_f32(iter.map(|f| f.to_f32()).product())
1030
0
    }
1031
}
1032
1033
#[allow(
1034
    clippy::cognitive_complexity,
1035
    clippy::float_cmp,
1036
    clippy::neg_cmp_op_on_partial_ord
1037
)]
1038
#[cfg(test)]
1039
mod test {
1040
    use super::*;
1041
    use core::cmp::Ordering;
1042
    #[cfg(feature = "num-traits")]
1043
    use num_traits::{AsPrimitive, FromPrimitive, ToPrimitive};
1044
    use quickcheck_macros::quickcheck;
1045
1046
    /// Checks `AsPrimitive` casts between [`bf16`] and `i32`, `f32`, and `f64` in both
    /// directions, using the value two.
    #[cfg(feature = "num-traits")]
    #[test]
    fn as_primitive() {
        let expected = bf16::from_f32(2.0);

        // Casts into bf16.
        assert_eq!(<i32 as AsPrimitive<bf16>>::as_(2), expected);
        assert_eq!(<f32 as AsPrimitive<bf16>>::as_(2.0), expected);
        assert_eq!(<f64 as AsPrimitive<bf16>>::as_(2.0), expected);

        // Casts out of bf16.
        assert_eq!(<bf16 as AsPrimitive<i32>>::as_(expected), 2);
        assert_eq!(<bf16 as AsPrimitive<f32>>::as_(expected), 2.0);
        assert_eq!(<bf16 as AsPrimitive<f64>>::as_(expected), 2.0);
    }
1059
1060
    /// Checks that `ToPrimitive` converts a [`bf16`] holding two into `i32`, `f32`, and `f64`.
    #[cfg(feature = "num-traits")]
    #[test]
    fn to_primitive() {
        let source = bf16::from_f32(2.0);

        let as_i32 = ToPrimitive::to_i32(&source).unwrap();
        assert_eq!(as_i32, 2i32);

        let as_f32 = ToPrimitive::to_f32(&source).unwrap();
        assert_eq!(as_f32, 2.0f32);

        let as_f64 = ToPrimitive::to_f64(&source).unwrap();
        assert_eq!(as_f64, 2.0f64);
    }
1068
1069
    /// Checks that `FromPrimitive` builds the [`bf16`] value two from `i32`, `f32`, and `f64`.
    #[cfg(feature = "num-traits")]
    #[test]
    fn from_primitive() {
        let expected = bf16::from_f32(2.0);

        let from_int = <bf16 as FromPrimitive>::from_i32(2).unwrap();
        assert_eq!(from_int, expected);

        let from_single = <bf16 as FromPrimitive>::from_f32(2.0).unwrap();
        assert_eq!(from_single, expected);

        let from_double = <bf16 as FromPrimitive>::from_f64(2.0).unwrap();
        assert_eq!(from_double, expected);
    }
1077
1078
    /// Verifies that the associated constants of [`bf16`] equal the matching `f32` values
    /// converted with `from_f32`.
    #[test]
    fn test_bf16_consts_from_f32() {
        use core::f32::consts;

        // Unit values, signed zeros, infinities, and NaN.
        let zero = bf16::from_f32(0.0);
        let neg_zero = bf16::from_f32(-0.0);
        let neg_one = bf16::from_f32(-1.0);

        assert_eq!(bf16::ONE, bf16::from_f32(1.0));
        assert_eq!(bf16::ZERO, zero);
        assert!(zero.is_sign_positive());
        assert_eq!(bf16::NEG_ZERO, neg_zero);
        assert!(neg_zero.is_sign_negative());
        assert_eq!(bf16::NEG_ONE, neg_one);
        assert!(neg_one.is_sign_negative());
        assert_eq!(bf16::INFINITY, bf16::from_f32(core::f32::INFINITY));
        assert_eq!(bf16::NEG_INFINITY, bf16::from_f32(core::f32::NEG_INFINITY));
        assert!(bf16::from_f32(core::f32::NAN).is_nan());
        assert!(bf16::NAN.is_nan());

        // core::f32::consts::LOG10_2 requires rustc 1.43.0
        let log10_2 = bf16::from_f32(2f32.log10());
        // core::f32::consts::LOG2_10 requires rustc 1.43.0
        let log2_10 = bf16::from_f32(10f32.log2());

        // Mathematical constants.
        assert_eq!(bf16::E, bf16::from_f32(consts::E));
        assert_eq!(bf16::PI, bf16::from_f32(consts::PI));
        assert_eq!(bf16::FRAC_1_PI, bf16::from_f32(consts::FRAC_1_PI));
        assert_eq!(bf16::FRAC_1_SQRT_2, bf16::from_f32(consts::FRAC_1_SQRT_2));
        assert_eq!(bf16::FRAC_2_PI, bf16::from_f32(consts::FRAC_2_PI));
        assert_eq!(bf16::FRAC_2_SQRT_PI, bf16::from_f32(consts::FRAC_2_SQRT_PI));
        assert_eq!(bf16::FRAC_PI_2, bf16::from_f32(consts::FRAC_PI_2));
        assert_eq!(bf16::FRAC_PI_3, bf16::from_f32(consts::FRAC_PI_3));
        assert_eq!(bf16::FRAC_PI_4, bf16::from_f32(consts::FRAC_PI_4));
        assert_eq!(bf16::FRAC_PI_6, bf16::from_f32(consts::FRAC_PI_6));
        assert_eq!(bf16::FRAC_PI_8, bf16::from_f32(consts::FRAC_PI_8));
        assert_eq!(bf16::LN_10, bf16::from_f32(consts::LN_10));
        assert_eq!(bf16::LN_2, bf16::from_f32(consts::LN_2));
        assert_eq!(bf16::LOG10_E, bf16::from_f32(consts::LOG10_E));
        assert_eq!(bf16::LOG10_2, log10_2);
        assert_eq!(bf16::LOG2_E, bf16::from_f32(consts::LOG2_E));
        assert_eq!(bf16::LOG2_10, log2_10);
        assert_eq!(bf16::SQRT_2, bf16::from_f32(consts::SQRT_2));
    }
1140
1141
    /// Verifies that the associated constants of [`bf16`] equal the matching `f64` values
    /// converted with `from_f64`.
    #[test]
    fn test_bf16_consts_from_f64() {
        use core::f64::consts;

        // Unit value, signed zeros, infinities, and NaN.
        assert_eq!(bf16::ONE, bf16::from_f64(1.0));
        assert_eq!(bf16::ZERO, bf16::from_f64(0.0));
        assert_eq!(bf16::NEG_ZERO, bf16::from_f64(-0.0));
        assert_eq!(bf16::INFINITY, bf16::from_f64(core::f64::INFINITY));
        assert_eq!(bf16::NEG_INFINITY, bf16::from_f64(core::f64::NEG_INFINITY));
        assert!(bf16::from_f64(core::f64::NAN).is_nan());
        assert!(bf16::NAN.is_nan());

        // core::f64::consts::LOG10_2 requires rustc 1.43.0
        let log10_2 = bf16::from_f64(2f64.log10());
        // core::f64::consts::LOG2_10 requires rustc 1.43.0
        let log2_10 = bf16::from_f64(10f64.log2());

        // Mathematical constants.
        assert_eq!(bf16::E, bf16::from_f64(consts::E));
        assert_eq!(bf16::PI, bf16::from_f64(consts::PI));
        assert_eq!(bf16::FRAC_1_PI, bf16::from_f64(consts::FRAC_1_PI));
        assert_eq!(bf16::FRAC_1_SQRT_2, bf16::from_f64(consts::FRAC_1_SQRT_2));
        assert_eq!(bf16::FRAC_2_PI, bf16::from_f64(consts::FRAC_2_PI));
        assert_eq!(bf16::FRAC_2_SQRT_PI, bf16::from_f64(consts::FRAC_2_SQRT_PI));
        assert_eq!(bf16::FRAC_PI_2, bf16::from_f64(consts::FRAC_PI_2));
        assert_eq!(bf16::FRAC_PI_3, bf16::from_f64(consts::FRAC_PI_3));
        assert_eq!(bf16::FRAC_PI_4, bf16::from_f64(consts::FRAC_PI_4));
        assert_eq!(bf16::FRAC_PI_6, bf16::from_f64(consts::FRAC_PI_6));
        assert_eq!(bf16::FRAC_PI_8, bf16::from_f64(consts::FRAC_PI_8));
        assert_eq!(bf16::LN_10, bf16::from_f64(consts::LN_10));
        assert_eq!(bf16::LN_2, bf16::from_f64(consts::LN_2));
        assert_eq!(bf16::LOG10_E, bf16::from_f64(consts::LOG10_E));
        assert_eq!(bf16::LOG10_2, log10_2);
        assert_eq!(bf16::LOG2_E, bf16::from_f64(consts::LOG2_E));
        assert_eq!(bf16::LOG2_10, log2_10);
        assert_eq!(bf16::SQRT_2, bf16::from_f64(consts::SQRT_2));
    }
1198
1199
    /// Checks that NaN-ness and sign are kept when NaNs are converted from `f64`/`f32` down to
    /// `f32`/[`bf16`].
    #[test]
    fn test_nan_conversion_to_smaller() {
        // Source NaNs of both signs and both widths.
        let pos64 = f64::from_bits(0x7FF0_0000_0000_0001u64);
        let neg64 = f64::from_bits(0xFFF0_0000_0000_0001u64);
        let pos32 = f32::from_bits(0x7F80_0001u32);
        let neg32 = f32::from_bits(0xFF80_0001u32);

        // Narrowed results.
        let pos32_from_64 = pos64 as f32;
        let neg32_from_64 = neg64 as f32;
        let pos16_from_64 = bf16::from_f64(pos64);
        let neg16_from_64 = bf16::from_f64(neg64);
        let pos16_from_32 = bf16::from_f32(pos32);
        let neg16_from_32 = bf16::from_f32(neg32);

        // The inputs themselves are NaNs with the intended sign.
        assert!(pos64.is_nan() && pos64.is_sign_positive());
        assert!(neg64.is_nan() && neg64.is_sign_negative());
        assert!(pos32.is_nan() && pos32.is_sign_positive());
        assert!(neg32.is_nan() && neg32.is_sign_negative());

        // f64 -> f32
        assert!(pos32_from_64.is_nan() && pos32_from_64.is_sign_positive());
        assert!(neg32_from_64.is_nan() && neg32_from_64.is_sign_negative());

        // f64 -> bf16
        assert!(pos16_from_64.is_nan() && pos16_from_64.is_sign_positive());
        assert!(neg16_from_64.is_nan() && neg16_from_64.is_sign_negative());

        // f32 -> bf16
        assert!(pos16_from_32.is_nan() && pos16_from_32.is_sign_positive());
        assert!(neg16_from_32.is_nan() && neg16_from_32.is_sign_negative());
    }
1223
1224
    /// Checks that NaN-ness and sign are kept when NaNs are converted from [`bf16`]/`f32` up to
    /// `f32`/`f64`.
    #[test]
    fn test_nan_conversion_to_larger() {
        // Source NaNs of both signs and both widths.
        let pos16 = bf16::from_bits(0x7F81u16);
        let neg16 = bf16::from_bits(0xFF81u16);
        let pos32 = f32::from_bits(0x7F80_0001u32);
        let neg32 = f32::from_bits(0xFF80_0001u32);

        // Widened results.
        let pos32_from_16 = f32::from(pos16);
        let neg32_from_16 = f32::from(neg16);
        let pos64_from_16 = f64::from(pos16);
        let neg64_from_16 = f64::from(neg16);
        let pos64_from_32 = f64::from(pos32);
        let neg64_from_32 = f64::from(neg32);

        // The inputs themselves are NaNs with the intended sign.
        assert!(pos16.is_nan() && pos16.is_sign_positive());
        assert!(neg16.is_nan() && neg16.is_sign_negative());
        assert!(pos32.is_nan() && pos32.is_sign_positive());
        assert!(neg32.is_nan() && neg32.is_sign_negative());

        // bf16 -> f32
        assert!(pos32_from_16.is_nan() && pos32_from_16.is_sign_positive());
        assert!(neg32_from_16.is_nan() && neg32_from_16.is_sign_negative());

        // bf16 -> f64
        assert!(pos64_from_16.is_nan() && pos64_from_16.is_sign_positive());
        assert!(neg64_from_16.is_nan() && neg64_from_16.is_sign_negative());

        // f32 -> f64
        assert!(pos64_from_32.is_nan() && pos64_from_32.is_sign_positive());
        assert!(neg64_from_32.is_nan() && neg64_from_32.is_sign_negative());
    }
1248
1249
    /// Exercises `to_f32`/`from_f32` on an exactly representable value, a rounded value, and the
    /// two raw bit patterns `0x0001` and `0x0005`.
    #[test]
    fn test_bf16_to_f32() {
        let exact = bf16::from_f32(7.0);
        assert_eq!(exact.to_f32(), 7.0f32);

        // 7.1 is NOT exactly representable in 16-bit, it's rounded
        let rounded = bf16::from_f32(7.1);
        let error = (rounded.to_f32() - 7.1f32).abs();
        // diff must be <= 4 * EPSILON, as 7 has two more significant bits than 1
        let tolerance = 4.0 * bf16::EPSILON.to_f32();
        assert!(error <= tolerance);

        let tiny32 = f32::from_bits(0x0001_0000u32);
        let one_step = bf16::from_bits(0x0001);
        let five_steps = bf16::from_bits(0x0005);

        // bf16 -> f32
        assert_eq!(one_step.to_f32(), tiny32);
        assert_eq!(five_steps.to_f32(), 5.0 * tiny32);

        // f32 -> bf16
        assert_eq!(one_step, bf16::from_f32(tiny32));
        assert_eq!(five_steps, bf16::from_f32(5.0 * tiny32));
    }
1267
1268
    /// Exercises `to_f64`/`from_f64` on an exactly representable value, a rounded value, and the
    /// two raw bit patterns `0x0001` and `0x0005`.
    #[test]
    fn test_bf16_to_f64() {
        let exact = bf16::from_f64(7.0);
        assert_eq!(exact.to_f64(), 7.0f64);

        // 7.1 is NOT exactly representable in 16-bit, it's rounded
        let rounded = bf16::from_f64(7.1);
        let error = (rounded.to_f64() - 7.1f64).abs();
        // diff must be <= 4 * EPSILON, as 7 has two more significant bits than 1
        let tolerance = 4.0 * bf16::EPSILON.to_f64();
        assert!(error <= tolerance);

        let tiny64 = 2.0f64.powi(-133);
        let one_step = bf16::from_bits(0x0001);
        let five_steps = bf16::from_bits(0x0005);

        // bf16 -> f64
        assert_eq!(one_step.to_f64(), tiny64);
        assert_eq!(five_steps.to_f64(), 5.0 * tiny64);

        // f64 -> bf16
        assert_eq!(one_step, bf16::from_f64(tiny64));
        assert_eq!(five_steps, bf16::from_f64(5.0 * tiny64));
    }
1286
1287
    /// Exercises `partial_cmp` and every comparison operator on the pairs (0, -0), (1, -0), and
    /// (1, -1).
    #[test]
    fn test_comparisons() {
        /// Asserts every comparison, in both operand orders, for a pair in which `hi` compares
        /// strictly greater than `lo`.
        fn assert_strictly_greater(hi: bf16, lo: bf16) {
            assert_eq!(hi.partial_cmp(&lo), Some(Ordering::Greater));
            assert_eq!(lo.partial_cmp(&hi), Some(Ordering::Less));
            assert!(!(hi == lo));
            assert!(!(lo == hi));
            assert!(hi != lo);
            assert!(lo != hi);
            assert!(!(hi < lo));
            assert!(lo < hi);
            assert!(!(hi <= lo));
            assert!(lo <= hi);
            assert!(hi > lo);
            assert!(!(lo > hi));
            assert!(hi >= lo);
            assert!(!(lo >= hi));
        }

        let zero = bf16::from_f64(0.0);
        let one = bf16::from_f64(1.0);
        let neg_zero = bf16::from_f64(-0.0);
        let neg_one = bf16::from_f64(-1.0);

        // Positive and negative zero compare equal in every respect.
        assert_eq!(zero.partial_cmp(&neg_zero), Some(Ordering::Equal));
        assert_eq!(neg_zero.partial_cmp(&zero), Some(Ordering::Equal));
        assert!(zero == neg_zero);
        assert!(neg_zero == zero);
        assert!(!(zero != neg_zero));
        assert!(!(neg_zero != zero));
        assert!(!(zero < neg_zero));
        assert!(!(neg_zero < zero));
        assert!(zero <= neg_zero);
        assert!(neg_zero <= zero);
        assert!(!(zero > neg_zero));
        assert!(!(neg_zero > zero));
        assert!(zero >= neg_zero);
        assert!(neg_zero >= zero);

        // One is strictly greater than negative zero and than negative one.
        assert_strictly_greater(one, neg_zero);
        assert_strictly_greater(one, neg_one);
    }
1339
1340
    /// Checks round-to-nearest-even behaviour of `from_f32` around the smallest subnormal and
    /// around 250..253.
    #[test]
    #[allow(clippy::erasing_op, clippy::identity_op)]
    fn round_to_even_f32() {
        // smallest positive subnormal = 0b0.0000_001 * 2^-126 = 2^-133
        let min_sub = bf16::from_bits(1);
        let min_sub_f = (-133f32).exp2();
        assert_eq!(bf16::from_f32(min_sub_f).to_bits(), min_sub.to_bits());
        assert_eq!(f32::from(min_sub).to_bits(), min_sub_f.to_bits());

        // Each entry is (multiple of the smallest subnormal fed in, expected multiple out).
        let subnormal_cases: [(f32, u16); 9] = [
            // 0.0000000_011111 rounded to 0.0000000 (< tie, no rounding)
            (0.49, 0),
            // 0.0000000_100000 rounded to 0.0000000 (tie and even, remains at even)
            (0.50, 0),
            // 0.0000000_100001 rounded to 0.0000001 (> tie, rounds up)
            (0.51, 1),
            // 0.0000001_011111 rounded to 0.0000001 (< tie, no rounding)
            (1.49, 1),
            // 0.0000001_100000 rounded to 0.0000010 (tie and odd, rounds up to even)
            (1.50, 2),
            // 0.0000001_100001 rounded to 0.0000010 (> tie, rounds up)
            (1.51, 2),
            // 0.0000010_011111 rounded to 0.0000010 (< tie, no rounding)
            (2.49, 2),
            // 0.0000010_100000 rounded to 0.0000010 (tie and even, remains at even)
            (2.50, 2),
            // 0.0000010_100001 rounded to 0.0000011 (> tie, rounds up)
            (2.51, 3),
        ];
        for &(scale, steps) in subnormal_cases.iter() {
            assert_eq!(
                bf16::from_f32(min_sub_f * scale).to_bits(),
                min_sub.to_bits() * steps
            );
        }

        // Each entry is (input, value whose conversion the input must round to).
        let normal_cases: [(f32, f32); 9] = [
            (250.49, 250.0),
            (250.50, 250.0),
            (250.51, 251.0),
            (251.49, 251.0),
            (251.50, 252.0),
            (251.51, 252.0),
            (252.49, 252.0),
            (252.50, 252.0),
            (252.51, 253.0),
        ];
        for &(input, nearest) in normal_cases.iter() {
            assert_eq!(
                bf16::from_f32(input).to_bits(),
                bf16::from_f32(nearest).to_bits()
            );
        }
    }
1434
1435
    /// Checks round-to-nearest-even behaviour of `from_f64` around the smallest subnormal and
    /// around 250..253.
    #[test]
    #[allow(clippy::erasing_op, clippy::identity_op)]
    fn round_to_even_f64() {
        // smallest positive subnormal = 0b0.0000_001 * 2^-126 = 2^-133
        let min_sub = bf16::from_bits(1);
        let min_sub_f = (-133f64).exp2();
        assert_eq!(bf16::from_f64(min_sub_f).to_bits(), min_sub.to_bits());
        assert_eq!(f64::from(min_sub).to_bits(), min_sub_f.to_bits());

        // Each entry is (multiple of the smallest subnormal fed in, expected multiple out).
        let subnormal_cases: [(f64, u16); 9] = [
            // 0.0000000_011111 rounded to 0.0000000 (< tie, no rounding)
            (0.49, 0),
            // 0.0000000_100000 rounded to 0.0000000 (tie and even, remains at even)
            (0.50, 0),
            // 0.0000000_100001 rounded to 0.0000001 (> tie, rounds up)
            (0.51, 1),
            // 0.0000001_011111 rounded to 0.0000001 (< tie, no rounding)
            (1.49, 1),
            // 0.0000001_100000 rounded to 0.0000010 (tie and odd, rounds up to even)
            (1.50, 2),
            // 0.0000001_100001 rounded to 0.0000010 (> tie, rounds up)
            (1.51, 2),
            // 0.0000010_011111 rounded to 0.0000010 (< tie, no rounding)
            (2.49, 2),
            // 0.0000010_100000 rounded to 0.0000010 (tie and even, remains at even)
            (2.50, 2),
            // 0.0000010_100001 rounded to 0.0000011 (> tie, rounds up)
            (2.51, 3),
        ];
        for &(scale, steps) in subnormal_cases.iter() {
            assert_eq!(
                bf16::from_f64(min_sub_f * scale).to_bits(),
                min_sub.to_bits() * steps
            );
        }

        // Each entry is (input, value whose conversion the input must round to).
        let normal_cases: [(f64, f64); 9] = [
            (250.49, 250.0),
            (250.50, 250.0),
            (250.51, 251.0),
            (251.49, 251.0),
            (251.50, 252.0),
            (251.51, 252.0),
            (252.49, 252.0),
            (252.50, 252.0),
            (252.51, 253.0),
        ];
        for &(input, nearest) in normal_cases.iter() {
            assert_eq!(
                bf16::from_f64(input).to_bits(),
                bf16::from_f64(nearest).to_bits()
            );
        }
    }
1529
1530
    impl quickcheck::Arbitrary for bf16 {
1531
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
1532
            bf16(u16::arbitrary(g))
1533
        }
1534
    }
1535
1536
    #[quickcheck]
1537
    fn qc_roundtrip_bf16_f32_is_identity(f: bf16) -> bool {
1538
        let roundtrip = bf16::from_f32(f.to_f32());
1539
        if f.is_nan() {
1540
            roundtrip.is_nan() && f.is_sign_negative() == roundtrip.is_sign_negative()
1541
        } else {
1542
            f.0 == roundtrip.0
1543
        }
1544
    }
1545
1546
    #[quickcheck]
1547
    fn qc_roundtrip_bf16_f64_is_identity(f: bf16) -> bool {
1548
        let roundtrip = bf16::from_f64(f.to_f64());
1549
        if f.is_nan() {
1550
            roundtrip.is_nan() && f.is_sign_negative() == roundtrip.is_sign_negative()
1551
        } else {
1552
            f.0 == roundtrip.0
1553
        }
1554
    }
1555
}