Coverage Report

Created: 2026-02-14 06:31

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/xnnpack/src/f32-qc8w-gemm/gen/f32-qc8w-gemm-5x16-minmax-avx-broadcast.c
Line
Count
Source
1
// clang-format off
2
// Auto-generated file. Do not edit!
3
//   Template: src/f32-gemm/avx-broadcast.c.in
4
//   Generator: tools/xngen
5
//
6
// Copyright 2019 Google LLC
7
//
8
// This source code is licensed under the BSD-style license found in the
9
// LICENSE file in the root directory of this source tree.
10
11
#include <assert.h>
12
#include <stddef.h>
13
#include <stdint.h>
14
15
#include <immintrin.h>
16
17
#include "src/xnnpack/common.h"
18
#include "src/xnnpack/gemm.h"
19
#include "src/xnnpack/unaligned.h"
20
#include "src/xnnpack/microparams.h"
21
22
23
void xnn_f32_qc8w_gemm_minmax_ukernel_5x16__avx_broadcast(
24
    size_t mr,
25
    size_t nc,
26
    size_t kc,
27
    const float* restrict a,
28
    size_t a_stride,
29
    const void* restrict w,
30
    float* restrict c,
31
    size_t cm_stride,
32
    size_t cn_stride,
33
    const struct xnn_f32_minmax_params* restrict params)
