/src/llama.cpp/src/models/gemma4-iswa.cpp
#include "models.h"

// get a 2D slice view from a 3D tensor; idx selects along the 3rd dim
static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) {
    GGML_ASSERT(idx < (int) x->ne[2]);
    return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
                        idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
}
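
// example (hypothetical shapes; x assumed contiguous): for x->ne = {8, 4, 3},
// ggml_view_2d_slice(ctx0, x, 1) returns an 8x4 view whose data starts at byte
// offset 1*8*4*ggml_element_size(x) - i.e. the second [ne0 x ne1] plane of x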

llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params) :
        llm_graph_context(params),
        model(model),
        n_embd_per_layer(model.hparams.n_embd_per_layer) {
    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = build_inp_embd(model.tok_embd);

    // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
    inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
    cb(inpL, "inp_scaled", -1);
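
    // note: scaling token embeddings by sqrt(hidden_size) is the standard Gemma-family
    // convention; raw (e.g. image) embeddings skip it, presumably because the encoder
    // output is already at the right scale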

    // inp_pos - contains the positions
    ggml_tensor * inp_pos = build_inp_pos();

    // TODO: is causal == true correct? might need some changes
    auto * inp_attn = build_attn_inp_kv_iswa();

    ggml_tensor * inp_out_ids = build_inp_out_ids();

    ggml_tensor * inp_per_layer = nullptr;
    if (model.per_layer_tok_embd) {
        inp_per_layer = build_inp_per_layer();
        ggml_build_forward_expand(gf, inp_per_layer);

        // inp_per_layer shape: [n_embd_per_layer, n_tokens, n_layer]
        inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer);
    }

    for (int il = 0; il < n_layer; ++il) {
        const int64_t n_embd_head = hparams.n_embd_head_k(il);
        GGML_ASSERT(n_embd_head == hparams.n_embd_head_v(il));

        const int64_t n_head    = hparams.n_head(il);
        const int64_t n_head_kv = hparams.n_head_kv(il);

        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
        const int   n_rot_l      = hparams.n_rot(il);

        // norm
        cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        ggml_tensor * freq_factors = nullptr;
        if (!hparams.is_swa(il)) {
            // full_attention layers use rope_freqs for proportional rope
            freq_factors = model.layers[il].rope_freqs;
        }

        // Q projection (shared for both non-KV and KV layers)
        // this mirrors Gemma4Attention in the pytorch code
        ggml_tensor * Qcur;
        {
            Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
            cb(Qcur, "Qcur_normed", il);

            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig,
                    freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
            cb(Qcur, "Qcur_pos", il);
        }

        // self-attention
        if (hparams.has_kv(il)) {
            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
            cb(Kcur, "Kcur", il);

            ggml_tensor * Vcur = model.layers[il].wv
                ? build_lora_mm(model.layers[il].wv, cur)
                : Kcur; // if v_proj is not present, use Kcur as Vcur
            cb(Vcur, "Vcur", il);

            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
            Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);

            cb(Kcur, "Kcur_normed", il);
            cb(Vcur, "Vcur_normed", il);

            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig,
                    freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);

            cb(Kcur, "Kcur_pos", il);

            cur = build_attn(inp_attn, model.layers[il].wo,
                    nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr,
                    hparams.f_attention_scale, il);
        } else {
            // reuse the KV cache of earlier layers
            cur = build_attn(inp_attn,
                    model.layers[il].wo, nullptr,
                    Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
        }

        // TODO @ngxson : strip unused tokens right after the last KV layer to speed up prompt processing
        if (il == n_layer - 1 && inp_out_ids) {
            cur  = ggml_get_rows(ctx0, cur,  inp_out_ids);
            inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
        }

        cur = build_norm(cur,
                model.layers[il].attn_post_norm, nullptr,
                LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

        ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL);
        cb(attn_out, "attn_out", il);

        // feed-forward network
        const bool is_moe_layer = model.layers[il].ffn_gate_inp != nullptr;
        if (is_moe_layer) {
            // MLP (shared exp)
            ggml_tensor * cur_mlp = build_norm(attn_out,
                    model.layers[il].ffn_norm, nullptr,
                    LLM_NORM_RMS, il);
            cb(cur_mlp, "ffn_norm_1", il);

            cur_mlp = build_ffn(cur_mlp,
                    model.layers[il].ffn_up,   nullptr, nullptr,
                    model.layers[il].ffn_gate, nullptr, nullptr,
                    model.layers[il].ffn_down, nullptr, nullptr,
                    nullptr,
                    LLM_FFN_GELU, LLM_FFN_PAR, il);

            cur_mlp = build_norm(cur_mlp,
                    model.layers[il].ffn_post_norm_1, nullptr,
                    LLM_NORM_RMS, il);
            cb(cur_mlp, "ffn_mlp", il);

            // Expert FFN
            ggml_tensor * cur_moe = build_norm(attn_out,
                    model.layers[il].ffn_pre_norm_2, nullptr,
                    LLM_NORM_RMS, il);
            cb(cur_moe, "ffn_norm_2", il);

            // custom MoE logits calculation (the router operates on attn_out, not cur)
            ggml_tensor * tmp = ggml_rms_norm(ctx0, attn_out, hparams.f_norm_rms_eps);
            tmp = ggml_scale(ctx0, tmp, 1.0f / sqrtf((float) n_embd));
            tmp = ggml_mul  (ctx0, tmp, model.layers[il].ffn_gate_inp_s);
            ggml_tensor * logits = build_lora_mm(model.layers[il].ffn_gate_inp, tmp); // [n_expert, n_tokens]
            cb(logits, "ffn_moe_logits", il);
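
            // in effect: logits = ffn_gate_inp @ (ffn_gate_inp_s * rms_norm(attn_out) / sqrt(n_embd))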

            cur_moe = build_moe_ffn(cur_moe,
                    nullptr, // gate_inp
                    nullptr, // up_exps
                    nullptr, // gate_exps
                    model.layers[il].ffn_down_exps,
                    nullptr, // exp_probs_b (not used for gemma4)
                    n_expert, n_expert_used,
                    LLM_FFN_GELU, true,
                    1.0f,
                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                    il, logits,
                    model.layers[il].ffn_gate_up_exps,
                    nullptr, // up_exps_s
                    nullptr, // gate_exps_s
                    model.layers[il].ffn_down_exps_s);

            cur_moe = build_norm(cur_moe,
                    model.layers[il].ffn_post_norm_2, nullptr,
                    LLM_NORM_RMS, il);
            cb(cur_moe, "ffn_moe", il);

            cur = ggml_add(ctx0, cur_mlp, cur_moe);
            cb(cur, "ffn_moe_combined", il);
        } else {
            cur = build_norm(attn_out,
                    model.layers[il].ffn_norm, nullptr,
                    LLM_NORM_RMS, il);
            cb(cur, "ffn_norm", il);

            cur = build_ffn(cur,
                    model.layers[il].ffn_up,   nullptr, nullptr,
                    model.layers[il].ffn_gate, nullptr, nullptr,
                    model.layers[il].ffn_down, nullptr, nullptr,
                    nullptr,
                    LLM_FFN_GELU, LLM_FFN_PAR, il);
            cb(cur, "ffn_out", il);
        }

        cur = build_norm(cur,
                model.layers[il].ffn_post_norm, nullptr,
                LLM_NORM_RMS, -1);
        cb(cur, "ffn_post_norm", il);

        // residual connection
        cur = ggml_add(ctx0, cur, attn_out);

        // per-layer embedding
        if (inp_per_layer) {
            ggml_tensor * pe_in = cur;
            cb(cur, "pe_in", il);

            cur = build_lora_mm(model.layers[il].per_layer_inp_gate, cur); // [n_embd_per_layer, n_tokens]
            cur = ggml_gelu(ctx0, cur);

            ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens]

            // TODO @ngxson : improve this
            if (il == n_layer - 1 && inp_out_ids) {
                inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids);
            }

            cur = ggml_mul(ctx0, cur, inp_this_layer);
            cur = build_lora_mm(model.layers[il].per_layer_proj, cur); // [n_embd, n_tokens]
            cur = build_norm(cur, model.layers[il].per_layer_post_norm, nullptr, LLM_NORM_RMS, il);
            cb(cur, "per_layer_embd_out", il);

            // residual connection
            cur = ggml_add(ctx0, pe_in, cur);
        }

        // layer_scalar
        if (model.layers[il].out_scale) {
            cur = ggml_mul(ctx0, cur, model.layers[il].out_scale);
            cb(cur, "out_scaled", il);
        }

        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);

        // input for next layer
        inpL = cur;
    }

    cur = inpL;

    cur = build_norm(cur,
            model.output_norm, nullptr,
            LLM_NORM_RMS, -1);

    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    // lm_head
    cur = build_lora_mm(model.output, cur);

    if (hparams.f_final_logit_softcapping) {
        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
        cur = ggml_tanh (ctx0, cur);
        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
    }
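
    // i.e. softcap(x) = C * tanh(x / C) with C = f_final_logit_softcapping, smoothly
    // bounding each logit to the range (-C, C)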

    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}

// equivalent to get_per_layer_inputs() in python code
// output shape: [n_embd_per_layer, n_layer, n_tokens]
ggml_tensor * llm_build_gemma4_iswa::build_inp_per_layer() {
    auto inp = std::make_unique<llm_graph_input_embd>(n_embd);

    ggml_tensor * inp_per_layer;
    float tok_embd_scale = sqrtf((float) n_embd_per_layer);
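    // note: sqrt(n_embd_per_layer) mirrors the sqrt(n_embd) scaling of the main token
    // embeddings in the constructor, applied here at the per-layer embedding width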
    if (ubatch.token) {
        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
        ggml_set_input(inp->tokens);
        res->t_inp_tokens = inp->tokens;

        inp_per_layer = ggml_get_rows  (ctx0, model.per_layer_tok_embd, inp->tokens);
        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, n_tokens);
        inp_per_layer = ggml_scale     (ctx0, inp_per_layer, tok_embd_scale);
        cb(inp_per_layer, "inp_per_layer_selected", -1);

        res->add_input(std::move(inp));
    } else {
        // Multimodal embedding path: use padding token (ID=0) embedding
        // TODO: verify if this is the correct behavior in the transformers implementation
        const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_per_layer * n_layer

        // Extract and dequantize padding token embedding (row 0)
        ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0);
        inp_per_layer = ggml_cast (ctx0, padding, GGML_TYPE_F32);
        inp_per_layer = ggml_scale(ctx0, inp_per_layer, tok_embd_scale);

        // Reshape to [n_embd_per_layer, n_layer, 1]
        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, 1);
        cb(inp_per_layer, "inp_per_layer_multimodal", -1);
    }
    return inp_per_layer;
}

// equivalent to project_per_layer_inputs() in python code
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
// inp_batch     shape: [n_embd,           n_tokens]
// inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from build_inp_per_layer)
// output        shape: [n_embd_per_layer, n_tokens, n_layer]
ggml_tensor * llm_build_gemma4_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) {
    const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
    const float per_layer_input_scale      = 1.0f / sqrtf(2.0f);
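    // note: the 1.0f/sqrtf(2.0f) factor offsets the addition of the two streams below,
    // keeping the sum's magnitude comparable to a single stream (assuming roughly
    // equal-variance inputs)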

    // note: this matrix multiplication will be performed in the input layer (i.e. on the CPU)
    ggml_tensor * per_layer_proj;
    per_layer_proj = ggml_mul_mat   (ctx0, model.per_layer_model_proj, inp_batch);
    per_layer_proj = ggml_scale     (ctx0, per_layer_proj, per_layer_projection_scale);
    per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_per_layer, n_layer, n_tokens);
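
    // per_layer_model_proj maps n_embd -> n_embd_per_layer * n_layer per token; the
    // reshape above splits each projected row into one n_embd_per_layer vector per layer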

    per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, nullptr, LLM_NORM_RMS, -1);
    cb(per_layer_proj, "per_layer_proj", -1);

    inp_per_layer = ggml_add  (ctx0, per_layer_proj, inp_per_layer);
    inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
    cb(inp_per_layer, "inp_per_layer", -1);

    // permute to shape: [n_embd_per_layer, n_tokens, n_layer]
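    // (ggml_permute(.., 0, 2, 1, 3) swaps dims 1 and 2; ggml_cont then materializes
    // the permuted view as a contiguous tensor)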
    inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
    return inp_per_layer;
}