Coverage Report

Created: 2026-06-22 06:47

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/models/gemma3n.cpp
Line
Count
Source
1
#include "models.h"
2
3
0
void llama_model_gemma3n::load_arch_hparams(llama_model_loader & ml) {
4
0
    uint32_t swa_period = 5;
5
0
    ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false);
6
0
    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
7
0
    hparams.set_swa_pattern(swa_period);
8
9
0
    hparams.n_layer_kv_from_start = 20;
10
0
    hparams.f_attention_scale     = 1.0f;
11
12
0
    ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA,          hparams.rope_freq_base_train_swa, false);
13
0
    ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
14
0
    ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
15
16
0
    switch (hparams.n_layer()) {
17
0
        case 30: type = LLM_TYPE_E2B; break;
18
0
        case 35: type = LLM_TYPE_E4B; break;
19
0
        default: type = LLM_TYPE_UNKNOWN;
20
0
    }
21
0
}
22
23
0
void llama_model_gemma3n::load_arch_tensors(llama_model_loader &) {
24
0
    LLAMA_LOAD_LOCALS;
25
26
0
    const int64_t n_altup      = hparams.n_altup;
27
0
    const int64_t laurel_rank  = hparams.laurel_rank;
28
0
    const int64_t n_embd_altup = hparams.n_embd_altup;
29
30
0
    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
31
    // if output is NULL, init from the input tok embed
32
0
    if (output == NULL) {
33
0
        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
34
0
    }
35
36
0
    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
37
38
0
    altup_proj        = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ,        "weight"), {n_embd, n_embd, n_altup - 1}, 0);
39
0
    altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
40
41
0
    per_layer_tok_embd   = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
42
0
    per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_altup * n_layer}, 0);
43
0
    per_layer_proj_norm  = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM,  "weight", 0), {n_embd_altup}, 0);
44
45
0
    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
46
47
0
    for (int i = 0; i < n_layer; ++i) {
48
0
        auto & layer = layers[i];
49
50
0
        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
51
52
0
        create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0);
53
0
        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
54
55
0
        layer.attn_q_norm    = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM,    "weight", i), {n_embd_head_k}, 0);
56
0
        layer.attn_k_norm    = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM,    "weight", i), {n_embd_head_k}, 0);
57
0
        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
58
59
0
        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
60
0
        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
61
0
        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
62
0
        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
63
0
        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
64
65
        // altup & laurel
66
0
        layer.per_layer_inp_gate   = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE,  "weight", i), {n_embd, n_embd_altup}, 0);
67
0
        layer.per_layer_proj       = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ,      "weight", i), {n_embd_altup, n_embd}, 0);
68
0
        layer.per_layer_post_norm  = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0);
69
0
        layer.altup_correct_coef   = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF,  "weight", i), {n_altup, n_altup}, 0);
70
0
        layer.altup_correct_scale  = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0);
71
0
        layer.altup_predict_coef   = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF,  "weight", i), {n_altup, n_altup * n_altup}, 0);
72
0
        layer.altup_router         = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER,        "weight", i), {n_embd, n_altup}, 0);
73
0
        layer.altup_router_norm    = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM,   "weight", i), {n_embd}, 0);
74
0
        layer.laurel_l             = create_tensor(tn(LLM_TENSOR_LAUREL_L,            "weight", i), {n_embd, laurel_rank}, 0);
75
0
        layer.laurel_r             = create_tensor(tn(LLM_TENSOR_LAUREL_R,            "weight", i), {laurel_rank, n_embd}, 0);
76
0
        layer.laurel_post_norm     = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM,    "weight", i), {n_embd}, 0);
77
0
    }
78
0
}
79
80
0
std::unique_ptr<llm_graph_context> llama_model_gemma3n::build_arch_graph(const llm_graph_params & params) const {
81
0
    return std::make_unique<graph>(*this, params);
82
0
}
83
84
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
85
0
static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) {
86
0
    GGML_ASSERT(idx < (int) x->ne[2]);
87
0
    return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
88
0
                        idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
89
0
}
90
91
llama_model_gemma3n::graph::graph(const llama_model & model, const llm_graph_params & params) :
92
0
    llm_graph_context(params),
93
0
    model(model),
94
0
    n_embd_head(model.hparams.n_embd_head_k()),
95
0
    n_embd_altup(model.hparams.n_embd_altup),
96
0
    n_altup(model.hparams.n_altup),
