Coverage Report

Created: 2026-06-13 06:24

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/ggml/src/ggml-cpu/simd-gemm.h
Line
Count
Source
1
#pragma once
2
3
// Computes C[M x N] += A[M x K] * B[K x N]
4
5
#include "simd-mappings.h"
6
7
// TODO: add support for sizeless vector types
8
#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)
9
10
// TODO: untested on avx512
// Register-tile sizes for simd_gemm_ukernel, in units of GGML_F32_EPR:
// a tile covers GEMM_RM rows of C by GEMM_RN vectors of columns. While it runs,
// the kernel keeps RM*RN accumulators + RN B vectors + 1 broadcast A vector
// live, so the tile is sized to the ISA's architectural vector register count.
//
// Note: AArch64 is detected via the ACLE macro __ARM_NEON plus __aarch64__.
// The legacy __ARM_NEON__ macro is not defined by GCC on AArch64, and on
// 32-bit ARM (only 16 q-registers) a 4x4 tile would spill.
#if defined(__AVX512F__) || (defined(__ARM_NEON) && defined(__aarch64__))
    // 32 vector registers (zmm0-31 / v0-v31)
    static constexpr int GEMM_RM = 4;
    static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32
#elif defined(__AVX2__) || defined(__AVX__)
    // 16 ymm registers
    static constexpr int GEMM_RM = 6;
    static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16
#else
    // Conservative default (also 32-bit ARM NEON and other 16-register targets)
    static constexpr int GEMM_RM = 2;
    static constexpr int GEMM_RN = 2; // 4+2+1 = 7
#endif
22
23
template <int RM, int RN>
24
static inline void simd_gemm_ukernel(
25
    float       * GGML_RESTRICT C,
26
    const float * GGML_RESTRICT A,
27
    const float * GGML_RESTRICT B,
28
    int K, int N)
29
0
{
30
0
    static constexpr int KN = GGML_F32_EPR;
31
32
0
    GGML_F32_VEC acc[RM][RN];
33
0
    for (int64_t i = 0; i < RM; i++) {
34
0
        for (int r = 0; r < RN; r++) {
35
0
            acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN);
36
0
        }
37
0
    }
38
39
0
    for (int64_t kk = 0; kk < K; kk++) {
40
0
        GGML_F32_VEC Bv[RN];
41
0
        for (int r = 0; r < RN; r++) {
42
0
            Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN);
43
0
        }
44
0
        for (int64_t i = 0; i < RM; i++) {
45
0
            GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]);
46
0
            for (int r = 0; r < RN; r++) {
47
0
                acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
48
0
            }
49
0
        }
50
0
    }
51
52
0
    for (int64_t i = 0; i < RM; i++) {
53
0
        for (int r = 0; r < RN; r++) {
54
0
            GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]);
55
0
        }
56
0
    }
57
0
}
Unexecuted instantiation: ops.cpp:void simd_gemm_ukernel<6, 2>(float*, float const*, float const*, int, int)
Unexecuted instantiation: ops.cpp:void simd_gemm_ukernel<6, 1>(float*, float const*, float const*, int, int)
Unexecuted instantiation: ops.cpp:void simd_gemm_ukernel<1, 2>(float*, float const*, float const*, int, int)
Unexecuted instantiation: ops.cpp:void simd_gemm_ukernel<1, 1>(float*, float const*, float const*, int, int)
58
59
// C[M x N] += A[M x K] * B[K x N]
60
static void simd_gemm(
61
    float       * GGML_RESTRICT C,
62
    const float * GGML_RESTRICT A,
63
    const float * GGML_RESTRICT B,
64
    int M, int K, int N)
65
0
{
66
0
    static constexpr int KN = GGML_F32_EPR;
67
68
0
    int64_t ii = 0;
69
0
    for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
70
0
        int64_t jj = 0;
71
0
        for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
72
0
            simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N);
73
0
        }
74
0
        for (; jj + KN <= N; jj += KN) {
75
0
            simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N);
76
0
        }
