Coverage Report

Created: 2026-06-22 06:47

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp
Line
Count
Source
1
#include "ops.h"
2
3
#include "ggml-cpu.h"
4
#include "ggml-impl.h"
5
#include "binary-ops.h"
6
#include "simd-gemm.h"
7
#include "ggml.h"
8
#include "unary-ops.h"
9
#include "vec.h"
10
11
#include <algorithm>
12
#include <cfloat>
13
#include <cmath>
14
15
// ggml_compute_forward_dup
16
17
static void ggml_compute_forward_dup_same_cont(
18
        const ggml_compute_params * params,
19
0
        ggml_tensor * dst) {
20
21
0
    const ggml_tensor * src0 = dst->src[0];
22
23
0
    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
24
0
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
25
0
    GGML_ASSERT(src0->type == dst->type);
26
27
0
    const size_t nb0 = ggml_type_size(src0->type);
28
29
0
    const int ith = params->ith; // thread index
30
0
    const int nth = params->nth; // number of threads
31
32
    // parallelize by blocks
33
0
    const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);
34
0
    const int dr = (nk + nth - 1) / nth;
35
0
    const int k0 = dr * ith;
36
0
    const int k1 = MIN(k0 + dr, nk);
37
38
0
    if (k0 < k1) {
39
0
        memcpy(
40
0
            ((char *)  dst->data + k0*nb0),
41
0
            ((char *) src0->data + k0*nb0),
42
0
            (k1 - k0) * nb0);
43
0
    }
44
0
}
45
46
template<typename src_t, typename dst_t>
47
static void ggml_compute_forward_dup_flt(
48
        const ggml_compute_params * params,
49
0
        ggml_tensor * dst) {
50
51
0
    const ggml_tensor * src0 = dst->src[0];
52
53
0
    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
54
0
    GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));
55
56
0
    GGML_TENSOR_UNARY_OP_LOCALS
57
58
0
    const int ith = params->ith; // thread index
59
0
    const int nth = params->nth; // number of threads
60
61
    // parallelize by rows
62
0
    const int nr = ne01;
63
    // number of rows per thread
64
0
    const int dr = (nr + nth - 1) / nth;
65
    // row range for this thread
66
0
    const int ir0 = dr * ith;
67
0
    const int ir1 = MIN(ir0 + dr, nr);
68
69
    // case: type & row size equal
70
0
    if (src0->type == dst->type &&
71
0
        ne00 == ne0 &&
72
0
        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
73
        // copy by rows
74
0
        const size_t rs = ne00*nb00;
75
0
        for (int64_t i03 = 0; i03 < ne03; i03++) {
76
0
            for (int64_t i02 = 0; i02 < ne02; i02++) {
77
0
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
78
0
                    memcpy(
79
0
                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
80
0
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
81
0
                        rs);
82
0
                }
83
0
            }
84
0
        }
85
0
        return;
86
0
    }
87
88
    // case: dst tensor is contiguous
89
0
    if (ggml_is_contiguous(dst)) {
90
0
        if (nb00 == sizeof(src_t)) {
91
0
            if constexpr (std::is_same_v<dst_t, src_t>) {
92
                // same type
93
0
                size_t id = 0;
94
0
                const size_t rs = ne00 * nb00;
95
0
                char * dst_ptr = (char *) dst->data;
96
97
0
                for (int i03 = 0; i03 < ne03; i03++) {
98
0
                    for (int i02 = 0; i02 < ne02; i02++) {
99
0
                        id += rs * ir0;
100
0
                        for (int i01 = ir0; i01 < ir1; i01++) {
101
0
                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
102
0
                            memcpy(dst_ptr + id, src0_ptr, rs);
103
0
                            id += rs;
104
0
                        }
105
0
                        id += rs * (ne01 - ir1);
106
0
                    }
107
0
                }
108
0
            } else {
109
                // casting between non-quantized types
110
0
                size_t id = 0;
111
0
                dst_t * dst_ptr = (dst_t *) dst->data;
112
113
0
                for (int i03 = 0; i03 < ne03; i03++) {
114
0
                    for (int i02 = 0; i02 < ne02; i02++) {
115
0
                        id += ne00 * ir0;
116
0
                        for (int i01 = ir0; i01 < ir1; i01++) {
117
0
                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
118
0
                            for (int i00 = 0; i00 < ne00; i00++) {
119
0
                                float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
120
0
                                dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
121
0
                                id++;
122
0
                            }
123
0
                        }
124
0
                        id += ne00 * (ne01 - ir1);
125
0
                    }
126
0
                }
127
0
            }
128
0
        } else {
129
            //printf("%s: this is not optimal - fix me\n", __func__);
130
131
0
            size_t id = 0;
132
0
            dst_t * dst_ptr = (dst_t *) dst->data;
133
134
0
            for (int i03 = 0; i03 < ne03; i03++) {
135
0
                for (int i02 = 0; i02 < ne02; i02++) {
136
0
                    id += ne00 * ir0;
137
0
                    for (int i01 = ir0; i01 < ir1; i01++) {
138
0
                        for (int i00 = 0; i00 < ne00; i00++) {
139
0
                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
140
141
0
                            float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
142
0
                            dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
143
0
                            id++;
144
0
                        }
145
0
                    }
146
0
                    id += ne00 * (ne01 - ir1);
147
0
                }
148
0
            }
149
0
        }
150
0
        return;
151
0
    }
152
153
    // dst counters
154
0
    int64_t i10 = 0;
155
0
    int64_t i11 = 0;
156
0
    int64_t i12 = 0;
157
0
    int64_t i13 = 0;
158
159
0
    if constexpr (std::is_same_v<dst_t, src_t>) {
160
0
        for (int64_t i03 = 0; i03 < ne03; i03++) {
161
0
            for (int64_t i02 = 0; i02 < ne02; i02++) {
162
0
                i10 += ne00 * ir0;
163
0
                while (i10 >= ne0) {
164
0
                    i10 -= ne0;
165
0
                    if (++i11 == ne1) {
166
0
                        i11 = 0;
167
0
                        if (++i12 == ne2) {
168
0
                            i12 = 0;
169
0
                            if (++i13 == ne3) {
170
0
                                i13 = 0;
171
0
                            }
172
0
                        }
173
0
                    }
174
0
                }
175
0
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
176
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
177
0
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
178
0
                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
179
180
0
                        memcpy(dst_ptr, src0_ptr, sizeof(dst_t));
181
182
0
                        if (++i10 == ne00) {
183
0
                            i10 = 0;
184
0
                            if (++i11 == ne01) {
185
0
                                i11 = 0;
186
0
                                if (++i12 == ne02) {
187
0
                                    i12 = 0;
188
0
                                    if (++i13 == ne03) {
189
0
                                        i13 = 0;
190
0
                                    }
191
0
                                }
192
0
                            }
193
0
                        }
194
0
                    }
195
0
                }
196
0
                i10 += ne00 * (ne01 - ir1);
197
0
                while (i10 >= ne0) {
198
0
                    i10 -= ne0;
199
0
                    if (++i11 == ne1) {
200
0
                        i11 = 0;
201
0
                        if (++i12 == ne2) {
202
0
                            i12 = 0;
203
0
                            if (++i13 == ne3) {
204
0
                                i13 = 0;
205
0
                            }
206
0
                        }
207
0
                    }
208
0
                }
209
0
            }
210
0
        }
211
212
0
    } else {
213
0
        for (int64_t i03 = 0; i03 < ne03; i03++) {
214
0
            for (int64_t i02 = 0; i02 < ne02; i02++) {
215
0
                i10 += ne00 * ir0;
216
0
                while (i10 >= ne0) {
217
0
                    i10 -= ne0;
218
0
                    if (++i11 == ne1) {
219
0
                        i11 = 0;
220
0
                        if (++i12 == ne2) {
221
0
                            i12 = 0;
222
0
                            if (++i13 == ne3) {
223
0
                                i13 = 0;
224
0
                            }
225
0
                        }
226
0
                    }
227
0
                }
228
0
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
229
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
230
0
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
231
0
                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
232
233
0
                        float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
234
0
                        *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);
235
236
0
                        if (++i10 == ne0) {
237
0
                            i10 = 0;
238
0
                            if (++i11 == ne1) {
239
0
                                i11 = 0;
240
0
                                if (++i12 == ne2) {
241
0
                                    i12 = 0;
242
0
                                    if (++i13 == ne3) {
243
0
                                        i13 = 0;
244
0
                                    }
245
0
                                }
246
0
                            }
247
0
                        }
248
0
                    }
249
0
                }
250
0
                i10 += ne00 * (ne01 - ir1);
251
0
                while (i10 >= ne0) {
252
0
                    i10 -= ne0;
253
0
                    if (++i11 == ne1) {
254
0
                        i11 = 0;
255
0
                        if (++i12 == ne2) {
256
0
                            i12 = 0;
257
0
                            if (++i13 == ne3) {
258
0
                                i13 = 0;
259
0
                            }
260
0
                        }
261
0
                    }
262
0
                }
263
0
            }
264
0
        }
265
0
    }
266
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<unsigned short, unsigned short>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<unsigned short, ggml_bf16_t>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<unsigned short, float>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<ggml_bf16_t, unsigned short>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<ggml_bf16_t, float>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<float, unsigned short>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<float, ggml_bf16_t>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<float, float>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<float, int>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<int, float>(ggml_compute_params const*, ggml_tensor*)
267
268
269
template<typename src_t>
270
static void ggml_compute_forward_dup_to_q(
271
        const ggml_compute_params * params,
272
0
        ggml_tensor * dst) {
273
274
0
    const ggml_tensor * src0 = dst->src[0];
275
276
0
    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
277
0
    GGML_ASSERT(!ggml_is_quantized(src0->type));
278
279
0
    GGML_TENSOR_UNARY_OP_LOCALS
280
281
0
    const int ith = params->ith; // thread index
282
0
    const int nth = params->nth; // number of threads
283
284
    // parallelize by rows
285
0
    const int nr = ne01;
286
    // number of rows per thread
287
0
    const int dr = (nr + nth - 1) / nth;
288
    // row range for this thread
289
0
    const int ir0 = dr * ith;
290
0
    const int ir1 = MIN(ir0 + dr, nr);
291
292
0
    if (ggml_is_contiguous(dst) &&
293
0
            nb00 == sizeof(src_t) &&
294
0
            ggml_get_type_traits_cpu(dst->type)->from_float) {
295
        // casting non-quantized types --> intermediate f32 --> quantized
296
0
        ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
297
0
        float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
298
299
0
        size_t id = 0;
300
0
        size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
301
0
        char * dst_ptr = (char *) dst->data;
302
303
0
        for (int i03 = 0; i03 < ne03; i03++) {
304
0
            for (int i02 = 0; i02 < ne02; i02++) {
305
0
                id += rs * ir0;
306
0
                for (int i01 = ir0; i01 < ir1; i01++) {
307
0
                    const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
308
309
0
                    for (int i00 = 0; i00 < ne00; i00++) {
310
0
                        src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
311
0
                    }
312
313
0
                    quantize_row_q(src0_f32, dst_ptr + id, ne00);
314
0
                    id += rs;
315
0
                }
316
0
                id += rs * (ne01 - ir1);
317
0
            }
318
0
        }
319
0
    } else {
320
        // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
321
0
        GGML_ABORT("not implemented");
322
0
    }
323
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_to_q<unsigned short>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_to_q<ggml_bf16_t>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_to_q<float>(ggml_compute_params const*, ggml_tensor*)
324
325
// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
326
static void ggml_compute_forward_dup_bytes(
327
        const ggml_compute_params * params,
328
0
        ggml_tensor * dst) {
329
0
    const ggml_tensor * src0 = dst->src[0];
330
331
0
    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
332
0
    GGML_ASSERT(src0->type == dst->type);
333
334
0
    GGML_TENSOR_UNARY_OP_LOCALS;
335
336
0
    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
337
0
        ggml_compute_forward_dup_same_cont(params, dst);
338
0
        return;
339
0
    }
340
341
0
    const size_t type_size = ggml_type_size(src0->type);
342
343
0
    const int ith = params->ith; // thread index
344
0
    const int nth = params->nth; // number of threads
345
346
    // parallelize by rows
347
0
    const int nr = ne01;
348
    // number of rows per thread
349
0
    const int dr = (nr + nth - 1) / nth;
350
    // row range for this thread
351
0
    const int ir0 = dr * ith;
352
0
    const int ir1 = MIN(ir0 + dr, nr);
353
354
0
    if (src0->type == dst->type &&
355
0
        ggml_are_same_shape(src0, dst) &&
356
0
        nb00 == type_size && nb0 == type_size) {
357
        // copy by rows
358
0
        const size_t rs = ggml_row_size(src0->type, ne00);
359
0
        for (int64_t i03 = 0; i03 < ne03; i03++) {
360
0
            for (int64_t i02 = 0; i02 < ne02; i02++) {
361
0
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
362
0
                    memcpy(
363
0
                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
364
0
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
365
0
                        rs);
366
0
                }
367
0
            }
368
0
        }
369
0
        return;
370
0
    }
371
372
0
    if (ggml_is_contiguous(dst)) {
373
0
        size_t id = 0;
374
0
        char * dst_ptr = (char *) dst->data;
375
0
        const size_t rs = ne00 * type_size;
376
377
0
        if (nb00 == type_size) {
378
            // src0 is contiguous on first dimension, copy by rows
379
0
            for (int64_t i03 = 0; i03 < ne03; i03++) {
380
0
                for (int64_t i02 = 0; i02 < ne02; i02++) {
381
0
                    id += rs * ir0;
382
0
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
383
0
                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
384
0
                        memcpy(dst_ptr + id, src0_ptr, rs);
385
0
                        id += rs;
386
0
                    }
387
0
                    id += rs * (ne01 - ir1);
388
0
                }
389
0
            }
390
0
        } else {
391
            //printf("%s: this is not optimal - fix me\n", __func__);
392
393
0
            for (int64_t i03 = 0; i03 < ne03; i03++) {
394
0
                for (int64_t i02 = 0; i02 < ne02; i02++) {
395
0
                    id += rs * ir0;
396
0
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
397
0
                        for (int64_t i00 = 0; i00 < ne00; i00++) {
398
0
                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
399
0
                            memcpy(dst_ptr + id, src0_ptr, type_size);
400
401
0
                            id += type_size;
402
0
                        }
403
0
                    }
404
0
                    id += rs * (ne01 - ir1);
405
0
                }
406
0
            }
407
0
        }
408
409
0
        return;
410
0
    }
411
412
    // dst counters
413
0
    int64_t k10 = 0;
414
0
    int64_t i11 = 0;
415
0
    int64_t i12 = 0;
416
0
    int64_t i13 = 0;
417
418
    // number of blocks in a row
419
0
    const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
420
0
    const int64_t nk0  = ne0  / ggml_blck_size(dst->type);
421
422
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
423
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
424
0
            k10 += nk00 * ir0;
425
0
            while (k10 >= nk0) {
426
0
                k10 -= nk0;
427
0
                if (++i11 == ne1) {
428
0
                    i11 = 0;
429
0
                    if (++i12 == ne2) {
430
0
                        i12 = 0;
431
0
                        if (++i13 == ne3) {
432
0
                            i13 = 0;
433
0
                        }
434
0
                    }
435
0
                }
436
0
            }
437
0
            for (int64_t i01 = ir0; i01 < ir1; i01++) {
438
0
                for (int64_t k00 = 0; k00 < nk00; k00++) {
439
0
                    const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
440
0
                          char * dst_ptr  = ((char *)  dst->data + k10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
441
442
0
                    memcpy(dst_ptr, src0_ptr, type_size);
443
444
0
                    if (++k10 == nk0) {
445
0
                        k10 = 0;
446
0
                        if (++i11 == ne1) {
447
0
                            i11 = 0;
448
0
                            if (++i12 == ne2) {
449
0
                                i12 = 0;
450
0
                                if (++i13 == ne3) {
451
0
                                    i13 = 0;
452
0
                                }
453
0
                            }
454
0
                        }
455
0
                    }
456
0
                }
457
0
            }
458
0
            k10 += nk00 * (ne01 - ir1);
459
0
            while (k10 >= nk0) {
460
0
                k10 -= nk0;
461
0
                if (++i11 == ne1) {
462
0
                    i11 = 0;
463
0
                    if (++i12 == ne2) {
464
0
                        i12 = 0;
465
0
                        if (++i13 == ne3) {
466
0
                            i13 = 0;
467
0
                        }
468
0
                    }
469
0
                }
470
0
            }
471
0
        }
472
0
    }
473
0
}
474
475
static void ggml_compute_forward_dup_from_q(
476
        const ggml_compute_params * params,
477
0
              ggml_tensor * dst) {
478
479
0
    const ggml_tensor * src0 = dst->src[0];
480
0
    const ggml_tensor * src1 = dst->src[1];
481
482
0
    GGML_TENSOR_BINARY_OP_LOCALS
483
484
0
    const ggml_type type = src0->type;
485
0
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
486
487
0
    size_t qk = ggml_blck_size(type);
488
0
    const int64_t nr = ggml_nelements(src1) / qk;
489
490
    // destination must be contiguous in the first dimension
491
0
    GGML_ASSERT(nb10 == ggml_type_size(dst->type));
492
    // must either have first dimension large enough to hold a row, or fully contiguous
493
0
    GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));
494
495
0
    const int ith = params->ith;
496
0
    const int nth = params->nth;
497
498
0
    const int dr = (nr + nth - 1)/nth;
499
500
    // row range for this thread
501
0
    const int ir0 = dr*ith;
502
0
    const int ir1 = MIN(ir0 + dr, nr);
503
504
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
505
506
0
        uint32_t i = ir * qk;
507
508
0
        const int64_t i03 = i/(ne00 * ne01 * ne02);
509
0
        const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
510
0
        const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
511
0
        const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
512
0
        const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
513
514
0
        const int64_t i13 = i/(ne10 * ne11 * ne12);
515
0
        const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
516
0
        const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
517
0
        const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
518
0
        const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
519
520
0
        dequantize_row_q(
521
0
                (const void *) ((char *) src0->data + x_offset),
522
0
                     (float *) ((char *)  dst->data + dst_offset), qk);
523
0
    }
524
0
}
525
526
void ggml_compute_forward_dup(
527
        const ggml_compute_params * params,
528
0
        ggml_tensor * dst) {
529
530
0
    const ggml_tensor * src0 = dst->src[0];
531
532
0
    if (src0->type == dst->type) {
533
0
        ggml_compute_forward_dup_bytes(params, dst);
534
0
        return;
535
0
    }
536
537
0
    switch (src0->type) {
538
0
        case GGML_TYPE_F16:
539
0
            {
540
0
                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
541
0
                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
542
0
                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_fp16_t, float      >(params, dst);
543
0
                else ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
544
0
            } break;
545
0
        case GGML_TYPE_BF16:
546
0
            {
547
0
                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
548
0
                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
549
0
                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_bf16_t, float      >(params, dst);
550
0
                else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
551
0
            } break;
552
0
        case GGML_TYPE_F32:
553
0
            {
554
0
                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
555
0
                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
556
0
                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<float, float      >(params, dst);
557
0
                else if (dst->type == GGML_TYPE_I32)  ggml_compute_forward_dup_flt<float, int32_t    >(params, dst);
558
0
                else ggml_compute_forward_dup_to_q<float>(params, dst);
559
0
            } break;
560
0
        case GGML_TYPE_I32:
561
0
            {
562
0
                if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
563
0
                else GGML_ABORT("not implemented");
564
0
            } break;
565
0
        default:
566
0
            {
567
0
                if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
568
0
                    ggml_compute_forward_dup_from_q(params, dst);
569
0
                    break;
570
0
                }
571
0
                GGML_ABORT("fatal error");
572
0
            }
573
0
    }
574
0
}
575
576
// ggml_compute_forward_add
577
578
static void ggml_compute_forward_add_q_f32(
579
        const ggml_compute_params * params,
580
0
        ggml_tensor * dst) {
581
582
0
    const ggml_tensor * src0 = dst->src[0];
583
0
    const ggml_tensor * src1 = dst->src[1];
584
585
0
    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
586
587
0
    const int nr  = ggml_nrows(src0);
588
589
0
    GGML_TENSOR_BINARY_OP_LOCALS
590
591
0
    const int ith = params->ith;
592
0
    const int nth = params->nth;
593
594
0
    const ggml_type type = src0->type;
595
0
    const ggml_type dtype = dst->type;
596
0
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
597
0
    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dtype)->from_float;
598
599
    // we don't support permuted src0 or src1
600
0
    GGML_ASSERT(nb00 == ggml_type_size(type));
601
0
    GGML_ASSERT(nb10 == sizeof(float));
602
603
    // dst cannot be transposed or permuted
604
0
    GGML_ASSERT(nb0 <= nb1);
605
0
    GGML_ASSERT(nb1 <= nb2);
606
0
    GGML_ASSERT(nb2 <= nb3);
607
608
0
    GGML_ASSERT(ggml_is_quantized(src0->type));
609
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
610
611
    // rows per thread
612
0
    const int dr = (nr + nth - 1)/nth;
613
614
    // row range for this thread
615
0
    const int ir0 = dr*ith;
616
0
    const int ir1 = MIN(ir0 + dr, nr);
617
618
0
    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
619
620
0
    for (int ir = ir0; ir < ir1; ++ir) {
621
        // src0 indices
622
0
        const int i03 = ir/(ne02*ne01);
623
0
        const int i02 = (ir - i03*ne02*ne01)/ne01;
624
0
        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
625
626
        // src1 and dst are same shape as src0 => same indices
627
0
        const int i13 = i03;
628
0
        const int i12 = i02;
629
0
        const int i11 = i01;
630
631
0
        const int i3 = i03;
632
0
        const int i2 = i02;
633
0
        const int i1 = i01;
634
635
0
        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
636
0
        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
637
0
        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
638
639
0
        assert(ne00 % 32 == 0);
640
641
        // unquantize row from src0 to temp buffer
642
0
        dequantize_row_q(src0_row, wdata, ne00);
643
        // add src1
644
0
        ggml_vec_acc_f32(ne00, wdata, src1_row);
645
        // quantize row to dst
646
0
        if (quantize_row_q != NULL) {
647
0
            quantize_row_q(wdata, dst_row, ne00);
648
0
        } else {
649
0
            memcpy(dst_row, wdata, ne0*nb0);
650
0
        }
651
0
    }
652
0
}
653
654
void ggml_compute_forward_add(
655
        const ggml_compute_params * params,
656
0
        ggml_tensor * dst) {
657
658
0
    const ggml_tensor * src0 = dst->src[0];
659
660
0
    switch (src0->type) {
661
0
        case GGML_TYPE_F32:
662
0
        case GGML_TYPE_F16:
663
0
        case GGML_TYPE_BF16:
664
0
            {
665
0
                ggml_compute_forward_add_non_quantized(params, dst);
666
0
            } break;
667
0
        case GGML_TYPE_Q1_0:
668
0
        case GGML_TYPE_Q4_0:
669
0
        case GGML_TYPE_Q4_1:
670
0
        case GGML_TYPE_Q5_0:
671
0
        case GGML_TYPE_Q5_1:
672
0
        case GGML_TYPE_Q8_0:
673
0
        case GGML_TYPE_MXFP4:
674
0
        case GGML_TYPE_NVFP4:
675
0
        case GGML_TYPE_Q2_K:
676
0
        case GGML_TYPE_Q3_K:
677
0
        case GGML_TYPE_Q4_K:
678
0
        case GGML_TYPE_Q5_K:
679
0
        case GGML_TYPE_Q6_K:
680
0
        case GGML_TYPE_TQ1_0:
681
0
        case GGML_TYPE_TQ2_0:
682
0
        case GGML_TYPE_IQ2_XXS:
683
0
        case GGML_TYPE_IQ2_XS:
684
0
        case GGML_TYPE_IQ3_XXS:
685
0
        case GGML_TYPE_IQ1_S:
686
0
        case GGML_TYPE_IQ1_M:
687
0
        case GGML_TYPE_IQ4_NL:
688
0
        case GGML_TYPE_IQ4_XS:
689
0
        case GGML_TYPE_IQ3_S:
690
0
        case GGML_TYPE_IQ2_S:
691
0
            {
692
0
                ggml_compute_forward_add_q_f32(params, dst);
693
0
            } break;
694
0
        default:
695
0
            {
696
0
                GGML_ABORT("fatal error");
697
0
            }
698
0
    }
699
0
}
700
701
// ggml_compute_forward_add_id
702
703
static void ggml_compute_forward_add_id_f32(
704
        const ggml_compute_params * params,
705
0
        ggml_tensor * dst) {
706
707
0
    const ggml_tensor * src0 = dst->src[0];
708
0
    const ggml_tensor * src1 = dst->src[1];
709
0
    const ggml_tensor * src2 = dst->src[2];
710
711
0
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
712
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
713
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
714
0
    GGML_ASSERT(src2->type == GGML_TYPE_I32);
715
716
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
717
0
    GGML_ASSERT(src1->nb[0] == sizeof(float));
718
719
0
    const int ith = params->ith;
720
0
    const int nth = params->nth;
721
722
0
    const int nr  = ggml_nrows(src0);
723
724
0
    GGML_TENSOR_TERNARY_OP_LOCALS
725
726
0
    GGML_ASSERT( nb0 == sizeof(float));
727
0
    GGML_ASSERT(nb10 == sizeof(float));
728
729
    // rows per thread
730
0
    const int dr = (nr + nth - 1)/nth;
731
732
    // row range for this thread
733
0
    const int ir0 = dr*ith;
734
0
    const int ir1 = MIN(ir0 + dr, nr);
735
736
0
    for (int ir = ir0; ir < ir1; ++ir) {
737
        // src0 indices
738
0
        const int i3 = ir/(ne2*ne1);
739
0
        const int i2 = (ir - i3*ne2*ne1)/ne1;
740
0
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
741
742
        // src1 indices
743
0
        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
744
745
0
        GGML_ASSERT(i11 >= 0 && i11 < ne11);
746
747
0
        ggml_vec_add_f32(ne0,
748
0
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
749
0
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
750
0
                (float *) ((char *) src1->data + i11*nb11));
751
0
    }
752
0
}
753
754
void ggml_compute_forward_add_id(
755
        const ggml_compute_params * params,
756
0
        ggml_tensor * dst) {
757
758
0
    const ggml_tensor * src0 = dst->src[0];
759
760
0
    switch (src0->type) {
761
0
        case GGML_TYPE_F32:
762
0
            {
763
0
                ggml_compute_forward_add_id_f32(params, dst);
764
0
            } break;
765
0
        default:
766
0
            {
767
0
                GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
768
0
            }
769
0
    }
770
0
}
771
772
// ggml_compute_forward_add1
773
774
static void ggml_compute_forward_add1_f32(
775
        const ggml_compute_params * params,
776
0
        ggml_tensor * dst) {
777
778
0
    const ggml_tensor * src0 = dst->src[0];
779
0
    const ggml_tensor * src1 = dst->src[1];
780
781
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
782
0
    GGML_ASSERT(ggml_is_scalar(src1));
783
784
0
    const int ith = params->ith;
785
0
    const int nth = params->nth;
786
787
0
    const int nr  = ggml_nrows(src0);
788
789
0
    GGML_TENSOR_UNARY_OP_LOCALS
790
791
0
    GGML_ASSERT( nb0 == sizeof(float));
792
0
    GGML_ASSERT(nb00 == sizeof(float));
793
794
    // rows per thread
795
0
    const int dr = (nr + nth - 1)/nth;
796
797
    // row range for this thread
798
0
    const int ir0 = dr*ith;
799
0
    const int ir1 = MIN(ir0 + dr, nr);
800
801
0
    for (int ir = ir0; ir < ir1; ++ir) {
802
        // src0 and dst are same shape => same indices
803
0
        const int i3 = ir/(ne2*ne1);
804
0
        const int i2 = (ir - i3*ne2*ne1)/ne1;
805
0
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
806
807
#ifdef GGML_USE_ACCELERATE
808
        GGML_UNUSED(ggml_vec_add1_f32);
809
810
        vDSP_vadd(
811
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
812
                (float *) ((char *) src1->data), 0,
813
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
814
                ne0);
815
#else
816
0
        ggml_vec_add1_f32(ne0,
817
0
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
818
0
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
819
0
               *(float *) src1->data);
820
0
#endif
821
0
    }
822
0
}
823
824
static void ggml_compute_forward_add1_f16_f32(
825
        const ggml_compute_params * params,
826
0
        ggml_tensor * dst) {
827
828
0
    const ggml_tensor * src0 = dst->src[0];
829
0
    const ggml_tensor * src1 = dst->src[1];
830
831
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
832
0
    GGML_ASSERT(ggml_is_scalar(src1));
833
834
    // scalar to add
835
0
    const float v = *(float *) src1->data;
836
837
0
    const int ith = params->ith;
838
0
    const int nth = params->nth;
839
840
0
    const int nr  = ggml_nrows(src0);
841
842
0
    GGML_TENSOR_UNARY_OP_LOCALS
843
844
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
845
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
846
0
    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
847
848
0
    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
849
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
850
851
    // rows per thread
852
0
    const int dr = (nr + nth - 1)/nth;
853
854
    // row range for this thread
855
0
    const int ir0 = dr*ith;
856
0
    const int ir1 = MIN(ir0 + dr, nr);
857
858
0
    for (int ir = ir0; ir < ir1; ++ir) {
859
        // src0 and dst are same shape => same indices
860
0
        const int i3 = ir/(ne2*ne1);
861
0
        const int i2 = (ir - i3*ne2*ne1)/ne1;
862
0
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
863
864
0
        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
865
0
        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
866
0
        for (int i = 0; i < ne0; i++) {
867
0
            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
868
0
        }
869
0
    }
870
0
}
871
872
static void ggml_compute_forward_add1_f16_f16(
873
        const ggml_compute_params * params,
874
0
        ggml_tensor * dst) {
875
876
0
    const ggml_tensor * src0 = dst->src[0];
877
0
    const ggml_tensor * src1 = dst->src[1];
878
879
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
880
0
    GGML_ASSERT(ggml_is_scalar(src1));
881
882
    // scalar to add
883
0
    const float v = GGML_CPU_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
884
885
0
    const int ith = params->ith;
886
0
    const int nth = params->nth;
887
888
0
    const int nr  = ggml_nrows(src0);
889
890
0
    GGML_TENSOR_UNARY_OP_LOCALS
891
892
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
893
0
    GGML_ASSERT(src1->type == GGML_TYPE_F16);
894
0
    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
895
896
0
    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
897
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
898
899
    // rows per thread
900
0
    const int dr = (nr + nth - 1)/nth;
901
902
    // row range for this thread
903
0
    const int ir0 = dr*ith;
904
0
    const int ir1 = MIN(ir0 + dr, nr);
905
906
0
    for (int ir = ir0; ir < ir1; ++ir) {
907
        // src0 and dst are same shape => same indices
908
0
        const int i3 = ir/(ne2*ne1);
909
0
        const int i2 = (ir - i3*ne2*ne1)/ne1;
910
0
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
911
912
0
        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
913
0
        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
914
0
        for (int i = 0; i < ne0; i++) {
915
0
            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
916
0
        }
917
0
    }
918
0
}
919
920
static void ggml_compute_forward_add1_q_f32(
921
        const ggml_compute_params * params,
922
0
        ggml_tensor * dst) {
923
924
0
    const ggml_tensor * src0 = dst->src[0];
925
0
    const ggml_tensor * src1 = dst->src[1];
926
927
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
928
0
    GGML_ASSERT(ggml_is_scalar(src1));
929
930
    // scalar to add
931
0
    const float v = *(float *) src1->data;
932
933
0
    const int ith = params->ith;
934
0
    const int nth = params->nth;
935
936
0
    const int nr  = ggml_nrows(src0);
937
938
0
    GGML_TENSOR_UNARY_OP_LOCALS
939
940
0
    const ggml_type type = src0->type;
941
0
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
942
0
    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type)->from_float;
943
944
    // we don't support permuted src0
945
0
    GGML_ASSERT(nb00 == ggml_type_size(type));
946
947
    // dst cannot be transposed or permuted
948
0
    GGML_ASSERT(nb0 <= nb1);
949
0
    GGML_ASSERT(nb1 <= nb2);
950
0
    GGML_ASSERT(nb2 <= nb3);
951
952
0
    GGML_ASSERT(ggml_is_quantized(src0->type));
953
0
    GGML_ASSERT(dst->type == src0->type);
954
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
955
956
    // rows per thread
957
0
    const int dr = (nr + nth - 1)/nth;
958
959
    // row range for this thread
960
0
    const int ir0 = dr*ith;
961
0
    const int ir1 = MIN(ir0 + dr, nr);
962
963
0
    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
964
965
0
    for (int ir = ir0; ir < ir1; ++ir) {
966
        // src0 and dst are same shape => same indices
967
0
        const int i3 = ir/(ne2*ne1);
968
0
        const int i2 = (ir - i3*ne2*ne1)/ne1;
969
0
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
970
971
0
        void  * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
972
0
        void  * dst_row  = (void *) ((char *)  dst->data + (i1*nb1  + i2*nb2  + i3*nb0 ));
973
974
0
        assert(ne0 % 32 == 0);
975
976
        // unquantize row from src0 to temp buffer
977
0
        dequantize_row_q(src0_row, wdata, ne0);
978
        // add src1
979
0
        ggml_vec_acc1_f32(ne0, wdata, v);
980
        // quantize row to dst
981
0
        quantize_row_q(wdata, dst_row, ne0);
982
0
    }
983
0
}
984
985
static void ggml_compute_forward_add1_bf16_f32(
986
        const ggml_compute_params * params,
987
0
        ggml_tensor * dst) {
988
989
0
    const ggml_tensor * src0 = dst->src[0];
990
0
    const ggml_tensor * src1 = dst->src[1];
991
992
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
993
0
    GGML_ASSERT(ggml_is_scalar(src1));
994
995
    // scalar to add
996
0
    const float v = *(float *) src1->data;
997
998
0
    const int ith = params->ith;
999
0
    const int nth = params->nth;
1000
1001
0
    const int nr  = ggml_nrows(src0);
1002
1003
0
    GGML_TENSOR_UNARY_OP_LOCALS
1004
1005
0
    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
1006
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
1007
0
    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
1008
1009
0
    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
1010
0
    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
1011
1012
    // rows per thread
1013
0
    const int dr = (nr + nth - 1)/nth;
1014
1015
    // row range for this thread
1016
0
    const int ir0 = dr*ith;
1017
0
    const int ir1 = MIN(ir0 + dr, nr);
1018
1019
0
    for (int ir = ir0; ir < ir1; ++ir) {
1020
        // src0 and dst are same shape => same indices
1021
0
        const int i3 = ir/(ne2*ne1);
1022
0
        const int i2 = (ir - i3*ne2*ne1)/ne1;
1023
0
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
1024
1025
0
        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
1026
0
        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
1027
0
        for (int i = 0; i < ne0; i++) {
1028
0
            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
1029
0
        }
1030
0
    }
1031
0
}
1032
1033
static void ggml_compute_forward_add1_bf16_bf16(
1034
        const ggml_compute_params * params,
1035
0
        ggml_tensor * dst) {
1036
1037
0
    const ggml_tensor * src0 = dst->src[0];
1038
0
    const ggml_tensor * src1 = dst->src[1];
1039
1040
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
1041
0
    GGML_ASSERT(ggml_is_scalar(src1));
1042
1043
    // scalar to add
1044
0
    const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
1045
1046
0
    const int ith = params->ith;
1047
0
    const int nth = params->nth;
1048
1049
0
    const int nr  = ggml_nrows(src0);
1050
1051
0
    GGML_TENSOR_UNARY_OP_LOCALS
1052
1053
0
    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
1054
0
    GGML_ASSERT(src1->type == GGML_TYPE_BF16);
1055
0
    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
1056
1057
0
    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
1058
0
    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
1059
1060
    // rows per thread
1061
0
    const int dr = (nr + nth - 1)/nth;
1062
1063
    // row range for this thread
1064
0
    const int ir0 = dr*ith;
1065
0
    const int ir1 = MIN(ir0 + dr, nr);
1066
1067
0
    for (int ir = ir0; ir < ir1; ++ir) {
1068
        // src0 and dst are same shape => same indices
1069
0
        const int i3 = ir/(ne2*ne1);
1070
0
        const int i2 = (ir - i3*ne2*ne1)/ne1;
1071
0
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
1072
1073
0
        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
1074
0
        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
1075
0
        for (int i = 0; i < ne0; i++) {
1076
0
            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
1077
0
        }
1078
0
    }
1079
0
}
1080
1081
void ggml_compute_forward_add1(
1082
        const ggml_compute_params * params,
1083
0
        ggml_tensor * dst) {
1084
1085
0
    const ggml_tensor * src0 = dst->src[0];
1086
0
    const ggml_tensor * src1 = dst->src[1];
1087
1088
0
    switch (src0->type) {
1089
0
        case GGML_TYPE_F32:
1090
0
            {
1091
0
                ggml_compute_forward_add1_f32(params, dst);
1092
0
            } break;
1093
0
        case GGML_TYPE_F16:
1094
0
            {
1095
0
                if (src1->type == GGML_TYPE_F16) {
1096
0
                    ggml_compute_forward_add1_f16_f16(params, dst);
1097
0
                }
1098
0
                else if (src1->type == GGML_TYPE_F32) {
1099
0
                    ggml_compute_forward_add1_f16_f32(params, dst);
1100
0
                }
1101
0
                else {
1102
0
                    GGML_ABORT("fatal error");
1103
0
                }
1104
0
            } break;
1105
0
        case GGML_TYPE_BF16:
1106
0
            {
1107
0
                if (src1->type == GGML_TYPE_BF16) {
1108
0
                    ggml_compute_forward_add1_bf16_bf16(params, dst);
1109
0
                }
1110
0
                else if (src1->type == GGML_TYPE_F32) {
1111
0
                    ggml_compute_forward_add1_bf16_f32(params, dst);
1112
0
                }
1113
0
                else {
1114
0
                    GGML_ABORT("fatal error");
1115
0
                }
1116
0
            } break;
1117
0
        case GGML_TYPE_Q1_0:
1118
0
        case GGML_TYPE_Q4_0:
1119
0
        case GGML_TYPE_Q4_1:
1120
0
        case GGML_TYPE_Q5_0:
1121
0
        case GGML_TYPE_Q5_1:
1122
0
        case GGML_TYPE_Q8_0:
1123
0
        case GGML_TYPE_Q8_1:
1124
0
        case GGML_TYPE_MXFP4:
1125
0
        case GGML_TYPE_NVFP4:
1126
0
        case GGML_TYPE_Q2_K:
1127
0
        case GGML_TYPE_Q3_K:
1128
0
        case GGML_TYPE_Q4_K:
1129
0
        case GGML_TYPE_Q5_K:
1130
0
        case GGML_TYPE_Q6_K:
1131
0
        case GGML_TYPE_TQ1_0:
1132
0
        case GGML_TYPE_TQ2_0:
1133
0
        case GGML_TYPE_IQ2_XXS:
1134
0
        case GGML_TYPE_IQ2_XS:
1135
0
        case GGML_TYPE_IQ3_XXS:
1136
0
        case GGML_TYPE_IQ1_S:
1137
0
        case GGML_TYPE_IQ1_M:
1138
0
        case GGML_TYPE_IQ4_NL:
1139
0
        case GGML_TYPE_IQ4_XS:
1140
0
        case GGML_TYPE_IQ3_S:
1141
0
        case GGML_TYPE_IQ2_S:
1142
0
            {
1143
0
                ggml_compute_forward_add1_q_f32(params, dst);
1144
0
            } break;
1145
0
        default:
1146
0
            {
1147
0
                GGML_ABORT("fatal error");
1148
0
            }
1149
0
    }
1150
0
}
1151
1152
// ggml_compute_forward_acc
1153
1154
static void ggml_compute_forward_acc_f32(
1155
        const ggml_compute_params * params,
1156
0
        ggml_tensor * dst) {
1157
1158
0
    const ggml_tensor * src0 = dst->src[0];
1159
0
    const ggml_tensor * src1 = dst->src[1];
1160
1161
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
1162
0
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
1163
1164
    // view src0 and dst with these strides and data offset inbytes during acc
1165
    // nb0 is implicitly element_size because src0 and dst are contiguous
1166
0
    size_t nb1     = ((int32_t *) dst->op_params)[0];
1167
0
    size_t nb2     = ((int32_t *) dst->op_params)[1];
1168
0
    size_t nb3     = ((int32_t *) dst->op_params)[2];
1169
0
    size_t offset  = ((int32_t *) dst->op_params)[3];
1170
0
    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
1171
1172
0
    if (!inplace) {
1173
0
        if (params->ith == 0) {
1174
            // memcpy needs to be synchronized across threads to avoid race conditions.
1175
            // => do it in INIT phase
1176
0
            memcpy(
1177
0
                ((char *)  dst->data),
1178
0
                ((char *) src0->data),
1179
0
                ggml_nbytes(dst));
1180
0
        }
1181
0
        ggml_barrier(params->threadpool);
1182
0
    }
1183
1184
0
    const int ith = params->ith;
1185
0
    const int nth = params->nth;
1186
1187
0
    const int nr = ggml_nrows(src1);
1188
0
    const int nc = src1->ne[0];
1189
1190
0
    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
1191
0
    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
1192
1193
    // src0 and dst as viewed during acc
1194
0
    const size_t nb0 = ggml_element_size(src0);
1195
1196
0
    const size_t nb00 = nb0;
1197
0
    const size_t nb01 = nb1;
1198
0
    const size_t nb02 = nb2;
1199
0
    const size_t nb03 = nb3;
1200
1201
0
    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0  + (ne11 == 0 ? 0 : ne11-1)*nb1  + (ne12 == 0 ? 0 : ne12-1)*nb2  + (ne13 == 0 ? 0 : ne13-1)*nb3  < ggml_nbytes(dst));
1202
0
    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
1203
1204
0
    GGML_ASSERT(nb10 == sizeof(float));
1205
1206
    // rows per thread
1207
0
    const int dr = (nr + nth - 1)/nth;
1208
1209
    // row range for this thread
1210
0
    const int ir0 = dr*ith;
1211
0
    const int ir1 = MIN(ir0 + dr, nr);
1212
1213
0
    for (int ir = ir0; ir < ir1; ++ir) {
1214
        // src0 and dst are viewed with shape of src1 and offset
1215
        // => same indices
1216
0
        const int i3 = ir/(ne12*ne11);
1217
0
        const int i2 = (ir - i3*ne12*ne11)/ne11;
1218
0
        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
1219
1220
#ifdef GGML_USE_ACCELERATE
1221
        vDSP_vadd(
1222
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
1223
                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
1224
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + offset), 1, nc);
1225
#else
1226
0
        ggml_vec_add_f32(nc,
1227
0
                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
1228
0
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
1229
0
                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
1230
0
#endif
1231
0
    }
1232
0
}
1233
1234
void ggml_compute_forward_acc(
1235
        const ggml_compute_params * params,
1236
0
        ggml_tensor * dst) {
1237
1238
0
    const ggml_tensor * src0 = dst->src[0];
1239
1240
0
    switch (src0->type) {
1241
0
        case GGML_TYPE_F32:
1242
0
            {
1243
0
                ggml_compute_forward_acc_f32(params, dst);
1244
0
            } break;
1245
0
        case GGML_TYPE_F16:
1246
0
        case GGML_TYPE_BF16:
1247
0
        case GGML_TYPE_Q1_0:
1248
0
        case GGML_TYPE_Q4_0:
1249
0
        case GGML_TYPE_Q4_1:
1250
0
        case GGML_TYPE_Q5_0:
1251
0
        case GGML_TYPE_Q5_1:
1252
0
        case GGML_TYPE_Q8_0:
1253
0
        case GGML_TYPE_Q8_1:
1254
0
        case GGML_TYPE_MXFP4:
1255
0
        case GGML_TYPE_NVFP4:
1256
0
        case GGML_TYPE_Q2_K:
1257
0
        case GGML_TYPE_Q3_K:
1258
0
        case GGML_TYPE_Q4_K:
1259
0
        case GGML_TYPE_Q5_K:
1260
0
        case GGML_TYPE_Q6_K:
1261
0
        case GGML_TYPE_TQ1_0:
1262
0
        case GGML_TYPE_TQ2_0:
1263
0
        case GGML_TYPE_IQ2_XXS:
1264
0
        case GGML_TYPE_IQ2_XS:
1265
0
        case GGML_TYPE_IQ3_XXS:
1266
0
        case GGML_TYPE_IQ1_S:
1267
0
        case GGML_TYPE_IQ1_M:
1268
0
        case GGML_TYPE_IQ4_NL:
1269
0
        case GGML_TYPE_IQ4_XS:
1270
0
        case GGML_TYPE_IQ3_S:
1271
0
        case GGML_TYPE_IQ2_S:
1272
0
        default:
1273
0
            {
1274
0
                GGML_ABORT("fatal error");
1275
0
            }
1276
0
    }
1277
0
}
1278
1279
// ggml_compute_forward_sum
1280
1281
static void ggml_compute_forward_sum_f32(
1282
        const ggml_compute_params * params,
1283
0
        ggml_tensor * dst) {
1284
1285
0
    const ggml_tensor * src0 = dst->src[0];
1286
1287
0
    if (params->ith != 0) {
1288
0
        return;
1289
0
    }
1290
1291
0
    assert(ggml_is_scalar(dst));
1292
0
    assert(src0->nb[0] == sizeof(float));
1293
1294
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
1295
0
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
1296
1297
0
    ggml_float sum     = 0;
1298
0
    ggml_float row_sum = 0;
1299
1300
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
1301
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
1302
0
            for (int64_t i01 = 0; i01 < ne01; i01++) {
1303
0
                ggml_vec_sum_f32_ggf(ne00,
1304
0
                        &row_sum,
1305
0
                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
1306
0
                sum += row_sum;
1307
0
            }
1308
0
        }
1309
0
    }
1310
0
    ((float *) dst->data)[0] = sum;
1311
0
}
1312
1313
static void ggml_compute_forward_sum_f16(
1314
    const ggml_compute_params * params,
1315
0
          ggml_tensor * dst) {
1316
1317
0
    const ggml_tensor * src0 = dst->src[0];
1318
1319
0
    if (params->ith != 0) {
1320
0
        return;
1321
0
    }
1322
1323
0
    assert(ggml_is_scalar(dst));
1324
1325
0
    assert(src0->nb[0] == sizeof(ggml_fp16_t));
1326
1327
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
1328
0
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
1329
1330
0
    float sum = 0;
1331
0
    float row_sum = 0;
1332
1333
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
1334
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
1335
0
            for (int64_t i01 = 0; i01 < ne01; i01++) {
1336
0
                ggml_vec_sum_f16_ggf(ne00,
1337
0
                    &row_sum,
1338
0
                    (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
1339
0
                sum += row_sum;
1340
0
            }
1341
0
        }
1342
0
    }
1343
0
    ((ggml_fp16_t *) dst->data)[0] = GGML_CPU_FP32_TO_FP16(sum);
1344
0
}
1345
1346
static void ggml_compute_forward_sum_bf16(
1347
    const ggml_compute_params * params,
1348
0
          ggml_tensor * dst) {
1349
1350
0
    const ggml_tensor * src0 = dst->src[0];
1351
1352
0
    if (params->ith != 0) {
1353
0
        return;
1354
0
    }
1355
1356
0
    assert(ggml_is_scalar(dst));
1357
1358
0
    assert(src0->nb[0] == sizeof(ggml_bf16_t));
1359
1360
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
1361
0
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
1362
1363
0
    float sum = 0;
1364
0
    float row_sum = 0;
1365
1366
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
1367
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
1368
0
            for (int64_t i01 = 0; i01 < ne01; i01++) {
1369
0
                ggml_vec_sum_bf16_ggf(ne00,
1370
0
                    &row_sum,
1371
0
                    (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
1372
0
                sum += row_sum;
1373
0
            }
1374
0
        }
1375
0
    }
1376
0
    ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
1377
0
}
1378
1379
void ggml_compute_forward_sum(
1380
        const ggml_compute_params * params,
1381
0
        ggml_tensor * dst) {
1382
1383
0
    const ggml_tensor * src0 = dst->src[0];
1384
1385
0
    switch (src0->type) {
1386
0
        case GGML_TYPE_F32:
1387
0
            {
1388
0
                ggml_compute_forward_sum_f32(params, dst);
1389
0
            } break;
1390
0
        case GGML_TYPE_F16:
1391
0
            {
1392
0
                ggml_compute_forward_sum_f16(params, dst);
1393
0
            } break;
1394
0
        case GGML_TYPE_BF16:
1395
0
            {
1396
0
                ggml_compute_forward_sum_bf16(params, dst);
1397
0
            } break;
1398
0
        default:
1399
0
            {
1400
0
                GGML_ABORT("fatal error");
1401
0
            }
1402
0
    }
1403
0
}
1404
1405
// ggml_compute_forward_cumsum
1406
1407
static void ggml_compute_forward_cumsum_f32(
1408
        const ggml_compute_params * params,
1409
0
        ggml_tensor * dst) {
1410
1411
0
    const ggml_tensor * src0 = dst->src[0];
1412
1413
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
1414
0
    GGML_ASSERT(dst->nb[0] == sizeof(float));
1415
1416
0
    GGML_TENSOR_UNARY_OP_LOCALS
1417
1418
0
    GGML_ASSERT(ne0 == ne00);
1419
0
    GGML_ASSERT(ne1 == ne01);
1420
0
    GGML_ASSERT(ne2 == ne02);
1421
0
    GGML_ASSERT(ne3 == ne03);
1422
1423
0
    const auto [ir0, ir1] = get_thread_range(params, src0);
1424
1425
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
1426
0
        const int64_t i03 = ir/(ne02*ne01);
1427
0
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
1428
0
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
1429
1430
0
        float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
1431
0
        float * dst_row = (float *) ((char *) dst->data  + i01*nb1  + i02*nb2  + i03*nb3);
1432
1433
0
        ggml_vec_cumsum_f32(ne00, dst_row, src_row);
1434
0
    }
1435
0
}
1436
1437
void ggml_compute_forward_cumsum(
1438
        const ggml_compute_params * params,
1439
0
        ggml_tensor * dst) {
1440
1441
0
    const ggml_tensor * src0 = dst->src[0];
1442
1443
0
    switch (src0->type) {
1444
0
        case GGML_TYPE_F32:
1445
0
            {
1446
0
                ggml_compute_forward_cumsum_f32(params, dst);
1447
0
            } break;
1448
0
        default:
1449
0
            {
1450
0
                GGML_ABORT("fatal error");
1451
0
            }
1452
0
    }
1453
0
}
1454
1455
// ggml_compute_forward_sum_rows
1456
1457
static void ggml_compute_forward_sum_rows_f32(
1458
        const ggml_compute_params * params,
1459
0
        ggml_tensor * dst) {
1460
1461
0
    const ggml_tensor * src0 = dst->src[0];
1462
1463
0
    if (params->ith != 0) {
1464
0
        return;
1465
0
    }
1466
1467
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
1468
0
    GGML_ASSERT(dst->nb[0] == sizeof(float));
1469
1470
0
    GGML_TENSOR_UNARY_OP_LOCALS
1471
1472
0
    GGML_ASSERT(ne0 == 1);
1473
0
    GGML_ASSERT(ne1 == ne01);
1474
0
    GGML_ASSERT(ne2 == ne02);
1475
0
    GGML_ASSERT(ne3 == ne03);
1476
1477
0
    for (int64_t i3 = 0; i3 < ne03; i3++) {
1478
0
        for (int64_t i2 = 0; i2 < ne02; i2++) {
1479
0
            for (int64_t i1 = 0; i1 < ne01; i1++) {
1480
0
                float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
1481
0
                float * dst_row = (float *) ((char *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);
1482
0
                float row_sum = 0;
1483
0
                ggml_vec_sum_f32(ne00, &row_sum, src_row);
1484
0
                dst_row[0] = row_sum;
1485
0
            }
1486
0
        }
1487
0
    }
1488
0
}
1489
1490
void ggml_compute_forward_sum_rows(
1491
        const ggml_compute_params * params,
1492
0
        ggml_tensor * dst) {
1493
1494
0
    const ggml_tensor * src0 = dst->src[0];
1495
1496
0
    switch (src0->type) {
1497
0
        case GGML_TYPE_F32:
1498
0
            {
1499
0
                ggml_compute_forward_sum_rows_f32(params, dst);
1500
0
            } break;
1501
0
        default:
1502
0
            {
1503
0
                GGML_ABORT("fatal error");
1504
0
            }
1505
0
    }
1506
0
}
1507
1508
// ggml_compute_forward_mean
1509
1510
static void ggml_compute_forward_mean_f32(
1511
        const ggml_compute_params * params,
1512
0
        ggml_tensor * dst) {
1513
1514
0
    const ggml_tensor * src0 = dst->src[0];
1515
1516
0
    if (params->ith != 0) {
1517
0
        return;
1518
0
    }
1519
1520
0
    assert(src0->nb[0] == sizeof(float));
1521
1522
0
    GGML_TENSOR_UNARY_OP_LOCALS
1523
1524
0
    assert(ne0 == 1);
1525
0
    assert(ne1 == ne01);
1526
0
    assert(ne2 == ne02);
1527
0
    assert(ne3 == ne03);
1528
1529
0
    GGML_UNUSED(ne0);
1530
0
    GGML_UNUSED(ne1);
1531
0
    GGML_UNUSED(ne2);
1532
0
    GGML_UNUSED(ne3);
1533
1534
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
1535
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
1536
0
            for (int64_t i01 = 0; i01 < ne01; i01++) {
1537
0
                ggml_vec_sum_f32(ne00,
1538
0
                        (float *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
1539
0
                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
1540
1541
0
                *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
1542
0
            }
1543
0
        }
1544
0
    }
1545
0
}
1546
1547
void ggml_compute_forward_mean(
1548
        const ggml_compute_params * params,
1549
0
        ggml_tensor * dst) {
1550
1551
0
    const ggml_tensor * src0 = dst->src[0];
1552
1553
0
    switch (src0->type) {
1554
0
        case GGML_TYPE_F32:
1555
0
            {
1556
0
                ggml_compute_forward_mean_f32(params, dst);
1557
0
            } break;
1558
0
        default:
1559
0
            {
1560
0
                GGML_ABORT("fatal error");
1561
0
            }
1562
0
    }
1563
0
}
1564
1565
// ggml_compute_forward_argmax
1566
1567
static void ggml_compute_forward_argmax_f32(
1568
        const ggml_compute_params * params,
1569
0
        ggml_tensor * dst) {
1570
1571
0
    const ggml_tensor * src0 = dst->src[0];
1572
1573
0
    if (params->ith != 0) {
1574
0
        return;
1575
0
    }
1576
1577
0
    assert(src0->nb[0] == sizeof(float));
1578
0
    assert(dst->nb[0] == sizeof(float));
1579
1580
0
    const int64_t ne00 = src0->ne[0];
1581
0
    const int64_t ne01 = src0->ne[1];
1582
1583
0
    const size_t nb01 = src0->nb[1];
1584
0
    const size_t nb0 = dst->nb[0];
1585
1586
0
    for (int64_t i1 = 0; i1 < ne01; i1++) {
1587
0
        float * src = (float *) ((char *) src0->data + i1*nb01);
1588
0
        int32_t * dst_ = (int32_t *) ((char *)  dst->data + i1*nb0);
1589
0
        int v = 0;
1590
0
        ggml_vec_argmax_f32(ne00, &v, src);
1591
0
        dst_[0] = v;
1592
0
    }
1593
0
}
1594
1595
void ggml_compute_forward_argmax(
1596
        const ggml_compute_params * params,
1597
0
        ggml_tensor * dst) {
1598
1599
0
    const ggml_tensor * src0 = dst->src[0];
1600
1601
0
    switch (src0->type) {
1602
0
        case GGML_TYPE_F32:
1603
0
            {
1604
0
                ggml_compute_forward_argmax_f32(params, dst);
1605
0
            } break;
1606
0
        default:
1607
0
            {
1608
0
                GGML_ABORT("fatal error");
1609
0
            }
1610
0
    }
1611
0
}
1612
1613
// ggml_compute_forward_count_equal
1614
1615
static void ggml_compute_forward_count_equal_i32(
1616
        const ggml_compute_params * params,
1617
0
        ggml_tensor * dst) {
1618
1619
0
    const ggml_tensor * src0 = dst->src[0];
1620
0
    const ggml_tensor * src1 = dst->src[1];
1621
1622
0
    GGML_TENSOR_BINARY_OP_LOCALS;
1623
1624
0
    GGML_ASSERT(src0->type == GGML_TYPE_I32);
1625
0
    GGML_ASSERT(src1->type == GGML_TYPE_I32);
1626
0
    GGML_ASSERT(ggml_are_same_shape(src0, src1));
1627
0
    GGML_ASSERT(ggml_is_scalar(dst));
1628
0
    GGML_ASSERT(dst->type == GGML_TYPE_I64);
1629
1630
0
    const int64_t nr = ggml_nrows(src0);
1631
1632
0
    const int ith = params->ith;
1633
0
    const int nth = params->nth;
1634
1635
0
    int64_t * sums = (int64_t *) params->wdata;
1636
0
    int64_t sum_thread = 0;
1637
1638
    // rows per thread
1639
0
    const int64_t dr = (nr + nth - 1)/nth;
1640
1641
    // row range for this thread
1642
0
    const int64_t ir0 = dr*ith;
1643
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
1644
1645
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
1646
0
        const int64_t i03 =  ir                        / (ne02*ne01);
1647
0
        const int64_t i02 = (ir - i03*ne03)            /       ne01;
1648
0
        const int64_t i01 =  ir - i03*ne03 - i02*ne02;
1649
1650
0
        const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
1651
0
        const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
1652
1653
0
        for (int64_t i00 = 0; i00 < ne00; ++i00) {
1654
0
            const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
1655
0
            const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
1656
1657
0
            sum_thread += val0 == val1;
1658
0
        }
1659
0
    }
1660
0
    if (ith != 0) {
1661
0
        sums[ith] = sum_thread;
1662
0
    }
1663
0
    ggml_barrier(params->threadpool);
1664
1665
0
    if (ith != 0) {
1666
0
        return;
1667
0
    }
1668
1669
0
    for (int ith_other = 1; ith_other < nth; ++ith_other) {
1670
0
        sum_thread += sums[ith_other];
1671
0
    }
1672
0
    *((int64_t *) dst->data) = sum_thread;
1673
0
}
1674
1675
void ggml_compute_forward_count_equal(
1676
        const ggml_compute_params * params,
1677
0
        ggml_tensor * dst) {
1678
1679
0
    const ggml_tensor * src0 = dst->src[0];
1680
1681
0
    switch (src0->type) {
1682
0
        case GGML_TYPE_I32:
1683
0
            {
1684
0
                ggml_compute_forward_count_equal_i32(params, dst);
1685
0
            } break;
1686
0
        default:
1687
0
            {
1688
0
                GGML_ABORT("fatal error");
1689
0
            }
1690
0
    }
1691
0
}
1692
1693
// ggml_compute_forward_repeat
1694
1695
static void ggml_compute_forward_repeat_f32(
1696
        const ggml_compute_params * params,
1697
0
        ggml_tensor * dst) {
1698
1699
0
    const ggml_tensor * src0 = dst->src[0];
1700
1701
0
    if (params->ith != 0) {
1702
0
        return;
1703
0
    }
1704
1705
0
    GGML_ASSERT(ggml_can_repeat(src0, dst));
1706
1707
0
    GGML_TENSOR_UNARY_OP_LOCALS
1708
1709
    // guaranteed to be an integer due to the check in ggml_can_repeat
1710
0
    const int nr0 = (int)(ne0/ne00);
1711
0
    const int nr1 = (int)(ne1/ne01);
1712
0
    const int nr2 = (int)(ne2/ne02);
1713
0
    const int nr3 = (int)(ne3/ne03);
1714
1715
    // TODO: support for transposed / permuted tensors
1716
0
    GGML_ASSERT(nb0  == sizeof(float));
1717
0
    GGML_ASSERT(nb00 == sizeof(float));
1718
1719
    // TODO: maybe this is not optimal?
1720
0
    for                         (int i3 = 0; i3 < nr3;  i3++) {
1721
0
        for                     (int k3 = 0; k3 < ne03; k3++) {
1722
0
            for                 (int i2 = 0; i2 < nr2;  i2++) {
1723
0
                for             (int k2 = 0; k2 < ne02; k2++) {
1724
0
                    for         (int i1 = 0; i1 < nr1;  i1++) {
1725
0
                        for     (int k1 = 0; k1 < ne01; k1++) {
1726
0
                            for (int i0 = 0; i0 < nr0;  i0++) {
1727
0
                                ggml_vec_cpy_f32(ne00,
1728
0
                                        (float *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0),
1729
0
                                        (float *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01));
1730
0
                            }
1731
0
                        }
1732
0
                    }
1733
0
                }
1734
0
            }
1735
0
        }
1736
0
    }
1737
0
}
1738
1739
static void ggml_compute_forward_repeat_f16(
1740
        const ggml_compute_params * params,
1741
0
        ggml_tensor * dst) {
1742
1743
0
    const ggml_tensor * src0 = dst->src[0];
1744
1745
0
    if (params->ith != 0) {
1746
0
        return;
1747
0
    }
1748
1749
0
    GGML_ASSERT(ggml_can_repeat(src0, dst));
1750
1751
0
    GGML_TENSOR_UNARY_OP_LOCALS
1752
1753
    // guaranteed to be an integer due to the check in ggml_can_repeat
1754
0
    const int nr0 = (int)(ne0/ne00);
1755
0
    const int nr1 = (int)(ne1/ne01);
1756
0
    const int nr2 = (int)(ne2/ne02);
1757
0
    const int nr3 = (int)(ne3/ne03);
1758
1759
    // TODO: support for transposed / permuted tensors
1760
0
    GGML_ASSERT(nb0  == sizeof(ggml_fp16_t));
1761
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
1762
1763
    // TODO: maybe this is not optimal?
1764
0
    for                         (int i3 = 0; i3 < nr3;  i3++) {
1765
0
        for                     (int k3 = 0; k3 < ne03; k3++) {
1766
0
            for                 (int i2 = 0; i2 < nr2;  i2++) {
1767
0
                for             (int k2 = 0; k2 < ne02; k2++) {
1768
0
                    for         (int i1 = 0; i1 < nr1;  i1++) {
1769
0
                        for     (int k1 = 0; k1 < ne01; k1++) {
1770
0
                            for (int i0 = 0; i0 < nr0;  i0++) {
1771
0
                                ggml_fp16_t * y = (ggml_fp16_t *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0);
1772
0
                                ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01);
1773
                                // ggml_vec_cpy_f16(ne00, y, x)
1774
0
                                for (int i = 0; i < ne00; ++i) {
1775
0
                                    y[i]  = x[i];
1776
0
                                }
1777
0
                            }
1778
0
                        }
1779
0
                    }
1780
0
                }
1781
0
            }
1782
0
        }
1783
0
    }
1784
0
}
1785
1786
void ggml_compute_forward_repeat(
1787
        const ggml_compute_params * params,
1788
0
        ggml_tensor * dst) {
1789
1790
0
    const ggml_tensor * src0 = dst->src[0];
1791
1792
0
    switch (src0->type) {
1793
0
        case GGML_TYPE_F16:
1794
0
        case GGML_TYPE_BF16:
1795
0
        case GGML_TYPE_I16:
1796
0
            {
1797
0
                ggml_compute_forward_repeat_f16(params, dst);
1798
0
            } break;
1799
0
        case GGML_TYPE_F32:
1800
0
        case GGML_TYPE_I32:
1801
0
            {
1802
0
                ggml_compute_forward_repeat_f32(params, dst);
1803
0
            } break;
1804
        // TODO: templateify the implementation and support for I64
1805
        //       ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
1806
        //case GGML_TYPE_I64:
1807
        //    {
1808
        //        ggml_compute_forward_repeat_i64(params, dst);
1809
        //    } break;
1810
0
        default:
1811
0
            {
1812
0
                GGML_ABORT("fatal error");
1813
0
            }
1814
0
    }
1815
0
}
1816
1817
// ggml_compute_forward_repeat_back
1818
1819
static void ggml_compute_forward_repeat_back_f32(
1820
        const ggml_compute_params * params,
1821
0
        ggml_tensor * dst) {
1822
1823
0
    const ggml_tensor * src0 = dst->src[0];
1824
1825
0
    if (params->ith != 0) {
1826
0
        return;
1827
0
    }
1828
1829
0
    GGML_ASSERT(ggml_can_repeat(dst, src0));
1830
1831
0
    GGML_TENSOR_UNARY_OP_LOCALS
1832
1833
    // guaranteed to be an integer due to the check in ggml_can_repeat
1834
0
    const int nr0 = (int)(ne00/ne0);
1835
0
    const int nr1 = (int)(ne01/ne1);
1836
0
    const int nr2 = (int)(ne02/ne2);
1837
0
    const int nr3 = (int)(ne03/ne3);
1838
1839
    // TODO: support for transposed / permuted tensors
1840
0
    GGML_ASSERT(nb0  == sizeof(float));
1841
0
    GGML_ASSERT(nb00 == sizeof(float));
1842
1843
0
    if (ggml_is_contiguous(dst)) {
1844
0
        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
1845
0
    } else {
1846
0
        for         (int k3 = 0; k3 < ne3; k3++) {
1847
0
            for     (int k2 = 0; k2 < ne2; k2++) {
1848
0
                for (int k1 = 0; k1 < ne1; k1++) {
1849
0
                    ggml_vec_set_f32(ne0,
1850
0
                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
1851
0
                        0);
1852
0
                }
1853
0
            }
1854
0
        }
1855
0
    }
1856
1857
    // TODO: maybe this is not optimal?
1858
0
    for                         (int i3 = 0; i3 < nr3; i3++) {
1859
0
        for                     (int k3 = 0; k3 < ne3; k3++) {
1860
0
            for                 (int i2 = 0; i2 < nr2; i2++) {
1861
0
                for             (int k2 = 0; k2 < ne2; k2++) {
1862
0
                    for         (int i1 = 0; i1 < nr1; i1++) {
1863
0
                        for     (int k1 = 0; k1 < ne1; k1++) {
1864
0
                            for (int i0 = 0; i0 < nr0; i0++) {
1865
0
                                ggml_vec_acc_f32(ne0,
1866
0
                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
1867
0
                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
1868
0
                            }
1869
0
                        }
1870
0
                    }
1871
0
                }
1872
0
            }
1873
0
        }
1874
0
    }
1875
0
}
1876
1877
void ggml_compute_forward_repeat_back(
1878
        const ggml_compute_params * params,
1879
0
        ggml_tensor * dst) {
1880
1881
0
    const ggml_tensor * src0 = dst->src[0];
1882
1883
0
    switch (src0->type) {
1884
0
        case GGML_TYPE_F32:
1885
0
            {
1886
0
                ggml_compute_forward_repeat_back_f32(params, dst);
1887
0
            } break;
1888
0
        default:
1889
0
            {
1890
0
                GGML_ABORT("fatal error");
1891
0
            }
1892
0
    }
1893
0
}
1894
1895
// ggml_compute_forward_concat
1896
1897
static void ggml_compute_forward_concat_any(
1898
    const ggml_compute_params * params,
1899
0
    ggml_tensor * dst) {
1900
1901
0
    const ggml_tensor * src0 = dst->src[0];
1902
0
    const ggml_tensor * src1 = dst->src[1];
1903
1904
0
    const size_t len = ggml_type_size(src0->type);
1905
1906
0
    const int ith = params->ith;
1907
0
    const int nth = params->nth;
1908
1909
0
    GGML_TENSOR_BINARY_OP_LOCALS
1910
1911
0
    const int32_t dim = ggml_get_op_params_i32(dst, 0);
1912
1913
0
    GGML_ASSERT(dim >= 0 && dim < 4);
1914
1915
0
    int64_t o[4] = {0, 0, 0, 0};
1916
0
    o[dim] = src0->ne[dim];
1917
1918
0
    const char * x;
1919
1920
    // TODO: smarter multi-theading
1921
0
    for (int i3 = 0; i3 < ne3; i3++) {
1922
0
        for (int i2 = ith; i2 < ne2; i2 += nth) {
1923
0
            for (int i1 = 0; i1 < ne1; i1++) {
1924
0
                for (int i0 = 0; i0 < ne0; i0++) {
1925
0
                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
1926
0
                        x = (const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03;
1927
0
                    } else {
1928
0
                        x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
1929
0
                    }
1930
1931
0
                    char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
1932
1933
0
                    memcpy(y, x, len);
1934
0
                }
1935
0
            }
1936
0
        }
1937
0
    }
1938
0
}
1939
1940
static void ggml_compute_forward_concat_i8(
1941
    const ggml_compute_params * params,
1942
0
    ggml_tensor * dst) {
1943
1944
0
    const ggml_tensor * src0 = dst->src[0];
1945
0
    const ggml_tensor * src1 = dst->src[1];
1946
1947
0
    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
1948
1949
0
    const int ith = params->ith;
1950
0
    const int nth = params->nth;
1951
1952
0
    GGML_TENSOR_BINARY_OP_LOCALS
1953
1954
0
    const int32_t dim = ggml_get_op_params_i32(dst, 0);
1955
1956
0
    GGML_ASSERT(dim >= 0 && dim < 4);
1957
1958
0
    int64_t o[4] = {0, 0, 0, 0};
1959
0
    o[dim] = src0->ne[dim];
1960
1961
0
    const int8_t * x;
1962
1963
    // TODO: smarter multi-theading
1964
0
    for (int i3 = 0; i3 < ne3; i3++) {
1965
0
        for (int i2 = ith; i2 < ne2; i2 += nth) {
1966
0
            for (int i1 = 0; i1 < ne1; i1++) {
1967
0
                for (int i0 = 0; i0 < ne0; i0++) {
1968
0
                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
1969
0
                        x = (const int8_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
1970
0
                    } else {
1971
0
                        x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
1972
0
                    }
1973
1974
0
                    int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
1975
1976
0
                    *y = *x;
1977
0
                }
1978
0
            }
1979
0
        }
1980
0
    }
1981
0
}
1982
1983
static void ggml_compute_forward_concat_f16(
1984
    const ggml_compute_params * params,
1985
0
    ggml_tensor * dst) {
1986
1987
0
    const ggml_tensor * src0 = dst->src[0];
1988
0
    const ggml_tensor * src1 = dst->src[1];
1989
1990
0
    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
1991
1992
0
    const int ith = params->ith;
1993
0
    const int nth = params->nth;
1994
1995
0
    GGML_TENSOR_BINARY_OP_LOCALS
1996
1997
0
    const int32_t dim = ggml_get_op_params_i32(dst, 0);
1998
1999
0
    GGML_ASSERT(dim >= 0 && dim < 4);
2000
2001
0
    int64_t o[4] = {0, 0, 0, 0};
2002
0
    o[dim] = src0->ne[dim];
2003
2004
0
    const ggml_fp16_t * x;
2005
2006
    // TODO: smarter multi-theading
2007
0
    for (int i3 = 0; i3 < ne3; i3++) {
2008
0
        for (int i2 = ith; i2 < ne2; i2 += nth) {
2009
0
            for (int i1 = 0; i1 < ne1; i1++) {
2010
0
                for (int i0 = 0; i0 < ne0; i0++) {
2011
0
                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
2012
0
                        x = (const ggml_fp16_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
2013
0
                    } else {
2014
0
                        x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
2015
0
                    }
2016
2017
0
                    ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
2018
2019
0
                    *y = *x;
2020
0
                }
2021
0
            }
2022
0
        }
2023
0
    }
2024
0
}
2025
2026
static void ggml_compute_forward_concat_f32(
2027
    const ggml_compute_params * params,
2028
0
    ggml_tensor * dst) {
2029
2030
0
    const ggml_tensor * src0 = dst->src[0];
2031
0
    const ggml_tensor * src1 = dst->src[1];
2032
2033
0
    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));
2034
2035
0
    const int ith = params->ith;
2036
0
    const int nth = params->nth;
2037
2038
0
    GGML_TENSOR_BINARY_OP_LOCALS
2039
2040
0
    const int32_t dim = ggml_get_op_params_i32(dst, 0);
2041
2042
0
    GGML_ASSERT(dim >= 0 && dim < 4);
2043
2044
0
    int64_t o[4] = {0, 0, 0, 0};
2045
0
    o[dim] = src0->ne[dim];
2046
2047
0
    const float * x;
2048
2049
    // TODO: smarter multi-theading
2050
0
    for (int i3 = 0; i3 < ne3; i3++) {
2051
0
        for (int i2 = ith; i2 < ne2; i2 += nth) {
2052
0
            for (int i1 = 0; i1 < ne1; i1++) {
2053
0
                for (int i0 = 0; i0 < ne0; i0++) {
2054
0
                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
2055
0
                        x = (const float *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
2056
0
                    } else {
2057
0
                        x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
2058
0
                    }
2059
2060
0
                    float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
2061
2062
0
                    *y = *x;
2063
0
                }
2064
0
            }
2065
0
        }
2066
0
    }
2067
0
}
2068
2069
void ggml_compute_forward_concat(
2070
    const ggml_compute_params * params,
2071
0
    ggml_tensor * dst) {
2072
2073
0
    const ggml_tensor * src0 = dst->src[0];
2074
2075
0
    switch (src0->type) {
2076
0
        case GGML_TYPE_F16:
2077
0
        case GGML_TYPE_BF16:
2078
0
        case GGML_TYPE_I16:
2079
0
            {
2080
0
                ggml_compute_forward_concat_f16(params, dst);
2081
0
            } break;
2082
0
        case GGML_TYPE_I8:
2083
0
            {
2084
0
                ggml_compute_forward_concat_i8(params, dst);
2085
0
            } break;
2086
0
        case GGML_TYPE_F32:
2087
0
        case GGML_TYPE_I32:
2088
0
            {
2089
0
                ggml_compute_forward_concat_f32(params, dst);
2090
0
            } break;
2091
0
        default:
2092
0
            {
2093
0
                ggml_compute_forward_concat_any(params, dst);
2094
0
            }
2095
0
    }
2096
0
}
2097
2098
// ggml_compute_forward_gelu
2099
2100
static void ggml_compute_forward_gelu_f32(
2101
        const ggml_compute_params * params,
2102
0
        ggml_tensor * dst) {
2103
2104
0
    const ggml_tensor * src0 = dst->src[0];
2105
2106
0
    assert(ggml_is_contiguous_rows(src0));
2107
0
    assert(ggml_are_same_shape(src0, dst));
2108
2109
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2110
0
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
2111
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
2112
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
2113
2114
0
    const int ith = params->ith;
2115
0
    const int nth = params->nth;
2116
2117
0
    const int nc = src0->ne[0];
2118
0
    const int nr = ggml_nrows(src0);
2119
2120
    // rows per thread
2121
0
    const int dr = (nr + nth - 1)/nth;
2122
2123
    // row range for this thread
2124
0
    const int ir0 = dr*ith;
2125
0
    const int ir1 = MIN(ir0 + dr, nr);
2126
2127
0
    for (int ir = ir0; ir < ir1; ++ir) {
2128
0
        const int i3 = ir/(ne02*ne01);
2129
0
        const int i2 = (ir - i3*ne02*ne01)/ne01;
2130
0
        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2131
2132
0
        ggml_vec_gelu_f32(nc,
2133
0
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
2134
0
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2135
2136
#ifndef NDEBUG
2137
        for (int k = 0; k < nc; k++) {
2138
            const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2139
            GGML_UNUSED(x);
2140
            assert(!isnan(x));
2141
            assert(!isinf(x));
2142
        }
2143
#endif // NDEBUG
2144
0
    }
2145
0
}
2146
2147
static void ggml_compute_forward_gelu_f16(
2148
    const ggml_compute_params * params,
2149
0
    ggml_tensor * dst) {
2150
2151
0
    const ggml_tensor * src0 = dst->src[0];
2152
2153
0
    assert(ggml_is_contiguous_rows(src0));
2154
0
    assert(ggml_are_same_shape(src0, dst));
2155
2156
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2157
0
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
2158
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
2159
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
2160
2161
0
    const int ith = params->ith;
2162
0
    const int nth = params->nth;
2163
2164
0
    const int nc = src0->ne[0];
2165
0
    const int nr = ggml_nrows(src0);
2166
2167
    // rows per thread
2168
0
    const int dr = (nr + nth - 1)/nth;
2169
2170
    // row range for this thread
2171
0
    const int ir0 = dr*ith;
2172
0
    const int ir1 = MIN(ir0 + dr, nr);
2173
2174
0
    for (int ir = ir0; ir < ir1; ++ir) {
2175
0
        const int i3 = ir/(ne02*ne01);
2176
0
        const int i2 = (ir - i3*ne02*ne01)/ne01;
2177
0
        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2178
2179
0
        ggml_vec_gelu_f16(nc,
2180
0
                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
2181
0
                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2182
2183
#ifndef NDEBUG
2184
        for (int k = 0; k < nc; k++) {
2185
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2186
            const float v = GGML_CPU_FP16_TO_FP32(x);
2187
            GGML_UNUSED(v);
2188
            assert(!isnan(v));
2189
            assert(!isinf(v));
2190
        }
2191
#endif // NDEBUG
2192
0
    }
2193
0
}
2194
2195
static void ggml_compute_forward_gelu(
2196
        const ggml_compute_params * params,
2197
0
        ggml_tensor * dst) {
2198
2199
0
    const ggml_tensor * src0 = dst->src[0];
2200
2201
0
    switch (src0->type) {
2202
0
        case GGML_TYPE_F32:
2203
0
            {
2204
0
                ggml_compute_forward_gelu_f32(params, dst);
2205
0
            } break;
2206
0
        case GGML_TYPE_F16:
2207
0
            {
2208
0
                ggml_compute_forward_gelu_f16(params, dst);
2209
0
            } break;
2210
0
        default:
2211
0
            {
2212
0
                GGML_ABORT("fatal error");
2213
0
            }
2214
0
    }
2215
0
}
2216
2217
// ggml_compute_forward_fill
2218
2219
0
static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
2220
0
    const float c = ggml_get_op_params_f32(dst, 0);
2221
2222
0
    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
2223
0
    GGML_TENSOR_LOCALS(size_t,  nb, dst, nb);
2224
2225
0
    const auto [ir0, ir1] = get_thread_range(params, dst);
2226
2227
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
2228
0
        const int64_t i03 = ir/(ne2*ne1);
2229
0
        const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
2230
0
        const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
2231
2232
0
        float * dst_ptr  = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
2233
2234
0
        ggml_vec_set_f32(ne0, dst_ptr, c);
2235
0
    }
2236
0
}
2237
2238
0
static void ggml_compute_forward_fill_f16(const ggml_compute_params * params, ggml_tensor * dst) {
2239
0
    const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0));
2240
2241
0
    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
2242
0
    GGML_TENSOR_LOCALS(size_t,  nb, dst, nb);
2243
2244
0
    const auto [ir0, ir1] = get_thread_range(params, dst);
2245
2246
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
2247
0
        const int64_t i03 = ir/(ne2*ne1);
2248
0
        const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
2249
0
        const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
2250
2251
0
        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
2252
2253
0
        ggml_vec_set_f16(ne0, dst_ptr, c);
2254
0
    }
2255
0
}
2256
2257
0
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
2258
0
    const ggml_tensor * src0 = dst->src[0];
2259
2260
0
    switch (src0->type) {
2261
0
        case GGML_TYPE_F32:
2262
0
            {
2263
0
                ggml_compute_forward_fill_f32(params, dst);
2264
0
            } break;
2265
0
        case GGML_TYPE_F16:
2266
0
            {
2267
0
                ggml_compute_forward_fill_f16(params, dst);
2268
0
            } break;
2269
0
        default:
2270
0
            {
2271
0
                GGML_ABORT("unsupported type for ggml_compute_forward_fill: %s", ggml_type_name(src0->type));
2272
0
            }
2273
0
    }
2274
0
}
2275
2276
// ggml_compute_forward_tri
2277
2278
0
static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
2279
0
    const ggml_tensor * src0 = dst->src[0];
2280
2281
0
    const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
2282
2283
0
    GGML_ASSERT(ggml_is_contiguous(src0));
2284
2285
0
    GGML_TENSOR_UNARY_OP_LOCALS
2286
2287
0
    const auto [ir0, ir1] = get_thread_range(params, src0);
2288
2289
0
    bool (*bipred)(int, int);
2290
2291
0
    switch (ttype) {
2292
0
        case GGML_TRI_TYPE_LOWER:      bipred = [](int i, int r) { return i <  r; }; break;
2293
0
        case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
2294
0
        case GGML_TRI_TYPE_UPPER:      bipred = [](int i, int r) { return i >  r; }; break;
2295
0
        case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
2296
0
        default: GGML_ABORT("invalid tri type");
2297
0
    }
2298
2299
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
2300
0
        const int64_t i03 = ir/(ne02*ne01);
2301
0
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
2302
0
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
2303
2304
0
        const float * src_ptr = (const float  *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
2305
0
              float * dst_ptr = (      float  *) ((      char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1);
2306
2307
0
        for (int i0 = 0; i0 < ne0; ++i0) {
2308
0
            dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
2309
0
        }
2310
0
    }
2311
0
}
2312
2313
0
void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
2314
0
    const ggml_tensor * src0 = dst->src[0];
2315
2316
0
    switch (src0->type) {
2317
0
        case GGML_TYPE_F32:
2318
0
            {
2319
0
                ggml_compute_forward_tri_f32(params, dst);
2320
0
            } break;
2321
0
        default:
2322
0
            {
2323
0
                GGML_ABORT("fatal error");
2324
0
            }
2325
0
    }
2326
0
}
2327
2328
// ggml_compute_forward_gelu_erf
2329
2330
static void ggml_compute_forward_gelu_erf_f32(
2331
        const ggml_compute_params * params,
2332
0
        ggml_tensor * dst) {
2333
2334
0
    const ggml_tensor * src0 = dst->src[0];
2335
2336
0
    assert(ggml_is_contiguous_rows(src0));
2337
0
    assert(ggml_are_same_shape(src0, dst));
2338
2339
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2340
0
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
2341
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
2342
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
2343
2344
0
    const int ith = params->ith;
2345
0
    const int nth = params->nth;
2346
2347
0
    const int nc = src0->ne[0];
2348
0
    const int nr = ggml_nrows(src0);
2349
2350
    // rows per thread
2351
0
    const int dr = (nr + nth - 1)/nth;
2352
2353
    // row range for this thread
2354
0
    const int ir0 = dr*ith;
2355
0
    const int ir1 = MIN(ir0 + dr, nr);
2356
2357
0
    for (int ir = ir0; ir < ir1; ++ir) {
2358
0
        const int i3 = ir/(ne02*ne01);
2359
0
        const int i2 = (ir - i3*ne02*ne01)/ne01;
2360
0
        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2361
2362
0
        ggml_vec_gelu_erf_f32(nc,
2363
0
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
2364
0
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2365
2366
#ifndef NDEBUG
2367
        for (int k = 0; k < nc; k++) {
2368
            const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2369
            GGML_UNUSED(x);
2370
            assert(!isnan(x));
2371
            assert(!isinf(x));
2372
        }
2373
#endif // NDEBUG
2374
0
    }
2375
0
}
2376
2377
static void ggml_compute_forward_gelu_erf_f16(
2378
    const ggml_compute_params * params,
2379
0
    ggml_tensor * dst) {
2380
2381
0
    const ggml_tensor * src0 = dst->src[0];
2382
2383
0
    assert(ggml_is_contiguous_rows(src0));
2384
0
    assert(ggml_are_same_shape(src0, dst));
2385
2386
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2387
0
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
2388
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
2389
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
2390
2391
0
    const int ith = params->ith;
2392
0
    const int nth = params->nth;
2393
2394
0
    const int nc = src0->ne[0];
2395
0
    const int nr = ggml_nrows(src0);
2396
2397
    // rows per thread
2398
0
    const int dr = (nr + nth - 1)/nth;
2399
2400
    // row range for this thread
2401
0
    const int ir0 = dr*ith;
2402
0
    const int ir1 = MIN(ir0 + dr, nr);
2403
2404
0
    for (int ir = ir0; ir < ir1; ++ir) {
2405
0
        const int i3 = ir/(ne02*ne01);
2406
0
        const int i2 = (ir - i3*ne02*ne01)/ne01;
2407
0
        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2408
2409
0
        ggml_vec_gelu_erf_f16(nc,
2410
0
                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
2411
0
                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2412
2413
#ifndef NDEBUG
2414
        for (int k = 0; k < nc; k++) {
2415
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2416
            const float v = GGML_CPU_FP16_TO_FP32(x);
2417
            GGML_UNUSED(v);
2418
            assert(!isnan(v));
2419
            assert(!isinf(v));
2420
        }
2421
#endif // NDEBUG
2422
0
    }
2423
0
}
2424
2425
static void ggml_compute_forward_gelu_erf(
2426
        const ggml_compute_params * params,
2427
0
        ggml_tensor * dst) {
2428
2429
0
    const ggml_tensor * src0 = dst->src[0];
2430
2431
0
    switch (src0->type) {
2432
0
        case GGML_TYPE_F32:
2433
0
            {
2434
0
                ggml_compute_forward_gelu_erf_f32(params, dst);
2435
0
            } break;
2436
0
        case GGML_TYPE_F16:
2437
0
            {
2438
0
                ggml_compute_forward_gelu_erf_f16(params, dst);
2439
0
            } break;
2440
0
        default:
2441
0
            {
2442
0
                GGML_ABORT("fatal error");
2443
0
            }
2444
0
    }
2445
0
}
2446
2447
// ggml_compute_forward_gelu_quick
2448
2449
static void ggml_compute_forward_gelu_quick_f32(
2450
        const ggml_compute_params * params,
2451
0
        ggml_tensor * dst) {
2452
2453
0
    const ggml_tensor * src0 = dst->src[0];
2454
2455
0
    assert(ggml_is_contiguous_rows(src0));
2456
0
    assert(ggml_are_same_shape(src0, dst));
2457
2458
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2459
0
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
2460
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
2461
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
2462
2463
0
    const int ith = params->ith;
2464
0
    const int nth = params->nth;
2465
2466
0
    const int nc = src0->ne[0];
2467
0
    const int nr = ggml_nrows(src0);
2468
2469
    // rows per thread
2470
0
    const int dr = (nr + nth - 1)/nth;
2471
2472
    // row range for this thread
2473
0
    const int ir0 = dr*ith;
2474
0
    const int ir1 = MIN(ir0 + dr, nr);
2475
2476
0
    for (int ir = ir0; ir < ir1; ++ir) {
2477
0
        const int i3 = ir/(ne02*ne01);
2478
0
        const int i2 = (ir - i3*ne02*ne01)/ne01;
2479
0
        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2480
2481
0
        ggml_vec_gelu_quick_f32(nc,
2482
0
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
2483
0
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2484
2485
#ifndef NDEBUG
2486
        for (int k = 0; k < nc; k++) {
2487
            const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2488
            GGML_UNUSED(x);
2489
            assert(!isnan(x));
2490
            assert(!isinf(x));
2491
        }
2492
#endif // NDEBUG
2493
0
    }
2494
0
}
2495
2496
static void ggml_compute_forward_gelu_quick_f16(
2497
    const ggml_compute_params * params,
2498
0
    ggml_tensor * dst) {
2499
2500
0
    const ggml_tensor * src0 = dst->src[0];
2501
2502
0
    assert(ggml_is_contiguous_rows(src0));
2503
0
    assert(ggml_are_same_shape(src0, dst));
2504
2505
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2506
0
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
2507
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
2508
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
2509
2510
0
    const int ith = params->ith;
2511
0
    const int nth = params->nth;
2512
2513
0
    const int nc = src0->ne[0];
2514
0
    const int nr = ggml_nrows(src0);
2515
2516
    // rows per thread
2517
0
    const int dr = (nr + nth - 1)/nth;
2518
2519
    // row range for this thread
2520
0
    const int ir0 = dr*ith;
2521
0
    const int ir1 = MIN(ir0 + dr, nr);
2522
2523
0
    for (int ir = ir0; ir < ir1; ++ir) {
2524
0
        const int i3 = ir/(ne02*ne01);
2525
0
        const int i2 = (ir - i3*ne02*ne01)/ne01;
2526
0
        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2527
2528
0
        ggml_vec_gelu_quick_f16(nc,
2529
0
                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
2530
0
                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2531
2532
#ifndef NDEBUG
2533
        for (int k = 0; k < nc; k++) {
2534
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2535
            const float v = GGML_CPU_FP16_TO_FP32(x);
2536
            GGML_UNUSED(v);
2537
            assert(!isnan(v));
2538
            assert(!isinf(v));
2539
        }
2540
#endif // NDEBUG
2541
0
    }
2542
0
}
2543
2544
static void ggml_compute_forward_gelu_quick(
2545
        const ggml_compute_params * params,
2546
0
        ggml_tensor * dst) {
2547
2548
0
    const ggml_tensor * src0 = dst->src[0];
2549
2550
0
    switch (src0->type) {
2551
0
        case GGML_TYPE_F32:
2552
0
            {
2553
0
                ggml_compute_forward_gelu_quick_f32(params, dst);
2554
0
            } break;
2555
0
        case GGML_TYPE_F16:
2556
0
            {
2557
0
                ggml_compute_forward_gelu_quick_f16(params, dst);
2558
0
            } break;
2559
0
        default:
2560
0
            {
2561
0
                GGML_ABORT("fatal error");
2562
0
            }
2563
0
    }
2564
0
}
2565
2566
// ggml_compute_forward_silu
2567
2568
static void ggml_compute_forward_silu_f32(
2569
        const ggml_compute_params * params,
2570
0
        ggml_tensor * dst) {
2571
2572
0
    const ggml_tensor * src0 = dst->src[0];
2573
2574
0
    assert(ggml_is_contiguous_rows(src0));
2575
0
    assert(ggml_are_same_shape(src0, dst));
2576
2577
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2578
0
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
2579
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
2580
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
2581
2582
0
    const int ith = params->ith;
2583
0
    const int nth = params->nth;
2584
2585
0
    const int nc = src0->ne[0];
2586
0
    const int nr = ggml_nrows(src0);
2587
2588
    // rows per thread
2589
0
    const int dr = (nr + nth - 1)/nth;
2590
2591
    // row range for this thread
2592
0
    const int ir0 = dr*ith;
2593
0
    const int ir1 = MIN(ir0 + dr, nr);
2594
2595
0
    for (int ir = ir0; ir < ir1; ++ir) {
2596
0
        const int i3 = ir/(ne02*ne01);
2597
0
        const int i2 = (ir - i3*ne02*ne01)/ne01;
2598
0
        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2599
2600
0
        ggml_vec_silu_f32(nc,
2601
0
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
2602
0
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2603
2604
#ifndef NDEBUG
2605
        for (int k = 0; k < nc; k++) {
2606
            const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
2607
            GGML_UNUSED(x);
2608
            assert(!isnan(x));
2609
            assert(!isinf(x));
2610
        }
2611
#endif // NDEBUG
2612
0
    }
2613
0
}
2614
2615
static void ggml_compute_forward_silu_f16(
2616
    const ggml_compute_params * params,
2617
0
    ggml_tensor * dst) {
2618
2619
0
    const ggml_tensor * src0 = dst->src[0];
2620
2621
0
    assert(ggml_is_contiguous_rows(src0));
2622
0
    assert(ggml_are_same_shape(src0, dst));
2623
2624
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2625
0
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
2626
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
2627
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
2628
2629
0
    const int ith = params->ith;
2630
0
    const int nth = params->nth;
2631
2632
0
    const int nc = src0->ne[0];
2633
0
    const int nr = ggml_nrows(src0);
2634
2635
    // rows per thread
2636
0
    const int dr = (nr + nth - 1)/nth;
2637
2638
    // row range for this thread
2639
0
    const int ir0 = dr*ith;
2640
0
    const int ir1 = MIN(ir0 + dr, nr);
2641
2642
0
    for (int ir = ir0; ir < ir1; ++ir) {
2643
0
        const int i3 = ir/(ne02*ne01);
2644
0
        const int i2 = (ir - i3*ne02*ne01)/ne01;
2645
0
        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2646
2647
0
        ggml_vec_silu_f16(nc,
2648
0
                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
2649
0
                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2650
2651
#ifndef NDEBUG
2652
        for (int k = 0; k < nc; k++) {
2653
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
2654
            const float v = GGML_CPU_FP16_TO_FP32(x);
2655
            GGML_UNUSED(v);
2656
            assert(!isnan(v));
2657
            assert(!isinf(v));
2658
        }
2659
#endif // NDEBUG
2660
0
    }
2661
0
}
2662
2663
static void ggml_compute_forward_silu(
2664
        const ggml_compute_params * params,
2665
0
        ggml_tensor * dst) {
2666
2667
0
    const ggml_tensor * src0 = dst->src[0];
2668
2669
0
    switch (src0->type) {
2670
0
        case GGML_TYPE_F32:
2671
0
            {
2672
0
                ggml_compute_forward_silu_f32(params, dst);
2673
0
            } break;
2674
0
        case GGML_TYPE_F16:
2675
0
            {
2676
0
                ggml_compute_forward_silu_f16(params, dst);
2677
0
            } break;
2678
0
        default:
2679
0
            {
2680
0
                GGML_ABORT("fatal error");
2681
0
            }
2682
0
    }
2683
0
}
2684
// ggml_compute_forward_leaky_relu
2685
2686
static void ggml_compute_forward_leaky_relu_f32(
2687
        const ggml_compute_params * params,
2688
0
        ggml_tensor * dst) {
2689
2690
0
    const ggml_tensor * src0 = dst->src[0];
2691
2692
0
    if (params->ith != 0) {
2693
0
        return;
2694
0
    }
2695
2696
0
    assert(ggml_is_contiguous_1(src0));
2697
0
    assert(ggml_is_contiguous_1(dst));
2698
0
    assert(ggml_are_same_shape(src0, dst));
2699
2700
0
    const int n  = ggml_nrows(src0);
2701
0
    const int nc = src0->ne[0];
2702
2703
0
    float negative_slope;
2704
0
    memcpy(&negative_slope, dst->op_params, sizeof(float));
2705
2706
0
    assert(dst->nb[0]  == sizeof(float));
2707
0
    assert(src0->nb[0] == sizeof(float));
2708
2709
0
    for (int i = 0; i < n; i++) {
2710
0
        ggml_vec_leaky_relu_f32(nc,
2711
0
                (float *) ((char *) dst->data  + i*( dst->nb[1])),
2712
0
                (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
2713
0
    }
2714
0
}
2715
2716
static void ggml_compute_forward_leaky_relu_f16(
2717
    const ggml_compute_params * params,
2718
0
    ggml_tensor * dst) {
2719
2720
0
    const ggml_tensor * src0 = dst->src[0];
2721
2722
0
    if (params->ith != 0) {
2723
0
        return;
2724
0
    }
2725
2726
0
    assert(ggml_is_contiguous_1(src0));
2727
0
    assert(ggml_is_contiguous_1(dst));
2728
0
    assert(ggml_are_same_shape(src0, dst));
2729
2730
0
    const int n  = ggml_nrows(src0);
2731
0
    const int nc = src0->ne[0];
2732
2733
0
    float negative_slope;
2734
0
    memcpy(&negative_slope, dst->op_params, sizeof(float));
2735
2736
0
    assert(dst->nb[0]  == sizeof(ggml_fp16_t));
2737
0
    assert(src0->nb[0] == sizeof(ggml_fp16_t));
2738
2739
0
    for (int i = 0; i < n; i++) {
2740
0
        ggml_vec_leaky_relu_f16(nc,
2741
0
                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
2742
0
                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
2743
0
    }
2744
0
}
2745
2746
void ggml_compute_forward_leaky_relu(
2747
        const ggml_compute_params * params,
2748
0
        ggml_tensor * dst) {
2749
2750
0
    const ggml_tensor * src0 = dst->src[0];
2751
2752
0
    switch (src0->type) {
2753
0
        case GGML_TYPE_F32:
2754
0
            {
2755
0
                ggml_compute_forward_leaky_relu_f32(params, dst);
2756
0
            } break;
2757
0
        case GGML_TYPE_F16:
2758
0
            {
2759
0
                ggml_compute_forward_leaky_relu_f16(params, dst);
2760
0
            } break;
2761
0
        default:
2762
0
            {
2763
0
                GGML_ABORT("fatal error");
2764
0
            }
2765
0
    }
2766
0
}
2767
2768
// ggml_compute_forward_silu_back
2769
2770
static void ggml_compute_forward_silu_back_f32(
2771
        const ggml_compute_params * params,
2772
0
        ggml_tensor * dst) {
2773
2774
0
    const ggml_tensor * grad = dst->src[0];
2775
0
    const ggml_tensor * src1 = dst->src[1];
2776
2777
0
    assert(ggml_is_contiguous_1(grad));
2778
0
    assert(ggml_is_contiguous_1(src1));
2779
0
    assert(ggml_is_contiguous_1(dst));
2780
0
    assert(ggml_are_same_shape(src1, dst));
2781
0
    assert(ggml_are_same_shape(src1, grad));
2782
2783
0
    const int ith = params->ith;
2784
0
    const int nth = params->nth;
2785
2786
0
    const int nc = src1->ne[0];
2787
0
    const int nr = ggml_nrows(src1);
2788
2789
    // rows per thread
2790
0
    const int dr = (nr + nth - 1)/nth;
2791
2792
    // row range for this thread
2793
0
    const int ir0 = dr*ith;
2794
0
    const int ir1 = MIN(ir0 + dr, nr);
2795
2796
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2797
0
        ggml_vec_silu_backward_f32(nc,
2798
0
                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
2799
0
                (float *) ((char *) src1->data + i1*(src1->nb[1])),
2800
0
                (float *) ((char *) grad->data + i1*(grad->nb[1])));
2801
2802
#ifndef NDEBUG
2803
        for (int k = 0; k < nc; k++) {
2804
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2805
            GGML_UNUSED(x);
2806
            assert(!isnan(x));
2807
            assert(!isinf(x));
2808
        }
2809
#endif // NDEBUG
2810
0
    }
2811
0
}
2812
2813
static void ggml_compute_forward_silu_back_f16(
2814
    const ggml_compute_params * params,
2815
0
    ggml_tensor * dst) {
2816
2817
0
    const ggml_tensor * grad = dst->src[0];
2818
0
    const ggml_tensor * src1 = dst->src[1];
2819
2820
0
    assert(ggml_is_contiguous_1(grad));
2821
0
    assert(ggml_is_contiguous_1(src1));
2822
0
    assert(ggml_is_contiguous_1(dst));
2823
0
    assert(ggml_are_same_shape(src1, dst));
2824
0
    assert(ggml_are_same_shape(src1, grad));
2825
2826
0
    const int ith = params->ith;
2827
0
    const int nth = params->nth;
2828
2829
0
    const int nc = src1->ne[0];
2830
0
    const int nr = ggml_nrows(src1);
2831
2832
    // rows per thread
2833
0
    const int dr = (nr + nth - 1)/nth;
2834
2835
    // row range for this thread
2836
0
    const int ir0 = dr*ith;
2837
0
    const int ir1 = MIN(ir0 + dr, nr);
2838
2839
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2840
0
        ggml_vec_silu_backward_f16(nc,
2841
0
                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
2842
0
                (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
2843
0
                (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
2844
2845
#ifndef NDEBUG
2846
        for (int k = 0; k < nc; k++) {
2847
            const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2848
            const float v = GGML_CPU_FP16_TO_FP32(x);
2849
            GGML_UNUSED(v);
2850
            assert(!isnan(v));
2851
            assert(!isinf(v));
2852
        }
2853
#endif // NDEBUG
2854
0
    }
2855
0
}
2856
2857
void ggml_compute_forward_silu_back(
2858
        const ggml_compute_params * params,
2859
0
        ggml_tensor * dst) {
2860
2861
0
    const ggml_tensor * src0 = dst->src[0];
2862
2863
0
    switch (src0->type) {
2864
0
        case GGML_TYPE_F32:
2865
0
            {
2866
0
                ggml_compute_forward_silu_back_f32(params, dst);
2867
0
            } break;
2868
0
        case GGML_TYPE_F16:
2869
0
            {
2870
0
                ggml_compute_forward_silu_back_f16(params, dst);
2871
0
            } break;
2872
0
        default:
2873
0
            {
2874
0
                GGML_ABORT("fatal error");
2875
0
            }
2876
0
    }
2877
0
}
2878
2879
// ggml_compute_forward_reglu
2880
2881
static void ggml_compute_forward_reglu_f32(
2882
        const ggml_compute_params * params,
2883
0
        ggml_tensor * dst) {
2884
2885
0
    const ggml_tensor * src0 = dst->src[0];
2886
0
    const ggml_tensor * src1 = dst->src[1];
2887
0
    char * src0_d = (char *) src0->data;
2888
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
2889
0
    const size_t src0_o = src0->nb[1];
2890
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2891
2892
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
2893
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
2894
2895
0
    if (src1) {
2896
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
2897
0
        GGML_ASSERT(src0->type == src1->type);
2898
0
    }
2899
2900
0
    const int ith = params->ith;
2901
0
    const int nth = params->nth;
2902
2903
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
2904
0
    const int nr = ggml_nrows(src0);
2905
2906
0
    GGML_ASSERT(dst->ne[0] == nc);
2907
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
2908
2909
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
2910
2911
    // rows per thread
2912
0
    const int dr = (nr + nth - 1)/nth;
2913
2914
    // row range for this thread
2915
0
    const int ir0 = dr*ith;
2916
0
    const int ir1 = MIN(ir0 + dr, nr);
2917
2918
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2919
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
2920
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
2921
2922
0
        if (!src1) {
2923
0
            src0_p += swapped ? nc : 0;
2924
0
            src1_p += swapped ? 0 : nc;
2925
0
        }
2926
2927
0
        ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
2928
2929
#ifndef NDEBUG
2930
        for (int k = 0; k < nc; k++) {
2931
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2932
            GGML_UNUSED(x);
2933
            assert(!isnan(x));
2934
            assert(!isinf(x));
2935
        }
2936
#endif // NDEBUG
2937
0
    }
2938
0
}
2939
2940
static void ggml_compute_forward_reglu_f16(
2941
    const ggml_compute_params * params,
2942
0
    ggml_tensor * dst) {
2943
2944
0
    const ggml_tensor * src0 = dst->src[0];
2945
0
    const ggml_tensor * src1 = dst->src[1];
2946
0
    char * src0_d = (char *) src0->data;
2947
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
2948
0
    const size_t src0_o = src0->nb[1];
2949
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2950
2951
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
2952
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
2953
2954
0
    if (src1) {
2955
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
2956
0
        GGML_ASSERT(src0->type == src1->type);
2957
0
    }
2958
2959
0
    const int ith = params->ith;
2960
0
    const int nth = params->nth;
2961
2962
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
2963
0
    const int nr = ggml_nrows(src0);
2964
2965
0
    GGML_ASSERT(dst->ne[0] == nc);
2966
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
2967
2968
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
2969
2970
    // rows per thread
2971
0
    const int dr = (nr + nth - 1)/nth;
2972
2973
    // row range for this thread
2974
0
    const int ir0 = dr*ith;
2975
0
    const int ir1 = MIN(ir0 + dr, nr);
2976
2977
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2978
0
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
2979
0
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
2980
2981
0
        if (!src1) {
2982
0
            src0_p += swapped ? nc : 0;
2983
0
            src1_p += swapped ? 0 : nc;
2984
0
        }
2985
2986
0
        ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
2987
2988
#ifndef NDEBUG
2989
        for (int k = 0; k < nc; k++) {
2990
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2991
            const float v = GGML_FP16_TO_FP32(x);
2992
            GGML_UNUSED(v);
2993
            assert(!isnan(v));
2994
            assert(!isinf(v));
2995
        }
2996
#endif // NDEBUG
2997
0
    }
2998
0
}
2999
3000
static void ggml_compute_forward_reglu(
3001
        const ggml_compute_params * params,
3002
0
        ggml_tensor * dst) {
3003
3004
0
    const ggml_tensor * src0 = dst->src[0];
3005
3006
0
    switch (src0->type) {
3007
0
        case GGML_TYPE_F32:
3008
0
            {
3009
0
                ggml_compute_forward_reglu_f32(params, dst);
3010
0
            } break;
3011
0
        case GGML_TYPE_F16:
3012
0
            {
3013
0
                ggml_compute_forward_reglu_f16(params, dst);
3014
0
            } break;
3015
0
        default:
3016
0
            {
3017
0
                GGML_ABORT("fatal error");
3018
0
            }
3019
0
    }
3020
0
}
3021
3022
// ggml_compute_forward_geglu
3023
3024
static void ggml_compute_forward_geglu_f32(
3025
        const ggml_compute_params * params,
3026
0
        ggml_tensor * dst) {
3027
3028
0
    const ggml_tensor * src0 = dst->src[0];
3029
0
    const ggml_tensor * src1 = dst->src[1];
3030
0
    char * src0_d = (char *) src0->data;
3031
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3032
0
    const size_t src0_o = src0->nb[1];
3033
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3034
3035
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3036
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3037
3038
0
    if (src1) {
3039
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3040
0
        GGML_ASSERT(src0->type == src1->type);
3041
0
    }
3042
3043
0
    const int ith = params->ith;
3044
0
    const int nth = params->nth;
3045
3046
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3047
0
    const int nr = ggml_nrows(src0);
3048
3049
0
    GGML_ASSERT(dst->ne[0] == nc);
3050
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3051
3052
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3053
3054
    // rows per thread
3055
0
    const int dr = (nr + nth - 1)/nth;
3056
3057
    // row range for this thread
3058
0
    const int ir0 = dr*ith;
3059
0
    const int ir1 = MIN(ir0 + dr, nr);
3060
3061
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3062
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
3063
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
3064
3065
0
        if (!src1) {
3066
0
            src0_p += swapped ? nc : 0;
3067
0
            src1_p += swapped ? 0 : nc;
3068
0
        }
3069
3070
0
        ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3071
3072
#ifndef NDEBUG
3073
        for (int k = 0; k < nc; k++) {
3074
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3075
            GGML_UNUSED(x);
3076
            assert(!isnan(x));
3077
            assert(!isinf(x));
3078
        }
3079
#endif // NDEBUG
3080
0
    }
3081
0
}
3082
3083
static void ggml_compute_forward_geglu_f16(
3084
    const ggml_compute_params * params,
3085
0
    ggml_tensor * dst) {
3086
3087
0
    const ggml_tensor * src0 = dst->src[0];
3088
0
    const ggml_tensor * src1 = dst->src[1];
3089
0
    char * src0_d = (char *) src0->data;
3090
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3091
0
    const size_t src0_o = src0->nb[1];
3092
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3093
3094
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3095
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3096
3097
0
    if (src1) {
3098
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3099
0
        GGML_ASSERT(src0->type == src1->type);
3100
0
    }
3101
3102
0
    const int ith = params->ith;
3103
0
    const int nth = params->nth;
3104
3105
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3106
0
    const int nr = ggml_nrows(src0);
3107
3108
0
    GGML_ASSERT(dst->ne[0] == nc);
3109
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3110
3111
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3112
3113
    // rows per thread
3114
0
    const int dr = (nr + nth - 1)/nth;
3115
3116
    // row range for this thread
3117
0
    const int ir0 = dr*ith;
3118
0
    const int ir1 = MIN(ir0 + dr, nr);
3119
3120
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3121
0
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3122
0
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3123
3124
0
        if (!src1) {
3125
0
            src0_p += swapped ? nc : 0;
3126
0
            src1_p += swapped ? 0 : nc;
3127
0
        }
3128
3129
0
        ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3130
3131
#ifndef NDEBUG
3132
        for (int k = 0; k < nc; k++) {
3133
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3134
            const float v = GGML_FP16_TO_FP32(x);
3135
            GGML_UNUSED(v);
3136
            assert(!isnan(v));
3137
            assert(!isinf(v));
3138
        }
3139
#endif // NDEBUG
3140
0
    }
3141
0
}
3142
3143
static void ggml_compute_forward_geglu(
3144
        const ggml_compute_params * params,
3145
0
        ggml_tensor * dst) {
3146
3147
0
    const ggml_tensor * src0 = dst->src[0];
3148
3149
0
    switch (src0->type) {
3150
0
        case GGML_TYPE_F32:
3151
0
            {
3152
0
                ggml_compute_forward_geglu_f32(params, dst);
3153
0
            } break;
3154
0
        case GGML_TYPE_F16:
3155
0
            {
3156
0
                ggml_compute_forward_geglu_f16(params, dst);
3157
0
            } break;
3158
0
        default:
3159
0
            {
3160
0
                GGML_ABORT("fatal error");
3161
0
            }
3162
0
    }
3163
0
}
3164
3165
// ggml_compute_forward_swiglu
3166
3167
static void ggml_compute_forward_swiglu_f32(
3168
        const ggml_compute_params * params,
3169
0
        ggml_tensor * dst) {
3170
3171
0
    const ggml_tensor * src0 = dst->src[0];
3172
0
    const ggml_tensor * src1 = dst->src[1];
3173
0
    char * src0_d = (char *) src0->data;
3174
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3175
0
    const size_t src0_o = src0->nb[1];
3176
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3177
3178
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3179
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3180
3181
0
    if (src1) {
3182
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3183
0
        GGML_ASSERT(src0->type == src1->type);
3184
0
    }
3185
3186
0
    const int ith = params->ith;
3187
0
    const int nth = params->nth;
3188
3189
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3190
0
    const int nr = ggml_nrows(src0);
3191
3192
0
    GGML_ASSERT(dst->ne[0] == nc);
3193
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3194
3195
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3196
3197
    // rows per thread
3198
0
    const int dr = (nr + nth - 1)/nth;
3199
3200
    // row range for this thread
3201
0
    const int ir0 = dr*ith;
3202
0
    const int ir1 = MIN(ir0 + dr, nr);
3203
3204
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3205
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
3206
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
3207
3208
0
        if (!src1) {
3209
0
            src0_p += swapped ? nc : 0;
3210
0
            src1_p += swapped ? 0 : nc;
3211
0
        }
3212
3213
0
        ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3214
3215
#ifndef NDEBUG
3216
        for (int k = 0; k < nc; k++) {
3217
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3218
            GGML_UNUSED(x);
3219
            assert(!isnan(x));
3220
            assert(!isinf(x));
3221
        }
3222
#endif // NDEBUG
3223
0
    }
3224
0
}
3225
3226
static void ggml_compute_forward_swiglu_f16(
3227
    const ggml_compute_params * params,
3228
0
    ggml_tensor * dst) {
3229
3230
0
    const ggml_tensor * src0 = dst->src[0];
3231
0
    const ggml_tensor * src1 = dst->src[1];
3232
0
    char * src0_d = (char *) src0->data;
3233
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3234
0
    const size_t src0_o = src0->nb[1];
3235
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3236
3237
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3238
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3239
3240
0
    if (src1) {
3241
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3242
0
        GGML_ASSERT(src0->type == src1->type);
3243
0
    }
3244
3245
0
    const int ith = params->ith;
3246
0
    const int nth = params->nth;
3247
3248
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3249
0
    const int nr = ggml_nrows(src0);
3250
3251
0
    GGML_ASSERT(dst->ne[0] == nc);
3252
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3253
3254
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3255
3256
    // rows per thread
3257
0
    const int dr = (nr + nth - 1)/nth;
3258
3259
    // row range for this thread
3260
0
    const int ir0 = dr*ith;
3261
0
    const int ir1 = MIN(ir0 + dr, nr);
3262
3263
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3264
0
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3265
0
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3266
3267
0
        if (!src1) {
3268
0
            src0_p += swapped ? nc : 0;
3269
0
            src1_p += swapped ? 0 : nc;
3270
0
        }
3271
3272
0
        ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3273
3274
#ifndef NDEBUG
3275
        for (int k = 0; k < nc; k++) {
3276
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3277
            const float v = GGML_FP16_TO_FP32(x);
3278
            GGML_UNUSED(v);
3279
            assert(!isnan(v));
3280
            assert(!isinf(v));
3281
        }
3282
#endif // NDEBUG
3283
0
    }
3284
0
}
3285
3286
static void ggml_compute_forward_swiglu(
3287
        const ggml_compute_params * params,
3288
0
        ggml_tensor * dst) {
3289
3290
0
    const ggml_tensor * src0 = dst->src[0];
3291
3292
0
    switch (src0->type) {
3293
0
        case GGML_TYPE_F32:
3294
0
            {
3295
0
                ggml_compute_forward_swiglu_f32(params, dst);
3296
0
            } break;
3297
0
        case GGML_TYPE_F16:
3298
0
            {
3299
0
                ggml_compute_forward_swiglu_f16(params, dst);
3300
0
            } break;
3301
0
        default:
3302
0
            {
3303
0
                GGML_ABORT("fatal error");
3304
0
            }
3305
0
    }
3306
0
}
3307
3308
// ggml_compute_forward_swiglu_oai
3309
3310
static void ggml_compute_forward_swiglu_oai_f32(
3311
        const ggml_compute_params * params,
3312
0
        ggml_tensor * dst) {
3313
3314
0
    const ggml_tensor * src0 = dst->src[0];
3315
0
    const ggml_tensor * src1 = dst->src[1];
3316
0
    char * src0_d = (char *) src0->data;
3317
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3318
0
    const size_t src0_o = src0->nb[1];
3319
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3320
3321
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3322
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3323
3324
0
    if (src1) {
3325
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3326
0
        GGML_ASSERT(src0->type == src1->type);
3327
0
    }
3328
3329
0
    const int ith = params->ith;
3330
0
    const int nth = params->nth;
3331
3332
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3333
0
    const int nr = ggml_nrows(src0);
3334
3335
0
    GGML_ASSERT(dst->ne[0] == nc);
3336
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3337
3338
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3339
0
    const float alpha = ggml_get_op_params_f32(dst, 2);
3340
0
    const float limit = ggml_get_op_params_f32(dst, 3);
3341
3342
    // rows per thread
3343
0
    const int dr = (nr + nth - 1)/nth;
3344
3345
    // row range for this thread
3346
0
    const int ir0 = dr*ith;
3347
0
    const int ir1 = MIN(ir0 + dr, nr);
3348
3349
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3350
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
3351
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
3352
0
        float * dst_p  = (float *) ((char *) dst->data + i1*(dst->nb[1]));
3353
3354
0
        if (!src1) {
3355
0
            src0_p += swapped ? nc : 0;
3356
0
            src1_p += swapped ? 0 : nc;
3357
0
        }
3358
3359
0
        for (int k = 0; k < nc; k++) {
3360
0
            const float x = std::min(src0_p[k], limit);
3361
0
            const float y = std::clamp(src1_p[k], -limit, limit);
3362
0
            const float out_glu = x / (1.f + expf(alpha * (-x)));
3363
0
            dst_p[k] = out_glu * (y + 1.f);
3364
0
        }
3365
3366
#ifndef NDEBUG
3367
        for (int k = 0; k < nc; k++) {
3368
            const float x = dst_p[k];
3369
            GGML_UNUSED(x);
3370
            assert(!isnan(x));
3371
            assert(!isinf(x));
3372
        }
3373
#endif // NDEBUG
3374
0
    }
3375
0
}
3376
3377
static void ggml_compute_forward_swiglu_oai(
3378
        const ggml_compute_params * params,
3379
0
        ggml_tensor * dst) {
3380
3381
0
    const ggml_tensor * src0 = dst->src[0];
3382
3383
0
    switch (src0->type) {
3384
0
        case GGML_TYPE_F32:
3385
0
            {
3386
0
                ggml_compute_forward_swiglu_oai_f32(params, dst);
3387
0
            } break;
3388
0
        default:
3389
0
            {
3390
0
                GGML_ABORT("fatal error");
3391
0
            }
3392
0
    }
3393
0
}
3394
3395
// ggml_compute_forward_geglu_erf
3396
3397
static void ggml_compute_forward_geglu_erf_f32(
3398
        const ggml_compute_params * params,
3399
0
        ggml_tensor * dst) {
3400
3401
0
    const ggml_tensor * src0 = dst->src[0];
3402
0
    const ggml_tensor * src1 = dst->src[1];
3403
0
    char * src0_d = (char *) src0->data;
3404
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3405
0
    const size_t src0_o = src0->nb[1];
3406
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3407
3408
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3409
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3410
3411
0
    if (src1) {
3412
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3413
0
        GGML_ASSERT(src0->type == src1->type);
3414
0
    }
3415
3416
0
    const int ith = params->ith;
3417
0
    const int nth = params->nth;
3418
3419
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3420
0
    const int nr = ggml_nrows(src0);
3421
3422
0
    GGML_ASSERT(dst->ne[0] == nc);
3423
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3424
3425
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3426
3427
    // rows per thread
3428
0
    const int dr = (nr + nth - 1)/nth;
3429
3430
    // row range for this thread
3431
0
    const int ir0 = dr*ith;
3432
0
    const int ir1 = MIN(ir0 + dr, nr);
3433
3434
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3435
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
3436
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
3437
3438
0
        if (!src1) {
3439
0
            src0_p += swapped ? nc : 0;
3440
0
            src1_p += swapped ? 0 : nc;
3441
0
        }
3442
3443
0
        ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3444
3445
#ifndef NDEBUG
3446
        for (int k = 0; k < nc; k++) {
3447
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3448
            GGML_UNUSED(x);
3449
            assert(!isnan(x));
3450
            assert(!isinf(x));
3451
        }
3452
#endif // NDEBUG
3453
0
    }
3454
0
}
3455
3456
static void ggml_compute_forward_geglu_erf_f16(
3457
    const ggml_compute_params * params,
3458
0
    ggml_tensor * dst) {
3459
3460
0
    const ggml_tensor * src0 = dst->src[0];
3461
0
    const ggml_tensor * src1 = dst->src[1];
3462
0
    char * src0_d = (char *) src0->data;
3463
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3464
0
    const size_t src0_o = src0->nb[1];
3465
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3466
3467
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3468
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3469
3470
0
    if (src1) {
3471
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3472
0
        GGML_ASSERT(src0->type == src1->type);
3473
0
    }
3474
3475
0
    const int ith = params->ith;
3476
0
    const int nth = params->nth;
3477
3478
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3479
0
    const int nr = ggml_nrows(src0);
3480
3481
0
    GGML_ASSERT(dst->ne[0] == nc);
3482
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3483
3484
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3485
3486
    // rows per thread
3487
0
    const int dr = (nr + nth - 1)/nth;
3488
3489
    // row range for this thread
3490
0
    const int ir0 = dr*ith;
3491
0
    const int ir1 = MIN(ir0 + dr, nr);
3492
3493
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3494
0
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3495
0
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3496
3497
0
        if (!src1) {
3498
0
            src0_p += swapped ? nc : 0;
3499
0
            src1_p += swapped ? 0 : nc;
3500
0
        }
3501
3502
0
        ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3503
3504
#ifndef NDEBUG
3505
        for (int k = 0; k < nc; k++) {
3506
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3507
            const float v = GGML_FP16_TO_FP32(x);
3508
            GGML_UNUSED(v);
3509
            assert(!isnan(v));
3510
            assert(!isinf(v));
3511
        }
3512
#endif // NDEBUG
3513
0
    }
3514
0
}
3515
3516
static void ggml_compute_forward_geglu_erf(
3517
        const ggml_compute_params * params,
3518
0
        ggml_tensor * dst) {
3519
3520
0
    const ggml_tensor * src0 = dst->src[0];
3521
3522
0
    switch (src0->type) {
3523
0
        case GGML_TYPE_F32:
3524
0
            {
3525
0
                ggml_compute_forward_geglu_erf_f32(params, dst);
3526
0
            } break;
3527
0
        case GGML_TYPE_F16:
3528
0
            {
3529
0
                ggml_compute_forward_geglu_erf_f16(params, dst);
3530
0
            } break;
3531
0
        default:
3532
0
            {
3533
0
                GGML_ABORT("fatal error");
3534
0
            }
3535
0
    }
3536
0
}
3537
3538
// ggml_compute_forward_geglu_quick
3539
3540
static void ggml_compute_forward_geglu_quick_f32(
3541
        const ggml_compute_params * params,
3542
0
        ggml_tensor * dst) {
3543
3544
0
    const ggml_tensor * src0 = dst->src[0];
3545
0
    const ggml_tensor * src1 = dst->src[1];
3546
0
    char * src0_d = (char *) src0->data;
3547
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3548
0
    const size_t src0_o = src0->nb[1];
3549
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3550
3551
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3552
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3553
3554
0
    if (src1) {
3555
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3556
0
        GGML_ASSERT(src0->type == src1->type);
3557
0
    }
3558
3559
0
    const int ith = params->ith;
3560
0
    const int nth = params->nth;
3561
3562
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3563
0
    const int nr = ggml_nrows(src0);
3564
3565
0
    GGML_ASSERT(dst->ne[0] == nc);
3566
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3567
3568
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3569
3570
    // rows per thread
3571
0
    const int dr = (nr + nth - 1)/nth;
3572
3573
    // row range for this thread
3574
0
    const int ir0 = dr*ith;
3575
0
    const int ir1 = MIN(ir0 + dr, nr);
3576
3577
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3578
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
3579
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
3580
3581
0
        if (!src1) {
3582
0
            src0_p += swapped ? nc : 0;
3583
0
            src1_p += swapped ? 0 : nc;
3584
0
        }
3585
3586
0
        ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3587
3588
#ifndef NDEBUG
3589
        for (int k = 0; k < nc; k++) {
3590
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3591
            GGML_UNUSED(x);
3592
            assert(!isnan(x));
3593
            assert(!isinf(x));
3594
        }
3595
#endif // NDEBUG
3596
0
    }
3597
0
}
3598
3599
static void ggml_compute_forward_geglu_quick_f16(
3600
    const ggml_compute_params * params,
3601
0
    ggml_tensor * dst) {
3602
3603
0
    const ggml_tensor * src0 = dst->src[0];
3604
0
    const ggml_tensor * src1 = dst->src[1];
3605
0
    char * src0_d = (char *) src0->data;
3606
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3607
0
    const size_t src0_o = src0->nb[1];
3608
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3609
3610
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3611
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3612
3613
0
    if (src1) {
3614
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3615
0
        GGML_ASSERT(src0->type == src1->type);
3616
0
    }
3617
3618
0
    const int ith = params->ith;
3619
0
    const int nth = params->nth;
3620
3621
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3622
0
    const int nr = ggml_nrows(src0);
3623
3624
0
    GGML_ASSERT(dst->ne[0] == nc);
3625
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3626
3627
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3628
3629
    // rows per thread
3630
0
    const int dr = (nr + nth - 1)/nth;
3631
3632
    // row range for this thread
3633
0
    const int ir0 = dr*ith;
3634
0
    const int ir1 = MIN(ir0 + dr, nr);
3635
3636
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3637
0
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3638
0
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3639
3640
0
        if (!src1) {
3641
0
            src0_p += swapped ? nc : 0;
3642
0
            src1_p += swapped ? 0 : nc;
3643
0
        }
3644
3645
0
        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3646
3647
#ifndef NDEBUG
3648
        for (int k = 0; k < nc; k++) {
3649
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3650
            const float v = GGML_FP16_TO_FP32(x);
3651
            GGML_UNUSED(v);
3652
            assert(!isnan(v));
3653
            assert(!isinf(v));
3654
        }
3655
#endif // NDEBUG
3656
0
    }
3657
0
}
3658
3659
static void ggml_compute_forward_geglu_quick(
3660
        const ggml_compute_params * params,
3661
0
        ggml_tensor * dst) {
3662
3663
0
    const ggml_tensor * src0 = dst->src[0];
3664
3665
0
    switch (src0->type) {
3666
0
        case GGML_TYPE_F32:
3667
0
            {
3668
0
                ggml_compute_forward_geglu_quick_f32(params, dst);
3669
0
            } break;
3670
0
        case GGML_TYPE_F16:
3671
0
            {
3672
0
                ggml_compute_forward_geglu_quick_f16(params, dst);
3673
0
            } break;
3674
0
        default:
3675
0
            {
3676
0
                GGML_ABORT("fatal error");
3677
0
            }
3678
0
    }
3679
0
}
3680
3681
// ggml_compute_forward_norm
3682
3683
static void ggml_compute_forward_norm_f32(
3684
        const ggml_compute_params * params,
3685
0
        ggml_tensor * dst) {
3686
3687
0
    const ggml_tensor * src0 = dst->src[0];
3688
3689
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
3690
3691
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
3692
3693
0
    const int ith = params->ith;
3694
0
    const int nth = params->nth;
3695
3696
0
    GGML_TENSOR_UNARY_OP_LOCALS
3697
3698
0
    float eps;
3699
0
    memcpy(&eps, dst->op_params, sizeof(float));
3700
3701
0
    GGML_ASSERT(eps >= 0.0f);
3702
3703
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
3704
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
3705
0
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3706
0
                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3707
3708
0
                float sum = 0.0;
3709
0
                ggml_vec_sum_f32(ne00, &sum, x);
3710
0
                float mean = sum/ne00;
3711
3712
0
                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3713
0
                float variance = 0;
3714
3715
#ifdef GGML_USE_ACCELERATE
3716
                mean = -mean;
3717
                vDSP_vsadd(x, 1, &mean, y, 1, ne00);
3718
                vDSP_measqv(y, 1, &variance, ne00);
3719
#else
3720
0
                variance = ggml_vec_cvar_f32(ne00, y, x, mean);
3721
0
#endif //GGML_USE_ACCELERATE
3722
3723
0
                const float scale = 1.0f/sqrtf(variance + eps);
3724
0
                ggml_vec_scale_f32(ne00, y, scale);
3725
0
            }
3726
0
        }
3727
0
    }
3728
0
}
3729
3730
void ggml_compute_forward_norm(
3731
        const ggml_compute_params * params,
3732
0
        ggml_tensor * dst) {
3733
3734
0
    const ggml_tensor * src0 = dst->src[0];
3735
3736
0
    switch (src0->type) {
3737
0
        case GGML_TYPE_F32:
3738
0
            {
3739
0
                ggml_compute_forward_norm_f32(params, dst);
3740
0
            } break;
3741
0
        default:
3742
0
            {
3743
0
                GGML_ABORT("fatal error");
3744
0
            }
3745
0
    }
3746
0
}
3747
3748
// ggml_compute_forward_rms_norm
3749
3750
// Operations that can be folded into the rms_norm computation so both run in a single pass.
// When adding a new fused variant (e.g. FUSE_ADD, FUSE_MUL_ADD, ...), append it to this enum.
enum ggml_rms_norm_fuse_op {
    GGML_RMS_NORM_FUSE_OP_NONE = 0, // plain rms_norm, nothing fused
    GGML_RMS_NORM_FUSE_OP_MUL  = 1, // rms_norm followed by an element-wise multiply
};
3756
3757
template <ggml_rms_norm_fuse_op FUSE_OP>
3758
static void ggml_compute_forward_rms_norm_f32(
3759
        const ggml_compute_params * params,
3760
        ggml_tensor * dst_rms_norm,
3761
0
        ggml_tensor * dst_fused = nullptr) {
3762
3763
0
    const ggml_tensor * src0 = dst_rms_norm->src[0];
3764
0
    const ggml_tensor * src1 = nullptr;
3765
0
    ggml_tensor       * dst  = dst_rms_norm;
3766
3767
0
    if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
3768
0
        src1 = (dst_fused->src[0] == dst_rms_norm) ? dst_fused->src[1] : dst_fused->src[0];
3769
0
        dst  = dst_fused;
3770
0
    }
3771
3772
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
3773
3774
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
3775
3776
0
    const int ith = params->ith;
3777
0
    const int nth = params->nth;
3778
3779
0
    GGML_TENSOR_BINARY_OP_LOCALS
3780
3781
0
    float eps;
3782
0
    memcpy(&eps, dst_rms_norm->op_params, sizeof(float));
3783
0
    GGML_ASSERT(eps >= 0.0f);
3784
3785
    // TODO: optimize
3786
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
3787
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
3788
0
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3789
0
                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3790
3791
0
                ggml_float sum = 0.0;
3792
                // worth switching to explicit SIMD?
3793
0
                for (int64_t i00 = 0; i00 < ne00; i00++) {
3794
0
                    sum += (ggml_float)(x[i00] * x[i00]);
3795
0
                }
3796
3797
0
                const float mean  = sum/ne00;
3798
0
                const float scale = 1.0f/sqrtf(mean + eps);
3799
3800
                // if you hit this, likely you got an inf somewhere earlier
3801
0
                assert(scale > 0.0f);
3802
3803
0
                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3804
3805
0
                if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) {
3806
0
                    const int64_t i11 = i01 % ne11;
3807
0
                    const int64_t i12 = i02 % ne12;
3808
0
                    const int64_t i13 = i03 % ne13;
3809
0
                    const float * w = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
3810
3811
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
3812
0
                        y[i00] = x[i00] * scale * w[i00];
3813
0
                    }
3814
0
                } else {
3815
0
                    memcpy(y, x, ne00 * sizeof(float));
3816
0
                    ggml_vec_scale_f32(ne00, y, scale);
3817
0
                }
3818
0
            }
3819
0
        }
3820
0
    }
3821
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_rms_norm_f32<(ggml_rms_norm_fuse_op)0>(ggml_compute_params const*, ggml_tensor*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_rms_norm_f32<(ggml_rms_norm_fuse_op)1>(ggml_compute_params const*, ggml_tensor*, ggml_tensor*)
3822
3823
void ggml_compute_forward_rms_norm(
3824
        const ggml_compute_params * params,
3825
0
        ggml_tensor * dst) {
3826
3827
0
    const ggml_tensor * src0 = dst->src[0];
3828
3829
0
    switch (src0->type) {
3830
0
        case GGML_TYPE_F32:
3831
0
            {
3832
0
                ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_NONE>(params, dst);
3833
0
            } break;
3834
0
        default:
3835
0
            {
3836
0
                GGML_ABORT("fatal error");
3837
0
            }
3838
0
    }
3839
0
}
3840
3841
// Fused RMS_NORM + MUL: computes dst = rms_norm(src0) * src1 in a single pass.
3842
// This avoids materializing the intermediate rms_norm result in memory.
3843
void ggml_compute_forward_rms_norm_mul_fused(
3844
        const ggml_compute_params * params,
3845
        ggml_tensor * dst_rms_norm,
3846
0
        ggml_tensor * dst_mul) {
3847
3848
0
    GGML_ASSERT(dst_mul != nullptr);
3849
0
    GGML_ASSERT(dst_mul->src[0] == dst_rms_norm || dst_mul->src[1] == dst_rms_norm);
3850
3851
0
    const ggml_tensor * src0 = dst_rms_norm->src[0];
3852
3853
0
    switch (src0->type) {
3854
0
        case GGML_TYPE_F32:
3855
0
            {
3856
0
                ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_MUL>(params, dst_rms_norm, dst_mul);
3857
0
            } break;
3858
0
        default:
3859
0
            {
3860
0
                GGML_ABORT("fatal error");
3861
0
            }
3862
0
    }
3863
0
}
3864
3865
static void ggml_compute_forward_rms_norm_back_f32(
3866
        const ggml_compute_params * params,
3867
0
        ggml_tensor * dst) {
3868
3869
0
    const ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
3870
0
    const ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
3871
3872
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
3873
3874
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
3875
0
    GGML_ASSERT(src1->nb[0] == sizeof(float));
3876
3877
0
    const int ith = params->ith;
3878
0
    const int nth = params->nth;
3879
3880
0
    GGML_TENSOR_BINARY_OP_LOCALS
3881
3882
0
    float eps;
3883
0
    memcpy(&eps, dst->op_params, sizeof(float));
3884
3885
    // TODO: optimize
3886
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
3887
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
3888
0
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3889
                // src1 is same shape as src0 => same indices
3890
0
                const int64_t i11 = i01;
3891
0
                const int64_t i12 = i02;
3892
0
                const int64_t i13 = i03;
3893
3894
0
                const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3895
0
                const float * x  = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
3896
3897
0
                ggml_float sum_xx  = 0.0;
3898
0
                ggml_float sum_xdz = 0.0;
3899
3900
0
                for (int64_t i00 = 0; i00 < ne00; i00++) {
3901
0
                    sum_xx  += (ggml_float)(x[i00] * x[i00]);
3902
0
                    sum_xdz += (ggml_float)(x[i00] * dz[i00]);
3903
0
                }
3904
3905
                //const float mean     = (float)(sum_xx)/ne00;
3906
0
                const float mean_eps = (float)(sum_xx)/ne00 + eps;
3907
0
                const float sum_eps  = (float)(sum_xx) + eps*ne00;
3908
                //const float mean_xdz = (float)(sum_xdz)/ne00;
3909
                // we could cache rms from forward pass to improve performance.
3910
                // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
3911
                //const float rms      = sqrtf(mean_eps);
3912
0
                const float rrms     = 1.0f / sqrtf(mean_eps);
3913
                //const float scale    = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
3914
3915
0
                {
3916
                    // z = rms_norm(x)
3917
                    //
3918
                    // rms_norm(src1) =
3919
                    //     scale(
3920
                    //         src1,
3921
                    //         div(
3922
                    //             1,
3923
                    //             sqrt(
3924
                    //                 add(
3925
                    //                     scale(
3926
                    //                         sum(
3927
                    //                             sqr(
3928
                    //                                 src1)),
3929
                    //                         (1.0/N)),
3930
                    //                     eps))));
3931
3932
                    // postorder:
3933
                    // ## op    args         grad
3934
                    // 00 param src1         grad[#00]
3935
                    // 01 const 1
3936
                    // 02 sqr   (#00)        grad[#02]
3937
                    // 03 sum   (#02)        grad[#03]
3938
                    // 04 const 1/N
3939
                    // 05 scale (#03, #04)   grad[#05]
3940
                    // 06 const eps
3941
                    // 07 add   (#05, #06)   grad[#07]
3942
                    // 08 sqrt  (#07)        grad[#08]
3943
                    // 09 div   (#01,#08)    grad[#09]
3944
                    // 10 scale (#00,#09)    grad[#10]
3945
                    //
3946
                    // backward pass, given grad[#10]
3947
                    // #10: scale
3948
                    // grad[#00] += scale(grad[#10],#09)
3949
                    // grad[#09] += sum(mul(grad[#10],#00))
3950
                    // #09: div
3951
                    // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
3952
                    // #08: sqrt
3953
                    // grad[#07] += mul(grad[#08], div(0.5, #08))
3954
                    // #07: add
3955
                    // grad[#05] += grad[#07]
3956
                    // #05: scale
3957
                    // grad[#03] += scale(grad[#05],#04)
3958
                    // #03: sum
3959
                    // grad[#02] += repeat(grad[#03], #02)
3960
                    // #02:
3961
                    // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
3962
                    //
3963
                    // substitute and simplify:
3964
                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
3965
                    // grad[#02] = repeat(grad[#03], #02)
3966
                    // grad[#02] = repeat(scale(grad[#05],#04), #02)
3967
                    // grad[#02] = repeat(scale(grad[#07],#04), #02)
3968
                    // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
3969
                    // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
3970
                    // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
3971
                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
3972
                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
3973
                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
3974
                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
3975
                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
3976
                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
3977
                    // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
3978
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
3979
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
3980
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
3981
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
3982
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
3983
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
3984
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
3985
                    // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
3986
                    // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
3987
                    // a = b*c + d*e
3988
                    // a = b*c*f/f + d*e*f/f
3989
                    // a = (b*c*f + d*e*f)*(1/f)
3990
                    // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
3991
                    // a = (b + d*e/c)*c
3992
                    // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
3993
                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
3994
                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
3995
                    // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
3996
                    // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
3997
                    // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
3998
                    // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
3999
                    // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
4000
                    // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
4001
                    // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
4002
0
                }
4003
                // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
4004
                // post-order:
4005
                // dx := x
4006
                // dx := scale(dx,-mean_xdz/mean_eps)
4007
                // dx := add(dx, dz)
4008
                // dx := scale(dx, rrms)
4009
0
                float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
4010
4011
                // dx[i00] = (dz + x*(-sum_xdz/sum_eps)) * rrms
4012
                // note: https://github.com/ggml-org/ggml/issues/1491
4013
0
                const float scale_x = (float) (-sum_xdz) / sum_eps;
4014
0
                for (int64_t i00 = 0; i00 < ne00; i00++) {
4015
0
                    dx[i00] = (dz[i00] + x[i00] * scale_x) * rrms;
4016
0
                }
4017
0
            }
4018
0
        }
4019
0
    }
4020
0
}
4021
4022
void ggml_compute_forward_rms_norm_back(
4023
        const ggml_compute_params * params,
4024
0
        ggml_tensor * dst) {
4025
4026
0
    const ggml_tensor * src0 = dst->src[0];
4027
4028
0
    switch (src0->type) {
4029
0
        case GGML_TYPE_F32:
4030
0
            {
4031
0
                ggml_compute_forward_rms_norm_back_f32(params, dst);
4032
0
            } break;
4033
0
        default:
4034
0
            {
4035
0
                GGML_ABORT("fatal error");
4036
0
            }
4037
0
    }
4038
0
}
4039
4040
// ggml_compute_forward_group_norm
4041
4042
static void ggml_compute_forward_group_norm_f32(
4043
    const ggml_compute_params * params,
4044
0
    ggml_tensor * dst) {
4045
4046
0
    const ggml_tensor * src0 = dst->src[0];
4047
4048
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
4049
4050
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
4051
4052
0
    const int ith = params->ith;
4053
0
    const int nth = params->nth;
4054
4055
0
    GGML_TENSOR_UNARY_OP_LOCALS
4056
4057
    // TODO: optimize
4058
4059
0
    float eps;
4060
0
    memcpy(&eps, dst->op_params + 1, sizeof(float));
4061
4062
0
    int n_channels = src0->ne[2];
4063
0
    int n_groups = dst->op_params[0];
4064
0
    int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
4065
0
    for (int i = ith; i < n_groups; i += nth) {
4066
0
        int start = i * n_channels_per_group;
4067
0
        int end = start + n_channels_per_group;
4068
0
        if (end > n_channels) {
4069
0
            end = n_channels;
4070
0
        }
4071
0
        int step = end - start;
4072
4073
0
        for (int64_t i03 = 0; i03 < ne03; i03++) {
4074
0
            ggml_float sum = 0.0;
4075
0
            for (int64_t i02 = start; i02 < end; i02++) {
4076
0
                for (int64_t i01 = 0; i01 < ne01; i01++) {
4077
0
                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
4078
4079
0
                    ggml_float sumr = 0.0;
4080
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
4081
0
                        sumr += (ggml_float)x[i00];
4082
0
                    }
4083
0
                    sum += sumr;
4084
0
                }
4085
0
            }
4086
0
            const float mean = sum / (ne00 * ne01 * step);
4087
4088
0
            ggml_float sum2 = 0.0;
4089
0
            for (int64_t i02 = start; i02 < end; i02++) {
4090
0
                for (int64_t i01 = 0; i01 < ne01; i01++) {
4091
0
                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
4092
4093
0
                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
4094
4095
0
                    ggml_float sumr = 0.0;
4096
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
4097
0
                        float v = x[i00] - mean;
4098
0
                        y[i00] = v;
4099
0
                        sumr += (ggml_float)(v * v);
4100
0
                    }
4101
0
                    sum2 += sumr;
4102
0
                }
4103
0
            }
4104
0
            const float variance = sum2 / (ne00 * ne01 * step);
4105
0
            const float scale = 1.0f / sqrtf(variance + eps);
4106
4107
0
            for (int64_t i02 = start; i02 < end; i02++) {
4108
0
                for (int64_t i01 = 0; i01 < ne01; i01++) {
4109
0
                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
4110
0
                    ggml_vec_scale_f32(ne00, y, scale);
4111
0
                }
4112
0
            }
4113
0
        }
4114
0
    }
4115
0
}
4116
4117
void ggml_compute_forward_group_norm(
4118
    const ggml_compute_params * params,
4119
0
    ggml_tensor * dst) {
4120
4121
0
    const ggml_tensor * src0 = dst->src[0];
4122
4123
0
    switch (src0->type) {
4124
0
        case GGML_TYPE_F32:
4125
0
            {
4126
0
                ggml_compute_forward_group_norm_f32(params, dst);
4127
0
            } break;
4128
0
        default:
4129
0
            {
4130
0
                GGML_ABORT("fatal error");
4131
0
            }
4132
0
    }
4133
0
}
4134
4135
// ggml_compute_forward_l2_norm
4136
4137
static void ggml_compute_forward_l2_norm_f32(
4138
    const ggml_compute_params * params,
4139
0
    ggml_tensor * dst) {
4140
4141
0
    const ggml_tensor * src0 = dst->src[0];
4142
4143
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
4144
4145
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
4146
4147
0
    const int ith = params->ith;
4148
0
    const int nth = params->nth;
4149
4150
0
    GGML_TENSOR_UNARY_OP_LOCALS
4151
4152
0
    float eps;
4153
0
    memcpy(&eps, dst->op_params, sizeof(float));
4154
4155
0
    GGML_ASSERT(eps >= 0.0f);
4156
4157
    // TODO: optimize
4158
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
4159
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
4160
0
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
4161
0
                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
4162
4163
0
                ggml_float sum = 0.0;
4164
0
                for (int64_t i00 = 0; i00 < ne00; i00++) {
4165
0
                    sum += (ggml_float)(x[i00] * x[i00]);
4166
0
                }
4167
4168
0
                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
4169
4170
0
                memcpy(y, x, ne00 * sizeof(float));
4171
4172
0
                const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
4173
4174
0
                ggml_vec_scale_f32(ne00, y, scale);
4175
0
            }
4176
0
        }
4177
0
    }
4178
0
}
4179
4180
void ggml_compute_forward_l2_norm(
4181
    const ggml_compute_params * params,
4182
0
    ggml_tensor * dst) {
4183
4184
0
    const ggml_tensor * src0 = dst->src[0];
4185
4186
0
    switch (src0->type) {
4187
0
        case GGML_TYPE_F32:
4188
0
            {
4189
0
                ggml_compute_forward_l2_norm_f32(params, dst);
4190
0
            } break;
4191
0
        default:
4192
0
            {
4193
0
                GGML_ABORT("fatal error");
4194
0
            }
4195
0
    }
4196
0
}
4197
4198
// ggml_compute_forward_out_prod
4199
4200
static void ggml_compute_forward_out_prod_f32(
4201
        const ggml_compute_params * params,
4202
0
              ggml_tensor * dst) {
4203
4204
0
    const ggml_tensor * src0 = dst->src[0];
4205
0
    const ggml_tensor * src1 = dst->src[1];
4206
4207
0
    GGML_TENSOR_BINARY_OP_LOCALS
4208
4209
0
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
4210
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
4211
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
4212
4213
0
    const int ith = params->ith;
4214
0
    const int nth = params->nth;
4215
4216
0
    GGML_ASSERT(ne0 == ne00);
4217
0
    GGML_ASSERT(ne1 == ne10);
4218
0
    GGML_ASSERT(ne2 == ne12);
4219
0
    GGML_ASSERT(ne3 == ne13);
4220
4221
0
    GGML_ASSERT(ne2 % ne02 == 0);
4222
0
    GGML_ASSERT(ne3 % ne03 == 0);
4223
4224
    // we don't support permuted src0 or src1
4225
0
    GGML_ASSERT(nb00 == sizeof(float));
4226
4227
    // dst cannot be transposed or permuted
4228
0
    GGML_ASSERT(nb0 == sizeof(float));
4229
    // GGML_ASSERT(nb0 <= nb1);
4230
    // GGML_ASSERT(nb1 <= nb2);
4231
    // GGML_ASSERT(nb2 <= nb3);
4232
4233
    // nb01 >= nb00 - src0 is not transposed
4234
    //   compute by src0 rows
4235
4236
0
    if (ith == 0) {
4237
0
        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
4238
0
    }
4239
0
    ggml_barrier(params->threadpool);
4240
4241
    // dst[:,:,:,:] = 0
4242
    // for i2,i3:
4243
    //   for i1:
4244
    //     for i01:
4245
    //       for i0:
4246
    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
4247
4248
    // parallelize by last three dimensions
4249
4250
    // total rows in dst
4251
0
    const int64_t nr = ne1*ne2*ne3;
4252
4253
    // rows per thread
4254
0
    const int64_t dr = (nr + nth - 1)/nth;
4255
4256
    // row range for this thread
4257
0
    const int64_t ir0 = dr*ith;
4258
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
4259
4260
    // block-tiling attempt
4261
0
    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
4262
0
    const int64_t blck_1 = 16;
4263
4264
    // dps == dst per src0, used for group query attention
4265
0
    const int64_t dps2 = ne2 / ne02;
4266
0
    const int64_t dps3 = ne3 / ne03;
4267
4268
0
    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
4269
0
        const int64_t bir1 = MIN(bir + blck_1, ir1);
4270
0
        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
4271
0
            const int64_t bne01 = MIN(bi01 + blck_0, ne01);
4272
0
            for (int64_t ir = bir; ir < bir1; ++ir) {
4273
                // dst indices
4274
0
                const int64_t i3 = ir/(ne2*ne1);
4275
0
                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
4276
0
                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
4277
4278
0
                const int64_t i02 = i2 / dps2;
4279
0
                const int64_t i03 = i3 / dps3;
4280
4281
                //const int64_t i10 = i1;
4282
0
                const int64_t i12 = i2;
4283
0
                const int64_t i13 = i3;
4284
4285
0
#if GGML_VEC_MAD_UNROLL > 2
4286
0
                const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
4287
0
                for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
4288
0
                    const int64_t i11 = i01;
4289
4290
0
                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
4291
0
                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4292
0
                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
4293
4294
0
                    ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
4295
0
                }
4296
0
                for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
4297
0
                    const int64_t i11 = i01;
4298
4299
0
                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
4300
0
                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4301
0
                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
4302
4303
0
                    ggml_vec_mad_f32(ne0, d, s0, *s1);
4304
0
                }
4305
#else
4306
                for (int64_t i01 = bi01; i01 < bne01; ++i01) {
4307
                    const int64_t i11 = i01;
4308
4309
                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
4310
                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4311
                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
4312
4313
                    ggml_vec_mad_f32(ne0, d, s0, *s1);
4314
                }
4315
#endif
4316
0
            }
4317
0
        }
4318
0
    }
4319
0
}
4320
4321
static void ggml_compute_forward_out_prod_q_f32(
4322
        const ggml_compute_params * params,
4323
0
              ggml_tensor * dst) {
4324
4325
0
    const ggml_tensor * src0 = dst->src[0];
4326
0
    const ggml_tensor * src1 = dst->src[1];
4327
4328
0
    GGML_TENSOR_BINARY_OP_LOCALS;
4329
4330
0
    const int ith = params->ith;
4331
0
    const int nth = params->nth;
4332
4333
0
    const ggml_type type = src0->type;
4334
0
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
4335
4336
0
    GGML_ASSERT(ne02 == ne12);
4337
0
    GGML_ASSERT(ne03 == ne13);
4338
0
    GGML_ASSERT(ne2  == ne12);
4339
0
    GGML_ASSERT(ne3  == ne13);
4340
4341
    // we don't support permuted src0 dim0
4342
0
    GGML_ASSERT(nb00 == ggml_type_size(type));
4343
4344
    // dst dim0 cannot be transposed or permuted
4345
0
    GGML_ASSERT(nb0 == sizeof(float));
4346
    // GGML_ASSERT(nb0 <= nb1);
4347
    // GGML_ASSERT(nb1 <= nb2);
4348
    // GGML_ASSERT(nb2 <= nb3);
4349
4350
0
    GGML_ASSERT(ne0 == ne00);
4351
0
    GGML_ASSERT(ne1 == ne10);
4352
0
    GGML_ASSERT(ne2 == ne02);
4353
0
    GGML_ASSERT(ne3 == ne03);
4354
4355
    // nb01 >= nb00 - src0 is not transposed
4356
    //   compute by src0 rows
4357
4358
0
    if (ith == 0) {
4359
0
        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
4360
0
    }
4361
0
    ggml_barrier(params->threadpool);
4362
4363
    // parallelize by last three dimensions
4364
4365
    // total rows in dst
4366
0
    const int64_t nr = ne1*ne2*ne3;
4367
4368
    // rows per thread
4369
0
    const int64_t dr = (nr + nth - 1)/nth;
4370
4371
    // row range for this thread
4372
0
    const int64_t ir0 = dr*ith;
4373
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
4374
4375
    // dst[:,:,:,:] = 0
4376
    // for i2,i3:
4377
    //   for i1:
4378
    //     for i01:
4379
    //       for i0:
4380
    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
4381
4382
0
    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
4383
4384
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
4385
        // dst indices
4386
0
        const int64_t i3 = ir/(ne2*ne1);
4387
0
        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
4388
0
        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
4389
4390
0
        const int64_t i02 = i2;
4391
0
        const int64_t i03 = i3;
4392
4393
        //const int64_t i10 = i1;
4394
0
        const int64_t i12 = i2;
4395
0
        const int64_t i13 = i3;
4396
4397
0
        for (int64_t i01 = 0; i01 < ne01; ++i01) {
4398
0
            const int64_t i11 = i01;
4399
4400
0
            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
4401
0
            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4402
0
            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
4403
4404
0
            dequantize_row_q(s0, wdata, ne0);
4405
0
            ggml_vec_mad_f32(ne0, d, wdata, *s1);
4406
0
        }
4407
0
    }
4408
0
}
4409
4410
void ggml_compute_forward_out_prod(
4411
        const ggml_compute_params * params,
4412
0
        ggml_tensor * dst) {
4413
4414
0
    const ggml_tensor * src0 = dst->src[0];
4415
4416
0
    switch (src0->type) {
4417
0
        case GGML_TYPE_Q1_0:
4418
0
        case GGML_TYPE_Q4_0:
4419
0
        case GGML_TYPE_Q4_1:
4420
0
        case GGML_TYPE_Q5_0:
4421
0
        case GGML_TYPE_Q5_1:
4422
0
        case GGML_TYPE_Q8_0:
4423
0
        case GGML_TYPE_MXFP4:
4424
0
        case GGML_TYPE_NVFP4:
4425
0
        case GGML_TYPE_Q2_K:
4426
0
        case GGML_TYPE_Q3_K:
4427
0
        case GGML_TYPE_Q4_K:
4428
0
        case GGML_TYPE_Q5_K:
4429
0
        case GGML_TYPE_Q6_K:
4430
0
        case GGML_TYPE_TQ1_0:
4431
0
        case GGML_TYPE_TQ2_0:
4432
0
        case GGML_TYPE_IQ2_XXS:
4433
0
        case GGML_TYPE_IQ2_XS:
4434
0
        case GGML_TYPE_IQ3_XXS:
4435
0
        case GGML_TYPE_IQ1_S:
4436
0
        case GGML_TYPE_IQ1_M:
4437
0
        case GGML_TYPE_IQ4_NL:
4438
0
        case GGML_TYPE_IQ4_XS:
4439
0
        case GGML_TYPE_IQ3_S:
4440
0
        case GGML_TYPE_IQ2_S:
4441
0
            {
4442
0
                ggml_compute_forward_out_prod_q_f32(params, dst);
4443
0
            } break;
4444
0
        case GGML_TYPE_F16:
4445
0
            {
4446
0
                GGML_ABORT("fatal error"); // todo
4447
                // ggml_compute_forward_out_prod_f16_f32(params, dst);
4448
0
            }
4449
0
        case GGML_TYPE_F32:
4450
0
            {
4451
0
                ggml_compute_forward_out_prod_f32(params, dst);
4452
0
            } break;
4453
0
        default:
4454
0
            {
4455
0
                GGML_ABORT("fatal error");
4456
0
            }
4457
0
    }
4458
0
}
4459
4460
// ggml_compute_forward_scale
4461
4462
static void ggml_compute_forward_scale_f32(
4463
        const ggml_compute_params * params,
4464
0
        ggml_tensor * dst) {
4465
4466
0
    const ggml_tensor * src0 = dst->src[0];
4467
4468
0
    GGML_ASSERT(ggml_is_contiguous(src0));
4469
0
    GGML_ASSERT(ggml_is_contiguous(dst));
4470
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
4471
4472
0
    float s; // scale factor
4473
0
    float b; // bias
4474
4475
0
    memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
4476
0
    memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
4477
4478
0
    const int ith = params->ith;
4479
0
    const int nth = params->nth;
4480
4481
0
    const int nc = src0->ne[0];
4482
0
    const int nr = ggml_nrows(src0);
4483
4484
    // rows per thread
4485
0
    const int dr = (nr + nth - 1)/nth;
4486
4487
    // row range for this thread
4488
0
    const int ir0 = dr*ith;
4489
0
    const int ir1 = MIN(ir0 + dr, nr);
4490
4491
0
    const size_t nb01 = src0->nb[1];
4492
4493
0
    const size_t nb1 = dst->nb[1];
4494
4495
0
    if (b == 0.0f) {
4496
0
        for (int i1 = ir0; i1 < ir1; i1++) {
4497
0
            if (dst->data != src0->data) {
4498
                // src0 is same shape as dst => same indices
4499
                // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
4500
0
                memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
4501
0
            }
4502
0
            ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
4503
0
        }
4504
0
    } else {
4505
0
        for (int i1 = ir0; i1 < ir1; i1++) {
4506
0
            ggml_vec_mad1_f32(nc,
4507
0
                (float *) ((char *) dst->data  + i1*nb1),
4508
0
                (float *) ((char *) src0->data + i1*nb1),
4509
0
                s, b);
4510
0
        }
4511
0
    }
4512
0
}
4513
4514
void ggml_compute_forward_scale(
4515
        const ggml_compute_params * params,
4516
0
        ggml_tensor * dst) {
4517
4518
0
    const ggml_tensor * src0 = dst->src[0];
4519
4520
0
    switch (src0->type) {
4521
0
        case GGML_TYPE_F32:
4522
0
            {
4523
0
                ggml_compute_forward_scale_f32(params, dst);
4524
0
            } break;
4525
0
        default:
4526
0
            {
4527
0
                GGML_ABORT("fatal error");
4528
0
            }
4529
0
    }
4530
0
}
4531
4532
// ggml_compute_forward_set
4533
4534
static void ggml_compute_forward_set_f32(
4535
        const ggml_compute_params * params,
4536
0
        ggml_tensor * dst) {
4537
4538
0
    const ggml_tensor * src0 = dst->src[0];
4539
0
    const ggml_tensor * src1 = dst->src[1];
4540
4541
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
4542
0
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
4543
4544
    // view src0 and dst with these strides and data offset inbytes during set
4545
    // nb0 is implicitly element_size because src0 and dst are contiguous
4546
0
    size_t nb1     = ((int32_t *) dst->op_params)[0];
4547
0
    size_t nb2     = ((int32_t *) dst->op_params)[1];
4548
0
    size_t nb3     = ((int32_t *) dst->op_params)[2];
4549
0
    size_t offset  = ((int32_t *) dst->op_params)[3];
4550
0
    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
4551
4552
0
    if (!inplace) {
4553
0
        if (params->ith == 0) {
4554
            // memcpy needs to be synchronized across threads to avoid race conditions.
4555
            // => do it in INIT phase
4556
0
            memcpy(
4557
0
                ((char *)  dst->data),
4558
0
                ((char *) src0->data),
4559
0
                ggml_nbytes(dst));
4560
0
        }
4561
0
        ggml_barrier(params->threadpool);
4562
0
    }
4563
4564
0
    const int ith = params->ith;
4565
0
    const int nth = params->nth;
4566
4567
0
    const int nr = ggml_nrows(src1);
4568
0
    const int nc = src1->ne[0];
4569
4570
0
    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
4571
0
    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
4572
4573
    // src0 and dst as viewed during set
4574
0
    const size_t nb0 = ggml_element_size(src0);
4575
4576
0
    const int im0 = (ne10 == 0 ? 0 : ne10-1);
4577
0
    const int im1 = (ne11 == 0 ? 0 : ne11-1);
4578
0
    const int im2 = (ne12 == 0 ? 0 : ne12-1);
4579
0
    const int im3 = (ne13 == 0 ? 0 : ne13-1);
4580
4581
0
    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
4582
4583
0
    GGML_ASSERT(nb10 == sizeof(float));
4584
4585
    // rows per thread
4586
0
    const int dr = (nr + nth - 1)/nth;
4587
4588
    // row range for this thread
4589
0
    const int ir0 = dr*ith;
4590
0
    const int ir1 = MIN(ir0 + dr, nr);
4591
4592
0
    for (int ir = ir0; ir < ir1; ++ir) {
4593
        // src0 and dst are viewed with shape of src1 and offset
4594
        // => same indices
4595
0
        const int i3 = ir/(ne12*ne11);
4596
0
        const int i2 = (ir - i3*ne12*ne11)/ne11;
4597
0
        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
4598
4599
0
        ggml_vec_cpy_f32(nc,
4600
0
                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
4601
0
                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
4602
0
    }
4603
0
}
4604
4605
static void ggml_compute_forward_set_i32(
4606
        const ggml_compute_params * params,
4607
0
        ggml_tensor * dst) {
4608
4609
0
    const ggml_tensor * src0 = dst->src[0];
4610
0
    const ggml_tensor * src1 = dst->src[1];
4611
4612
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
4613
0
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
4614
4615
    // view src0 and dst with these strides and data offset inbytes during set
4616
    // nb0 is implicitly element_size because src0 and dst are contiguous
4617
0
    size_t nb1     = ((int32_t *) dst->op_params)[0];
4618
0
    size_t nb2     = ((int32_t *) dst->op_params)[1];
4619
0
    size_t nb3     = ((int32_t *) dst->op_params)[2];
4620
0
    size_t offset  = ((int32_t *) dst->op_params)[3];
4621
0
    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
4622
4623
0
    if (!inplace) {
4624
0
        if (params->ith == 0) {
4625
            // memcpy needs to be synchronized across threads to avoid race conditions.
4626
            // => do it in INIT phase
4627
0
            memcpy(
4628
0
                ((char *)  dst->data),
4629
0
                ((char *) src0->data),
4630
0
                ggml_nbytes(dst));
4631
0
        }
4632
0
        ggml_barrier(params->threadpool);
4633
0
    }
4634
4635
0
    const int ith = params->ith;
4636
0
    const int nth = params->nth;
4637
4638
0
    const int nr = ggml_nrows(src1);
4639
0
    const int nc = src1->ne[0];
4640
4641
0
    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
4642
0
    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
4643
4644
    // src0 and dst as viewed during set
4645
0
    const size_t nb0 = ggml_element_size(src0);
4646
4647
0
    const int im0 = (ne10 == 0 ? 0 : ne10-1);
4648
0
    const int im1 = (ne11 == 0 ? 0 : ne11-1);
4649
0
    const int im2 = (ne12 == 0 ? 0 : ne12-1);
4650
0
    const int im3 = (ne13 == 0 ? 0 : ne13-1);
4651
4652
0
    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
4653
4654
0
    GGML_ASSERT(nb10 == sizeof(int32_t));
4655
4656
    // rows per thread
4657
0
    const int dr = (nr + nth - 1)/nth;
4658
4659
    // row range for this thread
4660
0
    const int ir0 = dr*ith;
4661
0
    const int ir1 = MIN(ir0 + dr, nr);
4662
4663
0
    for (int ir = ir0; ir < ir1; ++ir) {
4664
        // src0 and dst are viewed with shape of src1 and offset
4665
        // => same indices
4666
0
        const int i3 = ir/(ne12*ne11);
4667
0
        const int i2 = (ir - i3*ne12*ne11)/ne11;
4668
0
        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
4669
4670
0
        ggml_vec_cpy_i32(nc,
4671
0
                (int32_t *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
4672
0
                (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
4673
0
    }
4674
0
}
4675
4676
void ggml_compute_forward_set(
4677
        const ggml_compute_params * params,
4678
0
        ggml_tensor * dst) {
4679
4680
0
    const ggml_tensor * src0 = dst->src[0];
4681
4682
0
    switch (src0->type) {
4683
0
        case GGML_TYPE_F32:
4684
0
            {
4685
0
                ggml_compute_forward_set_f32(params, dst);
4686
0
            } break;
4687
0
        case GGML_TYPE_I32:
4688
0
            {
4689
0
                ggml_compute_forward_set_i32(params, dst);
4690
0
            } break;
4691
0
        case GGML_TYPE_F16:
4692
0
        case GGML_TYPE_BF16:
4693
0
        case GGML_TYPE_Q1_0:
4694
0
        case GGML_TYPE_Q4_0:
4695
0
        case GGML_TYPE_Q4_1:
4696
0
        case GGML_TYPE_Q5_0:
4697
0
        case GGML_TYPE_Q5_1:
4698
0
        case GGML_TYPE_Q8_0:
4699
0
        case GGML_TYPE_Q8_1:
4700
0
        case GGML_TYPE_MXFP4:
4701
0
        case GGML_TYPE_NVFP4:
4702
0
        case GGML_TYPE_Q2_K:
4703
0
        case GGML_TYPE_Q3_K:
4704
0
        case GGML_TYPE_Q4_K:
4705
0
        case GGML_TYPE_Q5_K:
4706
0
        case GGML_TYPE_Q6_K:
4707
0
        case GGML_TYPE_TQ1_0:
4708
0
        case GGML_TYPE_TQ2_0:
4709
0
        case GGML_TYPE_IQ2_XXS:
4710
0
        case GGML_TYPE_IQ2_XS:
4711
0
        case GGML_TYPE_IQ3_XXS:
4712
0
        case GGML_TYPE_IQ1_S:
4713
0
        case GGML_TYPE_IQ1_M:
4714
0
        case GGML_TYPE_IQ4_NL:
4715
0
        case GGML_TYPE_IQ4_XS:
4716
0
        case GGML_TYPE_IQ3_S:
4717
0
        case GGML_TYPE_IQ2_S:
4718
0
        default:
4719
0
            {
4720
0
                GGML_ABORT("fatal error");
4721
0
            }
4722
0
    }
4723
0
}
4724
4725
// ggml_compute_forward_cpy
4726
4727
void ggml_compute_forward_cpy(
4728
        const ggml_compute_params * params,
4729
0
        ggml_tensor * dst) {
4730
0
    ggml_compute_forward_dup(params, dst);
4731
0
}
4732
4733
// ggml_compute_forward_cont
4734
4735
void ggml_compute_forward_cont(
4736
        const ggml_compute_params * params,
4737
0
        ggml_tensor * dst) {
4738
0
    ggml_compute_forward_dup(params, dst);
4739
0
}
4740
4741
// ggml_compute_forward_get_rows
4742
4743
static void ggml_compute_forward_get_rows_q(
4744
        const ggml_compute_params * params,
4745
0
              ggml_tensor * dst) {
4746
4747
0
    const ggml_tensor * src0 = dst->src[0];
4748
0
    const ggml_tensor * src1 = dst->src[1];
4749
4750
0
    GGML_TENSOR_BINARY_OP_LOCALS
4751
4752
0
    const int64_t nc = ne00;
4753
0
    const int64_t nr = ggml_nelements(src1);
4754
4755
0
    const ggml_type type = src0->type;
4756
0
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
4757
4758
0
    assert(ne0  == nc);
4759
0
    assert(ne02 == ne11);
4760
0
    assert(nb00 == ggml_type_size(type));
4761
0
    assert(ggml_nrows(dst) == nr);
4762
4763
0
    const int ith = params->ith;
4764
0
    const int nth = params->nth;
4765
4766
    // rows per thread
4767
0
    const int dr = (nr + nth - 1)/nth;
4768
4769
    // row range for this thread
4770
0
    const int ir0 = dr*ith;
4771
0
    const int ir1 = MIN(ir0 + dr, nr);
4772
4773
0
    for (int64_t i = ir0; i < ir1; ++i) {
4774
0
        const int64_t i12 = i/(ne11*ne10);
4775
0
        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4776
0
        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4777
0
        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4778
4779
0
        GGML_ASSERT(i01 >= 0 && i01 < ne01);
4780
4781
0
        dequantize_row_q(
4782
0
                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
4783
0
                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
4784
0
    }
4785
0
}
4786
4787
static void ggml_compute_forward_get_rows_f16(
4788
        const ggml_compute_params * params,
4789
0
              ggml_tensor * dst) {
4790
4791
0
    const ggml_tensor * src0 = dst->src[0];
4792
0
    const ggml_tensor * src1 = dst->src[1];
4793
4794
0
    GGML_TENSOR_BINARY_OP_LOCALS
4795
4796
0
    const int64_t nc = ne00;
4797
0
    const int64_t nr = ggml_nelements(src1);
4798
4799
0
    assert(ne0  == nc);
4800
0
    assert(ne02 == ne11);
4801
0
    assert(nb00 == sizeof(ggml_fp16_t));
4802
0
    assert(ggml_nrows(dst) == nr);
4803
4804
0
    const int ith = params->ith;
4805
0
    const int nth = params->nth;
4806
4807
    // rows per thread
4808
0
    const int dr = (nr + nth - 1)/nth;
4809
4810
    // row range for this thread
4811
0
    const int ir0 = dr*ith;
4812
0
    const int ir1 = MIN(ir0 + dr, nr);
4813
4814
0
    for (int64_t i = ir0; i < ir1; ++i) {
4815
0
        const int64_t i12 = i/(ne11*ne10);
4816
0
        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4817
0
        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4818
0
        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4819
4820
0
        GGML_ASSERT(i01 >= 0 && i01 < ne01);
4821
4822
0
        ggml_cpu_fp16_to_fp32(
4823
0
            (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
4824
0
                       (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
4825
0
    }
4826
0
}
4827
4828
static void ggml_compute_forward_get_rows_bf16(
4829
        const ggml_compute_params * params,
4830
0
              ggml_tensor * dst) {
4831
4832
0
    const ggml_tensor * src0 = dst->src[0];
4833
0
    const ggml_tensor * src1 = dst->src[1];
4834
4835
0
    GGML_TENSOR_BINARY_OP_LOCALS
4836
4837
0
    const int64_t nc = ne00;
4838
0
    const int64_t nr = ggml_nelements(src1);
4839
4840
0
    assert(ne0  == nc);
4841
0
    assert(ne02 == ne11);
4842
0
    assert(nb00 == sizeof(ggml_bf16_t));
4843
0
    assert(ggml_nrows(dst) == nr);
4844
4845
0
    const int ith = params->ith;
4846
0
    const int nth = params->nth;
4847
4848
    // rows per thread
4849
0
    const int dr = (nr + nth - 1)/nth;
4850
4851
    // row range for this thread
4852
0
    const int ir0 = dr*ith;
4853
0
    const int ir1 = MIN(ir0 + dr, nr);
4854
4855
0
    for (int64_t i = ir0; i < ir1; ++i) {
4856
0
        const int64_t i12 = i/(ne11*ne10);
4857
0
        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4858
0
        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4859
0
        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4860
4861
0
        GGML_ASSERT(i01 >= 0 && i01 < ne01);
4862
4863
0
        ggml_cpu_bf16_to_fp32(
4864
0
            (const ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
4865
0
                        (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
4866
0
    }
4867
0
}
4868
4869
static void ggml_compute_forward_get_rows_f32(
4870
        const ggml_compute_params * params,
4871
0
              ggml_tensor * dst) {
4872
4873
0
    const ggml_tensor * src0 = dst->src[0];
4874
0
    const ggml_tensor * src1 = dst->src[1];
4875
4876
0
    GGML_TENSOR_BINARY_OP_LOCALS
4877
4878
0
    const int64_t nc = ne00;
4879
0
    const int64_t nr = ggml_nelements(src1);
4880
4881
0
    assert(ne0  == nc);
4882
0
    assert(ne02 == ne11);
4883
0
    assert(nb00 == sizeof(float));
4884
0
    assert(ggml_nrows(dst) == nr);
4885
4886
0
    const int ith = params->ith;
4887
0
    const int nth = params->nth;
4888
4889
    // rows per thread
4890
0
    const int dr = (nr + nth - 1)/nth;
4891
4892
    // row range for this thread
4893
0
    const int ir0 = dr*ith;
4894
0
    const int ir1 = MIN(ir0 + dr, nr);
4895
4896
0
    for (int64_t i = ir0; i < ir1; ++i) {
4897
0
        const int64_t i12 = i/(ne11*ne10);
4898
0
        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4899
0
        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4900
0
        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4901
4902
0
        GGML_ASSERT(i01 >= 0 && i01 < ne01);
4903
4904
0
        ggml_vec_cpy_f32(nc,
4905
0
                (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
4906
0
                (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
4907
0
    }
4908
0
}
4909
4910
void ggml_compute_forward_get_rows(
4911
        const ggml_compute_params * params,
4912
0
        ggml_tensor * dst) {
4913
4914
0
    const ggml_tensor * src0 = dst->src[0];
4915
4916
0
    switch (src0->type) {
4917
0
        case GGML_TYPE_Q1_0:
4918
0
        case GGML_TYPE_Q4_0:
4919
0
        case GGML_TYPE_Q4_1:
4920
0
        case GGML_TYPE_Q5_0:
4921
0
        case GGML_TYPE_Q5_1:
4922
0
        case GGML_TYPE_Q8_0:
4923
0
        case GGML_TYPE_Q8_1:
4924
0
        case GGML_TYPE_MXFP4:
4925
0
        case GGML_TYPE_NVFP4:
4926
0
        case GGML_TYPE_Q2_K:
4927
0
        case GGML_TYPE_Q3_K:
4928
0
        case GGML_TYPE_Q4_K:
4929
0
        case GGML_TYPE_Q5_K:
4930
0
        case GGML_TYPE_Q6_K:
4931
0
        case GGML_TYPE_TQ1_0:
4932
0
        case GGML_TYPE_TQ2_0:
4933
0
        case GGML_TYPE_IQ2_XXS:
4934
0
        case GGML_TYPE_IQ2_XS:
4935
0
        case GGML_TYPE_IQ3_XXS:
4936
0
        case GGML_TYPE_IQ1_S:
4937
0
        case GGML_TYPE_IQ1_M:
4938
0
        case GGML_TYPE_IQ4_NL:
4939
0
        case GGML_TYPE_IQ4_XS:
4940
0
        case GGML_TYPE_IQ3_S:
4941
0
        case GGML_TYPE_IQ2_S:
4942
0
            {
4943
0
                ggml_compute_forward_get_rows_q(params, dst);
4944
0
            } break;
4945
0
        case GGML_TYPE_F16:
4946
0
            {
4947
0
                ggml_compute_forward_get_rows_f16(params, dst);
4948
0
            } break;
4949
0
        case GGML_TYPE_BF16:
4950
0
            {
4951
0
                ggml_compute_forward_get_rows_bf16(params, dst);
4952
0
            } break;
4953
0
        case GGML_TYPE_F32:
4954
0
        case GGML_TYPE_I32:
4955
0
            {
4956
0
                ggml_compute_forward_get_rows_f32(params, dst);
4957
0
            } break;
4958
0
        default:
4959
0
            {
4960
0
                GGML_ABORT("fatal error");
4961
0
            }
4962
0
    }
4963
4964
    //static bool first = true;
4965
    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
4966
    //if (first) {
4967
    //    first = false;
4968
    //} else {
4969
    //    for (int k = 0; k < dst->ne[1]; ++k) {
4970
    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
4971
    //            for (int i = 0; i < 16; ++i) {
4972
    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
4973
    //            }
4974
    //            printf("\n");
4975
    //        }
4976
    //        printf("\n");
4977
    //    }
4978
    //    printf("\n");
4979
    //    exit(0);
4980
    //}
4981
0
}
4982
4983
template<typename idx_t>
4984
static void ggml_compute_forward_set_rows_f32(
4985
        const ggml_compute_params * params,
4986
0
              ggml_tensor * dst) {
4987
4988
0
    const ggml_tensor * src0 = dst->src[0];
4989
0
    const ggml_tensor * src1 = dst->src[1];
4990
4991
0
    GGML_TENSOR_BINARY_OP_LOCALS
4992
4993
0
    const int64_t nc = ne00;
4994
0
    const int64_t nr = ne01;
4995
4996
0
    assert(ne0  == nc);
4997
0
    assert(ne2  == ne02);
4998
0
    assert(ne3  == ne03);
4999
0
    assert(src0->type == GGML_TYPE_F32);
5000
0
    assert(ne02 % ne11 == 0);
5001
0
    assert(ne03 % ne12 == 0);
5002
5003
0
    const int ith = params->ith;
5004
0
    const int nth = params->nth;
5005
5006
    // rows per thread
5007
0
    const int64_t dr = (nr + nth - 1)/nth;
5008
5009
    // row range for this thread
5010
0
    const int64_t ir0 = dr*ith;
5011
0
    const int64_t ir1 = std::min(ir0 + dr, nr);
5012
5013
0
    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
5014
5015
0
    for (int64_t i03 = 0; i03 < ne03; ++i03) {
5016
0
        for (int64_t i02 = 0; i02 < ne02; ++i02) {
5017
0
            for (int64_t i = ir0; i < ir1; ++i) {
5018
0
                const int64_t i12 = i03%ne12;
5019
0
                const int64_t i11 = i02%ne11;
5020
0
                const int64_t i10 = i;
5021
5022
0
                const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
5023
5024
0
                GGML_ASSERT(i1 >= 0 && i1 < ne1);
5025
5026
0
                from_float(
5027
0
                        (const float *) ((char *) src0->data +  i*nb01 + i02*nb02 + i03*nb03),
5028
0
                                        ((char *)  dst->data + i1*nb1  + i02*nb2  + i03*nb3), nc);
5029
0
            }
5030
0
        }
5031
0
    }
5032
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_set_rows_f32<long>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_set_rows_f32<int>(ggml_compute_params const*, ggml_tensor*)
5033
5034
void ggml_compute_forward_set_rows(
5035
        const ggml_compute_params * params,
5036
0
        ggml_tensor * dst) {
5037
5038
0
    const ggml_tensor * src0 = dst->src[0];
5039
0
    const ggml_tensor * src1 = dst->src[1];
5040
5041
0
    switch (src0->type) {
5042
0
        case GGML_TYPE_F32:
5043
0
            {
5044
0
                if (src1->type == GGML_TYPE_I64) {
5045
0
                    ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
5046
0
                } else if (src1->type == GGML_TYPE_I32) {
5047
0
                    ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
5048
0
                } else {
5049
0
                    GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
5050
0
                }
5051
0
            } break;
5052
0
        default:
5053
0
            {
5054
0
                GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
5055
0
            }
5056
0
    }
5057
0
}
5058
5059
// ggml_compute_forward_get_rows_back
5060
5061
static void ggml_compute_forward_get_rows_back_f32_f16(
5062
        const ggml_compute_params * params,
5063
0
              ggml_tensor * dst) {
5064
5065
0
    const ggml_tensor * src0 = dst->src[0];
5066
0
    const ggml_tensor * src1 = dst->src[1];
5067
5068
0
    if (params->ith != 0) {
5069
0
        return;
5070
0
    }
5071
5072
0
    GGML_ASSERT(ggml_is_contiguous(dst));
5073
5074
    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
5075
5076
0
    memset(dst->data, 0, ggml_nbytes(dst));
5077
5078
0
    const int nc = src0->ne[0];
5079
0
    const int nr = ggml_nelements(src1);
5080
5081
0
    GGML_ASSERT( dst->ne[0] == nc);
5082
0
    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
5083
5084
0
    for (int i = 0; i < nr; ++i) {
5085
0
        const int r = ((int32_t *) src1->data)[i];
5086
5087
0
        for (int j = 0; j < nc; ++j) {
5088
0
            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
5089
0
            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_CPU_FP16_TO_FP32(v);
5090
0
        }
5091
0
    }
5092
0
}
5093
5094
static void ggml_compute_forward_get_rows_back_f32(
5095
        const ggml_compute_params * params,
5096
0
              ggml_tensor * dst) {
5097
5098
0
    const ggml_tensor * src0 = dst->src[0];
5099
0
    const ggml_tensor * src1 = dst->src[1];
5100
5101
0
    if (params->ith != 0) {
5102
0
        return;
5103
0
    }
5104
5105
0
    GGML_ASSERT(ggml_is_contiguous(dst));
5106
5107
    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
5108
5109
0
    memset(dst->data, 0, ggml_nbytes(dst));
5110
5111
0
    const int nc = src0->ne[0];
5112
0
    const int nr = ggml_nelements(src1);
5113
5114
0
    GGML_ASSERT( dst->ne[0] == nc);
5115
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
5116
5117
0
    for (int i = 0; i < nr; ++i) {
5118
0
        const int r = ((int32_t *) src1->data)[i];
5119
5120
0
        ggml_vec_add_f32(nc,
5121
0
                (float *) ((char *)  dst->data + r*dst->nb[1]),
5122
0
                (float *) ((char *)  dst->data + r*dst->nb[1]),
5123
0
                (float *) ((char *) src0->data + i*src0->nb[1]));
5124
0
    }
5125
0
}
5126
5127
void ggml_compute_forward_get_rows_back(
5128
        const ggml_compute_params * params,
5129
0
        ggml_tensor * dst) {
5130
5131
0
    const ggml_tensor * src0 = dst->src[0];
5132
5133
0
    switch (src0->type) {
5134
0
        case GGML_TYPE_F16:
5135
0
            {
5136
0
                ggml_compute_forward_get_rows_back_f32_f16(params, dst);
5137
0
            } break;
5138
0
        case GGML_TYPE_F32:
5139
0
            {
5140
0
                ggml_compute_forward_get_rows_back_f32(params, dst);
5141
0
            } break;
5142
0
        default:
5143
0
            {
5144
0
                GGML_ABORT("fatal error");
5145
0
            }
5146
0
    }
5147
5148
    //static bool first = true;
5149
    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
5150
    //if (first) {
5151
    //    first = false;
5152
    //} else {
5153
    //    for (int k = 0; k < dst->ne[1]; ++k) {
5154
    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
5155
    //            for (int i = 0; i < 16; ++i) {
5156
    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
5157
    //            }
5158
    //            printf("\n");
5159
    //        }
5160
    //        printf("\n");
5161
    //    }
5162
    //    printf("\n");
5163
    //    exit(0);
5164
    //}
5165
0
}
5166
5167
// ggml_compute_forward_diag
5168
5169
static void ggml_compute_forward_diag_f32(
5170
        const ggml_compute_params * params,
5171
0
        ggml_tensor * dst) {
5172
5173
0
    const ggml_tensor * src0 = dst->src[0];
5174
5175
0
    if (params->ith != 0) {
5176
0
        return;
5177
0
    }
5178
5179
    // TODO: handle transposed/permuted matrices
5180
5181
0
    GGML_TENSOR_UNARY_OP_LOCALS
5182
5183
0
    GGML_ASSERT(ne00 == ne0);
5184
0
    GGML_ASSERT(ne00 == ne1);
5185
0
    GGML_ASSERT(ne01 == 1);
5186
0
    GGML_ASSERT(ne02 == ne2);
5187
0
    GGML_ASSERT(ne03 == ne3);
5188
5189
0
    GGML_ASSERT(nb00 == sizeof(float));
5190
0
    GGML_ASSERT(nb0  == sizeof(float));
5191
5192
0
    for (int i3 = 0; i3 < ne3; i3++) {
5193
0
        for (int i2 = 0; i2 < ne2; i2++) {
5194
0
            for (int i1 = 0; i1 < ne1; i1++) {
5195
0
                float * d = (float *)((char *)  dst->data + i3*nb3  + i2*nb2 + i1*nb1);
5196
0
                float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
5197
0
                for (int i0 = 0; i0 < i1; i0++) {
5198
0
                    d[i0] = 0;
5199
0
                }
5200
0
                d[i1] = s[i1];
5201
0
                for (int i0 = i1+1; i0 < ne0; i0++) {
5202
0
                    d[i0] = 0;
5203
0
                }
5204
0
            }
5205
0
        }
5206
0
    }
5207
0
}
5208
5209
void ggml_compute_forward_diag(
5210
        const ggml_compute_params * params,
5211
0
        ggml_tensor * dst) {
5212
5213
0
    const ggml_tensor * src0 = dst->src[0];
5214
5215
0
    switch (src0->type) {
5216
0
        case GGML_TYPE_F32:
5217
0
            {
5218
0
                ggml_compute_forward_diag_f32(params, dst);
5219
0
            } break;
5220
0
        default:
5221
0
            {
5222
0
                GGML_ABORT("fatal error");
5223
0
            }
5224
0
    }
5225
0
}
5226
5227
// ggml_compute_forward_diag_mask_inf
5228
5229
static void ggml_compute_forward_diag_mask_f32(
5230
        const ggml_compute_params * params,
5231
        ggml_tensor * dst,
5232
0
        const float value) {
5233
5234
0
    const ggml_tensor * src0 = dst->src[0];
5235
5236
0
    const int ith = params->ith;
5237
0
    const int nth = params->nth;
5238
5239
0
    const int  n_past  = ((int32_t *) dst->op_params)[0];
5240
0
    const bool inplace = src0->data == dst->data;
5241
5242
0
    GGML_ASSERT(n_past >= 0);
5243
5244
0
    if (!inplace) {
5245
0
        if (ith == 0) {
5246
            // memcpy needs to be synchronized across threads to avoid race conditions.
5247
            // => do it in INIT phase
5248
0
            GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5249
0
            GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
5250
0
            memcpy(
5251
0
                ((char *)  dst->data),
5252
0
                ((char *) src0->data),
5253
0
                ggml_nbytes(dst));
5254
0
        }
5255
0
        ggml_barrier(params->threadpool);
5256
0
    }
5257
5258
    // TODO: handle transposed/permuted matrices
5259
5260
0
    const int n  = ggml_nrows(src0);
5261
0
    const int nc = src0->ne[0];
5262
0
    const int nr = src0->ne[1];
5263
0
    const int nz = n/nr;
5264
5265
0
    GGML_ASSERT( dst->nb[0] == sizeof(float));
5266
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
5267
5268
0
    for (int k = 0; k < nz; k++) {
5269
0
        for (int j = ith; j < nr; j += nth) {
5270
0
            for (int i = n_past; i < nc; i++) {
5271
0
                if (i > n_past + j) {
5272
0
                    *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
5273
0
                }
5274
0
            }
5275
0
        }
5276
0
    }
5277
0
}
5278
5279
void ggml_compute_forward_diag_mask_inf(
5280
        const ggml_compute_params * params,
5281
0
        ggml_tensor * dst) {
5282
5283
0
    const ggml_tensor * src0 = dst->src[0];
5284
5285
0
    switch (src0->type) {
5286
0
        case GGML_TYPE_F32:
5287
0
            {
5288
0
                ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
5289
0
            } break;
5290
0
        default:
5291
0
            {
5292
0
                GGML_ABORT("fatal error");
5293
0
            }
5294
0
    }
5295
0
}
5296
5297
void ggml_compute_forward_diag_mask_zero(
5298
        const ggml_compute_params * params,
5299
0
        ggml_tensor * dst) {
5300
5301
0
    const ggml_tensor * src0 = dst->src[0];
5302
5303
0
    switch (src0->type) {
5304
0
        case GGML_TYPE_F32:
5305
0
            {
5306
0
                ggml_compute_forward_diag_mask_f32(params, dst, 0);
5307
0
            } break;
5308
0
        default:
5309
0
            {
5310
0
                GGML_ABORT("fatal error");
5311
0
            }
5312
0
    }
5313
0
}
5314
5315
// ggml_compute_forward_soft_max
5316
5317
static void ggml_compute_forward_soft_max_f32(
5318
        const ggml_compute_params * params,
5319
0
              ggml_tensor * dst) {
5320
5321
0
    const ggml_tensor * src0 = dst->src[0];
5322
0
    const ggml_tensor * src1 = dst->src[1];
5323
0
    const ggml_tensor * src2 = dst->src[2];
5324
5325
0
    assert(ggml_is_contiguous(dst));
5326
0
    assert(ggml_are_same_shape(src0, dst));
5327
5328
0
    float scale    = 1.0f;
5329
0
    float max_bias = 0.0f;
5330
5331
0
    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
5332
0
    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
5333
5334
0
    const int ith = params->ith;
5335
0
    const int nth = params->nth;
5336
5337
0
    GGML_TENSOR_UNARY_OP_LOCALS
5338
5339
0
    const int64_t nb11 = src1 ? src1->nb[1] : 1;
5340
0
    const int64_t nb12 = src1 ? src1->nb[2] : 1;
5341
0
    const int64_t nb13 = src1 ? src1->nb[3] : 1;
5342
5343
0
    const int64_t ne12 = src1 ? src1->ne[2] : 1;
5344
0
    const int64_t ne13 = src1 ? src1->ne[3] : 1;
5345
5346
    // TODO: is this supposed to be ceil instead of floor?
5347
    //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
5348
0
    const uint32_t n_head      = ne02;
5349
0
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
5350
5351
0
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
5352
0
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
5353
5354
0
    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
5355
5356
0
    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
5357
5358
    // sinks
5359
0
    const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
5360
5361
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
5362
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
5363
0
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
5364
0
                const int64_t i11 = i01;
5365
0
                const int64_t i12 = i02%ne12;
5366
0
                const int64_t i13 = i03%ne13;
5367
5368
                // ALiBi
5369
0
                const uint32_t h = i02; // head
5370
0
                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
5371
5372
0
                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5373
0
                float * dp = (float *)((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3);
5374
5375
                // broadcast the mask across rows
5376
0
                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5377
0
                float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5378
5379
0
                ggml_vec_cpy_f32  (ne00, wp, sp);
5380
0
                ggml_vec_scale_f32(ne00, wp, scale);
5381
0
                if (mp_f32) {
5382
0
                    if (use_f16) {
5383
0
                        for (int i = 0; i < ne00; ++i) {
5384
0
                            wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
5385
0
                        }
5386
0
                    } else {
5387
0
                        for (int i = 0; i < ne00; ++i) {
5388
0
                            wp[i] += slope*mp_f32[i];
5389
0
                        }
5390
0
                    }
5391
0
                }
5392
5393
#ifndef NDEBUG
5394
                for (int i = 0; i < ne00; ++i) {
5395
                    //printf("p[%d] = %f\n", i, p[i]);
5396
                    assert(!isnan(wp[i]));
5397
                }
5398
#endif // NDEBUG
5399
5400
0
                float max = -INFINITY;
5401
0
                ggml_vec_max_f32(ne00, &max, wp);
5402
5403
                // if we have sinks, make a correction as if they were included in the softmax
5404
0
                if (sk) {
5405
0
                    max = MAX(max, sk[i02]);
5406
0
                }
5407
5408
0
                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
5409
0
                assert(sum > 0.0);
5410
5411
0
                if (sk) {
5412
0
                    sum += (ggml_float) expf(sk[i02] - max);
5413
0
                }
5414
5415
0
                sum = 1.0/sum;
5416
0
                ggml_vec_scale_f32(ne00, dp, sum);
5417
5418
#ifndef NDEBUG
5419
                for (int i = 0; i < ne00; ++i) {
5420
                    assert(!isnan(dp[i]));
5421
                    assert(!isinf(dp[i]));
5422
                }
5423
#endif // NDEBUG
5424
0
            }
5425
0
        }
5426
0
    }
5427
0
}
5428
5429
void ggml_compute_forward_soft_max(
5430
        const ggml_compute_params * params,
5431
0
              ggml_tensor * dst) {
5432
5433
0
    const ggml_tensor * src0 = dst->src[0];
5434
5435
0
    switch (src0->type) {
5436
0
        case GGML_TYPE_F32:
5437
0
            {
5438
0
                ggml_compute_forward_soft_max_f32(params, dst);
5439
0
            } break;
5440
0
        default:
5441
0
            {
5442
0
                GGML_ABORT("fatal error");
5443
0
            }
5444
0
    }
5445
0
}
5446
5447
5448
// ggml_compute_forward_soft_max_ext_back
5449
5450
static void ggml_compute_forward_soft_max_ext_back_f32(
5451
        const ggml_compute_params * params,
5452
0
        ggml_tensor * dst) {
5453
5454
0
    const ggml_tensor * src0 = dst->src[0];
5455
0
    const ggml_tensor * src1 = dst->src[1];
5456
5457
0
    GGML_ASSERT(ggml_is_contiguous(src0));
5458
0
    GGML_ASSERT(ggml_is_contiguous(src1));
5459
0
    GGML_ASSERT(ggml_is_contiguous(dst));
5460
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
5461
0
    GGML_ASSERT(ggml_are_same_shape(src1, dst));
5462
5463
0
    float scale    = 1.0f;
5464
0
    float max_bias = 0.0f;
5465
5466
0
    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
5467
0
    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
5468
5469
0
    GGML_ASSERT(max_bias == 0.0f);
5470
5471
    // TODO: handle transposed/permuted matrices
5472
5473
0
    const int ith = params->ith;
5474
0
    const int nth = params->nth;
5475
5476
0
    const int nc = src0->ne[0];
5477
0
    const int nr = ggml_nrows(src0);
5478
5479
    // rows per thread
5480
0
    const int dr = (nr + nth - 1)/nth;
5481
5482
    // row range for this thread
5483
0
    const int ir0 = dr*ith;
5484
0
    const int ir1 = MIN(ir0 + dr, nr);
5485
5486
0
    for (int i1 = ir0; i1 < ir1; i1++) {
5487
0
        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
5488
0
        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
5489
0
        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
5490
5491
#ifndef NDEBUG
5492
        for (int i = 0; i < nc; ++i) {
5493
            //printf("p[%d] = %f\n", i, p[i]);
5494
            assert(!isnan(dy[i]));
5495
            assert(!isnan(y[i]));
5496
        }
5497
#endif // NDEBUG
5498
        // Jii = yi - yi*yi
5499
        // Jij = -yi*yj
5500
        // J = diag(y)-y.T*y
5501
        // dx = J * dy
5502
        // dxk = sum_i(Jki * dyi)
5503
        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
5504
        // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
5505
        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
5506
        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
5507
        // dxk = -yk * dot(y, dy) + yk*dyk
5508
        // dxk = yk * (- dot(y, dy) + dyk)
5509
        // dxk = yk * (dyk - dot(y, dy))
5510
        //
5511
        // post-order:
5512
        // dot_y_dy := dot(y, dy)
5513
        // dx := dy
5514
        // dx := dx - dot_y_dy
5515
        // dx := dx * y
5516
5517
        // linear runtime, no additional memory
5518
0
        float dot_y_dy = 0;
5519
0
        ggml_vec_dot_f32  (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
5520
0
        ggml_vec_cpy_f32  (nc, dx, dy);
5521
0
        ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
5522
0
        ggml_vec_mul_f32  (nc, dx, dx, y);
5523
0
        ggml_vec_scale_f32(nc, dx, scale);
5524
5525
#ifndef NDEBUG
5526
        for (int i = 0; i < nc; ++i) {
5527
            assert(!isnan(dx[i]));
5528
            assert(!isinf(dx[i]));
5529
        }
5530
#endif // NDEBUG
5531
0
    }
5532
0
}
5533
5534
void ggml_compute_forward_soft_max_ext_back(
5535
        const ggml_compute_params * params,
5536
0
        ggml_tensor * dst) {
5537
5538
0
    const ggml_tensor * src0 = dst->src[0];
5539
5540
0
    switch (src0->type) {
5541
0
        case GGML_TYPE_F32:
5542
0
            {
5543
0
                ggml_compute_forward_soft_max_ext_back_f32(params, dst);
5544
0
            } break;
5545
0
        default:
5546
0
            {
5547
0
                GGML_ABORT("fatal error");
5548
0
            }
5549
0
    }
5550
0
}
5551
5552
// ggml_compute_forward_clamp
5553
5554
static void ggml_compute_forward_clamp_f32(
5555
        const ggml_compute_params * params,
5556
0
        ggml_tensor * dst) {
5557
5558
0
    const ggml_tensor * src0 = dst->src[0];
5559
5560
0
    float min;
5561
0
    float max;
5562
0
    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
5563
0
    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
5564
5565
0
    const int ith = params->ith;
5566
0
    const int nth = params->nth;
5567
5568
0
    const int n  = ggml_nrows(src0);
5569
0
    const int nc = src0->ne[0];
5570
5571
0
    const size_t nb00 = src0->nb[0];
5572
0
    const size_t nb01 = src0->nb[1];
5573
5574
0
    const size_t nb0 = dst->nb[0];
5575
0
    const size_t nb1 = dst->nb[1];
5576
5577
0
    GGML_ASSERT( nb0 == sizeof(float));
5578
0
    GGML_ASSERT(nb00 == sizeof(float));
5579
5580
0
    for (int j = ith; j < n; j += nth) {
5581
0
        float * dst_ptr  = (float *) ((char *)  dst->data + j*nb1);
5582
0
        float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
5583
5584
0
        for (int i = 0; i < nc; i++) {
5585
0
            dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
5586
0
        }
5587
0
    }
5588
0
}
5589
5590
static void ggml_compute_forward_clamp_f16(
5591
    const ggml_compute_params * params,
5592
0
    ggml_tensor * dst) {
5593
5594
0
    const ggml_tensor * src0 = dst->src[0];
5595
5596
0
    float min;
5597
0
    float max;
5598
0
    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
5599
0
    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
5600
5601
0
    const int ith = params->ith;
5602
0
    const int nth = params->nth;
5603
5604
0
    const int n  = ggml_nrows(src0);
5605
0
    const int nc = src0->ne[0];
5606
5607
0
    const size_t nb00 = src0->nb[0];
5608
0
    const size_t nb01 = src0->nb[1];
5609
5610
0
    const size_t nb0 = dst->nb[0];
5611
0
    const size_t nb1 = dst->nb[1];
5612
5613
0
    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
5614
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
5615
5616
0
    for (int j = ith; j < n; j += nth) {
5617
0
        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *)  dst->data + j*nb1);
5618
0
        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
5619
5620
0
        for (int i = 0; i < nc; i++) {
5621
0
            float v = GGML_CPU_FP16_TO_FP32(src0_ptr[i]);
5622
0
            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(MAX(MIN(v, max), min));
5623
0
        }
5624
0
    }
5625
0
}
5626
5627
void ggml_compute_forward_clamp(
5628
        const ggml_compute_params * params,
5629
0
        ggml_tensor * dst) {
5630
5631
0
    const ggml_tensor * src0 = dst->src[0];
5632
5633
0
    switch (src0->type) {
5634
0
        case GGML_TYPE_F32:
5635
0
            {
5636
0
                ggml_compute_forward_clamp_f32(params, dst);
5637
0
            } break;
5638
0
        case GGML_TYPE_F16:
5639
0
            {
5640
0
                ggml_compute_forward_clamp_f16(params, dst);
5641
0
            } break;
5642
0
        case GGML_TYPE_BF16:
5643
0
        case GGML_TYPE_Q1_0:
5644
0
        case GGML_TYPE_Q4_0:
5645
0
        case GGML_TYPE_Q4_1:
5646
0
        case GGML_TYPE_Q5_0:
5647
0
        case GGML_TYPE_Q5_1:
5648
0
        case GGML_TYPE_Q8_0:
5649
0
        case GGML_TYPE_Q8_1:
5650
0
        case GGML_TYPE_MXFP4:
5651
0
        case GGML_TYPE_NVFP4:
5652
0
        case GGML_TYPE_Q2_K:
5653
0
        case GGML_TYPE_Q3_K:
5654
0
        case GGML_TYPE_Q4_K:
5655
0
        case GGML_TYPE_Q5_K:
5656
0
        case GGML_TYPE_Q6_K:
5657
0
        case GGML_TYPE_TQ1_0:
5658
0
        case GGML_TYPE_TQ2_0:
5659
0
        case GGML_TYPE_IQ2_XXS:
5660
0
        case GGML_TYPE_IQ2_XS:
5661
0
        case GGML_TYPE_IQ3_XXS:
5662
0
        case GGML_TYPE_IQ1_S:
5663
0
        case GGML_TYPE_IQ1_M:
5664
0
        case GGML_TYPE_IQ4_NL:
5665
0
        case GGML_TYPE_IQ4_XS:
5666
0
        case GGML_TYPE_IQ3_S:
5667
0
        case GGML_TYPE_IQ2_S:
5668
0
        case GGML_TYPE_Q8_K:
5669
0
        case GGML_TYPE_I8:
5670
0
        case GGML_TYPE_I16:
5671
0
        case GGML_TYPE_I32:
5672
0
        case GGML_TYPE_I64:
5673
0
        case GGML_TYPE_F64:
5674
0
        case GGML_TYPE_COUNT:
5675
0
            {
5676
0
                GGML_ABORT("fatal error");
5677
0
            }
5678
0
    }
5679
0
}
5680
5681
// ggml_compute_forward_rope
5682
5683
0
// Linear ramp used to blend interpolated and extrapolated rotations.
// Returns 1 for pair indices (i0/2) at or below `low`, 0 at or above `high`,
// and falls linearly in between. The width of the ramp is never taken
// smaller than 0.001 so that the division stays finite.
static float rope_yarn_ramp(const float low, const float high, const int i0) {
    const float span  = high - low;
    const float width = 0.001f > span ? 0.001f : span;

    // position of this rotation pair relative to the ramp
    const float y = (i0 / 2 - low) / width;

    // clamp to [0, 1]
    const float y_lo = 0 > y ? 0 : y;
    const float y_01 = 1 < y_lo ? 1 : y_lo;

    return 1 - y_01;
}
5687
5688
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
5689
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
5690
static void rope_yarn(
5691
    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
5692
0
    float * cos_theta, float * sin_theta) {
5693
    // Get n-d rotational scaling corrected for extrapolation
5694
0
    float theta_interp = freq_scale * theta_extrap;
5695
0
    float theta = theta_interp;
5696
0
    if (ext_factor != 0.0f) {
5697
0
        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
5698
0
        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
5699
5700
        // Get n-d magnitude scaling corrected for interpolation
5701
0
        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
5702
0
    }
5703
0
    *cos_theta = cosf(theta) * mscale;
5704
0
    *sin_theta = sinf(theta) * mscale;
5705
0
}
5706
5707
static void ggml_rope_cache_init(
5708
     float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
5709
0
     float * cache, float sin_sign, float theta_scale) {
5710
    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
5711
0
    float theta = theta_base;
5712
0
    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
5713
0
        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
5714
0
        rope_yarn(
5715
0
            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
5716
0
        );
5717
0
        cache[i0 + 1] *= sin_sign;
5718
5719
0
        theta *= theta_scale;
5720
0
    }
5721
0
}
5722
5723
static void ggml_mrope_cache_init(
5724
     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
5725
     float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
5726
0
     float * cache, float sin_sign, float theta_scale) {
5727
    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
5728
0
    float theta_t = theta_base_t;
5729
0
    float theta_h = theta_base_h;
5730
0
    float theta_w = theta_base_w;
5731
0
    float theta_e = theta_base_e;  // extra position id for vision encoder
5732
0
    int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
5733
0
    int sec_w = sections[1] + sections[0];
5734
0
    int sec_e = sections[2] + sec_w;
5735
0
    GGML_ASSERT(sect_dims <= ne0);
5736
5737
0
    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
5738
0
        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
5739
5740
0
        int sector = (i0 / 2) % sect_dims;
5741
0
        if (indep_sects) {
5742
            // compute theta independently for each dim sections
5743
            // (i.e. reset corresponding theta when `i0` go from one section to another)
5744
0
            if (sector == 0) {
5745
0
                theta_t = theta_base_t;
5746
0
            }
5747
0
            else if (sector == sections[0]) {
5748
0
                theta_h = theta_base_h;;
5749
0
            }
5750
0
            else if (sector == sec_w) {
5751
0
                theta_w = theta_base_w;
5752
0
            }
5753
0
            else if (sector == sec_e) {
5754
0
                theta_e = theta_base_e;
5755
0
            }
5756
0
        }
5757
5758
0
        float theta = theta_t;
5759
0
        if (is_imrope) { // qwen3vl apply interleaved mrope
5760
0
            if (sector % 3 == 1 && sector < 3 * sections[1]) {
5761
0
                theta = theta_h;
5762
0
            } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
5763
0
                theta = theta_w;
5764
0
            } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
5765
0
                theta = theta_t;
5766
0
            } else {
5767
0
                theta = theta_e;
5768
0
            }
5769
0
        } else {
5770
0
            if (sector >= sections[0] && sector < sec_w) {
5771
0
                theta = theta_h;
5772
0
            }
5773
0
            else if (sector >= sec_w && sector < sec_w + sections[2]) {
5774
0
                theta = theta_w;
5775
0
            }
5776
0
            else if (sector >= sec_w + sections[2]) {
5777
0
                theta = theta_e;
5778
0
            }
5779
0
        }
5780
5781
0
        rope_yarn(
5782
0
            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
5783
0
        );
5784
0
        cache[i0 + 1] *= sin_sign;
5785
5786
0
        theta_t *= theta_scale;
5787
0
        theta_w *= theta_scale;
5788
0
        theta_h *= theta_scale;
5789
0
        theta_e *= theta_scale;
5790
0
    }
5791
0
}
5792
5793
5794
template<typename T>
5795
0
static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
5796
0
  for (int64_t i0 = 0; i0 < n; i0 += 2) {
5797
0
    const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
5798
5799
0
    const float cos_theta = cache[i0 + 0];
5800
0
    const float sin_theta = cache[i0 + 1];
5801
5802
0
    const T * const src = src_data + ic;
5803
0
    T * dst             = dst_data + ic;
5804
5805
0
    const float x0 = type_conversion_table<T>::to_f32(src[0]);
5806
0
    const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
5807
5808
0
    dst[0]        = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
5809
0
    dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
5810
0
  }
5811
0
}
Unexecuted instantiation: ops.cpp:void rotate_pairs<unsigned short>(long, long, float const*, unsigned short const*, unsigned short*, int)
Unexecuted instantiation: ops.cpp:void rotate_pairs<float>(long, long, float const*, float const*, float*, int)
5812
5813
template<typename T> //float or ggml_fp16_t
5814
static void ggml_compute_forward_rope_flt(
5815
        const ggml_compute_params * params,
5816
        ggml_tensor * dst,
5817
0
        const bool forward) {
5818
5819
0
    const ggml_tensor * src0 = dst->src[0];
5820
0
    const ggml_tensor * src1 = dst->src[1];
5821
0
    const ggml_tensor * src2 = dst->src[2];
5822
5823
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
5824
0
    GGML_ASSERT(src1->type == GGML_TYPE_I32);
5825
5826
0
    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5827
0
    int sections[4];
5828
5829
    //const int n_past     = ((int32_t *) dst->op_params)[0];
5830
0
    const int n_dims     = ((int32_t *) dst->op_params)[1];
5831
0
    const int mode       = ((int32_t *) dst->op_params)[2];
5832
    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
5833
0
    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
5834
5835
0
    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
5836
0
    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
5837
0
    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
5838
0
    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
5839
0
    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
5840
0
    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
5841
0
    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
5842
5843
0
    GGML_TENSOR_UNARY_OP_LOCALS
5844
5845
    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5846
    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5847
5848
0
    GGML_ASSERT(nb0 == nb00);
5849
0
    GGML_ASSERT(nb0 == sizeof(T));
5850
5851
0
    const int ith = params->ith;
5852
0
    const int nth = params->nth;
5853
5854
0
    const int nr = ggml_nrows(dst);
5855
5856
0
    GGML_ASSERT(n_dims <= ne0);
5857
0
    GGML_ASSERT(n_dims % 2 == 0);
5858
5859
    // rows per thread
5860
0
    const int dr = (nr + nth - 1)/nth;
5861
5862
    // row range for this thread
5863
0
    const int ir0 = dr*ith;
5864
0
    const int ir1 = MIN(ir0 + dr, nr);
5865
5866
    // row index used to determine which thread to use
5867
0
    int ir = 0;
5868
5869
0
    const float theta_scale = powf(freq_base, -2.0f/n_dims);
5870
5871
0
    float corr_dims[2];
5872
0
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
5873
5874
0
    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
5875
0
    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
5876
0
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
5877
5878
0
    if (mrope_used) {
5879
0
        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
5880
0
    }
5881
5882
0
    if (is_vision) {
5883
0
        GGML_ASSERT(n_dims == ne0/2);
5884
0
    }
5885
5886
0
    const float * freq_factors = NULL;
5887
0
    if (src2 != NULL) {
5888
0
        GGML_ASSERT(src2->type == GGML_TYPE_F32);
5889
0
        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
5890
0
        freq_factors = (const float *) src2->data;
5891
0
    }
5892
5893
    // backward process uses inverse rotation by cos and sin.
5894
    // cos and sin build a rotation matrix, where the inverse is the transpose.
5895
    // this essentially just switches the sign of sin.
5896
0
    const float sin_sign = forward ? 1.0f : -1.0f;
5897
5898
0
    const int32_t * pos = (const int32_t *) src1->data;
5899
5900
0
    int64_t last_i2 = -1;
5901
5902
0
    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
5903
0
        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
5904
0
            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
5905
0
                if (ir++ < ir0) continue; // skip rows mapped to other threads
5906
0
                if (ir   > ir1) break;
5907
5908
0
                float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5909
0
                if (last_i2 != i2) {
5910
0
                    if (!mrope_used) {
5911
0
                        const int64_t p = pos[i2];
5912
0
                        ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5913
0
                    }
5914
0
                    else {
5915
0
                        const int64_t p_t = pos[i2];
5916
0
                        const int64_t p_h = pos[i2 + ne2];
5917
0
                        const int64_t p_w = pos[i2 + ne2 * 2];
5918
0
                        const int64_t p_e = pos[i2 + ne2 * 3];
5919
0
                        ggml_mrope_cache_init(
5920
0
                            p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
5921
0
                            freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5922
0
                    }
5923
5924
0
                    last_i2 = i2;
5925
0
                }
5926
5927
0
                T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5928
0
                T * dst_data  = (T *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1);
5929
5930
0
                switch (mode) {
5931
0
                    case GGML_ROPE_TYPE_NORMAL:
5932
0
                        rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
5933
0
                        break;
5934
0
                    case GGML_ROPE_TYPE_NEOX:
5935
0
                    case GGML_ROPE_TYPE_MROPE:
5936
0
                    case GGML_ROPE_TYPE_IMROPE:
5937
0
                        rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
5938
0
                        break;
5939
0
                    case GGML_ROPE_TYPE_VISION:
5940
0
                        rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
5941
0
                        break;
5942
0
                    default:
5943
0
                        GGML_ABORT("rope type not supported");
5944
0
                }
5945
5946
0
                if (!is_vision) {
5947
                    // fill the remain channels with data from src tensor
5948
0
                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5949
0
                        const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5950
0
                        T * dst_data  = (T *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
5951
5952
0
                        dst_data[0] = src[0];
5953
0
                        dst_data[1] = src[1];
5954
0
                    }
5955
0
                }
5956
0
            } //attn-heads
5957
0
        }
5958
0
    }
5959
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_rope_flt<unsigned short>(ggml_compute_params const*, ggml_tensor*, bool)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_rope_flt<float>(ggml_compute_params const*, ggml_tensor*, bool)
5960
5961
void ggml_compute_forward_rope(
5962
        const ggml_compute_params * params,
5963
0
        ggml_tensor * dst) {
5964
5965
0
    const ggml_tensor * src0 = dst->src[0];
5966
5967
0
    switch (src0->type) {
5968
0
        case GGML_TYPE_F16:
5969
0
            {
5970
0
                ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
5971
0
            } break;
5972
0
        case GGML_TYPE_F32:
5973
0
            {
5974
0
                ggml_compute_forward_rope_flt<float>(params, dst, true);
5975
0
            } break;
5976
0
        default:
5977
0
            {
5978
0
                GGML_ABORT("fatal error");
5979
0
            }
5980
0
    }
5981
0
}
5982
5983
// ggml_compute_forward_rope_back
5984
5985
void ggml_compute_forward_rope_back(
5986
        const ggml_compute_params * params,
5987
0
        ggml_tensor * dst) {
5988
5989
0
    const ggml_tensor * src0 = dst->src[0];
5990
5991
0
    switch (src0->type) {
5992
0
        case GGML_TYPE_F16:
5993
0
            {
5994
0
                ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
5995
0
            } break;
5996
0
        case GGML_TYPE_F32:
5997
0
            {
5998
0
                ggml_compute_forward_rope_flt<float>(params, dst, false);
5999
0
            } break;
6000
0
        default:
6001
0
            {
6002
0
                GGML_ABORT("fatal error");
6003
0
            }
6004
0
    }
6005
0
}
6006
6007
// ggml_compute_forward_conv_transpose_1d
6008
6009
static void ggml_compute_forward_conv_transpose_1d_f16_f32(
6010
        const ggml_compute_params * params,
6011
0
              ggml_tensor * dst) {
6012
6013
0
    const ggml_tensor * src0 = dst->src[0];
6014
0
    const ggml_tensor * src1 = dst->src[1];
6015
6016
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
6017
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6018
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
6019
6020
0
    GGML_TENSOR_BINARY_OP_LOCALS
6021
6022
0
    const int ith = params->ith;
6023
0
    const int nth = params->nth;
6024
6025
0
    const int nk = ne00*ne01*ne02;
6026
6027
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6028
0
    GGML_ASSERT(nb10 == sizeof(float));
6029
6030
0
    if (ith == 0) {
6031
0
        memset(params->wdata, 0, params->wsize);
6032
6033
        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
6034
0
        {
6035
0
            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
6036
6037
0
            for (int64_t i02 = 0; i02 < ne02; i02++) {
6038
0
                for (int64_t i01 = 0; i01 < ne01; i01++) {
6039
0
                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
6040
0
                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
6041
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
6042
0
                        dst_data[i00*ne02 + i02] = src[i00];
6043
0
                    }
6044
0
                }
6045
0
            }
6046
0
        }
6047
6048
        // permute source data (src1) from (L x Cin) to (Cin x L)
6049
0
        {
6050
0
            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
6051
0
            ggml_fp16_t * dst_data = wdata;
6052
6053
0
            for (int64_t i11 = 0; i11 < ne11; i11++) {
6054
0
                const float * const src = (float *)((char *) src1->data + i11*nb11);
6055
0
                for (int64_t i10 = 0; i10 < ne10; i10++) {
6056
0
                    dst_data[i10*ne11 + i11] = GGML_CPU_FP32_TO_FP16(src[i10]);
6057
0
                }
6058
0
            }
6059
0
        }
6060
6061
        // need to zero dst since we are accumulating into it
6062
0
        memset(dst->data, 0, ggml_nbytes(dst));
6063
0
    }
6064
0
    ggml_barrier(params->threadpool);
6065
6066
0
    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
6067
6068
    // total rows in dst
6069
0
    const int nr = ne1;
6070
6071
    // rows per thread
6072
0
    const int dr = (nr + nth - 1)/nth;
6073
6074
    // row range for this thread
6075
0
    const int ir0 = dr*ith;
6076
0
    const int ir1 = MIN(ir0 + dr, nr);
6077
6078
0
    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
6079
0
    ggml_fp16_t * const wdata_src = wdata + nk;
6080
6081
0
    for (int i1 = ir0; i1 < ir1; i1++) {
6082
0
        float * dst_data = (float *)((char *) dst->data + i1*nb1);
6083
0
        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
6084
0
        for (int i10 = 0; i10 < ne10; i10++) {
6085
0
            const int i1n = i10*ne11;
6086
0
            for (int i00 = 0; i00 < ne00; i00++) {
6087
0
                float v = 0;
6088
0
                ggml_vec_dot_f16(ne02, &v, 0,
6089
0
                        (ggml_fp16_t *)    wdata_src + i1n, 0,
6090
0
                        (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
6091
0
                dst_data[i10*s0 + i00] += v;
6092
0
            }
6093
0
        }
6094
0
    }
6095
0
}
6096
6097
static void ggml_compute_forward_conv_transpose_1d_f32(
6098
        const ggml_compute_params * params,
6099
0
              ggml_tensor * dst) {
6100
6101
0
    const ggml_tensor * src0 = dst->src[0];
6102
0
    const ggml_tensor * src1 = dst->src[1];
6103
6104
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
6105
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6106
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
6107
6108
0
    GGML_TENSOR_BINARY_OP_LOCALS
6109
6110
0
    const int ith = params->ith;
6111
0
    const int nth = params->nth;
6112
6113
0
    const int nk = ne00*ne01*ne02;
6114
6115
0
    GGML_ASSERT(nb00 == sizeof(float));
6116
0
    GGML_ASSERT(nb10 == sizeof(float));
6117
6118
0
    if (ith == 0) {
6119
0
        memset(params->wdata, 0, params->wsize);
6120
6121
        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
6122
0
        {
6123
0
            float * const wdata = (float *) params->wdata + 0;
6124
6125
0
            for (int64_t i02 = 0; i02 < ne02; i02++) {
6126
0
                for (int64_t i01 = 0; i01 < ne01; i01++) {
6127
0
                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
6128
0
                    float * dst_data = wdata + i01*ne00*ne02;
6129
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
6130
0
                        dst_data[i00*ne02 + i02] = src[i00];
6131
0
                    }
6132
0
                }
6133
0
            }
6134
0
        }
6135
6136
        // prepare source data (src1)
6137
0
        {
6138
0
            float * const wdata = (float *) params->wdata + nk;
6139
0
            float * dst_data = wdata;
6140
6141
0
            for (int64_t i11 = 0; i11 < ne11; i11++) {
6142
0
                const float * const src = (float *)((char *) src1->data + i11*nb11);
6143
0
                for (int64_t i10 = 0; i10 < ne10; i10++) {
6144
0
                    dst_data[i10*ne11 + i11] = src[i10];
6145
0
                }
6146
0
            }
6147
0
        }
6148
6149
        // need to zero dst since we are accumulating into it
6150
0
        memset(dst->data, 0, ggml_nbytes(dst));
6151
0
    }
6152
0
    ggml_barrier(params->threadpool);
6153
6154
0
    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
6155
6156
    // total rows in dst
6157
0
    const int nr = ne1;
6158
6159
    // rows per thread
6160
0
    const int dr = (nr + nth - 1)/nth;
6161
6162
    // row range for this thread
6163
0
    const int ir0 = dr*ith;
6164
0
    const int ir1 = MIN(ir0 + dr, nr);
6165
6166
0
    float * const wdata     = (float *) params->wdata + 0;
6167
0
    float * const wdata_src = wdata + nk;
6168
6169
0
    for (int i1 = ir0; i1 < ir1; i1++) {
6170
0
        float * dst_data = (float *)((char *) dst->data + i1*nb1);
6171
0
        float * wdata_kernel = wdata + i1*ne02*ne00;
6172
0
        for (int i10 = 0; i10 < ne10; i10++) {
6173
0
            const int i1n = i10*ne11;
6174
0
            for (int i00 = 0; i00 < ne00; i00++) {
6175
0
                float v = 0;
6176
0
                ggml_vec_dot_f32(ne02, &v, 0,
6177
0
                        wdata_src + i1n, 0,
6178
0
                        wdata_kernel + i00*ne02, 0, 1);
6179
0
                dst_data[i10*s0 + i00] += v;
6180
0
            }
6181
0
        }
6182
0
    }
6183
0
}
6184
6185
void ggml_compute_forward_conv_transpose_1d(
6186
        const ggml_compute_params * params,
6187
0
              ggml_tensor * dst) {
6188
6189
0
    const ggml_tensor * src0 = dst->src[0];
6190
6191
0
    switch (src0->type) {
6192
0
        case GGML_TYPE_F16:
6193
0
            {
6194
0
                ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
6195
0
            } break;
6196
0
        case GGML_TYPE_F32:
6197
0
            {
6198
0
                ggml_compute_forward_conv_transpose_1d_f32(params, dst);
6199
0
            } break;
6200
0
        default:
6201
0
            {
6202
0
                GGML_ABORT("fatal error");
6203
0
            }
6204
0
    }
6205
0
}
6206
6207
// ggml_compute_forward_im2col_f32
6208
// src0: kernel [OC, IC, KH, KW]
6209
// src1: image [N, IC, IH, IW]
6210
// dst:  result [N, OH, OW, IC*KH*KW]
6211
static void ggml_compute_forward_im2col_f32(
6212
        const ggml_compute_params * params,
6213
0
              ggml_tensor * dst) {
6214
6215
0
    const ggml_tensor * src0 = dst->src[0];
6216
0
    const ggml_tensor * src1 = dst->src[1];
6217
6218
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6219
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
6220
6221
0
    GGML_TENSOR_BINARY_OP_LOCALS;
6222
6223
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6224
0
    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6225
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6226
0
    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
6227
0
    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
6228
0
    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
6229
0
    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
6230
6231
0
    const int ith = params->ith;
6232
0
    const int nth = params->nth;
6233
6234
0
    const int64_t N  = is_2D ? ne13 : ne12;
6235
0
    const int64_t IC = is_2D ? ne12 : ne11;
6236
0
    const int64_t IH = is_2D ? ne11 : 1;
6237
0
    const int64_t IW = ne10;
6238
6239
0
    const int64_t KH = is_2D ? ne01 : 1;
6240
0
    const int64_t KW = ne00;
6241
6242
0
    const int64_t OH = is_2D ? ne2 : 1;
6243
0
    const int64_t OW = ne1;
6244
6245
0
    int ofs0 = is_2D ? nb13 : nb12;
6246
0
    int ofs1 = is_2D ? nb12 : nb11;
6247
6248
0
    GGML_ASSERT(nb10 == sizeof(float));
6249
6250
    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6251
0
    {
6252
0
        float * const wdata = (float *) dst->data;
6253
6254
0
        for (int64_t in = 0; in < N; in++) {
6255
0
            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
6256
0
                for (int64_t iow = 0; iow < OW; iow++) {
6257
0
                    for (int64_t iic = ith; iic < IC; iic += nth) {
6258
6259
                        // micro kernel
6260
0
                        float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6261
0
                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
6262
6263
0
                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
6264
0
                            for (int64_t ikw = 0; ikw < KW; ikw++) {
6265
0
                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
6266
0
                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
6267
6268
0
                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6269
0
                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
6270
0
                                } else {
6271
0
                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
6272
0
                                }
6273
0
                            }
6274
0
                        }
6275
0
                    }
6276
0
                }
6277
0
            }
6278
0
        }
6279
0
    }
6280
0
}
6281
6282
6283
// ggml_compute_forward_im2col_f16
6284
// src0: kernel [OC, IC, KH, KW]
6285
// src1: image [N, IC, IH, IW]
6286
// dst:  result [N, OH, OW, IC*KH*KW]
6287
static void ggml_compute_forward_im2col_f16(
6288
        const ggml_compute_params * params,
6289
0
              ggml_tensor * dst) {
6290
6291
0
    const ggml_tensor * src0 = dst->src[0];
6292
0
    const ggml_tensor * src1 = dst->src[1];
6293
6294
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
6295
0
    GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
6296
0
    GGML_ASSERT( dst->type == GGML_TYPE_F16);
6297
6298
0
    GGML_TENSOR_BINARY_OP_LOCALS;
6299
6300
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6301
0
    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6302
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6303
0
    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
6304
0
    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
6305
0
    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
6306
0
    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
6307
6308
0
    const int ith = params->ith;
6309
0
    const int nth = params->nth;
6310
6311
0
    const int64_t N  = is_2D ? ne13 : ne12;
6312
0
    const int64_t IC = is_2D ? ne12 : ne11;
6313
0
    const int64_t IH = is_2D ? ne11 : 1;
6314
0
    const int64_t IW = ne10;
6315
6316
0
    const int64_t KH = is_2D ? ne01 : 1;
6317
0
    const int64_t KW = ne00;
6318
6319
0
    const int64_t OH = is_2D ? ne2 : 1;
6320
0
    const int64_t OW = ne1;
6321
6322
0
    int ofs0 = is_2D ? nb13 : nb12;
6323
0
    int ofs1 = is_2D ? nb12 : nb11;
6324
6325
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6326
0
    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
6327
6328
    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6329
0
    {
6330
0
        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
6331
6332
0
        for (int64_t in = 0; in < N; in++) {
6333
0
            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
6334
0
                for (int64_t iow = 0; iow < OW; iow++) {
6335
0
                    for (int64_t iic = ith; iic < IC; iic += nth) {
6336
6337
                        // micro kernel
6338
0
                        ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6339
0
                        const float * const src_data_f32 = src1->type == GGML_TYPE_F32
6340
0
                            ? (const float *)((const char *) src1->data + in*ofs0 + iic*ofs1)
6341
0
                            : nullptr; // [IH, IW]
6342
0
                        const ggml_fp16_t * const src_data_f16 = src1->type == GGML_TYPE_F16
6343
0
                            ? (const ggml_fp16_t *)((const char *) src1->data + in*ofs0 + iic*ofs1)
6344
0
                            : nullptr; // [IH, IW]
6345
6346
0
                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
6347
0
                            for (int64_t ikw = 0; ikw < KW; ikw++) {
6348
0
                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
6349
0
                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
6350
6351
0
                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6352
0
                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
6353
0
                                } else {
6354
0
                                    if (src_data_f32 != nullptr) {
6355
0
                                        dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data_f32[iih*IW + iiw]);
6356
0
                                    } else {
6357
0
                                        dst_data[iic*(KH*KW) + ikh*KW + ikw] = src_data_f16[iih*IW + iiw];
6358
0
                                    }
6359
0
                                }
6360
0
                            }
6361
0
                        }
6362
0
                    }
6363
0
                }
6364
0
            }
6365
0
        }
6366
0
    }
6367
0
}
6368
6369
void ggml_compute_forward_im2col(
6370
        const ggml_compute_params * params,
6371
0
              ggml_tensor * dst) {
6372
0
    switch (dst->type) {
6373
0
        case GGML_TYPE_F16:
6374
0
            {
6375
0
                ggml_compute_forward_im2col_f16(params, dst);
6376
0
            } break;
6377
0
        case GGML_TYPE_F32:
6378
0
            {
6379
0
                ggml_compute_forward_im2col_f32(params, dst);
6380
0
            } break;
6381
0
        default:
6382
0
            {
6383
0
                GGML_ABORT("fatal error");
6384
0
            }
6385
0
    }
6386
0
}
6387
6388
// ggml_compute_forward_im2col_back_f32
6389
6390
void ggml_compute_forward_im2col_back_f32(
6391
        const ggml_compute_params * params,
6392
0
              ggml_tensor * dst) {
6393
6394
0
    const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
6395
0
    const ggml_tensor * src1 = dst->src[1]; // convolution kernel
6396
6397
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
6398
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6399
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
6400
6401
0
    GGML_TENSOR_BINARY_OP_LOCALS;
6402
6403
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6404
0
    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6405
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6406
0
    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
6407
0
    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
6408
0
    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
6409
0
    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
6410
6411
0
    const int ith = params->ith;
6412
0
    const int nth = params->nth;
6413
6414
0
    const int64_t N  = is_2D ? ne3 : ne2;
6415
0
    const int64_t IC = is_2D ? ne2 : ne1;
6416
0
    const int64_t IH = is_2D ? ne1 : 1;
6417
0
    const int64_t IW = ne0;
6418
6419
0
    const int64_t KH = is_2D ? ne11 : 1;
6420
0
    const int64_t KW = ne10;
6421
6422
0
    const int64_t OH = is_2D ? ne02 : 1;
6423
0
    const int64_t OW = ne01;
6424
6425
0
    int ofs0 = is_2D ? nb3 : nb2;
6426
0
    int ofs1 = is_2D ? nb2 : nb1;
6427
6428
0
    GGML_ASSERT(nb0  == sizeof(float));
6429
6430
    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6431
0
    {
6432
0
        float * const wdata = (float *) dst->data;
6433
6434
0
        for (int64_t in = 0; in < N; in++) {
6435
0
            for (int64_t iic = ith; iic < IC; iic += nth) {
6436
0
                for (int64_t iih = 0; iih < IH; iih++) {
6437
0
                    for (int64_t iiw = 0; iiw < IW; iiw++) {
6438
6439
                        // micro kernel
6440
0
                        float grad = 0.0f;
6441
0
                        for (int64_t ikh = 0; ikh < KH; ikh++) {
6442
0
                            for (int64_t ikw = 0; ikw < KW; ikw++) {
6443
                                // For s0 > 1 some values were skipped over in the forward pass.
6444
                                // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
6445
0
                                const int64_t tmpw = (iiw + p0 - ikw*d0);
6446
0
                                if (tmpw % s0 != 0) {
6447
0
                                    continue;
6448
0
                                }
6449
0
                                const int64_t iow = tmpw / s0;
6450
6451
                                // Equivalent logic as above except for s1.
6452
0
                                int64_t ioh;
6453
0
                                if (is_2D) {
6454
0
                                    const int64_t tmph = iih + p1 - ikh*d1;
6455
6456
0
                                    if (tmph % s1 != 0) {
6457
0
                                        continue;
6458
0
                                    }
6459
6460
0
                                    ioh = tmph / s1;
6461
0
                                } else {
6462
0
                                    ioh = 0;
6463
0
                                }
6464
6465
0
                                if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
6466
0
                                    continue;
6467
0
                                }
6468
6469
0
                                const float * const grad_in = (const float *) src0->data
6470
0
                                    + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6471
0
                                grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
6472
0
                            }
6473
0
                        }
6474
0
                        float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
6475
0
                        dst_data[iih*IW + iiw] = grad;
6476
0
                    }
6477
0
                }
6478
0
            }
6479
0
        }
6480
0
    }
6481
0
}
6482
6483
6484
// ggml_compute_forward_im2col_3d_f16
6485
// src0: kernel [OC*IC, KD, KH, KW]
6486
// src1: image [N*IC, ID, IH, IW]
6487
// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
6488
static void ggml_compute_forward_im2col_3d_f16(
6489
        const ggml_compute_params * params,
6490
0
              ggml_tensor * dst) {
6491
6492
0
    const ggml_tensor * src0 = dst->src[0];
6493
0
    const ggml_tensor * src1 = dst->src[1];
6494
6495
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
6496
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6497
0
    GGML_ASSERT( dst->type == GGML_TYPE_F16);
6498
6499
0
    GGML_TENSOR_BINARY_OP_LOCALS;
6500
6501
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6502
0
    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6503
0
    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
6504
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
6505
0
    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
6506
0
    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
6507
0
    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
6508
0
    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
6509
0
    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
6510
0
    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
6511
6512
6513
0
    const int ith = params->ith;
6514
0
    const int nth = params->nth;
6515
6516
0
    const int64_t N  = ne13 / IC;
6517
0
    const int64_t ID = ne12;
6518
0
    const int64_t IH = ne11;
6519
0
    const int64_t IW = ne10;
6520
6521
0
    const int64_t OC = ne03 / IC;
6522
0
    GGML_UNUSED(OC);
6523
0
    const int64_t KD = ne02;
6524
0
    const int64_t KH = ne01;
6525
0
    const int64_t KW = ne00;
6526
6527
0
    const int64_t OD = ne3 / N;
6528
0
    const int64_t OH = ne2;
6529
0
    const int64_t OW = ne1;
6530
0
    const int64_t OH_OW = OH*OW;
6531
0
    const int64_t KD_KH_KW = KD*KH*KW;
6532
0
    const int64_t KH_KW = KH*KW;
6533
0
    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
6534
6535
0
    GGML_ASSERT(nb10 == sizeof(float));
6536
6537
    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
6538
0
    {
6539
0
        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
6540
6541
0
        for (int64_t in = 0; in < N; in++) {
6542
0
            for (int64_t iod = 0; iod < OD; iod++) {
6543
0
                for (int64_t ioh = 0; ioh < OH; ioh++) {
6544
0
                    for (int64_t iow = 0; iow < OW; iow++) {
6545
0
                        for (int64_t iic = ith; iic < IC; iic += nth) {
6546
6547
                            // micro kernel
6548
0
                            ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
6549
0
                            const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
6550
6551
0
                            for (int64_t ikd = 0; ikd < KD; ikd++) {
6552
0
                                for (int64_t ikh = 0; ikh < KH; ikh++) {
6553
0
                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
6554
0
                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
6555
0
                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
6556
0
                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
6557
6558
0
                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6559
0
                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
6560
0
                                        } else {
6561
0
                                            const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
6562
0
                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
6563
0
                                        }
6564
0
                                    }
6565
0
                                }
6566
0
                            }
6567
0
                        }
6568
0
                    }
6569
0
                }
6570
0
            }
6571
0
        }
6572
0
    }
6573
0
}
6574
6575
// ggml_compute_forward_im2col_3d_f32
6576
// src0: kernel [OC*IC, KD, KH, KW]
6577
// src1: image [N*IC, ID, IH, IW]
6578
// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
6579
static void ggml_compute_forward_im2col_3d_f32(
6580
        const ggml_compute_params * params,
6581
0
              ggml_tensor * dst) {
6582
6583
0
    const ggml_tensor * src0 = dst->src[0];
6584
0
    const ggml_tensor * src1 = dst->src[1];
6585
6586
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6587
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
6588
6589
0
    GGML_TENSOR_BINARY_OP_LOCALS;
6590
6591
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6592
0
    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6593
0
    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
6594
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
6595
0
    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
6596
0
    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
6597
0
    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
6598
0
    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
6599
0
    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
6600
0
    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
6601
6602
6603
0
    const int ith = params->ith;
6604
0
    const int nth = params->nth;
6605
6606
0
    const int64_t N  = ne13 / IC;
6607
0
    const int64_t ID = ne12;
6608
0
    const int64_t IH = ne11;
6609
0
    const int64_t IW = ne10;
6610
6611
0
    const int64_t OC = ne03 / IC;
6612
0
    GGML_UNUSED(OC);
6613
0
    const int64_t KD = ne02;
6614
0
    const int64_t KH = ne01;
6615
0
    const int64_t KW = ne00;
6616
6617
0
    const int64_t OD = ne3 / N;
6618
0
    const int64_t OH = ne2;
6619
0
    const int64_t OW = ne1;
6620
6621
0
    const int64_t OH_OW = OH*OW;
6622
0
    const int64_t KD_KH_KW = KD*KH*KW;
6623
0
    const int64_t KH_KW = KH*KW;
6624
0
    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
6625
6626
0
    GGML_ASSERT(nb10 == sizeof(float));
6627
6628
    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
6629
0
    {
6630
0
        float * const wdata = (float *) dst->data;
6631
6632
0
        for (int64_t in = 0; in < N; in++) {
6633
0
            for (int64_t iod = 0; iod < OD; iod++) {
6634
0
                for (int64_t ioh = 0; ioh < OH; ioh++) {
6635
0
                    for (int64_t iow = 0; iow < OW; iow++) {
6636
0
                        for (int64_t iic = ith; iic < IC; iic += nth) {
6637
6638
                            // micro kernel
6639
0
                            float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
6640
0
                            const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
6641
6642
0
                            for (int64_t ikd = 0; ikd < KD; ikd++) {
6643
0
                                for (int64_t ikh = 0; ikh < KH; ikh++) {
6644
0
                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
6645
0
                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
6646
0
                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
6647
0
                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
6648
6649
0
                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
6650
0
                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
6651
0
                                        } else {
6652
0
                                            const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
6653
0
                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
6654
0
                                        }
6655
0
                                    }
6656
0
                                }
6657
0
                            }
6658
0
                        }
6659
0
                    }
6660
0
                }
6661
0
            }
6662
0
        }
6663
0
    }
6664
0
}
6665
6666
6667
void ggml_compute_forward_im2col_3d(
6668
        const ggml_compute_params * params,
6669
0
              ggml_tensor * dst) {
6670
0
    switch (dst->type) {
6671
0
        case GGML_TYPE_F16:
6672
0
            {
6673
0
                ggml_compute_forward_im2col_3d_f16(params, dst);
6674
0
            } break;
6675
0
        case GGML_TYPE_F32:
6676
0
            {
6677
0
                ggml_compute_forward_im2col_3d_f32(params, dst);
6678
0
            } break;
6679
0
        default:
6680
0
            {
6681
0
                GGML_ABORT("fatal error");
6682
0
            }
6683
0
    }
6684
0
}
6685
6686
// Runs ggml_compute_forward_mul_mat on raw buffers by wrapping them in
// temporary 2D tensor descriptors that live on the stack.
//
//   a : buffer used as the src1 operand, described as ne = [k, m]
//   b : buffer used as the src0 operand, described as ne = [k, n]
//   c : float output buffer, described as ne = [n, m]
//
// Both inputs are described with element type `type` and rows of
// k * type_size bytes; the higher-dimension strides repeat the row stride.
static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
                              void * a, void * b, float * c) {
    const ggml_type_traits * traits = ggml_get_type_traits(type);

    const size_t elem_bytes    = traits->type_size;
    const size_t in_row_bytes  = k * elem_bytes;
    const size_t out_row_bytes = n * sizeof(float);

    struct ggml_tensor src0 = {};
    struct ggml_tensor src1 = {};
    struct ggml_tensor dst  = {};

    // src1 <- a : [k, m]
    src1.type  = type;
    src1.data  = a;
    src1.ne[0] = k;
    src1.ne[1] = m;
    src1.ne[2] = 1;
    src1.ne[3] = 1;
    src1.nb[0] = elem_bytes;
    src1.nb[1] = in_row_bytes;
    src1.nb[2] = in_row_bytes;
    src1.nb[3] = in_row_bytes;

    // src0 <- b : [k, n]
    src0.type  = type;
    src0.data  = b;
    src0.ne[0] = k;
    src0.ne[1] = n;
    src0.ne[2] = 1;
    src0.ne[3] = 1;
    src0.nb[0] = elem_bytes;
    src0.nb[1] = in_row_bytes;
    src0.nb[2] = in_row_bytes;
    src0.nb[3] = in_row_bytes;

    // dst <- c : [n, m]; its type field keeps the zero-initialised value
    dst.data   = c;
    dst.ne[0]  = n;
    dst.ne[1]  = m;
    dst.ne[2]  = 1;
    dst.ne[3]  = 1;
    dst.nb[0]  = sizeof(float);
    dst.nb[1]  = out_row_bytes;
    dst.nb[2]  = out_row_bytes;
    dst.nb[3]  = out_row_bytes;
    dst.src[0] = &src0;
    dst.src[1] = &src1;

    ggml_compute_forward_mul_mat(params, &dst);
}
6728
6729
0
// Maps `coord` into the range [0, size) by wrapping around (size must be > 0).
//
// The remainder operator truncates toward zero, so a negative coord yields
// a remainder in (-size, 0]; that case is shifted up by size. Unlike
// (coord + size) % size this stays correct for coord < -size and never
// computes coord + size, which could overflow for very large coordinates.
static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
    const int64_t rem = coord % size;
    return rem < 0 ? rem + size : rem;
}
6732
6733
// ggml_compute_forward_col2im_1d
6734
//
6735
// Scatter-add columns [K*OC, T_in] -> signal [T_out, OC]
6736
// where T_out = (T_in - 1)*s + K - 2*p.  Gather approach: each output reads ceil(K/s) inputs.
6737
// Parallelized over the time axis so the split stays balanced whatever OC is.
6738
// Supports F32, F16, BF16 input/output (same type), F32 accumulator.
6739
6740
template <typename elem_t>
6741
static void ggml_compute_forward_col2im_1d_impl(
6742
        const ggml_compute_params * params,
6743
0
        ggml_tensor * dst) {
6744
6745
0
    const ggml_tensor * src = dst->src[0];  // [K*OC, T_in]
6746
6747
0
    GGML_ASSERT(ggml_is_contiguous(src));
6748
0
    GGML_ASSERT(ggml_is_contiguous(dst));
6749
6750
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6751
0
    const int32_t OC = ((const int32_t *)(dst->op_params))[1];
6752
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6753
6754
0
    const int64_t K_OC = src->ne[0];
6755
0
    const int64_t T_in = src->ne[1];
6756
0
    const int64_t K    = K_OC / OC;
6757
0
    const int64_t T_out = dst->ne[0];
6758
6759
0
    const elem_t * col_data = (const elem_t *) src->data;
6760
0
    elem_t       * dst_data = (elem_t *) dst->data;
6761
6762
0
    const int ith = params->ith;
6763
0
    const int nth = params->nth;
6764
6765
    // Parallelize over the time axis: the split stays balanced whatever OC is,
6766
    // down to OC = 1 for mono audio, and threads read disjoint column bands
6767
0
    const int64_t dr = (T_out + nth - 1) / nth;
6768
0
    const int64_t it0 = dr * ith;
6769
0
    const int64_t it1 = it0 + dr < T_out ? it0 + dr : T_out;
6770
6771
0
    for (int64_t oc = 0; oc < OC; oc++) {
6772
0
        for (int64_t t_out = it0; t_out < it1; t_out++) {
6773
0
            const int64_t t_abs = t_out + p0;  // absolute position in uncropped signal
6774
            // Gather: find all (t_in, k) where t_in * s + k == t_abs, 0 <= k < K
6775
0
            int64_t t_in_min = (t_abs - K + 1 + s0 - 1) / s0;  // ceil((t_abs-K+1)/s)
6776
0
            if (t_in_min < 0) t_in_min = 0;
6777
0
            int64_t t_in_max = t_abs / s0;
6778
0
            if (t_in_max >= T_in) t_in_max = T_in - 1;
6779
6780
0
            float sum = 0.0f;
6781
0
            for (int64_t t_in = t_in_min; t_in <= t_in_max; t_in++) {
6782
0
                int64_t k = t_abs - t_in * s0;
6783
0
                if (k >= 0 && k < K) {
6784
                    // col layout: [K*OC, T_in], element (oc*K+k, t_in)
6785
0
                    sum += type_conversion_table<elem_t>::to_f32(col_data[(oc * K + k) + t_in * K_OC]);
6786
0
                }
6787
0
            }
6788
            // dst layout: [T_out, OC], element (t_out, oc)
6789
0
            dst_data[t_out + oc * T_out] = type_conversion_table<elem_t>::from_f32(sum);
6790
0
        }
6791
0
    }
6792
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_col2im_1d_impl<float>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_col2im_1d_impl<unsigned short>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_col2im_1d_impl<ggml_bf16_t>(ggml_compute_params const*, ggml_tensor*)
6793
6794
void ggml_compute_forward_col2im_1d(
6795
        const ggml_compute_params * params,
6796
0
        ggml_tensor * dst) {
6797
0
    switch (dst->src[0]->type) {
6798
0
        case GGML_TYPE_F32:  ggml_compute_forward_col2im_1d_impl<float>      (params, dst); break;
6799
0
        case GGML_TYPE_F16:  ggml_compute_forward_col2im_1d_impl<ggml_fp16_t>(params, dst); break;
6800
0
        case GGML_TYPE_BF16: ggml_compute_forward_col2im_1d_impl<ggml_bf16_t>(params, dst); break;
6801
0
        default: GGML_ABORT("col2im_1d: unsupported type %d", dst->src[0]->type);
6802
0
    }
6803
0
}
6804
6805
// ggml_compute_forward_conv_2d
6806
6807
6808
static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
6809
                                              const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
6810
                                              const ggml_tensor *         src,     // [W, H, C, N]
6811
                                              ggml_tensor *               dst,     // [OW, OH, OC, N]
6812
0
                                              ggml_type                   kernel_type) {
6813
6814
0
    GGML_ASSERT(ggml_is_contiguous(kernel));
6815
0
    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
6816
0
    GGML_ASSERT(kernel->type == kernel_type);
6817
6818
0
    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
6819
6820
0
    const int32_t stride_x   = dst->op_params[0];
6821
0
    const int32_t stride_y   = dst->op_params[1];
6822
0
    const int32_t pad_x      = dst->op_params[2];
6823
0
    const int32_t pad_y      = dst->op_params[3];
6824
0
    const int32_t dilation_x = dst->op_params[4];
6825
0
    const int32_t dilation_y = dst->op_params[5];
6826
6827
0
    const int64_t c_in  = src->ne[2];
6828
0
    const int64_t c_out = kernel->ne[3];
6829
0
    GGML_ASSERT(c_in == kernel->ne[2]);
6830
6831
0
    const int64_t src_w = src->ne[0];
6832
0
    const int64_t src_h = src->ne[1];
6833
0
    const int64_t knl_w = kernel->ne[0];
6834
0
    const int64_t knl_h = kernel->ne[1];
6835
0
    const int64_t dst_w = dst->ne[0];
6836
0
    const int64_t dst_h = dst->ne[1];
6837
6838
0
    const float * src_data = (float *) src->data;
6839
0
    void  * knl_data       = kernel->data;
6840
0
    float * dst_data       = (float *) dst->data;
6841
6842
0
    const int64_t knl_n           = knl_w * knl_h * c_in;
6843
0
    const int64_t patch_total     = dst->ne[3] * dst_w * dst_h;
6844
6845
0
    const int64_t space_per_patch   = knl_n * traits->type_size + c_out * sizeof(float);
6846
0
    const int64_t batch_size        = params->wsize / space_per_patch;
6847
0
    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6848
0
    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
6849
6850
0
    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6851
6852
0
    void * tmp = params->wdata;
6853
6854
0
    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
6855
6856
0
        const int64_t patch_start_batch = batch_i * patches_per_batch;
6857
0
        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch,
6858
0
                                              patch_total);
6859
0
        const int64_t patch_n           = patch_end_batch - patch_start_batch;
6860
6861
0
        const int64_t patch_per_thread  = (patch_n + params->nth - 1) / params->nth;
6862
0
        const int64_t patch_start       = patch_start_batch + params->ith * patch_per_thread;
6863
0
        const int64_t patch_end         = std::min(patch_start + patch_per_thread, patch_end_batch);
6864
6865
        //im2col for a patch
6866
0
        for (int64_t p = patch_start; p < patch_end; ++p) {
6867
0
            const int64_t  batch_n     =  p / (dst_w * dst_h);
6868
0
            const int64_t  src_x       = (p / dst_w) % dst_h;
6869
0
            const int64_t  src_y       =  p % dst_w;
6870
6871
0
            const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
6872
0
            char *        dst_row  = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
6873
6874
0
            for (int64_t ic = 0; ic < c_in; ++ic) {
6875
0
                for (int64_t ky = 0; ky < knl_h; ++ky) {
6876
0
                    for (int64_t kx = 0; kx < knl_w; ++kx) {
6877
0
                        const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
6878
0
                        const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
6879
6880
0
                        int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
6881
6882
0
                        float src_val;
6883
0
                        if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6884
0
                            src_val = 0.0f;
6885
0
                        } else {
6886
0
                            const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
6887
0
                            src_val               = *src_ptr;
6888
0
                        }
6889
6890
0
                        char * element_ptr = dst_row + dst_idx * traits->type_size;
6891
0
                        if (kernel_type == GGML_TYPE_F32) {
6892
0
                            *(float *) element_ptr = src_val;
6893
0
                        } else if (kernel_type == GGML_TYPE_F16) {
6894
0
                            *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
6895
0
                        }
6896
0
                    }
6897
0
                }
6898
0
            }
6899
0
        }   // patches handled by this thread
6900
6901
0
        ggml_barrier(params->threadpool);
6902
6903
0
        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
6904
6905
0
        GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);
6906
6907
        // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
6908
0
        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
6909
6910
0
        ggml_barrier(params->threadpool);
6911
6912
6913
        //permute back [OC, N, OH, OW] to [N, OC, OH, OW]
6914
0
        const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
6915
0
        const int64_t permute_start = params->ith * permute_per_thread;
6916
0
        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
6917
6918
0
        for (int64_t i = permute_start; i < permute_end; ++i) {
6919
0
            const int64_t p       = patch_start_batch + i;
6920
0
            const int64_t batch_n = p / (dst_w * dst_h);
6921
0
            const int64_t dst_y   = (p / dst_w) % dst_h;
6922
0
            const int64_t dst_x   = p % dst_w;
6923
6924
0
            for (int64_t oc = 0; oc < c_out; ++oc) {
6925
0
                const float value = gemm_output[i * c_out + oc];
6926
0
                float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
6927
0
                *dst_ptr = value;
6928
0
            }
6929
0
        }
6930
0
    }
6931
0
}
6932
6933
void ggml_compute_forward_conv_2d(
6934
        const ggml_compute_params * params,
6935
0
        ggml_tensor * dst) {
6936
6937
0
    const ggml_tensor * src0 = dst->src[0];
6938
0
    const ggml_tensor * src1 = dst->src[1];
6939
6940
0
    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
6941
0
}
6942
6943
// ggml_compute_forward_conv_3d
6944
6945
static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
6946
                                              const ggml_tensor *         kernel,
6947
                                              const ggml_tensor *         src,
6948
                                              ggml_tensor *               dst,
6949
0
                                              ggml_type                   kernel_type) {
6950
6951
0
    GGML_ASSERT(ggml_is_contiguous(kernel));
6952
0
    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
6953
0
    GGML_ASSERT(kernel->type == kernel_type);
6954
6955
0
    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
6956
6957
0
    const int32_t s0 = dst->op_params[0];
6958
0
    const int32_t s1 = dst->op_params[1];
6959
0
    const int32_t s2 = dst->op_params[2];
6960
0
    const int32_t p0 = dst->op_params[3];
6961
0
    const int32_t p1 = dst->op_params[4];
6962
0
    const int32_t p2 = dst->op_params[5];
6963
0
    const int32_t d0 = dst->op_params[6];
6964
0
    const int32_t d1 = dst->op_params[7];
6965
0
    const int32_t d2 = dst->op_params[8];
6966
0
    const int32_t c  = dst->op_params[9];
6967
0
    const int32_t n  = dst->op_params[10];
6968
0
    const int32_t oc = dst->op_params[11];
6969
6970
0
    const int64_t src_w = src->ne[0];
6971
0
    const int64_t src_h = src->ne[1];
6972
0
    const int64_t src_d = src->ne[2];
6973
0
    const int64_t knl_w = kernel->ne[0];
6974
0
    const int64_t knl_h = kernel->ne[1];
6975
0
    const int64_t knl_d = kernel->ne[2];
6976
0
    const int64_t dst_w = dst->ne[0];
6977
0
    const int64_t dst_h = dst->ne[1];
6978
0
    const int64_t dst_d = dst->ne[2];
6979
6980
0
    const float * src_data = (float *) src->data;
6981
0
    void  * knl_data       = kernel->data;
6982
0
    float * dst_data       = (float *) dst->data;
6983
6984
0
    const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
6985
0
    const int64_t knl_n_total       = knl_n_per_channel * c;
6986
0
    const int64_t patch_total       = n * dst_w * dst_h * dst_d;
6987
6988
0
    const int64_t space_per_patch   = knl_n_total * traits->type_size + oc * sizeof(float);
6989
0
    const int64_t batch_size        = params->wsize / space_per_patch;
6990
0
    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6991
0
    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
6992
6993
0
    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6994
6995
0
    void * tmp = params->wdata;
6996
6997
0
    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
6998
0
        const int64_t patch_start_batch = batch_i * patches_per_batch;
6999
0
        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch, patch_total);
7000
0
        const int64_t patch_n_in_batch  = patch_end_batch - patch_start_batch;
7001
7002
0
        const int64_t patch_per_thread  = (patch_n_in_batch + params->nth - 1) / params->nth;
7003
0
        const int64_t patch_start       = patch_start_batch + params->ith * patch_per_thread;
7004
0
        const int64_t patch_end         = std::min(patch_start + patch_per_thread, patch_end_batch);
7005
7006
0
        for (int64_t p = patch_start; p < patch_end; ++p) {
7007
0
            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
7008
0
            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
7009
0
            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
7010
0
            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
7011
0
            const int64_t dst_y      = p_in_depth / dst_w;
7012
0
            const int64_t dst_x      = p_in_depth % dst_w;
7013
7014
0
            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
7015
7016
0
            for (int64_t ic = 0; ic < c; ++ic) {
7017
0
                for (int64_t kz = 0; kz < knl_d; ++kz) {
7018
0
                    for (int64_t ky = 0; ky < knl_h; ++ky) {
7019
0
                        for (int64_t kx = 0; kx < knl_w; ++kx) {
7020
0
                            const int64_t sz = dst_z * s2 + kz * d2 - p2;
7021
0
                            const int64_t sy = dst_y * s1 + ky * d1 - p1;
7022
0
                            const int64_t sx = dst_x * s0 + kx * d0 - p0;
7023
7024
0
                            int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
7025
7026
0
                            float src_val;
7027
0
                            if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
7028
0
                                src_val = 0.0f;
7029
0
                            } else {
7030
0
                                const int64_t cn_idx = batch_idx * c + ic;
7031
0
                                const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
7032
0
                                src_val = *src_ptr;
7033
0
                            }
7034
7035
0
                            char * element_ptr = dst_row + dst_idx * traits->type_size;
7036
0
                            if (kernel_type == GGML_TYPE_F32) {
7037
0
                                *(float *)element_ptr = src_val;
7038
0
                            } else if (kernel_type == GGML_TYPE_F16) {
7039
0
                                *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
7040
0
                            }
7041
0
                        }
7042
0
                    }
7043
0
                }
7044
0
            }
7045
0
        }
7046
7047
0
        ggml_barrier(params->threadpool);
7048
7049
0
        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
7050
0
        ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
7051
7052
0
        ggml_barrier(params->threadpool);
7053
7054
0
        const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
7055
0
        const int64_t permute_start = params->ith * permute_per_thread;
7056
0
        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
7057
7058
0
        for (int64_t i = permute_start; i < permute_end; ++i) {
7059
0
            const int64_t p = patch_start_batch + i;
7060
0
            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
7061
0
            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
7062
0
            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
7063
0
            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
7064
0
            const int64_t dst_y      = p_in_depth / dst_w;
7065
0
            const int64_t dst_x      = p_in_depth % dst_w;
7066
7067
0
            for (int64_t ioc = 0; ioc < oc; ++ioc) {
7068
0
                const float value = gemm_output[i * oc + ioc];
7069
0
                const int64_t ocn_idx = batch_idx * oc + ioc;
7070
0
                float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
7071
0
                *dst_ptr = value;
7072
0
            }
7073
0
        }
7074
0
    }
7075
0
}
7076
7077
void ggml_compute_forward_conv_3d(
7078
        const ggml_compute_params * params,
7079
0
        ggml_tensor * dst) {
7080
0
    const ggml_tensor * src0 = dst->src[0];
7081
0
    const ggml_tensor * src1 = dst->src[1];
7082
0
    ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
7083
0
}
7084
7085
template <typename kernel_t>
7086
static void ggml_compute_forward_conv_transpose_2d_impl(
7087
    const ggml_compute_params * params,
7088
0
          ggml_tensor * dst) {
7089
7090
0
    const ggml_tensor * src0 = dst->src[0];
7091
0
    const ggml_tensor * src1 = dst->src[1];
7092
7093
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
7094
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
7095
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
7096
7097
0
    GGML_TENSOR_BINARY_OP_LOCALS
7098
7099
0
    const int ith = params->ith;
7100
0
    const int nth = params->nth;
7101
7102
0
    const int nk = ne00*ne01*ne02*ne03;
7103
7104
0
    GGML_ASSERT(nb00 == ggml_type_size(src0->type));
7105
0
    GGML_ASSERT(nb10 == sizeof(float));
7106
7107
0
    if (ith == 0) {
7108
0
        memset(params->wdata, 0, params->wsize);
7109
7110
        // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
7111
0
        {
7112
0
            kernel_t * const wdata = (kernel_t *) params->wdata + 0;
7113
7114
0
            for (int64_t i03 = 0; i03 < ne03; i03++) {
7115
0
                for (int64_t i02 = 0; i02 < ne02; i02++) {
7116
0
                    const kernel_t * const src = (kernel_t *)((char *) src0->data + i03*nb03 + i02*nb02);
7117
0
                    kernel_t * dst_data = wdata + i02*ne01*ne00*ne03;
7118
0
                    for (int64_t i01 = 0; i01 < ne01; i01++) {
7119
0
                        for (int64_t i00 = 0; i00 < ne00; i00++) {
7120
0
                            dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
7121
0
                        }
7122
0
                    }
7123
0
                }
7124
0
            }
7125
0
        }
7126
7127
        // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
7128
0
        {
7129
0
            kernel_t * const wdata = (kernel_t *) params->wdata + nk;
7130
0
            for (int i12 = 0; i12 < ne12; i12++) {
7131
0
                for (int i11 = 0; i11 < ne11; i11++) {
7132
0
                    const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
7133
0
                    kernel_t * dst_data = wdata + i11*ne10*ne12;
7134
0
                    for (int i10 = 0; i10 < ne10; i10++) {
7135
0
                        if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) {
7136
0
                            dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
7137
0
                        } else {
7138
0
                            dst_data[i10*ne12 + i12] = src[i10];
7139
0
                        }
7140
0
                    }
7141
0
                }
7142
0
            }
7143
0
        }
7144
7145
0
        memset(dst->data, 0, ggml_nbytes(dst));
7146
0
    }
7147
0
    ggml_barrier(params->threadpool);
7148
7149
0
    const int32_t stride = ggml_get_op_params_i32(dst, 0);
7150
7151
    // total patches in dst
7152
0
    const int np = ne2;
7153
7154
    // patches per thread
7155
0
    const int dp = (np + nth - 1)/nth;
7156
7157
    // patch range for this thread
7158
0
    const int ip0 = dp*ith;
7159
0
    const int ip1 = MIN(ip0 + dp, np);
7160
7161
0
    kernel_t * const wdata = (kernel_t *) params->wdata + 0;
7162
0
    kernel_t * const wdata_src = wdata + nk;
7163
7164
0
    for (int i2 = ip0; i2 < ip1; i2++) { // Cout
7165
0
        float * dst_data = (float *)((char *) dst->data + i2*nb2);
7166
0
        kernel_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
7167
0
        for (int i11 = 0; i11 < ne11; i11++) {
7168
0
            for (int i10 = 0; i10 < ne10; i10++) {
7169
0
                const int i1n = i11*ne10*ne12 + i10*ne12;
7170
0
                for (int i01 = 0; i01 < ne01; i01++) {
7171
0
                    for (int i00 = 0; i00 < ne00; i00++) {
7172
0
                        float v = 0;
7173
0
                        if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) {
7174
0
                            ggml_vec_dot_f16(ne03, &v, 0,
7175
0
                                    wdata_src + i1n, 0,
7176
0
                                    wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
7177
0
                        } else {
7178
0
                            ggml_vec_dot_f32(ne03, &v, 0,
7179
0
                                    wdata_src + i1n, 0,
7180
0
                                    wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
7181
0
                        }
7182
0
                        dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
7183
0
                    }
7184
0
                }
7185
0
            }
7186
0
        }
7187
0
    }
7188
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_conv_transpose_2d_impl<unsigned short>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_conv_transpose_2d_impl<float>(ggml_compute_params const*, ggml_tensor*)
7189
7190
void ggml_compute_forward_conv_transpose_2d(
7191
        const ggml_compute_params * params,
7192
0
              ggml_tensor * dst) {
7193
7194
0
    const ggml_tensor * src0 = dst->src[0];
7195
7196
0
    switch (src0->type) {
7197
0
        case GGML_TYPE_F16:
7198
0
            {
7199
0
                ggml_compute_forward_conv_transpose_2d_impl<ggml_fp16_t>(params, dst);
7200
0
            } break;
7201
0
        case GGML_TYPE_F32:
7202
0
            {
7203
0
                ggml_compute_forward_conv_transpose_2d_impl<float>(params, dst);
7204
0
            } break;
7205
0
        default:
7206
0
            {
7207
0
                GGML_ABORT("fatal error");
7208
0
            }
7209
0
    }
7210
0
}
7211
7212
// ggml_compute_forward_conv_2d_dw
7213
7214
// Geometry and hyper-parameters shared by the depthwise 2D convolution kernels.
struct ggml_conv_2d_dw_params {
    int64_t channels;   // number of channels    (src->ne[2])
    int64_t batch;      // number of images      (src->ne[3])
    int64_t src_w;      // input width           (src->ne[0])
    int64_t src_h;      // input height          (src->ne[1])
    int64_t dst_w;      // output width          (dst->ne[0])
    int64_t dst_h;      // output height         (dst->ne[1])
    int64_t knl_w;      // kernel width          (kernel->ne[0])
    int64_t knl_h;      // kernel height         (kernel->ne[1])
    int stride_x;       // horizontal stride     (op_params[0])
    int stride_y;       // vertical stride       (op_params[1])
    int pad_x;          // horizontal padding    (op_params[2])
    int pad_y;          // vertical padding      (op_params[3])
    int dilation_x;     // horizontal dilation   (op_params[4])
    int dilation_y;     // vertical dilation     (op_params[5])
};
7230
7231
static void ggml_compute_forward_conv_2d_dw_cwhn(
7232
        const ggml_compute_params * params,
7233
        const ggml_tensor * src,
7234
        const ggml_tensor * kernel,
7235
        ggml_tensor * dst,
7236
0
        const ggml_conv_2d_dw_params & p) {
7237
7238
0
    const int64_t c = p.channels;
7239
0
    const float * knl_data = (const float *)kernel->data;
7240
7241
0
    const int64_t rows_total = p.dst_h * p.batch;
7242
0
    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
7243
0
    const int64_t row_start = params->ith * rows_per_thread;
7244
0
    const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
7245
7246
0
#ifdef GGML_SIMD
7247
    #if defined(__ARM_FEATURE_SVE)
7248
        const int64_t pkg_size = svcntw();
7249
    #else
7250
0
        const int64_t pkg_size = GGML_F32_EPR;
7251
0
    #endif
7252
0
    const int64_t pkg_count = c / pkg_size;
7253
0
    const int64_t c_pkg_end = pkg_count * pkg_size;
7254
#else
7255
    const int64_t c_pkg_end = 0;
7256
#endif
7257
7258
0
    for (int64_t row = row_start; row < row_end; ++row) {
7259
0
        const int64_t dst_y = row % p.dst_h;
7260
0
        const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;
7261
0
        for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
7262
0
            float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;
7263
0
            const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;
7264
0
            const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;
7265
7266
0
#ifdef GGML_SIMD
7267
            // Vectorized loop
7268
0
            for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
7269
0
                GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
7270
0
                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
7271
0
                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
7272
0
                    if (src_y < 0 || src_y >= p.src_h) {
7273
0
                        continue;
7274
0
                    }
7275
0
                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
7276
0
                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
7277
0
                        if (src_x < 0 || src_x >= p.src_w) {
7278
0
                            continue;
7279
0
                        }
7280
0
                        GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
7281
0
                        GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);
7282
0
                        sum = GGML_F32_VEC_FMA(sum, k, s);
7283
0
                    }
7284
0
                }
7285
0
                GGML_F32_VEC_STORE(dst_data + c_i, sum);
7286
0
            }
7287
0
#endif
7288
            // Scalar loop
7289
0
            for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
7290
0
                float sum = 0.0f;
7291
0
                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
7292
0
                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
7293
0
                    if (src_y < 0 || src_y >= p.src_h) {
7294
0
                        continue;
7295
0
                    }
7296
0
                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
7297
0
                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
7298
0
                        if (src_x < 0 || src_x >= p.src_w) {
7299
0
                            continue;
7300
0
                        }
7301
0
                        sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
7302
0
                             * src_data[(src_y * p.src_w + src_x) * c + c_i];
7303
0
                    }
7304
0
                }
7305
0
                dst_data[c_i] = sum;
7306
0
            }
7307
0
        }
7308
0
    }
7309
0
}
7310
7311
static void ggml_compute_forward_conv_2d_dw_whcn(
7312
        const ggml_compute_params * params,
7313
        const ggml_tensor * src,
7314
        const ggml_tensor * kernel,
7315
        ggml_tensor * dst,
7316
0
        const ggml_conv_2d_dw_params & p) {
7317
7318
0
    const int64_t n = p.channels * p.batch;
7319
0
    const int64_t per_thread = (n + params->nth - 1) / params->nth;
7320
0
    const int64_t start = params->ith * per_thread;
7321
0
    const int64_t end = MIN(start + per_thread, n);
7322
7323
0
    for (int64_t i = start; i < end; ++i) {
7324
0
        const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h;
7325
0
        const float * src_data = (const float *)src->data + i * p.src_w * p.src_h;
7326
0
        float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h;
7327
7328
0
        for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) {
7329
0
            for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
7330
7331
0
                float sum = 0.0f;
7332
0
                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
7333
0
                    const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
7334
0
                    if (src_y < 0 || src_y >= p.src_h) {
7335
0
                        continue;
7336
0
                    }
7337
0
                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
7338
0
                        const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
7339
0
                        if (src_x < 0 || src_x >= p.src_w) {
7340
0
                            continue;
7341
0
                        }
7342
0
                        sum += knl_data[knl_y * p.knl_w + knl_x]
7343
0
                             * src_data[src_y * p.src_w + src_x];
7344
0
                    }
7345
0
                }
7346
0
                dst_data[dst_y * p.dst_w + dst_x] = sum;
7347
0
            }
7348
0
        }
7349
0
    }
7350
0
}
7351
7352
void ggml_compute_forward_conv_2d_dw(
7353
        const ggml_compute_params * params,
7354
0
        ggml_tensor * dst) {
7355
7356
0
    const ggml_tensor * kernel = dst->src[0];
7357
0
    const ggml_tensor * src = dst->src[1];
7358
0
    ggml_conv_2d_dw_params p;
7359
0
    p.channels = src->ne[2];
7360
0
    p.batch = src->ne[3];
7361
0
    p.src_w = src->ne[0];
7362
0
    p.src_h = src->ne[1];
7363
0
    p.dst_w = dst->ne[0];
7364
0
    p.dst_h = dst->ne[1];
7365
0
    p.knl_w = kernel->ne[0];
7366
0
    p.knl_h = kernel->ne[1];
7367
0
    p.stride_x = dst->op_params[0];
7368
0
    p.stride_y = dst->op_params[1];
7369
0
    p.pad_x = dst->op_params[2];
7370
0
    p.pad_y = dst->op_params[3];
7371
0
    p.dilation_x = dst->op_params[4];
7372
0
    p.dilation_y = dst->op_params[5];
7373
7374
0
    GGML_ASSERT(kernel->ne[3] == p.channels);
7375
0
    GGML_ASSERT(dst->ne[3] == p.batch);
7376
7377
0
    if (ggml_is_contiguous(src)) {
7378
0
        ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p);
7379
0
    } else if (ggml_is_contiguous_channels(src)) {
7380
        // kernel should also have channels most contiguous in memory
7381
0
        GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]);
7382
0
        ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p);
7383
0
    } else {
7384
0
        GGML_ABORT("non-contiguous memory layout not supported");
7385
0
    }
7386
0
}
7387
7388
// ggml_compute_forward_pool_1d_ksp
7389
static void ggml_compute_forward_pool_1d_ksp(
7390
        const ggml_compute_params * params,
7391
        const ggml_op_pool op,
7392
        const int k,
7393
        const int s,
7394
        const int p,
7395
0
        ggml_tensor * dst) {
7396
7397
0
    const ggml_tensor * src = dst->src[0];
7398
7399
0
    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
7400
7401
0
    if (params->ith != 0) {
7402
0
        return;
7403
0
    }
7404
7405
0
    const int64_t IW = src->ne[0];
7406
0
    const int64_t OW = dst->ne[0];
7407
7408
0
    const int64_t nr = ggml_nrows(src);
7409
7410
0
    for (int64_t ir = 0; ir < nr; ++ir) {
7411
0
        const char * srow_bytes =            (const char *) src->data + ir * src->nb[1];
7412
0
        float      * drow       = (float *) ((      char *) dst->data + ir * dst->nb[1]);
7413
7414
0
        for (int64_t ow = 0; ow < OW; ++ow) {
7415
0
            float res = 0;
7416
0
            switch (op) {
7417
0
                case GGML_OP_POOL_AVG: res = 0.0f;     break;
7418
0
                case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
7419
0
                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7420
0
            }
7421
7422
0
            int count = 0;
7423
0
            const int base = (int) ow * s - p;
7424
7425
0
            for (int ki = 0; ki < k; ++ki) {
7426
0
                const int j = base + ki;
7427
0
                if (j < 0 || j >= (int) IW) {
7428
0
                    continue;
7429
0
                }
7430
7431
0
                float v;
7432
0
                if (src->type == GGML_TYPE_F32) {
7433
0
                    v = ((const float *) srow_bytes)[j];
7434
0
                } else {
7435
0
                    v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
7436
0
                }
7437
7438
0
                switch (op) {
7439
0
                    case GGML_OP_POOL_AVG: res += v;                break;
7440
0
                    case GGML_OP_POOL_MAX: res =  std::max(v, res); break;
7441
0
                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7442
0
                }
7443
7444
0
                ++count;
7445
0
            }
7446
7447
0
            switch (op) {
7448
0
                case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
7449
0
                case GGML_OP_POOL_MAX:                                           break;
7450
0
                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7451
0
            }
7452
7453
0
            drow[ow] = res;
7454
0
        }
7455
0
    }
7456
0
}
7457
7458
// ggml_compute_forward_pool_1d
7459
7460
void ggml_compute_forward_pool_1d(
7461
        const ggml_compute_params * params,
7462
0
              ggml_tensor * dst) {
7463
7464
0
    const int32_t * opts = (const int32_t *)dst->op_params;
7465
0
    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7466
0
    const int k0 = opts[1];
7467
0
    const int s0 = opts[2];
7468
0
    const int p0 = opts[3];
7469
7470
0
    ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
7471
0
}
7472
7473
// ggml_compute_forward_pool_2d
7474
7475
void ggml_compute_forward_pool_2d(
7476
        const ggml_compute_params * params,
7477
0
        ggml_tensor * dst) {
7478
7479
0
    const ggml_tensor * src = dst->src[0];
7480
7481
0
    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
7482
7483
0
    if (params->ith != 0) {
7484
0
        return;
7485
0
    }
7486
7487
0
    const int32_t * opts = (const int32_t *)dst->op_params;
7488
7489
0
    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7490
0
    const int k0 = opts[1];
7491
0
    const int k1 = opts[2];
7492
0
    const int s0 = opts[3];
7493
0
    const int s1 = opts[4];
7494
0
    const int p0 = opts[5];
7495
0
    const int p1 = opts[6];
7496
0
    const char * cdata = (const char*)src->data;
7497
0
    const char * const data_end = cdata + ggml_nbytes(src);
7498
7499
0
    const int64_t px = dst->ne[0];
7500
0
    const int64_t py = dst->ne[1];
7501
0
    const int64_t pa = px * py;
7502
7503
0
    float * dplane = (float *)dst->data;
7504
7505
0
    const int ka = k0 * k1;
7506
0
    const int offset0 = -p0;
7507
0
    const int offset1 = -p1;
7508
7509
0
    while (cdata < data_end) {
7510
0
        for (int oy = 0; oy < py; ++oy) {
7511
0
            float * const drow = dplane + oy * px;
7512
0
            float * const out  = drow;
7513
7514
0
            for (int ox = 0; ox < px; ++ox) {
7515
0
                float res = 0;
7516
0
                switch (op) {
7517
0
                    case GGML_OP_POOL_AVG: res = 0;        break;
7518
0
                    case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
7519
0
                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7520
0
                }
7521
7522
0
                const int ix = offset0 + ox * s0;
7523
0
                const int iy = offset1 + oy * s1;
7524
7525
0
                for (int ky = 0; ky < k1; ++ky) {
7526
0
                    if (iy + ky < 0 || iy + ky >= src->ne[1]) {
7527
0
                        continue;
7528
0
                    }
7529
7530
0
                    const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
7531
0
                    for (int kx = 0; kx < k0; ++kx) {
7532
0
                        int j = ix + kx;
7533
0
                        if (j < 0 || j >= src->ne[0]) {
7534
0
                            continue;
7535
0
                        }
7536
7537
0
                        const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7538
0
                        switch (op) {
7539
0
                            case GGML_OP_POOL_AVG: res += srow_j;                break;
7540
0
                            case GGML_OP_POOL_MAX: res =  std::max(srow_j, res); break;
7541
0
                            case GGML_OP_POOL_COUNT:               GGML_ABORT("fatal error");
7542
0
                        }
7543
0
                    }
7544
0
                }
7545
0
                switch (op) {
7546
0
                    case GGML_OP_POOL_AVG:           res /= ka; break;
7547
0
                    case GGML_OP_POOL_MAX:                      break;
7548
0
                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7549
0
                }
7550
7551
0
                out[ox] = res;
7552
0
            }
7553
0
        }
7554
7555
0
        cdata  += src->nb[2];
7556
0
        dplane += pa;
7557
0
    }
7558
0
}
7559
7560
// ggml_compute_forward_pool_2d_back
7561
7562
void ggml_compute_forward_pool_2d_back(
7563
        const ggml_compute_params * params,
7564
0
        ggml_tensor * dst) {
7565
7566
0
    const ggml_tensor * src  = dst->src[0];
7567
0
    const ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
7568
7569
0
    assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
7570
7571
0
    if (params->ith != 0) {
7572
0
        return;
7573
0
    }
7574
7575
0
    const int32_t * opts = (const int32_t *)dst->op_params;
7576
0
    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7577
0
    const int k0 = opts[1];
7578
0
    const int k1 = opts[2];
7579
0
    const int s0 = opts[3];
7580
0
    const int s1 = opts[4];
7581
0
    const int p0 = opts[5];
7582
0
    const int p1 = opts[6];
7583
7584
0
    char       * cdata  = (char       *) dst->data;
7585
0
    const char * cdataf = (const char *) dstf->data;
7586
0
    const char * const data_end = cdata + ggml_nbytes(dst);
7587
7588
0
    GGML_ASSERT(params->ith == 0);
7589
0
    memset(cdata, 0, ggml_nbytes(dst));
7590
7591
0
    const int64_t px = src->ne[0];
7592
0
    const int64_t py = src->ne[1];
7593
0
    const int64_t pa = px * py;
7594
7595
0
    const float * splane = (const float *) src->data;
7596
7597
0
    const int ka = k0 * k1;
7598
0
    const int offset0 = -p0;
7599
0
    const int offset1 = -p1;
7600
7601
0
    while (cdata < data_end) {
7602
0
        for (int oy = 0; oy < py; ++oy) {
7603
0
            const float * const srow = splane + oy * px;
7604
0
            for (int ox = 0; ox < px; ++ox) {
7605
0
                const float grad0 = srow[ox];
7606
7607
0
                const int ix = offset0 + ox * s0;
7608
0
                const int iy = offset1 + oy * s1;
7609
7610
0
                if (op == GGML_OP_POOL_MAX) {
7611
0
                    float maxval = -FLT_MAX;
7612
0
                    int kxmax = -1;
7613
0
                    int kymax = -1;
7614
7615
0
                    for (int ky = 0; ky < k1; ++ky) {
7616
0
                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
7617
0
                            continue;
7618
0
                        }
7619
0
                        const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
7620
0
                        for (int kx = 0; kx < k0; ++kx) {
7621
0
                            int j = ix + kx;
7622
0
                            if (j < 0 || j >= dst->ne[0]) {
7623
0
                                continue;
7624
0
                            }
7625
7626
0
                            const float val = dst->type == GGML_TYPE_F32 ?
7627
0
                                ((const float *) drowf)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
7628
0
                            if (val <= maxval) {
7629
0
                                continue;
7630
0
                            }
7631
7632
0
                            maxval = val;
7633
0
                            kxmax = kx;
7634
0
                            kymax = ky;
7635
0
                        }
7636
0
                    }
7637
7638
0
                    if (kxmax == -1 || kymax == -1) {
7639
0
                        continue;
7640
0
                    }
7641
7642
0
                    void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
7643
0
                    const int j = ix + kxmax;
7644
0
                    if (dst->type == GGML_TYPE_F32) {
7645
0
                        ((float *) drow)[j] += grad0;
7646
0
                    } else {
7647
0
                        ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad0 + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
7648
0
                    }
7649
0
                } else if (op == GGML_OP_POOL_AVG) {
7650
0
                    const float grad = grad0 / ka;
7651
7652
0
                    for (int ky = 0; ky < k1; ++ky) {
7653
0
                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
7654
0
                            continue;
7655
0
                        }
7656
0
                        void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
7657
0
                        for (int kx = 0; kx < k0; ++kx) {
7658
0
                            int j = ix + kx;
7659
0
                            if (j < 0 || j >= dst->ne[0]) {
7660
0
                                continue;
7661
0
                            }
7662
7663
0
                            if (dst->type == GGML_TYPE_F32) {
7664
0
                                ((float *) drow)[j] += grad;
7665
0
                            } else {
7666
0
                                ((ggml_fp16_t *) drow)[j] += GGML_CPU_FP32_TO_FP16(grad);
7667
0
                            }
7668
0
                        }
7669
0
                    }
7670
0
                } else {
7671
0
                    GGML_ASSERT(false);
7672
0
                }
7673
0
            }
7674
0
        }
7675
7676
0
        cdata  += dst->nb[2];
7677
0
        cdataf += dst->nb[2];
7678
0
        splane += pa;
7679
0
    }
7680
0
}
7681
7682
// ggml_compute_forward_upscale
7683
7684
static void ggml_compute_forward_upscale_f32(
7685
    const ggml_compute_params * params,
7686
0
    ggml_tensor * dst) {
7687
7688
0
    const ggml_tensor * src0 = dst->src[0];
7689
7690
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
7691
7692
0
    const int ith = params->ith;
7693
0
    const int nth = params->nth;
7694
7695
0
    GGML_TENSOR_UNARY_OP_LOCALS
7696
7697
0
    float sf0 = (float)ne0/src0->ne[0];
7698
0
    float sf1 = (float)ne1/src0->ne[1];
7699
0
    float sf2 = (float)ne2/src0->ne[2];
7700
0
    float sf3 = (float)ne3/src0->ne[3];
7701
0
    float pixel_offset = 0.5f;
7702
7703
0
    const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
7704
0
    const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
7705
7706
0
    if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7707
0
        pixel_offset = 0.0f;
7708
0
        sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
7709
0
        sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
7710
0
    }
7711
7712
0
    if (mode == GGML_SCALE_MODE_NEAREST) {
7713
0
        for (int64_t i3 = 0; i3 < ne3; i3++) {
7714
0
            const int64_t i03 = i3 / sf3;
7715
0
            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7716
0
                const int64_t i02 = i2 / sf2;
7717
0
                for (int64_t i1 = 0; i1 < ne1; i1++) {
7718
0
                    const int64_t i01 = i1 / sf1;
7719
0
                    for (int64_t i0 = 0; i0 < ne0; i0++) {
7720
0
                        const int64_t i00 = i0 / sf0;
7721
7722
0
                        const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7723
0
                              float * y = (float *)((char *)  dst->data +  i0*nb0  +  i1*nb1  +  i2*nb2  +  i3*nb3);
7724
7725
0
                        *y = *x;
7726
0
                    }
7727
0
                }
7728
0
            }
7729
0
        }
7730
0
    } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
7731
        // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
7732
        // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
7733
0
        auto triangle_filter = [](float x) -> float {
7734
0
            return std::max(1.0f - fabsf(x), 0.0f);
7735
0
        };
7736
7737
        // support and invscale, minimum 1 pixel for bilinear
7738
0
        const float support1  = std::max(1.0f, 1.0f / sf1);
7739
0
        const float invscale1 = 1.0f / support1;
7740
0
        const float support0  = std::max(1.0f, 1.0f / sf0);
7741
0
        const float invscale0 = 1.0f / support0;
7742
7743
0
        for (int64_t i3 = 0; i3 < ne3; i3++) {
7744
0
            const int64_t i03 = i3 / sf3;
7745
0
            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7746
0
                const int64_t i02 = i2 / sf2;
7747
0
                for (int64_t i1 = 0; i1 < ne1; i1++) {
7748
0
                    const float y = ((float) i1 + pixel_offset) / sf1;
7749
0
                    for (int64_t i0 = 0; i0 < ne0; i0++) {
7750
0
                        const float x = ((float) i0 + pixel_offset) / sf0;
7751
7752
                        // the range of source pixels that contribute
7753
0
                        const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
7754
0
                        const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
7755
0
                        const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
7756
0
                        const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
7757
7758
                        // bilinear filter with antialiasing
7759
0
                        float val = 0.0f;
7760
0
                        float total_weight = 0.0f;
7761
7762
0
                        for (int64_t sy = y_min; sy < y_max; sy++) {
7763
0
                            const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
7764
7765
0
                            for (int64_t sx = x_min; sx < x_max; sx++) {
7766
0
                                const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
7767
0
                                const float weight = weight_x * weight_y;
7768
7769
0
                                if (weight <= 0.0f) {
7770
0
                                    continue;
7771
0
                                }
7772
7773
0
                                const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
7774
0
                                val += pixel * weight;
7775
0
                                total_weight += weight;
7776
0
                            }
7777
0
                        }
7778
7779
0
                        if (total_weight > 0.0f) {
7780
0
                            val /= total_weight;
7781
0
                        }
7782
7783
0
                        float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7784
0
                        *dst_ptr = val;
7785
0
                    }
7786
0
                }
7787
0
            }
7788
0
        }
7789
0
    } else if (mode == GGML_SCALE_MODE_BILINEAR) {
7790
0
        for (int64_t i3 = 0; i3 < ne3; i3++) {
7791
0
            const int64_t i03 = i3 / sf3;
7792
0
            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7793
0
                const int64_t i02 = i2 / sf2;
7794
0
                for (int64_t i1 = 0; i1 < ne1; i1++) {
7795
0
                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
7796
0
                    int64_t y0 = (int64_t)floorf(y);
7797
0
                    int64_t y1 = y0 + 1;
7798
7799
0
                    y0 = std::max(int64_t(0), std::min(y0, ne01 - 1));
7800
0
                    y1 = std::max(int64_t(0), std::min(y1, ne01 - 1));
7801
7802
0
                    float dy = y - (float)y0;
7803
0
                    dy = std::max(0.0f, std::min(dy, 1.0f));
7804
7805
0
                    for (int64_t i0 = 0; i0 < ne0; i0++) {
7806
0
                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
7807
0
                        int64_t x0 = (int64_t)floorf(x);
7808
0
                        int64_t x1 = x0 + 1;
7809
7810
0
                        x0 = std::max(int64_t(0), std::min(x0, ne00 - 1));
7811
0
                        x1 = std::max(int64_t(0), std::min(x1, ne00 - 1));
7812
7813
0
                        float dx = x - (float)x0;
7814
0
                        dx = std::max(0.0f, std::min(dx, 1.0f));
7815
7816
                        // fetch the four surrounding pixel values and interpolate
7817
0
                        const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
7818
0
                        const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
7819
0
                        const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
7820
0
                        const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
7821
7822
0
                        const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
7823
7824
0
                        float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7825
0
                        *y_dst = val;
7826
0
                    }
7827
0
                }
7828
0
            }
7829
0
        }
7830
0
    } else if (mode == GGML_SCALE_MODE_BICUBIC) {
7831
        // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
7832
0
        const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
7833
0
        auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
7834
0
        auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
7835
0
        auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
7836
0
            const float w0 = weight2(x + 1);
7837
0
            const float w1 = weight1(x + 0);
7838
0
            const float w2 = weight1(1 - x);
7839
0
            const float w3 = weight2(2 - x);
7840
0
            return p0*w0 + p1*w1 + p2*w2 + p3*w3;
7841
0
        };
7842
7843
0
        for (int64_t i3 = 0; i3 < ne3; i3++) {
7844
0
            const int64_t i03 = i3 / sf3;
7845
0
            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7846
0
                const int64_t i02 = i2 / sf2;
7847
0
                for (int64_t i1 = 0; i1 < ne1; i1++) {
7848
0
                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
7849
0
                    const int64_t y0 = (int64_t)floorf(y);
7850
0
                    const float dy = y - (float)y0;
7851
7852
0
                    for (int64_t i0 = 0; i0 < ne0; i0++) {
7853
0
                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
7854
0
                        const int64_t x0 = (int64_t)floorf(x);
7855
0
                        const float dx = x - (float)x0;
7856
7857
0
                        auto p = [=](int64_t x_off, int64_t y_off) -> float {
7858
0
                            int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
7859
0
                            int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
7860
0
                            return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7861
0
                        };
7862
7863
0
                        const float val = bicubic(
7864
0
                            bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
7865
0
                            bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
7866
0
                            bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
7867
0
                            bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
7868
7869
0
                        float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7870
0
                        *y_dst = val;
7871
0
                    }
7872
0
                }
7873
0
            }
7874
0
        }
7875
0
    } else {
7876
0
        GGML_ABORT("unsupported upscale mode");
7877
0
    }
7878
0
}
7879
7880
void ggml_compute_forward_upscale(
7881
    const ggml_compute_params * params,
7882
0
    ggml_tensor * dst) {
7883
7884
0
    const ggml_tensor * src0 = dst->src[0];
7885
7886
0
    switch (src0->type) {
7887
0
        case GGML_TYPE_F32:
7888
0
            {
7889
0
                ggml_compute_forward_upscale_f32(params, dst);
7890
0
            } break;
7891
0
        default:
7892
0
            {
7893
0
                GGML_ABORT("fatal error");
7894
0
            }
7895
0
    }
7896
0
}
7897
7898
7899
// ggml_compute_forward_pad
7900
7901
template<bool circular_t>
7902
static void ggml_compute_forward_pad_f32(
7903
    const ggml_compute_params * params,
7904
0
          ggml_tensor * dst) {
7905
7906
0
    const ggml_tensor * src0 = dst->src[0];
7907
7908
0
    assert(dst->nb[0] == sizeof(float));
7909
7910
0
    const int ith = params->ith;
7911
0
    const int nth = params->nth;
7912
7913
0
    GGML_TENSOR_UNARY_OP_LOCALS
7914
7915
0
    float * dst_ptr = (float *) dst->data;
7916
0
    const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
7917
0
    const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
7918
0
    const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
7919
0
    const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
7920
0
    const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
7921
0
    const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
7922
0
    const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
7923
0
    const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
7924
7925
    // TODO: optimize
7926
7927
0
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
7928
0
        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
7929
0
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
7930
0
                for (int64_t i3 = 0; i3 < ne3; ++i3) {
7931
                    // circular means wrap around on a torus, so x and y loop around
7932
0
                    if constexpr (circular_t) {
7933
0
                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7934
0
                        const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
7935
0
                        const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
7936
0
                        const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
7937
0
                        const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
7938
7939
0
                        const int64_t src_idx =
7940
0
                            src_i3*nb03 +
7941
0
                            src_i2*nb02 +
7942
0
                            src_i1*nb01 +
7943
0
                            src_i0*nb00;
7944
7945
0
                        const float * src_ptr = (const float *)((char *) src0->data + src_idx);
7946
0
                        dst_ptr[dst_idx] = *src_ptr;
7947
0
                    } else {
7948
0
                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7949
0
                        if ((i0 >= lp0 && i0 < ne0 - rp0) \
7950
0
                            && (i1 >= lp1 && i1 < ne1 - rp1) \
7951
0
                            && (i2 >= lp2 && i2 < ne2 - rp2) \
7952
0
                            && (i3 >= lp3 && i3 < ne3 - rp3)) {
7953
0
                            const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
7954
0
                            const float * src_ptr = (const float *)((char *) src0->data + src_idx);
7955
0
                            dst_ptr[dst_idx] = *src_ptr;
7956
0
                        } else {
7957
0
                            dst_ptr[dst_idx] = 0;
7958
0
                        }
7959
0
                    }
7960
0
                }
7961
0
            }
7962
0
        }
7963
0
    }
7964
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_pad_f32<true>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_pad_f32<false>(ggml_compute_params const*, ggml_tensor*)
7965
7966
7967
void ggml_compute_forward_pad(
7968
    const ggml_compute_params * params,
7969
0
    ggml_tensor * dst) {
7970
0
    const ggml_tensor * src0 = dst->src[0];
7971
0
    const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
7972
0
    switch (src0->type) {
7973
0
        case GGML_TYPE_F32:
7974
0
            {
7975
0
                if (circular) {
7976
0
                    ggml_compute_forward_pad_f32<true>(params, dst);
7977
0
                } else {
7978
0
                    ggml_compute_forward_pad_f32<false>(params, dst);
7979
0
                }
7980
0
            } break;
7981
0
        default:
7982
0
            {
7983
0
                GGML_ABORT("fatal error");
7984
0
            }
7985
0
    }
7986
0
}
7987
7988
// ggml_compute_forward_pad_reflect_1d
7989
7990
void ggml_compute_forward_pad_reflect_1d(
7991
        const ggml_compute_params * params,
7992
0
              ggml_tensor * dst) {
7993
7994
0
    const ggml_tensor * src0 = dst->src[0];
7995
7996
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
7997
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
7998
7999
0
    const int ith = params->ith;
8000
0
    const int nth = params->nth;
8001
8002
0
    const int32_t * opts = (const int32_t *) dst->op_params;
8003
0
    const int p0 = opts[0];
8004
0
    const int p1 = opts[1];
8005
8006
0
    GGML_TENSOR_UNARY_OP_LOCALS
8007
8008
0
    for (int64_t i3 = 0; i3 < ne3; i3++) {
8009
0
        for (int64_t i2 = 0; i2 < ne2; i2++) {
8010
0
            for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
8011
0
                float * left  = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 +         p0*nb0);
8012
0
                float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
8013
8014
0
                ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
8015
8016
0
                for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0];   }
8017
0
                for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
8018
0
            }
8019
0
        }
8020
0
    }
8021
0
}
8022
8023
// ggml_compute_forward_roll
8024
8025
0
// Brings an index that lies at most one period outside [0, ne) back into
// range by adding or subtracting ne once; in-range indices pass through.
static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
    if (i >= 0 && i < ne) {
        return i;
    }
    return i < 0 ? i + ne : i - ne;
}
8033
8034
static void ggml_compute_forward_roll_f32(
8035
        const ggml_compute_params * params,
8036
0
        ggml_tensor * dst) {
8037
8038
0
    const ggml_tensor * src0 = dst->src[0];
8039
0
    const float * src_data = (const float *) src0->data;
8040
0
    float * dst_data = (float *) dst->data;
8041
8042
0
    GGML_TENSOR_UNARY_OP_LOCALS
8043
8044
0
    const int s0 = ggml_get_op_params_i32(dst, 0);
8045
0
    const int s1 = ggml_get_op_params_i32(dst, 1);
8046
0
    const int s2 = ggml_get_op_params_i32(dst, 2);
8047
0
    const int s3 = ggml_get_op_params_i32(dst, 3);
8048
8049
0
    const int64_t total = ne1 * ne2 * ne3;
8050
0
    const int64_t per_thread = (total + params->nth) / params->nth;
8051
0
    const int64_t start = params->ith * per_thread;
8052
0
    const int64_t end   = std::min(start + per_thread, total);
8053
8054
0
    for (int64_t i = start; i < end; ++i) {
8055
0
        const int64_t i1 = i % ne1;
8056
0
        const int64_t i2 = (i / ne1) % ne2;
8057
0
        const int64_t i3 = i / (ne2 * ne1);
8058
0
        float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
8059
8060
0
        const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
8061
0
        const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
8062
0
        const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
8063
0
        const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
8064
8065
0
        const int64_t s = ggml_wrap_index(-s0, ne00);
8066
0
        const int64_t n = ne00 - s;
8067
0
        ggml_vec_cpy_f32(n, dst_row,     src_row + s);
8068
0
        ggml_vec_cpy_f32(s, dst_row + n, src_row);
8069
0
    }
8070
0
}
8071
8072
void ggml_compute_forward_roll(
8073
        const ggml_compute_params * params,
8074
0
        ggml_tensor * dst) {
8075
8076
0
    const ggml_tensor * src0 = dst->src[0];
8077
8078
0
    switch (src0->type) {
8079
0
        case GGML_TYPE_F32:
8080
0
            {
8081
0
                ggml_compute_forward_roll_f32(params, dst);
8082
0
            } break;
8083
0
        default:
8084
0
            {
8085
0
                GGML_ABORT("fatal error");
8086
0
            }
8087
0
    }
8088
0
}
8089
8090
// ggml_compute_forward_arange
8091
8092
static void ggml_compute_forward_arange_f32(
8093
    const ggml_compute_params * params,
8094
0
    ggml_tensor * dst) {
8095
8096
0
    GGML_ASSERT(dst->nb[0] == sizeof(float));
8097
8098
0
    const int ith = params->ith;
8099
0
    const int nth = params->nth;
8100
8101
0
    const float start = ggml_get_op_params_f32(dst, 0);
8102
0
    const float stop  = ggml_get_op_params_f32(dst, 1);
8103
0
    const float step  = ggml_get_op_params_f32(dst, 2);
8104
8105
0
    const int64_t steps = (int64_t) ceilf((stop - start) / step);
8106
8107
0
    GGML_ASSERT(ggml_nelements(dst) == steps);
8108
8109
0
    for (int64_t i = ith; i < steps; i+= nth) {
8110
0
        float value = start + step * i;
8111
0
        ((float *)dst->data)[i] = value;
8112
0
    }
8113
0
}
8114
8115
void ggml_compute_forward_arange(
8116
    const ggml_compute_params * params,
8117
0
    ggml_tensor * dst) {
8118
0
    switch (dst->type) {
8119
0
        case GGML_TYPE_F32:
8120
0
            {
8121
0
                ggml_compute_forward_arange_f32(params, dst);
8122
0
            } break;
8123
0
        default:
8124
0
            {
8125
0
                GGML_ABORT("fatal error");
8126
0
            }
8127
0
    }
8128
0
}
8129
8130
static void ggml_compute_forward_timestep_embedding_f32(
8131
    const ggml_compute_params * params,
8132
0
    ggml_tensor * dst) {
8133
8134
0
    const ggml_tensor * src0 = dst->src[0];
8135
8136
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
8137
8138
0
    const int ith = params->ith;
8139
0
    const int nth = params->nth;
8140
8141
0
    GGML_TENSOR_UNARY_OP_LOCALS
8142
8143
0
    const int dim = ggml_get_op_params_i32(dst, 0);
8144
0
    const int max_period = ggml_get_op_params_i32(dst, 1);
8145
8146
0
    int half = dim / 2;
8147
8148
0
    for (int64_t i = 0; i < ne00; i++) {
8149
0
        float * embed_data = (float *)((char *)  dst->data +  i*nb1);
8150
0
        for (int64_t j = ith; j < half; j += nth) {
8151
0
            float timestep = ((float *)src0->data)[i];
8152
0
            float freq = (float)expf(-logf(max_period) * j / half);
8153
0
            float arg = timestep * freq;
8154
0
            embed_data[j] = cosf(arg);
8155
0
            embed_data[j + half] = sinf(arg);
8156
0
        }
8157
0
        if (dim % 2 != 0 && ith == 0) {
8158
0
            embed_data[2 * half] = 0.f;
8159
0
        }
8160
0
    }
8161
0
}
8162
8163
void ggml_compute_forward_timestep_embedding(
8164
    const ggml_compute_params * params,
8165
0
    ggml_tensor * dst) {
8166
8167
0
    const ggml_tensor * src0 = dst->src[0];
8168
8169
0
    switch (src0->type) {
8170
0
        case GGML_TYPE_F32:
8171
0
            {
8172
0
                ggml_compute_forward_timestep_embedding_f32(params, dst);
8173
0
            } break;
8174
0
        default:
8175
0
            {
8176
0
                GGML_ABORT("fatal error");
8177
0
            }
8178
0
    }
8179
0
}
8180
8181
// ggml_compute_forward_argsort
8182
8183
// Index comparator for argsort: orders two indices by the float values they
// refer to in `data`. The direction is fixed at compile time by `order`:
// ascending for GGML_SORT_ORDER_ASC, descending otherwise.
template<enum ggml_sort_order order>
struct cmp_argsort {
    const float * data;
    bool operator()(int32_t a, int32_t b) const {
        const float lhs = data[a];
        const float rhs = data[b];
        return order == GGML_SORT_ORDER_ASC ? lhs < rhs : lhs > rhs;
    }
};
8194
8195
static void ggml_compute_forward_argsort_f32(
8196
    const ggml_compute_params * params,
8197
0
    ggml_tensor * dst) {
8198
8199
0
    const ggml_tensor * src0 = dst->src[0];
8200
8201
0
    GGML_TENSOR_UNARY_OP_LOCALS
8202
8203
0
    GGML_ASSERT(nb0 == sizeof(float));
8204
8205
0
    const int ith = params->ith;
8206
0
    const int nth = params->nth;
8207
8208
0
    const int64_t nr = ggml_nrows(src0);
8209
8210
0
    ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
8211
8212
0
    for (int64_t i = ith; i < nr; i += nth) {
8213
0
        const float * src_data = (float *)((char *) src0->data + i*nb01);
8214
8215
0
        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
8216
8217
0
        for (int64_t j = 0; j < ne0; j++) {
8218
0
            dst_data[j] = j;
8219
0
        }
8220
8221
0
        switch (order) {
8222
0
            case GGML_SORT_ORDER_ASC:
8223
0
                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
8224
0
                break;
8225
8226
0
            case GGML_SORT_ORDER_DESC:
8227
0
                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
8228
0
                break;
8229
8230
0
            default:
8231
0
                GGML_ABORT("invalid sort order");
8232
0
        }
8233
0
    }
8234
0
}
8235
8236
void ggml_compute_forward_argsort(
8237
    const ggml_compute_params * params,
8238
0
    ggml_tensor * dst) {
8239
8240
0
    const ggml_tensor * src0 = dst->src[0];
8241
8242
0
    switch (src0->type) {
8243
0
        case GGML_TYPE_F32:
8244
0
            {
8245
0
                ggml_compute_forward_argsort_f32(params, dst);
8246
0
            } break;
8247
0
        default:
8248
0
            {
8249
0
                GGML_ABORT("fatal error");
8250
0
            }
8251
0
    }
8252
0
}
8253
8254
// ggml_compute_forward_top_k
8255
8256
// Index comparator for top-k selection: index a comes before index b when the
// value at a is strictly greater than the value at b (descending by value).
struct cmp_top_k {
    const float * data;
    bool operator()(int32_t a, int32_t b) const {
        const float va = data[a];
        const float vb = data[b];
        return vb < va;
    }
};
8262
8263
static void ggml_compute_forward_top_k_f32(
8264
    const ggml_compute_params * params,
8265
0
    ggml_tensor * dst) {
8266
8267
0
    const ggml_tensor * src0 = dst->src[0];
8268
8269
0
    GGML_TENSOR_UNARY_OP_LOCALS
8270
8271
0
    GGML_ASSERT(nb0 == sizeof(float));
8272
8273
0
    const int ith = params->ith;
8274
0
    const int nth = params->nth;
8275
8276
0
    const int64_t nr = ggml_nrows(src0);
8277
8278
0
    const int top_k = ne0;
8279
8280
0
    int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
8281
8282
0
    for (int64_t i = ith; i < nr; i += nth) {
8283
0
        const float * src_data = (float *)((char *) src0->data + i*nb01);
8284
8285
0
        for (int64_t j = 0; j < ne00; j++) {
8286
0
            tmp[j] = j;
8287
0
        }
8288
8289
0
        std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
8290
8291
0
        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
8292
8293
0
        std::copy(tmp, tmp + top_k, dst_data);
8294
8295
        // emphasize that the order is not important
8296
0
        if (top_k > 1) {
8297
0
            std::swap(dst_data[0], dst_data[1]);
8298
0
        }
8299
0
    }
8300
0
}
8301
8302
void ggml_compute_forward_top_k(
8303
    const ggml_compute_params * params,
8304
0
    ggml_tensor * dst) {
8305
8306
0
    const ggml_tensor * src0 = dst->src[0];
8307
8308
0
    switch (src0->type) {
8309
0
        case GGML_TYPE_F32:
8310
0
            {
8311
0
                ggml_compute_forward_top_k_f32(params, dst);
8312
0
            } break;
8313
0
        default:
8314
0
            {
8315
0
                GGML_ABORT("fatal error");
8316
0
            }
8317
0
    }
8318
0
}
8319
8320
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8321
        const ggml_compute_params * params,
8322
        ggml_tensor * dst,
8323
        int ir0, int ir1,
8324
        int64_t ic_start, int64_t ic_end,
8325
0
        float * partials, int64_t partial_stride) {
8326
8327
0
    const bool write_partials = (partials != nullptr);
8328
0
    const ggml_tensor * q     = dst->src[0];
8329
0
    const ggml_tensor * k     = dst->src[1];
8330
0
    const ggml_tensor * v     = dst->src[2];
8331
0
    const ggml_tensor * mask  = dst->src[3];
8332
0
    const ggml_tensor * sinks = dst->src[4];
8333
8334
0
    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
8335
0
    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
8336
0
    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
8337
0
    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
8338
0
    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
8339
0
    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
8340
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
8341
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
8342
8343
0
    const int64_t DK = nek0;
8344
0
    const int64_t DV = nev0;
8345
0
    const int64_t N  = neq1;
8346
8347
0
    GGML_ASSERT(ne0 == DV);
8348
0
    GGML_ASSERT(ne2 == N);
8349
8350
    // input tensor rows must be contiguous
8351
0
    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8352
0
    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8353
0
    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8354
8355
0
    GGML_ASSERT(neq0 == DK);
8356
0
    GGML_ASSERT(nek0 == DK);
8357
0
    GGML_ASSERT(nev0 == DV);
8358
8359
0
    GGML_ASSERT(neq1 == N);
8360
8361
    // dst cannot be transposed or permuted
8362
0
    GGML_ASSERT(nb0 == sizeof(float));
8363
0
    GGML_ASSERT(nb0 <= nb1);
8364
0
    GGML_ASSERT(nb1 <= nb2);
8365
0
    GGML_ASSERT(nb2 <= nb3);
8366
8367
    // broadcast factors
8368
0
    const int64_t rk2 = neq2/nek2;
8369
0
    const int64_t rk3 = neq3/nek3;
8370
8371
0
    const int64_t rv2 = neq2/nev2;
8372
0
    const int64_t rv3 = neq3/nev3;
8373
8374
    // parallelize by q rows using ggml_vec_dot_f32
8375
8376
0
    float scale         = 1.0f;
8377
0
    float max_bias      = 0.0f;
8378
0
    float logit_softcap = 0.0f;
8379
8380
0
    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
8381
0
    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
8382
0
    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
8383
8384
0
    if (logit_softcap != 0) {
8385
0
        scale /= logit_softcap;
8386
0
    }
8387
8388
0
    const uint32_t n_head      = neq2;
8389
0
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
8390
8391
0
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
8392
0
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
8393
8394
0
    ggml_type         const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
8395
0
    ggml_from_float_t const q_to_vec_dot   = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
8396
0
    ggml_vec_dot_t    const kq_vec_dot     = ggml_get_type_traits_cpu(k->type)->vec_dot;
8397
0
    ggml_to_float_t   const v_to_float     = ggml_get_type_traits(v->type)->to_float;
8398
8399
0
    GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
8400
0
    GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
8401
8402
0
    int ith = params->ith;
8403
8404
0
    for (int ir = ir0; ir < ir1; ++ir) {
8405
        // q indices
8406
0
        const int iq3 = ir/(neq2*neq1);
8407
0
        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
8408
0
        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
8409
8410
0
        const uint32_t h = iq2; // head index
8411
0
        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
8412
8413
0
        float S = 0.0f;      // sum
8414
0
        float M = -INFINITY; // maximum KQ value
8415
8416
0
        float       * VKQ32 = (float       *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
8417
0
        float       * V32   =                 (VKQ32 + 1*DV); // (temporary) FP32 V buffer
8418
0
        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
8419
0
        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
8420
8421
0
        if (v->type == GGML_TYPE_F16) {
8422
0
            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
8423
0
        } else {
8424
0
            memset(VKQ32, 0, DV*sizeof(float));
8425
0
        }
8426
8427
0
        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
8428
8429
        // k indices
8430
0
        const int ik3 = iq3 / rk3;
8431
0
        const int ik2 = iq2 / rk2;
8432
8433
        // v indices
8434
0
        const int iv3 = iq3 / rv3;
8435
0
        const int iv2 = iq2 / rv2;
8436
8437
0
        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
8438
0
        q_to_vec_dot(pq, Q_q, DK);
8439
8440
        // online softmax / attention
8441
        // loop over n_kv and n_head_kv
8442
        // ref: https://arxiv.org/pdf/2112.05682.pdf
8443
8444
0
        for (int64_t ic = ic_start; ic < ic_end; ++ic) {
8445
0
            const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
8446
0
            if (mv == -INFINITY) {
8447
0
                continue;
8448
0
            }
8449
8450
0
            float s; // KQ value
8451
8452
0
            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
8453
0
            kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
8454
8455
0
            s = s*scale; // scale KQ value
8456
8457
0
            if (logit_softcap != 0.0f) {
8458
0
                s = logit_softcap*tanhf(s);
8459
0
            }
8460
8461
0
            s += mv; // apply mask
8462
8463
0
            const float Mold = M;
8464
8465
0
            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
8466
0
            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
8467
8468
0
            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
8469
8470
0
            if (v->type == GGML_TYPE_F16) {
8471
0
                if (s > M) {
8472
                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
8473
0
                    M = s;
8474
0
                    ms = expf(Mold - M);
8475
8476
                    // V = V*expf(Mold - M)
8477
0
                    ggml_vec_scale_f16(DV, VKQ16, ms);
8478
0
                } else {
8479
                    // no new maximum, ms == 1.0f, vs != 1.0f
8480
0
                    vs = expf(s - M);
8481
0
                }
8482
8483
                // V += v*expf(s - M)
8484
0
                ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
8485
0
            } else {
8486
0
                if (s > M) {
8487
                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
8488
0
                    M = s;
8489
0
                    ms = expf(Mold - M);
8490
8491
                    // V = V*expf(Mold - M)
8492
0
                    ggml_vec_scale_f32(DV, VKQ32, ms);
8493
0
                } else {
8494
                    // no new maximum, ms == 1.0f, vs != 1.0f
8495
0
                    vs = expf(s - M);
8496
0
                }
8497
8498
                // V += v*expf(s - M)
8499
0
                if (v_to_float) {
8500
0
                    v_to_float(v_data, V32, DV);
8501
0
                    ggml_vec_mad_f32(DV, VKQ32, V32, vs);
8502
0
                } else {
8503
                    // V is F32
8504
0
                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
8505
0
                }
8506
0
            }
8507
8508
0
            S = S*ms + vs; // scale and increment sum with partial sum
8509
0
        }
8510
8511
0
        if (v->type == GGML_TYPE_F16) {
8512
0
            for (int64_t d = 0; d < DV; ++d) {
8513
0
                VKQ32[d] = GGML_CPU_FP16_TO_FP32(VKQ16[d]);
8514
0
            }
8515
0
        }
8516
8517
        // sinks - apply only on the first kv-chunk
8518
0
        if (sinks && ic_start == 0) {
8519
0
            const float s = ((float *)((char *) sinks->data))[h];
8520
8521
0
            float ms = 1.0f;
8522
0
            float vs = 1.0f;
8523
8524
0
            if (s > M) {
8525
0
                ms = expf(M - s);
8526
0
                M = s;
8527
0
                ggml_vec_scale_f32(DV, VKQ32, ms);
8528
0
            } else {
8529
0
                vs = expf(s - M);
8530
0
            }
8531
8532
0
            S = S*ms + vs;
8533
0
        }
8534
8535
0
        if (write_partials) {
8536
            // Write M, S, VKQ to partials for later reduction
8537
            // partials layout: [M, S, VKQ[DV]] per query head
8538
0
            float * partial = partials + ir * partial_stride;
8539
0
            partial[0] = M;
8540
0
            partial[1] = S;
8541
0
            memcpy(partial + 2, VKQ32, DV * sizeof(float));
8542
0
        } else {
8543
            // V /= S
8544
0
            const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
8545
0
            ggml_vec_scale_f32(DV, VKQ32, S_inv);
8546
8547
            // dst indices
8548
0
            const int i1 = iq1;
8549
0
            const int i2 = iq2;
8550
0
            const int i3 = iq3;
8551
8552
            // permute(0, 2, 1, 3)
8553
0
            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
8554
0
        }
8555
0
    }
8556
0
}
8557
8558
static void ggml_compute_forward_flash_attn_ext_tiled(
8559
        const ggml_compute_params * params,
8560
        ggml_tensor * dst,
8561
0
        int ir0, int ir1) {
8562
0
    const ggml_tensor * q     = dst->src[0];
8563
0
    const ggml_tensor * k     = dst->src[1];
8564
0
    const ggml_tensor * v     = dst->src[2];
8565
0
    const ggml_tensor * mask  = dst->src[3];
8566
0
    const ggml_tensor * sinks = dst->src[4];
8567
8568
0
    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
8569
0
    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
8570
0
    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
8571
0
    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
8572
0
    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
8573
0
    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
8574
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
8575
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
8576
8577
0
    const int64_t DK = nek0;
8578
0
    const int64_t DV = nev0;
8579
0
    const int64_t N  = neq1;
8580
8581
0
    GGML_ASSERT(ne0 == DV);
8582
0
    GGML_ASSERT(ne2 == N);
8583
8584
    // input tensor rows must be contiguous
8585
0
    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8586
0
    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8587
0
    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8588
8589
0
    GGML_ASSERT(neq0 == DK);
8590
0
    GGML_ASSERT(nek0 == DK);
8591
0
    GGML_ASSERT(nev0 == DV);
8592
8593
0
    GGML_ASSERT(neq1 == N);
8594
8595
    // dst cannot be transposed or permuted
8596
0
    GGML_ASSERT(nb0 == sizeof(float));
8597
0
    GGML_ASSERT(nb0 <= nb1);
8598
0
    GGML_ASSERT(nb1 <= nb2);
8599
0
    GGML_ASSERT(nb2 <= nb3);
8600
8601
0
    GGML_ASSERT(k->type == v->type);
8602
0
    const ggml_type kv_type = k->type;
8603
8604
8605
    // broadcast factors
8606
0
    const int64_t rk2 = neq2/nek2;
8607
0
    const int64_t rk3 = neq3/nek3;
8608
8609
0
    const int64_t rv2 = neq2/nev2;
8610
0
    const int64_t rv3 = neq3/nev3;
8611
8612
0
    float scale         = 1.0f;
8613
0
    float max_bias      = 0.0f;
8614
0
    float logit_softcap = 0.0f;
8615
8616
0
    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
8617
0
    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
8618
0
    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
8619
8620
0
    if (logit_softcap != 0) {
8621
0
        scale /= logit_softcap;
8622
0
    }
8623
8624
0
    const uint32_t n_head      = neq2;
8625
0
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
8626
8627
0
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
8628
0
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
8629
8630
0
    int ith = params->ith;
8631
8632
0
    static constexpr int Q_TILE_SZ  = ggml_fa_tile_config::Q;
8633
0
    static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
8634
8635
0
    int ir = ir0;
8636
0
    while (ir < ir1) {
8637
        // q indices for the start of this tile
8638
0
        const int iq3 = ir/(neq2*neq1);
8639
0
        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
8640
0
        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
8641
8642
        // Number of valid rows in this tile:
8643
        // - limited by tile size (Q_TILE_SZ)
8644
        // - limited by chunk boundary (ir1 - ir)
8645
        // - limited by head boundary (neq1 - iq1) to avoid crossing into next head
8646
0
        const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
8647
0
        GGML_ASSERT(tile_rows > 0);
8648
8649
0
        const uint32_t h = iq2; // head index
8650
0
        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
8651
8652
0
        float S[Q_TILE_SZ];
8653
0
        float M[Q_TILE_SZ];
8654
8655
0
        for (int i = 0 ; i < Q_TILE_SZ; ++i) {
8656
0
            S[i] = 0.;
8657
0
            M[i] = -INFINITY;
8658
0
        }
8659
8660
        // Per-thread scratch layout:
8661
        // Q_q:    Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
8662
        // KQ:     Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
8663
        // mask:   Q_TILE_SZ * KV_TILE_SZ (mask in float)
8664
        // VKQ32:  Q_TILE_SZ * DV (FP32 output accumulator)
8665
        // V32:    KV_TILE_SZ * DV (F32 buffer for V tile)
8666
        // K_f32:  KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
8667
0
        float * base  = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
8668
8669
0
        void  * Q_q    = base;
8670
0
        float * KQ     = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
8671
0
        float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
8672
0
        float * VKQ32  = mask32 + Q_TILE_SZ * KV_TILE_SZ;
8673
0
        float * V32    = VKQ32 + Q_TILE_SZ * DV;
8674
0
        float * K_f32  = V32 + KV_TILE_SZ * DV;
8675
8676
0
        memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
8677
0
        memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
8678
8679
        // k indices
8680
0
        const int ik3 = iq3 / rk3;
8681
0
        const int ik2 = iq2 / rk2;
8682
8683
        // v indices
8684
0
        const int iv3 = iq3 / rv3;
8685
0
        const int iv2 = iq2 / rv2;
8686
8687
0
        {
8688
0
            float * Q_f32 = (float *)Q_q;
8689
0
            for (int tq = 0; tq < tile_rows; tq++) {
8690
0
                const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
8691
0
                memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
8692
0
            }
8693
0
            for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
8694
0
                memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
8695
0
            }
8696
0
        }
8697
8698
0
        memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
8699
0
        memset(V32,   0, KV_TILE_SZ * DV * sizeof(float));
8700
8701
0
        for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
8702
0
            const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
8703
8704
            // skip the tile entirely if all the masks are -inf
8705
0
            if (mask) {
8706
0
                bool can_skip = true;
8707
0
                for (int tq = 0; tq < tile_rows; tq++) {
8708
0
                    const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
8709
0
                    for (int tk = 0; tk < kv_tile; tk++) {
8710
0
                        mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
8711
0
                        if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
8712
0
                            can_skip = false;
8713
0
                        }
8714
0
                    }
8715
                    // Pad remaining mask entries with -inf
8716
0
                    for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
8717
0
                        mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
8718
0
                    }
8719
0
                }
8720
8721
0
                if (can_skip) {
8722
0
                    continue;
8723
0
                }
8724
0
            }
8725
8726
            // Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
8727
            // Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
8728
0
            for (int tk = 0; tk < kv_tile; tk++) {
8729
0
                const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
8730
0
                if (kv_type == GGML_TYPE_F16) {
8731
0
                    const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
8732
0
                    for (int64_t dk = 0; dk < DK; dk++) {
8733
0
                        K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
8734
0
                    }
8735
0
                } else {
8736
0
                    const float * k_f32_src = (const float *)k_data;
8737
0
                    for (int64_t dk = 0; dk < DK; dk++) {
8738
0
                        K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
8739
0
                    }
8740
0
                }
8741
0
            }
8742
0
            memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
8743
0
            simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
8744
0
            ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
8745
8746
            // Set padded KQ entries to -inf so softmax gives them zero weight
8747
0
            if (kv_tile < KV_TILE_SZ) {
8748
0
                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8749
0
                    for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
8750
0
                        KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
8751
0
                    }
8752
0
                }
8753
0
            }
8754
8755
0
            if (logit_softcap != 0.0f) {
8756
0
                ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
8757
0
                ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
8758
0
            }
8759
8760
0
            if (mask) {
8761
0
                ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
8762
0
            }
8763
8764
0
            bool skip[Q_TILE_SZ] = {};
8765
8766
0
            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8767
0
                float * kq_row = KQ + tq * KV_TILE_SZ;
8768
8769
0
                float tile_max;
8770
0
                ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
8771
8772
0
                if (tile_max == -INFINITY) {
8773
0
                    skip[tq] = true;
8774
0
                    continue;
8775
0
                }
8776
8777
0
                const float Mold = M[tq];
8778
0
                const float Mnew = fmaxf(Mold, tile_max);
8779
8780
0
                if (Mnew > Mold) {
8781
0
                    const float ms = expf(Mold - Mnew);
8782
0
                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
8783
0
                    S[tq] *= ms;
8784
0
                }
8785
0
                M[tq] = Mnew;
8786
8787
8788
0
                S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
8789
0
            }
8790
8791
            // V accumulation: VKQ32 += softmax(KQ) * V
8792
            // Pack V tile to contiguous F32, zero-padded
8793
0
            for (int tk = 0; tk < kv_tile; tk++) {
8794
0
                const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
8795
0
                if (kv_type == GGML_TYPE_F16) {
8796
0
                    ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
8797
0
                } else {
8798
0
                    memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
8799
0
                }
8800
0
            }
8801
0
            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8802
0
                if (skip[tq]) {
8803
0
                    memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
8804
0
                }
8805
0
            }
8806
0
            simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
8807
0
        }
8808
8809
        // sinks (apply only to valid rows in the tile)
8810
0
        if (sinks) {
8811
0
            const float s = ((float *)((char *) sinks->data))[h];
8812
8813
0
            for (int tq = 0; tq < tile_rows; tq++) {
8814
0
                float ms = 1.0f;
8815
0
                float vs = 1.0f;
8816
8817
0
                if (s > M[tq]) {
8818
0
                    ms = expf(M[tq] - s);
8819
0
                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
8820
0
                } else {
8821
0
                    vs = expf(s - M[tq]);
8822
0
                }
8823
8824
0
                S[tq] = S[tq] * ms + vs;
8825
0
            }
8826
0
        }
8827
8828
0
        for (int tq = 0; tq < tile_rows; tq++) {
8829
            // V /= S
8830
0
            const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
8831
0
            ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
8832
8833
            // dst indices
8834
0
            const int i1 = iq1 + tq;
8835
0
            const int i2 = iq2;
8836
0
            const int i3 = iq3;
8837
8838
            // permute(0, 2, 1, 3)
8839
0
            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
8840
0
        }
8841
8842
0
        ir += tile_rows;
8843
0
    }
8844
0
}
8845
8846
// Reduction function: combines partial results across KV chunks
8847
// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
8848
static void ggml_flash_attn_ext_reduce_partials(
8849
        const ggml_compute_params * params,
8850
        ggml_tensor * dst,
8851
        const int64_t n_chunks,
8852
0
        const int64_t chunk_size) {
8853
8854
0
    const ggml_tensor * q = dst->src[0];
8855
0
    const ggml_tensor * k = dst->src[1];
8856
0
    const ggml_tensor * v = dst->src[2];
8857
8858
0
    const int64_t DK        = k->ne[0];
8859
0
    const int64_t DV        = v->ne[0];
8860
0
    const int64_t nek1      = k->ne[1];
8861
0
    const int64_t n_q_heads = q->ne[2];
8862
8863
0
    const int ith = params->ith;
8864
0
    const int nth = params->nth;
8865
8866
0
    const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
8867
0
    float *       thread_wdata     = (float *) params->wdata + ith * wdata_per_thread;
8868
8869
0
    const int64_t partials_offset  = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
8870
0
    const int64_t partial_size     = 2 + DV;
8871
0
    const float * partials_base    = (const float *) params->wdata + partials_offset;
8872
8873
    // Output layout
8874
0
    const int64_t ne1 = dst->ne[1];
8875
0
    const int64_t ne2 = dst->ne[2];
8876
0
    const size_t  nb1 = dst->nb[1];
8877
8878
    // Each thread reduces a subset of query heads
8879
0
    for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
8880
0
        float   M_final   = -INFINITY;
8881
0
        float   S_final   = 0.0f;
8882
0
        float * VKQ_final = thread_wdata;
8883
0
        memset(VKQ_final, 0, DV * sizeof(float));
8884
8885
        // Combine partials from all chunks
8886
0
        for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
8887
0
            const int64_t ic_start = chunk_idx * chunk_size;
8888
0
            if (ic_start >= nek1) continue;
8889
8890
0
            const float * partial   = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
8891
0
            const float   M_chunk   = partial[0];
8892
0
            const float   S_chunk   = partial[1];
8893
0
            const float * VKQ_chunk = partial + 2;
8894
8895
0
            if (S_chunk == 0.0f) continue;
8896
8897
0
            const float M_new     = fmaxf(M_final, M_chunk);
8898
0
            const float scale_old = expf(M_final - M_new);
8899
0
            const float scale_new = expf(M_chunk - M_new);
8900
8901
0
            for (int64_t d = 0; d < DV; ++d) {
8902
0
                VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
8903
0
            }
8904
0
            S_final = S_final * scale_old + S_chunk * scale_new;
8905
0
            M_final = M_new;
8906
0
        }
8907
8908
        // Normalize and write to output
8909
0
        if (S_final != 0.0f) {
8910
0
            const float S_inv = 1.0f / S_final;
8911
0
            ggml_vec_scale_f32(DV, VKQ_final, S_inv);
8912
0
        }
8913
        // iq1=0, iq3=0 for decode
8914
0
        memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
8915
0
    }
8916
0
}
8917
8918
static void ggml_compute_forward_flash_attn_ext_f16(
8919
        const ggml_compute_params * params,
8920
0
        ggml_tensor * dst) {
8921
8922
0
    const ggml_tensor * q     = dst->src[0];
8923
0
    const ggml_tensor * k     = dst->src[1];
8924
0
    const ggml_tensor * v     = dst->src[2];
8925
8926
0
    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
8927
0
    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
8928
0
    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
8929
0
    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
8930
0
    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
8931
0
    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
8932
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
8933
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
8934
8935
0
    const int64_t DK = nek0;
8936
0
    const int64_t DV = nev0;
8937
0
    const int64_t N  = neq1;
8938
8939
8940
0
    GGML_ASSERT(ne0 == DV);
8941
0
    GGML_ASSERT(ne2 == N);
8942
8943
    // input tensor rows must be contiguous
8944
0
    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8945
0
    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8946
0
    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8947
8948
0
    GGML_ASSERT(neq0 == DK);
8949
0
    GGML_ASSERT(nek0 == DK);
8950
0
    GGML_ASSERT(nev0 == DV);
8951
8952
0
    GGML_ASSERT(neq1 == N);
8953
8954
    // dst cannot be transposed or permuted
8955
0
    GGML_ASSERT(nb0 == sizeof(float));
8956
0
    GGML_ASSERT(nb0 <= nb1);
8957
0
    GGML_ASSERT(nb1 <= nb2);
8958
0
    GGML_ASSERT(nb2 <= nb3);
8959
8960
0
    const int ith = params->ith;
8961
0
    const int nth = params->nth;
8962
8963
    // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
8964
0
    const bool use_ref = params->use_ref;
8965
8966
0
    const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
8967
0
    const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
8968
8969
0
    if (use_split_kv_path) {
8970
0
        const int64_t chunk_size = (nek1 + nth - 1) / nth;
8971
8972
        // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
8973
0
        const int64_t partial_size  = 2 + DV;
8974
0
        float *       partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
8975
8976
0
        const int64_t ic_start = ith * chunk_size;
8977
0
        const int64_t ic_end   = std::min(ic_start + chunk_size, nek1);
8978
8979
0
        const int64_t partial_stride = nth * partial_size;
8980
0
        float *       chunk_partials = partials_base + ith * partial_size;
8981
8982
0
        if (ic_start < nek1) {
8983
0
            for (int64_t q_head = 0; q_head < neq2; q_head++) {
8984
0
                ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8985
0
                    params, dst, q_head, q_head + 1, ic_start, ic_end,
8986
0
                    chunk_partials, partial_stride);
8987
0
            }
8988
0
        } else {
8989
0
            for (int64_t q_head = 0; q_head < neq2; q_head++) {
8990
0
                float * q_partials = chunk_partials + q_head * partial_stride;
8991
0
                q_partials[0] = -INFINITY;  // M
8992
0
                q_partials[1] = 0.0f;       // S
8993
0
            }
8994
0
        }
8995
8996
0
        ggml_barrier(params->threadpool);
8997
0
        ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
8998
0
    } else {
8999
9000
        // total rows in q
9001
0
        const int64_t nr = neq1*neq2*neq3;
9002
9003
        // disable for NUMA
9004
0
        const bool disable_chunking = ggml_is_numa();
9005
9006
        // 4x chunks per thread
9007
0
        int nth_scaled = nth * 4;
9008
0
        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
9009
0
        int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
9010
9011
0
        if (nth == 1 || nchunk < nth || disable_chunking) {
9012
0
            nchunk = nth;
9013
0
        }
9014
9015
0
        if (ith == 0) {
9016
0
            ggml_threadpool_chunk_set(params->threadpool, nth);
9017
0
        }
9018
9019
0
        ggml_barrier(params->threadpool);
9020
9021
0
        const int64_t dr = (nr + nchunk - 1) / nchunk;
9022
9023
0
        static constexpr int64_t Q_TILE_SZ  = ggml_fa_tile_config::Q;
9024
0
        bool use_tiled = !use_ref &&
9025
0
                               (q->type == GGML_TYPE_F32 &&
9026
0
                                kv_is_f32_or_f16 &&
9027
0
                                k->type == v->type &&
9028
0
                                neq1 >= Q_TILE_SZ);
9029
0
#ifdef GGML_SIMD
9030
#if defined(__ARM_FEATURE_SVE)
9031
        const int64_t f32_epr = svcntw();
9032
#else
9033
0
        const int64_t f32_epr = GGML_F32_EPR;
9034
0
#endif
9035
0
        use_tiled &= (DV % f32_epr == 0);
9036
0
#endif
9037
0
        int current_chunk = ith;
9038
9039
0
        while (current_chunk < nchunk) {
9040
0
            const int64_t ir0 = dr * current_chunk;
9041
0
            const int64_t ir1 = MIN(ir0 + dr, nr);
9042
9043
0
            if (use_tiled) {
9044
0
                ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
9045
0
            } else {
9046
0
                ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
9047
0
            }
9048
9049
0
            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
9050
0
        }
9051
0
    }
9052
0
}
9053
9054
void ggml_compute_forward_flash_attn_ext(
9055
        const ggml_compute_params * params,
9056
0
        ggml_tensor * dst) {
9057
0
    switch (dst->op_params[3]) {
9058
0
        case GGML_PREC_DEFAULT:
9059
0
        case GGML_PREC_F32:
9060
0
            {
9061
                // uses F32 accumulators
9062
0
                ggml_compute_forward_flash_attn_ext_f16(params, dst);
9063
0
            } break;
9064
0
        default:
9065
0
            {
9066
0
                GGML_ABORT("fatal error");
9067
0
            }
9068
0
    }
9069
0
}
9070
9071
// ggml_compute_forward_flash_attn_back
9072
9073
static void ggml_compute_forward_flash_attn_back_f32(
9074
        const ggml_compute_params * params,
9075
        const bool masked,
9076
0
              ggml_tensor * dst) {
9077
9078
0
    const ggml_tensor * q = dst->src[0];
9079
0
    const ggml_tensor * k = dst->src[1];
9080
0
    const ggml_tensor * v = dst->src[2];
9081
0
    const ggml_tensor * d = dst->src[3];
9082
9083
0
    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
9084
0
    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
9085
0
    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
9086
0
    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
9087
0
    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
9088
0
    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
9089
0
    GGML_TENSOR_LOCALS(int64_t, ned, d,   ne)
9090
0
    GGML_TENSOR_LOCALS(size_t,  nbd, d,   nb)
9091
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
9092
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
9093
9094
0
    const int ith = params->ith;
9095
0
    const int nth = params->nth;
9096
9097
0
    const int64_t D = neq0;
9098
0
    const int64_t N = neq1;
9099
0
    const int64_t P = nek1 - N;
9100
0
    const int64_t M = P + N;
9101
9102
0
    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
9103
0
    const int mxDM = MAX(D, Mup);
9104
9105
    // GGML_ASSERT(ne0 == D);
9106
    // GGML_ASSERT(ne1 == N);
9107
0
    GGML_ASSERT(P >= 0);
9108
9109
0
    GGML_ASSERT(nbq0 == sizeof(float));
9110
0
    GGML_ASSERT(nbk0 == sizeof(float));
9111
0
    GGML_ASSERT(nbv0 == sizeof(float));
9112
9113
0
    GGML_ASSERT(neq0 == D);
9114
0
    GGML_ASSERT(nek0 == D);
9115
0
    GGML_ASSERT(nev1 == D);
9116
0
    GGML_ASSERT(ned0 == D);
9117
9118
0
    GGML_ASSERT(neq1 == N);
9119
0
    GGML_ASSERT(nek1 == N + P);
9120
0
    GGML_ASSERT(nev1 == D);
9121
0
    GGML_ASSERT(ned1 == N);
9122
9123
    // dst cannot be transposed or permuted
9124
0
    GGML_ASSERT(nb0 == sizeof(float));
9125
0
    GGML_ASSERT(nb0 <= nb1);
9126
0
    GGML_ASSERT(nb1 <= nb2);
9127
0
    GGML_ASSERT(nb2 <= nb3);
9128
9129
0
    if (ith == 0) {
9130
0
        memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
9131
0
    }
9132
0
    ggml_barrier(params->threadpool);
9133
9134
0
    const int64_t elem_q = ggml_nelements(q);
9135
0
    const int64_t elem_k = ggml_nelements(k);
9136
9137
0
    ggml_type result_type = dst->type;
9138
0
    GGML_ASSERT(ggml_blck_size(result_type) == 1);
9139
0
    const size_t tsize = ggml_type_size(result_type);
9140
9141
0
    const size_t offs_q = 0;
9142
0
    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
9143
0
    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
9144
9145
0
    void * grad_q = (char *) dst->data;
9146
0
    void * grad_k = (char *) dst->data + offs_k;
9147
0
    void * grad_v = (char *) dst->data + offs_v;
9148
9149
0
    const size_t nbgq1 = nb0*neq0;
9150
0
    const size_t nbgq2 = nb0*neq0*neq1;
9151
0
    const size_t nbgq3 = nb0*neq0*neq1*neq2;
9152
9153
0
    const size_t nbgk1 = nb0*nek0;
9154
0
    const size_t nbgk2 = nb0*nek0*nek1;
9155
0
    const size_t nbgk3 = nb0*nek0*nek1*neq2;
9156
9157
0
    const size_t nbgv1 = nb0*nev0;
9158
0
    const size_t nbgv2 = nb0*nev0*nev1;
9159
0
    const size_t nbgv3 = nb0*nev0*nev1*neq2;
9160
9161
    // parallelize by k rows using ggml_vec_dot_f32
9162
9163
    // total rows in k
9164
0
    const int nr = nek2*nek3;
9165
9166
    // rows per thread
9167
0
    const int dr = (nr + nth - 1)/nth;
9168
9169
    // row range for this thread
9170
0
    const int ir0 = dr*ith;
9171
0
    const int ir1 = MIN(ir0 + dr, nr);
9172
9173
0
    const float scale = 1.0f/sqrtf(D);
9174
9175
    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
9176
9177
    // how often k2 (and v2) is repeated in q2
9178
0
    int nrep = neq2/nek2;
9179
9180
0
    for (int ir = ir0; ir < ir1; ++ir) {
9181
        // q indices
9182
0
        const int ik3 = ir/(nek2);
9183
0
        const int ik2 = ir - ik3*nek2;
9184
9185
0
        const int iq3 = ik3;
9186
0
        const int id3 = ik3;
9187
0
        const int iv3 = ik3;
9188
0
        const int iv2 = ik2;
9189
9190
0
        for (int irep = 0; irep < nrep; ++irep) {
9191
0
            const int iq2 = ik2 + irep*nek2;
9192
0
            const int id2 = iq2;
9193
9194
            // (ik2 + irep*nek2) % nek2 == ik2
9195
0
            for (int iq1 = 0; iq1 < neq1; ++iq1) {
9196
0
                const int id1 = iq1;
9197
9198
                // not sure about CACHE_LINE_SIZE_F32..
9199
                // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
9200
0
                float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
9201
0
                float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
9202
9203
0
                for (int i = M; i < Mup; ++i) {
9204
0
                    S[i] = -INFINITY;
9205
0
                }
9206
9207
0
                const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
9208
0
                for (int64_t ic = 0; ic < masked_begin; ++ic) {
9209
                    // k indices
9210
0
                    const int ik1 = ic;
9211
9212
                    // S indices
9213
0
                    const int i1 = ik1;
9214
9215
0
                    ggml_vec_dot_f32(neq0,
9216
0
                            S + i1, 0,
9217
0
                            (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
9218
0
                            (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
9219
0
                }
9220
9221
                // scale
9222
0
                ggml_vec_scale_f32(masked_begin, S, scale);
9223
9224
0
                for (int64_t i = masked_begin; i < M; i++) {
9225
0
                    S[i] = -INFINITY;
9226
0
                }
9227
9228
                // softmax
9229
                // exclude known -INF S[..] values from max and loop
9230
                // dont forget to set their SM values to zero
9231
0
                {
9232
0
                    float max = -INFINITY;
9233
0
                    ggml_vec_max_f32(masked_begin, &max, S);
9234
9235
0
                    ggml_float sum = 0.0;
9236
0
                    {
9237
#ifdef GGML_SOFT_MAX_ACCELERATE
9238
                        max = -max;
9239
                        vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
9240
                        vvexpf(SM, SM, &Mup);
9241
                        ggml_vec_sum_f32(Mup, &sum, SM);
9242
#else
9243
0
                        sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
9244
0
#endif
9245
0
                    }
9246
9247
0
                    assert(sum > 0.0);
9248
9249
0
                    sum = 1.0/sum;
9250
0
                    ggml_vec_scale_f32(masked_begin, SM, sum);
9251
9252
0
                }
9253
9254
                // step-by-step explanation
9255
0
                {
9256
                    // forward-process                    shape      grads from backward process
9257
                    // parallel_for ik2,ik3:
9258
                    //  for irep:
9259
                    //   iq2 = ik2 + irep*nek2
9260
                    //   k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,ik2,ik3]  += grad[kcur]
9261
                    //   q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
9262
                    //   v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iv2,iv3]  += grad[vcur]
9263
                    //   for iq1:
9264
                    //    kcur   = k[:D,:M,ik2,ik3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
9265
                    //    qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
9266
                    //    vcur   = v[:M,:D,iv2,iv3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
9267
                    //    S0     = -Inf                   [D,1,1,1]
9268
                    //   ~S1[i]  = dot(kcur[:D,i], qcur)
9269
                    //    S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale
9270
                    //    S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
9271
                    //    S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
9272
                    //    S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
9273
                    //   ~S5[i]  = dot(vcur[:,i], S4)
9274
                    //    S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,id1,id2,id3]
9275
                    //   ~dst[i,iq1,iq2,iq3]  = S5[i]              ^
9276
                    //    dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
9277
                    // dst                               backward-/ grad[dst]                 = d
9278
                    //
9279
                    // output gradients with their dependencies:
9280
                    //
9281
                    // grad[kcur] = grad[S1].T @ qcur
9282
                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
9283
                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
9284
                    // grad[S4]   = grad[S5] @ vcur
9285
                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
9286
                    // grad[qcur] = grad[S1]   @ kcur
9287
                    // grad[vcur] = grad[S5].T @ S4
9288
                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
9289
                    //
9290
                    // in post-order:
9291
                    //
9292
                    // S1         = qcur @ kcur.T
9293
                    // S2         = S1 * scale
9294
                    // S3         = diag_mask_inf(S2, P)
9295
                    // S4         = softmax(S3)
9296
                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
9297
                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
9298
                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
9299
                    // grad[qcur] = grad[S1]   @ kcur
9300
                    // grad[kcur] = grad[S1].T @ qcur
9301
                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
9302
                    //
9303
                    // using less variables (SM=S4):
9304
                    //
9305
                    // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
9306
                    // SM            = softmax(S)
9307
                    // S             = d[:D,iq1,iq2,iq3] @ vcur
9308
                    // dot_SM_gradSM = dot(SM, S)
9309
                    // S             = SM * (S - dot(SM, S))
9310
                    // S             = diag_mask_zero(S, P) * scale
9311
                    //
9312
                    // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
9313
                    // grad[k][:D,:M,ik2,ik3]  += S.T @ qcur
9314
                    // grad[v][:M,:D,iv2,iv3]  += d[:D,id1,id2,id3].T @ SM
9315
0
                }
9316
9317
                // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
9318
                // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
9319
                // for ic:
9320
                //   S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
9321
                // exclude known future zero S[..] values from operation
9322
0
                ggml_vec_set_f32(masked_begin, S, 0);
9323
0
                for (int64_t ic = 0; ic < D; ++ic) {
9324
0
                    ggml_vec_mad_f32(masked_begin,
9325
0
                            S,
9326
0
                             (float *) ((char *) v->data + (          ic*nbv1  + iv2*nbv2 + iv3*nbv3)),
9327
0
                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
9328
0
                }
9329
9330
                // S = SM * (S - dot(SM, S))
9331
0
                float dot_SM_gradSM = 0;
9332
0
                ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
9333
0
                ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
9334
0
                ggml_vec_mul_f32 (masked_begin, S, S, SM);
9335
9336
                // S = diag_mask_zero(S, P) * scale
9337
                // already done by above ggml_vec_set_f32
9338
9339
                // exclude known zero S[..] values from operation
9340
0
                ggml_vec_scale_f32(masked_begin, S, scale);
9341
9342
                // S    shape [M,1]
9343
                // SM   shape [M,1]
9344
                // kcur shape [D,M]
9345
                // qcur shape [D,1]
9346
                // vcur shape [M,D]
9347
9348
                // grad[q][:D,iq1,iq2,iq3] += S @ kcur
9349
                // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
9350
                // for ic:
9351
                //  grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
9352
                // exclude known zero S[..] values from loop
9353
0
                for (int64_t ic = 0; ic < masked_begin; ++ic) {
9354
0
                    ggml_vec_mad_f32(D,
9355
0
                            (float *) ((char *) grad_q  + (iq1*nbgq1 + iq2*nbgq2  + iq3*nbgq3)),
9356
0
                            (float *) ((char *) k->data + (ic*nbk1   + ik2*nbk2   + ik3*nbk3)),
9357
0
                            S[ic]);
9358
0
                }
9359
9360
                // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
9361
                // for ic:
9362
                //  grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
9363
                //  grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
9364
                // exclude known zero S[..] values from loop
9365
0
                for (int64_t ic = 0; ic < masked_begin; ++ic) {
9366
0
                    ggml_vec_mad_f32(D,
9367
0
                            (float *) ((char *) grad_k  + (ic*nbgk1  + ik2*nbgk2  + ik3*nbgk3)),
9368
0
                            (float *) ((char *) q->data + (iq1*nbq1  + iq2*nbq2   + iq3*nbq3)),
9369
0
                            S[ic]);
9370
0
                }
9371
9372
                // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T       @ SM
9373
                // for ic:
9374
                //  grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
9375
                //  grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3]         * SM[:M]
9376
                // exclude known zero SM[..] values from mad
9377
0
                for (int64_t ic = 0; ic < D; ++ic) {
9378
0
                    ggml_vec_mad_f32(masked_begin,
9379
0
                            (float *) ((char *) grad_v   + (          ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
9380
0
                            SM,
9381
0
                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2  + id3*nbd3)));
9382
0
                }
9383
0
            }
9384
0
        }
9385
0
    }
9386
0
}
9387
9388
void ggml_compute_forward_flash_attn_back(
9389
        const ggml_compute_params * params,
9390
        const bool masked,
9391
0
        ggml_tensor * dst) {
9392
9393
0
    const ggml_tensor * q = dst->src[0];
9394
9395
0
    switch (q->type) {
9396
0
        case GGML_TYPE_F32:
9397
0
            {
9398
0
                ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
9399
0
            } break;
9400
0
        default:
9401
0
            {
9402
0
                GGML_ABORT("fatal error");
9403
0
            }
9404
0
    }
9405
0
}
9406
9407
// ggml_compute_forward_ssm_conv
9408
9409
static void ggml_compute_forward_ssm_conv_f32(
9410
        const ggml_compute_params * params,
9411
0
        ggml_tensor * dst) {
9412
0
    const ggml_tensor * src0 = dst->src[0]; // conv_x
9413
0
    const ggml_tensor * src1 = dst->src[1]; // conv1d.weight
9414
9415
0
    const int ith = params->ith;
9416
0
    const int nth = params->nth;
9417
9418
0
    const int nc  = src1->ne[0]; // d_conv
9419
0
    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
9420
0
    const int nr  = src0->ne[1]; // d_inner
9421
0
    const int n_t =  dst->ne[1]; // tokens per sequence
9422
0
    const int n_s =  dst->ne[2]; // number of sequences in the batch
9423
9424
0
    GGML_ASSERT( dst->ne[0] == nr);
9425
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
9426
0
    GGML_ASSERT(src1->nb[0] == sizeof(float));
9427
0
    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
9428
9429
    // rows per thread
9430
0
    const int dr = (nr + nth - 1)/nth;
9431
9432
    // row range for this thread
9433
0
    const int ir0 = dr*ith;
9434
0
    const int ir1 = MIN(ir0 + dr, nr);
9435
0
    const int ir  = ir1 - ir0;
9436
9437
0
    for (int i3 = 0; i3 < n_s; ++i3) {
9438
0
        for (int i2 = 0; i2 < n_t; ++i2) {
9439
            // {d_conv - 1 + n_t, d_inner, n_seqs}
9440
            // sliding window
9441
0
            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
9442
0
            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
9443
0
            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
9444
9445
            // TODO: transpose the output for smaller strides for big batches?
9446
            // d_inner
9447
0
            for (int i1 = 0; i1 < ir; ++i1) {
9448
                // rowwise dot product
9449
                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
9450
0
                float sumf = 0.0f;
9451
9452
                // d_conv
9453
0
                for (int i0 = 0; i0 < nc; ++i0) {
9454
0
                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
9455
0
                }
9456
0
                x[i1] = sumf;
9457
0
            }
9458
0
        }
9459
0
    }
9460
0
}
9461
9462
void ggml_compute_forward_ssm_conv(
9463
        const ggml_compute_params * params,
9464
0
        ggml_tensor * dst) {
9465
0
    switch (dst->src[0]->type) {
9466
0
        case GGML_TYPE_F32:
9467
0
            {
9468
0
                ggml_compute_forward_ssm_conv_f32(params, dst);
9469
0
            } break;
9470
0
        default:
9471
0
            {
9472
0
                GGML_ABORT("fatal error");
9473
0
            }
9474
0
    }
9475
0
}
9476
9477
// ggml_compute_forward_ssm_scan
9478
9479
static void ggml_compute_forward_ssm_scan_f32(
9480
        const ggml_compute_params * params,
9481
0
        ggml_tensor * dst) {
9482
0
    const ggml_tensor * src0 = dst->src[0]; // s  {d_state, dim, n_head, n_seqs+}
9483
0
    const ggml_tensor * src1 = dst->src[1]; // x  {dim, n_head, n_seq_tokens, n_seqs}
9484
0
    const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
9485
0
    const ggml_tensor * src3 = dst->src[3]; // A  {d_state, n_head} or {1, n_head}
9486
0
    const ggml_tensor * src4 = dst->src[4]; // B  {d_state, n_group, n_seq_tokens, n_seqs}
9487
0
    const ggml_tensor * src5 = dst->src[5]; // C  {d_state, n_group, n_seq_tokens, n_seqs}
9488
0
    const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
9489
9490
0
    const int ith = params->ith;
9491
0
    const int nth = params->nth;
9492
9493
0
    const int64_t nc = src0->ne[0]; // d_state
9494
0
    const int64_t nr = src0->ne[1]; // dim
9495
0
    const int64_t nh = src1->ne[1]; // n_head
9496
0
    const int64_t ng = src4->ne[1];
9497
0
    const int64_t nt = src1->ne[2]; // number of tokens per sequence
9498
0
    const int64_t ns = src1->ne[3]; // number of sequences in the batch
9499
9500
    // can't use ggml_nbytes because src1 is not necessarily contiguous
9501
0
    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
9502
9503
0
    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
9504
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
9505
0
    GGML_ASSERT(src1->nb[0] == sizeof(float));
9506
0
    GGML_ASSERT(src2->nb[0] == sizeof(float));
9507
0
    GGML_ASSERT(src3->nb[0] == sizeof(float));
9508
0
    GGML_ASSERT(src4->nb[0] == sizeof(float));
9509
0
    GGML_ASSERT(src5->nb[0] == sizeof(float));
9510
0
    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
9511
0
    GGML_ASSERT(nh % ng == 0);
9512
9513
    // heads per thread
9514
0
    const int dh = (nh + nth - 1)/nth;
9515
9516
    // head range for this thread
9517
0
    const int ih0 = dh*ith;
9518
0
    const int ih1 = MIN(ih0 + dh, nh);
9519
9520
0
    const int32_t * ids = (const int32_t *) src6->data;
9521
9522
0
    for (int i3 = 0; i3 < ns; ++i3) {
9523
0
        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
9524
0
              float * s  = (      float *) ((      char *) dst->data  + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
9525
9526
0
        for (int i2 = 0; i2 < nt; ++i2) {
9527
0
            const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
9528
0
            const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
9529
0
            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
9530
0
            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
9531
0
            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
9532
0
                  float * y  = (      float *) ((      char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
9533
9534
0
            if (src3->ne[0] == 1) {
9535
                // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
9536
9537
                // n_head
9538
0
                for (int h = ih0; h < ih1; ++h) {
9539
                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
9540
0
                    const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
9541
0
                    const float dA = expf(dt_soft_plus * A[h]);
9542
0
                    const int g = h / (nh / ng); // repeat_interleave
9543
9544
                    // dim
9545
0
                    for (int i1 = 0; i1 < nr; ++i1) {
9546
0
                        const int ii = i1 + h*nr;
9547
0
                        const float x_dt = x[ii] * dt_soft_plus;
9548
0
                        float sumf = 0.0f;
9549
0
#if defined(GGML_SIMD)
9550
    #if defined(__ARM_FEATURE_SVE)
9551
                        const int ggml_f32_epr = svcntw();
9552
                        const int ggml_f32_step = 1 * ggml_f32_epr;
9553
9554
                        const int np = (nc & ~(ggml_f32_step - 1));
9555
9556
                        GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
9557
9558
                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
9559
                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
9560
9561
                        for (int i = 0; i < np; i += ggml_f32_step) {
9562
                            // TODO: maybe unroll more?
9563
                            for (int j = 0; j < 1; j++) {
9564
                                GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
9565
                                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
9566
                                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
9567
9568
                                t0 = GGML_F32_VEC_MUL(t0, adA);
9569
                                t1 = GGML_F32_VEC_MUL(t1, axdt);
9570
9571
                                t0 = GGML_F32_VEC_ADD(t0, t1);
9572
9573
                                sum = GGML_F32_VEC_FMA(sum, t0, t2);
9574
9575
                                GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
9576
                            }
9577
                        }
9578
9579
                        sumf = GGML_F32xt_REDUCE_ONE(sum);
9580
    #elif defined(__riscv_v_intrinsic)
9581
                        // todo: RVV implementation
9582
                        const int np = 0;
9583
    #else
9584
0
                        const int np = (nc & ~(GGML_F32_STEP - 1));
9585
9586
0
                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
9587
9588
0
                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
9589
0
                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
9590
9591
0
                        GGML_F32_VEC ax[GGML_F32_ARR];
9592
0
                        GGML_F32_VEC ay[GGML_F32_ARR];
9593
0
                        GGML_F32_VEC az[GGML_F32_ARR];
9594
9595
0
                        for (int i = 0; i < np; i += GGML_F32_STEP) {
9596
0
                            for (int j = 0; j < GGML_F32_ARR; j++) {
9597
0
                                ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
9598
0
                                ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
9599
0
                                az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
9600
9601
0
                                ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
9602
0
                                ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
9603
9604
0
                                ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
9605
9606
0
                                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
9607
9608
0
                                GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
9609
0
                            }
9610
0
                        }
9611
9612
                        // reduce sum0..sum3 to sum0
9613
0
                        GGML_F32_VEC_REDUCE(sumf, sum);
9614
0
    #endif
9615
#else
9616
                        const int np = 0;
9617
#endif
9618
                        // d_state
9619
0
                        for (int i0 = np; i0 < nc; ++i0) {
9620
0
                            const int i = i0 + ii*nc;
9621
0
                            const int ig = i0 + g*nc;
9622
                            // state = prev_state * dA + dB * x
9623
0
                            const float state = (s0[i] * dA) + (B[ig] * x_dt);
9624
                            // y = rowwise_dotprod(state, C)
9625
0
                            sumf += state * C[ig];
9626
0
                            s[i] = state;
9627
0
                        }
9628
0
                        y[ii] = sumf;
9629
0
                    }
9630
0
                }
9631
0
            } else {
9632
                // Mamba-1 has an element-wise decay factor for the states
9633
9634
                // n_head
9635
0
                for (int h = ih0; h < ih1; ++h) {
9636
                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
9637
0
                    const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
9638
0
                    const int g = h / (nh / ng); // repeat_interleave
9639
9640
                    // dim
9641
0
                    for (int i1 = 0; i1 < nr; ++i1) {
9642
0
                        const int ii = i1 + h*nr;
9643
0
                        const float x_dt = x[ii] * dt_soft_plus;
9644
#if defined(__ARM_FEATURE_SVE)
9645
                        svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
9646
                        svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
9647
                        svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
9648
9649
                        // d_state
9650
                        // TODO: what happens when (d_state % svcntw()) != 0?
9651
                        for (int64_t k = 0; k < nc; k += svcntw()) {
9652
                            svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
9653
                            svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
9654
                            svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
9655
                            svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
9656
9657
                            svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
9658
                            t1 = exp_ps_sve(svptrue_b32(), t1);
9659
                            svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
9660
9661
                            vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
9662
                            r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
9663
9664
                            GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
9665
                        }
9666
                        y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
9667
#else
9668
0
                        float sumf = 0.0f;
9669
                        // NOTE: can't really use GGML_SIMD here because d_state is usually 16
9670
                        //       and also because expf is used within the loop.
9671
                        // d_state
9672
0
                        for (int i0 = 0; i0 < nc; ++i0) {
9673
0
                            const int i = i0 + ii*nc;
9674
0
                            const int ig = i0 + g*nc;
9675
                            // state = prev_state * dA + dB * x
9676
0
                            const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
9677
                            // y = rowwise_dotprod(state, C)
9678
0
                            sumf += state * C[ig];
9679
0
                            s[i] = state;
9680
0
                        }
9681
0
                        y[ii] = sumf;
9682
0
#endif
9683
0
                    }
9684
0
                }
9685
0
            }
9686
            // use the output as the source when it's not the first token-wise iteration
9687
0
            s0 = s;
9688
0
        }
9689
0
    }
9690
0
}
9691
9692
void ggml_compute_forward_ssm_scan(
9693
        const ggml_compute_params * params,
9694
0
        ggml_tensor * dst) {
9695
0
    switch (dst->src[0]->type) {
9696
0
        case GGML_TYPE_F32:
9697
0
            {
9698
0
                ggml_compute_forward_ssm_scan_f32(params, dst);
9699
0
            } break;
9700
0
        default:
9701
0
            {
9702
0
                GGML_ABORT("fatal error");
9703
0
            }
9704
0
    }
9705
0
}
9706
9707
// ggml_compute_forward_win_part
9708
9709
static void ggml_compute_forward_win_part_f32(
9710
        const ggml_compute_params * params,
9711
0
        ggml_tensor * dst) {
9712
0
    GGML_UNUSED(params);
9713
9714
0
    const ggml_tensor * src0 = dst->src[0];
9715
9716
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
9717
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
9718
9719
0
    const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
9720
0
    const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
9721
0
    const int32_t w    = ((const int32_t *)(dst->op_params))[2];
9722
9723
0
    assert(ne00 == ne0);
9724
0
    assert(ne3  == nep0*nep1);
9725
9726
    // TODO: optimize / multi-thread
9727
0
    for (int py = 0; py < nep1; ++py) {
9728
0
        for (int px = 0; px < nep0; ++px) {
9729
0
            const int64_t i3 = py*nep0 + px;
9730
0
            for (int64_t i2 = 0; i2 < ne2; ++i2) {
9731
0
                for (int64_t i1 = 0; i1 < ne1; ++i1) {
9732
0
                    for (int64_t i0 = 0; i0 < ne0; ++i0) {
9733
0
                        const int64_t i02 = py*w + i2;
9734
0
                        const int64_t i01 = px*w + i1;
9735
0
                        const int64_t i00 = i0;
9736
9737
0
                        const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0    + i1*ne0   + i0;
9738
0
                        const int64_t j =                  i02*ne01*ne00 + i01*ne00 + i00;
9739
9740
0
                        if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
9741
0
                            ((float *) dst->data)[i] = 0.0f;
9742
0
                        } else {
9743
0
                            ((float *) dst->data)[i] = ((float *) src0->data)[j];
9744
0
                        }
9745
0
                    }
9746
0
                }
9747
0
            }
9748
0
        }
9749
0
    }
9750
0
}
9751
9752
void ggml_compute_forward_win_part(
9753
        const ggml_compute_params * params,
9754
0
        ggml_tensor * dst) {
9755
9756
0
    const ggml_tensor * src0 = dst->src[0];
9757
9758
0
    switch (src0->type) {
9759
0
        case GGML_TYPE_F32:
9760
0
            {
9761
0
                ggml_compute_forward_win_part_f32(params, dst);
9762
0
            } break;
9763
0
        default:
9764
0
            {
9765
0
                GGML_ABORT("fatal error");
9766
0
            }
9767
0
    }
9768
0
}
9769
9770
// ggml_compute_forward_win_unpart
9771
9772
static void ggml_compute_forward_win_unpart_f32(
9773
        const ggml_compute_params * params,
9774
0
        ggml_tensor * dst) {
9775
0
    GGML_UNUSED(params);
9776
9777
0
    const ggml_tensor * src0 = dst->src[0];
9778
9779
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
9780
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
9781
9782
0
    const int32_t w = ((const int32_t *)(dst->op_params))[0];
9783
9784
    // padding
9785
0
    const int px = (w - ne1%w)%w;
9786
    //const int py = (w - ne2%w)%w;
9787
9788
0
    const int npx = (px + ne1)/w;
9789
    //const int npy = (py + ne2)/w;
9790
9791
0
    assert(ne0 == ne00);
9792
9793
    // TODO: optimize / multi-thread
9794
0
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
9795
0
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
9796
0
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
9797
0
                const int ip2 = i2/w;
9798
0
                const int ip1 = i1/w;
9799
9800
0
                const int64_t i02 = i2%w;
9801
0
                const int64_t i01 = i1%w;
9802
0
                const int64_t i00 = i0;
9803
9804
0
                const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
9805
0
                const int64_t j =                                  i2*ne1*ne0    + i1*ne0   + i0;
9806
9807
0
                ((float *) dst->data)[j] = ((float *) src0->data)[i];
9808
0
            }
9809
0
        }
9810
0
    }
9811
0
}
9812
9813
void ggml_compute_forward_win_unpart(
9814
        const ggml_compute_params * params,
9815
0
        ggml_tensor * dst) {
9816
9817
0
    const ggml_tensor * src0 = dst->src[0];
9818
9819
0
    switch (src0->type) {
9820
0
        case GGML_TYPE_F32:
9821
0
            {
9822
0
                ggml_compute_forward_win_unpart_f32(params, dst);
9823
0
            } break;
9824
0
        default:
9825
0
            {
9826
0
                GGML_ABORT("fatal error");
9827
0
            }
9828
0
    }
9829
0
}
9830
9831
//ggml_compute_forward_unary
9832
9833
void ggml_compute_forward_unary(
9834
        const ggml_compute_params * params,
9835
0
        ggml_tensor * dst) {
9836
9837
0
    const ggml_unary_op op = ggml_get_unary_op(dst);
9838
9839
0
    switch (op) {
9840
0
        case GGML_UNARY_OP_ABS:
9841
0
            {
9842
0
                ggml_compute_forward_abs(params, dst);
9843
0
            } break;
9844
0
        case GGML_UNARY_OP_SGN:
9845
0
            {
9846
0
                ggml_compute_forward_sgn(params, dst);
9847
0
            } break;
9848
0
        case GGML_UNARY_OP_NEG:
9849
0
            {
9850
0
                ggml_compute_forward_neg(params, dst);
9851
0
            } break;
9852
0
        case GGML_UNARY_OP_STEP:
9853
0
            {
9854
0
                ggml_compute_forward_step(params, dst);
9855
0
            } break;
9856
0
        case GGML_UNARY_OP_TANH:
9857
0
            {
9858
0
                ggml_compute_forward_tanh(params, dst);
9859
0
            } break;
9860
0
        case GGML_UNARY_OP_ELU:
9861
0
            {
9862
0
                ggml_compute_forward_elu(params, dst);
9863
0
            } break;
9864
0
        case GGML_UNARY_OP_RELU:
9865
0
            {
9866
0
                ggml_compute_forward_relu(params, dst);
9867
0
            } break;
9868
0
        case GGML_UNARY_OP_SIGMOID:
9869
0
            {
9870
0
                ggml_compute_forward_sigmoid(params, dst);
9871
0
            } break;
9872
0
        case GGML_UNARY_OP_GELU:
9873
0
            {
9874
0
                ggml_compute_forward_gelu(params, dst);
9875
0
            } break;
9876
0
        case GGML_UNARY_OP_GELU_ERF:
9877
0
            {
9878
0
                ggml_compute_forward_gelu_erf(params, dst);
9879
0
            } break;
9880
0
        case GGML_UNARY_OP_GELU_QUICK:
9881
0
            {
9882
0
                ggml_compute_forward_gelu_quick(params, dst);
9883
0
            } break;
9884
0
        case GGML_UNARY_OP_SILU:
9885
0
            {
9886
0
                ggml_compute_forward_silu(params, dst);
9887
0
            } break;
9888
0
        case GGML_UNARY_OP_HARDSWISH:
9889
0
            {
9890
0
                ggml_compute_forward_hardswish(params, dst);
9891
0
            } break;
9892
0
        case GGML_UNARY_OP_HARDSIGMOID:
9893
0
            {
9894
0
                ggml_compute_forward_hardsigmoid(params, dst);
9895
0
            } break;
9896
0
        case GGML_UNARY_OP_EXP:
9897
0
            {
9898
0
                ggml_compute_forward_exp(params, dst);
9899
0
            } break;
9900
0
        case GGML_UNARY_OP_FLOOR:
9901
0
            {
9902
0
                ggml_compute_forward_floor(params, dst);
9903
0
            } break;
9904
0
        case GGML_UNARY_OP_CEIL:
9905
0
            {
9906
0
                ggml_compute_forward_ceil(params, dst);
9907
0
            } break;
9908
0
        case GGML_UNARY_OP_ROUND:
9909
0
            {
9910
0
                ggml_compute_forward_round(params, dst);
9911
0
            } break;
9912
0
        case GGML_UNARY_OP_TRUNC:
9913
0
            {
9914
0
                ggml_compute_forward_trunc(params, dst);
9915
0
            } break;
9916
0
        case GGML_UNARY_OP_XIELU:
9917
0
            {
9918
0
                ggml_compute_forward_xielu(params, dst);
9919
0
            } break;
9920
0
        case GGML_UNARY_OP_EXPM1:
9921
0
            {
9922
0
                ggml_compute_forward_expm1(params, dst);
9923
0
            } break;
9924
0
        case GGML_UNARY_OP_SOFTPLUS:
9925
0
            {
9926
0
                ggml_compute_forward_softplus(params, dst);
9927
0
            } break;
9928
0
        default:
9929
0
            {
9930
0
                GGML_ABORT("fatal error");
9931
0
            }
9932
0
    }
9933
0
}
9934
9935
//ggml_compute_forward_glu
9936
9937
void ggml_compute_forward_glu(
9938
        const ggml_compute_params * params,
9939
0
        ggml_tensor * dst) {
9940
9941
0
    const ggml_glu_op op = ggml_get_glu_op(dst);
9942
9943
0
    switch (op) {
9944
0
        case GGML_GLU_OP_REGLU:
9945
0
            {
9946
0
                ggml_compute_forward_reglu(params, dst);
9947
0
            } break;
9948
0
        case GGML_GLU_OP_GEGLU:
9949
0
            {
9950
0
                ggml_compute_forward_geglu(params, dst);
9951
0
            } break;
9952
0
        case GGML_GLU_OP_SWIGLU:
9953
0
            {
9954
0
                ggml_compute_forward_swiglu(params, dst);
9955
0
            } break;
9956
0
        case GGML_GLU_OP_SWIGLU_OAI:
9957
0
            {
9958
0
                ggml_compute_forward_swiglu_oai(params, dst);
9959
0
            } break;
9960
0
        case GGML_GLU_OP_GEGLU_ERF:
9961
0
            {
9962
0
                ggml_compute_forward_geglu_erf(params, dst);
9963
0
            } break;
9964
0
        case GGML_GLU_OP_GEGLU_QUICK:
9965
0
            {
9966
0
                ggml_compute_forward_geglu_quick(params, dst);
9967
0
            } break;
9968
0
        default:
9969
0
            {
9970
0
                GGML_ABORT("fatal error");
9971
0
            }
9972
0
    }
9973
0
}
9974
9975
// ggml_compute_forward_get_rel_pos
9976
9977
static void ggml_compute_forward_get_rel_pos_f16(
9978
        const ggml_compute_params * params,
9979
0
        ggml_tensor * dst) {
9980
0
    GGML_UNUSED(params);
9981
9982
0
    const ggml_tensor * src0 = dst->src[0];
9983
9984
    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
9985
9986
0
    GGML_TENSOR_UNARY_OP_LOCALS
9987
9988
0
    const int64_t w = ne1;
9989
9990
0
    ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
9991
0
    ggml_fp16_t * dst_data  = (ggml_fp16_t *) dst->data;
9992
9993
0
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
9994
0
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
9995
0
            const int64_t pos = (w - i1 - 1) + i2;
9996
0
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
9997
0
                dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
9998
0
            }
9999
0
        }
10000
0
    }
10001
0
}
10002
10003
void ggml_compute_forward_get_rel_pos(
10004
        const ggml_compute_params * params,
10005
0
        ggml_tensor * dst) {
10006
10007
0
    const ggml_tensor * src0 = dst->src[0];
10008
10009
0
    switch (src0->type) {
10010
0
        case GGML_TYPE_F16:
10011
0
        case GGML_TYPE_BF16:
10012
0
            {
10013
0
                ggml_compute_forward_get_rel_pos_f16(params, dst);
10014
0
            } break;
10015
0
        default:
10016
0
            {
10017
0
                GGML_ABORT("fatal error");
10018
0
            }
10019
0
    }
10020
0
}
10021
10022
// ggml_compute_forward_add_rel_pos
10023
10024
static void ggml_compute_forward_add_rel_pos_f32(
10025
        const ggml_compute_params * params,
10026
0
        ggml_tensor * dst) {
10027
10028
0
    const ggml_tensor * src0 = dst->src[0];
10029
0
    const ggml_tensor * src1 = dst->src[1];
10030
0
    const ggml_tensor * src2 = dst->src[2];
10031
10032
0
    const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
10033
0
    if (!inplace) {
10034
0
        if (params->ith == 0) {
10035
0
            memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
10036
0
        }
10037
0
        ggml_barrier(params->threadpool);
10038
0
    }
10039
    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
10040
10041
0
    float * src1_data = (float *) src1->data;
10042
0
    float * src2_data = (float *) src2->data;
10043
0
    float * dst_data  = (float *) dst->data;
10044
10045
0
    const int64_t ne10 = src1->ne[0];
10046
0
    const int64_t ne11 = src1->ne[1];
10047
0
    const int64_t ne12 = src1->ne[2];
10048
0
    const int64_t ne13 = src1->ne[3];
10049
10050
0
    const int ith = params->ith;
10051
0
    const int nth = params->nth;
10052
10053
    // total patches in dst
10054
0
    const int np = ne13;
10055
10056
    // patches per thread
10057
0
    const int dp = (np + nth - 1)/nth;
10058
10059
    // patch range for this thread
10060
0
    const int ip0 = dp*ith;
10061
0
    const int ip1 = MIN(ip0 + dp, np);
10062
10063
0
    for (int64_t i13 = ip0; i13 < ip1; ++i13) {
10064
0
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
10065
0
            for (int64_t i11 = 0; i11 < ne11; ++i11) {
10066
0
                const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
10067
0
                for (int64_t i10 = 0; i10 < ne10; ++i10) {
10068
0
                    const int64_t jp0  = jp1 + i10;
10069
0
                    const float src1_e = src1_data[jp0];
10070
0
                    const float src2_e = src2_data[jp0];
10071
10072
0
                    const int64_t jdh = jp0 * ne10;
10073
0
                    const int64_t jdw = jdh - (ne10 - 1) * i10;
10074
10075
0
                    for (int64_t j = 0; j < ne10; ++j) {
10076
0
                        dst_data[jdh + j     ] += src2_e;
10077
0
                        dst_data[jdw + j*ne10] += src1_e;
10078
0
                    }
10079
0
                }
10080
0
            }
10081
0
        }
10082
0
    }
10083
0
}
10084
10085
void ggml_compute_forward_add_rel_pos(
10086
        const ggml_compute_params * params,
10087
0
        ggml_tensor * dst) {
10088
10089
0
    const ggml_tensor * src0 = dst->src[0];
10090
10091
0
    switch (src0->type) {
10092
0
        case GGML_TYPE_F32:
10093
0
            {
10094
0
                ggml_compute_forward_add_rel_pos_f32(params, dst);
10095
0
            } break;
10096
0
        default:
10097
0
            {
10098
0
                GGML_ABORT("fatal error");
10099
0
            }
10100
0
    }
10101
0
}
10102
10103
// ggml_compute_forward_rwkv_wkv6
10104
10105
static void ggml_compute_forward_rwkv_wkv6_f32(
10106
        const ggml_compute_params * params,
10107
0
        ggml_tensor * dst) {
10108
0
    const int64_t T = dst->src[1]->ne[2];
10109
0
    const int64_t C = dst->ne[0];
10110
0
    const int64_t HEADS = dst->src[1]->ne[1];
10111
0
    const int64_t n_seqs = dst->src[5]->ne[1];
10112
0
    const int64_t head_size = C / HEADS;
10113
10114
0
    float * dst_data = (float *) dst->data;
10115
0
    float * state = ((float *) dst->data) + C * T;
10116
10117
0
    const int ith = params->ith;
10118
0
    const int nth = params->nth;
10119
10120
0
    const int h_start =  (HEADS * (ith    )) / nth;
10121
0
    const int h_end   = ((HEADS * (ith + 1)) / nth < HEADS) ?
10122
0
                         (HEADS * (ith + 1)) / nth : HEADS;
10123
10124
0
    float * k =          (float *) dst->src[0]->data;
10125
0
    float * v =          (float *) dst->src[1]->data;
10126
0
    float * r =          (float *) dst->src[2]->data;
10127
0
    float * time_faaaa = (float *) dst->src[3]->data;
10128
0
    float * time_decay = (float *) dst->src[4]->data;
10129
10130
0
    size_t t_stride = HEADS * head_size; // Same to C
10131
10132
0
    size_t h_stride = C / HEADS;
10133
0
    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
10134
0
    size_t h_stride_2d = head_size * head_size;
10135
10136
0
    if (ith == 0) {
10137
0
        memset(dst_data, 0, T * C * sizeof(float));
10138
0
    }
10139
0
    ggml_barrier(params->threadpool);
10140
10141
10142
0
    #if defined(__AVX__) && !defined(__AVX512F__)
10143
0
        #define GGML_F32X GGML_F32x8
10144
0
        #define GGML_F32X_SET1 GGML_F32x8_SET1
10145
0
        #define GGML_F32X_LOAD GGML_F32x8_LOAD
10146
0
        #define GGML_F32X_STORE GGML_F32x8_STORE
10147
0
        #define GGML_F32X_MUL GGML_F32x8_MUL
10148
0
        #define GGML_F32X_FMA GGML_F32x8_FMA
10149
0
        #define WKV_VECTOR_SIZE 8
10150
    #elif defined(__AVX512F__)
10151
        #define GGML_F32X GGML_F32x16
10152
        #define GGML_F32X_SET1 GGML_F32x16_SET1
10153
        #define GGML_F32X_LOAD GGML_F32x16_LOAD
10154
        #define GGML_F32X_STORE GGML_F32x16_STORE
10155
        #define GGML_F32X_MUL GGML_F32x16_MUL
10156
        #define GGML_F32X_FMA GGML_F32x16_FMA
10157
        #define WKV_VECTOR_SIZE 16
10158
    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
10159
        #define GGML_F32X GGML_F32xt
10160
        #define GGML_F32X_SET1 GGML_F32xt_SET1
10161
        #define GGML_F32X_LOAD GGML_F32xt_LOAD
10162
        #define GGML_F32X_STORE GGML_F32xt_STORE
10163
        #define GGML_F32X_MUL GGML_F32xt_MUL
10164
        #define GGML_F32X_FMA GGML_F32xt_FMA
10165
        #define WKV_VECTOR_SIZE 8
10166
    #elif defined(__ARM_NEON) && defined(__aarch64__)
10167
        #define GGML_F32X GGML_F32x4
10168
        #define GGML_F32X_SET1 GGML_F32x4_SET1
10169
        #define GGML_F32X_LOAD GGML_F32x4_LOAD
10170
        #define GGML_F32X_STORE GGML_F32x4_STORE
10171
        #define GGML_F32X_MUL GGML_F32x4_MUL
10172
        #define GGML_F32X_FMA GGML_F32x4_FMA
10173
        #define WKV_VECTOR_SIZE 4
10174
    #endif
10175
10176
0
    #ifdef WKV_VECTOR_SIZE
10177
0
        int wkv_vector_size;
10178
        #if defined(__ARM_FEATURE_SVE)
10179
            wkv_vector_size = svcntw();
10180
        #else
10181
0
            wkv_vector_size = WKV_VECTOR_SIZE;
10182
0
        #endif
10183
0
        const int64_t vec_count = head_size / wkv_vector_size;
10184
10185
0
        for (int64_t t = 0; t < T; t++) {
10186
0
            size_t t_offset = t * t_stride;
10187
0
            size_t state_offset = head_size * C * (t / (T / n_seqs));
10188
0
            float * state_cur = state + state_offset;
10189
0
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
10190
10191
0
            for (int64_t h = h_start; h < h_end; h++) {
10192
0
                size_t h_offset = h * h_stride;
10193
0
                size_t t_h_offset = t_offset + h_offset;
10194
0
                size_t h_2d_offset = h * h_stride_2d;
10195
10196
0
                for (int64_t i = 0; i < head_size; i++) {
10197
0
                    size_t t_h_i_offset = t_h_offset + i;
10198
0
                    size_t h_i_offset = h_offset + i;
10199
0
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
10200
10201
0
                    float k_val = k[t_h_i_offset];
10202
0
                    float r_val = r[t_h_i_offset];
10203
0
                    float time_faaaa_val = time_faaaa[h_i_offset];
10204
0
                    float time_decay_val = time_decay[t_h_i_offset];
10205
10206
                    // Broadcast scalar values to vectors
10207
0
                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
10208
0
                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
10209
0
                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
10210
0
                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
10211
10212
0
                    for (int64_t j = 0; j < vec_count; j++) {
10213
0
                        size_t base_j = j * wkv_vector_size;
10214
0
                        size_t t_h_j_offset = t_h_offset + base_j;
10215
0
                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
10216
10217
                        // Load x elements at once
10218
0
                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
10219
0
                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
10220
0
                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
10221
10222
                        // Compute kv = v * k
10223
0
                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
10224
10225
                        // Compute temp = kv * time_faaaa + prev_state
10226
0
                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
10227
10228
                        // Update dst: dst += temp * r
10229
0
                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
10230
0
                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
10231
10232
                        // Update state: state = prev_state * time_decay + kv
10233
0
                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
10234
0
                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
10235
0
                    }
10236
10237
                    // Handle remaining elements, this will not be used.
10238
0
                    for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
10239
0
                        size_t t_h_j_offset = t_h_offset + j;
10240
0
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
10241
0
                        float v_val = v[t_h_j_offset];
10242
0
                        float kv_val = v_val * k_val;
10243
0
                        float prev_state_val = state_prev[h_2d_i_j_offset];
10244
0
                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
10245
0
                        dst_data[t_h_j_offset] += temp_val * r_val;
10246
0
                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
10247
0
                    }
10248
0
                }
10249
0
            }
10250
0
        }
10251
10252
    #else
10253
        // basically fused operations:
10254
        // dst = r @ (time_faaaa * (k @ v) + state),
10255
        // state = time_decay * state + (k @ v),
10256
        // recursive through each token
10257
        for (int64_t t = 0; t < T; t++) {
10258
            size_t t_offset = t * t_stride;
10259
            size_t state_offset = head_size * C * (t / (T / n_seqs));
10260
            float * state_cur = state + state_offset;
10261
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
10262
10263
            for (int64_t h = h_start; h < h_end; h++) {
10264
                size_t h_offset = h * h_stride;
10265
                size_t t_h_offset = t_offset + h_offset;
10266
                size_t h_2d_offset = h * h_stride_2d;
10267
10268
                for (int64_t i = 0; i < head_size; i++) {
10269
                    size_t t_h_i_offset = t_h_offset + i;
10270
                    size_t h_i_offset = h_offset + i;
10271
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
10272
10273
                    float k_val = k[t_h_i_offset];
10274
                    float r_val = r[t_h_i_offset];
10275
                    float time_faaaa_val = time_faaaa[h_i_offset];
10276
                    // RWKV v6: different time_decay for each token.
10277
                    float time_decay_val = time_decay[t_h_i_offset];
10278
10279
                    for (int64_t j = 0; j < head_size; j++) {
10280
                        size_t t_h_j_offset = t_h_offset + j;
10281
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
10282
10283
                        float v_val = v[t_h_j_offset];
10284
                        float kv_val = v_val * k_val;
10285
                        float prev_state_val = state_prev[h_2d_i_j_offset];
10286
                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
10287
                        dst_data[t_h_j_offset] += temp_val * r_val;
10288
                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
10289
                    }
10290
                }
10291
            }
10292
        }
10293
    #endif
10294
0
}
10295
10296
10297
void ggml_compute_forward_rwkv_wkv6(
10298
        const ggml_compute_params * params,
10299
0
        ggml_tensor * dst) {
10300
10301
0
    const ggml_tensor * src0 = dst->src[0];
10302
10303
0
    switch (src0->type) {
10304
0
        case GGML_TYPE_F32:
10305
0
            {
10306
0
                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
10307
0
            } break;
10308
0
        default:
10309
0
            {
10310
0
                GGML_ABORT("fatal error");
10311
0
            }
10312
0
    }
10313
0
}
10314
10315
// ggml_compute_forward_gla
10316
10317
static void ggml_compute_forward_gla_f32(
10318
        const ggml_compute_params * params,
10319
0
        ggml_tensor * dst) {
10320
0
    const int64_t T = dst->src[1]->ne[2];
10321
0
    const int64_t C = dst->ne[0];
10322
0
    const int64_t HEADS = dst->src[1]->ne[1];
10323
0
    const int64_t n_seqs = dst->src[4]->ne[1];
10324
0
    const int64_t head_size = C / HEADS;
10325
0
    const float scale = ggml_get_op_params_f32(dst, 0);
10326
10327
0
    float * dst_data = (float *) dst->data;
10328
0
    float * state = ((float *) dst->data) + C * T;
10329
10330
0
    const int ith = params->ith;
10331
0
    const int nth = params->nth;
10332
10333
0
    const int h_start =  (HEADS * (ith    )) / nth;
10334
0
    const int h_end   = ((HEADS * (ith + 1)) / nth < HEADS) ?
10335
0
                         (HEADS * (ith + 1)) / nth : HEADS;
10336
10337
0
    float * k = (float *) dst->src[0]->data;
10338
0
    float * v = (float *) dst->src[1]->data;
10339
0
    float * q = (float *) dst->src[2]->data;
10340
0
    float * g = (float *) dst->src[3]->data;
10341
10342
0
    size_t t_stride = HEADS * head_size; // Same to C
10343
10344
0
    size_t h_stride = C / HEADS;
10345
0
    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
10346
0
    size_t h_stride_2d = head_size * head_size;
10347
10348
0
    if (ith == 0) {
10349
0
        memset(dst_data, 0, T * C * sizeof(float));
10350
0
    }
10351
0
    ggml_barrier(params->threadpool);
10352
10353
10354
0
    #if defined(__AVX__) && !defined(__AVX512F__)
10355
0
        #define GGML_F32X GGML_F32x8
10356
0
        #define GGML_F32X_SET1 GGML_F32x8_SET1
10357
0
        #define GGML_F32X_LOAD GGML_F32x8_LOAD
10358
0
        #define GGML_F32X_STORE GGML_F32x8_STORE
10359
0
        #define GGML_F32X_MUL GGML_F32x8_MUL
10360
0
        #define GGML_F32X_FMA GGML_F32x8_FMA
10361
0
        #define GLA_VECTOR_SIZE 8
10362
    #elif defined(__AVX512F__)
10363
        #define GGML_F32X GGML_F32x16
10364
        #define GGML_F32X_SET1 GGML_F32x16_SET1
10365
        #define GGML_F32X_LOAD GGML_F32x16_LOAD
10366
        #define GGML_F32X_STORE GGML_F32x16_STORE
10367
        #define GGML_F32X_MUL GGML_F32x16_MUL
10368
        #define GGML_F32X_FMA GGML_F32x16_FMA
10369
        #define GLA_VECTOR_SIZE 16
10370
    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
10371
        #define GGML_F32X GGML_F32xt
10372
        #define GGML_F32X_SET1 GGML_F32xt_SET1
10373
        #define GGML_F32X_LOAD GGML_F32xt_LOAD
10374
        #define GGML_F32X_STORE GGML_F32xt_STORE
10375
        #define GGML_F32X_MUL GGML_F32xt_MUL
10376
        #define GGML_F32X_FMA GGML_F32xt_FMA
10377
        #define GLA_VECTOR_SIZE 8
10378
    #elif defined(__ARM_NEON) && defined(__aarch64__)
10379
        #define GGML_F32X GGML_F32x4
10380
        #define GGML_F32X_SET1 GGML_F32x4_SET1
10381
        #define GGML_F32X_LOAD GGML_F32x4_LOAD
10382
        #define GGML_F32X_STORE GGML_F32x4_STORE
10383
        #define GGML_F32X_MUL GGML_F32x4_MUL
10384
        #define GGML_F32X_FMA GGML_F32x4_FMA
10385
        #define GLA_VECTOR_SIZE 4
10386
    #endif
10387
10388
0
    #ifdef GLA_VECTOR_SIZE
10389
0
        int gla_vector_size;
10390
        #if defined(__ARM_FEATURE_SVE)
10391
            gla_vector_size = svcntw();
10392
        #else
10393
0
            gla_vector_size = GLA_VECTOR_SIZE;
10394
0
        #endif
10395
0
        const int64_t vec_count = head_size / gla_vector_size;
10396
10397
0
        for (int64_t t = 0; t < T; t++) {
10398
0
            size_t t_offset = t * t_stride;
10399
0
            size_t state_offset = head_size * C * (t / (T / n_seqs));
10400
0
            float * state_cur = state + state_offset;
10401
0
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
10402
10403
0
            for (int64_t h = h_start; h < h_end; h++) {
10404
0
                size_t h_offset = h * h_stride;
10405
0
                size_t t_h_offset = t_offset + h_offset;
10406
0
                size_t h_2d_offset = h * h_stride_2d;
10407
10408
0
                for (int64_t i = 0; i < head_size; i++) {
10409
0
                    size_t t_h_i_offset = t_h_offset + i;
10410
0
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
10411
10412
0
                    float k_val = k[t_h_i_offset];
10413
0
                    float q_val = q[t_h_i_offset] * scale;
10414
0
                    float g_val = g[t_h_i_offset];
10415
10416
                    // Broadcast scalar values to vectors
10417
0
                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
10418
0
                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
10419
0
                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
10420
10421
0
                    for (int64_t j = 0; j < vec_count; j++) {
10422
0
                        size_t base_j = j * gla_vector_size;
10423
0
                        size_t t_h_j_offset = t_h_offset + base_j;
10424
0
                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
10425
10426
                        // Load x elements at once
10427
0
                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
10428
0
                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
10429
0
                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
10430
10431
                        // Compute kv = v * k
10432
0
                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
10433
10434
                        // Compute temp = prev_state * g + kv
10435
0
                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
10436
10437
                        // Update dst: dst += temp * q
10438
0
                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
10439
0
                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
10440
10441
                        // Update state
10442
0
                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
10443
0
                    }
10444
10445
                    // Handle remaining elements, this will not be used.
10446
0
                    for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
10447
0
                        size_t t_h_j_offset = t_h_offset + j;
10448
0
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
10449
0
                        float v_val = v[t_h_j_offset];
10450
0
                        float kv_val = v_val * k_val;
10451
0
                        float prev_state_val = state_prev[h_2d_i_j_offset];
10452
0
                        float temp_val = kv_val + prev_state_val * g_val;
10453
0
                        dst_data[t_h_j_offset] += temp_val * q_val;
10454
0
                        state_cur[h_2d_i_j_offset] = temp_val;
10455
0
                    }
10456
0
                }
10457
0
            }
10458
0
        }
10459
10460
    #else
10461
        for (int64_t t = 0; t < T; t++) {
10462
            size_t t_offset = t * t_stride;
10463
            size_t state_offset = head_size * C * (t / (T / n_seqs));
10464
            float * state_cur = state + state_offset;
10465
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
10466
10467
            for (int64_t h = h_start; h < h_end; h++) {
10468
                size_t h_offset = h * h_stride;
10469
                size_t t_h_offset = t_offset + h_offset;
10470
                size_t h_2d_offset = h * h_stride_2d;
10471
10472
                for (int64_t i = 0; i < head_size; i++) {
10473
                    size_t t_h_i_offset = t_h_offset + i;
10474
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
10475
10476
                    float k_val = k[t_h_i_offset];
10477
                    float q_val = q[t_h_i_offset] * scale;
10478
                    float g_val = g[t_h_i_offset];
10479
10480
                    for (int64_t j = 0; j < head_size; j++) {
10481
                        size_t t_h_j_offset = t_h_offset + j;
10482
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
10483
10484
                        float v_val = v[t_h_j_offset];
10485
                        float kv_val = v_val * k_val;
10486
                        float prev_state_val = state_prev[h_2d_i_j_offset];
10487
                        float temp_val = prev_state_val * g_val + kv_val;
10488
                        dst_data[t_h_j_offset] += temp_val * q_val;
10489
                        state_cur[h_2d_i_j_offset] = temp_val;
10490
                    }
10491
                }
10492
            }
10493
        }
10494
    #endif
10495
0
}
10496
10497
10498
void ggml_compute_forward_gla(
10499
        const ggml_compute_params * params,
10500
0
        ggml_tensor * dst) {
10501
10502
0
    const ggml_tensor * src0 = dst->src[0];
10503
10504
0
    switch (src0->type) {
10505
0
        case GGML_TYPE_F32:
10506
0
            {
10507
0
                ggml_compute_forward_gla_f32(params, dst);
10508
0
            } break;
10509
0
        default:
10510
0
            {
10511
0
                GGML_ABORT("fatal error");
10512
0
            }
10513
0
    }
10514
0
}
10515
10516
0
static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
10517
0
    const struct ggml_tensor * src0 = dst->src[0];  // A (lower triangular)
10518
0
    const struct ggml_tensor * src1 = dst->src[1];  // B (RHS)
10519
10520
0
    GGML_TENSOR_BINARY_OP_LOCALS;
10521
10522
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
10523
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
10524
0
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
10525
10526
0
    GGML_ASSERT(ne00 == ne01); // A must be square
10527
0
    GGML_ASSERT(ne0  == ne10); // solution cols == B cols
10528
0
    GGML_ASSERT(ne1  == ne11); // solution rows == B rows
10529
10530
0
    GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
10531
0
    GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
10532
10533
0
    const int ith = params->ith;
10534
0
    const int nth = params->nth;
10535
10536
0
    const int64_t k = ne10;   // number of RHS columns
10537
0
    const int64_t n = ne11;   // A is n×n
10538
0
    const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
10539
10540
    // chunks per thread
10541
0
    const int64_t dr = (nr + nth - 1)/nth;
10542
10543
    // chunk range for this thread
10544
0
    const int64_t ir0 = dr*ith;
10545
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
10546
10547
0
    const float * A = (const float *) src0->data;  // [n, n, B1, B2]
10548
0
    const float * B = (const float *) src1->data;  // [n, k, B1, B2]
10549
0
          float * X = (      float *) dst->data;   // [n, k, B1, B2]
10550
10551
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
10552
0
        const int64_t i03 = ir/(ne02*k);
10553
0
        const int64_t i02 = (ir - i03*ne02*k)/k;
10554
0
        const int64_t i01 = (ir - i03*ne02*k - i02*k);
10555
10556
0
        const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
10557
0
        const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
10558
10559
0
        float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
10560
10561
0
        for (int64_t i00 = 0; i00 < n; ++i00) {
10562
0
            float sum = 0.0f;
10563
0
            for (int64_t t = 0; t < i00; ++t) {
10564
0
                sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
10565
0
            }
10566
10567
0
            const float diag = A_batch[i00 * n + i00];
10568
0
            assert(diag != 0.0f && "Zero diagonal in triangular matrix");
10569
10570
0
            X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
10571
0
        }
10572
0
    }
10573
0
}
10574
10575
0
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
10576
0
    const ggml_tensor * src0 = dst->src[0];
10577
0
    const ggml_tensor * src1 = dst->src[1];
10578
10579
0
    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
10580
0
        ggml_compute_forward_solve_tri_f32(params, dst);
10581
0
    } else {
10582
0
        GGML_ABORT("fatal error");
10583
0
    }
10584
0
}
10585
10586
// ggml_compute_forward_gated_delta_net
10587
static void ggml_compute_forward_gated_delta_net_one_chunk(
10588
    const ggml_compute_params * params,
10589
    ggml_tensor * dst,
10590
    int64_t ir0,
10591
0
    int64_t ir1) {
10592
10593
0
    ggml_tensor * src_q     = dst->src[0];
10594
0
    ggml_tensor * src_k     = dst->src[1];
10595
0
    ggml_tensor * src_v     = dst->src[2];
10596
0
    ggml_tensor * src_g     = dst->src[3];
10597
0
    ggml_tensor * src_beta  = dst->src[4];
10598
0
    ggml_tensor * src_state = dst->src[5];
10599
10600
0
    const int64_t S_v      = src_v->ne[0];
10601
0
    const int64_t H        = src_v->ne[1];
10602
0
    const int64_t n_tokens = src_v->ne[2];
10603
0
    const int64_t n_seqs   = src_v->ne[3];
10604
10605
0
    GGML_ASSERT(ggml_is_contiguous_rows(src_q));
10606
0
    GGML_ASSERT(ggml_is_contiguous_rows(src_k));
10607
0
    GGML_ASSERT(ggml_is_contiguous_rows(src_v));
10608
0
    GGML_ASSERT(ggml_is_contiguous(src_g));
10609
0
    GGML_ASSERT(ggml_is_contiguous(src_beta));
10610
0
    GGML_ASSERT(ggml_is_contiguous(src_state));
10611
10612
0
    GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v);
10613
0
    GGML_ASSERT(src_beta->ne[0] == 1);
10614
10615
0
    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
10616
0
    GGML_TENSOR_LOCALS(size_t,  nbq, src_q, nb);
10617
0
    GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
10618
0
    GGML_TENSOR_LOCALS(size_t,  nbk, src_k, nb);
10619
0
    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
10620
0
    GGML_TENSOR_LOCALS(size_t,  nbv, src_v, nb);
10621
0
    GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne);
10622
0
    GGML_TENSOR_LOCALS(size_t,  nbg, src_g, nb);
10623
0
    GGML_TENSOR_LOCALS(size_t,  nbb, src_beta, nb);
10624
10625
0
    const bool kda = (neg0 == S_v);
10626
10627
    // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs].
10628
0
    const int64_t K = ggml_get_op_params_i32(dst, 0);
10629
0
    GGML_ASSERT(K >= 1);
10630
    // per-seq stride in floats (seq s starts at state + s * seq_stride)
10631
0
    const int64_t state_seq_stride = src_state->nb[3] / sizeof(float);
10632
10633
0
    const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0);
10634
0
    const int ith = params->ith;
10635
10636
0
    float * delta       = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32;
10637
0
    float * state_work  = K > 1 ? (delta + S_v) : nullptr;
10638
10639
    // output layout: [attn_scores | new_states]
10640
    // attn_scores: S_v * H * n_tokens * n_seqs    floats
10641
    // new_states:  S_v * S_v * H * n_seqs * K     floats  (K snapshot slots; last min(n_tokens, K))
10642
0
    const int64_t attn_score_elems    = S_v * H * n_tokens * n_seqs;
10643
0
    const int64_t state_size_per_snap = S_v * S_v * H * n_seqs;
10644
0
    float * attn_out_base  = (float *)dst->data;
10645
0
    float * state_out_base = (float *)dst->data + attn_score_elems;
10646
10647
    // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
10648
    // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
10649
10650
0
    const float * state_in_base = (const float *)src_state->data;
10651
10652
  //const int64_t rq1 = nev1 / neq1;
10653
  //const int64_t rk1 = nev1 / nek1;
10654
0
    const int64_t rq3 = nev3 / neq3;
10655
0
    const int64_t rk3 = nev3 / nek3;
10656
10657
0
    const float scale = 1.0f / sqrtf((float) S_v);
10658
10659
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
10660
0
        const int64_t iv1 = ir % H; // head_index
10661
0
        const int64_t iv3 = ir / H; // sequence
10662
10663
0
        const int64_t iq1 = iv1 % neq1;
10664
0
        const int64_t ik1 = iv1 % nek1;
10665
10666
0
        const int64_t iq3 = iv3 / rq3;
10667
0
        const int64_t ik3 = iv3 / rk3;
10668
10669
        // For K=1, write directly to the single output slot to avoid an extra memcpy at the end.
10670
        // For K>1, work in scratch and copy out per-token when the slot is in range.
10671
0
        float * s_out = (K > 1)
10672
0
            ? state_work
10673
0
            : state_out_base + (iv3 * H + iv1) * S_v * S_v;
10674
10675
        // copy input state into the working buffer and operate in-place
10676
        // state layout [S_v, S_v, H, n_seqs]: seq iv3 starts at iv3 * state_seq_stride.
10677
0
        const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v;
10678
0
        memcpy(s_out, s_in, S_v * S_v * sizeof(float));
10679
10680
        // attn output pointer for first token of this (head, seq)
10681
0
        float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;
10682
10683
0
        for (int64_t t = 0; t < n_tokens; t++) {
10684
0
            const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1);
10685
0
            const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1);
10686
0
            const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
10687
10688
0
            const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
10689
0
            const float * g_d    =  (const float *)((const char *)src_g->data    + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
10690
10691
            // state is stored transposed: s_out[j*S_v + i] = S[i][j]
10692
            // so row j of s_out = column j of S (contiguous access)
10693
10694
0
            if (kda) {
10695
                // precompute exp(g) into delta scratch (reused below)
10696
0
                for (int64_t i = 0; i < S_v; ++i) {
10697
0
                    delta[i] = expf(g_d[i]);
10698
0
                }
10699
                // S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i])
10700
0
                for (int64_t j = 0; j < S_v; ++j) {
10701
0
                    ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);
10702
0
                }
10703
0
            } else {
10704
0
                ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
10705
0
            }
10706
10707
            // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)
10708
0
            for (int64_t j = 0; j < S_v; ++j) {
10709
0
                float sum = 0.0f;
10710
0
                ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);
10711
0
                delta[j] = (v_d[j] - sum) * beta_val;
10712
0
            }
10713
10714
            // outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]
10715
0
            for (int64_t j = 0; j < S_v; ++j) {
10716
0
                ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);
10717
0
            }
10718
10719
            // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)
10720
0
            for (int64_t j = 0; j < S_v; ++j) {
10721
0
                float sum = 0.0f;
10722
0
                ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);
10723
0
                attn_data[j] = sum * scale;
10724
0
            }
10725
10726
0
            attn_data += S_v * H; // advance to next token
10727
10728
0
            if (K > 1) {
10729
0
                const int64_t target_slot = n_tokens - 1 - t;
10730
0
                if (target_slot >= 0 && target_slot < K) {
10731
0
                    float * curr_state_o = state_out_base + target_slot * state_size_per_snap +
10732
0
                                     (iv3 * H + iv1) * S_v * S_v;
10733
0
                    memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float));
10734
0
                }
10735
0
            }
10736
0
        }
10737
0
    }
10738
0
}
10739
10740
10741
static void ggml_compute_forward_gated_delta_net_f32(
10742
        const ggml_compute_params * params,
10743
0
        ggml_tensor * dst) {
10744
10745
0
    ggml_tensor * V = dst->src[2];
10746
0
    int64_t nr = V->ne[1] * V->ne[3];
10747
10748
    // disable for NUMA
10749
0
    const bool disable_chunking = ggml_is_numa();
10750
10751
0
    int nth = params->nth;
10752
0
    int ith = params->ith;
10753
10754
    // 4x chunks per thread
10755
0
    int nth_scaled = nth * 4;
10756
0
    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
10757
0
    int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
10758
10759
0
    if (nth == 1 || nchunk < nth || disable_chunking) {
10760
0
      nchunk = nth;
10761
0
    }
10762
10763
0
    if (ith == 0) {
10764
0
      ggml_threadpool_chunk_set(params->threadpool, nth);
10765
0
    }
10766
10767
0
    ggml_barrier(params->threadpool);
10768
10769
0
    const int64_t dr = (nr + nchunk - 1) / nchunk;
10770
10771
0
    int current_chunk = ith;
10772
10773
0
    while (current_chunk < nchunk) {
10774
0
        const int64_t ir0 = dr * current_chunk;
10775
0
        const int64_t ir1 = MIN(ir0 + dr, nr);
10776
10777
0
        ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1);
10778
0
        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
10779
0
    }
10780
0
}
10781
10782
void ggml_compute_forward_gated_delta_net(
10783
        const ggml_compute_params * params,
10784
0
        ggml_tensor * dst) {
10785
0
    const ggml_tensor * src0 = dst->src[0];
10786
10787
0
    switch (src0->type) {
10788
0
        case GGML_TYPE_F32:
10789
0
            {
10790
0
                ggml_compute_forward_gated_delta_net_f32(params, dst);
10791
0
            } break;
10792
0
        default:
10793
0
            {
10794
0
                GGML_ABORT("fatal error");
10795
0
            }
10796
0
    }
10797
0
}
10798
10799
// ggml_compute_forward_rwkv_wkv7
10800
10801
static void ggml_compute_forward_rwkv_wkv7_f32(
10802
        const ggml_compute_params * params,
10803
0
        ggml_tensor * dst) {
10804
0
    const int64_t T = dst->src[1]->ne[2];
10805
0
    const int64_t C = dst->ne[0];
10806
0
    const int64_t HEADS = dst->src[1]->ne[1];
10807
0
    const int64_t n_seqs = dst->src[6]->ne[1];
10808
0
    const int64_t head_size = C / HEADS;
10809
10810
0
    float * dst_data = (float *) dst->data;
10811
0
    float * state = ((float *) dst->data) + C * T;
10812
10813
0
    const int ith = params->ith;
10814
0
    const int nth = params->nth;
10815
10816
0
    const int h_start =  (HEADS * (ith    )) / nth;
10817
0
    const int h_end   = ((HEADS * (ith + 1)) / nth < HEADS) ?
10818
0
                         (HEADS * (ith + 1)) / nth : HEADS;
10819
10820
0
    float * r = (float *) dst->src[0]->data;
10821
0
    float * w = (float *) dst->src[1]->data;
10822
0
    float * k = (float *) dst->src[2]->data;
10823
0
    float * v = (float *) dst->src[3]->data;
10824
0
    float * a = (float *) dst->src[4]->data;
10825
0
    float * b = (float *) dst->src[5]->data;
10826
10827
0
    int64_t t_stride = HEADS * head_size; // Same to C
10828
10829
0
    int64_t h_stride = C / HEADS;
10830
0
    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
10831
0
    int64_t h_stride_2d = head_size * head_size;
10832
10833
0
    #if defined(GGML_SIMD)
10834
        #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
10835
            // scalar Route to scalar implementation       //TODO: Write SVE code and RVV code
10836
            for (int64_t t = 0; t < T; t++) {
10837
                int64_t t_offset = t * t_stride;
10838
                int64_t state_offset = head_size * C * (t / (T / n_seqs));
10839
                float * state_cur = state + state_offset;
10840
                float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
10841
10842
                for (int64_t h = h_start; h < h_end; h++) {
10843
                    int64_t h_offset = h * h_stride;
10844
                    int64_t t_h_offset = t_offset + h_offset;
10845
                    int64_t h_2d_offset = h * h_stride_2d;
10846
10847
                    for (int64_t i = 0; i < head_size; i++) {
10848
                        int64_t t_h_i_offset = t_h_offset + i;
10849
                        int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
10850
10851
                        float v_val = v[t_h_i_offset];
10852
10853
                        float sa = 0, result = 0;
10854
                        for (int64_t j = 0; j < head_size; j++) {
10855
                            sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
10856
                        }
10857
10858
                        for (int64_t j = 0; j < head_size; j++) {
10859
                            int64_t t_h_j_offset = t_h_offset + j;
10860
                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
10861
10862
                            float r_val = r[t_h_j_offset];
10863
                            float w_val = w[t_h_j_offset];
10864
                            float k_val = k[t_h_j_offset];
10865
                            float b_val = b[t_h_j_offset];
10866
                            float kv_val = v_val * k_val;
10867
                            float prev_state_val = state_prev[h_2d_i_j_offset];
10868
                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
10869
                            result += state_cur[h_2d_i_j_offset] * r_val;
10870
                        }
10871
                        dst_data[t_h_i_offset] = result;
10872
                    }
10873
                }
10874
            }
10875
        #else
10876
0
            for (int64_t t = 0; t < T; t++) {
10877
0
                int64_t t_offset = t * t_stride;
10878
0
                int64_t state_offset = head_size * C * (t / (T / n_seqs));
10879
0
                float * state_cur = state + state_offset;
10880
0
                float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
10881
10882
0
                for (int64_t h = h_start; h < h_end; h++) {
10883
0
                    int64_t h_offset = h * h_stride;
10884
0
                    int64_t t_h_offset = t_offset + h_offset;
10885
0
                    int64_t h_2d_offset = h * h_stride_2d;
10886
10887
0
                    for (int64_t ii = 0; ii < head_size; ii++) {
10888
0
                        int64_t t_h_i_offset = t_h_offset + ii;
10889
0
                        int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
10890
10891
0
                        GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
10892
10893
0
                        float sa = 0;
10894
0
                        {
10895
0
                            GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
10896
0
                            GGML_F32_VEC ax[GGML_F32_ARR];
10897
0
                            GGML_F32_VEC ay[GGML_F32_ARR];
10898
0
                            for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
10899
0
                                for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
10900
0
                                    ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
10901
0
                                    ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
10902
0
                                    sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
10903
0
                                }
10904
0
                            }
10905
0
                            GGML_F32_VEC_REDUCE(sa, sum);
10906
0
                        }
10907
10908
0
                        GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
10909
10910
0
                        int64_t j = 0;
10911
0
                        GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
10912
0
                        for (; j < head_size; j += GGML_F32_STEP) {
10913
0
                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
10914
0
                                int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
10915
0
                                int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
10916
10917
0
                                GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
10918
0
                                GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
10919
0
                                GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
10920
0
                                GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
10921
10922
0
                                k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
10923
10924
0
                                GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
10925
                                // kv + s * decay + sa * b
10926
0
                                state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
10927
0
                                state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
10928
0
                                GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
10929
10930
0
                                result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
10931
0
                            }
10932
0
                        }
10933
0
                        GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
10934
10935
                        // There shouldn't be left-overs though.
10936
0
                        for (; j < head_size; j++) {
10937
0
                            int64_t t_h_j_offset = t_h_offset + j;
10938
0
                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
10939
10940
0
                            float r_val = r[t_h_j_offset];
10941
0
                            float w_val = w[t_h_j_offset];
10942
0
                            float k_val = k[t_h_j_offset];
10943
0
                            float b_val = b[t_h_j_offset];
10944
0
                            float kv_val = v[t_h_i_offset] * k_val;
10945
10946
0
                            float prev_state_val = state_prev[h_2d_i_j_offset];
10947
0
                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
10948
0
                            dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
10949
0
                        }
10950
0
                    }
10951
0
                }
10952
0
            }
10953
0
        #endif
10954
    #else
10955
        for (int64_t t = 0; t < T; t++) {
10956
            int64_t t_offset = t * t_stride;
10957
            int64_t state_offset = head_size * C * (t / (T / n_seqs));
10958
            float * state_cur = state + state_offset;
10959
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
10960
10961
            for (int64_t h = h_start; h < h_end; h++) {
10962
                int64_t h_offset = h * h_stride;
10963
                int64_t t_h_offset = t_offset + h_offset;
10964
                int64_t h_2d_offset = h * h_stride_2d;
10965
10966
                for (int64_t i = 0; i < head_size; i++) {
10967
                    int64_t t_h_i_offset = t_h_offset + i;
10968
                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
10969
10970
                    float v_val = v[t_h_i_offset];
10971
10972
                    float sa = 0, result = 0;
10973
                    for (int64_t j = 0; j < head_size; j++) {
10974
                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
10975
                    }
10976
10977
                    for (int64_t j = 0; j < head_size; j++) {
10978
                        int64_t t_h_j_offset = t_h_offset + j;
10979
                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
10980
10981
                        float r_val = r[t_h_j_offset];
10982
                        float w_val = w[t_h_j_offset];
10983
                        float k_val = k[t_h_j_offset];
10984
                        float b_val = b[t_h_j_offset];
10985
                        float kv_val = v_val * k_val;
10986
                        float prev_state_val = state_prev[h_2d_i_j_offset];
10987
                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
10988
                        result += state_cur[h_2d_i_j_offset] * r_val;
10989
                    }
10990
                    dst_data[t_h_i_offset] = result;
10991
                }
10992
            }
10993
        }
10994
    #endif
10995
0
}
10996
10997
10998
void ggml_compute_forward_rwkv_wkv7(
10999
        const ggml_compute_params * params,
11000
0
        ggml_tensor * dst) {
11001
11002
0
    const ggml_tensor * src0 = dst->src[0];
11003
11004
0
    switch (src0->type) {
11005
0
        case GGML_TYPE_F32:
11006
0
            {
11007
0
                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
11008
0
            } break;
11009
0
        default:
11010
0
            {
11011
0
                GGML_ABORT("fatal error");
11012
0
            }
11013
0
    }
11014
0
}
11015
11016
// ggml_compute_forward_map_custom1
11017
11018
void ggml_compute_forward_map_custom1(
11019
        const ggml_compute_params * params,
11020
0
              ggml_tensor * dst) {
11021
11022
0
    const ggml_tensor * a = dst->src[0];
11023
11024
0
    struct ggml_map_custom1_op_params p;
11025
0
    memcpy(&p, dst->op_params, sizeof(p));
11026
11027
0
    p.fun(dst, a, params->ith, params->nth, p.userdata);
11028
0
}
11029
11030
// ggml_compute_forward_map_custom2
11031
11032
void ggml_compute_forward_map_custom2(
11033
        const ggml_compute_params * params,
11034
0
              ggml_tensor * dst) {
11035
11036
0
    const ggml_tensor * a = dst->src[0];
11037
0
    const ggml_tensor * b = dst->src[1];
11038
11039
0
    struct ggml_map_custom2_op_params p;
11040
0
    memcpy(&p, dst->op_params, sizeof(p));
11041
11042
0
    p.fun(dst, a, b, params->ith, params->nth, p.userdata);
11043
0
}
11044
11045
// ggml_compute_forward_map_custom3
11046
11047
void ggml_compute_forward_map_custom3(
11048
        const ggml_compute_params * params,
11049
0
              ggml_tensor * dst) {
11050
11051
0
    const ggml_tensor * a = dst->src[0];
11052
0
    const ggml_tensor * b = dst->src[1];
11053
0
    const ggml_tensor * c = dst->src[2];
11054
11055
0
    struct ggml_map_custom3_op_params p;
11056
0
    memcpy(&p, dst->op_params, sizeof(p));
11057
11058
0
    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
11059
0
}
11060
11061
// ggml_compute_forward_custom
11062
11063
void ggml_compute_forward_custom(
11064
    const struct ggml_compute_params * params,
11065
0
          struct ggml_tensor * dst) {
11066
11067
0
    struct ggml_custom_op_params p;
11068
0
    memcpy(&p, dst->op_params, sizeof(p));
11069
11070
0
    p.fun(dst, params->ith, params->nth, p.userdata);
11071
0
}
11072
11073
// ggml_compute_forward_cross_entropy_loss
11074
11075
static void ggml_compute_forward_cross_entropy_loss_f32(
11076
        const ggml_compute_params * params,
11077
0
        ggml_tensor * dst) {
11078
11079
0
    const ggml_tensor * src0 = dst->src[0];
11080
0
    const ggml_tensor * src1 = dst->src[1];
11081
11082
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
11083
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
11084
0
    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
11085
0
    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
11086
0
    GGML_ASSERT(ggml_are_same_shape(src0, src1));
11087
0
    GGML_ASSERT(ggml_is_scalar(dst));
11088
0
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
11089
11090
    // TODO: handle transposed/permuted matrices
11091
0
    const int64_t nc = src0->ne[0];
11092
0
    const int64_t nr = ggml_nrows(src0);
11093
11094
0
    const int ith = params->ith;
11095
0
    const int nth = params->nth;
11096
11097
0
    float * sums =  (float *) params->wdata;
11098
0
    float * st   = ((float *) params->wdata) + nth + ith*nc;
11099
0
    float sum_thread = 0.0f;
11100
11101
0
    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
11102
11103
    // rows per thread
11104
0
    const int64_t dr = (nr + nth - 1)/nth;
11105
11106
    // row range for this thread
11107
0
    const int64_t ir0 = dr*ith;
11108
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
11109
11110
0
    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
11111
0
        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
11112
0
        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
11113
11114
#ifndef NDEBUG
11115
        for (int64_t i = 0; i < nc; ++i) {
11116
            //printf("p[%d] = %f\n", i, p[i]);
11117
            assert(!isnan(s0[i]));
11118
            assert(!isnan(s1[i]));
11119
        }
11120
#endif // NDEBUG
11121
11122
0
        float max = -INFINITY;
11123
0
        ggml_vec_max_f32(nc, &max, s0);
11124
0
        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
11125
0
        assert(sum_softmax >= 0.0);
11126
11127
0
        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
11128
0
        ggml_vec_mul_f32(nc, st, st, s1);
11129
11130
0
        float sum_st = 0.0f;
11131
0
        ggml_vec_sum_f32(nc, &sum_st, st);
11132
0
        sum_thread += sum_st;
11133
11134
#ifndef NDEBUG
11135
        for (int64_t i = 0; i < nc; ++i) {
11136
            assert(!isnan(st[i]));
11137
            assert(!isinf(st[i]));
11138
        }
11139
#endif // NDEBUG
11140
0
    }
11141
0
    sums[ith] = sum_thread;
11142
0
    ggml_barrier(params->threadpool);
11143
11144
0
    if (ith == 0) {
11145
0
        float * dp = (float *) dst->data;
11146
0
        ggml_vec_sum_f32(nth, dp, sums);
11147
0
        dp[0] *= -1.0f / (float) nr;
11148
0
    }
11149
0
}
11150
11151
void ggml_compute_forward_cross_entropy_loss(
11152
        const ggml_compute_params * params,
11153
0
        ggml_tensor * dst) {
11154
11155
0
    const ggml_tensor * src0 = dst->src[0];
11156
11157
0
    switch (src0->type) {
11158
0
        case GGML_TYPE_F32:
11159
0
            {
11160
0
                ggml_compute_forward_cross_entropy_loss_f32(params, dst);
11161
0
            } break;
11162
0
        default:
11163
0
            {
11164
0
                GGML_ABORT("fatal error");
11165
0
            }
11166
0
    }
11167
0
}
11168
11169
// ggml_compute_forward_cross_entropy_loss_back
11170
11171
static void ggml_compute_forward_cross_entropy_loss_back_f32(
11172
        const ggml_compute_params * params,
11173
0
        ggml_tensor * dst) {
11174
11175
0
    const ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
11176
0
    const ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
11177
0
    const ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
11178
11179
0
    GGML_ASSERT(ggml_is_contiguous(dst));
11180
0
    GGML_ASSERT(ggml_is_contiguous(src0f));
11181
0
    GGML_ASSERT(ggml_is_contiguous(src1f));
11182
0
    GGML_ASSERT(ggml_is_contiguous(grad));
11183
0
    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
11184
11185
0
    const int64_t ith = params->ith;
11186
0
    const int64_t nth = params->nth;
11187
11188
    // TODO: handle transposed/permuted matrices
11189
0
    const int64_t nc = src0f->ne[0];
11190
0
    const int64_t nr = ggml_nrows(src0f);
11191
11192
    // rows per thread
11193
0
    const int64_t dr = (nr + nth - 1)/nth;
11194
11195
    // row range for this thread
11196
0
    const int64_t ir0 = dr*ith;
11197
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
11198
11199
0
    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
11200
11201
0
    for (int64_t i1 = ir0; i1 < ir1; i1++) {
11202
0
        float       * ds0 = (float       *)((char       *) dst->data   + i1*dst->nb[1]);
11203
0
        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
11204
0
        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
11205
11206
#ifndef NDEBUG
11207
        for (int64_t i = 0; i < nc; ++i) {
11208
            //printf("p[%d] = %f\n", i, p[i]);
11209
            assert(!isnan(s0[i]));
11210
            assert(!isnan(s1[i]));
11211
        }
11212
#endif // NDEBUG
11213
11214
        // soft_max
11215
0
        float max = -INFINITY;
11216
0
        ggml_vec_max_f32(nc, &max, s0);
11217
0
        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
11218
0
        assert(sum > 0.0);
11219
0
        ggml_vec_scale_f32(nc, ds0, 1.0/sum);
11220
11221
        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
11222
0
        ggml_vec_sub_f32(nc, ds0, ds0, s1);
11223
0
        ggml_vec_scale_f32(nc, ds0, d_by_nr);
11224
11225
#ifndef NDEBUG
11226
        for (int64_t i = 0; i < nc; ++i) {
11227
            assert(!isnan(ds0[i]));
11228
            assert(!isinf(ds0[i]));
11229
        }
11230
#endif // NDEBUG
11231
0
    }
11232
0
}
11233
11234
void ggml_compute_forward_cross_entropy_loss_back(
11235
        const ggml_compute_params * params,
11236
0
        ggml_tensor * dst) {
11237
11238
0
    const ggml_tensor * src0 = dst->src[0];
11239
11240
0
    switch (src0->type) {
11241
0
        case GGML_TYPE_F32:
11242
0
            {
11243
0
                ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
11244
0
            } break;
11245
0
        default:
11246
0
            {
11247
0
                GGML_ABORT("fatal error");
11248
0
            }
11249
0
    }
11250
0
}
11251
11252
static void ggml_compute_forward_opt_step_adamw_f32(
11253
        const ggml_compute_params * params,
11254
0
        ggml_tensor * dst) {
11255
11256
0
    const ggml_tensor * src0         = dst->src[0];
11257
0
    const ggml_tensor * src0_grad    = dst->src[1];
11258
0
    const ggml_tensor * src0_grad_m  = dst->src[2];
11259
0
    const ggml_tensor * src0_grad_v  = dst->src[3];
11260
0
    const ggml_tensor * adamw_params = dst->src[4];
11261
11262
0
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
11263
0
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
11264
0
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
11265
0
    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
11266
11267
0
    const int ith = params->ith;
11268
0
    const int nth = params->nth;
11269
11270
0
    const int nr  = ggml_nrows(src0);
11271
11272
0
    GGML_TENSOR_UNARY_OP_LOCALS
11273
0
    GGML_ASSERT(nb00 == sizeof(float));
11274
11275
    // rows per thread
11276
0
    const int dr = (nr + nth - 1)/nth;
11277
11278
    // row range for this thread
11279
0
    const int ir0 = dr*ith;
11280
0
    const int ir1 = MIN(ir0 + dr, nr);
11281
11282
0
    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
11283
11284
0
    const float alpha  = adamw_params_ptr[0];
11285
0
    const float beta1  = adamw_params_ptr[1];
11286
0
    const float beta2  = adamw_params_ptr[2];
11287
0
    const float eps    = adamw_params_ptr[3];
11288
0
    const float wd     = adamw_params_ptr[4];
11289
0
    const float beta1h = adamw_params_ptr[5];
11290
0
    const float beta2h = adamw_params_ptr[6];
11291
0
    const float keep   = 1.f - alpha * wd;
11292
0
    for (int ir = ir0; ir < ir1; ++ir) {
11293
0
        const int64_t i03 = ir/(ne02*ne01);
11294
0
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
11295
0
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
11296
11297
0
        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
11298
11299
0
        float       * w = (float       *) ((char       *) src0->data        + offset); // weight
11300
0
        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
11301
0
        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);
11302
0
        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);
11303
11304
0
        for (int i00 = 0; i00 < ne00; ++i00) {
11305
0
            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);
11306
0
            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
11307
11308
0
            const float mh =       m[i00]*beta1h;
11309
0
            const float vh = sqrtf(v[i00]*beta2h) + eps;
11310
11311
            // The weight decay is applied independently of the Adam momenta m and v.
11312
            // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
11313
            // See: https://arxiv.org/pdf/1711.05101v3.pdf
11314
0
            w[i00] = w[i00] * keep - alpha * mh / vh;
11315
0
        }
11316
0
    }
11317
0
}
11318
11319
void ggml_compute_forward_opt_step_adamw(
11320
        const ggml_compute_params * params,
11321
0
        ggml_tensor * dst) {
11322
11323
0
    const ggml_tensor * src0 = dst->src[0];
11324
11325
0
    switch (src0->type) {
11326
0
        case GGML_TYPE_F32:
11327
0
            {
11328
0
                ggml_compute_forward_opt_step_adamw_f32(params, dst);
11329
0
            } break;
11330
0
        default:
11331
0
            {
11332
0
                GGML_ABORT("fatal error");
11333
0
            }
11334
0
    }
11335
0
}
11336
11337
0
static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
11338
0
    const ggml_tensor * src0       = dst->src[0];
11339
0
    const ggml_tensor * src0_grad  = dst->src[1];
11340
0
    const ggml_tensor * sgd_params = dst->src[2];
11341
11342
0
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
11343
0
    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
11344
11345
0
    const int ith = params->ith;
11346
0
    const int nth = params->nth;
11347
11348
0
    const int nr = ggml_nrows(src0);
11349
11350
0
    GGML_TENSOR_UNARY_OP_LOCALS
11351
0
    GGML_ASSERT(nb00 == sizeof(float));
11352
11353
    // rows per thread
11354
0
    const int dr = (nr + nth - 1) / nth;
11355
11356
    // row range for this thread
11357
0
    const int ir0 = dr * ith;
11358
0
    const int ir1 = MIN(ir0 + dr, nr);
11359
11360
    // using adamw param subset we care about - alpha, wd - could have a separate struct
11361
0
    const float * sgd_params_ptr   = ggml_get_data_f32(sgd_params);
11362
0
    const float   alpha            = sgd_params_ptr[0];
11363
0
    const float   keep             = 1.f - alpha * sgd_params_ptr[1];
11364
11365
0
    for (int ir = ir0; ir < ir1; ++ir) {
11366
0
        const int64_t i03 = ir / (ne02 * ne01);
11367
0
        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
11368
0
        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
11369
11370
0
        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
11371
11372
0
        float *       w = (float *) ((char *) src0->data + offset);                   // weight
11373
0
        const float * g = (const float *) ((const char *) src0_grad->data + offset);  // grad
11374
11375
0
        for (int i00 = 0; i00 < ne00; ++i00) {
11376
0
            w[i00] = w[i00] * keep - alpha * g[i00];
11377
0
        }
11378
0
    }
11379
0
}
11380
11381
0
void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
11382
0
    const ggml_tensor * src0 = dst->src[0];
11383
11384
0
    switch (src0->type) {
11385
0
        case GGML_TYPE_F32:
11386
0
            {
11387
0
                ggml_compute_forward_opt_step_sgd_f32(params, dst);
11388
0
            }
11389
0
            break;
11390
0
        default:
11391
0
            {
11392
0
                GGML_ABORT("fatal error - sgd is F32 only");
11393
0
            }
11394
0
    }
11395
0
}
11396
11397
0
static void ggml_compute_forward_fwht_f32(const ggml_compute_params * params, ggml_tensor * dst) {
11398
0
    const ggml_tensor * src0 = dst->src[0];
11399
0
    const ggml_tensor * src1 = dst->src[1];
11400
11401
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
11402
0
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
11403
11404
0
    GGML_TENSOR_BINARY_OP_LOCALS
11405
11406
0
    const int ith = params->ith;
11407
0
    const int nth = params->nth;
11408
11409
0
    const int64_t n = ne10;
11410
0
    GGML_ASSERT((n & (n - 1)) == 0); // must be power of 2
11411
11412
0
    const int64_t nr = ne11 * ne12 * ne13;
11413
0
    const int64_t rows_per_thread = (nr + nth - 1) / nth;
11414
0
    const int64_t start_row = ith * rows_per_thread;
11415
0
    const int64_t end_row = MIN(start_row + rows_per_thread, nr);
11416
11417
0
    const float scale = 1.0f / sqrtf((float)n);
11418
11419
0
#if defined(GGML_SIMD)
11420
0
    const GGML_F32_VEC v_minus_one = GGML_F32_VEC_SET1(-1.0f);
11421
0
#endif
11422
11423
0
    for (int64_t r = start_row; r < end_row; r++) {
11424
0
        const int64_t i13 = r / (ne11 * ne12);
11425
0
        const int64_t i12 = (r - i13 * ne11 * ne12) / ne11;
11426
0
        const int64_t i11 = r - i13 * ne11 * ne12 - i12 * ne11;
11427
11428
0
        const float * src_row = (const float *) ((const char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13);
11429
0
        float * dst_row = (float *) ((char *) dst->data + i11 * nb1 + i12 * nb2 + i13 * nb3);
11430
11431
0
        for (int64_t j = 0; j < n; j++) {
11432
0
            dst_row[j] = src_row[j] * scale;
11433
0
        }
11434
11435
        // Scalar passes
11436
0
#if defined(GGML_SIMD)
11437
#if defined(__ARM_FEATURE_SVE)
11438
        const int step = svcntw();
11439
#else
11440
0
        const int step = GGML_F32_EPR;
11441
0
#endif
11442
#else
11443
        const int step = n;
11444
#endif
11445
0
        for (int64_t len = 1; len < step && len < n; len <<= 1) {
11446
0
            for (int64_t i = 0; i < n; i += 2 * len) {
11447
0
                for (int64_t j = 0; j < len; j++) {
11448
0
                    float u = dst_row[i + j];
11449
0
                    float v = dst_row[i + len + j];
11450
0
                    dst_row[i + j] = u + v;
11451
0
                    dst_row[i + len + j] = u - v;
11452
0
                }
11453
0
            }
11454
0
        }
11455
11456
        // SIMD passes using GGML_F32_VEC_* macros for multi-architecture support
11457
0
#if defined(GGML_SIMD)
11458
0
        for (int64_t len = step; len < n; len <<= 1) {
11459
0
            for (int64_t i = 0; i < n; i += 2 * len) {
11460
0
                for (int64_t j = 0; j < len; j += step) {
11461
0
                    GGML_F32_VEC u = GGML_F32_VEC_LOAD(dst_row + i + j);
11462
0
                    GGML_F32_VEC v = GGML_F32_VEC_LOAD(dst_row + i + len + j);
11463
11464
0
                    GGML_F32_VEC_STORE(dst_row + i + j,       GGML_F32_VEC_ADD(u, v));
11465
0
                    GGML_F32_VEC_STORE(dst_row + i + len + j, GGML_F32_VEC_FMA(u, v, v_minus_one));
11466
0
                }
11467
0
            }
11468
0
        }
11469
0
#endif
11470
0
    }
11471
0
}
11472
11473
0
void ggml_compute_forward_fwht(const ggml_compute_params * params, ggml_tensor * dst) {
11474
0
    const ggml_tensor * src1 = dst->src[1];
11475
11476
0
    switch (src1->type) {
11477
0
        case GGML_TYPE_F32:
11478
0
            {
11479
0
                ggml_compute_forward_fwht_f32(params, dst);
11480
0
            }
11481
0
            break;
11482
0
        default:
11483
0
            {
11484
0
                GGML_ABORT("fatal error - fwht is F32 only");
11485
0
            }
11486
0
    }
11487
0
}