97
0
    i_altup_act(model.hparams.i_altup_act) {
98
0
    ggml_tensor * cur;
99
0
    ggml_tensor * inpL;
100
101
0
    inpL = build_inp_embd(model.tok_embd);
102
103
    // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
104
0
    inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
105
0
    cb(inpL, "inp_scaled", -1);
106
107
    // inp_pos - contains the positions
108
0
    ggml_tensor * inp_pos = build_inp_pos();
109
110
    // TODO: is causal == true correct? might need some changes
111
0
    auto * inp_attn = build_attn_inp_kv_iswa();
112
113
0
    ggml_tensor * inp_per_layer = build_inp_per_layer();
114
0
    ggml_build_forward_expand(gf, inp_per_layer);
115
116
    // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
117
0
    inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer);
118
119
    // inpL now has only 1 altup, project it to the rest of the altups
120
    // these "added" altups will be concat to the last dim of inpL
121
0
    {
122
0
        ggml_tensor * target_magnitude = calc_magnitude(inpL);
123
0
        ggml_tensor * inp_repeated     = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
124
0
        ggml_tensor * altup_added =
125
0
            ggml_mul_mat(ctx0, model.altup_proj, inp_repeated);  // shape: [n_embd, n_tokens, n_altup - 1]
126
0
        ggml_tensor * new_magnitude = calc_magnitude(altup_added);
127
0
        altup_added                 = ggml_div(ctx0, ggml_mul(ctx0, altup_added, target_magnitude), new_magnitude);
128
0
        inpL                        = ggml_concat(ctx0, inpL, altup_added, 2);  // shape: [n_embd, n_tokens, n_altup]
129
0
        cb(inpL, "inp_stacked", -1);
130
0
    }
131
    // inpL now has shape: [n_embd, n_tokens, n_altup]
132
133
0
    for (int il = 0; il < n_layer; ++il) {
134
        // this block is made to be closely resemble Gemma3p5DecoderLayer on python code
135
0
        const float freq_base_l  = model.get_rope_freq_base(cparams, il);
136
0
        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
137
138
0
        ggml_tensor * cur         = inpL;                    // [n_embd, n_tokens, n_altup]
139
0
        ggml_tensor * predictions = altup_predict(cur, il);  // [n_embd, n_tokens, n_altup]
140
141
        // predicted value will go through self-attention and laurel
142
0
        ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act);  // [n_embd, n_tokens]
143
0
        cur = active_prediction;
144
0
        cb(cur, "active_prediction", il);
145
146
        // norm
147
0
        cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
148
0
        cb(cur, "attn_norm", il);
149
150
        // laurel
151
0
        ggml_tensor * laurel_out = laurel(cur, il);  // [n_embd, n_tokens]
152
153
        // self-attention
154
0
        if (hparams.has_kv(il)) {
155
0
            auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, n_head, n_head_kv, il);
156
157
0
            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
158
0
            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
159
0
            Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);
160
161
0
            cb(Qcur, "Qcur_normed", il);
162
0
            cb(Kcur, "Kcur_normed", il);
163
0
            cb(Vcur, "Vcur_normed", il);
164
165
0
            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
166
0
                                 ext_factor, attn_factor, beta_fast, beta_slow);
167
168
0
            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
169
0
                                 ext_factor, attn_factor, beta_fast, beta_slow);
170
171
0
            cb(Qcur, "Qcur_pos", il);
172
0
            cb(Kcur, "Kcur_pos", il);
173
174
0
            cur = build_attn(inp_attn, model.layers[il].wo,
175
0
                    NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr,
176
0
                    hparams.f_attention_scale, il);
177
0
        } else {
178
            // reuse KV cache of earlier layers
179
0
            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
180
0
            cb(Qcur, "Qcur", il);
181
0
            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
182
183
0
            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
184
0
            cb(Qcur, "Qcur_normed", il);
185
186
0
            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
187
0
                                 ext_factor, attn_factor, beta_fast, beta_slow);
188
0
            cb(Qcur, "Qcur_pos", il);
189
190
0
            cur = build_attn(inp_attn,
191
0
                    model.layers[il].wo, NULL, model.layers[il].wo_s,
192
0
                    Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
193
0
        }
194
0
        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
195
0
        cb(cur, "attn_post_norm", il);
196
197
0
        cur = ggml_add(ctx0, cur, active_prediction);  // [n_embd, n_tokens]
