Coverage Report

Created: 2025-09-27 07:04

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/xnnpack/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x4c8-minmax-sse41-ld128.c
Line
Count
Source
1
// clang-format off
2
// Auto-generated file. Do not edit!
3
//   Template: src/qs8-gemm/MRx4c8-sse.c.in
4
//   Generator: tools/xngen
5
//
6
// Copyright 2020 Google LLC
7
//
8
// This source code is licensed under the BSD-style license found in the
9
// LICENSE file in the root directory of this source tree.
10
11
#include <assert.h>
12
#include <stddef.h>
13
#include <stdint.h>
14
15
#include <smmintrin.h>
16
17
#include "src/xnnpack/common.h"
18
#include "src/xnnpack/gemm.h"
19
#include "src/xnnpack/math.h"
20
#include "src/xnnpack/microparams.h"
21
22
23
// GEMM microkernel: multiplies one row of int8 inputs by 4-bit packed weights,
// 4 output columns per outer iteration, and writes float results clamped to
// [params->scalar.min, params->scalar.max].
//
// Parameters:
//   mr        - row count; asserted to be non-zero and <= 1 (only row 0 is used).
//   nc        - number of output columns still to produce; consumed 4 at a time.
//   kc        - reduction length in bytes of `a`; rounded up internally (see below).
//   a         - int8 input row.
//   a_stride  - not referenced in this body.
//   w         - packed weight stream. Per group of 4 columns this body reads, in
//               order: 4 floats (scaled by the input zero point), then for each
//               block the packed 4-bit weight bytes followed by four 16-bit
//               scales, and finally 4 float bias values.
//   c         - float output.
//   cm_stride - not referenced in this body.
//   cn_stride - byte offset added to the output pointer after a full 4-column store.
//   params    - supplies scalar.blocksize, scalar.min and scalar.max.
//   quantization_params - element 0 supplies zero_point and inv_scale.
void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld128(
24
    size_t mr,
25
    size_t nc,
26
    size_t kc,
27
    const int8_t* restrict a,
28
    size_t a_stride,
29
    const void* restrict w,
30
    float* restrict c,
31
    size_t cm_stride,
32
    size_t cn_stride,
33
    const struct xnn_f32_qb4w_minmax_params* restrict params,
34
    const struct xnn_qd8_quantization_params* restrict quantization_params) XNN_OOB_READS
