Coverage Report

Created: 2024-10-01 06:54

/src/Simd/src/Simd/SimdFloat16.h
Line
Count
Source (jump to first uncovered line)
1
/*
2
* Simd Library (http://ermig1979.github.io/Simd).
3
*
4
* Copyright (c) 2011-2021 Yermalayeu Ihar.
5
*
6
* Permission is hereby granted, free of charge, to any person obtaining a copy
7
* of this software and associated documentation files (the "Software"), to deal
8
* in the Software without restriction, including without limitation the rights
9
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
* copies of the Software, and to permit persons to whom the Software is
11
* furnished to do so, subject to the following conditions:
12
*
13
* The above copyright notice and this permission notice shall be included in
14
* all copies or substantial portions of the Software.
15
*
16
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
* SOFTWARE.
23
*/
24
#ifndef __SimdFloat16_h__
25
#define __SimdFloat16_h__
26
27
#include "Simd/SimdInit.h"
28
29
namespace Simd
30
{
31
    namespace Base
32
    {
33
        // Helper union and bit-pattern constants used by the scalar
        // Float32ToFloat16 / Float16ToFloat32 conversions below.
        namespace Fp16
34
        {
35
            // Gives access to the same 32 bits as a float, a signed integer and an unsigned integer.
            union Bits
36
            {
37
                float f;
38
                int32_t si;
39
                uint32_t ui;
40
            };
41
42
            // Difference between the flt32 (23) and flt16 (10) mantissa widths.
            const int SHIFT = 13;
43
            // Distance between the flt32 sign bit (31) and the flt16 sign bit (15).
            const int SHIFT_SIGN = 16;
44
45
            const int32_t INF_N = 0x7F800000; // flt32 infinity
46
            const int32_t MAX_N = 0x477FE000; // max flt16 normal as a flt32
47
            const int32_t MIN_N = 0x38800000; // min flt16 normal as a flt32
48
            // NOTE(review): 0x80000000 does not fit in int32_t; this relies on the conversion
            // wrapping to the most negative value (implementation-defined before C++20).
            const int32_t SIGN_N = 0x80000000; // flt32 sign bit
49
50
            // The flt32 patterns above shifted right by SHIFT.
            const int32_t INF_C = INF_N >> SHIFT;
51
            const int32_t NAN_N = (INF_C + 1) << SHIFT; // minimum flt16 nan as a flt32
52
            const int32_t MAX_C = MAX_N >> SHIFT;
53
            const int32_t MIN_C = MIN_N >> SHIFT;
54
            // NOTE(review): SIGN_N is negative, so with an arithmetic right shift this is 0xFFFF8000
            // rather than 0x00008000; the conversions only apply it to values whose upper bits are
            // expected to be zero - confirm for the SIMD callers.
            const int32_t SIGN_C = SIGN_N >> SHIFT_SIGN; // flt16 sign bit
55
56
            // flt32 bit patterns of the scale factors used on the subnormal paths.
            const int32_t MUL_N = 0x52000000; // (1 << 23) / MIN_N
57
            const int32_t MUL_C = 0x33800000; // MIN_N / (1 << (23 - shift))
58
59
            const int32_t SUB_C = 0x003FF; // max flt32 subnormal down shifted
60
            const int32_t NOR_C = 0x00400; // min flt32 normal down shifted
61
62
            // Offsets subtracted (flt32 -> flt16) or added (flt16 -> flt32) to move the shifted
            // value between the two exponent ranges: MAX_D for infinity/NaN, MIN_D for normals.
            const int32_t MAX_D = INF_C - MAX_C - 1;
63
            const int32_t MIN_D = MIN_C - SUB_C - 1;
64
        }
65
66
        // Converts a 32-bit float to the bit pattern of a 16-bit (half precision) float.
        //
        // value  - the flt32 number to convert.
        // return - the flt16 bits: sign in bit 15, exponent and mantissa below it.
        //
        // Mantissa bits that do not fit are dropped (no rounding to nearest). Finite
        // magnitudes above the largest flt16 normal become infinity, NaN stays NaN and
        // magnitudes below the smallest flt16 normal are encoded as flt16 subnormals.
        SIMD_INLINE uint16_t Float32ToFloat16(float value)
        {
            Fp16::Bits v;
            v.f = value;
            uint32_t sign = v.si & Fp16::SIGN_N;
            v.si ^= sign; // clear the sign bit, v now holds |value|
            sign >>= Fp16::SHIFT_SIGN; // logical shift
            if (v.si < Fp16::MIN_N)
            {
                // Below the smallest flt16 normal: scale by MUL_N so that the integer part of
                // the product, after the SHIFT below, is the flt16 subnormal mantissa.
                // The scaling is done only here on purpose: for larger magnitudes (and for NaN)
                // the product is not representable in int32_t, and converting such a float
                // to an integer is undefined behavior.
                Fp16::Bits s;
                s.si = Fp16::MUL_N;
                v.si = int32_t(s.f * v.f); // correct subnormals
            }
            // Finite values above the flt16 range are replaced by infinity.
            v.si ^= (Fp16::INF_N ^ v.si) & -((Fp16::INF_N > v.si) & (v.si > Fp16::MAX_N));
            // NaNs whose payload would vanish in the shift are replaced by the minimum flt16 NaN.
            v.si ^= (Fp16::NAN_N ^ v.si) & -((Fp16::NAN_N > v.si) & (v.si > Fp16::INF_N));
            v.ui >>= Fp16::SHIFT; // logical shift
            // Move infinity/NaN and then normal values into the flt16 exponent range.
            v.si ^= ((v.si - Fp16::MAX_D) ^ v.si) & -(v.si > Fp16::MAX_C);
            v.si ^= ((v.si - Fp16::MIN_D) ^ v.si) & -(v.si > Fp16::SUB_C);
            return v.ui | sign;
        }
83
84
        // Converts the bit pattern of a 16-bit (half precision) float to a 32-bit float.
        //
        // value  - the flt16 bits: sign in bit 15, exponent and mantissa below it.
        // return - the corresponding flt32 number.
        //
        // The conversion is branchless: every "x ^= (y ^ x) & mask" statement replaces x by y
        // when mask is all ones and leaves x unchanged when mask is zero.
        SIMD_INLINE float Float16ToFloat32(uint16_t value)
85
0
        {
86
0
            Fp16::Bits v;
87
0
            v.ui = value;
88
0
            // Split off the sign bit and move it to the flt32 sign position.
            int32_t sign = v.si & Fp16::SIGN_C;
89
0
            v.si ^= sign;
90
0
            sign <<= Fp16::SHIFT_SIGN;
91
0
            // Values above the largest subnormal: add MIN_D to move into the flt32 exponent range.
            v.si ^= ((v.si + Fp16::MIN_D) ^ v.si) & -(v.si > Fp16::SUB_C);
92
0
            // Infinity/NaN: additionally add MAX_D so that the exponent becomes all ones.
            v.si ^= ((v.si + Fp16::MAX_D) ^ v.si) & -(v.si > Fp16::MAX_C);
93
0
            Fp16::Bits s;
94
0
            s.si = Fp16::MUL_C;
95
0
            // Candidate result for subnormal input: the integer value scaled by the MUL_C factor.
            s.f *= v.si;
96
0
            // All ones when the input is zero or subnormal (below the smallest normal pattern).
            int32_t mask = -(Fp16::NOR_C > v.si);
97
0
            v.si <<= Fp16::SHIFT;
98
0
            // Take the scaled value on the subnormal path, the shifted bits otherwise.
            v.si ^= (s.si ^ v.si) & mask;
99
0
            v.si |= sign;
100
0
            return v.f;
101
0
        }
102
    }
103
104
#ifdef SIMD_SSE41_ENABLE    
105
    namespace Sse41
106
    {
107
        // SSE4.1 counterparts of the Base::Fp16 helpers: a float/integer view of one
        // 128-bit register and vector versions of the Base::Fp16 constants.
        // NOTE(review): SIMD_MM_SET1_EPI32 is defined elsewhere; presumably it replicates
        // the given value into every 32-bit lane - confirm against its definition.
        namespace Fp16
108
        {
109
            // Gives access to the same 128 bits as four floats or as an integer vector.
            union Bits
110
            {
111
                __m128 f;
112
                __m128i i;
113
            };
114
115
            const __m128i INF_N = SIMD_MM_SET1_EPI32(Base::Fp16::INF_N);
116
            const __m128i MAX_N = SIMD_MM_SET1_EPI32(Base::Fp16::MAX_N);
117
            const __m128i MIN_N = SIMD_MM_SET1_EPI32(Base::Fp16::MIN_N);
118
            const __m128i SIGN_N = SIMD_MM_SET1_EPI32(Base::Fp16::SIGN_N);
119
120
            const __m128i INF_C = SIMD_MM_SET1_EPI32(Base::Fp16::INF_C);
121
            const __m128i NAN_N = SIMD_MM_SET1_EPI32(Base::Fp16::NAN_N);
122
            const __m128i MAX_C = SIMD_MM_SET1_EPI32(Base::Fp16::MAX_C);
123
            const __m128i MIN_C = SIMD_MM_SET1_EPI32(Base::Fp16::MIN_C);
124
            const __m128i SIGN_C = SIMD_MM_SET1_EPI32(Base::Fp16::SIGN_C);
125
126
            const __m128i MUL_N = SIMD_MM_SET1_EPI32(Base::Fp16::MUL_N);
127
            const __m128i MUL_C = SIMD_MM_SET1_EPI32(Base::Fp16::MUL_C);
128
129
            const __m128i SUB_C = SIMD_MM_SET1_EPI32(Base::Fp16::SUB_C);
130
            const __m128i NOR_C = SIMD_MM_SET1_EPI32(Base::Fp16::NOR_C);
131
132
            const __m128i MAX_D = SIMD_MM_SET1_EPI32(Base::Fp16::MAX_D);
133
            const __m128i MIN_D = SIMD_MM_SET1_EPI32(Base::Fp16::MIN_D);
134
        }
135
136
        // Converts four 32-bit floats to 16-bit (half precision) bit patterns.
        //
        // value  - four flt32 numbers.
        // return - four 32-bit lanes, each holding one flt16 pattern (sign moved to bit 15,
        //          exponent and mantissa below it).
        //
        // Follows the same sequence of steps as Base::Float32ToFloat16, applied to every lane.
        // Each "_mm_xor_si128(x, _mm_and_si128(_mm_xor_si128(y, x), mask))" replaces x by y
        // in the lanes where mask is all ones and keeps x in the other lanes.
        SIMD_INLINE __m128i Float32ToFloat16(__m128 value)
137
0
        {
138
0
            Fp16::Bits v, s;
139
0
            v.f = value;
140
0
            // Split off the sign bits and move them to the flt16 sign position.
            __m128i sign = _mm_and_si128(v.i, Fp16::SIGN_N);
141
0
            v.i = _mm_xor_si128(v.i, sign);
142
0
            sign = _mm_srli_epi32(sign, Base::Fp16::SHIFT_SIGN);
143
0
            s.i = Fp16::MUL_N;
144
0
            // Candidate result for lanes below the smallest flt16 normal: scaled magnitude as an integer.
            s.i = _mm_cvtps_epi32(_mm_floor_ps(_mm_mul_ps(s.f, v.f))); 
145
0
            // Use the scaled value only in lanes that are below MIN_N.
            v.i = _mm_xor_si128(v.i, _mm_and_si128(_mm_xor_si128(s.i, v.i), _mm_cmpgt_epi32(Fp16::MIN_N, v.i)));
146
0
            // Lanes strictly between MAX_N and INF_N are replaced by infinity.
            v.i = _mm_xor_si128(v.i, _mm_and_si128(_mm_xor_si128(Fp16::INF_N, v.i), _mm_and_si128(_mm_cmpgt_epi32(Fp16::INF_N, v.i), _mm_cmpgt_epi32(v.i, Fp16::MAX_N))));
147
0
            // Lanes strictly between INF_N and NAN_N are replaced by NAN_N.
            v.i = _mm_xor_si128(v.i, _mm_and_si128(_mm_xor_si128(Fp16::NAN_N, v.i), _mm_and_si128(_mm_cmpgt_epi32(Fp16::NAN_N, v.i), _mm_cmpgt_epi32(v.i, Fp16::INF_N))));
148
0
            // Logical shift that drops the low SHIFT mantissa bits.
            v.i = _mm_srli_epi32(v.i, Base::Fp16::SHIFT); 
149
0
            // Subtract MAX_D in lanes above MAX_C, then MIN_D in lanes above SUB_C.
            v.i = _mm_xor_si128(v.i, _mm_and_si128(_mm_xor_si128(_mm_sub_epi32(v.i, Fp16::MAX_D), v.i), _mm_cmpgt_epi32(v.i, Fp16::MAX_C)));
150
0
            v.i = _mm_xor_si128(v.i, _mm_and_si128(_mm_xor_si128(_mm_sub_epi32(v.i, Fp16::MIN_D), v.i), _mm_cmpgt_epi32(v.i, Fp16::SUB_C)));
151
0
            return _mm_or_si128(v.i, sign);
152
0
        }
153
154
        // Converts four 16-bit (half precision) bit patterns to 32-bit floats.
        //
        // value  - four 32-bit lanes, each holding one flt16 pattern.
        //          NOTE(review): presumably the upper 16 bits of every lane must be zero - confirm with callers.
        // return - four flt32 numbers.
        //
        // Follows the same sequence of steps as Base::Float16ToFloat32, applied to every lane.
        // Each "_mm_xor_si128(x, _mm_and_si128(_mm_xor_si128(y, x), mask))" replaces x by y
        // in the lanes where mask is all ones and keeps x in the other lanes.
        SIMD_INLINE __m128 Float16ToFloat32(__m128i value)
155
0
        {
156
0
            Fp16::Bits v;
157
0
            v.i = value;
158
0
            // Split off the sign bits and move them to the flt32 sign position.
            __m128i sign = _mm_and_si128(v.i, Fp16::SIGN_C);
159
0
            v.i = _mm_xor_si128(v.i, sign);
160
0
            sign = _mm_slli_epi32(sign, Base::Fp16::SHIFT_SIGN);
161
0
            // Add MIN_D in lanes above SUB_C, then MAX_D in lanes above MAX_C.
            v.i = _mm_xor_si128(v.i, _mm_and_si128(_mm_xor_si128(_mm_add_epi32(v.i, Fp16::MIN_D), v.i), _mm_cmpgt_epi32(v.i, Fp16::SUB_C)));
162
0
            v.i = _mm_xor_si128(v.i, _mm_and_si128(_mm_xor_si128(_mm_add_epi32(v.i, Fp16::MAX_D), v.i), _mm_cmpgt_epi32(v.i, Fp16::MAX_C)));
163
0
            Fp16::Bits s;
164
0
            s.i = Fp16::MUL_C;
165
0
            // Candidate result for subnormal lanes: the integer value scaled by the MUL_C factor.
            s.f = _mm_mul_ps(s.f, _mm_cvtepi32_ps(v.i));
166
0
            // All ones in lanes that are below NOR_C (zero or subnormal input).
            __m128i mask = _mm_cmpgt_epi32(Fp16::NOR_C, v.i);
167
0
            v.i = _mm_slli_epi32(v.i, Base::Fp16::SHIFT);
168
0
            // Take the scaled value in masked lanes, the shifted bits in the others.
            v.i = _mm_xor_si128(v.i, _mm_and_si128(_mm_xor_si128(s.i, v.i), mask));
169
0
            v.i = _mm_or_si128(v.i, sign);
170
0
            return v.f;
171
0
        }
172
    }
173
#endif   
174
}
175
176
#endif//__SimdFloat16_h__