Coverage Report

Created: 2025-10-13 07:19

/src/xnnpack/src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u8.c
  Line|  Count|Source
     1|       |// clang-format off
     2|       |// Auto-generated file. Do not edit!
     3|       |//   Template: src/x32-packw/gio-avx.c.in
     4|       |//   Generator: tools/xngen
     5|       |//
     6|       |// Copyright 2024 Google LLC
     7|       |//
     8|       |// This source code is licensed under the BSD-style license found in the
     9|       |// LICENSE file in the root directory of this source tree.
    10|       |
    11|       |#include <assert.h>
    12|       |#include <stddef.h>
    13|       |#include <stdint.h>
    14|       |
    15|       |#include <immintrin.h>
    16|       |
    17|       |#include "src/xnnpack/common.h"
    18|       |#include "src/xnnpack/packw.h"
    19|       |
    20|       |
    21|       |void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8(
    22|       |  size_t g,
    23|       |  size_t nc,
    24|       |  size_t kc,
    25|       |  size_t nr,
    26|       |  size_t kr,
    27|       |  size_t sr,
    28|       |  size_t k_stride,
    29|       |  const uint32_t* weights,
    30|       |  const uint32_t* bias,
    31|       |  const void* scale,
    32|       |  uint32_t* packed_weights,
    33|       |  size_t extra_bytes,
    34|       |  const void* params)
    35|      0|{
    36|      0|  assert(g != 0);
    37|      0|  assert(nc != 0);
    38|      0|  assert(kc != 0);
    39|      0|  assert(nr == 16);   // This kernel is for NR=16
    40|      0|  assert(kr == 1);
    41|      0|  assert(sr == 1);
    42|      0|  assert(k_stride != 0);
    43|      0|  assert(weights != NULL);
    44|      0|  assert(packed_weights != NULL);
    45|      0|  static const int32_t mask_table[32] = {
    46|      0|    -1, -1, -1, -1, -1, -1, -1, -1,
    47|      0|    -1, -1, -1, -1, -1, -1, -1, -1,
    48|      0|    0, 0, 0, 0, 0, 0, 0, 0,
    49|      0|    0, 0, 0, 0, 0, 0, 0, 0,
    50|      0|  };
    51|       |
    52|      0|  const __m256 vzero = _mm256_setzero_ps();
    53|      0|  const float* b = (const float*) bias;
    54|      0|  float* packed_w = (float*) packed_weights;
    55|      0|  do {
    56|       |    // NC main loop multiple of 16
    57|      0|    const float* w = (const float*) weights;
    58|      0|    size_t n = nc;
    59|       |
    60|      0|    for (; n >= 16; n -= 16) {
    61|      0|      if XNN_LIKELY(b != NULL) {
    62|      0|        const __m256 vb0 = _mm256_loadu_ps(b + 0);
    63|      0|        const __m256 vb1 = _mm256_loadu_ps(b + 8);
    64|      0|        _mm256_store_ps(packed_w + 0, vb0);
    65|      0|        _mm256_store_ps(packed_w + 8, vb1);
    66|      0|        b += 16;
    67|      0|      } else {
    68|      0|        _mm256_store_ps(packed_w + 0, vzero);
    69|      0|        _mm256_store_ps(packed_w + 8, vzero);
    70|      0|      }
    71|      0|      packed_w += 16;
    72|       |
    73|      0|      size_t k = kc;
    74|       |      // KC main loop 8x16
    75|      0|      for (; k >= 8; k -= 8) {
    76|      0|        const __m256 v0_0 = _mm256_loadu_ps(w + 0 + 0 * k_stride);
    77|      0|        const __m256 v1_0 = _mm256_loadu_ps(w + 8 + 0 * k_stride);
    78|      0|        const __m256 v0_1 = _mm256_loadu_ps(w + 0 + 1 * k_stride);
    79|      0|        const __m256 v1_1 = _mm256_loadu_ps(w + 8 + 1 * k_stride);
    80|      0|        const __m256 v0_2 = _mm256_loadu_ps(w + 0 + 2 * k_stride);
    81|      0|        const __m256 v1_2 = _mm256_loadu_ps(w + 8 + 2 * k_stride);
    82|      0|        const __m256 v0_3 = _mm256_loadu_ps(w + 0 + 3 * k_stride);
    83|      0|        const __m256 v1_3 = _mm256_loadu_ps(w + 8 + 3 * k_stride);
    84|      0|        const __m256 v0_4 = _mm256_loadu_ps(w + 0 + 4 * k_stride);
    85|      0|        const __m256 v1_4 = _mm256_loadu_ps(w + 8 + 4 * k_stride);
    86|      0|        const __m256 v0_5 = _mm256_loadu_ps(w + 0 + 5 * k_stride);
    87|      0|        const __m256 v1_5 = _mm256_loadu_ps(w + 8 + 5 * k_stride);
    88|      0|        const __m256 v0_6 = _mm256_loadu_ps(w + 0 + 6 * k_stride);
    89|      0|        const __m256 v1_6 = _mm256_loadu_ps(w + 8 + 6 * k_stride);
    90|      0|        const __m256 v0_7 = _mm256_loadu_ps(w + 0 + 7 * k_stride);
    91|      0|        const __m256 v1_7 = _mm256_loadu_ps(w + 8 + 7 * k_stride);
    92|      0|        _mm256_store_ps(packed_w + 0, v0_0);
    93|      0|        _mm256_store_ps(packed_w + 8, v1_0);
    94|      0|        _mm256_store_ps(packed_w + 16, v0_1);
    95|      0|        _mm256_store_ps(packed_w + 24, v1_1);
    96|      0|        _mm256_store_ps(packed_w + 32, v0_2);
    97|      0|        _mm256_store_ps(packed_w + 40, v1_2);
    98|      0|        _mm256_store_ps(packed_w + 48, v0_3);
    99|      0|        _mm256_store_ps(packed_w + 56, v1_3);
   100|      0|        _mm256_store_ps(packed_w + 64, v0_4);
   101|      0|        _mm256_store_ps(packed_w + 72, v1_4);
   102|      0|        _mm256_store_ps(packed_w + 80, v0_5);
   103|      0|        _mm256_store_ps(packed_w + 88, v1_5);
   104|      0|        _mm256_store_ps(packed_w + 96, v0_6);
   105|      0|        _mm256_store_ps(packed_w + 104, v1_6);
   106|      0|        _mm256_store_ps(packed_w + 112, v0_7);
   107|      0|        _mm256_store_ps(packed_w + 120, v1_7);
   108|      0|        w += k_stride * 8;
   109|      0|        packed_w += 128;
   110|      0|      }
   111|       |
   112|       |      // KC remainder loop
   113|      0|      for (; k > 0; --k) {
   114|      0|        const __m256 v0 = _mm256_loadu_ps(w + 0);
   115|      0|        const __m256 v1 = _mm256_loadu_ps(w + 8);
   116|      0|        _mm256_store_ps(packed_w + 0, v0);
   117|      0|        _mm256_store_ps(packed_w + 8, v1);
   118|      0|        w += k_stride;
   119|      0|        packed_w += 16;
   120|      0|      }
   121|      0|      w = w - kc * k_stride + 16;  // Advance to next column of 16 floats
   122|      0|    }
   123|       |
   124|       |    // NC remainder (1..15)
   125|      0|    if XNN_UNLIKELY(n != 0) {
   126|      0|      assert(n >= 1);
   127|      0|      assert(n <= 15);
   128|      0|      const __m256i vmask0 = _mm256_loadu_si256((const __m256i*) &mask_table[16 - n]);
   129|      0|      const __m256i vmask1 = _mm256_loadu_si256((const __m256i*) &mask_table[24 - n]);
   130|       |
   131|      0|      if XNN_LIKELY(b != NULL) {
   132|      0|        const __m256 vb0 = _mm256_maskload_ps(b + 0, vmask0);
   133|      0|        const __m256 vb1 = _mm256_maskload_ps(b + 8, vmask1);
   134|      0|        _mm256_store_ps(packed_w + 0, vb0);
   135|      0|        _mm256_store_ps(packed_w + 8, vb1);
   136|      0|        b += n;
   137|      0|      } else {
   138|      0|        _mm256_store_ps(packed_w + 0, vzero);
   139|      0|        _mm256_store_ps(packed_w + 8, vzero);
   140|      0|      }
   141|      0|      packed_w += 16;
   142|       |
   143|       |      // KC main loop
   144|      0|      for (size_t k = kc; k > 0; --k) {
   145|      0|        const __m256 v0 = _mm256_maskload_ps(w + 0, vmask0);
   146|      0|        const __m256 v1 = _mm256_maskload_ps(w + 8, vmask1);
   147|      0|        _mm256_store_ps(packed_w + 0, v0);
   148|      0|        _mm256_store_ps(packed_w + 8, v1);
   149|      0|        w += k_stride;
   150|      0|        packed_w += 16;
   151|      0|      }
   152|      0|    }
   153|      0|    weights += nc * kc;
   154|      0|  } while (--g != 0);
   155|      0|}
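
Every instrumented line above reports a count of 0, i.e. this ukernel was never executed in the current test run. The sketch below is a minimal, hypothetical driver (not part of the report or of XNNPACK's test suite) showing one call that satisfies the kernel's asserts (nr == 16, kr == 1, sr == 1) and reaches both the 16-column NC main loop and the NC/KC remainder paths. The chosen sizes (g = 1, nc = 20, kc = 10, k_stride = nc), the packed-buffer sizing, and the 32-byte alignment (needed because the kernel uses _mm256_store_ps) are assumptions for illustration, not values taken from the report.

#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include "src/xnnpack/packw.h"  // included by the kernel above; assumed to declare the ukernel prototype

int main(void) {
  // Hypothetical sizes chosen to hit every loop in the ukernel:
  // nc = 20 -> one 16-wide column block plus a 4-wide NC remainder,
  // kc = 10 -> one 8-row KC block plus a 2-row KC remainder.
  const size_t g = 1;
  const size_t nc = 20;
  const size_t kc = 10;
  const size_t k_stride = nc;  // GIO weights: assume the row stride equals the number of columns

  uint32_t* weights = calloc(g * kc * k_stride, sizeof(uint32_t));
  uint32_t* bias = calloc(g * nc, sizeof(uint32_t));

  // Output layout written by the kernel: per group, round_up(nc, 16) / 16 blocks,
  // each holding 16 bias values followed by kc rows of 16 values.
  // The aligned stores require a 32-byte aligned destination.
  const size_t nc_padded = (nc + 15) & ~(size_t) 15;
  const size_t packed_bytes = g * nc_padded * (kc + 1) * sizeof(uint32_t);
  uint32_t* packed_weights = aligned_alloc(32, packed_bytes);
  memset(packed_weights, 0, packed_bytes);

  xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8(
      g, nc, kc,
      /*nr=*/16, /*kr=*/1, /*sr=*/1,
      k_stride,
      weights, bias, /*scale=*/NULL,
      packed_weights, /*extra_bytes=*/0, /*params=*/NULL);

  free(weights);
  free(bias);
  free(packed_weights);
  return 0;
}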