Coverage Report

Created: 2026-01-10 06:24

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/models/gemma3n-iswa.cpp
Line
Count
Source
1
#include "models.h"
2
3
// Builds the compute graph for Gemma 3n with interleaved sliding-window attention.
//
// Outline of the graph built below:
//   1. embed the input (scaled by sqrt(n_embd) only for token input) and compute per-layer inputs
//   2. expand the single hidden state into n_altup stacked streams ("altups") via model.altup_proj
//   3. per layer: altup_predict -> attention + laurel -> FFN -> altup_correct -> per-layer input gating
//   4. merge the altups back into one stream, then output norm, LM head and logit soft-capping
llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context(params),
    model(model),
    n_embd_head(model.hparams.n_embd_head_k),
    n_embd_altup(model.hparams.n_embd_altup),
    n_altup(model.hparams.n_altup),
    i_altup_act(model.hparams.i_altup_act) {
    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = build_inp_embd(model.tok_embd);

    // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
    inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
    cb(inpL, "inp_scaled", -1);

    // inp_pos - contains the positions
    ggml_tensor * inp_pos = build_inp_pos();

    // TODO: is causal == true correct? might need some changes
    auto * inp_attn = build_attn_inp_kv_iswa();

    // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
    ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());

    // inpL now has only 1 altup, project it to the rest of the altups
    // these "added" altups will be concat to the last dim of inpL
    {
        ggml_tensor * target_magnitude = calc_magnitude(inpL);  // [1, n_tokens]
        ggml_tensor * inp_repeated     = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
        ggml_tensor * altup_added =
            ggml_mul_mat(ctx0, model.altup_proj, inp_repeated);  // shape: [n_embd, n_tokens, n_altup - 1]
        ggml_tensor * new_magnitude = calc_magnitude(altup_added);
        // rescale each added altup so its per-token L2 norm matches the original embedding
        altup_added                 = ggml_div(ctx0, ggml_mul(ctx0, altup_added, target_magnitude), new_magnitude);
        inpL                        = ggml_concat(ctx0, inpL, altup_added, 2);  // shape: [n_embd, n_tokens, n_altup]
        cb(inpL, "inp_stacked", -1);
    }
    // inpL now has shape:          [n_embd,       n_tokens, n_altup]
    // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]

    for (int il = 0; il < n_layer; ++il) {
        // this block is made to closely resemble Gemma3p5DecoderLayer in the python code
        const float freq_base_l  = model.get_rope_freq_base(cparams, il);
        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);

        // reuse the function-scope `cur` instead of shadowing it with a new local
        cur                       = inpL;                    // [n_embd, n_tokens, n_altup]
        ggml_tensor * predictions = altup_predict(cur, il);  // [n_embd, n_tokens, n_altup]

        // predicted value will go through self-attention and laurel
        ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);  // [n_embd, n_tokens]
        cur                             = active_prediction;
        cb(cur, "active_prediction", il);

        // norm
        cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // laurel
        ggml_tensor * laurel_out = laurel(cur, il);  // [n_embd, n_tokens]

        // self-attention
        if (hparams.has_kv(il)) {
            // compute Q and K and RoPE them
            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);

            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
            cb(Kcur, "Kcur", il);

            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
            cb(Vcur, "Vcur", il);

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
            // V is RMS-normalized without a learned weight
            Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);

            cb(Qcur, "Qcur_normed", il);
            cb(Kcur, "Kcur_normed", il);
            cb(Vcur, "Vcur_normed", il);

            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                 ext_factor, attn_factor, beta_fast, beta_slow);

            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                 ext_factor, attn_factor, beta_fast, beta_slow);

            cb(Qcur, "Qcur_pos", il);
            cb(Kcur, "Kcur_pos", il);

            cur = build_attn(inp_attn, model.layers[il].wo,
                    NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr,
                    hparams.f_attention_scale, il);
        } else {
            // reuse KV cache of earlier layers
            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);
            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
            cb(Qcur, "Qcur_normed", il);

            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                 ext_factor, attn_factor, beta_fast, beta_slow);
            cb(Qcur, "Qcur_pos", il);

            cur = build_attn(inp_attn,
                    model.layers[il].wo, NULL,
                    Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
        }
        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

        cur = ggml_add(ctx0, cur, active_prediction);  // [n_embd, n_tokens]
        cb(cur, "attn_gated", il);

        ggml_tensor * attn_laurel = ggml_scale(ctx0, ggml_add(ctx0, cur, laurel_out),
                                               1.0f / sqrtf(2.0f));  // [n_embd, n_tokens]
        cb(attn_laurel, "attn_laurel", il);

        cur = build_norm(attn_laurel, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

        // feed-forward network
        {
            ggml_tensor * up_proj   = build_lora_mm(model.layers[il].ffn_up, cur);
            ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);

            if (il < n_layer_sparsity) {
                // apply activation sparsity
                gate_proj = gaussian_topk(gate_proj);
            }
            gate_proj = ggml_gelu(ctx0, gate_proj);

            cur = ggml_mul(ctx0, up_proj, gate_proj);
            cur = build_lora_mm(model.layers[il].ffn_down, cur);
            cb(cur, "ffn_out", il);
        }
        // pass the layer index (was -1), like every other per-layer norm in this graph,
        // so the norm's intermediate tensors are tagged with the layer they belong to
        cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_post_norm", il);

        ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel);  // [n_embd, n_tokens]
        cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);

        ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il);  // [n_embd, n_tokens, n_altup]

        ggml_tensor * first_prediction;                                                   // [n_embd, n_tokens]
        {
            first_prediction = view_2d_slice(corrected, i_altup_act);                     // [n_embd, n_tokens]
            first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
            first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
            first_prediction = ggml_gelu(ctx0, first_prediction);                 // [n_embd_altup, n_tokens]
            cb(first_prediction, "first_prediction_gated", il);
            ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il);      // [n_embd_altup, n_tokens]
            first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer);  // [n_embd_altup, n_tokens]
            cb(first_prediction, "first_prediction_scaled", il);

            first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction);  // [n_embd, n_tokens]
            first_prediction =
                build_norm(first_prediction, model.layers[il].per_layer_post_norm, NULL, LLM_NORM_RMS, il);
            cb(first_prediction, "first_prediction_out", il);
        }
        // equivalent to python code: corrected_predictions[1:] += first_prediction
        {
            ggml_tensor * slice_first = view_2d_slice(corrected, 0);
            // slices 1..n_altup-1, using the tensor's own row/slice strides
            ggml_tensor * slice_rest  = ggml_view_3d(ctx0, corrected, n_embd, n_tokens, n_altup - 1,
                                                     corrected->nb[1], corrected->nb[2], corrected->nb[2]);
            ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction);  // [n_embd, n_tokens, n_altup - 1]
            corrected         = ggml_concat(ctx0, slice_first, tmp, 2);        // [n_embd, n_tokens, n_altup]
        }
        cur = corrected;                                                       // [n_embd, n_tokens, n_altup]
        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);

        // input for next layer
        inpL = cur;
    }
    cur = inpL;  // [n_embd, n_tokens, n_altup]

    // cur now has multiple altup(s), we want to merge them back to 1 altup
    {
        ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act));  // [1, n_tokens]
        // do a view to skip the first slice (slice 0, treated here as the active altup)
        ggml_tensor * alt_slice =
            ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, cur->nb[1], cur->nb[2], cur->nb[2]);
        ggml_tensor * altup_unembd =
            ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice);  // shape: [n_embd, n_tokens, n_altup - 1]
        ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
        altup_unembd                = ggml_div(ctx0, ggml_mul(ctx0, altup_unembd, target_magnitude), new_magnitude);
        cb(altup_unembd, "altup_unembd", -1);

        // equivalent to torch.mean(hidden_states, dim=0)
        cur = view_2d_slice(cur, 0);  // [n_embd, n_tokens]
        for (int i = 0; i < n_altup - 1; ++i) {
            cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
        }
        cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup));  // [n_embd, n_tokens]
        cb(cur, "unembd_merged", -1);
    }
    // cur now has shape: [n_embd, n_tokens]

    // TODO: move this to right after the last KV layer
    {
        // skip computing output for unused tokens
        ggml_tensor * inp_out_ids = build_inp_out_ids();
        cur                       = ggml_get_rows(ctx0, cur, inp_out_ids);
    }
    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);

    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    cur = build_lora_mm(model.output, cur);

    {
        // final logit soft-capping: softcap * tanh(logits / softcap)
        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
        cur = ggml_tanh(ctx0, cur);
        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
    }
    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}
