/src/llama.cpp/src/models/gemma3n-iswa.cpp
#include "models.h"

llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context(params),
    model(model),
    n_embd_head(model.hparams.n_embd_head_k),
    n_embd_altup(model.hparams.n_embd_altup),
    n_altup(model.hparams.n_altup),
    i_altup_act(model.hparams.i_altup_act) {
    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = build_inp_embd(model.tok_embd);

    // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
    if (ubatch.token) {
        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
        cb(inpL, "inp_scaled", -1);
    }

    // inp_pos - contains the positions
    ggml_tensor * inp_pos = build_inp_pos();

    // TODO: is causal == true correct? might need some changes
    auto * inp_attn = build_attn_inp_kv_iswa();

    // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
    ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());

    // inpL currently holds only 1 altup stream; project it to the remaining altups
    // these "added" altups will be concatenated to the last dim of inpL
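    // the projected streams are rescaled so that their per-token L2 norm matches that of the
    // original embedding (target_magnitude), keeping all altup streams at a comparable scale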
    {
        ggml_tensor * target_magnitude = calc_magnitude(inpL);
        ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
        ggml_tensor * altup_added =
            ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1]
        ggml_tensor * new_magnitude = calc_magnitude(altup_added);
        altup_added = ggml_div(ctx0, ggml_mul(ctx0, altup_added, target_magnitude), new_magnitude);
        inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
        cb(inpL, "inp_stacked", -1);
    }

    // inpL now has shape:          [n_embd, n_tokens, n_altup]
    // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]

    for (int il = 0; il < n_layer; ++il) {
        // this block is made to closely resemble Gemma3p5DecoderLayer in the python code
        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
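        // note: the RoPE base/scale are looked up per layer, since sliding-window and
        // full-attention layers may use different RoPE settings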

        ggml_tensor * cur = inpL;                            // [n_embd, n_tokens, n_altup]
        ggml_tensor * predictions = altup_predict(cur, il);  // [n_embd, n_tokens, n_altup]
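        // only the altup stream at index i_altup_act goes through attention and the FFN below;
        // the remaining streams are carried along and updated afterwards via altup_correct()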

        // predicted value will go through self-attention and laurel
        ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
        cur = active_prediction;
        cb(cur, "active_prediction", il);

        // norm
        cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // laurel
        ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]

        // self-attention
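        // layers without their own KV (hparams.has_kv(il) == false) compute only Q and reuse
        // the KV cache written by an earlier layer inside build_attn()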
        if (hparams.has_kv(il)) {
            // compute Q and K and RoPE them
            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);

            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
            cb(Kcur, "Kcur", il);

            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
            cb(Vcur, "Vcur", il);

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
            Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);

            cb(Qcur, "Qcur_normed", il);
            cb(Kcur, "Kcur_normed", il);
            cb(Vcur, "Vcur_normed", il);

            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                 ext_factor, attn_factor, beta_fast, beta_slow);

            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                 ext_factor, attn_factor, beta_fast, beta_slow);

            cb(Qcur, "Qcur_pos", il);
            cb(Kcur, "Kcur_pos", il);

            cur = build_attn(inp_attn, model.layers[il].wo,
                             NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr,
                             hparams.f_attention_scale, il);
        } else {
            // reuse KV cache of earlier layers
            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);
            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
            cb(Qcur, "Qcur_normed", il);

            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                 ext_factor, attn_factor, beta_fast, beta_slow);
            cb(Qcur, "Qcur_pos", il);

            cur = build_attn(inp_attn,
                             model.layers[il].wo, NULL,
                             Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
        }
        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

        cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
        cb(cur, "attn_gated", il);

        ggml_tensor * attn_laurel = ggml_scale(ctx0, ggml_add(ctx0, cur, laurel_out),
                                               1.0f / sqrtf(2.0f)); // [n_embd, n_tokens]
        cb(attn_laurel, "attn_laurel", il);
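        // the attention branch and the laurel branch are summed and rescaled by 1/sqrt(2),
        // presumably to keep the variance of the combined activations comparable to a single branch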

        cur = build_norm(attn_laurel, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

        // feed-forward network
        {
            ggml_tensor * up_proj   = build_lora_mm(model.layers[il].ffn_up,   cur);
            ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);

            if (il < n_layer_sparsity) {
                // apply activation sparsity
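                // gaussian_topk() (defined below) keeps only the part of each activation that
                // exceeds a statistical cutoff of mean + f_sparsity_std_mul * std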
                gate_proj = gaussian_topk(gate_proj);
            }
            gate_proj = ggml_gelu(ctx0, gate_proj);

            cur = ggml_mul(ctx0, up_proj, gate_proj);
            cur = build_lora_mm(model.layers[il].ffn_down, cur);
            cb(cur, "ffn_out", il);
        }

        cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, -1);
        cb(cur, "ffn_post_norm", il);

        ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens]
        cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);

        ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup]

        ggml_tensor * first_prediction; // [n_embd, n_tokens]
        {
            first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
            first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
            first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
            first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
            cb(first_prediction, "first_prediction_gated", il);
            ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
            first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
            cb(first_prediction, "first_prediction_scaled", il);

            first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens]
            first_prediction =
                build_norm(first_prediction, model.layers[il].per_layer_post_norm, NULL, LLM_NORM_RMS, il);
            cb(first_prediction, "first_prediction_out", il);
        }

        // equivalent to python code: corrected_predictions[1:] += first_prediction
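        // expressed with views and a concat: keep slice 0 as-is, add first_prediction (broadcast)
        // to a view of slices 1..n_altup-1, then concat the two parts back along the 3rd dim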
        {
            ggml_tensor * slice_first = view_2d_slice(corrected, 0);
            ggml_tensor * slice_rest  = ggml_view_3d(ctx0, corrected, n_embd, n_tokens, n_altup - 1,
                                                     ggml_row_size(corrected->type, n_embd),
                                                     ggml_row_size(corrected->type, n_embd * n_tokens),
                                                     n_embd * n_tokens * ggml_element_size(corrected));
            ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1]
            corrected = ggml_concat(ctx0, slice_first, tmp, 2);               // [n_embd, n_tokens, n_altup]
        }

        cur = corrected; // [n_embd, n_tokens, n_altup]
        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);

        // input for next layer
        inpL = cur;
    }

    cur = inpL; // [n_embd, n_tokens, n_altup]

    // cur now has multiple altup(s), we want to merge them back to 1 altup
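    // the non-active streams are projected through altup_unembd_proj, rescaled to the magnitude
    // of the active stream, and then all n_altup streams are averaged into a single one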
    {
        ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
        // do a view to skip the first slice (active altup)
        ggml_tensor * alt_slice =
            ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd),
                         ggml_row_size(cur->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(cur));
        ggml_tensor * altup_unembd =
            ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
        ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
        altup_unembd = ggml_div(ctx0, ggml_mul(ctx0, altup_unembd, target_magnitude), new_magnitude);
        cb(altup_unembd, "altup_unembd", -1);

        // equivalent to torch.mean(hidden_states, dim=0)
        cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
        for (int i = 0; i < n_altup - 1; ++i) {
            cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
        }
        cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
        cb(cur, "unembd_merged", -1);
    }

    // cur now has shape: [n_embd, n_tokens]

    // TODO: move this to right after the last KV layer
    {
        // skip computing output for unused tokens
        ggml_tensor * inp_out_ids = build_inp_out_ids();
        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
    }

    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);

    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    cur = build_lora_mm(model.output, cur);

    {
        // final logit soft-capping
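        // i.e. logits = softcap * tanh(logits / softcap), with softcap = hparams.f_final_logit_softcapping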
        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
        cur = ggml_tanh(ctx0, cur);
        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
    }

    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}

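// per-token L2 norm along the first dim: sqrt(sum(x^2))
// input x shape: [n_embd, n_tokens, ...]
// output shape:  [1, n_tokens, ...]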
ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) {
    return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
}

// get a 2D slice view from a 3D tensor; idx selects along the 3rd dim
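// note: the byte offset idx * ne[0] * ne[1] * element_size assumes x is contiguous in memory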
ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
    GGML_ASSERT(idx < (int) x->ne[2]);
    return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
                        idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
}

// equivalent to get_per_layer_inputs() in python code
// output shape: [n_embd_altup, n_layer, n_tokens]
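// note: n_layer is the middle dim here; project_per_layer_inputs() below permutes it to the last dim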
ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
    auto inp = std::make_unique<llm_graph_input_embd>();
    ggml_tensor * inp_per_layer;
    if (ubatch.token) {
        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
        ggml_set_input(inp->tokens);
        res->t_tokens = inp->tokens;
        inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
        inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
        cb(inp_per_layer, "inp_per_layer_selected", -1);
    } else {
        GGML_ABORT("TODO: support embd input");
    }
    res->add_input(std::move(inp));
    return inp_per_layer;
}

// equivalent to project_per_layer_inputs() in python code
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
// output shape: [n_embd_altup, n_tokens, n_layer]
ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
    const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
    const float per_layer_input_scale      = 1.0f / sqrtf(2.0f);

    ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
    per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
    per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
    per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, -1); // [n_embd_altup, n_layer, n_tokens]
    cb(per_layer_proj, "per_layer_proj", -1);

    inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj);
    inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
    cb(inp_per_layer, "inp_per_layer", -1);

    // permute to shape: [n_embd_altup, n_tokens, n_layer]
    inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
    return inp_per_layer;
}

// input cur shape: [n_embd, n_tokens]
// output shape:    [n_embd, n_tokens]
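// laurel: project cur through laurel_l then laurel_r, RMS-normalize, and add the result back
// to cur as a learned residual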
ggml_tensor * llm_build_gemma3n_iswa::laurel(ggml_tensor * cur, int il) {
    ggml_tensor * tmp = cur;
    tmp = build_lora_mm(model.layers[il].laurel_l, tmp);
    tmp = build_lora_mm(model.layers[il].laurel_r, tmp);
    tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
    tmp = ggml_add(ctx0, tmp, cur);
    cb(tmp, "laurel_out", il);
    return tmp;
}

// input x shape: [n_embd, n_tokens]
// output shape:  [n_embd, n_tokens]
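// the cutoff is mean + f_sparsity_std_mul * std per token, where std is the sample standard
// deviation along the first dim (note the 1/(n-1) correction); relu(x - cutoff) keeps only the
// part of each activation that exceeds the cutoff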
ggml_tensor * llm_build_gemma3n_iswa::gaussian_topk(ggml_tensor * x) {
    ggml_tensor * mean = ggml_mean(ctx0, x);
    ggml_tensor * std  = ggml_sqrt(ctx0, ggml_scale(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
                                                    1.0f / (float) (x->ne[0] - 1)));
    ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
    return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
}

//
// altup functions
//

// equivalent to compute_router_modalities() in python code
// input x shape: [n_embd, n_tokens]
// output shape:  [n_altup, n_tokens]
ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tensor * x, int il) {
    ggml_tensor * router_inputs = build_norm(x, model.layers[il].altup_router_norm, NULL, LLM_NORM_RMS, il);

    // router_input_scale
    router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float) n_embd);

    ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
    return ggml_tanh(ctx0, output); // [n_altup, n_tokens]
}

// input cur shape: [n_embd, n_tokens, n_altup]
// output shape:    [n_embd, n_tokens, n_altup]
ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) {
    ggml_tensor * activated  = view_2d_slice(cur, i_altup_act);                // [n_embd, n_tokens]
    ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
    cb(modalities, "modalities", il);

    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
    cb(all_coefs, "all_coefs", il);
    // the first dim now has n_altup^2 elements; reshape it to 2D (so we end up with a 3D tensor)
    all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);
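    // for every token, all_coefs is now an n_altup x n_altup mixing matrix; the mul_mat below
    // forms each predicted stream as a learned linear combination of the existing streams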

    // permute to [n_altup, n_embd, n_tokens]
    ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
    ggml_tensor * predictions  = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_embd, n_altup, n_tokens]

    // final shape must be the same as cur: [n_embd, n_tokens, n_altup]
    predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
    predictions = ggml_add(ctx0, predictions, cur);
    cb(predictions, "predictions", il);

    return predictions;
}

// input predictions shape: [n_embd, n_tokens, n_altup]
// input activated shape:   [n_embd, n_tokens]
// output shape:            [n_embd, n_tokens, n_altup]
ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
    ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
    cb(modalities, "modalities", il);

    ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
    ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
    cb(innovation, "innovation", il);
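    // the innovation measures how far the post-attention/FFN output drifted from the prediction;
    // below it is broadcast to all altup streams and weighted by a per-stream, per-token
    // coefficient (offset by +1.0) before being added back onto the predictions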

    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
    all_coefs = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f); // + 1.0
    cb(all_coefs, "all_coefs", il);
    all_coefs = ggml_transpose(ctx0, all_coefs);                     // [n_tokens, n_altup]
    all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]

    innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
    ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
    corrected = ggml_add(ctx0, corrected, predictions);              // [n_embd, n_tokens, n_altup]
    cb(corrected, "corrected", il);

    return corrected;
}