/src/xnnpack/src/qs8-qc4w-gemm/gen/qs8-qc4w-gemm-7x8c8-minmax-avx2-madd-prfm.c
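Coverage listing for the QS8·QC4W 7x8 (c8) GEMM minmax microkernel targeting AVX2 with the emulated-VNNI "madd" dot product and software prefetch. Every executable line reports a count of 0, i.e. this kernel was not exercised in the covered run.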
Line | Count | Source |
1 | | // clang-format off |
2 | | // Auto-generated file. Do not edit! |
3 | | // Template: src/qs8-gemm/MRx8c8-avxvnni.c.in |
4 | | // Generator: tools/xngen |
5 | | // |
6 | | // Copyright 2024 Google LLC |
7 | | // |
8 | | // This source code is licensed under the BSD-style license found in the |
9 | | // LICENSE file in the root directory of this source tree. |
10 | | |
11 | | #include <assert.h> |
12 | | #include <stddef.h> |
13 | | #include <stdint.h> |
14 | | |
15 | | #include <immintrin.h> |
16 | | |
17 | | #include "src/xnnpack/common.h" |
18 | | #include "src/xnnpack/gemm.h" |
19 | | #include "src/xnnpack/intrinsics-polyfill.h" |
20 | | #include "src/xnnpack/math.h" |
21 | | #include "src/xnnpack/microparams.h" |
22 | | #include "src/xnnpack/prefetch.h" |
23 | | #include "src/xnnpack/unaligned.h" |
24 | | |
25 | | |
26 | | void xnn_qs8_qc4w_gemm_minmax_fp32_ukernel_7x8c8__avx2_madd_prfm( |
27 | | size_t mr, |
28 | | size_t nc, |
29 | | size_t kc, |
30 | | const int8_t* restrict a, |
31 | | size_t a_stride, |
32 | | const void* restrict w, |
33 | | int8_t* restrict c, |
34 | | size_t cm_stride, |
35 | | size_t cn_stride, |
36 | | const union xnn_qs8_qc8w_conv_minmax_params* restrict params) XNN_OOB_READS |
37 | 0 | { |
38 | 0 | assert(mr != 0); |
39 | 0 | assert(mr <= 7); |
40 | 0 | assert(nc != 0); |
41 | 0 | assert(kc != 0); |
42 | 0 | assert(kc % sizeof(int8_t) == 0); |
43 | 0 | assert(a != NULL); |
44 | 0 | assert(w != NULL); |
45 | 0 | assert(c != NULL); |
46 | | |
47 | 0 | kc = round_up_po2(kc, 8 * sizeof(int8_t)); |
48 | 0 | const int8_t* a0 = a; |
49 | 0 | int8_t* c0 = c; |
50 | 0 | const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride); |
51 | 0 | int8_t* c1 = (int8_t*) ((uintptr_t) c0 + cm_stride); |
52 | 0 | if XNN_UNPREDICTABLE(mr < 2) { |
53 | 0 | a1 = a0; |
54 | 0 | c1 = c0; |
55 | 0 | } |
56 | 0 | const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride); |
57 | 0 | int8_t* c2 = (int8_t*) ((uintptr_t) c1 + cm_stride); |
58 | 0 | if XNN_UNPREDICTABLE(mr <= 2) { |
59 | 0 | a2 = a1; |
60 | 0 | c2 = c1; |
61 | 0 | } |
62 | 0 | const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride); |
63 | 0 | int8_t* c3 = (int8_t*) ((uintptr_t) c2 + cm_stride); |
64 | 0 | if XNN_UNPREDICTABLE(mr < 4) { |
65 | 0 | a3 = a2; |
66 | 0 | c3 = c2; |
67 | 0 | } |
68 | 0 | const int8_t* a4 = (const int8_t*) ((uintptr_t) a3 + a_stride); |
69 | 0 | int8_t* c4 = (int8_t*) ((uintptr_t) c3 + cm_stride); |
70 | 0 | if XNN_UNPREDICTABLE(mr <= 4) { |
71 | 0 | a4 = a3; |
72 | 0 | c4 = c3; |
73 | 0 | } |
74 | 0 | const int8_t* a5 = (const int8_t*) ((uintptr_t) a4 + a_stride); |
75 | 0 | int8_t* c5 = (int8_t*) ((uintptr_t) c4 + cm_stride); |
76 | 0 | if XNN_UNPREDICTABLE(mr < 6) { |
77 | 0 | a5 = a4; |
78 | 0 | c5 = c4; |
79 | 0 | } |
80 | 0 | const int8_t* a6 = (const int8_t*) ((uintptr_t) a5 + a_stride); |
81 | 0 | int8_t* c6 = (int8_t*) ((uintptr_t) c5 + cm_stride); |
82 | 0 | if XNN_UNPREDICTABLE(mr <= 6) { |
83 | 0 | a6 = a5; |
84 | 0 | c6 = c5; |
85 | 0 | } |
86 | |
87 | 0 | const __m256i vsign_mask = _mm256_set1_epi8(0x80); |
88 | 0 | XNN_FORCE_REALIZATION(vsign_mask); |
89 | 0 | const __m256 voutput_max_less_zero_point = _mm256_set1_ps((int32_t) params->fp32_scalar.output_max - (int32_t) params->fp32_scalar.output_zero_point); |
90 | 0 | const __m256i voutput_zero_point = _mm256_set1_epi32(params->fp32_scalar.output_zero_point); |
91 | 0 | const __m128i voutput_min = _mm_set1_epi8(params->fp32_scalar.output_min); |
92 | | // XNN_FORCE_REALIZATION(voutput_max_less_zero_point); |
93 | | // XNN_FORCE_REALIZATION(voutput_zero_point); |
94 | | // XNN_FORCE_REALIZATION(voutput_min); |
95 | 0 | const __m256i vmask = _mm256_set1_epi8(0x0F); |
96 | 0 | XNN_FORCE_REALIZATION(vmask); |
97 | 0 | do { |
98 | 0 | __m256i vacc0x0123 = _mm256_cvtepu32_epi64(_mm_load_si128((const __m128i*) w)); |
99 | 0 | __m256i vacc0x4567 = _mm256_cvtepu32_epi64(_mm_load_si128((const __m128i*) ((const int32_t*) w + 4))); |
100 | 0 | __m256i vacc1x0123 = vacc0x0123; |
101 | 0 | __m256i vacc1x4567 = vacc0x4567; |
102 | 0 | __m256i vacc2x0123 = vacc0x0123; |
103 | 0 | __m256i vacc2x4567 = vacc0x4567; |
104 | 0 | __m256i vacc3x0123 = vacc0x0123; |
105 | 0 | __m256i vacc3x4567 = vacc0x4567; |
106 | 0 | __m256i vacc4x0123 = vacc0x0123; |
107 | 0 | __m256i vacc4x4567 = vacc0x4567; |
108 | 0 | __m256i vacc5x0123 = vacc0x0123; |
109 | 0 | __m256i vacc5x4567 = vacc0x4567; |
110 | 0 | __m256i vacc6x0123 = vacc0x0123; |
111 | 0 | __m256i vacc6x4567 = vacc0x4567; |
112 | 0 | w = (const int32_t*) w + 8; |
113 | |
114 | 0 | size_t k = kc; |
115 | 0 | while (k >= 16 * sizeof(int8_t)) { |
116 | 0 | const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); |
117 | 0 | const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); |
118 | 0 | a0 += 16; |
119 | 0 | const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); |
120 | 0 | const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); |
121 | 0 | a1 += 16; |
122 | 0 | const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); |
123 | 0 | const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); |
124 | 0 | a2 += 16; |
125 | 0 | const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); |
126 | 0 | const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); |
127 | 0 | a3 += 16; |
128 | 0 | const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); |
129 | 0 | const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); |
130 | 0 | a4 += 16; |
131 | 0 | const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); |
132 | 0 | const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); |
133 | 0 | a5 += 16; |
134 | 0 | const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); |
135 | 0 | const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); |
136 | 0 | a6 += 16; |
137 | |
138 | 0 | const __m256i vbb01234567x01234567 = _mm256_load_si256(w); |
139 | 0 | const __m256i vbb89ABCDEFx01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); |
140 | 0 | const __m256i vbs01234567x4567 = _mm256_srli_epi32(vbb01234567x01234567, 4); |
141 | 0 | const __m256i vbs89ABCDEFx4567 = _mm256_srli_epi32(vbb89ABCDEFx01234567, 4); |
142 | 0 | const __m256i vb01234567x0123 = _mm256_and_si256(vbb01234567x01234567, vmask); |
143 | 0 | const __m256i vb89ABCDEFx0123 = _mm256_and_si256(vbb89ABCDEFx01234567, vmask); |
144 | 0 | const __m256i vb01234567x4567 = _mm256_and_si256(vbs01234567x4567, vmask); |
145 | 0 | const __m256i vb89ABCDEFx4567 = _mm256_and_si256(vbs89ABCDEFx4567, vmask); |
146 | |
147 | 0 | vacc0x0123 = _mm256_dpbusd_epi32_madd(vacc0x0123, va0x01234567, vb01234567x0123); |
148 | 0 | vacc0x4567 = _mm256_dpbusd_epi32_madd(vacc0x4567, va0x01234567, vb89ABCDEFx0123); |
149 | 0 | vacc1x0123 = _mm256_dpbusd_epi32_madd(vacc1x0123, va1x01234567, vb01234567x0123); |
150 | 0 | vacc1x4567 = _mm256_dpbusd_epi32_madd(vacc1x4567, va1x01234567, vb89ABCDEFx0123); |
151 | 0 | vacc2x0123 = _mm256_dpbusd_epi32_madd(vacc2x0123, va2x01234567, vb01234567x0123); |
152 | 0 | vacc2x4567 = _mm256_dpbusd_epi32_madd(vacc2x4567, va2x01234567, vb89ABCDEFx0123); |
153 | 0 | vacc3x0123 = _mm256_dpbusd_epi32_madd(vacc3x0123, va3x01234567, vb01234567x0123); |
154 | 0 | vacc3x4567 = _mm256_dpbusd_epi32_madd(vacc3x4567, va3x01234567, vb89ABCDEFx0123); |
155 | 0 | vacc4x0123 = _mm256_dpbusd_epi32_madd(vacc4x0123, va4x01234567, vb01234567x0123); |
156 | 0 | vacc4x4567 = _mm256_dpbusd_epi32_madd(vacc4x4567, va4x01234567, vb89ABCDEFx0123); |
157 | 0 | vacc5x0123 = _mm256_dpbusd_epi32_madd(vacc5x0123, va5x01234567, vb01234567x0123); |
158 | 0 | vacc5x4567 = _mm256_dpbusd_epi32_madd(vacc5x4567, va5x01234567, vb89ABCDEFx0123); |
159 | 0 | vacc6x0123 = _mm256_dpbusd_epi32_madd(vacc6x0123, va6x01234567, vb01234567x0123); |
160 | 0 | vacc6x4567 = _mm256_dpbusd_epi32_madd(vacc6x4567, va6x01234567, vb89ABCDEFx0123); |
161 | 0 | xnn_prefetch_to_l1((const int8_t*) w + 960); |
162 | 0 | vacc0x0123 = _mm256_dpbusd_epi32_madd(vacc0x0123, va0x89ABCDEF, vb01234567x4567); |
163 | 0 | vacc0x4567 = _mm256_dpbusd_epi32_madd(vacc0x4567, va0x89ABCDEF, vb89ABCDEFx4567); |
164 | 0 | vacc1x0123 = _mm256_dpbusd_epi32_madd(vacc1x0123, va1x89ABCDEF, vb01234567x4567); |
165 | 0 | vacc1x4567 = _mm256_dpbusd_epi32_madd(vacc1x4567, va1x89ABCDEF, vb89ABCDEFx4567); |
166 | 0 | vacc2x0123 = _mm256_dpbusd_epi32_madd(vacc2x0123, va2x89ABCDEF, vb01234567x4567); |
167 | 0 | vacc2x4567 = _mm256_dpbusd_epi32_madd(vacc2x4567, va2x89ABCDEF, vb89ABCDEFx4567); |
168 | 0 | vacc3x0123 = _mm256_dpbusd_epi32_madd(vacc3x0123, va3x89ABCDEF, vb01234567x4567); |
169 | 0 | vacc3x4567 = _mm256_dpbusd_epi32_madd(vacc3x4567, va3x89ABCDEF, vb89ABCDEFx4567); |
170 | 0 | vacc4x0123 = _mm256_dpbusd_epi32_madd(vacc4x0123, va4x89ABCDEF, vb01234567x4567); |
171 | 0 | vacc4x4567 = _mm256_dpbusd_epi32_madd(vacc4x4567, va4x89ABCDEF, vb89ABCDEFx4567); |
172 | 0 | vacc5x0123 = _mm256_dpbusd_epi32_madd(vacc5x0123, va5x89ABCDEF, vb01234567x4567); |
173 | 0 | vacc5x4567 = _mm256_dpbusd_epi32_madd(vacc5x4567, va5x89ABCDEF, vb89ABCDEFx4567); |
174 | 0 | vacc6x0123 = _mm256_dpbusd_epi32_madd(vacc6x0123, va6x89ABCDEF, vb01234567x4567); |
175 | 0 | vacc6x4567 = _mm256_dpbusd_epi32_madd(vacc6x4567, va6x89ABCDEF, vb89ABCDEFx4567); |
176 | |
177 | 0 | w = (const int8_t*) w + 64; |
178 | 0 | k -= 16 * sizeof(int8_t); |
179 | 0 | } |
180 | |
181 | 0 | if (k != 0) { |
182 | 0 | const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); |
183 | 0 | a0 += 8; |
184 | 0 | const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); |
185 | 0 | a1 += 8; |
186 | 0 | const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); |
187 | 0 | a2 += 8; |
188 | 0 | const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); |
189 | 0 | a3 += 8; |
190 | 0 | const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); |
191 | 0 | a4 += 8; |
192 | 0 | const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); |
193 | 0 | a5 += 8; |
194 | 0 | const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); |
195 | 0 | a6 += 8; |
196 | |
197 | 0 | const __m256i vbb01234567x01234567 = _mm256_load_si256(w); |
198 | 0 | const __m256i vbb89ABCDEFx01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); |
199 | 0 | const __m256i vb01234567x0123 = _mm256_and_si256(vbb01234567x01234567, vmask); |
200 | 0 | const __m256i vb89ABCDEFx0123 = _mm256_and_si256(vbb89ABCDEFx01234567, vmask); |
201 | |
202 | 0 | vacc0x0123 = _mm256_dpbusd_epi32_madd(vacc0x0123, va0x01234567, vb01234567x0123); |
203 | 0 | vacc0x4567 = _mm256_dpbusd_epi32_madd(vacc0x4567, va0x01234567, vb89ABCDEFx0123); |
204 | 0 | vacc1x0123 = _mm256_dpbusd_epi32_madd(vacc1x0123, va1x01234567, vb01234567x0123); |
205 | 0 | vacc1x4567 = _mm256_dpbusd_epi32_madd(vacc1x4567, va1x01234567, vb89ABCDEFx0123); |
206 | 0 | vacc2x0123 = _mm256_dpbusd_epi32_madd(vacc2x0123, va2x01234567, vb01234567x0123); |
207 | 0 | vacc2x4567 = _mm256_dpbusd_epi32_madd(vacc2x4567, va2x01234567, vb89ABCDEFx0123); |
208 | 0 | vacc3x0123 = _mm256_dpbusd_epi32_madd(vacc3x0123, va3x01234567, vb01234567x0123); |
209 | 0 | vacc3x4567 = _mm256_dpbusd_epi32_madd(vacc3x4567, va3x01234567, vb89ABCDEFx0123); |
210 | 0 | vacc4x0123 = _mm256_dpbusd_epi32_madd(vacc4x0123, va4x01234567, vb01234567x0123); |
211 | 0 | vacc4x4567 = _mm256_dpbusd_epi32_madd(vacc4x4567, va4x01234567, vb89ABCDEFx0123); |
212 | 0 | vacc5x0123 = _mm256_dpbusd_epi32_madd(vacc5x0123, va5x01234567, vb01234567x0123); |
213 | 0 | vacc5x4567 = _mm256_dpbusd_epi32_madd(vacc5x4567, va5x01234567, vb89ABCDEFx0123); |
214 | 0 | vacc6x0123 = _mm256_dpbusd_epi32_madd(vacc6x0123, va6x01234567, vb01234567x0123); |
215 | 0 | vacc6x4567 = _mm256_dpbusd_epi32_madd(vacc6x4567, va6x01234567, vb89ABCDEFx0123); |
216 | 0 | xnn_prefetch_to_l1((const int8_t*) w + 960); |
217 | |
218 | 0 | w = (const int8_t*) w + 64; |
219 | 0 | k -= 8 * sizeof(int8_t); |
220 | 0 | } |
221 | | |
222 | | // Add adjacent pairs |
223 | 0 | const __m256i vsum0x02134657 = _mm256_hadd_epi32(vacc0x0123, vacc0x4567); |
224 | 0 | __m256i vacc0x01234567 = _mm256_permute4x64_epi64(vsum0x02134657, _MM_SHUFFLE(3, 1, 2, 0)); |
225 | 0 | const __m256i vsum1x02134657 = _mm256_hadd_epi32(vacc1x0123, vacc1x4567); |
226 | 0 | __m256i vacc1x01234567 = _mm256_permute4x64_epi64(vsum1x02134657, _MM_SHUFFLE(3, 1, 2, 0)); |
227 | 0 | const __m256i vsum2x02134657 = _mm256_hadd_epi32(vacc2x0123, vacc2x4567); |
228 | 0 | __m256i vacc2x01234567 = _mm256_permute4x64_epi64(vsum2x02134657, _MM_SHUFFLE(3, 1, 2, 0)); |
229 | 0 | const __m256i vsum3x02134657 = _mm256_hadd_epi32(vacc3x0123, vacc3x4567); |
230 | 0 | __m256i vacc3x01234567 = _mm256_permute4x64_epi64(vsum3x02134657, _MM_SHUFFLE(3, 1, 2, 0)); |
231 | 0 | const __m256i vsum4x02134657 = _mm256_hadd_epi32(vacc4x0123, vacc4x4567); |
232 | 0 | __m256i vacc4x01234567 = _mm256_permute4x64_epi64(vsum4x02134657, _MM_SHUFFLE(3, 1, 2, 0)); |
233 | 0 | const __m256i vsum5x02134657 = _mm256_hadd_epi32(vacc5x0123, vacc5x4567); |
234 | 0 | __m256i vacc5x01234567 = _mm256_permute4x64_epi64(vsum5x02134657, _MM_SHUFFLE(3, 1, 2, 0)); |
235 | 0 | const __m256i vsum6x02134657 = _mm256_hadd_epi32(vacc6x0123, vacc6x4567); |
236 | 0 | __m256i vacc6x01234567 = _mm256_permute4x64_epi64(vsum6x02134657, _MM_SHUFFLE(3, 1, 2, 0)); |
237 | |
238 | 0 | __m256 vout0x01234567 = _mm256_cvtepi32_ps(vacc0x01234567); |
239 | 0 | __m256 vout1x01234567 = _mm256_cvtepi32_ps(vacc1x01234567); |
240 | 0 | __m256 vout2x01234567 = _mm256_cvtepi32_ps(vacc2x01234567); |
241 | 0 | __m256 vout3x01234567 = _mm256_cvtepi32_ps(vacc3x01234567); |
242 | 0 | __m256 vout4x01234567 = _mm256_cvtepi32_ps(vacc4x01234567); |
243 | 0 | __m256 vout5x01234567 = _mm256_cvtepi32_ps(vacc5x01234567); |
244 | 0 | __m256 vout6x01234567 = _mm256_cvtepi32_ps(vacc6x01234567); |
245 | |
246 | 0 | const __m256 vscale01234567 = _mm256_load_ps(w); |
247 | 0 | w = (const float*) w + 8; |
248 | 0 | vout0x01234567 = _mm256_mul_ps(vout0x01234567, vscale01234567); |
249 | 0 | vout1x01234567 = _mm256_mul_ps(vout1x01234567, vscale01234567); |
250 | 0 | vout2x01234567 = _mm256_mul_ps(vout2x01234567, vscale01234567); |
251 | 0 | vout3x01234567 = _mm256_mul_ps(vout3x01234567, vscale01234567); |
252 | 0 | vout4x01234567 = _mm256_mul_ps(vout4x01234567, vscale01234567); |
253 | 0 | vout5x01234567 = _mm256_mul_ps(vout5x01234567, vscale01234567); |
254 | 0 | vout6x01234567 = _mm256_mul_ps(vout6x01234567, vscale01234567); |
255 | |
256 | 0 | vout0x01234567 = _mm256_min_ps(vout0x01234567, voutput_max_less_zero_point); |
257 | 0 | vout1x01234567 = _mm256_min_ps(vout1x01234567, voutput_max_less_zero_point); |
258 | 0 | vout2x01234567 = _mm256_min_ps(vout2x01234567, voutput_max_less_zero_point); |
259 | 0 | vout3x01234567 = _mm256_min_ps(vout3x01234567, voutput_max_less_zero_point); |
260 | 0 | vout4x01234567 = _mm256_min_ps(vout4x01234567, voutput_max_less_zero_point); |
261 | 0 | vout5x01234567 = _mm256_min_ps(vout5x01234567, voutput_max_less_zero_point); |
262 | 0 | vout6x01234567 = _mm256_min_ps(vout6x01234567, voutput_max_less_zero_point); |
263 | |
264 | 0 | vacc0x01234567 = _mm256_cvtps_epi32(vout0x01234567); |
265 | 0 | vacc1x01234567 = _mm256_cvtps_epi32(vout1x01234567); |
266 | 0 | vacc2x01234567 = _mm256_cvtps_epi32(vout2x01234567); |
267 | 0 | vacc3x01234567 = _mm256_cvtps_epi32(vout3x01234567); |
268 | 0 | vacc4x01234567 = _mm256_cvtps_epi32(vout4x01234567); |
269 | 0 | vacc5x01234567 = _mm256_cvtps_epi32(vout5x01234567); |
270 | 0 | vacc6x01234567 = _mm256_cvtps_epi32(vout6x01234567); |
271 | |
272 | 0 | vacc0x01234567 = _mm256_add_epi32(vacc0x01234567, voutput_zero_point); |
273 | 0 | vacc1x01234567 = _mm256_add_epi32(vacc1x01234567, voutput_zero_point); |
274 | 0 | vacc2x01234567 = _mm256_add_epi32(vacc2x01234567, voutput_zero_point); |
275 | 0 | vacc3x01234567 = _mm256_add_epi32(vacc3x01234567, voutput_zero_point); |
276 | 0 | vacc4x01234567 = _mm256_add_epi32(vacc4x01234567, voutput_zero_point); |
277 | 0 | vacc5x01234567 = _mm256_add_epi32(vacc5x01234567, voutput_zero_point); |
278 | 0 | vacc6x01234567 = _mm256_add_epi32(vacc6x01234567, voutput_zero_point); |
279 | |
280 | 0 | vacc0x01234567 = _mm256_packs_epi32(vacc0x01234567, _mm256_castsi128_si256(_mm256_extracti128_si256(vacc0x01234567, 1))); |
281 | 0 | __m128i voutb0x01234567 = _mm256_castsi256_si128(_mm256_packs_epi16(vacc0x01234567, vacc0x01234567)); |
282 | 0 | vacc1x01234567 = _mm256_packs_epi32(vacc1x01234567, _mm256_castsi128_si256(_mm256_extracti128_si256(vacc1x01234567, 1))); |
283 | 0 | __m128i voutb1x01234567 = _mm256_castsi256_si128(_mm256_packs_epi16(vacc1x01234567, vacc1x01234567)); |
284 | 0 | vacc2x01234567 = _mm256_packs_epi32(vacc2x01234567, _mm256_castsi128_si256(_mm256_extracti128_si256(vacc2x01234567, 1))); |
285 | 0 | __m128i voutb2x01234567 = _mm256_castsi256_si128(_mm256_packs_epi16(vacc2x01234567, vacc2x01234567)); |
286 | 0 | vacc3x01234567 = _mm256_packs_epi32(vacc3x01234567, _mm256_castsi128_si256(_mm256_extracti128_si256(vacc3x01234567, 1))); |
287 | 0 | __m128i voutb3x01234567 = _mm256_castsi256_si128(_mm256_packs_epi16(vacc3x01234567, vacc3x01234567)); |
288 | 0 | vacc4x01234567 = _mm256_packs_epi32(vacc4x01234567, _mm256_castsi128_si256(_mm256_extracti128_si256(vacc4x01234567, 1))); |
289 | 0 | __m128i voutb4x01234567 = _mm256_castsi256_si128(_mm256_packs_epi16(vacc4x01234567, vacc4x01234567)); |
290 | 0 | vacc5x01234567 = _mm256_packs_epi32(vacc5x01234567, _mm256_castsi128_si256(_mm256_extracti128_si256(vacc5x01234567, 1))); |
291 | 0 | __m128i voutb5x01234567 = _mm256_castsi256_si128(_mm256_packs_epi16(vacc5x01234567, vacc5x01234567)); |
292 | 0 | vacc6x01234567 = _mm256_packs_epi32(vacc6x01234567, _mm256_castsi128_si256(_mm256_extracti128_si256(vacc6x01234567, 1))); |
293 | 0 | __m128i voutb6x01234567 = _mm256_castsi256_si128(_mm256_packs_epi16(vacc6x01234567, vacc6x01234567)); |
294 | |
295 | 0 | voutb0x01234567 = _mm_max_epi8(voutb0x01234567, voutput_min); |
296 | 0 | voutb1x01234567 = _mm_max_epi8(voutb1x01234567, voutput_min); |
297 | 0 | voutb2x01234567 = _mm_max_epi8(voutb2x01234567, voutput_min); |
298 | 0 | voutb3x01234567 = _mm_max_epi8(voutb3x01234567, voutput_min); |
299 | 0 | voutb4x01234567 = _mm_max_epi8(voutb4x01234567, voutput_min); |
300 | 0 | voutb5x01234567 = _mm_max_epi8(voutb5x01234567, voutput_min); |
301 | 0 | voutb6x01234567 = _mm_max_epi8(voutb6x01234567, voutput_min); |
302 | |
303 | 0 | if (nc >= 8) { |
304 | 0 | _mm_storel_epi64((__m128i*) c0, voutb0x01234567); |
305 | 0 | c0 = (int8_t*) ((uintptr_t) c0 + cn_stride); |
306 | 0 | a0 = (const int8_t*) ((uintptr_t) a0 - kc); |
307 | 0 | _mm_storel_epi64((__m128i*) c1, voutb1x01234567); |
308 | 0 | c1 = (int8_t*) ((uintptr_t) c1 + cn_stride); |
309 | 0 | a1 = (const int8_t*) ((uintptr_t) a1 - kc); |
310 | 0 | _mm_storel_epi64((__m128i*) c2, voutb2x01234567); |
311 | 0 | c2 = (int8_t*) ((uintptr_t) c2 + cn_stride); |
312 | 0 | a2 = (const int8_t*) ((uintptr_t) a2 - kc); |
313 | 0 | _mm_storel_epi64((__m128i*) c3, voutb3x01234567); |
314 | 0 | c3 = (int8_t*) ((uintptr_t) c3 + cn_stride); |
315 | 0 | a3 = (const int8_t*) ((uintptr_t) a3 - kc); |
316 | 0 | _mm_storel_epi64((__m128i*) c4, voutb4x01234567); |
317 | 0 | c4 = (int8_t*) ((uintptr_t) c4 + cn_stride); |
318 | 0 | a4 = (const int8_t*) ((uintptr_t) a4 - kc); |
319 | 0 | _mm_storel_epi64((__m128i*) c5, voutb5x01234567); |
320 | 0 | c5 = (int8_t*) ((uintptr_t) c5 + cn_stride); |
321 | 0 | a5 = (const int8_t*) ((uintptr_t) a5 - kc); |
322 | 0 | _mm_storel_epi64((__m128i*) c6, voutb6x01234567); |
323 | 0 | c6 = (int8_t*) ((uintptr_t) c6 + cn_stride); |
324 | 0 | a6 = (const int8_t*) ((uintptr_t) a6 - kc); |
325 | |
326 | 0 | nc -= 8; |
327 | 0 | } else { |
328 | 0 | if (nc & 4) { |
329 | 0 | _mm_storeu_si32(c0, voutb0x01234567); |
330 | 0 | c0 += 4; |
331 | 0 | _mm_storeu_si32(c1, voutb1x01234567); |
332 | 0 | c1 += 4; |
333 | 0 | _mm_storeu_si32(c2, voutb2x01234567); |
334 | 0 | c2 += 4; |
335 | 0 | _mm_storeu_si32(c3, voutb3x01234567); |
336 | 0 | c3 += 4; |
337 | 0 | _mm_storeu_si32(c4, voutb4x01234567); |
338 | 0 | c4 += 4; |
339 | 0 | _mm_storeu_si32(c5, voutb5x01234567); |
340 | 0 | c5 += 4; |
341 | 0 | _mm_storeu_si32(c6, voutb6x01234567); |
342 | 0 | c6 += 4; |
343 | 0 | voutb0x01234567 = _mm_srli_epi64(voutb0x01234567, 32); |
344 | 0 | voutb1x01234567 = _mm_srli_epi64(voutb1x01234567, 32); |
345 | 0 | voutb2x01234567 = _mm_srli_epi64(voutb2x01234567, 32); |
346 | 0 | voutb3x01234567 = _mm_srli_epi64(voutb3x01234567, 32); |
347 | 0 | voutb4x01234567 = _mm_srli_epi64(voutb4x01234567, 32); |
348 | 0 | voutb5x01234567 = _mm_srli_epi64(voutb5x01234567, 32); |
349 | 0 | voutb6x01234567 = _mm_srli_epi64(voutb6x01234567, 32); |
350 | 0 | } |
351 | 0 | if (nc & 2) { |
352 | 0 | unaligned_store_u16(c0, (uint16_t) _mm_extract_epi16(voutb0x01234567, 0)); |
353 | 0 | c0 += 2; |
354 | 0 | unaligned_store_u16(c1, (uint16_t) _mm_extract_epi16(voutb1x01234567, 0)); |
355 | 0 | c1 += 2; |
356 | 0 | unaligned_store_u16(c2, (uint16_t) _mm_extract_epi16(voutb2x01234567, 0)); |
357 | 0 | c2 += 2; |
358 | 0 | unaligned_store_u16(c3, (uint16_t) _mm_extract_epi16(voutb3x01234567, 0)); |
359 | 0 | c3 += 2; |
360 | 0 | unaligned_store_u16(c4, (uint16_t) _mm_extract_epi16(voutb4x01234567, 0)); |
361 | 0 | c4 += 2; |
362 | 0 | unaligned_store_u16(c5, (uint16_t) _mm_extract_epi16(voutb5x01234567, 0)); |
363 | 0 | c5 += 2; |
364 | 0 | unaligned_store_u16(c6, (uint16_t) _mm_extract_epi16(voutb6x01234567, 0)); |
365 | 0 | c6 += 2; |
366 | 0 | voutb0x01234567 = _mm_srli_epi32(voutb0x01234567, 16); |
367 | 0 | voutb1x01234567 = _mm_srli_epi32(voutb1x01234567, 16); |
368 | 0 | voutb2x01234567 = _mm_srli_epi32(voutb2x01234567, 16); |
369 | 0 | voutb3x01234567 = _mm_srli_epi32(voutb3x01234567, 16); |
370 | 0 | voutb4x01234567 = _mm_srli_epi32(voutb4x01234567, 16); |
371 | 0 | voutb5x01234567 = _mm_srli_epi32(voutb5x01234567, 16); |
372 | 0 | voutb6x01234567 = _mm_srli_epi32(voutb6x01234567, 16); |
373 | 0 | } |
374 | 0 | if (nc & 1) { |
375 | 0 | *c0 = (int8_t) _mm_extract_epi8(voutb0x01234567, 0); |
376 | 0 | *c1 = (int8_t) _mm_extract_epi8(voutb1x01234567, 0); |
377 | 0 | *c2 = (int8_t) _mm_extract_epi8(voutb2x01234567, 0); |
378 | 0 | *c3 = (int8_t) _mm_extract_epi8(voutb3x01234567, 0); |
379 | 0 | *c4 = (int8_t) _mm_extract_epi8(voutb4x01234567, 0); |
380 | 0 | *c5 = (int8_t) _mm_extract_epi8(voutb5x01234567, 0); |
381 | 0 | *c6 = (int8_t) _mm_extract_epi8(voutb6x01234567, 0); |
382 | 0 | } |
383 | 0 | nc = 0; |
384 | 0 | } |
385 | 0 | } while (nc != 0); |
386 | 0 | } |
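
Note: the kernel above leans on two ideas that are easy to miss in the generated code. First, activations are XORed with 0x80 (vsign_mask) so that signed int8 inputs become the unsigned operand of a dpbusd-style u8*s8 dot product (the constant shift this introduces is typically compensated in the packed bias). Second, each weight byte packs two 4-bit values, which are split with vmask and a 4-bit shift. The sketch below illustrates only these two ideas under stated assumptions; the names dpbusd_madd_sketch and unpack_nibbles_sketch are hypothetical, and the real _mm256_dpbusd_epi32_madd polyfill in src/xnnpack/intrinsics-polyfill.h may differ in detail.

#include <immintrin.h>

// Sketch: dpbusd-style multiply-accumulate emulated with AVX2 "madd" instructions.
// Unlike true AVX-VNNI, _mm256_maddubs_epi16 can saturate its intermediate s16 sums.
static inline __m256i dpbusd_madd_sketch(__m256i acc, __m256i a_u8, __m256i b_s8) {
  const __m256i prod16 = _mm256_maddubs_epi16(a_u8, b_s8);               // u8*s8 -> pairwise s16
  const __m256i prod32 = _mm256_madd_epi16(prod16, _mm256_set1_epi16(1)); // s16 pairs -> s32
  return _mm256_add_epi32(acc, prod32);
}

// Sketch: splitting the two 4-bit weights packed per byte, as done with vmask above.
// Low nibbles carry the weights for the first 8 k values, high nibbles the next 8.
static inline void unpack_nibbles_sketch(__m256i packed, __m256i* lo, __m256i* hi) {
  const __m256i vmask = _mm256_set1_epi8(0x0F);
  *lo = _mm256_and_si256(packed, vmask);
  *hi = _mm256_and_si256(_mm256_srli_epi32(packed, 4), vmask);
}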