/rust/registry/src/index.crates.io-6f17d22bba15001f/half-2.4.1/src/binary16/arch.rs
#![allow(dead_code, unused_imports)]
use crate::leading_zeros::leading_zeros_u16;
use core::mem;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
mod x86;

#[cfg(target_arch = "aarch64")]
mod aarch64;

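// Summary of the dispatch performed by `convert_fn!` (and by `math_fn!` further
// below): if the relevant hardware feature is enabled at compile time
// (`target_feature = "f16c"` on x86/x86_64, `"fp16"` on aarch64), the intrinsic
// expression is used unconditionally; otherwise, when `std` is available, the
// feature is detected at runtime and either the intrinsic or the software
// fallback is chosen per call; in all other cases (e.g. `no_std` without the
// compile-time feature) the software fallback is used.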
macro_rules! convert_fn {
    (if x86_feature("f16c") { $f16c:expr }
    else if aarch64_feature("fp16") { $aarch64:expr }
    else { $fallback:expr }) => {
        cfg_if::cfg_if! {
            // Use intrinsics directly when the feature is enabled for the compile target
            // or when building for no_std
            if #[cfg(all(
                any(target_arch = "x86", target_arch = "x86_64"),
                target_feature = "f16c"
            ))] {
                $f16c
            }
            else if #[cfg(all(
                target_arch = "aarch64",
                target_feature = "fp16"
            ))] {
                $aarch64
            }

            // Use CPU feature detection if using std
            else if #[cfg(all(
                feature = "std",
                any(target_arch = "x86", target_arch = "x86_64")
            ))] {
                use std::arch::is_x86_feature_detected;
                if is_x86_feature_detected!("f16c") {
                    $f16c
                } else {
                    $fallback
                }
            }
            else if #[cfg(all(
                feature = "std",
                target_arch = "aarch64",
            ))] {
                use std::arch::is_aarch64_feature_detected;
                if is_aarch64_feature_detected!("fp16") {
                    $aarch64
                } else {
                    $fallback
                }
            }

            // Fallback to software
            else {
                $fallback
            }
        }
    };
}

#[inline]
pub(crate) fn f32_to_f16(f: f32) -> u16 {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f32_to_f16_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f32_to_f16_fp16(f) }
        } else {
            f32_to_f16_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f64_to_f16(f: f64) -> u16 {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f32_to_f16_x86_f16c(f as f32) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f64_to_f16_fp16(f) }
        } else {
            f64_to_f16_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f16_to_f32(i: u16) -> f32 {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16_to_f32_fp16(i) }
        } else {
            f16_to_f32_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f16_to_f64(i: u16) -> f64 {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) as f64 }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16_to_f64_fp16(i) }
        } else {
            f16_to_f64_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f32x4_to_f16x4(f: &[f32; 4]) -> [u16; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f32x4_to_f16x4_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f32x4_to_f16x4_fp16(f) }
        } else {
            f32x4_to_f16x4_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f16x4_to_f32x4(i: &[u16; 4]) -> [f32; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x4_to_f32x4_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16x4_to_f32x4_fp16(i) }
        } else {
            f16x4_to_f32x4_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f64x4_to_f16x4(f: &[f64; 4]) -> [u16; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f64x4_to_f16x4_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f64x4_to_f16x4_fp16(f) }
        } else {
            f64x4_to_f16x4_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f16x4_to_f64x4(i: &[u16; 4]) -> [f64; 4] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x4_to_f64x4_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            unsafe { aarch64::f16x4_to_f64x4_fp16(i) }
        } else {
            f16x4_to_f64x4_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f32x8_to_f16x8(f: &[f32; 8]) -> [u16; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f32x8_to_f16x8_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0u16; 8];
                convert_chunked_slice_4(f.as_slice(), result.as_mut_slice(),
                    aarch64::f32x4_to_f16x4_fp16);
                result
            }
        } else {
            f32x8_to_f16x8_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f16x8_to_f32x8(i: &[u16; 8]) -> [f32; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x8_to_f32x8_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0f32; 8];
                convert_chunked_slice_4(i.as_slice(), result.as_mut_slice(),
                    aarch64::f16x4_to_f32x4_fp16);
                result
            }
        } else {
            f16x8_to_f32x8_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f64x8_to_f16x8(f: &[f64; 8]) -> [u16; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f64x8_to_f16x8_x86_f16c(f) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0u16; 8];
                convert_chunked_slice_4(f.as_slice(), result.as_mut_slice(),
                    aarch64::f64x4_to_f16x4_fp16);
                result
            }
        } else {
            f64x8_to_f16x8_fallback(f)
        }
    }
}

#[inline]
pub(crate) fn f16x8_to_f64x8(i: &[u16; 8]) -> [f64; 8] {
    convert_fn! {
        if x86_feature("f16c") {
            unsafe { x86::f16x8_to_f64x8_x86_f16c(i) }
        } else if aarch64_feature("fp16") {
            {
                let mut result = [0f64; 8];
                convert_chunked_slice_4(i.as_slice(), result.as_mut_slice(),
                    aarch64::f16x4_to_f64x4_fp16);
                result
            }
        } else {
            f16x8_to_f64x8_fallback(i)
        }
    }
}

#[inline]
pub(crate) fn f32_to_f16_slice(src: &[f32], dst: &mut [u16]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f32x8_to_f16x8_x86_f16c,
                x86::f32x4_to_f16x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f32x4_to_f16x4_fp16)
        } else {
            slice_fallback(src, dst, f32_to_f16_fallback)
        }
    }
}

#[inline]
pub(crate) fn f16_to_f32_slice(src: &[u16], dst: &mut [f32]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f16x8_to_f32x8_x86_f16c,
                x86::f16x4_to_f32x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f16x4_to_f32x4_fp16)
        } else {
            slice_fallback(src, dst, f16_to_f32_fallback)
        }
    }
}

#[inline]
pub(crate) fn f64_to_f16_slice(src: &[f64], dst: &mut [u16]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f64x8_to_f16x8_x86_f16c,
                x86::f64x4_to_f16x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f64x4_to_f16x4_fp16)
        } else {
            slice_fallback(src, dst, f64_to_f16_fallback)
        }
    }
}

#[inline]
pub(crate) fn f16_to_f64_slice(src: &[u16], dst: &mut [f64]) {
    convert_fn! {
        if x86_feature("f16c") {
            convert_chunked_slice_8(src, dst, x86::f16x8_to_f64x8_x86_f16c,
                x86::f16x4_to_f64x4_x86_f16c)
        } else if aarch64_feature("fp16") {
            convert_chunked_slice_4(src, dst, aarch64::f16x4_to_f64x4_fp16)
        } else {
            slice_fallback(src, dst, f16_to_f64_fallback)
        }
    }
}
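
// Illustrative usage of the slice converters above (a sketch, not code from the
// original file): `src` and `dst` must have equal length, which is asserted by
// the chunk helpers and by `slice_fallback` below.
//
//     let src = [1.0f32, 2.0, 3.0];
//     let mut dst = [0u16; 3];
//     f32_to_f16_slice(&src, &mut dst);
//     assert_eq!(dst, [0x3C00, 0x4000, 0x4200]);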

macro_rules! math_fn {
    (if aarch64_feature("fp16") { $aarch64:expr }
    else { $fallback:expr }) => {
        cfg_if::cfg_if! {
            // Use intrinsics directly when the feature is enabled for the compile target
            // or when building for no_std
            if #[cfg(all(
                target_arch = "aarch64",
                target_feature = "fp16"
            ))] {
                $aarch64
            }

            // Use CPU feature detection if using std
            else if #[cfg(all(
                feature = "std",
                target_arch = "aarch64",
                not(target_feature = "fp16")
            ))] {
                use std::arch::is_aarch64_feature_detected;
                if is_aarch64_feature_detected!("fp16") {
                    $aarch64
                } else {
                    $fallback
                }
            }

            // Fallback to software
            else {
                $fallback
            }
        }
    };
}
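
// Note: `math_fn!` mirrors `convert_fn!` but only dispatches to AArch64. The x86
// F16C extension provides conversion instructions only, not half-precision
// arithmetic, so x86 arithmetic always goes through the f32-based fallbacks
// below. There is likewise no hardware remainder instruction, which is why
// `remainder_f16` further down always uses the software fallback.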

#[inline]
pub(crate) fn add_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::add_f16_fp16(a, b) }
        } else {
            add_f16_fallback(a, b)
        }
    }
}

#[inline]
pub(crate) fn subtract_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::subtract_f16_fp16(a, b) }
        } else {
            subtract_f16_fallback(a, b)
        }
    }
}

#[inline]
pub(crate) fn multiply_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::multiply_f16_fp16(a, b) }
        } else {
            multiply_f16_fallback(a, b)
        }
    }
}

#[inline]
pub(crate) fn divide_f16(a: u16, b: u16) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            unsafe { aarch64::divide_f16_fp16(a, b) }
        } else {
            divide_f16_fallback(a, b)
        }
    }
}

#[inline]
pub(crate) fn remainder_f16(a: u16, b: u16) -> u16 {
    remainder_f16_fallback(a, b)
}

#[inline]
pub(crate) fn product_f16<I: Iterator<Item = u16>>(iter: I) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            // Fold from 1.0 (0x3C00 in binary16), the multiplicative identity;
            // folding from 0 would force every product to zero (or NaN).
            iter.fold(0x3C00, |acc, x| unsafe { aarch64::multiply_f16_fp16(acc, x) })
        } else {
            product_f16_fallback(iter)
        }
    }
}

#[inline]
pub(crate) fn sum_f16<I: Iterator<Item = u16>>(iter: I) -> u16 {
    math_fn! {
        if aarch64_feature("fp16") {
            // 0 is the bit pattern of +0.0, the additive identity.
            iter.fold(0, |acc, x| unsafe { aarch64::add_f16_fp16(acc, x) })
        } else {
            sum_f16_fallback(iter)
        }
    }
}

/// Converts a slice in chunks of 8, using the x4 conversion for a short remainder
#[inline]
fn convert_chunked_slice_8<S: Copy + Default, D: Copy>(
    src: &[S],
    dst: &mut [D],
    fn8: unsafe fn(&[S; 8]) -> [D; 8],
    fn4: unsafe fn(&[S; 4]) -> [D; 4],
) {
    assert_eq!(src.len(), dst.len());

    // TODO: Can be further optimized with array_chunks when it becomes stabilized

    let src_chunks = src.chunks_exact(8);
    let mut dst_chunks = dst.chunks_exact_mut(8);
    let src_remainder = src_chunks.remainder();
    for (s, d) in src_chunks.zip(&mut dst_chunks) {
        let chunk: &[S; 8] = s.try_into().unwrap();
        d.copy_from_slice(unsafe { &fn8(chunk) });
    }

    // Process remainder
    if src_remainder.len() > 4 {
        let mut buf: [S; 8] = Default::default();
        buf[..src_remainder.len()].copy_from_slice(src_remainder);
        let vec = unsafe { fn8(&buf) };
        let dst_remainder = dst_chunks.into_remainder();
        dst_remainder.copy_from_slice(&vec[..dst_remainder.len()]);
    } else if !src_remainder.is_empty() {
        let mut buf: [S; 4] = Default::default();
        buf[..src_remainder.len()].copy_from_slice(src_remainder);
        let vec = unsafe { fn4(&buf) };
        let dst_remainder = dst_chunks.into_remainder();
        dst_remainder.copy_from_slice(&vec[..dst_remainder.len()]);
    }
}

/// Converts a slice in chunks of 4
#[inline]
fn convert_chunked_slice_4<S: Copy + Default, D: Copy>(
    src: &[S],
    dst: &mut [D],
    f: unsafe fn(&[S; 4]) -> [D; 4],
) {
    assert_eq!(src.len(), dst.len());

    // TODO: Can be further optimized with array_chunks when it becomes stabilized

    let src_chunks = src.chunks_exact(4);
    let mut dst_chunks = dst.chunks_exact_mut(4);
    let src_remainder = src_chunks.remainder();
    for (s, d) in src_chunks.zip(&mut dst_chunks) {
        let chunk: &[S; 4] = s.try_into().unwrap();
        d.copy_from_slice(unsafe { &f(chunk) });
    }

    // Process remainder
    if !src_remainder.is_empty() {
        let mut buf: [S; 4] = Default::default();
        buf[..src_remainder.len()].copy_from_slice(src_remainder);
        let vec = unsafe { f(&buf) };
        let dst_remainder = dst_chunks.into_remainder();
        dst_remainder.copy_from_slice(&vec[..dst_remainder.len()]);
    }
}
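
// Worked example of the remainder handling above (illustrative only): with
// `src.len() == 6`, `chunks_exact(4)` yields one full chunk (elements 0..4)
// that is converted by `f` directly, plus a remainder of length 2. The
// remainder is copied into a `Default`-padded `[S; 4]` buffer, converted in one
// call, and only its first two lanes are copied back into `dst`; the padded
// lanes are computed but discarded.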

/////////////// Fallbacks ////////////////

// In the below functions, round to nearest, with ties to even.
// Let us call the most significant bit that will be shifted out the round_bit.
//
// Round up if either
//  a) Removed part > tie.
//     (mantissa & round_bit) != 0 && (mantissa & (round_bit - 1)) != 0
//  b) Removed part == tie, and retained part is odd.
//     (mantissa & round_bit) != 0 && (mantissa & (2 * round_bit)) != 0
// (If removed part == tie and retained part is even, do not round up.)
// These two conditions can be combined into one:
//     (mantissa & round_bit) != 0 && (mantissa & ((round_bit - 1) | (2 * round_bit))) != 0
// which can be simplified into
//     (mantissa & round_bit) != 0 && (mantissa & (3 * round_bit - 1)) != 0
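//
// Sanity check of that simplification (round_bit is always a power of two, so
// the two masks coincide):
//     (round_bit - 1) | (2 * round_bit) == 3 * round_bit - 1
// For example, in the f32 -> f16 case below round_bit = 0x1000, and
//     (0x1000 - 1) | 0x2000 == 0x0FFF | 0x2000 == 0x2FFF == 3 * 0x1000 - 1.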

#[inline]
pub(crate) const fn f32_to_f16_fallback(value: f32) -> u16 {
    // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized
    // Convert to raw bytes
    let x: u32 = unsafe { mem::transmute::<f32, u32>(value) };

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7F80_0000u32;
    let man = x & 0x007F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7F80_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits)
        let nan_bit = if man == 0 { 0 } else { 0x0200u32 };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 13)) as u16;
    }

    // The number is normalized, start assembling half precision version
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for half precision
    let unbiased_exp = ((exp >> 23) as i32) - 127;
    let half_exp = unbiased_exp + 15;

    // Check for exponent overflow, return +infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 14 - half_exp > 24 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0080_0000u32;
        let mut half_man = man >> (14 - half_exp);
        // Check for rounding (see comment above functions)
        let round_bit = 1 << (13 - half_exp);
        if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 13;
    // Check for rounding (see comment above functions)
    let round_bit = 0x0000_1000u32;
    if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
        // Round it
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}

#[inline]
pub(crate) const fn f64_to_f16_fallback(value: f64) -> u16 {
    // Convert to raw bytes, truncating the last 32-bits of mantissa; that precision will always
    // be lost on half-precision.
    // TODO: Replace mem::transmute with to_bits() once to_bits is const-stabilized
    let val: u64 = unsafe { mem::transmute::<f64, u64>(value) };
    let x = (val >> 32) as u32;

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7FF0_0000u32;
    let man = x & 0x000F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7FF0_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits).
        // We also have to check the last 32 bits.
        let nan_bit = if man == 0 && (val as u32 == 0) {
            0
        } else {
            0x0200u32
        };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 10)) as u16;
    }

    // The number is normalized, start assembling half precision version
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for half precision
    let unbiased_exp = ((exp >> 20) as i64) - 1023;
    let half_exp = unbiased_exp + 15;

    // Check for exponent overflow, return +infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 10 - half_exp > 21 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0010_0000u32;
        let mut half_man = man >> (11 - half_exp);
        // Check for rounding (see comment above functions)
        let round_bit = 1 << (10 - half_exp);
        if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 10;
    // Check for rounding (see comment above functions)
    let round_bit = 0x0000_0200u32;
    if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
        // Round it
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}

#[inline]
pub(crate) const fn f16_to_f32_fallback(i: u16) -> f32 {
    // Check for signed zero
    // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized
    if i & 0x7FFFu16 == 0 {
        return unsafe { mem::transmute::<u32, f32>((i as u32) << 16) };
    }

    let half_sign = (i & 0x8000u16) as u32;
    let half_exp = (i & 0x7C00u16) as u32;
    let half_man = (i & 0x03FFu16) as u32;

    // Check for an infinity or NaN when all exponent bits set
    if half_exp == 0x7C00u32 {
        // Check for signed infinity if mantissa is zero
        if half_man == 0 {
            return unsafe { mem::transmute::<u32, f32>((half_sign << 16) | 0x7F80_0000u32) };
        } else {
            // NaN, keep current mantissa but also set most significant mantissa bit
            return unsafe {
                mem::transmute::<u32, f32>((half_sign << 16) | 0x7FC0_0000u32 | (half_man << 13))
            };
        }
    }

    // Calculate single-precision components with adjusted exponent
    let sign = half_sign << 16;
    // Unbias exponent
    let unbiased_exp = ((half_exp as i32) >> 10) - 15;

    // Check for subnormals, which will be normalized by adjusting exponent
    if half_exp == 0 {
        // Calculate how much to adjust the exponent by
        let e = leading_zeros_u16(half_man as u16) - 6;

        // Rebias and adjust exponent
        let exp = (127 - 15 - e) << 23;
        let man = (half_man << (14 + e)) & 0x7F_FF_FFu32;
        return unsafe { mem::transmute::<u32, f32>(sign | exp | man) };
    }

    // Rebias exponent for a normalized normal
    let exp = ((unbiased_exp + 127) as u32) << 23;
    let man = (half_man & 0x03FFu32) << 13;
    unsafe { mem::transmute::<u32, f32>(sign | exp | man) }
}

#[inline]
pub(crate) const fn f16_to_f64_fallback(i: u16) -> f64 {
    // Check for signed zero
    // TODO: Replace mem::transmute with from_bits() once from_bits is const-stabilized
    if i & 0x7FFFu16 == 0 {
        return unsafe { mem::transmute::<u64, f64>((i as u64) << 48) };
    }

    let half_sign = (i & 0x8000u16) as u64;
    let half_exp = (i & 0x7C00u16) as u64;
    let half_man = (i & 0x03FFu16) as u64;

    // Check for an infinity or NaN when all exponent bits set
    if half_exp == 0x7C00u64 {
        // Check for signed infinity if mantissa is zero
        if half_man == 0 {
            return unsafe {
                mem::transmute::<u64, f64>((half_sign << 48) | 0x7FF0_0000_0000_0000u64)
            };
        } else {
            // NaN, keep current mantissa but also set most significant mantissa bit
            return unsafe {
                mem::transmute::<u64, f64>(
                    (half_sign << 48) | 0x7FF8_0000_0000_0000u64 | (half_man << 42),
                )
            };
        }
    }

    // Calculate double-precision components with adjusted exponent
    let sign = half_sign << 48;
    // Unbias exponent
    let unbiased_exp = ((half_exp as i64) >> 10) - 15;

    // Check for subnormals, which will be normalized by adjusting exponent
    if half_exp == 0 {
        // Calculate how much to adjust the exponent by
        let e = leading_zeros_u16(half_man as u16) - 6;

        // Rebias and adjust exponent
        let exp = ((1023 - 15 - e) as u64) << 52;
        let man = (half_man << (43 + e)) & 0xF_FFFF_FFFF_FFFFu64;
        return unsafe { mem::transmute::<u64, f64>(sign | exp | man) };
    }

    // Rebias exponent for a normalized normal
    let exp = ((unbiased_exp + 1023) as u64) << 52;
    let man = (half_man & 0x03FFu64) << 42;
    unsafe { mem::transmute::<u64, f64>(sign | exp | man) }
}

#[inline]
fn f16x4_to_f32x4_fallback(v: &[u16; 4]) -> [f32; 4] {
    [
        f16_to_f32_fallback(v[0]),
        f16_to_f32_fallback(v[1]),
        f16_to_f32_fallback(v[2]),
        f16_to_f32_fallback(v[3]),
    ]
}

#[inline]
fn f32x4_to_f16x4_fallback(v: &[f32; 4]) -> [u16; 4] {
    [
        f32_to_f16_fallback(v[0]),
        f32_to_f16_fallback(v[1]),
        f32_to_f16_fallback(v[2]),
        f32_to_f16_fallback(v[3]),
    ]
}

#[inline]
fn f16x4_to_f64x4_fallback(v: &[u16; 4]) -> [f64; 4] {
    [
        f16_to_f64_fallback(v[0]),
        f16_to_f64_fallback(v[1]),
        f16_to_f64_fallback(v[2]),
        f16_to_f64_fallback(v[3]),
    ]
}

#[inline]
fn f64x4_to_f16x4_fallback(v: &[f64; 4]) -> [u16; 4] {
    [
        f64_to_f16_fallback(v[0]),
        f64_to_f16_fallback(v[1]),
        f64_to_f16_fallback(v[2]),
        f64_to_f16_fallback(v[3]),
    ]
}

#[inline]
fn f16x8_to_f32x8_fallback(v: &[u16; 8]) -> [f32; 8] {
    [
        f16_to_f32_fallback(v[0]),
        f16_to_f32_fallback(v[1]),
        f16_to_f32_fallback(v[2]),
        f16_to_f32_fallback(v[3]),
        f16_to_f32_fallback(v[4]),
        f16_to_f32_fallback(v[5]),
        f16_to_f32_fallback(v[6]),
        f16_to_f32_fallback(v[7]),
    ]
}

#[inline]
fn f32x8_to_f16x8_fallback(v: &[f32; 8]) -> [u16; 8] {
    [
        f32_to_f16_fallback(v[0]),
        f32_to_f16_fallback(v[1]),
        f32_to_f16_fallback(v[2]),
        f32_to_f16_fallback(v[3]),
        f32_to_f16_fallback(v[4]),
        f32_to_f16_fallback(v[5]),
        f32_to_f16_fallback(v[6]),
        f32_to_f16_fallback(v[7]),
    ]
}

#[inline]
fn f16x8_to_f64x8_fallback(v: &[u16; 8]) -> [f64; 8] {
    [
        f16_to_f64_fallback(v[0]),
        f16_to_f64_fallback(v[1]),
        f16_to_f64_fallback(v[2]),
        f16_to_f64_fallback(v[3]),
        f16_to_f64_fallback(v[4]),
        f16_to_f64_fallback(v[5]),
        f16_to_f64_fallback(v[6]),
        f16_to_f64_fallback(v[7]),
    ]
}

#[inline]
fn f64x8_to_f16x8_fallback(v: &[f64; 8]) -> [u16; 8] {
    [
        f64_to_f16_fallback(v[0]),
        f64_to_f16_fallback(v[1]),
        f64_to_f16_fallback(v[2]),
        f64_to_f16_fallback(v[3]),
        f64_to_f16_fallback(v[4]),
        f64_to_f16_fallback(v[5]),
        f64_to_f16_fallback(v[6]),
        f64_to_f16_fallback(v[7]),
    ]
}

#[inline]
fn slice_fallback<S: Copy, D>(src: &[S], dst: &mut [D], f: fn(S) -> D) {
    assert_eq!(src.len(), dst.len());
    for (s, d) in src.iter().copied().zip(dst.iter_mut()) {
        *d = f(s);
    }
}

#[inline]
fn add_f16_fallback(a: u16, b: u16) -> u16 {
    f32_to_f16(f16_to_f32(a) + f16_to_f32(b))
}

#[inline]
fn subtract_f16_fallback(a: u16, b: u16) -> u16 {
    f32_to_f16(f16_to_f32(a) - f16_to_f32(b))
}

#[inline]
fn multiply_f16_fallback(a: u16, b: u16) -> u16 {
    f32_to_f16(f16_to_f32(a) * f16_to_f32(b))
}

#[inline]
fn divide_f16_fallback(a: u16, b: u16) -> u16 {
    f32_to_f16(f16_to_f32(a) / f16_to_f32(b))
}

#[inline]
fn remainder_f16_fallback(a: u16, b: u16) -> u16 {
    f32_to_f16(f16_to_f32(a) % f16_to_f32(b))
}

#[inline]
fn product_f16_fallback<I: Iterator<Item = u16>>(iter: I) -> u16 {
    f32_to_f16(iter.map(f16_to_f32).product())
}

#[inline]
fn sum_f16_fallback<I: Iterator<Item = u16>>(iter: I) -> u16 {
    f32_to_f16(iter.map(f16_to_f32).sum())
}

// TODO SIMD arithmetic
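
// A minimal test sketch for the software fallbacks (added as an illustration;
// the module name is arbitrary and the expected values are the standard
// binary16 bit patterns):
#[cfg(test)]
mod fallback_sketch_tests {
    use super::*;

    #[test]
    fn fallback_known_bit_patterns() {
        // 1.0, 2.0, and the largest finite half-precision value, 65504.0.
        assert_eq!(f32_to_f16_fallback(1.0), 0x3C00);
        assert_eq!(f32_to_f16_fallback(2.0), 0x4000);
        assert_eq!(f32_to_f16_fallback(65504.0), 0x7BFF);
        // Values too large for binary16 saturate to infinity (0x7C00).
        assert_eq!(f32_to_f16_fallback(1e9), 0x7C00);
        // Widening conversions are exact for representable values.
        assert_eq!(f16_to_f32_fallback(0x3C00), 1.0);
        assert_eq!(f16_to_f32_fallback(0xC000), -2.0);
        assert_eq!(f16_to_f64_fallback(0x4000), 2.0);
    }
}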