233
234
0
// L2 norm of every row of x along dim 0: sqrt(sum_i x[i]^2).
// The result keeps x's other dims and has ne[0] == 1 (ggml_sum_rows reduces dim 0).
ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) {
    ggml_tensor * squared = ggml_sqr(ctx0, x);
    ggml_tensor * sum_sq  = ggml_sum_rows(ctx0, squared);
    return ggml_sqrt(ctx0, sum_sq);
}
237
238
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
239
0
ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
240
0
    GGML_ASSERT(idx < (int) x->ne[2]);
241
0
    return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
242
0
                        idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
243
0
}
244
245
// equivalent to get_per_layer_inputs() in python code
246
// output shape: [n_embd_altup, n_layer, n_tokens]
247
0
ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
248
0
    auto inp = std::make_unique<llm_graph_input_embd>();
249
0
    ggml_tensor * inp_per_layer;
250
0
    if (ubatch.token) {
251
0
        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
252
0
        ggml_set_input(inp->tokens);
253
0
        res->t_tokens = inp->tokens;
254
0
        inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
255
0
        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
256
0
        inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
257
0
        cb(inp_per_layer, "inp_per_layer_selected", -1);
258
0
        res->add_input(std::move(inp));
259
0
    } else {
260
        // Vision embedding path: use padding token (ID=0) embedding
261
0
        const int64_t embd_size = model.tok_embd_per_layer->ne[0];  // n_embd_altup * n_layer
262
263
        // Extract and dequantize padding token embedding (column 0)
264
0
        ggml_tensor * padding_q = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
265
0
        ggml_tensor * padding_f32 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, embd_size);
