Coverage Report

Created: 2026-01-15 07:10

/src/xnnpack/src/x32-packw/gen/x32-packw-x16-gemm-gio-avx512f-u8.c
Execution counts: every instrumented line in this file has a count of 0, i.e. the file is entirely uncovered.

Source:
// clang-format off
// Auto-generated file. Do not edit!
//   Template: src/x32-packw/gio-avx512.c.in
//   Generator: tools/xngen
//
// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <immintrin.h>

#include "src/xnnpack/common.h"
#include "src/xnnpack/intrinsics-polyfill.h"
#include "src/xnnpack/packw.h"

// Pack pre-transposed weights (GIO) for use by f32-gemm
void xnn_x32_packw_gemm_gio_ukernel_x16__avx512f_u8(
  size_t g,                  // Batch size (outer loop), usually 1
  size_t nc,                 // Number of columns, typically large
  size_t kc,                 // Number of rows, typically small
  size_t nr,                 // Matches the gemm and is a multiple of the vector size
  size_t kr,                 // Unused - must be 1
  size_t sr,                 // Unused - must be 1
  size_t k_stride,           // Elements per row (typically the same as nc)
  const uint32_t* weights,   // Weights to pack; unaligned, unpadded
  const uint32_t* bias,      // Bias to pack; unaligned, unpadded, can be NULL
  const void* scale,         // Unused
  uint32_t* packed_weights,  // Packed weights output buffer; aligned, padded
  size_t extra_bytes,        // Number of extra bytes between weights; aligned
  const void* params)        // Unused
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == 16);   // This kernel is for NR=16
  assert(kr == 1);
  assert(sr == 1);
  assert(k_stride != 0);
  assert(weights != NULL);
  assert(packed_weights != NULL);

  const __m512 vzero = _mm512_setzero_ps();
  const float* b = (const float*) bias;
  float* packed_w = (float*) packed_weights;
  do {
    // NC main loop multiple of 16
    const float* w = (const float*) weights;
    size_t n = nc;

    for (; n >= 16; n -= 16) {
      if XNN_LIKELY(b != NULL) {
        const __m512 vb0 = _mm512_loadu_ps(b + 0);
        _mm512_store_ps(packed_w + 0, vb0);
        b += 16;
      } else {
        _mm512_store_ps(packed_w + 0, vzero);
      }
      packed_w += 16;

      size_t k = kc;
      // KC main loop 8x16
      for (; k >= 8; k -= 8) {
        const __m512 v0_0 = _mm512_loadu_ps(w + 0 + 0 * k_stride);
        const __m512 v0_1 = _mm512_loadu_ps(w + 0 + 1 * k_stride);
        const __m512 v0_2 = _mm512_loadu_ps(w + 0 + 2 * k_stride);
        const __m512 v0_3 = _mm512_loadu_ps(w + 0 + 3 * k_stride);
        const __m512 v0_4 = _mm512_loadu_ps(w + 0 + 4 * k_stride);
        const __m512 v0_5 = _mm512_loadu_ps(w + 0 + 5 * k_stride);
        const __m512 v0_6 = _mm512_loadu_ps(w + 0 + 6 * k_stride);
        const __m512 v0_7 = _mm512_loadu_ps(w + 0 + 7 * k_stride);
        _mm512_store_ps(packed_w + 0, v0_0);
        _mm512_store_ps(packed_w + 16, v0_1);
        _mm512_store_ps(packed_w + 32, v0_2);
        _mm512_store_ps(packed_w + 48, v0_3);
        _mm512_store_ps(packed_w + 64, v0_4);
        _mm512_store_ps(packed_w + 80, v0_5);
        _mm512_store_ps(packed_w + 96, v0_6);
        _mm512_store_ps(packed_w + 112, v0_7);
        w += k_stride * 8;
        packed_w += 128;
      }

      // KC remainder loop
      for (; k > 0; --k) {
        const __m512 v0 = _mm512_loadu_ps(w + 0);
        _mm512_store_ps(packed_w + 0, v0);
        w += k_stride;
        packed_w += 16;
      }
      w = w - kc * k_stride + 16;  // Advance to next column of 16 floats
    }

    // NC remainder (1..15)
    if XNN_UNLIKELY(n != 0) {
      assert(n >= 1);
      assert(n <= 15);

      // Prepare mask for valid 32-bit elements (depends on n).
      const __mmask16 vmask0 = _cvtu32_mask16((uint32_t) (((UINT64_C(1) << n) - 1) >> 0));

      if XNN_LIKELY(b != NULL) {
        const __m512 vb0 = _mm512_maskz_loadu_ps(vmask0, b + 0);
        _mm512_store_ps(packed_w + 0, vb0);
        b += n;
      } else {
        _mm512_store_ps(packed_w + 0, vzero);
      }
      packed_w += 16;

      // KC main loop
      for (size_t k = kc; k > 0; --k) {
        const __m512 v0 = _mm512_maskz_loadu_ps(vmask0, w + 0);
        _mm512_store_ps(packed_w + 0, v0);
        w += k_stride;
        packed_w += 16;
      }
    }
    weights += nc * kc;
  } while (--g != 0);
}
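
For readers tracing the layout, the following is a minimal scalar sketch of what the kernel above produces, assuming only what is visible in the code: each group of nr = 16 columns is emitted as 16 bias values (zeros when bias is NULL) followed by kc rows of 16 weights, with lanes past the last valid column zero-padded. The name reference_packw_gio is hypothetical, not an XNNPACK symbol.

// Hypothetical scalar reference for the packed layout produced above.
#include <stddef.h>
#include <stdint.h>

static void reference_packw_gio(
    size_t g, size_t nc, size_t kc, size_t nr, size_t k_stride,
    const uint32_t* weights, const uint32_t* bias, uint32_t* packed)
{
  do {
    for (size_t n = 0; n < nc; n += nr) {
      const size_t cols = nc - n < nr ? nc - n : nr;  // valid columns in this group
      for (size_t i = 0; i < nr; i++) {               // bias row (zeros if bias == NULL)
        packed[i] = (bias != NULL && i < cols) ? bias[n + i] : 0;
      }
      packed += nr;
      for (size_t k = 0; k < kc; k++) {               // kc rows of nr columns
        for (size_t i = 0; i < nr; i++) {             // invalid lanes are zero-padded
          packed[i] = i < cols ? weights[k * k_stride + n + i] : 0;
        }
        packed += nr;
      }
    }
    weights += nc * kc;                               // next batch of weights
    if (bias != NULL) bias += nc;                     // bias is consumed contiguously
  } while (--g != 0);
}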
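
A hedged usage sketch follows, sizing the output buffer from the stores in the listing: ceil(nc/16) groups per batch, each 16 bias floats plus kc*16 weight floats. The wrapper pack_example and the my_weights/my_bias buffers are placeholders, and extra_bytes is passed as 0 because this particular kernel emits no per-group tail bytes.

// Hypothetical caller; sizes and alignment follow the kernel's stores.
#include <stdlib.h>
#include "src/xnnpack/packw.h"

void pack_example(size_t nc, size_t kc,
                  const uint32_t* my_weights,  // kc * nc elements, row-major (GIO)
                  const uint32_t* my_bias)     // nc elements, or NULL
{
  const size_t g = 1;
  const size_t groups = (nc + 15) / 16;                    // nc rounded up to NR=16
  const size_t packed_elems = g * groups * 16 * (1 + kc);  // bias row + kc rows per group
  // _mm512_store_ps requires a 64-byte-aligned output buffer.
  uint32_t* packed = (uint32_t*) aligned_alloc(64, packed_elems * sizeof(uint32_t));
  if (packed == NULL) return;

  xnn_x32_packw_gemm_gio_ukernel_x16__avx512f_u8(
      g, nc, kc,
      /*nr=*/16, /*kr=*/1, /*sr=*/1,
      /*k_stride=*/nc,        // rows of the unpacked matrix are nc elements apart
      my_weights, my_bias,
      /*scale=*/NULL,
      packed,
      /*extra_bytes=*/0,      // this kernel inserts no extra bytes between groups
      /*params=*/NULL);

  free(packed);
}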