/src/llama.cpp/src/models/delta-net-base.cpp
#include "models.h"

// utility to get one slice from the third dimension
// input dim: [x, y, c, b]
// output dim: [x, y, 1, b]
static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
    return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
            t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
}

llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) : llm_graph_context(params) {}
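
// chunked delta-rule attention (summary comment added here; details inferred from
// the code below): the padded sequence is split into chunks of CS tokens, the
// intra-chunk delta-rule system is solved in closed form via a triangular solve,
// and the recurrent state s is carried from chunk to chunk.
// returns {output [S_v, H_v, n_tokens, n_seqs], state [S_v, S_v, H_v, n_seqs]}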
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_chunking(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * g,
        ggml_tensor * b,
        ggml_tensor * s,
        int il) {
    const int64_t S_k      = q->ne[0];
    const int64_t H_k      = q->ne[1];
    const int64_t n_tokens = q->ne[2];
    const int64_t n_seqs   = q->ne[3];

    const int64_t S_v = v->ne[0];
    const int64_t H_v = v->ne[1];
    const bool    kda = (g->ne[0] == S_k && g->ne[1] == H_k);
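    // kda: the gate g carries one log-decay per channel ([S_k, H_k, ...]), as in
    // KDA-style layers; otherwise g holds a single log-decay per head (ne[0] == 1)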

    GGML_ASSERT(S_k == S_v);
    GGML_ASSERT(H_v % H_k == 0);

    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);

    GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
    GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
    GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);

    const float scale = 1.0f / sqrtf(S_k);

    q = ggml_scale(ctx0, q, scale);

    cb(q, "q_in", il);
    cb(k, "k_in", il);
    cb(v, "v_in", il);
    cb(b, "b_in", il);
    cb(g, "g_in", il);
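
    // inputs arrive as [dim, head, token, seq]; the permutes below move tokens next
    // to the feature dimension so each (head, seq) slice is contiguous in tokens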
    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
    g = ggml_permute(ctx0, g, 0, 2, 1, 3); // [g_0, n_tokens, H_v, n_seqs]
    b = ggml_permute(ctx0, b, 0, 2, 1, 3); // [  1, n_tokens, H_v, n_seqs]

    const int CS = kda ? 16 : 64; // chunk size

    const int pad      = (CS - n_tokens % CS) % CS;
    const int n_chunks = (n_tokens + pad) / CS;
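
    // zero-pad the token dimension up to a multiple of CS so the sequence splits
    // into n_chunks full chunks; the padded tail is truncated again before returning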
    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
    g = ggml_pad(ctx0, g, 0, pad, 0, 0);
    b = ggml_pad(ctx0, b, 0, pad, 0, 0);

    ggml_tensor * v_b = ggml_mul(ctx0, v, b);
    ggml_tensor * k_b = ggml_mul(ctx0, k, b);

    cb(v_b, "v_b", il);
    cb(k_b, "k_b", il);

    q   = ggml_reshape_4d(ctx0, q,   S_k, CS, n_chunks, H_k * n_seqs);
    k   = ggml_reshape_4d(ctx0, k,   S_k, CS, n_chunks, H_k * n_seqs);
    k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs);
    v   = ggml_reshape_4d(ctx0, v,   S_v, CS, n_chunks, H_v * n_seqs);
    v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);

    g = ggml_reshape_4d(ctx0, g, g->ne[0], CS, n_chunks, H_v * n_seqs);
    b = ggml_reshape_4d(ctx0, b, 1,        CS, n_chunks, H_v * n_seqs);

    // [CS, g_0, n_chunks, H_v * n_seqs]
    // TODO: extend ggml_cumsum with axis parameter to avoid transpose
    ggml_tensor * g_cs = ggml_cumsum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, g)));
    cb(g_cs, "g_cs", il);
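
    // g_cs is the within-chunk cumulative sum of the log-decays, so
    // exp(g_cs_j - g_cs_i) below is the total decay accumulated between positions
    // i and j of the same chunk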
    ggml_tensor * kb = nullptr;
    ggml_tensor * kq = nullptr;
    if (kda) {
        const int64_t CHB = n_chunks * H_k * n_seqs;

        ggml_tensor * g_cs_i = ggml_reshape_4d(ctx0, g_cs, CS, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB]
        ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, S_k, CHB); // [1, chunk_size, S_k, CHB]

        g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]

        // decay_mask [chunk_size, chunk_size, S_k, CHB]
        ggml_tensor * decay_mask;
        decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
        decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
        decay_mask = ggml_exp(ctx0, decay_mask);
        cb(decay_mask, "decay_mask", il);

        // decay_mask [S_k, BT_j, BT_i, CHB] *Note* second and third chunk_sizes are switched
        decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, CS, CS, CHB);

        ggml_tensor * k_b_i = ggml_reshape_4d(ctx0, k_b, S_k, CS, 1, CHB);
        ggml_tensor * k_j   = ggml_reshape_4d(ctx0, k,   S_k, 1, CS, CHB);
        ggml_tensor * q_i   = ggml_reshape_4d(ctx0, q,   S_k, CS, 1, CHB);

        ggml_tensor * decay_k_b_i = ggml_mul(ctx0, decay_mask, k_b_i);
        ggml_tensor * decay_q_i   = ggml_mul(ctx0, decay_mask, q_i);

        // decay_k_b_i [S, BT, BT, CHB] @ k_j [S, 1, BT, CHB] = Akk [BT, 1, BT, CHB]
        kb = ggml_mul_mat(ctx0, decay_k_b_i, k_j);
        kq = ggml_mul_mat(ctx0, decay_q_i,   k_j);

        kb = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kb, CS, CS, n_chunks, H_v * n_seqs)));
        kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kq, CS, CS, n_chunks, H_v * n_seqs)));
    } else {
        ggml_tensor * g_cs_i = g_cs;
        ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);

        g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);

        // [CS, CS, n_chunks, H_v * n_seqs]
        ggml_tensor * decay_mask;
        decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
        decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
        decay_mask = ggml_exp(ctx0, decay_mask);
        cb(decay_mask, "decay_mask", il);

        // [CS, CS, n_chunks, H_k * n_seqs]
        kb = ggml_mul_mat(ctx0, k, k_b);
        kb = ggml_mul    (ctx0, kb, decay_mask);

        // [CS, CS, n_chunks, H_k * n_seqs]
        kq = ggml_mul_mat(ctx0, k, q);
        kq = ggml_mul    (ctx0, kq, decay_mask);
    }
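
    // at this point both branches have produced per-chunk [CS, CS] matrices:
    // kb holds decay-weighted (beta * k) . k dot products and kq holds
    // decay-weighted q . k dot products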
    kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
    cb(kq, "kq", il);

    // [CS, CS, n_chunks, H_k * n_seqs]
    ggml_tensor * attn;
    attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);
    cb(attn, "attn", il);

    ggml_tensor * identity;
    identity = ggml_view_1d(ctx0, attn, CS, 0);
    identity = ggml_fill   (ctx0, identity, 1.0f);
    identity = ggml_diag   (ctx0, identity);

    ggml_tensor * lhs = ggml_add(ctx0, attn, identity);
    cb(lhs, "dnet_add_ch_lhs", il);

    attn = ggml_neg(ctx0, attn);
    cb(attn, "attn_pre_solve", il);

    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
    attn = ggml_add(ctx0, lin_solve, identity);
    cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs]
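
    // attn now holds (I + tril(kb))^-1: from (I + A) X = -A we get
    // X = -(I + A)^-1 A, hence X + I = (I + A)^-1, computed with a triangular
    // solve instead of an explicit matrix inverse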
    // [S_v, CS, n_chunks, H_v * n_seqs]
    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);

    // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs]
    ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);

    k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));

    // [CS, S_k, n_chunks, H_k * n_seqs]
    ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp);
    cb(kbg, "k_beta_g_exp", il);

    // [S_k, CS, n_chunks, H_k * n_seqs]
    ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
    cb(k_cd, "k_cumdecay", il);

    // [1, CS, n_chunks, H_k * n_seqs] KDA: [S_k, CS, n_chunks, H_k * n_seqs]
    ggml_tensor * g_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_exp));
    ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);

    // vectorized calculation of key_gdiff
    // improved from the chunked version:
    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
    //   key_gdiff = key * g_diff.unsqueeze(-1)
    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew

    // get last element in g_cumsum along CS dimension (ne0)
    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
    // [1, 1, n_chunks, H_v * n_seqs] KDA: [1, S_k, n_chunks, H_v * n_seqs]
    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, g_cs->ne[1], g_cs->ne[2], g_cs->ne[3],
            g_cs->nb[1],
            g_cs->nb[2],
            g_cs->nb[3],
            ggml_row_size(g_cs->type, g_cs->ne[0] - 1));
    cb(g_last, "g_last", il);

    // TODO: remove this cont when CUDA supports non-cont unary ops
    g_last = ggml_cont(ctx0, g_last);

    // [1, 1, n_chunks, H_v * n_seqs] KDA: [S_k, 1, n_chunks, H_v * n_seqs]
    ggml_tensor * g_last_exp_t = ggml_transpose(ctx0, ggml_exp(ctx0, g_last));
    cb(g_last_exp_t, "g_last_exp_t", il);

    // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs]
    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
    cb(g_diff, "g_diff", il);

    ggml_tensor * g_diff_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_exp(ctx0, g_diff)));

    // [S_k, CS, n_chunks, H_v * n_seqs]
    ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
    cb(kg, "key_gdiff", il);

    // [CS, S_k, n_chunks, H_v * n_seqs]
    ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
    cb(kg_t, "key_gdiff_t", il);

    ggml_tensor * s_t = ggml_transpose(ctx0, s);
    s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
    cb(s_t, "dnet_add_ch_state", il);

    // [CS, S_v, n_chunks, H_v * n_seqs]
    ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
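
    // scan the chunks in order, carrying the recurrent state s_t:
    //   v_prime = k_cumdecay @ state      (part of v already explained by the state)
    //   v_new   = v - v_prime             (delta-rule corrected values)
    //   out     = state @ (q * g_exp) + intra-chunk attention of v_new
    //   state   = state * g_last_exp + key_gdiff^T @ v_new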
    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
        ggml_tensor * ch_k_cd    = get_slice_2d(ctx0, k_cd,    chunk); // [S_k, CS, 1, H_k * n_seqs]
        ggml_tensor * ch_v_t     = get_slice_2d(ctx0, v_t,     chunk); // [ CS, S_v, 1, H_v * n_seqs]
        ggml_tensor * ch_kq      = get_slice_2d(ctx0, kq,      chunk); // [ CS, CS, 1, H_k * n_seqs]
        ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs]
        ggml_tensor * ch_kg_t    = get_slice_2d(ctx0, kg_t,    chunk); // [ CS, S_k, 1, H_v * n_seqs]

        // [CS, S_v, 1, H_v * n_seqs]
        ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
        cb(v_t_p, "v_prime", il);

        // [CS, S_v, 1, H_v * n_seqs]
        ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p);
        cb(v_t_new, "v_t_new", il);

        // [S_v, CS, 1, H_v * n_seqs]
        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq);
        cb(v_attn, "v_attn", il);

        // [S_v, CS, 1, H_v * n_seqs]
        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
        cb(attn_inter, "attn_inter", il);

        // [S_v, CS, 1, H_v * n_seqs]
        ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn);
        cb(o_ch, "dnet_add_ch_attn_out", il);

        v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]);

        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
        // TODO: head broadcast might not work here - probably will need a transpose
        ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs]

        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
        ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk);

        s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t);
        s_t = ggml_add(ctx0, s_t, kgv);
        cb(s_t, "dnet_add_ch_state", il);
    }

    s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);

    // truncate padded tokens
    ggml_tensor * o = ggml_view_4d(ctx0, v,
            S_v, n_tokens, H_v, n_seqs,
            ggml_row_size(v->type, S_v),
            ggml_row_size(v->type, S_v * CS * n_chunks),
            ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
    o = ggml_permute  (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
    s = ggml_transpose(ctx0, s_t);
    cb(s, "output_state", il);

    return {o, s};
}
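
// single-token recurrent form of the delta rule (n_tokens == 1): rather than the
// chunked closed form above, perform one decay step and one rank-1 state update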
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_autoregressive(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * g,
        ggml_tensor * b, // beta
        ggml_tensor * s, // state
        int il) {
    const int64_t S_k      = q->ne[0];
    const int64_t H_k      = q->ne[1];
    const int64_t n_tokens = q->ne[2];
    const int64_t n_seqs   = q->ne[3];

    const int64_t S_v = v->ne[0];
    const int64_t H_v = v->ne[1];

    GGML_ASSERT(n_tokens == 1);

    GGML_ASSERT(S_k == S_v);
    GGML_ASSERT(H_v % H_k == 0);

    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);

    GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
    GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
    GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);

    const float scale = 1.0f / sqrtf(S_k);

    q = ggml_scale(ctx0, q, scale);

    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]

    cb(q, "q_in", il);
    cb(k, "k_in", il);
    cb(v, "v_in", il);
    cb(b, "b_in", il);
    cb(g, "g_in", il);

    // GDA: [1, 1, H_v, n_seqs]
    // KDA: [1, S_k, H_v, n_seqs]
    g = ggml_reshape_4d(ctx0, g, 1, g->ne[0], H_v, n_seqs);
    b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs);
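
    // one delta-rule step, reading M := s^T as the state matrix:
    //   M <- M * exp(g)          (apply the decay)
    //   d  = b * (v - M k)       (prediction error, scaled by beta)
    //   M <- M + d k^T           (rank-1 outer-product update)
    //   o  = M q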
    // [S_v, S_v, H_v, n_seqs]
    g = ggml_exp(ctx0, g);
    s = ggml_mul(ctx0, s, g);

    ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));

    // [1, S_v, H_v, n_seqs]
    ggml_tensor * sk;
    sk = ggml_mul     (ctx0, s_t, k);
    sk = ggml_sum_rows(ctx0, sk);

    // [S_v, 1, H_v, n_seqs]
    ggml_tensor * d;
    d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk));
    d = ggml_mul(ctx0, d, b);

    // [1, S_v, H_v, n_seqs]
    ggml_tensor * d_t;
    d_t = ggml_transpose(ctx0, d);

    // [S_v, S_v, H_v, n_seqs]
    ggml_tensor * kd;
    k  = ggml_repeat(ctx0, k, s);
    kd = ggml_mul   (ctx0, k, d_t);

    s_t = ggml_add(ctx0, s_t, kd);

    cb(s_t, "dnet_add_ar_state", il);

    ggml_tensor * s_q = ggml_mul     (ctx0, s_t, q);
    ggml_tensor * o   = ggml_sum_rows(ctx0, s_q);

    o = ggml_permute  (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
    s = ggml_transpose(ctx0, s_t);           // [S_v, S_v, H_v, n_seqs]

    return {o, s};
}