34
0
{
35
0
  assert(mr != 0);
36
0
  assert(mr <= 5);
37
0
  assert(nc != 0);
38
0
  assert(kc != 0);
39
0
  assert(kc % sizeof(float) == 0);
40
0
  assert(a != NULL);
41
0
  assert(w != NULL);
42
0
  assert(c != NULL);
43
44
0
  const float* a0 = a;
45
0
  float* c0 = c;
46
0
  const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
47
0
  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
48
0
  if XNN_UNPREDICTABLE(mr < 2) {
49
0
    a1 = a0;
50
0
    c1 = c0;
51
0
  }
52
0
  const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
53
0
  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
54
0
  if XNN_UNPREDICTABLE(mr <= 2) {
55
0
    a2 = a1;
56
0
    c2 = c1;
57
0
  }
58
0
  const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
59
0
  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
60
0
  if XNN_UNPREDICTABLE(mr < 4) {
61
0
    a3 = a2;
62
0
    c3 = c2;
63
0
  }
64
0
  const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
65
0
  float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
66
0
  if XNN_UNPREDICTABLE(mr <= 4) {
67
0
    a4 = a3;
68
0
    c4 = c3;
69
0
  }
70
71
0
  const __m256 vmin = _mm256_set1_ps(params->scalar.min);
72
0
  const __m256 vmax = _mm256_set1_ps(params->scalar.max);
73
0
  XNN_FORCE_REALIZATION(vmin);
74
0
  XNN_FORCE_REALIZATION(vmax);
75
76
0
  do {
77
0
    __m256 vacc0x01234567 = _mm256_loadu_ps((const float*) w + 0);
78
0
    __m256 vacc0x89ABCDEF = _mm256_loadu_ps((const float*) w + 8);
79
0
    __m256 vacc1x01234567 = vacc0x01234567;
80
0
    __m256 vacc1x89ABCDEF = vacc0x89ABCDEF;
81
0
    __m256 vacc2x01234567 = vacc0x01234567;
82
0
    __m256 vacc2x89ABCDEF = vacc0x89ABCDEF;
83
0
    __m256 vacc3x01234567 = vacc0x01234567;
84
0
    __m256 vacc3x89ABCDEF = vacc0x89ABCDEF;
85
0
    __m256 vacc4x01234567 = vacc0x01234567;
86
0
    __m256 vacc4x89ABCDEF = vacc0x89ABCDEF;
87
0
    w = (const float*) w + 16;
88
89
0
    size_t k = kc;
90
0
    do {
91
0
      const __m256 va0 = _mm256_broadcast_ss(a0);
92
0
      a0 += 1;
93
0
      const __m256 va1 = _mm256_broadcast_ss(a1);
94
0
      a1 += 1;
95
0
      const __m256 va2 = _mm256_broadcast_ss(a2);
96
0
      a2 += 1;
97
0
      const __m256 va3 = _mm256_broadcast_ss(a3);
98
0
      a3 += 1;
99
0
      const __m256 va4 = _mm256_broadcast_ss(a4);
100
0
      a4 += 1;
101
102
0
      const __m128i vbi0123 = _mm_cvtepi8_epi32(_mm_cvtsi32_si128((int) unaligned_load_u32((const int8_t*) w)));
103
0
      const __m128i vbi4567 = _mm_cvtepi8_epi32(_mm_cvtsi32_si128((int) unaligned_load_u32((const int8_t*) w + 4)));
104
0
      const __m128i vbi89AB = _mm_cvtepi8_epi32(_mm_cvtsi32_si128((int) unaligned_load_u32((const int8_t*) w + 8)));
105
0
      const __m128i vbiCDEF = _mm_cvtepi8_epi32(_mm_cvtsi32_si128((int) unaligned_load_u32((const int8_t*) w + 12)));
106
0
      const __m256i vbi01234567 = _mm256_castps_si256(_mm256_insertf128_ps(_mm256_castsi256_ps(_mm256_castsi128_si256(vbi0123)), _mm_castsi128_ps(vbi4567), 1));
107
0
      const __m256i vbi89ABCDEF = _mm256_castps_si256(_mm256_insertf128_ps(_mm256_castsi256_ps(_mm256_castsi128_si256(vbi89AB)), _mm_castsi128_ps(vbiCDEF), 1));
108
0
      w = (const int8_t*) w + 16;
109
0
      const __m256 vb01234567 = _mm256_cvtepi32_ps(vbi01234567);
110
0
      const __m256 vb89ABCDEF = _mm256_cvtepi32_ps(vbi89ABCDEF);
111
112
0
      vacc0x01234567 = _mm256_add_ps(vacc0x01234567, _mm256_mul_ps(va0, vb01234567));
113
0
      vacc1x01234567 = _mm256_add_ps(vacc1x01234567, _mm256_mul_ps(va1, vb01234567));
114
0
      vacc2x01234567 = _mm256_add_ps(vacc2x01234567, _mm256_mul_ps(va2, vb01234567));
115
0
      vacc3x01234567 = _mm256_add_ps(vacc3x01234567, _mm256_mul_ps(va3, vb01234567));
116
0
      vacc4x01234567 = _mm256_add_ps(vacc4x01234567, _mm256_mul_ps(va4, vb01234567));
117
0
      vacc0x89ABCDEF = _mm256_add_ps(vacc0x89ABCDEF, _mm256_mul_ps(va0, vb89ABCDEF));
118
0
      vacc1x89ABCDEF = _mm256_add_ps(vacc1x89ABCDEF, _mm256_mul_ps(va1, vb89ABCDEF));
119
0
      vacc2x89ABCDEF = _mm256_add_ps(vacc2x89ABCDEF, _mm256_mul_ps(va2, vb89ABCDEF));
120
0
      vacc3x89ABCDEF = _mm256_add_ps(vacc3x89ABCDEF, _mm256_mul_ps(va3, vb89ABCDEF));
121
0
      vacc4x89ABCDEF = _mm256_add_ps(vacc4x89ABCDEF, _mm256_mul_ps(va4, vb89ABCDEF));
122
123
0
      k -= sizeof(float);
124
0
    } while (k != 0);
125
126
0
    const __m256 vscale01234567 = _mm256_loadu_ps((const float*) w + 0);
127
0
    vacc0x01234567 = _mm256_mul_ps(vacc0x01234567, vscale01234567);
128
0
    vacc1x01234567 = _mm256_mul_ps(vacc1x01234567, vscale01234567);
129
0
    vacc2x01234567 = _mm256_mul_ps(vacc2x01234567, vscale01234567);
130
0
    vacc3x01234567 = _mm256_mul_ps(vacc3x01234567, vscale01234567);
131
0
    vacc4x01234567 = _mm256_mul_ps(vacc4x01234567, vscale01234567);
132
0
    const __m256 vscale89ABCDEF = _mm256_loadu_ps((const float*) w + 8);
133
0
    vacc0x89ABCDEF = _mm256_mul_ps(vacc0x89ABCDEF, vscale89ABCDEF);
134
0
    vacc1x89ABCDEF = _mm256_mul_ps(vacc1x89ABCDEF, vscale89ABCDEF);
135
0
    vacc2x89ABCDEF = _mm256_mul_ps(vacc2x89ABCDEF, vscale89ABCDEF);
136
0
    vacc3x89ABCDEF = _mm256_mul_ps(vacc3x89ABCDEF, vscale89ABCDEF);
137
0
    vacc4x89ABCDEF = _mm256_mul_ps(vacc4x89ABCDEF, vscale89ABCDEF);
138
0
    w = (const float*) w + 16;
139
0
    vacc0x01234567 = _mm256_max_ps(vmin, vacc0x01234567);
140
0
    vacc1x01234567 = _mm256_max_ps(vmin, vacc1x01234567);
141
0
    vacc2x01234567 = _mm256_max_ps(vmin, vacc2x01234567);
142
0
    vacc3x01234567 = _mm256_max_ps(vmin, vacc3x01234567);
143
0
    vacc4x01234567 = _mm256_max_ps(vmin, vacc4x01234567);
144
0
    vacc0x89ABCDEF = _mm256_max_ps(vmin, vacc0x89ABCDEF);
145
0
    vacc1x89ABCDEF = _mm256_max_ps(vmin, vacc1x89ABCDEF);
146
0
    vacc2x89ABCDEF = _mm256_max_ps(vmin, vacc2x89ABCDEF);
147
0
    vacc3x89ABCDEF = _mm256_max_ps(vmin, vacc3x89ABCDEF);
148
0
    vacc4x89ABCDEF = _mm256_max_ps(vmin, vacc4x89ABCDEF);
149
150
0
    vacc0x01234567 = _mm256_min_ps(vmax, vacc0x01234567);
151
0
    vacc1x01234567 = _mm256_min_ps(vmax, vacc1x01234567);
152
0
    vacc2x01234567 = _mm256_min_ps(vmax, vacc2x01234567);
153
0
    vacc3x01234567 = _mm256_min_ps(vmax, vacc3x01234567);
154
0
    vacc4x01234567 = _mm256_min_ps(vmax, vacc4x01234567);
155
0
    vacc0x89ABCDEF = _mm256_min_ps(vmax, vacc0x89ABCDEF);
156
0
    vacc1x89ABCDEF = _mm256_min_ps(vmax, vacc1x89ABCDEF);
157
0
    vacc2x89ABCDEF = _mm256_min_ps(vmax, vacc2x89ABCDEF);
158
0
    vacc3x89ABCDEF = _mm256_min_ps(vmax, vacc3x89ABCDEF);
159
0
    vacc4x89ABCDEF = _mm256_min_ps(vmax, vacc4x89ABCDEF);
160
161
0
    if XNN_LIKELY(nc >= 16) {
162
0
      _mm256_storeu_ps(c0, vacc0x01234567);
163
0
      _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
164
0
      c0 = (float*) ((uintptr_t) c0 + cn_stride);
165
0
      _mm256_storeu_ps(c1, vacc1x01234567);
166
0
      _mm256_storeu_ps(c1 + 8, vacc1x89ABCDEF);
167
0
      c1 = (float*) ((uintptr_t) c1 + cn_stride);
168
0
      _mm256_storeu_ps(c2, vacc2x01234567);
169
0
      _mm256_storeu_ps(c2 + 8, vacc2x89ABCDEF);
170
0
      c2 = (float*) ((uintptr_t) c2 + cn_stride);
171
0
      _mm256_storeu_ps(c3, vacc3x01234567);
172
0
      _mm256_storeu_ps(c3 + 8, vacc3x89ABCDEF);
173
0
      c3 = (float*) ((uintptr_t) c3 + cn_stride);
174
0
      _mm256_storeu_ps(c4, vacc4x01234567);
175
0
      _mm256_storeu_ps(c4 + 8, vacc4x89ABCDEF);
176
0
      c4 = (float*) ((uintptr_t) c4 + cn_stride);
177
178
0
      a0 = (const float*) ((uintptr_t) a0 - kc);
179
0
      a1 = (const float*) ((uintptr_t) a1 - kc);
180
0
      a2 = (const float*) ((uintptr_t) a2 - kc);
181
0
      a3 = (const float*) ((uintptr_t) a3 - kc);
182
0
      a4 = (const float*) ((uintptr_t) a4 - kc);
183
184
0
      nc -= 16;
185
0
    } else {
186
0
      if (nc & 8) {
187
0
        _mm256_storeu_ps(c0, vacc0x01234567);
188
0
        _mm256_storeu_ps(c1, vacc1x01234567);
189
0
        _mm256_storeu_ps(c2, vacc2x01234567);
190
0
        _mm256_storeu_ps(c3, vacc3x01234567);
191
0
        _mm256_storeu_ps(c4, vacc4x01234567);
192
193
0
        vacc0x01234567 = vacc0x89ABCDEF;
194
0
        vacc1x01234567 = vacc1x89ABCDEF;
195
0
        vacc2x01234567 = vacc2x89ABCDEF;
196
0
        vacc3x01234567 = vacc3x89ABCDEF;
197
0
        vacc4x01234567 = vacc4x89ABCDEF;
198
199
0
        c0 += 8;
200
0
        c1 += 8;
201
0
        c2 += 8;
202
0
        c3 += 8;
203
0
        c4 += 8;
204
0
      }
205
0
      __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
206
0
      __m128 vacc1x0123 = _mm256_castps256_ps128(vacc1x01234567);
207
0
      __m128 vacc2x0123 = _mm256_castps256_ps128(vacc2x01234567);
208
0
      __m128 vacc3x0123 = _mm256_castps256_ps128(vacc3x01234567);
209
0
      __m128 vacc4x0123 = _mm256_castps256_ps128(vacc4x01234567);
210
0
      if (nc & 4) {
211
0
        _mm_storeu_ps(c0, vacc0x0123);
212
0
        _mm_storeu_ps(c1, vacc1x0123);
213
0
        _mm_storeu_ps(c2, vacc2x0123);
214
0
        _mm_storeu_ps(c3, vacc3x0123);
215
0
        _mm_storeu_ps(c4, vacc4x0123);
216
217
0
        vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
218
0
        vacc1x0123 = _mm256_extractf128_ps(vacc1x01234567, 1);
219
0
        vacc2x0123 = _mm256_extractf128_ps(vacc2x01234567, 1);
220
0
        vacc3x0123 = _mm256_extractf128_ps(vacc3x01234567, 1);
221
0
        vacc4x0123 = _mm256_extractf128_ps(vacc4x01234567, 1);
222
223
0
        c0 += 4;
224
0
        c1 += 4;
225
0
        c2 += 4;
226
0
        c3 += 4;
227
0
        c4 += 4;
228
0
      }
229
0
      if (nc & 2) {
230
0
        _mm_storel_pi((__m64*) c0, vacc0x0123);
231
0
        _mm_storel_pi((__m64*) c1, vacc1x0123);
232
0
        _mm_storel_pi((__m64*) c2, vacc2x0123);
233
0
        _mm_storel_pi((__m64*) c3, vacc3x0123);
234
0
        _mm_storel_pi((__m64*) c4, vacc4x0123);
235
236
0
        vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
237
0
        vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
238
0
        vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
239
0
        vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
240
0
        vacc4x0123 = _mm_movehl_ps(vacc4x0123, vacc4x0123);
241
242
0
        c0 += 2;
243
0
        c1 += 2;
244
0
        c2 += 2;
245
0
        c3 += 2;
246
0
        c4 += 2;
247
0
      }
248
0
      if (nc & 1) {
249
0
        _mm_store_ss(c0, vacc0x0123);
250
0
        _mm_store_ss(c1, vacc1x0123);
251
0
        _mm_store_ss(c2, vacc2x0123);
252
0
        _mm_store_ss(c3, vacc3x0123);
253
0
        _mm_store_ss(c4, vacc4x0123);
254
0
      }
255
256
0
      nc = 0;
257
0
    }
258
0
  } while (nc != 0);
259
0
}