/src/Simd/src/Simd/SimdAvx512bwSynetSoftmax.cpp
/*
* Simd Library (http://ermig1979.github.io/Simd).
*
* Copyright (c) 2011-2024 Yermalayeu Ihar.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#include "Simd/SimdMemory.h"
#include "Simd/SimdStore.h"
#include "Simd/SimdBase.h"
#include "Simd/SimdSse41.h"
#include "Simd/SimdAvx2.h"
#include "Simd/SimdAvx512bw.h"
#include "Simd/SimdSynet.h"
#include "Simd/SimdExtract.h"
#include "Simd/SimdExp.h"
#include "Simd/SimdPow.h"
#include "Simd/SimdInterleave.h"
#include "Simd/SimdDeinterleave.h"

namespace Simd
{
#if defined(SIMD_AVX512BW_ENABLE) && defined(SIMD_SYNET_ENABLE)
    namespace Avx512bw
    {
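        // Softmax over 'count' channels is computed here in the numerically stable form
        //
        //   dst[c] = exp(src[c] - max) / sum(exp(src[j] - max), j = 0 .. count - 1),
        //
        // where max is the per-sample maximum: subtracting it leaves the result unchanged
        // but keeps exp() from overflowing for large inputs. A scalar sketch of the
        // inner == 1 case (illustrative only, not the reference these kernels are built from):
        //
        //   for (size_t o = 0; o < outer; ++o, src += count, dst += count)
        //   {
        //       float max = src[0];
        //       for (size_t c = 1; c < count; ++c) max = std::max(max, src[c]);
        //       float sum = 0;
        //       for (size_t c = 0; c < count; ++c) sum += (dst[c] = ::expf(src[c] - max));
        //       for (size_t c = 0; c < count; ++c) dst[c] /= sum;
        //   }
        //
        // SynetSoftmaxLayerForward21 handles count == 2, inner == 1: every iteration loads
        // F = 16 (x, y) pairs, separates the even and odd elements of each 128-bit lane with
        // _mm512_shuffle_ps (controls 0x88 / 0xDD), applies softmax to each pair, and restores
        // the interleaved order with _mm512_unpacklo_ps / _mm512_unpackhi_ps.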
        void SynetSoftmaxLayerForward21(const float* src, size_t outer, float* dst)
        {
            Exp exp;
            size_t aligned = Simd::AlignLo(outer, F), tail = outer - aligned;
            for (size_t o = 0; o < aligned; o += F)
            {
                __m512 s0 = _mm512_loadu_ps(src + 0);
                __m512 s1 = _mm512_loadu_ps(src + F);
                __m512 ss0 = _mm512_shuffle_ps(s0, s1, 0x88);
                __m512 ss1 = _mm512_shuffle_ps(s0, s1, 0xDD);
                __m512 max = _mm512_max_ps(ss0, ss1);
                __m512 exp0 = exp.Exponent(_mm512_sub_ps(ss0, max));
                __m512 exp1 = exp.Exponent(_mm512_sub_ps(ss1, max));
                __m512 sum = _mm512_add_ps(exp0, exp1);
                __m512 d0 = _mm512_div_ps(exp0, sum);
                __m512 d1 = _mm512_div_ps(exp1, sum);
                _mm512_storeu_ps(dst + 0, _mm512_unpacklo_ps(d0, d1));
                _mm512_storeu_ps(dst + F, _mm512_unpackhi_ps(d0, d1));
                src += DF;
                dst += DF;
            }
            if (tail)
            {
                __mmask16 mask0 = TailMask16(tail * 2 - 0 * F);
                __mmask16 mask1 = TailMask16(tail * 2 - 1 * F);
                __m512 s0 = _mm512_maskz_loadu_ps(mask0, src + 0 * F);
                __m512 s1 = _mm512_maskz_loadu_ps(mask1, src + 1 * F);
                __m512 ss0 = _mm512_shuffle_ps(s0, s1, 0x88);
                __m512 ss1 = _mm512_shuffle_ps(s0, s1, 0xDD);
                __m512 max = _mm512_max_ps(ss0, ss1);
                __m512 exp0 = exp.Exponent(_mm512_sub_ps(ss0, max));
                __m512 exp1 = exp.Exponent(_mm512_sub_ps(ss1, max));
                __m512 sum = _mm512_add_ps(exp0, exp1);
                __m512 d0 = _mm512_div_ps(exp0, sum);
                __m512 d1 = _mm512_div_ps(exp1, sum);
                _mm512_mask_storeu_ps(dst + 0 * F, mask0, _mm512_unpacklo_ps(d0, d1));
                _mm512_mask_storeu_ps(dst + 1 * F, mask1, _mm512_unpackhi_ps(d0, d1));
            }
        }

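        // count == 3, inner == 1: 'idx' addresses every third float (0x00, 0x03, ..., 0x2D),
        // so gathers based at src + 0, src + 1 and src + 2 collect the three channels of 16
        // consecutive triples into buf[0..2]; the SIMD_INLINE helper applies softmax lane-wise
        // and the results are scattered back to the interleaved layout.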
        SIMD_INLINE void SynetSoftmaxLayerForward31(const Exp& exp, __m512 buf[3])
        {
            __m512 max = _mm512_max_ps(buf[0], _mm512_max_ps(buf[1], buf[2]));
            buf[0] = exp.Exponent(_mm512_sub_ps(buf[0], max));
            buf[1] = exp.Exponent(_mm512_sub_ps(buf[1], max));
            buf[2] = exp.Exponent(_mm512_sub_ps(buf[2], max));
            __m512 sum = _mm512_add_ps(buf[0], _mm512_add_ps(buf[1], buf[2]));
            buf[0] = _mm512_div_ps(buf[0], sum);
            buf[1] = _mm512_div_ps(buf[1], sum);
            buf[2] = _mm512_div_ps(buf[2], sum);
        }

        void SynetSoftmaxLayerForward31(const float* src, size_t outer, float* dst)
        {
            static const __m512i idx = _mm512_setr_epi32(0x00, 0x03, 0x06, 0x09, 0x0C, 0x0F, 0x12, 0x15, 0x18, 0x1B, 0x1E, 0x21, 0x24, 0x27, 0x2A, 0x2D);
            Exp exp;
            __m512 buf[3];
            size_t aligned = Simd::AlignLo(outer, F), tail = outer - aligned;
            for (size_t o = 0; o < aligned; o += F)
            {
                buf[0] = _mm512_i32gather_ps(idx, src + 0, 4);
                buf[1] = _mm512_i32gather_ps(idx, src + 1, 4);
                buf[2] = _mm512_i32gather_ps(idx, src + 2, 4);
                SynetSoftmaxLayerForward31(exp, buf);
                _mm512_i32scatter_ps(dst + 0, idx, buf[0], 4);
                _mm512_i32scatter_ps(dst + 1, idx, buf[1], 4);
                _mm512_i32scatter_ps(dst + 2, idx, buf[2], 4);
                src += 3 * F;
                dst += 3 * F;
            }
            if (tail)
            {
                __mmask16 mask = TailMask16(tail);
                buf[0] = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, idx, src + 0, 4);
                buf[1] = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, idx, src + 1, 4);
                buf[2] = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, idx, src + 2, 4);
                SynetSoftmaxLayerForward31(exp, buf);
                _mm512_mask_i32scatter_ps(dst + 0, mask, idx, buf[0], 4);
                _mm512_mask_i32scatter_ps(dst + 1, mask, idx, buf[1], 4);
                _mm512_mask_i32scatter_ps(dst + 2, mask, idx, buf[2], 4);
            }
        }

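        // LoadTansp16x16 loads a tile of 16 outer samples x 'cols' channels, padding the
        // masked channel columns with -FLT_MAX so they can never win a maximum. The tile is
        // transposed in registers by the usual four-stage network (two _mm512_unpack(lo|hi)_ps
        // stages followed by two _mm512_shuffle_f32x4 stages), after which each register holds
        // one channel for all 16 samples; the per-sample running maximum is updated lane-wise
        // and the transposed tile is written to the temporary buffer 'dst'.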
        SIMD_INLINE void LoadTansp16x16(const float* src, size_t srcStride, size_t cols, float* dst, __m512& max)
        {
            __m512 a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, aA, aB, aC, aD, aE, aF, b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, bA, bB, bC, bD, bE, bF;

            __mmask16 srcMask = __mmask16(-1) >> (16 - cols);
            __m512 def = _mm512_set1_ps(-FLT_MAX);
            a0 = _mm512_mask_loadu_ps(def, srcMask, src + 0x0 * srcStride);
            a1 = _mm512_mask_loadu_ps(def, srcMask, src + 0x1 * srcStride);
            a2 = _mm512_mask_loadu_ps(def, srcMask, src + 0x2 * srcStride);
            a3 = _mm512_mask_loadu_ps(def, srcMask, src + 0x3 * srcStride);
            a4 = _mm512_mask_loadu_ps(def, srcMask, src + 0x4 * srcStride);
            a5 = _mm512_mask_loadu_ps(def, srcMask, src + 0x5 * srcStride);
            a6 = _mm512_mask_loadu_ps(def, srcMask, src + 0x6 * srcStride);
            a7 = _mm512_mask_loadu_ps(def, srcMask, src + 0x7 * srcStride);
            a8 = _mm512_mask_loadu_ps(def, srcMask, src + 0x8 * srcStride);
            a9 = _mm512_mask_loadu_ps(def, srcMask, src + 0x9 * srcStride);
            aA = _mm512_mask_loadu_ps(def, srcMask, src + 0xA * srcStride);
            aB = _mm512_mask_loadu_ps(def, srcMask, src + 0xB * srcStride);
            aC = _mm512_mask_loadu_ps(def, srcMask, src + 0xC * srcStride);
            aD = _mm512_mask_loadu_ps(def, srcMask, src + 0xD * srcStride);
            aE = _mm512_mask_loadu_ps(def, srcMask, src + 0xE * srcStride);
            aF = _mm512_mask_loadu_ps(def, srcMask, src + 0xF * srcStride);

            b0 = _mm512_unpacklo_ps(a0, a2);
            b1 = _mm512_unpacklo_ps(a1, a3);
            b2 = _mm512_unpackhi_ps(a0, a2);
            b3 = _mm512_unpackhi_ps(a1, a3);
            b4 = _mm512_unpacklo_ps(a4, a6);
            b5 = _mm512_unpacklo_ps(a5, a7);
            b6 = _mm512_unpackhi_ps(a4, a6);
            b7 = _mm512_unpackhi_ps(a5, a7);
            b8 = _mm512_unpacklo_ps(a8, aA);
            b9 = _mm512_unpacklo_ps(a9, aB);
            bA = _mm512_unpackhi_ps(a8, aA);
            bB = _mm512_unpackhi_ps(a9, aB);
            bC = _mm512_unpacklo_ps(aC, aE);
            bD = _mm512_unpacklo_ps(aD, aF);
            bE = _mm512_unpackhi_ps(aC, aE);
            bF = _mm512_unpackhi_ps(aD, aF);

            a0 = _mm512_unpacklo_ps(b0, b1);
            a1 = _mm512_unpackhi_ps(b0, b1);
            a2 = _mm512_unpacklo_ps(b2, b3);
            a3 = _mm512_unpackhi_ps(b2, b3);
            a4 = _mm512_unpacklo_ps(b4, b5);
            a5 = _mm512_unpackhi_ps(b4, b5);
            a6 = _mm512_unpacklo_ps(b6, b7);
            a7 = _mm512_unpackhi_ps(b6, b7);
            a8 = _mm512_unpacklo_ps(b8, b9);
            a9 = _mm512_unpackhi_ps(b8, b9);
            aA = _mm512_unpacklo_ps(bA, bB);
            aB = _mm512_unpackhi_ps(bA, bB);
            aC = _mm512_unpacklo_ps(bC, bD);
            aD = _mm512_unpackhi_ps(bC, bD);
            aE = _mm512_unpacklo_ps(bE, bF);
            aF = _mm512_unpackhi_ps(bE, bF);

            b0 = _mm512_shuffle_f32x4(a0, a4, 0x44);
            b1 = _mm512_shuffle_f32x4(a1, a5, 0x44);
            b2 = _mm512_shuffle_f32x4(a2, a6, 0x44);
            b3 = _mm512_shuffle_f32x4(a3, a7, 0x44);
            b4 = _mm512_shuffle_f32x4(a0, a4, 0xEE);
            b5 = _mm512_shuffle_f32x4(a1, a5, 0xEE);
            b6 = _mm512_shuffle_f32x4(a2, a6, 0xEE);
            b7 = _mm512_shuffle_f32x4(a3, a7, 0xEE);
            b8 = _mm512_shuffle_f32x4(a8, aC, 0x44);
            b9 = _mm512_shuffle_f32x4(a9, aD, 0x44);
            bA = _mm512_shuffle_f32x4(aA, aE, 0x44);
            bB = _mm512_shuffle_f32x4(aB, aF, 0x44);
            bC = _mm512_shuffle_f32x4(a8, aC, 0xEE);
            bD = _mm512_shuffle_f32x4(a9, aD, 0xEE);
            bE = _mm512_shuffle_f32x4(aA, aE, 0xEE);
            bF = _mm512_shuffle_f32x4(aB, aF, 0xEE);

            a0 = _mm512_shuffle_f32x4(b0, b8, 0x88);
            a1 = _mm512_shuffle_f32x4(b1, b9, 0x88);
            a2 = _mm512_shuffle_f32x4(b2, bA, 0x88);
            a3 = _mm512_shuffle_f32x4(b3, bB, 0x88);
            a4 = _mm512_shuffle_f32x4(b0, b8, 0xDD);
            a5 = _mm512_shuffle_f32x4(b1, b9, 0xDD);
            a6 = _mm512_shuffle_f32x4(b2, bA, 0xDD);
            a7 = _mm512_shuffle_f32x4(b3, bB, 0xDD);
            a8 = _mm512_shuffle_f32x4(b4, bC, 0x88);
            a9 = _mm512_shuffle_f32x4(b5, bD, 0x88);
            aA = _mm512_shuffle_f32x4(b6, bE, 0x88);
            aB = _mm512_shuffle_f32x4(b7, bF, 0x88);
            aC = _mm512_shuffle_f32x4(b4, bC, 0xDD);
            aD = _mm512_shuffle_f32x4(b5, bD, 0xDD);
            aE = _mm512_shuffle_f32x4(b6, bE, 0xDD);
            aF = _mm512_shuffle_f32x4(b7, bF, 0xDD);

            max = _mm512_max_ps(max, a0);
            max = _mm512_max_ps(max, a1);
            max = _mm512_max_ps(max, a2);
            max = _mm512_max_ps(max, a3);
            max = _mm512_max_ps(max, a4);
            max = _mm512_max_ps(max, a5);
            max = _mm512_max_ps(max, a6);
            max = _mm512_max_ps(max, a7);
            max = _mm512_max_ps(max, a8);
            max = _mm512_max_ps(max, a9);
            max = _mm512_max_ps(max, aA);
            max = _mm512_max_ps(max, aB);
            max = _mm512_max_ps(max, aC);
            max = _mm512_max_ps(max, aD);
            max = _mm512_max_ps(max, aE);
            max = _mm512_max_ps(max, aF);

            _mm512_storeu_ps(dst + 0x0 * F, a0);
            _mm512_storeu_ps(dst + 0x1 * F, a1);
            _mm512_storeu_ps(dst + 0x2 * F, a2);
            _mm512_storeu_ps(dst + 0x3 * F, a3);
            _mm512_storeu_ps(dst + 0x4 * F, a4);
            _mm512_storeu_ps(dst + 0x5 * F, a5);
            _mm512_storeu_ps(dst + 0x6 * F, a6);
            _mm512_storeu_ps(dst + 0x7 * F, a7);
            _mm512_storeu_ps(dst + 0x8 * F, a8);
            _mm512_storeu_ps(dst + 0x9 * F, a9);
            _mm512_storeu_ps(dst + 0xA * F, aA);
            _mm512_storeu_ps(dst + 0xB * F, aB);
            _mm512_storeu_ps(dst + 0xC * F, aC);
            _mm512_storeu_ps(dst + 0xD * F, aD);
            _mm512_storeu_ps(dst + 0xE * F, aE);
            _mm512_storeu_ps(dst + 0xF * F, aF);
        }

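        // Overload for a partial tile of 'rows' (< 16) outer samples: only 'rows' source rows
        // are loaded and only the first 'cols' transposed registers take part in the maximum
        // and the store, with the same transpose network expressed as loops.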
        SIMD_INLINE void LoadTansp16x16(const float* src, size_t srcStride, size_t cols, size_t rows, float* dst, __m512& max)
        {
            __m512 a[16], b[16];

            __mmask16 srcMask = __mmask16(-1) >> (16 - cols);
            __m512 def = _mm512_set1_ps(-FLT_MAX);
            for (size_t r = 0; r < rows; ++r)
                a[r] = _mm512_mask_loadu_ps(def, srcMask, src + r * srcStride);

            for (size_t r = 0; r < rows; r += 4)
            {
                b[r + 0] = _mm512_unpacklo_ps(a[r + 0], a[r + 2]);
                b[r + 1] = _mm512_unpacklo_ps(a[r + 1], a[r + 3]);
                b[r + 2] = _mm512_unpackhi_ps(a[r + 0], a[r + 2]);
                b[r + 3] = _mm512_unpackhi_ps(a[r + 1], a[r + 3]);
            }

            for (size_t r = 0; r < rows; r += 4)
            {
                a[r + 0] = _mm512_unpacklo_ps(b[r + 0], b[r + 1]);
                a[r + 1] = _mm512_unpackhi_ps(b[r + 0], b[r + 1]);
                a[r + 2] = _mm512_unpacklo_ps(b[r + 2], b[r + 3]);
                a[r + 3] = _mm512_unpackhi_ps(b[r + 2], b[r + 3]);
            }

            for (size_t i = 0; i < 4; i += 1)
            {
                b[0x0 + i] = _mm512_shuffle_f32x4(a[0x0 + i], a[0x4 + i], 0x44);
                b[0x4 + i] = _mm512_shuffle_f32x4(a[0x0 + i], a[0x4 + i], 0xEE);
                b[0x8 + i] = _mm512_shuffle_f32x4(a[0x8 + i], a[0xC + i], 0x44);
                b[0xC + i] = _mm512_shuffle_f32x4(a[0x8 + i], a[0xC + i], 0xEE);
            }

            for (size_t i = 0; i < 4; i += 1)
            {
                a[0x0 + i] = _mm512_shuffle_f32x4(b[0x0 + i], b[0x8 + i], 0x88);
                a[0x4 + i] = _mm512_shuffle_f32x4(b[0x0 + i], b[0x8 + i], 0xDD);
                a[0x8 + i] = _mm512_shuffle_f32x4(b[0x4 + i], b[0xC + i], 0x88);
                a[0xC + i] = _mm512_shuffle_f32x4(b[0x4 + i], b[0xC + i], 0xDD);
            }

            for (size_t c = 0; c < cols; ++c)
            {
                max = _mm512_max_ps(max, a[c]);
                _mm512_storeu_ps(dst + c * F, a[c]);
            }
        }

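        // StoreTansp16x16 performs the inverse step: it scales every transposed channel
        // register by 'k' (the lane-wise reciprocal of the exponent sum), transposes the tile
        // back to the sample-major layout and writes 'cols' channels per sample with a masked
        // store. The second overload below handles a partial tile of 'rows' samples.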
        SIMD_INLINE void StoreTansp16x16(const float* src, size_t cols, __m512 k, float* dst, size_t dstStride)
        {
            __m512 a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, aA, aB, aC, aD, aE, aF, b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, bA, bB, bC, bD, bE, bF;

            a0 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x0 * F));
            a1 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x1 * F));
            a2 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x2 * F));
            a3 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x3 * F));
            a4 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x4 * F));
            a5 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x5 * F));
            a6 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x6 * F));
            a7 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x7 * F));
            a8 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x8 * F));
            a9 = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0x9 * F));
            aA = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0xA * F));
            aB = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0xB * F));
            aC = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0xC * F));
            aD = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0xD * F));
            aE = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0xE * F));
            aF = _mm512_mul_ps(k, _mm512_loadu_ps(src + 0xF * F));

            b0 = _mm512_unpacklo_ps(a0, a2);
            b1 = _mm512_unpacklo_ps(a1, a3);
            b2 = _mm512_unpackhi_ps(a0, a2);
            b3 = _mm512_unpackhi_ps(a1, a3);
            b4 = _mm512_unpacklo_ps(a4, a6);
            b5 = _mm512_unpacklo_ps(a5, a7);
            b6 = _mm512_unpackhi_ps(a4, a6);
            b7 = _mm512_unpackhi_ps(a5, a7);
            b8 = _mm512_unpacklo_ps(a8, aA);
            b9 = _mm512_unpacklo_ps(a9, aB);
            bA = _mm512_unpackhi_ps(a8, aA);
            bB = _mm512_unpackhi_ps(a9, aB);
            bC = _mm512_unpacklo_ps(aC, aE);
            bD = _mm512_unpacklo_ps(aD, aF);
            bE = _mm512_unpackhi_ps(aC, aE);
            bF = _mm512_unpackhi_ps(aD, aF);

            a0 = _mm512_unpacklo_ps(b0, b1);
            a1 = _mm512_unpackhi_ps(b0, b1);
            a2 = _mm512_unpacklo_ps(b2, b3);
            a3 = _mm512_unpackhi_ps(b2, b3);
            a4 = _mm512_unpacklo_ps(b4, b5);
            a5 = _mm512_unpackhi_ps(b4, b5);
            a6 = _mm512_unpacklo_ps(b6, b7);
            a7 = _mm512_unpackhi_ps(b6, b7);
            a8 = _mm512_unpacklo_ps(b8, b9);
            a9 = _mm512_unpackhi_ps(b8, b9);
            aA = _mm512_unpacklo_ps(bA, bB);
            aB = _mm512_unpackhi_ps(bA, bB);
            aC = _mm512_unpacklo_ps(bC, bD);
            aD = _mm512_unpackhi_ps(bC, bD);
            aE = _mm512_unpacklo_ps(bE, bF);
            aF = _mm512_unpackhi_ps(bE, bF);

            b0 = _mm512_shuffle_f32x4(a0, a4, 0x44);
            b1 = _mm512_shuffle_f32x4(a1, a5, 0x44);
            b2 = _mm512_shuffle_f32x4(a2, a6, 0x44);
            b3 = _mm512_shuffle_f32x4(a3, a7, 0x44);
            b4 = _mm512_shuffle_f32x4(a0, a4, 0xEE);
            b5 = _mm512_shuffle_f32x4(a1, a5, 0xEE);
            b6 = _mm512_shuffle_f32x4(a2, a6, 0xEE);
            b7 = _mm512_shuffle_f32x4(a3, a7, 0xEE);
            b8 = _mm512_shuffle_f32x4(a8, aC, 0x44);
            b9 = _mm512_shuffle_f32x4(a9, aD, 0x44);
            bA = _mm512_shuffle_f32x4(aA, aE, 0x44);
            bB = _mm512_shuffle_f32x4(aB, aF, 0x44);
            bC = _mm512_shuffle_f32x4(a8, aC, 0xEE);
            bD = _mm512_shuffle_f32x4(a9, aD, 0xEE);
            bE = _mm512_shuffle_f32x4(aA, aE, 0xEE);
            bF = _mm512_shuffle_f32x4(aB, aF, 0xEE);

            a0 = _mm512_shuffle_f32x4(b0, b8, 0x88);
            a1 = _mm512_shuffle_f32x4(b1, b9, 0x88);
            a2 = _mm512_shuffle_f32x4(b2, bA, 0x88);
            a3 = _mm512_shuffle_f32x4(b3, bB, 0x88);
            a4 = _mm512_shuffle_f32x4(b0, b8, 0xDD);
            a5 = _mm512_shuffle_f32x4(b1, b9, 0xDD);
            a6 = _mm512_shuffle_f32x4(b2, bA, 0xDD);
            a7 = _mm512_shuffle_f32x4(b3, bB, 0xDD);
            a8 = _mm512_shuffle_f32x4(b4, bC, 0x88);
            a9 = _mm512_shuffle_f32x4(b5, bD, 0x88);
            aA = _mm512_shuffle_f32x4(b6, bE, 0x88);
            aB = _mm512_shuffle_f32x4(b7, bF, 0x88);
            aC = _mm512_shuffle_f32x4(b4, bC, 0xDD);
            aD = _mm512_shuffle_f32x4(b5, bD, 0xDD);
            aE = _mm512_shuffle_f32x4(b6, bE, 0xDD);
            aF = _mm512_shuffle_f32x4(b7, bF, 0xDD);

            __mmask16 dstMask = __mmask16(-1) >> (16 - cols);
            _mm512_mask_storeu_ps(dst + 0x0 * dstStride, dstMask, a0);
            _mm512_mask_storeu_ps(dst + 0x1 * dstStride, dstMask, a1);
            _mm512_mask_storeu_ps(dst + 0x2 * dstStride, dstMask, a2);
            _mm512_mask_storeu_ps(dst + 0x3 * dstStride, dstMask, a3);
            _mm512_mask_storeu_ps(dst + 0x4 * dstStride, dstMask, a4);
            _mm512_mask_storeu_ps(dst + 0x5 * dstStride, dstMask, a5);
            _mm512_mask_storeu_ps(dst + 0x6 * dstStride, dstMask, a6);
            _mm512_mask_storeu_ps(dst + 0x7 * dstStride, dstMask, a7);
            _mm512_mask_storeu_ps(dst + 0x8 * dstStride, dstMask, a8);
            _mm512_mask_storeu_ps(dst + 0x9 * dstStride, dstMask, a9);
            _mm512_mask_storeu_ps(dst + 0xA * dstStride, dstMask, aA);
            _mm512_mask_storeu_ps(dst + 0xB * dstStride, dstMask, aB);
            _mm512_mask_storeu_ps(dst + 0xC * dstStride, dstMask, aC);
            _mm512_mask_storeu_ps(dst + 0xD * dstStride, dstMask, aD);
            _mm512_mask_storeu_ps(dst + 0xE * dstStride, dstMask, aE);
            _mm512_mask_storeu_ps(dst + 0xF * dstStride, dstMask, aF);
        }

        SIMD_INLINE void StoreTansp16x16(const float* src, size_t cols, size_t rows, __m512 k, float* dst, size_t dstStride)
        {
            __m512 a[16], b[16];

            for (size_t c = 0; c < cols; ++c)
                a[c] = _mm512_mul_ps(k, _mm512_loadu_ps(src + c * F));

            for (size_t i = 0; i < 4; i += 1)
            {
                b[0x0 + i] = _mm512_shuffle_f32x4(a[0x0 + i], a[0x4 + i], 0x44);
                b[0x4 + i] = _mm512_shuffle_f32x4(a[0x0 + i], a[0x4 + i], 0xEE);
                b[0x8 + i] = _mm512_shuffle_f32x4(a[0x8 + i], a[0xC + i], 0x44);
                b[0xC + i] = _mm512_shuffle_f32x4(a[0x8 + i], a[0xC + i], 0xEE);
            }

            for (size_t i = 0; i < 4; i += 1)
            {
                a[0x0 + i] = _mm512_shuffle_f32x4(b[0x0 + i], b[0x8 + i], 0x88);
                a[0x4 + i] = _mm512_shuffle_f32x4(b[0x0 + i], b[0x8 + i], 0xDD);
                a[0x8 + i] = _mm512_shuffle_f32x4(b[0x4 + i], b[0xC + i], 0x88);
                a[0xC + i] = _mm512_shuffle_f32x4(b[0x4 + i], b[0xC + i], 0xDD);
            }

            for (size_t r = 0; r < rows; r += 4)
            {
                b[r + 0] = _mm512_unpacklo_ps(a[r + 0], a[r + 2]);
                b[r + 1] = _mm512_unpacklo_ps(a[r + 1], a[r + 3]);
                b[r + 2] = _mm512_unpackhi_ps(a[r + 0], a[r + 2]);
                b[r + 3] = _mm512_unpackhi_ps(a[r + 1], a[r + 3]);
            }

            for (size_t r = 0; r < rows; r += 4)
            {
                a[r + 0] = _mm512_unpacklo_ps(b[r + 0], b[r + 1]);
                a[r + 1] = _mm512_unpackhi_ps(b[r + 0], b[r + 1]);
                a[r + 2] = _mm512_unpacklo_ps(b[r + 2], b[r + 3]);
                a[r + 3] = _mm512_unpackhi_ps(b[r + 2], b[r + 3]);
            }

            __mmask16 dstMask = __mmask16(-1) >> (16 - cols);
            for (size_t r = 0; r < rows; ++r)
                _mm512_mask_storeu_ps(dst + r * dstStride, dstMask, a[r]);
        }

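        // The generic inner == 1 kernel: outer samples are processed in blocks of F = 16. Each
        // block is transposed tile-by-tile into 'buf' while the per-sample maximum is tracked,
        // one vectorized pass then computes the exponents and their sum (16 independent softmax
        // problems, one per lane), and the normalized values are transposed back to memory. The
        // final partial block clears 'buf' and routes through the row-masked tile overloads.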
        void SynetSoftmaxLayerForwardX1(const float* src, size_t outer, size_t count, float* dst)
        {
            size_t o = 0, c = 0, outerF = AlignLo(outer, F), countF = AlignLo(count, F);
            Array32f buf(AlignHi(count, F) * F);
            Exp exp;
            for (; o < outerF; o += F)
            {
                __m512 _max = _mm512_set1_ps(-FLT_MAX);
                for (c = 0; c < countF; c += F)
                    LoadTansp16x16(src + c, count, F, buf.data + c * F, _max);
                if (c < count)
                    LoadTansp16x16(src + c, count, count - c, buf.data + c * F, _max);
                __m512 _sum = _mm512_setzero_ps();
                for (c = 0; c < count; ++c)
                {
                    __m512 _exp = exp.Exponent(_mm512_sub_ps(_mm512_loadu_ps(buf.data + c * F), _max));
                    _sum = _mm512_add_ps(_sum, _exp);
                    _mm512_storeu_ps(buf.data + c * F, _exp);
                }
                __m512 _k = _mm512_div_ps(_mm512_set1_ps(1.0f), _sum);
                for (c = 0; c < countF; c += F)
                    StoreTansp16x16(buf.data + c * F, F, _k, dst + c, count);
                if (c < count)
                    StoreTansp16x16(buf.data + c * F, count - c, _k, dst + c, count);
                src += count * F;
                dst += count * F;
            }
            if (o < outer)
            {
                buf.Clear();
                __m512 _max = _mm512_set1_ps(-FLT_MAX);
                for (c = 0; c < countF; c += F)
                    LoadTansp16x16(src + c, count, F, outer - o, buf.data + c * F, _max);
                if (c < count)
                    LoadTansp16x16(src + c, count, count - c, outer - o, buf.data + c * F, _max);
                __m512 _sum = _mm512_setzero_ps();
                for (c = 0; c < count; ++c)
                {
                    __m512 _exp = exp.Exponent(_mm512_sub_ps(_mm512_loadu_ps(buf.data + c * F), _max));
                    _sum = _mm512_add_ps(_sum, _exp);
                    _mm512_storeu_ps(buf.data + c * F, _exp);
                }
                __m512 _k = _mm512_div_ps(_mm512_set1_ps(1.0f), _sum);
                for (c = 0; c < countF; c += F)
                    StoreTansp16x16(buf.data + c * F, F, outer - o, _k, dst + c, count);
                if (c < count)
                    StoreTansp16x16(buf.data + c * F, count - c, outer - o, _k, dst + c, count);
            }
        }

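        // Dispatcher: inner == 1 is served by the specialized kernels above. For inner > 1 the
        // channels of one sample are 'inner' floats apart, so per-position 'max' and 'sum'
        // arrays are kept in a temporary buffer and softmax is applied across the inner
        // dimension in three passes: maximum, exponent + accumulation, and normalization.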
        void SynetSoftmaxLayerForward(const float* src, size_t outer, size_t count, size_t inner, float* dst)
        {
            if (inner == 1)
            {
                if (count == 2)
                    SynetSoftmaxLayerForward21(src, outer, dst);
                else if (count == 3)
                    SynetSoftmaxLayerForward31(src, outer, dst);
                else
                    SynetSoftmaxLayerForwardX1(src, outer, count, dst);
            }
            else
            {
                Exp exp;
                size_t aligned = Simd::AlignLo(inner, F);
                __mmask16 tail = TailMask16(inner - aligned);
                Array32f tmp(inner * 2);
                const float* s;
                float* max = tmp.data, * sum = tmp.data + inner, * d;
                for (size_t o = 0; o < outer; ++o)
                {
                    memcpy(max, src, inner * sizeof(float));
                    s = src + inner;
                    for (size_t c = 1; c < count; ++c)
                    {
                        size_t i = 0;
                        for (; i < aligned; i += F)
                            _mm512_storeu_ps(max + i, _mm512_max_ps(_mm512_loadu_ps(s + i), _mm512_loadu_ps(max + i)));
                        if (i < inner)
                            _mm512_mask_storeu_ps(max + i, tail, _mm512_max_ps(_mm512_maskz_loadu_ps(tail, s + i), _mm512_maskz_loadu_ps(tail, max + i)));
                        s += inner;
                    }

                    s = src;
                    d = dst;
                    memset(sum, 0, inner * sizeof(float));
                    for (size_t c = 0; c < count; ++c)
                    {
                        size_t i = 0;
                        for (; i < aligned; i += F)
                        {
                            __m512 _d = exp.Exponent(_mm512_sub_ps(_mm512_loadu_ps(s + i), _mm512_loadu_ps(max + i)));
                            _mm512_storeu_ps(d + i, _d);
                            _mm512_storeu_ps(sum + i, _mm512_add_ps(_d, _mm512_loadu_ps(sum + i)));
                        }
                        if (i < inner)
                        {
                            __m512 _d = exp.Exponent(_mm512_sub_ps(_mm512_maskz_loadu_ps(tail, s + i), _mm512_maskz_loadu_ps(tail, max + i)));
                            _mm512_mask_storeu_ps(d + i, tail, _d);
                            _mm512_mask_storeu_ps(sum + i, tail, _mm512_add_ps(_d, _mm512_maskz_loadu_ps(tail, sum + i)));
                        }
                        s += inner;
                        d += inner;
                    }

                    d = dst;
                    for (size_t c = 0; c < count; ++c)
                    {
                        size_t i = 0;
                        for (; i < aligned; i += F)
                            _mm512_storeu_ps(d + i, _mm512_div_ps(_mm512_loadu_ps(d + i), _mm512_loadu_ps(sum + i)));
                        if (i < inner)
                            _mm512_mask_storeu_ps(d + i, tail, _mm512_div_ps(_mm512_maskz_loadu_ps(tail, d + i), _mm512_maskz_loadu_ps(tail, sum + i)));
                        d += inner;
                    }
                    src += count * inner;
                    dst += count * inner;
                }
            }
        }
    }
#endif
}