Coverage Report

Created: 2026-04-01 07:11

/src/xnnpack/src/f32-qc8w-gemm/gen/f32-qc8w-gemm-5x16-minmax-avx2-broadcast.c
Every executable line below has an execution count of 0; the kernel is entirely uncovered.
// clang-format off
// Auto-generated file. Do not edit!
//   Template: src/f32-gemm/avx-broadcast.c.in
//   Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <immintrin.h>

#include "src/xnnpack/common.h"
#include "src/xnnpack/gemm.h"
#include "src/xnnpack/microparams.h"

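// 5x16 f32 GEMM microkernel with per-channel-quantized int8 weights ("qc8w"):
// each K step dequantizes 16 int8 weights to f32 and accumulates with FMA;
// results are scaled per output channel and clamped to [min, max].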
void xnn_f32_qc8w_gemm_minmax_ukernel_5x16__avx2_broadcast(
    size_t mr,
    size_t nc,
    size_t kc,
    const float* restrict a,
    size_t a_stride,
    const void* restrict w,
    float* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const struct xnn_f32_minmax_params* restrict params)
{
  assert(mr != 0);
  assert(mr <= 5);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(float) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

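  // Per-row input/output pointers; rows beyond mr alias the previous row,
  // so out-of-range rows redo the same work and stores stay in bounds.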
  const float* a0 = a;
  float* c0 = c;
  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
  if XNN_UNPREDICTABLE(mr < 2) {
    a1 = a0;
    c1 = c0;
  }
  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 2) {
    a2 = a1;
    c2 = c1;
  }
  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
  if XNN_UNPREDICTABLE(mr < 4) {
    a3 = a2;
    c3 = c2;
  }
  const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
  float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 4) {
    a4 = a3;
    c4 = c3;
  }

  const __m256 vmin = _mm256_set1_ps(params->scalar.min);
  const __m256 vmax = _mm256_set1_ps(params->scalar.max);
  XNN_FORCE_REALIZATION(vmin);
  XNN_FORCE_REALIZATION(vmax);

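  // Outer loop: one iteration per 16-column block of C. The packed weights
  // for each block begin with 16 f32 biases, which seed the accumulators.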
  do {
    __m256 vacc0x01234567 = _mm256_loadu_ps((const float*) w + 0);
    __m256 vacc0x89ABCDEF = _mm256_loadu_ps((const float*) w + 8);
    __m256 vacc1x01234567 = vacc0x01234567;
    __m256 vacc1x89ABCDEF = vacc0x89ABCDEF;
    __m256 vacc2x01234567 = vacc0x01234567;
    __m256 vacc2x89ABCDEF = vacc0x89ABCDEF;
    __m256 vacc3x01234567 = vacc0x01234567;
    __m256 vacc3x89ABCDEF = vacc0x89ABCDEF;
    __m256 vacc4x01234567 = vacc0x01234567;
    __m256 vacc4x89ABCDEF = vacc0x89ABCDEF;
    w = (const float*) w + 16;

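    // Inner K loop: broadcast one f32 from each A row, sign-extend 16 int8
    // weights to int32, convert to f32, and accumulate with FMA.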
    size_t k = kc;
    do {
      const __m256 va0 = _mm256_broadcast_ss(a0);
      a0 += 1;
      const __m256 va1 = _mm256_broadcast_ss(a1);
      a1 += 1;
      const __m256 va2 = _mm256_broadcast_ss(a2);
      a2 += 1;
      const __m256 va3 = _mm256_broadcast_ss(a3);
      a3 += 1;
      const __m256 va4 = _mm256_broadcast_ss(a4);
      a4 += 1;

      const __m256i vbi01234567 = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) w));
      const __m256i vbi89ABCDEF = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*) ((const int8_t*) w + 8)));
      w = (const int8_t*) w + 16;
      const __m256 vb01234567 = _mm256_cvtepi32_ps(vbi01234567);
      const __m256 vb89ABCDEF = _mm256_cvtepi32_ps(vbi89ABCDEF);

      vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567, vacc0x01234567);
      vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567, vacc1x01234567);
      vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567, vacc2x01234567);
      vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567, vacc3x01234567);
      vacc4x01234567 = _mm256_fmadd_ps(va4, vb01234567, vacc4x01234567);
      vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEF, vacc0x89ABCDEF);
      vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEF, vacc1x89ABCDEF);
      vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEF, vacc2x89ABCDEF);
      vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEF, vacc3x89ABCDEF);
      vacc4x89ABCDEF = _mm256_fmadd_ps(va4, vb89ABCDEF, vacc4x89ABCDEF);

      k -= sizeof(float);
    } while (k != 0);

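    // Epilogue: multiply by the 16 per-channel f32 scales that follow the
    // int8 weights in the packed buffer, then clamp to [min, max].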
    const __m256 vscale01234567 = _mm256_loadu_ps((const float*) w + 0);
    vacc0x01234567 = _mm256_mul_ps(vacc0x01234567, vscale01234567);
    vacc1x01234567 = _mm256_mul_ps(vacc1x01234567, vscale01234567);
    vacc2x01234567 = _mm256_mul_ps(vacc2x01234567, vscale01234567);
    vacc3x01234567 = _mm256_mul_ps(vacc3x01234567, vscale01234567);
    vacc4x01234567 = _mm256_mul_ps(vacc4x01234567, vscale01234567);
    const __m256 vscale89ABCDEF = _mm256_loadu_ps((const float*) w + 8);
    vacc0x89ABCDEF = _mm256_mul_ps(vacc0x89ABCDEF, vscale89ABCDEF);
    vacc1x89ABCDEF = _mm256_mul_ps(vacc1x89ABCDEF, vscale89ABCDEF);
    vacc2x89ABCDEF = _mm256_mul_ps(vacc2x89ABCDEF, vscale89ABCDEF);
    vacc3x89ABCDEF = _mm256_mul_ps(vacc3x89ABCDEF, vscale89ABCDEF);
    vacc4x89ABCDEF = _mm256_mul_ps(vacc4x89ABCDEF, vscale89ABCDEF);
    w = (const float*) w + 16;
    vacc0x01234567 = _mm256_max_ps(vmin, vacc0x01234567);
    vacc1x01234567 = _mm256_max_ps(vmin, vacc1x01234567);
    vacc2x01234567 = _mm256_max_ps(vmin, vacc2x01234567);
    vacc3x01234567 = _mm256_max_ps(vmin, vacc3x01234567);
    vacc4x01234567 = _mm256_max_ps(vmin, vacc4x01234567);
    vacc0x89ABCDEF = _mm256_max_ps(vmin, vacc0x89ABCDEF);
    vacc1x89ABCDEF = _mm256_max_ps(vmin, vacc1x89ABCDEF);
    vacc2x89ABCDEF = _mm256_max_ps(vmin, vacc2x89ABCDEF);
    vacc3x89ABCDEF = _mm256_max_ps(vmin, vacc3x89ABCDEF);
    vacc4x89ABCDEF = _mm256_max_ps(vmin, vacc4x89ABCDEF);

    vacc0x01234567 = _mm256_min_ps(vmax, vacc0x01234567);
    vacc1x01234567 = _mm256_min_ps(vmax, vacc1x01234567);
    vacc2x01234567 = _mm256_min_ps(vmax, vacc2x01234567);
    vacc3x01234567 = _mm256_min_ps(vmax, vacc3x01234567);
    vacc4x01234567 = _mm256_min_ps(vmax, vacc4x01234567);
    vacc0x89ABCDEF = _mm256_min_ps(vmax, vacc0x89ABCDEF);
    vacc1x89ABCDEF = _mm256_min_ps(vmax, vacc1x89ABCDEF);
    vacc2x89ABCDEF = _mm256_min_ps(vmax, vacc2x89ABCDEF);
    vacc3x89ABCDEF = _mm256_min_ps(vmax, vacc3x89ABCDEF);
    vacc4x89ABCDEF = _mm256_min_ps(vmax, vacc4x89ABCDEF);

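    // Store the 5x16 tile: full 16-wide stores when nc >= 16; otherwise
    // write the remainder with progressively narrower 8/4/2/1 stores.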
    if XNN_LIKELY(nc >= 16) {
      _mm256_storeu_ps(c0, vacc0x01234567);
      _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
      c0 = (float*) ((uintptr_t) c0 + cn_stride);
      _mm256_storeu_ps(c1, vacc1x01234567);
      _mm256_storeu_ps(c1 + 8, vacc1x89ABCDEF);
      c1 = (float*) ((uintptr_t) c1 + cn_stride);
      _mm256_storeu_ps(c2, vacc2x01234567);
      _mm256_storeu_ps(c2 + 8, vacc2x89ABCDEF);
      c2 = (float*) ((uintptr_t) c2 + cn_stride);
      _mm256_storeu_ps(c3, vacc3x01234567);
      _mm256_storeu_ps(c3 + 8, vacc3x89ABCDEF);
      c3 = (float*) ((uintptr_t) c3 + cn_stride);
      _mm256_storeu_ps(c4, vacc4x01234567);
      _mm256_storeu_ps(c4 + 8, vacc4x89ABCDEF);
      c4 = (float*) ((uintptr_t) c4 + cn_stride);

      a0 = (const float*) ((uintptr_t) a0 - kc);
      a1 = (const float*) ((uintptr_t) a1 - kc);
      a2 = (const float*) ((uintptr_t) a2 - kc);
      a3 = (const float*) ((uintptr_t) a3 - kc);
      a4 = (const float*) ((uintptr_t) a4 - kc);

      nc -= 16;
    } else {
      if (nc & 8) {
        _mm256_storeu_ps(c0, vacc0x01234567);
        _mm256_storeu_ps(c1, vacc1x01234567);
        _mm256_storeu_ps(c2, vacc2x01234567);
        _mm256_storeu_ps(c3, vacc3x01234567);
        _mm256_storeu_ps(c4, vacc4x01234567);

        vacc0x01234567 = vacc0x89ABCDEF;
        vacc1x01234567 = vacc1x89ABCDEF;
        vacc2x01234567 = vacc2x89ABCDEF;
        vacc3x01234567 = vacc3x89ABCDEF;
        vacc4x01234567 = vacc4x89ABCDEF;

        c0 += 8;
        c1 += 8;
        c2 += 8;
        c3 += 8;
        c4 += 8;
      }
      __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
      __m128 vacc1x0123 = _mm256_castps256_ps128(vacc1x01234567);
      __m128 vacc2x0123 = _mm256_castps256_ps128(vacc2x01234567);
      __m128 vacc3x0123 = _mm256_castps256_ps128(vacc3x01234567);
      __m128 vacc4x0123 = _mm256_castps256_ps128(vacc4x01234567);
      if (nc & 4) {
        _mm_storeu_ps(c0, vacc0x0123);
        _mm_storeu_ps(c1, vacc1x0123);
        _mm_storeu_ps(c2, vacc2x0123);
        _mm_storeu_ps(c3, vacc3x0123);
        _mm_storeu_ps(c4, vacc4x0123);

        vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
        vacc1x0123 = _mm256_extractf128_ps(vacc1x01234567, 1);
        vacc2x0123 = _mm256_extractf128_ps(vacc2x01234567, 1);
        vacc3x0123 = _mm256_extractf128_ps(vacc3x01234567, 1);
        vacc4x0123 = _mm256_extractf128_ps(vacc4x01234567, 1);

        c0 += 4;
        c1 += 4;
        c2 += 4;
        c3 += 4;
        c4 += 4;
      }
      if (nc & 2) {
        _mm_storel_pi((__m64*) c0, vacc0x0123);
        _mm_storel_pi((__m64*) c1, vacc1x0123);
        _mm_storel_pi((__m64*) c2, vacc2x0123);
        _mm_storel_pi((__m64*) c3, vacc3x0123);
        _mm_storel_pi((__m64*) c4, vacc4x0123);

        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
        vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
        vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
        vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
        vacc4x0123 = _mm_movehl_ps(vacc4x0123, vacc4x0123);

        c0 += 2;
        c1 += 2;
        c2 += 2;
        c3 += 2;
        c4 += 2;
      }
      if (nc & 1) {
        _mm_store_ss(c0, vacc0x0123);
        _mm_store_ss(c1, vacc1x0123);
        _mm_store_ss(c2, vacc2x0123);
        _mm_store_ss(c3, vacc3x0123);
        _mm_store_ss(c4, vacc4x0123);
      }

      nc = 0;
    }
  } while (nc != 0);
}
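
Nothing above executes in the current test run. A kernel-level smoke test that calls the microkernel directly would cover every line. The sketch below is hypothetical, not an XNNPACK test harness: the packed-weight layout is inferred from the loads in the kernel itself (per 16-column block: 16 f32 biases, then kc/sizeof(float) groups of 16 int8 weights, then 16 f32 per-channel scales), and the sizes, test values, and params field names (params->scalar.min/.max, exactly as the kernel reads them) are assumptions.

// Hypothetical standalone driver; assumes an AVX2+FMA build with the
// XNNPACK source tree on the include path.
#include <stdint.h>
#include <stdio.h>

#include "src/xnnpack/gemm.h"
#include "src/xnnpack/microparams.h"

#define M 5   // rows, mr <= 5
#define N 16  // columns, one full tile
#define K 8   // reduction depth; kc = K * sizeof(float)

int main(void) {
  float a[M][K];  // f32 activations
  float c[M][N];  // f32 outputs
  // Packed weights in the order the kernel consumes them:
  // 16 f32 biases, K groups of 16 int8 weights, 16 f32 scales.
  struct {
    float bias[N];
    int8_t w[K][N];
    float scale[N];
  } packed;

  for (int m = 0; m < M; m++)
    for (int k = 0; k < K; k++) a[m][k] = 1.0f;
  for (int n = 0; n < N; n++) {
    packed.bias[n] = 0.5f;
    packed.scale[n] = 0.25f;
  }
  for (int k = 0; k < K; k++)
    for (int n = 0; n < N; n++) packed.w[k][n] = 2;

  // Field names as the kernel reads them (params->scalar.min / .max).
  struct xnn_f32_minmax_params params;
  params.scalar.min = 0.0f;
  params.scalar.max = 6.0f;

  xnn_f32_qc8w_gemm_minmax_ukernel_5x16__avx2_broadcast(
      /*mr=*/M, /*nc=*/N, /*kc=*/K * sizeof(float),
      &a[0][0], /*a_stride=*/K * sizeof(float),
      &packed,
      &c[0][0], /*cm_stride=*/N * sizeof(float), /*cn_stride=*/N * sizeof(float),
      &params);

  // Each output should be (0.5 + K * 1.0 * 2) * 0.25 = 4.125,
  // which lies inside the [0, 6] clamp.
  printf("c[0][0] = %f\n", c[0][0]);
  return 0;
}

With mr = 5 and nc = 16 this drives the XNN_LIKELY full-width store path; calling again with nc < 16 would reach the 8/4/2/1 remainder branches.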