/src/llama.cpp/src/models/rwkv6-base.cpp
#include "models.h"

llm_build_rwkv6_base::llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context(params),
    model(model) {}

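// Channel mixing (the RWKV feed-forward analogue). Reading the ops below:
// lerp the input towards the previous token's embedding, then
//   out = sigmoid(R @ xr) * (V @ relu(K @ xk)^2)
// i.e. a squared-ReLU FFN gated by a per-channel receptance.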
ggml_tensor * llm_build_rwkv6_base::build_rwkv6_channel_mix(const llama_layer * layer,
                                                            ggml_tensor *       cur,
                                                            ggml_tensor *       x_prev,
                                                            llm_arch            arch) const {
    // token shift: difference between the previous token's embedding and the current one
    ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
    switch (arch) {
        case LLM_ARCH_RWKV6:
            {
                ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
                ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur);

                ggml_tensor * r = ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr));
                ggml_tensor * k = ggml_sqr(ctx0, ggml_relu(ctx0, build_lora_mm(layer->channel_mix_key, xk)));
                cur = ggml_mul(ctx0, r, build_lora_mm(layer->channel_mix_value, k));
            }
            break;
        default:
            GGML_ABORT("fatal error");
    }
    return cur;
}

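// Time mixing (the RWKV attention analogue). The steps below: token-shifted,
// data-dependent lerps build per-channel inputs for w/k/v/r/g, which feed the
// linear-recurrent WKV6 kernel (or gated linear attention on QRWKV models).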
ggml_tensor * llm_build_rwkv6_base::build_rwkv6_time_mix(llm_graph_input_rs * inp,
                                                         ggml_tensor *        cur,
                                                         ggml_tensor *        x_prev,
                                                         const llama_ubatch & ubatch,
                                                         int                  il) const {
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

    const auto n_tokens     = ubatch.n_tokens;
    const auto n_seqs       = ubatch.n_seqs;
    const auto n_seq_tokens = ubatch.n_seq_tokens;
    const auto n_embd       = hparams.n_embd;
    const auto head_size    = hparams.wkv_head_size;
    const auto n_head       = n_embd / head_size;
    const auto n_head_kv    = hparams.n_head_kv(il);

    const auto kv_head = mctx_cur->get_head();

    const auto & layer = model.layers[il];

    // QRWKV variants ship no per-head bonus ("first") tensor
    bool is_qrwkv = layer.time_mix_first == nullptr;

    ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);

    sx  = ggml_reshape_2d(ctx0, sx, n_embd, n_tokens);
    cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);

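    // Data-dependent lerp coefficients: a low-rank bottleneck through
    // time_mix_w1/time_mix_w2 (with a tanh in between) yields 5 deltas per
    // token, one for each of the w/k/v/r/g token-shift lerps that follow.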
    ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_x), cur);

    xxx = ggml_reshape_4d(ctx0, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xxx)),
                          layer.time_mix_w1->ne[1] / 5, 1, 5, n_tokens);

    xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2));

    xxx = ggml_mul_mat(
        ctx0, ggml_reshape_4d(ctx0, layer.time_mix_w2, layer.time_mix_w2->ne[0], layer.time_mix_w2->ne[1], 1, 5), xxx);

    ggml_tensor *xw, *xk, *xv, *xr, *xg;
    if (layer.time_mix_lerp_fused) {
        // fusing the five lerp coefficients into one tensor gives a small performance improvement
        sx  = ggml_reshape_3d(ctx0, sx, n_embd, 1, n_tokens);
        cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
        xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer.time_mix_lerp_fused), sx), cur);
        xw  = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
        xk  = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
        xv  = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
        xr  = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
        xg  = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
    } else {
        // unfused path, kept for backward compatibility
        xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
        xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
        xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
        xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
        xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));

        xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer.time_mix_lerp_w), sx), cur);
        xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer.time_mix_lerp_k), sx), cur);
        xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer.time_mix_lerp_v), sx), cur);
        xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer.time_mix_lerp_r), sx), cur);
        xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer.time_mix_lerp_g), sx), cur);
    }
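    // Project the mixed inputs to receptance, key and value. The optional
    // biases only exist on some checkpoints (presumably models converted from
    // attention that keep their projection biases), hence the null checks.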
    ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr);
    ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk);
    ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv);
    if (layer.time_mix_receptance_b) {
        r = ggml_add(ctx0, r, layer.time_mix_receptance_b);
    }
    if (layer.time_mix_key_b) {
        k = ggml_add(ctx0, k, layer.time_mix_key_b);
    }
    if (layer.time_mix_value_b) {
        v = ggml_add(ctx0, v, layer.time_mix_value_b);
    }
    ggml_tensor * g = build_lora_mm(layer.time_mix_gate, xg);
    if (is_qrwkv) {
        g = ggml_sigmoid(ctx0, g);
    } else {
        g = ggml_silu(ctx0, g);
    }
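    // GQA-style broadcast: when the checkpoint has fewer k/v heads than output
    // heads, repeat each k/v head across its group before the WKV kernel.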
    if (n_head_kv != 0 && n_head_kv != n_head) {
        GGML_ASSERT(n_head % n_head_kv == 0);
        k = ggml_reshape_4d(ctx0, k, head_size, 1, n_head_kv, n_tokens);
        v = ggml_reshape_4d(ctx0, v, head_size, 1, n_head_kv, n_tokens);
        ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, n_head / n_head_kv, n_head_kv, n_tokens);
        k = ggml_repeat(ctx0, k, tmp);
        v = ggml_repeat(ctx0, v, tmp);
    }
    k = ggml_reshape_3d(ctx0, k, head_size, n_head, n_tokens);
    v = ggml_reshape_3d(ctx0, v, head_size, n_head, n_tokens);
    r = ggml_reshape_3d(ctx0, r, head_size, n_head, n_tokens);

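    // Data-dependent decay: a small tanh bottleneck over xw plus the learned
    // base decay, mapped through w = exp(-exp(w)) so each factor lies strictly
    // in (0, 1).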
    ggml_tensor * w =
        ggml_mul_mat(ctx0, layer.time_mix_decay_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_decay_w1, xw)));

    w = ggml_add(ctx0, w, layer.time_mix_decay);
    w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w)));
    w = ggml_reshape_3d(ctx0, w, head_size, n_head, n_tokens);

    if (is_qrwkv) {
        // k = k * (1 - w)
        k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
    }
    ggml_tensor * wkv_state = build_rs(inp, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs);

    ggml_tensor * wkv_output;
    if (is_qrwkv) {
        wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, wkv_state, pow(head_size, -0.5f));
    } else {
        wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer.time_mix_first, w, wkv_state);
    }
    cur       = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
    wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));

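    // The kernel packs the per-token outputs and the updated recurrent state
    // into one flat tensor; the two views above split them, and the copy below
    // writes the state back into this layer's slot in the recurrent memory.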
    ggml_build_forward_expand(
        gf, ggml_cpy(ctx0, wkv_state,
                     ggml_view_1d(ctx0, mctx_cur->get_s_l(il), hparams.n_embd_s() * n_seqs,
                                  hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il)))));

    if (!is_qrwkv) {
        // group norm with head_count groups
        cur = ggml_reshape_3d(ctx0, cur, n_embd / n_head, n_head, n_tokens);
        cur = ggml_norm(ctx0, cur, 64e-5f);

        // convert back to regular vectors
        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
        cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b);
    } else {
        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
    }
    cur = ggml_mul(ctx0, cur, g);
    cur = build_lora_mm(layer.time_mix_output, cur);

    return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
}