198
0
        cb(cur, "attn_gated", il);
199
200
0
        ggml_tensor * attn_laurel = ggml_scale(ctx0, ggml_add(ctx0, cur, laurel_out),
201
0
                                               1.0f / sqrtf(2.0f));  // [n_embd, n_tokens]
202
0
        cb(attn_laurel, "attn_laurel", il);
203
204
0
        cur = build_norm(attn_laurel, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
205
0
        cb(cur, "ffn_norm", il);
206
207
        // feed-forward network
208
0
        {
209
0
            ggml_tensor * up_proj   = build_lora_mm(model.layers[il].ffn_up, cur);
210
0
            ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);
211
212
0
            if (il < n_layer_sparsity) {
213
                // apply activation sparsity
214
0
                gate_proj = gaussian_topk(gate_proj);
215
0
            }
216
0
            gate_proj = ggml_gelu(ctx0, gate_proj);
217
218
0
            cur = ggml_mul(ctx0, up_proj, gate_proj);
219
0
            cur = build_lora_mm(model.layers[il].ffn_down, cur);
220
0
            cb(cur, "ffn_out", il);
221
0
        }
222
0
        cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, -1);
223
0
        cb(cur, "ffn_post_norm", il);
224
225
0
        ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel);  // [n_embd, n_tokens]
226
0
        cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);
227
228
0
        ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il);  // [n_embd, n_tokens, n_altup]
229
230
0
        ggml_tensor * first_prediction;                                                   // [n_embd, n_tokens]
231
0
        {
232
0
            first_prediction = ggml_view_2d_slice(ctx0, corrected, i_altup_act);          // [n_embd, n_tokens]
233
0
            first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
234
0
            first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
235
0
            first_prediction = ggml_gelu(ctx0, first_prediction);                 // [n_embd_altup, n_tokens]
236
0
            cb(first_prediction, "first_prediction_gated", il);
237
238
0
            ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il);   // [n_embd_altup, n_tokens]
239
0
            first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer);  // [n_embd_altup, n_tokens]
240
0
            cb(first_prediction, "first_prediction_scaled", il);
241
242
0
            first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction);  // [n_embd, n_tokens]
243
0
            first_prediction =
244
0
                build_norm(first_prediction, model.layers[il].per_layer_post_norm, NULL, LLM_NORM_RMS, il);
245
0
            cb(first_prediction, "first_prediction_out", il);
246
0
        }
247
        // equivalent to python code: corrected_predictions[1:] += first_prediction
248
0
        {
249
0
            ggml_tensor * slice_first = ggml_view_2d_slice(ctx0, corrected, 0);
250
0
            ggml_tensor * slice_rest  = ggml_view_3d(
251
0
                ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd),
252
0
                ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected));
253
0
            ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction);  // [n_embd, n_tokens, n_altup - 1]
254
0
            corrected         = ggml_concat(ctx0, slice_first, tmp, 2);        // [n_embd, n_tokens, n_altup]
255
0
        }
256
0
        cur = corrected;                                                       // [n_embd, n_tokens, n_altup]
257
0
        cur = build_cvec(cur, il);
258
0
        cb(cur, "l_out", il);
259
260
        // input for next layer
261
0
        inpL = cur;
262
0
    }
263
0
    cur = inpL;  // [n_embd, n_tokens, n_altup]
264
265
    // cur now has multiple altup(s), we want to merge them back to 1 altup
266
0
    {
267
0
        ggml_tensor * target_magnitude = calc_magnitude(ggml_view_2d_slice(ctx0, cur, i_altup_act));  // [n_embd, n_tokens]
268
        // do a view to skip the first slice (active altup)
269
0
        ggml_tensor * alt_slice =
270
0
            ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd),
271
0
                         ggml_row_size(cur->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(cur));
272
0
        ggml_tensor * altup_unembd =
273
0
            ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice);  // shape: [n_embd, n_tokens, n_altup - 1]
274
0
        ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
275
0
        altup_unembd                = ggml_div(ctx0, ggml_mul(ctx0, altup_unembd, target_magnitude), new_magnitude);
276
0
        cb(altup_unembd, "altup_unembd", -1);
277
278
        // equivalent to torch.mean(hidden_states, dim=0)
279
0
        cur = ggml_view_2d_slice(ctx0, cur, 0);  // [n_embd, n_tokens]
