/src/xnnpack/src/f32-qc8w-gemm/gen/f32-qc8w-gemm-5x16-minmax-avx2-broadcast.c
Line | Count | Source |
1 | | // clang-format off |
2 | | // Auto-generated file. Do not edit! |
3 | | // Template: src/f32-gemm/avx-broadcast.c.in |
4 | | // Generator: tools/xngen |
5 | | // |
6 | | // Copyright 2019 Google LLC |
7 | | // |
8 | | // This source code is licensed under the BSD-style license found in the |
9 | | // LICENSE file in the root directory of this source tree. |
10 | | |
11 | | #include <assert.h> |
12 | | #include <stddef.h> |
13 | | #include <stdint.h> |
14 | | |
15 | | #include <immintrin.h> |
16 | | |
17 | | #include "src/xnnpack/common.h" |
18 | | #include "src/xnnpack/gemm.h" |
19 | | #include "src/xnnpack/microparams.h" |
20 | | |
21 | | |
// NOTE(review): this is an llvm-cov line-coverage dump (all counts are 0 --
// the kernel was never executed in this run). The "<line> | <count> |"
// prefix columns are report formatting, not C tokens. The underlying file
// is auto-generated by tools/xngen; fix the template, not this file.
//
// Kernel contract, as established by the code below:
//   Computes a tile of up to 5 rows (mr) x 16 columns of C = A*B for f32
//   GEMM with per-channel-quantized int8 weights ("qc8w"). For each group
//   of 16 output columns, the packed weight stream w holds:
//     16 f32 biases, then (kc/4) groups of 16 int8 weights, then 16 f32
//     per-column dequantization scales.
//   kc is in BYTES (asserted to be a multiple of sizeof(float)); results
//   are clamped to [params->scalar.min, params->scalar.max].
22 | | void xnn_f32_qc8w_gemm_minmax_ukernel_5x16__avx2_broadcast(
23 | | size_t mr,
24 | | size_t nc,
25 | | size_t kc,
26 | | const float* restrict a,
27 | | size_t a_stride,
28 | | const void* restrict w,
29 | | float* restrict c,
30 | | size_t cm_stride,
31 | | size_t cn_stride,
32 | | const struct xnn_f32_minmax_params* restrict params)
33 | 0 | {
34 | 0 | assert(mr != 0);
35 | 0 | assert(mr <= 5);
36 | 0 | assert(nc != 0);
37 | 0 | assert(kc != 0);
38 | 0 | assert(kc % sizeof(float) == 0); // kc counts bytes, not elements
39 | 0 | assert(a != NULL);
40 | 0 | assert(w != NULL);
41 | 0 | assert(c != NULL);
42 | |
// Per-row A/C pointers. When mr < 5 the out-of-range rows are aliased to
// the previous valid row, so all 5 rows' loads/stores stay in bounds; the
// duplicate rows just redo (and re-store) row mr-1's work.
43 | 0 | const float* a0 = a;
44 | 0 | float* c0 = c;
45 | 0 | const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
46 | 0 | float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
47 | 0 | if XNN_UNPREDICTABLE(mr < 2) {
48 | 0 | a1 = a0;
49 | 0 | c1 = c0;
50 | 0 | }
51 | 0 | const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
52 | 0 | float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
53 | 0 | if XNN_UNPREDICTABLE(mr <= 2) {
54 | 0 | a2 = a1;
55 | 0 | c2 = c1;
56 | 0 | }
57 | 0 | const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
58 | 0 | float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
59 | 0 | if XNN_UNPREDICTABLE(mr < 4) {
60 | 0 | a3 = a2;
61 | 0 | c3 = c2;
62 | 0 | }
63 | 0 | const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
64 | 0 | float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
65 | 0 | if XNN_UNPREDICTABLE(mr <= 4) {
66 | 0 | a4 = a3;
67 | 0 | c4 = c3;
68 | 0 | }
69 |
|
// Broadcast the clamp bounds once per call. XNN_FORCE_REALIZATION
// presumably forces the compiler to materialize these vectors up front
// instead of rematerializing per use -- confirm against the macro in
// src/xnnpack/common.h.
70 | 0 | const __m256 vmin = _mm256_set1_ps(params->scalar.min);
71 | 0 | const __m256 vmax = _mm256_set1_ps(params->scalar.max);
72 | 0 | XNN_FORCE_REALIZATION(vmin);
73 | 0 | XNN_FORCE_REALIZATION(vmax);
74 |
|
// Outer loop: one iteration per group of up to 16 output columns.
75 | 0 | do {
// Initialize all 5 rows' accumulators with the same 16 f32 biases that
// lead this column group in the packed weight stream.
76 | 0 | __m256 vacc0x01234567 = _mm256_loadu_ps((const float*) w + 0);
77 | 0 | __m256 vacc0x89ABCDEF = _mm256_loadu_ps((const float*) w + 8);
78 | 0 | __m256 vacc1x01234567 = vacc0x01234567;
79 | 0 | __m256 vacc1x89ABCDEF = vacc0x89ABCDEF;
80 | 0 | __m256 vacc2x01234567 = vacc0x01234567;
81 | 0 | __m256 vacc2x89ABCDEF = vacc0x89ABCDEF;
82 | 0 | __m256 vacc3x01234567 = vacc0x01234567;
83 | 0 | __m256 vacc3x89ABCDEF = vacc0x89ABCDEF;
84 | 0 | __m256 vacc4x01234567 = vacc0x01234567;
85 | 0 | __m256 vacc4x89ABCDEF = vacc0x89ABCDEF;
86 | 0 | w = (const float*) w + 16; // skip past the 16 biases
87 |
|
// Inner reduction over K: broadcast one f32 from each A row, multiply
// against 16 dequantized weights, accumulate with FMA.
88 | 0 | size_t k = kc;
89 | 0 | do {
90 | 0 | const __m256 va0 = _mm256_broadcast_ss(a0);
91 | 0 | a0 += 1;
92 | 0 | const __m256 va1 = _mm256_broadcast_ss(a1);
93 | 0 | a1 += 1;
94 | 0 | const __m256 va2 = _mm256_broadcast_ss(a2);
95 | 0 | a2 += 1;
96 | 0 | const __m256 va3 = _mm256_broadcast_ss(a3);
97 | 0 | a3 += 1;
98 | 0 | const __m256 va4 = _mm256_broadcast_ss(a4);
99 | 0 | a4 += 1;
100 |
|
// Sign-extend 16 int8 weights (two 8-byte loads) to int32, then
// convert to f32. Scales are applied once after the k loop, not here.
101 | 0 | const __m256i vbi01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) w));
102 | 0 | const __m256i vbi89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((const int8_t*) w + 8)));
103 | 0 | w = (const int8_t*) w + 16; // 16 int8 weights consumed this k step
104 | 0 | const __m256 vb01234567 = _mm256_cvtepi32_ps(vbi01234567);
105 | 0 | const __m256 vb89ABCDEF = _mm256_cvtepi32_ps(vbi89ABCDEF);
106 |
|
107 | 0 | vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567, vacc0x01234567);
108 | 0 | vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567, vacc1x01234567);
109 | 0 | vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567, vacc2x01234567);
110 | 0 | vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567, vacc3x01234567);
111 | 0 | vacc4x01234567 = _mm256_fmadd_ps(va4, vb01234567, vacc4x01234567);
112 | 0 | vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEF, vacc0x89ABCDEF);
113 | 0 | vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEF, vacc1x89ABCDEF);
114 | 0 | vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEF, vacc2x89ABCDEF);
115 | 0 | vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEF, vacc3x89ABCDEF);
116 | 0 | vacc4x89ABCDEF = _mm256_fmadd_ps(va4, vb89ABCDEF, vacc4x89ABCDEF);
117 |
|
118 | 0 | k -= sizeof(float); // k counts bytes; one f32 per row consumed
119 | 0 | } while (k != 0);
120 |
|
// Apply the 16 per-column f32 dequantization scales that follow the
// int8 weight stream for this column group.
121 | 0 | const __m256 vscale01234567 = _mm256_loadu_ps((const float*) w + 0);
122 | 0 | vacc0x01234567 = _mm256_mul_ps(vacc0x01234567, vscale01234567);
123 | 0 | vacc1x01234567 = _mm256_mul_ps(vacc1x01234567, vscale01234567);
124 | 0 | vacc2x01234567 = _mm256_mul_ps(vacc2x01234567, vscale01234567);
125 | 0 | vacc3x01234567 = _mm256_mul_ps(vacc3x01234567, vscale01234567);
126 | 0 | vacc4x01234567 = _mm256_mul_ps(vacc4x01234567, vscale01234567);
127 | 0 | const __m256 vscale89ABCDEF = _mm256_loadu_ps((const float*) w + 8);
128 | 0 | vacc0x89ABCDEF = _mm256_mul_ps(vacc0x89ABCDEF, vscale89ABCDEF);
129 | 0 | vacc1x89ABCDEF = _mm256_mul_ps(vacc1x89ABCDEF, vscale89ABCDEF);
130 | 0 | vacc2x89ABCDEF = _mm256_mul_ps(vacc2x89ABCDEF, vscale89ABCDEF);
131 | 0 | vacc3x89ABCDEF = _mm256_mul_ps(vacc3x89ABCDEF, vscale89ABCDEF);
132 | 0 | vacc4x89ABCDEF = _mm256_mul_ps(vacc4x89ABCDEF, vscale89ABCDEF);
133 | 0 | w = (const float*) w + 16; // skip the 16 scales
// Clamp to [min, max]. Note the operand order (constant first): for NaN
// inputs the result follows _mm256_max_ps/_mm256_min_ps semantics of
// returning the second operand -- intentional in XNNPACK templates.
134 | 0 | vacc0x01234567 = _mm256_max_ps(vmin, vacc0x01234567);
135 | 0 | vacc1x01234567 = _mm256_max_ps(vmin, vacc1x01234567);
136 | 0 | vacc2x01234567 = _mm256_max_ps(vmin, vacc2x01234567);
137 | 0 | vacc3x01234567 = _mm256_max_ps(vmin, vacc3x01234567);
138 | 0 | vacc4x01234567 = _mm256_max_ps(vmin, vacc4x01234567);
139 | 0 | vacc0x89ABCDEF = _mm256_max_ps(vmin, vacc0x89ABCDEF);
140 | 0 | vacc1x89ABCDEF = _mm256_max_ps(vmin, vacc1x89ABCDEF);
141 | 0 | vacc2x89ABCDEF = _mm256_max_ps(vmin, vacc2x89ABCDEF);
142 | 0 | vacc3x89ABCDEF = _mm256_max_ps(vmin, vacc3x89ABCDEF);
143 | 0 | vacc4x89ABCDEF = _mm256_max_ps(vmin, vacc4x89ABCDEF);
144 |
|
145 | 0 | vacc0x01234567 = _mm256_min_ps(vmax, vacc0x01234567);
146 | 0 | vacc1x01234567 = _mm256_min_ps(vmax, vacc1x01234567);
147 | 0 | vacc2x01234567 = _mm256_min_ps(vmax, vacc2x01234567);
148 | 0 | vacc3x01234567 = _mm256_min_ps(vmax, vacc3x01234567);
149 | 0 | vacc4x01234567 = _mm256_min_ps(vmax, vacc4x01234567);
150 | 0 | vacc0x89ABCDEF = _mm256_min_ps(vmax, vacc0x89ABCDEF);
151 | 0 | vacc1x89ABCDEF = _mm256_min_ps(vmax, vacc1x89ABCDEF);
152 | 0 | vacc2x89ABCDEF = _mm256_min_ps(vmax, vacc2x89ABCDEF);
153 | 0 | vacc3x89ABCDEF = _mm256_min_ps(vmax, vacc3x89ABCDEF);
154 | 0 | vacc4x89ABCDEF = _mm256_min_ps(vmax, vacc4x89ABCDEF);
155 |
|
// Fast path: a full 16-column group remains. Store all 5 rows, advance
// C pointers by cn_stride, and rewind A pointers by kc bytes so the
// next column group re-reads the same A rows.
156 | 0 | if XNN_LIKELY(nc >= 16) {
157 | 0 | _mm256_storeu_ps(c0, vacc0x01234567);
158 | 0 | _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
159 | 0 | c0 = (float*) ((uintptr_t) c0 + cn_stride);
160 | 0 | _mm256_storeu_ps(c1, vacc1x01234567);
161 | 0 | _mm256_storeu_ps(c1 + 8, vacc1x89ABCDEF);
162 | 0 | c1 = (float*) ((uintptr_t) c1 + cn_stride);
163 | 0 | _mm256_storeu_ps(c2, vacc2x01234567);
164 | 0 | _mm256_storeu_ps(c2 + 8, vacc2x89ABCDEF);
165 | 0 | c2 = (float*) ((uintptr_t) c2 + cn_stride);
166 | 0 | _mm256_storeu_ps(c3, vacc3x01234567);
167 | 0 | _mm256_storeu_ps(c3 + 8, vacc3x89ABCDEF);
168 | 0 | c3 = (float*) ((uintptr_t) c3 + cn_stride);
169 | 0 | _mm256_storeu_ps(c4, vacc4x01234567);
170 | 0 | _mm256_storeu_ps(c4 + 8, vacc4x89ABCDEF);
171 | 0 | c4 = (float*) ((uintptr_t) c4 + cn_stride);
172 |
|
173 | 0 | a0 = (const float*) ((uintptr_t) a0 - kc);
174 | 0 | a1 = (const float*) ((uintptr_t) a1 - kc);
175 | 0 | a2 = (const float*) ((uintptr_t) a2 - kc);
176 | 0 | a3 = (const float*) ((uintptr_t) a3 - kc);
177 | 0 | a4 = (const float*) ((uintptr_t) a4 - kc);
178 |
|
179 | 0 | nc -= 16;
180 | 0 | } else {
// Remainder path (nc < 16): store progressively smaller pieces
// (8, 4, 2, 1 floats) driven by the bits of nc, shifting the
// remaining lanes down after each partial store. Terminal: nc = 0.
181 | 0 | if (nc & 8) {
182 | 0 | _mm256_storeu_ps(c0, vacc0x01234567);
183 | 0 | _mm256_storeu_ps(c1, vacc1x01234567);
184 | 0 | _mm256_storeu_ps(c2, vacc2x01234567);
185 | 0 | _mm256_storeu_ps(c3, vacc3x01234567);
186 | 0 | _mm256_storeu_ps(c4, vacc4x01234567);
187 |
|
188 | 0 | vacc0x01234567 = vacc0x89ABCDEF;
189 | 0 | vacc1x01234567 = vacc1x89ABCDEF;
190 | 0 | vacc2x01234567 = vacc2x89ABCDEF;
191 | 0 | vacc3x01234567 = vacc3x89ABCDEF;
192 | 0 | vacc4x01234567 = vacc4x89ABCDEF;
193 |
|
194 | 0 | c0 += 8;
195 | 0 | c1 += 8;
196 | 0 | c2 += 8;
197 | 0 | c3 += 8;
198 | 0 | c4 += 8;
199 | 0 | }
200 | 0 | __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
201 | 0 | __m128 vacc1x0123 = _mm256_castps256_ps128(vacc1x01234567);
202 | 0 | __m128 vacc2x0123 = _mm256_castps256_ps128(vacc2x01234567);
203 | 0 | __m128 vacc3x0123 = _mm256_castps256_ps128(vacc3x01234567);
204 | 0 | __m128 vacc4x0123 = _mm256_castps256_ps128(vacc4x01234567);
205 | 0 | if (nc & 4) {
206 | 0 | _mm_storeu_ps(c0, vacc0x0123);
207 | 0 | _mm_storeu_ps(c1, vacc1x0123);
208 | 0 | _mm_storeu_ps(c2, vacc2x0123);
209 | 0 | _mm_storeu_ps(c3, vacc3x0123);
210 | 0 | _mm_storeu_ps(c4, vacc4x0123);
211 |
|
212 | 0 | vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
213 | 0 | vacc1x0123 = _mm256_extractf128_ps(vacc1x01234567, 1);
214 | 0 | vacc2x0123 = _mm256_extractf128_ps(vacc2x01234567, 1);
215 | 0 | vacc3x0123 = _mm256_extractf128_ps(vacc3x01234567, 1);
216 | 0 | vacc4x0123 = _mm256_extractf128_ps(vacc4x01234567, 1);
217 |
|
218 | 0 | c0 += 4;
219 | 0 | c1 += 4;
220 | 0 | c2 += 4;
221 | 0 | c3 += 4;
222 | 0 | c4 += 4;
223 | 0 | }
224 | 0 | if (nc & 2) {
225 | 0 | _mm_storel_pi((__m64*) c0, vacc0x0123);
226 | 0 | _mm_storel_pi((__m64*) c1, vacc1x0123);
227 | 0 | _mm_storel_pi((__m64*) c2, vacc2x0123);
228 | 0 | _mm_storel_pi((__m64*) c3, vacc3x0123);
229 | 0 | _mm_storel_pi((__m64*) c4, vacc4x0123);
230 |
|
231 | 0 | vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
232 | 0 | vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
233 | 0 | vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
234 | 0 | vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
235 | 0 | vacc4x0123 = _mm_movehl_ps(vacc4x0123, vacc4x0123);
236 |
|
237 | 0 | c0 += 2;
238 | 0 | c1 += 2;
239 | 0 | c2 += 2;
240 | 0 | c3 += 2;
241 | 0 | c4 += 2;
242 | 0 | }
243 | 0 | if (nc & 1) {
244 | 0 | _mm_store_ss(c0, vacc0x0123);
245 | 0 | _mm_store_ss(c1, vacc1x0123);
246 | 0 | _mm_store_ss(c2, vacc2x0123);
247 | 0 | _mm_store_ss(c3, vacc3x0123);
248 | 0 | _mm_store_ss(c4, vacc4x0123);
249 | 0 | }
250 |
|
251 | 0 | nc = 0;
252 | 0 | }
253 | 0 | } while (nc != 0);
254 | 0 | }