Coverage Report

Created: 2026-02-26 07:06

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/models/rwkv6-base.cpp
Line
Count
Source
1
#include "models.h"
2
3
#include "llama-memory-recurrent.h"
4
5
// Constructs the shared RWKV6 graph-builder base.
//   model  - model whose layers are read by the build_rwkv6_* helpers; stored in the `model` member
//   params - graph parameters, forwarded unchanged to the llm_graph_context base
llm_build_rwkv6_base::llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params)
    : llm_graph_context(params)
    , model(model)
{
}
8
9
// Builds the RWKV6 channel-mix sub-graph.
//   layer  - layer providing the channel_mix_* weights
//   cur    - current activations
//   x_prev - activations that `cur` is interpolated towards (only x_prev - cur is used)
//   arch   - architecture selector; anything other than LLM_ARCH_RWKV6 aborts
//
// Returns sigmoid(W_r * xr) * (W_v * relu(W_k * xk)^2), where
//   xk = cur + (x_prev - cur) * lerp_k
//   xr = cur + (x_prev - cur) * lerp_r
ggml_tensor * llm_build_rwkv6_base::build_rwkv6_channel_mix(const llama_layer * layer,
                                                            ggml_tensor *       cur,
                                                            ggml_tensor *       x_prev,
                                                            llm_arch            arch) const {
    ggml_tensor * delta = ggml_sub(ctx0, x_prev, cur);

    // Only the RWKV6 formulation is implemented by this helper.
    if (arch != LLM_ARCH_RWKV6) {
        GGML_ABORT("fatal error");
    }

    // Per-channel interpolation between `cur` and `x_prev` for the key and receptance inputs.
    ggml_tensor * key_mix = ggml_mul(ctx0, delta, layer->channel_mix_lerp_k);
    ggml_tensor * key_in  = ggml_add(ctx0, key_mix, cur);
    ggml_tensor * rec_mix = ggml_mul(ctx0, delta, layer->channel_mix_lerp_r);
    ggml_tensor * rec_in  = ggml_add(ctx0, rec_mix, cur);

    // Receptance gate.
    ggml_tensor * rec_proj = build_lora_mm(layer->channel_mix_receptance, rec_in);
    ggml_tensor * gate     = ggml_sigmoid(ctx0, rec_proj);

    // Squared-ReLU key activation.
    ggml_tensor * key_proj = build_lora_mm(layer->channel_mix_key, key_in);
    ggml_tensor * key_act  = ggml_sqr(ctx0, ggml_relu(ctx0, key_proj));

    // Gated value projection.
    ggml_tensor * val_proj = build_lora_mm(layer->channel_mix_value, key_act);

    return ggml_mul(ctx0, gate, val_proj);
}
30
31
// Builds the RWKV6 time-mix sub-graph for layer `il`.
//   inp    - recurrent-state graph input, forwarded to build_rs()
//   cur    - current activations; reshaped here to [n_embd, n_tokens]
//   x_prev - activations that `cur` is interpolated towards (only x_prev - cur is used)
//   ubatch - supplies n_tokens, n_seqs and n_seq_tokens
//   il     - layer index, used for model.layers[il], hparams.n_head_kv(il) and the recurrent state
//
// Returns the time-mix output reshaped to [n_embd, n_seq_tokens, n_seqs].
// Side effect: a copy of the updated WKV state into the recurrent-state tensor is added to `gf`.
//
// A layer without `time_mix_first` is handled as the "qrwkv" variant: sigmoid gate instead of
// SiLU, k scaled by (1 - w), ggml_gated_linear_attn instead of ggml_rwkv_wkv6, and no group norm.
ggml_tensor * llm_build_rwkv6_base::build_rwkv6_time_mix(llm_graph_input_rs * inp,
                                                         ggml_tensor *        cur,
                                                         ggml_tensor *        x_prev,
                                                         const llama_ubatch & ubatch,
                                                         int                  il) const {
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

    const auto n_tokens     = ubatch.n_tokens;
    const auto n_seqs       = ubatch.n_seqs;
    const auto n_seq_tokens = ubatch.n_seq_tokens;
    const auto n_embd       = hparams.n_embd;
    const auto head_size    = hparams.wkv_head_size;
    const auto n_head       = n_embd / head_size;
    const auto n_head_kv    = hparams.n_head_kv(il);

    const auto kv_head = mctx_cur->get_head();

    const auto & layer = model.layers[il];

    const bool is_qrwkv = layer.time_mix_first == nullptr;

    // Size in bytes of one [n_embd, n_tokens] F32 matrix. sizeof() comes first so the whole
    // product is evaluated in size_t: the dimension types are `auto`-deduced from fields declared
    // elsewhere (assumed 32-bit - confirm), and multiplying them together before widening could
    // wrap for large batches.
    const auto tok_mat_bytes = sizeof(float) * n_embd * n_tokens;

    ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);

    sx  = ggml_reshape_2d(ctx0, sx, n_embd, n_tokens);
    cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);

    ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_x), cur);

    // Low-rank projection through w1/w2; w1's output rows are split into 5 groups.
    xxx = ggml_reshape_4d(ctx0, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xxx)),
                          layer.time_mix_w1->ne[1] / 5, 1, 5, n_tokens);

    xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2));

    xxx = ggml_mul_mat(
        ctx0, ggml_reshape_4d(ctx0, layer.time_mix_w2, layer.time_mix_w2->ne[0], layer.time_mix_w2->ne[1], 1, 5), xxx);

    // `xxx` is read below as five consecutive [n_embd, n_tokens] F32 slabs, in the order
    // w, k, v, r, g. The helper views slab `slot` of whatever `xxx` currently points to.
    const auto lerp_slab = [&](unsigned slot) {
        return ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], slot * tok_mat_bytes);
    };

    ggml_tensor *xw, *xk, *xv, *xr, *xg;
    if (layer.time_mix_lerp_fused) {
        // fusing these weights makes some performance improvement
        sx  = ggml_reshape_3d(ctx0, sx, n_embd, 1, n_tokens);
        cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
        xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer.time_mix_lerp_fused), sx), cur);
        xw  = lerp_slab(0);
        xk  = lerp_slab(1);
        xv  = lerp_slab(2);
        xr  = lerp_slab(3);
        xg  = lerp_slab(4);
    } else {
        // for backward compatibility
        xw = lerp_slab(0);
        xk = lerp_slab(1);
        xv = lerp_slab(2);
        xr = lerp_slab(3);
        xg = lerp_slab(4);

        xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer.time_mix_lerp_w), sx), cur);
        xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer.time_mix_lerp_k), sx), cur);
        xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer.time_mix_lerp_v), sx), cur);
        xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer.time_mix_lerp_r), sx), cur);
        xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer.time_mix_lerp_g), sx), cur);
    }

    // Receptance / key / value projections, each with an optional bias.
    ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr);
    ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk);
    ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv);
    if (layer.time_mix_receptance_b) {
        r = ggml_add(ctx0, r, layer.time_mix_receptance_b);
    }
    if (layer.time_mix_key_b) {
        k = ggml_add(ctx0, k, layer.time_mix_key_b);
    }
    if (layer.time_mix_value_b) {
        v = ggml_add(ctx0, v, layer.time_mix_value_b);
    }

    // Output gate: sigmoid for the qrwkv variant, SiLU otherwise.
    ggml_tensor * g = build_lora_mm(layer.time_mix_gate, xg);
    if (is_qrwkv) {
        g = ggml_sigmoid(ctx0, g);
    } else {
        g = ggml_silu(ctx0, g);
    }

    // When there are fewer k/v heads than heads, repeat each k/v head n_head / n_head_kv times.
    // `tmp` only serves as the target shape for ggml_repeat.
    if (n_head_kv != 0 && n_head_kv != n_head) {
        GGML_ASSERT(n_head % n_head_kv == 0);
        k                 = ggml_reshape_4d(ctx0, k, head_size, 1, n_head_kv, n_tokens);
        v                 = ggml_reshape_4d(ctx0, v, head_size, 1, n_head_kv, n_tokens);
        ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, n_head / n_head_kv, n_head_kv, n_tokens);
        k                 = ggml_repeat(ctx0, k, tmp);
        v                 = ggml_repeat(ctx0, v, tmp);
    }
    k = ggml_reshape_3d(ctx0, k, head_size, n_head, n_tokens);
    v = ggml_reshape_3d(ctx0, v, head_size, n_head, n_tokens);
    r = ggml_reshape_3d(ctx0, r, head_size, n_head, n_tokens);

    // Decay: w = exp(-exp(time_mix_decay + W2 * tanh(W1 * xw)))
    ggml_tensor * w =
        ggml_mul_mat(ctx0, layer.time_mix_decay_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_decay_w1, xw)));

    w = ggml_add(ctx0, w, layer.time_mix_decay);
    w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w)));
    w = ggml_reshape_3d(ctx0, w, head_size, n_head, n_tokens);

    if (is_qrwkv) {
        // k = k * (1 - w)
        k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
    }
    ggml_tensor * wkv_state = build_rs(inp, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs);

    ggml_tensor * wkv_output;
    if (is_qrwkv) {
        wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, wkv_state, pow(head_size, -0.5f));
    } else {
        wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer.time_mix_first, w, wkv_state);
    }

    // wkv_output is read as two parts: n_embd * n_tokens output elements first, then
    // n_embd * head_size * n_seqs elements of updated state starting right after them.
    cur       = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
    wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, tok_mat_bytes);

    // Write the updated state back into the recurrent-state tensor, starting at slot `kv_head`.
    // The element size comes first so the byte offset is computed in size_t (see tok_mat_bytes).
    ggml_build_forward_expand(
        gf, ggml_cpy(ctx0, wkv_state,
                     ggml_view_1d(ctx0, mctx_cur->get_s_l(il), hparams.n_embd_s() * n_seqs,
                                  ggml_element_size(mctx_cur->get_s_l(il)) * hparams.n_embd_s() * kv_head)));

    if (!is_qrwkv) {
        // group norm with head_count groups
        cur = ggml_reshape_3d(ctx0, cur, n_embd / n_head, n_head, n_tokens);
        cur = ggml_norm(ctx0, cur, 64e-5f);

        // Convert back to regular vectors.
        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
        cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b);
    } else {
        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
    }

    // Apply the gate, then the output projection.
    cur = ggml_mul(ctx0, cur, g);
    cur = build_lora_mm(layer.time_mix_output, cur);

    return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
}