266
0
        inp_per_layer = ggml_cpy(ctx0, padding_q, padding_f32);
267
268
        // Reshape to [n_embd_altup, n_layer, 1]
269
0
        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1);
270
0
        cb(inp_per_layer, "inp_per_layer_vision", -1);
271
0
    }
272
0
    return inp_per_layer;
273
0
}
274
275
// equivalent to project_per_layer_inputs() in python code
276
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
277
// output shape: [n_embd_altup, n_tokens, n_layer]
278
0
ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
279
0
    const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
280
0
    const float per_layer_input_scale      = 1.0f / sqrtf(2.0f);
281
282
0
    ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
283
0
    per_layer_proj               = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
284
0
    per_layer_proj               = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
285
0
    per_layer_proj               = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS,
286
0
                                              -1);  // [n_embd_altup, n_layer, n_tokens]
287
0
    cb(per_layer_proj, "per_layer_proj", -1);
288
289
0
    inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer);
290
0
    inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
291
0
    cb(inp_per_layer, "inp_per_layer", -1);
292
293
    // permute to shape: [n_embd_altup, n_tokens, n_layer]
294
0
    inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
295
0
    return inp_per_layer;
296
0
}
297
298
// input cur shape: [n_altup, n_tokens]
299
// output    shape: [n_altup, n_tokens]
300
0
ggml_tensor * llm_build_gemma3n_iswa::laurel(ggml_tensor * cur, int il) {
301
0
    ggml_tensor * tmp = cur;
302
0
    tmp               = build_lora_mm(model.layers[il].laurel_l, tmp);
303
0
    tmp               = build_lora_mm(model.layers[il].laurel_r, tmp);
304
0
    tmp               = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
305
0
    tmp               = ggml_add(ctx0, tmp, cur);
306
0
    cb(tmp, "laurel_out", il);
307
0
    return tmp;
308
0
}
309
310
// input x shape: [n_embd, n_tokens]
311
// output  shape: [n_embd, n_tokens]
312
0
ggml_tensor * llm_build_gemma3n_iswa::gaussian_topk(ggml_tensor * x) {
313
0
    ggml_tensor * mean = ggml_mean(ctx0, x);
314
0
    ggml_tensor * std  = ggml_sqrt(ctx0, ggml_scale(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
315
0
                                                    1.0f / (float) (x->ne[0] - 1)));
316
0
    ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
317
0
    return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
318
0
}
319
320
//
321
// altup functions
322
//
323
324
// equivalent to compute_router_modalities() in python code
325
// input x shape: [n_embd,  n_tokens]
326
// output  shape: [n_altup, n_tokens]
327
0
ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tensor * x, int il) {
328
0
    ggml_tensor * router_inputs = build_norm(x, model.layers[il].altup_router_norm, NULL, LLM_NORM_RMS, il);
329
330
    // router_input_scale
331
0
    router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float) n_embd);
