Coverage Report

Created: 2024-10-01 06:54

/src/Simd/src/Simd/SimdAvx512bwSynetSoftmax.cpp
Every executable line below has a hit count of 0: the file is entirely uncovered. The Line and Count columns of the report have been folded into the plain listing that follows.
/*
* Simd Library (http://ermig1979.github.io/Simd).
*
* Copyright (c) 2011-2024 Yermalayeu Ihar.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#include "Simd/SimdMemory.h"
#include "Simd/SimdStore.h"
#include "Simd/SimdBase.h"
#include "Simd/SimdSse41.h"
#include "Simd/SimdAvx2.h"
#include "Simd/SimdAvx512bw.h"
#include "Simd/SimdSynet.h"
#include "Simd/SimdExtract.h"
#include "Simd/SimdExp.h"
#include "Simd/SimdPow.h"
#include "Simd/SimdInterleave.h"
#include "Simd/SimdDeinterleave.h"
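// Note: in the Avx512bw namespace F is the number of 32-bit floats in a ZMM
// register (16) and DF = 2 * F; both constants come from the headers above.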
namespace Simd
{
#if defined(SIMD_AVX512BW_ENABLE) && defined(SIMD_SYNET_ENABLE)
    namespace Avx512bw
    {
        // Softmax for count == 2, inner == 1: each pair (x0, x1) of src maps to
        // (e0, e1) / (e0 + e1) with e = exp(x - max(x0, x1)). Pairs are
        // deinterleaved with _mm512_shuffle_ps, 16 pairs are processed per
        // iteration, and the results are reinterleaved on store.
        void SynetSoftmaxLayerForward21(const float* src, size_t outer, float* dst)
        {
            Exp exp;
            size_t aligned = Simd::AlignLo(outer, F), tail = outer - aligned;
            for (size_t o = 0; o < aligned; o += F)
            {
                __m512 s0 = _mm512_loadu_ps(src + 0);
                __m512 s1 = _mm512_loadu_ps(src + F);
                __m512 ss0 = _mm512_shuffle_ps(s0, s1, 0x88);
                __m512 ss1 = _mm512_shuffle_ps(s0, s1, 0xDD);
                __m512 max = _mm512_max_ps(ss0, ss1);
                __m512 exp0 = exp.Exponent(_mm512_sub_ps(ss0, max));
                __m512 exp1 = exp.Exponent(_mm512_sub_ps(ss1, max));
                __m512 sum = _mm512_add_ps(exp0, exp1);
                __m512 d0 = _mm512_div_ps(exp0, sum);
                __m512 d1 = _mm512_div_ps(exp1, sum);
                _mm512_storeu_ps(dst + 0, _mm512_unpacklo_ps(d0, d1));
                _mm512_storeu_ps(dst + F, _mm512_unpackhi_ps(d0, d1));
                src += DF;
                dst += DF;
            }
            if (tail)
            {
                __mmask16 mask0 = TailMask16(tail * 2 - 0 * F);
                __mmask16 mask1 = TailMask16(tail * 2 - 1 * F);
                __m512 s0 = _mm512_maskz_loadu_ps(mask0, src + 0 * F);
                __m512 s1 = _mm512_maskz_loadu_ps(mask1, src + 1 * F);
                __m512 ss0 = _mm512_shuffle_ps(s0, s1, 0x88);
                __m512 ss1 = _mm512_shuffle_ps(s0, s1, 0xDD);
                __m512 max = _mm512_max_ps(ss0, ss1);
                __m512 exp0 = exp.Exponent(_mm512_sub_ps(ss0, max));
                __m512 exp1 = exp.Exponent(_mm512_sub_ps(ss1, max));
                __m512 sum = _mm512_add_ps(exp0, exp1);
                __m512 d0 = _mm512_div_ps(exp0, sum);
                __m512 d1 = _mm512_div_ps(exp1, sum);
                _mm512_mask_storeu_ps(dst + 0 * F, mask0, _mm512_unpacklo_ps(d0, d1));
                _mm512_mask_storeu_ps(dst + 1 * F, mask1, _mm512_unpackhi_ps(d0, d1));
            }
        }
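
For reference, the scalar computation this kernel vectorizes looks as follows (an editor's sketch; the function name is hypothetical and the code is not part of the library):

#include <math.h>
#include <stddef.h>

// Softmax over independent pairs: dst[2o], dst[2o+1] = softmax(src[2o], src[2o+1]).
static void SoftmaxForward21Scalar(const float* src, size_t outer, float* dst)
{
    for (size_t o = 0; o < outer; ++o, src += 2, dst += 2)
    {
        float max = src[0] > src[1] ? src[0] : src[1]; // subtract the max for numerical stability
        float e0 = expf(src[0] - max);
        float e1 = expf(src[1] - max);
        float sum = e0 + e1;
        dst[0] = e0 / sum;
        dst[1] = e1 / sum;
    }
}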

        // Shared 3-way softmax core: buf[0..2] hold channels 0..2 for 16 items each.
        SIMD_INLINE void SynetSoftmaxLayerForward31(const Exp& exp, __m512 buf[3])
        {
            __m512 max = _mm512_max_ps(buf[0], _mm512_max_ps(buf[1], buf[2]));
            buf[0] = exp.Exponent(_mm512_sub_ps(buf[0], max));
            buf[1] = exp.Exponent(_mm512_sub_ps(buf[1], max));
            buf[2] = exp.Exponent(_mm512_sub_ps(buf[2], max));
            __m512 sum = _mm512_add_ps(buf[0], _mm512_add_ps(buf[1], buf[2]));
            buf[0] = _mm512_div_ps(buf[0], sum);
            buf[1] = _mm512_div_ps(buf[1], sum);
            buf[2] = _mm512_div_ps(buf[2], sum);
        }

        // Softmax for count == 3, inner == 1: gathers channels 0/1/2 of 16
        // consecutive triples with stride-3 indices, applies the core above,
        // and scatters the results back in the same pattern.
        void SynetSoftmaxLayerForward31(const float* src, size_t outer, float* dst)
        {
            static const __m512i idx = _mm512_setr_epi32(0x00, 0x03, 0x06, 0x09, 0x0C, 0x0F, 0x12, 0x15, 0x18, 0x1B, 0x1E, 0x21, 0x24, 0x27, 0x2A, 0x2D);
            Exp exp;
            __m512 buf[3];
            size_t aligned = Simd::AlignLo(outer, F), tail = outer - aligned;
            for (size_t o = 0; o < aligned; o += F)
            {
                buf[0] = _mm512_i32gather_ps(idx, src + 0, 4);
                buf[1] = _mm512_i32gather_ps(idx, src + 1, 4);
                buf[2] = _mm512_i32gather_ps(idx, src + 2, 4);
                SynetSoftmaxLayerForward31(exp, buf);
                _mm512_i32scatter_ps(dst + 0, idx, buf[0], 4);
                _mm512_i32scatter_ps(dst + 1, idx, buf[1], 4);
                _mm512_i32scatter_ps(dst + 2, idx, buf[2], 4);
                src += 3 * F;
                dst += 3 * F;
            }
            if (tail)
            {
                __mmask16 mask = TailMask16(tail);
                buf[0] = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, idx, src + 0, 4);
                buf[1] = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, idx, src + 1, 4);
                buf[2] = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, idx, src + 2, 4);
                SynetSoftmaxLayerForward31(exp, buf);
                _mm512_mask_i32scatter_ps(dst + 0, mask, idx, buf[0], 4);
                _mm512_mask_i32scatter_ps(dst + 1, mask, idx, buf[1], 4);
                _mm512_mask_i32scatter_ps(dst + 2, mask, idx, buf[2], 4);
            }
        }
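
The stride-3 index vector makes the gathers and scatters walk the packed (x0, x1, x2) triples; in scalar terms (editor's sketch):

// For one block of 16 triples, the three gathers with base src + c and the
// stride-3 index vector {0, 3, 6, ..., 45} read
//   buf[c][i] = src[3 * i + c]   for c = 0..2, i = 0..15,
// and the scatters write dst[3 * i + c] back in the same pattern.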

        // Loads a 16 x cols block (16 outer items by cols channels), padding
        // absent channels with -FLT_MAX, transposes it so that dst + c * F
        // holds channel c of all 16 items, and folds the block into the
        // per-item running maximum.
        SIMD_INLINE void LoadTansp16x16(const float* src, size_t srcStride, size_t cols, float* dst, __m512& max)
        {
            __m512 a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, aA, aB, aC, aD, aE, aF, b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, bA, bB, bC, bD, bE, bF;

            __mmask16 srcMask = __mmask16(-1) >> (16 - cols);
            __m512 def = _mm512_set1_ps(-FLT_MAX);
            a0 = _mm512_mask_loadu_ps(def, srcMask, src + 0x0 * srcStride);
            a1 = _mm512_mask_loadu_ps(def, srcMask, src + 0x1 * srcStride);
            a2 = _mm512_mask_loadu_ps(def, srcMask, src + 0x2 * srcStride);
            a3 = _mm512_mask_loadu_ps(def, srcMask, src + 0x3 * srcStride);
            a4 = _mm512_mask_loadu_ps(def, srcMask, src + 0x4 * srcStride);
            a5 = _mm512_mask_loadu_ps(def, srcMask, src + 0x5 * srcStride);
            a6 = _mm512_mask_loadu_ps(def, srcMask, src + 0x6 * srcStride);
            a7 = _mm512_mask_loadu_ps(def, srcMask, src + 0x7 * srcStride);
            a8 = _mm512_mask_loadu_ps(def, srcMask, src + 0x8 * srcStride);
            a9 = _mm512_mask_loadu_ps(def, srcMask, src + 0x9 * srcStride);
            aA = _mm512_mask_loadu_ps(def, srcMask, src + 0xA * srcStride);
            aB = _mm512_mask_loadu_ps(def, srcMask, src + 0xB * srcStride);
            aC = _mm512_mask_loadu_ps(def, srcMask, src + 0xC * srcStride);
            aD = _mm512_mask_loadu_ps(def, srcMask, src + 0xD * srcStride);
            aE = _mm512_mask_loadu_ps(def, srcMask, src + 0xE * srcStride);
            aF = _mm512_mask_loadu_ps(def, srcMask, src + 0xF * srcStride);

            // Transpose stage 1: interleave row pairs within 128-bit lanes.
            b0 = _mm512_unpacklo_ps(a0, a2);
            b1 = _mm512_unpacklo_ps(a1, a3);
            b2 = _mm512_unpackhi_ps(a0, a2);
            b3 = _mm512_unpackhi_ps(a1, a3);
            b4 = _mm512_unpacklo_ps(a4, a6);
            b5 = _mm512_unpacklo_ps(a5, a7);
            b6 = _mm512_unpackhi_ps(a4, a6);
            b7 = _mm512_unpackhi_ps(a5, a7);
            b8 = _mm512_unpacklo_ps(a8, aA);
            b9 = _mm512_unpacklo_ps(a9, aB);
            bA = _mm512_unpackhi_ps(a8, aA);
            bB = _mm512_unpackhi_ps(a9, aB);
            bC = _mm512_unpacklo_ps(aC, aE);
            bD = _mm512_unpacklo_ps(aD, aF);
            bE = _mm512_unpackhi_ps(aC, aE);
            bF = _mm512_unpackhi_ps(aD, aF);

            // Transpose stage 2.
            a0 = _mm512_unpacklo_ps(b0, b1);
            a1 = _mm512_unpackhi_ps(b0, b1);
            a2 = _mm512_unpacklo_ps(b2, b3);
            a3 = _mm512_unpackhi_ps(b2, b3);
            a4 = _mm512_unpacklo_ps(b4, b5);
            a5 = _mm512_unpackhi_ps(b4, b5);
            a6 = _mm512_unpacklo_ps(b6, b7);
            a7 = _mm512_unpackhi_ps(b6, b7);
            a8 = _mm512_unpacklo_ps(b8, b9);
            a9 = _mm512_unpackhi_ps(b8, b9);
            aA = _mm512_unpacklo_ps(bA, bB);
            aB = _mm512_unpackhi_ps(bA, bB);
            aC = _mm512_unpacklo_ps(bC, bD);
            aD = _mm512_unpackhi_ps(bC, bD);
            aE = _mm512_unpacklo_ps(bE, bF);
            aF = _mm512_unpackhi_ps(bE, bF);

            // Transpose stage 3: exchange 128-bit lanes between register pairs.
            b0 = _mm512_shuffle_f32x4(a0, a4, 0x44);
            b1 = _mm512_shuffle_f32x4(a1, a5, 0x44);
            b2 = _mm512_shuffle_f32x4(a2, a6, 0x44);
            b3 = _mm512_shuffle_f32x4(a3, a7, 0x44);
            b4 = _mm512_shuffle_f32x4(a0, a4, 0xEE);
            b5 = _mm512_shuffle_f32x4(a1, a5, 0xEE);
            b6 = _mm512_shuffle_f32x4(a2, a6, 0xEE);
            b7 = _mm512_shuffle_f32x4(a3, a7, 0xEE);
            b8 = _mm512_shuffle_f32x4(a8, aC, 0x44);
            b9 = _mm512_shuffle_f32x4(a9, aD, 0x44);
            bA = _mm512_shuffle_f32x4(aA, aE, 0x44);
            bB = _mm512_shuffle_f32x4(aB, aF, 0x44);
            bC = _mm512_shuffle_f32x4(a8, aC, 0xEE);
            bD = _mm512_shuffle_f32x4(a9, aD, 0xEE);
            bE = _mm512_shuffle_f32x4(aA, aE, 0xEE);
            bF = _mm512_shuffle_f32x4(aB, aF, 0xEE);

            // Transpose stage 4: a0..aF now hold one column (channel) each.
            a0 = _mm512_shuffle_f32x4(b0, b8, 0x88);
            a1 = _mm512_shuffle_f32x4(b1, b9, 0x88);
            a2 = _mm512_shuffle_f32x4(b2, bA, 0x88);
            a3 = _mm512_shuffle_f32x4(b3, bB, 0x88);
            a4 = _mm512_shuffle_f32x4(b0, b8, 0xDD);
            a5 = _mm512_shuffle_f32x4(b1, b9, 0xDD);
            a6 = _mm512_shuffle_f32x4(b2, bA, 0xDD);
            a7 = _mm512_shuffle_f32x4(b3, bB, 0xDD);
            a8 = _mm512_shuffle_f32x4(b4, bC, 0x88);
            a9 = _mm512_shuffle_f32x4(b5, bD, 0x88);
            aA = _mm512_shuffle_f32x4(b6, bE, 0x88);
            aB = _mm512_shuffle_f32x4(b7, bF, 0x88);
            aC = _mm512_shuffle_f32x4(b4, bC, 0xDD);
            aD = _mm512_shuffle_f32x4(b5, bD, 0xDD);
            aE = _mm512_shuffle_f32x4(b6, bE, 0xDD);
            aF = _mm512_shuffle_f32x4(b7, bF, 0xDD);

            // Fold the block into the per-item running maximum.
            max = _mm512_max_ps(max, a0);
            max = _mm512_max_ps(max, a1);
            max = _mm512_max_ps(max, a2);
            max = _mm512_max_ps(max, a3);
            max = _mm512_max_ps(max, a4);
            max = _mm512_max_ps(max, a5);
            max = _mm512_max_ps(max, a6);
            max = _mm512_max_ps(max, a7);
            max = _mm512_max_ps(max, a8);
            max = _mm512_max_ps(max, a9);
            max = _mm512_max_ps(max, aA);
            max = _mm512_max_ps(max, aB);
            max = _mm512_max_ps(max, aC);
            max = _mm512_max_ps(max, aD);
            max = _mm512_max_ps(max, aE);
            max = _mm512_max_ps(max, aF);

            _mm512_storeu_ps(dst + 0x0 * F, a0);
            _mm512_storeu_ps(dst + 0x1 * F, a1);
            _mm512_storeu_ps(dst + 0x2 * F, a2);
            _mm512_storeu_ps(dst + 0x3 * F, a3);
            _mm512_storeu_ps(dst + 0x4 * F, a4);
            _mm512_storeu_ps(dst + 0x5 * F, a5);
            _mm512_storeu_ps(dst + 0x6 * F, a6);
            _mm512_storeu_ps(dst + 0x7 * F, a7);
            _mm512_storeu_ps(dst + 0x8 * F, a8);
            _mm512_storeu_ps(dst + 0x9 * F, a9);
            _mm512_storeu_ps(dst + 0xA * F, aA);
            _mm512_storeu_ps(dst + 0xB * F, aB);
            _mm512_storeu_ps(dst + 0xC * F, aC);
            _mm512_storeu_ps(dst + 0xD * F, aD);
            _mm512_storeu_ps(dst + 0xE * F, aE);
            _mm512_storeu_ps(dst + 0xF * F, aF);
        }
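
A scalar model of the full-block load-transpose above (an editor's sketch with a hypothetical name, assuming cols <= 16):

#include <float.h>
#include <stddef.h>

// What LoadTansp16x16 computes, one element at a time: a masked, padded
// 16 x 16 transpose plus a per-item running maximum.
static void LoadTransp16x16Scalar(const float* src, size_t srcStride, size_t cols, float* dst, float max[16])
{
    for (size_t r = 0; r < 16; ++r)
        for (size_t c = 0; c < 16; ++c)
        {
            float v = c < cols ? src[r * srcStride + c] : -FLT_MAX; // pad absent channels
            dst[c * 16 + r] = v;              // transpose: channel c becomes a contiguous row
            if (v > max[r])
                max[r] = v;                   // running maximum over channels, per item
        }
}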

        // Partial-block variant: the same transpose for rows < 16 outer items;
        // only the first cols transposed vectors are stored and folded into
        // the running maximum.
        SIMD_INLINE void LoadTansp16x16(const float* src, size_t srcStride, size_t cols, size_t rows, float* dst, __m512& max)
        {
            __m512 a[16], b[16];

            __mmask16 srcMask = __mmask16(-1) >> (16 - cols);
            __m512 def = _mm512_set1_ps(-FLT_MAX);
            for (size_t r = 0; r < rows; ++r)
                a[r] = _mm512_mask_loadu_ps(def, srcMask, src + r * srcStride);

            for (size_t r = 0; r < rows; r += 4)
            {
                b[r + 0] = _mm512_unpacklo_ps(a[r + 0], a[r + 2]);
                b[r + 1] = _mm512_unpacklo_ps(a[r + 1], a[r + 3]);
                b[r + 2] = _mm512_unpackhi_ps(a[r + 0], a[r + 2]);
                b[r + 3] = _mm512_unpackhi_ps(a[r + 1], a[r + 3]);
            }

            for (size_t r = 0; r < rows; r += 4)
            {
                a[r + 0] = _mm512_unpacklo_ps(b[r + 0], b[r + 1]);
                a[r + 1] = _mm512_unpackhi_ps(b[r + 0], b[r + 1]);
                a[r + 2] = _mm512_unpacklo_ps(b[r + 2], b[r + 3]);
                a[r + 3] = _mm512_unpackhi_ps(b[r + 2], b[r + 3]);
            }

            for (size_t i = 0; i < 4; i += 1)
            {
                b[0x0 + i] = _mm512_shuffle_f32x4(a[0x0 + i], a[0x4 + i], 0x44);
                b[0x4 + i] = _mm512_shuffle_f32x4(a[0x0 + i], a[0x4 + i], 0xEE);
                b[0x8 + i] = _mm512_shuffle_f32x4(a[0x8 + i], a[0xC + i], 0x44);
                b[0xC + i] = _mm512_shuffle_f32x4(a[0x8 + i], a[0xC + i], 0xEE);
            }

            for (size_t i = 0; i < 4; i += 1)
            {
                a[0x0 + i] = _mm512_shuffle_f32x4(b[0x0 + i], b[0x8 + i], 0x88);
                a[0x4 + i] = _mm512_shuffle_f32x4(b[0x0 + i], b[0x8 + i], 0xDD);
                a[0x8 + i] = _mm512_shuffle_f32x4(b[0x4 + i], b[0xC + i], 0x88);
                a[0xC + i] = _mm512_shuffle_f32x4(b[0x4 + i], b[0xC + i], 0xDD);
            }

            for (size_t c = 0; c < cols; ++c)
            {
                max = _mm512_max_ps(max, a[c]);
                _mm512_storeu_ps(dst + c * F, a[c]);
            }
        }

        // Inverse of LoadTansp16x16: scales the transposed buffer by the
        // per-item normalization factors k, transposes back with the same
        // four-stage network, and stores cols channels per row under a mask.
        SIMD_INLINE void StoreTansp16x16(const float* src, size_t cols, __m512 k, float* dst, size_t dstStride)
        {
            __m512 a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, aA, aB, aC, aD, aE, aF, b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, bA, bB, bC, bD, bE, bF;

            // Scale each channel vector by the per-item factor k while loading.
            a0 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x0 * F));
            a1 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x1 * F));
            a2 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x2 * F));
            a3 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x3 * F));
            a4 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x4 * F));
            a5 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x5 * F));
            a6 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x6 * F));
            a7 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x7 * F));
            a8 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x8 * F));
            a9 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x9 * F));
            aA = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0xA * F));
            aB = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0xB * F));
            aC = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0xC * F));
            aD = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0xD * F));
            aE = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0xE * F));
            aF = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0xF * F));

            // Transpose back: the same four-stage network as on the load side.
            b0 = _mm512_unpacklo_ps(a0, a2);
            b1 = _mm512_unpacklo_ps(a1, a3);
            b2 = _mm512_unpackhi_ps(a0, a2);
            b3 = _mm512_unpackhi_ps(a1, a3);
            b4 = _mm512_unpacklo_ps(a4, a6);
            b5 = _mm512_unpacklo_ps(a5, a7);
            b6 = _mm512_unpackhi_ps(a4, a6);
            b7 = _mm512_unpackhi_ps(a5, a7);
            b8 = _mm512_unpacklo_ps(a8, aA);
            b9 = _mm512_unpacklo_ps(a9, aB);
            bA = _mm512_unpackhi_ps(a8, aA);
            bB = _mm512_unpackhi_ps(a9, aB);
            bC = _mm512_unpacklo_ps(aC, aE);
            bD = _mm512_unpacklo_ps(aD, aF);
            bE = _mm512_unpackhi_ps(aC, aE);
            bF = _mm512_unpackhi_ps(aD, aF);

            a0 = _mm512_unpacklo_ps(b0, b1);
            a1 = _mm512_unpackhi_ps(b0, b1);
            a2 = _mm512_unpacklo_ps(b2, b3);
            a3 = _mm512_unpackhi_ps(b2, b3);
            a4 = _mm512_unpacklo_ps(b4, b5);
            a5 = _mm512_unpackhi_ps(b4, b5);
            a6 = _mm512_unpacklo_ps(b6, b7);
            a7 = _mm512_unpackhi_ps(b6, b7);
            a8 = _mm512_unpacklo_ps(b8, b9);
            a9 = _mm512_unpackhi_ps(b8, b9);
            aA = _mm512_unpacklo_ps(bA, bB);
            aB = _mm512_unpackhi_ps(bA, bB);
            aC = _mm512_unpacklo_ps(bC, bD);
            aD = _mm512_unpackhi_ps(bC, bD);
            aE = _mm512_unpacklo_ps(bE, bF);
            aF = _mm512_unpackhi_ps(bE, bF);

            b0 = _mm512_shuffle_f32x4(a0, a4, 0x44);
            b1 = _mm512_shuffle_f32x4(a1, a5, 0x44);
            b2 = _mm512_shuffle_f32x4(a2, a6, 0x44);
            b3 = _mm512_shuffle_f32x4(a3, a7, 0x44);
            b4 = _mm512_shuffle_f32x4(a0, a4, 0xEE);
            b5 = _mm512_shuffle_f32x4(a1, a5, 0xEE);
            b6 = _mm512_shuffle_f32x4(a2, a6, 0xEE);
            b7 = _mm512_shuffle_f32x4(a3, a7, 0xEE);
            b8 = _mm512_shuffle_f32x4(a8, aC, 0x44);
            b9 = _mm512_shuffle_f32x4(a9, aD, 0x44);
            bA = _mm512_shuffle_f32x4(aA, aE, 0x44);
            bB = _mm512_shuffle_f32x4(aB, aF, 0x44);
            bC = _mm512_shuffle_f32x4(a8, aC, 0xEE);
            bD = _mm512_shuffle_f32x4(a9, aD, 0xEE);
            bE = _mm512_shuffle_f32x4(aA, aE, 0xEE);
            bF = _mm512_shuffle_f32x4(aB, aF, 0xEE);

            a0 = _mm512_shuffle_f32x4(b0, b8, 0x88);
            a1 = _mm512_shuffle_f32x4(b1, b9, 0x88);
            a2 = _mm512_shuffle_f32x4(b2, bA, 0x88);
            a3 = _mm512_shuffle_f32x4(b3, bB, 0x88);
            a4 = _mm512_shuffle_f32x4(b0, b8, 0xDD);
            a5 = _mm512_shuffle_f32x4(b1, b9, 0xDD);
            a6 = _mm512_shuffle_f32x4(b2, bA, 0xDD);
            a7 = _mm512_shuffle_f32x4(b3, bB, 0xDD);
            a8 = _mm512_shuffle_f32x4(b4, bC, 0x88);
            a9 = _mm512_shuffle_f32x4(b5, bD, 0x88);
            aA = _mm512_shuffle_f32x4(b6, bE, 0x88);
            aB = _mm512_shuffle_f32x4(b7, bF, 0x88);
            aC = _mm512_shuffle_f32x4(b4, bC, 0xDD);
            aD = _mm512_shuffle_f32x4(b5, bD, 0xDD);
            aE = _mm512_shuffle_f32x4(b6, bE, 0xDD);
            aF = _mm512_shuffle_f32x4(b7, bF, 0xDD);

            __mmask16 dstMask = __mmask16(-1) >> (16 - cols);
            _mm512_mask_storeu_ps(dst + 0x0 * dstStride, dstMask, a0);
            _mm512_mask_storeu_ps(dst + 0x1 * dstStride, dstMask, a1);
            _mm512_mask_storeu_ps(dst + 0x2 * dstStride, dstMask, a2);
            _mm512_mask_storeu_ps(dst + 0x3 * dstStride, dstMask, a3);
            _mm512_mask_storeu_ps(dst + 0x4 * dstStride, dstMask, a4);
            _mm512_mask_storeu_ps(dst + 0x5 * dstStride, dstMask, a5);
            _mm512_mask_storeu_ps(dst + 0x6 * dstStride, dstMask, a6);
            _mm512_mask_storeu_ps(dst + 0x7 * dstStride, dstMask, a7);
            _mm512_mask_storeu_ps(dst + 0x8 * dstStride, dstMask, a8);
            _mm512_mask_storeu_ps(dst + 0x9 * dstStride, dstMask, a9);
            _mm512_mask_storeu_ps(dst + 0xA * dstStride, dstMask, aA);
            _mm512_mask_storeu_ps(dst + 0xB * dstStride, dstMask, aB);
            _mm512_mask_storeu_ps(dst + 0xC * dstStride, dstMask, aC);
            _mm512_mask_storeu_ps(dst + 0xD * dstStride, dstMask, aD);
            _mm512_mask_storeu_ps(dst + 0xE * dstStride, dstMask, aE);
            _mm512_mask_storeu_ps(dst + 0xF * dstStride, dstMask, aF);
        }

        // Partial-block variant of the store: here the 128-bit-lane shuffles
        // run before the element unpacks, the reverse of the load-side stage
        // order; both orders realize the same 16x16 transpose.
        SIMD_INLINE void StoreTansp16x16(const float* src, size_t cols, size_t rows, __m512 k, float* dst, size_t dstStride)
        {
            __m512 a[16], b[16];

            for (size_t c = 0; c < cols; ++c)
                a[c] = _mm512_mul_ps(k, _mm512_loadu_ps(src + c * F));

            for (size_t i = 0; i < 4; i += 1)
            {
                b[0x0 + i] = _mm512_shuffle_f32x4(a[0x0 + i], a[0x4 + i], 0x44);
                b[0x4 + i] = _mm512_shuffle_f32x4(a[0x0 + i], a[0x4 + i], 0xEE);
                b[0x8 + i] = _mm512_shuffle_f32x4(a[0x8 + i], a[0xC + i], 0x44);
                b[0xC + i] = _mm512_shuffle_f32x4(a[0x8 + i], a[0xC + i], 0xEE);
            }

            for (size_t i = 0; i < 4; i += 1)
            {
                a[0x0 + i] = _mm512_shuffle_f32x4(b[0x0 + i], b[0x8 + i], 0x88);
                a[0x4 + i] = _mm512_shuffle_f32x4(b[0x0 + i], b[0x8 + i], 0xDD);
                a[0x8 + i] = _mm512_shuffle_f32x4(b[0x4 + i], b[0xC + i], 0x88);
                a[0xC + i] = _mm512_shuffle_f32x4(b[0x4 + i], b[0xC + i], 0xDD);
            }

            for (size_t r = 0; r < rows; r += 4)
            {
                b[r + 0] = _mm512_unpacklo_ps(a[r + 0], a[r + 2]);
                b[r + 1] = _mm512_unpacklo_ps(a[r + 1], a[r + 3]);
                b[r + 2] = _mm512_unpackhi_ps(a[r + 0], a[r + 2]);
                b[r + 3] = _mm512_unpackhi_ps(a[r + 1], a[r + 3]);
            }

            for (size_t r = 0; r < rows; r += 4)
            {
                a[r + 0] = _mm512_unpacklo_ps(b[r + 0], b[r + 1]);
                a[r + 1] = _mm512_unpackhi_ps(b[r + 0], b[r + 1]);
                a[r + 2] = _mm512_unpacklo_ps(b[r + 2], b[r + 3]);
                a[r + 3] = _mm512_unpackhi_ps(b[r + 2], b[r + 3]);
            }

            __mmask16 dstMask = __mmask16(-1) >> (16 - cols);
            for (size_t r = 0; r < rows; ++r)
                _mm512_mask_storeu_ps(dst + r * dstStride, dstMask, a[r]);
        }
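
In scalar terms, both store variants perform the inverse mapping with normalization (editor's sketch):

// For r < rows and c < cols:
//   dst[r * dstStride + c] = k[r] * src[c * 16 + r]
// i.e. transpose back and multiply each item's channels by its own 1 / sum.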

        // General softmax for inner == 1: processes 16 outer items at a time.
        // Each block is transposed into buf so that channel c of all 16 items
        // sits in one vector; max, exp and sum then reduce vertically, and the
        // normalized result is transposed back on store.
        void SynetSoftmaxLayerForwardX1(const float* src, size_t outer, size_t count, float* dst)
        {
            size_t o = 0, c = 0, outerF = AlignLo(outer, F), countF = AlignLo(count, F);
            Array32f buf(AlignHi(count, F) * F);
            Exp exp;
            for (; o < outerF; o += F)
            {
                __m512 _max = _mm512_set1_ps(-FLT_MAX);
                for (c = 0; c < countF; c += F)
                    LoadTansp16x16(src + c, count, F, buf.data + c * F, _max);
                if (c < count)
                    LoadTansp16x16(src + c, count, count - c, buf.data + c * F, _max);
                __m512 _sum = _mm512_setzero_ps();
                for (c = 0; c < count; ++c)
                {
                    __m512 _exp = exp.Exponent(_mm512_sub_ps(_mm512_loadu_ps(buf.data + c * F), _max));
                    _sum = _mm512_add_ps(_sum, _exp);
                    _mm512_storeu_ps(buf.data + c * F, _exp);
                }
                __m512 _k = _mm512_div_ps(_mm512_set1_ps(1.0f), _sum);
                for (c = 0; c < countF; c += F)
                    StoreTansp16x16(buf.data + c * F, F, _k, dst + c, count);
                if (c < count)
                    StoreTansp16x16(buf.data + c * F, count - c, _k, dst + c, count);
                src += count * F;
                dst += count * F;
            }
            if (o < outer)
            {
                buf.Clear();
                __m512 _max = _mm512_set1_ps(-FLT_MAX);
                for (c = 0; c < countF; c += F)
                    LoadTansp16x16(src + c, count, F, outer - o, buf.data + c * F, _max);
                if (c < count)
                    LoadTansp16x16(src + c, count, count - c, outer - o, buf.data + c * F, _max);
                __m512 _sum = _mm512_setzero_ps();
                for (c = 0; c < count; ++c)
                {
                    __m512 _exp = exp.Exponent(_mm512_sub_ps(_mm512_loadu_ps(buf.data + c * F), _max));
                    _sum = _mm512_add_ps(_sum, _exp);
                    _mm512_storeu_ps(buf.data + c * F, _exp);
                }
                __m512 _k = _mm512_div_ps(_mm512_set1_ps(1.0f), _sum);
                for (c = 0; c < countF; c += F)
                    StoreTansp16x16(buf.data + c * F, F, outer - o, _k, dst + c, count);
                if (c < count)
                    StoreTansp16x16(buf.data + c * F, count - c, outer - o, _k, dst + c, count);
            }
        }
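
A plain scalar reference for this inner == 1 path (an editor's sketch; the name is hypothetical):

#include <math.h>
#include <stddef.h>

// Row-wise softmax: for each outer item, normalize its count channels.
static void SoftmaxForwardX1Scalar(const float* src, size_t outer, size_t count, float* dst)
{
    for (size_t o = 0; o < outer; ++o, src += count, dst += count)
    {
        float max = src[0];
        for (size_t c = 1; c < count; ++c)
            max = src[c] > max ? src[c] : max;     // per-item maximum
        float sum = 0.0f;
        for (size_t c = 0; c < count; ++c)
        {
            dst[c] = expf(src[c] - max);           // stabilized exponent
            sum += dst[c];
        }
        float k = 1.0f / sum;
        for (size_t c = 0; c < count; ++c)
            dst[c] *= k;                           // normalize
    }
}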

        // Entry point: picks a specialized kernel for inner == 1 (count 2, 3,
        // or general), otherwise runs a three-pass max / exp-sum / normalize
        // scheme vectorized along the inner dimension.
        void SynetSoftmaxLayerForward(const float* src, size_t outer, size_t count, size_t inner, float* dst)
        {
            if (inner == 1)
            {
                if (count == 2)
                    SynetSoftmaxLayerForward21(src, outer, dst);
                else if (count == 3)
                    SynetSoftmaxLayerForward31(src, outer, dst);
                else
                    SynetSoftmaxLayerForwardX1(src, outer, count, dst);
            }
            else
            {
                Exp exp;
                size_t aligned = Simd::AlignLo(inner, F);
                __mmask16 tail = TailMask16(inner - aligned);
                Array32f tmp(inner * 2);
                const float* s;
                float* max = tmp.data, * sum = tmp.data + inner, * d;
                for (size_t o = 0; o < outer; ++o)
                {
                    // Pass 1: per-inner-position maximum over the count channels.
                    memcpy(max, src, inner * sizeof(float));
                    s = src + inner;
                    for (size_t c = 1; c < count; ++c)
                    {
                        size_t i = 0;
                        for (; i < aligned; i += F)
                            _mm512_storeu_ps(max + i, _mm512_max_ps(_mm512_loadu_ps(s + i), _mm512_loadu_ps(max + i)));
                        if (i < inner)
                            _mm512_mask_storeu_ps(max + i, tail, _mm512_max_ps(_mm512_maskz_loadu_ps(tail, s + i), _mm512_maskz_loadu_ps(tail, max + i)));
                        s += inner;
                    }

                    // Pass 2: exponents and their per-position sums.
                    s = src;
                    d = dst;
                    memset(sum, 0, inner * sizeof(float));
                    for (size_t c = 0; c < count; ++c)
                    {
                        size_t i = 0;
                        for (; i < aligned; i += F)
                        {
                            __m512 _d = exp.Exponent(_mm512_sub_ps(_mm512_loadu_ps(s + i), _mm512_loadu_ps(max + i)));
                            _mm512_storeu_ps(d + i, _d);
                            _mm512_storeu_ps(sum + i, _mm512_add_ps(_d, _mm512_loadu_ps(sum + i)));
                        }
                        if (i < inner)
                        {
                            __m512 _d = exp.Exponent(_mm512_sub_ps(_mm512_maskz_loadu_ps(tail, s + i), _mm512_maskz_loadu_ps(tail, max + i)));
                            _mm512_mask_storeu_ps(d + i, tail, _d);
                            _mm512_mask_storeu_ps(sum + i, tail, _mm512_add_ps(_d, _mm512_maskz_loadu_ps(tail, sum + i)));
                        }
                        s += inner;
                        d += inner;
                    }

                    // Pass 3: divide every channel by its position's sum.
                    d = dst;
                    for (size_t c = 0; c < count; ++c)
                    {
                        size_t i = 0;
                        for (; i < aligned; i += F)
                            _mm512_storeu_ps(d + i, _mm512_div_ps(_mm512_loadu_ps(d + i), _mm512_loadu_ps(sum + i)));
                        if (i < inner)
                            _mm512_mask_storeu_ps(d + i, tail, _mm512_div_ps(_mm512_maskz_loadu_ps(tail, d + i), _mm512_maskz_loadu_ps(tail, sum + i)));
                        d += inner;
                    }
                    src += count * inner;
                    dst += count * inner;
                }
            }
        }
    }
#endif
}
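
Callers normally reach this kernel through the library's public C API, which dispatches to the best available extension at run time. A minimal usage sketch, assuming the SimdSynetSoftmaxLayerForward entry point declared in Simd/SimdLib.h:

#include "Simd/SimdLib.h"

int main()
{
    // outer = 4 independent items, count = 8 classes each, inner = 1:
    // softmax is applied to each row of 8 values.
    float src[4 * 8] = { 0.0f }, dst[4 * 8];
    SimdSynetSoftmaxLayerForward(src, 4, 8, 1, dst);
    return 0;
}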