/src/xnnpack/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnni.c
// clang-format off
// Auto-generated file. Do not edit!
//   Template: src/qs8-gemm/MRx8c8-avxvnni.c.in
//   Generator: tools/xngen
//
// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <immintrin.h>

#include "src/xnnpack/common.h"
#include "src/xnnpack/gemm.h"
#include "src/xnnpack/intrinsics-polyfill.h"
#include "src/xnnpack/math.h"
#include "src/xnnpack/microparams.h"
#include "src/xnnpack/unaligned.h"

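// Computes a 1x8 tile of a GEMM: dynamically quantized int8 activations
// (qd8) times per-channel-quantized 4-bit weights (qc4w), producing fp16
// outputs, using 256-bit AVX-VNNI dot-product instructions.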
void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    xnn_float16* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const struct xnn_f16_qc4w_minmax_params* restrict params,
    const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= 1);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

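  // Round K up to a multiple of 8 bytes: the kernel reads activations 8 int8
  // values at a time (and may over-read past kc, per XNN_OOB_READS).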
  kc = round_up_po2(kc, 8 * sizeof(int8_t));
  const int8_t* a0 = a;
  uint16_t* c0 = (uint16_t*) c;
|
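  // Hoist loop-invariant constants: the row-0 activation zero point and the
  // fp16 output clamping bounds, widened to fp32. vmask (0xF0) later isolates
  // 4-bit weight nibbles in the high-nibble position of each byte.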
  const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point);
  const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min));
  const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max));
  // XNN_FORCE_REALIZATION(voutput_min);
  // XNN_FORCE_REALIZATION(voutput_max);
  const __m256i vmask = _mm256_set1_epi8(0xF0);
  XNN_FORCE_REALIZATION(vmask);
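
  // Each iteration of the N loop produces one 8-column tile. The accumulators
  // are seeded with the zero-point correction (packed per-column weight sums
  // times the activation zero point); cvtepu32_epi64 widens the seeds into
  // the even 32-bit lanes, since each column's dot product is spread across
  // two adjacent lanes until the pairwise horizontal add at the end.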
  do {
    const __m256i vksum01234567 = _mm256_load_si256(w);
    __m256i vsum0x01234567 = _mm256_mullo_epi32(vksum01234567, vinput_zero_point0);
    __m256i vacc0x0123 = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(vsum0x01234567, 0));
    __m256i vacc0x4567 = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(vsum0x01234567, 1));
    __m256i vacc1x0x0123 = _mm256_setzero_si256();
    __m256i vacc1x0x4567 = _mm256_setzero_si256();
    w = (const int32_t*) w + 8;
|
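    // Main K loop: 16 int8 activations and 64 bytes of packed weights (two
    // 4-bit values per byte) per iteration. A second accumulator pair
    // (vacc1x0x...) lets two VPDPBUSD chains run independently; it is merged
    // back in after the loop.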
    size_t k = kc;
    while (k >= 16 * sizeof(int8_t)) {
      const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0));
      const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8));
      a0 += 16;
|
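      // Unpack the 4-bit weights: a left shift by 4 moves each low nibble
      // into the high-nibble position, and masking with 0xF0 isolates the
      // high nibbles in place. Both forms hold the weight value scaled by 16.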
      const __m256i vbb01234567x01234567 = _mm256_load_si256(w);
      const __m256i vbb89ABCDEFx01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32));
      const __m256i vbs01234567x0123 = _mm256_slli_epi32(vbb01234567x01234567, 4);
      const __m256i vbs89ABCDEFx0123 = _mm256_slli_epi32(vbb89ABCDEFx01234567, 4);
      const __m256i vb01234567x4567 = _mm256_and_si256(vbb01234567x01234567, vmask);
      const __m256i vb89ABCDEFx4567 = _mm256_and_si256(vbb89ABCDEFx01234567, vmask);
      const __m256i vb01234567x0123 = _mm256_and_si256(vbs01234567x0123, vmask);
      const __m256i vb89ABCDEFx0123 = _mm256_and_si256(vbs89ABCDEFx0123, vmask);
|
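      // VPDPBUSD treats the activation bytes as unsigned and the weight bytes
      // as signed, accumulating four adjacent products into each 32-bit lane.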
      vacc0x0123 = _mm256_dpbusd_epi32(vacc0x0123, va0x01234567, vb01234567x0123);
      vacc0x4567 = _mm256_dpbusd_epi32(vacc0x4567, va0x01234567, vb89ABCDEFx0123);
      vacc1x0x0123 = _mm256_dpbusd_epi32(vacc1x0x0123, va0x89ABCDEF, vb01234567x4567);
      vacc1x0x4567 = _mm256_dpbusd_epi32(vacc1x0x4567, va0x89ABCDEF, vb89ABCDEFx4567);
|
      w = (const int8_t*) w + 64;
      k -= 16 * sizeof(int8_t);
    }
|
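    // K remainder: kc is a multiple of 8, so at most 8 activations remain.
    // Only the shifted (low-nibble) weights are consumed here.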
    if (k != 0) {
      const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0));
      a0 += 8;
|
      const __m256i vbb01234567x01234567 = _mm256_load_si256(w);
      const __m256i vbb89ABCDEFx01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32));
      const __m256i vb01234567x0123 = _mm256_slli_epi32(vbb01234567x01234567, 4);
      const __m256i vb89ABCDEFx0123 = _mm256_slli_epi32(vbb89ABCDEFx01234567, 4);
|
      vacc0x0123 = _mm256_dpbusd_epi32(vacc0x0123, va0x01234567, vb01234567x0123);
      vacc0x4567 = _mm256_dpbusd_epi32(vacc0x4567, va0x01234567, vb89ABCDEFx0123);
|
      w = (const int8_t*) w + 64;
      k -= 8 * sizeof(int8_t);
    }
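    // Merge the second accumulator pair from the 16-wide K loop.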
    vacc0x0123 = _mm256_add_epi32(vacc0x0123, vacc1x0x0123);
    vacc0x4567 = _mm256_add_epi32(vacc0x4567, vacc1x0x4567);

    // Add adjacent pairs: each column's result is split across two adjacent
    // 32-bit lanes, so hadd folds the pairs and the 64-bit permute restores
    // sequential column order 0-7.
    const __m256i vsum0x02134657 = _mm256_hadd_epi32(vacc0x0123, vacc0x4567);
    __m256i vacc0x01234567 = _mm256_permute4x64_epi64(vsum0x02134657, _MM_SHUFFLE(3, 1, 2, 0));
|
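    // The accumulators hold 16x the intended values because the nibbles were
    // processed in the high-nibble position; an arithmetic shift right by 4
    // rescales them before the fp32 convert.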
    vacc0x01234567 = _mm256_srai_epi32(vacc0x01234567, 4);
    __m256 vout0x01234567 = _mm256_cvtepi32_ps(vacc0x01234567);
|
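    // Dequantize: scale by the row's inverse activation quantization scale.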
    vout0x01234567 = _mm256_mul_ps(vout0x01234567, _mm256_set1_ps(quantization_params[0].inv_scale));
|
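    // The per-channel weight scales and biases follow the packed weights in
    // w; they are applied below with a single FMA.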
    const __m256 vfilter_output_scale01234567 = _mm256_load_ps((const float*) w);
    const __m256 vbias01234567 = _mm256_load_ps((const float*) w + 8);
    w = (const float*) w + 16;
|
    vout0x01234567 = _mm256_fmadd_ps(vout0x01234567, vfilter_output_scale01234567, vbias01234567);
|
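    // Clamp to the requested output range.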
    vout0x01234567 = _mm256_max_ps(vout0x01234567, voutput_min);
    vout0x01234567 = _mm256_min_ps(vout0x01234567, voutput_max);
|
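    // Convert to fp16 (round-to-nearest-even) and store; a masked store
    // handles the final columns when nc < 8.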
    __m128i vfp16out0x01234567 = _mm256_cvtps_ph(vout0x01234567, _MM_FROUND_TO_NEAREST_INT);
    if XNN_LIKELY(nc >= 8) {
      _mm_storeu_si128((__m128i*) c0, vfp16out0x01234567);
      a0 = (const int8_t*) ((uintptr_t) a0 - kc);
      c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);
      nc -= 8;
    } else {
      // Prepare mask for valid 16-bit elements (depends on nc).
      const __mmask8 vmask = _cvtu32_mask8((UINT32_C(1) << nc) - 1);
      _mm_mask_storeu_epi16(c0, vmask, vfp16out0x01234567);
      nc = 0;
    }
  } while (nc != 0);
}