35
0
{
36
0
  assert(mr != 0);
37
0
  assert(mr <= 1);
38
0
  assert(nc != 0);
39
0
  assert(kc != 0);
40
0
  assert(kc % sizeof(int8_t) == 0);
41
0
  assert(a != NULL);
42
0
  assert(w != NULL);
43
0
  assert(c != NULL);
44
45
0
  // bl is the number of input bytes covered by one weight block (one scale set).
  size_t bl = params->scalar.blocksize;
46
0
  assert(bl <= round_up_po2(kc, 2));
47
0
  assert(bl != 0);
48
0
  assert(bl % 32 == 0);
49
0
  // round_up_po2 is defined elsewhere; presumably rounds kc up to a multiple of 8.
  kc = round_up_po2(kc, 8 * sizeof(int8_t));
50
0
  const int8_t* a0 = a;
51
0
  float* c0 = c;
52
53
0
  // Output clamp bounds broadcast to all four float lanes.
  const __m128 vmin = _mm_set1_ps(params->scalar.min);
54
0
  const __m128 vmax = _mm_set1_ps(params->scalar.max);
55
0
  XNN_FORCE_REALIZATION(vmin);
56
0
  XNN_FORCE_REALIZATION(vmax);
57
58
0
  // 0xF0 in every byte: keeps only the upper nibble of each byte.
  const __m128i vmask = _mm_set1_epi8(0xF0);
59
0
  XNN_FORCE_REALIZATION(vmask);
60
61
0
  // One iteration per group of (up to) 4 output columns.
  do {
62
0
    // The first 4 floats of the group (named ksum) seed the output once
    // multiplied by the input zero point.
    const __m128 vksum = _mm_loadu_ps((const float*) w);
63
0
    // zero_point is read through an int pointer; its declared type is not
    // visible here - NOTE(review): confirm it is a 32-bit int.
    __m128i vinput_zero_point0 = _mm_cvtsi32_si128(*((const int*) &quantization_params[0].zero_point));
64
0
    // Broadcast lane 0 to all four 32-bit lanes.
    vinput_zero_point0 = _mm_shuffle_epi32(vinput_zero_point0, _MM_SHUFFLE(0, 0, 0, 0));
65
66
0
    __m128 vinput_zero_point0_float = _mm_cvtepi32_ps(vinput_zero_point0);
67
0
    __m128 vout0x0123 = _mm_mul_ps(vksum, vinput_zero_point0_float);
68
0
    // Skip the 16 bytes just consumed.
    w = (const int32_t*) w + 4;
69
70
0
    // Walk the reduction dimension one block (bl input bytes) at a time; each
    // block has its own integer accumulators and its own scales.
    for (size_t kb=0; kb < kc; kb += bl) {
71
0
      // One 4x32-bit partial-sum vector per output column.
      __m128i vacc0x0 = _mm_setzero_si128();
72
0
      __m128i vacc0x1 = _mm_setzero_si128();
73
0
      __m128i vacc0x2 = _mm_setzero_si128();
74
0
      __m128i vacc0x3 = _mm_setzero_si128();
75
0
      size_t k = bl;
76
77
0
      // Main loop: 16 input bytes against 32 weight bytes (64 nibbles) per pass.
      while (k >= 16 * sizeof(int8_t)) {
78
0
        // First 8 input bytes, sign-extended to 16 bits.
        const __m128i va0c0 = _mm_loadl_epi64((const __m128i*) a0);
79
0
        const __m128i vxa0c0 = _mm_cvtepi8_epi16(va0c0);
80
0
        a0 += 8;
81
82
0
        // 16 weight bytes: low 8 bytes feed column 0, high 8 bytes feed column 1.
        const __m128i vb01c01 = _mm_loadu_si128((const __m128i*) w);
83
0
        // Shift left by 4 to move each byte's low nibble into its high-nibble
        // position; the mask then drops bits that crossed in from the
        // neighbouring byte. Each weight is thus held as (nibble << 4).
        const __m128i vbs01c0 = _mm_slli_epi32(vb01c01, 4);
84
0
        const __m128i vb01c0 = _mm_and_si128(vbs01c0, vmask);
85
0
        // 0xFF for negative bytes, 0x00 otherwise: the sign-extension byte.
        const __m128i vsb01c0 = _mm_cmpgt_epi8(_mm_setzero_si128(), vb01c0);
86
0
        // Interleave value and sign bytes to form sign-extended 16-bit lanes.
        const __m128i vxb0c0 = _mm_unpacklo_epi8(vb01c0, vsb01c0);
87
0
        const __m128i vxb1c0 = _mm_unpackhi_epi8(vb01c0, vsb01c0);
88
89
0
        // madd: multiply 16-bit lanes and add adjacent pairs into 32-bit sums.
        vacc0x0 = _mm_add_epi32(vacc0x0, _mm_madd_epi16(vxa0c0, vxb0c0));
90
0
        vacc0x1 = _mm_add_epi32(vacc0x1, _mm_madd_epi16(vxa0c0, vxb1c0));
91
0
        // Next 16 weight bytes: same treatment for columns 2 and 3.
        const __m128i vb23c01 = _mm_loadu_si128((const __m128i*) ((const int8_t*) w + 16));
92
0
        const __m128i vbs23c0 = _mm_slli_epi32(vb23c01, 4);
93
0
        const __m128i vb23c0 = _mm_and_si128(vbs23c0, vmask);
94
0
        const __m128i vsb23c0 = _mm_cmpgt_epi8(_mm_setzero_si128(), vb23c0);
95
0
        const __m128i vxb2c0 = _mm_unpacklo_epi8(vb23c0, vsb23c0);
96
0
        const __m128i vxb3c0 = _mm_unpackhi_epi8(vb23c0, vsb23c0);
97
98
0
        vacc0x2 = _mm_add_epi32(vacc0x2, _mm_madd_epi16(vxa0c0, vxb2c0));
99
0
        vacc0x3 = _mm_add_epi32(vacc0x3, _mm_madd_epi16(vxa0c0, vxb3c0));
100
101
0
        // Second 8 input bytes pair with the upper nibbles of the same weight bytes.
        const __m128i va0c1 = _mm_loadl_epi64((const __m128i*) a0);
102
0
        const __m128i vxa0c1 = _mm_cvtepi8_epi16(va0c1);
103
0
        a0 += 8;
104
105
0
        // Upper nibbles are already in place, so only the mask is needed (no shift).
        const __m128i vb01c1 = _mm_and_si128(vb01c01, vmask);
106
0
        const __m128i vsb01c1 = _mm_cmpgt_epi8(_mm_setzero_si128(), vb01c1);
107
0
        const __m128i vxb0c1 = _mm_unpacklo_epi8(vb01c1, vsb01c1);
108
0
        const __m128i vxb1c1 = _mm_unpackhi_epi8(vb01c1, vsb01c1);
109
110
0
        vacc0x0 = _mm_add_epi32(vacc0x0, _mm_madd_epi16(vxa0c1, vxb0c1));
111
0
        vacc0x1 = _mm_add_epi32(vacc0x1, _mm_madd_epi16(vxa0c1, vxb1c1));
112
0
        const __m128i vb23c1 = _mm_and_si128(vb23c01, vmask);
113
0
        const __m128i vsb23c1 = _mm_cmpgt_epi8(_mm_setzero_si128(), vb23c1);
114
0
        const __m128i vxb2c1 = _mm_unpacklo_epi8(vb23c1, vsb23c1);
115
0
        const __m128i vxb3c1 = _mm_unpackhi_epi8(vb23c1, vsb23c1);
116
117
0
        vacc0x2 = _mm_add_epi32(vacc0x2, _mm_madd_epi16(vxa0c1, vxb2c1));
118
0
        vacc0x3 = _mm_add_epi32(vacc0x3, _mm_madd_epi16(vxa0c1, vxb3c1));
119
120
121
0
        w = (const int8_t*) w + 32;
122
0
        k -= 16 * sizeof(int8_t);
123
0
      }
124
125
0
      // Remainder loop for 8 leftover input bytes; only the low nibbles of the
      // 32 weight bytes are used. NOTE(review): k starts at bl, which is asserted
      // to be a multiple of 32, and the loop above removes 16 per pass, so k is 0
      // here whenever that assertion holds - this loop looks unreachable.
      while (k >= 8 * sizeof(int8_t)) {
126
0
        const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
127
0
        const __m128i vxa0 = _mm_cvtepi8_epi16(va0);
128
0
        a0 += 8;
129
130
0
        __m128i vb01 = _mm_loadu_si128((const __m128i*) w);
131
0
        vb01 = _mm_slli_epi32(vb01, 4);
132
0
        vb01 = _mm_and_si128(vb01, vmask);
133
134
0
        // Duplicate each high byte into a 16-bit lane, then arithmetic-shift
        // right by 8 (below) to obtain the sign-extended value.
        const __m128i vxbm1 = _mm_unpackhi_epi8(vb01, vb01);
135
0
        const __m128i vxb0 = _mm_cvtepi8_epi16(vb01);
136
0
        const __m128i vxb1 = _mm_srai_epi16(vxbm1, 8);
137
138
0
        vacc0x0 = _mm_add_epi32(vacc0x0, _mm_madd_epi16(vxa0, vxb0));
139
0
        vacc0x1 = _mm_add_epi32(vacc0x1, _mm_madd_epi16(vxa0, vxb1));
140
0
        __m128i vb23 = _mm_loadu_si128((const __m128i*) ((const int8_t*) w + 16));
141
0
        vb23 = _mm_slli_epi32(vb23, 4);
142
0
        vb23 = _mm_and_si128(vb23, vmask);
143
144
0
        const __m128i vxbm3 = _mm_unpackhi_epi8(vb23, vb23);
145
0
        const __m128i vxb2 = _mm_cvtepi8_epi16(vb23);
146
0
        const __m128i vxb3 = _mm_srai_epi16(vxbm3, 8);
147
148
0
        vacc0x2 = _mm_add_epi32(vacc0x2, _mm_madd_epi16(vxa0, vxb2));
149
0
        vacc0x3 = _mm_add_epi32(vacc0x3, _mm_madd_epi16(vxa0, vxb3));
150
151
0
        w = (const int8_t*) w + 32;
152
0
        k -= 8 * sizeof(int8_t);
153
0
      }
154
      // Accumulate this block into the float output.
      // The four 16-bit scales that follow the block's weights are zero-extended
      // and shifted into the upper half of each 32-bit lane, then reinterpreted
      // as float (the bit layout of a bfloat16 widened to float32).
      // NOTE(review): weights were accumulated as (nibble << 4), so the integer
      // sums are 16x the plain nibble products; no explicit 1/16 appears in this
      // function - presumably it is folded into the packed scales. Confirm
      // against the weight-packing code.
155
0
      const __m128 vfilter_output_scale0123 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*) w)), 16));
