/src/xnnpack/src/qd8-f16-qb4w-gemm/gen/qd8-f16-qb4w-gemm-3x8c8-minmax-avx2.c
// clang-format off
// Auto-generated file. Do not edit!
//   Template: src/qs8-gemm/MRx8c8-avx2.c.in
//   Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <immintrin.h>

#include "src/xnnpack/common.h"
#include "src/xnnpack/gemm.h"
#include "src/xnnpack/intrinsics-polyfill.h"
#include "src/xnnpack/math.h"
#include "src/xnnpack/microparams.h"


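// 3x8 GEMM microkernel: multiplies 3 rows of dynamically quantized int8
// activations (qd8) by blockwise 4-bit weights (qb4w) and writes f16 output,
// using AVX2 with an 8-column, 8-deep (c8) packed weight layout.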
void xnn_qd8_f16_qb4w_gemm_minmax_ukernel_3x8c8__avx2(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    xnn_float16* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const struct xnn_f16_qb4w_minmax_params* restrict params,
    const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= 3);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

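  // Round kc up to the 8-byte granularity of the packed weights; the block
  // size bl must be a nonzero multiple of 32 that does not exceed kc
  // (rounded up to 16).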
  kc = round_up_po2(kc, 8 * sizeof(int8_t));
  size_t bl = params->scalar.blocksize;
  assert(bl <= round_up_po2(kc, 16));
  assert(bl != 0);
  assert(bl % 32 == 0);
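
  // Set up per-row input and output pointers; when mr < 3 the unused rows
  // alias the previous row so loads and stores stay in bounds.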
  const int8_t* a0 = a;
  uint16_t* c0 = (uint16_t*) c;
  const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
  uint16_t* c1 = (uint16_t*) ((uintptr_t) c0 + cm_stride);
  if XNN_UNPREDICTABLE(mr < 2) {
    a1 = a0;
    c1 = c0;
  }
  const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride);
  uint16_t* c2 = (uint16_t*) ((uintptr_t) c1 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 2) {
    a2 = a1;
    c2 = c1;
  }

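  // vmask keeps the upper 4 bits of each weight byte; vmin/vmax are the f16
  // output clamping bounds, widened to f32 for the AVX2 min/max below.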
  const __m128i vmask = _mm_set1_epi8(0xF0);
  XNN_FORCE_REALIZATION(vmask);
  const __m256 vmin = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min));
  const __m256 vmax = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max));
  XNN_FORCE_REALIZATION(vmin);
  XNN_FORCE_REALIZATION(vmax);

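  // Outer loop over groups of 8 output columns. Each group starts from 8
  // packed per-column float terms (loaded from w below) scaled by the row's
  // activation zero point; these seed the float accumulators that the
  // per-block partial sums are later folded into.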
  do {
    const __m128 vinit0 = _mm_load_ss(&((const float*) w)[0]);
    const __m128 vinit1 = _mm_load_ss(&((const float*) w)[1]);
    const __m256 vinit01 = _mm256_insertf128_ps(_mm256_castps128_ps256(vinit0), vinit1, 1);
    const __m128 vinit2 = _mm_load_ss(&((const float*) w)[2]);
    const __m128 vinit3 = _mm_load_ss(&((const float*) w)[3]);
    const __m256 vinit23 = _mm256_insertf128_ps(_mm256_castps128_ps256(vinit2), vinit3, 1);
    const __m128 vinit4 = _mm_load_ss(&((const float*) w)[4]);
    const __m128 vinit5 = _mm_load_ss(&((const float*) w)[5]);
    const __m256 vinit45 = _mm256_insertf128_ps(_mm256_castps128_ps256(vinit4), vinit5, 1);
    const __m128 vinit6 = _mm_load_ss(&((const float*) w)[6]);
    const __m128 vinit7 = _mm_load_ss(&((const float*) w)[7]);
    const __m256 vinit67 = _mm256_insertf128_ps(_mm256_castps128_ps256(vinit6), vinit7, 1);
    const __m256 vinput_zero_point0 = _mm256_set1_ps((float) quantization_params[0].zero_point);
    __m256 vout0x01 = _mm256_mul_ps(vinit01, vinput_zero_point0);
    __m256 vout0x23 = _mm256_mul_ps(vinit23, vinput_zero_point0);
    __m256 vout0x45 = _mm256_mul_ps(vinit45, vinput_zero_point0);
    __m256 vout0x67 = _mm256_mul_ps(vinit67, vinput_zero_point0);
    const __m256 vinput_zero_point1 = _mm256_set1_ps((float) quantization_params[1].zero_point);
    __m256 vout1x01 = _mm256_mul_ps(vinit01, vinput_zero_point1);
    __m256 vout1x23 = _mm256_mul_ps(vinit23, vinput_zero_point1);
    __m256 vout1x45 = _mm256_mul_ps(vinit45, vinput_zero_point1);
    __m256 vout1x67 = _mm256_mul_ps(vinit67, vinput_zero_point1);
    const __m256 vinput_zero_point2 = _mm256_set1_ps((float) quantization_params[2].zero_point);
    __m256 vout2x01 = _mm256_mul_ps(vinit01, vinput_zero_point2);
    __m256 vout2x23 = _mm256_mul_ps(vinit23, vinput_zero_point2);
    __m256 vout2x45 = _mm256_mul_ps(vinit45, vinput_zero_point2);
    __m256 vout2x67 = _mm256_mul_ps(vinit67, vinput_zero_point2);
    w = (const int32_t*) w + 8;

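    // Loop over the K dimension one quantization block (bl elements) at a
    // time, accumulating each block's products in fresh int32 accumulators.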
    for (size_t kb=0; kb < kc; kb += bl) {
      __m256i vacc0x01 = _mm256_setzero_si256();
      __m256i vacc0x23 = _mm256_setzero_si256();
      __m256i vacc0x45 = _mm256_setzero_si256();
      __m256i vacc0x67 = _mm256_setzero_si256();
      __m256i vacc1x01 = _mm256_setzero_si256();
      __m256i vacc1x23 = _mm256_setzero_si256();
      __m256i vacc1x45 = _mm256_setzero_si256();
      __m256i vacc1x67 = _mm256_setzero_si256();
      __m256i vacc2x01 = _mm256_setzero_si256();
      __m256i vacc2x23 = _mm256_setzero_si256();
      __m256i vacc2x45 = _mm256_setzero_si256();
      __m256i vacc2x67 = _mm256_setzero_si256();

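      // Main loop: 16 K values per iteration. Each packed weight byte holds
      // two 4-bit values; the first half of the iteration takes the low
      // nibbles (shift left by 4, then mask), the second half reuses the same
      // bytes and takes the high nibbles (mask only). Both land in the top 4
      // bits of the byte, so every weight is effectively scaled by 16; that
      // factor is presumably folded into the packed per-block scales.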
      size_t k = bl;
      while (k >= 16 * sizeof(int8_t)) {
        __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
        __m256i vxa0 = _mm256_cvtepi8_epi16(va0);
        a0 += 8;
        __m128i va1 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a1));
        __m256i vxa1 = _mm256_cvtepi8_epi16(va1);
        a1 += 8;
        __m128i va2 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a2));
        __m256i vxa2 = _mm256_cvtepi8_epi16(va2);
        a2 += 8;

        __m128i vb01 = _mm_load_si128((const __m128i*) w);
        __m128i vbs01 = _mm_slli_epi32(vb01, 4);
        __m128i vbm01 = _mm_and_si128(vbs01, vmask);
        __m256i vxb01 = _mm256_cvtepi8_epi16(vbm01);

        vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
        vacc1x01 = _mm256_add_epi32(vacc1x01, _mm256_madd_epi16(vxa1, vxb01));
        vacc2x01 = _mm256_add_epi32(vacc2x01, _mm256_madd_epi16(vxa2, vxb01));
        __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
        __m128i vbs23 = _mm_slli_epi32(vb23, 4);
        __m128i vbm23 = _mm_and_si128(vbs23, vmask);
        __m256i vxb23 = _mm256_cvtepi8_epi16(vbm23);

        vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
        vacc1x23 = _mm256_add_epi32(vacc1x23, _mm256_madd_epi16(vxa1, vxb23));
        vacc2x23 = _mm256_add_epi32(vacc2x23, _mm256_madd_epi16(vxa2, vxb23));
        __m128i vb45 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 32));
        __m128i vbs45 = _mm_slli_epi32(vb45, 4);
        __m128i vbm45 = _mm_and_si128(vbs45, vmask);
        __m256i vxb45 = _mm256_cvtepi8_epi16(vbm45);

        vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
        vacc1x45 = _mm256_add_epi32(vacc1x45, _mm256_madd_epi16(vxa1, vxb45));
        vacc2x45 = _mm256_add_epi32(vacc2x45, _mm256_madd_epi16(vxa2, vxb45));
        __m128i vb67 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 48));
        __m128i vbs67 = _mm_slli_epi32(vb67, 4);
        __m128i vbm67 = _mm_and_si128(vbs67, vmask);
        __m256i vxb67 = _mm256_cvtepi8_epi16(vbm67);

        vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));
        vacc1x67 = _mm256_add_epi32(vacc1x67, _mm256_madd_epi16(vxa1, vxb67));
        vacc2x67 = _mm256_add_epi32(vacc2x67, _mm256_madd_epi16(vxa2, vxb67));

        va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
        vxa0 = _mm256_cvtepi8_epi16(va0);
        a0 += 8;
        va1 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a1));
        vxa1 = _mm256_cvtepi8_epi16(va1);
        a1 += 8;
        va2 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a2));
        vxa2 = _mm256_cvtepi8_epi16(va2);
        a2 += 8;

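        // Second half of the iteration: the same weight bytes, masked without
        // the shift, yield the high nibbles for the next 8 K values.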
        vbm01 = _mm_and_si128(vb01, vmask);
        vxb01 = _mm256_cvtepi8_epi16(vbm01);

        vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
        vacc1x01 = _mm256_add_epi32(vacc1x01, _mm256_madd_epi16(vxa1, vxb01));
        vacc2x01 = _mm256_add_epi32(vacc2x01, _mm256_madd_epi16(vxa2, vxb01));
        vbm23 = _mm_and_si128(vb23, vmask);
        vxb23 = _mm256_cvtepi8_epi16(vbm23);

        vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
        vacc1x23 = _mm256_add_epi32(vacc1x23, _mm256_madd_epi16(vxa1, vxb23));
        vacc2x23 = _mm256_add_epi32(vacc2x23, _mm256_madd_epi16(vxa2, vxb23));
        vbm45 = _mm_and_si128(vb45, vmask);
        vxb45 = _mm256_cvtepi8_epi16(vbm45);

        vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
        vacc1x45 = _mm256_add_epi32(vacc1x45, _mm256_madd_epi16(vxa1, vxb45));
        vacc2x45 = _mm256_add_epi32(vacc2x45, _mm256_madd_epi16(vxa2, vxb45));
        vbm67 = _mm_and_si128(vb67, vmask);
        vxb67 = _mm256_cvtepi8_epi16(vbm67);

        vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));
        vacc1x67 = _mm256_add_epi32(vacc1x67, _mm256_madd_epi16(vxa1, vxb67));
        vacc2x67 = _mm256_add_epi32(vacc2x67, _mm256_madd_epi16(vxa2, vxb67));

        w = (const int8_t*) w + 64;
        k -= 16 * sizeof(int8_t);
      }

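      // Remainder loop: 8 K values at a time, using only the low nibbles of
      // each weight byte. With bl asserted to be a multiple of 32, the main
      // loop normally consumes the whole block and this path is not taken.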
      while (k >= 8 * sizeof(int8_t)) {
        const __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
        const __m256i vxa0 = _mm256_cvtepi8_epi16(va0);
        a0 += 8;
        const __m128i va1 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a1));
        const __m256i vxa1 = _mm256_cvtepi8_epi16(va1);
        a1 += 8;
        const __m128i va2 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a2));
        const __m256i vxa2 = _mm256_cvtepi8_epi16(va2);
        a2 += 8;

        const __m128i vb01 = _mm_load_si128((const __m128i*) w);
        const __m128i vbs01 = _mm_slli_epi32(vb01, 4);
        const __m128i vbm01 = _mm_and_si128(vbs01, vmask);
        const __m256i vxb01 = _mm256_cvtepi8_epi16(vbm01);

        vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
        vacc1x01 = _mm256_add_epi32(vacc1x01, _mm256_madd_epi16(vxa1, vxb01));
        vacc2x01 = _mm256_add_epi32(vacc2x01, _mm256_madd_epi16(vxa2, vxb01));
        const __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
        const __m128i vbs23 = _mm_slli_epi32(vb23, 4);
        const __m128i vbm23 = _mm_and_si128(vbs23, vmask);
        const __m256i vxb23 = _mm256_cvtepi8_epi16(vbm23);

        vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
        vacc1x23 = _mm256_add_epi32(vacc1x23, _mm256_madd_epi16(vxa1, vxb23));
        vacc2x23 = _mm256_add_epi32(vacc2x23, _mm256_madd_epi16(vxa2, vxb23));
        const __m128i vb45 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 32));
        const __m128i vbs45 = _mm_slli_epi32(vb45, 4);
        const __m128i vbm45 = _mm_and_si128(vbs45, vmask);
        const __m256i vxb45 = _mm256_cvtepi8_epi16(vbm45);

        vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
        vacc1x45 = _mm256_add_epi32(vacc1x45, _mm256_madd_epi16(vxa1, vxb45));
        vacc2x45 = _mm256_add_epi32(vacc2x45, _mm256_madd_epi16(vxa2, vxb45));
        const __m128i vb67 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 48));
        const __m128i vbs67 = _mm_slli_epi32(vb67, 4);
        const __m128i vbm67 = _mm_and_si128(vbs67, vmask);
        const __m256i vxb67 = _mm256_cvtepi8_epi16(vbm67);

        vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));
        vacc1x67 = _mm256_add_epi32(vacc1x67, _mm256_madd_epi16(vxa1, vxb67));
        vacc2x67 = _mm256_add_epi32(vacc2x67, _mm256_madd_epi16(vxa2, vxb67));

        w = (const int8_t*) w + 64;
        k -= 8 * sizeof(int8_t);
      }

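      // Fold this block's int32 partial sums into the float accumulators.
      // The 8 per-block, per-column scales are stored as bf16 (the upper 16
      // bits of an f32): shifting them left by 16 and bit-casting reconstructs
      // the float value.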
      const __m128 vfilter_output_scale0 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[0] << 16));
      const __m128 vfilter_output_scale1 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[1] << 16));
      const __m256 vfilter_output_scale01 = _mm256_insertf128_ps(
          _mm256_castps128_ps256(vfilter_output_scale0), vfilter_output_scale1, 1);
      vout0x01 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc0x01), vfilter_output_scale01, vout0x01);
      vout1x01 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc1x01), vfilter_output_scale01, vout1x01);
      vout2x01 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc2x01), vfilter_output_scale01, vout2x01);
      const __m128 vfilter_output_scale2 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[2] << 16));
      const __m128 vfilter_output_scale3 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[3] << 16));
      const __m256 vfilter_output_scale23 = _mm256_insertf128_ps(
          _mm256_castps128_ps256(vfilter_output_scale2), vfilter_output_scale3, 1);
      vout0x23 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc0x23), vfilter_output_scale23, vout0x23);
      vout1x23 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc1x23), vfilter_output_scale23, vout1x23);
      vout2x23 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc2x23), vfilter_output_scale23, vout2x23);
      const __m128 vfilter_output_scale4 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[4] << 16));
      const __m128 vfilter_output_scale5 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[5] << 16));
      const __m256 vfilter_output_scale45 = _mm256_insertf128_ps(
          _mm256_castps128_ps256(vfilter_output_scale4), vfilter_output_scale5, 1);
      vout0x45 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc0x45), vfilter_output_scale45, vout0x45);
      vout1x45 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc1x45), vfilter_output_scale45, vout1x45);
      vout2x45 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc2x45), vfilter_output_scale45, vout2x45);
      const __m128 vfilter_output_scale6 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[6] << 16));
      const __m128 vfilter_output_scale7 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[7] << 16));
      const __m256 vfilter_output_scale67 = _mm256_insertf128_ps(
          _mm256_castps128_ps256(vfilter_output_scale6), vfilter_output_scale7, 1);
      vout0x67 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc0x67), vfilter_output_scale67, vout0x67);
      vout1x67 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc1x67), vfilter_output_scale67, vout1x67);
      vout2x67 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc2x67), vfilter_output_scale67, vout2x67);

      w = (const uint16_t*) w + 8;
    }

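    // Horizontal reduction: each vout?xNN vector holds partial sums for a
    // pair of columns; two rounds of hadd plus a cross-lane permute collapse
    // them into one vector per row with columns in order 0..7.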
    const __m256 vout0x0213 = _mm256_hadd_ps(vout0x01, vout0x23);
    const __m256 vout0x4657 = _mm256_hadd_ps(vout0x45, vout0x67);
    const __m256 vout1x0213 = _mm256_hadd_ps(vout1x01, vout1x23);
    const __m256 vout1x4657 = _mm256_hadd_ps(vout1x45, vout1x67);
    const __m256 vout2x0213 = _mm256_hadd_ps(vout2x01, vout2x23);
    const __m256 vout2x4657 = _mm256_hadd_ps(vout2x45, vout2x67);

    const __m256 vout0x02461357 = _mm256_hadd_ps(vout0x0213, vout0x4657);
    const __m256 vout1x02461357 = _mm256_hadd_ps(vout1x0213, vout1x4657);
    const __m256 vout2x02461357 = _mm256_hadd_ps(vout2x0213, vout2x4657);

    const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
    __m256 vout0x01234567 = _mm256_permutevar8x32_ps(vout0x02461357, vpermute_mask);
    __m256 vout1x01234567 = _mm256_permutevar8x32_ps(vout1x02461357, vpermute_mask);
    __m256 vout2x01234567 = _mm256_permutevar8x32_ps(vout2x02461357, vpermute_mask);

    const __m256 vinput_scale0 = _mm256_broadcast_ss(&quantization_params[0].inv_scale);
    const __m256 vinput_scale1 = _mm256_broadcast_ss(&quantization_params[1].inv_scale);
    const __m256 vinput_scale2 = _mm256_broadcast_ss(&quantization_params[2].inv_scale);

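    // Apply the per-row activation dequantization scale and add the packed
    // per-column bias: vout = vout * input_scale + bias.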
    const __m256 vbias01234567 = _mm256_loadu_ps((const float*) w);
    w = (const float*) w + 8;
    vout0x01234567 = _mm256_fmadd_ps(vout0x01234567, vinput_scale0, vbias01234567);
    vout1x01234567 = _mm256_fmadd_ps(vout1x01234567, vinput_scale1, vbias01234567);
    vout2x01234567 = _mm256_fmadd_ps(vout2x01234567, vinput_scale2, vbias01234567);

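    // Clamp to [min, max] and round the results to f16.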
    vout0x01234567 = _mm256_max_ps(vout0x01234567, vmin);
    vout1x01234567 = _mm256_max_ps(vout1x01234567, vmin);
    vout2x01234567 = _mm256_max_ps(vout2x01234567, vmin);

    vout0x01234567 = _mm256_min_ps(vout0x01234567, vmax);
    vout1x01234567 = _mm256_min_ps(vout1x01234567, vmax);
    vout2x01234567 = _mm256_min_ps(vout2x01234567, vmax);
    __m128i vfp16out0x01234567 = _mm256_cvtps_ph(vout0x01234567, _MM_FROUND_TO_NEAREST_INT);
    __m128i vfp16out1x01234567 = _mm256_cvtps_ph(vout1x01234567, _MM_FROUND_TO_NEAREST_INT);
    __m128i vfp16out2x01234567 = _mm256_cvtps_ph(vout2x01234567, _MM_FROUND_TO_NEAREST_INT);
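    // Full tile: store all 8 f16 columns per row, advance the output pointers
    // by cn_stride, and rewind the activation pointers for the next column
    // group. Otherwise fall through to the partial-tile store below.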
    if XNN_LIKELY(nc >= 8) {
      _mm_storeu_si128((__m128i*) c0, vfp16out0x01234567);
      c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);
      _mm_storeu_si128((__m128i*) c1, vfp16out1x01234567);
      c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride);
      _mm_storeu_si128((__m128i*) c2, vfp16out2x01234567);
      c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride);

      a0 = (const int8_t*) ((uintptr_t) a0 - kc);
      a1 = (const int8_t*) ((uintptr_t) a1 - kc);
      a2 = (const int8_t*) ((uintptr_t) a2 - kc);

      nc -= 8;
    } else {
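      // Partial tile: store 4, 2, then 1 remaining f16 values, shifting the
      // already-stored lanes out of the register at each step.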
      if (nc & 4) {
        _mm_storel_epi64((__m128i*) c0, vfp16out0x01234567);
        _mm_storel_epi64((__m128i*) c1, vfp16out1x01234567);
        _mm_storel_epi64((__m128i*) c2, vfp16out2x01234567);

        vfp16out0x01234567 = _mm_unpackhi_epi64(vfp16out0x01234567, vfp16out0x01234567);
        vfp16out1x01234567 = _mm_unpackhi_epi64(vfp16out1x01234567, vfp16out1x01234567);
        vfp16out2x01234567 = _mm_unpackhi_epi64(vfp16out2x01234567, vfp16out2x01234567);

        c0 += 4;
        c1 += 4;
        c2 += 4;
      }
      if (nc & 2) {
        _mm_storeu_si32(c0, vfp16out0x01234567);
        _mm_storeu_si32(c1, vfp16out1x01234567);
        _mm_storeu_si32(c2, vfp16out2x01234567);

        vfp16out0x01234567 = _mm_srli_epi64(vfp16out0x01234567, 32);
        vfp16out1x01234567 = _mm_srli_epi64(vfp16out1x01234567, 32);
        vfp16out2x01234567 = _mm_srli_epi64(vfp16out2x01234567, 32);

        c0 += 2;
        c1 += 2;
        c2 += 2;
      }
      if (nc & 1) {
        *c0 = (uint16_t) _mm_extract_epi16(vfp16out0x01234567, 0);
        *c1 = (uint16_t) _mm_extract_epi16(vfp16out1x01234567, 0);
        *c2 = (uint16_t) _mm_extract_epi16(vfp16out2x01234567, 0);
      }
      nc = 0;
    }
  } while (nc != 0);
}