280
0
        for (int i = 0; i < n_altup - 1; ++i) {
281
0
            cur = ggml_add(ctx0, cur, ggml_view_2d_slice(ctx0, altup_unembd, i));
282
0
        }
283
0
        cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup));  // [n_embd, n_tokens]
284
0
        cb(cur, "unembd_merged", -1);
285
0
    }
286
    // cur now has shape: [n_embd, n_tokens]
287
288
    // TODO: move this to right after the last KV layer
289
0
    {
290
        // skip computing output for unused tokens
291
0
        ggml_tensor * inp_out_ids = build_inp_out_ids();
292
0
        cur                       = ggml_get_rows(ctx0, cur, inp_out_ids);
293
0
    }
294
0
    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
295
296
0
    cb(cur, "result_norm", -1);
297
0
    res->t_embd = cur;
298
299
0
    cur = build_lora_mm(model.output, cur, model.output_s);
300
301
0
    {
302
        // final logit soft-capping
303
0
        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
304
0
        cur = ggml_tanh(ctx0, cur);
305
0
        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
306
0
    }
307
0
    cb(cur, "result_output", -1);
308
0
    res->t_logits = cur;
309
310
0
    ggml_build_forward_expand(gf, cur);
311
0
}
312
313
0
ggml_tensor * llama_model_gemma3n::graph::calc_magnitude(ggml_tensor * x) {
314
0
    return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
315
0
}
316
317
// equivalent to get_per_layer_inputs() in python code
318
// output shape: [n_embd_altup, n_layer, n_tokens]
319
0
ggml_tensor * llama_model_gemma3n::graph::build_inp_per_layer() {
320
0
    auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
321
0
    ggml_tensor * inp_per_layer;
322
0
    float tok_embd_scale = sqrtf((float) n_embd_altup);
323
0
    if (ubatch.token) {
324
0
        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
325
0
        ggml_set_input(inp->tokens);
326
0
        res->t_inp_tokens = inp->tokens;
327
0
        inp_per_layer = ggml_get_rows  (ctx0, model.per_layer_tok_embd, inp->tokens);
328
0
        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
329
0
        inp_per_layer = ggml_scale     (ctx0, inp_per_layer, tok_embd_scale);
330
0
        cb(inp_per_layer, "inp_per_layer_selected", -1);
331
0
        res->add_input(std::move(inp));
332
0
    } else {
333
        // Multimodal embedding path: use padding token (ID=0) embedding
334
        // TODO: verify if this is the correct behavior in transformers implementation
335
0
        const int64_t embd_size = model.per_layer_tok_embd->ne[0];  // n_embd_altup * n_layer
336
337
        // Extract and dequantize padding token embedding (row 0)
338
0
        ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0);
339
0
        inp_per_layer = ggml_cast (ctx0, padding, GGML_TYPE_F32);
340
0
        inp_per_layer = ggml_scale(ctx0, inp_per_layer, tok_embd_scale);
341
342
        // Reshape to [n_embd_altup, n_layer, 1]
343
0
        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1);
344
0
        cb(inp_per_layer, "inp_per_layer_multimodal", -1);
345
0
    }
346
0
    return inp_per_layer;
347
0
}
348
349
// equivalent to project_per_layer_inputs() in python code
350
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
351
// output shape: [n_embd_altup, n_tokens, n_layer]
352
0
ggml_tensor * llama_model_gemma3n::graph::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) {
353
0
    const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
354
0
    const float per_layer_input_scale      = 1.0f / sqrtf(2.0f);
355
356
0
    ggml_tensor * per_layer_proj;
357
0
    per_layer_proj = ggml_mul_mat   (ctx0, model.per_layer_model_proj, inp_batch);
358
0
    per_layer_proj = ggml_scale     (ctx0, per_layer_proj, per_layer_projection_scale);
359
0
    per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
360
361
0
    per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, -1);
362
0
    cb(per_layer_proj, "per_layer_proj", -1);
363
364
0
    inp_per_layer = ggml_add  (ctx0, per_layer_proj, inp_per_layer);
365
0
    inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
366
0
    cb(inp_per_layer, "inp_per_layer", -1);
367
368
    // permute to shape: [n_embd_altup, n_tokens, n_layer]
369
0
    inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
370
0
    return inp_per_layer;
