/src/xnnpack/src/x32-packw/gen/x32-packw-x16-gemm-gio-avx-u8.c
// clang-format off
// Auto-generated file. Do not edit!
// Template: src/x32-packw/gio-avx.c.in
// Generator: tools/xngen
//
// Copyright 2024 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <immintrin.h>

#include "src/xnnpack/common.h"
#include "src/xnnpack/packw.h"

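// Packs an nc x kc block of 32-bit GIO-layout weights (K-major: each of the
// kc rows holds nc consecutive values, with k_stride elements between rows)
// into the layout consumed by NR=16 GEMM microkernels: for each tile of 16
// columns, 16 bias values (zero-filled when bias == NULL) followed by kc rows
// of 16 weight values. kr == 1 and sr == 1, so rows are copied without
// interleaving.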
void xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8(
  size_t g,
  size_t nc,
  size_t kc,
  size_t nr,
  size_t kr,
  size_t sr,
  size_t k_stride,
  const uint32_t* weights,
  const uint32_t* bias,
  const void* scale,
  uint32_t* packed_weights,
  size_t extra_bytes,
  const void* params)
{
  assert(g != 0);
  assert(nc != 0);
  assert(kc != 0);
  assert(nr == 16);  // This kernel is for NR=16
  assert(kr == 1);
  assert(sr == 1);
  assert(k_stride != 0);
  assert(weights != NULL);
  assert(packed_weights != NULL);
  static const int32_t mask_table[32] = {
    -1, -1, -1, -1, -1, -1, -1, -1,
    -1, -1, -1, -1, -1, -1, -1, -1,
    0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0,
  };
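  // A window of 8 consecutive entries starting at &mask_table[16 - n]
  // (1 <= n <= 15) has its first min(n, 8) lanes set to -1 and the rest to 0;
  // the window at &mask_table[24 - n] covers lanes 8..15. The NC remainder
  // path below loads these windows as AVX masks for _mm256_maskload_ps.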

  const __m256 vzero = _mm256_setzero_ps();
  const float* b = (const float*) bias;
  float* packed_w = (float*) packed_weights;
  do {
    // NC main loop multiple of 16
    const float* w = (const float*) weights;
    size_t n = nc;

    for (; n >= 16; n -= 16) {
      if XNN_LIKELY(b != NULL) {
        const __m256 vb0 = _mm256_loadu_ps(b + 0);
        const __m256 vb1 = _mm256_loadu_ps(b + 8);
        _mm256_store_ps(packed_w + 0, vb0);
        _mm256_store_ps(packed_w + 8, vb1);
        b += 16;
      } else {
        _mm256_store_ps(packed_w + 0, vzero);
        _mm256_store_ps(packed_w + 8, vzero);
      }
      packed_w += 16;

      size_t k = kc;
      // KC main loop 8x16
      for (; k >= 8; k -= 8) {
        const __m256 v0_0 = _mm256_loadu_ps(w + 0 + 0 * k_stride);
        const __m256 v1_0 = _mm256_loadu_ps(w + 8 + 0 * k_stride);
        const __m256 v0_1 = _mm256_loadu_ps(w + 0 + 1 * k_stride);
        const __m256 v1_1 = _mm256_loadu_ps(w + 8 + 1 * k_stride);
        const __m256 v0_2 = _mm256_loadu_ps(w + 0 + 2 * k_stride);
        const __m256 v1_2 = _mm256_loadu_ps(w + 8 + 2 * k_stride);
        const __m256 v0_3 = _mm256_loadu_ps(w + 0 + 3 * k_stride);
        const __m256 v1_3 = _mm256_loadu_ps(w + 8 + 3 * k_stride);
        const __m256 v0_4 = _mm256_loadu_ps(w + 0 + 4 * k_stride);
        const __m256 v1_4 = _mm256_loadu_ps(w + 8 + 4 * k_stride);
        const __m256 v0_5 = _mm256_loadu_ps(w + 0 + 5 * k_stride);
        const __m256 v1_5 = _mm256_loadu_ps(w + 8 + 5 * k_stride);
        const __m256 v0_6 = _mm256_loadu_ps(w + 0 + 6 * k_stride);
        const __m256 v1_6 = _mm256_loadu_ps(w + 8 + 6 * k_stride);
        const __m256 v0_7 = _mm256_loadu_ps(w + 0 + 7 * k_stride);
        const __m256 v1_7 = _mm256_loadu_ps(w + 8 + 7 * k_stride);
        _mm256_store_ps(packed_w + 0, v0_0);
        _mm256_store_ps(packed_w + 8, v1_0);
        _mm256_store_ps(packed_w + 16, v0_1);
        _mm256_store_ps(packed_w + 24, v1_1);
        _mm256_store_ps(packed_w + 32, v0_2);
        _mm256_store_ps(packed_w + 40, v1_2);
        _mm256_store_ps(packed_w + 48, v0_3);
        _mm256_store_ps(packed_w + 56, v1_3);
        _mm256_store_ps(packed_w + 64, v0_4);
        _mm256_store_ps(packed_w + 72, v1_4);
        _mm256_store_ps(packed_w + 80, v0_5);
        _mm256_store_ps(packed_w + 88, v1_5);
        _mm256_store_ps(packed_w + 96, v0_6);
        _mm256_store_ps(packed_w + 104, v1_6);
        _mm256_store_ps(packed_w + 112, v0_7);
        _mm256_store_ps(packed_w + 120, v1_7);
        w += k_stride * 8;
        packed_w += 128;
      }

      // KC remainder loop
      for (; k > 0; --k) {
        const __m256 v0 = _mm256_loadu_ps(w + 0);
        const __m256 v1 = _mm256_loadu_ps(w + 8);
        _mm256_store_ps(packed_w + 0, v0);
        _mm256_store_ps(packed_w + 8, v1);
        w += k_stride;
        packed_w += 16;
      }
      w = w - kc * k_stride + 16;  // Rewind w to the first row and advance to the next tile of 16 columns
    }

    // NC remainder (1..15)
    if XNN_UNLIKELY(n != 0) {
      assert(n >= 1);
      assert(n <= 15);
      const __m256i vmask0 = _mm256_loadu_si256((const __m256i*) &mask_table[16 - n]);
      const __m256i vmask1 = _mm256_loadu_si256((const __m256i*) &mask_table[24 - n]);

      if XNN_LIKELY(b != NULL) {
        const __m256 vb0 = _mm256_maskload_ps(b + 0, vmask0);
        const __m256 vb1 = _mm256_maskload_ps(b + 8, vmask1);
        _mm256_store_ps(packed_w + 0, vb0);
        _mm256_store_ps(packed_w + 8, vb1);
        b += n;
      } else {
        _mm256_store_ps(packed_w + 0, vzero);
        _mm256_store_ps(packed_w + 8, vzero);
      }
      packed_w += 16;

      // KC main loop
      for (size_t k = kc; k > 0; --k) {
        const __m256 v0 = _mm256_maskload_ps(w + 0, vmask0);
        const __m256 v1 = _mm256_maskload_ps(w + 8, vmask1);
        _mm256_store_ps(packed_w + 0, v0);
        _mm256_store_ps(packed_w + 8, v1);
        w += k_stride;
        packed_w += 16;
      }
    }
    weights += nc * kc;
  } while (--g != 0);
}
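
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the generated kernel; the helper
// name, the dimensions, and the assumption that k_stride == nc for a dense
// row-major GIO matrix are hypothetical). It packs one GIO weight block with
// kc = 10 rows and nc = 20 columns plus a bias vector; 20 columns span two
// NR=16 tiles, so the packed buffer holds 2 * 16 * (10 + 1) values.
// XNN_ALIGN comes from src/xnnpack/common.h; 32-byte alignment satisfies the
// aligned _mm256_store_ps stores into packed_weights.
// ---------------------------------------------------------------------------
void example_x32_packw_x16_gio(void) {  // hypothetical example function
  enum { kNc = 20, kKc = 10, kNr = 16, kTiles = (kNc + kNr - 1) / kNr };
  static uint32_t weights[kKc * kNc];  // GIO layout: kc rows of nc values
  static uint32_t bias[kNc];
  static XNN_ALIGN(32) uint32_t packed[kTiles * kNr * (kKc + 1)];

  xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8(
      /*g=*/1, /*nc=*/kNc, /*kc=*/kKc, /*nr=*/kNr, /*kr=*/1, /*sr=*/1,
      /*k_stride=*/kNc, weights, bias, /*scale=*/NULL, packed,
      /*extra_bytes=*/0, /*params=*/NULL);
}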