/src/llama.cpp/ggml/src/ggml-cpu/simd-gemm.h
Line | Count | Source |
1 | | #pragma once |
2 | | |
3 | | // Computes C[M x N] += A[M x K] * B[K x N] |
4 | | |
5 | | #include "simd-mappings.h" |
6 | | |
7 | | // TODO: add support for sizeless vector types |
8 | | #if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic) |
9 | | |
10 | | // TODO: untested on avx512 |
11 | | // These are in units of GGML_F32_EPR |
12 | | #if defined(__AVX512F__) || defined (__ARM_NEON__) |
13 | | static constexpr int GEMM_RM = 4; |
14 | | static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32 |
15 | | #elif defined(__AVX2__) || defined(__AVX__) |
16 | | static constexpr int GEMM_RM = 6; |
17 | | static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16 |
18 | | #else |
19 | | static constexpr int GEMM_RM = 2; |
20 | | static constexpr int GEMM_RN = 2; |
21 | | #endif |
22 | | |
23 | | template <int RM, int RN> |
24 | | static inline void simd_gemm_ukernel( |
25 | | float * GGML_RESTRICT C, |
26 | | const float * GGML_RESTRICT A, |
27 | | const float * GGML_RESTRICT B, |
28 | | int K, int N) |
29 | 0 | { |
30 | 0 | static constexpr int KN = GGML_F32_EPR; |
31 | |
|
32 | 0 | GGML_F32_VEC acc[RM][RN]; |
33 | 0 | for (int64_t i = 0; i < RM; i++) { |
34 | 0 | for (int r = 0; r < RN; r++) { |
35 | 0 | acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN); |
36 | 0 | } |
37 | 0 | } |
38 | |
|
39 | 0 | for (int64_t kk = 0; kk < K; kk++) { |
40 | 0 | GGML_F32_VEC Bv[RN]; |
41 | 0 | for (int r = 0; r < RN; r++) { |
42 | 0 | Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN); |
43 | 0 | } |
44 | 0 | for (int64_t i = 0; i < RM; i++) { |
45 | 0 | GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]); |
46 | 0 | for (int r = 0; r < RN; r++) { |
47 | 0 | acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p); |
48 | 0 | } |
49 | 0 | } |
50 | 0 | } |
51 | |
|
52 | 0 | for (int64_t i = 0; i < RM; i++) { |
53 | 0 | for (int r = 0; r < RN; r++) { |
54 | 0 | GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]); |
55 | 0 | } |
56 | 0 | } |
57 | 0 | } Unexecuted instantiation: ops.cpp:void simd_gemm_ukernel<6, 2>(float*, float const*, float const*, int, int) Unexecuted instantiation: ops.cpp:void simd_gemm_ukernel<6, 1>(float*, float const*, float const*, int, int) Unexecuted instantiation: ops.cpp:void simd_gemm_ukernel<1, 2>(float*, float const*, float const*, int, int) Unexecuted instantiation: ops.cpp:void simd_gemm_ukernel<1, 1>(float*, float const*, float const*, int, int) |
58 | | |
59 | | // C[M x N] += A[M x K] * B[K x N] |
60 | | static void simd_gemm( |
61 | | float * GGML_RESTRICT C, |
62 | | const float * GGML_RESTRICT A, |
63 | | const float * GGML_RESTRICT B, |
64 | | int M, int K, int N) |
65 | 0 | { |
66 | 0 | static constexpr int KN = GGML_F32_EPR; |
67 | |
|
68 | 0 | int64_t ii = 0; |
69 | 0 | for (; ii + GEMM_RM <= M; ii += GEMM_RM) { |
70 | 0 | int64_t jj = 0; |
71 | 0 | for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) { |
72 | 0 | simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N); |
73 | 0 | } |
74 | 0 | for (; jj + KN <= N; jj += KN) { |
75 | 0 | simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N); |
76 | 0 | } |
77 | 0 | for (; jj < N; jj++) { |
78 | 0 | for (int64_t i = 0; i < GEMM_RM; i++) { |
79 | 0 | float a = C[i * N + jj]; |
80 | 0 | for (int64_t kk = 0; kk < K; kk++) { |
81 | 0 | a += A[i + kk] * B[kk * N + jj]; |
82 | 0 | } |
83 | 0 | C[i * N + jj] = a; |
84 | 0 | } |
85 | 0 | } |
86 | |
|
87 | 0 | A += GEMM_RM * K; |
88 | 0 | C += GEMM_RM * N; |
89 | 0 | } |
90 | | |
91 | | // Tail rows: one at a time |
92 | 0 | for (; ii < M; ii++) { |
93 | 0 | int64_t jj = 0; |
94 | 0 | for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) { |
95 | 0 | simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N); |
96 | 0 | } |
97 | 0 | for (; jj + KN <= N; jj += KN) { |
98 | 0 | simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N); |
99 | 0 | } |
100 | 0 | for (; jj < N; jj++) { |
101 | 0 | float a = C[jj]; |
102 | 0 | for (int64_t kk = 0; kk < K; kk++) { |
103 | 0 | a += A[kk] * B[kk * N + jj]; |
104 | 0 | } |
105 | 0 | C[jj] = a; |
106 | 0 | } |
107 | |
|
108 | 0 | A += K; |
109 | 0 | C += N; |
110 | 0 | } |
111 | 0 | } |
112 | | #elif defined(GGML_SIMD) && defined(__riscv_v_intrinsic) |
113 | | // RM accumulators + 1 B vector = RM + 1 <= 8 => RM <= 7 |
114 | | // Microkernel: C[RM x vl] += A[RM x K] * B[K x N] |
115 | | template <int RM> |
116 | | static inline void rvv_simd_gemm_ukernel( |
117 | | float * GGML_RESTRICT C, |
118 | | const float * GGML_RESTRICT A, |
119 | | const float * GGML_RESTRICT B, |
120 | | int K, int N, size_t vl) |
121 | | { |
122 | | static_assert(RM >= 1 && RM <= 7, "RM must be 1..7 for LMUL=4"); |
123 | | |
124 | | vfloat32m4_t acc_0 = __riscv_vle32_v_f32m4(C + 0 * N, vl); |
125 | | vfloat32m4_t acc_1, acc_2, acc_3, acc_4, acc_5, acc_6; |
126 | | if constexpr (RM > 1) acc_1 = __riscv_vle32_v_f32m4(C + 1 * N, vl); |
127 | | if constexpr (RM > 2) acc_2 = __riscv_vle32_v_f32m4(C + 2 * N, vl); |
128 | | if constexpr (RM > 3) acc_3 = __riscv_vle32_v_f32m4(C + 3 * N, vl); |
129 | | if constexpr (RM > 4) acc_4 = __riscv_vle32_v_f32m4(C + 4 * N, vl); |
130 | | if constexpr (RM > 5) acc_5 = __riscv_vle32_v_f32m4(C + 5 * N, vl); |
131 | | if constexpr (RM > 6) acc_6 = __riscv_vle32_v_f32m4(C + 6 * N, vl); |
132 | | |
133 | | for (int kk = 0; kk < K; kk++) { |
134 | | vfloat32m4_t b_0 = __riscv_vle32_v_f32m4(B + kk * N, vl); |
135 | | |
136 | | acc_0 = __riscv_vfmacc_vf_f32m4(acc_0, A[0 * K + kk], b_0, vl); |
137 | | if constexpr (RM > 1) acc_1 = __riscv_vfmacc_vf_f32m4(acc_1, A[1 * K + kk], b_0, vl); |
138 | | if constexpr (RM > 2) acc_2 = __riscv_vfmacc_vf_f32m4(acc_2, A[2 * K + kk], b_0, vl); |
139 | | if constexpr (RM > 3) acc_3 = __riscv_vfmacc_vf_f32m4(acc_3, A[3 * K + kk], b_0, vl); |
140 | | if constexpr (RM > 4) acc_4 = __riscv_vfmacc_vf_f32m4(acc_4, A[4 * K + kk], b_0, vl); |
141 | | if constexpr (RM > 5) acc_5 = __riscv_vfmacc_vf_f32m4(acc_5, A[5 * K + kk], b_0, vl); |
142 | | if constexpr (RM > 6) acc_6 = __riscv_vfmacc_vf_f32m4(acc_6, A[6 * K + kk], b_0, vl); |
143 | | } |
144 | | |
145 | | __riscv_vse32_v_f32m4(C + 0 * N, acc_0, vl); |
146 | | if constexpr (RM > 1) __riscv_vse32_v_f32m4(C + 1 * N, acc_1, vl); |
147 | | if constexpr (RM > 2) __riscv_vse32_v_f32m4(C + 2 * N, acc_2, vl); |
148 | | if constexpr (RM > 3) __riscv_vse32_v_f32m4(C + 3 * N, acc_3, vl); |
149 | | if constexpr (RM > 4) __riscv_vse32_v_f32m4(C + 4 * N, acc_4, vl); |
150 | | if constexpr (RM > 5) __riscv_vse32_v_f32m4(C + 5 * N, acc_5, vl); |
151 | | if constexpr (RM > 6) __riscv_vse32_v_f32m4(C + 6 * N, acc_6, vl); |
152 | | } |
153 | | |
154 | | template <int RM> |
155 | | static inline void rvv_simd_gemm_dispatch_tail( |
156 | | float * GGML_RESTRICT C, |
157 | | const float * GGML_RESTRICT A, |
158 | | const float * GGML_RESTRICT B, |
159 | | int K, int N, int KN, int remaining_rows) |
160 | | { |
161 | | if constexpr (RM > 0) { |
162 | | if (remaining_rows == RM) { |
163 | | int64_t jj = 0; |
164 | | for (; jj + KN <= N; jj += KN) { |
165 | | rvv_simd_gemm_ukernel<RM>(C + jj, A, B + jj, K, N, KN); |
166 | | } |
167 | | if (jj < N) { |
168 | | rvv_simd_gemm_ukernel<RM>(C + jj, A, B + jj, K, N, N - jj); |
169 | | } |
170 | | } else { |
171 | | rvv_simd_gemm_dispatch_tail<RM - 1>(C, A, B, K, N, KN, remaining_rows); |
172 | | } |
173 | | } |
174 | | } |
175 | | |
176 | | static constexpr int GEMM_RM = 7; |
177 | | |
178 | | // C[M x N] += A[M x K] * B[K x N] |
179 | | static void simd_gemm( |
180 | | float * GGML_RESTRICT C, |
181 | | const float * GGML_RESTRICT A, |
182 | | const float * GGML_RESTRICT B, |
183 | | int M, int K, int N) |
184 | | { |
185 | | const int KN = (int)__riscv_vlenb(); |
186 | | int64_t ii = 0; |
187 | | for (; ii + GEMM_RM <= M; ii += GEMM_RM) { |
188 | | int64_t jj = 0; |
189 | | for (; jj + KN <= N; jj += KN) { |
190 | | rvv_simd_gemm_ukernel<GEMM_RM>(C + jj, A, B + jj, K, N, KN); |
191 | | } |
192 | | if (jj < N) { |
193 | | rvv_simd_gemm_ukernel<GEMM_RM>(C + jj, A, B + jj, K, N, N - jj); |
194 | | } |
195 | | A += GEMM_RM * K; |
196 | | C += GEMM_RM * N; |
197 | | } |
198 | | |
199 | | int remaining_rows = M - ii; |
200 | | rvv_simd_gemm_dispatch_tail<GEMM_RM - 1>(C, A, B, K, N, KN, remaining_rows); |
201 | | } |
202 | | |
// NOTE(review): no matching `#pragma GCC diagnostic push` appears in this
// file, so this pop is unmatched here. An unmatched pop restores the
// command-line diagnostic state. If an includer pushed state before including
// this header, this pop would consume that push instead. Confirm where the
// intended push (and which warning it was meant to silence) lives.
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
206 | | |
207 | | #else // scalar path |
208 | | |
209 | | static void simd_gemm( |
210 | | float * GGML_RESTRICT C, |
211 | | const float * GGML_RESTRICT A, |
212 | | const float * GGML_RESTRICT B, |
213 | | int M, int K, int N) |
214 | | { |
215 | | for (int64_t i = 0; i < M; i++) { |
216 | | for (int64_t j = 0; j < N; j++) { |
217 | | float sum = C[i * N + j]; |
218 | | for (int64_t kk = 0; kk < K; kk++) { |
219 | | sum += A[i * K + kk] * B[kk * N + j]; |
220 | | } |
221 | | C[i * N + j] = sum; |
222 | | } |
223 | | } |
224 | | } |
225 | | |
226 | | #endif // GGML_SIMD |