/src/llama.cpp/src/models/rwkv7-base.cpp
Line | Count | Source |
1 | | #include "models.h" |
2 | | |
// Base class constructor for RWKV7-family graph builders: forwards the graph
// parameters to the generic llm_graph_context and keeps a reference to the
// model so derived builders can access per-layer weights.
llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context(params),
    model(model) {}
6 | | |
7 | | ggml_tensor * llm_build_rwkv7_base::build_rwkv7_channel_mix(const llama_layer * layer, |
8 | | ggml_tensor * cur, |
9 | | ggml_tensor * x_prev, |
10 | 0 | llm_arch arch) const { |
11 | 0 | ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); |
12 | 0 | switch (arch) { |
13 | 0 | case LLM_ARCH_RWKV7: |
14 | 0 | { |
15 | 0 | ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur); |
16 | |
|
17 | 0 | ggml_tensor * k = ggml_sqr(ctx0, ggml_relu(ctx0, build_lora_mm(layer->channel_mix_key, xk))); |
18 | |
|
19 | 0 | cur = build_lora_mm(layer->channel_mix_value, k); |
20 | 0 | } |
21 | 0 | break; |
22 | 0 | default: |
23 | 0 | GGML_ABORT("fatal error"); |
24 | 0 | } |
25 | 0 | return cur; |
26 | 0 | } |
27 | | |
// Build the RWKV7 time-mix (attention-replacement) sub-graph for layer `il`.
//
// Params:
//   inp               - recurrent-state graph input used to load the WKV state
//   cur               - current hidden states
//   x_prev            - token-shifted hidden states (previous token)
//   first_layer_value - in/out: the value projection of the first layer; set
//                       here on the first layer, then reused by later layers
//                       as a learned residual blend target
//   ubatch            - current micro-batch (token/sequence layout)
//   il                - layer index
//
// Returns the time-mix output reshaped to (n_embd, n_seq_tokens, n_seqs).
// Side effect: appends a copy node to `gf` that writes the updated WKV state
// back into the recurrent memory slot for this layer.
ggml_tensor * llm_build_rwkv7_base::build_rwkv7_time_mix(llm_graph_input_rs * inp,
                                                         ggml_tensor * cur,
                                                         ggml_tensor * x_prev,
                                                         ggml_tensor *& first_layer_value,
                                                         const llama_ubatch & ubatch,
                                                         int il) const {
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

    const auto n_tokens = ubatch.n_tokens;
    const auto n_seqs = ubatch.n_seqs;
    const auto n_embd = hparams.n_embd;
    const auto head_size = hparams.wkv_head_size;
    const auto head_count = n_embd / head_size;
    const auto n_seq_tokens = ubatch.n_seq_tokens;

    // Slot in the recurrent KV memory where this sequence's state lives.
    const auto kv_head = mctx_cur->get_head();

    const auto & layer = model.layers[il];

    // Output gating is optional: present only when both g1/g2 tensors exist.
    bool has_gating = layer.time_mix_g1 && layer.time_mix_g2;

    // Token-shift delta, broadcast to 5 (or 6, with gating) copies so a single
    // fused lerp tensor can produce all of r/w/k/v/a(/g) interpolations at once.
    ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
    ggml_tensor * dummy = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5);
    sx = ggml_repeat(ctx0, sx, dummy);

    ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_fused), cur);

    // Slice the fused result into per-projection inputs (one n_embd x n_tokens
    // view per copy along the last dimension).
    ggml_tensor * xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
    ggml_tensor * xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
    ggml_tensor * xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
    ggml_tensor * xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
    ggml_tensor * xa = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
    ggml_tensor * xg =
        has_gating ? ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 5 * sizeof(float)) :
                     nullptr;

    // Receptance projection.
    ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr);
    // Decay: low-rank (w1, w2) path with tanh, plus bias w0 ...
    ggml_tensor * w = ggml_add(
        ctx0, ggml_mul_mat(ctx0, layer.time_mix_w2, ggml_tanh(ctx0, ggml_mul_mat(ctx0, layer.time_mix_w1, xw))),
        layer.time_mix_w0);
    // ... then w = exp(-0.606531 * sigmoid(w)); 0.606531 ~= exp(-0.5), keeping
    // the per-channel decay in (exp(-e^-0.5), 1).
    w = ggml_exp(ctx0, ggml_scale(ctx0, ggml_sigmoid(ctx0, w), -0.606531));

    ggml_tensor * k = build_lora_mm(layer.time_mix_key, xk);
    ggml_tensor * v = build_lora_mm(layer.time_mix_value, xv);
    if (first_layer_value == nullptr) {
        // First layer: remember v so later layers can blend toward it.
        first_layer_value = v;
    } else {
        // Add the first layer value as a residual connection:
        // v += (v_first - v) * sigmoid(v2 @ (v1 @ xv) + v0)
        v = ggml_add(ctx0, v,
                     ggml_mul(ctx0, ggml_sub(ctx0, first_layer_value, v),
                              ggml_sigmoid(ctx0, ggml_add(ctx0,
                                                          ggml_mul_mat(ctx0, layer.time_mix_v2,
                                                                       ggml_mul_mat(ctx0, layer.time_mix_v1, xv)),
                                                          layer.time_mix_v0))));
    }
    // Optional output gate: low-rank sigmoid-gated projection.
    ggml_tensor * g = nullptr;
    if (layer.time_mix_g1 && layer.time_mix_g2) {
        g = ggml_mul_mat(ctx0, layer.time_mix_g2, ggml_sigmoid(ctx0, ggml_mul_mat(ctx0, layer.time_mix_g1, xg)));
    }
    // In-context learning rate `a` (sigmoid of a low-rank projection + bias).
    ggml_tensor * a = ggml_sigmoid(
        ctx0, ggml_add(ctx0, ggml_mul_mat(ctx0, layer.time_mix_a2, ggml_mul_mat(ctx0, layer.time_mix_a1, xa)),
                       layer.time_mix_a0));

    // kk: per-head L2-normalized "removal" key used by the wkv7 kernel.
    ggml_tensor * kk = ggml_reshape_3d(ctx0, ggml_mul(ctx0, k, layer.time_mix_k_k), head_size, head_count, n_tokens);
    kk = ggml_l2_norm(ctx0, kk, 1e-12);

    // k = k * (1 + (a - 1) * k_a), written as k + (a*ka - ka) with ka = k*k_a.
    ggml_tensor * ka = ggml_mul(ctx0, k, layer.time_mix_k_a);
    k = ggml_add(ctx0, k, ggml_sub(ctx0, ggml_mul(ctx0, a, ka), ka));

    // Split all operands into (head_size, head_count, n_tokens) heads.
    r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens);
    w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens);
    k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens);
    v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
    a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);

    // Load this layer's recurrent WKV state for the sequences in the batch.
    ggml_tensor * wkv_state = build_rs(inp, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs);

    // Fused RWKV7 kernel; its output packs the per-token result first, then
    // the updated state (extracted by the two views below).
    ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
    cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
    wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));

    // Persist the updated state back into the recurrent memory at kv_head.
    ggml_build_forward_expand(
        gf, ggml_cpy(ctx0, wkv_state,
                     ggml_view_1d(ctx0, mctx_cur->get_s_l(il), hparams.n_embd_s() * n_seqs,
                                  hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il)))));

    if (layer.time_mix_ln && layer.time_mix_ln_b) {
        // group norm with head_count groups (normalize within each head), with
        // learned scale/bias applied after flattening back to vectors.
        cur = ggml_reshape_3d(ctx0, cur, n_embd / head_count, head_count, n_tokens);
        cur = ggml_norm(ctx0, cur, 64e-5f);

        // Convert back to regular vectors.
        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
        cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b);
    } else {
        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
    }
    // Bonus term: cur += v * sum_over_head(k * r * r_k), a per-head scalar
    // weighting of v added to the normalized kernel output.
    ggml_tensor * rk = ggml_sum_rows(
        ctx0, ggml_mul(ctx0, ggml_mul(ctx0, k, r), ggml_reshape_2d(ctx0, layer.time_mix_r_k, head_size, head_count)));
    cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, ggml_mul(ctx0, v, rk), n_embd, n_tokens));

    if (has_gating) {
        cur = ggml_mul(ctx0, cur, g);
    }
    cur = build_lora_mm(layer.time_mix_output, cur);

    // Restore the (n_embd, n_seq_tokens, n_seqs) batch layout for the caller.
    return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
}