/src/xnnpack/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x8c8-minmax-avx2.c
Line | Count | Source |
1 | | // clang-format off |
2 | | // Auto-generated file. Do not edit! |
3 | | // Template: src/qs8-gemm/MRx8c8-avx2.c.in |
4 | | // Generator: tools/xngen |
5 | | // |
6 | | // Copyright 2020 Google LLC |
7 | | // |
8 | | // This source code is licensed under the BSD-style license found in the |
9 | | // LICENSE file in the root directory of this source tree. |
10 | | |
11 | | #include <assert.h> |
12 | | #include <stddef.h> |
13 | | #include <stdint.h> |
14 | | |
15 | | #include <immintrin.h> |
16 | | |
17 | | #include "src/xnnpack/common.h" |
18 | | #include "src/xnnpack/gemm.h" |
19 | | #include "src/xnnpack/intrinsics-polyfill.h" |
20 | | #include "src/xnnpack/math.h" |
21 | | #include "src/xnnpack/microparams.h" |
22 | | |
23 | | |
// GEMM microkernel computing a 1x8 output tile (mr=1 row, 8 columns) of
// C = A * B for dynamically-quantized int8 activations (qd8) against
// blockwise 4-bit-packed weights (qb4w), producing f32 outputs, using AVX2.
//
// Arguments:
//   mr        - number of rows to process; must be 1 for this variant.
//   nc        - number of output columns remaining (any positive value).
//   kc        - reduction (K) dimension in bytes of int8; rounded up to 8 here.
//   a         - pointer to the quantized int8 activation row.
//   a_stride  - row stride of A; unused in this mr=1 variant (never read).
//   w         - packed weights: per 8-column group, the visible pointer
//               advances show the layout is 8 int32-sized floats, then per
//               k-block [64 bytes of packed nibbles per 16 k-elements followed
//               by 8 uint16 (bf16) scales], and finally 8 f32 bias values.
//   c         - output pointer (f32).
//   cm_stride - row stride of C; unused in this mr=1 variant (never read).
//   cn_stride - byte stride to the next 8-column group of C.
//   params    - min/max clamp bounds and the quantization blocksize.
//   quantization_params - per-row zero point and inverse scale of A.
//
// XNN_OOB_READS: the kernel may read past the end of buffers (e.g. the
// remainder path loads full vectors), which the caller must tolerate.
void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x8c8__avx2(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    float* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const struct xnn_f32_qb4w_minmax_params* restrict params,
    const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= 1);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  // K is processed 8 bytes at a time per column pair; pad it up to a
  // multiple of 8 so the inner loops need no sub-8 remainder handling.
  kc = round_up_po2(kc, 8 * sizeof(int8_t));
  // bl is the quantization blocksize: each block of bl k-elements shares one
  // per-column weight scale. Must be a nonzero multiple of 32 and not exceed
  // the (rounded) K extent.
  size_t bl = params->scalar.blocksize;
  assert(bl <= round_up_po2(kc, 16));
  assert(bl != 0);
  assert(bl % 32 == 0);
  const int8_t* a0 = a;
  float* c0 = c;

  // Mask keeping the high nibble of each byte; used to isolate 4-bit weight
  // values in the upper-nibble position so sign extension works directly.
  const __m128i vmask = _mm_set1_epi8(0xF0);
  XNN_FORCE_REALIZATION(vmask);
  const __m256 vmin = _mm256_set1_ps(params->scalar.min);
  const __m256 vmax = _mm256_set1_ps(params->scalar.max);
  XNN_FORCE_REALIZATION(vmin);
  XNN_FORCE_REALIZATION(vmax);

  do {
    // Load 8 per-column init values (one f32 each) from the weight stream and
    // scale them by the activation zero point. Presumably these encode the
    // zero-point correction term (column sums) — NOTE(review): layout intent
    // comes from the packing code, not visible here.
    const __m128 vinit0 = _mm_load_ss(&((const float*) w)[0]);
    const __m128 vinit1 = _mm_load_ss(&((const float*) w)[1]);
    const __m256 vinit01 = _mm256_insertf128_ps(_mm256_castps128_ps256(vinit0), vinit1, 1);
    const __m128 vinit2 = _mm_load_ss(&((const float*) w)[2]);
    const __m128 vinit3 = _mm_load_ss(&((const float*) w)[3]);
    const __m256 vinit23 = _mm256_insertf128_ps(_mm256_castps128_ps256(vinit2), vinit3, 1);
    const __m128 vinit4 = _mm_load_ss(&((const float*) w)[4]);
    const __m128 vinit5 = _mm_load_ss(&((const float*) w)[5]);
    const __m256 vinit45 = _mm256_insertf128_ps(_mm256_castps128_ps256(vinit4), vinit5, 1);
    const __m128 vinit6 = _mm_load_ss(&((const float*) w)[6]);
    const __m128 vinit7 = _mm_load_ss(&((const float*) w)[7]);
    const __m256 vinit67 = _mm256_insertf128_ps(_mm256_castps128_ps256(vinit6), vinit7, 1);
    const __m256 vinput_zero_point0 = _mm256_set1_ps((float) quantization_params[0].zero_point);
    // Running f32 outputs for column pairs (0,1), (2,3), (4,5), (6,7);
    // each __m256 holds one column's partial sums per 128-bit lane.
    __m256 vout0x01 = _mm256_mul_ps(vinit01, vinput_zero_point0);
    __m256 vout0x23 = _mm256_mul_ps(vinit23, vinput_zero_point0);
    __m256 vout0x45 = _mm256_mul_ps(vinit45, vinput_zero_point0);
    __m256 vout0x67 = _mm256_mul_ps(vinit67, vinput_zero_point0);
    w = (const int32_t*) w + 8;

    // Outer loop over quantization blocks: accumulate each block's integer
    // dot products, then fold them into the f32 outputs with that block's
    // per-column scale.
    for (size_t kb=0; kb < kc; kb += bl) {
      // Fresh int32 accumulators per block (scale applies per block).
      __m256i vacc0x01 = _mm256_setzero_si256();
      __m256i vacc0x23 = _mm256_setzero_si256();
      __m256i vacc0x45 = _mm256_setzero_si256();
      __m256i vacc0x67 = _mm256_setzero_si256();

      size_t k = bl;
      // Main loop: 16 k-elements per iteration. Each 16-byte weight load
      // holds two columns x 16 k-values as packed nibbles; the low nibbles
      // serve the first 8 k-elements and the high nibbles the next 8, so the
      // same four loads (vb01..vb67) are consumed twice.
      while (k >= 16 * sizeof(int8_t)) {
        // Broadcast 8 activation bytes to both 64-bit halves, then
        // sign-extend to 16-bit so madd pairs them against both columns.
        __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
        __m256i vxa0 = _mm256_cvtepi8_epi16(va0);
        a0 += 8;

        // Low nibbles: shift into the high-nibble position and mask. This
        // leaves each 4-bit value multiplied by 16 with its sign bit at the
        // byte's top, so cvtepi8_epi16 sign-extends it correctly. The
        // constant factor of 16 is presumably folded into the stored scales
        // — NOTE(review): compensation happens at pack time, not here.
        __m128i vb01 = _mm_load_si128((const __m128i*) w);
        __m128i vbs01 = _mm_slli_epi32(vb01, 4);
        __m128i vbm01 = _mm_and_si128(vbs01, vmask);
        __m256i vxb01 = _mm256_cvtepi8_epi16(vbm01);

        // madd multiplies int16 pairs and adds adjacent products into int32.
        vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
        __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
        __m128i vbs23 = _mm_slli_epi32(vb23, 4);
        __m128i vbm23 = _mm_and_si128(vbs23, vmask);
        __m256i vxb23 = _mm256_cvtepi8_epi16(vbm23);

        vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
        __m128i vb45 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 32));
        __m128i vbs45 = _mm_slli_epi32(vb45, 4);
        __m128i vbm45 = _mm_and_si128(vbs45, vmask);
        __m256i vxb45 = _mm256_cvtepi8_epi16(vbm45);

        vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
        __m128i vb67 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 48));
        __m128i vbs67 = _mm_slli_epi32(vb67, 4);
        __m128i vbm67 = _mm_and_si128(vbs67, vmask);
        __m256i vxb67 = _mm256_cvtepi8_epi16(vbm67);

        vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));

        // Second 8 k-elements: fresh activations, same weight registers.
        va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
        vxa0 = _mm256_cvtepi8_epi16(va0);
        a0 += 8;

        // High nibbles: already in position, just mask (same x16 scaling
        // as the low-nibble path, keeping both halves consistent).
        vbm01 = _mm_and_si128(vb01, vmask);
        vxb01 = _mm256_cvtepi8_epi16(vbm01);

        vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
        vbm23 = _mm_and_si128(vb23, vmask);
        vxb23 = _mm256_cvtepi8_epi16(vbm23);

        vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
        vbm45 = _mm_and_si128(vb45, vmask);
        vxb45 = _mm256_cvtepi8_epi16(vbm45);

        vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
        vbm67 = _mm_and_si128(vb67, vmask);
        vxb67 = _mm256_cvtepi8_epi16(vbm67);

        vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));

        // 64 bytes of packed weights consumed (8 columns x 16 k x 4 bits).
        w = (const int8_t*) w + 64;
        k -= 16 * sizeof(int8_t);
      }

      // Remainder loop: 8 k-elements, consuming only the low nibbles of a
      // 64-byte weight group before advancing past the whole group.
      while (k >= 8 * sizeof(int8_t)) {
        const __m128i va0 = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a0));
        const __m256i vxa0 = _mm256_cvtepi8_epi16(va0);
        a0 += 8;

        const __m128i vb01 = _mm_load_si128((const __m128i*) w);
        const __m128i vbs01 = _mm_slli_epi32(vb01, 4);
        const __m128i vbm01 = _mm_and_si128(vbs01, vmask);
        const __m256i vxb01 = _mm256_cvtepi8_epi16(vbm01);

        vacc0x01 = _mm256_add_epi32(vacc0x01, _mm256_madd_epi16(vxa0, vxb01));
        const __m128i vb23 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 16));
        const __m128i vbs23 = _mm_slli_epi32(vb23, 4);
        const __m128i vbm23 = _mm_and_si128(vbs23, vmask);
        const __m256i vxb23 = _mm256_cvtepi8_epi16(vbm23);

        vacc0x23 = _mm256_add_epi32(vacc0x23, _mm256_madd_epi16(vxa0, vxb23));
        const __m128i vb45 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 32));
        const __m128i vbs45 = _mm_slli_epi32(vb45, 4);
        const __m128i vbm45 = _mm_and_si128(vbs45, vmask);
        const __m256i vxb45 = _mm256_cvtepi8_epi16(vbm45);

        vacc0x45 = _mm256_add_epi32(vacc0x45, _mm256_madd_epi16(vxa0, vxb45));
        const __m128i vb67 = _mm_load_si128((const __m128i*) ((const int8_t*) w + 48));
        const __m128i vbs67 = _mm_slli_epi32(vb67, 4);
        const __m128i vbm67 = _mm_and_si128(vbs67, vmask);
        const __m256i vxb67 = _mm256_cvtepi8_epi16(vbm67);

        vacc0x67 = _mm256_add_epi32(vacc0x67, _mm256_madd_epi16(vxa0, vxb67));

        w = (const int8_t*) w + 64;
        k -= 8 * sizeof(int8_t);
      }

      // Per-block, per-column scales stored as bf16: placing the uint16 in
      // the high half of a 32-bit word and bit-casting yields the f32 value.
      // Fold this block's int32 accumulators into the f32 outputs.
      const __m128 vfilter_output_scale0 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[0] << 16));
      const __m128 vfilter_output_scale1 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[1] << 16));
      const __m256 vfilter_output_scale01 = _mm256_insertf128_ps(
          _mm256_castps128_ps256(vfilter_output_scale0), vfilter_output_scale1, 1);
      vout0x01 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc0x01), vfilter_output_scale01, vout0x01);
      const __m128 vfilter_output_scale2 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[2] << 16));
      const __m128 vfilter_output_scale3 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[3] << 16));
      const __m256 vfilter_output_scale23 = _mm256_insertf128_ps(
          _mm256_castps128_ps256(vfilter_output_scale2), vfilter_output_scale3, 1);
      vout0x23 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc0x23), vfilter_output_scale23, vout0x23);
      const __m128 vfilter_output_scale4 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[4] << 16));
      const __m128 vfilter_output_scale5 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[5] << 16));
      const __m256 vfilter_output_scale45 = _mm256_insertf128_ps(
          _mm256_castps128_ps256(vfilter_output_scale4), vfilter_output_scale5, 1);
      vout0x45 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc0x45), vfilter_output_scale45, vout0x45);
      const __m128 vfilter_output_scale6 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[6] << 16));
      const __m128 vfilter_output_scale7 = _mm_castsi128_ps(_mm_set1_epi32((uint32_t) ((const uint16_t*) w)[7] << 16));
      const __m256 vfilter_output_scale67 = _mm256_insertf128_ps(
          _mm256_castps128_ps256(vfilter_output_scale6), vfilter_output_scale7, 1);
      vout0x67 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(vacc0x67), vfilter_output_scale67, vout0x67);

      // Skip past the 8 bf16 scales of this block.
      w = (const uint16_t*) w + 8;
    }

    // Horizontal reduction: hadd operates within each 128-bit lane, so the
    // result order is interleaved (0,2,4,6 in lane 0 crossed with 1,3,5,7),
    // hence the variable names and the final cross-lane permute below.
    const __m256 vout0x0213 = _mm256_hadd_ps(vout0x01, vout0x23);
    const __m256 vout0x4657 = _mm256_hadd_ps(vout0x45, vout0x67);

    const __m256 vout0x02461357 = _mm256_hadd_ps(vout0x0213, vout0x4657);

    // Permute lanes 0,4,1,5,2,6,3,7 -> columns 0..7 in natural order.
    const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
    __m256 vout0x01234567 = _mm256_permutevar8x32_ps(vout0x02461357, vpermute_mask);

    // Apply the activation row's inverse quantization scale and add the
    // per-column f32 bias stored after all weight blocks.
    const __m256 vinput_scale0 = _mm256_broadcast_ss(&quantization_params[0].inv_scale);

    const __m256 vbias01234567 = _mm256_loadu_ps((const float*) w);
    w = (const float*) w + 8;
    vout0x01234567 = _mm256_fmadd_ps(vout0x01234567, vinput_scale0, vbias01234567);

    // Clamp to the [min, max] output range.
    vout0x01234567 = _mm256_max_ps(vout0x01234567, vmin);

    vout0x01234567 = _mm256_min_ps(vout0x01234567, vmax);

    if XNN_LIKELY(nc >= 8) {
      // Full 8-column store; rewind A for the next column group.
      _mm256_storeu_ps(c0, vout0x01234567);
      c0 = (float*) ((uintptr_t) c0 + cn_stride);

      a0 = (const int8_t*) ((uintptr_t) a0 - kc);

      nc -= 8;
    } else {
      // Tail: store 4/2/1 remaining columns without writing past nc.
      __m128 vout0x0123 = _mm256_castps256_ps128(vout0x01234567);
      if (nc & 4) {
        _mm_storeu_ps(c0, vout0x0123);

        vout0x0123 = _mm256_extractf128_ps(vout0x01234567, 1);

        c0 += 4;
      }
      if (nc & 2) {
        _mm_storel_pi((__m64*) c0, vout0x0123);

        vout0x0123 = _mm_movehl_ps(vout0x0123, vout0x0123);

        c0 += 2;
      }
      if (nc & 1) {
        _mm_store_ss(c0, vout0x0123);
      }
      nc = 0;
    }
  } while (nc != 0);
}