156
0
      w = (const uint16_t*) w + 4;
157
158
0
      // Two rounds of horizontal adds collapse each column's 4 partial sums so
      // that lane i of vacc0x0123 holds the total for column i.
      const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
159
0
      const __m128i vacc0x23 = _mm_hadd_epi32(vacc0x2, vacc0x3);
160
161
0
        __m128i vacc0x0123 = _mm_hadd_epi32(vacc0x01, vacc0x23);
162
163
0
      vout0x0123 = _mm_add_ps(vout0x0123, _mm_mul_ps(_mm_cvtepi32_ps(vacc0x0123), vfilter_output_scale0123));
164
0
    }
165
166
0
    // Apply the per-row input scale factor (inv_scale of quantization_params[0]).
    const __m128 vinput_scale0 = _mm_load1_ps(&quantization_params[0].inv_scale);
167
168
0
    vout0x0123 = _mm_mul_ps(vout0x0123, vinput_scale0);
169
170
171
0
    // Four float biases follow the last block of this column group.
    const __m128 vbias0123 = _mm_loadu_ps((const float*) w);
172
0
    w = (const float*) w + 4;
173
0
    vout0x0123 = _mm_add_ps(vout0x0123, vbias0123);
174
175
0
    // Clamp to [min, max].
    vout0x0123 = _mm_max_ps(vout0x0123, vmin);
