/src/llama.cpp/src/models/rwkv6-base.cpp
Line | Count | Source |
1 | | #include "models.h" |
2 | | |
3 | | #include "llama-memory-recurrent.h" |
4 | | |
5 | | llm_build_rwkv6_base::llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : |
6 | 0 | llm_graph_context(params), |
7 | 0 | model(model) {} |
8 | | |
9 | | ggml_tensor * llm_build_rwkv6_base::build_rwkv6_channel_mix(const llama_layer * layer, |
10 | | ggml_tensor * cur, |
11 | | ggml_tensor * x_prev, |
12 | 0 | llm_arch arch) const { |
13 | 0 | ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); |
14 | 0 | switch (arch) { |
15 | 0 | case LLM_ARCH_RWKV6: |
16 | 0 | { |
17 | 0 | ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur); |
18 | 0 | ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur); |
19 | |
|
20 | 0 | ggml_tensor * r = ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr)); |
21 | 0 | ggml_tensor * k = ggml_sqr(ctx0, ggml_relu(ctx0, build_lora_mm(layer->channel_mix_key, xk))); |
22 | 0 | cur = ggml_mul(ctx0, r, build_lora_mm(layer->channel_mix_value, k)); |
23 | 0 | } |
24 | 0 | break; |
25 | 0 | default: |
26 | 0 | GGML_ABORT("fatal error"); |
27 | 0 | } |
28 | 0 | return cur; |
29 | 0 | } |
30 | | |
31 | | ggml_tensor * llm_build_rwkv6_base::build_rwkv6_time_mix(llm_graph_input_rs * inp, |
32 | | ggml_tensor * cur, |
33 | | ggml_tensor * x_prev, |
34 | | const llama_ubatch & ubatch, |
35 | 0 | int il) const { |
36 | 0 | const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx); |
37 | |
|
38 | 0 | const auto n_tokens = ubatch.n_tokens; |
39 | 0 | const auto n_seqs = ubatch.n_seqs; |
40 | 0 | const auto n_seq_tokens = ubatch.n_seq_tokens; |
41 | 0 | const auto n_embd = hparams.n_embd; |
42 | 0 | const auto head_size = hparams.wkv_head_size; |
43 | 0 | const auto n_head = n_embd / head_size; |
44 | 0 | const auto n_head_kv = hparams.n_head_kv(il); |
45 | |
|
46 | 0 | const auto kv_head = mctx_cur->get_head(); |
47 | |
|
48 | 0 | const auto & layer = model.layers[il]; |
49 | |
|
50 | 0 | bool is_qrwkv = layer.time_mix_first == nullptr; |
51 | |
|
52 | 0 | ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); |
53 | |
|
54 | 0 | sx = ggml_reshape_2d(ctx0, sx, n_embd, n_tokens); |
55 | 0 | cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); |
56 | |
|
57 | 0 | ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_x), cur); |
58 | |
|
59 | 0 | xxx = ggml_reshape_4d(ctx0, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xxx)), |
60 | 0 | layer.time_mix_w1->ne[1] / 5, 1, 5, n_tokens); |
61 | |
|
62 | 0 | xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2)); |
63 | |
|
64 | 0 | xxx = ggml_mul_mat( |
65 | 0 | ctx0, ggml_reshape_4d(ctx0, layer.time_mix_w2, layer.time_mix_w2->ne[0], layer.time_mix_w2->ne[1], 1, 5), xxx); |
66 | |
|
67 | 0 | ggml_tensor *xw, *xk, *xv, *xr, *xg; |
68 | 0 | if (layer.time_mix_lerp_fused) { |
69 | | // fusing these weights makes some performance improvement |
70 | 0 | sx = ggml_reshape_3d(ctx0, sx, n_embd, 1, n_tokens); |
71 | 0 | cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); |
72 | 0 | xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer.time_mix_lerp_fused), sx), cur); |
73 | 0 | xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0); |
74 | 0 | xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float)); |
75 | 0 | xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float)); |
76 | 0 | xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float)); |
77 | 0 | xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float)); |
78 | 0 | } else { |
79 | | // for backward compatibility |
80 | 0 | xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0); |
81 | 0 | xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float)); |
82 | 0 | xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float)); |
83 | 0 | xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float)); |
84 | 0 | xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float)); |
85 | |
|
86 | 0 | xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer.time_mix_lerp_w), sx), cur); |
87 | 0 | xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer.time_mix_lerp_k), sx), cur); |
88 | 0 | xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer.time_mix_lerp_v), sx), cur); |
89 | 0 | xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer.time_mix_lerp_r), sx), cur); |
90 | 0 | xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer.time_mix_lerp_g), sx), cur); |
91 | 0 | } |
92 | 0 | ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr); |
93 | 0 | ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk); |
94 | 0 | ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv); |
95 | 0 | if (layer.time_mix_receptance_b) { |
96 | 0 | r = ggml_add(ctx0, r, layer.time_mix_receptance_b); |
97 | 0 | } |
98 | 0 | if (layer.time_mix_key_b) { |
99 | 0 | k = ggml_add(ctx0, k, layer.time_mix_key_b); |
100 | 0 | } |
101 | 0 | if (layer.time_mix_value_b) { |
102 | 0 | v = ggml_add(ctx0, v, layer.time_mix_value_b); |
103 | 0 | } |
104 | 0 | ggml_tensor * g = build_lora_mm(layer.time_mix_gate, xg); |
105 | 0 | if (is_qrwkv) { |
106 | 0 | g = ggml_sigmoid(ctx0, g); |
107 | 0 | } else { |
108 | 0 | g = ggml_silu(ctx0, g); |
109 | 0 | } |
110 | 0 | if (n_head_kv != 0 && n_head_kv != n_head) { |
111 | 0 | GGML_ASSERT(n_head % n_head_kv == 0); |
112 | 0 | k = ggml_reshape_4d(ctx0, k, head_size, 1, n_head_kv, n_tokens); |
113 | 0 | v = ggml_reshape_4d(ctx0, v, head_size, 1, n_head_kv, n_tokens); |
114 | 0 | ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, n_head / n_head_kv, n_head_kv, n_tokens); |
115 | 0 | k = ggml_repeat(ctx0, k, tmp); |
116 | 0 | v = ggml_repeat(ctx0, v, tmp); |
117 | 0 | } |
118 | 0 | k = ggml_reshape_3d(ctx0, k, head_size, n_head, n_tokens); |
119 | 0 | v = ggml_reshape_3d(ctx0, v, head_size, n_head, n_tokens); |
120 | 0 | r = ggml_reshape_3d(ctx0, r, head_size, n_head, n_tokens); |
121 | |
|
122 | 0 | ggml_tensor * w = |
123 | 0 | ggml_mul_mat(ctx0, layer.time_mix_decay_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_decay_w1, xw))); |
124 | |
|
125 | 0 | w = ggml_add(ctx0, w, layer.time_mix_decay); |
126 | 0 | w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w))); |
127 | 0 | w = ggml_reshape_3d(ctx0, w, head_size, n_head, n_tokens); |
128 | |
|
129 | 0 | if (is_qrwkv) { |
130 | | // k = k * (1 - w) |
131 | 0 | k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w)); |
132 | 0 | } |
133 | 0 | ggml_tensor * wkv_state = build_rs(inp, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs); |
134 | |
|
135 | 0 | ggml_tensor * wkv_output; |
136 | 0 | if (is_qrwkv) { |
137 | 0 | wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, wkv_state, pow(head_size, -0.5f)); |
138 | 0 | } else { |
139 | 0 | wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer.time_mix_first, w, wkv_state); |
140 | 0 | } |
141 | 0 | cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0); |
142 | 0 | wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float)); |
143 | |
|
144 | 0 | ggml_build_forward_expand( |
145 | 0 | gf, ggml_cpy(ctx0, wkv_state, |
146 | 0 | ggml_view_1d(ctx0, mctx_cur->get_s_l(il), hparams.n_embd_s() * n_seqs, |
147 | 0 | hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))))); |
148 | |
|
149 | 0 | if (!is_qrwkv) { |
150 | | // group norm with head_count groups |
151 | 0 | cur = ggml_reshape_3d(ctx0, cur, n_embd / n_head, n_head, n_tokens); |
152 | 0 | cur = ggml_norm(ctx0, cur, 64e-5f); |
153 | | |
154 | | // Convert back to regular vectors. |
155 | 0 | cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); |
156 | 0 | cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b); |
157 | 0 | } else { |
158 | 0 | cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); |
159 | 0 | } |
160 | 0 | cur = ggml_mul(ctx0, cur, g); |
161 | 0 | cur = build_lora_mm(layer.time_mix_output, cur); |
162 | |
|
163 | 0 | return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs); |
164 | 0 | } |