77
0
        for (; jj < N; jj++) {
78
0
            for (int64_t i = 0; i < GEMM_RM; i++) {
79
0
                float a = C[i * N + jj];
80
0
                for (int64_t kk = 0; kk < K; kk++) {
81
0
                    a += A[i + kk] * B[kk * N + jj];
82
0
                }
83
0
                C[i * N + jj] = a;
84
0
            }
85
0
        }
86
87
0
        A += GEMM_RM * K;
88
0
        C += GEMM_RM * N;
89
0
    }
90
91
    // Tail rows: one at a time
92
0
    for (; ii < M; ii++) {
93
0
        int64_t jj = 0;
94
0
        for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
95
0
            simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N);
96
0
        }
97
0
        for (; jj + KN <= N; jj += KN) {
98
0
            simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N);
99
0
        }
100
0
        for (; jj < N; jj++) {
101
0
            float a = C[jj];
102
0
            for (int64_t kk = 0; kk < K; kk++) {
103
0
                a += A[kk] * B[kk * N + jj];
104
0
            }
105
0
            C[jj] = a;
106
0
        }
107
108
0
        A += K;
109
0
        C += N;
110
0
    }
111
0
}
112
#elif defined(GGML_SIMD) && defined(__riscv_v_intrinsic)
113
// RM accumulators + 1 B vector = RM + 1 <= 8  =>  RM <= 7
114
// Microkernel: C[RM x vl] += A[RM x K] * B[K x N]
115
template <int RM>
116
static inline void rvv_simd_gemm_ukernel(
117
    float       * GGML_RESTRICT C,
118
    const float * GGML_RESTRICT A,
119
    const float * GGML_RESTRICT B,
120
    int K, int N, size_t vl)
121
{
122
    static_assert(RM >= 1 && RM <= 7, "RM must be 1..7 for LMUL=4");
123
124
    vfloat32m4_t acc_0 = __riscv_vle32_v_f32m4(C + 0 * N, vl);
125
    vfloat32m4_t acc_1, acc_2, acc_3, acc_4, acc_5, acc_6;
126
    if constexpr (RM > 1) acc_1 = __riscv_vle32_v_f32m4(C + 1 * N, vl);
127
    if constexpr (RM > 2) acc_2 = __riscv_vle32_v_f32m4(C + 2 * N, vl);
128
    if constexpr (RM > 3) acc_3 = __riscv_vle32_v_f32m4(C + 3 * N, vl);
129
    if constexpr (RM > 4) acc_4 = __riscv_vle32_v_f32m4(C + 4 * N, vl);
130
    if constexpr (RM > 5) acc_5 = __riscv_vle32_v_f32m4(C + 5 * N, vl);
131
    if constexpr (RM > 6) acc_6 = __riscv_vle32_v_f32m4(C + 6 * N, vl);
132
133
    for (int kk = 0; kk < K; kk++) {
134
        vfloat32m4_t b_0 = __riscv_vle32_v_f32m4(B + kk * N, vl);
135
136
                              acc_0 = __riscv_vfmacc_vf_f32m4(acc_0, A[0 * K + kk], b_0, vl);
137
        if constexpr (RM > 1) acc_1 = __riscv_vfmacc_vf_f32m4(acc_1, A[1 * K + kk], b_0, vl);
138
        if constexpr (RM > 2) acc_2 = __riscv_vfmacc_vf_f32m4(acc_2, A[2 * K + kk], b_0, vl);
139
        if constexpr (RM > 3) acc_3 = __riscv_vfmacc_vf_f32m4(acc_3, A[3 * K + kk], b_0, vl);
140
        if constexpr (RM > 4) acc_4 = __riscv_vfmacc_vf_f32m4(acc_4, A[4 * K + kk], b_0, vl);
141
        if constexpr (RM > 5) acc_5 = __riscv_vfmacc_vf_f32m4(acc_5, A[5 * K + kk], b_0, vl);
142
        if constexpr (RM > 6) acc_6 = __riscv_vfmacc_vf_f32m4(acc_6, A[6 * K + kk], b_0, vl);
143
    }
144
145
                          __riscv_vse32_v_f32m4(C + 0 * N, acc_0, vl);
146
    if constexpr (RM > 1) __riscv_vse32_v_f32m4(C + 1 * N, acc_1, vl);
147
    if constexpr (RM > 2) __riscv_vse32_v_f32m4(C + 2 * N, acc_2, vl);
148
    if constexpr (RM > 3) __riscv_vse32_v_f32m4(C + 3 * N, acc_3, vl);
149
    if constexpr (RM > 4) __riscv_vse32_v_f32m4(C + 4 * N, acc_4, vl);
150
    if constexpr (RM > 5) __riscv_vse32_v_f32m4(C + 5 * N, acc_5, vl);
151
    if constexpr (RM > 6) __riscv_vse32_v_f32m4(C + 6 * N, acc_6, vl);
152
}
153
154
template <int RM>
155
static inline void rvv_simd_gemm_dispatch_tail(
156
    float       * GGML_RESTRICT C,
157
    const float * GGML_RESTRICT A,
158
    const float * GGML_RESTRICT B,
159
    int K, int N, int KN, int remaining_rows)
