/src/xnnpack/src/f32-igemm/gen/f32-igemm-5x8-minmax-avx-broadcast.c
Line | Count | Source |
1 | | // clang-format off |
2 | | // Auto-generated file. Do not edit! |
3 | | // Template: src/f32-igemm/avx-broadcast.c.in |
4 | | // Generator: tools/xngen |
5 | | // |
6 | | // Copyright 2019 Google LLC |
7 | | // |
8 | | // This source code is licensed under the BSD-style license found in the |
9 | | // LICENSE file in the root directory of this source tree. |
10 | | |
11 | | #include <assert.h> |
12 | | #include <stddef.h> |
13 | | #include <stdint.h> |
14 | | |
15 | | #include <immintrin.h> |
16 | | |
17 | | #include "src/xnnpack/common.h" |
18 | | #include "src/xnnpack/microparams.h" |
19 | | #include "src/xnnpack/igemm.h" |
20 | | |
21 | | |
22 | | void xnn_f32_igemm_minmax_ukernel_5x8__avx_broadcast( |
23 | | size_t mr, |
24 | | size_t nc, |
25 | | size_t kc, |
26 | | size_t ks, |
27 | | const float** restrict a, |
28 | | const float* restrict w, |
29 | | float* restrict c, |
30 | | size_t cm_stride, |
31 | | size_t cn_stride, |
32 | | size_t a_offset, |
33 | | const float* zero, |
34 | | const struct xnn_f32_minmax_params* restrict params) |
35 | 0 | { |
36 | 0 | assert(mr != 0); |
37 | 0 | assert(mr <= 5); |
38 | 0 | assert(nc != 0); |
39 | 0 | assert(kc != 0); |
40 | 0 | assert(kc % sizeof(float) == 0); |
41 | 0 | assert(ks != 0); |
42 | 0 | assert(ks % (5 * sizeof(void*)) == 0); |
43 | 0 | assert(a_offset % sizeof(float) == 0); |
44 | 0 | assert(a != NULL); |
45 | 0 | assert(w != NULL); |
46 | 0 | assert(c != NULL); |
47 | | |
48 | 0 | float* c0 = c; |
49 | 0 | float* c1 = (float*) ((uintptr_t) c0 + cm_stride); |
50 | 0 | if XNN_UNPREDICTABLE(mr < 2) { |
51 | 0 | c1 = c0; |
52 | 0 | } |
53 | 0 | float* c2 = (float*) ((uintptr_t) c1 + cm_stride); |
54 | 0 | if XNN_UNPREDICTABLE(mr <= 2) { |
55 | 0 | c2 = c1; |
56 | 0 | } |
57 | 0 | float* c3 = (float*) ((uintptr_t) c2 + cm_stride); |
58 | 0 | if XNN_UNPREDICTABLE(mr < 4) { |
59 | 0 | c3 = c2; |
60 | 0 | } |
61 | 0 | float* c4 = (float*) ((uintptr_t) c3 + cm_stride); |
62 | 0 | if XNN_UNPREDICTABLE(mr <= 4) { |
63 | 0 | c4 = c3; |
64 | 0 | } |
65 | | 
66 | 0 | const __m256 vmin = _mm256_set1_ps(params->scalar.min); |
67 | 0 | const __m256 vmax = _mm256_set1_ps(params->scalar.max); |
68 | 0 | XNN_FORCE_REALIZATION(vmin); |
69 | 0 | XNN_FORCE_REALIZATION(vmax); |
70 | | 
71 | 0 | do { |
72 | 0 | __m256 vacc0x01234567 = _mm256_load_ps(w); |
73 | 0 | __m256 vacc1x01234567 = vacc0x01234567; |
74 | 0 | __m256 vacc2x01234567 = vacc0x01234567; |
75 | 0 | __m256 vacc3x01234567 = vacc0x01234567; |
76 | 0 | __m256 vacc4x01234567 = vacc0x01234567; |
77 | 0 | w += 8; |
78 | | 
79 | 0 | size_t p = ks; |
80 | 0 | do { |
81 | 0 | const float* restrict a0 = a[0]; |
82 | 0 | assert(a0 != NULL); |
83 | 0 | if XNN_UNPREDICTABLE(a0 != zero) { |
84 | 0 | a0 = (const float*) ((uintptr_t) a0 + a_offset); |
85 | 0 | } |
86 | 0 | const float* restrict a1 = a[1]; |
87 | 0 | assert(a1 != NULL); |
88 | 0 | if XNN_UNPREDICTABLE(a1 != zero) { |
89 | 0 | a1 = (const float*) ((uintptr_t) a1 + a_offset); |
90 | 0 | } |
91 | 0 | const float* restrict a2 = a[2]; |
92 | 0 | assert(a2 != NULL); |
93 | 0 | if XNN_UNPREDICTABLE(a2 != zero) { |
94 | 0 | a2 = (const float*) ((uintptr_t) a2 + a_offset); |
95 | 0 | } |
96 | 0 | const float* restrict a3 = a[3]; |
97 | 0 | assert(a3 != NULL); |
98 | 0 | if XNN_UNPREDICTABLE(a3 != zero) { |
99 | 0 | a3 = (const float*) ((uintptr_t) a3 + a_offset); |
100 | 0 | } |
101 | 0 | const float* restrict a4 = a[4]; |
102 | 0 | assert(a4 != NULL); |
103 | 0 | if XNN_UNPREDICTABLE(a4 != zero) { |
104 | 0 | a4 = (const float*) ((uintptr_t) a4 + a_offset); |
105 | 0 | } |
106 | 0 | a += 5; |
107 | | 
108 | 0 | size_t k = kc; |
109 | 0 | do { |
110 | 0 | const __m256 vb01234567 = _mm256_load_ps(w); |
111 | 0 | w += 8; |
112 | | 
113 | 0 | const __m256 va0 = _mm256_broadcast_ss(a0); |
114 | 0 | a0 += 1; |
115 | 0 | const __m256 va1 = _mm256_broadcast_ss(a1); |
116 | 0 | a1 += 1; |
117 | 0 | const __m256 va2 = _mm256_broadcast_ss(a2); |
118 | 0 | a2 += 1; |
119 | 0 | const __m256 va3 = _mm256_broadcast_ss(a3); |
120 | 0 | a3 += 1; |
121 | 0 | const __m256 va4 = _mm256_broadcast_ss(a4); |
122 | 0 | a4 += 1; |
123 | | 
124 | 0 | vacc0x01234567 = _mm256_add_ps(vacc0x01234567, _mm256_mul_ps(va0, vb01234567)); |
125 | 0 | vacc1x01234567 = _mm256_add_ps(vacc1x01234567, _mm256_mul_ps(va1, vb01234567)); |
126 | 0 | vacc2x01234567 = _mm256_add_ps(vacc2x01234567, _mm256_mul_ps(va2, vb01234567)); |
127 | 0 | vacc3x01234567 = _mm256_add_ps(vacc3x01234567, _mm256_mul_ps(va3, vb01234567)); |
128 | 0 | vacc4x01234567 = _mm256_add_ps(vacc4x01234567, _mm256_mul_ps(va4, vb01234567)); |
129 | 0 | k -= sizeof(float); |
130 | 0 | } while (k != 0); |
131 | 0 | p -= 5 * sizeof(void*); |
132 | 0 | } while (p != 0); |
133 | | |
134 | 0 | vacc0x01234567 = _mm256_max_ps(vmin, vacc0x01234567); |
135 | 0 | vacc1x01234567 = _mm256_max_ps(vmin, vacc1x01234567); |
136 | 0 | vacc2x01234567 = _mm256_max_ps(vmin, vacc2x01234567); |
137 | 0 | vacc3x01234567 = _mm256_max_ps(vmin, vacc3x01234567); |
138 | 0 | vacc4x01234567 = _mm256_max_ps(vmin, vacc4x01234567); |
139 | | 
140 | 0 | vacc0x01234567 = _mm256_min_ps(vmax, vacc0x01234567); |
141 | 0 | vacc1x01234567 = _mm256_min_ps(vmax, vacc1x01234567); |
142 | 0 | vacc2x01234567 = _mm256_min_ps(vmax, vacc2x01234567); |
143 | 0 | vacc3x01234567 = _mm256_min_ps(vmax, vacc3x01234567); |
144 | 0 | vacc4x01234567 = _mm256_min_ps(vmax, vacc4x01234567); |
145 | | 
146 | 0 | if XNN_LIKELY(nc >= 8) { |
147 | 0 | _mm256_storeu_ps(c4, vacc4x01234567); |
148 | 0 | c4 = (float*) ((uintptr_t) c4 + cn_stride); |
149 | 0 | _mm256_storeu_ps(c3, vacc3x01234567); |
150 | 0 | c3 = (float*) ((uintptr_t) c3 + cn_stride); |
151 | 0 | _mm256_storeu_ps(c2, vacc2x01234567); |
152 | 0 | c2 = (float*) ((uintptr_t) c2 + cn_stride); |
153 | 0 | _mm256_storeu_ps(c1, vacc1x01234567); |
154 | 0 | c1 = (float*) ((uintptr_t) c1 + cn_stride); |
155 | 0 | _mm256_storeu_ps(c0, vacc0x01234567); |
156 | 0 | c0 = (float*) ((uintptr_t) c0 + cn_stride); |
157 | | 
158 | 0 | a = (const float**restrict) ((uintptr_t) a - ks); |
159 | 0 | nc -= 8; |
160 | 0 | } else { |
161 | 0 | __m128 vacc4x0123 = _mm256_castps256_ps128(vacc4x01234567); |
162 | 0 | __m128 vacc3x0123 = _mm256_castps256_ps128(vacc3x01234567); |
163 | 0 | __m128 vacc2x0123 = _mm256_castps256_ps128(vacc2x01234567); |
164 | 0 | __m128 vacc1x0123 = _mm256_castps256_ps128(vacc1x01234567); |
165 | 0 | __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567); |
166 | 0 | if (nc & 4) { |
167 | 0 | _mm_storeu_ps(c4, vacc4x0123); |
168 | 0 | _mm_storeu_ps(c3, vacc3x0123); |
169 | 0 | _mm_storeu_ps(c2, vacc2x0123); |
170 | 0 | _mm_storeu_ps(c1, vacc1x0123); |
171 | 0 | _mm_storeu_ps(c0, vacc0x0123); |
172 | | 
173 | 0 | vacc4x0123 = _mm256_extractf128_ps(vacc4x01234567, 1); |
174 | 0 | vacc3x0123 = _mm256_extractf128_ps(vacc3x01234567, 1); |
175 | 0 | vacc2x0123 = _mm256_extractf128_ps(vacc2x01234567, 1); |
176 | 0 | vacc1x0123 = _mm256_extractf128_ps(vacc1x01234567, 1); |
177 | 0 | vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1); |
178 | | 
179 | 0 | c4 += 4; |
180 | 0 | c3 += 4; |
181 | 0 | c2 += 4; |
182 | 0 | c1 += 4; |
183 | 0 | c0 += 4; |
184 | 0 | } |
185 | 0 | if (nc & 2) { |
186 | 0 | _mm_storel_pi((__m64*) c4, vacc4x0123); |
187 | 0 | _mm_storel_pi((__m64*) c3, vacc3x0123); |
188 | 0 | _mm_storel_pi((__m64*) c2, vacc2x0123); |
189 | 0 | _mm_storel_pi((__m64*) c1, vacc1x0123); |
190 | 0 | _mm_storel_pi((__m64*) c0, vacc0x0123); |
191 | | 
192 | 0 | vacc4x0123 = _mm_movehl_ps(vacc4x0123, vacc4x0123); |
193 | 0 | vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123); |
194 | 0 | vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123); |
195 | 0 | vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123); |
196 | 0 | vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123); |
197 | | 
198 | 0 | c4 += 2; |
199 | 0 | c3 += 2; |
200 | 0 | c2 += 2; |
201 | 0 | c1 += 2; |
202 | 0 | c0 += 2; |
203 | 0 | } |
204 | 0 | if (nc & 1) { |
205 | 0 | _mm_store_ss(c4, vacc4x0123); |
206 | 0 | _mm_store_ss(c3, vacc3x0123); |
207 | 0 | _mm_store_ss(c2, vacc2x0123); |
208 | 0 | _mm_store_ss(c1, vacc1x0123); |
209 | 0 | _mm_store_ss(c0, vacc0x0123); |
210 | 0 | } |
211 | | 
212 | 0 | nc = 0; |
213 | 0 | } |
214 | 0 | } while (nc != 0); |
215 | 0 | } |
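
For reference, the sketch below shows one plausible way to drive this microkernel directly for a single 5x8 output tile. It is an illustrative guess, not XNNPACK's own packing or invocation code: the packed-weight layout (8 bias floats followed by 8 weights per k step), kc being a byte count, and ks = 5 * sizeof(void*) for a single indirection pass are inferred from the kernel body above; the params field names (params->scalar.min/max) are taken from source lines 66-67, and the include paths are the ones the generated file itself uses. Building it requires the XNNPACK source tree and an AVX-capable target.

// Illustrative driver (sketch, not part of XNNPACK): one 5x8 tile, K = 2.
#include <math.h>
#include <stdalign.h>
#include <stddef.h>
#include <stdio.h>

#include "src/xnnpack/igemm.h"        // assumed to declare the microkernel
#include "src/xnnpack/microparams.h"  // assumed to define xnn_f32_minmax_params

int main(void) {
  enum { MR = 5, NR = 8, K = 2 };

  // Packed weights: 8 bias values, then 8 weights per k step (layout inferred
  // from the kernel's read order). Must be 32-byte aligned for _mm256_load_ps.
  alignas(32) float w[NR + K * NR];
  for (size_t i = 0; i < NR; i++) w[i] = 0.0f;           // bias
  for (size_t i = 0; i < K * NR; i++) w[NR + i] = 1.0f;  // unit weights

  // One input row per output row, plus the indirection buffer of MR pointers.
  float a0[K] = {1.0f, 2.0f}, a1[K] = {1.0f, 2.0f}, a2[K] = {1.0f, 2.0f},
        a3[K] = {1.0f, 2.0f}, a4[K] = {1.0f, 2.0f};
  const float* a[MR] = {a0, a1, a2, a3, a4};

  float c[MR * NR];
  struct xnn_f32_minmax_params params;
  params.scalar.min = -INFINITY;  // field names taken from lines 66-67
  params.scalar.max = +INFINITY;

  xnn_f32_igemm_minmax_ukernel_5x8__avx_broadcast(
      /*mr=*/MR, /*nc=*/NR,
      /*kc=*/K * sizeof(float),    // reduction length in bytes
      /*ks=*/MR * sizeof(void*),   // one indirection pass of MR pointers
      a, w, c,
      /*cm_stride=*/NR * sizeof(float),
      /*cn_stride=*/NR * sizeof(float),
      /*a_offset=*/0,
      /*zero=*/NULL,               // no padding entries in the indirection buffer
      &params);

  printf("c[0][0] = %g\n", c[0]);  // 0 + 1*1 + 2*1 = 3 with unit weights
  return 0;
}

With unit weights and zero bias, every output element should equal the sum of the corresponding input row (3.0 here), which gives a quick sanity check that the assumed weight layout matches what the kernel actually reads.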