/src/llama.cpp/src/models/gemma3n.cpp
Line | Count | Source |
1 | | #include "models.h" |
2 | | |
3 | 0 | void llama_model_gemma3n::load_arch_hparams(llama_model_loader & ml) { |
4 | 0 | uint32_t swa_period = 5; |
5 | 0 | ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); |
6 | 0 | hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; |
7 | 0 | hparams.set_swa_pattern(swa_period); |
8 | |
|
9 | 0 | hparams.n_layer_kv_from_start = 20; |
10 | 0 | hparams.f_attention_scale = 1.0f; |
11 | |
|
12 | 0 | ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); |
13 | 0 | ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); |
14 | 0 | ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); |
15 | |
|
16 | 0 | switch (hparams.n_layer()) { |
17 | 0 | case 30: type = LLM_TYPE_E2B; break; |
18 | 0 | case 35: type = LLM_TYPE_E4B; break; |
19 | 0 | default: type = LLM_TYPE_UNKNOWN; |
20 | 0 | } |
21 | 0 | } |
22 | | |
23 | 0 | void llama_model_gemma3n::load_arch_tensors(llama_model_loader &) { |
24 | 0 | LLAMA_LOAD_LOCALS; |
25 | |
|
26 | 0 | const int64_t n_altup = hparams.n_altup; |
27 | 0 | const int64_t laurel_rank = hparams.laurel_rank; |
28 | 0 | const int64_t n_embd_altup = hparams.n_embd_altup; |
29 | |
|
30 | 0 | output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); |
31 | | // if output is NULL, init from the input tok embed |
32 | 0 | if (output == NULL) { |
33 | 0 | output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); |
34 | 0 | } |
35 | |
|
36 | 0 | tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); |
37 | |
|
38 | 0 | altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); |
39 | 0 | altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); |
40 | |
|
41 | 0 | per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0); |
42 | 0 | per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_altup * n_layer}, 0); |
43 | 0 | per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_altup}, 0); |
44 | |
|
45 | 0 | output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); |
46 | |
|
47 | 0 | for (int i = 0; i < n_layer; ++i) { |
48 | 0 | auto & layer = layers[i]; |
49 | |
|
50 | 0 | layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); |
51 | |
|
52 | 0 | create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); |
53 | 0 | layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); |
54 | |
|
55 | 0 | layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); |
56 | 0 | layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); |
57 | 0 | layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); |
58 | |
|
59 | 0 | layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); |
60 | 0 | layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); |
61 | 0 | layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); |
62 | 0 | layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); |
63 | 0 | layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); |
64 | | |
65 | | // altup & laurel |
66 | 0 | layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0); |
67 | 0 | layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0); |
68 | 0 | layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); |
69 | 0 | layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0); |
70 | 0 | layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0); |
71 | 0 | layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0); |
72 | 0 | layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0); |
73 | 0 | layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0); |
74 | 0 | layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0); |
75 | 0 | layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0); |
76 | 0 | layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0); |
77 | 0 | } |
78 | 0 | } |
79 | | |
80 | 0 | std::unique_ptr<llm_graph_context> llama_model_gemma3n::build_arch_graph(const llm_graph_params & params) const { |
81 | 0 | return std::make_unique<graph>(*this, params); |
82 | 0 | } |
83 | | |
84 | | // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim |
85 | 0 | static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) { |
86 | 0 | GGML_ASSERT(idx < (int) x->ne[2]); |
87 | 0 | return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), |
88 | 0 | idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); |
89 | 0 | } |
90 | | |
91 | | llama_model_gemma3n::graph::graph(const llama_model & model, const llm_graph_params & params) : |
92 | 0 | llm_graph_context(params), |
93 | 0 | model(model), |
94 | 0 | n_embd_head(model.hparams.n_embd_head_k()), |
95 | 0 | n_embd_altup(model.hparams.n_embd_altup), |
96 | 0 | n_altup(model.hparams.n_altup), |
97 | 0 | i_altup_act(model.hparams.i_altup_act) { |
98 | 0 | ggml_tensor * cur; |
99 | 0 | ggml_tensor * inpL; |
100 | |
|
101 | 0 | inpL = build_inp_embd(model.tok_embd); |
102 | | |
103 | | // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings) |
104 | 0 | inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); |
105 | 0 | cb(inpL, "inp_scaled", -1); |
106 | | |
107 | | // inp_pos - contains the positions |
108 | 0 | ggml_tensor * inp_pos = build_inp_pos(); |
109 | | |
110 | | // TODO: is causal == true correct? might need some changes |
111 | 0 | auto * inp_attn = build_attn_inp_kv_iswa(); |
112 | |
|
113 | 0 | ggml_tensor * inp_per_layer = build_inp_per_layer(); |
114 | 0 | ggml_build_forward_expand(gf, inp_per_layer); |
115 | | |
116 | | // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer] |
117 | 0 | inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer); |
118 | | |
119 | | // inpL now has only 1 altup, project it to the rest of the altups |
120 | | // these "added" altups will be concat to the last dim of inpL |
121 | 0 | { |
122 | 0 | ggml_tensor * target_magnitude = calc_magnitude(inpL); |
123 | 0 | ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1); |
124 | 0 | ggml_tensor * altup_added = |
125 | 0 | ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1] |
126 | 0 | ggml_tensor * new_magnitude = calc_magnitude(altup_added); |
127 | 0 | altup_added = ggml_div(ctx0, ggml_mul(ctx0, altup_added, target_magnitude), new_magnitude); |
128 | 0 | inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup] |
129 | 0 | cb(inpL, "inp_stacked", -1); |
130 | 0 | } |
131 | | // inpL now has shape: [n_embd, n_tokens, n_altup] |
132 | |
|
133 | 0 | for (int il = 0; il < n_layer; ++il) { |
134 | | // this block is made to be closely resemble Gemma3p5DecoderLayer on python code |
135 | 0 | const float freq_base_l = model.get_rope_freq_base(cparams, il); |
136 | 0 | const float freq_scale_l = model.get_rope_freq_scale(cparams, il); |
137 | |
|
138 | 0 | ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup] |
139 | 0 | ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup] |
140 | | |
141 | | // predicted value will go through self-attention and laurel |
142 | 0 | ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); // [n_embd, n_tokens] |
143 | 0 | cur = active_prediction; |
144 | 0 | cb(cur, "active_prediction", il); |
145 | | |
146 | | // norm |
147 | 0 | cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); |
148 | 0 | cb(cur, "attn_norm", il); |
149 | | |
150 | | // laurel |
151 | 0 | ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens] |
152 | | |
153 | | // self-attention |
154 | 0 | if (hparams.has_kv(il)) { |
155 | 0 | auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, n_head, n_head_kv, il); |
156 | |
|
157 | 0 | Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); |
158 | 0 | Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); |
159 | 0 | Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps); |
160 | |
|
161 | 0 | cb(Qcur, "Qcur_normed", il); |
162 | 0 | cb(Kcur, "Kcur_normed", il); |
163 | 0 | cb(Vcur, "Vcur_normed", il); |
164 | |
|
165 | 0 | Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, |
166 | 0 | ext_factor, attn_factor, beta_fast, beta_slow); |
167 | |
|
168 | 0 | Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, |
169 | 0 | ext_factor, attn_factor, beta_fast, beta_slow); |
170 | |
|
171 | 0 | cb(Qcur, "Qcur_pos", il); |
172 | 0 | cb(Kcur, "Kcur_pos", il); |
173 | |
|
174 | 0 | cur = build_attn(inp_attn, model.layers[il].wo, |
175 | 0 | NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, |
176 | 0 | hparams.f_attention_scale, il); |
177 | 0 | } else { |
178 | | // reuse KV cache of earlier layers |
179 | 0 | ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); |
180 | 0 | cb(Qcur, "Qcur", il); |
181 | 0 | Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); |
182 | |
|
183 | 0 | Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); |
184 | 0 | cb(Qcur, "Qcur_normed", il); |
185 | |
|
186 | 0 | Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, |
187 | 0 | ext_factor, attn_factor, beta_fast, beta_slow); |
188 | 0 | cb(Qcur, "Qcur_pos", il); |
189 | |
|
190 | 0 | cur = build_attn(inp_attn, |
191 | 0 | model.layers[il].wo, NULL, model.layers[il].wo_s, |
192 | 0 | Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); |
193 | 0 | } |
194 | 0 | cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); |
195 | 0 | cb(cur, "attn_post_norm", il); |
196 | |
|
197 | 0 | cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens] |
198 | 0 | cb(cur, "attn_gated", il); |
199 | |
|
200 | 0 | ggml_tensor * attn_laurel = ggml_scale(ctx0, ggml_add(ctx0, cur, laurel_out), |
201 | 0 | 1.0f / sqrtf(2.0f)); // [n_embd, n_tokens] |
202 | 0 | cb(attn_laurel, "attn_laurel", il); |
203 | |
|
204 | 0 | cur = build_norm(attn_laurel, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); |
205 | 0 | cb(cur, "ffn_norm", il); |
206 | | |
207 | | // feed-forward network |
208 | 0 | { |
209 | 0 | ggml_tensor * up_proj = build_lora_mm(model.layers[il].ffn_up, cur); |
210 | 0 | ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur); |
211 | |
|
212 | 0 | if (il < n_layer_sparsity) { |
213 | | // apply activation sparsity |
214 | 0 | gate_proj = gaussian_topk(gate_proj); |
215 | 0 | } |
216 | 0 | gate_proj = ggml_gelu(ctx0, gate_proj); |
217 | |
|
218 | 0 | cur = ggml_mul(ctx0, up_proj, gate_proj); |
219 | 0 | cur = build_lora_mm(model.layers[il].ffn_down, cur); |
220 | 0 | cb(cur, "ffn_out", il); |
221 | 0 | } |
222 | 0 | cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, -1); |
223 | 0 | cb(cur, "ffn_post_norm", il); |
224 | |
|
225 | 0 | ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens] |
226 | 0 | cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il); |
227 | |
|
228 | 0 | ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup] |
229 | |
|
230 | 0 | ggml_tensor * first_prediction; // [n_embd, n_tokens] |
231 | 0 | { |
232 | 0 | first_prediction = ggml_view_2d_slice(ctx0, corrected, i_altup_act); // [n_embd, n_tokens] |
233 | 0 | first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale); |
234 | 0 | first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction); |
235 | 0 | first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens] |
236 | 0 | cb(first_prediction, "first_prediction_gated", il); |
237 | |
|
238 | 0 | ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_altup, n_tokens] |
239 | 0 | first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens] |
240 | 0 | cb(first_prediction, "first_prediction_scaled", il); |
241 | |
|
242 | 0 | first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens] |
243 | 0 | first_prediction = |
244 | 0 | build_norm(first_prediction, model.layers[il].per_layer_post_norm, NULL, LLM_NORM_RMS, il); |
245 | 0 | cb(first_prediction, "first_prediction_out", il); |
246 | 0 | } |
247 | | // equivalent to python code: corrected_predictions[1:] += first_prediction |
248 | 0 | { |
249 | 0 | ggml_tensor * slice_first = ggml_view_2d_slice(ctx0, corrected, 0); |
250 | 0 | ggml_tensor * slice_rest = ggml_view_3d( |
251 | 0 | ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd), |
252 | 0 | ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected)); |
253 | 0 | ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1] |
254 | 0 | corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup] |
255 | 0 | } |
256 | 0 | cur = corrected; // [n_embd, n_tokens, n_altup] |
257 | 0 | cur = build_cvec(cur, il); |
258 | 0 | cb(cur, "l_out", il); |
259 | | |
260 | | // input for next layer |
261 | 0 | inpL = cur; |
262 | 0 | } |
263 | 0 | cur = inpL; // [n_embd, n_tokens, n_altup] |
264 | | |
265 | | // cur now has multiple altup(s), we want to merge them back to 1 altup |
266 | 0 | { |
267 | 0 | ggml_tensor * target_magnitude = calc_magnitude(ggml_view_2d_slice(ctx0, cur, i_altup_act)); // [n_embd, n_tokens] |
268 | | // do a view to skip the first slice (active altup) |
269 | 0 | ggml_tensor * alt_slice = |
270 | 0 | ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd), |
271 | 0 | ggml_row_size(cur->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(cur)); |
272 | 0 | ggml_tensor * altup_unembd = |
273 | 0 | ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1] |
274 | 0 | ggml_tensor * new_magnitude = calc_magnitude(altup_unembd); |
275 | 0 | altup_unembd = ggml_div(ctx0, ggml_mul(ctx0, altup_unembd, target_magnitude), new_magnitude); |
276 | 0 | cb(altup_unembd, "altup_unembd", -1); |
277 | | |
278 | | // equivalent to torch.mean(hidden_states, dim=0) |
279 | 0 | cur = ggml_view_2d_slice(ctx0, cur, 0); // [n_embd, n_tokens] |
280 | 0 | for (int i = 0; i < n_altup - 1; ++i) { |
281 | 0 | cur = ggml_add(ctx0, cur, ggml_view_2d_slice(ctx0, altup_unembd, i)); |
282 | 0 | } |
283 | 0 | cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens] |
284 | 0 | cb(cur, "unembd_merged", -1); |
285 | 0 | } |
286 | | // cur now has shape: [n_embd, n_tokens] |
287 | | |
288 | | // TODO: move this to right after the last KV layer |
289 | 0 | { |
290 | | // skip computing output for unused tokens |
291 | 0 | ggml_tensor * inp_out_ids = build_inp_out_ids(); |
292 | 0 | cur = ggml_get_rows(ctx0, cur, inp_out_ids); |
293 | 0 | } |
294 | 0 | cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); |
295 | |
|
296 | 0 | cb(cur, "result_norm", -1); |
297 | 0 | res->t_embd = cur; |
298 | |
|
299 | 0 | cur = build_lora_mm(model.output, cur, model.output_s); |
300 | |
|
301 | 0 | { |
302 | | // final logit soft-capping |
303 | 0 | cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); |
304 | 0 | cur = ggml_tanh(ctx0, cur); |
305 | 0 | cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); |
306 | 0 | } |
307 | 0 | cb(cur, "result_output", -1); |
308 | 0 | res->t_logits = cur; |
309 | |
|
310 | 0 | ggml_build_forward_expand(gf, cur); |
311 | 0 | } |
312 | | |
313 | 0 | ggml_tensor * llama_model_gemma3n::graph::calc_magnitude(ggml_tensor * x) { |
314 | 0 | return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x))); |
315 | 0 | } |
316 | | |
317 | | // equivalent to get_per_layer_inputs() in python code |
318 | | // output shape: [n_embd_altup, n_layer, n_tokens] |
319 | 0 | ggml_tensor * llama_model_gemma3n::graph::build_inp_per_layer() { |
320 | 0 | auto inp = std::make_unique<llm_graph_input_embd>(n_embd); |
321 | 0 | ggml_tensor * inp_per_layer; |
322 | 0 | float tok_embd_scale = sqrtf((float) n_embd_altup); |
323 | 0 | if (ubatch.token) { |
324 | 0 | inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); |
325 | 0 | ggml_set_input(inp->tokens); |
326 | 0 | res->t_inp_tokens = inp->tokens; |
327 | 0 | inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd, inp->tokens); |
328 | 0 | inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); |
329 | 0 | inp_per_layer = ggml_scale (ctx0, inp_per_layer, tok_embd_scale); |
330 | 0 | cb(inp_per_layer, "inp_per_layer_selected", -1); |
331 | 0 | res->add_input(std::move(inp)); |
332 | 0 | } else { |
333 | | // Multimodal embedding path: use padding token (ID=0) embedding |
334 | | // TODO: verify if this is the correct behavior in transformers implementation |
335 | 0 | const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_altup * n_layer |
336 | | |
337 | | // Extract and dequantize padding token embedding (row 0) |
338 | 0 | ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0); |
339 | 0 | inp_per_layer = ggml_cast (ctx0, padding, GGML_TYPE_F32); |
340 | 0 | inp_per_layer = ggml_scale(ctx0, inp_per_layer, tok_embd_scale); |
341 | | |
342 | | // Reshape to [n_embd_altup, n_layer, 1] |
343 | 0 | inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1); |
344 | 0 | cb(inp_per_layer, "inp_per_layer_multimodal", -1); |
345 | 0 | } |
346 | 0 | return inp_per_layer; |
347 | 0 | } |
348 | | |
349 | | // equivalent to project_per_layer_inputs() in python code |
350 | | // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim |
351 | | // output shape: [n_embd_altup, n_tokens, n_layer] |
352 | 0 | ggml_tensor * llama_model_gemma3n::graph::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { |
353 | 0 | const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd); |
354 | 0 | const float per_layer_input_scale = 1.0f / sqrtf(2.0f); |
355 | |
|
356 | 0 | ggml_tensor * per_layer_proj; |
357 | 0 | per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch); |
358 | 0 | per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale); |
359 | 0 | per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens); |
360 | |
|
361 | 0 | per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, -1); |
362 | 0 | cb(per_layer_proj, "per_layer_proj", -1); |
363 | |
|
364 | 0 | inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer); |
365 | 0 | inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale); |
366 | 0 | cb(inp_per_layer, "inp_per_layer", -1); |
367 | | |
368 | | // permute to shape: [n_embd_altup, n_tokens, n_layer] |
369 | 0 | inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3)); |
370 | 0 | return inp_per_layer; |
371 | 0 | } |
372 | | |
373 | | // input cur shape: [n_altup, n_tokens] |
374 | | // output shape: [n_altup, n_tokens] |
375 | 0 | ggml_tensor * llama_model_gemma3n::graph::laurel(ggml_tensor * cur, int il) { |
376 | 0 | ggml_tensor * tmp = cur; |
377 | 0 | tmp = build_lora_mm(model.layers[il].laurel_l, tmp); |
378 | 0 | tmp = build_lora_mm(model.layers[il].laurel_r, tmp); |
379 | 0 | tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il); |
380 | 0 | tmp = ggml_add(ctx0, tmp, cur); |
381 | 0 | cb(tmp, "laurel_out", il); |
382 | 0 | return tmp; |
383 | 0 | } |
384 | | |
385 | | // input x shape: [n_embd, n_tokens] |
386 | | // output shape: [n_embd, n_tokens] |
387 | 0 | ggml_tensor * llama_model_gemma3n::graph::gaussian_topk(ggml_tensor * x) { |
388 | 0 | ggml_tensor * mean = ggml_mean(ctx0, x); |
389 | 0 | ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))), |
390 | 0 | 1.0f / (float) (x->ne[0] - 1))); |
391 | 0 | ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul)); |
392 | 0 | return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x)); |
393 | 0 | } |
394 | | |
395 | | // |
396 | | // altup functions |
397 | | // |
398 | | |
399 | | // equivalent to compute_router_modalities() in python code |
400 | | // input x shape: [n_embd, n_tokens] |
401 | | // output shape: [n_altup, n_tokens] |
402 | 0 | ggml_tensor * llama_model_gemma3n::graph::altup_compute_router_modalities(ggml_tensor * x, int il) { |
403 | 0 | ggml_tensor * router_inputs = build_norm(x, model.layers[il].altup_router_norm, NULL, LLM_NORM_RMS, il); |
404 | | |
405 | | // router_input_scale |
406 | 0 | router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float) n_embd); |
407 | |
|
408 | 0 | ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs); |
409 | 0 | return ggml_tanh(ctx0, output); // [n_altup, n_tokens] |
410 | 0 | } |
411 | | |
412 | | // input cur shape: [n_embd, n_tokens, n_altup] |
413 | | // output shape: [n_embd, n_tokens, n_altup] |
414 | 0 | ggml_tensor * llama_model_gemma3n::graph::altup_predict(ggml_tensor * cur, int il) { |
415 | 0 | ggml_tensor * activated = ggml_view_2d_slice(ctx0, cur, i_altup_act); // [n_embd, n_tokens] |
416 | 0 | ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] |
417 | 0 | cb(modalities, "modalities", il); |
418 | |
|
419 | 0 | ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities); |
420 | 0 | cb(all_coefs, "all_coefs", il); |
421 | | // first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor) |
422 | 0 | all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens); |
423 | | |
424 | | // permute to [n_altup, n_embd, n_tokens] |
425 | 0 | ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3)); |
426 | 0 | ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens] |
427 | | |
428 | | // final shape must be the same as cur: [n_embd, n_tokens, n_altup] |
429 | 0 | predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3)); |
430 | 0 | predictions = ggml_add(ctx0, predictions, cur); |
431 | 0 | cb(predictions, "predictions", il); |
432 | |
|
433 | 0 | return predictions; |
434 | 0 | } |
435 | | |
436 | | // input predictions shape: [n_embd, n_tokens, n_altup] |
437 | | // input activated shape: [n_embd, n_tokens] |
438 | | // output shape: [n_embd, n_tokens, n_altup] |
439 | 0 | ggml_tensor * llama_model_gemma3n::graph::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) { |
440 | 0 | ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] |
441 | 0 | cb(modalities, "modalities", il); |
442 | |
|
443 | 0 | ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); |
444 | 0 | ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens] |
445 | 0 | cb(innovation, "innovation", il); |
446 | |
|
447 | 0 | ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens] |
448 | 0 | all_coefs = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f); // + 1.0 |
449 | 0 | cb(all_coefs, "all_coefs", il); |
450 | 0 | all_coefs = ggml_transpose(ctx0, all_coefs); // [n_tokens, n_altup] |
451 | 0 | all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] |
452 | |
|
453 | 0 | innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1); |
454 | 0 | ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup] |
455 | 0 | corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup] |
456 | 0 | cb(corrected, "corrected", il); |
457 | |
|
458 | 0 | return corrected; |
459 | 0 | } |