/src/llama.cpp/src/models/afmoe.cpp
Line | Count | Source |
1 | | #include "models.h" |
2 | | |
3 | 0 | llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { |
4 | 0 | const int64_t n_embd_head = hparams.n_embd_head_v; |
5 | 0 | GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); |
6 | |
|
7 | 0 | ggml_tensor * cur; |
8 | 0 | ggml_tensor * inpL; |
9 | |
|
10 | 0 | inpL = build_inp_embd(model.tok_embd); |
11 | | |
12 | | // MuP scaling: embeddings * sqrt(hidden_size) |
13 | | // mup_enabled = true, hidden_size = 1024, scale = 32.0 |
14 | 0 | inpL = ggml_scale(ctx0, inpL, sqrtf(float(n_embd))); |
15 | 0 | cb(inpL, "inp_embd_scaled", -1); |
16 | | |
17 | | // inp_pos - contains the positions |
18 | 0 | ggml_tensor * inp_pos = build_inp_pos(); |
19 | 0 | auto * inp_attn = build_attn_inp_kv_iswa(); |
20 | 0 | ggml_tensor * inp_out_ids = build_inp_out_ids(); |
21 | |
|
22 | 0 | const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); |
23 | |
|
24 | 0 | for (int il = 0; il < n_layer; ++il) { |
25 | 0 | ggml_tensor * inpSA = inpL; |
26 | | |
27 | | // dual attention normalization (pre) |
28 | 0 | cur = build_norm(inpL, |
29 | 0 | model.layers[il].attn_norm, NULL, |
30 | 0 | LLM_NORM_RMS, il); |
31 | 0 | cb(cur, "attn_norm", il); |
32 | | |
33 | | // self-attention |
34 | 0 | { |
35 | 0 | ggml_tensor * attn_inp = cur; // save input for gate computation |
36 | |
|
37 | 0 | ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); |
38 | 0 | cb(Qcur, "Qcur", il); |
39 | |
|
40 | 0 | ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); |
41 | 0 | cb(Kcur, "Kcur", il); |
42 | |
|
43 | 0 | ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); |
44 | 0 | cb(Vcur, "Vcur", il); |
45 | | |
46 | | // compute gate from input |
47 | 0 | ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp); |
48 | 0 | cb(gate, "attn_gate_proj", il); |
49 | |
|
50 | 0 | Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); |
51 | 0 | Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); |
52 | | |
53 | | // Q/K normalization |
54 | 0 | Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); |
55 | 0 | Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); |
56 | 0 | cb(Qcur, "Qcur_normed", il); |
57 | 0 | cb(Kcur, "Kcur_normed", il); |
58 | | |
59 | | // RoPE only for sliding_attention layers |
60 | 0 | const bool use_rope = hparams.n_no_rope_layer_step > 0 && |
61 | 0 | ((il + 1) % hparams.n_no_rope_layer_step) != 0; |
62 | 0 | if (use_rope) { |
63 | 0 | Qcur = ggml_rope_ext( |
64 | 0 | ctx0, Qcur, inp_pos, nullptr, |
65 | 0 | n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, |
66 | 0 | ext_factor, attn_factor, beta_fast, beta_slow); |
67 | 0 | cb(Qcur, "Qcur_rope", il); |
68 | |
|
69 | 0 | Kcur = ggml_rope_ext( |
70 | 0 | ctx0, Kcur, inp_pos, nullptr, |
71 | 0 | n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, |
72 | 0 | ext_factor, attn_factor, beta_fast, beta_slow); |
73 | 0 | cb(Kcur, "Kcur_rope", il); |
74 | 0 | } |
75 | |
|
76 | 0 | Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); |
77 | |
|
78 | 0 | cur = build_attn(inp_attn, |
79 | 0 | NULL, NULL, // wo will be applied after gating |
80 | 0 | Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); |
81 | 0 | cb(cur, "attn_out", il); |
82 | | |
83 | | // attention gating: attn_out * sigmoid(gate) BEFORE o_proj |
84 | 0 | gate = ggml_sigmoid(ctx0, gate); |
85 | 0 | cb(gate, "attn_gate_sig", il); |
86 | 0 | cur = ggml_mul(ctx0, cur, gate); |
87 | 0 | cb(cur, "attn_gated", il); |
88 | | |
89 | | // now apply output projection |
90 | 0 | cur = build_lora_mm(model.layers[il].wo, cur); |
91 | 0 | cb(cur, "attn_o_proj", il); |
92 | 0 | } |
93 | | |
94 | | // dual attention normalization (post) |
95 | 0 | cur = build_norm(cur, |
96 | 0 | model.layers[il].attn_post_norm, NULL, |
97 | 0 | LLM_NORM_RMS, il); |
98 | 0 | cb(cur, "attn_post_norm", il); |
99 | |
|
100 | 0 | if (il == n_layer - 1 && inp_out_ids) { |
101 | 0 | cur = ggml_get_rows(ctx0, cur, inp_out_ids); |
102 | 0 | inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); |
103 | 0 | } |
104 | |
|
105 | 0 | ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); |
106 | 0 | cb(ffn_inp, "ffn_inp", il); |
107 | | |
108 | | // dual ffn normalization (pre) |
109 | 0 | cur = build_norm(ffn_inp, |
110 | 0 | model.layers[il].ffn_norm, NULL, |
111 | 0 | LLM_NORM_RMS, il); |
112 | 0 | cb(cur, "ffn_norm", il); |
113 | | |
114 | | // MoE or dense FFN |
115 | 0 | if ((uint32_t)il >= hparams.n_layer_dense_lead) { |
116 | | // MoE layer with sigmoid routing, normalization, and scaling |
117 | 0 | ggml_tensor * moe_out = build_moe_ffn(cur, |
118 | 0 | model.layers[il].ffn_gate_inp, |
119 | 0 | model.layers[il].ffn_up_exps, |
120 | 0 | model.layers[il].ffn_gate_exps, |
121 | 0 | model.layers[il].ffn_down_exps, |
122 | 0 | model.layers[il].ffn_exp_probs_b, |
123 | 0 | n_expert, n_expert_used, |
124 | 0 | LLM_FFN_SILU, |
125 | 0 | hparams.expert_weights_norm, // norm_w (route_norm=True) |
126 | 0 | hparams.expert_weights_scale, // scale_w |
127 | 0 | hparams.expert_weights_scale, // w_scale (route_scale=2.826) |
128 | 0 | (llama_expert_gating_func_type) hparams.expert_gating_func, |
129 | 0 | il); |
130 | 0 | cb(moe_out, "ffn_moe_out", il); |
131 | | |
132 | | // shared expert |
133 | 0 | if (hparams.n_expert_shared > 0) { |
134 | 0 | ggml_tensor * ffn_shexp = build_ffn(cur, |
135 | 0 | model.layers[il].ffn_up_shexp, NULL, NULL, |
136 | 0 | model.layers[il].ffn_gate_shexp, NULL, NULL, |
137 | 0 | model.layers[il].ffn_down_shexp, NULL, NULL, |
138 | 0 | NULL, |
139 | 0 | LLM_FFN_SILU, LLM_FFN_PAR, il); |
140 | 0 | cb(ffn_shexp, "ffn_shexp", il); |
141 | |
|
142 | 0 | cur = ggml_add(ctx0, moe_out, ffn_shexp); |
143 | 0 | cb(cur, "ffn_out", il); |
144 | 0 | } else { |
145 | 0 | cur = moe_out; |
146 | 0 | } |
147 | 0 | } else { |
148 | | // dense layer |
149 | 0 | cur = build_ffn(cur, |
150 | 0 | model.layers[il].ffn_up, NULL, NULL, |
151 | 0 | model.layers[il].ffn_gate, NULL, NULL, |
152 | 0 | model.layers[il].ffn_down, NULL, NULL, |
153 | 0 | NULL, |
154 | 0 | LLM_FFN_SILU, LLM_FFN_PAR, il); |
155 | 0 | cb(cur, "ffn_out", il); |
156 | 0 | } |
157 | | |
158 | | // dual ffn normalization (post) |
159 | 0 | cur = build_norm(cur, |
160 | 0 | model.layers[il].ffn_post_norm, NULL, |
161 | 0 | LLM_NORM_RMS, il); |
162 | 0 | cb(cur, "ffn_post_norm", il); |
163 | |
|
164 | 0 | cur = ggml_add(ctx0, cur, ffn_inp); |
165 | 0 | cur = build_cvec(cur, il); |
166 | 0 | cb(cur, "l_out", il); |
167 | | |
168 | | // input for next layer |
169 | 0 | inpL = cur; |
170 | 0 | } |
171 | |
|
172 | 0 | cur = inpL; |
173 | |
|
174 | 0 | cur = build_norm(cur, |
175 | 0 | model.output_norm, NULL, |
176 | 0 | LLM_NORM_RMS, -1); |
177 | 0 | cb(cur, "result_norm", -1); |
178 | |
|
179 | 0 | res->t_embd = cur; |
180 | | |
181 | | // lm_head |
182 | 0 | cur = build_lora_mm(model.output, cur); |
183 | 0 | cb(cur, "result_output", -1); |
184 | 0 | res->t_logits = cur; |
185 | |
|
186 | 0 | ggml_build_forward_expand(gf, cur); |
187 | 0 | } |