/src/llama.cpp/src/models/delta-net-base.cpp
Line | Count | Source |
1 | | #include "models.h" |
2 | | |
3 | | #include "llama-impl.h" |
4 | | #include "llama-memory-recurrent.h" |
5 | | |
6 | | // utility to get one slice from the third dimension |
7 | | // input dim: [x, y, c, b] |
8 | | // output dim: [x, y, 1, b] |
9 | 0 | static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { |
10 | 0 | return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], |
11 | 0 | t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); |
12 | 0 | } |
13 | | |
14 | 0 | llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) : llm_graph_context(params) {} |
15 | | |
16 | | std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_chunking( |
17 | | ggml_tensor * q, |
18 | | ggml_tensor * k, |
19 | | ggml_tensor * v, |
20 | | ggml_tensor * g, |
21 | | ggml_tensor * b, |
22 | | ggml_tensor * s, |
23 | 0 | int il) { |
24 | 0 | const int64_t S_k = q->ne[0]; |
25 | 0 | const int64_t H_k = q->ne[1]; |
26 | 0 | const int64_t n_tokens = q->ne[2]; |
27 | 0 | const int64_t n_seqs = q->ne[3]; |
28 | |
|
29 | 0 | const int64_t S_v = v->ne[0]; |
30 | 0 | const int64_t H_v = v->ne[1]; |
31 | 0 | const bool kda = (g->ne[0] == S_k && g->ne[1] == H_k); |
32 | |
|
33 | 0 | GGML_ASSERT(S_k == S_v); |
34 | 0 | GGML_ASSERT(H_v % H_k == 0); |
35 | |
|
36 | 0 | GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); |
37 | 0 | GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); |
38 | 0 | GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); |
39 | |
|
40 | 0 | GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); |
41 | 0 | GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs); |
42 | 0 | GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); |
43 | 0 | GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); |
44 | |
|
45 | 0 | const float scale = 1.0f / sqrtf(S_k); |
46 | |
|
47 | 0 | q = ggml_scale(ctx0, q, scale); |
48 | |
|
49 | 0 | cb(q, "q_in", il); |
50 | 0 | cb(k, "k_in", il); |
51 | 0 | cb(v, "v_in", il); |
52 | 0 | cb(b, "b_in", il); |
53 | 0 | cb(g, "g_in", il); |
54 | |
|
55 | 0 | q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] |
56 | 0 | k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] |
57 | 0 | v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] |
58 | 0 | g = ggml_permute(ctx0, g, 0, 2, 1, 3); // [g_0, n_tokens, H_v, n_seqs] |
59 | 0 | b = ggml_permute(ctx0, b, 0, 2, 1, 3); // [ 1, n_tokens, H_v, n_seqs] |
60 | |
|
61 | 0 | const int CS = kda ? 16 : 64; // chunk size |
62 | |
|
63 | 0 | const int pad = (CS - n_tokens % CS) % CS; |
64 | 0 | const int n_chunks = (n_tokens + pad) / CS; |
65 | |
|
66 | 0 | q = ggml_pad(ctx0, q, 0, pad, 0, 0); |
67 | 0 | k = ggml_pad(ctx0, k, 0, pad, 0, 0); |
68 | 0 | v = ggml_pad(ctx0, v, 0, pad, 0, 0); |
69 | 0 | g = ggml_pad(ctx0, g, 0, pad, 0, 0); |
70 | 0 | b = ggml_pad(ctx0, b, 0, pad, 0, 0); |
71 | |
|
72 | 0 | ggml_tensor * v_b = ggml_mul(ctx0, v, b); |
73 | 0 | ggml_tensor * k_b = ggml_mul(ctx0, k, b); |
74 | |
|
75 | 0 | cb(v_b, "v_b", il); |
76 | 0 | cb(k_b, "k_b", il); |
77 | |
|
78 | 0 | q = ggml_reshape_4d(ctx0, q, S_k, CS, n_chunks, H_k * n_seqs); |
79 | 0 | k = ggml_reshape_4d(ctx0, k, S_k, CS, n_chunks, H_k * n_seqs); |
80 | 0 | k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs); |
81 | 0 | v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs); |
82 | 0 | v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs); |
83 | |
|
84 | 0 | g = ggml_reshape_4d(ctx0, g, g->ne[0], CS, n_chunks, H_v * n_seqs); |
85 | 0 | b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs); |
86 | | |
87 | | // [CS, g_0, n_chunks, H_v * n_seqs] |
88 | | // TODO: extend ggml_cumsum with axis parameter to avoid transpose |
89 | 0 | ggml_tensor * g_cs = ggml_cumsum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, g))); |
90 | 0 | cb(g_cs, "g_cs", il); |
91 | |
|
92 | 0 | ggml_tensor * kb = nullptr; |
93 | 0 | ggml_tensor * kq = nullptr; |
94 | 0 | if (kda) { |
95 | 0 | const int64_t CHB = n_chunks * H_k * n_seqs; |
96 | |
|
97 | 0 | ggml_tensor * g_cs_i = ggml_reshape_4d(ctx0, g_cs, CS, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB] |
98 | 0 | ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, S_k, CHB); // [1, chunk_size, S_k, CHB] |
99 | |
|
100 | 0 | g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB] |
101 | | |
102 | | // decay_mask [chunk_size,chunk_size,S_k,CHB] |
103 | 0 | ggml_tensor * decay_mask; |
104 | 0 | decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); |
105 | 0 | decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); |
106 | 0 | decay_mask = ggml_exp(ctx0, decay_mask); |
107 | 0 | cb(decay_mask, "decay_mask", il); |
108 | | |
109 | | // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched |
110 | 0 | decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, CS, CS, CHB); |
111 | |
|
112 | 0 | ggml_tensor * k_b_i = ggml_reshape_4d(ctx0, k_b, S_k, CS, 1, CHB); |
113 | 0 | ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, CS, CHB); |
114 | 0 | ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, CS, 1, CHB); |
115 | |
|
116 | 0 | ggml_tensor * decay_k_b_i = ggml_mul(ctx0, decay_mask, k_b_i); |
117 | 0 | ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i); |
118 | | |
119 | | // decay_k_b_i [S,BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB] |
120 | 0 | kb = ggml_mul_mat(ctx0, decay_k_b_i, k_j); |
121 | 0 | kq = ggml_mul_mat(ctx0, decay_q_i, k_j); |
122 | |
|
123 | 0 | kb = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kb, CS, CS, n_chunks, H_v * n_seqs))); |
124 | 0 | kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kq, CS, CS, n_chunks, H_v * n_seqs))); |
125 | 0 | } else { |
126 | 0 | ggml_tensor * g_cs_i = g_cs; |
127 | 0 | ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs); |
128 | |
|
129 | 0 | g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs); |
130 | | |
131 | | // [CS, CS, n_chunks, H_v * n_seqs] |
132 | 0 | ggml_tensor * decay_mask; |
133 | 0 | decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); |
134 | 0 | decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); |
135 | 0 | decay_mask = ggml_exp(ctx0, decay_mask); |
136 | 0 | cb(decay_mask, "decay_mask", il); |
137 | | |
138 | | // [CS, CS, n_chunks, H_k * n_seqs] |
139 | 0 | kb = ggml_mul_mat(ctx0, k, k_b); |
140 | 0 | kb = ggml_mul (ctx0, kb, decay_mask); |
141 | | |
142 | | // [CS, CS, n_chunks, H_k * n_seqs] |
143 | 0 | kq = ggml_mul_mat(ctx0, k, q); |
144 | 0 | kq = ggml_mul(ctx0, kq, decay_mask); |
145 | 0 | } |
146 | |
|
147 | 0 | kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG); |
148 | 0 | cb(kq, "kq", il); |
149 | | |
150 | | // [CS, CS, n_chunks, H_k * n_seqs] |
151 | 0 | ggml_tensor * attn; |
152 | 0 | attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER); |
153 | 0 | cb(attn, "attn", il); |
154 | |
|
155 | 0 | ggml_tensor * identity; |
156 | 0 | identity = ggml_view_1d(ctx0, attn, CS, 0); |
157 | 0 | identity = ggml_fill (ctx0, identity, 1.0f); |
158 | 0 | identity = ggml_diag (ctx0, identity); |
159 | |
|
160 | 0 | ggml_tensor * lhs = ggml_add(ctx0, attn, identity); |
161 | 0 | cb(lhs, "dnet_add_ch_lhs", il); |
162 | |
|
163 | 0 | attn = ggml_neg(ctx0, attn); |
164 | 0 | cb(attn, "attn_pre_solve", il); |
165 | |
|
166 | 0 | ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); |
167 | 0 | attn = ggml_add(ctx0, lin_solve, identity); |
168 | 0 | cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs] |
169 | | |
170 | | // [S_v, CS, n_chunks, H_v * n_seqs] |
171 | 0 | v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn); |
172 | | |
173 | | // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs] |
174 | 0 | ggml_tensor * g_exp = ggml_exp(ctx0, g_cs); |
175 | |
|
176 | 0 | k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b)); |
177 | | |
178 | | // [CS, S_k, n_chunks, H_k * n_seqs] |
179 | 0 | ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp); |
180 | 0 | cb(kbg, "k_beta_g_exp", il); |
181 | | |
182 | | // [S_k, CS, n_chunks, H_k * n_seqs] |
183 | 0 | ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn); |
184 | 0 | cb(k_cd, "k_cumdecay", il); |
185 | | |
186 | | // [1, CS, n_chunks, H_k * n_seqs] KDA: [S_k, CS, n_chunks, H_k * n_seqs] |
187 | 0 | ggml_tensor * g_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_exp)); |
188 | 0 | ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t); |
189 | | |
190 | | // vectorized calculation of key_gdiff |
191 | | // improved from the chunked version: |
192 | | // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) |
193 | | // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() |
194 | | // key_gdiff = key * g_diff.unsqueeze(-1) |
195 | | // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new |
196 | | // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew |
197 | | |
198 | | // get last element in g_cumsum along CS dimension (ne0) |
199 | | // example: [[x, y, z, ..., last], ...] -> [[last], ...] |
200 | | // [1, 1, n_chunks, H_v * n_seqs] KDA: [1, S_k, n_chunks, H_v * n_seqs] |
201 | 0 | ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, g_cs->ne[1], g_cs->ne[2], g_cs->ne[3], |
202 | 0 | g_cs->nb[1], |
203 | 0 | g_cs->nb[2], |
204 | 0 | g_cs->nb[3], |
205 | 0 | ggml_row_size(g_cs->type, g_cs->ne[0] - 1)); |
206 | 0 | cb(g_last, "g_last", il); |
207 | | |
208 | | // TODO: remove this cont when CUDA supports non-cont unary ops |
209 | 0 | g_last = ggml_cont(ctx0, g_last); |
210 | | |
211 | | // [1, 1, n_chunks, H_v * n_seqs] KDA: [S_k, 1, n_chunks, H_v * n_seqs] |
212 | 0 | ggml_tensor * g_last_exp_t = ggml_transpose(ctx0, ggml_exp(ctx0, g_last)); |
213 | 0 | cb(g_last_exp_t, "g_last_exp_t", il); |
214 | | |
215 | | // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs] |
216 | 0 | ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last)); |
217 | 0 | cb(g_diff, "g_diff", il); |
218 | |
|
219 | 0 | ggml_tensor * g_diff_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_exp(ctx0, g_diff))); |
220 | | |
221 | | // [S_k, CS, n_chunks, H_v * n_seqs] |
222 | 0 | ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t); |
223 | 0 | cb(kg, "key_gdiff", il); |
224 | | |
225 | | // [CS, S_k, n_chunks, H_v * n_seqs] |
226 | 0 | ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg)); |
227 | 0 | cb(kg_t, "key_gdiff_t", il); |
228 | |
|
229 | 0 | s = ggml_reshape_4d(ctx0, s, S_v, S_v, 1, H_v * n_seqs); |
230 | 0 | cb(s, "dnet_add_ch_state", il); |
231 | | |
232 | | // [CS, S_v, n_chunks, H_v * n_seqs] |
233 | 0 | ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v)); |
234 | |
|
235 | 0 | for (int64_t chunk = 0; chunk < n_chunks; chunk++) { |
236 | 0 | ggml_tensor * ch_k_cd = get_slice_2d(ctx0, k_cd, chunk); // [S_k, CS, 1, H_k * n_seqs] |
237 | 0 | ggml_tensor * ch_v_t = get_slice_2d(ctx0, v_t, chunk); // [ CS, S_v, 1, H_v * n_seqs] |
238 | 0 | ggml_tensor * ch_kq = get_slice_2d(ctx0, kq, chunk); // [ CS, CS, 1, H_k * n_seqs] |
239 | 0 | ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs] |
240 | 0 | ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs] |
241 | | |
242 | | // [CS, S_v, 1, H_v * n_seqs] |
243 | 0 | ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s); |
244 | 0 | cb(v_t_p, "v_prime", il); |
245 | | |
246 | | // [CS, S_v, 1, H_v * n_seqs] |
247 | 0 | ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p); |
248 | 0 | cb(v_t_new, "v_t_new", il); |
249 | | |
250 | | // [S_v, CS, 1, H_v * n_seqs] |
251 | 0 | ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq); |
252 | 0 | cb(v_attn, "v_attn", il); |
253 | | |
254 | | // [S_v, CS, 1, H_v * n_seqs] |
255 | 0 | ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s, ch_q_g_exp); |
256 | 0 | cb(attn_inter, "attn_inter", il); |
257 | | |
258 | | // [S_v, CS, 1, H_v * n_seqs] |
259 | 0 | ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn); |
260 | 0 | cb(o_ch, "dnet_add_ch_attn_out", il); |
261 | |
|
262 | 0 | v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]); |
263 | | |
264 | | // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new |
265 | | // TODO: head broadcast might not work here - probably will need a transpose |
266 | 0 | ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs] |
267 | | |
268 | | // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew |
269 | 0 | ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk); |
270 | |
|
271 | 0 | s = ggml_mul(ctx0, s, ch_g_last_exp_t); |
272 | 0 | s = ggml_add(ctx0, s, kgv); |
273 | 0 | cb(s, "dnet_add_ch_state", il); |
274 | 0 | } |
275 | | |
276 | | // truncate padded tokens |
277 | 0 | ggml_tensor * o = ggml_view_4d(ctx0, v, |
278 | 0 | S_v, n_tokens, H_v, n_seqs, |
279 | 0 | ggml_row_size(v->type, S_v), |
280 | 0 | ggml_row_size(v->type, S_v * CS * n_chunks), |
281 | 0 | ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0); |
282 | 0 | o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs] |
283 | 0 | s = ggml_reshape_4d(ctx0, s, S_v, S_v, H_v, n_seqs); |
284 | 0 | cb(s, "output_state", il); |
285 | |
|
286 | 0 | return {o, s}; |
287 | 0 | } |
288 | | |
289 | | std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_autoregressive( |
290 | | ggml_tensor * q, |
291 | | ggml_tensor * k, |
292 | | ggml_tensor * v, |
293 | | ggml_tensor * g, |
294 | | ggml_tensor * b, // beta |
295 | | ggml_tensor * s, // state |
296 | 0 | int il) { |
297 | 0 | const int64_t S_k = q->ne[0]; |
298 | 0 | const int64_t H_k = q->ne[1]; |
299 | 0 | const int64_t n_tokens = q->ne[2]; |
300 | 0 | const int64_t n_seqs = q->ne[3]; |
301 | |
|
302 | 0 | const int64_t S_v = v->ne[0]; |
303 | 0 | const int64_t H_v = v->ne[1]; |
304 | |
|
305 | 0 | GGML_ASSERT(n_tokens == 1); |
306 | |
|
307 | 0 | GGML_ASSERT(S_k == S_v); |
308 | 0 | GGML_ASSERT(H_v % H_k == 0); |
309 | |
|
310 | 0 | GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); |
311 | 0 | GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); |
312 | 0 | GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); |
313 | |
|
314 | 0 | GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); |
315 | 0 | GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs); |
316 | 0 | GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); |
317 | 0 | GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); |
318 | |
|
319 | 0 | const float scale = 1.0f / sqrtf(S_k); |
320 | |
|
321 | 0 | q = ggml_scale(ctx0, q, scale); |
322 | |
|
323 | 0 | q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] |
324 | 0 | k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] |
325 | 0 | v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] |
326 | |
|
327 | 0 | cb(q, "q_in", il); |
328 | 0 | cb(k, "k_in", il); |
329 | 0 | cb(v, "v_in", il); |
330 | 0 | cb(b, "b_in", il); |
331 | 0 | cb(g, "g_in", il); |
332 | | |
333 | | // GDA: [1, 1, H_v, n_seqs] |
334 | | // KDA: [1, S_k, H_v, n_seqs] |
335 | 0 | g = ggml_reshape_4d(ctx0, g, 1, g->ne[0], H_v, n_seqs); |
336 | 0 | b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs); |
337 | | |
338 | | // [S_v, S_v, H_v, n_seqs] |
339 | 0 | g = ggml_exp(ctx0, g); |
340 | 0 | s = ggml_mul(ctx0, s, g); |
341 | | |
342 | | // [1, S_v, H_v, n_seqs] |
343 | 0 | ggml_tensor * sk; |
344 | 0 | sk = ggml_mul (ctx0, s, k); |
345 | 0 | sk = ggml_sum_rows(ctx0, sk); |
346 | | |
347 | | // [S_v, 1, H_v, n_seqs] |
348 | 0 | ggml_tensor * d; |
349 | 0 | d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk)); |
350 | 0 | d = ggml_mul(ctx0, d, b); |
351 | | |
352 | | // [1, S_v, H_v, n_seqs] |
353 | 0 | ggml_tensor * d_t; |
354 | 0 | d_t = ggml_transpose(ctx0, d); |
355 | | |
356 | | // [S_v, S_v, H_v, n_seqs] |
357 | 0 | ggml_tensor * kd; |
358 | 0 | k = ggml_repeat(ctx0, k, s); |
359 | 0 | kd = ggml_mul (ctx0, k, d_t); |
360 | |
|
361 | 0 | s = ggml_add(ctx0, s, kd); |
362 | |
|
363 | 0 | cb(s, "dnet_add_ar_state", il); |
364 | |
|
365 | 0 | ggml_tensor * s_q = ggml_mul (ctx0, s, q); |
366 | 0 | ggml_tensor * o = ggml_sum_rows(ctx0, s_q); |
367 | |
|
368 | 0 | o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs] |
369 | |
|
370 | 0 | return {o, s}; |
371 | 0 | } |
372 | | |
373 | | std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_fused( |
374 | | ggml_tensor * q, |
375 | | ggml_tensor * k, |
376 | | ggml_tensor * v, |
377 | | ggml_tensor * g, |
378 | | ggml_tensor * b, |
379 | | ggml_tensor * s, |
380 | 0 | int il) { |
381 | 0 | const int64_t S_k = q->ne[0]; |
382 | 0 | const int64_t H_k = q->ne[1]; |
383 | 0 | const int64_t n_tokens = q->ne[2]; |
384 | 0 | const int64_t n_seqs = q->ne[3]; |
385 | |
|
386 | 0 | const int64_t S_v = v->ne[0]; |
387 | 0 | const int64_t H_v = v->ne[1]; |
388 | |
|
389 | 0 | GGML_ASSERT(S_k == S_v); |
390 | 0 | GGML_ASSERT(H_v % H_k == 0); |
391 | |
|
392 | 0 | GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); |
393 | 0 | GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); |
394 | 0 | GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); |
395 | |
|
396 | 0 | GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); |
397 | 0 | GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs); |
398 | 0 | GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); |
399 | 0 | GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); |
400 | | |
401 | | // K=1: output carries the final state only. state s is 4D [S_v, S_v, H_v, n_seqs]. |
402 | 0 | ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, /*K=*/1); |
403 | 0 | if (n_tokens == 1) { |
404 | 0 | cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); |
405 | 0 | } else { |
406 | 0 | cb(result, LLAMA_TENSOR_NAME_FGDN_CH, il); |
407 | 0 | } |
408 | |
|
409 | 0 | ggml_tensor * output = ggml_view_4d(ctx0, result, |
410 | 0 | S_v, H_v, n_tokens, n_seqs, |
411 | 0 | ggml_row_size(result->type, S_v), |
412 | 0 | ggml_row_size(result->type, S_v * H_v), |
413 | 0 | ggml_row_size(result->type, S_v * H_v * n_tokens), 0); |
414 | |
|
415 | 0 | ggml_tensor * new_state = ggml_view_4d(ctx0, result, |
416 | 0 | S_v, S_v, H_v, n_seqs, |
417 | 0 | ggml_row_size(result->type, S_v), |
418 | 0 | ggml_row_size(result->type, S_v * S_v), |
419 | 0 | ggml_row_size(result->type, S_v * S_v * H_v), |
420 | 0 | ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs)); |
421 | |
|
422 | 0 | return {output, new_state}; |
423 | 0 | } |
424 | | |
425 | | std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net( |
426 | | ggml_tensor * q, |
427 | | ggml_tensor * k, |
428 | | ggml_tensor * v, |
429 | | ggml_tensor * g, |
430 | | ggml_tensor * b, |
431 | | ggml_tensor * s, |
432 | 0 | int il) { |
433 | 0 | const int64_t n_seq_tokens = q->ne[2]; |
434 | |
|
435 | 0 | if (n_seq_tokens == 1) { |
436 | 0 | if (cparams.fused_gdn_ar) { |
437 | 0 | return build_delta_net_fused(q, k, v, g, b, s, il); |
438 | 0 | } |
439 | 0 | return build_delta_net_autoregressive(q, k, v, g, b, s, il); |
440 | 0 | } |
441 | | |
442 | 0 | if (cparams.fused_gdn_ch) { |
443 | 0 | return build_delta_net_fused(q, k, v, g, b, s, il); |
444 | 0 | } |
445 | | |
446 | 0 | return build_delta_net_chunking(q, k, v, g, b, s, il); |
447 | 0 | } |
448 | | |
449 | | ggml_tensor * llm_build_delta_net_base::build_conv_state( |
450 | | llm_graph_input_rs * inp, |
451 | | ggml_tensor * conv_states_all, |
452 | | ggml_tensor * qkv_mixed, |
453 | | int64_t conv_kernel_size, |
454 | | int64_t conv_channels, |
455 | 0 | int il) { |
456 | 0 | const auto * mctx_cur = inp->mctx; |
457 | |
|
458 | 0 | const auto kv_head = mctx_cur->get_head(); |
459 | 0 | const auto mem_size = mctx_cur->get_size(); |
460 | |
|
461 | 0 | const int64_t n_seqs = ubatch.n_seqs; |
462 | |
|
463 | 0 | ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); |
464 | 0 | cb(conv_states, "conv_states", il); |
465 | |
|
466 | 0 | conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); |
467 | 0 | cb(conv_states, "conv_states_reshaped", il); |
468 | |
|
469 | 0 | qkv_mixed = ggml_transpose(ctx0, qkv_mixed); |
470 | 0 | cb(qkv_mixed, "qkv_mixed_transposed", il); |
471 | |
|
472 | 0 | ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); |
473 | 0 | cb(conv_input, "conv_input", il); |
474 | |
|
475 | 0 | const int64_t row_count = (conv_kernel_size - 1) * conv_channels; |
476 | |
|
477 | 0 | const size_t row_size = ggml_row_size(conv_states_all->type, row_count); |
478 | |
|
479 | 0 | if (cparams.n_rs_seq == 0) { |
480 | 0 | const int64_t s_idx = conv_input->ne[0] - conv_states->ne[0]; |
481 | 0 | const int64_t s_slot = 0; |
482 | |
|
483 | 0 | ggml_tensor * conv_state_last = |
484 | 0 | ggml_view_3d(ctx0, conv_input, |
485 | 0 | conv_kernel_size - 1, conv_channels, n_seqs, |
486 | 0 | conv_input->nb[1], conv_input->nb[2], |
487 | 0 | ggml_row_size(conv_input->type, s_idx)); |
488 | 0 | cb(conv_state_last, "conv_state_last", il); |
489 | |
|
490 | 0 | ggml_tensor * conv_state_update = |
491 | 0 | ggml_view_2d(ctx0, conv_states_all, |
492 | 0 | row_count, n_seqs, conv_states_all->nb[1], |
493 | 0 | (s_slot * mem_size + kv_head) * row_size); |
494 | 0 | cb(conv_state_update, "conv_state_update", il); |
495 | |
|
496 | 0 | ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); |
497 | 0 | } else { |
498 | | // [TAG_RECURRENT_ROLLBACK_SPLITS] |
499 | | // TODO: this logic incorrectly assumes that the last (n_rs_seq + 1) tokens of a sequence in a batch are |
500 | | // inside the same ubatch. currently with `split_equal()` this is not correct |
501 | |
|
502 | 0 | const int64_t K = (int64_t) cparams.n_rs_seq + 1; |
503 | |
|
504 | 0 | for (int64_t t = 1; t <= K; ++t) { |
505 | 0 | const int64_t s_idx = std::max<int64_t>(0, conv_input->ne[0] - conv_states->ne[0] - K + t); |
506 | 0 | const int64_t s_slot = K - t; |
507 | |
|
508 | 0 | ggml_tensor * conv_state_last = |
509 | 0 | ggml_view_3d(ctx0, conv_input, |
510 | 0 | conv_kernel_size - 1, conv_channels, n_seqs, |
511 | 0 | conv_input->nb[1], conv_input->nb[2], |
512 | 0 | ggml_row_size(conv_input->type, s_idx)); |
513 | |
|
514 | 0 | ggml_tensor * conv_state_update = |
515 | 0 | ggml_view_2d(ctx0, |
516 | 0 | conv_states_all, row_count, n_seqs, |
517 | 0 | conv_states_all->nb[1], |
518 | 0 | (s_slot * mem_size + kv_head) * row_size); |
519 | |
|
520 | 0 | ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); |
521 | 0 | } |
522 | 0 | } |
523 | |
|
524 | 0 | return conv_input; |
525 | 0 | } |
526 | | |
527 | | ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( |
528 | | llm_graph_input_rs * inp, |
529 | | ggml_tensor * ssm_states_all, |
530 | | ggml_tensor * q, |
531 | | ggml_tensor * k, |
532 | | ggml_tensor * v, |
533 | | ggml_tensor * g, |
534 | | ggml_tensor * b, |
535 | | ggml_tensor * s, |
536 | 0 | int il) { |
537 | 0 | const auto * mctx_cur = inp->mctx; |
538 | 0 | const auto kv_head = mctx_cur->get_head(); |
539 | 0 | const uint32_t mem_size = mctx_cur->get_size(); |
540 | |
|
541 | 0 | const int64_t S_v = s->ne[0]; |
542 | 0 | const int64_t H_v = s->ne[2]; |
543 | 0 | const int64_t n_seqs = s->ne[3]; |
544 | 0 | const int64_t n_seq_tokens = q->ne[2]; |
545 | |
|
546 | 0 | const bool keep = cparams.n_rs_seq > 0; |
547 | |
|
548 | 0 | if (!keep) { |
549 | 0 | auto attn_out = build_delta_net(q, k, v, g, b, s, il); |
550 | 0 | ggml_tensor * output = attn_out.first; |
551 | 0 | ggml_tensor * new_state = attn_out.second; |
552 | 0 | cb(output, "attn_output", il); |
553 | 0 | cb(new_state, "new_state", il); |
554 | |
|
555 | 0 | ggml_build_forward_expand(gf, |
556 | 0 | ggml_cpy(ctx0, new_state, |
557 | 0 | ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], |
558 | 0 | kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); |
559 | |
|
560 | 0 | return output; |
561 | 0 | } |
562 | | |
563 | 0 | const int64_t D = S_v * S_v * H_v; |
564 | 0 | const int64_t K = cparams.n_rs_seq + 1; |
565 | | |
566 | | // state s is 4D [S_v, S_v, H_v, n_seqs]; K snapshot slots are written into the output. |
567 | 0 | ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, K); |
568 | 0 | if (n_seq_tokens > 1) { |
569 | 0 | cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); |
570 | 0 | } else { |
571 | 0 | cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_AR, il); |
572 | 0 | } |
573 | |
|
574 | 0 | const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs; |
575 | 0 | const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs; |
576 | |
|
577 | 0 | ggml_tensor * output = ggml_view_4d(ctx0, gdn_out, |
578 | 0 | S_v, H_v, n_seq_tokens, n_seqs, |
579 | 0 | ggml_row_size(gdn_out->type, S_v), |
580 | 0 | ggml_row_size(gdn_out->type, S_v * H_v), |
581 | 0 | ggml_row_size(gdn_out->type, S_v * H_v * n_seq_tokens), |
582 | 0 | 0); |
583 | 0 | cb(output, "attn_output", il); |
584 | |
|
585 | 0 | const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all); |
586 | | |
587 | | // op writes the last min(n_seq_tokens, K) snapshots; trailing slots are left unwritten |
588 | 0 | const int64_t n_written = std::min<int64_t>(n_seq_tokens, K); |
589 | | |
590 | | // write the produced snapshots into the recurrent cache (snapshot slot i -> rollback group i) |
591 | 0 | ggml_tensor * src = ggml_view_3d(ctx0, gdn_out, |
592 | 0 | D, n_seqs, n_written, |
593 | 0 | ggml_row_size(gdn_out->type, D), |
594 | 0 | ggml_row_size(gdn_out->type, state_size_per_snap), |
595 | 0 | ggml_row_size(gdn_out->type, attn_score_elems)); |
596 | |
|
597 | 0 | ggml_tensor * dst = ggml_view_3d(ctx0, ssm_states_all, |
598 | 0 | D, n_seqs, n_written, |
599 | 0 | ssm_states_all->nb[1], |
600 | 0 | (size_t) mem_size * row_size, |
601 | 0 | (size_t) kv_head * row_size); |
602 | |
|
603 | 0 | ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); |
604 | |
|
605 | 0 | return output; |
606 | 0 | } |