/src/xnnpack/src/x32-packw/gen/x32-packw-x16-gemm-gio-avx512f-u8.c
// clang-format off
// Auto-generated file. Do not edit!
//   Template: src/x32-packw/gio-avx512.c.in
//   Generator: tools/xngen
//
// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.


#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <immintrin.h>

#include "src/xnnpack/common.h"
#include "src/xnnpack/intrinsics-polyfill.h"
#include "src/xnnpack/packw.h"


// Pack pre-transposed weights (GIO) for use by f32-gemm.
// (A usage sketch follows the function body.)
void xnn_x32_packw_gemm_gio_ukernel_x16__avx512f_u8(
  size_t g,                 // Batch size (outer loop), usually 1
  size_t nc,                // Number of columns, typically large
  size_t kc,                // Number of rows, typically small
  size_t nr,                // Matches the GEMM NR; a multiple of the vector size
  size_t kr,                // Unused; must be 1
  size_t sr,                // Unused; must be 1
  size_t k_stride,          // Elements per row (typically the same as nc)
  const uint32_t* weights,  // Weights to pack; unaligned, unpadded
  const uint32_t* bias,     // Bias to pack; unaligned, unpadded, may be NULL
  const void* scale,        // Unused
  uint32_t* packed_weights, // Packed-weights output buffer; aligned, padded
  size_t extra_bytes,       // Number of extra bytes between weights; aligned
  const void* params)       // Unused
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == 16);  // This kernel is for NR=16
  assert(kr == 1);
  assert(sr == 1);
  assert(k_stride != 0);
  assert(weights != NULL);
  assert(packed_weights != NULL);

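  // Packed layout, as produced by the stores below: for each group, and for
  // each block of 16 columns, 16 bias values (zeros when bias is NULL) are
  // written first, followed by kc rows of 16 weight values each.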
  const __m512 vzero = _mm512_setzero_ps();
  const float* b = (const float*) bias;
  float* packed_w = (float*) packed_weights;
  do {
    // NC main loop: blocks of 16 columns
    const float* w = (const float*) weights;
    size_t n = nc;

    for (; n >= 16; n -= 16) {
      if XNN_LIKELY(b != NULL) {
        const __m512 vb0 = _mm512_loadu_ps(b + 0);
        _mm512_store_ps(packed_w + 0, vb0);
        b += 16;
      } else {
        _mm512_store_ps(packed_w + 0, vzero);
      }
      packed_w += 16;

      size_t k = kc;
      // KC main loop 8x16
      for (; k >= 8; k -= 8) {
        const __m512 v0_0 = _mm512_loadu_ps(w + 0 + 0 * k_stride);
        const __m512 v0_1 = _mm512_loadu_ps(w + 0 + 1 * k_stride);
        const __m512 v0_2 = _mm512_loadu_ps(w + 0 + 2 * k_stride);
        const __m512 v0_3 = _mm512_loadu_ps(w + 0 + 3 * k_stride);
        const __m512 v0_4 = _mm512_loadu_ps(w + 0 + 4 * k_stride);
        const __m512 v0_5 = _mm512_loadu_ps(w + 0 + 5 * k_stride);
        const __m512 v0_6 = _mm512_loadu_ps(w + 0 + 6 * k_stride);
        const __m512 v0_7 = _mm512_loadu_ps(w + 0 + 7 * k_stride);
        _mm512_store_ps(packed_w + 0, v0_0);
        _mm512_store_ps(packed_w + 16, v0_1);
        _mm512_store_ps(packed_w + 32, v0_2);
        _mm512_store_ps(packed_w + 48, v0_3);
        _mm512_store_ps(packed_w + 64, v0_4);
        _mm512_store_ps(packed_w + 80, v0_5);
        _mm512_store_ps(packed_w + 96, v0_6);
        _mm512_store_ps(packed_w + 112, v0_7);
        w += k_stride * 8;
        packed_w += 128;
      }
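      // Note: because the weights are pre-transposed (GIO), each packed
      // 16-column block is a sequence of contiguous 16-float row segments,
      // so the 8x16 loop above is eight plain vector copies, no shuffling.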

      // KC remainder loop
      for (; k > 0; --k) {
        const __m512 v0 = _mm512_loadu_ps(w + 0);
        _mm512_store_ps(packed_w + 0, v0);
        w += k_stride;
        packed_w += 16;
      }
      w = w - kc * k_stride + 16;  // Rewind to the top row and advance to the next block of 16 columns
    }

    // NC remainder (1..15)
    if XNN_UNLIKELY(n != 0) {
      assert(n >= 1);
      assert(n <= 15);

      // Prepare mask for valid 32-bit elements (depends on n).
      const __mmask16 vmask0 = _cvtu32_mask16((uint32_t) (((UINT64_C(1) << n) - 1) >> 0));
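      // For example, n == 5 gives vmask0 == 0x001F: the masked loads below
      // read only the first 5 floats and zero the remaining 11 lanes.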

      if XNN_LIKELY(b != NULL) {
        const __m512 vb0 = _mm512_maskz_loadu_ps(vmask0, b + 0);
        _mm512_store_ps(packed_w + 0, vb0);
        b += n;
      } else {
        _mm512_store_ps(packed_w + 0, vzero);
      }
      packed_w += 16;

      // KC main loop
      for (size_t k = kc; k > 0; --k) {
        const __m512 v0 = _mm512_maskz_loadu_ps(vmask0, w + 0);
        _mm512_store_ps(packed_w + 0, v0);
        w += k_stride;
        packed_w += 16;
      }
    }
    weights += nc * kc;
  } while (--g != 0);
}
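
// A minimal usage sketch, not part of the generated kernel: pack_example and
// all sizes below are hypothetical. It packs a kc x nc = 3 x 20 GIO matrix
// for this NR=16 ukernel, sizing the output per the layout described above:
// ceil(nc/16) blocks of (16 bias + kc*16 weight) 32-bit values per group.
#include <stdlib.h>

static void pack_example(void) {
  const size_t g = 1, nc = 20, kc = 3, nr = 16;
  const size_t blocks = (nc + nr - 1) / nr;                // 2 blocks of 16 columns
  const size_t packed_size = g * blocks * (nr + kc * nr);  // 128 values
  uint32_t weights[3 * 20] = {0};  // GIO: kc rows, each nc elements wide
  uint32_t bias[20] = {0};
  uint32_t* packed = (uint32_t*) aligned_alloc(64, packed_size * sizeof(uint32_t));
  xnn_x32_packw_gemm_gio_ukernel_x16__avx512f_u8(
      g, nc, kc, nr, /*kr=*/1, /*sr=*/1, /*k_stride=*/nc,
      weights, bias, /*scale=*/NULL, packed, /*extra_bytes=*/0, /*params=*/NULL);
  free(packed);
}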