176
177
0
    vout0x0123 = _mm_min_ps(vout0x0123, vmax);
178
179
0
    if XNN_LIKELY(nc >= 4) {
180
0
      // Full group: store all 4 columns.
      _mm_storeu_ps(c0, vout0x0123);
181
182
0
      // Rewind the input pointer for the next column group.
      // NOTE(review): a0 advanced by bl bytes per block iteration, so rewinding
      // by kc restores the start only when kc is a multiple of bl - presumably
      // guaranteed by the caller; verify.
      a0 = (const int8_t*) ((uintptr_t) a0 - kc);
183
184
0
      // cn_stride is applied as a byte offset.
      c0 = (float*) ((uintptr_t) c0 + cn_stride);
185
186
0
      nc -= 4;
187
0
    } else {
188
0
      // Tail: 1 to 3 columns remain; store 2 and/or 1 floats as indicated by
      // the bits of nc.
      if (nc & 2) {
189
0
        _mm_storel_pi((__m64*) c0, vout0x0123);
190
0
        // Move the upper two lanes down so lane 0 holds the next column.
        vout0x0123 = _mm_unpackhi_ps(vout0x0123, vout0x0123);
191
0
        c0 += 2;
192
0
      }
193
0
      if (nc & 1) {
194
0
        _mm_store_ss(c0, vout0x0123);
195
0
      }
196
0
      nc = 0;
197
0
    }
198
0
  } while (nc != 0);
199
0
}