/src/xnnpack/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x4c8-minmax-sse41-ld128.c
// clang-format off
// Auto-generated file. Do not edit!
//   Template: src/qs8-gemm/MRx4c8-sse.c.in
//   Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <smmintrin.h>

#include "src/xnnpack/common.h"
#include "src/xnnpack/gemm.h"
#include "src/xnnpack/math.h"
#include "src/xnnpack/microparams.h"


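// GEMM microkernel for a 1x4 output tile: one row of dynamically quantized
// int8 activations (qd8) times 4-bit blockwise-quantized weights (qb4w),
// producing f32 outputs with min/max clamping. "c8" means 8 bytes of K are
// consumed per accumulator update; "ld128" means packed weights are fetched
// with 128-bit loads, two 4-bit values per byte.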
void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld128(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    float* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const struct xnn_f32_qb4w_minmax_params* restrict params,
    const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= 1);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

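  // Weights are quantized in blocks of `bl` elements along K, each block with
  // its own scale. The blocksize must be a nonzero multiple of 32 that does
  // not exceed the padded K dimension.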
  size_t bl = params->scalar.blocksize;
  assert(bl <= round_up_po2(kc, 2));
  assert(bl != 0);
  assert(bl % 32 == 0);
  kc = round_up_po2(kc, 8 * sizeof(int8_t));
  const int8_t* a0 = a;
  float* c0 = c;

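  // Output clamping bounds, broadcast to all four lanes.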
  const __m128 vmin = _mm_set1_ps(params->scalar.min);
  const __m128 vmax = _mm_set1_ps(params->scalar.max);
  XNN_FORCE_REALIZATION(vmin);
  XNN_FORCE_REALIZATION(vmax);

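  // Each packed byte holds two 4-bit weights. vmask (0xF0) keeps a nibble in
  // the *high* half of its byte, so every extracted weight is effectively
  // scaled by 16; the packed per-block scales are expected to compensate for
  // this factor.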
  const __m128i vmask = _mm_set1_epi8(0xF0);
  XNN_FORCE_REALIZATION(vmask);

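  // Per-tile loop over N. The packed weight stream for a 4-column tile is
  // laid out exactly as it is consumed below: 4 f32 weight-sum terms (ksum),
  // then for each block 32 bytes of packed nibbles followed by 4 bf16 scales,
  // and finally 4 f32 biases. Seeding the output with vksum *
  // input_zero_point cancels the activation zero point out of the integer
  // dot products (the sums are pre-scaled by the packing routine as needed).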
  do {
    const __m128 vksum = _mm_loadu_ps((const float*) w);
    __m128i vinput_zero_point0 = _mm_cvtsi32_si128(*((const int*) &quantization_params[0].zero_point));
    vinput_zero_point0 = _mm_shuffle_epi32(vinput_zero_point0, _MM_SHUFFLE(0, 0, 0, 0));

    __m128 vinput_zero_point0_float = _mm_cvtepi32_ps(vinput_zero_point0);
    __m128 vout0x0123 = _mm_mul_ps(vksum, vinput_zero_point0_float);
    w = (const int32_t*) w + 4;

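    // Walk K one block at a time: accumulate int32 dot products for the four
    // columns, then fold the block into the f32 output with its scale.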
    for (size_t kb = 0; kb < kc; kb += bl) {
      __m128i vacc0x0 = _mm_setzero_si128();
      __m128i vacc0x1 = _mm_setzero_si128();
      __m128i vacc0x2 = _mm_setzero_si128();
      __m128i vacc0x3 = _mm_setzero_si128();
      size_t k = bl;

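      // Main loop: 16 K per iteration. Each 16-byte load covers two columns
      // for both 8-K halves: the low nibbles belong to the first half (c0),
      // the high nibbles to the second half (c1).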
      while (k >= 16 * sizeof(int8_t)) {
        const __m128i va0c0 = _mm_loadl_epi64((const __m128i*) a0);
        const __m128i vxa0c0 = _mm_cvtepi8_epi16(va0c0);
        a0 += 8;

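        // Shift left by 4 to move the low nibbles into the high half of each
        // byte, mask, then sign-extend bytes to int16: vsb = (0 > vb) yields
        // 0xFF for negative bytes, and unpacklo/unpackhi interleave value and
        // sign bytes into 16-bit lanes.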
        const __m128i vb01c01 = _mm_loadu_si128((const __m128i*) w);
        const __m128i vbs01c0 = _mm_slli_epi32(vb01c01, 4);
        const __m128i vb01c0 = _mm_and_si128(vbs01c0, vmask);
        const __m128i vsb01c0 = _mm_cmpgt_epi8(_mm_setzero_si128(), vb01c0);
        const __m128i vxb0c0 = _mm_unpacklo_epi8(vb01c0, vsb01c0);
        const __m128i vxb1c0 = _mm_unpackhi_epi8(vb01c0, vsb01c0);

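        // _mm_madd_epi16 multiplies int16 pairs and sums adjacent products
        // into int32 lanes, so each vacc0xN accumulates partial dot products
        // for one output column.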
        vacc0x0 = _mm_add_epi32(vacc0x0, _mm_madd_epi16(vxa0c0, vxb0c0));
        vacc0x1 = _mm_add_epi32(vacc0x1, _mm_madd_epi16(vxa0c0, vxb1c0));
        const __m128i vb23c01 = _mm_loadu_si128((const __m128i*) ((const int8_t*) w + 16));
        const __m128i vbs23c0 = _mm_slli_epi32(vb23c01, 4);
        const __m128i vb23c0 = _mm_and_si128(vbs23c0, vmask);
        const __m128i vsb23c0 = _mm_cmpgt_epi8(_mm_setzero_si128(), vb23c0);
        const __m128i vxb2c0 = _mm_unpacklo_epi8(vb23c0, vsb23c0);
        const __m128i vxb3c0 = _mm_unpackhi_epi8(vb23c0, vsb23c0);

        vacc0x2 = _mm_add_epi32(vacc0x2, _mm_madd_epi16(vxa0c0, vxb2c0));
        vacc0x3 = _mm_add_epi32(vacc0x3, _mm_madd_epi16(vxa0c0, vxb3c0));

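        // Second 8 K of the same 16-byte loads: masking without the shift
        // selects the high nibbles (c1).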
        const __m128i va0c1 = _mm_loadl_epi64((const __m128i*) a0);
        const __m128i vxa0c1 = _mm_cvtepi8_epi16(va0c1);
        a0 += 8;

        const __m128i vb01c1 = _mm_and_si128(vb01c01, vmask);
        const __m128i vsb01c1 = _mm_cmpgt_epi8(_mm_setzero_si128(), vb01c1);
        const __m128i vxb0c1 = _mm_unpacklo_epi8(vb01c1, vsb01c1);
        const __m128i vxb1c1 = _mm_unpackhi_epi8(vb01c1, vsb01c1);

        vacc0x0 = _mm_add_epi32(vacc0x0, _mm_madd_epi16(vxa0c1, vxb0c1));
        vacc0x1 = _mm_add_epi32(vacc0x1, _mm_madd_epi16(vxa0c1, vxb1c1));
        const __m128i vb23c1 = _mm_and_si128(vb23c01, vmask);
        const __m128i vsb23c1 = _mm_cmpgt_epi8(_mm_setzero_si128(), vb23c1);
        const __m128i vxb2c1 = _mm_unpacklo_epi8(vb23c1, vsb23c1);
        const __m128i vxb3c1 = _mm_unpackhi_epi8(vb23c1, vsb23c1);

        vacc0x2 = _mm_add_epi32(vacc0x2, _mm_madd_epi16(vxa0c1, vxb2c1));
        vacc0x3 = _mm_add_epi32(vacc0x3, _mm_madd_epi16(vxa0c1, vxb3c1));

        w = (const int8_t*) w + 32;
        k -= 16 * sizeof(int8_t);
      }

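      // Remainder: 8 K per iteration, consuming only the low nibbles of each
      // load. Since bl is asserted to be a multiple of 32, the 16-K loop
      // drains k completely and this loop appears to be unreachable here; it
      // is retained from the shared template.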
      while (k >= 8 * sizeof(int8_t)) {
        const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
        const __m128i vxa0 = _mm_cvtepi8_epi16(va0);
        a0 += 8;

        __m128i vb01 = _mm_loadu_si128((const __m128i*) w);
        vb01 = _mm_slli_epi32(vb01, 4);
        vb01 = _mm_and_si128(vb01, vmask);

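        // Alternative sign-extension: column 0 via pmovsxbw, column 1 by
        // duplicating the high bytes (unpackhi with itself) and shifting
        // arithmetically right by 8.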
        const __m128i vxbm1 = _mm_unpackhi_epi8(vb01, vb01);
        const __m128i vxb0 = _mm_cvtepi8_epi16(vb01);
        const __m128i vxb1 = _mm_srai_epi16(vxbm1, 8);

        vacc0x0 = _mm_add_epi32(vacc0x0, _mm_madd_epi16(vxa0, vxb0));
        vacc0x1 = _mm_add_epi32(vacc0x1, _mm_madd_epi16(vxa0, vxb1));
        __m128i vb23 = _mm_loadu_si128((const __m128i*) ((const int8_t*) w + 16));
        vb23 = _mm_slli_epi32(vb23, 4);
        vb23 = _mm_and_si128(vb23, vmask);

        const __m128i vxbm3 = _mm_unpackhi_epi8(vb23, vb23);
        const __m128i vxb2 = _mm_cvtepi8_epi16(vb23);
        const __m128i vxb3 = _mm_srai_epi16(vxbm3, 8);

        vacc0x2 = _mm_add_epi32(vacc0x2, _mm_madd_epi16(vxa0, vxb2));
        vacc0x3 = _mm_add_epi32(vacc0x3, _mm_madd_epi16(vxa0, vxb3));

        w = (const int8_t*) w + 32;
        k -= 8 * sizeof(int8_t);
      }
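      // Per-block scales are stored as bf16 (the upper 16 bits of an f32).
      // Widening the 4 uint16 values to 32 bits and shifting left by 16
      // reconstructs the f32 scales.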
      // accumulate float
      const __m128 vfilter_output_scale0123 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*) w)), 16));
      w = (const uint16_t*) w + 4;

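      // Pairwise horizontal adds collapse the four per-column accumulators
      // into one vector of int32 dot products, ordered [c0, c1, c2, c3].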
      const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
      const __m128i vacc0x23 = _mm_hadd_epi32(vacc0x2, vacc0x3);

      __m128i vacc0x0123 = _mm_hadd_epi32(vacc0x01, vacc0x23);

      vout0x0123 = _mm_add_ps(vout0x0123, _mm_mul_ps(_mm_cvtepi32_ps(vacc0x0123), vfilter_output_scale0123));
    }

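    // The activation row was dynamically quantized; multiplying by the row's
    // inverse scale converts the accumulated result back to real values.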
    const __m128 vinput_scale0 = _mm_load1_ps(&quantization_params[0].inv_scale);

    vout0x0123 = _mm_mul_ps(vout0x0123, vinput_scale0);

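    // Add the per-column f32 bias, stored after the last block's data.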
    const __m128 vbias0123 = _mm_loadu_ps((const float*) w);
    w = (const float*) w + 4;
    vout0x0123 = _mm_add_ps(vout0x0123, vbias0123);

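    // Clamp to [min, max].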
    vout0x0123 = _mm_max_ps(vout0x0123, vmin);

    vout0x0123 = _mm_min_ps(vout0x0123, vmax);

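    // Full tile: store 4 outputs, rewind a0 by the padded kc for the next
    // column tile, and advance the output pointer. For a partial tile
    // (nc < 4), store 2 and then 1 lanes as needed.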
    if XNN_LIKELY(nc >= 4) {
      _mm_storeu_ps(c0, vout0x0123);

      a0 = (const int8_t*) ((uintptr_t) a0 - kc);

      c0 = (float*) ((uintptr_t) c0 + cn_stride);

      nc -= 4;
    } else {
      if (nc & 2) {
        _mm_storel_pi((__m64*) c0, vout0x0123);
        vout0x0123 = _mm_unpackhi_ps(vout0x0123, vout0x0123);
        c0 += 2;
      }
      if (nc & 1) {
        _mm_store_ss(c0, vout0x0123);
      }
      nc = 0;
    }
  } while (nc != 0);
}