371
0
}
372
373
// input cur shape: [n_altup, n_tokens]
374
// output    shape: [n_altup, n_tokens]
375
0
ggml_tensor * llama_model_gemma3n::graph::laurel(ggml_tensor * cur, int il) {
376
0
    ggml_tensor * tmp = cur;
377
0
    tmp               = build_lora_mm(model.layers[il].laurel_l, tmp);
378
0
    tmp               = build_lora_mm(model.layers[il].laurel_r, tmp);
379
0
    tmp               = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
380
0
    tmp               = ggml_add(ctx0, tmp, cur);
381
0
    cb(tmp, "laurel_out", il);
382
0
    return tmp;
383
0
}
384
385
// input x shape: [n_embd, n_tokens]
386
// output  shape: [n_embd, n_tokens]
387
0
ggml_tensor * llama_model_gemma3n::graph::gaussian_topk(ggml_tensor * x) {
388
0
    ggml_tensor * mean = ggml_mean(ctx0, x);
389
0
    ggml_tensor * std  = ggml_sqrt(ctx0, ggml_scale(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
390
0
                                                    1.0f / (float) (x->ne[0] - 1)));
391
0
    ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
392
0
    return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
393
0
}
394
395
//
396
// altup functions
397
//
398
399
// equivalent to compute_router_modalities() in python code
400
// input x shape: [n_embd,  n_tokens]
401
// output  shape: [n_altup, n_tokens]
402
0
ggml_tensor * llama_model_gemma3n::graph::altup_compute_router_modalities(ggml_tensor * x, int il) {
403
0
    ggml_tensor * router_inputs = build_norm(x, model.layers[il].altup_router_norm, NULL, LLM_NORM_RMS, il);
404
405
    // router_input_scale
406
0
    router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float) n_embd);
407
408
0
    ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
409
0
    return ggml_tanh(ctx0, output);  // [n_altup, n_tokens]
410
0
}
411
412
// input cur shape: [n_embd, n_tokens, n_altup]
413
// output    shape: [n_embd, n_tokens, n_altup]
414
0
ggml_tensor * llama_model_gemma3n::graph::altup_predict(ggml_tensor * cur, int il) {
415
0
    ggml_tensor * activated  = ggml_view_2d_slice(ctx0, cur, i_altup_act);      // [n_embd, n_tokens]
416
0
    ggml_tensor * modalities = altup_compute_router_modalities(activated, il);  // [n_altup, n_tokens]
417
0
    cb(modalities, "modalities", il);
418
419
0
    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
420
0
    cb(all_coefs, "all_coefs", il);
421
    // first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor)
422
0
    all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);
423
424
    // permute to [n_altup, n_embd, n_tokens]
425
0
    ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
426
0
    ggml_tensor * predictions  = ggml_mul_mat(ctx0, cur_permuted, all_coefs);  // [n_altup, n_embd, n_tokens]
427
428
    // final shape must be the same as cur: [n_embd, n_tokens, n_altup]
429
0
    predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
430
0
    predictions = ggml_add(ctx0, predictions, cur);
431
0
    cb(predictions, "predictions", il);
432
433
0
    return predictions;
434
0
}
435
436
// input predictions       shape: [n_embd, n_tokens, n_altup]
437
// input activated         shape: [n_embd, n_tokens]
438
// output                  shape: [n_embd, n_tokens, n_altup]
439
0
ggml_tensor * llama_model_gemma3n::graph::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
440
0
    ggml_tensor * modalities = altup_compute_router_modalities(activated, il);  // [n_altup, n_tokens]
441
0
    cb(modalities, "modalities", il);
442
443
0
    ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act);
444
0
    ggml_tensor * innovation        = ggml_sub(ctx0, activated, active_prediction);  // [n_embd, n_tokens]
445
0
    cb(innovation, "innovation", il);
446
447
0
    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities);  // [n_altup, n_tokens]
448
0
    all_coefs               = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f);                    // + 1.0
449
0
    cb(all_coefs, "all_coefs", il);
450
0
    all_coefs = ggml_transpose(ctx0, all_coefs);                                               // [n_tokens, n_altup]
451
0
    all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup);                           // [1, n_tokens, n_altup]
452
453
0
    innovation              = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
454
0
    ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs);   // [n_embd, n_tokens, n_altup]
455
0
    corrected               = ggml_add(ctx0, corrected, predictions);  // [n_embd, n_tokens, n_altup]
456
0
    cb(corrected, "corrected", il);
457
458
0
    return corrected;
459
0
}