/src/xnnpack/src/f32-qc8w-gemm/gen/f32-qc8w-gemm-5x16-minmax-avx-broadcast.c
Line | Count | Source |
1 | | // clang-format off |
2 | | // Auto-generated file. Do not edit! |
3 | | // Template: src/f32-gemm/avx-broadcast.c.in |
4 | | // Generator: tools/xngen |
5 | | // |
6 | | // Copyright 2019 Google LLC |
7 | | // |
8 | | // This source code is licensed under the BSD-style license found in the |
9 | | // LICENSE file in the root directory of this source tree. |
10 | | |
11 | | #include <assert.h> |
12 | | #include <stddef.h> |
13 | | #include <stdint.h> |
14 | | |
15 | | #include <immintrin.h> |
16 | | |
17 | | #include "src/xnnpack/common.h" |
18 | | #include "src/xnnpack/gemm.h" |
19 | | #include "src/xnnpack/unaligned.h" |
20 | | #include "src/xnnpack/microparams.h" |
21 | | |
22 | | |
23 | | void xnn_f32_qc8w_gemm_minmax_ukernel_5x16__avx_broadcast( |
24 | | size_t mr, |
25 | | size_t nc, |
26 | | size_t kc, |
27 | | const float* restrict a, |
28 | | size_t a_stride, |
29 | | const void* restrict w, |
30 | | float* restrict c, |
31 | | size_t cm_stride, |
32 | | size_t cn_stride, |
33 | | const struct xnn_f32_minmax_params* restrict params) |
34 | 0 | { |
35 | 0 | assert(mr != 0); |
36 | 0 | assert(mr <= 5); |
37 | 0 | assert(nc != 0); |
38 | 0 | assert(kc != 0); |
39 | 0 | assert(kc % sizeof(float) == 0); |
40 | 0 | assert(a != NULL); |
41 | 0 | assert(w != NULL); |
42 | 0 | assert(c != NULL); |
43 | | |
44 | 0 | const float* a0 = a; |
45 | 0 | float* c0 = c; |
46 | 0 | const float* a1 = (const float*) ((uintptr_t) a0 + a_stride); |
47 | 0 | float* c1 = (float*) ((uintptr_t) c0 + cm_stride); |
48 | 0 | if XNN_UNPREDICTABLE(mr < 2) { |
49 | 0 | a1 = a0; |
50 | 0 | c1 = c0; |
51 | 0 | } |
52 | 0 | const float* a2 = (const float*) ((uintptr_t) a1 + a_stride); |
53 | 0 | float* c2 = (float*) ((uintptr_t) c1 + cm_stride); |
54 | 0 | if XNN_UNPREDICTABLE(mr <= 2) { |
55 | 0 | a2 = a1; |
56 | 0 | c2 = c1; |
57 | 0 | } |
58 | 0 | const float* a3 = (const float*) ((uintptr_t) a2 + a_stride); |
59 | 0 | float* c3 = (float*) ((uintptr_t) c2 + cm_stride); |
60 | 0 | if XNN_UNPREDICTABLE(mr < 4) { |
61 | 0 | a3 = a2; |
62 | 0 | c3 = c2; |
63 | 0 | } |
64 | 0 | const float* a4 = (const float*) ((uintptr_t) a3 + a_stride); |
65 | 0 | float* c4 = (float*) ((uintptr_t) c3 + cm_stride); |
66 | 0 | if XNN_UNPREDICTABLE(mr <= 4) { |
67 | 0 | a4 = a3; |
68 | 0 | c4 = c3; |
69 | 0 | } |
70 | |
71 | 0 | const __m256 vmin = _mm256_set1_ps(params->scalar.min); |
72 | 0 | const __m256 vmax = _mm256_set1_ps(params->scalar.max); |
73 | 0 | XNN_FORCE_REALIZATION(vmin); |
74 | 0 | XNN_FORCE_REALIZATION(vmax); |
75 | |
76 | 0 | do { |
77 | 0 | __m256 vacc0x01234567 = _mm256_loadu_ps((const float*) w + 0); |
78 | 0 | __m256 vacc0x89ABCDEF = _mm256_loadu_ps((const float*) w + 8); |
79 | 0 | __m256 vacc1x01234567 = vacc0x01234567; |
80 | 0 | __m256 vacc1x89ABCDEF = vacc0x89ABCDEF; |
81 | 0 | __m256 vacc2x01234567 = vacc0x01234567; |
82 | 0 | __m256 vacc2x89ABCDEF = vacc0x89ABCDEF; |
83 | 0 | __m256 vacc3x01234567 = vacc0x01234567; |
84 | 0 | __m256 vacc3x89ABCDEF = vacc0x89ABCDEF; |
85 | 0 | __m256 vacc4x01234567 = vacc0x01234567; |
86 | 0 | __m256 vacc4x89ABCDEF = vacc0x89ABCDEF; |
87 | 0 | w = (const float*) w + 16; |
88 | |
89 | 0 | size_t k = kc; |
90 | 0 | do { |
91 | 0 | const __m256 va0 = _mm256_broadcast_ss(a0); |
92 | 0 | a0 += 1; |
93 | 0 | const __m256 va1 = _mm256_broadcast_ss(a1); |
94 | 0 | a1 += 1; |
95 | 0 | const __m256 va2 = _mm256_broadcast_ss(a2); |
96 | 0 | a2 += 1; |
97 | 0 | const __m256 va3 = _mm256_broadcast_ss(a3); |
98 | 0 | a3 += 1; |
99 | 0 | const __m256 va4 = _mm256_broadcast_ss(a4); |
100 | 0 | a4 += 1; |
101 | |
102 | 0 | const __m128i vbi0123 = _mm_cvtepi8_epi32(_mm_cvtsi32_si128((int) unaligned_load_u32((const int8_t*) w))); |
103 | 0 | const __m128i vbi4567 = _mm_cvtepi8_epi32(_mm_cvtsi32_si128((int) unaligned_load_u32((const int8_t*) w + 4))); |
104 | 0 | const __m128i vbi89AB = _mm_cvtepi8_epi32(_mm_cvtsi32_si128((int) unaligned_load_u32((const int8_t*) w + 8))); |
105 | 0 | const __m128i vbiCDEF = _mm_cvtepi8_epi32(_mm_cvtsi32_si128((int) unaligned_load_u32((const int8_t*) w + 12))); |
106 | 0 | const __m256i vbi01234567 = _mm256_castps_si256(_mm256_insertf128_ps(_mm256_castsi256_ps(_mm256_castsi128_si256(vbi0123)), _mm_castsi128_ps(vbi4567), 1)); |
107 | 0 | const __m256i vbi89ABCDEF = _mm256_castps_si256(_mm256_insertf128_ps(_mm256_castsi256_ps(_mm256_castsi128_si256(vbi89AB)), _mm_castsi128_ps(vbiCDEF), 1)); |
108 | 0 | w = (const int8_t*) w + 16; |
109 | 0 | const __m256 vb01234567 = _mm256_cvtepi32_ps(vbi01234567); |
110 | 0 | const __m256 vb89ABCDEF = _mm256_cvtepi32_ps(vbi89ABCDEF); |
111 | |
112 | 0 | vacc0x01234567 = _mm256_add_ps(vacc0x01234567, _mm256_mul_ps(va0, vb01234567)); |
113 | 0 | vacc1x01234567 = _mm256_add_ps(vacc1x01234567, _mm256_mul_ps(va1, vb01234567)); |
114 | 0 | vacc2x01234567 = _mm256_add_ps(vacc2x01234567, _mm256_mul_ps(va2, vb01234567)); |
115 | 0 | vacc3x01234567 = _mm256_add_ps(vacc3x01234567, _mm256_mul_ps(va3, vb01234567)); |
116 | 0 | vacc4x01234567 = _mm256_add_ps(vacc4x01234567, _mm256_mul_ps(va4, vb01234567)); |
117 | 0 | vacc0x89ABCDEF = _mm256_add_ps(vacc0x89ABCDEF, _mm256_mul_ps(va0, vb89ABCDEF)); |
118 | 0 | vacc1x89ABCDEF = _mm256_add_ps(vacc1x89ABCDEF, _mm256_mul_ps(va1, vb89ABCDEF)); |
119 | 0 | vacc2x89ABCDEF = _mm256_add_ps(vacc2x89ABCDEF, _mm256_mul_ps(va2, vb89ABCDEF)); |
120 | 0 | vacc3x89ABCDEF = _mm256_add_ps(vacc3x89ABCDEF, _mm256_mul_ps(va3, vb89ABCDEF)); |
121 | 0 | vacc4x89ABCDEF = _mm256_add_ps(vacc4x89ABCDEF, _mm256_mul_ps(va4, vb89ABCDEF)); |
122 | |
123 | 0 | k -= sizeof(float); |
124 | 0 | } while (k != 0); |
125 | |
126 | 0 | const __m256 vscale01234567 = _mm256_loadu_ps((const float*) w + 0); |
127 | 0 | vacc0x01234567 = _mm256_mul_ps(vacc0x01234567, vscale01234567); |
128 | 0 | vacc1x01234567 = _mm256_mul_ps(vacc1x01234567, vscale01234567); |
129 | 0 | vacc2x01234567 = _mm256_mul_ps(vacc2x01234567, vscale01234567); |
130 | 0 | vacc3x01234567 = _mm256_mul_ps(vacc3x01234567, vscale01234567); |
131 | 0 | vacc4x01234567 = _mm256_mul_ps(vacc4x01234567, vscale01234567); |
132 | 0 | const __m256 vscale89ABCDEF = _mm256_loadu_ps((const float*) w + 8); |
133 | 0 | vacc0x89ABCDEF = _mm256_mul_ps(vacc0x89ABCDEF, vscale89ABCDEF); |
134 | 0 | vacc1x89ABCDEF = _mm256_mul_ps(vacc1x89ABCDEF, vscale89ABCDEF); |
135 | 0 | vacc2x89ABCDEF = _mm256_mul_ps(vacc2x89ABCDEF, vscale89ABCDEF); |
136 | 0 | vacc3x89ABCDEF = _mm256_mul_ps(vacc3x89ABCDEF, vscale89ABCDEF); |
137 | 0 | vacc4x89ABCDEF = _mm256_mul_ps(vacc4x89ABCDEF, vscale89ABCDEF); |
138 | 0 | w = (const float*) w + 16; |
139 | 0 | vacc0x01234567 = _mm256_max_ps(vmin, vacc0x01234567); |
140 | 0 | vacc1x01234567 = _mm256_max_ps(vmin, vacc1x01234567); |
141 | 0 | vacc2x01234567 = _mm256_max_ps(vmin, vacc2x01234567); |
142 | 0 | vacc3x01234567 = _mm256_max_ps(vmin, vacc3x01234567); |
143 | 0 | vacc4x01234567 = _mm256_max_ps(vmin, vacc4x01234567); |
144 | 0 | vacc0x89ABCDEF = _mm256_max_ps(vmin, vacc0x89ABCDEF); |
145 | 0 | vacc1x89ABCDEF = _mm256_max_ps(vmin, vacc1x89ABCDEF); |
146 | 0 | vacc2x89ABCDEF = _mm256_max_ps(vmin, vacc2x89ABCDEF); |
147 | 0 | vacc3x89ABCDEF = _mm256_max_ps(vmin, vacc3x89ABCDEF); |
148 | 0 | vacc4x89ABCDEF = _mm256_max_ps(vmin, vacc4x89ABCDEF); |
149 | |
150 | 0 | vacc0x01234567 = _mm256_min_ps(vmax, vacc0x01234567); |
151 | 0 | vacc1x01234567 = _mm256_min_ps(vmax, vacc1x01234567); |
152 | 0 | vacc2x01234567 = _mm256_min_ps(vmax, vacc2x01234567); |
153 | 0 | vacc3x01234567 = _mm256_min_ps(vmax, vacc3x01234567); |
154 | 0 | vacc4x01234567 = _mm256_min_ps(vmax, vacc4x01234567); |
155 | 0 | vacc0x89ABCDEF = _mm256_min_ps(vmax, vacc0x89ABCDEF); |
156 | 0 | vacc1x89ABCDEF = _mm256_min_ps(vmax, vacc1x89ABCDEF); |
157 | 0 | vacc2x89ABCDEF = _mm256_min_ps(vmax, vacc2x89ABCDEF); |
158 | 0 | vacc3x89ABCDEF = _mm256_min_ps(vmax, vacc3x89ABCDEF); |
159 | 0 | vacc4x89ABCDEF = _mm256_min_ps(vmax, vacc4x89ABCDEF); |
160 | |
161 | 0 | if XNN_LIKELY(nc >= 16) { |
162 | 0 | _mm256_storeu_ps(c0, vacc0x01234567); |
163 | 0 | _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF); |
164 | 0 | c0 = (float*) ((uintptr_t) c0 + cn_stride); |
165 | 0 | _mm256_storeu_ps(c1, vacc1x01234567); |
166 | 0 | _mm256_storeu_ps(c1 + 8, vacc1x89ABCDEF); |
167 | 0 | c1 = (float*) ((uintptr_t) c1 + cn_stride); |
168 | 0 | _mm256_storeu_ps(c2, vacc2x01234567); |
169 | 0 | _mm256_storeu_ps(c2 + 8, vacc2x89ABCDEF); |
170 | 0 | c2 = (float*) ((uintptr_t) c2 + cn_stride); |
171 | 0 | _mm256_storeu_ps(c3, vacc3x01234567); |
172 | 0 | _mm256_storeu_ps(c3 + 8, vacc3x89ABCDEF); |
173 | 0 | c3 = (float*) ((uintptr_t) c3 + cn_stride); |
174 | 0 | _mm256_storeu_ps(c4, vacc4x01234567); |
175 | 0 | _mm256_storeu_ps(c4 + 8, vacc4x89ABCDEF); |
176 | 0 | c4 = (float*) ((uintptr_t) c4 + cn_stride); |
177 | |
178 | 0 | a0 = (const float*) ((uintptr_t) a0 - kc); |
179 | 0 | a1 = (const float*) ((uintptr_t) a1 - kc); |
180 | 0 | a2 = (const float*) ((uintptr_t) a2 - kc); |
181 | 0 | a3 = (const float*) ((uintptr_t) a3 - kc); |
182 | 0 | a4 = (const float*) ((uintptr_t) a4 - kc); |
183 | |
184 | 0 | nc -= 16; |
185 | 0 | } else { |
186 | 0 | if (nc & 8) { |
187 | 0 | _mm256_storeu_ps(c0, vacc0x01234567); |
188 | 0 | _mm256_storeu_ps(c1, vacc1x01234567); |
189 | 0 | _mm256_storeu_ps(c2, vacc2x01234567); |
190 | 0 | _mm256_storeu_ps(c3, vacc3x01234567); |
191 | 0 | _mm256_storeu_ps(c4, vacc4x01234567); |
192 | |
193 | 0 | vacc0x01234567 = vacc0x89ABCDEF; |
194 | 0 | vacc1x01234567 = vacc1x89ABCDEF; |
195 | 0 | vacc2x01234567 = vacc2x89ABCDEF; |
196 | 0 | vacc3x01234567 = vacc3x89ABCDEF; |
197 | 0 | vacc4x01234567 = vacc4x89ABCDEF; |
198 | |
199 | 0 | c0 += 8; |
200 | 0 | c1 += 8; |
201 | 0 | c2 += 8; |
202 | 0 | c3 += 8; |
203 | 0 | c4 += 8; |
204 | 0 | } |
205 | 0 | __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567); |
206 | 0 | __m128 vacc1x0123 = _mm256_castps256_ps128(vacc1x01234567); |
207 | 0 | __m128 vacc2x0123 = _mm256_castps256_ps128(vacc2x01234567); |
208 | 0 | __m128 vacc3x0123 = _mm256_castps256_ps128(vacc3x01234567); |
209 | 0 | __m128 vacc4x0123 = _mm256_castps256_ps128(vacc4x01234567); |
210 | 0 | if (nc & 4) { |
211 | 0 | _mm_storeu_ps(c0, vacc0x0123); |
212 | 0 | _mm_storeu_ps(c1, vacc1x0123); |
213 | 0 | _mm_storeu_ps(c2, vacc2x0123); |
214 | 0 | _mm_storeu_ps(c3, vacc3x0123); |
215 | 0 | _mm_storeu_ps(c4, vacc4x0123); |
216 | |
217 | 0 | vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1); |
218 | 0 | vacc1x0123 = _mm256_extractf128_ps(vacc1x01234567, 1); |
219 | 0 | vacc2x0123 = _mm256_extractf128_ps(vacc2x01234567, 1); |
220 | 0 | vacc3x0123 = _mm256_extractf128_ps(vacc3x01234567, 1); |
221 | 0 | vacc4x0123 = _mm256_extractf128_ps(vacc4x01234567, 1); |
222 | |
223 | 0 | c0 += 4; |
224 | 0 | c1 += 4; |
225 | 0 | c2 += 4; |
226 | 0 | c3 += 4; |
227 | 0 | c4 += 4; |
228 | 0 | } |
229 | 0 | if (nc & 2) { |
230 | 0 | _mm_storel_pi((__m64*) c0, vacc0x0123); |
231 | 0 | _mm_storel_pi((__m64*) c1, vacc1x0123); |
232 | 0 | _mm_storel_pi((__m64*) c2, vacc2x0123); |
233 | 0 | _mm_storel_pi((__m64*) c3, vacc3x0123); |
234 | 0 | _mm_storel_pi((__m64*) c4, vacc4x0123); |
235 | |
236 | 0 | vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123); |
237 | 0 | vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123); |
238 | 0 | vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123); |
239 | 0 | vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123); |
240 | 0 | vacc4x0123 = _mm_movehl_ps(vacc4x0123, vacc4x0123); |
241 | |
|
242 | 0 | c0 += 2; |
243 | 0 | c1 += 2; |
244 | 0 | c2 += 2; |
245 | 0 | c3 += 2; |
246 | 0 | c4 += 2; |
247 | 0 | } |
248 | 0 | if (nc & 1) { |
249 | 0 | _mm_store_ss(c0, vacc0x0123); |
250 | 0 | _mm_store_ss(c1, vacc1x0123); |
251 | 0 | _mm_store_ss(c2, vacc2x0123); |
252 | 0 | _mm_store_ss(c3, vacc3x0123); |
253 | 0 | _mm_store_ss(c4, vacc4x0123); |
254 | 0 | } |
255 | |
256 | 0 | nc = 0; |
257 | 0 | } |
258 | 0 | } while (nc != 0); |
259 | 0 | } |
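
Nothing in this report exercises the kernel (every execution count is 0). The standalone driver below is a minimal sketch of one way the microkernel could be invoked, based only on what the listing itself implies: the packed-weight layout (per block of 16 columns: 16 floats used to initialize the accumulators, then one group of 16 int8 weights per k step, then 16 float scales) is inferred from the loads in the listing, and the main() driver, test values, and buffer names are hypothetical, not part of XNNPACK's packing API. It assumes compilation inside the XNNPACK source tree with AVX enabled.

// Minimal, hypothetical driver for the microkernel above. Everything here
// (sizes, test values, the packed layout) is inferred from the listing and
// is NOT XNNPACK's official packing API.
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#include "src/xnnpack/microparams.h"  // struct xnn_f32_minmax_params

// Prototype copied from the listing above.
void xnn_f32_qc8w_gemm_minmax_ukernel_5x16__avx_broadcast(
    size_t mr, size_t nc, size_t kc,
    const float* restrict a, size_t a_stride,
    const void* restrict w,
    float* restrict c, size_t cm_stride, size_t cn_stride,
    const struct xnn_f32_minmax_params* restrict params);

int main(void) {
  enum { M = 5, N = 16, K = 8 };  // M <= mr=5, N = one full 16-column block
  float a[M][K], c[M][N];
  int8_t wq[K][N];                 // int8 weights, indexed [k][n]
  float bias[N], scale[N];

  for (int m = 0; m < M; m++)
    for (int k = 0; k < K; k++) a[m][k] = 0.25f * (float) (m + k);
  for (int k = 0; k < K; k++)
    for (int n = 0; n < N; n++) wq[k][n] = (int8_t) ((k + n) % 7 - 3);
  for (int n = 0; n < N; n++) { bias[n] = 0.5f; scale[n] = 0.01f; }

  // Per-block packed layout implied by the kernel's loads:
  // 16 floats (accumulator init), K groups of 16 int8 weights, 16 float scales.
  char packed[sizeof(bias) + K * N + sizeof(scale)];
  char* p = packed;
  memcpy(p, bias, sizeof(bias)); p += sizeof(bias);
  for (int k = 0; k < K; k++) { memcpy(p, wq[k], N); p += N; }
  memcpy(p, scale, sizeof(scale));

  struct xnn_f32_minmax_params params;
  params.scalar.min = -100.0f;  // fields the kernel reads: params->scalar.min/max
  params.scalar.max = +100.0f;

  // kc and all strides are in bytes, as the asserts and pointer math above imply.
  xnn_f32_qc8w_gemm_minmax_ukernel_5x16__avx_broadcast(
      /*mr=*/M, /*nc=*/N, /*kc=*/K * sizeof(float),
      &a[0][0], /*a_stride=*/K * sizeof(float),
      packed,
      &c[0][0], /*cm_stride=*/N * sizeof(float), /*cn_stride=*/N * sizeof(float),
      &params);

  // Scalar reference for one element: scale * (bias + dot(a row, int8 column)),
  // which is what the kernel computes before clamping to [min, max].
  float ref = bias[0];
  for (int k = 0; k < K; k++) ref += a[0][k] * (float) wq[k][0];
  ref *= scale[0];
  printf("c[0][0] = %f, reference = %f\n", c[0][0], ref);
  return 0;
}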