160
{
161
    if constexpr (RM > 0) {
162
        if (remaining_rows == RM) {
163
            int64_t jj = 0;
164
            for (; jj + KN <= N; jj += KN) {
165
                rvv_simd_gemm_ukernel<RM>(C + jj, A, B + jj, K, N, KN);
166
            }
167
            if (jj < N) {
168
                rvv_simd_gemm_ukernel<RM>(C + jj, A, B + jj, K, N, N - jj);
169
            }
170
        } else {
171
            rvv_simd_gemm_dispatch_tail<RM - 1>(C, A, B, K, N, KN, remaining_rows);
172
        }
173
    }
174
}
175
176
static constexpr int GEMM_RM = 7;
177
178
// C[M x N] += A[M x K] * B[K x N]
179
static void simd_gemm(
180
    float       * GGML_RESTRICT C,
181
    const float * GGML_RESTRICT A,
182
    const float * GGML_RESTRICT B,
183
    int M, int K, int N)
184
{
185
    const int KN = (int)__riscv_vlenb();
186
    int64_t ii = 0;
187
    for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
188
        int64_t jj = 0;
189
        for (; jj + KN <= N; jj += KN) {
190
            rvv_simd_gemm_ukernel<GEMM_RM>(C + jj, A, B + jj, K, N, KN);
191
        }
192
        if (jj < N) {
193
            rvv_simd_gemm_ukernel<GEMM_RM>(C + jj, A, B + jj, K, N, N - jj);
194
        }
195
        A += GEMM_RM * K;
196
        C += GEMM_RM * N;
197
    }
198
199
    int remaining_rows = M - ii;
200
    rvv_simd_gemm_dispatch_tail<GEMM_RM - 1>(C, A, B, K, N, KN, remaining_rows);
201
}
202
203
#if defined(__GNUC__) && !defined(__clang__)
204
#pragma GCC diagnostic pop
205
#endif
206
207
#else // scalar path
208
209
// Portable fallback: C[M x N] += A[M x K] * B[K x N], all row-major.
// For each output element, the existing C value is the starting sum, and
// A[row][k] * B[k][col] is added for k = 0 .. K-1 in ascending order.
static void simd_gemm(
    float       * GGML_RESTRICT C,
    const float * GGML_RESTRICT A,
    const float * GGML_RESTRICT B,
    int M, int K, int N)
{
    for (int64_t row = 0; row < M; ++row) {
        const float * a_row = A + row * K;
        float       * c_row = C + row * N;

        for (int64_t col = 0; col < N; ++col) {
            float acc = c_row[col];
            for (int64_t k = 0; k < K; ++k) {
                acc += a_row[k] * B[k * N + col];
            }
            c_row[col] = acc;
        }
    }
}
225
226
#endif // GGML_SIMD