332
333
0
    ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
334
0
    return ggml_tanh(ctx0, output);  // [n_altup, n_tokens]
335
0
}
336
337
// input cur shape: [n_embd, n_tokens, n_altup]
338
// output    shape: [n_embd, n_tokens, n_altup]
339
0
ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) {
340
0
    ggml_tensor * activated  = view_2d_slice(cur, i_altup_act);                 // [n_embd, n_tokens]
341
0
    ggml_tensor * modalities = altup_compute_router_modalities(activated, il);  // [n_altup, n_tokens]
342
0
    cb(modalities, "modalities", il);
343
344
0
    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
345
0
    cb(all_coefs, "all_coefs", il);
346
    // first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor)
347
0
    all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);
348
349
    // permute to [n_altup, n_embd, n_tokens]
350
0
    ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
351
0
    ggml_tensor * predictions  = ggml_mul_mat(ctx0, cur_permuted, all_coefs);  // [n_altup, n_embd, n_tokens]
352
353
    // final shape must be the same as cur: [n_embd, n_tokens, n_altup]
354
0
    predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
355
0
    predictions = ggml_add(ctx0, predictions, cur);
356
0
    cb(predictions, "predictions", il);
357
358
0
    return predictions;
359
0
}
360
361
// input predictions       shape: [n_embd, n_tokens, n_altup]
362
// input activated         shape: [n_embd, n_tokens]
363
// output                  shape: [n_embd, n_tokens, n_altup]
364
0
ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
365
0
    ggml_tensor * modalities = altup_compute_router_modalities(activated, il);  // [n_altup, n_tokens]
366
0
    cb(modalities, "modalities", il);
367
368
0
    ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
369
0
    ggml_tensor * innovation        = ggml_sub(ctx0, activated, active_prediction);  // [n_embd, n_tokens]
370
0
    cb(innovation, "innovation", il);
371
372
0
    ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities);  // [n_altup, n_tokens]
373
0
    all_coefs               = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f);                    // + 1.0
374
0
    cb(all_coefs, "all_coefs", il);
375
0
    all_coefs = ggml_transpose(ctx0, all_coefs);                                               // [n_tokens, n_altup]
376
0
    all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup);                           // [1, n_tokens, n_altup]
377
378
0
    innovation              = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
379
0
    ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs);   // [n_embd, n_tokens, n_altup]
380
0
    corrected               = ggml_add(ctx0, corrected, predictions);  // [n_embd, n_tokens, n_altup]
381
0
    cb(corrected, "corrected", il);
382
383
0
    return corrected;
384
0
}