/src/llama.cpp/src/models/qwen3next.cpp
Line | Count | Source |
1 | | #include "ggml.h" |
2 | | #include "models.h" |
3 | | |
4 | 0 | #define CHUNK_SIZE 64 |
5 | | |
6 | | llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) : |
7 | 0 | llm_graph_context_mamba(params), model(model) { |
8 | 0 | ggml_tensor * cur; |
9 | 0 | ggml_tensor * inpL; |
10 | |
11 | 0 | inpL = build_inp_embd(model.tok_embd); |
12 | 0 | cb(inpL, "model.embed_tokens", -1); |
13 | |
14 | 0 | auto * inp = build_inp_mem_hybrid(); |
15 | |
16 | 0 | ggml_tensor * inp_pos = build_inp_pos(); |
17 | 0 | ggml_tensor * inp_out_ids = build_inp_out_ids(); |
18 | |
19 | 0 | ggml_tensor * causal_mask = |
20 | 0 | ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), |
21 | 0 | GGML_TRI_TYPE_LOWER); |
22 | |
23 | 0 | ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); |
24 | 0 | ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); |
25 | |
26 | 0 | ggml_build_forward_expand(gf, causal_mask); |
27 | 0 | ggml_build_forward_expand(gf, identity); |
28 | 0 | ggml_build_forward_expand(gf, diag_mask); |
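// Illustrative note (a sketch, assuming GGML_TRI_TYPE_LOWER keeps only the entries strictly
// below the diagonal): for a hypothetical 4x4 chunk the three helper masks would be
//
//   causal_mask        identity           diag_mask = causal_mask + identity
//   0 0 0 0            1 0 0 0            1 0 0 0
//   1 0 0 0            0 1 0 0            1 1 0 0
//   1 1 0 0            0 0 1 0            1 1 1 0
//   1 1 1 0            0 0 0 1            1 1 1 1
//
// causal_mask selects strictly-earlier positions inside a chunk, while diag_mask also keeps
// the current position; they are built once here and reused by every recurrent layer below.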
29 | |
30 | 0 | for (int il = 0; il < n_layer; ++il) { |
31 | 0 | ggml_tensor * inpSA = inpL; |
32 | |
33 | 0 | cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); |
34 | 0 | cb(cur, "attn_norm", il); |
35 | | |
36 | | // Determine layer type and build appropriate attention mechanism |
37 | 0 | if (hparams.is_recurrent(il)) { |
38 | | // Linear attention layer (gated delta net) |
39 | 0 | cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); |
40 | 0 | } else { |
41 | | // Full attention layer |
42 | 0 | cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il); |
43 | 0 | } |
44 | |
45 | 0 | if (il == n_layer - 1 && inp_out_ids) { |
46 | 0 | cur = ggml_get_rows(ctx0, cur, inp_out_ids); |
47 | 0 | inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); |
48 | 0 | } |
49 | | |
50 | | // Residual connection |
51 | 0 | cur = ggml_add(ctx0, cur, inpSA); |
52 | 0 | cb(cur, "attn_residual", il); |
53 | | |
54 | | // Save the tensor before post-attention norm for residual connection |
55 | 0 | ggml_tensor * ffn_residual = cur; |
56 | | |
57 | | // Post-attention norm |
58 | 0 | ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); |
59 | 0 | cb(attn_post_norm, "attn_post_norm", il); |
60 | | |
61 | | // FFN layer (MoE or dense) - without residual connection |
62 | 0 | cur = build_layer_ffn(attn_post_norm, il); |
63 | 0 | cb(cur, "ffn_out", il); |
64 | | |
65 | | // Residual connection for FFN - add to the tensor from before post_attention_layernorm |
66 | 0 | cur = ggml_add(ctx0, cur, ffn_residual); |
67 | 0 | cb(cur, "post_moe", il); |
68 | | |
69 | | // Input for next layer |
70 | 0 | inpL = cur; |
71 | 0 | } |
72 | 0 | cur = inpL; |
73 | | |
74 | | // Final norm |
75 | 0 | cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); |
76 | |
77 | 0 | cb(cur, "result_norm", -1); |
78 | 0 | res->t_embd = cur; |
79 | | |
80 | | // LM head |
81 | 0 | cur = build_lora_mm(model.output, cur); |
82 | |
83 | 0 | cb(cur, "result_output", -1); |
84 | 0 | res->t_logits = cur; |
85 | |
86 | 0 | ggml_build_forward_expand(gf, cur); |
87 | 0 | } |
88 | | |
89 | | ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( |
90 | | ggml_tensor * q, |
91 | | ggml_tensor * k, |
92 | | ggml_tensor * v, |
93 | | ggml_tensor * g, |
94 | | ggml_tensor * beta, |
95 | | ggml_tensor * state, |
96 | | ggml_tensor * causal_mask, |
97 | | ggml_tensor * identity, |
98 | | ggml_tensor * diag_mask, |
99 | 0 | int il) { |
100 | 0 | const int64_t S_k = q->ne[0]; |
101 | 0 | const int64_t H_k = q->ne[1]; |
102 | 0 | const int64_t n_tokens = q->ne[2]; |
103 | 0 | const int64_t n_seqs = q->ne[3]; |
104 | |
105 | 0 | const int64_t S_v = v->ne[0]; |
106 | 0 | const int64_t H_v = v->ne[1]; |
107 | |
108 | 0 | GGML_ASSERT(v->ne[2] == n_tokens); |
109 | 0 | GGML_ASSERT(k->ne[2] == n_tokens); |
110 | 0 | GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); |
111 | 0 | GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); |
112 | 0 | GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); |
113 | |
114 | 0 | GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); |
115 | 0 | GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); |
116 | |
117 | 0 | GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case |
118 | |
119 | 0 | const float eps_norm = hparams.f_norm_rms_eps; |
120 | |
121 | 0 | q = ggml_l2_norm(ctx0, q, eps_norm); |
122 | 0 | k = ggml_l2_norm(ctx0, k, eps_norm); |
123 | |
124 | 0 | const float scale = 1.0f / sqrtf(S_v); |
125 | |
126 | 0 | q = ggml_scale(ctx0, q, scale); |
127 | |
128 | 0 | beta = ggml_sigmoid(ctx0, beta); |
129 | |
130 | 0 | cb(q, "q_in", il); |
131 | 0 | cb(k, "k_in", il); |
132 | 0 | cb(v, "v_in", il); |
133 | 0 | cb(beta, "beta_in", il); |
134 | 0 | cb(g, "g_in", il); |
135 | |
136 | 0 | q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); |
137 | 0 | k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); |
138 | 0 | v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); |
139 | 0 | g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); |
140 | |
141 | 0 | beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); |
142 | 0 | state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); |
143 | |
144 | 0 | cb(q, "q_perm", il); |
145 | 0 | cb(k, "k_perm", il); |
146 | 0 | cb(v, "v_perm", il); |
147 | 0 | cb(beta, "beta_perm", il); |
148 | 0 | cb(g, "g_perm", il); |
149 | 0 | cb(state, "state_in", il); |
150 | |
151 | 0 | GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); |
152 | 0 | GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); |
153 | 0 | GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); |
154 | 0 | GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); |
155 | | |
156 | | // Pad the token dimension up to a multiple of the chunk size |
157 | 0 | const int64_t chunk_size = CHUNK_SIZE; |
158 | |
159 | 0 | const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; |
160 | 0 | const int64_t n_chunks = (n_tokens + pad) / chunk_size; |
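// Worked example (illustrative): with CHUNK_SIZE = 64 and n_tokens = 100,
// pad = (64 - 100 % 64) % 64 = 28 and n_chunks = (100 + 28) / 64 = 2; if n_tokens is already
// a multiple of 64 (e.g. 128), pad = 0 and n_chunks = 2. The padded tail is discarded again
// when output_tokens is viewed back down to n_tokens at the end of this function.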
161 | |
162 | 0 | q = ggml_pad(ctx0, q, 0, pad, 0, 0); |
163 | 0 | k = ggml_pad(ctx0, k, 0, pad, 0, 0); |
164 | 0 | v = ggml_pad(ctx0, v, 0, pad, 0, 0); |
165 | 0 | g = ggml_pad(ctx0, g, pad, 0, 0, 0); |
166 | 0 | beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); |
167 | |
168 | 0 | cb(q, "q_pad", il); |
169 | 0 | cb(k, "k_pad", il); |
170 | 0 | cb(v, "v_pad", il); |
171 | 0 | cb(beta, "beta_pad", il); |
172 | 0 | cb(g, "g_pad", il); |
173 | |
174 | 0 | ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); |
175 | 0 | ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); |
176 | |
177 | 0 | cb(v_beta, "v_beta", il); |
178 | 0 | cb(k_beta, "k_beta", il); |
179 | |
180 | 0 | q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); |
181 | 0 | k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); |
182 | 0 | k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); |
183 | 0 | v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); |
184 | 0 | v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); |
185 | |
186 | 0 | g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); |
187 | 0 | beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); |
188 | |
189 | 0 | ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); |
190 | |
191 | 0 | cb(g_cumsum, "g_cumsum", il); |
192 | |
193 | 0 | ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); |
194 | 0 | ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); |
195 | |
196 | 0 | ggml_tensor * gcs_j_broadcast = |
197 | 0 | ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); |
198 | |
199 | 0 | ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); |
200 | |
201 | 0 | cb(decay_mask, "decay_mask", il); |
202 | |
203 | 0 | decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); |
204 | 0 | decay_mask = ggml_exp(ctx0, decay_mask); |
205 | 0 | decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); |
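// Illustrative note: after these two masking steps, the surviving entry that pairs a later
// in-chunk position t with an earlier-or-equal position s holds exp(g_cum[t] - g_cum[s]),
// the cumulative gating decay between s and t. Multiplying by diag_mask both before and
// after the exp matters: without the second multiply, the masked "future" entries would be
// exp(0) = 1 instead of 0.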
206 | |
207 | 0 | ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); |
208 | |
209 | 0 | ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); |
210 | 0 | ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); |
211 | |
212 | 0 | cb(attn, "attn_pre_solve", il); |
213 | |
214 | 0 | ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); |
215 | 0 | ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); |
216 | |
217 | 0 | ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); |
218 | 0 | attn = ggml_mul(ctx0, lin_solve, causal_mask); |
219 | 0 | attn = ggml_add(ctx0, attn, identity); |
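// Sketch of the math (illustrative, at the matrix level, ignoring ggml's dimension order):
// let L be the strictly lower triangular part of ((diag(beta) K) K^T element-wise multiplied
// by decay_mask). The steps above appear to compute
//   attn = (I + L)^{-1}
// using a triangular solve instead of an explicit inverse: solving (I + L) X = -L gives
// X = (I + L)^{-1} - I, and adding the identity back yields (I + L)^{-1}. This is the usual
// per-chunk transform of chunked delta-rule formulations.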
220 | |
221 | 0 | cb(attn, "attn_solved", il); |
222 | |
223 | 0 | v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); |
224 | |
225 | 0 | ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); |
226 | 0 | ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); |
227 | |
228 | 0 | ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); |
229 | |
230 | 0 | cb(kbeta_gexp, "kbeta_gexp", il); |
231 | |
232 | 0 | ggml_tensor * k_cumdecay = |
233 | 0 | ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); |
234 | |
235 | 0 | cb(k_cumdecay, "k_cumdecay", il); |
236 | |
237 | 0 | ggml_tensor * core_attn_out = nullptr; |
238 | 0 | ggml_tensor * new_state = ggml_dup(ctx0, state); |
239 | |
240 | 0 | cb(new_state, "new_state", il); |
241 | |
242 | 0 | for (int64_t chunk = 0; chunk < n_chunks; chunk++) { |
243 | 0 | auto chunkify = [=](ggml_tensor * t) { |
244 | 0 | return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3], |
245 | 0 | t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); |
246 | 0 | }; |
247 | |
248 | 0 | auto chunkify_g = [=](ggml_tensor * t) { |
249 | 0 | return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, t->ne[1], 1, t->ne[3], |
250 | 0 | t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); |
251 | 0 | }; |
252 | |
253 | 0 | ggml_tensor * k_chunk = chunkify(k); |
254 | 0 | ggml_tensor * q_chunk = chunkify(q); |
255 | 0 | ggml_tensor * v_chunk = chunkify(v); |
256 | |
257 | 0 | ggml_tensor * g_cs_chunk = chunkify_g(g_cumsum); |
258 | 0 | ggml_tensor * g_cs_chunk_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cs_chunk)); |
259 | |
260 | 0 | ggml_tensor * decay_mask_chunk = chunkify(decay_mask); |
261 | 0 | ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay); |
262 | |
263 | 0 | ggml_tensor * gexp_chunk = ggml_exp(ctx0, g_cs_chunk_t); |
264 | | |
265 | | // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) |
266 | 0 | attn = ggml_mul_mat(ctx0, k_chunk, q_chunk); |
267 | 0 | attn = ggml_mul(ctx0, attn, decay_mask_chunk); |
268 | 0 | attn = ggml_mul(ctx0, attn, diag_mask); |
269 | |
270 | 0 | ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); |
271 | | |
272 | | // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state |
273 | 0 | ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); |
274 | | |
275 | | // v_new = v_i - v_prime |
276 | 0 | ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); |
277 | 0 | ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); |
278 | | |
279 | | // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state |
280 | 0 | ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); |
281 | 0 | ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); |
282 | | |
283 | | // core_attn_out[:, :, i] = attn_inter + attn @ v_new |
284 | 0 | ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn); |
285 | |
286 | 0 | ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); |
287 | |
288 | 0 | core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1); |
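// Illustrative note: mirroring the reference comments above, each chunk's output has two
// parts: attn_inter propagates the state accumulated before this chunk (queries scaled by
// exp(g_cum)), and attn @ v_new adds the causal within-chunk attention over the
// delta-corrected values. Concatenating along dim 1 rebuilds the padded
// chunk_size * n_chunks token sequence, which is un-padded again after the loop.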
289 | | |
290 | | // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) |
291 | | // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() |
292 | | // key_gdiff = key * g_diff.unsqueeze(-1) |
293 | | // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new |
294 | | // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew |
295 | |
296 | 0 | ggml_tensor * g_cum_last = |
297 | 0 | ggml_cont(ctx0, ggml_view_4d(ctx0, g_cs_chunk_t, g_cs_chunk_t->ne[0], 1, g_cs_chunk_t->ne[2], g_cs_chunk_t->ne[3], |
298 | 0 | g_cs_chunk_t->nb[1], g_cs_chunk_t->nb[2], g_cs_chunk_t->nb[3], |
299 | 0 | g_cs_chunk_t->nb[0] * (g_cs_chunk_t->ne[1] - 1))); |
300 | |
301 | 0 | ggml_tensor * gexp_last = |
302 | 0 | ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]); |
303 | |
304 | 0 | ggml_tensor * g_cum_last_3d = |
305 | 0 | ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]); |
306 | |
307 | 0 | ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cs_chunk, g_cs_chunk->ne[0], g_cs_chunk->ne[2], g_cs_chunk->ne[3]); |
308 | |
309 | 0 | ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d)); |
310 | |
311 | 0 | ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); |
312 | |
313 | 0 | ggml_tensor * key_gdiff = ggml_mul(ctx0, k_chunk, |
314 | 0 | ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], |
315 | 0 | g_diff_exp->ne[2] * g_diff_exp->ne[3])); |
316 | |
317 | 0 | ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff))); |
318 | |
319 | 0 | new_state = ggml_add(ctx0, |
320 | 0 | ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last, gexp_last->ne[0], gexp_last->ne[1], H_v, n_seqs)), |
321 | 0 | ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); |
322 | 0 | } |
323 | |
324 | 0 | core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs); |
325 | |
326 | 0 | ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0); |
327 | 0 | cb(output_tokens, "output_tokens", il); |
328 | | |
329 | | // flatten output |
330 | 0 | ggml_tensor * flat_output = |
331 | 0 | ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs); |
332 | |
333 | 0 | ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs); |
334 | |
335 | 0 | return ggml_concat(ctx0, flat_output, flat_state, 0); |
336 | 0 | } |
337 | | |
338 | | ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive( |
339 | | ggml_tensor * q, |
340 | | ggml_tensor * k, |
341 | | ggml_tensor * v, |
342 | | ggml_tensor * g, |
343 | | ggml_tensor * beta, |
344 | | ggml_tensor * state, |
345 | 0 | int il) { |
346 | 0 | const int64_t S_k = q->ne[0]; |
347 | 0 | const int64_t H_k = q->ne[1]; |
348 | 0 | const int64_t n_tokens = q->ne[2]; |
349 | 0 | const int64_t n_seqs = q->ne[3]; |
350 | |
351 | 0 | const int64_t S_v = v->ne[0]; |
352 | 0 | const int64_t H_v = v->ne[1]; |
353 | |
354 | 0 | GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing |
355 | 0 | GGML_ASSERT(v->ne[2] == n_tokens); |
356 | 0 | GGML_ASSERT(k->ne[2] == n_tokens); |
357 | 0 | GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); |
358 | 0 | GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); |
359 | 0 | GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); |
360 | |
361 | 0 | GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); |
362 | 0 | GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); |
363 | |
364 | 0 | GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case |
365 | |
366 | 0 | const float eps_norm = hparams.f_norm_rms_eps; |
367 | |
368 | 0 | q = ggml_l2_norm(ctx0, q, eps_norm); |
369 | 0 | k = ggml_l2_norm(ctx0, k, eps_norm); |
370 | |
371 | 0 | const float scale = 1.0f / sqrtf(S_v); |
372 | |
373 | 0 | q = ggml_scale(ctx0, q, scale); |
374 | 0 | beta = ggml_sigmoid(ctx0, beta); |
375 | |
376 | 0 | cb(q, "q_in", il); |
377 | 0 | cb(k, "k_in", il); |
378 | 0 | cb(v, "v_in", il); |
379 | 0 | cb(beta, "beta_in", il); |
380 | 0 | cb(g, "g_in", il); |
381 | |
382 | 0 | state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); |
383 | |
384 | 0 | ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); |
385 | 0 | ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); |
386 | | |
387 | | // Apply exponential to g_t |
388 | 0 | g_t = ggml_exp(ctx0, g_t); |
389 | | |
390 | | // Apply the gated delta rule for the single timestep |
391 | | // last_recurrent_state = last_recurrent_state * g_t |
392 | 0 | state = ggml_mul(ctx0, state, g_t); |
393 | | |
394 | | // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) |
395 | 0 | ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); |
396 | 0 | ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); |
397 | | // we need to sum over dim=-2, so we transpose, sum, then transpose again |
398 | 0 | kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); |
399 | | |
400 | | // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) |
401 | 0 | ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); |
402 | | // delta = (v_t - kv_mem) * beta_t |
403 | 0 | ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] |
404 | 0 | ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); |
405 | | |
406 | | // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta |
407 | 0 | ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); |
408 | 0 | state = ggml_add(ctx0, state, k_t_delta); |
409 | | |
410 | | // Compute the attention output |
411 | | // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) |
412 | 0 | ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t |
413 | 0 | ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); |
414 | | // again, since it's over dim = -2, transpose, sum, transpose back |
415 | 0 | ggml_tensor * core_attn_out = |
416 | 0 | ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); |
417 | | |
418 | | // core_attn_out should be [S_v, 1, H_v, n_seqs] after this |
419 | 0 | cb(core_attn_out, "output_tokens", il); |
420 | 0 | cb(state, "new_state", il); |
421 | | |
422 | | // Flatten the output. No permute is needed: with n_tokens == 1, [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] have the same memory layout |
423 | 0 | ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs); |
424 | 0 | ggml_tensor * flat_state = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs); |
425 | |
426 | 0 | return ggml_concat(ctx0, flat_output, flat_state, 0); |
427 | 0 | } |
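// ---------------------------------------------------------------------------------------
// Illustrative reference (a sketch, not part of the original file): a minimal single-head,
// scalar version of the gated delta rule that the two builders above implement (token by
// token in build_delta_net_autoregressive, in chunked form in build_delta_net_chunking).
// It assumes q has already been L2-normalized and scaled, beta already passed through the
// sigmoid and g given in log space, matching the preprocessing done above, and it
// simplifies the head dimensions to a single size D (S_k == S_v == D).

#include <math.h>

static void gated_delta_rule_reference(
        int T, int D,             // number of tokens and head dimension
        const float * q,          // [T][D] queries
        const float * k,          // [T][D] keys
        const float * v,          // [T][D] values
        const float * g,          // [T]    per-token log decay (<= 0)
        const float * beta,       // [T]    per-token update strength in (0, 1)
        float       * S,          // [D][D] recurrent state; S[i][j]: value dim i, key dim j
        float       * out) {      // [T][D] attention output
    for (int t = 0; t < T; ++t) {
        // last_recurrent_state = last_recurrent_state * g_t.exp()
        const float gt = expf(g[t]);
        for (int n = 0; n < D * D; ++n) {
            S[n] *= gt;
        }
        // kv_mem = S k_t;  delta = beta_t * (v_t - kv_mem);  S += delta k_t^T
        for (int i = 0; i < D; ++i) {
            float kv = 0.0f;
            for (int j = 0; j < D; ++j) {
                kv += S[i*D + j] * k[t*D + j];
            }
            const float delta = beta[t] * (v[t*D + i] - kv);
            for (int j = 0; j < D; ++j) {
                S[i*D + j] += delta * k[t*D + j];
            }
        }
        // o_t = S q_t
        for (int i = 0; i < D; ++i) {
            float o = 0.0f;
            for (int j = 0; j < D; ++j) {
                o += S[i*D + j] * q[t*D + j];
            }
            out[t*D + i] = o;
        }
    }
}
// ---------------------------------------------------------------------------------------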
428 | | |
429 | | ggml_tensor * llm_build_qwen3next::build_norm_gated( |
430 | | ggml_tensor * input, |
431 | | ggml_tensor * weights, |
432 | | ggml_tensor * gate, |
433 | 0 | int layer) { |
434 | 0 | ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer); |
435 | 0 | ggml_tensor * gated_silu = ggml_silu(ctx0, gate); |
436 | |
437 | 0 | return ggml_mul(ctx0, normalized, gated_silu); |
438 | 0 | } |
439 | | |
440 | | ggml_tensor * llm_build_qwen3next::build_layer_attn( |
441 | | llm_graph_input_attn_kv * inp, |
442 | | ggml_tensor * cur, |
443 | | ggml_tensor * inp_pos, |
444 | 0 | int il) { |
445 | 0 | const int64_t n_embd_head = hparams.n_embd_head_v; |
446 | 0 | GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); |
447 | | |
448 | | // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention |
449 | | |
450 | | // Qwen3Next uses a single Q projection that outputs query + gate |
451 | 0 | ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); |
452 | 0 | cb(Qcur_full, "Qcur_full", il); |
453 | |
454 | 0 | Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1); |
455 | | |
456 | | // Split Q projection into query and gate |
457 | | // The split should be along dimension 0 (the feature dimension) |
458 | 0 | ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, |
459 | 0 | Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0); |
460 | 0 | ggml_tensor * gate = |
461 | 0 | ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, |
462 | 0 | Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full)); |
463 | 0 | cb(Qcur, "Qcur", il); |
464 | 0 | cb(gate, "gate", il); |
465 | | |
466 | | // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention |
467 | 0 | Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); |
468 | 0 | cb(Qcur, "Qcur_reshaped", il); |
469 | | |
470 | | // Apply Q normalization |
471 | 0 | Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); |
472 | 0 | cb(Qcur, "Qcur_normed", il); |
473 | |
474 | 0 | ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); |
475 | 0 | cb(Kcur, "Kcur", il); |
476 | |
477 | 0 | ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); |
478 | 0 | cb(Vcur, "Vcur", il); |
479 | | |
480 | | // Apply K normalization |
481 | 0 | Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); |
482 | 0 | Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); |
483 | 0 | cb(Kcur, "Kcur_normed", il); |
484 | | |
485 | | // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads) |
486 | 0 | gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); |
487 | 0 | cb(gate, "gate_reshaped", il); |
488 | |
489 | 0 | Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); |
490 | | |
491 | | // Apply RoPE |
492 | 0 | Qcur = ggml_rope_ext( |
493 | 0 | ctx0, Qcur, inp_pos, nullptr, |
494 | 0 | n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, |
495 | 0 | ext_factor, attn_factor, beta_fast, beta_slow); |
496 | |
497 | 0 | Kcur = ggml_rope_ext( |
498 | 0 | ctx0, Kcur, inp_pos, nullptr, |
499 | 0 | n_rot, rope_type, n_ctx_orig, freq_base, |
500 | 0 | freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); |
501 | |
502 | 0 | cb(Qcur, "Qcur", il); |
503 | 0 | cb(Kcur, "Kcur", il); |
504 | 0 | cb(Vcur, "Vcur", il); |
505 | | |
506 | | // Attention computation |
507 | 0 | const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; |
508 | |
509 | 0 | cur = build_attn(inp, |
510 | 0 | nullptr, nullptr, |
511 | 0 | Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); |
512 | 0 | cb(cur, "attn_pregate", il); |
513 | |
514 | 0 | ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); |
515 | 0 | cb(gate_sigmoid, "gate_sigmoid", il); |
516 | |
517 | 0 | cur = ggml_mul(ctx0, cur, gate_sigmoid); |
518 | 0 | cb(cur, "attn_gated", il); |
519 | |
520 | 0 | cur = build_lora_mm(model.layers[il].wo, cur); |
521 | 0 | cb(cur, "attn_output", il); |
522 | |
523 | 0 | return cur; |
524 | 0 | } |
525 | | |
526 | | ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( |
527 | | llm_graph_input_rs * inp, |
528 | | ggml_tensor * cur, |
529 | | ggml_tensor * causal_mask, |
530 | | ggml_tensor * identity, |
531 | | ggml_tensor * diag_mask, |
532 | 0 | int il) { |
533 | 0 | const auto * mctx_cur = inp->mctx; |
534 | |
535 | 0 | const int64_t d_inner = hparams.ssm_d_inner; |
536 | 0 | const int64_t n_seqs = ubatch.n_seqs; |
537 | 0 | const int64_t head_k_dim = hparams.ssm_d_state; |
538 | 0 | const int64_t num_k_heads = hparams.ssm_n_group; |
539 | 0 | const int64_t num_v_heads = hparams.ssm_dt_rank; |
540 | 0 | const int64_t head_v_dim = d_inner / num_v_heads; |
541 | 0 | const int64_t n_seq_tokens = ubatch.n_seq_tokens; |
542 | |
543 | 0 | const auto kv_head = mctx_cur->get_head(); |
544 | |
545 | 0 | GGML_ASSERT(n_seqs != 0); |
546 | 0 | GGML_ASSERT(ubatch.equal_seqs()); |
547 | 0 | GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); |
548 | | |
549 | | // Input projections |
550 | 0 | ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur); |
551 | 0 | cb(mixed_qkvz, "linear_attn_mixed_qkvz", il); |
552 | |
553 | 0 | ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur); |
554 | 0 | cb(mixed_ba, "linear_attn_mixed_ba", il); |
555 | |
556 | 0 | int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads); |
557 | 0 | ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs); |
558 | | |
559 | | // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads] |
560 | 0 | int64_t ba_new_dim = 2 * num_v_heads / num_k_heads; |
561 | 0 | ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs); |
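// Worked example (hypothetical sizes, for illustration only): with head_k_dim = 128,
// num_k_heads = 16, head_v_dim = 128 and num_v_heads = 32, each key-head slot packs
// qkvz_new_dim = 2*128 + 2*128*(32/16) = 768 values (q, k, then the v and z groups for the
// two value heads sharing that key head), and ba_new_dim = 2*32/16 = 4 values (their two
// betas followed by their two alphas).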
562 | | |
563 | | // Split mixed_ba into b and a (beta and alpha parameters) |
564 | 0 | int64_t split_sizes_ba[2] = { |
565 | 0 | num_v_heads / num_k_heads, // beta size |
566 | 0 | num_v_heads / num_k_heads // alpha size |
567 | 0 | }; |
568 | |
569 | 0 | ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_seq_tokens, n_seqs, |
570 | 0 | mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0); |
571 | 0 | cb(b, "b", il); |
572 | |
573 | 0 | ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_seq_tokens, n_seqs, |
574 | 0 | mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], |
575 | 0 | split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped)); |
576 | 0 | cb(a, "a", il); |
577 | | |
578 | | // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads] |
579 | 0 | ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs); |
580 | 0 | ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs); |
581 | |
582 | 0 | ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); |
583 | 0 | ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); |
584 | 0 | cb(alpha_softplus, "a_softplus", il); |
585 | 0 | ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus |
586 | 0 | cb(gate, "gate", il); |
587 | | |
588 | | // Split mixed_qkvz into query, key, value, z |
589 | 0 | int64_t split_sizes_qkvz[4] = { |
590 | 0 | head_k_dim, // query size |
591 | 0 | head_k_dim, // key size |
592 | 0 | head_v_dim * num_v_heads / num_k_heads, // value size |
593 | 0 | head_v_dim * num_v_heads / num_k_heads // z size |
594 | 0 | }; |
595 | |
596 | 0 | ggml_tensor * query = |
597 | 0 | ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs, |
598 | 0 | mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0); |
599 | 0 | cb(query, "q", il); |
600 | |
601 | 0 | ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs, |
602 | 0 | mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], |
603 | 0 | split_sizes_qkvz[0] * sizeof(float)); |
604 | 0 | cb(key, "k", il); |
605 | |
606 | 0 | ggml_tensor * value = |
607 | 0 | ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs, |
608 | 0 | mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], |
609 | 0 | (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)); |
610 | 0 | cb(value, "v", il); |
611 | |
612 | 0 | ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs, |
613 | 0 | mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], |
614 | 0 | (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)); |
615 | 0 | cb(z, "z", il); |
616 | | |
617 | | // After creating query, key, and value, reshape each to flatten the head dimensions |
618 | | // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] |
619 | 0 | ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); |
620 | 0 | cb(query_flat, "query_flat", il); |
621 | | |
622 | | // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] |
623 | 0 | ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); |
624 | 0 | cb(key_flat, "key_flat", il); |
625 | | |
626 | | // value: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs] |
627 | 0 | ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); |
628 | 0 | cb(value_flat, "value_flat", il); |
629 | | |
630 | | // Get convolution states from cache |
631 | 0 | ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); |
632 | 0 | ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); |
633 | | |
634 | | // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); |
635 | | |
636 | | // Build the convolution states tensor |
637 | 0 | ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); |
638 | 0 | cb(conv_states, "conv_states", il); |
639 | | |
640 | | // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs] |
641 | 0 | ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0); |
642 | 0 | qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0); |
643 | 0 | cb(qkv_mixed, "qkv_mixed", il); |
644 | |
645 | 0 | qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); |
646 | 0 | cb(qkv_mixed, "qkv_mixed_permuted", il); |
647 | | |
648 | | // Calculate the total conv dimension |
649 | 0 | int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; |
650 | | |
651 | | // Calculate convolution kernel size |
652 | 0 | ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; |
653 | 0 | const int64_t conv_kernel_size = conv_kernel->ne[0]; |
654 | 0 | const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; |
655 | 0 | conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); |
656 | 0 | cb(conv_states, "conv_states_reshaped", il); |
657 | |
658 | 0 | ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); |
659 | 0 | cb(conv_input, "conv_input", il); |
660 | | |
661 | | // Update convolution state cache |
662 | | // Extract the last (conv_kernel_size - 1) states from conv_input |
663 | 0 | ggml_tensor * last_conv_states = |
664 | 0 | ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], |
665 | 0 | conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); |
666 | 0 | cb(last_conv_states, "last_conv_states", il); |
667 | |
668 | 0 | ggml_tensor * state_update_target = |
669 | 0 | ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, |
670 | 0 | kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); |
671 | 0 | cb(state_update_target, "state_update_target", il); |
672 | |
673 | 0 | ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); |
674 | 0 | cb(conv_states_all, "conv_states_updated", il); |
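// Illustrative note: the convolution cache is a rolling window of the last
// (conv_kernel_size - 1) positions for each of the conv_channels = d_inner +
// 2*n_group*d_state channels. conv_input prepends those cached positions to the current
// tokens so ggml_ssm_conv sees a full causal window, and the view + cpy above write the
// newest (conv_kernel_size - 1) positions back into the cache for the next ubatch.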
675 | | |
676 | | // Apply SSM convolution |
677 | 0 | ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); |
678 | 0 | cb(conv_output_proper, "conv_output_raw", il); |
679 | |
680 | 0 | conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper)); |
681 | 0 | cb(conv_output_proper, "conv_output_pre_silu", il); |
682 | |
683 | 0 | ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); |
684 | 0 | cb(conv_output_silu, "conv_output_silu", il); |
685 | |
686 | 0 | ggml_tensor * conv_qkv_mix = |
687 | 0 | ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs); |
688 | 0 | cb(conv_qkv_mix, "conv_qkv_mix", il); |
689 | | |
690 | | // Extract the convolved Q, K, V from conv_output |
691 | 0 | ggml_tensor * q_conv = |
692 | 0 | ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], 0); |
693 | 0 | cb(q_conv, "q_conv", il); |
694 | 0 | ggml_tensor * k_conv = |
695 | 0 | ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], |
696 | 0 | head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); |
697 | 0 | cb(k_conv, "k_conv", il); |
698 | 0 | ggml_tensor * v_conv = |
699 | 0 | ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], |
700 | 0 | 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); |
701 | 0 | cb(v_conv, "v_conv", il); |
702 | | |
703 | | // Unsqueeze them |
704 | 0 | q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); |
705 | 0 | k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); |
706 | 0 | v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); |
707 | |
708 | 0 | beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs); |
709 | |
710 | 0 | ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); |
711 | 0 | state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); |
712 | 0 | cb(state, "state_predelta", il); |
713 | | |
714 | | // If the number of key heads differs from the number of value heads, repeat-interleave q/k so all tensors have num_v_heads heads |
715 | 0 | if (num_k_heads != num_v_heads) { |
716 | 0 | GGML_ASSERT(num_v_heads % num_k_heads == 0); |
717 | 0 | int64_t repeat_factor = num_v_heads / num_k_heads; |
718 | | |
719 | | // Repeat-interleave: reshape to (head_dim, 1, heads * tokens * seqs), repeat along the singleton dimension, then reshape back |
720 | 0 | ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); |
721 | 0 | ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); |
722 | | |
723 | | // Repeat along the third dimension (the new dimension with size 1) |
724 | 0 | ggml_tensor * q_repeated = |
725 | 0 | ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); |
726 | 0 | ggml_tensor * k_repeated = |
727 | 0 | ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); |
728 | | |
729 | | // Reshape back to merge the head and repeat dimensions |
730 | | // From [head_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1] |
731 | | // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs] |
732 | 0 | q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); |
733 | 0 | k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); |
734 | 0 | } |
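// Illustrative note: the reshape/repeat/reshape above is a repeat-interleave along the head
// dimension. For example, with num_k_heads = 2 and num_v_heads = 4 (repeat_factor = 2), the
// key heads [k0, k1] become [k0, k0, k1, k1], so value head i is paired with key head
// i / repeat_factor, which matches the grouping used when q, k, v and z were unpacked from
// the joint projection.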
735 | |
736 | 0 | cb(q_conv, "q_conv_predelta", il); |
737 | 0 | cb(k_conv, "k_conv_predelta", il); |
738 | 0 | cb(v_conv, "v_conv_predelta", il); |
739 | | |
740 | | // Choose between build_delta_net_chunking and build_delta_net_autoregressive based on the number of tokens per sequence |
741 | 0 | ggml_tensor * attn_out; |
742 | 0 | if (n_seq_tokens == 1) { |
743 | 0 | attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); |
744 | 0 | } else { |
745 | 0 | attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); |
746 | 0 | } |
747 | 0 | cb(attn_out, "attn_out", il); |
748 | | |
749 | | // The tensors were concatenated 1d, so we need to extract them 1d as well |
750 | 0 | const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs; |
751 | 0 | ggml_tensor * attn_out_1d = ggml_view_1d(ctx0, attn_out, output_flat_size, 0); |
752 | 0 | cb(attn_out_1d, "attn_out_1d", il); |
753 | |
754 | 0 | ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); |
755 | 0 | cb(attn_out_final, "attn_out_reshaped", il); |
756 | | |
757 | | // Extract the state part (the second half of the concatenated tensor) |
758 | | // The state begins right after the flattened output, i.e. output_flat_size elements into the 1D tensor |
759 | 0 | const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs; |
760 | |
761 | 0 | ggml_tensor * state_1d = |
762 | 0 | ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out)); |
763 | 0 | cb(state_1d, "state_1d", il); |
764 | | |
765 | | // Update the recurrent states |
766 | 0 | ggml_build_forward_expand(gf, |
767 | 0 | ggml_cpy(ctx0, state_1d, |
768 | 0 | ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, |
769 | 0 | kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); |
770 | |
771 | 0 | GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out)); |
772 | | |
773 | | // Reshape both attn_out_final and z to 2D tensors for normalization |
774 | | // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [head_dim, n_heads * n_tokens * n_seqs] |
775 | 0 | ggml_tensor * attn_out_2d_final = |
776 | 0 | ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); |
777 | | |
778 | | // z: [head_dim, n_heads, n_tokens, n_seqs] -> [head_dim, n_heads * n_tokens * n_seqs] |
779 | 0 | ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); |
780 | | |
781 | | // Apply gated normalization: self.norm(core_attn_out, z) |
782 | 0 | ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); |
783 | | |
784 | | // Final reshape: [head_dim, n_heads * n_tokens * n_seqs] -> [n_heads * head_dim, n_tokens, n_seqs] |
785 | 0 | ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); |
786 | 0 | cb(final_output, "final_output", il); |
787 | | |
788 | | // Output projection |
789 | 0 | cur = build_lora_mm(model.layers[il].ssm_out, final_output); |
790 | 0 | cb(cur, "linear_attn_out", il); |
791 | | |
792 | | // Reshape back to original dimensions |
793 | 0 | cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); |
794 | 0 | return cur; |
795 | 0 | } |
796 | | |
797 | 0 | ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int il) { |
798 | | // Check if this is an MoE layer |
799 | 0 | if (model.layers[il].ffn_gate_inp != nullptr) { |
800 | | // MoE branch |
801 | 0 | ggml_tensor * moe_out = |
802 | 0 | build_moe_ffn(cur, |
803 | 0 | model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, |
804 | 0 | model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, |
805 | 0 | nullptr, |
806 | 0 | n_expert, n_expert_used, LLM_FFN_SILU, |
807 | 0 | true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); |
808 | 0 | cb(moe_out, "ffn_moe_out", il); |
809 | | |
810 | | // Add shared experts if present - following Qwen3Next reference implementation |
811 | 0 | if (model.layers[il].ffn_up_shexp != nullptr) { |
812 | 0 | ggml_tensor * ffn_shexp = |
813 | 0 | build_ffn(cur, |
814 | 0 | model.layers[il].ffn_up_shexp, NULL, NULL, |
815 | 0 | model.layers[il].ffn_gate_shexp, NULL, NULL, |
816 | 0 | model.layers[il].ffn_down_shexp, NULL, NULL, |
817 | 0 | NULL, |
818 | 0 | LLM_FFN_SILU, LLM_FFN_PAR, il); |
819 | 0 | cb(ffn_shexp, "ffn_shexp", il); |
820 | | |
821 | | // Apply shared expert gating as in the reference implementation |
822 | | // The shared expert has its own gate that is sigmoided |
823 | | // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token) |
824 | 0 | ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); |
825 | 0 | cb(shared_gate, "shared_expert_gate", il); |
826 | | |
827 | | // Apply sigmoid to the gate |
828 | 0 | shared_gate = ggml_sigmoid(ctx0, shared_gate); |
829 | 0 | cb(shared_gate, "shared_expert_gate_sigmoid", il); |
830 | | |
831 | | // The gate needs to be broadcast to match the dimensions of ffn_shexp |
832 | | // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1] |
833 | | // We need to repeat the gate along the feature dimension |
834 | 0 | shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp); |
835 | 0 | cb(shared_gate, "shared_expert_gate_broadcast", il); |
836 | | |
837 | | // Apply the gate to the shared expert output |
838 | 0 | ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); |
839 | 0 | cb(ffn_shexp, "ffn_shexp_gated", il); |
840 | |
841 | 0 | cur = ggml_add(ctx0, moe_out, ffn_shexp); |
842 | 0 | cb(cur, "ffn_out", il); |
843 | 0 | } else { |
844 | 0 | cur = moe_out; |
845 | 0 | } |
846 | 0 | } else { |
847 | | // Dense FFN branch (fallback; current Qwen3Next models use MoE layers, so this path is not expected to be taken) |
848 | 0 | cur = build_ffn(cur, |
849 | 0 | model.layers[il].ffn_up, NULL, NULL, |
850 | 0 | model.layers[il].ffn_gate, NULL, NULL, |
851 | 0 | model.layers[il].ffn_down, NULL, NULL, |
852 | | NULL, |
853 | 0 | LLM_FFN_SILU, LLM_FFN_PAR, il); |
854 | 0 | cb(cur, "ffn_out", il); |
855 | 0 | } |
856 | 0 | return cur; |
857 | 0 | } |