/src/llama.cpp/src/models/gemma3n-iswa.cpp
Line | Count | Source |
1 | | #include "models.h" |
2 | | |
3 | | llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) : |
4 | 0 | llm_graph_context(params), |
5 | 0 | model(model), |
6 | 0 | n_embd_head(model.hparams.n_embd_head_k), |
7 | 0 | n_embd_altup(model.hparams.n_embd_altup), |
8 | 0 | n_altup(model.hparams.n_altup), |
9 | 0 | i_altup_act(model.hparams.i_altup_act) { |
10 | 0 | ggml_tensor * cur; |
11 | 0 | ggml_tensor * inpL; |
12 | |
13 | 0 | inpL = build_inp_embd(model.tok_embd); |
14 | | |
15 | | // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
16 | 0 | inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); |
17 | 0 | cb(inpL, "inp_scaled", -1); |
18 | | |
19 | | // inp_pos - contains the positions |
20 | 0 | ggml_tensor * inp_pos = build_inp_pos(); |
21 | | |
22 | | // TODO: is causal == true correct? might need some changes |
23 | 0 | auto * inp_attn = build_attn_inp_kv_iswa(); |
24 | | |
25 | | // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer] |
26 | 0 | ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs()); |
27 | | |
28 | | // inpL currently has only 1 altup; project it to the remaining altups
29 | | // these "added" altups will be concatenated along the last dim of inpL
30 | 0 | { |
31 | 0 | ggml_tensor * target_magnitude = calc_magnitude(inpL); |
32 | 0 | ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1); |
33 | 0 | ggml_tensor * altup_added = |
34 | 0 | ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1] |
35 | 0 | ggml_tensor * new_magnitude = calc_magnitude(altup_added); |
36 | 0 | altup_added = ggml_div(ctx0, ggml_mul(ctx0, altup_added, target_magnitude), new_magnitude); |
37 | 0 | inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup] |
38 | 0 | cb(inpL, "inp_stacked", -1); |
39 | 0 | } |
40 | | // inpL now has shape: [n_embd, n_tokens, n_altup] |
41 | | // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer] |
42 | |
43 | 0 | for (int il = 0; il < n_layer; ++il) { |
44 | | // this block closely resembles Gemma3p5DecoderLayer in the python code
45 | 0 | const float freq_base_l = model.get_rope_freq_base(cparams, il); |
46 | 0 | const float freq_scale_l = model.get_rope_freq_scale(cparams, il); |
47 | |
48 | 0 | ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup] |
49 | 0 | ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup] |
50 | | |
51 | | // predicted value will go through self-attention and laurel |
52 | 0 | ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens] |
53 | 0 | cur = active_prediction; |
54 | 0 | cb(cur, "active_prediction", il); |
55 | | |
56 | | // norm |
57 | 0 | cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); |
58 | 0 | cb(cur, "attn_norm", il); |
59 | | |
60 | | // laurel |
61 | 0 | ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens] |
62 | | |
63 | | // self-attention |
64 | 0 | if (hparams.has_kv(il)) { |
65 | | // compute Q and K and RoPE them |
66 | 0 | ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); |
67 | 0 | cb(Qcur, "Qcur", il); |
68 | |
69 | 0 | ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); |
70 | 0 | cb(Kcur, "Kcur", il); |
71 | |
72 | 0 | ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); |
73 | 0 | cb(Vcur, "Vcur", il); |
74 | |
75 | 0 | Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); |
76 | 0 | Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); |
77 | 0 | Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); |
78 | |
79 | 0 | Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); |
80 | 0 | Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); |
81 | 0 | Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps); |
82 | |
83 | 0 | cb(Qcur, "Qcur_normed", il); |
84 | 0 | cb(Kcur, "Kcur_normed", il); |
85 | 0 | cb(Vcur, "Vcur_normed", il); |
86 | |
87 | 0 | Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, |
88 | 0 | ext_factor, attn_factor, beta_fast, beta_slow); |
89 | |
90 | 0 | Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, |
91 | 0 | ext_factor, attn_factor, beta_fast, beta_slow); |
92 | |
93 | 0 | cb(Qcur, "Qcur_pos", il); |
94 | 0 | cb(Kcur, "Kcur_pos", il); |
95 | |
96 | 0 | cur = build_attn(inp_attn, model.layers[il].wo, |
97 | 0 | NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, |
98 | 0 | hparams.f_attention_scale, il); |
99 | 0 | } else { |
100 | | // reuse KV cache of earlier layers |
101 | 0 | ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); |
102 | 0 | cb(Qcur, "Qcur", il); |
103 | 0 | Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); |
104 | |
105 | 0 | Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); |
106 | 0 | cb(Qcur, "Qcur_normed", il); |
107 | |
108 | 0 | Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, |
109 | 0 | ext_factor, attn_factor, beta_fast, beta_slow); |
110 | 0 | cb(Qcur, "Qcur_pos", il); |
111 | |
112 | 0 | cur = build_attn(inp_attn, |
113 | 0 | model.layers[il].wo, NULL, |
114 | 0 | Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); |
115 | 0 | } |
116 | 0 | cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); |
117 | 0 | cb(cur, "attn_post_norm", il); |
118 | |
119 | 0 | cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens] |
120 | 0 | cb(cur, "attn_gated", il); |
121 | |
122 | 0 | ggml_tensor * attn_laurel = ggml_scale(ctx0, ggml_add(ctx0, cur, laurel_out), |
123 | 0 | 1.0f / sqrtf(2.0f)); // [n_embd, n_tokens] |
124 | 0 | cb(attn_laurel, "attn_laurel", il); |
125 | |
126 | 0 | cur = build_norm(attn_laurel, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); |
127 | 0 | cb(cur, "ffn_norm", il); |
128 | | |
129 | | // feed-forward network |
130 | 0 | { |
131 | 0 | ggml_tensor * up_proj = build_lora_mm(model.layers[il].ffn_up, cur); |
132 | 0 | ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur); |
133 | |
134 | 0 | if (il < n_layer_sparsity) { |
135 | | // apply activation sparsity |
136 | 0 | gate_proj = gaussian_topk(gate_proj); |
137 | 0 | } |
138 | 0 | gate_proj = ggml_gelu(ctx0, gate_proj); |
139 | |
140 | 0 | cur = ggml_mul(ctx0, up_proj, gate_proj); |
141 | 0 | cur = build_lora_mm(model.layers[il].ffn_down, cur); |
142 | 0 | cb(cur, "ffn_out", il); |
143 | 0 | } |
144 | 0 | cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, -1); |
145 | 0 | cb(cur, "ffn_post_norm", il); |
146 | |
147 | 0 | ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens] |
148 | 0 | cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il); |
149 | |
150 | 0 | ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup] |
151 | |
152 | 0 | ggml_tensor * first_prediction; // [n_embd, n_tokens] |
153 | 0 | { |
154 | 0 | first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens] |
155 | 0 | first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale); |
156 | 0 | first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction); |
157 | 0 | first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens] |
158 | 0 | cb(first_prediction, "first_prediction_gated", il); |
159 | 0 | ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens] |
160 | 0 | first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens] |
161 | 0 | cb(first_prediction, "first_prediction_scaled", il); |
162 | |
163 | 0 | first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens] |
164 | 0 | first_prediction = |
165 | 0 | build_norm(first_prediction, model.layers[il].per_layer_post_norm, NULL, LLM_NORM_RMS, il); |
166 | 0 | cb(first_prediction, "first_prediction_out", il); |
167 | 0 | } |
168 | | // equivalent to python code: corrected_predictions[1:] += first_prediction |
169 | 0 | { |
170 | 0 | ggml_tensor * slice_first = view_2d_slice(corrected, 0); |
171 | 0 | ggml_tensor * slice_rest = ggml_view_3d( |
172 | 0 | ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd), |
173 | 0 | ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected)); |
174 | 0 | ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1] |
175 | 0 | corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup] |
176 | 0 | } |
177 | 0 | cur = corrected; // [n_embd, n_tokens, n_altup] |
178 | 0 | cur = build_cvec(cur, il); |
179 | 0 | cb(cur, "l_out", il); |
180 | | |
181 | | // input for next layer |
182 | 0 | inpL = cur; |
183 | 0 | } |
184 | 0 | cur = inpL; // [n_embd, n_tokens, n_altup] |
185 | | |
186 | | // cur now has multiple altup(s), we want to merge them back to 1 altup |
187 | 0 | { |
188 | 0 | ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens] |
189 | | // do a view to skip the first slice (active altup) |
190 | 0 | ggml_tensor * alt_slice = |
191 | 0 | ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd), |
192 | 0 | ggml_row_size(cur->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(cur)); |
193 | 0 | ggml_tensor * altup_unembd = |
194 | 0 | ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1] |
195 | 0 | ggml_tensor * new_magnitude = calc_magnitude(altup_unembd); |
196 | 0 | altup_unembd = ggml_div(ctx0, ggml_mul(ctx0, altup_unembd, target_magnitude), new_magnitude); |
197 | 0 | cb(altup_unembd, "altup_unembd", -1); |
198 | | |
199 | | // equivalent to torch.mean(hidden_states, dim=0) |
200 | 0 | cur = view_2d_slice(cur, 0); // [n_embd, n_tokens] |
201 | 0 | for (int i = 0; i < n_altup - 1; ++i) { |
202 | 0 | cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i)); |
203 | 0 | } |
204 | 0 | cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens] |
205 | 0 | cb(cur, "unembd_merged", -1); |
206 | 0 | } |
207 | | // cur now has shape: [n_embd, n_tokens] |
208 | | |
209 | | // TODO: move this to right after the last KV layer |
210 | 0 | { |
211 | | // skip computing output for unused tokens |
212 | 0 | ggml_tensor * inp_out_ids = build_inp_out_ids(); |
213 | 0 | cur = ggml_get_rows(ctx0, cur, inp_out_ids); |
214 | 0 | } |
215 | 0 | cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); |
216 | |
217 | 0 | cb(cur, "result_norm", -1); |
218 | 0 | res->t_embd = cur; |
219 | |
220 | 0 | cur = build_lora_mm(model.output, cur); |
221 | |
222 | 0 | { |
223 | | // final logit soft-capping |
224 | 0 | cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); |
225 | 0 | cur = ggml_tanh(ctx0, cur); |
226 | 0 | cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); |
227 | 0 | } |
228 | 0 | cb(cur, "result_output", -1); |
229 | 0 | res->t_logits = cur; |
230 | |
231 | 0 | ggml_build_forward_expand(gf, cur); |
232 | 0 | } |
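The scale → tanh → scale block near the end of the constructor implements Gemma-style final logit soft-capping: logits are squashed into (-cap, +cap) with cap = hparams.f_final_logit_softcapping. A minimal standalone sketch of the same transform (illustrative only; the function name is made up, this is not llama.cpp code):

    #include <cmath>
    #include <vector>

    // soft_cap corresponds to hparams.f_final_logit_softcapping in the graph above
    void softcap_logits_ref(std::vector<float> & logits, float soft_cap) {
        for (float & v : logits) {
            v = soft_cap * std::tanh(v / soft_cap);   // scale(1/cap) -> tanh -> scale(cap)
        }
    }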
233 | | |
234 | 0 | ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) { |
235 | 0 | return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x))); |
236 | 0 | } |
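calc_magnitude() reduces along the first (embedding) dimension, so for a [n_embd, n_tokens] input it yields one L2 norm per token (shape [1, n_tokens]), which then broadcasts in the ggml_mul/ggml_div calls that rescale the extra altup streams. A plain-C++ reference sketch of the same reduction over an ne0-contiguous buffer (helper name is hypothetical, not llama.cpp code):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // per-column L2 norm, mirroring ggml_sqrt(ggml_sum_rows(ggml_sqr(x)))
    std::vector<float> calc_magnitude_ref(const std::vector<float> & x, int ne0, int ne1) {
        std::vector<float> out(ne1, 0.0f);
        for (int j = 0; j < ne1; ++j) {
            float sum = 0.0f;
            for (int i = 0; i < ne0; ++i) {
                const float v = x[(std::size_t) j * ne0 + i];
                sum += v * v;            // ggml_sqr + ggml_sum_rows
            }
            out[j] = std::sqrt(sum);     // ggml_sqrt
        }
        return out;
    }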
237 | | |
238 | | // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim |
239 | 0 | ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) { |
240 | 0 | GGML_ASSERT(idx < (int) x->ne[2]); |
241 | 0 | return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), |
242 | 0 | idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); |
243 | 0 | } |
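The view assumes x is contiguous: the [ne0, ne1] planes along dim 2 are laid out back-to-back, so slice idx starts idx * ne0 * ne1 elements into the buffer. The same arithmetic in plain C++ (illustrative only, hypothetical helper name):

    #include <cstddef>

    // byte offset of the idx-th [ne0, ne1] plane in a contiguous 3D buffer
    std::size_t plane_byte_offset(std::size_t ne0, std::size_t ne1, std::size_t idx, std::size_t elem_size) {
        return idx * ne0 * ne1 * elem_size;
    }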
244 | | |
245 | | // equivalent to get_per_layer_inputs() in python code |
246 | | // output shape: [n_embd_altup, n_layer, n_tokens] |
247 | 0 | ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { |
248 | 0 | auto inp = std::make_unique<llm_graph_input_embd>(); |
249 | 0 | ggml_tensor * inp_per_layer; |
250 | 0 | if (ubatch.token) { |
251 | 0 | inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); |
252 | 0 | ggml_set_input(inp->tokens); |
253 | 0 | res->t_tokens = inp->tokens; |
254 | 0 | inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens); |
255 | 0 | inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); |
256 | 0 | inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup)); |
257 | 0 | cb(inp_per_layer, "inp_per_layer_selected", -1); |
258 | 0 | res->add_input(std::move(inp)); |
259 | 0 | } else { |
260 | | // Vision embedding path: use padding token (ID=0) embedding |
261 | 0 | const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer |
262 | | |
263 | | // Extract and dequantize padding token embedding (column 0) |
264 | 0 | ggml_tensor * padding_q = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0); |
265 | 0 | ggml_tensor * padding_f32 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, embd_size); |
266 | 0 | inp_per_layer = ggml_cpy(ctx0, padding_q, padding_f32); |
267 | | |
268 | | // Reshape to [n_embd_altup, n_layer, 1] |
269 | 0 | inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1); |
270 | 0 | cb(inp_per_layer, "inp_per_layer_vision", -1); |
271 | 0 | } |
272 | 0 | return inp_per_layer; |
273 | 0 | } |
274 | | |
275 | | // equivalent to project_per_layer_inputs() in python code |
276 | | // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim |
277 | | // output shape: [n_embd_altup, n_tokens, n_layer] |
278 | 0 | ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) { |
279 | 0 | const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd); |
280 | 0 | const float per_layer_input_scale = 1.0f / sqrtf(2.0f); |
281 | |
282 | 0 | ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds); |
283 | 0 | per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale); |
284 | 0 | per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens); |
285 | 0 | per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, |
286 | 0 | -1); // [n_embd_altup, n_layer, n_tokens] |
287 | 0 | cb(per_layer_proj, "per_layer_proj", -1); |
288 | |
289 | 0 | inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer); |
290 | 0 | inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale); |
291 | 0 | cb(inp_per_layer, "inp_per_layer", -1); |
292 | | |
293 | | // permute to shape: [n_embd_altup, n_tokens, n_layer] |
294 | 0 | inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3)); |
295 | 0 | return inp_per_layer; |
296 | 0 | } |
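In equation form (an informal summary of the graph above, not a line-by-line transcription): for token t and layer l, with e_t the input embedding column, W_plp the per_layer_model_proj weight, and p_{l,t} the per-layer embedding from get_per_layer_inputs(),

    \mathrm{proj}_{l,t} = \mathrm{RMSNorm}\!\left(\tfrac{1}{\sqrt{n_\mathrm{embd}}}\, W_\mathrm{plp}\, e_t\right)_{l}, \qquad \mathrm{inp\_per\_layer}_{l,t} = \tfrac{1}{\sqrt{2}}\left(\mathrm{proj}_{l,t} + p_{l,t}\right)

and the result is then permuted so that n_layer becomes the last dimension.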
297 | | |
298 | | // input cur shape: [n_embd, n_tokens]
299 | | // output shape: [n_embd, n_tokens]
300 | 0 | ggml_tensor * llm_build_gemma3n_iswa::laurel(ggml_tensor * cur, int il) { |
301 | 0 | ggml_tensor * tmp = cur; |
302 | 0 | tmp = build_lora_mm(model.layers[il].laurel_l, tmp); |
303 | 0 | tmp = build_lora_mm(model.layers[il].laurel_r, tmp); |
304 | 0 | tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il); |
305 | 0 | tmp = ggml_add(ctx0, tmp, cur); |
306 | 0 | cb(tmp, "laurel_out", il); |
307 | 0 | return tmp; |
308 | 0 | } |
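laurel() is a learned low-rank residual branch (LAUREL): project down with laurel_l, back up with laurel_r, RMS-normalize, then add the input back. Informally:

    \mathrm{laurel}(x) = x + \mathrm{RMSNorm}\big(W_{r}\,(W_{l}\,x)\big)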
309 | | |
310 | | // input x shape: [n_embd, n_tokens] |
311 | | // output shape: [n_embd, n_tokens] |
312 | 0 | ggml_tensor * llm_build_gemma3n_iswa::gaussian_topk(ggml_tensor * x) { |
313 | 0 | ggml_tensor * mean = ggml_mean(ctx0, x); |
314 | 0 | ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))), |
315 | 0 | 1.0f / (float) (x->ne[0] - 1))); |
316 | 0 | ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul)); |
317 | 0 | return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x)); |
318 | 0 | } |
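gaussian_topk() keeps only activations above a per-token cutoff of mean + f_sparsity_std_mul * std (std taken over the ne0 dimension with the 1/(ne0 - 1) scale), shifting the survivors down by the cutoff. A scalar reference sketch for one column (illustrative only, not llama.cpp code):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // apply the activation-sparsity cutoff to one ne0-length column
    void gaussian_topk_ref(std::vector<float> & col, float std_mul /* f_sparsity_std_mul */) {
        const int n = (int) col.size();
        float mean = 0.0f;
        for (float v : col) mean += v;
        mean /= n;
        float var = 0.0f;
        for (float v : col) var += (v - mean) * (v - mean);
        var /= (n - 1);                                         // matches the 1/(ne0 - 1) scale above
        const float cutoff = mean + std_mul * std::sqrt(var);
        for (float & v : col) v = std::max(0.0f, v - cutoff);   // relu(x - cutoff_x)
    }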
319 | | |
320 | | // |
321 | | // altup functions |
322 | | // |
323 | | |
324 | | // equivalent to compute_router_modalities() in python code |
325 | | // input x shape: [n_embd, n_tokens] |
326 | | // output shape: [n_altup, n_tokens] |
327 | 0 | ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tensor * x, int il) { |
328 | 0 | ggml_tensor * router_inputs = build_norm(x, model.layers[il].altup_router_norm, NULL, LLM_NORM_RMS, il); |
329 | | |
330 | | // router_input_scale |
331 | 0 | router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float) n_embd); |
332 | |
333 | 0 | ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs); |
334 | 0 | return ggml_tanh(ctx0, output); // [n_altup, n_tokens] |
335 | 0 | } |
336 | | |
337 | | // input cur shape: [n_embd, n_tokens, n_altup] |
338 | | // output shape: [n_embd, n_tokens, n_altup] |
339 | 0 | ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) { |
340 | 0 | ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens] |
341 | 0 | ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] |
342 | 0 | cb(modalities, "modalities", il); |
343 | |
344 | 0 | ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities); |
345 | 0 | cb(all_coefs, "all_coefs", il); |
346 | | // the first dim now has n_altup^2 elements; we reshape it to 2D (so we end up with a 3D tensor)
347 | 0 | all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens); |
348 | | |
349 | | // permute to [n_altup, n_embd, n_tokens] |
350 | 0 | ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3)); |
351 | 0 | ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_embd, n_altup, n_tokens]
352 | | |
353 | | // final shape must be the same as cur: [n_embd, n_tokens, n_altup] |
354 | 0 | predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3)); |
355 | 0 | predictions = ggml_add(ctx0, predictions, cur); |
356 | 0 | cb(predictions, "predictions", il); |
357 | |
358 | 0 | return predictions; |
359 | 0 | } |
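Per token t, altup_predict() mixes the n_altup hidden streams with coefficients derived from the router modalities of the active stream, then adds the identity residual; informally, for output stream a:

    \hat{x}^{(a)}_t = x^{(a)}_t + \sum_{k=1}^{n_\mathrm{altup}} c_{k,a,t}\, x^{(k)}_t

where c_{.,.,t} is the [n_altup, n_altup] coefficient matrix produced by altup_predict_coef for token t.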
360 | | |
361 | | // input predictions shape: [n_embd, n_tokens, n_altup] |
362 | | // input activated shape: [n_embd, n_tokens] |
363 | | // output shape: [n_embd, n_tokens, n_altup] |
364 | 0 | ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) { |
365 | 0 | ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] |
366 | 0 | cb(modalities, "modalities", il); |
367 | |
368 | 0 | ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); |
369 | 0 | ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens] |
370 | 0 | cb(innovation, "innovation", il); |
371 | |
372 | 0 | ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens] |
373 | 0 | all_coefs = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f); // + 1.0 |
374 | 0 | cb(all_coefs, "all_coefs", il); |
375 | 0 | all_coefs = ggml_transpose(ctx0, all_coefs); // [n_tokens, n_altup] |
376 | 0 | all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] |
377 | |
378 | 0 | innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1); |
379 | 0 | ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup] |
380 | 0 | corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup] |
381 | 0 | cb(corrected, "corrected", il); |
382 | |
383 | 0 | return corrected; |
384 | 0 | } |
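altup_correct() feeds the "innovation" (the gap between the post-attention/FFN activation and the active prediction) back into every stream, weighted per stream by a router coefficient plus one; informally, for stream a at token t:

    \mathrm{corrected}^{(a)}_t = \mathrm{pred}^{(a)}_t + \big(1 + c_{a,t}\big)\left(\mathrm{activated}_t - \mathrm{pred}^{(i_\mathrm{altup\_act})}_t\right)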