Coverage Report

Created: 2025-12-28 06:26

/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp
Line | Count | Source
1
#include "ops.h"
2
3
#include "ggml-cpu.h"
4
#include "ggml-impl.h"
5
#include "binary-ops.h"
6
#include "ggml.h"
7
#include "unary-ops.h"
8
#include "vec.h"
9
10
#include <cfloat>
11
#include <algorithm>
12
#include <cmath>
13
#include <functional>
14
15
// ggml_compute_forward_dup
16
17
static void ggml_compute_forward_dup_same_cont(
18
        const ggml_compute_params * params,
19
0
        ggml_tensor * dst) {
20
21
0
    const ggml_tensor * src0 = dst->src[0];
22
23
0
    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
24
0
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
25
0
    GGML_ASSERT(src0->type == dst->type);
26
27
0
    const size_t nb0 = ggml_type_size(src0->type);
28
29
0
    const int ith = params->ith; // thread index
30
0
    const int nth = params->nth; // number of threads
31
32
    // parallelize by blocks
33
0
    const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);
34
0
    const int dr = (nk + nth - 1) / nth;
35
0
    const int k0 = dr * ith;
36
0
    const int k1 = MIN(k0 + dr, nk);
37
38
0
    if (k0 < k1) {
39
0
        memcpy(
40
0
            ((char *)  dst->data + k0*nb0),
41
0
            ((char *) src0->data + k0*nb0),
42
0
            (k1 - k0) * nb0);
43
0
    }
44
0
}
45
46
template<typename src_t, typename dst_t>
47
static void ggml_compute_forward_dup_flt(
48
        const ggml_compute_params * params,
49
0
        ggml_tensor * dst) {
50
51
0
    const ggml_tensor * src0 = dst->src[0];
52
53
0
    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
54
0
    GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));
55
56
0
    GGML_TENSOR_UNARY_OP_LOCALS
57
58
0
    const int ith = params->ith; // thread index
59
0
    const int nth = params->nth; // number of threads
60
61
    // parallelize by rows
62
0
    const int nr = ne01;
63
    // number of rows per thread
64
0
    const int dr = (nr + nth - 1) / nth;
65
    // row range for this thread
66
0
    const int ir0 = dr * ith;
67
0
    const int ir1 = MIN(ir0 + dr, nr);
68
69
    // case: type & row size equal
70
0
    if (src0->type == dst->type &&
71
0
        ne00 == ne0 &&
72
0
        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
73
        // copy by rows
74
0
        const size_t rs = ne00*nb00;
75
0
        for (int64_t i03 = 0; i03 < ne03; i03++) {
76
0
            for (int64_t i02 = 0; i02 < ne02; i02++) {
77
0
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
78
0
                    memcpy(
79
0
                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
80
0
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
81
0
                        rs);
82
0
                }
83
0
            }
84
0
        }
85
0
        return;
86
0
    }
87
88
    // case: dst tensor is contiguous
89
0
    if (ggml_is_contiguous(dst)) {
90
0
        if (nb00 == sizeof(src_t)) {
91
0
            if constexpr (std::is_same_v<dst_t, src_t>) {
92
                // same type
93
0
                size_t id = 0;
94
0
                const size_t rs = ne00 * nb00;
95
0
                char * dst_ptr = (char *) dst->data;
96
97
0
                for (int i03 = 0; i03 < ne03; i03++) {
98
0
                    for (int i02 = 0; i02 < ne02; i02++) {
99
0
                        id += rs * ir0;
100
0
                        for (int i01 = ir0; i01 < ir1; i01++) {
101
0
                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
102
0
                            memcpy(dst_ptr + id, src0_ptr, rs);
103
0
                            id += rs;
104
0
                        }
105
0
                        id += rs * (ne01 - ir1);
106
0
                    }
107
0
                }
108
0
            } else {
109
                // casting between non-quantized types
110
0
                size_t id = 0;
111
0
                dst_t * dst_ptr = (dst_t *) dst->data;
112
113
0
                for (int i03 = 0; i03 < ne03; i03++) {
114
0
                    for (int i02 = 0; i02 < ne02; i02++) {
115
0
                        id += ne00 * ir0;
116
0
                        for (int i01 = ir0; i01 < ir1; i01++) {
117
0
                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
118
0
                            for (int i00 = 0; i00 < ne00; i00++) {
119
0
                                float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
120
0
                                dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
121
0
                                id++;
122
0
                            }
123
0
                        }
124
0
                        id += ne00 * (ne01 - ir1);
125
0
                    }
126
0
                }
127
0
            }
128
0
        } else {
129
            //printf("%s: this is not optimal - fix me\n", __func__);
130
131
0
            size_t id = 0;
132
0
            dst_t * dst_ptr = (dst_t *) dst->data;
133
134
0
            for (int i03 = 0; i03 < ne03; i03++) {
135
0
                for (int i02 = 0; i02 < ne02; i02++) {
136
0
                    id += ne00 * ir0;
137
0
                    for (int i01 = ir0; i01 < ir1; i01++) {
138
0
                        for (int i00 = 0; i00 < ne00; i00++) {
139
0
                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
140
141
0
                            float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
142
0
                            dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
143
0
                            id++;
144
0
                        }
145
0
                    }
146
0
                    id += ne00 * (ne01 - ir1);
147
0
                }
148
0
            }
149
0
        }
150
0
        return;
151
0
    }
152
153
    // dst counters
154
0
    int64_t i10 = 0;
155
0
    int64_t i11 = 0;
156
0
    int64_t i12 = 0;
157
0
    int64_t i13 = 0;
158
159
0
    if constexpr (std::is_same_v<dst_t, src_t>) {
160
0
        for (int64_t i03 = 0; i03 < ne03; i03++) {
161
0
            for (int64_t i02 = 0; i02 < ne02; i02++) {
162
0
                i10 += ne00 * ir0;
163
0
                while (i10 >= ne0) {
164
0
                    i10 -= ne0;
165
0
                    if (++i11 == ne1) {
166
0
                        i11 = 0;
167
0
                        if (++i12 == ne2) {
168
0
                            i12 = 0;
169
0
                            if (++i13 == ne3) {
170
0
                                i13 = 0;
171
0
                            }
172
0
                        }
173
0
                    }
174
0
                }
175
0
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
176
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
177
0
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
178
0
                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
179
180
0
                        memcpy(dst_ptr, src0_ptr, sizeof(dst_t));
181
182
0
                        if (++i10 == ne00) {
183
0
                            i10 = 0;
184
0
                            if (++i11 == ne01) {
185
0
                                i11 = 0;
186
0
                                if (++i12 == ne02) {
187
0
                                    i12 = 0;
188
0
                                    if (++i13 == ne03) {
189
0
                                        i13 = 0;
190
0
                                    }
191
0
                                }
192
0
                            }
193
0
                        }
194
0
                    }
195
0
                }
196
0
                i10 += ne00 * (ne01 - ir1);
197
0
                while (i10 >= ne0) {
198
0
                    i10 -= ne0;
199
0
                    if (++i11 == ne1) {
200
0
                        i11 = 0;
201
0
                        if (++i12 == ne2) {
202
0
                            i12 = 0;
203
0
                            if (++i13 == ne3) {
204
0
                                i13 = 0;
205
0
                            }
206
0
                        }
207
0
                    }
208
0
                }
209
0
            }
210
0
        }
211
212
0
    } else {
213
0
        for (int64_t i03 = 0; i03 < ne03; i03++) {
214
0
            for (int64_t i02 = 0; i02 < ne02; i02++) {
215
0
                i10 += ne00 * ir0;
216
0
                while (i10 >= ne0) {
217
0
                    i10 -= ne0;
218
0
                    if (++i11 == ne1) {
219
0
                        i11 = 0;
220
0
                        if (++i12 == ne2) {
221
0
                            i12 = 0;
222
0
                            if (++i13 == ne3) {
223
0
                                i13 = 0;
224
0
                            }
225
0
                        }
226
0
                    }
227
0
                }
228
0
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
229
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
230
0
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
231
0
                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
232
233
0
                        float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
234
0
                        *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);
235
236
0
                        if (++i10 == ne0) {
237
0
                            i10 = 0;
238
0
                            if (++i11 == ne1) {
239
0
                                i11 = 0;
240
0
                                if (++i12 == ne2) {
241
0
                                    i12 = 0;
242
0
                                    if (++i13 == ne3) {
243
0
                                        i13 = 0;
244
0
                                    }
245
0
                                }
246
0
                            }
247
0
                        }
248
0
                    }
249
0
                }
250
0
                i10 += ne00 * (ne01 - ir1);
251
0
                while (i10 >= ne0) {
252
0
                    i10 -= ne0;
253
0
                    if (++i11 == ne1) {
254
0
                        i11 = 0;
255
0
                        if (++i12 == ne2) {
256
0
                            i12 = 0;
257
0
                            if (++i13 == ne3) {
258
0
                                i13 = 0;
259
0
                            }
260
0
                        }
261
0
                    }
262
0
                }
263
0
            }
264
0
        }
265
0
    }
266
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<unsigned short, unsigned short>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<unsigned short, ggml_bf16_t>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<unsigned short, float>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<ggml_bf16_t, unsigned short>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<ggml_bf16_t, float>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<float, unsigned short>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<float, ggml_bf16_t>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<float, float>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<float, int>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<int, float>(ggml_compute_params const*, ggml_tensor*)
267
268
269
template<typename src_t>
270
static void ggml_compute_forward_dup_to_q(
271
        const ggml_compute_params * params,
272
0
        ggml_tensor * dst) {
273
274
0
    const ggml_tensor * src0 = dst->src[0];
275
276
0
    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
277
0
    GGML_ASSERT(!ggml_is_quantized(src0->type));
278
279
0
    GGML_TENSOR_UNARY_OP_LOCALS
280
281
0
    const int ith = params->ith; // thread index
282
0
    const int nth = params->nth; // number of threads
283
284
    // parallelize by rows
285
0
    const int nr = ne01;
286
    // number of rows per thread
287
0
    const int dr = (nr + nth - 1) / nth;
288
    // row range for this thread
289
0
    const int ir0 = dr * ith;
290
0
    const int ir1 = MIN(ir0 + dr, nr);
291
292
0
    if (ggml_is_contiguous(dst) &&
293
0
            nb00 == sizeof(src_t) &&
294
0
            ggml_get_type_traits_cpu(dst->type)->from_float) {
295
        // casting non-quantized types --> intermediate f32 --> quantized
296
0
        ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
297
0
        float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
298
299
0
        size_t id = 0;
300
0
        size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
301
0
        char * dst_ptr = (char *) dst->data;
302
303
0
        for (int i03 = 0; i03 < ne03; i03++) {
304
0
            for (int i02 = 0; i02 < ne02; i02++) {
305
0
                id += rs * ir0;
306
0
                for (int i01 = ir0; i01 < ir1; i01++) {
307
0
                    const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
308
309
0
                    for (int i00 = 0; i00 < ne00; i00++) {
310
0
                        src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
311
0
                    }
312
313
0
                    quantize_row_q(src0_f32, dst_ptr + id, ne00);
314
0
                    id += rs;
315
0
                }
316
0
                id += rs * (ne01 - ir1);
317
0
            }
318
0
        }
319
0
    } else {
320
        // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
321
0
        GGML_ABORT("not implemented");
322
0
    }
323
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_to_q<unsigned short>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_to_q<ggml_bf16_t>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_to_q<float>(ggml_compute_params const*, ggml_tensor*)
324
325
// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just does a plain memcpy.
326
static void ggml_compute_forward_dup_bytes(
327
        const ggml_compute_params * params,
328
0
        ggml_tensor * dst) {
329
0
    const ggml_tensor * src0 = dst->src[0];
330
331
0
    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
332
0
    GGML_ASSERT(src0->type == dst->type);
333
334
0
    GGML_TENSOR_UNARY_OP_LOCALS;
335
336
0
    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
337
0
        ggml_compute_forward_dup_same_cont(params, dst);
338
0
        return;
339
0
    }
340
341
0
    const size_t type_size = ggml_type_size(src0->type);
342
343
0
    const int ith = params->ith; // thread index
344
0
    const int nth = params->nth; // number of threads
345
346
    // parallelize by rows
347
0
    const int nr = ne01;
348
    // number of rows per thread
349
0
    const int dr = (nr + nth - 1) / nth;
350
    // row range for this thread
351
0
    const int ir0 = dr * ith;
352
0
    const int ir1 = MIN(ir0 + dr, nr);
353
354
0
    if (src0->type == dst->type &&
355
0
        ggml_are_same_shape(src0, dst) &&
356
0
        nb00 == type_size && nb0 == type_size) {
357
        // copy by rows
358
0
        const size_t rs = ggml_row_size(src0->type, ne00);
359
0
        for (int64_t i03 = 0; i03 < ne03; i03++) {
360
0
            for (int64_t i02 = 0; i02 < ne02; i02++) {
361
0
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
362
0
                    memcpy(
363
0
                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
364
0
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
365
0
                        rs);
366
0
                }
367
0
            }
368
0
        }
369
0
        return;
370
0
    }
371
372
0
    if (ggml_is_contiguous(dst)) {
373
0
        size_t id = 0;
374
0
        char * dst_ptr = (char *) dst->data;
375
0
        const size_t rs = ne00 * type_size;
376
377
0
        if (nb00 == type_size) {
378
            // src0 is contiguous on first dimension, copy by rows
379
0
            for (int64_t i03 = 0; i03 < ne03; i03++) {
380
0
                for (int64_t i02 = 0; i02 < ne02; i02++) {
381
0
                    id += rs * ir0;
382
0
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
383
0
                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
384
0
                        memcpy(dst_ptr + id, src0_ptr, rs);
385
0
                        id += rs;
386
0
                    }
387
0
                    id += rs * (ne01 - ir1);
388
0
                }
389
0
            }
390
0
        } else {
391
            //printf("%s: this is not optimal - fix me\n", __func__);
392
393
0
            for (int64_t i03 = 0; i03 < ne03; i03++) {
394
0
                for (int64_t i02 = 0; i02 < ne02; i02++) {
395
0
                    id += rs * ir0;
396
0
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
397
0
                        for (int64_t i00 = 0; i00 < ne00; i00++) {
398
0
                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
399
0
                            memcpy(dst_ptr + id, src0_ptr, type_size);
400
401
0
                            id += type_size;
402
0
                        }
403
0
                    }
404
0
                    id += rs * (ne01 - ir1);
405
0
                }
406
0
            }
407
0
        }
408
409
0
        return;
410
0
    }
411
412
    // dst counters
413
0
    int64_t k10 = 0;
414
0
    int64_t i11 = 0;
415
0
    int64_t i12 = 0;
416
0
    int64_t i13 = 0;
417
418
    // number of blocks in a row
419
0
    const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
420
0
    const int64_t nk0  = ne0  / ggml_blck_size(dst->type);
421
422
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
423
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
424
0
            k10 += nk00 * ir0;
425
0
            while (k10 >= nk0) {
426
0
                k10 -= nk0;
427
0
                if (++i11 == ne1) {
428
0
                    i11 = 0;
429
0
                    if (++i12 == ne2) {
430
0
                        i12 = 0;
431
0
                        if (++i13 == ne3) {
432
0
                            i13 = 0;
433
0
                        }
434
0
                    }
435
0
                }
436
0
            }
437
0
            for (int64_t i01 = ir0; i01 < ir1; i01++) {
438
0
                for (int64_t k00 = 0; k00 < nk00; k00++) {
439
0
                    const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
440
0
                          char * dst_ptr  = ((char *)  dst->data + k10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
441
442
0
                    memcpy(dst_ptr, src0_ptr, type_size);
443
444
0
                    if (++k10 == nk0) {
445
0
                        k10 = 0;
446
0
                        if (++i11 == ne1) {
447
0
                            i11 = 0;
448
0
                            if (++i12 == ne2) {
449
0
                                i12 = 0;
450
0
                                if (++i13 == ne3) {
451
0
                                    i13 = 0;
452
0
                                }
453
0
                            }
454
0
                        }
455
0
                    }
456
0
                }
457
0
            }
458
0
            k10 += nk00 * (ne01 - ir1);
459
0
            while (k10 >= nk0) {
460
0
                k10 -= nk0;
461
0
                if (++i11 == ne1) {
462
0
                    i11 = 0;
463
0
                    if (++i12 == ne2) {
464
0
                        i12 = 0;
465
0
                        if (++i13 == ne3) {
466
0
                            i13 = 0;
467
0
                        }
468
0
                    }
469
0
                }
470
0
            }
471
0
        }
472
0
    }
473
0
}
474
475
static void ggml_compute_forward_dup_from_q(
476
        const ggml_compute_params * params,
477
0
              ggml_tensor * dst) {
478
479
0
    const ggml_tensor * src0 = dst->src[0];
480
0
    const ggml_tensor * src1 = dst->src[1];
481
482
0
    GGML_TENSOR_BINARY_OP_LOCALS
483
484
0
    const ggml_type type = src0->type;
485
0
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
486
487
0
    size_t qk = ggml_blck_size(type);
488
0
    const int64_t nr = ggml_nelements(src1) / qk;
489
490
    // destination must be contiguous in the first dimension
491
0
    GGML_ASSERT(nb10 == ggml_type_size(dst->type));
492
    // must either have first dimension large enough to hold a row, or fully contiguous
493
0
    GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));
494
495
0
    const int ith = params->ith;
496
0
    const int nth = params->nth;
497
498
0
    const int dr = (nr + nth - 1)/nth;
499
500
    // row range for this thread
501
0
    const int ir0 = dr*ith;
502
0
    const int ir1 = MIN(ir0 + dr, nr);
503
504
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
505
506
0
        uint32_t i = ir * qk;
507
508
0
        const int64_t i03 = i/(ne00 * ne01 * ne02);
509
0
        const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
510
0
        const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
511
0
        const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
512
0
        const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
513
514
0
        const int64_t i13 = i/(ne10 * ne11 * ne12);
515
0
        const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
516
0
        const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
517
0
        const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
518
0
        const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
519
520
0
        dequantize_row_q(
521
0
                (const void *) ((char *) src0->data + x_offset),
522
0
                     (float *) ((char *)  dst->data + dst_offset), qk);
523
0
    }
524
0
}
525
526
void ggml_compute_forward_dup(
527
        const ggml_compute_params * params,
528
0
        ggml_tensor * dst) {
529
530
0
    const ggml_tensor * src0 = dst->src[0];
531
532
0
    if (src0->type == dst->type) {
533
0
        ggml_compute_forward_dup_bytes(params, dst);
534
0
        return;
535
0
    }
536
537
0
    switch (src0->type) {
538
0
        case GGML_TYPE_F16:
539
0
            {
540
0
                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
541
0
                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
542
0
                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_fp16_t, float      >(params, dst);
543
0
                else ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
544
0
            } break;
545
0
        case GGML_TYPE_BF16:
546
0
            {
547
0
                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
548
0
                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
549
0
                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_bf16_t, float      >(params, dst);
550
0
                else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
551
0
            } break;
552
0
        case GGML_TYPE_F32:
553
0
            {
554
0
                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
555
0
                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
556
0
                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<float, float      >(params, dst);
557
0
                else if (dst->type == GGML_TYPE_I32)  ggml_compute_forward_dup_flt<float, int32_t    >(params, dst);
558
0
                else ggml_compute_forward_dup_to_q<float>(params, dst);
559
0
            } break;
560
0
        case GGML_TYPE_I32:
561
0
            {
562
0
                if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
563
0
                else GGML_ABORT("not implemented");
564
0
            } break;
565
0
        default:
566
0
            {
567
0
                if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
568
0
                    ggml_compute_forward_dup_from_q(params, dst);
569
0
                    break;
570
0
                }
571
0
                GGML_ABORT("fatal error");
572
0
            }
573
0
    }
574
0
}
575
576
// ggml_compute_forward_add
577
578
static void ggml_compute_forward_add_q_f32(
579
        const ggml_compute_params * params,
580
0
        ggml_tensor * dst) {
581
582
0
    const ggml_tensor * src0 = dst->src[0];
583
0
    const ggml_tensor * src1 = dst->src[1];
584
585
0
    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
586
587
0
    const int nr  = ggml_nrows(src0);
588
589
0
    GGML_TENSOR_BINARY_OP_LOCALS
590
591
0
    const int ith = params->ith;
592
0
    const int nth = params->nth;
593
594
0
    const ggml_type type = src0->type;
595
0
    const ggml_type dtype = dst->type;
596
0
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
597
0
    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dtype)->from_float;
598
599
    // we don't support permuted src0 or src1
600
0
    GGML_ASSERT(nb00 == ggml_type_size(type));
601
0
    GGML_ASSERT(nb10 == sizeof(float));
602
603
    // dst cannot be transposed or permuted
604
0
    GGML_ASSERT(nb0 <= nb1);
605
0
    GGML_ASSERT(nb1 <= nb2);
606
0
    GGML_ASSERT(nb2 <= nb3);
607
608
0
    GGML_ASSERT(ggml_is_quantized(src0->type));
609
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
610
611
    // rows per thread
612
0
    const int dr = (nr + nth - 1)/nth;
613
614
    // row range for this thread
615
0
    const int ir0 = dr*ith;
616
0
    const int ir1 = MIN(ir0 + dr, nr);
617
618
0
    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
619
620
0
    for (int ir = ir0; ir < ir1; ++ir) {
621
        // src0 indices
622
0
        const int i03 = ir/(ne02*ne01);
623
0
        const int i02 = (ir - i03*ne02*ne01)/ne01;
624
0
        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
625
626
        // src1 and dst are same shape as src0 => same indices
627
0
        const int i13 = i03;
628
0
        const int i12 = i02;
629
0
        const int i11 = i01;
630
631
0
        const int i3 = i03;
632
0
        const int i2 = i02;
633
0
        const int i1 = i01;
634
635
0
        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
636
0
        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
637
0
        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
638
639
0
        assert(ne00 % 32 == 0);
640
641
        // dequantize row from src0 to temp buffer
642
0
        dequantize_row_q(src0_row, wdata, ne00);
643
        // add src1
644
0
        ggml_vec_acc_f32(ne00, wdata, src1_row);
645
        // quantize row to dst
646
0
        if (quantize_row_q != NULL) {
647
0
            quantize_row_q(wdata, dst_row, ne00);
648
0
        } else {
649
0
            memcpy(dst_row, wdata, ne0*nb0);
650
0
        }
651
0
    }
652
0
}
653
654
void ggml_compute_forward_add(
655
        const ggml_compute_params * params,
656
0
        ggml_tensor * dst) {
657
658
0
    const ggml_tensor * src0 = dst->src[0];
659
660
0
    switch (src0->type) {
661
0
        case GGML_TYPE_F32:
662
0
        case GGML_TYPE_F16:
663
0
        case GGML_TYPE_BF16:
664
0
            {
665
0
                ggml_compute_forward_add_non_quantized(params, dst);
666
0
            } break;
667
0
        case GGML_TYPE_Q4_0:
668
0
        case GGML_TYPE_Q4_1:
669
0
        case GGML_TYPE_Q5_0:
670
0
        case GGML_TYPE_Q5_1:
671
0
        case GGML_TYPE_Q8_0:
672
0
        case GGML_TYPE_MXFP4:
673
0
        case GGML_TYPE_Q2_K:
674
0
        case GGML_TYPE_Q3_K:
675
0
        case GGML_TYPE_Q4_K:
676
0
        case GGML_TYPE_Q5_K:
677
0
        case GGML_TYPE_Q6_K:
678
0
        case GGML_TYPE_TQ1_0:
679
0
        case GGML_TYPE_TQ2_0:
680
0
        case GGML_TYPE_IQ2_XXS:
681
0
        case GGML_TYPE_IQ2_XS:
682
0
        case GGML_TYPE_IQ3_XXS:
683
0
        case GGML_TYPE_IQ1_S:
684
0
        case GGML_TYPE_IQ1_M:
685
0
        case GGML_TYPE_IQ4_NL:
686
0
        case GGML_TYPE_IQ4_XS:
687
0
        case GGML_TYPE_IQ3_S:
688
0
        case GGML_TYPE_IQ2_S:
689
0
            {
690
0
                ggml_compute_forward_add_q_f32(params, dst);
691
0
            } break;
692
0
        default:
693
0
            {
694
0
                GGML_ABORT("fatal error");
695
0
            }
696
0
    }
697
0
}
698
699
// ggml_compute_forward_add_id
700
701
static void ggml_compute_forward_add_id_f32(
702
        const ggml_compute_params * params,
703
0
        ggml_tensor * dst) {
704
705
0
    const ggml_tensor * src0 = dst->src[0];
706
0
    const ggml_tensor * src1 = dst->src[1];
707
0
    const ggml_tensor * src2 = dst->src[2];
708
709
0
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
710
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
711
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
712
0
    GGML_ASSERT(src2->type == GGML_TYPE_I32);
713
714
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
715
0
    GGML_ASSERT(src1->nb[0] == sizeof(float));
716
717
0
    const int ith = params->ith;
718
0
    const int nth = params->nth;
719
720
0
    const int nr  = ggml_nrows(src0);
721
722
0
    GGML_TENSOR_TERNARY_OP_LOCALS
723
724
0
    GGML_ASSERT( nb0 == sizeof(float));
725
0
    GGML_ASSERT(nb10 == sizeof(float));
726
727
    // rows per thread
728
0
    const int dr = (nr + nth - 1)/nth;
729
730
    // row range for this thread
731
0
    const int ir0 = dr*ith;
732
0
    const int ir1 = MIN(ir0 + dr, nr);
733
734
0
    for (int ir = ir0; ir < ir1; ++ir) {
735
        // src0 indices
736
0
        const int i3 = ir/(ne2*ne1);
737
0
        const int i2 = (ir - i3*ne2*ne1)/ne1;
738
0
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
739
740
        // src1 indices
741
0
        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
742
743
0
        GGML_ASSERT(i11 >= 0 && i11 < ne11);
744
745
0
        ggml_vec_add_f32(ne0,
746
0
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
747
0
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
748
0
                (float *) ((char *) src1->data + i11*nb11));
749
0
    }
750
0
}
751
752
void ggml_compute_forward_add_id(
753
        const ggml_compute_params * params,
754
0
        ggml_tensor * dst) {
755
756
0
    const ggml_tensor * src0 = dst->src[0];
757
758
0
    switch (src0->type) {
759
0
        case GGML_TYPE_F32:
760
0
            {
761
0
                ggml_compute_forward_add_id_f32(params, dst);
762
0
            } break;
763
0
        default:
764
0
            {
765
0
                GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
766
0
            }
767
0
    }
768
0
}
769
770
// ggml_compute_forward_add1
771
772
static void ggml_compute_forward_add1_f32(
773
        const ggml_compute_params * params,
774
0
        ggml_tensor * dst) {
775
776
0
    const ggml_tensor * src0 = dst->src[0];
777
0
    const ggml_tensor * src1 = dst->src[1];
778
779
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
780
0
    GGML_ASSERT(ggml_is_scalar(src1));
781
782
0
    const int ith = params->ith;
783
0
    const int nth = params->nth;
784
785
0
    const int nr  = ggml_nrows(src0);
786
787
0
    GGML_TENSOR_UNARY_OP_LOCALS
788
789
0
    GGML_ASSERT( nb0 == sizeof(float));
790
0
    GGML_ASSERT(nb00 == sizeof(float));
791
792
    // rows per thread
793
0
    const int dr = (nr + nth - 1)/nth;
794
795
    // row range for this thread
796
0
    const int ir0 = dr*ith;
797
0
    const int ir1 = MIN(ir0 + dr, nr);
798
799
0
    for (int ir = ir0; ir < ir1; ++ir) {
800
        // src0 and dst are same shape => same indices
801
0
        const int i3 = ir/(ne2*ne1);
802
0
        const int i2 = (ir - i3*ne2*ne1)/ne1;
803
0
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
804
805
#ifdef GGML_USE_ACCELERATE
806
        GGML_UNUSED(ggml_vec_add1_f32);
807
808
        vDSP_vadd(
809
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
810
                (float *) ((char *) src1->data), 0,
811
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
812
                ne0);
813
#else
814
0
        ggml_vec_add1_f32(ne0,
815
0
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
816
0
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
817
0
               *(float *) src1->data);
818
0
#endif
819
0
    }
820
0
}
821
822
static void ggml_compute_forward_add1_f16_f32(
823
        const ggml_compute_params * params,
824
0
        ggml_tensor * dst) {
825
826
0
    const ggml_tensor * src0 = dst->src[0];
827
0
    const ggml_tensor * src1 = dst->src[1];
828
829
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
830
0
    GGML_ASSERT(ggml_is_scalar(src1));
831
832
    // scalar to add
833
0
    const float v = *(float *) src1->data;
834
835
0
    const int ith = params->ith;
836
0
    const int nth = params->nth;
837
838
0
    const int nr  = ggml_nrows(src0);
839
840
0
    GGML_TENSOR_UNARY_OP_LOCALS
841
842
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
843
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
844
0
    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
845
846
0
    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
847
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
848
849
    // rows per thread
850
0
    const int dr = (nr + nth - 1)/nth;
851
852
    // row range for this thread
853
0
    const int ir0 = dr*ith;
854
0
    const int ir1 = MIN(ir0 + dr, nr);
855
856
0
    for (int ir = ir0; ir < ir1; ++ir) {
857
        // src0 and dst are same shape => same indices
858
0
        const int i3 = ir/(ne2*ne1);
859
0
        const int i2 = (ir - i3*ne2*ne1)/ne1;
860
0
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
861
862
0
        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
863
0
        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
864
0
        for (int i = 0; i < ne0; i++) {
865
0
            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
866
0
        }
867
0
    }
868
0
}
869
870
static void ggml_compute_forward_add1_f16_f16(
871
        const ggml_compute_params * params,
872
0
        ggml_tensor * dst) {
873
874
0
    const ggml_tensor * src0 = dst->src[0];
875
0
    const ggml_tensor * src1 = dst->src[1];
876
877
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
878
0
    GGML_ASSERT(ggml_is_scalar(src1));
879
880
    // scalar to add
881
0
    const float v = GGML_CPU_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
882
883
0
    const int ith = params->ith;
884
0
    const int nth = params->nth;
885
886
0
    const int nr  = ggml_nrows(src0);
887
888
0
    GGML_TENSOR_UNARY_OP_LOCALS
889
890
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
891
0
    GGML_ASSERT(src1->type == GGML_TYPE_F16);
892
0
    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
893
894
0
    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
895
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
896
897
    // rows per thread
898
0
    const int dr = (nr + nth - 1)/nth;
899
900
    // row range for this thread
901
0
    const int ir0 = dr*ith;
902
0
    const int ir1 = MIN(ir0 + dr, nr);
903
904
0
    for (int ir = ir0; ir < ir1; ++ir) {
905
        // src0 and dst are same shape => same indices
906
0
        const int i3 = ir/(ne2*ne1);
907
0
        const int i2 = (ir - i3*ne2*ne1)/ne1;
908
0
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
909
910
0
        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
911
0
        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
912
0
        for (int i = 0; i < ne0; i++) {
913
0
            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
914
0
        }
915
0
    }
916
0
}
917
918
static void ggml_compute_forward_add1_q_f32(
919
        const ggml_compute_params * params,
920
0
        ggml_tensor * dst) {
921
922
0
    const ggml_tensor * src0 = dst->src[0];
923
0
    const ggml_tensor * src1 = dst->src[1];
924
925
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
926
0
    GGML_ASSERT(ggml_is_scalar(src1));
927
928
    // scalar to add
929
0
    const float v = *(float *) src1->data;
930
931
0
    const int ith = params->ith;
932
0
    const int nth = params->nth;
933
934
0
    const int nr  = ggml_nrows(src0);
935
936
0
    GGML_TENSOR_UNARY_OP_LOCALS
937
938
0
    const ggml_type type = src0->type;
939
0
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
940
0
    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type)->from_float;
941
942
    // we don't support permuted src0
943
0
    GGML_ASSERT(nb00 == ggml_type_size(type));
944
945
    // dst cannot be transposed or permuted
946
0
    GGML_ASSERT(nb0 <= nb1);
947
0
    GGML_ASSERT(nb1 <= nb2);
948
0
    GGML_ASSERT(nb2 <= nb3);
949
950
0
    GGML_ASSERT(ggml_is_quantized(src0->type));
951
0
    GGML_ASSERT(dst->type == src0->type);
952
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
953
954
    // rows per thread
955
0
    const int dr = (nr + nth - 1)/nth;
956
957
    // row range for this thread
958
0
    const int ir0 = dr*ith;
959
0
    const int ir1 = MIN(ir0 + dr, nr);
960
961
0
    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
962
963
0
    for (int ir = ir0; ir < ir1; ++ir) {
964
        // src0 and dst are same shape => same indices
965
0
        const int i3 = ir/(ne2*ne1);
966
0
        const int i2 = (ir - i3*ne2*ne1)/ne1;
967
0
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
968
969
0
        void  * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
970
0
        void  * dst_row  = (void *) ((char *)  dst->data + (i1*nb1  + i2*nb2  + i3*nb0 ));
971
972
0
        assert(ne0 % 32 == 0);
973
974
        // dequantize row from src0 to temp buffer
975
0
        dequantize_row_q(src0_row, wdata, ne0);
976
        // add src1
977
0
        ggml_vec_acc1_f32(ne0, wdata, v);
978
        // quantize row to dst
979
0
        quantize_row_q(wdata, dst_row, ne0);
980
0
    }
981
0
}
982
983
static void ggml_compute_forward_add1_bf16_f32(
984
        const ggml_compute_params * params,
985
0
        ggml_tensor * dst) {
986
987
0
    const ggml_tensor * src0 = dst->src[0];
988
0
    const ggml_tensor * src1 = dst->src[1];
989
990
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
991
0
    GGML_ASSERT(ggml_is_scalar(src1));
992
993
    // scalar to add
994
0
    const float v = *(float *) src1->data;
995
996
0
    const int ith = params->ith;
997
0
    const int nth = params->nth;
998
999
0
    const int nr  = ggml_nrows(src0);
1000
1001
0
    GGML_TENSOR_UNARY_OP_LOCALS
1002
1003
0
    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
1004
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
1005
0
    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
1006
1007
0
    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
1008
0
    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
1009
1010
    // rows per thread
1011
0
    const int dr = (nr + nth - 1)/nth;
1012
1013
    // row range for this thread
1014
0
    const int ir0 = dr*ith;
1015
0
    const int ir1 = MIN(ir0 + dr, nr);
1016
1017
0
    for (int ir = ir0; ir < ir1; ++ir) {
1018
        // src0 and dst are same shape => same indices
1019
0
        const int i3 = ir/(ne2*ne1);
1020
0
        const int i2 = (ir - i3*ne2*ne1)/ne1;
1021
0
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
1022
1023
0
        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
1024
0
        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
1025
0
        for (int i = 0; i < ne0; i++) {
1026
0
            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
1027
0
        }
1028
0
    }
1029
0
}
1030
1031
static void ggml_compute_forward_add1_bf16_bf16(
1032
        const ggml_compute_params * params,
1033
0
        ggml_tensor * dst) {
1034
1035
0
    const ggml_tensor * src0 = dst->src[0];
1036
0
    const ggml_tensor * src1 = dst->src[1];
1037
1038
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
1039
0
    GGML_ASSERT(ggml_is_scalar(src1));
1040
1041
    // scalar to add
1042
0
    const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
1043
1044
0
    const int ith = params->ith;
1045
0
    const int nth = params->nth;
1046
1047
0
    const int nr  = ggml_nrows(src0);
1048
1049
0
    GGML_TENSOR_UNARY_OP_LOCALS
1050
1051
0
    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
1052
0
    GGML_ASSERT(src1->type == GGML_TYPE_BF16);
1053
0
    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
1054
1055
0
    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
1056
0
    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
1057
1058
    // rows per thread
1059
0
    const int dr = (nr + nth - 1)/nth;
1060
1061
    // row range for this thread
1062
0
    const int ir0 = dr*ith;
1063
0
    const int ir1 = MIN(ir0 + dr, nr);
1064
1065
0
    for (int ir = ir0; ir < ir1; ++ir) {
1066
        // src0 and dst are same shape => same indices
1067
0
        const int i3 = ir/(ne2*ne1);
1068
0
        const int i2 = (ir - i3*ne2*ne1)/ne1;
1069
0
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
1070
1071
0
        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
1072
0
        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
1073
0
        for (int i = 0; i < ne0; i++) {
1074
0
            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
1075
0
        }
1076
0
    }
1077
0
}
1078
1079
void ggml_compute_forward_add1(
1080
        const ggml_compute_params * params,
1081
0
        ggml_tensor * dst) {
1082
1083
0
    const ggml_tensor * src0 = dst->src[0];
1084
0
    const ggml_tensor * src1 = dst->src[1];
1085
1086
0
    switch (src0->type) {
1087
0
        case GGML_TYPE_F32:
1088
0
            {
1089
0
                ggml_compute_forward_add1_f32(params, dst);
1090
0
            } break;
1091
0
        case GGML_TYPE_F16:
1092
0
            {
1093
0
                if (src1->type == GGML_TYPE_F16) {
1094
0
                    ggml_compute_forward_add1_f16_f16(params, dst);
1095
0
                }
1096
0
                else if (src1->type == GGML_TYPE_F32) {
1097
0
                    ggml_compute_forward_add1_f16_f32(params, dst);
1098
0
                }
1099
0
                else {
1100
0
                    GGML_ABORT("fatal error");
1101
0
                }
1102
0
            } break;
1103
0
        case GGML_TYPE_BF16:
1104
0
            {
1105
0
                if (src1->type == GGML_TYPE_BF16) {
1106
0
                    ggml_compute_forward_add1_bf16_bf16(params, dst);
1107
0
                }
1108
0
                else if (src1->type == GGML_TYPE_F32) {
1109
0
                    ggml_compute_forward_add1_bf16_f32(params, dst);
1110
0
                }
1111
0
                else {
1112
0
                    GGML_ABORT("fatal error");
1113
0
                }
1114
0
            } break;
1115
0
        case GGML_TYPE_Q4_0:
1116
0
        case GGML_TYPE_Q4_1:
1117
0
        case GGML_TYPE_Q5_0:
1118
0
        case GGML_TYPE_Q5_1:
1119
0
        case GGML_TYPE_Q8_0:
1120
0
        case GGML_TYPE_Q8_1:
1121
0
        case GGML_TYPE_MXFP4:
1122
0
        case GGML_TYPE_Q2_K:
1123
0
        case GGML_TYPE_Q3_K:
1124
0
        case GGML_TYPE_Q4_K:
1125
0
        case GGML_TYPE_Q5_K:
1126
0
        case GGML_TYPE_Q6_K:
1127
0
        case GGML_TYPE_TQ1_0:
1128
0
        case GGML_TYPE_TQ2_0:
1129
0
        case GGML_TYPE_IQ2_XXS:
1130
0
        case GGML_TYPE_IQ2_XS:
1131
0
        case GGML_TYPE_IQ3_XXS:
1132
0
        case GGML_TYPE_IQ1_S:
1133
0
        case GGML_TYPE_IQ1_M:
1134
0
        case GGML_TYPE_IQ4_NL:
1135
0
        case GGML_TYPE_IQ4_XS:
1136
0
        case GGML_TYPE_IQ3_S:
1137
0
        case GGML_TYPE_IQ2_S:
1138
0
            {
1139
0
                ggml_compute_forward_add1_q_f32(params, dst);
1140
0
            } break;
1141
0
        default:
1142
0
            {
1143
0
                GGML_ABORT("fatal error");
1144
0
            }
1145
0
    }
1146
0
}
1147
1148
// ggml_compute_forward_acc
1149
1150
static void ggml_compute_forward_acc_f32(
1151
        const ggml_compute_params * params,
1152
0
        ggml_tensor * dst) {
1153
1154
0
    const ggml_tensor * src0 = dst->src[0];
1155
0
    const ggml_tensor * src1 = dst->src[1];
1156
1157
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
1158
0
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
1159
1160
    // view src0 and dst with these strides and data offset in bytes during acc
1161
    // nb0 is implicitly element_size because src0 and dst are contiguous
1162
0
    size_t nb1     = ((int32_t *) dst->op_params)[0];
1163
0
    size_t nb2     = ((int32_t *) dst->op_params)[1];
1164
0
    size_t nb3     = ((int32_t *) dst->op_params)[2];
1165
0
    size_t offset  = ((int32_t *) dst->op_params)[3];
1166
0
    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
1167
1168
0
    if (!inplace) {
1169
0
        if (params->ith == 0) {
1170
            // memcpy needs to be synchronized across threads to avoid race conditions.
1171
            // => do it in INIT phase
1172
0
            memcpy(
1173
0
                ((char *)  dst->data),
1174
0
                ((char *) src0->data),
1175
0
                ggml_nbytes(dst));
1176
0
        }
1177
0
        ggml_barrier(params->threadpool);
1178
0
    }
1179
1180
0
    const int ith = params->ith;
1181
0
    const int nth = params->nth;
1182
1183
0
    const int nr = ggml_nrows(src1);
1184
0
    const int nc = src1->ne[0];
1185
1186
0
    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
1187
0
    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
1188
1189
    // src0 and dst as viewed during acc
1190
0
    const size_t nb0 = ggml_element_size(src0);
1191
1192
0
    const size_t nb00 = nb0;
1193
0
    const size_t nb01 = nb1;
1194
0
    const size_t nb02 = nb2;
1195
0
    const size_t nb03 = nb3;
1196
1197
0
    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0  + (ne11 == 0 ? 0 : ne11-1)*nb1  + (ne12 == 0 ? 0 : ne12-1)*nb2  + (ne13 == 0 ? 0 : ne13-1)*nb3  < ggml_nbytes(dst));
1198
0
    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
1199
1200
0
    GGML_ASSERT(nb10 == sizeof(float));
1201
1202
    // rows per thread
1203
0
    const int dr = (nr + nth - 1)/nth;
1204
1205
    // row range for this thread
1206
0
    const int ir0 = dr*ith;
1207
0
    const int ir1 = MIN(ir0 + dr, nr);
1208
1209
0
    for (int ir = ir0; ir < ir1; ++ir) {
1210
        // src0 and dst are viewed with shape of src1 and offset
1211
        // => same indices
1212
0
        const int i3 = ir/(ne12*ne11);
1213
0
        const int i2 = (ir - i3*ne12*ne11)/ne11;
1214
0
        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
1215
1216
#ifdef GGML_USE_ACCELERATE
1217
        vDSP_vadd(
1218
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
1219
                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
1220
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + offset), 1, nc);
1221
#else
1222
0
        ggml_vec_add_f32(nc,
1223
0
                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
1224
0
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
1225
0
                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
1226
0
#endif
1227
0
    }
1228
0
}
1229
1230
void ggml_compute_forward_acc(
1231
        const ggml_compute_params * params,
1232
0
        ggml_tensor * dst) {
1233
1234
0
    const ggml_tensor * src0 = dst->src[0];
1235
1236
0
    switch (src0->type) {
1237
0
        case GGML_TYPE_F32:
1238
0
            {
1239
0
                ggml_compute_forward_acc_f32(params, dst);
1240
0
            } break;
1241
0
        case GGML_TYPE_F16:
1242
0
        case GGML_TYPE_BF16:
1243
0
        case GGML_TYPE_Q4_0:
1244
0
        case GGML_TYPE_Q4_1:
1245
0
        case GGML_TYPE_Q5_0:
1246
0
        case GGML_TYPE_Q5_1:
1247
0
        case GGML_TYPE_Q8_0:
1248
0
        case GGML_TYPE_Q8_1:
1249
0
        case GGML_TYPE_MXFP4:
1250
0
        case GGML_TYPE_Q2_K:
1251
0
        case GGML_TYPE_Q3_K:
1252
0
        case GGML_TYPE_Q4_K:
1253
0
        case GGML_TYPE_Q5_K:
1254
0
        case GGML_TYPE_Q6_K:
1255
0
        case GGML_TYPE_TQ1_0:
1256
0
        case GGML_TYPE_TQ2_0:
1257
0
        case GGML_TYPE_IQ2_XXS:
1258
0
        case GGML_TYPE_IQ2_XS:
1259
0
        case GGML_TYPE_IQ3_XXS:
1260
0
        case GGML_TYPE_IQ1_S:
1261
0
        case GGML_TYPE_IQ1_M:
1262
0
        case GGML_TYPE_IQ4_NL:
1263
0
        case GGML_TYPE_IQ4_XS:
1264
0
        case GGML_TYPE_IQ3_S:
1265
0
        case GGML_TYPE_IQ2_S:
1266
0
        default:
1267
0
            {
1268
0
                GGML_ABORT("fatal error");
1269
0
            }
1270
0
    }
1271
0
}
1272
1273
// ggml_compute_forward_sum
1274
1275
static void ggml_compute_forward_sum_f32(
1276
        const ggml_compute_params * params,
1277
0
        ggml_tensor * dst) {
1278
1279
0
    const ggml_tensor * src0 = dst->src[0];
1280
1281
0
    if (params->ith != 0) {
1282
0
        return;
1283
0
    }
1284
1285
0
    assert(ggml_is_scalar(dst));
1286
0
    assert(src0->nb[0] == sizeof(float));
1287
1288
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
1289
0
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
1290
1291
0
    ggml_float sum     = 0;
1292
0
    ggml_float row_sum = 0;
1293
1294
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
1295
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
1296
0
            for (int64_t i01 = 0; i01 < ne01; i01++) {
1297
0
                ggml_vec_sum_f32_ggf(ne00,
1298
0
                        &row_sum,
1299
0
                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
1300
0
                sum += row_sum;
1301
0
            }
1302
0
        }
1303
0
    }
1304
0
    ((float *) dst->data)[0] = sum;
1305
0
}
1306
1307
static void ggml_compute_forward_sum_f16(
1308
    const ggml_compute_params * params,
1309
0
          ggml_tensor * dst) {
1310
1311
0
    const ggml_tensor * src0 = dst->src[0];
1312
1313
0
    if (params->ith != 0) {
1314
0
        return;
1315
0
    }
1316
1317
0
    assert(ggml_is_scalar(dst));
1318
1319
0
    assert(src0->nb[0] == sizeof(ggml_fp16_t));
1320
1321
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
1322
0
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
1323
1324
0
    float sum = 0;
1325
0
    float row_sum = 0;
1326
1327
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
1328
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
1329
0
            for (int64_t i01 = 0; i01 < ne01; i01++) {
1330
0
                ggml_vec_sum_f16_ggf(ne00,
1331
0
                    &row_sum,
1332
0
                    (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
1333
0
                sum += row_sum;
1334
0
            }
1335
0
        }
1336
0
    }
1337
0
    ((ggml_fp16_t *) dst->data)[0] = GGML_CPU_FP32_TO_FP16(sum);
1338
0
}
1339
1340
static void ggml_compute_forward_sum_bf16(
1341
    const ggml_compute_params * params,
1342
0
          ggml_tensor * dst) {
1343
1344
0
    const ggml_tensor * src0 = dst->src[0];
1345
1346
0
    if (params->ith != 0) {
1347
0
        return;
1348
0
    }
1349
1350
0
    assert(ggml_is_scalar(dst));
1351
1352
0
    assert(src0->nb[0] == sizeof(ggml_bf16_t));
1353
1354
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
1355
0
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
1356
1357
0
    float sum = 0;
1358
0
    float row_sum = 0;
1359
1360
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
1361
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
1362
0
            for (int64_t i01 = 0; i01 < ne01; i01++) {
1363
0
                ggml_vec_sum_bf16_ggf(ne00,
1364
0
                    &row_sum,
1365
0
                    (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
1366
0
                sum += row_sum;
1367
0
            }
1368
0
        }
1369
0
    }
1370
0
    ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
1371
0
}
1372
1373
void ggml_compute_forward_sum(
1374
        const ggml_compute_params * params,
1375
0
        ggml_tensor * dst) {
1376
1377
0
    const ggml_tensor * src0 = dst->src[0];
1378
1379
0
    switch (src0->type) {
1380
0
        case GGML_TYPE_F32:
1381
0
            {
1382
0
                ggml_compute_forward_sum_f32(params, dst);
1383
0
            } break;
1384
0
        case GGML_TYPE_F16:
1385
0
            {
1386
0
                ggml_compute_forward_sum_f16(params, dst);
1387
0
            } break;
1388
0
        case GGML_TYPE_BF16:
1389
0
            {
1390
0
                ggml_compute_forward_sum_bf16(params, dst);
1391
0
            } break;
1392
0
        default:
1393
0
            {
1394
0
                GGML_ABORT("fatal error");
1395
0
            }
1396
0
    }
1397
0
}
1398
1399
// ggml_compute_forward_cumsum
1400
1401
static void ggml_compute_forward_cumsum_f32(
1402
        const ggml_compute_params * params,
1403
0
        ggml_tensor * dst) {
1404
1405
0
    const ggml_tensor * src0 = dst->src[0];
1406
1407
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
1408
0
    GGML_ASSERT(dst->nb[0] == sizeof(float));
1409
1410
0
    GGML_TENSOR_UNARY_OP_LOCALS
1411
1412
0
    GGML_ASSERT(ne0 == ne00);
1413
0
    GGML_ASSERT(ne1 == ne01);
1414
0
    GGML_ASSERT(ne2 == ne02);
1415
0
    GGML_ASSERT(ne3 == ne03);
1416
1417
0
    const auto [ir0, ir1] = get_thread_range(params, src0);
1418
1419
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
1420
0
        const int64_t i03 = ir/(ne02*ne01);
1421
0
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
1422
0
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
1423
1424
0
        float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
1425
0
        float * dst_row = (float *) ((char *) dst->data  + i01*nb1  + i02*nb2  + i03*nb3);
1426
1427
0
        ggml_vec_cumsum_f32(ne00, dst_row, src_row);
1428
0
    }
1429
0
}
1430
1431
void ggml_compute_forward_cumsum(
1432
        const ggml_compute_params * params,
1433
0
        ggml_tensor * dst) {
1434
1435
0
    const ggml_tensor * src0 = dst->src[0];
1436
1437
0
    switch (src0->type) {
1438
0
        case GGML_TYPE_F32:
1439
0
            {
1440
0
                ggml_compute_forward_cumsum_f32(params, dst);
1441
0
            } break;
1442
0
        default:
1443
0
            {
1444
0
                GGML_ABORT("fatal error");
1445
0
            }
1446
0
    }
1447
0
}
1448
1449
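For reference, the per-row work delegated to ggml_vec_cumsum_f32 is an inclusive prefix sum. A hedged standalone sketch (hypothetical name cumsum_row; not the ggml implementation):

#include <cstdint>

// Inclusive prefix sum over one row: dst[i] = src[0] + ... + src[i].
static void cumsum_row(int64_t n, float * dst, const float * src) {
    float acc = 0.0f;
    for (int64_t i = 0; i < n; ++i) {
        acc    += src[i];
        dst[i]  = acc;
    }
}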
// ggml_compute_forward_sum_rows
1450
1451
static void ggml_compute_forward_sum_rows_f32(
1452
        const ggml_compute_params * params,
1453
0
        ggml_tensor * dst) {
1454
1455
0
    const ggml_tensor * src0 = dst->src[0];
1456
1457
0
    if (params->ith != 0) {
1458
0
        return;
1459
0
    }
1460
1461
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
1462
0
    GGML_ASSERT(dst->nb[0] == sizeof(float));
1463
1464
0
    GGML_TENSOR_UNARY_OP_LOCALS
1465
1466
0
    GGML_ASSERT(ne0 == 1);
1467
0
    GGML_ASSERT(ne1 == ne01);
1468
0
    GGML_ASSERT(ne2 == ne02);
1469
0
    GGML_ASSERT(ne3 == ne03);
1470
1471
0
    for (int64_t i3 = 0; i3 < ne03; i3++) {
1472
0
        for (int64_t i2 = 0; i2 < ne02; i2++) {
1473
0
            for (int64_t i1 = 0; i1 < ne01; i1++) {
1474
0
                float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
1475
0
                float * dst_row = (float *) ((char *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);
1476
0
                float row_sum = 0;
1477
0
                ggml_vec_sum_f32(ne00, &row_sum, src_row);
1478
0
                dst_row[0] = row_sum;
1479
0
            }
1480
0
        }
1481
0
    }
1482
0
}
1483
1484
void ggml_compute_forward_sum_rows(
1485
        const ggml_compute_params * params,
1486
0
        ggml_tensor * dst) {
1487
1488
0
    const ggml_tensor * src0 = dst->src[0];
1489
1490
0
    switch (src0->type) {
1491
0
        case GGML_TYPE_F32:
1492
0
            {
1493
0
                ggml_compute_forward_sum_rows_f32(params, dst);
1494
0
            } break;
1495
0
        default:
1496
0
            {
1497
0
                GGML_ABORT("fatal error");
1498
0
            }
1499
0
    }
1500
0
}
1501
1502
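The sum_rows kernel (like most kernels in this file) addresses each row by byte strides nb01/nb02/nb03 rather than element indices. A small sketch of that addressing, assuming a hypothetical 4D descriptor with ggml-style ne/nb fields:

#include <cstdint>
#include <cstddef>

struct tensor4d {            // assumed stand-in for ggml_tensor's ne/nb/data
    int64_t ne[4];           // elements per dimension
    size_t  nb[4];           // byte stride per dimension
    void  * data;
};

// Pointer to the start of row (i1, i2, i3), mirroring the casts in the kernels above.
static float * row_ptr(const tensor4d & t, int64_t i1, int64_t i2, int64_t i3) {
    return (float *) ((char *) t.data + i1*t.nb[1] + i2*t.nb[2] + i3*t.nb[3]);
}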
// ggml_compute_forward_mean
1503
1504
static void ggml_compute_forward_mean_f32(
1505
        const ggml_compute_params * params,
1506
0
        ggml_tensor * dst) {
1507
1508
0
    const ggml_tensor * src0 = dst->src[0];
1509
1510
0
    if (params->ith != 0) {
1511
0
        return;
1512
0
    }
1513
1514
0
    assert(src0->nb[0] == sizeof(float));
1515
1516
0
    GGML_TENSOR_UNARY_OP_LOCALS
1517
1518
0
    assert(ne0 == 1);
1519
0
    assert(ne1 == ne01);
1520
0
    assert(ne2 == ne02);
1521
0
    assert(ne3 == ne03);
1522
1523
0
    GGML_UNUSED(ne0);
1524
0
    GGML_UNUSED(ne1);
1525
0
    GGML_UNUSED(ne2);
1526
0
    GGML_UNUSED(ne3);
1527
1528
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
1529
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
1530
0
            for (int64_t i01 = 0; i01 < ne01; i01++) {
1531
0
                ggml_vec_sum_f32(ne00,
1532
0
                        (float *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
1533
0
                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
1534
1535
0
                *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
1536
0
            }
1537
0
        }
1538
0
    }
1539
0
}
1540
1541
void ggml_compute_forward_mean(
1542
        const ggml_compute_params * params,
1543
0
        ggml_tensor * dst) {
1544
1545
0
    const ggml_tensor * src0 = dst->src[0];
1546
1547
0
    switch (src0->type) {
1548
0
        case GGML_TYPE_F32:
1549
0
            {
1550
0
                ggml_compute_forward_mean_f32(params, dst);
1551
0
            } break;
1552
0
        default:
1553
0
            {
1554
0
                GGML_ABORT("fatal error");
1555
0
            }
1556
0
    }
1557
0
}
1558
1559
// ggml_compute_forward_argmax
1560
1561
static void ggml_compute_forward_argmax_f32(
1562
        const ggml_compute_params * params,
1563
0
        ggml_tensor * dst) {
1564
1565
0
    const ggml_tensor * src0 = dst->src[0];
1566
1567
0
    if (params->ith != 0) {
1568
0
        return;
1569
0
    }
1570
1571
0
    assert(src0->nb[0] == sizeof(float));
1572
0
    assert(dst->nb[0] == sizeof(float));
1573
1574
0
    const int64_t ne00 = src0->ne[0];
1575
0
    const int64_t ne01 = src0->ne[1];
1576
1577
0
    const size_t nb01 = src0->nb[1];
1578
0
    const size_t nb0 = dst->nb[0];
1579
1580
0
    for (int64_t i1 = 0; i1 < ne01; i1++) {
1581
0
        float * src = (float *) ((char *) src0->data + i1*nb01);
1582
0
        int32_t * dst_ = (int32_t *) ((char *)  dst->data + i1*nb0);
1583
0
        int v = 0;
1584
0
        ggml_vec_argmax_f32(ne00, &v, src);
1585
0
        dst_[0] = v;
1586
0
    }
1587
0
}
1588
1589
void ggml_compute_forward_argmax(
1590
        const ggml_compute_params * params,
1591
0
        ggml_tensor * dst) {
1592
1593
0
    const ggml_tensor * src0 = dst->src[0];
1594
1595
0
    switch (src0->type) {
1596
0
        case GGML_TYPE_F32:
1597
0
            {
1598
0
                ggml_compute_forward_argmax_f32(params, dst);
1599
0
            } break;
1600
0
        default:
1601
0
            {
1602
0
                GGML_ABORT("fatal error");
1603
0
            }
1604
0
    }
1605
0
}
1606
1607
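ggml_vec_argmax_f32 is expected to return the index of the largest element of a row. A minimal reference (illustrative argmax_row; assumes n > 0 and resolves ties to the first maximum, which may or may not match the vectorized kernel):

#include <cstdint>

static int argmax_row(int64_t n, const float * x) {
    int   best_i = 0;
    float best_v = x[0];              // assumes n > 0, as ne00 >= 1 here
    for (int64_t i = 1; i < n; ++i) {
        if (x[i] > best_v) {          // strict '>' keeps the first maximum on ties
            best_v = x[i];
            best_i = (int) i;
        }
    }
    return best_i;
}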
// ggml_compute_forward_count_equal
1608
1609
static void ggml_compute_forward_count_equal_i32(
1610
        const ggml_compute_params * params,
1611
0
        ggml_tensor * dst) {
1612
1613
0
    const ggml_tensor * src0 = dst->src[0];
1614
0
    const ggml_tensor * src1 = dst->src[1];
1615
1616
0
    GGML_TENSOR_BINARY_OP_LOCALS;
1617
1618
0
    GGML_ASSERT(src0->type == GGML_TYPE_I32);
1619
0
    GGML_ASSERT(src1->type == GGML_TYPE_I32);
1620
0
    GGML_ASSERT(ggml_are_same_shape(src0, src1));
1621
0
    GGML_ASSERT(ggml_is_scalar(dst));
1622
0
    GGML_ASSERT(dst->type == GGML_TYPE_I64);
1623
1624
0
    const int64_t nr = ggml_nrows(src0);
1625
1626
0
    const int ith = params->ith;
1627
0
    const int nth = params->nth;
1628
1629
0
    int64_t * sums = (int64_t *) params->wdata;
1630
0
    int64_t sum_thread = 0;
1631
1632
    // rows per thread
1633
0
    const int64_t dr = (nr + nth - 1)/nth;
1634
1635
    // row range for this thread
1636
0
    const int64_t ir0 = dr*ith;
1637
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
1638
1639
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
1640
0
        const int64_t i03 =  ir                        / (ne02*ne01);
1641
0
        const int64_t i02 = (ir - i03*ne02*ne01)       /       ne01;
1642
0
        const int64_t i01 =  ir - i03*ne02*ne01 - i02*ne01;
1643
1644
0
        const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
1645
0
        const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
1646
1647
0
        for (int64_t i00 = 0; i00 < ne00; ++i00) {
1648
0
            const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
1649
0
            const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
1650
1651
0
            sum_thread += val0 == val1;
1652
0
        }
1653
0
    }
1654
0
    if (ith != 0) {
1655
0
        sums[ith] = sum_thread;
1656
0
    }
1657
0
    ggml_barrier(params->threadpool);
1658
1659
0
    if (ith != 0) {
1660
0
        return;
1661
0
    }
1662
1663
0
    for (int ith_other = 1; ith_other < nth; ++ith_other) {
1664
0
        sum_thread += sums[ith_other];
1665
0
    }
1666
0
    *((int64_t *) dst->data) = sum_thread;
1667
0
}
1668
1669
void ggml_compute_forward_count_equal(
1670
        const ggml_compute_params * params,
1671
0
        ggml_tensor * dst) {
1672
1673
0
    const ggml_tensor * src0 = dst->src[0];
1674
1675
0
    switch (src0->type) {
1676
0
        case GGML_TYPE_I32:
1677
0
            {
1678
0
                ggml_compute_forward_count_equal_i32(params, dst);
1679
0
            } break;
1680
0
        default:
1681
0
            {
1682
0
                GGML_ABORT("fatal error");
1683
0
            }
1684
0
    }
1685
0
}
1686
1687
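count_equal is the first op in this file that combines per-thread partial results through params->wdata and ggml_barrier. A simplified sketch of the same reduce pattern using std::thread instead of the ggml threadpool (illustrative only; names are assumptions):

#include <algorithm>
#include <cstdint>
#include <thread>
#include <vector>

// Each worker counts matches in its slice; the partial sums are combined at the end.
static int64_t count_equal_threads(const std::vector<int32_t> & a,
                                   const std::vector<int32_t> & b, int nth) {
    std::vector<int64_t> sums(nth, 0);
    std::vector<std::thread> workers;
    const int64_t n  = (int64_t) a.size();
    const int64_t dr = (n + nth - 1)/nth;            // elements per thread, rounded up
    for (int ith = 0; ith < nth; ++ith) {
        workers.emplace_back([&, ith]() {
            const int64_t i0 = dr*ith;
            const int64_t i1 = std::min(i0 + dr, n);
            for (int64_t i = i0; i < i1; ++i) {
                sums[ith] += a[i] == b[i];
            }
        });
    }
    for (auto & w : workers) { w.join(); }           // stands in for ggml_barrier
    int64_t total = 0;
    for (int64_t s : sums) { total += s; }
    return total;
}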
// ggml_compute_forward_repeat
1688
1689
static void ggml_compute_forward_repeat_f32(
1690
        const ggml_compute_params * params,
1691
0
        ggml_tensor * dst) {
1692
1693
0
    const ggml_tensor * src0 = dst->src[0];
1694
1695
0
    if (params->ith != 0) {
1696
0
        return;
1697
0
    }
1698
1699
0
    GGML_ASSERT(ggml_can_repeat(src0, dst));
1700
1701
0
    GGML_TENSOR_UNARY_OP_LOCALS
1702
1703
    // guaranteed to be an integer due to the check in ggml_can_repeat
1704
0
    const int nr0 = (int)(ne0/ne00);
1705
0
    const int nr1 = (int)(ne1/ne01);
1706
0
    const int nr2 = (int)(ne2/ne02);
1707
0
    const int nr3 = (int)(ne3/ne03);
1708
1709
    // TODO: support for transposed / permuted tensors
1710
0
    GGML_ASSERT(nb0  == sizeof(float));
1711
0
    GGML_ASSERT(nb00 == sizeof(float));
1712
1713
    // TODO: maybe this is not optimal?
1714
0
    for                         (int i3 = 0; i3 < nr3;  i3++) {
1715
0
        for                     (int k3 = 0; k3 < ne03; k3++) {
1716
0
            for                 (int i2 = 0; i2 < nr2;  i2++) {
1717
0
                for             (int k2 = 0; k2 < ne02; k2++) {
1718
0
                    for         (int i1 = 0; i1 < nr1;  i1++) {
1719
0
                        for     (int k1 = 0; k1 < ne01; k1++) {
1720
0
                            for (int i0 = 0; i0 < nr0;  i0++) {
1721
0
                                ggml_vec_cpy_f32(ne00,
1722
0
                                        (float *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0),
1723
0
                                        (float *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01));
1724
0
                            }
1725
0
                        }
1726
0
                    }
1727
0
                }
1728
0
            }
1729
0
        }
1730
0
    }
1731
0
}
1732
1733
static void ggml_compute_forward_repeat_f16(
1734
        const ggml_compute_params * params,
1735
0
        ggml_tensor * dst) {
1736
1737
0
    const ggml_tensor * src0 = dst->src[0];
1738
1739
0
    if (params->ith != 0) {
1740
0
        return;
1741
0
    }
1742
1743
0
    GGML_ASSERT(ggml_can_repeat(src0, dst));
1744
1745
0
    GGML_TENSOR_UNARY_OP_LOCALS
1746
1747
    // guaranteed to be an integer due to the check in ggml_can_repeat
1748
0
    const int nr0 = (int)(ne0/ne00);
1749
0
    const int nr1 = (int)(ne1/ne01);
1750
0
    const int nr2 = (int)(ne2/ne02);
1751
0
    const int nr3 = (int)(ne3/ne03);
1752
1753
    // TODO: support for transposed / permuted tensors
1754
0
    GGML_ASSERT(nb0  == sizeof(ggml_fp16_t));
1755
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
1756
1757
    // TODO: maybe this is not optimal?
1758
0
    for                         (int i3 = 0; i3 < nr3;  i3++) {
1759
0
        for                     (int k3 = 0; k3 < ne03; k3++) {
1760
0
            for                 (int i2 = 0; i2 < nr2;  i2++) {
1761
0
                for             (int k2 = 0; k2 < ne02; k2++) {
1762
0
                    for         (int i1 = 0; i1 < nr1;  i1++) {
1763
0
                        for     (int k1 = 0; k1 < ne01; k1++) {
1764
0
                            for (int i0 = 0; i0 < nr0;  i0++) {
1765
0
                                ggml_fp16_t * y = (ggml_fp16_t *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0);
1766
0
                                ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01);
1767
                                // ggml_vec_cpy_f16(ne00, y, x)
1768
0
                                for (int i = 0; i < ne00; ++i) {
1769
0
                                    y[i]  = x[i];
1770
0
                                }
1771
0
                            }
1772
0
                        }
1773
0
                    }
1774
0
                }
1775
0
            }
1776
0
        }
1777
0
    }
1778
0
}
1779
1780
void ggml_compute_forward_repeat(
1781
        const ggml_compute_params * params,
1782
0
        ggml_tensor * dst) {
1783
1784
0
    const ggml_tensor * src0 = dst->src[0];
1785
1786
0
    switch (src0->type) {
1787
0
        case GGML_TYPE_F16:
1788
0
        case GGML_TYPE_BF16:
1789
0
        case GGML_TYPE_I16:
1790
0
            {
1791
0
                ggml_compute_forward_repeat_f16(params, dst);
1792
0
            } break;
1793
0
        case GGML_TYPE_F32:
1794
0
        case GGML_TYPE_I32:
1795
0
            {
1796
0
                ggml_compute_forward_repeat_f32(params, dst);
1797
0
            } break;
1798
        // TODO: templateify the implementation and support for I64
1799
        //       ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
1800
        //case GGML_TYPE_I64:
1801
        //    {
1802
        //        ggml_compute_forward_repeat_i64(params, dst);
1803
        //    } break;
1804
0
        default:
1805
0
            {
1806
0
                GGML_ABORT("fatal error");
1807
0
            }
1808
0
    }
1809
0
}
1810
1811
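The repeat kernels tile the source nrX times along each dimension; along dimension d the destination coordinate decomposes as i*ne0d + k, where i selects the tile and k the position inside it. A 1-D sketch of that mapping (illustrative repeat_1d, not the ggml entry point):

#include <cstdint>

// dst has nr*ne elements; dst[i*ne + k] = src[k] for every tile i.
static void repeat_1d(int64_t ne, int64_t nr, float * dst, const float * src) {
    for (int64_t i = 0; i < nr; ++i) {
        for (int64_t k = 0; k < ne; ++k) {
            dst[i*ne + k] = src[k];
        }
    }
}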
// ggml_compute_forward_repeat_back
1812
1813
static void ggml_compute_forward_repeat_back_f32(
1814
        const ggml_compute_params * params,
1815
0
        ggml_tensor * dst) {
1816
1817
0
    const ggml_tensor * src0 = dst->src[0];
1818
1819
0
    if (params->ith != 0) {
1820
0
        return;
1821
0
    }
1822
1823
0
    GGML_ASSERT(ggml_can_repeat(dst, src0));
1824
1825
0
    GGML_TENSOR_UNARY_OP_LOCALS
1826
1827
    // guaranteed to be an integer due to the check in ggml_can_repeat
1828
0
    const int nr0 = (int)(ne00/ne0);
1829
0
    const int nr1 = (int)(ne01/ne1);
1830
0
    const int nr2 = (int)(ne02/ne2);
1831
0
    const int nr3 = (int)(ne03/ne3);
1832
1833
    // TODO: support for transposed / permuted tensors
1834
0
    GGML_ASSERT(nb0  == sizeof(float));
1835
0
    GGML_ASSERT(nb00 == sizeof(float));
1836
1837
0
    if (ggml_is_contiguous(dst)) {
1838
0
        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
1839
0
    } else {
1840
0
        for         (int k3 = 0; k3 < ne3; k3++) {
1841
0
            for     (int k2 = 0; k2 < ne2; k2++) {
1842
0
                for (int k1 = 0; k1 < ne1; k1++) {
1843
0
                    ggml_vec_set_f32(ne0,
1844
0
                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
1845
0
                        0);
1846
0
                }
1847
0
            }
1848
0
        }
1849
0
    }
1850
1851
    // TODO: maybe this is not optimal?
1852
0
    for                         (int i3 = 0; i3 < nr3; i3++) {
1853
0
        for                     (int k3 = 0; k3 < ne3; k3++) {
1854
0
            for                 (int i2 = 0; i2 < nr2; i2++) {
1855
0
                for             (int k2 = 0; k2 < ne2; k2++) {
1856
0
                    for         (int i1 = 0; i1 < nr1; i1++) {
1857
0
                        for     (int k1 = 0; k1 < ne1; k1++) {
1858
0
                            for (int i0 = 0; i0 < nr0; i0++) {
1859
0
                                ggml_vec_acc_f32(ne0,
1860
0
                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
1861
0
                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
1862
0
                            }
1863
0
                        }
1864
0
                    }
1865
0
                }
1866
0
            }
1867
0
        }
1868
0
    }
1869
0
}
1870
1871
void ggml_compute_forward_repeat_back(
1872
        const ggml_compute_params * params,
1873
0
        ggml_tensor * dst) {
1874
1875
0
    const ggml_tensor * src0 = dst->src[0];
1876
1877
0
    switch (src0->type) {
1878
0
        case GGML_TYPE_F32:
1879
0
            {
1880
0
                ggml_compute_forward_repeat_back_f32(params, dst);
1881
0
            } break;
1882
0
        default:
1883
0
            {
1884
0
                GGML_ABORT("fatal error");
1885
0
            }
1886
0
    }
1887
0
}
1888
1889
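repeat_back is the adjoint of repeat: every tile of the incoming gradient is accumulated onto the smaller destination, which is why the kernel first zeroes dst and then uses ggml_vec_acc_f32. A matching 1-D sketch (illustrative repeat_back_1d):

#include <cstdint>

// dst[k] receives the sum of src[i*ne + k] over all nr tiles.
static void repeat_back_1d(int64_t ne, int64_t nr, float * dst, const float * src) {
    for (int64_t k = 0; k < ne; ++k) {
        dst[k] = 0.0f;                       // clear, like ggml_vec_set_f32(..., 0)
    }
    for (int64_t i = 0; i < nr; ++i) {
        for (int64_t k = 0; k < ne; ++k) {
            dst[k] += src[i*ne + k];         // accumulate, like ggml_vec_acc_f32
        }
    }
}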
// ggml_compute_forward_concat
1890
1891
static void ggml_compute_forward_concat_any(
1892
    const ggml_compute_params * params,
1893
0
    ggml_tensor * dst) {
1894
1895
0
    const ggml_tensor * src0 = dst->src[0];
1896
0
    const ggml_tensor * src1 = dst->src[1];
1897
1898
0
    const size_t len = ggml_type_size(src0->type);
1899
1900
0
    const int ith = params->ith;
1901
0
    const int nth = params->nth;
1902
1903
0
    GGML_TENSOR_BINARY_OP_LOCALS
1904
1905
0
    const int32_t dim = ggml_get_op_params_i32(dst, 0);
1906
1907
0
    GGML_ASSERT(dim >= 0 && dim < 4);
1908
1909
0
    int64_t o[4] = {0, 0, 0, 0};
1910
0
    o[dim] = src0->ne[dim];
1911
1912
0
    const char * x;
1913
1914
    // TODO: smarter multi-threading
1915
0
    for (int i3 = 0; i3 < ne3; i3++) {
1916
0
        for (int i2 = ith; i2 < ne2; i2 += nth) {
1917
0
            for (int i1 = 0; i1 < ne1; i1++) {
1918
0
                for (int i0 = 0; i0 < ne0; i0++) {
1919
0
                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
1920
0
                        x = (const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03;
1921
0
                    } else {
1922
0
                        x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
1923
0
                    }
1924
1925
0
                    char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
1926
1927
0
                    memcpy(y, x, len);
1928
0
                }
1929
0
            }
1930
0
        }
1931
0
    }
1932
0
}
1933
1934
static void ggml_compute_forward_concat_i8(
1935
    const ggml_compute_params * params,
1936
0
    ggml_tensor * dst) {
1937
1938
0
    const ggml_tensor * src0 = dst->src[0];
1939
0
    const ggml_tensor * src1 = dst->src[1];
1940
1941
0
    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
1942
1943
0
    const int ith = params->ith;
1944
0
    const int nth = params->nth;
1945
1946
0
    GGML_TENSOR_BINARY_OP_LOCALS
1947
1948
0
    const int32_t dim = ggml_get_op_params_i32(dst, 0);
1949
1950
0
    GGML_ASSERT(dim >= 0 && dim < 4);
1951
1952
0
    int64_t o[4] = {0, 0, 0, 0};
1953
0
    o[dim] = src0->ne[dim];
1954
1955
0
    const int8_t * x;
1956
1957
    // TODO: smarter multi-threading
1958
0
    for (int i3 = 0; i3 < ne3; i3++) {
1959
0
        for (int i2 = ith; i2 < ne2; i2 += nth) {
1960
0
            for (int i1 = 0; i1 < ne1; i1++) {
1961
0
                for (int i0 = 0; i0 < ne0; i0++) {
1962
0
                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
1963
0
                        x = (const int8_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
1964
0
                    } else {
1965
0
                        x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
1966
0
                    }
1967
1968
0
                    int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
1969
1970
0
                    *y = *x;
1971
0
                }
1972
0
            }
1973
0
        }
1974
0
    }
1975
0
}
1976
1977
static void ggml_compute_forward_concat_f16(
1978
    const ggml_compute_params * params,
1979
0
    ggml_tensor * dst) {
1980
1981
0
    const ggml_tensor * src0 = dst->src[0];
1982
0
    const ggml_tensor * src1 = dst->src[1];
1983
1984
0
    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
1985
1986
0
    const int ith = params->ith;
1987
0
    const int nth = params->nth;
1988
1989
0
    GGML_TENSOR_BINARY_OP_LOCALS
1990
1991
0
    const int32_t dim = ggml_get_op_params_i32(dst, 0);
1992
1993
0
    GGML_ASSERT(dim >= 0 && dim < 4);
1994
1995
0
    int64_t o[4] = {0, 0, 0, 0};
1996
0
    o[dim] = src0->ne[dim];
1997
1998
0
    const ggml_fp16_t * x;
1999
2000
    // TODO: smarter multi-threading
2001
0
    for (int i3 = 0; i3 < ne3; i3++) {
2002
0
        for (int i2 = ith; i2 < ne2; i2 += nth) {
2003
0
            for (int i1 = 0; i1 < ne1; i1++) {
2004
0
                for (int i0 = 0; i0 < ne0; i0++) {
2005
0
                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
2006
0
                        x = (const ggml_fp16_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
2007
0
                    } else {
2008
0
                        x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
2009
0
                    }
2010
2011
0
                    ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
2012
2013
0
                    *y = *x;
2014
0
                }
2015
0
            }
2016
0
        }
2017
0
    }
2018
0
}
2019
2020
static void ggml_compute_forward_concat_f32(
2021
    const ggml_compute_params * params,
2022
0
    ggml_tensor * dst) {
2023
2024
0
    const ggml_tensor * src0 = dst->src[0];
2025
0
    const ggml_tensor * src1 = dst->src[1];
2026
2027
0
    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));
2028
2029
0
    const int ith = params->ith;
2030
0
    const int nth = params->nth;
2031
2032
0
    GGML_TENSOR_BINARY_OP_LOCALS
2033
2034
0
    const int32_t dim = ggml_get_op_params_i32(dst, 0);
2035
2036
0
    GGML_ASSERT(dim >= 0 && dim < 4);
2037
2038
0
    int64_t o[4] = {0, 0, 0, 0};
2039
0
    o[dim] = src0->ne[dim];
2040
2041
0
    const float * x;
2042
2043
    // TODO: smarter multi-threading
2044
0
    for (int i3 = 0; i3 < ne3; i3++) {
2045
0
        for (int i2 = ith; i2 < ne2; i2 += nth) {
2046
0
            for (int i1 = 0; i1 < ne1; i1++) {
2047
0
                for (int i0 = 0; i0 < ne0; i0++) {
2048
0
                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
2049
0
                        x = (const float *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
2050
0
                    } else {
2051
0
                        x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
2052
0
                    }
2053
2054
0
                    float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
2055
2056
0
                    *y = *x;
2057
0
                }
2058
0
            }
2059
0
        }
2060
0
    }
2061
0
}
2062
2063
void ggml_compute_forward_concat(
2064
    const ggml_compute_params * params,
2065
0
    ggml_tensor * dst) {
2066
2067
0
    const ggml_tensor * src0 = dst->src[0];
2068
2069
0
    switch (src0->type) {
2070
0
        case GGML_TYPE_F16:
2071
0
        case GGML_TYPE_BF16:
2072
0
        case GGML_TYPE_I16:
2073
0
            {
2074
0
                ggml_compute_forward_concat_f16(params, dst);
2075
0
            } break;
2076
0
        case GGML_TYPE_I8:
2077
0
            {
2078
0
                ggml_compute_forward_concat_i8(params, dst);
2079
0
            } break;
2080
0
        case GGML_TYPE_F32:
2081
0
        case GGML_TYPE_I32:
2082
0
            {
2083
0
                ggml_compute_forward_concat_f32(params, dst);
2084
0
            } break;
2085
0
        default:
2086
0
            {
2087
0
                ggml_compute_forward_concat_any(params, dst);
2088
0
            }
2089
0
    }
2090
0
}
2091
2092
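All concat variants share one selection rule: a destination coordinate still inside src0's extent is read from src0, otherwise it is read from src1 after subtracting the offset o[dim] = src0->ne[dim]. A 1-D sketch of that rule along the concatenated dimension (illustrative concat_1d):

#include <cstdint>

// dst has ne0 + ne1 elements along the concatenated dimension.
static void concat_1d(int64_t ne0, const float * src0,
                      int64_t ne1, const float * src1,
                      float * dst) {
    for (int64_t i = 0; i < ne0 + ne1; ++i) {
        dst[i] = i < ne0 ? src0[i]          // still inside src0's extent
                         : src1[i - ne0];   // shift by the offset o[dim]
    }
}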
// ggml_compute_forward_gelu
2093
2094
static void ggml_compute_forward_gelu_f32(
2095
        const ggml_compute_params * params,
2096
0
        ggml_tensor * dst) {
2097
2098
0
    const ggml_tensor * src0 = dst->src[0];
2099
2100
0
    assert(ggml_is_contiguous_1(src0));
2101
0
    assert(ggml_is_contiguous_1(dst));
2102
0
    assert(ggml_are_same_shape(src0, dst));
2103
2104
0
    const int ith = params->ith;
2105
0
    const int nth = params->nth;
2106
2107
0
    const int nc = src0->ne[0];
2108
0
    const int nr = ggml_nrows(src0);
2109
2110
    // rows per thread
2111
0
    const int dr = (nr + nth - 1)/nth;
2112
2113
    // row range for this thread
2114
0
    const int ir0 = dr*ith;
2115
0
    const int ir1 = MIN(ir0 + dr, nr);
2116
2117
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2118
0
        ggml_vec_gelu_f32(nc,
2119
0
                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
2120
0
                (float *) ((char *) src0->data + i1*(src0->nb[1])));
2121
2122
#ifndef NDEBUG
2123
        for (int k = 0; k < nc; k++) {
2124
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2125
            GGML_UNUSED(x);
2126
            assert(!isnan(x));
2127
            assert(!isinf(x));
2128
        }
2129
#endif
2130
0
    }
2131
0
}
2132
2133
static void ggml_compute_forward_gelu_f16(
2134
    const ggml_compute_params * params,
2135
0
    ggml_tensor * dst) {
2136
2137
0
    const ggml_tensor * src0 = dst->src[0];
2138
2139
0
    assert(ggml_is_contiguous_1(src0));
2140
0
    assert(ggml_is_contiguous_1(dst));
2141
0
    assert(ggml_are_same_shape(src0, dst));
2142
2143
0
    const int ith = params->ith;
2144
0
    const int nth = params->nth;
2145
2146
0
    const int nc = src0->ne[0];
2147
0
    const int nr = ggml_nrows(src0);
2148
2149
    // rows per thread
2150
0
    const int dr = (nr + nth - 1)/nth;
2151
2152
    // row range for this thread
2153
0
    const int ir0 = dr*ith;
2154
0
    const int ir1 = MIN(ir0 + dr, nr);
2155
2156
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2157
0
        ggml_vec_gelu_f16(nc,
2158
0
                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
2159
0
                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2160
2161
#ifndef NDEBUG
2162
        for (int k = 0; k < nc; k++) {
2163
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2164
            const float v = GGML_CPU_FP16_TO_FP32(x);
2165
            GGML_UNUSED(v);
2166
            assert(!isnan(v));
2167
            assert(!isinf(v));
2168
        }
2169
#endif
2170
0
    }
2171
0
}
2172
2173
static void ggml_compute_forward_gelu(
2174
        const ggml_compute_params * params,
2175
0
        ggml_tensor * dst) {
2176
2177
0
    const ggml_tensor * src0 = dst->src[0];
2178
2179
0
    switch (src0->type) {
2180
0
        case GGML_TYPE_F32:
2181
0
            {
2182
0
                ggml_compute_forward_gelu_f32(params, dst);
2183
0
            } break;
2184
0
        case GGML_TYPE_F16:
2185
0
            {
2186
0
                ggml_compute_forward_gelu_f16(params, dst);
2187
0
            } break;
2188
0
        default:
2189
0
            {
2190
0
                GGML_ABORT("fatal error");
2191
0
            }
2192
0
    }
2193
0
}
2194
2195
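The GELU kernels above (and the other activations later in this file) split rows across threads with the same arithmetic: dr rows per thread, with thread ith handling [ir0, ir1). For reference, GELU itself is commonly written as 0.5*x*(1 + erf(x/sqrt(2))). A hedged scalar sketch of both pieces (gelu_ref and thread_row_range are illustrative names; ggml_vec_gelu_f32 may instead use a tanh approximation or a lookup table):

#include <cmath>

// Reference GELU; numerical details may differ from the vectorized kernel.
static float gelu_ref(float x) {
    return 0.5f*x*(1.0f + std::erf(x*0.70710678f));   // erf(x / sqrt(2))
}

// Row range for thread ith of nth, as computed throughout this file.
static void thread_row_range(int nr, int ith, int nth, int & ir0, int & ir1) {
    const int dr = (nr + nth - 1)/nth;    // rows per thread, rounded up
    ir0 = dr*ith;
    ir1 = ir0 + dr < nr ? ir0 + dr : nr;  // MIN(ir0 + dr, nr)
}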
// ggml_compute_forward_fill
2196
2197
0
static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
2198
0
    const float c = ggml_get_op_params_f32(dst, 0);
2199
2200
0
    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
2201
0
    GGML_TENSOR_LOCALS(size_t,  nb, dst, nb);
2202
2203
0
    const auto [ir0, ir1] = get_thread_range(params, dst);
2204
2205
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
2206
0
        const int64_t i03 = ir/(ne2*ne1);
2207
0
        const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
2208
0
        const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
2209
2210
0
        float * dst_ptr  = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
2211
2212
0
        ggml_vec_set_f32(ne0, dst_ptr, c);
2213
0
    }
2214
0
}
2215
2216
0
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
2217
0
    ggml_compute_forward_fill_f32(params, dst);
2218
0
}
2219
2220
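fill, and the other kernels here that iterate a flat row index ir, recover the (i1, i2, i3) coordinates by successive division and remainder. A small sketch of that decomposition (illustrative unravel_row, assuming ne1*ne2*ne3 rows in row-major order):

#include <cstdint>

// ir ranges over ne1*ne2*ne3 rows; recover (i1, i2, i3) as in the kernels above.
static void unravel_row(int64_t ir, int64_t ne1, int64_t ne2,
                        int64_t & i1, int64_t & i2, int64_t & i3) {
    i3 = ir/(ne2*ne1);
    i2 = (ir - i3*ne2*ne1)/ne1;
    i1 =  ir - i3*ne2*ne1 - i2*ne1;
}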
// ggml_compute_forward_tri
2221
2222
0
static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
2223
0
    const ggml_tensor * src0 = dst->src[0];
2224
2225
0
    const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
2226
2227
0
    GGML_ASSERT(ggml_is_contiguous(src0));
2228
2229
0
    GGML_TENSOR_UNARY_OP_LOCALS
2230
2231
0
    const auto [ir0, ir1] = get_thread_range(params, src0);
2232
2233
0
    bool (*bipred)(int, int);
2234
2235
0
    switch (ttype) {
2236
0
        case GGML_TRI_TYPE_LOWER:      bipred = [](int i, int r) { return i <  r; }; break;
2237
0
        case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
2238
0
        case GGML_TRI_TYPE_UPPER:      bipred = [](int i, int r) { return i >  r; }; break;
2239
0
        case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
2240
0
        default: GGML_ABORT("invalid tri type");
2241
0
    }
2242
2243
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
2244
0
        const int64_t i03 = ir/(ne02*ne01);
2245
0
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
2246
0
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
2247
2248
0
        const float * src_ptr = (const float  *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
2249
0
              float * dst_ptr = (      float  *) ((      char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1);
2250
2251
0
        for (int i0 = 0; i0 < ne0; ++i0) {
2252
0
            dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
2253
0
        }
2254
0
    }
2255
0
}
2256
2257
0
void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
2258
0
    const ggml_tensor * src0 = dst->src[0];
2259
2260
0
    switch (src0->type) {
2261
0
        case GGML_TYPE_F32:
2262
0
            {
2263
0
                ggml_compute_forward_tri_f32(params, dst);
2264
0
            } break;
2265
0
        default:
2266
0
            {
2267
0
                GGML_ABORT("fatal error");
2268
0
            }
2269
0
    }
2270
0
}
2271
2272
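The four tri variants differ only in the predicate that decides whether column i0 of row i01 is kept. A compact sketch of the lower-with-diagonal case (illustrative tri_lower_diag_row; the other variants swap the comparison):

#include <cstdint>

// Keep src[i0] when i0 <= row (lower triangle including the diagonal), else write 0.
static void tri_lower_diag_row(int64_t ne0, int64_t row, float * dst, const float * src) {
    for (int64_t i0 = 0; i0 < ne0; ++i0) {
        dst[i0] = i0 <= row ? src[i0] : 0.0f;
    }
}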
// ggml_compute_forward_gelu_erf
2273
2274
static void ggml_compute_forward_gelu_erf_f32(
2275
        const ggml_compute_params * params,
2276
0
        ggml_tensor * dst) {
2277
2278
0
    const ggml_tensor * src0 = dst->src[0];
2279
2280
0
    assert(ggml_is_contiguous_1(src0));
2281
0
    assert(ggml_is_contiguous_1(dst));
2282
0
    assert(ggml_are_same_shape(src0, dst));
2283
2284
0
    const int ith = params->ith;
2285
0
    const int nth = params->nth;
2286
2287
0
    const int nc = src0->ne[0];
2288
0
    const int nr = ggml_nrows(src0);
2289
2290
    // rows per thread
2291
0
    const int dr = (nr + nth - 1)/nth;
2292
2293
    // row range for this thread
2294
0
    const int ir0 = dr*ith;
2295
0
    const int ir1 = MIN(ir0 + dr, nr);
2296
2297
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2298
0
        ggml_vec_gelu_erf_f32(nc,
2299
0
                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
2300
0
                (float *) ((char *) src0->data + i1*(src0->nb[1])));
2301
2302
#ifndef NDEBUG
2303
        for (int k = 0; k < nc; k++) {
2304
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2305
            GGML_UNUSED(x);
2306
            assert(!isnan(x));
2307
            assert(!isinf(x));
2308
        }
2309
#endif
2310
0
    }
2311
0
}
2312
2313
static void ggml_compute_forward_gelu_erf_f16(
2314
    const ggml_compute_params * params,
2315
0
    ggml_tensor * dst) {
2316
2317
0
    const ggml_tensor * src0 = dst->src[0];
2318
2319
0
    assert(ggml_is_contiguous_1(src0));
2320
0
    assert(ggml_is_contiguous_1(dst));
2321
0
    assert(ggml_are_same_shape(src0, dst));
2322
2323
0
    const int ith = params->ith;
2324
0
    const int nth = params->nth;
2325
2326
0
    const int nc = src0->ne[0];
2327
0
    const int nr = ggml_nrows(src0);
2328
2329
    // rows per thread
2330
0
    const int dr = (nr + nth - 1)/nth;
2331
2332
    // row range for this thread
2333
0
    const int ir0 = dr*ith;
2334
0
    const int ir1 = MIN(ir0 + dr, nr);
2335
2336
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2337
0
        ggml_vec_gelu_erf_f16(nc,
2338
0
                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
2339
0
                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2340
2341
#ifndef NDEBUG
2342
        for (int k = 0; k < nc; k++) {
2343
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2344
            const float v = GGML_CPU_FP16_TO_FP32(x);
2345
            GGML_UNUSED(v);
2346
            assert(!isnan(v));
2347
            assert(!isinf(v));
2348
        }
2349
#endif
2350
0
    }
2351
0
}
2352
2353
static void ggml_compute_forward_gelu_erf(
2354
        const ggml_compute_params * params,
2355
0
        ggml_tensor * dst) {
2356
2357
0
    const ggml_tensor * src0 = dst->src[0];
2358
2359
0
    switch (src0->type) {
2360
0
        case GGML_TYPE_F32:
2361
0
            {
2362
0
                ggml_compute_forward_gelu_erf_f32(params, dst);
2363
0
            } break;
2364
0
        case GGML_TYPE_F16:
2365
0
            {
2366
0
                ggml_compute_forward_gelu_erf_f16(params, dst);
2367
0
            } break;
2368
0
        default:
2369
0
            {
2370
0
                GGML_ABORT("fatal error");
2371
0
            }
2372
0
    }
2373
0
}
2374
2375
// ggml_compute_forward_gelu_quick
2376
2377
static void ggml_compute_forward_gelu_quick_f32(
2378
        const ggml_compute_params * params,
2379
0
        ggml_tensor * dst) {
2380
2381
0
    const ggml_tensor * src0 = dst->src[0];
2382
2383
0
    assert(ggml_is_contiguous_1(src0));
2384
0
    assert(ggml_is_contiguous_1(dst));
2385
0
    assert(ggml_are_same_shape(src0, dst));
2386
2387
0
    const int ith = params->ith;
2388
0
    const int nth = params->nth;
2389
2390
0
    const int nc = src0->ne[0];
2391
0
    const int nr = ggml_nrows(src0);
2392
2393
    // rows per thread
2394
0
    const int dr = (nr + nth - 1)/nth;
2395
2396
    // row range for this thread
2397
0
    const int ir0 = dr*ith;
2398
0
    const int ir1 = MIN(ir0 + dr, nr);
2399
2400
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2401
0
        ggml_vec_gelu_quick_f32(nc,
2402
0
                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
2403
0
                (float *) ((char *) src0->data + i1*(src0->nb[1])));
2404
2405
#ifndef NDEBUG
2406
        for (int k = 0; k < nc; k++) {
2407
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2408
            GGML_UNUSED(x);
2409
            assert(!isnan(x));
2410
            assert(!isinf(x));
2411
        }
2412
#endif
2413
0
    }
2414
0
}
2415
2416
static void ggml_compute_forward_gelu_quick_f16(
2417
    const ggml_compute_params * params,
2418
0
    ggml_tensor * dst) {
2419
2420
0
    const ggml_tensor * src0 = dst->src[0];
2421
2422
0
    assert(ggml_is_contiguous_1(src0));
2423
0
    assert(ggml_is_contiguous_1(dst));
2424
0
    assert(ggml_are_same_shape(src0, dst));
2425
2426
0
    const int ith = params->ith;
2427
0
    const int nth = params->nth;
2428
2429
0
    const int nc = src0->ne[0];
2430
0
    const int nr = ggml_nrows(src0);
2431
2432
    // rows per thread
2433
0
    const int dr = (nr + nth - 1)/nth;
2434
2435
    // row range for this thread
2436
0
    const int ir0 = dr*ith;
2437
0
    const int ir1 = MIN(ir0 + dr, nr);
2438
2439
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2440
0
        ggml_vec_gelu_quick_f16(nc,
2441
0
                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
2442
0
                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2443
2444
#ifndef NDEBUG
2445
        for (int k = 0; k < nc; k++) {
2446
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2447
            const float v = GGML_CPU_FP16_TO_FP32(x);
2448
            GGML_UNUSED(v);
2449
            assert(!isnan(v));
2450
            assert(!isinf(v));
2451
        }
2452
#endif
2453
0
    }
2454
0
}
2455
2456
static void ggml_compute_forward_gelu_quick(
2457
        const ggml_compute_params * params,
2458
0
        ggml_tensor * dst) {
2459
2460
0
    const ggml_tensor * src0 = dst->src[0];
2461
2462
0
    switch (src0->type) {
2463
0
        case GGML_TYPE_F32:
2464
0
            {
2465
0
                ggml_compute_forward_gelu_quick_f32(params, dst);
2466
0
            } break;
2467
0
        case GGML_TYPE_F16:
2468
0
            {
2469
0
                ggml_compute_forward_gelu_quick_f16(params, dst);
2470
0
            } break;
2471
0
        default:
2472
0
            {
2473
0
                GGML_ABORT("fatal error");
2474
0
            }
2475
0
    }
2476
0
}
2477
2478
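gelu_quick is the sigmoid-based approximation of GELU, usually written x*sigmoid(1.702*x). A scalar sketch under that assumption (gelu_quick_ref is illustrative; the coefficient and any table-based evaluation in ggml_vec_gelu_quick_* may differ):

#include <cmath>

static float gelu_quick_ref(float x) {
    return x/(1.0f + std::exp(-1.702f*x));  // x * sigmoid(1.702 * x)
}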
// ggml_compute_forward_silu
2479
2480
static void ggml_compute_forward_silu_f32(
2481
        const ggml_compute_params * params,
2482
0
        ggml_tensor * dst) {
2483
2484
0
    const ggml_tensor * src0 = dst->src[0];
2485
2486
0
    assert(ggml_is_contiguous_1(src0));
2487
0
    assert(ggml_is_contiguous_1(dst));
2488
0
    assert(ggml_are_same_shape(src0, dst));
2489
2490
0
    const int ith = params->ith;
2491
0
    const int nth = params->nth;
2492
2493
0
    const int nc = src0->ne[0];
2494
0
    const int nr = ggml_nrows(src0);
2495
2496
    // rows per thread
2497
0
    const int dr = (nr + nth - 1)/nth;
2498
2499
    // row range for this thread
2500
0
    const int ir0 = dr*ith;
2501
0
    const int ir1 = MIN(ir0 + dr, nr);
2502
2503
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2504
0
        ggml_vec_silu_f32(nc,
2505
0
                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
2506
0
                (float *) ((char *) src0->data + i1*(src0->nb[1])));
2507
2508
#ifndef NDEBUG
2509
        for (int k = 0; k < nc; k++) {
2510
            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
2511
            GGML_UNUSED(x);
2512
            assert(!isnan(x));
2513
            assert(!isinf(x));
2514
        }
2515
#endif
2516
0
    }
2517
0
}
2518
2519
static void ggml_compute_forward_silu_f16(
2520
    const ggml_compute_params * params,
2521
0
    ggml_tensor * dst) {
2522
2523
0
    const ggml_tensor * src0 = dst->src[0];
2524
2525
0
    assert(ggml_is_contiguous_1(src0));
2526
0
    assert(ggml_is_contiguous_1(dst));
2527
0
    assert(ggml_are_same_shape(src0, dst));
2528
2529
0
    const int ith = params->ith;
2530
0
    const int nth = params->nth;
2531
2532
0
    const int nc = src0->ne[0];
2533
0
    const int nr = ggml_nrows(src0);
2534
2535
    // rows per thread
2536
0
    const int dr = (nr + nth - 1)/nth;
2537
2538
    // row range for this thread
2539
0
    const int ir0 = dr*ith;
2540
0
    const int ir1 = MIN(ir0 + dr, nr);
2541
2542
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2543
0
        ggml_vec_silu_f16(nc,
2544
0
                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
2545
0
                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2546
2547
#ifndef NDEBUG
2548
        for (int k = 0; k < nc; k++) {
2549
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
2550
            const float v = GGML_CPU_FP16_TO_FP32(x);
2551
            GGML_UNUSED(v);
2552
            assert(!isnan(v));
2553
            assert(!isinf(v));
2554
        }
2555
#endif
2556
0
    }
2557
0
}
2558
2559
static void ggml_compute_forward_silu(
2560
        const ggml_compute_params * params,
2561
0
        ggml_tensor * dst) {
2562
2563
0
    const ggml_tensor * src0 = dst->src[0];
2564
2565
0
    switch (src0->type) {
2566
0
        case GGML_TYPE_F32:
2567
0
            {
2568
0
                ggml_compute_forward_silu_f32(params, dst);
2569
0
            } break;
2570
0
        case GGML_TYPE_F16:
2571
0
            {
2572
0
                ggml_compute_forward_silu_f16(params, dst);
2573
0
            } break;
2574
0
        default:
2575
0
            {
2576
0
                GGML_ABORT("fatal error");
2577
0
            }
2578
0
    }
2579
0
}
2580
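SiLU (swish) is x*sigmoid(x); a scalar reference for what ggml_vec_silu_f32 computes per element (illustrative; the vectorized kernel may use a faster exp approximation):

#include <cmath>

static float silu_ref(float x) {
    return x/(1.0f + std::exp(-x));   // x * sigmoid(x)
}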
// ggml_compute_forward_leaky_relu
2581
2582
static void ggml_compute_forward_leaky_relu_f32(
2583
        const ggml_compute_params * params,
2584
0
        ggml_tensor * dst) {
2585
2586
0
    const ggml_tensor * src0 = dst->src[0];
2587
2588
0
    if (params->ith != 0) {
2589
0
        return;
2590
0
    }
2591
2592
0
    assert(ggml_is_contiguous_1(src0));
2593
0
    assert(ggml_is_contiguous_1(dst));
2594
0
    assert(ggml_are_same_shape(src0, dst));
2595
2596
0
    const int n  = ggml_nrows(src0);
2597
0
    const int nc = src0->ne[0];
2598
2599
0
    float negative_slope;
2600
0
    memcpy(&negative_slope, dst->op_params, sizeof(float));
2601
2602
0
    assert(dst->nb[0]  == sizeof(float));
2603
0
    assert(src0->nb[0] == sizeof(float));
2604
2605
0
    for (int i = 0; i < n; i++) {
2606
0
        ggml_vec_leaky_relu_f32(nc,
2607
0
                (float *) ((char *) dst->data  + i*( dst->nb[1])),
2608
0
                (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
2609
0
    }
2610
0
}
2611
2612
static void ggml_compute_forward_leaky_relu_f16(
2613
    const ggml_compute_params * params,
2614
0
    ggml_tensor * dst) {
2615
2616
0
    const ggml_tensor * src0 = dst->src[0];
2617
2618
0
    if (params->ith != 0) {
2619
0
        return;
2620
0
    }
2621
2622
0
    assert(ggml_is_contiguous_1(src0));
2623
0
    assert(ggml_is_contiguous_1(dst));
2624
0
    assert(ggml_are_same_shape(src0, dst));
2625
2626
0
    const int n  = ggml_nrows(src0);
2627
0
    const int nc = src0->ne[0];
2628
2629
0
    float negative_slope;
2630
0
    memcpy(&negative_slope, dst->op_params, sizeof(float));
2631
2632
0
    assert(dst->nb[0]  == sizeof(ggml_fp16_t));
2633
0
    assert(src0->nb[0] == sizeof(ggml_fp16_t));
2634
2635
0
    for (int i = 0; i < n; i++) {
2636
0
        ggml_vec_leaky_relu_f16(nc,
2637
0
                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
2638
0
                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
2639
0
    }
2640
0
}
2641
2642
void ggml_compute_forward_leaky_relu(
2643
        const ggml_compute_params * params,
2644
0
        ggml_tensor * dst) {
2645
2646
0
    const ggml_tensor * src0 = dst->src[0];
2647
2648
0
    switch (src0->type) {
2649
0
        case GGML_TYPE_F32:
2650
0
            {
2651
0
                ggml_compute_forward_leaky_relu_f32(params, dst);
2652
0
            } break;
2653
0
        case GGML_TYPE_F16:
2654
0
            {
2655
0
                ggml_compute_forward_leaky_relu_f16(params, dst);
2656
0
            } break;
2657
0
        default:
2658
0
            {
2659
0
                GGML_ABORT("fatal error");
2660
0
            }
2661
0
    }
2662
0
}
2663
2664
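leaky_relu reads its negative slope out of dst->op_params with memcpy, the usual way scalar op parameters are passed in this file. The element-wise rule itself is simple; a scalar sketch (illustrative leaky_relu_ref):

static float leaky_relu_ref(float x, float negative_slope) {
    return x > 0.0f ? x : negative_slope*x;
}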
// ggml_compute_forward_silu_back
2665
2666
static void ggml_compute_forward_silu_back_f32(
2667
        const ggml_compute_params * params,
2668
0
        ggml_tensor * dst) {
2669
2670
0
    const ggml_tensor * grad = dst->src[0];
2671
0
    const ggml_tensor * src1 = dst->src[1];
2672
2673
0
    assert(ggml_is_contiguous_1(grad));
2674
0
    assert(ggml_is_contiguous_1(src1));
2675
0
    assert(ggml_is_contiguous_1(dst));
2676
0
    assert(ggml_are_same_shape(src1, dst));
2677
0
    assert(ggml_are_same_shape(src1, grad));
2678
2679
0
    const int ith = params->ith;
2680
0
    const int nth = params->nth;
2681
2682
0
    const int nc = src1->ne[0];
2683
0
    const int nr = ggml_nrows(src1);
2684
2685
    // rows per thread
2686
0
    const int dr = (nr + nth - 1)/nth;
2687
2688
    // row range for this thread
2689
0
    const int ir0 = dr*ith;
2690
0
    const int ir1 = MIN(ir0 + dr, nr);
2691
2692
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2693
0
        ggml_vec_silu_backward_f32(nc,
2694
0
                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
2695
0
                (float *) ((char *) src1->data + i1*(src1->nb[1])),
2696
0
                (float *) ((char *) grad->data + i1*(grad->nb[1])));
2697
2698
#ifndef NDEBUG
2699
        for (int k = 0; k < nc; k++) {
2700
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2701
            GGML_UNUSED(x);
2702
            assert(!isnan(x));
2703
            assert(!isinf(x));
2704
        }
2705
#endif
2706
0
    }
2707
0
}
2708
2709
static void ggml_compute_forward_silu_back_f16(
2710
    const ggml_compute_params * params,
2711
0
    ggml_tensor * dst) {
2712
2713
0
    const ggml_tensor * grad = dst->src[0];
2714
0
    const ggml_tensor * src1 = dst->src[1];
2715
2716
0
    assert(ggml_is_contiguous_1(grad));
2717
0
    assert(ggml_is_contiguous_1(src1));
2718
0
    assert(ggml_is_contiguous_1(dst));
2719
0
    assert(ggml_are_same_shape(src1, dst));
2720
0
    assert(ggml_are_same_shape(src1, grad));
2721
2722
0
    const int ith = params->ith;
2723
0
    const int nth = params->nth;
2724
2725
0
    const int nc = src1->ne[0];
2726
0
    const int nr = ggml_nrows(src1);
2727
2728
    // rows per thread
2729
0
    const int dr = (nr + nth - 1)/nth;
2730
2731
    // row range for this thread
2732
0
    const int ir0 = dr*ith;
2733
0
    const int ir1 = MIN(ir0 + dr, nr);
2734
2735
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2736
0
        ggml_vec_silu_backward_f16(nc,
2737
0
                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
2738
0
                (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
2739
0
                (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
2740
2741
    #ifndef NDEBUG
2742
        for (int k = 0; k < nc; k++) {
2743
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2744
            const float v = GGML_CPU_FP16_TO_FP32(x);
2745
            GGML_UNUSED(v);
2746
            assert(!isnan(v));
2747
            assert(!isinf(v));
2748
        }
2749
    #endif
2750
0
    }
2751
0
}
2752
2753
void ggml_compute_forward_silu_back(
2754
        const ggml_compute_params * params,
2755
0
        ggml_tensor * dst) {
2756
2757
0
    const ggml_tensor * src0 = dst->src[0];
2758
2759
0
    switch (src0->type) {
2760
0
        case GGML_TYPE_F32:
2761
0
            {
2762
0
                ggml_compute_forward_silu_back_f32(params, dst);
2763
0
            } break;
2764
0
        case GGML_TYPE_F16:
2765
0
            {
2766
0
                ggml_compute_forward_silu_back_f16(params, dst);
2767
0
            } break;
2768
0
        default:
2769
0
            {
2770
0
                GGML_ABORT("fatal error");
2771
0
            }
2772
0
    }
2773
0
}
2774
2775
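silu_back multiplies the incoming gradient by the derivative of SiLU; with s = sigmoid(x), that derivative is s*(1 + x*(1 - s)). A scalar sketch of the backward rule (illustrative silu_backward_ref, matching what ggml_vec_silu_backward_f32 is expected to compute element-wise):

#include <cmath>

// dx = dy * d/dx [x * sigmoid(x)]
static float silu_backward_ref(float x, float dy) {
    const float s = 1.0f/(1.0f + std::exp(-x));   // sigmoid(x)
    return dy*s*(1.0f + x*(1.0f - s));
}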
// ggml_compute_forward_reglu
2776
2777
static void ggml_compute_forward_reglu_f32(
2778
        const ggml_compute_params * params,
2779
0
        ggml_tensor * dst) {
2780
2781
0
    const ggml_tensor * src0 = dst->src[0];
2782
0
    const ggml_tensor * src1 = dst->src[1];
2783
0
    char * src0_d = (char *) src0->data;
2784
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
2785
0
    const size_t src0_o = src0->nb[1];
2786
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2787
2788
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
2789
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
2790
2791
0
    if (src1) {
2792
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
2793
0
        GGML_ASSERT(src0->type == src1->type);
2794
0
    }
2795
2796
0
    const int ith = params->ith;
2797
0
    const int nth = params->nth;
2798
2799
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
2800
0
    const int nr = ggml_nrows(src0);
2801
2802
0
    GGML_ASSERT(dst->ne[0] == nc);
2803
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
2804
2805
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
2806
2807
    // rows per thread
2808
0
    const int dr = (nr + nth - 1)/nth;
2809
2810
    // row range for this thread
2811
0
    const int ir0 = dr*ith;
2812
0
    const int ir1 = MIN(ir0 + dr, nr);
2813
2814
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2815
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
2816
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
2817
2818
0
        if (!src1) {
2819
0
            src0_p += swapped ? nc : 0;
2820
0
            src1_p += swapped ? 0 : nc;
2821
0
        }
2822
2823
0
        ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
2824
2825
#ifndef NDEBUG
2826
        for (int k = 0; k < nc; k++) {
2827
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2828
            GGML_UNUSED(x);
2829
            assert(!isnan(x));
2830
            assert(!isinf(x));
2831
        }
2832
#endif
2833
0
    }
2834
0
}
2835
2836
static void ggml_compute_forward_reglu_f16(
2837
    const ggml_compute_params * params,
2838
0
    ggml_tensor * dst) {
2839
2840
0
    const ggml_tensor * src0 = dst->src[0];
2841
0
    const ggml_tensor * src1 = dst->src[1];
2842
0
    char * src0_d = (char *) src0->data;
2843
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
2844
0
    const size_t src0_o = src0->nb[1];
2845
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2846
2847
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
2848
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
2849
2850
0
    if (src1) {
2851
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
2852
0
        GGML_ASSERT(src0->type == src1->type);
2853
0
    }
2854
2855
0
    const int ith = params->ith;
2856
0
    const int nth = params->nth;
2857
2858
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
2859
0
    const int nr = ggml_nrows(src0);
2860
2861
0
    GGML_ASSERT(dst->ne[0] == nc);
2862
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
2863
2864
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
2865
2866
    // rows per thread
2867
0
    const int dr = (nr + nth - 1)/nth;
2868
2869
    // row range for this thread
2870
0
    const int ir0 = dr*ith;
2871
0
    const int ir1 = MIN(ir0 + dr, nr);
2872
2873
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2874
0
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
2875
0
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
2876
2877
0
        if (!src1) {
2878
0
            src0_p += swapped ? nc : 0;
2879
0
            src1_p += swapped ? 0 : nc;
2880
0
        }
2881
2882
0
        ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
2883
2884
#ifndef NDEBUG
2885
        for (int k = 0; k < nc; k++) {
2886
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2887
            const float v = GGML_FP16_TO_FP32(x);
2888
            GGML_UNUSED(v);
2889
            assert(!isnan(v));
2890
            assert(!isinf(v));
2891
        }
2892
#endif
2893
0
    }
2894
0
}
2895
2896
static void ggml_compute_forward_reglu(
2897
        const ggml_compute_params * params,
2898
0
        ggml_tensor * dst) {
2899
2900
0
    const ggml_tensor * src0 = dst->src[0];
2901
2902
0
    switch (src0->type) {
2903
0
        case GGML_TYPE_F32:
2904
0
            {
2905
0
                ggml_compute_forward_reglu_f32(params, dst);
2906
0
            } break;
2907
0
        case GGML_TYPE_F16:
2908
0
            {
2909
0
                ggml_compute_forward_reglu_f16(params, dst);
2910
0
            } break;
2911
0
        default:
2912
0
            {
2913
0
                GGML_ABORT("fatal error");
2914
0
            }
2915
0
    }
2916
0
}
2917
2918
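reglu, geglu and swiglu share the same gated layout: with a single input, each row holds the activated half and the linear half side by side (ne0/2 columns each), and op_param 1 ("swapped") selects which half feeds the activation. A sketch of that split for a SiLU gate, i.e. the swiglu case (glu_row and its gate choice are assumptions; the two-input form simply takes the halves from separate tensors):

#include <cmath>
#include <cstdint>

// One fused row of 2*nc floats -> nc outputs: dst[k] = act(a[k]) * b[k].
static void glu_row(int64_t nc, bool swapped, const float * row, float * dst) {
    const float * a = row + (swapped ? nc : 0);   // half that goes through the activation
    const float * b = row + (swapped ? 0 : nc);   // linear half
    for (int64_t k = 0; k < nc; ++k) {
        const float s = 1.0f/(1.0f + std::exp(-a[k]));  // SiLU gate (swiglu)
        dst[k] = a[k]*s*b[k];
    }
}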
// ggml_compute_forward_geglu
2919
2920
static void ggml_compute_forward_geglu_f32(
2921
        const ggml_compute_params * params,
2922
0
        ggml_tensor * dst) {
2923
2924
0
    const ggml_tensor * src0 = dst->src[0];
2925
0
    const ggml_tensor * src1 = dst->src[1];
2926
0
    char * src0_d = (char *) src0->data;
2927
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
2928
0
    const size_t src0_o = src0->nb[1];
2929
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2930
2931
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
2932
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
2933
2934
0
    if (src1) {
2935
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
2936
0
        GGML_ASSERT(src0->type == src1->type);
2937
0
    }
2938
2939
0
    const int ith = params->ith;
2940
0
    const int nth = params->nth;
2941
2942
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
2943
0
    const int nr = ggml_nrows(src0);
2944
2945
0
    GGML_ASSERT(dst->ne[0] == nc);
2946
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
2947
2948
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
2949
2950
    // rows per thread
2951
0
    const int dr = (nr + nth - 1)/nth;
2952
2953
    // row range for this thread
2954
0
    const int ir0 = dr*ith;
2955
0
    const int ir1 = MIN(ir0 + dr, nr);
2956
2957
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2958
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
2959
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
2960
2961
0
        if (!src1) {
2962
0
            src0_p += swapped ? nc : 0;
2963
0
            src1_p += swapped ? 0 : nc;
2964
0
        }
2965
2966
0
        ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
2967
2968
#ifndef NDEBUG
2969
        for (int k = 0; k < nc; k++) {
2970
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2971
            GGML_UNUSED(x);
2972
            assert(!isnan(x));
2973
            assert(!isinf(x));
2974
        }
2975
#endif
2976
0
    }
2977
0
}
2978
2979
static void ggml_compute_forward_geglu_f16(
2980
    const ggml_compute_params * params,
2981
0
    ggml_tensor * dst) {
2982
2983
0
    const ggml_tensor * src0 = dst->src[0];
2984
0
    const ggml_tensor * src1 = dst->src[1];
2985
0
    char * src0_d = (char *) src0->data;
2986
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
2987
0
    const size_t src0_o = src0->nb[1];
2988
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2989
2990
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
2991
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
2992
2993
0
    if (src1) {
2994
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
2995
0
        GGML_ASSERT(src0->type == src1->type);
2996
0
    }
2997
2998
0
    const int ith = params->ith;
2999
0
    const int nth = params->nth;
3000
3001
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3002
0
    const int nr = ggml_nrows(src0);
3003
3004
0
    GGML_ASSERT(dst->ne[0] == nc);
3005
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3006
3007
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3008
3009
    // rows per thread
3010
0
    const int dr = (nr + nth - 1)/nth;
3011
3012
    // row range for this thread
3013
0
    const int ir0 = dr*ith;
3014
0
    const int ir1 = MIN(ir0 + dr, nr);
3015
3016
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3017
0
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3018
0
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3019
3020
0
        if (!src1) {
3021
0
            src0_p += swapped ? nc : 0;
3022
0
            src1_p += swapped ? 0 : nc;
3023
0
        }
3024
3025
0
        ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3026
3027
#ifndef NDEBUG
3028
        for (int k = 0; k < nc; k++) {
3029
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3030
            const float v = GGML_FP16_TO_FP32(x);
3031
            GGML_UNUSED(v);
3032
            assert(!isnan(v));
3033
            assert(!isinf(v));
3034
        }
3035
#endif
3036
0
    }
3037
0
}
3038
3039
static void ggml_compute_forward_geglu(
3040
        const ggml_compute_params * params,
3041
0
        ggml_tensor * dst) {
3042
3043
0
    const ggml_tensor * src0 = dst->src[0];
3044
3045
0
    switch (src0->type) {
3046
0
        case GGML_TYPE_F32:
3047
0
            {
3048
0
                ggml_compute_forward_geglu_f32(params, dst);
3049
0
            } break;
3050
0
        case GGML_TYPE_F16:
3051
0
            {
3052
0
                ggml_compute_forward_geglu_f16(params, dst);
3053
0
            } break;
3054
0
        default:
3055
0
            {
3056
0
                GGML_ABORT("fatal error");
3057
0
            }
3058
0
    }
3059
0
}
3060
3061
// ggml_compute_forward_swiglu
3062
3063
static void ggml_compute_forward_swiglu_f32(
3064
        const ggml_compute_params * params,
3065
0
        ggml_tensor * dst) {
3066
3067
0
    const ggml_tensor * src0 = dst->src[0];
3068
0
    const ggml_tensor * src1 = dst->src[1];
3069
0
    char * src0_d = (char *) src0->data;
3070
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3071
0
    const size_t src0_o = src0->nb[1];
3072
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3073
3074
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3075
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3076
3077
0
    if (src1) {
3078
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3079
0
        GGML_ASSERT(src0->type == src1->type);
3080
0
    }
3081
3082
0
    const int ith = params->ith;
3083
0
    const int nth = params->nth;
3084
3085
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3086
0
    const int nr = ggml_nrows(src0);
3087
3088
0
    GGML_ASSERT(dst->ne[0] == nc);
3089
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3090
3091
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3092
3093
    // rows per thread
3094
0
    const int dr = (nr + nth - 1)/nth;
3095
3096
    // row range for this thread
3097
0
    const int ir0 = dr*ith;
3098
0
    const int ir1 = MIN(ir0 + dr, nr);
3099
3100
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3101
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
3102
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
3103
3104
0
        if (!src1) {
3105
0
            src0_p += swapped ? nc : 0;
3106
0
            src1_p += swapped ? 0 : nc;
3107
0
        }
3108
3109
0
        ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3110
3111
#ifndef NDEBUG
3112
        for (int k = 0; k < nc; k++) {
3113
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3114
            GGML_UNUSED(x);
3115
            assert(!isnan(x));
3116
            assert(!isinf(x));
3117
        }
3118
#endif
3119
0
    }
3120
0
}
3121
3122
static void ggml_compute_forward_swiglu_f16(
3123
    const ggml_compute_params * params,
3124
0
    ggml_tensor * dst) {
3125
3126
0
    const ggml_tensor * src0 = dst->src[0];
3127
0
    const ggml_tensor * src1 = dst->src[1];
3128
0
    char * src0_d = (char *) src0->data;
3129
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3130
0
    const size_t src0_o = src0->nb[1];
3131
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3132
3133
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3134
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3135
3136
0
    if (src1) {
3137
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3138
0
        GGML_ASSERT(src0->type == src1->type);
3139
0
    }
3140
3141
0
    const int ith = params->ith;
3142
0
    const int nth = params->nth;
3143
3144
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3145
0
    const int nr = ggml_nrows(src0);
3146
3147
0
    GGML_ASSERT(dst->ne[0] == nc);
3148
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3149
3150
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3151
3152
    // rows per thread
3153
0
    const int dr = (nr + nth - 1)/nth;
3154
3155
    // row range for this thread
3156
0
    const int ir0 = dr*ith;
3157
0
    const int ir1 = MIN(ir0 + dr, nr);
3158
3159
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3160
0
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3161
0
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3162
3163
0
        if (!src1) {
3164
0
            src0_p += swapped ? nc : 0;
3165
0
            src1_p += swapped ? 0 : nc;
3166
0
        }
3167
3168
0
        ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3169
3170
#ifndef NDEBUG
3171
        for (int k = 0; k < nc; k++) {
3172
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3173
            const float v = GGML_FP16_TO_FP32(x);
3174
            GGML_UNUSED(v);
3175
            assert(!isnan(v));
3176
            assert(!isinf(v));
3177
        }
3178
#endif
3179
0
    }
3180
0
}
3181
3182
static void ggml_compute_forward_swiglu(
3183
        const ggml_compute_params * params,
3184
0
        ggml_tensor * dst) {
3185
3186
0
    const ggml_tensor * src0 = dst->src[0];
3187
3188
0
    switch (src0->type) {
3189
0
        case GGML_TYPE_F32:
3190
0
            {
3191
0
                ggml_compute_forward_swiglu_f32(params, dst);
3192
0
            } break;
3193
0
        case GGML_TYPE_F16:
3194
0
            {
3195
0
                ggml_compute_forward_swiglu_f16(params, dst);
3196
0
            } break;
3197
0
        default:
3198
0
            {
3199
0
                GGML_ABORT("fatal error");
3200
0
            }
3201
0
    }
3202
0
}
3203
3204
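Assuming ggml_vec_swiglu_* computes silu(x) * g, each output element of the kernels above reduces to the following scalar form (a sketch, not the vec.h implementation):

static float swiglu_scalar(float x, float g) {
    const float silu = x / (1.0f + expf(-x));   // silu(x) = x * sigmoid(x)
    return silu * g;                            // gated by the linear half
}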
// ggml_compute_forward_swiglu_oai
3205
3206
static void ggml_compute_forward_swiglu_oai_f32(
3207
        const ggml_compute_params * params,
3208
0
        ggml_tensor * dst) {
3209
3210
0
    const ggml_tensor * src0 = dst->src[0];
3211
0
    const ggml_tensor * src1 = dst->src[1];
3212
0
    char * src0_d = (char *) src0->data;
3213
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3214
0
    const size_t src0_o = src0->nb[1];
3215
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3216
3217
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3218
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3219
3220
0
    if (src1) {
3221
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3222
0
        GGML_ASSERT(src0->type == src1->type);
3223
0
    }
3224
3225
0
    const int ith = params->ith;
3226
0
    const int nth = params->nth;
3227
3228
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3229
0
    const int nr = ggml_nrows(src0);
3230
3231
0
    GGML_ASSERT(dst->ne[0] == nc);
3232
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3233
3234
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3235
0
    const float alpha = ggml_get_op_params_f32(dst, 2);
3236
0
    const float limit = ggml_get_op_params_f32(dst, 3);
3237
3238
    // rows per thread
3239
0
    const int dr = (nr + nth - 1)/nth;
3240
3241
    // row range for this thread
3242
0
    const int ir0 = dr*ith;
3243
0
    const int ir1 = MIN(ir0 + dr, nr);
3244
3245
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3246
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
3247
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
3248
0
        float * dst_p  = (float *) ((char *) dst->data + i1*(dst->nb[1]));
3249
3250
0
        if (!src1) {
3251
0
            src0_p += swapped ? nc : 0;
3252
0
            src1_p += swapped ? 0 : nc;
3253
0
        }
3254
3255
0
        for (int k = 0; k < nc; k++) {
3256
0
            const float x = std::min(src0_p[k], limit);
3257
0
            const float y = std::clamp(src1_p[k], -limit, limit);
3258
0
            const float out_glu = x / (1.f + expf(alpha * (-x)));
3259
0
            dst_p[k] = out_glu * (y + 1.f);
3260
0
        }
3261
3262
#ifndef NDEBUG
3263
        for (int k = 0; k < nc; k++) {
3264
            const float x = dst_p[k];
3265
            GGML_UNUSED(x);
3266
            assert(!isnan(x));
3267
            assert(!isinf(x));
3268
        }
3269
#endif
3270
0
    }
3271
0
}
3272
3273
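The loop above reduces to this per-element form, restated as a standalone helper for clarity (illustrative only):

static float swiglu_oai_scalar(float x, float y, float alpha, float limit) {
    x = std::min(x, limit);                                  // gate input clamped from above only
    y = std::clamp(y, -limit, limit);                        // linear input clamped symmetrically
    const float out_glu = x / (1.f + expf(alpha * (-x)));    // x * sigmoid(alpha*x)
    return out_glu * (y + 1.f);                              // +1 bias on the linear half
}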
static void ggml_compute_forward_swiglu_oai(
3274
        const ggml_compute_params * params,
3275
0
        ggml_tensor * dst) {
3276
3277
0
    const ggml_tensor * src0 = dst->src[0];
3278
3279
0
    switch (src0->type) {
3280
0
        case GGML_TYPE_F32:
3281
0
            {
3282
0
                ggml_compute_forward_swiglu_oai_f32(params, dst);
3283
0
            } break;
3284
0
        default:
3285
0
            {
3286
0
                GGML_ABORT("fatal error");
3287
0
            }
3288
0
    }
3289
0
}
3290
3291
// ggml_compute_forward_geglu_erf
3292
3293
static void ggml_compute_forward_geglu_erf_f32(
3294
        const ggml_compute_params * params,
3295
0
        ggml_tensor * dst) {
3296
3297
0
    const ggml_tensor * src0 = dst->src[0];
3298
0
    const ggml_tensor * src1 = dst->src[1];
3299
0
    char * src0_d = (char *) src0->data;
3300
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3301
0
    const size_t src0_o = src0->nb[1];
3302
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3303
3304
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3305
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3306
3307
0
    if (src1) {
3308
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3309
0
        GGML_ASSERT(src0->type == src1->type);
3310
0
    }
3311
3312
0
    const int ith = params->ith;
3313
0
    const int nth = params->nth;
3314
3315
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3316
0
    const int nr = ggml_nrows(src0);
3317
3318
0
    GGML_ASSERT(dst->ne[0] == nc);
3319
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3320
3321
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3322
3323
    // rows per thread
3324
0
    const int dr = (nr + nth - 1)/nth;
3325
3326
    // row range for this thread
3327
0
    const int ir0 = dr*ith;
3328
0
    const int ir1 = MIN(ir0 + dr, nr);
3329
3330
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3331
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
3332
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
3333
3334
0
        if (!src1) {
3335
0
            src0_p += swapped ? nc : 0;
3336
0
            src1_p += swapped ? 0 : nc;
3337
0
        }
3338
3339
0
        ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3340
3341
#ifndef NDEBUG
3342
        for (int k = 0; k < nc; k++) {
3343
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3344
            GGML_UNUSED(x);
3345
            assert(!isnan(x));
3346
            assert(!isinf(x));
3347
        }
3348
#endif
3349
0
    }
3350
0
}
3351
3352
static void ggml_compute_forward_geglu_erf_f16(
3353
    const ggml_compute_params * params,
3354
0
    ggml_tensor * dst) {
3355
3356
0
    const ggml_tensor * src0 = dst->src[0];
3357
0
    const ggml_tensor * src1 = dst->src[1];
3358
0
    char * src0_d = (char *) src0->data;
3359
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3360
0
    const size_t src0_o = src0->nb[1];
3361
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3362
3363
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3364
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3365
3366
0
    if (src1) {
3367
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3368
0
        GGML_ASSERT(src0->type == src1->type);
3369
0
    }
3370
3371
0
    const int ith = params->ith;
3372
0
    const int nth = params->nth;
3373
3374
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3375
0
    const int nr = ggml_nrows(src0);
3376
3377
0
    GGML_ASSERT(dst->ne[0] == nc);
3378
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3379
3380
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3381
3382
    // rows per thread
3383
0
    const int dr = (nr + nth - 1)/nth;
3384
3385
    // row range for this thread
3386
0
    const int ir0 = dr*ith;
3387
0
    const int ir1 = MIN(ir0 + dr, nr);
3388
3389
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3390
0
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3391
0
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3392
3393
0
        if (!src1) {
3394
0
            src0_p += swapped ? nc : 0;
3395
0
            src1_p += swapped ? 0 : nc;
3396
0
        }
3397
3398
0
        ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3399
3400
#ifndef NDEBUG
3401
        for (int k = 0; k < nc; k++) {
3402
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3403
            const float v = GGML_FP16_TO_FP32(x);
3404
            GGML_UNUSED(v);
3405
            assert(!isnan(v));
3406
            assert(!isinf(v));
3407
        }
3408
#endif
3409
0
    }
3410
0
}
3411
3412
static void ggml_compute_forward_geglu_erf(
3413
        const ggml_compute_params * params,
3414
0
        ggml_tensor * dst) {
3415
3416
0
    const ggml_tensor * src0 = dst->src[0];
3417
3418
0
    switch (src0->type) {
3419
0
        case GGML_TYPE_F32:
3420
0
            {
3421
0
                ggml_compute_forward_geglu_erf_f32(params, dst);
3422
0
            } break;
3423
0
        case GGML_TYPE_F16:
3424
0
            {
3425
0
                ggml_compute_forward_geglu_erf_f16(params, dst);
3426
0
            } break;
3427
0
        default:
3428
0
            {
3429
0
                GGML_ABORT("fatal error");
3430
0
            }
3431
0
    }
3432
0
}
3433
3434
// ggml_compute_forward_geglu_quick
3435
3436
static void ggml_compute_forward_geglu_quick_f32(
3437
        const ggml_compute_params * params,
3438
0
        ggml_tensor * dst) {
3439
3440
0
    const ggml_tensor * src0 = dst->src[0];
3441
0
    const ggml_tensor * src1 = dst->src[1];
3442
0
    char * src0_d = (char *) src0->data;
3443
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3444
0
    const size_t src0_o = src0->nb[1];
3445
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3446
3447
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3448
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3449
3450
0
    if (src1) {
3451
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3452
0
        GGML_ASSERT(src0->type == src1->type);
3453
0
    }
3454
3455
0
    const int ith = params->ith;
3456
0
    const int nth = params->nth;
3457
3458
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3459
0
    const int nr = ggml_nrows(src0);
3460
3461
0
    GGML_ASSERT(dst->ne[0] == nc);
3462
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3463
3464
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3465
3466
    // rows per thread
3467
0
    const int dr = (nr + nth - 1)/nth;
3468
3469
    // row range for this thread
3470
0
    const int ir0 = dr*ith;
3471
0
    const int ir1 = MIN(ir0 + dr, nr);
3472
3473
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3474
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
3475
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
3476
3477
0
        if (!src1) {
3478
0
            src0_p += swapped ? nc : 0;
3479
0
            src1_p += swapped ? 0 : nc;
3480
0
        }
3481
3482
0
        ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3483
3484
#ifndef NDEBUG
3485
        for (int k = 0; k < nc; k++) {
3486
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3487
            GGML_UNUSED(x);
3488
            assert(!isnan(x));
3489
            assert(!isinf(x));
3490
        }
3491
#endif
3492
0
    }
3493
0
}
3494
3495
static void ggml_compute_forward_geglu_quick_f16(
3496
    const ggml_compute_params * params,
3497
0
    ggml_tensor * dst) {
3498
3499
0
    const ggml_tensor * src0 = dst->src[0];
3500
0
    const ggml_tensor * src1 = dst->src[1];
3501
0
    char * src0_d = (char *) src0->data;
3502
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3503
0
    const size_t src0_o = src0->nb[1];
3504
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3505
3506
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3507
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3508
3509
0
    if (src1) {
3510
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3511
0
        GGML_ASSERT(src0->type == src1->type);
3512
0
    }
3513
3514
0
    const int ith = params->ith;
3515
0
    const int nth = params->nth;
3516
3517
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3518
0
    const int nr = ggml_nrows(src0);
3519
3520
0
    GGML_ASSERT(dst->ne[0] == nc);
3521
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3522
3523
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3524
3525
    // rows per thread
3526
0
    const int dr = (nr + nth - 1)/nth;
3527
3528
    // row range for this thread
3529
0
    const int ir0 = dr*ith;
3530
0
    const int ir1 = MIN(ir0 + dr, nr);
3531
3532
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3533
0
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3534
0
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3535
3536
0
        if (!src1) {
3537
0
            src0_p += swapped ? nc : 0;
3538
0
            src1_p += swapped ? 0 : nc;
3539
0
        }
3540
3541
0
        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3542
3543
#ifndef NDEBUG
3544
        for (int k = 0; k < nc; k++) {
3545
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3546
            const float v = GGML_FP16_TO_FP32(x);
3547
            GGML_UNUSED(v);
3548
            assert(!isnan(v));
3549
            assert(!isinf(v));
3550
        }
3551
#endif
3552
0
    }
3553
0
}
3554
3555
static void ggml_compute_forward_geglu_quick(
3556
        const ggml_compute_params * params,
3557
0
        ggml_tensor * dst) {
3558
3559
0
    const ggml_tensor * src0 = dst->src[0];
3560
3561
0
    switch (src0->type) {
3562
0
        case GGML_TYPE_F32:
3563
0
            {
3564
0
                ggml_compute_forward_geglu_quick_f32(params, dst);
3565
0
            } break;
3566
0
        case GGML_TYPE_F16:
3567
0
            {
3568
0
                ggml_compute_forward_geglu_quick_f16(params, dst);
3569
0
            } break;
3570
0
        default:
3571
0
            {
3572
0
                GGML_ABORT("fatal error");
3573
0
            }
3574
0
    }
3575
0
}
3576
3577
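The geglu_erf and geglu_quick paths differ from geglu only in the activation applied to the gated half. The formulas below are the textbook definitions and are assumptions about what the ggml_vec_geglu_erf_*/ggml_vec_geglu_quick_* helpers in vec.h compute, not copied from them:

static float gelu_erf_scalar(float x)   { return 0.5f*x*(1.0f + erff(x/sqrtf(2.0f))); }  // exact GELU
static float gelu_quick_scalar(float x) { return x / (1.0f + expf(-1.702f*x)); }         // sigmoid approximation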
// ggml_compute_forward_norm
3578
3579
static void ggml_compute_forward_norm_f32(
3580
        const ggml_compute_params * params,
3581
0
        ggml_tensor * dst) {
3582
3583
0
    const ggml_tensor * src0 = dst->src[0];
3584
3585
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
3586
3587
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
3588
3589
0
    const int ith = params->ith;
3590
0
    const int nth = params->nth;
3591
3592
0
    GGML_TENSOR_UNARY_OP_LOCALS
3593
3594
0
    float eps;
3595
0
    memcpy(&eps, dst->op_params, sizeof(float));
3596
3597
0
    GGML_ASSERT(eps >= 0.0f);
3598
3599
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
3600
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
3601
0
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3602
0
                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3603
3604
0
                float sum = 0.0;
3605
0
                ggml_vec_sum_f32(ne00, &sum, x);
3606
0
                float mean = sum/ne00;
3607
3608
0
                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3609
0
                float variance = 0;
3610
3611
#ifdef GGML_USE_ACCELERATE
3612
                mean = -mean;
3613
                vDSP_vsadd(x, 1, &mean, y, 1, ne00);
3614
                vDSP_measqv(y, 1, &variance, ne00);
3615
#else
3616
0
                variance = ggml_vec_cvar_f32(ne00, y, x, mean);
3617
0
#endif //GGML_USE_ACCELERATE
3618
3619
0
                const float scale = 1.0f/sqrtf(variance + eps);
3620
0
                ggml_vec_scale_f32(ne00, y, scale);
3621
0
            }
3622
0
        }
3623
0
    }
3624
0
}
3625
3626
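A plain-loop reference for one row of the computation above (mean/variance normalization with no affine parameters); a sketch equivalent to the vDSP/ggml_vec path, not a replacement for it:

static void norm_row_reference(int64_t n, float * y, const float * x, float eps) {
    float mean = 0.0f;
    for (int64_t i = 0; i < n; i++) mean += x[i];
    mean /= n;                                   // per-row mean
    float variance = 0.0f;
    for (int64_t i = 0; i < n; i++) {
        const float v = x[i] - mean;
        y[i] = v;                                // center in place, as the kernel does
        variance += v*v;
    }
    variance /= n;                               // biased (1/N) variance
    const float scale = 1.0f/sqrtf(variance + eps);
    for (int64_t i = 0; i < n; i++) y[i] *= scale;
}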
void ggml_compute_forward_norm(
3627
        const ggml_compute_params * params,
3628
0
        ggml_tensor * dst) {
3629
3630
0
    const ggml_tensor * src0 = dst->src[0];
3631
3632
0
    switch (src0->type) {
3633
0
        case GGML_TYPE_F32:
3634
0
            {
3635
0
                ggml_compute_forward_norm_f32(params, dst);
3636
0
            } break;
3637
0
        default:
3638
0
            {
3639
0
                GGML_ABORT("fatal error");
3640
0
            }
3641
0
    }
3642
0
}
3643
3644
// ggml_compute_forward_rms_norm
3645
3646
static void ggml_compute_forward_rms_norm_f32(
3647
        const ggml_compute_params * params,
3648
0
        ggml_tensor * dst) {
3649
3650
0
    const ggml_tensor * src0 = dst->src[0];
3651
3652
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
3653
3654
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
3655
3656
0
    const int ith = params->ith;
3657
0
    const int nth = params->nth;
3658
3659
0
    GGML_TENSOR_UNARY_OP_LOCALS
3660
3661
0
    float eps;
3662
0
    memcpy(&eps, dst->op_params, sizeof(float));
3663
3664
0
    GGML_ASSERT(eps >= 0.0f);
3665
3666
    // TODO: optimize
3667
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
3668
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
3669
0
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3670
0
                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3671
3672
0
                ggml_float sum = 0.0;
3673
0
                for (int64_t i00 = 0; i00 < ne00; i00++) {
3674
0
                    sum += (ggml_float)(x[i00] * x[i00]);
3675
0
                }
3676
3677
0
                const float mean = sum/ne00;
3678
3679
0
                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3680
3681
0
                memcpy(y, x, ne00 * sizeof(float));
3682
                // for (int i00 = 0; i00 < ne00; i00++) {
3683
                //     y[i00] = x[i00];
3684
                // }
3685
3686
0
                const float scale = 1.0f/sqrtf(mean + eps);
3687
3688
                // if you hit this, likely you got an inf somewhere earlier
3689
0
                assert(scale > 0.0f);
3690
3691
0
                ggml_vec_scale_f32(ne00, y, scale);
3692
0
            }
3693
0
        }
3694
0
    }
3695
0
}
3696
3697
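RMS norm keeps the row un-centered and rescales by the root mean square; a scalar restatement of the loop above (illustrative only):

static void rms_norm_row_reference(int64_t n, float * y, const float * x, float eps) {
    float sum_xx = 0.0f;
    for (int64_t i = 0; i < n; i++) sum_xx += x[i]*x[i];
    const float mean  = sum_xx / n;              // mean of squares, not the mean of x
    const float scale = 1.0f/sqrtf(mean + eps);
    for (int64_t i = 0; i < n; i++) y[i] = x[i]*scale;
}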
void ggml_compute_forward_rms_norm(
3698
        const ggml_compute_params * params,
3699
0
        ggml_tensor * dst) {
3700
3701
0
    const ggml_tensor * src0 = dst->src[0];
3702
3703
0
    switch (src0->type) {
3704
0
        case GGML_TYPE_F32:
3705
0
            {
3706
0
                ggml_compute_forward_rms_norm_f32(params, dst);
3707
0
            } break;
3708
0
        default:
3709
0
            {
3710
0
                GGML_ABORT("fatal error");
3711
0
            }
3712
0
    }
3713
0
}
3714
3715
static void ggml_compute_forward_rms_norm_back_f32(
3716
        const ggml_compute_params * params,
3717
0
        ggml_tensor * dst) {
3718
3719
0
    const ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
3720
0
    const ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
3721
3722
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
3723
3724
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
3725
0
    GGML_ASSERT(src1->nb[0] == sizeof(float));
3726
3727
0
    const int ith = params->ith;
3728
0
    const int nth = params->nth;
3729
3730
0
    GGML_TENSOR_BINARY_OP_LOCALS
3731
3732
0
    float eps;
3733
0
    memcpy(&eps, dst->op_params, sizeof(float));
3734
3735
    // TODO: optimize
3736
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
3737
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
3738
0
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3739
                // src1 is same shape as src0 => same indices
3740
0
                const int64_t i11 = i01;
3741
0
                const int64_t i12 = i02;
3742
0
                const int64_t i13 = i03;
3743
3744
0
                const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3745
0
                const float * x  = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
3746
3747
0
                ggml_float sum_xx  = 0.0;
3748
0
                ggml_float sum_xdz = 0.0;
3749
3750
0
                for (int64_t i00 = 0; i00 < ne00; i00++) {
3751
0
                    sum_xx  += (ggml_float)(x[i00] * x[i00]);
3752
0
                    sum_xdz += (ggml_float)(x[i00] * dz[i00]);
3753
0
                }
3754
3755
                //const float mean     = (float)(sum_xx)/ne00;
3756
0
                const float mean_eps = (float)(sum_xx)/ne00 + eps;
3757
0
                const float sum_eps  = (float)(sum_xx) + eps*ne00;
3758
                //const float mean_xdz = (float)(sum_xdz)/ne00;
3759
                // we could cache rms from forward pass to improve performance.
3760
                // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
3761
                //const float rms      = sqrtf(mean_eps);
3762
0
                const float rrms     = 1.0f / sqrtf(mean_eps);
3763
                //const float scale    = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
3764
3765
0
                {
3766
                    // z = rms_norm(x)
3767
                    //
3768
                    // rms_norm(src1) =
3769
                    //     scale(
3770
                    //         src1,
3771
                    //         div(
3772
                    //             1,
3773
                    //             sqrt(
3774
                    //                 add(
3775
                    //                     scale(
3776
                    //                         sum(
3777
                    //                             sqr(
3778
                    //                                 src1)),
3779
                    //                         (1.0/N)),
3780
                    //                     eps))));
3781
3782
                    // postorder:
3783
                    // ## op    args         grad
3784
                    // 00 param src1         grad[#00]
3785
                    // 01 const 1
3786
                    // 02 sqr   (#00)        grad[#02]
3787
                    // 03 sum   (#02)        grad[#03]
3788
                    // 04 const 1/N
3789
                    // 05 scale (#03, #04)   grad[#05]
3790
                    // 06 const eps
3791
                    // 07 add   (#05, #06)   grad[#07]
3792
                    // 08 sqrt  (#07)        grad[#08]
3793
                    // 09 div   (#01,#08)    grad[#09]
3794
                    // 10 scale (#00,#09)    grad[#10]
3795
                    //
3796
                    // backward pass, given grad[#10]
3797
                    // #10: scale
3798
                    // grad[#00] += scale(grad[#10],#09)
3799
                    // grad[#09] += sum(mul(grad[#10],#00))
3800
                    // #09: div
3801
                    // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
3802
                    // #08: sqrt
3803
                    // grad[#07] += mul(grad[#08], div(0.5, #08))
3804
                    // #07: add
3805
                    // grad[#05] += grad[#07]
3806
                    // #05: scale
3807
                    // grad[#03] += scale(grad[#05],#04)
3808
                    // #03: sum
3809
                    // grad[#02] += repeat(grad[#03], #02)
3810
                    // #02:
3811
                    // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
3812
                    //
3813
                    // substitute and simplify:
3814
                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
3815
                    // grad[#02] = repeat(grad[#03], #02)
3816
                    // grad[#02] = repeat(scale(grad[#05],#04), #02)
3817
                    // grad[#02] = repeat(scale(grad[#07],#04), #02)
3818
                    // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
3819
                    // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
3820
                    // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
3821
                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
3822
                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
3823
                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
3824
                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
3825
                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
3826
                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
3827
                    // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
3828
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
3829
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
3830
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
3831
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
3832
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
3833
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
3834
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
3835
                    // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
3836
                    // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
3837
                    // a = b*c + d*e
3838
                    // a = b*c*f/f + d*e*f/f
3839
                    // a = (b*c*f + d*e*f)*(1/f)
3840
                    // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
3841
                    // a = (b + d*e/c)*c
3842
                    // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
3843
                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
3844
                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
3845
                    // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
3846
                    // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
3847
                    // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
3848
                    // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
3849
                    // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
3850
                    // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
3851
                    // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
3852
0
                }
3853
                // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
3854
                // post-order:
3855
                // dx := x
3856
                // dx := scale(dx,-mean_xdz/mean_eps)
3857
                // dx := add(dx, dz)
3858
                // dx := scale(dx, rrms)
3859
0
                float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3860
3861
                // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
3862
0
                ggml_vec_cpy_f32  (ne00, dx, x);
3863
                // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
3864
0
                ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
3865
0
                ggml_vec_acc_f32  (ne00, dx, dz);
3866
0
                ggml_vec_scale_f32(ne00, dx, rrms);
3867
0
            }
3868
0
        }
3869
0
    }
3870
0
}
3871
3872
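The derivation in the comment block above collapses to a short closed form; the same math as the kernel, restated per row (illustrative only):

// dx[i] = (dz[i] + x[i] * (-sum_xdz / (sum_xx + n*eps))) / sqrt(sum_xx/n + eps)
static void rms_norm_back_row_reference(int64_t n, float * dx,
                                        const float * x, const float * dz, float eps) {
    float sum_xx = 0.0f, sum_xdz = 0.0f;
    for (int64_t i = 0; i < n; i++) {
        sum_xx  += x[i]*x[i];
        sum_xdz += x[i]*dz[i];
    }
    const float sum_eps = sum_xx + eps*n;
    const float rrms    = 1.0f/sqrtf(sum_xx/n + eps);
    for (int64_t i = 0; i < n; i++) {
        dx[i] = (dz[i] + x[i]*(-sum_xdz/sum_eps)) * rrms;
    }
}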
void ggml_compute_forward_rms_norm_back(
3873
        const ggml_compute_params * params,
3874
0
        ggml_tensor * dst) {
3875
3876
0
    const ggml_tensor * src0 = dst->src[0];
3877
3878
0
    switch (src0->type) {
3879
0
        case GGML_TYPE_F32:
3880
0
            {
3881
0
                ggml_compute_forward_rms_norm_back_f32(params, dst);
3882
0
            } break;
3883
0
        default:
3884
0
            {
3885
0
                GGML_ABORT("fatal error");
3886
0
            }
3887
0
    }
3888
0
}
3889
3890
// ggml_compute_forward_group_norm
3891
3892
static void ggml_compute_forward_group_norm_f32(
3893
    const ggml_compute_params * params,
3894
0
    ggml_tensor * dst) {
3895
3896
0
    const ggml_tensor * src0 = dst->src[0];
3897
3898
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
3899
3900
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
3901
3902
0
    const int ith = params->ith;
3903
0
    const int nth = params->nth;
3904
3905
0
    GGML_TENSOR_UNARY_OP_LOCALS
3906
3907
    // TODO: optimize
3908
3909
0
    float eps;
3910
0
    memcpy(&eps, dst->op_params + 1, sizeof(float));
3911
3912
0
    int n_channels = src0->ne[2];
3913
0
    int n_groups = dst->op_params[0];
3914
0
    int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
3915
0
    for (int i = ith; i < n_groups; i += nth) {
3916
0
        int start = i * n_channels_per_group;
3917
0
        int end = start + n_channels_per_group;
3918
0
        if (end > n_channels) {
3919
0
            end = n_channels;
3920
0
        }
3921
0
        int step = end - start;
3922
3923
0
        for (int64_t i03 = 0; i03 < ne03; i03++) {
3924
0
            ggml_float sum = 0.0;
3925
0
            for (int64_t i02 = start; i02 < end; i02++) {
3926
0
                for (int64_t i01 = 0; i01 < ne01; i01++) {
3927
0
                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3928
3929
0
                    ggml_float sumr = 0.0;
3930
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
3931
0
                        sumr += (ggml_float)x[i00];
3932
0
                    }
3933
0
                    sum += sumr;
3934
0
                }
3935
0
            }
3936
0
            const float mean = sum / (ne00 * ne01 * step);
3937
3938
0
            ggml_float sum2 = 0.0;
3939
0
            for (int64_t i02 = start; i02 < end; i02++) {
3940
0
                for (int64_t i01 = 0; i01 < ne01; i01++) {
3941
0
                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3942
3943
0
                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
3944
3945
0
                    ggml_float sumr = 0.0;
3946
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
3947
0
                        float v = x[i00] - mean;
3948
0
                        y[i00] = v;
3949
0
                        sumr += (ggml_float)(v * v);
3950
0
                    }
3951
0
                    sum2 += sumr;
3952
0
                }
3953
0
            }
3954
0
            const float variance = sum2 / (ne00 * ne01 * step);
3955
0
            const float scale = 1.0f / sqrtf(variance + eps);
3956
3957
0
            for (int64_t i02 = start; i02 < end; i02++) {
3958
0
                for (int64_t i01 = 0; i01 < ne01; i01++) {
3959
0
                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
3960
0
                    ggml_vec_scale_f32(ne00, y, scale);
3961
0
                }
3962
0
            }
3963
0
        }
3964
0
    }
3965
0
}
3966
3967
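The channel-to-group mapping used above, pulled out as a small sketch (hypothetical helper; the kernel then normalizes over ne00*ne01*step elements per group and batch):

static void group_channel_range(int i, int n_groups, int n_channels, int * start, int * end) {
    const int per_group = (n_channels + n_groups - 1) / n_groups;   // ceil division
    *start = i * per_group;                                         // first channel of group i
    *end   = std::min(*start + per_group, n_channels);              // last group may be short
}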
void ggml_compute_forward_group_norm(
3968
    const ggml_compute_params * params,
3969
0
    ggml_tensor * dst) {
3970
3971
0
    const ggml_tensor * src0 = dst->src[0];
3972
3973
0
    switch (src0->type) {
3974
0
        case GGML_TYPE_F32:
3975
0
            {
3976
0
                ggml_compute_forward_group_norm_f32(params, dst);
3977
0
            } break;
3978
0
        default:
3979
0
            {
3980
0
                GGML_ABORT("fatal error");
3981
0
            }
3982
0
    }
3983
0
}
3984
3985
// ggml_compute_forward_l2_norm
3986
3987
static void ggml_compute_forward_l2_norm_f32(
3988
    const ggml_compute_params * params,
3989
0
    ggml_tensor * dst) {
3990
3991
0
    const ggml_tensor * src0 = dst->src[0];
3992
3993
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
3994
3995
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
3996
3997
0
    const int ith = params->ith;
3998
0
    const int nth = params->nth;
3999
4000
0
    GGML_TENSOR_UNARY_OP_LOCALS
4001
4002
0
    float eps;
4003
0
    memcpy(&eps, dst->op_params, sizeof(float));
4004
4005
0
    GGML_ASSERT(eps >= 0.0f);
4006
4007
    // TODO: optimize
4008
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
4009
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
4010
0
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
4011
0
                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
4012
4013
0
                ggml_float sum = 0.0;
4014
0
                for (int64_t i00 = 0; i00 < ne00; i00++) {
4015
0
                    sum += (ggml_float)(x[i00] * x[i00]);
4016
0
                }
4017
4018
0
                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
4019
4020
0
                memcpy(y, x, ne00 * sizeof(float));
4021
4022
0
                const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
4023
4024
0
                ggml_vec_scale_f32(ne00, y, scale);
4025
0
            }
4026
0
        }
4027
0
    }
4028
0
}
4029
4030
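L2 normalization divides each row by its Euclidean norm, with eps acting as a lower bound rather than an additive term; scalar sketch of the loop above:

static void l2_norm_row_reference(int64_t n, float * y, const float * x, float eps) {
    float sum_xx = 0.0f;
    for (int64_t i = 0; i < n; i++) sum_xx += x[i]*x[i];
    const float scale = 1.0f/fmaxf(sqrtf(sum_xx), eps);   // note: no "+ eps" inside the sqrt
    for (int64_t i = 0; i < n; i++) y[i] = x[i]*scale;
}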
void ggml_compute_forward_l2_norm(
4031
    const ggml_compute_params * params,
4032
0
    ggml_tensor * dst) {
4033
4034
0
    const ggml_tensor * src0 = dst->src[0];
4035
4036
0
    switch (src0->type) {
4037
0
        case GGML_TYPE_F32:
4038
0
            {
4039
0
                ggml_compute_forward_l2_norm_f32(params, dst);
4040
0
            } break;
4041
0
        default:
4042
0
            {
4043
0
                GGML_ABORT("fatal error");
4044
0
            }
4045
0
    }
4046
0
}
4047
4048
// ggml_compute_forward_out_prod
4049
4050
static void ggml_compute_forward_out_prod_f32(
4051
        const ggml_compute_params * params,
4052
0
              ggml_tensor * dst) {
4053
4054
0
    const ggml_tensor * src0 = dst->src[0];
4055
0
    const ggml_tensor * src1 = dst->src[1];
4056
4057
0
    GGML_TENSOR_BINARY_OP_LOCALS
4058
4059
0
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
4060
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
4061
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
4062
4063
0
    const int ith = params->ith;
4064
0
    const int nth = params->nth;
4065
4066
0
    GGML_ASSERT(ne0 == ne00);
4067
0
    GGML_ASSERT(ne1 == ne10);
4068
0
    GGML_ASSERT(ne2 == ne12);
4069
0
    GGML_ASSERT(ne3 == ne13);
4070
4071
0
    GGML_ASSERT(ne2 % ne02 == 0);
4072
0
    GGML_ASSERT(ne3 % ne03 == 0);
4073
4074
    // we don't support permuted src0 or src1
4075
0
    GGML_ASSERT(nb00 == sizeof(float));
4076
4077
    // dst cannot be transposed or permuted
4078
0
    GGML_ASSERT(nb0 == sizeof(float));
4079
    // GGML_ASSERT(nb0 <= nb1);
4080
    // GGML_ASSERT(nb1 <= nb2);
4081
    // GGML_ASSERT(nb2 <= nb3);
4082
4083
    // nb01 >= nb00 - src0 is not transposed
4084
    //   compute by src0 rows
4085
4086
0
    if (ith == 0) {
4087
0
        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
4088
0
    }
4089
0
    ggml_barrier(params->threadpool);
4090
4091
    // dst[:,:,:,:] = 0
4092
    // for i2,i3:
4093
    //   for i1:
4094
    //     for i01:
4095
    //       for i0:
4096
    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
4097
4098
    // parallelize by last three dimensions
4099
4100
    // total rows in dst
4101
0
    const int64_t nr = ne1*ne2*ne3;
4102
4103
    // rows per thread
4104
0
    const int64_t dr = (nr + nth - 1)/nth;
4105
4106
    // row range for this thread
4107
0
    const int64_t ir0 = dr*ith;
4108
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
4109
4110
    // block-tiling attempt
4111
0
    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
4112
0
    const int64_t blck_1 = 16;
4113
4114
    // dps == dst per src0, used for group query attention
4115
0
    const int64_t dps2 = ne2 / ne02;
4116
0
    const int64_t dps3 = ne3 / ne03;
4117
4118
0
    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
4119
0
        const int64_t bir1 = MIN(bir + blck_1, ir1);
4120
0
        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
4121
0
            const int64_t bne01 = MIN(bi01 + blck_0, ne01);
4122
0
            for (int64_t ir = bir; ir < bir1; ++ir) {
4123
                // dst indices
4124
0
                const int64_t i3 = ir/(ne2*ne1);
4125
0
                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
4126
0
                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
4127
4128
0
                const int64_t i02 = i2 / dps2;
4129
0
                const int64_t i03 = i3 / dps3;
4130
4131
                //const int64_t i10 = i1;
4132
0
                const int64_t i12 = i2;
4133
0
                const int64_t i13 = i3;
4134
4135
0
#if GGML_VEC_MAD_UNROLL > 2
4136
0
                const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
4137
0
                for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
4138
0
                    const int64_t i11 = i01;
4139
4140
0
                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
4141
0
                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4142
0
                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
4143
4144
0
                    ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
4145
0
                }
4146
0
                for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
4147
0
                    const int64_t i11 = i01;
4148
4149
0
                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
4150
0
                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4151
0
                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
4152
4153
0
                    ggml_vec_mad_f32(ne0, d, s0, *s1);
4154
0
                }
4155
#else
4156
                for (int64_t i01 = bi01; i01 < bne01; ++i01) {
4157
                    const int64_t i11 = i01;
4158
4159
                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
4160
                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4161
                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
4162
4163
                    ggml_vec_mad_f32(ne0, d, s0, *s1);
4164
                }
4165
#endif
4166
0
            }
4167
0
        }
4168
0
    }
4169
0
}
4170
4171
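Stripped of tiling, threading and the GQA broadcast, the accumulation above is the outer-product update from the comment (dst = src0 * src1^T per 2-D slice); a naive sketch assuming contiguous slices:

//   dst : ne0 x ne1   (dst[i0 + i1*ne0])
//   s0  : ne0 x ne01  (s0[i0 + i01*ne0])
//   s1  : ne1 x ne01  (s1[i1 + i01*ne1])
// dst[i0, i1] += sum over i01 of s0[i0, i01] * s1[i1, i01]
static void out_prod_slice_reference(int64_t ne0, int64_t ne1, int64_t ne01,
                                     float * dst, const float * s0, const float * s1) {
    for (int64_t i1 = 0; i1 < ne1; i1++) {
        for (int64_t i01 = 0; i01 < ne01; i01++) {
            const float v = s1[i1 + i01*ne1];
            for (int64_t i0 = 0; i0 < ne0; i0++) {
                dst[i0 + i1*ne0] += s0[i0 + i01*ne0] * v;   // the ggml_vec_mad_f32 inner step
            }
        }
    }
}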
static void ggml_compute_forward_out_prod_q_f32(
4172
        const ggml_compute_params * params,
4173
0
              ggml_tensor * dst) {
4174
4175
0
    const ggml_tensor * src0 = dst->src[0];
4176
0
    const ggml_tensor * src1 = dst->src[1];
4177
4178
0
    GGML_TENSOR_BINARY_OP_LOCALS;
4179
4180
0
    const int ith = params->ith;
4181
0
    const int nth = params->nth;
4182
4183
0
    const ggml_type type = src0->type;
4184
0
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
4185
4186
0
    GGML_ASSERT(ne02 == ne12);
4187
0
    GGML_ASSERT(ne03 == ne13);
4188
0
    GGML_ASSERT(ne2  == ne12);
4189
0
    GGML_ASSERT(ne3  == ne13);
4190
4191
    // we don't support permuted src0 dim0
4192
0
    GGML_ASSERT(nb00 == ggml_type_size(type));
4193
4194
    // dst dim0 cannot be transposed or permuted
4195
0
    GGML_ASSERT(nb0 == sizeof(float));
4196
    // GGML_ASSERT(nb0 <= nb1);
4197
    // GGML_ASSERT(nb1 <= nb2);
4198
    // GGML_ASSERT(nb2 <= nb3);
4199
4200
0
    GGML_ASSERT(ne0 == ne00);
4201
0
    GGML_ASSERT(ne1 == ne10);
4202
0
    GGML_ASSERT(ne2 == ne02);
4203
0
    GGML_ASSERT(ne3 == ne03);
4204
4205
    // nb01 >= nb00 - src0 is not transposed
4206
    //   compute by src0 rows
4207
4208
0
    if (ith == 0) {
4209
0
        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
4210
0
    }
4211
0
    ggml_barrier(params->threadpool);
4212
4213
    // parallelize by last three dimensions
4214
4215
    // total rows in dst
4216
0
    const int64_t nr = ne1*ne2*ne3;
4217
4218
    // rows per thread
4219
0
    const int64_t dr = (nr + nth - 1)/nth;
4220
4221
    // row range for this thread
4222
0
    const int64_t ir0 = dr*ith;
4223
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
4224
4225
    // dst[:,:,:,:] = 0
4226
    // for i2,i3:
4227
    //   for i1:
4228
    //     for i01:
4229
    //       for i0:
4230
    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
4231
4232
0
    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
4233
4234
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
4235
        // dst indices
4236
0
        const int64_t i3 = ir/(ne2*ne1);
4237
0
        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
4238
0
        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
4239
4240
0
        const int64_t i02 = i2;
4241
0
        const int64_t i03 = i3;
4242
4243
        //const int64_t i10 = i1;
4244
0
        const int64_t i12 = i2;
4245
0
        const int64_t i13 = i3;
4246
4247
0
        for (int64_t i01 = 0; i01 < ne01; ++i01) {
4248
0
            const int64_t i11 = i01;
4249
4250
0
            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
4251
0
            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4252
0
            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
4253
4254
0
            dequantize_row_q(s0, wdata, ne0);
4255
0
            ggml_vec_mad_f32(ne0, d, wdata, *s1);
4256
0
        }
4257
0
    }
4258
0
}
4259
4260
void ggml_compute_forward_out_prod(
4261
        const ggml_compute_params * params,
4262
0
        ggml_tensor * dst) {
4263
4264
0
    const ggml_tensor * src0 = dst->src[0];
4265
4266
0
    switch (src0->type) {
4267
0
        case GGML_TYPE_Q4_0:
4268
0
        case GGML_TYPE_Q4_1:
4269
0
        case GGML_TYPE_Q5_0:
4270
0
        case GGML_TYPE_Q5_1:
4271
0
        case GGML_TYPE_Q8_0:
4272
0
        case GGML_TYPE_MXFP4:
4273
0
        case GGML_TYPE_Q2_K:
4274
0
        case GGML_TYPE_Q3_K:
4275
0
        case GGML_TYPE_Q4_K:
4276
0
        case GGML_TYPE_Q5_K:
4277
0
        case GGML_TYPE_Q6_K:
4278
0
        case GGML_TYPE_TQ1_0:
4279
0
        case GGML_TYPE_TQ2_0:
4280
0
        case GGML_TYPE_IQ2_XXS:
4281
0
        case GGML_TYPE_IQ2_XS:
4282
0
        case GGML_TYPE_IQ3_XXS:
4283
0
        case GGML_TYPE_IQ1_S:
4284
0
        case GGML_TYPE_IQ1_M:
4285
0
        case GGML_TYPE_IQ4_NL:
4286
0
        case GGML_TYPE_IQ4_XS:
4287
0
        case GGML_TYPE_IQ3_S:
4288
0
        case GGML_TYPE_IQ2_S:
4289
0
            {
4290
0
                ggml_compute_forward_out_prod_q_f32(params, dst);
4291
0
            } break;
4292
0
        case GGML_TYPE_F16:
4293
0
            {
4294
0
                GGML_ABORT("fatal error"); // todo
4295
                // ggml_compute_forward_out_prod_f16_f32(params, dst);
4296
0
            }
4297
0
        case GGML_TYPE_F32:
4298
0
            {
4299
0
                ggml_compute_forward_out_prod_f32(params, dst);
4300
0
            } break;
4301
0
        default:
4302
0
            {
4303
0
                GGML_ABORT("fatal error");
4304
0
            }
4305
0
    }
4306
0
}
4307
4308
// ggml_compute_forward_scale
4309
4310
static void ggml_compute_forward_scale_f32(
4311
        const ggml_compute_params * params,
4312
0
        ggml_tensor * dst) {
4313
4314
0
    const ggml_tensor * src0 = dst->src[0];
4315
4316
0
    GGML_ASSERT(ggml_is_contiguous(src0));
4317
0
    GGML_ASSERT(ggml_is_contiguous(dst));
4318
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
4319
4320
0
    float s; // scale factor
4321
0
    float b; // bias
4322
4323
0
    memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
4324
0
    memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
4325
4326
0
    const int ith = params->ith;
4327
0
    const int nth = params->nth;
4328
4329
0
    const int nc = src0->ne[0];
4330
0
    const int nr = ggml_nrows(src0);
4331
4332
    // rows per thread
4333
0
    const int dr = (nr + nth - 1)/nth;
4334
4335
    // row range for this thread
4336
0
    const int ir0 = dr*ith;
4337
0
    const int ir1 = MIN(ir0 + dr, nr);
4338
4339
0
    const size_t nb01 = src0->nb[1];
4340
4341
0
    const size_t nb1 = dst->nb[1];
4342
4343
0
    if (b == 0.0f) {
4344
0
        for (int i1 = ir0; i1 < ir1; i1++) {
4345
0
            if (dst->data != src0->data) {
4346
                // src0 is same shape as dst => same indices
4347
                // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
4348
0
                memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
4349
0
            }
4350
0
            ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
4351
0
        }
4352
0
    } else {
4353
0
        for (int i1 = ir0; i1 < ir1; i1++) {
4354
0
            ggml_vec_mad1_f32(nc,
4355
0
                (float *) ((char *) dst->data  + i1*nb1),
4356
0
                (float *) ((char *) src0->data + i1*nb1),
4357
0
                s, b);
4358
0
        }
4359
0
    }
4360
0
}
4361
4362
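Per row this is just y = s*x + b; the b == 0 branch above exists so the common pure-scale case can run in place without the extra add. A scalar sketch, assuming ggml_vec_mad1_f32 computes x*s + b:

static void scale_row_reference(int64_t n, float * y, const float * x, float s, float b) {
    for (int64_t i = 0; i < n; i++) {
        y[i] = x[i]*s + b;   // with b == 0 this matches the memcpy + ggml_vec_scale_f32 path
    }
}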
void ggml_compute_forward_scale(
4363
        const ggml_compute_params * params,
4364
0
        ggml_tensor * dst) {
4365
4366
0
    const ggml_tensor * src0 = dst->src[0];
4367
4368
0
    switch (src0->type) {
4369
0
        case GGML_TYPE_F32:
4370
0
            {
4371
0
                ggml_compute_forward_scale_f32(params, dst);
4372
0
            } break;
4373
0
        default:
4374
0
            {
4375
0
                GGML_ABORT("fatal error");
4376
0
            }
4377
0
    }
4378
0
}
4379
4380
// ggml_compute_forward_set
4381
4382
static void ggml_compute_forward_set_f32(
4383
        const ggml_compute_params * params,
4384
0
        ggml_tensor * dst) {
4385
4386
0
    const ggml_tensor * src0 = dst->src[0];
4387
0
    const ggml_tensor * src1 = dst->src[1];
4388
4389
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
4390
0
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
4391
4392
    // view src0 and dst with these strides and data offset in bytes during set
4393
    // nb0 is implicitly element_size because src0 and dst are contiguous
4394
0
    size_t nb1     = ((int32_t *) dst->op_params)[0];
4395
0
    size_t nb2     = ((int32_t *) dst->op_params)[1];
4396
0
    size_t nb3     = ((int32_t *) dst->op_params)[2];
4397
0
    size_t offset  = ((int32_t *) dst->op_params)[3];
4398
0
    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
4399
4400
0
    if (!inplace) {
4401
0
        if (params->ith == 0) {
4402
            // memcpy needs to be synchronized across threads to avoid race conditions.
4403
            // => do it in INIT phase
4404
0
            memcpy(
4405
0
                ((char *)  dst->data),
4406
0
                ((char *) src0->data),
4407
0
                ggml_nbytes(dst));
4408
0
        }
4409
0
        ggml_barrier(params->threadpool);
4410
0
    }
4411
4412
0
    const int ith = params->ith;
4413
0
    const int nth = params->nth;
4414
4415
0
    const int nr = ggml_nrows(src1);
4416
0
    const int nc = src1->ne[0];
4417
4418
0
    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
4419
0
    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
4420
4421
    // src0 and dst as viewed during set
4422
0
    const size_t nb0 = ggml_element_size(src0);
4423
4424
0
    const int im0 = (ne10 == 0 ? 0 : ne10-1);
4425
0
    const int im1 = (ne11 == 0 ? 0 : ne11-1);
4426
0
    const int im2 = (ne12 == 0 ? 0 : ne12-1);
4427
0
    const int im3 = (ne13 == 0 ? 0 : ne13-1);
4428
4429
0
    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
4430
4431
0
    GGML_ASSERT(nb10 == sizeof(float));
4432
4433
    // rows per thread
4434
0
    const int dr = (nr + nth - 1)/nth;
4435
4436
    // row range for this thread
4437
0
    const int ir0 = dr*ith;
4438
0
    const int ir1 = MIN(ir0 + dr, nr);
4439
4440
0
    for (int ir = ir0; ir < ir1; ++ir) {
4441
        // src0 and dst are viewed with shape of src1 and offset
4442
        // => same indices
4443
0
        const int i3 = ir/(ne12*ne11);
4444
0
        const int i2 = (ir - i3*ne12*ne11)/ne11;
4445
0
        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
4446
4447
0
        ggml_vec_cpy_f32(nc,
4448
0
                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
4449
0
                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
4450
0
    }
4451
0
}
4452
4453
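GGML_OP_SET writes src1 into a strided, offset view of dst (after an optional full copy of src0). The address arithmetic above reduces to this sketch for one destination row (hypothetical helper; nb1/nb2/nb3 and offset are the byte values read from op_params):

static float * set_dst_row_addr(char * dst_data, size_t offset,
                                size_t nb1, size_t nb2, size_t nb3,
                                int64_t i1, int64_t i2, int64_t i3) {
    return (float *) (dst_data + i3*nb3 + i2*nb2 + i1*nb1 + offset);   // row (i1,i2,i3) of the view
}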
static void ggml_compute_forward_set_i32(
4454
        const ggml_compute_params * params,
4455
0
        ggml_tensor * dst) {
4456
4457
0
    const ggml_tensor * src0 = dst->src[0];
4458
0
    const ggml_tensor * src1 = dst->src[1];
4459
4460
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
4461
0
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
4462
4463
    // view src0 and dst with these strides and data offset in bytes during set
4464
    // nb0 is implicitly element_size because src0 and dst are contiguous
4465
0
    size_t nb1     = ((int32_t *) dst->op_params)[0];
4466
0
    size_t nb2     = ((int32_t *) dst->op_params)[1];
4467
0
    size_t nb3     = ((int32_t *) dst->op_params)[2];
4468
0
    size_t offset  = ((int32_t *) dst->op_params)[3];
4469
0
    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
4470
4471
0
    if (!inplace) {
4472
0
        if (params->ith == 0) {
4473
            // memcpy needs to be synchronized across threads to avoid race conditions.
4474
            // => do it in INIT phase
4475
0
            memcpy(
4476
0
                ((char *)  dst->data),
4477
0
                ((char *) src0->data),
4478
0
                ggml_nbytes(dst));
4479
0
        }
4480
0
        ggml_barrier(params->threadpool);
4481
0
    }
4482
4483
0
    const int ith = params->ith;
4484
0
    const int nth = params->nth;
4485
4486
0
    const int nr = ggml_nrows(src1);
4487
0
    const int nc = src1->ne[0];
4488
4489
0
    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
4490
0
    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
4491
4492
    // src0 and dst as viewed during set
4493
0
    const size_t nb0 = ggml_element_size(src0);
4494
4495
0
    const int im0 = (ne10 == 0 ? 0 : ne10-1);
4496
0
    const int im1 = (ne11 == 0 ? 0 : ne11-1);
4497
0
    const int im2 = (ne12 == 0 ? 0 : ne12-1);
4498
0
    const int im3 = (ne13 == 0 ? 0 : ne13-1);
4499
4500
0
    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
4501
4502
0
    GGML_ASSERT(nb10 == sizeof(int32_t));
4503
4504
    // rows per thread
4505
0
    const int dr = (nr + nth - 1)/nth;
4506
4507
    // row range for this thread
4508
0
    const int ir0 = dr*ith;
4509
0
    const int ir1 = MIN(ir0 + dr, nr);
4510
4511
0
    for (int ir = ir0; ir < ir1; ++ir) {
4512
        // src0 and dst are viewed with shape of src1 and offset
4513
        // => same indices
4514
0
        const int i3 = ir/(ne12*ne11);
4515
0
        const int i2 = (ir - i3*ne12*ne11)/ne11;
4516
0
        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
4517
4518
0
        ggml_vec_cpy_i32(nc,
4519
0
                (int32_t *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
4520
0
                (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
4521
0
    }
4522
0
}
4523
4524
void ggml_compute_forward_set(
4525
        const ggml_compute_params * params,
4526
0
        ggml_tensor * dst) {
4527
4528
0
    const ggml_tensor * src0 = dst->src[0];
4529
4530
0
    switch (src0->type) {
4531
0
        case GGML_TYPE_F32:
4532
0
            {
4533
0
                ggml_compute_forward_set_f32(params, dst);
4534
0
            } break;
4535
0
        case GGML_TYPE_I32:
4536
0
            {
4537
0
                ggml_compute_forward_set_i32(params, dst);
4538
0
            } break;
4539
0
        case GGML_TYPE_F16:
4540
0
        case GGML_TYPE_BF16:
4541
0
        case GGML_TYPE_Q4_0:
4542
0
        case GGML_TYPE_Q4_1:
4543
0
        case GGML_TYPE_Q5_0:
4544
0
        case GGML_TYPE_Q5_1:
4545
0
        case GGML_TYPE_Q8_0:
4546
0
        case GGML_TYPE_Q8_1:
4547
0
        case GGML_TYPE_MXFP4:
4548
0
        case GGML_TYPE_Q2_K:
4549
0
        case GGML_TYPE_Q3_K:
4550
0
        case GGML_TYPE_Q4_K:
4551
0
        case GGML_TYPE_Q5_K:
4552
0
        case GGML_TYPE_Q6_K:
4553
0
        case GGML_TYPE_TQ1_0:
4554
0
        case GGML_TYPE_TQ2_0:
4555
0
        case GGML_TYPE_IQ2_XXS:
4556
0
        case GGML_TYPE_IQ2_XS:
4557
0
        case GGML_TYPE_IQ3_XXS:
4558
0
        case GGML_TYPE_IQ1_S:
4559
0
        case GGML_TYPE_IQ1_M:
4560
0
        case GGML_TYPE_IQ4_NL:
4561
0
        case GGML_TYPE_IQ4_XS:
4562
0
        case GGML_TYPE_IQ3_S:
4563
0
        case GGML_TYPE_IQ2_S:
4564
0
        default:
4565
0
            {
4566
0
                GGML_ABORT("fatal error");
4567
0
            }
4568
0
    }
4569
0
}
4570
4571
// ggml_compute_forward_cpy
4572
4573
void ggml_compute_forward_cpy(
4574
        const ggml_compute_params * params,
4575
0
        ggml_tensor * dst) {
4576
0
    ggml_compute_forward_dup(params, dst);
4577
0
}
4578
4579
// ggml_compute_forward_cont
4580
4581
void ggml_compute_forward_cont(
4582
        const ggml_compute_params * params,
4583
0
        ggml_tensor * dst) {
4584
0
    ggml_compute_forward_dup(params, dst);
4585
0
}
4586
4587
// ggml_compute_forward_get_rows
4588
4589
static void ggml_compute_forward_get_rows_q(
4590
        const ggml_compute_params * params,
4591
0
              ggml_tensor * dst) {
4592
4593
0
    const ggml_tensor * src0 = dst->src[0];
4594
0
    const ggml_tensor * src1 = dst->src[1];
4595
4596
0
    GGML_TENSOR_BINARY_OP_LOCALS
4597
4598
0
    const int64_t nc = ne00;
4599
0
    const int64_t nr = ggml_nelements(src1);
4600
4601
0
    const ggml_type type = src0->type;
4602
0
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
4603
4604
0
    assert(ne0  == nc);
4605
0
    assert(ne02 == ne11);
4606
0
    assert(nb00 == ggml_type_size(type));
4607
0
    assert(ggml_nrows(dst) == nr);
4608
4609
0
    const int ith = params->ith;
4610
0
    const int nth = params->nth;
4611
4612
    // rows per thread
4613
0
    const int dr = (nr + nth - 1)/nth;
4614
4615
    // row range for this thread
4616
0
    const int ir0 = dr*ith;
4617
0
    const int ir1 = MIN(ir0 + dr, nr);
4618
4619
0
    for (int64_t i = ir0; i < ir1; ++i) {
4620
0
        const int64_t i12 = i/(ne11*ne10);
4621
0
        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4622
0
        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4623
0
        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4624
4625
0
        GGML_ASSERT(i01 >= 0 && i01 < ne01);
4626
4627
0
        dequantize_row_q(
4628
0
                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
4629
0
                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
4630
0
    }
4631
0
}
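The dr/ir0/ir1 arithmetic above is the ceil-divide row partition used by most kernels in this file. A minimal standalone sketch of just that partition (nr = 10 and nth = 4 are illustrative values, not taken from the report):

#include <algorithm>
#include <cstdio>

// Split nr rows into nth nearly equal chunks; thread ith works on [ir0, ir1).
int main() {
    const int nr  = 10; // total rows (illustrative)
    const int nth = 4;  // number of threads (illustrative)
    const int dr  = (nr + nth - 1)/nth; // rows per thread, rounded up
    for (int ith = 0; ith < nth; ++ith) {
        const int ir0 = dr*ith;
        const int ir1 = std::min(ir0 + dr, nr);
        // the last thread may get a short range, or an empty one when nth > nr
        std::printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}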
4632
4633
static void ggml_compute_forward_get_rows_f16(
4634
        const ggml_compute_params * params,
4635
0
              ggml_tensor * dst) {
4636
4637
0
    const ggml_tensor * src0 = dst->src[0];
4638
0
    const ggml_tensor * src1 = dst->src[1];
4639
4640
0
    GGML_TENSOR_BINARY_OP_LOCALS
4641
4642
0
    const int64_t nc = ne00;
4643
0
    const int64_t nr = ggml_nelements(src1);
4644
4645
0
    assert(ne0  == nc);
4646
0
    assert(ne02 == ne11);
4647
0
    assert(nb00 == sizeof(ggml_fp16_t));
4648
0
    assert(ggml_nrows(dst) == nr);
4649
4650
0
    const int ith = params->ith;
4651
0
    const int nth = params->nth;
4652
4653
    // rows per thread
4654
0
    const int dr = (nr + nth - 1)/nth;
4655
4656
    // row range for this thread
4657
0
    const int ir0 = dr*ith;
4658
0
    const int ir1 = MIN(ir0 + dr, nr);
4659
4660
0
    for (int64_t i = ir0; i < ir1; ++i) {
4661
0
        const int64_t i12 = i/(ne11*ne10);
4662
0
        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4663
0
        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4664
0
        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4665
4666
0
        GGML_ASSERT(i01 >= 0 && i01 < ne01);
4667
4668
0
        ggml_cpu_fp16_to_fp32(
4669
0
            (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
4670
0
                       (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
4671
0
    }
4672
0
}
4673
4674
static void ggml_compute_forward_get_rows_bf16(
4675
        const ggml_compute_params * params,
4676
0
              ggml_tensor * dst) {
4677
4678
0
    const ggml_tensor * src0 = dst->src[0];
4679
0
    const ggml_tensor * src1 = dst->src[1];
4680
4681
0
    GGML_TENSOR_BINARY_OP_LOCALS
4682
4683
0
    const int64_t nc = ne00;
4684
0
    const int64_t nr = ggml_nelements(src1);
4685
4686
0
    assert(ne0  == nc);
4687
0
    assert(ne02 == ne11);
4688
0
    assert(nb00 == sizeof(ggml_bf16_t));
4689
0
    assert(ggml_nrows(dst) == nr);
4690
4691
0
    const int ith = params->ith;
4692
0
    const int nth = params->nth;
4693
4694
    // rows per thread
4695
0
    const int dr = (nr + nth - 1)/nth;
4696
4697
    // row range for this thread
4698
0
    const int ir0 = dr*ith;
4699
0
    const int ir1 = MIN(ir0 + dr, nr);
4700
4701
0
    for (int64_t i = ir0; i < ir1; ++i) {
4702
0
        const int64_t i12 = i/(ne11*ne10);
4703
0
        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4704
0
        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4705
0
        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4706
4707
0
        GGML_ASSERT(i01 >= 0 && i01 < ne01);
4708
4709
0
        ggml_cpu_bf16_to_fp32(
4710
0
            (const ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
4711
0
                        (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
4712
0
    }
4713
0
}
4714
4715
static void ggml_compute_forward_get_rows_f32(
4716
        const ggml_compute_params * params,
4717
0
              ggml_tensor * dst) {
4718
4719
0
    const ggml_tensor * src0 = dst->src[0];
4720
0
    const ggml_tensor * src1 = dst->src[1];
4721
4722
0
    GGML_TENSOR_BINARY_OP_LOCALS
4723
4724
0
    const int64_t nc = ne00;
4725
0
    const int64_t nr = ggml_nelements(src1);
4726
4727
0
    assert(ne0  == nc);
4728
0
    assert(ne02 == ne11);
4729
0
    assert(nb00 == sizeof(float));
4730
0
    assert(ggml_nrows(dst) == nr);
4731
4732
0
    const int ith = params->ith;
4733
0
    const int nth = params->nth;
4734
4735
    // rows per thread
4736
0
    const int dr = (nr + nth - 1)/nth;
4737
4738
    // row range for this thread
4739
0
    const int ir0 = dr*ith;
4740
0
    const int ir1 = MIN(ir0 + dr, nr);
4741
4742
0
    for (int64_t i = ir0; i < ir1; ++i) {
4743
0
        const int64_t i12 = i/(ne11*ne10);
4744
0
        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4745
0
        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4746
0
        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4747
4748
0
        GGML_ASSERT(i01 >= 0 && i01 < ne01);
4749
4750
0
        ggml_vec_cpy_f32(nc,
4751
0
                (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
4752
0
                (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
4753
0
    }
4754
0
}
4755
4756
void ggml_compute_forward_get_rows(
4757
        const ggml_compute_params * params,
4758
0
        ggml_tensor * dst) {
4759
4760
0
    const ggml_tensor * src0 = dst->src[0];
4761
4762
0
    switch (src0->type) {
4763
0
        case GGML_TYPE_Q4_0:
4764
0
        case GGML_TYPE_Q4_1:
4765
0
        case GGML_TYPE_Q5_0:
4766
0
        case GGML_TYPE_Q5_1:
4767
0
        case GGML_TYPE_Q8_0:
4768
0
        case GGML_TYPE_Q8_1:
4769
0
        case GGML_TYPE_MXFP4:
4770
0
        case GGML_TYPE_Q2_K:
4771
0
        case GGML_TYPE_Q3_K:
4772
0
        case GGML_TYPE_Q4_K:
4773
0
        case GGML_TYPE_Q5_K:
4774
0
        case GGML_TYPE_Q6_K:
4775
0
        case GGML_TYPE_TQ1_0:
4776
0
        case GGML_TYPE_TQ2_0:
4777
0
        case GGML_TYPE_IQ2_XXS:
4778
0
        case GGML_TYPE_IQ2_XS:
4779
0
        case GGML_TYPE_IQ3_XXS:
4780
0
        case GGML_TYPE_IQ1_S:
4781
0
        case GGML_TYPE_IQ1_M:
4782
0
        case GGML_TYPE_IQ4_NL:
4783
0
        case GGML_TYPE_IQ4_XS:
4784
0
        case GGML_TYPE_IQ3_S:
4785
0
        case GGML_TYPE_IQ2_S:
4786
0
            {
4787
0
                ggml_compute_forward_get_rows_q(params, dst);
4788
0
            } break;
4789
0
        case GGML_TYPE_F16:
4790
0
            {
4791
0
                ggml_compute_forward_get_rows_f16(params, dst);
4792
0
            } break;
4793
0
        case GGML_TYPE_BF16:
4794
0
            {
4795
0
                ggml_compute_forward_get_rows_bf16(params, dst);
4796
0
            } break;
4797
0
        case GGML_TYPE_F32:
4798
0
        case GGML_TYPE_I32:
4799
0
            {
4800
0
                ggml_compute_forward_get_rows_f32(params, dst);
4801
0
            } break;
4802
0
        default:
4803
0
            {
4804
0
                GGML_ABORT("fatal error");
4805
0
            }
4806
0
    }
4807
4808
    //static bool first = true;
4809
    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
4810
    //if (first) {
4811
    //    first = false;
4812
    //} else {
4813
    //    for (int k = 0; k < dst->ne[1]; ++k) {
4814
    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
4815
    //            for (int i = 0; i < 16; ++i) {
4816
    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
4817
    //            }
4818
    //            printf("\n");
4819
    //        }
4820
    //        printf("\n");
4821
    //    }
4822
    //    printf("\n");
4823
    //    exit(0);
4824
    //}
4825
0
}
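All the type-specialized get_rows variants above implement the same per-row gather; restating the loops as one formula (no additional behavior is implied):

\[ \mathrm{dst}[\,:, i_{10}, i_{11}, i_{12}\,] = \mathrm{to\_float}\big(\mathrm{src0}[\,:, \mathrm{src1}[i_{10}, i_{11}, i_{12}], i_{11}, i_{12}\,]\big) \]

with to_float standing for the per-type dequantization or conversion (a plain copy in the f32/i32 case), and the rows i = (i10, i11, i12) split across threads by the partition sketched earlier.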
4826
4827
template<typename idx_t>
4828
static void ggml_compute_forward_set_rows_f32(
4829
        const ggml_compute_params * params,
4830
0
              ggml_tensor * dst) {
4831
4832
0
    const ggml_tensor * src0 = dst->src[0];
4833
0
    const ggml_tensor * src1 = dst->src[1];
4834
4835
0
    GGML_TENSOR_BINARY_OP_LOCALS
4836
4837
0
    const int64_t nc = ne00;
4838
0
    const int64_t nr = ne01;
4839
4840
0
    assert(ne0  == nc);
4841
0
    assert(ne2  == ne02);
4842
0
    assert(ne3  == ne03);
4843
0
    assert(src0->type == GGML_TYPE_F32);
4844
0
    assert(ne02 % ne11 == 0);
4845
0
    assert(ne03 % ne12 == 0);
4846
4847
0
    const int ith = params->ith;
4848
0
    const int nth = params->nth;
4849
4850
    // rows per thread
4851
0
    const int64_t dr = (nr + nth - 1)/nth;
4852
4853
    // row range for this thread
4854
0
    const int64_t ir0 = dr*ith;
4855
0
    const int64_t ir1 = std::min(ir0 + dr, nr);
4856
4857
0
    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
4858
4859
0
    for (int64_t i03 = 0; i03 < ne03; ++i03) {
4860
0
        for (int64_t i02 = 0; i02 < ne02; ++i02) {
4861
0
            for (int64_t i = ir0; i < ir1; ++i) {
4862
0
                const int64_t i12 = i03%ne12;
4863
0
                const int64_t i11 = i02%ne11;
4864
0
                const int64_t i10 = i;
4865
4866
0
                const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4867
4868
0
                GGML_ASSERT(i1 >= 0 && i1 < ne1);
4869
4870
0
                from_float(
4871
0
                        (const float *) ((char *) src0->data +  i*nb01 + i02*nb02 + i03*nb03),
4872
0
                                        ((char *)  dst->data + i1*nb1  + i02*nb2  + i03*nb3), nc);
4873
0
            }
4874
0
        }
4875
0
    }
4876
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_set_rows_f32<long>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_set_rows_f32<int>(ggml_compute_params const*, ggml_tensor*)
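set_rows is the scatter counterpart of get_rows: each f32 row of src0 is converted with the destination type's from_float (so dst may be quantized) and written to the dst row named by src1, whose outer dims are broadcast by modulo. Restating the loops above:

\[ \mathrm{dst}[\,:, \mathrm{src1}[i,\, i_{02} \bmod ne_{11},\, i_{03} \bmod ne_{12}],\, i_{02},\, i_{03}\,] = \mathrm{from\_float}\big(\mathrm{src0}[\,:, i, i_{02}, i_{03}\,]\big) \]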
4877
4878
void ggml_compute_forward_set_rows(
4879
        const ggml_compute_params * params,
4880
0
        ggml_tensor * dst) {
4881
4882
0
    const ggml_tensor * src0 = dst->src[0];
4883
0
    const ggml_tensor * src1 = dst->src[1];
4884
4885
0
    switch (src0->type) {
4886
0
        case GGML_TYPE_F32:
4887
0
            {
4888
0
                if (src1->type == GGML_TYPE_I64) {
4889
0
                    ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
4890
0
                } else if (src1->type == GGML_TYPE_I32) {
4891
0
                    ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
4892
0
                } else {
4893
0
                    GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
4894
0
                }
4895
0
            } break;
4896
0
        default:
4897
0
            {
4898
0
                GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
4899
0
            }
4900
0
    }
4901
0
}
4902
4903
// ggml_compute_forward_get_rows_back
4904
4905
static void ggml_compute_forward_get_rows_back_f32_f16(
4906
        const ggml_compute_params * params,
4907
0
              ggml_tensor * dst) {
4908
4909
0
    const ggml_tensor * src0 = dst->src[0];
4910
0
    const ggml_tensor * src1 = dst->src[1];
4911
4912
0
    if (params->ith != 0) {
4913
0
        return;
4914
0
    }
4915
4916
0
    GGML_ASSERT(ggml_is_contiguous(dst));
4917
4918
    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
4919
4920
0
    memset(dst->data, 0, ggml_nbytes(dst));
4921
4922
0
    const int nc = src0->ne[0];
4923
0
    const int nr = ggml_nelements(src1);
4924
4925
0
    GGML_ASSERT( dst->ne[0] == nc);
4926
0
    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
4927
4928
0
    for (int i = 0; i < nr; ++i) {
4929
0
        const int r = ((int32_t *) src1->data)[i];
4930
4931
0
        for (int j = 0; j < nc; ++j) {
4932
0
            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
4933
0
            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_CPU_FP16_TO_FP32(v);
4934
0
        }
4935
0
    }
4936
0
}
4937
4938
static void ggml_compute_forward_get_rows_back_f32(
4939
        const ggml_compute_params * params,
4940
0
              ggml_tensor * dst) {
4941
4942
0
    const ggml_tensor * src0 = dst->src[0];
4943
0
    const ggml_tensor * src1 = dst->src[1];
4944
4945
0
    if (params->ith != 0) {
4946
0
        return;
4947
0
    }
4948
4949
0
    GGML_ASSERT(ggml_is_contiguous(dst));
4950
4951
    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
4952
4953
0
    memset(dst->data, 0, ggml_nbytes(dst));
4954
4955
0
    const int nc = src0->ne[0];
4956
0
    const int nr = ggml_nelements(src1);
4957
4958
0
    GGML_ASSERT( dst->ne[0] == nc);
4959
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
4960
4961
0
    for (int i = 0; i < nr; ++i) {
4962
0
        const int r = ((int32_t *) src1->data)[i];
4963
4964
0
        ggml_vec_add_f32(nc,
4965
0
                (float *) ((char *)  dst->data + r*dst->nb[1]),
4966
0
                (float *) ((char *)  dst->data + r*dst->nb[1]),
4967
0
                (float *) ((char *) src0->data + i*src0->nb[1]));
4968
0
    }
4969
0
}
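Both variants above are the backward pass of get_rows: dst is zeroed once and every incoming row is scatter-added back to the row it was originally gathered from. Only thread 0 runs, since different i may map to the same r:

\[ \mathrm{dst}[\,:, r\,] \mathrel{+}= \mathrm{src0}[\,:, i\,] \quad \text{for every } i \text{ with } \mathrm{src1}[i] = r \]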
4970
4971
void ggml_compute_forward_get_rows_back(
4972
        const ggml_compute_params * params,
4973
0
        ggml_tensor * dst) {
4974
4975
0
    const ggml_tensor * src0 = dst->src[0];
4976
4977
0
    switch (src0->type) {
4978
0
        case GGML_TYPE_F16:
4979
0
            {
4980
0
                ggml_compute_forward_get_rows_back_f32_f16(params, dst);
4981
0
            } break;
4982
0
        case GGML_TYPE_F32:
4983
0
            {
4984
0
                ggml_compute_forward_get_rows_back_f32(params, dst);
4985
0
            } break;
4986
0
        default:
4987
0
            {
4988
0
                GGML_ABORT("fatal error");
4989
0
            }
4990
0
    }
4991
4992
    //static bool first = true;
4993
    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
4994
    //if (first) {
4995
    //    first = false;
4996
    //} else {
4997
    //    for (int k = 0; k < dst->ne[1]; ++k) {
4998
    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
4999
    //            for (int i = 0; i < 16; ++i) {
5000
    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
5001
    //            }
5002
    //            printf("\n");
5003
    //        }
5004
    //        printf("\n");
5005
    //    }
5006
    //    printf("\n");
5007
    //    exit(0);
5008
    //}
5009
0
}
5010
5011
// ggml_compute_forward_diag
5012
5013
static void ggml_compute_forward_diag_f32(
5014
        const ggml_compute_params * params,
5015
0
        ggml_tensor * dst) {
5016
5017
0
    const ggml_tensor * src0 = dst->src[0];
5018
5019
0
    if (params->ith != 0) {
5020
0
        return;
5021
0
    }
5022
5023
    // TODO: handle transposed/permuted matrices
5024
5025
0
    GGML_TENSOR_UNARY_OP_LOCALS
5026
5027
0
    GGML_ASSERT(ne00 == ne0);
5028
0
    GGML_ASSERT(ne00 == ne1);
5029
0
    GGML_ASSERT(ne01 == 1);
5030
0
    GGML_ASSERT(ne02 == ne2);
5031
0
    GGML_ASSERT(ne03 == ne3);
5032
5033
0
    GGML_ASSERT(nb00 == sizeof(float));
5034
0
    GGML_ASSERT(nb0  == sizeof(float));
5035
5036
0
    for (int i3 = 0; i3 < ne3; i3++) {
5037
0
        for (int i2 = 0; i2 < ne2; i2++) {
5038
0
            for (int i1 = 0; i1 < ne1; i1++) {
5039
0
                float * d = (float *)((char *)  dst->data + i3*nb3  + i2*nb2 + i1*nb1);
5040
0
                float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
5041
0
                for (int i0 = 0; i0 < i1; i0++) {
5042
0
                    d[i0] = 0;
5043
0
                }
5044
0
                d[i1] = s[i1];
5045
0
                for (int i0 = i1+1; i0 < ne0; i0++) {
5046
0
                    d[i0] = 0;
5047
0
                }
5048
0
            }
5049
0
        }
5050
0
    }
5051
0
}
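In index form, diag expands each length-ne0 row (ne01 is asserted to be 1) into an ne0 by ne0 matrix carrying that row on its diagonal:

\[ \mathrm{dst}[i_0, i_1, i_2, i_3] = \begin{cases} \mathrm{src0}[i_0, 0, i_2, i_3] & i_0 = i_1 \\ 0 & \text{otherwise} \end{cases} \]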
5052
5053
void ggml_compute_forward_diag(
5054
        const ggml_compute_params * params,
5055
0
        ggml_tensor * dst) {
5056
5057
0
    const ggml_tensor * src0 = dst->src[0];
5058
5059
0
    switch (src0->type) {
5060
0
        case GGML_TYPE_F32:
5061
0
            {
5062
0
                ggml_compute_forward_diag_f32(params, dst);
5063
0
            } break;
5064
0
        default:
5065
0
            {
5066
0
                GGML_ABORT("fatal error");
5067
0
            }
5068
0
    }
5069
0
}
5070
5071
// ggml_compute_forward_diag_mask_inf
5072
5073
static void ggml_compute_forward_diag_mask_f32(
5074
        const ggml_compute_params * params,
5075
        ggml_tensor * dst,
5076
0
        const float value) {
5077
5078
0
    const ggml_tensor * src0 = dst->src[0];
5079
5080
0
    const int ith = params->ith;
5081
0
    const int nth = params->nth;
5082
5083
0
    const int  n_past  = ((int32_t *) dst->op_params)[0];
5084
0
    const bool inplace = src0->data == dst->data;
5085
5086
0
    GGML_ASSERT(n_past >= 0);
5087
5088
0
    if (!inplace) {
5089
0
        if (ith == 0) {
5090
            // memcpy needs to be synchronized across threads to avoid race conditions.
5091
            // => do it in INIT phase
5092
0
            GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5093
0
            GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
5094
0
            memcpy(
5095
0
                ((char *)  dst->data),
5096
0
                ((char *) src0->data),
5097
0
                ggml_nbytes(dst));
5098
0
        }
5099
0
        ggml_barrier(params->threadpool);
5100
0
    }
5101
5102
    // TODO: handle transposed/permuted matrices
5103
5104
0
    const int n  = ggml_nrows(src0);
5105
0
    const int nc = src0->ne[0];
5106
0
    const int nr = src0->ne[1];
5107
0
    const int nz = n/nr;
5108
5109
0
    GGML_ASSERT( dst->nb[0] == sizeof(float));
5110
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
5111
5112
0
    for (int k = 0; k < nz; k++) {
5113
0
        for (int j = ith; j < nr; j += nth) {
5114
0
            for (int i = n_past; i < nc; i++) {
5115
0
                if (i > n_past + j) {
5116
0
                    *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
5117
0
                }
5118
0
            }
5119
0
        }
5120
0
    }
5121
0
}
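The condition i > n_past + j writes value strictly to the right of the diagonal shifted by n_past, so row j keeps its first n_past + j + 1 entries. A worked 3 by 5 example with n_past = 1 (illustrative sizes), kept entries shown as dots and masked ones as v:

\[ \begin{pmatrix} \cdot & \cdot & v & v & v \\ \cdot & \cdot & \cdot & v & v \\ \cdot & \cdot & \cdot & \cdot & v \end{pmatrix} \]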
5122
5123
void ggml_compute_forward_diag_mask_inf(
5124
        const ggml_compute_params * params,
5125
0
        ggml_tensor * dst) {
5126
5127
0
    const ggml_tensor * src0 = dst->src[0];
5128
5129
0
    switch (src0->type) {
5130
0
        case GGML_TYPE_F32:
5131
0
            {
5132
0
                ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
5133
0
            } break;
5134
0
        default:
5135
0
            {
5136
0
                GGML_ABORT("fatal error");
5137
0
            }
5138
0
    }
5139
0
}
5140
5141
void ggml_compute_forward_diag_mask_zero(
5142
        const ggml_compute_params * params,
5143
0
        ggml_tensor * dst) {
5144
5145
0
    const ggml_tensor * src0 = dst->src[0];
5146
5147
0
    switch (src0->type) {
5148
0
        case GGML_TYPE_F32:
5149
0
            {
5150
0
                ggml_compute_forward_diag_mask_f32(params, dst, 0);
5151
0
            } break;
5152
0
        default:
5153
0
            {
5154
0
                GGML_ABORT("fatal error");
5155
0
            }
5156
0
    }
5157
0
}
5158
5159
// ggml_compute_forward_soft_max
5160
5161
static void ggml_compute_forward_soft_max_f32(
5162
        const ggml_compute_params * params,
5163
0
              ggml_tensor * dst) {
5164
5165
0
    const ggml_tensor * src0 = dst->src[0];
5166
0
    const ggml_tensor * src1 = dst->src[1];
5167
0
    const ggml_tensor * src2 = dst->src[2];
5168
5169
0
    assert(ggml_is_contiguous(dst));
5170
0
    assert(ggml_are_same_shape(src0, dst));
5171
5172
0
    float scale    = 1.0f;
5173
0
    float max_bias = 0.0f;
5174
5175
0
    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
5176
0
    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
5177
5178
0
    const int ith = params->ith;
5179
0
    const int nth = params->nth;
5180
5181
0
    GGML_TENSOR_UNARY_OP_LOCALS
5182
5183
0
    const int64_t nb11 = src1 ? src1->nb[1] : 1;
5184
0
    const int64_t nb12 = src1 ? src1->nb[2] : 1;
5185
0
    const int64_t nb13 = src1 ? src1->nb[3] : 1;
5186
5187
0
    const int64_t ne12 = src1 ? src1->ne[2] : 1;
5188
0
    const int64_t ne13 = src1 ? src1->ne[3] : 1;
5189
5190
    // TODO: is this supposed to be ceil instead of floor?
5191
    //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
5192
0
    const uint32_t n_head      = ne02;
5193
0
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
5194
5195
0
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
5196
0
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
5197
5198
0
    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
5199
5200
0
    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
5201
5202
    // sinks
5203
0
    const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
5204
5205
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
5206
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
5207
0
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
5208
0
                const int64_t i11 = i01;
5209
0
                const int64_t i12 = i02%ne12;
5210
0
                const int64_t i13 = i03%ne13;
5211
5212
                // ALiBi
5213
0
                const uint32_t h = i02; // head
5214
0
                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
5215
5216
0
                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5217
0
                float * dp = (float *)((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3);
5218
5219
                // broadcast the mask across rows
5220
0
                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5221
0
                float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5222
5223
0
                ggml_vec_cpy_f32  (ne00, wp, sp);
5224
0
                ggml_vec_scale_f32(ne00, wp, scale);
5225
0
                if (mp_f32) {
5226
0
                    if (use_f16) {
5227
0
                        for (int i = 0; i < ne00; ++i) {
5228
0
                            wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
5229
0
                        }
5230
0
                    } else {
5231
0
                        for (int i = 0; i < ne00; ++i) {
5232
0
                            wp[i] += slope*mp_f32[i];
5233
0
                        }
5234
0
                    }
5235
0
                }
5236
5237
#ifndef NDEBUG
5238
                for (int i = 0; i < ne00; ++i) {
5239
                    //printf("p[%d] = %f\n", i, p[i]);
5240
                    assert(!isnan(wp[i]));
5241
                }
5242
#endif
5243
5244
0
                float max = -INFINITY;
5245
0
                ggml_vec_max_f32(ne00, &max, wp);
5246
5247
                // if we have sinks, make a correction as if they were included in the softmax
5248
0
                if (sk) {
5249
0
                    max = MAX(max, sk[i02]);
5250
0
                }
5251
5252
0
                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
5253
0
                assert(sum > 0.0);
5254
5255
0
                if (sk) {
5256
0
                    sum += (ggml_float) expf(sk[i02] - max);
5257
0
                }
5258
5259
0
                sum = 1.0/sum;
5260
0
                ggml_vec_scale_f32(ne00, dp, sum);
5261
5262
#ifndef NDEBUG
5263
                for (int i = 0; i < ne00; ++i) {
5264
                    assert(!isnan(dp[i]));
5265
                    assert(!isinf(dp[i]));
5266
                }
5267
#endif
5268
0
            }
5269
0
        }
5270
0
    }
5271
0
}
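Per row, the loop above is a numerically stable scaled softmax with an optional ALiBi-weighted mask and an optional per-head sink. With scale s, mask m, ALiBi slope a and sink k:

\[ w_i = s\,x_i + a\,m_i, \qquad M = \max\big(\max_i w_i,\; k\big), \qquad p_i = \frac{e^{w_i - M}}{\sum_j e^{w_j - M} + e^{k - M}} \]

and the slope for head h follows the usual ALiBi schedule built from m0 and m1 above:

\[ a_h = \begin{cases} m_0^{\,h+1} & h < n_{\mathrm{head\_log2}} \\ m_1^{\,2(h - n_{\mathrm{head\_log2}}) + 1} & \text{otherwise} \end{cases}, \qquad m_0 = 2^{-\mathrm{max\_bias}/n_{\mathrm{head\_log2}}}, \quad m_1 = 2^{-(\mathrm{max\_bias}/2)/n_{\mathrm{head\_log2}}} \]

(when src1 is absent the mask term is dropped, when src2 is absent the sink terms are dropped, and a_h is 1 when max_bias is 0).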
5272
5273
void ggml_compute_forward_soft_max(
5274
        const ggml_compute_params * params,
5275
0
              ggml_tensor * dst) {
5276
5277
0
    const ggml_tensor * src0 = dst->src[0];
5278
5279
0
    switch (src0->type) {
5280
0
        case GGML_TYPE_F32:
5281
0
            {
5282
0
                ggml_compute_forward_soft_max_f32(params, dst);
5283
0
            } break;
5284
0
        default:
5285
0
            {
5286
0
                GGML_ABORT("fatal error");
5287
0
            }
5288
0
    }
5289
0
}
5290
5291
5292
// ggml_compute_forward_soft_max_ext_back
5293
5294
static void ggml_compute_forward_soft_max_ext_back_f32(
5295
        const ggml_compute_params * params,
5296
0
        ggml_tensor * dst) {
5297
5298
0
    const ggml_tensor * src0 = dst->src[0];
5299
0
    const ggml_tensor * src1 = dst->src[1];
5300
5301
0
    GGML_ASSERT(ggml_is_contiguous(src0));
5302
0
    GGML_ASSERT(ggml_is_contiguous(src1));
5303
0
    GGML_ASSERT(ggml_is_contiguous(dst));
5304
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
5305
0
    GGML_ASSERT(ggml_are_same_shape(src1, dst));
5306
5307
0
    float scale    = 1.0f;
5308
0
    float max_bias = 0.0f;
5309
5310
0
    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
5311
0
    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
5312
5313
0
    GGML_ASSERT(max_bias == 0.0f);
5314
5315
    // TODO: handle transposed/permuted matrices
5316
5317
0
    const int ith = params->ith;
5318
0
    const int nth = params->nth;
5319
5320
0
    const int nc = src0->ne[0];
5321
0
    const int nr = ggml_nrows(src0);
5322
5323
    // rows per thread
5324
0
    const int dr = (nr + nth - 1)/nth;
5325
5326
    // row range for this thread
5327
0
    const int ir0 = dr*ith;
5328
0
    const int ir1 = MIN(ir0 + dr, nr);
5329
5330
0
    for (int i1 = ir0; i1 < ir1; i1++) {
5331
0
        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
5332
0
        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
5333
0
        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
5334
5335
#ifndef NDEBUG
5336
        for (int i = 0; i < nc; ++i) {
5337
            //printf("p[%d] = %f\n", i, p[i]);
5338
            assert(!isnan(dy[i]));
5339
            assert(!isnan(y[i]));
5340
        }
5341
#endif
5342
        // Jii = yi - yi*yi
5343
        // Jij = -yi*yj
5344
        // J = diag(y)-y.T*y
5345
        // dx = J * dy
5346
        // dxk = sum_i(Jki * dyi)
5347
        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
5348
        // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
5349
        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
5350
        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
5351
        // dxk = -yk * dot(y, dy) + yk*dyk
5352
        // dxk = yk * (- dot(y, dy) + dyk)
5353
        // dxk = yk * (dyk - dot(y, dy))
5354
        //
5355
        // post-order:
5356
        // dot_y_dy := dot(y, dy)
5357
        // dx := dy
5358
        // dx := dx - dot_y_dy
5359
        // dx := dx * y
5360
5361
        // linear runtime, no additional memory
5362
0
        float dot_y_dy = 0;
5363
0
        ggml_vec_dot_f32  (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
5364
0
        ggml_vec_cpy_f32  (nc, dx, dy);
5365
0
        ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
5366
0
        ggml_vec_mul_f32  (nc, dx, dx, y);
5367
0
        ggml_vec_scale_f32(nc, dx, scale);
5368
5369
#ifndef NDEBUG
5370
        for (int i = 0; i < nc; ++i) {
5371
            assert(!isnan(dx[i]));
5372
            assert(!isinf(dx[i]));
5373
        }
5374
#endif
5375
0
    }
5376
0
}
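Written compactly, the block above computes the softmax Jacobian-vector product, scaled by scale, exactly as the comment derives:

\[ J = \mathrm{diag}(y) - y\,y^{\top}, \qquad dx = \mathrm{scale}\; J\, dy = \mathrm{scale}\,\big(y \odot (dy - \langle y, dy\rangle\,\mathbf{1})\big) \]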
5377
5378
void ggml_compute_forward_soft_max_ext_back(
5379
        const ggml_compute_params * params,
5380
0
        ggml_tensor * dst) {
5381
5382
0
    const ggml_tensor * src0 = dst->src[0];
5383
5384
0
    switch (src0->type) {
5385
0
        case GGML_TYPE_F32:
5386
0
            {
5387
0
                ggml_compute_forward_soft_max_ext_back_f32(params, dst);
5388
0
            } break;
5389
0
        default:
5390
0
            {
5391
0
                GGML_ABORT("fatal error");
5392
0
            }
5393
0
    }
5394
0
}
5395
5396
// ggml_compute_forward_clamp
5397
5398
static void ggml_compute_forward_clamp_f32(
5399
        const ggml_compute_params * params,
5400
0
        ggml_tensor * dst) {
5401
5402
0
    const ggml_tensor * src0 = dst->src[0];
5403
5404
0
    float min;
5405
0
    float max;
5406
0
    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
5407
0
    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
5408
5409
0
    const int ith = params->ith;
5410
0
    const int nth = params->nth;
5411
5412
0
    const int n  = ggml_nrows(src0);
5413
0
    const int nc = src0->ne[0];
5414
5415
0
    const size_t nb00 = src0->nb[0];
5416
0
    const size_t nb01 = src0->nb[1];
5417
5418
0
    const size_t nb0 = dst->nb[0];
5419
0
    const size_t nb1 = dst->nb[1];
5420
5421
0
    GGML_ASSERT( nb0 == sizeof(float));
5422
0
    GGML_ASSERT(nb00 == sizeof(float));
5423
5424
0
    for (int j = ith; j < n; j += nth) {
5425
0
        float * dst_ptr  = (float *) ((char *)  dst->data + j*nb1);
5426
0
        float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
5427
5428
0
        for (int i = 0; i < nc; i++) {
5429
0
            dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
5430
0
        }
5431
0
    }
5432
0
}
5433
5434
static void ggml_compute_forward_clamp_f16(
5435
    const ggml_compute_params * params,
5436
0
    ggml_tensor * dst) {
5437
5438
0
    const ggml_tensor * src0 = dst->src[0];
5439
5440
0
    float min;
5441
0
    float max;
5442
0
    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
5443
0
    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
5444
5445
0
    const int ith = params->ith;
5446
0
    const int nth = params->nth;
5447
5448
0
    const int n  = ggml_nrows(src0);
5449
0
    const int nc = src0->ne[0];
5450
5451
0
    const size_t nb00 = src0->nb[0];
5452
0
    const size_t nb01 = src0->nb[1];
5453
5454
0
    const size_t nb0 = dst->nb[0];
5455
0
    const size_t nb1 = dst->nb[1];
5456
5457
0
    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
5458
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
5459
5460
0
    for (int j = ith; j < n; j += nth) {
5461
0
        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *)  dst->data + j*nb1);
5462
0
        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
5463
5464
0
        for (int i = 0; i < nc; i++) {
5465
0
            float v = GGML_CPU_FP16_TO_FP32(src0_ptr[i]);
5466
0
            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(MAX(MIN(v, max), min));
5467
0
        }
5468
0
    }
5469
0
}
5470
5471
void ggml_compute_forward_clamp(
5472
        const ggml_compute_params * params,
5473
0
        ggml_tensor * dst) {
5474
5475
0
    const ggml_tensor * src0 = dst->src[0];
5476
5477
0
    switch (src0->type) {
5478
0
        case GGML_TYPE_F32:
5479
0
            {
5480
0
                ggml_compute_forward_clamp_f32(params, dst);
5481
0
            } break;
5482
0
        case GGML_TYPE_F16:
5483
0
            {
5484
0
                ggml_compute_forward_clamp_f16(params, dst);
5485
0
            } break;
5486
0
        case GGML_TYPE_BF16:
5487
0
        case GGML_TYPE_Q4_0:
5488
0
        case GGML_TYPE_Q4_1:
5489
0
        case GGML_TYPE_Q5_0:
5490
0
        case GGML_TYPE_Q5_1:
5491
0
        case GGML_TYPE_Q8_0:
5492
0
        case GGML_TYPE_Q8_1:
5493
0
        case GGML_TYPE_MXFP4:
5494
0
        case GGML_TYPE_Q2_K:
5495
0
        case GGML_TYPE_Q3_K:
5496
0
        case GGML_TYPE_Q4_K:
5497
0
        case GGML_TYPE_Q5_K:
5498
0
        case GGML_TYPE_Q6_K:
5499
0
        case GGML_TYPE_TQ1_0:
5500
0
        case GGML_TYPE_TQ2_0:
5501
0
        case GGML_TYPE_IQ2_XXS:
5502
0
        case GGML_TYPE_IQ2_XS:
5503
0
        case GGML_TYPE_IQ3_XXS:
5504
0
        case GGML_TYPE_IQ1_S:
5505
0
        case GGML_TYPE_IQ1_M:
5506
0
        case GGML_TYPE_IQ4_NL:
5507
0
        case GGML_TYPE_IQ4_XS:
5508
0
        case GGML_TYPE_IQ3_S:
5509
0
        case GGML_TYPE_IQ2_S:
5510
0
        case GGML_TYPE_Q8_K:
5511
0
        case GGML_TYPE_I8:
5512
0
        case GGML_TYPE_I16:
5513
0
        case GGML_TYPE_I32:
5514
0
        case GGML_TYPE_I64:
5515
0
        case GGML_TYPE_F64:
5516
0
        case GGML_TYPE_COUNT:
5517
0
            {
5518
0
                GGML_ABORT("fatal error");
5519
0
            }
5520
0
    }
5521
0
}
5522
5523
// ggml_compute_forward_rope
5524
5525
0
static float rope_yarn_ramp(const float low, const float high, const int i0) {
5526
0
    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
5527
0
    return 1 - MIN(1, MAX(0, y));
5528
0
}
5529
5530
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
5531
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
5532
static void rope_yarn(
5533
    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
5534
0
    float * cos_theta, float * sin_theta) {
5535
    // Get n-d rotational scaling corrected for extrapolation
5536
0
    float theta_interp = freq_scale * theta_extrap;
5537
0
    float theta = theta_interp;
5538
0
    if (ext_factor != 0.0f) {
5539
0
        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
5540
0
        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
5541
5542
        // Get n-d magnitude scaling corrected for interpolation
5543
0
        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
5544
0
    }
5545
0
    *cos_theta = cosf(theta) * mscale;
5546
0
    *sin_theta = sinf(theta) * mscale;
5547
0
}
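In formula form, rope_yarn blends the interpolated and extrapolated angles with the ramp above, applying the magnitude correction only when ext_factor is non-zero (f_s is freq_scale, eps is ext_factor, and the ramp denominator is clamped away from zero as in the code):

\[ r(i_0) = 1 - \mathrm{clamp}\!\left(\tfrac{i_0/2 - \mathrm{low}}{\mathrm{high} - \mathrm{low}},\, 0,\, 1\right), \qquad \theta = \theta_{\mathrm{interp}}\,(1 - r\,\varepsilon) + \theta_{\mathrm{extrap}}\, r\,\varepsilon, \qquad \theta_{\mathrm{interp}} = f_s\,\theta_{\mathrm{extrap}} \]

\[ \mathrm{mscale} \mathrel{*}= 1 + 0.1\,\ln(1/f_s), \qquad (\cos\theta,\ \sin\theta)\cdot\mathrm{mscale} \]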
5548
5549
static void ggml_rope_cache_init(
5550
     float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
5551
0
     float * cache, float sin_sign, float theta_scale) {
5552
    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
5553
0
    float theta = theta_base;
5554
0
    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
5555
0
        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
5556
0
        rope_yarn(
5557
0
            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
5558
0
        );
5559
0
        cache[i0 + 1] *= sin_sign;
5560
5561
0
        theta *= theta_scale;
5562
0
    }
5563
0
}
5564
5565
static void ggml_mrope_cache_init(
5566
     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
5567
     float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
5568
0
     float * cache, float sin_sign, float theta_scale) {
5569
    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
5570
0
    float theta_t = theta_base_t;
5571
0
    float theta_h = theta_base_h;
5572
0
    float theta_w = theta_base_w;
5573
0
    float theta_e = theta_base_e;  // extra position id for vision encoder
5574
0
    int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
5575
0
    int sec_w = sections[1] + sections[0];
5576
0
    int sec_e = sections[2] + sec_w;
5577
0
    GGML_ASSERT(sect_dims <= ne0);
5578
5579
0
    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
5580
0
        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
5581
5582
0
        int sector = (i0 / 2) % sect_dims;
5583
0
        if (indep_sects) {
5584
            // compute theta independently for each dim section
5585
            // (i.e. reset the corresponding theta when `i0` goes from one section to another)
5586
0
            if (sector == 0) {
5587
0
                theta_t = theta_base_t;
5588
0
            }
5589
0
            else if (sector == sections[0]) {
5590
0
                theta_h = theta_base_h;
5591
0
            }
5592
0
            else if (sector == sec_w) {
5593
0
                theta_w = theta_base_w;
5594
0
            }
5595
0
            else if (sector == sec_e) {
5596
0
                theta_e = theta_base_e;
5597
0
            }
5598
0
        }
5599
5600
0
        float theta = theta_t;
5601
0
        if (is_imrope) { // qwen3vl applies interleaved mrope
5602
0
            if (sector % 3 == 1 && sector < 3 * sections[1]) {
5603
0
                theta = theta_h;
5604
0
            } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
5605
0
                theta = theta_w;
5606
0
            } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
5607
0
                theta = theta_t;
5608
0
            } else {
5609
0
                theta = theta_e;
5610
0
            }
5611
0
        } else {
5612
0
            if (sector >= sections[0] && sector < sec_w) {
5613
0
                theta = theta_h;
5614
0
            }
5615
0
            else if (sector >= sec_w && sector < sec_w + sections[2]) {
5616
0
                theta = theta_w;
5617
0
            }
5618
0
            else if (sector >= sec_w + sections[2]) {
5619
0
                theta = theta_e;
5620
0
            }
5621
0
        }
5622
5623
0
        rope_yarn(
5624
0
            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
5625
0
        );
5626
0
        cache[i0 + 1] *= sin_sign;
5627
5628
0
        theta_t *= theta_scale;
5629
0
        theta_w *= theta_scale;
5630
0
        theta_h *= theta_scale;
5631
0
        theta_e *= theta_scale;
5632
0
    }
5633
0
}
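Each pair index i0/2 is routed to one of the four position streams (t, h, w, e) by its sector within sect_dims. As a worked example with sections = {16, 24, 24, 0} (illustrative numbers, not taken from this file): sect_dims = 64, sector = (i0/2) mod 64, sectors [0, 16) use theta_t, [16, 40) use theta_h, [40, 64) use theta_w, and theta_e is never selected. The interleaved imrope path instead cycles t/h/w by sector mod 3, falling back to theta_e once sector exceeds the corresponding 3*sections bound.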
5634
5635
5636
template<typename T>
5637
0
static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
5638
0
  for (int64_t i0 = 0; i0 < n; i0 += 2) {
5639
0
    const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
5640
5641
0
    const float cos_theta = cache[i0 + 0];
5642
0
    const float sin_theta = cache[i0 + 1];
5643
5644
0
    const T * const src = src_data + ic;
5645
0
    T * dst             = dst_data + ic;
5646
5647
0
    const float x0 = type_conversion_table<T>::to_f32(src[0]);
5648
0
    const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
5649
5650
0
    dst[0]        = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
5651
0
    dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
5652
0
  }
5653
0
}
Unexecuted instantiation: ops.cpp:void rotate_pairs<unsigned short>(long, long, float const*, unsigned short const*, unsigned short*, int)
Unexecuted instantiation: ops.cpp:void rotate_pairs<float>(long, long, float const*, float const*, float*, int)
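The (n, n_offset, scale) arguments select which two elements form each rotated pair. For the call sites below, with n_dims = 8 and ne0 = 16 as illustrative values:

    NORMAL      (scale = 1, n_offset = 1):                pairs (x0,x1), (x2,x3), (x4,x5), (x6,x7)
    NEOX/MROPE  (scale = 2, n_offset = n_dims/2):         pairs (x0,x4), (x1,x5), (x2,x6), (x3,x7)
    VISION      (scale = 2, n_offset = n_dims, n = ne0):  pairs (x0,x8), (x1,x9), ..., (x7,x15)

each pair (a, b) being replaced by (a cos t - b sin t, a sin t + b cos t) with the angle taken from the cache.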
5654
5655
template<typename T> //float or ggml_fp16_t
5656
static void ggml_compute_forward_rope_flt(
5657
        const ggml_compute_params * params,
5658
        ggml_tensor * dst,
5659
0
        const bool forward) {
5660
5661
0
    const ggml_tensor * src0 = dst->src[0];
5662
0
    const ggml_tensor * src1 = dst->src[1];
5663
0
    const ggml_tensor * src2 = dst->src[2];
5664
5665
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
5666
0
    GGML_ASSERT(src1->type == GGML_TYPE_I32);
5667
5668
0
    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5669
0
    int sections[4];
5670
5671
    //const int n_past     = ((int32_t *) dst->op_params)[0];
5672
0
    const int n_dims     = ((int32_t *) dst->op_params)[1];
5673
0
    const int mode       = ((int32_t *) dst->op_params)[2];
5674
    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
5675
0
    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
5676
5677
0
    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
5678
0
    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
5679
0
    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
5680
0
    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
5681
0
    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
5682
0
    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
5683
0
    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
5684
5685
0
    GGML_TENSOR_UNARY_OP_LOCALS
5686
5687
    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5688
    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5689
5690
0
    GGML_ASSERT(nb0 == nb00);
5691
0
    GGML_ASSERT(nb0 == sizeof(T));
5692
5693
0
    const int ith = params->ith;
5694
0
    const int nth = params->nth;
5695
5696
0
    const int nr = ggml_nrows(dst);
5697
5698
0
    GGML_ASSERT(n_dims <= ne0);
5699
0
    GGML_ASSERT(n_dims % 2 == 0);
5700
5701
    // rows per thread
5702
0
    const int dr = (nr + nth - 1)/nth;
5703
5704
    // row range for this thread
5705
0
    const int ir0 = dr*ith;
5706
0
    const int ir1 = MIN(ir0 + dr, nr);
5707
5708
    // row index used to determine which thread to use
5709
0
    int ir = 0;
5710
5711
0
    const float theta_scale = powf(freq_base, -2.0f/n_dims);
5712
5713
0
    float corr_dims[2];
5714
0
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
5715
5716
0
    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl applies interleaved mrope
5717
0
    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
5718
0
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
5719
5720
0
    if (mrope_used) {
5721
0
        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
5722
0
    }
5723
5724
0
    if (is_vision) {
5725
0
        GGML_ASSERT(n_dims == ne0/2);
5726
0
    }
5727
5728
0
    const float * freq_factors = NULL;
5729
0
    if (src2 != NULL) {
5730
0
        GGML_ASSERT(src2->type == GGML_TYPE_F32);
5731
0
        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
5732
0
        freq_factors = (const float *) src2->data;
5733
0
    }
5734
5735
    // backward process uses inverse rotation by cos and sin.
5736
    // cos and sin build a rotation matrix, where the inverse is the transpose.
5737
    // this essentially just switches the sign of sin.
5738
0
    const float sin_sign = forward ? 1.0f : -1.0f;
5739
5740
0
    const int32_t * pos = (const int32_t *) src1->data;
5741
5742
0
    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
5743
0
        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
5744
5745
0
            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5746
0
            if (!mrope_used) {
5747
0
                const int64_t p = pos[i2];
5748
0
                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5749
0
            }
5750
0
            else {
5751
0
                const int64_t p_t = pos[i2];
5752
0
                const int64_t p_h = pos[i2 + ne2];
5753
0
                const int64_t p_w = pos[i2 + ne2 * 2];
5754
0
                const int64_t p_e = pos[i2 + ne2 * 3];
5755
0
                ggml_mrope_cache_init(
5756
0
                    p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
5757
0
                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5758
0
            }
5759
5760
0
            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
5761
0
                if (ir++ < ir0) continue;
5762
0
                if (ir   > ir1) break;
5763
5764
0
                T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5765
0
                T * dst_data  = (T *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1);
5766
5767
0
                switch (mode) {
5768
0
                    case GGML_ROPE_TYPE_NORMAL:
5769
0
                        rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
5770
0
                        break;
5771
0
                    case GGML_ROPE_TYPE_NEOX:
5772
0
                    case GGML_ROPE_TYPE_MROPE:
5773
0
                    case GGML_ROPE_TYPE_IMROPE:
5774
0
                        rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
5775
0
                        break;
5776
0
                    case GGML_ROPE_TYPE_VISION:
5777
0
                        rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
5778
0
                        break;
5779
0
                    default:
5780
0
                        GGML_ABORT("rope type not supported");
5781
0
                }
5782
5783
0
                if (!is_vision) {
5784
                    // fill the remaining channels with data from the src tensor
5785
0
                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5786
0
                        const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5787
0
                        T * dst_data  = (T *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
5788
5789
0
                        dst_data[0] = src[0];
5790
0
                        dst_data[1] = src[1];
5791
0
                    }
5792
0
                }
5793
0
            } //attn-heads
5794
0
        }
5795
0
    }
5796
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_rope_flt<unsigned short>(ggml_compute_params const*, ggml_tensor*, bool)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_rope_flt<float>(ggml_compute_params const*, ggml_tensor*, bool)
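The forward flag only flips sin_sign because each 2-D rotation is orthogonal, so its inverse is its transpose, which is the same rotation with the sign of sin negated:

\[ R(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}, \qquad R(\theta)^{-1} = R(\theta)^{\top} = R(-\theta) \]

hence ggml_compute_forward_rope_back below reuses the same kernel with forward = false.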
5797
5798
void ggml_compute_forward_rope(
5799
        const ggml_compute_params * params,
5800
0
        ggml_tensor * dst) {
5801
5802
0
    const ggml_tensor * src0 = dst->src[0];
5803
5804
0
    switch (src0->type) {
5805
0
        case GGML_TYPE_F16:
5806
0
            {
5807
0
                ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
5808
0
            } break;
5809
0
        case GGML_TYPE_F32:
5810
0
            {
5811
0
                ggml_compute_forward_rope_flt<float>(params, dst, true);
5812
0
            } break;
5813
0
        default:
5814
0
            {
5815
0
                GGML_ABORT("fatal error");
5816
0
            }
5817
0
    }
5818
0
}
5819
5820
// ggml_compute_forward_rope_back
5821
5822
void ggml_compute_forward_rope_back(
5823
        const ggml_compute_params * params,
5824
0
        ggml_tensor * dst) {
5825
5826
0
    const ggml_tensor * src0 = dst->src[0];
5827
5828
0
    switch (src0->type) {
5829
0
        case GGML_TYPE_F16:
5830
0
            {
5831
0
                ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
5832
0
            } break;
5833
0
        case GGML_TYPE_F32:
5834
0
            {
5835
0
                ggml_compute_forward_rope_flt<float>(params, dst, false);
5836
0
            } break;
5837
0
        default:
5838
0
            {
5839
0
                GGML_ABORT("fatal error");
5840
0
            }
5841
0
    }
5842
0
}
5843
5844
// ggml_compute_forward_conv_transpose_1d
5845
5846
static void ggml_compute_forward_conv_transpose_1d_f16_f32(
5847
        const ggml_compute_params * params,
5848
0
              ggml_tensor * dst) {
5849
5850
0
    const ggml_tensor * src0 = dst->src[0];
5851
0
    const ggml_tensor * src1 = dst->src[1];
5852
5853
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
5854
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
5855
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
5856
5857
0
    GGML_TENSOR_BINARY_OP_LOCALS
5858
5859
0
    const int ith = params->ith;
5860
0
    const int nth = params->nth;
5861
5862
0
    const int nk = ne00*ne01*ne02;
5863
5864
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
5865
0
    GGML_ASSERT(nb10 == sizeof(float));
5866
5867
0
    if (ith == 0) {
5868
0
        memset(params->wdata, 0, params->wsize);
5869
5870
        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
5871
0
        {
5872
0
            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
5873
5874
0
            for (int64_t i02 = 0; i02 < ne02; i02++) {
5875
0
                for (int64_t i01 = 0; i01 < ne01; i01++) {
5876
0
                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
5877
0
                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
5878
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
5879
0
                        dst_data[i00*ne02 + i02] = src[i00];
5880
0
                    }
5881
0
                }
5882
0
            }
5883
0
        }
5884
5885
        // permute source data (src1) from (L x Cin) to (Cin x L)
5886
0
        {
5887
0
            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
5888
0
            ggml_fp16_t * dst_data = wdata;
5889
5890
0
            for (int64_t i11 = 0; i11 < ne11; i11++) {
5891
0
                const float * const src = (float *)((char *) src1->data + i11*nb11);
5892
0
                for (int64_t i10 = 0; i10 < ne10; i10++) {
5893
0
                    dst_data[i10*ne11 + i11] = GGML_CPU_FP32_TO_FP16(src[i10]);
5894
0
                }
5895
0
            }
5896
0
        }
5897
5898
        // need to zero dst since we are accumulating into it
5899
0
        memset(dst->data, 0, ggml_nbytes(dst));
5900
0
    }
5901
0
    ggml_barrier(params->threadpool);
5902
5903
0
    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
5904
5905
    // total rows in dst
5906
0
    const int nr = ne1;
5907
5908
    // rows per thread
5909
0
    const int dr = (nr + nth - 1)/nth;
5910
5911
    // row range for this thread
5912
0
    const int ir0 = dr*ith;
5913
0
    const int ir1 = MIN(ir0 + dr, nr);
5914
5915
0
    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
5916
0
    ggml_fp16_t * const wdata_src = wdata + nk;
5917
5918
0
    for (int i1 = ir0; i1 < ir1; i1++) {
5919
0
        float * dst_data = (float *)((char *) dst->data + i1*nb1);
5920
0
        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
5921
0
        for (int i10 = 0; i10 < ne10; i10++) {
5922
0
            const int i1n = i10*ne11;
5923
0
            for (int i00 = 0; i00 < ne00; i00++) {
5924
0
                float v = 0;
5925
0
                ggml_vec_dot_f16(ne02, &v, 0,
5926
0
                        (ggml_fp16_t *)    wdata_src + i1n, 0,
5927
0
                        (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
5928
0
                dst_data[i10*s0 + i00] += v;
5929
0
            }
5930
0
        }
5931
0
    }
5932
0
}
5933
5934
static void ggml_compute_forward_conv_transpose_1d_f32(
5935
        const ggml_compute_params * params,
5936
0
              ggml_tensor * dst) {
5937
5938
0
    const ggml_tensor * src0 = dst->src[0];
5939
0
    const ggml_tensor * src1 = dst->src[1];
5940
5941
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
5942
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
5943
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
5944
5945
0
    GGML_TENSOR_BINARY_OP_LOCALS
5946
5947
0
    const int ith = params->ith;
5948
0
    const int nth = params->nth;
5949
5950
0
    const int nk = ne00*ne01*ne02;
5951
5952
0
    GGML_ASSERT(nb00 == sizeof(float));
5953
0
    GGML_ASSERT(nb10 == sizeof(float));
5954
5955
0
    if (ith == 0) {
5956
0
        memset(params->wdata, 0, params->wsize);
5957
5958
        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
5959
0
        {
5960
0
            float * const wdata = (float *) params->wdata + 0;
5961
5962
0
            for (int64_t i02 = 0; i02 < ne02; i02++) {
5963
0
                for (int64_t i01 = 0; i01 < ne01; i01++) {
5964
0
                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
5965
0
                    float * dst_data = wdata + i01*ne00*ne02;
5966
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
5967
0
                        dst_data[i00*ne02 + i02] = src[i00];
5968
0
                    }
5969
0
                }
5970
0
            }
5971
0
        }
5972
5973
        // prepare source data (src1)
5974
0
        {
5975
0
            float * const wdata = (float *) params->wdata + nk;
5976
0
            float * dst_data = wdata;
5977
5978
0
            for (int64_t i11 = 0; i11 < ne11; i11++) {
5979
0
                const float * const src = (float *)((char *) src1->data + i11*nb11);
5980
0
                for (int64_t i10 = 0; i10 < ne10; i10++) {
5981
0
                    dst_data[i10*ne11 + i11] = src[i10];
5982
0
                }
5983
0
            }
5984
0
        }
5985
5986
        // need to zero dst since we are accumulating into it
5987
0
        memset(dst->data, 0, ggml_nbytes(dst));
5988
0
    }
5989
0
    ggml_barrier(params->threadpool);
5990
5991
0
    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
5992
5993
    // total rows in dst
5994
0
    const int nr = ne1;
5995
5996
    // rows per thread
5997
0
    const int dr = (nr + nth - 1)/nth;
5998
5999
    // row range for this thread
6000
0
    const int ir0 = dr*ith;
6001
0
    const int ir1 = MIN(ir0 + dr, nr);
6002
6003
0
    float * const wdata     = (float *) params->wdata + 0;
6004
0
    float * const wdata_src = wdata + nk;
6005
6006
0
    for (int i1 = ir0; i1 < ir1; i1++) {
6007
0
        float * dst_data = (float *)((char *) dst->data + i1*nb1);
6008
0
        float * wdata_kernel = wdata + i1*ne02*ne00;
6009
0
        for (int i10 = 0; i10 < ne10; i10++) {
6010
0
            const int i1n = i10*ne11;
6011
0
            for (int i00 = 0; i00 < ne00; i00++) {
6012
0
                float v = 0;
6013
0
                ggml_vec_dot_f32(ne02, &v, 0,
6014
0
                        wdata_src + i1n, 0,
6015
0
                        wdata_kernel + i00*ne02, 0, 1);
6016
0
                dst_data[i10*s0 + i00] += v;
6017
0
            }
6018
0
        }
6019
0
    }
6020
0
}
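Both variants above stage the operands so that the inner loop is a single dot product over input channels; the accumulation they perform is, in index form,

\[ \mathrm{dst}[c_{\mathrm{out}},\ l\,s_0 + k] \mathrel{+}= \sum_{c_{\mathrm{in}}} \mathrm{src1}[c_{\mathrm{in}}, l]\,\cdot\,\mathrm{src0}[c_{\mathrm{in}}, c_{\mathrm{out}}, k] \]

for every input position l and kernel tap k, which is why dst is zeroed before the barrier.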
6021
6022
void ggml_compute_forward_conv_transpose_1d(
6023
        const ggml_compute_params * params,
6024
0
              ggml_tensor * dst) {
6025
6026
0
    const ggml_tensor * src0 = dst->src[0];
6027
6028
0
    switch (src0->type) {
6029
0
        case GGML_TYPE_F16:
6030
0
            {
6031
0
                ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
6032
0
            } break;
6033
0
        case GGML_TYPE_F32:
6034
0
            {
6035
0
                ggml_compute_forward_conv_transpose_1d_f32(params, dst);
6036
0
            } break;
6037
0
        default:
6038
0
            {
6039
0
                GGML_ABORT("fatal error");
6040
0
            }
6041
0
    }
6042
0
}
6043
6044
// ggml_compute_forward_im2col_f32
6045
// src0: kernel [OC, IC, KH, KW]
6046
// src1: image [N, IC, IH, IW]
6047
// dst:  result [N, OH, OW, IC*KH*KW]
6048
static void ggml_compute_forward_im2col_f32(
6049
        const ggml_compute_params * params,
6050
0
              ggml_tensor * dst) {
6051
6052
0
    const ggml_tensor * src0 = dst->src[0];
6053
0
    const ggml_tensor * src1 = dst->src[1];
6054
6055
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6056
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
6057
6058
0
    GGML_TENSOR_BINARY_OP_LOCALS;
6059
6060
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6061
0
    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6062
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6063
0
    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
6064
0
    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
6065
0
    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
6066
0
    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
6067
6068
0
    const int ith = params->ith;
6069
0
    const int nth = params->nth;
6070
6071
0
    const int64_t N  = is_2D ? ne13 : ne12;
6072
0
    const int64_t IC = is_2D ? ne12 : ne11;
6073
0
    const int64_t IH = is_2D ? ne11 : 1;
6074
0
    const int64_t IW = ne10;
6075
6076
0
    const int64_t KH = is_2D ? ne01 : 1;
6077
0
    const int64_t KW = ne00;
6078
6079
0
    const int64_t OH = is_2D ? ne2 : 1;
6080
0
    const int64_t OW = ne1;
6081
6082
0
    int ofs0 = is_2D ? nb13 : nb12;
6083
0
    int ofs1 = is_2D ? nb12 : nb11;
6084
6085
0
    GGML_ASSERT(nb10 == sizeof(float));
6086
6087
    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6088
0
    {
6089
0
        float * const wdata = (float *) dst->data;
6090
6091
0
        for (int64_t in = 0; in < N; in++) {
6092
0
            for (int64_t ioh = 0; ioh < OH; ioh++) { // OH == 1 when !is_2D
6093
0
                for (int64_t iow = 0; iow < OW; iow++) {
6094
0
                    for (int64_t iic = ith; iic < IC; iic += nth) {
6095
6096
                        // micro kernel
6097
0
                        float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6098
0
                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
6099
6100
0
                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // KH == 1 when !is_2D
6101
0
                            for (int64_t ikw = 0; ikw < KW; ikw++) {
6102
0
                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
6103
0
                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
6104
6105
0
                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6106
0
                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
6107
0
                                } else {
6108
0
                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
6109
0
                                }
6110
0
                            }
6111
0
                        }
6112
0
                    }
6113
0
                }
6114
0
            }
6115
0
        }
6116
0
    }
6117
0
}
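
The bounds check above (iiw = iow*s0 + ikw*d0 - p0 must stay inside [0, IW)) corresponds to the standard convolution output-size formula that the caller uses to choose OW and OH. A standalone sketch of that arithmetic, not taken from ops.cpp itself:

    // standalone sketch, not part of ops.cpp
    #include <cstdint>
    #include <cstdio>

    static int64_t conv_out_size(int64_t in, int64_t kernel, int64_t stride, int64_t pad, int64_t dilation) {
        // effective kernel extent with dilation, then the usual strided count
        return (in + 2*pad - dilation*(kernel - 1) - 1) / stride + 1;
    }

    int main() {
        // e.g. IW = 224, KW = 3, s0 = 1, p0 = 1, d0 = 1  ->  OW = 224 ("same" padding)
        printf("OW = %lld\n", (long long) conv_out_size(224, 3, 1, 1, 1));
        return 0;
    }
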
6118
6119
6120
// ggml_compute_forward_im2col_f16
6121
// src0: kernel [OC, IC, KH, KW]
6122
// src1: image [N, IC, IH, IW]
6123
// dst:  result [N, OH, OW, IC*KH*KW]
6124
static void ggml_compute_forward_im2col_f16(
6125
        const ggml_compute_params * params,
6126
0
              ggml_tensor * dst) {
6127
6128
0
    const ggml_tensor * src0 = dst->src[0];
6129
0
    const ggml_tensor * src1 = dst->src[1];
6130
6131
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
6132
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6133
0
    GGML_ASSERT( dst->type == GGML_TYPE_F16);
6134
6135
0
    GGML_TENSOR_BINARY_OP_LOCALS;
6136
6137
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6138
0
    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6139
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6140
0
    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
6141
0
    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
6142
0
    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
6143
0
    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
6144
6145
0
    const int ith = params->ith;
6146
0
    const int nth = params->nth;
6147
6148
0
    const int64_t N  = is_2D ? ne13 : ne12;
6149
0
    const int64_t IC = is_2D ? ne12 : ne11;
6150
0
    const int64_t IH = is_2D ? ne11 : 1;
6151
0
    const int64_t IW = ne10;
6152
6153
0
    const int64_t KH = is_2D ? ne01 : 1;
6154
0
    const int64_t KW = ne00;
6155
6156
0
    const int64_t OH = is_2D ? ne2 : 1;
6157
0
    const int64_t OW = ne1;
6158
6159
0
    int ofs0 = is_2D ? nb13 : nb12;
6160
0
    int ofs1 = is_2D ? nb12 : nb11;
6161
6162
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6163
0
    GGML_ASSERT(nb10 == sizeof(float));
6164
6165
    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6166
0
    {
6167
0
        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
6168
6169
0
        for (int64_t in = 0; in < N; in++) {
6170
0
            for (int64_t ioh = 0; ioh < OH; ioh++) { // OH == 1 when !is_2D
6171
0
                for (int64_t iow = 0; iow < OW; iow++) {
6172
0
                    for (int64_t iic = ith; iic < IC; iic += nth) {
6173
6174
                        // micro kernel
6175
0
                        ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6176
0
                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
6177
6178
0
                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // KH == 1 when !is_2D
6179
0
                            for (int64_t ikw = 0; ikw < KW; ikw++) {
6180
0
                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
6181
0
                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
6182
6183
0
                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6184
0
                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
6185
0
                                } else {
6186
0
                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
6187
0
                                }
6188
0
                            }
6189
0
                        }
6190
0
                    }
6191
0
                }
6192
0
            }
6193
0
        }
6194
0
    }
6195
0
}
6196
6197
void ggml_compute_forward_im2col(
6198
        const ggml_compute_params * params,
6199
0
              ggml_tensor * dst) {
6200
0
    switch (dst->type) {
6201
0
        case GGML_TYPE_F16:
6202
0
            {
6203
0
                ggml_compute_forward_im2col_f16(params, dst);
6204
0
            } break;
6205
0
        case GGML_TYPE_F32:
6206
0
            {
6207
0
                ggml_compute_forward_im2col_f32(params, dst);
6208
0
            } break;
6209
0
        default:
6210
0
            {
6211
0
                GGML_ABORT("fatal error");
6212
0
            }
6213
0
    }
6214
0
}
6215
6216
// ggml_compute_forward_im2col_back_f32
6217
6218
void ggml_compute_forward_im2col_back_f32(
6219
        const ggml_compute_params * params,
6220
0
              ggml_tensor * dst) {
6221
6222
0
    const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
6223
0
    const ggml_tensor * src1 = dst->src[1]; // convolution kernel
6224
6225
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
6226
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6227
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
6228
6229
0
    GGML_TENSOR_BINARY_OP_LOCALS;
6230
6231
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6232
0
    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6233
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6234
0
    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
6235
0
    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
6236
0
    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
6237
0
    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
6238
6239
0
    const int ith = params->ith;
6240
0
    const int nth = params->nth;
6241
6242
0
    const int64_t N  = is_2D ? ne3 : ne2;
6243
0
    const int64_t IC = is_2D ? ne2 : ne1;
6244
0
    const int64_t IH = is_2D ? ne1 : 1;
6245
0
    const int64_t IW = ne0;
6246
6247
0
    const int64_t KH = is_2D ? ne11 : 1;
6248
0
    const int64_t KW = ne10;
6249
6250
0
    const int64_t OH = is_2D ? ne02 : 1;
6251
0
    const int64_t OW = ne01;
6252
6253
0
    int ofs0 = is_2D ? nb3 : nb2;
6254
0
    int ofs1 = is_2D ? nb2 : nb1;
6255
6256
0
    GGML_ASSERT(nb0  == sizeof(float));
6257
6258
    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6259
0
    {
6260
0
        float * const wdata = (float *) dst->data;
6261
6262
0
        for (int64_t in = 0; in < N; in++) {
6263
0
            for (int64_t iic = ith; iic < IC; iic += nth) {
6264
0
                for (int64_t iih = 0; iih < IH; iih++) {
6265
0
                    for (int64_t iiw = 0; iiw < IW; iiw++) {
6266
6267
                        // micro kernel
6268
0
                        float grad = 0.0f;
6269
0
                        for (int64_t ikh = 0; ikh < KH; ikh++) {
6270
0
                            for (int64_t ikw = 0; ikw < KW; ikw++) {
6271
                                // For s0 > 1 some values were skipped over in the forward pass.
6272
                                // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
6273
0
                                const int64_t tmpw = (iiw + p0 - ikw*d0);
6274
0
                                if (tmpw % s0 != 0) {
6275
0
                                    continue;
6276
0
                                }
6277
0
                                const int64_t iow = tmpw / s0;
6278
6279
                                // Equivalent logic as above except for s1.
6280
0
                                int64_t ioh;
6281
0
                                if (is_2D) {
6282
0
                                    const int64_t tmph = iih + p1 - ikh*d1;
6283
6284
0
                                    if (tmph % s1 != 0) {
6285
0
                                        continue;
6286
0
                                    }
6287
6288
0
                                    ioh = tmph / s1;
6289
0
                                } else {
6290
0
                                    ioh = 0;
6291
0
                                }
6292
6293
0
                                if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
6294
0
                                    continue;
6295
0
                                }
6296
6297
0
                                const float * const grad_in = (const float *) src0->data
6298
0
                                    + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6299
0
                                grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
6300
0
                            }
6301
0
                        }
6302
0
                        float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
6303
0
                        dst_data[iih*IW + iiw] = grad;
6304
0
                    }
6305
0
                }
6306
0
            }
6307
0
        }
6308
0
    }
6309
0
}
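
The tmpw % s0 skip above encodes the fact that, for stride s0 > 1, an input column iiw receives gradient only from output columns iow where iiw + p0 - ikw*d0 is divisible by s0. A small standalone illustration follows; the concrete values are chosen arbitrarily and are not from the source.

    // standalone sketch, not part of ops.cpp
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t s0 = 2, p0 = 0, d0 = 1, KW = 3, OW = 4;
        const int64_t iiw = 5; // some input column
        for (int64_t ikw = 0; ikw < KW; ikw++) {
            const int64_t tmpw = iiw + p0 - ikw*d0;
            if (tmpw % s0 != 0) continue;       // skipped, just like the forward pass
            const int64_t iow = tmpw / s0;
            if (iow < 0 || iow >= OW) continue; // outside the output
            printf("input col %lld gets gradient from (iow=%lld, ikw=%lld)\n",
                   (long long) iiw, (long long) iow, (long long) ikw);
        }
        return 0;
    }
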
6310
6311
6312
// ggml_compute_forward_im2col_3d_f16
6313
// src0: kernel [OC*IC, KD, KH, KW]
6314
// src1: image [N*IC, ID, IH, IW]
6315
// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
6316
static void ggml_compute_forward_im2col_3d_f16(
6317
        const ggml_compute_params * params,
6318
0
              ggml_tensor * dst) {
6319
6320
0
    const ggml_tensor * src0 = dst->src[0];
6321
0
    const ggml_tensor * src1 = dst->src[1];
6322
6323
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
6324
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6325
0
    GGML_ASSERT( dst->type == GGML_TYPE_F16);
6326
6327
0
    GGML_TENSOR_BINARY_OP_LOCALS;
6328
6329
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6330
0
    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6331
0
    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
6332
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
6333
0
    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
6334
0
    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
6335
0
    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
6336
0
    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
6337
0
    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
6338
0
    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
6339
6340
6341
0
    const int ith = params->ith;
6342
0
    const int nth = params->nth;
6343
6344
0
    const int64_t N  = ne13 / IC;
6345
0
    const int64_t ID = ne12;
6346
0
    const int64_t IH = ne11;
6347
0
    const int64_t IW = ne10;
6348
6349
0
    const int64_t OC = ne03 / IC;
6350
0
    GGML_UNUSED(OC);
6351
0
    const int64_t KD = ne02;
6352
0
    const int64_t KH = ne01;
6353
0
    const int64_t KW = ne00;
6354
6355
0
    const int64_t OD = ne3 / N;
6356
0
    const int64_t OH = ne2;
6357
0
    const int64_t OW = ne1;
6358
0
    const int64_t OH_OW = OH*OW;
6359
0
    const int64_t KD_KH_KW = KD*KH*KW;
6360
0
    const int64_t KH_KW = KH*KW;
6361
0
    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
6362
6363
0
    GGML_ASSERT(nb10 == sizeof(float));
6364
6365
    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
6366
0
    {
6367
0
        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
6368
6369
0
        for (int64_t in = 0; in < N; in++) {
6370
0
            for (int64_t iod = 0; iod < OD; iod++) {
6371
0
                for (int64_t ioh = 0; ioh < OH; ioh++) {
6372
0
                    for (int64_t iow = 0; iow < OW; iow++) {
6373
0
                        for (int64_t iic = ith; iic < IC; iic += nth) {
6374
6375
                            // micro kernel
6376
0
                            ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
6377
0
                            const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
6378
6379
0
                            for (int64_t ikd = 0; ikd < KD; ikd++) {
6380
0
                                for (int64_t ikh = 0; ikh < KH; ikh++) {
6381
0
                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
6382
0
                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
6383
0
                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
6384
0
                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
6385
6386
0
                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6387
0
                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
6388
0
                                        } else {
6389
0
                                            const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
6390
0
                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
6391
0
                                        }
6392
0
                                    }
6393
0
                                }
6394
0
                            }
6395
0
                        }
6396
0
                    }
6397
0
                }
6398
0
            }
6399
0
        }
6400
0
    }
6401
0
}
6402
6403
// ggml_compute_forward_im2col_3d_f32
6404
// src0: kernel [OC*IC, KD, KH, KW]
6405
// src1: image [N*IC, ID, IH, IW]
6406
// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
6407
static void ggml_compute_forward_im2col_3d_f32(
6408
        const ggml_compute_params * params,
6409
0
              ggml_tensor * dst) {
6410
6411
0
    const ggml_tensor * src0 = dst->src[0];
6412
0
    const ggml_tensor * src1 = dst->src[1];
6413
6414
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6415
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
6416
6417
0
    GGML_TENSOR_BINARY_OP_LOCALS;
6418
6419
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6420
0
    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6421
0
    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
6422
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
6423
0
    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
6424
0
    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
6425
0
    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
6426
0
    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
6427
0
    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
6428
0
    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
6429
6430
6431
0
    const int ith = params->ith;
6432
0
    const int nth = params->nth;
6433
6434
0
    const int64_t N  = ne13 / IC;
6435
0
    const int64_t ID = ne12;
6436
0
    const int64_t IH = ne11;
6437
0
    const int64_t IW = ne10;
6438
6439
0
    const int64_t OC = ne03 / IC;
6440
0
    GGML_UNUSED(OC);
6441
0
    const int64_t KD = ne02;
6442
0
    const int64_t KH = ne01;
6443
0
    const int64_t KW = ne00;
6444
6445
0
    const int64_t OD = ne3 / N;
6446
0
    const int64_t OH = ne2;
6447
0
    const int64_t OW = ne1;
6448
6449
0
    const int64_t OH_OW = OH*OW;
6450
0
    const int64_t KD_KH_KW = KD*KH*KW;
6451
0
    const int64_t KH_KW = KH*KW;
6452
0
    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
6453
6454
0
    GGML_ASSERT(nb10 == sizeof(float));
6455
6456
    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
6457
0
    {
6458
0
        float * const wdata = (float *) dst->data;
6459
6460
0
        for (int64_t in = 0; in < N; in++) {
6461
0
            for (int64_t iod = 0; iod < OD; iod++) {
6462
0
                for (int64_t ioh = 0; ioh < OH; ioh++) {
6463
0
                    for (int64_t iow = 0; iow < OW; iow++) {
6464
0
                        for (int64_t iic = ith; iic < IC; iic += nth) {
6465
6466
                            // micro kernel
6467
0
                            float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
6468
0
                            const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
6469
6470
0
                            for (int64_t ikd = 0; ikd < KD; ikd++) {
6471
0
                                for (int64_t ikh = 0; ikh < KH; ikh++) {
6472
0
                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
6473
0
                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
6474
0
                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
6475
0
                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
6476
6477
0
                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6478
0
                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
6479
0
                                        } else {
6480
0
                                            const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
6481
0
                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
6482
0
                                        }
6483
0
                                    }
6484
0
                                }
6485
0
                            }
6486
0
                        }
6487
0
                    }
6488
0
                }
6489
0
            }
6490
0
        }
6491
0
    }
6492
0
}
6493
6494
6495
void ggml_compute_forward_im2col_3d(
6496
        const ggml_compute_params * params,
6497
0
              ggml_tensor * dst) {
6498
0
    switch (dst->type) {
6499
0
        case GGML_TYPE_F16:
6500
0
            {
6501
0
                ggml_compute_forward_im2col_3d_f16(params, dst);
6502
0
            } break;
6503
0
        case GGML_TYPE_F32:
6504
0
            {
6505
0
                ggml_compute_forward_im2col_3d_f32(params, dst);
6506
0
            } break;
6507
0
        default:
6508
0
            {
6509
0
                GGML_ABORT("fatal error");
6510
0
            }
6511
0
    }
6512
0
}
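
Each output voxel row produced by the 3D im2col kernels above is a flat vector of IC*KD*KH*KW values indexed as iic*KD_KH_KW + ikd*KH_KW + ikh*KW + ikw. A standalone sketch of that index arithmetic (the helper name is mine, not from ops.cpp):

    // standalone sketch, not part of ops.cpp
    #include <cstdint>
    #include <cassert>

    static int64_t im2col_3d_col(int64_t iic, int64_t ikd, int64_t ikh, int64_t ikw,
                                 int64_t KD, int64_t KH, int64_t KW) {
        return iic*(KD*KH*KW) + ikd*(KH*KW) + ikh*KW + ikw;
    }

    int main() {
        // the last position of channel 0 is followed immediately by the first of channel 1
        const int64_t KD = 2, KH = 3, KW = 3;
        assert(im2col_3d_col(0, KD-1, KH-1, KW-1, KD, KH, KW) + 1 ==
               im2col_3d_col(1, 0, 0, 0, KD, KH, KW));
        return 0;
    }
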
6513
6514
static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
6515
0
                              void * a, void * b, float * c) {
6516
0
    const ggml_type_traits * traits = ggml_get_type_traits(type);
6517
0
    struct ggml_tensor src1 = {};
6518
0
    src1.type  = type;
6519
0
    src1.ne[0] = k;
6520
0
    src1.ne[1] = m;
6521
0
    src1.ne[2] = 1;
6522
0
    src1.ne[3] = 1;
6523
0
    src1.nb[0] = traits->type_size;
6524
0
    src1.nb[1] = k * traits->type_size;
6525
0
    src1.nb[2] = src1.nb[1];
6526
0
    src1.nb[3] = src1.nb[2];
6527
0
    src1.data  = a;
6528
6529
0
    struct ggml_tensor src0 = {};
6530
0
    src0.type  = type;
6531
0
    src0.ne[0] = k;
6532
0
    src0.ne[1] = n;
6533
0
    src0.ne[2] = 1;
6534
0
    src0.ne[3] = 1;
6535
0
    src0.nb[0] = traits->type_size;
6536
0
    src0.nb[1] = k * traits->type_size;
6537
0
    src0.nb[2] = src0.nb[1];
6538
0
    src0.nb[3] = src0.nb[2];
6539
0
    src0.data  = b;
6540
6541
0
    struct ggml_tensor dst = {};
6542
0
    dst.ne[0] = n;
6543
0
    dst.ne[1] = m;
6544
0
    dst.ne[2] = 1;
6545
0
    dst.ne[3] = 1;
6546
0
    dst.nb[0] = sizeof(float);
6547
0
    dst.nb[1] = n * sizeof(float);
6548
0
    dst.nb[2] = dst.nb[1];
6549
0
    dst.nb[3] = dst.nb[2];
6550
0
    dst.data  = c;
6551
0
    dst.src[0] = &src0;
6552
0
    dst.src[1] = &src1;
6553
6554
0
    ggml_compute_forward_mul_mat(params, &dst);
6555
0
}
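
Reading the tensor shapes set up above (src1 = a with ne = [k, m], src0 = b with ne = [k, n], dst = c with ne = [n, m]), the helper asks the existing mul_mat path to contract two k-contiguous operands. The following scalar reference expresses that contraction under this reading; it is a sketch for orientation only, not the actual mul_mat implementation.

    // standalone scalar sketch, not part of ops.cpp
    #include <cstdint>

    static void mul_mat_ref(int64_t m, int64_t n, int64_t k,
                            const float * a, const float * b, float * c) {
        for (int64_t i = 0; i < m; ++i) {
            for (int64_t j = 0; j < n; ++j) {
                float sum = 0.0f;
                for (int64_t l = 0; l < k; ++l) {
                    sum += a[i*k + l] * b[j*k + l]; // both operands keep k contiguous
                }
                c[i*n + j] = sum;
            }
        }
    }

    int main() {
        const float a[2*3] = {1, 2, 3,  4, 5, 6}; // m = 2, k = 3
        const float b[2*3] = {1, 0, 0,  0, 1, 0}; // n = 2, k = 3
        float c[2*2];
        mul_mat_ref(2, 2, 3, a, b, c);            // c = {1, 2,  4, 5}
        return 0;
    }
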
6556
6557
0
static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
6558
0
    return (coord + size) % size; // adding size keeps the result non-negative for coord >= -size
6559
0
}
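
As a quick illustration of the helper above (the assertions below are mine, not from the source): coordinates one step outside either edge wrap to the opposite side, and the trick only holds for coord >= -size because C's % keeps the sign of the dividend.

    // standalone sketch, not part of ops.cpp
    #include <cassert>
    #include <cstdint>

    static int64_t wrap_around(int64_t coord, int64_t size) { // same body as above
        return (coord + size) % size;
    }

    int main() {
        assert(wrap_around(-1, 8) == 7); // one step left of the edge wraps to the end
        assert(wrap_around( 8, 8) == 0); // one step past the end wraps to the start
        assert(wrap_around( 3, 8) == 3); // in-range coordinates are unchanged
        return 0;
    }
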
6560
6561
// ggml_compute_forward_conv_2d
6562
6563
6564
static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
6565
                                              const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
6566
                                              const ggml_tensor *         src,     // [W, H, C, N]
6567
                                              ggml_tensor *               dst,     // [OW, OH, OC, N]
6568
0
                                              ggml_type                   kernel_type) {
6569
6570
0
    GGML_ASSERT(ggml_is_contiguous(kernel));
6571
0
    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
6572
0
    GGML_ASSERT(kernel->type == kernel_type);
6573
6574
0
    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
6575
6576
0
    const int32_t stride_x   = dst->op_params[0];
6577
0
    const int32_t stride_y   = dst->op_params[1];
6578
0
    const int32_t pad_x      = dst->op_params[2];
6579
0
    const int32_t pad_y      = dst->op_params[3];
6580
0
    const int32_t dilation_x = dst->op_params[4];
6581
0
    const int32_t dilation_y = dst->op_params[5];
6582
6583
0
    const int64_t c_in  = src->ne[2];
6584
0
    const int64_t c_out = kernel->ne[3];
6585
0
    GGML_ASSERT(c_in == kernel->ne[2]);
6586
6587
0
    const int64_t src_w = src->ne[0];
6588
0
    const int64_t src_h = src->ne[1];
6589
0
    const int64_t knl_w = kernel->ne[0];
6590
0
    const int64_t knl_h = kernel->ne[1];
6591
0
    const int64_t dst_w = dst->ne[0];
6592
0
    const int64_t dst_h = dst->ne[1];
6593
6594
0
    const float * src_data = (float *) src->data;
6595
0
    void  * knl_data       = kernel->data;
6596
0
    float * dst_data       = (float *) dst->data;
6597
6598
0
    const int64_t knl_n           = knl_w * knl_h * c_in;
6599
0
    const int64_t patch_total     = dst->ne[3] * dst_w * dst_h;
6600
6601
0
    const int64_t space_per_patch   = knl_n * traits->type_size + c_out * sizeof(float);
6602
0
    const int64_t batch_size        = params->wsize / space_per_patch;
6603
0
    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6604
0
    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
6605
6606
0
    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6607
6608
0
    void * tmp = params->wdata;
6609
6610
0
    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
6611
6612
0
        const int64_t patch_start_batch = batch_i * patches_per_batch;
6613
0
        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch,
6614
0
                                              patch_total);
6615
0
        const int64_t patch_n           = patch_end_batch - patch_start_batch;
6616
6617
0
        const int64_t patch_per_thread  = (patch_n + params->nth - 1) / params->nth;
6618
0
        const int64_t patch_start       = patch_start_batch + params->ith * patch_per_thread;
6619
0
        const int64_t patch_end         = std::min(patch_start + patch_per_thread, patch_end_batch);
6620
6621
        // im2col for a patch
6622
0
        for (int64_t p = patch_start; p < patch_end; ++p) {
6623
0
            const int64_t  batch_n     =  p / (dst_w * dst_h);
6624
0
            const int64_t  src_x       = (p / dst_w) % dst_h;
6625
0
            const int64_t  src_y       =  p % dst_w;
6626
6627
0
            const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
6628
0
            char *        dst_row  = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
6629
6630
0
            for (int64_t ic = 0; ic < c_in; ++ic) {
6631
0
                for (int64_t ky = 0; ky < knl_h; ++ky) {
6632
0
                    for (int64_t kx = 0; kx < knl_w; ++kx) {
6633
0
                        const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
6634
0
                        const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
6635
6636
0
                        int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
6637
6638
0
                        float src_val;
6639
0
                        if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6640
0
                            src_val = 0.0f;
6641
0
                        } else {
6642
0
                            const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
6643
0
                            src_val               = *src_ptr;
6644
0
                        }
6645
6646
0
                        char * element_ptr = dst_row + dst_idx * traits->type_size;
6647
0
                        if (kernel_type == GGML_TYPE_F32) {
6648
0
                            *(float *) element_ptr = src_val;
6649
0
                        } else if (kernel_type == GGML_TYPE_F16) {
6650
0
                            *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
6651
0
                        }
6652
0
                    }
6653
0
                }
6654
0
            }
6655
0
        }   // patches handled by this thread
6656
6657
0
        ggml_barrier(params->threadpool);
6658
6659
0
        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
6660
6661
0
        GGML_ASSERT((char *) (gemm_output + patch_n * c_out) <= (char *) tmp + params->wsize); // wsize is in bytes
6662
6663
        // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
6664
0
        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
6665
6666
0
        ggml_barrier(params->threadpool);
6667
6668
6669
        // permute back [OC, N, OH, OW] to [N, OC, OH, OW]
6670
0
        const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
6671
0
        const int64_t permute_start = params->ith * permute_per_thread;
6672
0
        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
6673
6674
0
        for (int64_t i = permute_start; i < permute_end; ++i) {
6675
0
            const int64_t p       = patch_start_batch + i;
6676
0
            const int64_t batch_n = p / (dst_w * dst_h);
6677
0
            const int64_t dst_y   = (p / dst_w) % dst_h;
6678
0
            const int64_t dst_x   = p % dst_w;
6679
6680
0
            for (int64_t oc = 0; oc < c_out; ++oc) {
6681
0
                const float value = gemm_output[i * c_out + oc];
6682
0
                float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
6683
0
                *dst_ptr = value;
6684
0
            }
6685
0
        }
6686
0
    }
6687
0
}
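
The batching above sizes each patch at knl_n*type_size bytes of im2col data plus c_out floats of GEMM output, and rounds the per-batch patch count down to a multiple of 8 when the workspace allows. A standalone sketch with made-up sizes, just to show the arithmetic:

    // standalone sketch, not part of ops.cpp; the concrete sizes are illustrative only
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t knl_n     = 3*3*64;   // KW*KH*IC columns per patch
        const int64_t c_out     = 128;      // output channels
        const int64_t type_size = 2;        // e.g. an F16 kernel
        const int64_t wsize     = 1 << 20;  // assumed 1 MiB workspace

        const int64_t space_per_patch   = knl_n*type_size + c_out*(int64_t) sizeof(float);
        const int64_t batch_size        = wsize / space_per_patch;
        const int64_t patches_per_batch = batch_size > 8 ? (batch_size/8)*8 : batch_size;

        printf("%lld patches per workspace batch\n", (long long) patches_per_batch);
        return 0;
    }
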
6688
6689
void ggml_compute_forward_conv_2d(
6690
        const ggml_compute_params * params,
6691
0
        ggml_tensor * dst) {
6692
6693
0
    const ggml_tensor * src0 = dst->src[0];
6694
0
    const ggml_tensor * src1 = dst->src[1];
6695
6696
0
    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
6697
0
}
6698
6699
// ggml_compute_forward_conv_3d
6700
6701
static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
6702
                                              const ggml_tensor *         kernel,
6703
                                              const ggml_tensor *         src,
6704
                                              ggml_tensor *               dst,
6705
0
                                              ggml_type                   kernel_type) {
6706
6707
0
    GGML_ASSERT(ggml_is_contiguous(kernel));
6708
0
    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
6709
0
    GGML_ASSERT(kernel->type == kernel_type);
6710
6711
0
    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
6712
6713
0
    const int32_t s0 = dst->op_params[0];
6714
0
    const int32_t s1 = dst->op_params[1];
6715
0
    const int32_t s2 = dst->op_params[2];
6716
0
    const int32_t p0 = dst->op_params[3];
6717
0
    const int32_t p1 = dst->op_params[4];
6718
0
    const int32_t p2 = dst->op_params[5];
6719
0
    const int32_t d0 = dst->op_params[6];
6720
0
    const int32_t d1 = dst->op_params[7];
6721
0
    const int32_t d2 = dst->op_params[8];
6722
0
    const int32_t c  = dst->op_params[9];
6723
0
    const int32_t n  = dst->op_params[10];
6724
0
    const int32_t oc = dst->op_params[11];
6725
6726
0
    const int64_t src_w = src->ne[0];
6727
0
    const int64_t src_h = src->ne[1];
6728
0
    const int64_t src_d = src->ne[2];
6729
0
    const int64_t knl_w = kernel->ne[0];
6730
0
    const int64_t knl_h = kernel->ne[1];
6731
0
    const int64_t knl_d = kernel->ne[2];
6732
0
    const int64_t dst_w = dst->ne[0];
6733
0
    const int64_t dst_h = dst->ne[1];
6734
0
    const int64_t dst_d = dst->ne[2];
6735
6736
0
    const float * src_data = (float *) src->data;
6737
0
    void  * knl_data       = kernel->data;
6738
0
    float * dst_data       = (float *) dst->data;
6739
6740
0
    const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
6741
0
    const int64_t knl_n_total       = knl_n_per_channel * c;
6742
0
    const int64_t patch_total       = n * dst_w * dst_h * dst_d;
6743
6744
0
    const int64_t space_per_patch   = knl_n_total * traits->type_size + oc * sizeof(float);
6745
0
    const int64_t batch_size        = params->wsize / space_per_patch;
6746
0
    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6747
0
    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
6748
6749
0
    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6750
6751
0
    void * tmp = params->wdata;
6752
6753
0
    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
6754
0
        const int64_t patch_start_batch = batch_i * patches_per_batch;
6755
0
        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch, patch_total);
6756
0
        const int64_t patch_n_in_batch  = patch_end_batch - patch_start_batch;
6757
6758
0
        const int64_t patch_per_thread  = (patch_n_in_batch + params->nth - 1) / params->nth;
6759
0
        const int64_t patch_start       = patch_start_batch + params->ith * patch_per_thread;
6760
0
        const int64_t patch_end         = std::min(patch_start + patch_per_thread, patch_end_batch);
6761
6762
0
        for (int64_t p = patch_start; p < patch_end; ++p) {
6763
0
            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
6764
0
            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
6765
0
            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
6766
0
            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
6767
0
            const int64_t dst_y      = p_in_depth / dst_w;
6768
0
            const int64_t dst_x      = p_in_depth % dst_w;
6769
6770
0
            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
6771
6772
0
            for (int64_t ic = 0; ic < c; ++ic) {
6773
0
                for (int64_t kz = 0; kz < knl_d; ++kz) {
6774
0
                    for (int64_t ky = 0; ky < knl_h; ++ky) {
6775
0
                        for (int64_t kx = 0; kx < knl_w; ++kx) {
6776
0
                            const int64_t sz = dst_z * s2 + kz * d2 - p2;
6777
0
                            const int64_t sy = dst_y * s1 + ky * d1 - p1;
6778
0
                            const int64_t sx = dst_x * s0 + kx * d0 - p0;
6779
6780
0
                            int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
6781
6782
0
                            float src_val;
6783
0
                            if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6784
0
                                src_val = 0.0f;
6785
0
                            } else {
6786
0
                                const int64_t cn_idx = batch_idx * c + ic;
6787
0
                                const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
6788
0
                                src_val = *src_ptr;
6789
0
                            }
6790
6791
0
                            char * element_ptr = dst_row + dst_idx * traits->type_size;
6792
0
                            if (kernel_type == GGML_TYPE_F32) {
6793
0
                                *(float *)element_ptr = src_val;
6794
0
                            } else if (kernel_type == GGML_TYPE_F16) {
6795
0
                                *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
6796
0
                            }
6797
0
                        }
6798
0
                    }
6799
0
                }
6800
0
            }
6801
0
        }
6802
6803
0
        ggml_barrier(params->threadpool);
6804
6805
0
        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
6806
0
        ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
6807
6808
0
        ggml_barrier(params->threadpool);
6809
6810
0
        const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
6811
0
        const int64_t permute_start = params->ith * permute_per_thread;
6812
0
        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
6813
6814
0
        for (int64_t i = permute_start; i < permute_end; ++i) {
6815
0
            const int64_t p = patch_start_batch + i;
6816
0
            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
6817
0
            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
6818
0
            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
6819
0
            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
6820
0
            const int64_t dst_y      = p_in_depth / dst_w;
6821
0
            const int64_t dst_x      = p_in_depth % dst_w;
6822
6823
0
            for (int64_t ioc = 0; ioc < oc; ++ioc) {
6824
0
                const float value = gemm_output[i * oc + ioc];
6825
0
                const int64_t ocn_idx = batch_idx * oc + ioc;
6826
0
                float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
6827
0
                *dst_ptr = value;
6828
0
            }
6829
0
        }
6830
0
    }
6831
0
}
6832
6833
void ggml_compute_forward_conv_3d(
6834
        const ggml_compute_params * params,
6835
0
        ggml_tensor * dst) {
6836
0
    const ggml_tensor * src0 = dst->src[0];
6837
0
    const ggml_tensor * src1 = dst->src[1];
6838
0
    ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
6839
0
}
6840
6841
// ggml_compute_forward_conv_transpose_2d
6842
6843
void ggml_compute_forward_conv_transpose_2d(
6844
        const ggml_compute_params * params,
6845
0
              ggml_tensor * dst) {
6846
6847
0
    const ggml_tensor * src0 = dst->src[0];
6848
0
    const ggml_tensor * src1 = dst->src[1];
6849
6850
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
6851
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6852
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
6853
6854
0
    GGML_TENSOR_BINARY_OP_LOCALS
6855
6856
0
    const int ith = params->ith;
6857
0
    const int nth = params->nth;
6858
6859
0
    const int nk = ne00*ne01*ne02*ne03;
6860
6861
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6862
0
    GGML_ASSERT(nb10 == sizeof(float));
6863
6864
0
    if (ith == 0) {
6865
0
        memset(params->wdata, 0, params->wsize);
6866
6867
        // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
6868
0
        {
6869
0
            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
6870
6871
0
            for (int64_t i03 = 0; i03 < ne03; i03++) {
6872
0
                for (int64_t i02 = 0; i02 < ne02; i02++) {
6873
0
                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
6874
0
                    ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
6875
0
                    for (int64_t i01 = 0; i01 < ne01; i01++) {
6876
0
                        for (int64_t i00 = 0; i00 < ne00; i00++) {
6877
0
                            dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
6878
0
                        }
6879
0
                    }
6880
0
                }
6881
0
            }
6882
0
        }
6883
6884
        // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
6885
0
        {
6886
0
            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
6887
0
            for (int i12 = 0; i12 < ne12; i12++) {
6888
0
                for (int i11 = 0; i11 < ne11; i11++) {
6889
0
                    const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
6890
0
                    ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
6891
0
                    for (int i10 = 0; i10 < ne10; i10++) {
6892
0
                        dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
6893
0
                    }
6894
0
                }
6895
0
            }
6896
0
        }
6897
6898
0
        memset(dst->data, 0, ggml_nbytes(dst));
6899
0
    }
6900
0
    ggml_barrier(params->threadpool);
6901
6902
0
    const int32_t stride = ggml_get_op_params_i32(dst, 0);
6903
6904
    // total patches in dst
6905
0
    const int np = ne2;
6906
6907
    // patches per thread
6908
0
    const int dp = (np + nth - 1)/nth;
6909
6910
    // patch range for this thread
6911
0
    const int ip0 = dp*ith;
6912
0
    const int ip1 = MIN(ip0 + dp, np);
6913
6914
0
    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
6915
0
    ggml_fp16_t * const wdata_src = wdata + nk;
6916
6917
0
    for (int i2 = ip0; i2 < ip1; i2++) { // Cout
6918
0
        float * dst_data = (float *)((char *) dst->data + i2*nb2);
6919
0
        ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
6920
0
        for (int i11 = 0; i11 < ne11; i11++) {
6921
0
            for (int i10 = 0; i10 < ne10; i10++) {
6922
0
                const int i1n = i11*ne10*ne12 + i10*ne12;
6923
0
                for (int i01 = 0; i01 < ne01; i01++) {
6924
0
                    for (int i00 = 0; i00 < ne00; i00++) {
6925
0
                        float v = 0;
6926
0
                        ggml_vec_dot_f16(ne03, &v, 0,
6927
0
                                wdata_src + i1n, 0,
6928
0
                                wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
6929
0
                        dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
6930
0
                    }
6931
0
                }
6932
0
            }
6933
0
        }
6934
0
    }
6935
0
}
6936
6937
// ggml_compute_forward_conv_2d_dw
6938
6939
struct ggml_conv_2d_dw_params {
6940
    int64_t channels;
6941
    int64_t batch;
6942
    int64_t src_w;
6943
    int64_t src_h;
6944
    int64_t dst_w;
6945
    int64_t dst_h;
6946
    int64_t knl_w;
6947
    int64_t knl_h;
6948
    int stride_x;
6949
    int stride_y;
6950
    int pad_x;
6951
    int pad_y;
6952
    int dilation_x;
6953
    int dilation_y;
6954
};
6955
6956
static void ggml_compute_forward_conv_2d_dw_cwhn(
6957
        const ggml_compute_params * params,
6958
        const ggml_tensor * src,
6959
        const ggml_tensor * kernel,
6960
        ggml_tensor * dst,
6961
0
        const ggml_conv_2d_dw_params & p) {
6962
6963
0
    const int64_t c = p.channels;
6964
0
    const float * knl_data = (const float *)kernel->data;
6965
6966
0
    const int64_t rows_total = p.dst_h * p.batch;
6967
0
    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
6968
0
    const int64_t row_start = params->ith * rows_per_thread;
6969
0
    const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
6970
6971
0
#ifdef GGML_SIMD
6972
    #if defined(__ARM_FEATURE_SVE)
6973
        const int64_t pkg_size = svcntw();
6974
    #else
6975
0
        const int64_t pkg_size = GGML_F32_EPR;
6976
0
    #endif
6977
0
    const int64_t pkg_count = c / pkg_size;
6978
0
    const int64_t c_pkg_end = pkg_count * pkg_size;
6979
#else
6980
    const int64_t c_pkg_end = 0;
6981
#endif
6982
6983
0
    for (int64_t row = row_start; row < row_end; ++row) {
6984
0
        const int64_t dst_y = row % p.dst_h;
6985
0
        const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;
6986
0
        for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
6987
0
            float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;
6988
0
            const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;
6989
0
            const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;
6990
6991
0
#ifdef GGML_SIMD
6992
            // Vectorized loop
6993
0
            for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
6994
0
                GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
6995
0
                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
6996
0
                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
6997
0
                    if (src_y < 0 || src_y >= p.src_h) {
6998
0
                        continue;
6999
0
                    }
7000
0
                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
7001
0
                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
7002
0
                        if (src_x < 0 || src_x >= p.src_w) {
7003
0
                            continue;
7004
0
                        }
7005
0
                        GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
7006
0
                        GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);
7007
0
                        sum = GGML_F32_VEC_FMA(sum, k, s);
7008
0
                    }
7009
0
                }
7010
0
                GGML_F32_VEC_STORE(dst_data + c_i, sum);
7011
0
            }
7012
0
#endif
7013
            // Scalar loop
7014
0
            for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
7015
0
                float sum = 0.0f;
7016
0
                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
7017
0
                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
7018
0
                    if (src_y < 0 || src_y >= p.src_h) {
7019
0
                        continue;
7020
0
                    }
7021
0
                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
7022
0
                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
7023
0
                        if (src_x < 0 || src_x >= p.src_w) {
7024
0
                            continue;
7025
0
                        }
7026
0
                        sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
7027
0
                             * src_data[(src_y * p.src_w + src_x) * c + c_i];
7028
0
                    }
7029
0
                }
7030
0
                dst_data[c_i] = sum;
7031
0
            }
7032
0
        }
7033
0
    }
7034
0
}
7035
7036
static void ggml_compute_forward_conv_2d_dw_whcn(
7037
        const ggml_compute_params * params,
7038
        const ggml_tensor * src,
7039
        const ggml_tensor * kernel,
7040
        ggml_tensor * dst,
7041
0
        const ggml_conv_2d_dw_params & p) {
7042
7043
0
    const int64_t n = p.channels * p.batch;
7044
0
    const int64_t per_thread = (n + params->nth - 1) / params->nth;
7045
0
    const int64_t start = params->ith * per_thread;
7046
0
    const int64_t end = MIN(start + per_thread, n);
7047
7048
0
    for (int64_t i = start; i < end; ++i) {
7049
0
        const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h;
7050
0
        const float * src_data = (const float *)src->data + i * p.src_w * p.src_h;
7051
0
        float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h;
7052
7053
0
        for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) {
7054
0
            for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
7055
7056
0
                float sum = 0.0f;
7057
0
                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
7058
0
                    const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
7059
0
                    if (src_y < 0 || src_y >= p.src_h) {
7060
0
                        continue;
7061
0
                    }
7062
0
                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
7063
0
                        const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
7064
0
                        if (src_x < 0 || src_x >= p.src_w) {
7065
0
                            continue;
7066
0
                        }
7067
0
                        sum += knl_data[knl_y * p.knl_w + knl_x]
7068
0
                             * src_data[src_y * p.src_w + src_x];
7069
0
                    }
7070
0
                }
7071
0
                dst_data[dst_y * p.dst_w + dst_x] = sum;
7072
0
            }
7073
0
        }
7074
0
    }
7075
0
}
7076
7077
void ggml_compute_forward_conv_2d_dw(
7078
        const ggml_compute_params * params,
7079
0
        ggml_tensor * dst) {
7080
7081
0
    const ggml_tensor * kernel = dst->src[0];
7082
0
    const ggml_tensor * src = dst->src[1];
7083
0
    ggml_conv_2d_dw_params p;
7084
0
    p.channels = src->ne[2];
7085
0
    p.batch = src->ne[3];
7086
0
    p.src_w = src->ne[0];
7087
0
    p.src_h = src->ne[1];
7088
0
    p.dst_w = dst->ne[0];
7089
0
    p.dst_h = dst->ne[1];
7090
0
    p.knl_w = kernel->ne[0];
7091
0
    p.knl_h = kernel->ne[1];
7092
0
    p.stride_x = dst->op_params[0];
7093
0
    p.stride_y = dst->op_params[1];
7094
0
    p.pad_x = dst->op_params[2];
7095
0
    p.pad_y = dst->op_params[3];
7096
0
    p.dilation_x = dst->op_params[4];
7097
0
    p.dilation_y = dst->op_params[5];
7098
7099
0
    GGML_ASSERT(kernel->ne[3] == p.channels);
7100
0
    GGML_ASSERT(dst->ne[3] == p.batch);
7101
7102
0
    if (ggml_is_contiguous(src)) {
7103
0
        ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p);
7104
0
    } else if (ggml_is_contiguous_channels(src)) {
7105
        // kernel should also have channels most contiguous in memory
7106
0
        GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]);
7107
0
        ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p);
7108
0
    } else {
7109
0
        GGML_ABORT("non-contiguous memory layout not supported");
7110
0
    }
7111
0
}
7112
7113
// ggml_compute_forward_pool_1d_sk_p0
7114
7115
static void ggml_compute_forward_pool_1d_sk_p0(
7116
        const ggml_compute_params * params,
7117
        const ggml_op_pool op,
7118
        const int k,
7119
0
        ggml_tensor * dst) {
7120
7121
0
    const ggml_tensor * src = dst->src[0];
7122
7123
0
    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
7124
7125
0
    if (params->ith != 0) {
7126
0
        return;
7127
0
    }
7128
7129
0
    const char * cdata = (const char *)src->data;
7130
0
    const char * const data_end = cdata + ggml_nbytes(src);
7131
0
    float * drow = (float *)dst->data;
7132
7133
0
    const int64_t rs = dst->ne[0];
7134
7135
0
    while (cdata < data_end) {
7136
0
        const void * srow = (const void *)cdata;
7137
0
        int j = 0;
7138
0
        for (int64_t i = 0; i < rs; ++i) {
7139
0
            switch (op) {
7140
0
                case GGML_OP_POOL_AVG:   drow[i] = 0;        break;
7141
0
                case GGML_OP_POOL_MAX:   drow[i] = -FLT_MAX; break;
7142
0
                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7143
0
            }
7144
0
            for (int ki = 0; ki < k; ++ki) {
7145
0
                const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7146
0
                switch (op) {
7147
0
                    case GGML_OP_POOL_AVG:                         drow[i] += srow_j; break;
7148
0
                    case GGML_OP_POOL_MAX:   if (srow_j > drow[i]) drow[i]  = srow_j; break;
7149
0
                    case GGML_OP_POOL_COUNT:                       GGML_ABORT("fatal error");
7150
0
                }
7151
0
                ++j;
7152
0
            }
7153
0
            switch (op) {
7154
0
                case GGML_OP_POOL_AVG:         drow[i] /= k; break;
7155
0
                case GGML_OP_POOL_MAX:                       break;
7156
0
                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7157
0
            }
7158
0
        }
7159
7160
0
        cdata += src->nb[1];
7161
0
        drow  += rs;
7162
0
    }
7163
0
}
7164
7165
// ggml_compute_forward_pool_1d
7166
7167
void ggml_compute_forward_pool_1d(
7168
        const ggml_compute_params * params,
7169
0
              ggml_tensor * dst) {
7170
7171
0
    const int32_t * opts = (const int32_t *)dst->op_params;
7172
0
    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7173
0
    const int k0 = opts[1];
7174
0
    const int s0 = opts[2];
7175
0
    const int p0 = opts[3];
7176
0
    GGML_ASSERT(p0 == 0); // padding not supported
7177
0
    GGML_ASSERT(k0 == s0); // only s = k supported
7178
7179
0
    ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
7180
0
}
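
With the p0 == 0 and k0 == s0 restrictions asserted above, 1D pooling simply reduces disjoint windows of k elements. A small standalone example (not from the source):

    // standalone sketch, not part of ops.cpp
    #include <cfloat>
    #include <cstdio>

    int main() {
        const float src[4] = {1.0f, 2.0f, 3.0f, 4.0f};
        const int k = 2; // kernel size == stride, no padding
        for (int i = 0; i < 4/k; ++i) {
            float avg = 0.0f, mx = -FLT_MAX;
            for (int j = 0; j < k; ++j) {
                const float v = src[i*k + j];
                avg += v;
                if (v > mx) mx = v;
            }
            avg /= k;
            printf("window %d: avg=%.1f max=%.1f\n", i, avg, mx); // 1.5/2.0, then 3.5/4.0
        }
        return 0;
    }
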
7181
7182
// ggml_compute_forward_pool_2d
7183
7184
void ggml_compute_forward_pool_2d(
7185
        const ggml_compute_params * params,
7186
0
        ggml_tensor * dst) {
7187
7188
0
    const ggml_tensor * src = dst->src[0];
7189
7190
0
    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
7191
7192
0
    if (params->ith != 0) {
7193
0
        return;
7194
0
    }
7195
7196
0
    const int32_t * opts = (const int32_t *)dst->op_params;
7197
0
    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7198
0
    const int k0 = opts[1];
7199
0
    const int k1 = opts[2];
7200
0
    const int s0 = opts[3];
7201
0
    const int s1 = opts[4];
7202
0
    const int p0 = opts[5];
7203
0
    const int p1 = opts[6];
7204
0
    const char * cdata = (const char*)src->data;
7205
0
    const char * const data_end = cdata + ggml_nbytes(src);
7206
7207
0
    const int64_t px = dst->ne[0];
7208
0
    const int64_t py = dst->ne[1];
7209
0
    const int64_t pa = px * py;
7210
7211
0
    float * dplane = (float *)dst->data;
7212
7213
0
    const int ka = k0 * k1;
7214
0
    const int offset0 = -p0;
7215
0
    const int offset1 = -p1;
7216
7217
0
    while (cdata < data_end) {
7218
0
        for (int oy = 0; oy < py; ++oy) {
7219
0
            float * const drow = dplane + oy * px;
7220
0
            for (int ox = 0; ox < px; ++ox) {
7221
0
                float * const out =  drow + ox;
7222
0
                switch (op) {
7223
0
                    case GGML_OP_POOL_AVG:     *out = 0;        break;
7224
0
                    case GGML_OP_POOL_MAX:     *out = -FLT_MAX; break;
7225
0
                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7226
0
                }
7227
7228
0
                const int ix = offset0 + ox * s0;
7229
0
                const int iy = offset1 + oy * s1;
7230
7231
0
                for (int ky = 0; ky < k1; ++ky) {
7232
0
                    if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
7233
0
                    const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
7234
0
                    for (int kx = 0; kx < k0; ++kx) {
7235
0
                        int j = ix + kx;
7236
0
                        if (j < 0 || j >= src->ne[0]) continue;
7237
0
                        const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7238
0
                        switch (op) {
7239
0
                            case GGML_OP_POOL_AVG:                     *out += srow_j; break;
7240
0
                            case GGML_OP_POOL_MAX: if (srow_j > *out)  *out  = srow_j; break;
7241
0
                            case GGML_OP_POOL_COUNT:               GGML_ABORT("fatal error");
7242
0
                        }
7243
0
                    }
7244
0
                }
7245
0
                switch (op) {
7246
0
                    case GGML_OP_POOL_AVG:           *out /= ka; break;
7247
0
                    case GGML_OP_POOL_MAX:                       break;
7248
0
                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7249
0
                }
7250
0
            }
7251
0
        }
7252
7253
0
        cdata  += src->nb[2];
7254
0
        dplane += pa;
7255
0
    }
7256
0
}
7257
7258
// ggml_compute_forward_pool_2d_back
7259
7260
void ggml_compute_forward_pool_2d_back(
7261
        const ggml_compute_params * params,
7262
0
        ggml_tensor * dst) {
7263
7264
0
    const ggml_tensor * src  = dst->src[0];
7265
0
    const ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
7266
7267
0
    assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
7268
7269
0
    if (params->ith != 0) {
7270
0
        return;
7271
0
    }
7272
7273
0
    const int32_t * opts = (const int32_t *)dst->op_params;
7274
0
    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7275
0
    const int k0 = opts[1];
7276
0
    const int k1 = opts[2];
7277
0
    const int s0 = opts[3];
7278
0
    const int s1 = opts[4];
7279
0
    const int p0 = opts[5];
7280
0
    const int p1 = opts[6];
7281
7282
0
    char       * cdata  = (char       *) dst->data;
7283
0
    const char * cdataf = (const char *) dstf->data;
7284
0
    const char * const data_end = cdata + ggml_nbytes(dst);
7285
7286
0
    GGML_ASSERT(params->ith == 0);
7287
0
    memset(cdata, 0, ggml_nbytes(dst));
7288
7289
0
    const int64_t px = src->ne[0];
7290
0
    const int64_t py = src->ne[1];
7291
0
    const int64_t pa = px * py;
7292
7293
0
    const float * splane = (const float *) src->data;
7294
7295
0
    const int ka = k0 * k1;
7296
0
    const int offset0 = -p0;
7297
0
    const int offset1 = -p1;
7298
7299
0
    while (cdata < data_end) {
7300
0
        for (int oy = 0; oy < py; ++oy) {
7301
0
            const float * const srow = splane + oy * px;
7302
0
            for (int ox = 0; ox < px; ++ox) {
7303
0
                const float grad0 = srow[ox];
7304
7305
0
                const int ix = offset0 + ox * s0;
7306
0
                const int iy = offset1 + oy * s1;
7307
7308
0
                if (op == GGML_OP_POOL_MAX) {
7309
0
                    float maxval = -FLT_MAX;
7310
0
                    int kxmax = -1;
7311
0
                    int kymax = -1;
7312
7313
0
                    for (int ky = 0; ky < k1; ++ky) {
7314
0
                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
7315
0
                            continue;
7316
0
                        }
7317
0
                        const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
7318
0
                        for (int kx = 0; kx < k0; ++kx) {
7319
0
                            int j = ix + kx;
7320
0
                            if (j < 0 || j >= dst->ne[0]) {
7321
0
                                continue;
7322
0
                            }
7323
7324
0
                            const float val = dst->type == GGML_TYPE_F32 ?
7325
0
                                ((const float *) drowf)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
7326
0
                            if (val <= maxval) {
7327
0
                                continue;
7328
0
                            }
7329
7330
0
                            maxval = val;
7331
0
                            kxmax = kx;
7332
0
                            kymax = ky;
7333
0
                        }
7334
0
                    }
7335
7336
0
                    if (kxmax == -1 || kymax == -1) {
7337
0
                        continue;
7338
0
                    }
7339
7340
0
                    void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
7341
0
                    const int j = ix + kxmax;
7342
0
                    if (dst->type == GGML_TYPE_F32) {
7343
0
                        ((float *) drow)[j] += grad0;
7344
0
                    } else {
7345
0
                        ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad0 + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
7346
0
                    }
7347
0
                } else if (op == GGML_OP_POOL_AVG) {
7348
0
                    const float grad = grad0 / ka;
7349
7350
0
                    for (int ky = 0; ky < k1; ++ky) {
7351
0
                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
7352
0
                            continue;
7353
0
                        }
7354
0
                        void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
7355
0
                        for (int kx = 0; kx < k0; ++kx) {
7356
0
                            int j = ix + kx;
7357
0
                            if (j < 0 || j >= dst->ne[0]) {
7358
0
                                continue;
7359
0
                            }
7360
7361
0
                            if (dst->type == GGML_TYPE_F32) {
7362
0
                                ((float *) drow)[j] += grad;
7363
0
                            } else {
7364
0
                                ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]) + grad);
7365
0
                            }
7366
0
                        }
7367
0
                    }
7368
0
                } else {
7369
0
                    GGML_ASSERT(false);
7370
0
                }
7371
0
            }
7372
0
        }
7373
7374
0
        cdata  += dst->nb[2];
7375
0
        cdataf += dst->nb[2];
7376
0
        splane += pa;
7377
0
    }
7378
0
}
7379
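The backward pass above applies the standard pooling gradients per output position (ox, oy): for max pooling the whole incoming gradient grad0 is routed to the one window element that attained the forward maximum (recomputed from the forward output dstf, ties going to the first element scanned), and for average pooling it is spread uniformly as grad0/ka over the in-bounds window positions. In formula form, with window W(ox, oy) and kernel area k0*k1:

$$
\frac{\partial L}{\partial x_{ij}} \;=\;
\begin{cases}
\displaystyle\sum_{(ox,oy):\,(i,j)\in W(ox,oy)} \mathbb{1}\!\left[(i,j)=\arg\max_{W(ox,oy)} x\right]\frac{\partial L}{\partial y_{ox,oy}} & \text{max pooling}\\[8pt]
\displaystyle\sum_{(ox,oy):\,(i,j)\in W(ox,oy)} \frac{1}{k_0 k_1}\,\frac{\partial L}{\partial y_{ox,oy}} & \text{average pooling}
\end{cases}
$$

Out-of-bounds (padding) positions are simply skipped by the continue checks above.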
7380
// ggml_compute_forward_upscale
7381
7382
static void ggml_compute_forward_upscale_f32(
7383
    const ggml_compute_params * params,
7384
0
    ggml_tensor * dst) {
7385
7386
0
    const ggml_tensor * src0 = dst->src[0];
7387
7388
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
7389
7390
0
    const int ith = params->ith;
7391
0
    const int nth = params->nth;
7392
7393
0
    GGML_TENSOR_UNARY_OP_LOCALS
7394
7395
0
    float sf0 = (float)ne0/src0->ne[0];
7396
0
    float sf1 = (float)ne1/src0->ne[1];
7397
0
    float sf2 = (float)ne2/src0->ne[2];
7398
0
    float sf3 = (float)ne3/src0->ne[3];
7399
0
    float pixel_offset = 0.5f;
7400
7401
0
    const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
7402
0
    const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
7403
7404
0
    if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7405
0
        pixel_offset = 0.0f;
7406
0
        sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
7407
0
        sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
7408
0
    }
7409
7410
0
    if (mode == GGML_SCALE_MODE_NEAREST) {
7411
0
        for (int64_t i3 = 0; i3 < ne3; i3++) {
7412
0
            const int64_t i03 = i3 / sf3;
7413
0
            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7414
0
                const int64_t i02 = i2 / sf2;
7415
0
                for (int64_t i1 = 0; i1 < ne1; i1++) {
7416
0
                    const int64_t i01 = i1 / sf1;
7417
0
                    for (int64_t i0 = 0; i0 < ne0; i0++) {
7418
0
                        const int64_t i00 = i0 / sf0;
7419
7420
0
                        const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7421
0
                              float * y = (float *)((char *)  dst->data +  i0*nb0  +  i1*nb1  +  i2*nb2  +  i3*nb3);
7422
7423
0
                        *y = *x;
7424
0
                    }
7425
0
                }
7426
0
            }
7427
0
        }
7428
0
    } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
7429
        // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
7430
        // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
7431
0
        auto triangle_filter = [](float x) -> float {
7432
0
            return std::max(1.0f - fabsf(x), 0.0f);
7433
0
        };
7434
7435
        // support and invscale, minimum 1 pixel for bilinear
7436
0
        const float support1  = std::max(1.0f, 1.0f / sf1);
7437
0
        const float invscale1 = 1.0f / support1;
7438
0
        const float support0  = std::max(1.0f, 1.0f / sf0);
7439
0
        const float invscale0 = 1.0f / support0;
7440
7441
0
        for (int64_t i3 = 0; i3 < ne3; i3++) {
7442
0
            const int64_t i03 = i3 / sf3;
7443
0
            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7444
0
                const int64_t i02 = i2 / sf2;
7445
0
                for (int64_t i1 = 0; i1 < ne1; i1++) {
7446
0
                    const float y = ((float) i1 + pixel_offset) / sf1;
7447
0
                    for (int64_t i0 = 0; i0 < ne0; i0++) {
7448
0
                        const float x = ((float) i0 + pixel_offset) / sf0;
7449
7450
                        // the range of source pixels that contribute
7451
0
                        const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
7452
0
                        const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
7453
0
                        const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
7454
0
                        const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
7455
7456
                        // bilinear filter with antialiasing
7457
0
                        float val = 0.0f;
7458
0
                        float total_weight = 0.0f;
7459
7460
0
                        for (int64_t sy = y_min; sy < y_max; sy++) {
7461
0
                            const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
7462
7463
0
                            for (int64_t sx = x_min; sx < x_max; sx++) {
7464
0
                                const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
7465
0
                                const float weight = weight_x * weight_y;
7466
7467
0
                                if (weight <= 0.0f) {
7468
0
                                    continue;
7469
0
                                }
7470
7471
0
                                const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
7472
0
                                val += pixel * weight;
7473
0
                                total_weight += weight;
7474
0
                            }
7475
0
                        }
7476
7477
0
                        if (total_weight > 0.0f) {
7478
0
                            val /= total_weight;
7479
0
                        }
7480
7481
0
                        float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7482
0
                        *dst_ptr = val;
7483
0
                    }
7484
0
                }
7485
0
            }
7486
0
        }
7487
0
    } else if (mode == GGML_SCALE_MODE_BILINEAR) {
7488
0
        for (int64_t i3 = 0; i3 < ne3; i3++) {
7489
0
            const int64_t i03 = i3 / sf3;
7490
0
            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7491
0
                const int64_t i02 = i2 / sf2;
7492
0
                for (int64_t i1 = 0; i1 < ne1; i1++) {
7493
0
                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
7494
0
                    int64_t y0 = (int64_t)floorf(y);
7495
0
                    int64_t y1 = y0 + 1;
7496
7497
0
                    y0 = std::max(int64_t(0), std::min(y0, ne01 - 1));
7498
0
                    y1 = std::max(int64_t(0), std::min(y1, ne01 - 1));
7499
7500
0
                    float dy = y - (float)y0;
7501
0
                    dy = std::max(0.0f, std::min(dy, 1.0f));
7502
7503
0
                    for (int64_t i0 = 0; i0 < ne0; i0++) {
7504
0
                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
7505
0
                        int64_t x0 = (int64_t)floorf(x);
7506
0
                        int64_t x1 = x0 + 1;
7507
7508
0
                        x0 = std::max(int64_t(0), std::min(x0, ne00 - 1));
7509
0
                        x1 = std::max(int64_t(0), std::min(x1, ne00 - 1));
7510
7511
0
                        float dx = x - (float)x0;
7512
0
                        dx = std::max(0.0f, std::min(dx, 1.0f));
7513
7514
                        // fetch the four surrounding pixel values and interpolate
7515
0
                        const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
7516
0
                        const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
7517
0
                        const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
7518
0
                        const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
7519
7520
0
                        const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
7521
7522
0
                        float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7523
0
                        *y_dst = val;
7524
0
                    }
7525
0
                }
7526
0
            }
7527
0
        }
7528
0
    } else if (mode == GGML_SCALE_MODE_BICUBIC) {
7529
        // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
7530
0
        const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
7531
0
        auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
7532
0
        auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
7533
0
        auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
7534
0
            const float w0 = weight2(x + 1);
7535
0
            const float w1 = weight1(x + 0);
7536
0
            const float w2 = weight1(1 - x);
7537
0
            const float w3 = weight2(2 - x);
7538
0
            return p0*w0 + p1*w1 + p2*w2 + p3*w3;
7539
0
        };
7540
7541
0
        for (int64_t i3 = 0; i3 < ne3; i3++) {
7542
0
            const int64_t i03 = i3 / sf3;
7543
0
            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7544
0
                const int64_t i02 = i2 / sf2;
7545
0
                for (int64_t i1 = 0; i1 < ne1; i1++) {
7546
0
                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
7547
0
                    const int64_t y0 = (int64_t)floorf(y);
7548
0
                    const float dy = y - (float)y0;
7549
7550
0
                    for (int64_t i0 = 0; i0 < ne0; i0++) {
7551
0
                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
7552
0
                        const int64_t x0 = (int64_t)floorf(x);
7553
0
                        const float dx = x - (float)x0;
7554
7555
0
                        auto p = [=](int64_t x_off, int64_t y_off) -> float {
7556
0
                            int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
7557
0
                            int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
7558
0
                            return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7559
0
                        };
7560
7561
0
                        const float val = bicubic(
7562
0
                            bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
7563
0
                            bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
7564
0
                            bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
7565
0
                            bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
7566
7567
0
                        float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7568
0
                        *y_dst = val;
7569
0
                    }
7570
0
                }
7571
0
            }
7572
0
        }
7573
0
    } else {
7574
0
        GGML_ABORT("unsupported upscale mode");
7575
0
    }
7576
0
}
7577
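The weight1/weight2 lambdas in the bicubic branch are the two pieces of the cubic convolution kernel with a = -0.75 (the PyTorch default), and the nested bicubic(...) calls apply it separably: four horizontal interpolations with dx, then one vertical interpolation of the results with dy. The kernel is:

$$
W(x) \;=\;
\begin{cases}
(a+2)\lvert x\rvert^{3} - (a+3)\lvert x\rvert^{2} + 1 & \lvert x\rvert \le 1 \\
a\lvert x\rvert^{3} - 5a\lvert x\rvert^{2} + 8a\lvert x\rvert - 4a & 1 < \lvert x\rvert < 2 \\
0 & \text{otherwise}
\end{cases}
\qquad a = -0.75
$$

weight1 evaluates the first branch and weight2 the second, so the four taps at offsets -1, 0, 1, 2 receive weights W(1+dx), W(dx), W(1-dx), W(2-dx), i.e. exactly w0..w3 above.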
7578
void ggml_compute_forward_upscale(
7579
    const ggml_compute_params * params,
7580
0
    ggml_tensor * dst) {
7581
7582
0
    const ggml_tensor * src0 = dst->src[0];
7583
7584
0
    switch (src0->type) {
7585
0
        case GGML_TYPE_F32:
7586
0
            {
7587
0
                ggml_compute_forward_upscale_f32(params, dst);
7588
0
            } break;
7589
0
        default:
7590
0
            {
7591
0
                GGML_ABORT("fatal error");
7592
0
            }
7593
0
    }
7594
0
}
7595
7596
7597
// ggml_compute_forward_pad
7598
7599
template<bool circular_t>
7600
static void ggml_compute_forward_pad_f32(
7601
    const ggml_compute_params * params,
7602
0
          ggml_tensor * dst) {
7603
7604
0
    const ggml_tensor * src0 = dst->src[0];
7605
7606
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
7607
0
    GGML_ASSERT( dst->nb[0] == sizeof(float));
7608
7609
0
    const int ith = params->ith;
7610
0
    const int nth = params->nth;
7611
7612
0
    GGML_TENSOR_UNARY_OP_LOCALS
7613
7614
0
    float * dst_ptr = (float *) dst->data;
7615
0
    const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
7616
0
    const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
7617
0
    const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
7618
0
    const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
7619
0
    const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
7620
0
    const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
7621
0
    const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
7622
0
    const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
7623
7624
    // TODO: optimize
7625
7626
0
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
7627
0
        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
7628
0
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
7629
0
                for (int64_t i3 = 0; i3 < ne3; ++i3) {
7630
                    // circular means wrap around on a torus, so x and y loop around
7631
0
                    if constexpr (circular_t) {
7632
0
                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7633
0
                        const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
7634
0
                        const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
7635
0
                        const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
7636
0
                        const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
7637
7638
0
                        const int64_t src_idx =
7639
0
                            src_i3*nb03 +
7640
0
                            src_i2*nb02 +
7641
0
                            src_i1*nb01 +
7642
0
                            src_i0*nb00;
7643
7644
0
                        const float * src_ptr = (const float *)((char *) src0->data + src_idx);
7645
0
                        dst_ptr[dst_idx] = *src_ptr;
7646
0
                    } else {
7647
0
                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7648
0
                        if ((i0 >= lp0 && i0 < ne0 - rp0) \
7649
0
                            && (i1 >= lp1 && i1 < ne1 - rp1) \
7650
0
                            && (i2 >= lp2 && i2 < ne2 - rp2) \
7651
0
                            && (i3 >= lp3 && i3 < ne3 - rp3)) {
7652
0
                            const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
7653
0
                            const float * src_ptr = (const float *)((char *) src0->data + src_idx);
7654
0
                            dst_ptr[dst_idx] = *src_ptr;
7655
0
                        } else {
7656
0
                            dst_ptr[dst_idx] = 0;
7657
0
                        }
7658
0
                    }
7659
0
                }
7660
0
            }
7661
0
        }
7662
0
    }
7663
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_pad_f32<true>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_pad_f32<false>(ggml_compute_params const*, ggml_tensor*)
7664
7665
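Circular padding maps every destination coordinate back into the source tensor with modular arithmetic. ggml_wrap_around is defined elsewhere in ggml; its intended behaviour (an assumption here, since the definition is outside this listing) is a non-negative modulo:

    #include <cstdint>

    // sketch: wrap a possibly-negative index into [0, n) (assumed behaviour of ggml_wrap_around)
    static inline int64_t wrap_around_sketch(int64_t i, int64_t n) {
        return ((i % n) + n) % n;   // e.g. wrap_around_sketch(-1, 4) == 3, wrap_around_sketch(5, 4) == 1
    }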
7666
void ggml_compute_forward_pad(
7667
    const ggml_compute_params * params,
7668
0
    ggml_tensor * dst) {
7669
0
    const ggml_tensor * src0 = dst->src[0];
7670
0
    const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
7671
0
    switch (src0->type) {
7672
0
        case GGML_TYPE_F32:
7673
0
            {
7674
0
                if (circular) {
7675
0
                    ggml_compute_forward_pad_f32<true>(params, dst);
7676
0
                } else {
7677
0
                    ggml_compute_forward_pad_f32<false>(params, dst);
7678
0
                }
7679
0
            } break;
7680
0
        default:
7681
0
            {
7682
0
                GGML_ABORT("fatal error");
7683
0
            }
7684
0
    }
7685
0
}
7686
7687
// ggml_compute_forward_pad_reflect_1d
7688
7689
void ggml_compute_forward_pad_reflect_1d(
7690
        const ggml_compute_params * params,
7691
0
              ggml_tensor * dst) {
7692
7693
0
    const ggml_tensor * src0 = dst->src[0];
7694
7695
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
7696
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
7697
7698
0
    const int ith = params->ith;
7699
0
    const int nth = params->nth;
7700
7701
0
    const int32_t * opts = (const int32_t *) dst->op_params;
7702
0
    const int p0 = opts[0];
7703
0
    const int p1 = opts[1];
7704
7705
0
    GGML_TENSOR_UNARY_OP_LOCALS
7706
7707
0
    for (int64_t i3 = 0; i3 < ne3; i3++) {
7708
0
        for (int64_t i2 = 0; i2 < ne2; i2++) {
7709
0
            for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
7710
0
                float * left  = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 +         p0*nb0);
7711
0
                float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
7712
7713
0
                ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
7714
7715
0
                for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0];   }
7716
0
                for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
7717
0
            }
7718
0
        }
7719
0
    }
7720
0
}
7721
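The two trailing loops mirror the row around its first and last sample (reflect padding, edge sample not repeated). A hypothetical worked example with p0 = p1 = 2 and a source row [a b c d]:

    // after ggml_vec_cpy_f32:               [_ _ a b c d _ _]
    // left loop  (left[-i0]  = left[i0]):   [c b a b c d _ _]
    // right loop (right[i0]  = right[-i0]): [c b a b c d c b]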
7722
// ggml_compute_forward_roll
7723
7724
0
static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
7725
0
    if (i < 0) {
7726
0
        return i + ne;
7727
0
    } else if (i >= ne) {
7728
0
        return i - ne;
7729
0
    }
7730
0
    return i;
7731
0
}
7732
7733
static void ggml_compute_forward_roll_f32(
7734
        const ggml_compute_params * params,
7735
0
        ggml_tensor * dst) {
7736
7737
0
    const ggml_tensor * src0 = dst->src[0];
7738
0
    const float * src_data = (const float *) src0->data;
7739
0
    float * dst_data = (float *) dst->data;
7740
7741
0
    GGML_TENSOR_UNARY_OP_LOCALS
7742
7743
0
    const int s0 = ggml_get_op_params_i32(dst, 0);
7744
0
    const int s1 = ggml_get_op_params_i32(dst, 1);
7745
0
    const int s2 = ggml_get_op_params_i32(dst, 2);
7746
0
    const int s3 = ggml_get_op_params_i32(dst, 3);
7747
7748
0
    const int64_t total = ne1 * ne2 * ne3;
7749
0
    const int64_t per_thread = (total + params->nth) / params->nth;
7750
0
    const int64_t start = params->ith * per_thread;
7751
0
    const int64_t end   = std::min(start + per_thread, total);
7752
7753
0
    for (int64_t i = start; i < end; ++i) {
7754
0
        const int64_t i1 = i % ne1;
7755
0
        const int64_t i2 = (i / ne1) % ne2;
7756
0
        const int64_t i3 = i / (ne2 * ne1);
7757
0
        float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
7758
7759
0
        const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
7760
0
        const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
7761
0
        const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
7762
0
        const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
7763
7764
0
        const int64_t s = ggml_wrap_index(-s0, ne00);
7765
0
        const int64_t n = ne00 - s;
7766
0
        ggml_vec_cpy_f32(n, dst_row,     src_row + s);
7767
0
        ggml_vec_cpy_f32(s, dst_row + n, src_row);
7768
0
    }
7769
0
}
7770
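Along dimension 0 the roll is performed with two contiguous copies: the source row is split at s = ggml_wrap_index(-s0, ne00) and the two pieces are swapped, which shifts every element forward by s0 positions with wrap-around. Hypothetical worked example with ne00 = 5 and s0 = 2:

    // s = ggml_wrap_index(-2, 5) = 3, n = ne00 - s = 2
    // dst_row = { src[3], src[4], src[0], src[1], src[2] }   // a roll by +2 along dim 0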
7771
void ggml_compute_forward_roll(
7772
        const ggml_compute_params * params,
7773
0
        ggml_tensor * dst) {
7774
7775
0
    const ggml_tensor * src0 = dst->src[0];
7776
7777
0
    switch (src0->type) {
7778
0
        case GGML_TYPE_F32:
7779
0
            {
7780
0
                ggml_compute_forward_roll_f32(params, dst);
7781
0
            } break;
7782
0
        default:
7783
0
            {
7784
0
                GGML_ABORT("fatal error");
7785
0
            }
7786
0
    }
7787
0
}
7788
7789
// ggml_compute_forward_arange
7790
7791
static void ggml_compute_forward_arange_f32(
7792
    const ggml_compute_params * params,
7793
0
    ggml_tensor * dst) {
7794
7795
0
    GGML_ASSERT(dst->nb[0] == sizeof(float));
7796
7797
0
    const int ith = params->ith;
7798
0
    const int nth = params->nth;
7799
7800
0
    const float start = ggml_get_op_params_f32(dst, 0);
7801
0
    const float stop  = ggml_get_op_params_f32(dst, 1);
7802
0
    const float step  = ggml_get_op_params_f32(dst, 2);
7803
7804
0
    const int64_t steps = (int64_t) ceilf((stop - start) / step);
7805
7806
0
    GGML_ASSERT(ggml_nelements(dst) == steps);
7807
7808
0
    for (int64_t i = ith; i < steps; i+= nth) {
7809
0
        float value = start + step * i;
7810
0
        ((float *)dst->data)[i] = value;
7811
0
    }
7812
0
}
7813
7814
void ggml_compute_forward_arange(
7815
    const ggml_compute_params * params,
7816
0
    ggml_tensor * dst) {
7817
0
    switch (dst->type) {
7818
0
        case GGML_TYPE_F32:
7819
0
            {
7820
0
                ggml_compute_forward_arange_f32(params, dst);
7821
0
            } break;
7822
0
        default:
7823
0
            {
7824
0
                GGML_ABORT("fatal error");
7825
0
            }
7826
0
    }
7827
0
}
7828
7829
static void ggml_compute_forward_timestep_embedding_f32(
7830
    const ggml_compute_params * params,
7831
0
    ggml_tensor * dst) {
7832
7833
0
    const ggml_tensor * src0 = dst->src[0];
7834
7835
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
7836
7837
0
    const int ith = params->ith;
7838
0
    const int nth = params->nth;
7839
7840
0
    GGML_TENSOR_UNARY_OP_LOCALS
7841
7842
0
    const int dim = ggml_get_op_params_i32(dst, 0);
7843
0
    const int max_period = ggml_get_op_params_i32(dst, 1);
7844
7845
0
    int half = dim / 2;
7846
7847
0
    for (int64_t i = 0; i < ne00; i++) {
7848
0
        float * embed_data = (float *)((char *)  dst->data +  i*nb1);
7849
0
        for (int64_t j = ith; j < half; j += nth) {
7850
0
            float timestep = ((float *)src0->data)[i];
7851
0
            float freq = (float)expf(-logf(max_period) * j / half);
7852
0
            float arg = timestep * freq;
7853
0
            embed_data[j] = cosf(arg);
7854
0
            embed_data[j + half] = sinf(arg);
7855
0
        }
7856
0
        if (dim % 2 != 0 && ith == 0) {
7857
0
            embed_data[2 * half] = 0.f;
7858
0
        }
7859
0
    }
7860
0
}
7861
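Each scalar timestep t from src0 is expanded into the usual sinusoidal embedding; with half = dim/2 the loop above computes, for j = 0..half-1,

$$
e_j = \cos\!\left(t \cdot \exp\!\left(-\frac{j\,\ln(\text{max\_period})}{\text{half}}\right)\right),
\qquad
e_{j+\text{half}} = \sin\!\left(t \cdot \exp\!\left(-\frac{j\,\ln(\text{max\_period})}{\text{half}}\right)\right),
$$

with a trailing zero written at index 2*half when dim is odd.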
7862
void ggml_compute_forward_timestep_embedding(
7863
    const ggml_compute_params * params,
7864
0
    ggml_tensor * dst) {
7865
7866
0
    const ggml_tensor * src0 = dst->src[0];
7867
7868
0
    switch (src0->type) {
7869
0
        case GGML_TYPE_F32:
7870
0
            {
7871
0
                ggml_compute_forward_timestep_embedding_f32(params, dst);
7872
0
            } break;
7873
0
        default:
7874
0
            {
7875
0
                GGML_ABORT("fatal error");
7876
0
            }
7877
0
    }
7878
0
}
7879
7880
// ggml_compute_forward_argsort
7881
7882
template<enum ggml_sort_order order>
7883
struct cmp_argsort {
7884
    const float * data;
7885
0
    bool operator()(int32_t a, int32_t b) const {
7886
0
        if constexpr (order == GGML_SORT_ORDER_ASC) {
7887
0
            return data[a] < data[b];
7888
0
        } else {
7889
0
            return data[a] > data[b];
7890
0
        }
7891
0
    }
Unexecuted instantiation: cmp_argsort<(ggml_sort_order)0>::operator()(int, int) const
Unexecuted instantiation: cmp_argsort<(ggml_sort_order)1>::operator()(int, int) const
7892
};
7893
7894
static void ggml_compute_forward_argsort_f32(
7895
    const ggml_compute_params * params,
7896
0
    ggml_tensor * dst) {
7897
7898
0
    const ggml_tensor * src0 = dst->src[0];
7899
7900
0
    GGML_TENSOR_UNARY_OP_LOCALS
7901
7902
0
    GGML_ASSERT(nb0 == sizeof(float));
7903
7904
0
    const int ith = params->ith;
7905
0
    const int nth = params->nth;
7906
7907
0
    const int64_t nr = ggml_nrows(src0);
7908
7909
0
    ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
7910
7911
0
    for (int64_t i = ith; i < nr; i += nth) {
7912
0
        const float * src_data = (float *)((char *) src0->data + i*nb01);
7913
7914
0
        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
7915
7916
0
        for (int64_t j = 0; j < ne0; j++) {
7917
0
            dst_data[j] = j;
7918
0
        }
7919
7920
0
        switch (order) {
7921
0
            case GGML_SORT_ORDER_ASC:
7922
0
                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
7923
0
                break;
7924
7925
0
            case GGML_SORT_ORDER_DESC:
7926
0
                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
7927
0
                break;
7928
7929
0
            default:
7930
0
                GGML_ABORT("invalid sort order");
7931
0
        }
7932
0
    }
7933
0
}
7934
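cmp_argsort sorts indices by the values they point to, so each dst row ends up holding a permutation of 0..ne0-1 that orders the corresponding src row. A self-contained equivalent for one row (a sketch independent of ggml's tensor types):

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // sketch: indices that would sort `data` ascending (what one dst row contains for GGML_SORT_ORDER_ASC)
    static std::vector<int32_t> argsort_row(const float * data, int64_t n) {
        std::vector<int32_t> idx(n);
        std::iota(idx.begin(), idx.end(), 0);              // idx = 0, 1, ..., n-1
        std::sort(idx.begin(), idx.end(),
                  [data](int32_t a, int32_t b) { return data[a] < data[b]; });
        return idx;
    }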
7935
void ggml_compute_forward_argsort(
7936
    const ggml_compute_params * params,
7937
0
    ggml_tensor * dst) {
7938
7939
0
    const ggml_tensor * src0 = dst->src[0];
7940
7941
0
    switch (src0->type) {
7942
0
        case GGML_TYPE_F32:
7943
0
            {
7944
0
                ggml_compute_forward_argsort_f32(params, dst);
7945
0
            } break;
7946
0
        default:
7947
0
            {
7948
0
                GGML_ABORT("fatal error");
7949
0
            }
7950
0
    }
7951
0
}
7952
7953
// ggml_compute_forward_top_k
7954
7955
struct cmp_top_k {
7956
    const float * data;
7957
0
    bool operator()(int32_t a, int32_t b) const {
7958
0
        return data[a] > data[b];
7959
0
    }
7960
};
7961
7962
static void ggml_compute_forward_top_k_f32(
7963
    const ggml_compute_params * params,
7964
0
    ggml_tensor * dst) {
7965
7966
0
    const ggml_tensor * src0 = dst->src[0];
7967
7968
0
    GGML_TENSOR_UNARY_OP_LOCALS
7969
7970
0
    GGML_ASSERT(nb0 == sizeof(float));
7971
7972
0
    const int ith = params->ith;
7973
0
    const int nth = params->nth;
7974
7975
0
    const int64_t nr = ggml_nrows(src0);
7976
7977
0
    const int top_k = ne0;
7978
7979
0
    int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
7980
7981
0
    for (int64_t i = ith; i < nr; i += nth) {
7982
0
        const float * src_data = (float *)((char *) src0->data + i*nb01);
7983
7984
0
        for (int64_t j = 0; j < ne00; j++) {
7985
0
            tmp[j] = j;
7986
0
        }
7987
7988
0
        std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
7989
7990
0
        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
7991
7992
0
        std::copy(tmp, tmp + top_k, dst_data);
7993
7994
        // emphasize that the order is not important
7995
0
        if (top_k > 1) {
7996
0
            std::swap(dst_data[0], dst_data[1]);
7997
0
        }
7998
0
    }
7999
0
}
8000
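std::partial_sort with cmp_top_k (descending) moves the indices of the top_k largest values into the first top_k slots of tmp and leaves the remainder unordered; the final swap of dst_data[0] and dst_data[1] deliberately scrambles the result so callers cannot rely on any particular order. A hypothetical use of the comparator on a small row (<algorithm> is already included at the top of ops.cpp):

    // sketch reusing cmp_top_k from above: top-3 indices of a 6-element row
    const float row[6] = { 0.1f, 0.9f, 0.4f, 0.7f, 0.2f, 0.8f };
    int32_t    idx[6] = { 0, 1, 2, 3, 4, 5 };
    std::partial_sort(idx, idx + 3, idx + 6, cmp_top_k{row});
    // idx[0..3) now holds {1, 5, 3} (the positions of 0.9, 0.8, 0.7)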
8001
void ggml_compute_forward_top_k(
8002
    const ggml_compute_params * params,
8003
0
    ggml_tensor * dst) {
8004
8005
0
    const ggml_tensor * src0 = dst->src[0];
8006
8007
0
    switch (src0->type) {
8008
0
        case GGML_TYPE_F32:
8009
0
            {
8010
0
                ggml_compute_forward_top_k_f32(params, dst);
8011
0
            } break;
8012
0
        default:
8013
0
            {
8014
0
                GGML_ABORT("fatal error");
8015
0
            }
8016
0
    }
8017
0
}
8018
8019
// ggml_compute_forward_flash_attn_ext
8020
8021
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8022
        const ggml_compute_params * params,
8023
        ggml_tensor * dst,
8024
0
        int ir0, int ir1) {
8025
0
    const ggml_tensor * q     = dst->src[0];
8026
0
    const ggml_tensor * k     = dst->src[1];
8027
0
    const ggml_tensor * v     = dst->src[2];
8028
0
    const ggml_tensor * mask  = dst->src[3];
8029
0
    const ggml_tensor * sinks = dst->src[4];
8030
8031
0
    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
8032
0
    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
8033
0
    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
8034
0
    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
8035
0
    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
8036
0
    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
8037
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
8038
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
8039
8040
0
    const int64_t DK = nek0;
8041
0
    const int64_t DV = nev0;
8042
0
    const int64_t N  = neq1;
8043
8044
0
    GGML_ASSERT(ne0 == DV);
8045
0
    GGML_ASSERT(ne2 == N);
8046
8047
    // input tensor rows must be contiguous
8048
0
    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8049
0
    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8050
0
    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8051
8052
0
    GGML_ASSERT(neq0 == DK);
8053
0
    GGML_ASSERT(nek0 == DK);
8054
0
    GGML_ASSERT(nev0 == DV);
8055
8056
0
    GGML_ASSERT(neq1 == N);
8057
8058
    // dst cannot be transposed or permuted
8059
0
    GGML_ASSERT(nb0 == sizeof(float));
8060
0
    GGML_ASSERT(nb0 <= nb1);
8061
0
    GGML_ASSERT(nb1 <= nb2);
8062
0
    GGML_ASSERT(nb2 <= nb3);
8063
8064
    // broadcast factors
8065
0
    const int64_t rk2 = neq2/nek2;
8066
0
    const int64_t rk3 = neq3/nek3;
8067
8068
0
    const int64_t rv2 = neq2/nev2;
8069
0
    const int64_t rv3 = neq3/nev3;
8070
8071
    // parallelize by q rows using ggml_vec_dot_f32
8072
8073
0
    float scale         = 1.0f;
8074
0
    float max_bias      = 0.0f;
8075
0
    float logit_softcap = 0.0f;
8076
8077
0
    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
8078
0
    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
8079
0
    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
8080
8081
0
    if (logit_softcap != 0) {
8082
0
        scale /= logit_softcap;
8083
0
    }
8084
8085
0
    const uint32_t n_head      = neq2;
8086
0
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
8087
8088
0
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
8089
0
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
8090
8091
0
    ggml_type         const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
8092
0
    ggml_from_float_t const q_to_vec_dot   = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
8093
0
    ggml_vec_dot_t    const kq_vec_dot     = ggml_get_type_traits_cpu(k->type)->vec_dot;
8094
0
    ggml_to_float_t   const v_to_float     = ggml_get_type_traits(v->type)->to_float;
8095
8096
0
    GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
8097
0
    GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
8098
8099
0
    int ith = params->ith;
8100
8101
    // loop over n_batch and n_head
8102
0
    for (int ir = ir0; ir < ir1; ++ir) {
8103
        // q indices
8104
0
        const int iq3 = ir/(neq2*neq1);
8105
0
        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
8106
0
        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
8107
8108
0
        const uint32_t h = iq2; // head index
8109
0
        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
8110
8111
0
        float S = 0.0f;      // sum
8112
0
        float M = -INFINITY; // maximum KQ value
8113
8114
0
        float       * VKQ32 = (float       *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
8115
0
        float       * V32   =                 (VKQ32 + 1*DV); // (temporary) FP32 V buffer
8116
0
        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
8117
0
        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
8118
8119
0
        if (v->type == GGML_TYPE_F16) {
8120
0
            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
8121
0
        } else {
8122
0
            memset(VKQ32, 0, DV*sizeof(float));
8123
0
        }
8124
8125
0
        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
8126
8127
        // k indices
8128
0
        const int ik3 = iq3 / rk3;
8129
0
        const int ik2 = iq2 / rk2;
8130
8131
        // v indices
8132
0
        const int iv3 = iq3 / rv3;
8133
0
        const int iv2 = iq2 / rv2;
8134
8135
0
        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
8136
0
        q_to_vec_dot(pq, Q_q, DK);
8137
8138
        // online softmax / attention
8139
        // loop over n_kv and n_head_kv
8140
        // ref: https://arxiv.org/pdf/2112.05682.pdf
8141
0
        for (int64_t ic = 0; ic < nek1; ++ic) {
8142
0
            const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
8143
0
            if (mv == -INFINITY) {
8144
0
                continue;
8145
0
            }
8146
8147
0
            float s; // KQ value
8148
8149
0
            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
8150
0
            kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
8151
8152
0
            s = s*scale; // scale KQ value
8153
8154
0
            if (logit_softcap != 0.0f) {
8155
0
                s = logit_softcap*tanhf(s);
8156
0
            }
8157
8158
0
            s += mv; // apply mask
8159
8160
0
            const float Mold = M;
8161
8162
0
            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
8163
0
            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
8164
8165
0
            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
8166
8167
0
            if (v->type == GGML_TYPE_F16) {
8168
0
                if (s > M) {
8169
                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
8170
0
                    M = s;
8171
0
                    ms = expf(Mold - M);
8172
8173
                    // V = V*expf(Mold - M)
8174
0
                    ggml_vec_scale_f16(DV, VKQ16, ms);
8175
0
                } else {
8176
                    // no new maximum, ms == 1.0f, vs != 1.0f
8177
0
                    vs = expf(s - M);
8178
0
                }
8179
8180
                // V += v*expf(s - M)
8181
0
                ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
8182
0
            } else {
8183
0
                if (s > M) {
8184
                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
8185
0
                    M = s;
8186
0
                    ms = expf(Mold - M);
8187
8188
                    // V = V*expf(Mold - M)
8189
0
                    ggml_vec_scale_f32(DV, VKQ32, ms);
8190
0
                } else {
8191
                    // no new maximum, ms == 1.0f, vs != 1.0f
8192
0
                    vs = expf(s - M);
8193
0
                }
8194
8195
                // V += v*expf(s - M)
8196
0
                if (v_to_float) {
8197
0
                    v_to_float(v_data, V32, DV);
8198
0
                    ggml_vec_mad_f32(DV, VKQ32, V32, vs);
8199
0
                } else {
8200
                    // V is F32
8201
0
                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
8202
0
                }
8203
0
            }
8204
8205
0
            S = S*ms + vs; // scale and increment sum with partial sum
8206
0
        }
8207
8208
0
        if (v->type == GGML_TYPE_F16) {
8209
0
            for (int64_t d = 0; d < DV; ++d) {
8210
0
                VKQ32[d] = GGML_CPU_FP16_TO_FP32(VKQ16[d]);
8211
0
            }
8212
0
        }
8213
8214
        // sinks
8215
0
        if (sinks) {
8216
0
            const float s = ((float *)((char *) sinks->data))[h];
8217
8218
0
            float ms = 1.0f;
8219
0
            float vs = 1.0f;
8220
8221
0
            if (s > M) {
8222
0
                ms = expf(M - s);
8223
0
                ggml_vec_scale_f32(DV, VKQ32, ms);
8224
0
            } else {
8225
0
                vs = expf(s - M);
8226
0
            }
8227
8228
0
            S = S*ms + vs;
8229
0
        }
8230
8231
        // V /= S
8232
0
        const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
8233
0
        ggml_vec_scale_f32(DV, VKQ32, S_inv);
8234
8235
        // dst indices
8236
0
        const int i1 = iq1;
8237
0
        const int i2 = iq2;
8238
0
        const int i3 = iq3;
8239
8240
        // original
8241
        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
8242
8243
        // permute(0, 2, 1, 3)
8244
0
        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
8245
0
    }
8246
0
}
8247
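The inner loop over ic is the streaming ("online") softmax from the referenced paper: a single pass over the keys maintains the running maximum M, the accumulator VKQ and the normalizer S, rescaling the partial sums whenever a new maximum appears (ms) and weighting each value row by vs. Per key score s the update is

$$
M' = \max(M, s), \qquad
V \leftarrow V\, e^{M - M'} + v\, e^{s - M'}, \qquad
S \leftarrow S\, e^{M - M'} + e^{s - M'}, \qquad
\text{out} = V / S ,
$$

where ms = e^{M - M'} and vs = e^{s - M'} in the code; the optional sinks block applies the same max/rescale update to S (and rescales VKQ) without accumulating a value row.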
8248
static void ggml_compute_forward_flash_attn_ext_f16(
8249
        const ggml_compute_params * params,
8250
0
        ggml_tensor * dst) {
8251
8252
0
    const ggml_tensor * q     = dst->src[0];
8253
0
    const ggml_tensor * k     = dst->src[1];
8254
0
    const ggml_tensor * v     = dst->src[2];
8255
8256
0
    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
8257
0
    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
8258
0
    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
8259
0
    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
8260
0
    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
8261
0
    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
8262
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
8263
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
8264
8265
0
    const int64_t DK = nek0;
8266
0
    const int64_t DV = nev0;
8267
0
    const int64_t N  = neq1;
8268
8269
0
    GGML_ASSERT(ne0 == DV);
8270
0
    GGML_ASSERT(ne2 == N);
8271
8272
    // input tensor rows must be contiguous
8273
0
    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8274
0
    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8275
0
    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8276
8277
0
    GGML_ASSERT(neq0 == DK);
8278
0
    GGML_ASSERT(nek0 == DK);
8279
0
    GGML_ASSERT(nev0 == DV);
8280
8281
0
    GGML_ASSERT(neq1 == N);
8282
8283
    // dst cannot be transposed or permuted
8284
0
    GGML_ASSERT(nb0 == sizeof(float));
8285
0
    GGML_ASSERT(nb0 <= nb1);
8286
0
    GGML_ASSERT(nb1 <= nb2);
8287
0
    GGML_ASSERT(nb2 <= nb3);
8288
8289
    // parallelize by q rows using ggml_vec_dot_f32
8290
8291
    // total rows in q
8292
0
    const int64_t nr = neq1*neq2*neq3;
8293
8294
    // rows per thread
8295
0
    const int ith = params->ith;
8296
0
    const int nth = params->nth;
8297
8298
    // disable for NUMA
8299
0
    const bool disable_chunking = ggml_is_numa();
8300
8301
    // 4x chunks per thread
8302
0
    int nth_scaled = nth * 4;
8303
0
    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
8304
0
    int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
8305
8306
0
    if (nth == 1 || nchunk < nth || disable_chunking) {
8307
0
        nchunk = nth;
8308
0
    }
8309
8310
0
    if (ith == 0) {
8311
        // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
8312
0
        ggml_threadpool_chunk_set(params->threadpool, nth);
8313
0
    }
8314
8315
0
    ggml_barrier(params->threadpool);
8316
8317
    // The number of elements in each chunk
8318
0
    const int64_t dr = (nr + nchunk - 1) / nchunk;
8319
8320
    // The first chunk comes from our thread_id, the rest will get auto-assigned.
8321
0
    int current_chunk = ith;
8322
8323
0
    while (current_chunk < nchunk) {
8324
0
        const int64_t ir0 = dr * current_chunk;
8325
0
        const int64_t ir1 = MIN(ir0 + dr, nr);
8326
8327
0
        ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
8328
8329
0
        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
8330
0
    }
8331
0
}
8332
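Work is distributed over q rows in chunks: every thread processes the chunk matching its own index first, then repeatedly claims the next unprocessed chunk from a shared counter (ggml_threadpool_chunk_add), which load-balances uneven rows without a static split. A minimal stand-alone analogue of the scheme using std::atomic (a sketch, not the ggml threadpool API):

    #include <algorithm>
    #include <atomic>
    #include <cstdint>

    // sketch: dynamic chunk scheduling; `next_chunk` is shared between threads and pre-seeded to nth
    static void run_chunks(int ith, int64_t nchunk, int64_t nr, int64_t dr,
                           std::atomic<int64_t> & next_chunk,
                           void (*work)(int64_t ir0, int64_t ir1)) {
        int64_t current = ith;                                   // first chunk == own thread index
        while (current < nchunk) {
            const int64_t ir0 = dr * current;
            const int64_t ir1 = std::min(ir0 + dr, nr);
            work(ir0, ir1);
            current = next_chunk.fetch_add(1);                   // claim the next unprocessed chunk
        }
    }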
8333
void ggml_compute_forward_flash_attn_ext(
8334
        const ggml_compute_params * params,
8335
0
        ggml_tensor * dst) {
8336
0
    switch (dst->op_params[3]) {
8337
0
        case GGML_PREC_DEFAULT:
8338
0
        case GGML_PREC_F32:
8339
0
            {
8340
                // uses F32 accumulators
8341
0
                ggml_compute_forward_flash_attn_ext_f16(params, dst);
8342
0
            } break;
8343
0
        default:
8344
0
            {
8345
0
                GGML_ABORT("fatal error");
8346
0
            }
8347
0
    }
8348
0
}
8349
8350
// ggml_compute_forward_flash_attn_back
8351
8352
static void ggml_compute_forward_flash_attn_back_f32(
8353
        const ggml_compute_params * params,
8354
        const bool masked,
8355
0
              ggml_tensor * dst) {
8356
8357
0
    const ggml_tensor * q = dst->src[0];
8358
0
    const ggml_tensor * k = dst->src[1];
8359
0
    const ggml_tensor * v = dst->src[2];
8360
0
    const ggml_tensor * d = dst->src[3];
8361
8362
0
    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
8363
0
    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
8364
0
    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
8365
0
    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
8366
0
    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
8367
0
    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
8368
0
    GGML_TENSOR_LOCALS(int64_t, ned, d,   ne)
8369
0
    GGML_TENSOR_LOCALS(size_t,  nbd, d,   nb)
8370
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
8371
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
8372
8373
0
    const int ith = params->ith;
8374
0
    const int nth = params->nth;
8375
8376
0
    const int64_t D = neq0;
8377
0
    const int64_t N = neq1;
8378
0
    const int64_t P = nek1 - N;
8379
0
    const int64_t M = P + N;
8380
8381
0
    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
8382
0
    const int mxDM = MAX(D, Mup);
8383
8384
    // GGML_ASSERT(ne0 == D);
8385
    // GGML_ASSERT(ne1 == N);
8386
0
    GGML_ASSERT(P >= 0);
8387
8388
0
    GGML_ASSERT(nbq0 == sizeof(float));
8389
0
    GGML_ASSERT(nbk0 == sizeof(float));
8390
0
    GGML_ASSERT(nbv0 == sizeof(float));
8391
8392
0
    GGML_ASSERT(neq0 == D);
8393
0
    GGML_ASSERT(nek0 == D);
8394
0
    GGML_ASSERT(nev1 == D);
8395
0
    GGML_ASSERT(ned0 == D);
8396
8397
0
    GGML_ASSERT(neq1 == N);
8398
0
    GGML_ASSERT(nek1 == N + P);
8399
0
    GGML_ASSERT(nev1 == D);
8400
0
    GGML_ASSERT(ned1 == N);
8401
8402
    // dst cannot be transposed or permuted
8403
0
    GGML_ASSERT(nb0 == sizeof(float));
8404
0
    GGML_ASSERT(nb0 <= nb1);
8405
0
    GGML_ASSERT(nb1 <= nb2);
8406
0
    GGML_ASSERT(nb2 <= nb3);
8407
8408
0
    if (ith == 0) {
8409
0
        memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
8410
0
    }
8411
0
    ggml_barrier(params->threadpool);
8412
8413
0
    const int64_t elem_q = ggml_nelements(q);
8414
0
    const int64_t elem_k = ggml_nelements(k);
8415
8416
0
    ggml_type result_type = dst->type;
8417
0
    GGML_ASSERT(ggml_blck_size(result_type) == 1);
8418
0
    const size_t tsize = ggml_type_size(result_type);
8419
8420
0
    const size_t offs_q = 0;
8421
0
    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
8422
0
    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
8423
8424
0
    void * grad_q = (char *) dst->data;
8425
0
    void * grad_k = (char *) dst->data + offs_k;
8426
0
    void * grad_v = (char *) dst->data + offs_v;
8427
8428
0
    const size_t nbgq1 = nb0*neq0;
8429
0
    const size_t nbgq2 = nb0*neq0*neq1;
8430
0
    const size_t nbgq3 = nb0*neq0*neq1*neq2;
8431
8432
0
    const size_t nbgk1 = nb0*nek0;
8433
0
    const size_t nbgk2 = nb0*nek0*nek1;
8434
0
    const size_t nbgk3 = nb0*nek0*nek1*neq2;
8435
8436
0
    const size_t nbgv1 = nb0*nev0;
8437
0
    const size_t nbgv2 = nb0*nev0*nev1;
8438
0
    const size_t nbgv3 = nb0*nev0*nev1*neq2;
8439
8440
    // parallelize by k rows using ggml_vec_dot_f32
8441
8442
    // total rows in k
8443
0
    const int nr = nek2*nek3;
8444
8445
    // rows per thread
8446
0
    const int dr = (nr + nth - 1)/nth;
8447
8448
    // row range for this thread
8449
0
    const int ir0 = dr*ith;
8450
0
    const int ir1 = MIN(ir0 + dr, nr);
8451
8452
0
    const float scale = 1.0f/sqrtf(D);
8453
8454
    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
8455
8456
    // how often k2 (and v2) is repeated in q2
8457
0
    int nrep = neq2/nek2;
8458
8459
0
    for (int ir = ir0; ir < ir1; ++ir) {
8460
        // q indices
8461
0
        const int ik3 = ir/(nek2);
8462
0
        const int ik2 = ir - ik3*nek2;
8463
8464
0
        const int iq3 = ik3;
8465
0
        const int id3 = ik3;
8466
0
        const int iv3 = ik3;
8467
0
        const int iv2 = ik2;
8468
8469
0
        for (int irep = 0; irep < nrep; ++irep) {
8470
0
            const int iq2 = ik2 + irep*nek2;
8471
0
            const int id2 = iq2;
8472
8473
            // (ik2 + irep*nek2) % nek2 == ik2
8474
0
            for (int iq1 = 0; iq1 < neq1; ++iq1) {
8475
0
                const int id1 = iq1;
8476
8477
                // not sure about CACHE_LINE_SIZE_F32..
8478
                // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
8479
0
                float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
8480
0
                float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
8481
8482
0
                for (int i = M; i < Mup; ++i) {
8483
0
                    S[i] = -INFINITY;
8484
0
                }
8485
8486
0
                const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
8487
0
                for (int64_t ic = 0; ic < masked_begin; ++ic) {
8488
                    // k indices
8489
0
                    const int ik1 = ic;
8490
8491
                    // S indices
8492
0
                    const int i1 = ik1;
8493
8494
0
                    ggml_vec_dot_f32(neq0,
8495
0
                            S + i1, 0,
8496
0
                            (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
8497
0
                            (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
8498
0
                }
8499
8500
                // scale
8501
0
                ggml_vec_scale_f32(masked_begin, S, scale);
8502
8503
0
                for (int64_t i = masked_begin; i < M; i++) {
8504
0
                    S[i] = -INFINITY;
8505
0
                }
8506
8507
                // softmax
8508
                // exclude known -INF S[..] values from max and loop
8509
                // don't forget to set their SM values to zero
8510
0
                {
8511
0
                    float max = -INFINITY;
8512
0
                    ggml_vec_max_f32(masked_begin, &max, S);
8513
8514
0
                    ggml_float sum = 0.0;
8515
0
                    {
8516
#ifdef GGML_SOFT_MAX_ACCELERATE
8517
                        max = -max;
8518
                        vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
8519
                        vvexpf(SM, SM, &Mup);
8520
                        ggml_vec_sum_f32(Mup, &sum, SM);
8521
#else
8522
0
                        sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
8523
0
#endif
8524
0
                    }
8525
8526
0
                    assert(sum > 0.0);
8527
8528
0
                    sum = 1.0/sum;
8529
0
                    ggml_vec_scale_f32(masked_begin, SM, sum);
8530
8531
0
                }
8532
8533
                // step-by-step explanation
8534
0
                {
8535
                    // forward-process                    shape      grads from backward process
8536
                    // parallel_for ik2,ik3:
8537
                    //  for irep:
8538
                    //   iq2 = ik2 + irep*nek2
8539
                    //   k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,ik2,ik3]  += grad[kcur]
8540
                    //   q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
8541
                    //   v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iv2,iv3]  += grad[vcur]
8542
                    //   for iq1:
8543
                    //    kcur   = k[:D,:M,ik2,ik3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
8544
                    //    qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
8545
                    //    vcur   = v[:M,:D,iv2,iv3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
8546
                    //    S0     = -Inf                   [D,1,1,1]
8547
                    //   ~S1[i]  = dot(kcur[:D,i], qcur)
8548
                    //    S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale
8549
                    //    S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
8550
                    //    S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
8551
                    //    S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
8552
                    //   ~S5[i]  = dot(vcur[:,i], S4)
8553
                    //    S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,id1,id2,id3]
8554
                    //   ~dst[i,iq1,iq2,iq3]  = S5[i]              ^
8555
                    //    dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
8556
                    // dst                               backward-/ grad[dst]                 = d
8557
                    //
8558
                    // output gradients with their dependencies:
8559
                    //
8560
                    // grad[kcur] = grad[S1].T @ qcur
8561
                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
8562
                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
8563
                    // grad[S4]   = grad[S5] @ vcur
8564
                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
8565
                    // grad[qcur] = grad[S1]   @ kcur
8566
                    // grad[vcur] = grad[S5].T @ S4
8567
                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
8568
                    //
8569
                    // in post-order:
8570
                    //
8571
                    // S1         = qcur @ kcur.T
8572
                    // S2         = S1 * scale
8573
                    // S3         = diag_mask_inf(S2, P)
8574
                    // S4         = softmax(S3)
8575
                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
8576
                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
8577
                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
8578
                    // grad[qcur] = grad[S1]   @ kcur
8579
                    // grad[kcur] = grad[S1].T @ qcur
8580
                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
8581
                    //
8582
                    // using less variables (SM=S4):
8583
                    //
8584
                    // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
8585
                    // SM            = softmax(S)
8586
                    // S             = d[:D,iq1,iq2,iq3] @ vcur
8587
                    // dot_SM_gradSM = dot(SM, S)
8588
                    // S             = SM * (S - dot(SM, S))
8589
                    // S             = diag_mask_zero(S, P) * scale
8590
                    //
8591
                    // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
8592
                    // grad[k][:D,:M,ik2,ik3]  += S.T @ qcur
8593
                    // grad[v][:M,:D,iv2,iv3]  += d[:D,id1,id2,id3].T @ SM
8594
0
                }
8595
8596
                // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
8597
                // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
8598
                // for ic:
8599
                //   S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
8600
                // exclude known future zero S[..] values from operation
8601
0
                ggml_vec_set_f32(masked_begin, S, 0);
8602
0
                for (int64_t ic = 0; ic < D; ++ic) {
8603
0
                    ggml_vec_mad_f32(masked_begin,
8604
0
                            S,
8605
0
                             (float *) ((char *) v->data + (          ic*nbv1  + iv2*nbv2 + iv3*nbv3)),
8606
0
                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
8607
0
                }
8608
8609
                // S = SM * (S - dot(SM, S))
8610
0
                float dot_SM_gradSM = 0;
8611
0
                ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
8612
0
                ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
8613
0
                ggml_vec_mul_f32 (masked_begin, S, S, SM);
8614
8615
                // S = diag_mask_zero(S, P) * scale
8616
                // already done by above ggml_vec_set_f32
8617
8618
                // exclude known zero S[..] values from operation
8619
0
                ggml_vec_scale_f32(masked_begin, S, scale);
8620
8621
                // S    shape [M,1]
8622
                // SM   shape [M,1]
8623
                // kcur shape [D,M]
8624
                // qcur shape [D,1]
8625
                // vcur shape [M,D]
8626
8627
                // grad[q][:D,iq1,iq2,iq3] += S @ kcur
8628
                // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
8629
                // for ic:
8630
                //  grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
8631
                // exclude known zero S[..] values from loop
8632
0
                for (int64_t ic = 0; ic < masked_begin; ++ic) {
8633
0
                    ggml_vec_mad_f32(D,
8634
0
                            (float *) ((char *) grad_q  + (iq1*nbgq1 + iq2*nbgq2  + iq3*nbgq3)),
8635
0
                            (float *) ((char *) k->data + (ic*nbk1   + ik2*nbk2   + ik3*nbk3)),
8636
0
                            S[ic]);
8637
0
                }
8638
8639
                // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
8640
                // for ic:
8641
                //  grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
8642
                //  grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
8643
                // exclude known zero S[..] values from loop
8644
0
                for (int64_t ic = 0; ic < masked_begin; ++ic) {
8645
0
                    ggml_vec_mad_f32(D,
8646
0
                            (float *) ((char *) grad_k  + (ic*nbgk1  + ik2*nbgk2  + ik3*nbgk3)),
8647
0
                            (float *) ((char *) q->data + (iq1*nbq1  + iq2*nbq2   + iq3*nbq3)),
8648
0
                            S[ic]);
8649
0
                }
8650
8651
                // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T       @ SM
8652
                // for ic:
8653
                //  grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
8654
                //  grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3]         * SM[:M]
8655
                // exclude known zero SM[..] values from mad
8656
0
                for (int64_t ic = 0; ic < D; ++ic) {
8657
0
                    ggml_vec_mad_f32(masked_begin,
8658
0
                            (float *) ((char *) grad_v   + (          ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
8659
0
                            SM,
8660
0
                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2  + id3*nbd3)));
8661
0
                }
8662
0
            }
8663
0
        }
8664
0
    }
8665
0
}
8666
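A minimal standalone sketch (not part of ops.cpp; the helper name softmax_backward and the tiny main are illustrative) of the softmax-backward identity that ggml_compute_forward_flash_attn_back_f32 applies above via ggml_vec_dot_f32, ggml_vec_acc1_f32 and ggml_vec_mul_f32: for y = softmax(x) and upstream gradient dy, dx = y * (dy - dot(y, dy)), so the components of dx always sum to zero.

#include <cstdio>
#include <vector>

// Reference softmax backward: dx = y * (dy - dot(y, dy)),
// mirroring the "S = SM * (S - dot(SM, S))" step in the kernel above.
static std::vector<float> softmax_backward(const std::vector<float> & y,
                                           const std::vector<float> & dy) {
    float dot = 0.0f;
    for (size_t i = 0; i < y.size(); ++i) {
        dot += y[i] * dy[i];
    }
    std::vector<float> dx(y.size());
    for (size_t i = 0; i < y.size(); ++i) {
        dx[i] = y[i] * (dy[i] - dot);
    }
    return dx;
}

int main() {
    const std::vector<float> y  = {0.1f, 0.2f, 0.7f}; // a softmax output (sums to 1)
    const std::vector<float> dy = {1.0f, 0.0f, -1.0f};
    for (float v : softmax_backward(y, dy)) {
        printf("%f\n", v); // the printed values sum to 0 (up to rounding)
    }
    return 0;
}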
8667
void ggml_compute_forward_flash_attn_back(
8668
        const ggml_compute_params * params,
8669
        const bool masked,
8670
0
        ggml_tensor * dst) {
8671
8672
0
    const ggml_tensor * q = dst->src[0];
8673
8674
0
    switch (q->type) {
8675
0
        case GGML_TYPE_F32:
8676
0
            {
8677
0
                ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
8678
0
            } break;
8679
0
        default:
8680
0
            {
8681
0
                GGML_ABORT("fatal error");
8682
0
            }
8683
0
    }
8684
0
}
8685
8686
// ggml_compute_forward_ssm_conv
8687
8688
static void ggml_compute_forward_ssm_conv_f32(
8689
        const ggml_compute_params * params,
8690
0
        ggml_tensor * dst) {
8691
0
    const ggml_tensor * src0 = dst->src[0]; // conv_x
8692
0
    const ggml_tensor * src1 = dst->src[1]; // conv1d.weight
8693
8694
0
    const int ith = params->ith;
8695
0
    const int nth = params->nth;
8696
8697
0
    const int nc  = src1->ne[0]; // d_conv
8698
0
    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
8699
0
    const int nr  = src0->ne[1]; // d_inner
8700
0
    const int n_t =  dst->ne[1]; // tokens per sequence
8701
0
    const int n_s =  dst->ne[2]; // number of sequences in the batch
8702
8703
0
    GGML_ASSERT( dst->ne[0] == nr);
8704
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
8705
0
    GGML_ASSERT(src1->nb[0] == sizeof(float));
8706
0
    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
8707
8708
    // rows per thread
8709
0
    const int dr = (nr + nth - 1)/nth;
8710
8711
    // row range for this thread
8712
0
    const int ir0 = dr*ith;
8713
0
    const int ir1 = MIN(ir0 + dr, nr);
8714
0
    const int ir  = ir1 - ir0;
8715
8716
0
    for (int i3 = 0; i3 < n_s; ++i3) {
8717
0
        for (int i2 = 0; i2 < n_t; ++i2) {
8718
            // {d_conv - 1 + n_t, d_inner, n_seqs}
8719
            // sliding window
8720
0
            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
8721
0
            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
8722
0
            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
8723
8724
            // TODO: transpose the output for smaller strides for big batches?
8725
            // d_inner
8726
0
            for (int i1 = 0; i1 < ir; ++i1) {
8727
                // rowwise dot product
8728
                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
8729
0
                float sumf = 0.0f;
8730
8731
                // d_conv
8732
0
                for (int i0 = 0; i0 < nc; ++i0) {
8733
0
                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
8734
0
                }
8735
0
                x[i1] = sumf;
8736
0
            }
8737
0
        }
8738
0
    }
8739
0
}
8740
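The loop above is a depthwise causal 1-D convolution over a pre-padded buffer: each output element is a d_conv-wide dot product between a sliding window of conv_x and that channel's filter. A single-sequence, single-threaded sketch under those assumptions (the helper name ssm_conv_ref and the flat std::vector layout are illustrative, not ops.cpp code):

#include <vector>

// conv_x holds, per channel, d_conv-1 past values followed by the n_t new
// ones (length ncs = d_conv-1+n_t, contiguous per channel as in src0 above).
// out is written token-major with d_inner contiguous, matching dst above.
void ssm_conv_ref(const std::vector<float> & conv_x, // [d_inner][ncs]
                  const std::vector<float> & weight, // [d_inner][d_conv]
                  std::vector<float> & out,          // [n_t][d_inner]
                  int d_conv, int d_inner, int n_t) {
    const int ncs = d_conv - 1 + n_t;
    for (int t = 0; t < n_t; ++t) {
        for (int d = 0; d < d_inner; ++d) {
            float sum = 0.0f;
            for (int j = 0; j < d_conv; ++j) {
                sum += conv_x[d*ncs + t + j] * weight[d*d_conv + j];
            }
            out[t*d_inner + d] = sum;
        }
    }
}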
8741
void ggml_compute_forward_ssm_conv(
8742
        const ggml_compute_params * params,
8743
0
        ggml_tensor * dst) {
8744
0
    switch (dst->src[0]->type) {
8745
0
        case GGML_TYPE_F32:
8746
0
            {
8747
0
                ggml_compute_forward_ssm_conv_f32(params, dst);
8748
0
            } break;
8749
0
        default:
8750
0
            {
8751
0
                GGML_ABORT("fatal error");
8752
0
            }
8753
0
    }
8754
0
}
8755
8756
// ggml_compute_forward_ssm_scan
8757
8758
static void ggml_compute_forward_ssm_scan_f32(
8759
        const ggml_compute_params * params,
8760
0
        ggml_tensor * dst) {
8761
0
    const ggml_tensor * src0 = dst->src[0]; // s  {d_state, dim, n_head, n_seqs+}
8762
0
    const ggml_tensor * src1 = dst->src[1]; // x  {dim, n_head, n_seq_tokens, n_seqs}
8763
0
    const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
8764
0
    const ggml_tensor * src3 = dst->src[3]; // A  {d_state, n_head} or {1, n_head}
8765
0
    const ggml_tensor * src4 = dst->src[4]; // B  {d_state, n_group, n_seq_tokens, n_seqs}
8766
0
    const ggml_tensor * src5 = dst->src[5]; // C  {d_state, n_group, n_seq_tokens, n_seqs}
8767
0
    const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
8768
8769
0
    const int ith = params->ith;
8770
0
    const int nth = params->nth;
8771
8772
0
    const int64_t nc = src0->ne[0]; // d_state
8773
0
    const int64_t nr = src0->ne[1]; // dim
8774
0
    const int64_t nh = src1->ne[1]; // n_head
8775
0
    const int64_t ng = src4->ne[1];
8776
0
    const int64_t nt = src1->ne[2]; // number of tokens per sequence
8777
0
    const int64_t ns = src1->ne[3]; // number of sequences in the batch
8778
8779
    // can't use ggml_nbytes because src1 is not necessarily contiguous
8780
0
    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
8781
8782
0
    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
8783
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
8784
0
    GGML_ASSERT(src1->nb[0] == sizeof(float));
8785
0
    GGML_ASSERT(src2->nb[0] == sizeof(float));
8786
0
    GGML_ASSERT(src3->nb[0] == sizeof(float));
8787
0
    GGML_ASSERT(src4->nb[0] == sizeof(float));
8788
0
    GGML_ASSERT(src5->nb[0] == sizeof(float));
8789
0
    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
8790
0
    GGML_ASSERT(nh % ng == 0);
8791
8792
    // heads per thread
8793
0
    const int dh = (nh + nth - 1)/nth;
8794
8795
    // head range for this thread
8796
0
    const int ih0 = dh*ith;
8797
0
    const int ih1 = MIN(ih0 + dh, nh);
8798
8799
0
    const int32_t * ids = (const int32_t *) src6->data;
8800
8801
0
    for (int i3 = 0; i3 < ns; ++i3) {
8802
0
        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
8803
0
              float * s  = (      float *) ((      char *) dst->data  + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
8804
8805
0
        for (int i2 = 0; i2 < nt; ++i2) {
8806
0
            const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
8807
0
            const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
8808
0
            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
8809
0
            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
8810
0
            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
8811
0
                  float * y  = (      float *) ((      char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
8812
8813
0
            if (src3->ne[0] == 1) {
8814
                // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
8815
8816
                // n_head
8817
0
                for (int h = ih0; h < ih1; ++h) {
8818
                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8819
0
                    const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
8820
0
                    const float dA = expf(dt_soft_plus * A[h]);
8821
0
                    const int g = h / (nh / ng); // repeat_interleave
8822
8823
                    // dim
8824
0
                    for (int i1 = 0; i1 < nr; ++i1) {
8825
0
                        const int ii = i1 + h*nr;
8826
0
                        const float x_dt = x[ii] * dt_soft_plus;
8827
0
                        float sumf = 0.0f;
8828
0
#if defined(GGML_SIMD)
8829
    #if defined(__ARM_FEATURE_SVE)
8830
                        const int ggml_f32_epr = svcntw();
8831
                        const int ggml_f32_step = 1 * ggml_f32_epr;
8832
8833
                        const int np = (nc & ~(ggml_f32_step - 1));
8834
8835
                        GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
8836
8837
                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
8838
                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
8839
8840
                        for (int i = 0; i < np; i += ggml_f32_step) {
8841
                            // TODO: maybe unroll more?
8842
                            for (int j = 0; j < 1; j++) {
8843
                                GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
8844
                                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
8845
                                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
8846
8847
                                t0 = GGML_F32_VEC_MUL(t0, adA);
8848
                                t1 = GGML_F32_VEC_MUL(t1, axdt);
8849
8850
                                t0 = GGML_F32_VEC_ADD(t0, t1);
8851
8852
                                sum = GGML_F32_VEC_FMA(sum, t0, t2);
8853
8854
                                GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
8855
                            }
8856
                        }
8857
8858
                        sumf = GGML_F32xt_REDUCE_ONE(sum);
8859
    #elif defined(__riscv_v_intrinsic)
8860
                        // todo: RVV implementation
8861
                        const int np = 0;
8862
    #else
8863
0
                        const int np = (nc & ~(GGML_F32_STEP - 1));
8864
8865
0
                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8866
8867
0
                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
8868
0
                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
8869
8870
0
                        GGML_F32_VEC ax[GGML_F32_ARR];
8871
0
                        GGML_F32_VEC ay[GGML_F32_ARR];
8872
0
                        GGML_F32_VEC az[GGML_F32_ARR];
8873
8874
0
                        for (int i = 0; i < np; i += GGML_F32_STEP) {
8875
0
                            for (int j = 0; j < GGML_F32_ARR; j++) {
8876
0
                                ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
8877
0
                                ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
8878
0
                                az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
8879
8880
0
                                ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
8881
0
                                ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
8882
8883
0
                                ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
8884
8885
0
                                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
8886
8887
0
                                GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
8888
0
                            }
8889
0
                        }
8890
8891
                        // reduce sum0..sum3 to sum0
8892
0
                        GGML_F32_VEC_REDUCE(sumf, sum);
8893
0
    #endif
8894
#else
8895
                        const int np = 0;
8896
#endif
8897
                        // d_state
8898
0
                        for (int i0 = np; i0 < nc; ++i0) {
8899
0
                            const int i = i0 + ii*nc;
8900
0
                            const int ig = i0 + g*nc;
8901
                            // state = prev_state * dA + dB * x
8902
0
                            const float state = (s0[i] * dA) + (B[ig] * x_dt);
8903
                            // y = rowwise_dotprod(state, C)
8904
0
                            sumf += state * C[ig];
8905
0
                            s[i] = state;
8906
0
                        }
8907
0
                        y[ii] = sumf;
8908
0
                    }
8909
0
                }
8910
0
            } else {
8911
                // Mamba-1 has an element-wise decay factor for the states
8912
8913
                // n_head
8914
0
                for (int h = ih0; h < ih1; ++h) {
8915
                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8916
0
                    const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
8917
0
                    const int g = h / (nh / ng); // repeat_interleave
8918
8919
                    // dim
8920
0
                    for (int i1 = 0; i1 < nr; ++i1) {
8921
0
                        const int ii = i1 + h*nr;
8922
0
                        const float x_dt = x[ii] * dt_soft_plus;
8923
#if defined(__ARM_FEATURE_SVE)
8924
                        svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
8925
                        svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
8926
                        svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
8927
8928
                        // d_state
8929
                        // TODO: what happens when (d_state % svcntw()) != 0?
8930
                        for (int64_t k = 0; k < nc; k += svcntw()) {
8931
                            svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
8932
                            svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
8933
                            svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
8934
                            svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
8935
8936
                            svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
8937
                            t1 = exp_ps_sve(svptrue_b32(), t1);
8938
                            svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
8939
8940
                            vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
8941
                            r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
8942
8943
                            GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
8944
                        }
8945
                        y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
8946
#else
8947
0
                        float sumf = 0.0f;
8948
                        // NOTE: can't really use GGML_SIMD here because d_state is usually 16
8949
                        //       and also because expf is used within the loop.
8950
                        // d_state
8951
0
                        for (int i0 = 0; i0 < nc; ++i0) {
8952
0
                            const int i = i0 + ii*nc;
8953
0
                            const int ig = i0 + g*nc;
8954
                            // state = prev_state * dA + dB * x
8955
0
                            const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
8956
                            // y = rowwise_dotprod(state, C)
8957
0
                            sumf += state * C[ig];
8958
0
                            s[i] = state;
8959
0
                        }
8960
0
                        y[ii] = sumf;
8961
0
#endif
8962
0
                    }
8963
0
                }
8964
0
            }
8965
            // use the output as the source when it's not the first token-wise iteration
8966
0
            s0 = s;
8967
0
        }
8968
0
    }
8969
0
}
8970
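For one (head, channel) pair the scan above reduces to the recurrence state = prev_state * exp(dt_softplus * A) + B * (x * dt_softplus), y = dot(state, C); the Mamba-2 branch only hoists the scalar exp(dt_softplus * A[h]) out of the inner loop. A scalar single-step sketch (ssm_scan_step_ref is an illustrative name; the softplus threshold below is a common stable form, not necessarily the exact one in ggml_compute_softplus_f32):

#include <cmath>
#include <vector>

// One token-step of the selective scan for a single head/channel pair
// (Mamba-1 style: element-wise decay exp(dt*A[i])). Returns y and updates
// the d_state-sized hidden state in place.
float ssm_scan_step_ref(std::vector<float> & state,      // [d_state]
                        const std::vector<float> & A,    // [d_state]
                        const std::vector<float> & B,    // [d_state]
                        const std::vector<float> & C,    // [d_state]
                        float x, float dt) {
    const float dt_sp = dt > 20.0f ? dt : log1pf(expf(dt)); // softplus (stable form)
    const float x_dt  = x * dt_sp;
    float y = 0.0f;
    for (size_t i = 0; i < state.size(); ++i) {
        state[i] = state[i] * expf(dt_sp * A[i]) + B[i] * x_dt; // state = prev_state*dA + dB*x
        y += state[i] * C[i];                                   // y = rowwise_dotprod(state, C)
    }
    return y;
}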
8971
void ggml_compute_forward_ssm_scan(
8972
        const ggml_compute_params * params,
8973
0
        ggml_tensor * dst) {
8974
0
    switch (dst->src[0]->type) {
8975
0
        case GGML_TYPE_F32:
8976
0
            {
8977
0
                ggml_compute_forward_ssm_scan_f32(params, dst);
8978
0
            } break;
8979
0
        default:
8980
0
            {
8981
0
                GGML_ABORT("fatal error");
8982
0
            }
8983
0
    }
8984
0
}
8985
8986
// ggml_compute_forward_win_part
8987
8988
static void ggml_compute_forward_win_part_f32(
8989
        const ggml_compute_params * params,
8990
0
        ggml_tensor * dst) {
8991
0
    GGML_UNUSED(params);
8992
8993
0
    const ggml_tensor * src0 = dst->src[0];
8994
8995
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
8996
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
8997
8998
0
    const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
8999
0
    const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
9000
0
    const int32_t w    = ((const int32_t *)(dst->op_params))[2];
9001
9002
0
    assert(ne00 == ne0);
9003
0
    assert(ne3  == nep0*nep1);
9004
9005
    // TODO: optimize / multi-thread
9006
0
    for (int py = 0; py < nep1; ++py) {
9007
0
        for (int px = 0; px < nep0; ++px) {
9008
0
            const int64_t i3 = py*nep0 + px;
9009
0
            for (int64_t i2 = 0; i2 < ne2; ++i2) {
9010
0
                for (int64_t i1 = 0; i1 < ne1; ++i1) {
9011
0
                    for (int64_t i0 = 0; i0 < ne0; ++i0) {
9012
0
                        const int64_t i02 = py*w + i2;
9013
0
                        const int64_t i01 = px*w + i1;
9014
0
                        const int64_t i00 = i0;
9015
9016
0
                        const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0    + i1*ne0   + i0;
9017
0
                        const int64_t j =                  i02*ne01*ne00 + i01*ne00 + i00;
9018
9019
0
                        if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
9020
0
                            ((float *) dst->data)[i] = 0.0f;
9021
0
                        } else {
9022
0
                            ((float *) dst->data)[i] = ((float *) src0->data)[j];
9023
0
                        }
9024
0
                    }
9025
0
                }
9026
0
            }
9027
0
        }
9028
0
    }
9029
0
}
9030
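The partition above splits a (C, W, H) tensor into w x w windows with zero padding past the original extent. A compact reference with the same index arithmetic (win_part_ref and the flat std::vector layout are illustrative assumptions):

#include <vector>

// Split src (innermost dim C, then width W, then height H) into
// npx*npy windows of size w x w, zero-padding out-of-range positions.
std::vector<float> win_part_ref(const std::vector<float> & src,
                                int C, int W, int H, int w,
                                int & npx, int & npy) {
    npx = (W + w - 1) / w;
    npy = (H + w - 1) / w;
    std::vector<float> dst((size_t) npx * npy * w * w * C, 0.0f);
    for (int py = 0; py < npy; ++py)
    for (int px = 0; px < npx; ++px)
    for (int y  = 0; y  < w;   ++y)
    for (int x  = 0; x  < w;   ++x) {
        const int sy = py*w + y;
        const int sx = px*w + x;
        if (sy >= H || sx >= W) continue;            // padded region stays zero
        for (int c = 0; c < C; ++c) {
            const size_t di = (((size_t)(py*npx + px)*w + y)*w + x)*(size_t)C + c;
            const size_t si = ((size_t)sy*W + sx)*(size_t)C + c;
            dst[di] = src[si];
        }
    }
    return dst;
}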
9031
void ggml_compute_forward_win_part(
9032
        const ggml_compute_params * params,
9033
0
        ggml_tensor * dst) {
9034
9035
0
    const ggml_tensor * src0 = dst->src[0];
9036
9037
0
    switch (src0->type) {
9038
0
        case GGML_TYPE_F32:
9039
0
            {
9040
0
                ggml_compute_forward_win_part_f32(params, dst);
9041
0
            } break;
9042
0
        default:
9043
0
            {
9044
0
                GGML_ABORT("fatal error");
9045
0
            }
9046
0
    }
9047
0
}
9048
9049
// ggml_compute_forward_win_unpart
9050
9051
static void ggml_compute_forward_win_unpart_f32(
9052
        const ggml_compute_params * params,
9053
0
        ggml_tensor * dst) {
9054
0
    GGML_UNUSED(params);
9055
9056
0
    const ggml_tensor * src0 = dst->src[0];
9057
9058
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
9059
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
9060
9061
0
    const int32_t w = ((const int32_t *)(dst->op_params))[0];
9062
9063
    // padding
9064
0
    const int px = (w - ne1%w)%w;
9065
    //const int py = (w - ne2%w)%w;
9066
9067
0
    const int npx = (px + ne1)/w;
9068
    //const int npy = (py + ne2)/w;
9069
9070
0
    assert(ne0 == ne00);
9071
9072
    // TODO: optimize / multi-thread
9073
0
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
9074
0
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
9075
0
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
9076
0
                const int ip2 = i2/w;
9077
0
                const int ip1 = i1/w;
9078
9079
0
                const int64_t i02 = i2%w;
9080
0
                const int64_t i01 = i1%w;
9081
0
                const int64_t i00 = i0;
9082
9083
0
                const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
9084
0
                const int64_t j =                                  i2*ne1*ne0    + i1*ne0   + i0;
9085
9086
0
                ((float *) dst->data)[j] = ((float *) src0->data)[i];
9087
0
            }
9088
0
        }
9089
0
    }
9090
0
}
9091
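And the inverse gather, matching the i = (ip2*npx + ip1)*... indexing above; together with the win_part_ref sketch earlier this round-trips back to the original tensor (again an illustrative helper, not ops.cpp code):

#include <vector>

// Reassemble a (C, W, H) tensor from w x w windows, skipping the padding.
std::vector<float> win_unpart_ref(const std::vector<float> & win,
                                  int C, int W, int H, int w) {
    const int npx = (W + w - 1) / w;
    std::vector<float> dst((size_t) H * W * C);
    for (int sy = 0; sy < H; ++sy)
    for (int sx = 0; sx < W; ++sx)
    for (int c  = 0; c  < C; ++c) {
        const int py = sy / w, y = sy % w;
        const int px = sx / w, x = sx % w;
        const size_t si = (((size_t)(py*npx + px)*w + y)*w + x)*(size_t)C + c;
        dst[((size_t)sy*W + sx)*(size_t)C + c] = win[si];
    }
    return dst;
}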
9092
void ggml_compute_forward_win_unpart(
9093
        const ggml_compute_params * params,
9094
0
        ggml_tensor * dst) {
9095
9096
0
    const ggml_tensor * src0 = dst->src[0];
9097
9098
0
    switch (src0->type) {
9099
0
        case GGML_TYPE_F32:
9100
0
            {
9101
0
                ggml_compute_forward_win_unpart_f32(params, dst);
9102
0
            } break;
9103
0
        default:
9104
0
            {
9105
0
                GGML_ABORT("fatal error");
9106
0
            }
9107
0
    }
9108
0
}
9109
9110
// ggml_compute_forward_unary
9111
9112
void ggml_compute_forward_unary(
9113
        const ggml_compute_params * params,
9114
0
        ggml_tensor * dst) {
9115
9116
0
    const ggml_unary_op op = ggml_get_unary_op(dst);
9117
9118
0
    switch (op) {
9119
0
        case GGML_UNARY_OP_ABS:
9120
0
            {
9121
0
                ggml_compute_forward_abs(params, dst);
9122
0
            } break;
9123
0
        case GGML_UNARY_OP_SGN:
9124
0
            {
9125
0
                ggml_compute_forward_sgn(params, dst);
9126
0
            } break;
9127
0
        case GGML_UNARY_OP_NEG:
9128
0
            {
9129
0
                ggml_compute_forward_neg(params, dst);
9130
0
            } break;
9131
0
        case GGML_UNARY_OP_STEP:
9132
0
            {
9133
0
                ggml_compute_forward_step(params, dst);
9134
0
            } break;
9135
0
        case GGML_UNARY_OP_TANH:
9136
0
            {
9137
0
                ggml_compute_forward_tanh(params, dst);
9138
0
            } break;
9139
0
        case GGML_UNARY_OP_ELU:
9140
0
            {
9141
0
                ggml_compute_forward_elu(params, dst);
9142
0
            } break;
9143
0
        case GGML_UNARY_OP_RELU:
9144
0
            {
9145
0
                ggml_compute_forward_relu(params, dst);
9146
0
            } break;
9147
0
        case GGML_UNARY_OP_SIGMOID:
9148
0
            {
9149
0
                ggml_compute_forward_sigmoid(params, dst);
9150
0
            } break;
9151
0
        case GGML_UNARY_OP_GELU:
9152
0
            {
9153
0
                ggml_compute_forward_gelu(params, dst);
9154
0
            } break;
9155
0
        case GGML_UNARY_OP_GELU_ERF:
9156
0
            {
9157
0
                ggml_compute_forward_gelu_erf(params, dst);
9158
0
            } break;
9159
0
        case GGML_UNARY_OP_GELU_QUICK:
9160
0
            {
9161
0
                ggml_compute_forward_gelu_quick(params, dst);
9162
0
            } break;
9163
0
        case GGML_UNARY_OP_SILU:
9164
0
            {
9165
0
                ggml_compute_forward_silu(params, dst);
9166
0
            } break;
9167
0
        case GGML_UNARY_OP_HARDSWISH:
9168
0
            {
9169
0
                ggml_compute_forward_hardswish(params, dst);
9170
0
            } break;
9171
0
        case GGML_UNARY_OP_HARDSIGMOID:
9172
0
            {
9173
0
                ggml_compute_forward_hardsigmoid(params, dst);
9174
0
            } break;
9175
0
        case GGML_UNARY_OP_EXP:
9176
0
            {
9177
0
                ggml_compute_forward_exp(params, dst);
9178
0
            } break;
9179
0
        case GGML_UNARY_OP_FLOOR:
9180
0
            {
9181
0
                ggml_compute_forward_floor(params, dst);
9182
0
            } break;
9183
0
        case GGML_UNARY_OP_CEIL:
9184
0
            {
9185
0
                ggml_compute_forward_ceil(params, dst);
9186
0
            } break;
9187
0
        case GGML_UNARY_OP_ROUND:
9188
0
            {
9189
0
                ggml_compute_forward_round(params, dst);
9190
0
            } break;
9191
0
        case GGML_UNARY_OP_TRUNC:
9192
0
            {
9193
0
                ggml_compute_forward_trunc(params, dst);
9194
0
            } break;
9195
0
        case GGML_UNARY_OP_XIELU:
9196
0
            {
9197
0
                ggml_compute_forward_xielu(params, dst);
9198
0
            } break;
9199
0
        case GGML_UNARY_OP_EXPM1:
9200
0
            {
9201
0
                ggml_compute_forward_expm1(params, dst);
9202
0
            } break;
9203
0
        case GGML_UNARY_OP_SOFTPLUS:
9204
0
            {
9205
0
                ggml_compute_forward_softplus(params, dst);
9206
0
            } break;
9207
0
        default:
9208
0
            {
9209
0
                GGML_ABORT("fatal error");
9210
0
            }
9211
0
    }
9212
0
}
9213
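For reference, textbook scalar forms of a few of the unary ops dispatched above (the actual ggml kernels may use fused or approximated formulations; these one-liners and their _ref names are only illustrative):

#include <cmath>

inline float relu_ref    (float x) { return fmaxf(x, 0.0f); }                      // GGML_UNARY_OP_RELU
inline float silu_ref    (float x) { return x / (1.0f + expf(-x)); }               // GGML_UNARY_OP_SILU
inline float gelu_erf_ref(float x) { return 0.5f*x*(1.0f + erff(x*0.70710678f)); } // GGML_UNARY_OP_GELU_ERF
inline float softplus_ref(float x) { return x > 20.0f ? x : log1pf(expf(x)); }     // GGML_UNARY_OP_SOFTPLUS (stable form)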
9214
// ggml_compute_forward_glu
9215
9216
void ggml_compute_forward_glu(
9217
        const ggml_compute_params * params,
9218
0
        ggml_tensor * dst) {
9219
9220
0
    const ggml_glu_op op = ggml_get_glu_op(dst);
9221
9222
0
    switch (op) {
9223
0
        case GGML_GLU_OP_REGLU:
9224
0
            {
9225
0
                ggml_compute_forward_reglu(params, dst);
9226
0
            } break;
9227
0
        case GGML_GLU_OP_GEGLU:
9228
0
            {
9229
0
                ggml_compute_forward_geglu(params, dst);
9230
0
            } break;
9231
0
        case GGML_GLU_OP_SWIGLU:
9232
0
            {
9233
0
                ggml_compute_forward_swiglu(params, dst);
9234
0
            } break;
9235
0
        case GGML_GLU_OP_SWIGLU_OAI:
9236
0
            {
9237
0
                ggml_compute_forward_swiglu_oai(params, dst);
9238
0
            } break;
9239
0
        case GGML_GLU_OP_GEGLU_ERF:
9240
0
            {
9241
0
                ggml_compute_forward_geglu_erf(params, dst);
9242
0
            } break;
9243
0
        case GGML_GLU_OP_GEGLU_QUICK:
9244
0
            {
9245
0
                ggml_compute_forward_geglu_quick(params, dst);
9246
0
            } break;
9247
0
        default:
9248
0
            {
9249
0
                GGML_ABORT("fatal error");
9250
0
            }
9251
0
    }
9252
0
}
9253
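The GLU family dispatched above combines a value half x with a gate half g of the input (or two separate sources): REGLU uses relu(g), GEGLU gelu(g), SWIGLU silu(g). Which half acts as the gate, and the extra parameters of the _OAI/_ERF/_QUICK variants, are op-param details not reproduced in this illustrative sketch:

#include <cmath>

inline float reglu_ref (float x, float g) { return x * fmaxf(g, 0.0f); }           // x * relu(g)
inline float swiglu_ref(float x, float g) { return x * (g / (1.0f + expf(-g))); }  // x * silu(g)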
9254
// ggml_compute_forward_get_rel_pos
9255
9256
static void ggml_compute_forward_get_rel_pos_f16(
9257
        const ggml_compute_params * params,
9258
0
        ggml_tensor * dst) {
9259
0
    GGML_UNUSED(params);
9260
9261
0
    const ggml_tensor * src0 = dst->src[0];
9262
9263
    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
9264
9265
0
    GGML_TENSOR_UNARY_OP_LOCALS
9266
9267
0
    const int64_t w = ne1;
9268
9269
0
    ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
9270
0
    ggml_fp16_t * dst_data  = (ggml_fp16_t *) dst->data;
9271
9272
0
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
9273
0
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
9274
0
            const int64_t pos = (w - i1 - 1) + i2;
9275
0
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
9276
0
                dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
9277
0
            }
9278
0
        }
9279
0
    }
9280
0
}
9281
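The gather above indexes a table of 2*w - 1 relative-position embeddings with (w - 1) + q - k for every (query, key) pair, as in the SAM reference linked above (assuming i2 indexes queries and i1 keys). An f32 sketch with an illustrative helper name; the kernel itself operates on f16/bf16 data:

#include <cstdint>
#include <vector>

// out[q][k][:] = table[(w - 1) + q - k][:]
void get_rel_pos_ref(const std::vector<float> & table, // [2*w-1][C]
                     std::vector<float> & out,         // [w][w][C]
                     int64_t w, int64_t C) {
    for (int64_t q = 0; q < w; ++q)
    for (int64_t k = 0; k < w; ++k) {
        const int64_t pos = (w - k - 1) + q;           // same as pos = (w - i1 - 1) + i2 above
        for (int64_t c = 0; c < C; ++c) {
            out[(q*w + k)*C + c] = table[pos*C + c];
        }
    }
}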
9282
void ggml_compute_forward_get_rel_pos(
9283
        const ggml_compute_params * params,
9284
0
        ggml_tensor * dst) {
9285
9286
0
    const ggml_tensor * src0 = dst->src[0];
9287
9288
0
    switch (src0->type) {
9289
0
        case GGML_TYPE_F16:
9290
0
        case GGML_TYPE_BF16:
9291
0
            {
9292
0
                ggml_compute_forward_get_rel_pos_f16(params, dst);
9293
0
            } break;
9294
0
        default:
9295
0
            {
9296
0
                GGML_ABORT("fatal error");
9297
0
            }
9298
0
    }
9299
0
}
9300
9301
// ggml_compute_forward_add_rel_pos
9302
9303
static void ggml_compute_forward_add_rel_pos_f32(
9304
        const ggml_compute_params * params,
9305
0
        ggml_tensor * dst) {
9306
9307
0
    const ggml_tensor * src0 = dst->src[0];
9308
0
    const ggml_tensor * src1 = dst->src[1];
9309
0
    const ggml_tensor * src2 = dst->src[2];
9310
9311
0
    const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
9312
0
    if (!inplace) {
9313
0
        if (params->ith == 0) {
9314
0
            memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
9315
0
        }
9316
0
        ggml_barrier(params->threadpool);
9317
0
    }
9318
    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
9319
9320
0
    float * src1_data = (float *) src1->data;
9321
0
    float * src2_data = (float *) src2->data;
9322
0
    float * dst_data  = (float *) dst->data;
9323
9324
0
    const int64_t ne10 = src1->ne[0];
9325
0
    const int64_t ne11 = src1->ne[1];
9326
0
    const int64_t ne12 = src1->ne[2];
9327
0
    const int64_t ne13 = src1->ne[3];
9328
9329
0
    const int ith = params->ith;
9330
0
    const int nth = params->nth;
9331
9332
    // total patches in dst
9333
0
    const int np = ne13;
9334
9335
    // patches per thread
9336
0
    const int dp = (np + nth - 1)/nth;
9337
9338
    // patch range for this thread
9339
0
    const int ip0 = dp*ith;
9340
0
    const int ip1 = MIN(ip0 + dp, np);
9341
9342
0
    for (int64_t i13 = ip0; i13 < ip1; ++i13) {
9343
0
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
9344
0
            for (int64_t i11 = 0; i11 < ne11; ++i11) {
9345
0
                const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
9346
0
                for (int64_t i10 = 0; i10 < ne10; ++i10) {
9347
0
                    const int64_t jp0  = jp1 + i10;
9348
0
                    const float src1_e = src1_data[jp0];
9349
0
                    const float src2_e = src2_data[jp0];
9350
9351
0
                    const int64_t jdh = jp0 * ne10;
9352
0
                    const int64_t jdw = jdh - (ne10 - 1) * i10;
9353
9354
0
                    for (int64_t j = 0; j < ne10; ++j) {
9355
0
                        dst_data[jdh + j     ] += src2_e;
9356
0
                        dst_data[jdw + j*ne10] += src1_e;
9357
0
                    }
9358
0
                }
9359
0
            }
9360
0
        }
9361
0
    }
9362
0
}
9363
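The jdh/jdw arithmetic above is a broadcast add over an n x n block: viewing dst as [outer, n, n] with n = ne10, src2[.., i] is added to row i and src1[.., i] to column i of each block, i.e. the decomposed relative-position add attn += rel_h[..., :, None] + rel_w[..., None, :] from the SAM reference. A single-threaded sketch (the helper name and which of src1/src2 plays rel_h vs rel_w are illustrative, not asserted):

#include <cstdint>
#include <vector>

// dst viewed as [outer][n][n]; src1 and src2 as [outer][n].
void add_rel_pos_ref(std::vector<float> & dst,
                     const std::vector<float> & src1,
                     const std::vector<float> & src2,
                     int64_t outer, int64_t n) {
    for (int64_t o = 0; o < outer; ++o)
    for (int64_t i = 0; i < n; ++i) {
        const float e1 = src1[o*n + i];
        const float e2 = src2[o*n + i];
        for (int64_t j = 0; j < n; ++j) {
            dst[(o*n + i)*n + j] += e2;   // row i of the block,    same as dst_data[jdh + j]
            dst[(o*n + j)*n + i] += e1;   // column i of the block, same as dst_data[jdw + j*ne10]
        }
    }
}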
9364
void ggml_compute_forward_add_rel_pos(
9365
        const ggml_compute_params * params,
9366
0
        ggml_tensor * dst) {
9367
9368
0
    const ggml_tensor * src0 = dst->src[0];
9369
9370
0
    switch (src0->type) {
9371
0
        case GGML_TYPE_F32:
9372
0
            {
9373
0
                ggml_compute_forward_add_rel_pos_f32(params, dst);
9374
0
            } break;
9375
0
        default:
9376
0
            {
9377
0
                GGML_ABORT("fatal error");
9378
0
            }
9379
0
    }
9380
0
}
9381
9382
// ggml_compute_forward_rwkv_wkv6
9383
9384
static void ggml_compute_forward_rwkv_wkv6_f32(
9385
        const ggml_compute_params * params,
9386
0
        ggml_tensor * dst) {
9387
0
    const int64_t T = dst->src[1]->ne[2];
9388
0
    const int64_t C = dst->ne[0];
9389
0
    const int64_t HEADS = dst->src[1]->ne[1];
9390
0
    const int64_t n_seqs = dst->src[5]->ne[1];
9391
0
    const int64_t head_size = C / HEADS;
9392
9393
0
    float * dst_data = (float *) dst->data;
9394
0
    float * state = ((float *) dst->data) + C * T;
9395
9396
0
    const int ith = params->ith;
9397
0
    const int nth = params->nth;
9398
9399
0
    if (ith >= HEADS) {
9400
0
        return;
9401
0
    }
9402
9403
0
    const int h_start = (HEADS * ith) / nth;
9404
0
    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
9405
0
                (HEADS * (ith + 1)) / nth : HEADS;
9406
9407
0
    float * k =          (float *) dst->src[0]->data;
9408
0
    float * v =          (float *) dst->src[1]->data;
9409
0
    float * r =          (float *) dst->src[2]->data;
9410
0
    float * time_faaaa = (float *) dst->src[3]->data;
9411
0
    float * time_decay = (float *) dst->src[4]->data;
9412
9413
0
    size_t t_stride = HEADS * head_size; // Same as C
9414
9415
0
    size_t h_stride = C / HEADS;
9416
0
    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
9417
0
    size_t h_stride_2d = head_size * head_size;
9418
9419
0
    if (ith == 0) {
9420
0
        memset(dst_data, 0, T * C * sizeof(float));
9421
0
    }
9422
0
    ggml_barrier(params->threadpool);
9423
9424
9425
0
    #if defined(__AVX__) && !defined(__AVX512F__)
9426
0
        #define GGML_F32X GGML_F32x8
9427
0
        #define GGML_F32X_SET1 GGML_F32x8_SET1
9428
0
        #define GGML_F32X_LOAD GGML_F32x8_LOAD
9429
0
        #define GGML_F32X_STORE GGML_F32x8_STORE
9430
0
        #define GGML_F32X_MUL GGML_F32x8_MUL
9431
0
        #define GGML_F32X_FMA GGML_F32x8_FMA
9432
0
        #define WKV_VECTOR_SIZE 8
9433
    #elif defined(__AVX512F__)
9434
        #define GGML_F32X GGML_F32x16
9435
        #define GGML_F32X_SET1 GGML_F32x16_SET1
9436
        #define GGML_F32X_LOAD GGML_F32x16_LOAD
9437
        #define GGML_F32X_STORE GGML_F32x16_STORE
9438
        #define GGML_F32X_MUL GGML_F32x16_MUL
9439
        #define GGML_F32X_FMA GGML_F32x16_FMA
9440
        #define WKV_VECTOR_SIZE 16
9441
    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
9442
        #define GGML_F32X GGML_F32xt
9443
        #define GGML_F32X_SET1 GGML_F32xt_SET1
9444
        #define GGML_F32X_LOAD GGML_F32xt_LOAD
9445
        #define GGML_F32X_STORE GGML_F32xt_STORE
9446
        #define GGML_F32X_MUL GGML_F32xt_MUL
9447
        #define GGML_F32X_FMA GGML_F32xt_FMA
9448
        #define WKV_VECTOR_SIZE 8
9449
    #elif defined(__ARM_NEON) && defined(__aarch64__)
9450
        #define GGML_F32X GGML_F32x4
9451
        #define GGML_F32X_SET1 GGML_F32x4_SET1
9452
        #define GGML_F32X_LOAD GGML_F32x4_LOAD
9453
        #define GGML_F32X_STORE GGML_F32x4_STORE
9454
        #define GGML_F32X_MUL GGML_F32x4_MUL
9455
        #define GGML_F32X_FMA GGML_F32x4_FMA
9456
        #define WKV_VECTOR_SIZE 4
9457
    #endif
9458
9459
0
    #ifdef WKV_VECTOR_SIZE
9460
0
        int wkv_vector_size;
9461
        #if defined(__ARM_FEATURE_SVE)
9462
            wkv_vector_size = svcntw();
9463
        #else
9464
0
            wkv_vector_size = WKV_VECTOR_SIZE;
9465
0
        #endif
9466
0
        const int64_t vec_count = head_size / wkv_vector_size;
9467
9468
0
        for (int64_t t = 0; t < T; t++) {
9469
0
            size_t t_offset = t * t_stride;
9470
0
            size_t state_offset = head_size * C * (t / (T / n_seqs));
9471
0
            float * state_cur = state + state_offset;
9472
0
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
9473
9474
0
            for (int64_t h = h_start; h < h_end; h++) {
9475
0
                size_t h_offset = h * h_stride;
9476
0
                size_t t_h_offset = t_offset + h_offset;
9477
0
                size_t h_2d_offset = h * h_stride_2d;
9478
9479
0
                for (int64_t i = 0; i < head_size; i++) {
9480
0
                    size_t t_h_i_offset = t_h_offset + i;
9481
0
                    size_t h_i_offset = h_offset + i;
9482
0
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
9483
9484
0
                    float k_val = k[t_h_i_offset];
9485
0
                    float r_val = r[t_h_i_offset];
9486
0
                    float time_faaaa_val = time_faaaa[h_i_offset];
9487
0
                    float time_decay_val = time_decay[t_h_i_offset];
9488
9489
                    // Broadcast scalar values to vectors
9490
0
                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
9491
0
                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
9492
0
                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
9493
0
                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
9494
9495
0
                    for (int64_t j = 0; j < vec_count; j++) {
9496
0
                        size_t base_j = j * wkv_vector_size;
9497
0
                        size_t t_h_j_offset = t_h_offset + base_j;
9498
0
                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
9499
9500
                        // Load one vector width of elements at once
9501
0
                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
9502
0
                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
9503
0
                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
9504
9505
                        // Compute kv = v * k
9506
0
                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
9507
9508
                        // Compute temp = kv * time_faaaa + prev_state
9509
0
                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
9510
9511
                        // Update dst: dst += temp * r
9512
0
                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
9513
0
                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
9514
9515
                        // Update state: state = prev_state * time_decay + kv
9516
0
                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
9517
0
                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
9518
0
                    }
9519
9520
                    // Handle remaining elements; this will normally not be used.
9521
0
                    for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
9522
0
                        size_t t_h_j_offset = t_h_offset + j;
9523
0
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
9524
0
                        float v_val = v[t_h_j_offset];
9525
0
                        float kv_val = v_val * k_val;
9526
0
                        float prev_state_val = state_prev[h_2d_i_j_offset];
9527
0
                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
9528
0
                        dst_data[t_h_j_offset] += temp_val * r_val;
9529
0
                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
9530
0
                    }
9531
0
                }
9532
0
            }
9533
0
        }
9534
9535
    #else
9536
        // basically fused operations:
9537
        // dst = r @ (time_faaaa * (k @ v) + state),
9538
        // state = time_decay * state + (k @ v),
9539
        // applied recurrently, token by token
9540
        for (int64_t t = 0; t < T; t++) {
9541
            size_t t_offset = t * t_stride;
9542
            size_t state_offset = head_size * C * (t / (T / n_seqs));
9543
            float * state_cur = state + state_offset;
9544
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
9545
9546
            for (int64_t h = h_start; h < h_end; h++) {
9547
                size_t h_offset = h * h_stride;
9548
                size_t t_h_offset = t_offset + h_offset;
9549
                size_t h_2d_offset = h * h_stride_2d;
9550
9551
                for (int64_t i = 0; i < head_size; i++) {
9552
                    size_t t_h_i_offset = t_h_offset + i;
9553
                    size_t h_i_offset = h_offset + i;
9554
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
9555
9556
                    float k_val = k[t_h_i_offset];
9557
                    float r_val = r[t_h_i_offset];
9558
                    float time_faaaa_val = time_faaaa[h_i_offset];
9559
                    // RWKV v6: different time_decay for each token.
9560
                    float time_decay_val = time_decay[t_h_i_offset];
9561
9562
                    for (int64_t j = 0; j < head_size; j++) {
9563
                        size_t t_h_j_offset = t_h_offset + j;
9564
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
9565
9566
                        float v_val = v[t_h_j_offset];
9567
                        float kv_val = v_val * k_val;
9568
                        float prev_state_val = state_prev[h_2d_i_j_offset];
9569
                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
9570
                        dst_data[t_h_j_offset] += temp_val * r_val;
9571
                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
9572
                    }
9573
                }
9574
            }
9575
        }
9576
    #endif
9577
0
}
9578
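Per head and per token, the WKV6 recurrence above (both the SIMD and scalar branches) computes out[j] = sum_i r[i] * (u[i]*k[i]*v[j] + S[i][j]) and then S[i][j] = w[i]*S[i][j] + k[i]*v[j], where u is time_faaaa (the "bonus") and w the per-token time_decay. A scalar single-step sketch with illustrative names:

#include <cstdint>
#include <vector>

// One token of WKV6 for a single head; S is the n x n state, out accumulates
// the head output for this token (callers zero it first, as the kernel does).
void wkv6_step_ref(std::vector<float> & S, std::vector<float> & out,
                   const std::vector<float> & r, const std::vector<float> & k,
                   const std::vector<float> & v, const std::vector<float> & u,
                   const std::vector<float> & w, int64_t n) {
    for (int64_t i = 0; i < n; ++i)
    for (int64_t j = 0; j < n; ++j) {
        const float kv = k[i] * v[j];
        out[j]    += (kv * u[i] + S[i*n + j]) * r[i];
        S[i*n + j] = S[i*n + j] * w[i] + kv;
    }
}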
9579
9580
void ggml_compute_forward_rwkv_wkv6(
9581
        const ggml_compute_params * params,
9582
0
        ggml_tensor * dst) {
9583
9584
0
    const ggml_tensor * src0 = dst->src[0];
9585
9586
0
    switch (src0->type) {
9587
0
        case GGML_TYPE_F32:
9588
0
            {
9589
0
                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
9590
0
            } break;
9591
0
        default:
9592
0
            {
9593
0
                GGML_ABORT("fatal error");
9594
0
            }
9595
0
    }
9596
0
}
9597
9598
// ggml_compute_forward_gla
9599
9600
static void ggml_compute_forward_gla_f32(
9601
        const ggml_compute_params * params,
9602
0
        ggml_tensor * dst) {
9603
0
    const int64_t T = dst->src[1]->ne[2];
9604
0
    const int64_t C = dst->ne[0];
9605
0
    const int64_t HEADS = dst->src[1]->ne[1];
9606
0
    const int64_t n_seqs = dst->src[4]->ne[1];
9607
0
    const int64_t head_size = C / HEADS;
9608
0
    const float scale = ggml_get_op_params_f32(dst, 0);
9609
9610
0
    float * dst_data = (float *) dst->data;
9611
0
    float * state = ((float *) dst->data) + C * T;
9612
9613
0
    const int ith = params->ith;
9614
0
    const int nth = params->nth;
9615
9616
0
    if (ith >= HEADS) {
9617
0
        return;
9618
0
    }
9619
9620
0
    const int h_start = (HEADS * ith) / nth;
9621
0
    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
9622
0
                (HEADS * (ith + 1)) / nth : HEADS;
9623
9624
0
    float * k = (float *) dst->src[0]->data;
9625
0
    float * v = (float *) dst->src[1]->data;
9626
0
    float * q = (float *) dst->src[2]->data;
9627
0
    float * g = (float *) dst->src[3]->data;
9628
9629
0
    size_t t_stride = HEADS * head_size; // Same as C
9630
9631
0
    size_t h_stride = C / HEADS;
9632
0
    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
9633
0
    size_t h_stride_2d = head_size * head_size;
9634
9635
0
    if (ith == 0) {
9636
0
        memset(dst_data, 0, T * C * sizeof(float));
9637
0
    }
9638
0
    ggml_barrier(params->threadpool);
9639
9640
9641
0
    #if defined(__AVX__) && !defined(__AVX512F__)
9642
0
        #define GGML_F32X GGML_F32x8
9643
0
        #define GGML_F32X_SET1 GGML_F32x8_SET1
9644
0
        #define GGML_F32X_LOAD GGML_F32x8_LOAD
9645
0
        #define GGML_F32X_STORE GGML_F32x8_STORE
9646
0
        #define GGML_F32X_MUL GGML_F32x8_MUL
9647
0
        #define GGML_F32X_FMA GGML_F32x8_FMA
9648
0
        #define GLA_VECTOR_SIZE 8
9649
    #elif defined(__AVX512F__)
9650
        #define GGML_F32X GGML_F32x16
9651
        #define GGML_F32X_SET1 GGML_F32x16_SET1
9652
        #define GGML_F32X_LOAD GGML_F32x16_LOAD
9653
        #define GGML_F32X_STORE GGML_F32x16_STORE
9654
        #define GGML_F32X_MUL GGML_F32x16_MUL
9655
        #define GGML_F32X_FMA GGML_F32x16_FMA
9656
        #define GLA_VECTOR_SIZE 16
9657
    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
9658
        #define GGML_F32X GGML_F32xt
9659
        #define GGML_F32X_SET1 GGML_F32xt_SET1
9660
        #define GGML_F32X_LOAD GGML_F32xt_LOAD
9661
        #define GGML_F32X_STORE GGML_F32xt_STORE
9662
        #define GGML_F32X_MUL GGML_F32xt_MUL
9663
        #define GGML_F32X_FMA GGML_F32xt_FMA
9664
        #define GLA_VECTOR_SIZE 8
9665
    #elif defined(__ARM_NEON) && defined(__aarch64__)
9666
        #define GGML_F32X GGML_F32x4
9667
        #define GGML_F32X_SET1 GGML_F32x4_SET1
9668
        #define GGML_F32X_LOAD GGML_F32x4_LOAD
9669
        #define GGML_F32X_STORE GGML_F32x4_STORE
9670
        #define GGML_F32X_MUL GGML_F32x4_MUL
9671
        #define GGML_F32X_FMA GGML_F32x4_FMA
9672
        #define GLA_VECTOR_SIZE 4
9673
    #endif
9674
9675
0
    #ifdef GLA_VECTOR_SIZE
9676
0
        int gla_vector_size;
9677
        #if defined(__ARM_FEATURE_SVE)
9678
            gla_vector_size = svcntw();
9679
        #else
9680
0
            gla_vector_size = GLA_VECTOR_SIZE;
9681
0
        #endif
9682
0
        const int64_t vec_count = head_size / gla_vector_size;
9683
9684
0
        for (int64_t t = 0; t < T; t++) {
9685
0
            size_t t_offset = t * t_stride;
9686
0
            size_t state_offset = head_size * C * (t / (T / n_seqs));
9687
0
            float * state_cur = state + state_offset;
9688
0
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
9689
9690
0
            for (int64_t h = h_start; h < h_end; h++) {
9691
0
                size_t h_offset = h * h_stride;
9692
0
                size_t t_h_offset = t_offset + h_offset;
9693
0
                size_t h_2d_offset = h * h_stride_2d;
9694
9695
0
                for (int64_t i = 0; i < head_size; i++) {
9696
0
                    size_t t_h_i_offset = t_h_offset + i;
9697
0
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
9698
9699
0
                    float k_val = k[t_h_i_offset];
9700
0
                    float q_val = q[t_h_i_offset] * scale;
9701
0
                    float g_val = g[t_h_i_offset];
9702
9703
                    // Broadcast scalar values to vectors
9704
0
                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
9705
0
                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
9706
0
                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
9707
9708
0
                    for (int64_t j = 0; j < vec_count; j++) {
9709
0
                        size_t base_j = j * gla_vector_size;
9710
0
                        size_t t_h_j_offset = t_h_offset + base_j;
9711
0
                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
9712
9713
                        // Load one vector width of elements at once
9714
0
                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
9715
0
                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
9716
0
                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
9717
9718
                        // Compute kv = v * k
9719
0
                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
9720
9721
                        // Compute temp = prev_state * g + kv
9722
0
                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
9723
9724
                        // Update dst: dst += temp * q
9725
0
                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
9726
0
                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
9727
9728
                        // Update state
9729
0
                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
9730
0
                    }
9731
9732
                    // Handle remaining elements; this will normally not be used.
9733
0
                    for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
9734
0
                        size_t t_h_j_offset = t_h_offset + j;
9735
0
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
9736
0
                        float v_val = v[t_h_j_offset];
9737
0
                        float kv_val = v_val * k_val;
9738
0
                        float prev_state_val = state_prev[h_2d_i_j_offset];
9739
0
                        float temp_val = kv_val + prev_state_val * g_val;
9740
0
                        dst_data[t_h_j_offset] += temp_val * q_val;
9741
0
                        state_cur[h_2d_i_j_offset] = temp_val;
9742
0
                    }
9743
0
                }
9744
0
            }
9745
0
        }
9746
9747
    #else
9748
        for (int64_t t = 0; t < T; t++) {
9749
            size_t t_offset = t * t_stride;
9750
            size_t state_offset = head_size * C * (t / (T / n_seqs));
9751
            float * state_cur = state + state_offset;
9752
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
9753
9754
            for (int64_t h = h_start; h < h_end; h++) {
9755
                size_t h_offset = h * h_stride;
9756
                size_t t_h_offset = t_offset + h_offset;
9757
                size_t h_2d_offset = h * h_stride_2d;
9758
9759
                for (int64_t i = 0; i < head_size; i++) {
9760
                    size_t t_h_i_offset = t_h_offset + i;
9761
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
9762
9763
                    float k_val = k[t_h_i_offset];
9764
                    float q_val = q[t_h_i_offset] * scale;
9765
                    float g_val = g[t_h_i_offset];
9766
9767
                    for (int64_t j = 0; j < head_size; j++) {
9768
                        size_t t_h_j_offset = t_h_offset + j;
9769
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
9770
9771
                        float v_val = v[t_h_j_offset];
9772
                        float kv_val = v_val * k_val;
9773
                        float prev_state_val = state_prev[h_2d_i_j_offset];
9774
                        float temp_val = prev_state_val * g_val + kv_val;
9775
                        dst_data[t_h_j_offset] += temp_val * q_val;
9776
                        state_cur[h_2d_i_j_offset] = temp_val;
9777
                    }
9778
                }
9779
            }
9780
        }
9781
    #endif
9782
0
}
9783
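The GLA recurrence above has the same shape with the bonus term removed and a per-channel gate in place of the decay: S[i][j] = g[i]*S[i][j] + k[i]*v[j], out[j] += scale*q[i]*S[i][j]. Scalar single-step sketch (illustrative names, single head):

#include <cstdint>
#include <vector>

void gla_step_ref(std::vector<float> & S, std::vector<float> & out,
                  const std::vector<float> & q, const std::vector<float> & k,
                  const std::vector<float> & v, const std::vector<float> & g,
                  float scale, int64_t n) {
    for (int64_t i = 0; i < n; ++i)
    for (int64_t j = 0; j < n; ++j) {
        S[i*n + j] = S[i*n + j] * g[i] + k[i] * v[j];
        out[j]    += S[i*n + j] * q[i] * scale;
    }
}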
9784
9785
void ggml_compute_forward_gla(
9786
        const ggml_compute_params * params,
9787
0
        ggml_tensor * dst) {
9788
9789
0
    const ggml_tensor * src0 = dst->src[0];
9790
9791
0
    switch (src0->type) {
9792
0
        case GGML_TYPE_F32:
9793
0
            {
9794
0
                ggml_compute_forward_gla_f32(params, dst);
9795
0
            } break;
9796
0
        default:
9797
0
            {
9798
0
                GGML_ABORT("fatal error");
9799
0
            }
9800
0
    }
9801
0
}
9802
9803
0
static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
9804
0
    const struct ggml_tensor * src0 = dst->src[0];  // A (lower triangular)
9805
0
    const struct ggml_tensor * src1 = dst->src[1];  // B (RHS)
9806
9807
0
    GGML_TENSOR_BINARY_OP_LOCALS;
9808
9809
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
9810
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
9811
0
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
9812
9813
0
    GGML_ASSERT(ne00 == ne01); // A must be square
9814
0
    GGML_ASSERT(ne0  == ne10); // solution cols == B cols
9815
0
    GGML_ASSERT(ne1  == ne11); // solution rows == B rows
9816
9817
0
    GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
9818
0
    GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
9819
9820
0
    const int ith = params->ith;
9821
0
    const int nth = params->nth;
9822
9823
0
    const int64_t k = ne10;   // number of RHS columns
9824
0
    const int64_t n = ne11;   // A is n×n
9825
0
    const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
9826
9827
    // chunks per thread
9828
0
    const int64_t dr = (nr + nth - 1)/nth;
9829
9830
    // chunk range for this thread
9831
0
    const int64_t ir0 = dr*ith;
9832
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
9833
9834
0
    const float * A = (const float *) src0->data;  // [n, n, B1, B2]
9835
0
    const float * B = (const float *) src1->data;  // [n, k, B1, B2]
9836
0
          float * X = (      float *) dst->data;   // [n, k, B1, B2]
9837
9838
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
9839
0
        const int64_t i03 = ir/(ne02*k);
9840
0
        const int64_t i02 = (ir - i03*ne02*k)/k;
9841
0
        const int64_t i01 = (ir - i03*ne02*k - i02*k);
9842
9843
0
        const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
9844
0
        const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
9845
9846
0
        float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
9847
9848
0
        for (int64_t i00 = 0; i00 < n; ++i00) {
9849
0
            float sum = 0.0f;
9850
0
            for (int64_t t = 0; t < i00; ++t) {
9851
0
                sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
9852
0
            }
9853
9854
0
            const float diag = A_batch[i00 * n + i00];
9855
0
            assert(diag != 0.0f && "Zero diagonal in triangular matrix");
9856
9857
0
            X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
9858
0
        }
9859
0
    }
9860
0
}
9861
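The inner i00/t loops above are plain forward substitution: x[i] = (b[i] - sum_{t<i} A[i][t]*x[t]) / A[i][i], applied independently to every right-hand-side column and batch. A single-column sketch (illustrative helper name):

#include <cassert>
#include <vector>

// Solve L*x = b for a lower-triangular n x n matrix L (row-major).
std::vector<float> solve_lower_tri_ref(const std::vector<float> & L,
                                       const std::vector<float> & b, int n) {
    std::vector<float> x(n);
    for (int i = 0; i < n; ++i) {
        float sum = 0.0f;
        for (int t = 0; t < i; ++t) {
            sum += L[i*n + t] * x[t];
        }
        assert(L[i*n + i] != 0.0f && "zero diagonal in triangular matrix");
        x[i] = (b[i] - sum) / L[i*n + i];
    }
    return x;
}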
9862
0
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
9863
0
    const ggml_tensor * src0 = dst->src[0];
9864
0
    const ggml_tensor * src1 = dst->src[1];
9865
9866
0
    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
9867
0
        ggml_compute_forward_solve_tri_f32(params, dst);
9868
0
    } else {
9869
0
        GGML_ABORT("fatal error");
9870
0
    }
9871
0
}
9872
9873
// ggml_compute_forward_rwkv_wkv7
9874
9875
static void ggml_compute_forward_rwkv_wkv7_f32(
9876
        const ggml_compute_params * params,
9877
0
        ggml_tensor * dst) {
9878
0
    const int64_t T = dst->src[1]->ne[2];
9879
0
    const int64_t C = dst->ne[0];
9880
0
    const int64_t HEADS = dst->src[1]->ne[1];
9881
0
    const int64_t n_seqs = dst->src[6]->ne[1];
9882
0
    const int64_t head_size = C / HEADS;
9883
9884
0
    float * dst_data = (float *) dst->data;
9885
0
    float * state = ((float *) dst->data) + C * T;
9886
9887
0
    const int ith = params->ith;
9888
0
    const int nth = params->nth;
9889
9890
0
    if (ith >= HEADS) {
9891
0
        return;
9892
0
    }
9893
9894
0
    const int h_start = (HEADS * ith) / nth;
9895
0
    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
9896
0
                (HEADS * (ith + 1)) / nth : HEADS;
9897
9898
0
    float * r = (float *) dst->src[0]->data;
9899
0
    float * w = (float *) dst->src[1]->data;
9900
0
    float * k = (float *) dst->src[2]->data;
9901
0
    float * v = (float *) dst->src[3]->data;
9902
0
    float * a = (float *) dst->src[4]->data;
9903
0
    float * b = (float *) dst->src[5]->data;
9904
9905
0
    int64_t t_stride = HEADS * head_size; // Same as C
9906
9907
0
    int64_t h_stride = C / HEADS;
9908
0
    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
9909
0
    int64_t h_stride_2d = head_size * head_size;
9910
9911
0
    #if defined(GGML_SIMD)
9912
        #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
9913
            // route to the scalar implementation       // TODO: write SVE and RVV code
9914
            for (int64_t t = 0; t < T; t++) {
9915
                int64_t t_offset = t * t_stride;
9916
                int64_t state_offset = head_size * C * (t / (T / n_seqs));
9917
                float * state_cur = state + state_offset;
9918
                float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
9919
9920
                for (int64_t h = h_start; h < h_end; h++) {
9921
                    int64_t h_offset = h * h_stride;
9922
                    int64_t t_h_offset = t_offset + h_offset;
9923
                    int64_t h_2d_offset = h * h_stride_2d;
9924
9925
                    for (int64_t i = 0; i < head_size; i++) {
9926
                        int64_t t_h_i_offset = t_h_offset + i;
9927
                        int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
9928
9929
                        float v_val = v[t_h_i_offset];
9930
9931
                        float sa = 0, result = 0;
9932
                        for (int64_t j = 0; j < head_size; j++) {
9933
                            sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
9934
                        }
9935
9936
                        for (int64_t j = 0; j < head_size; j++) {
9937
                            int64_t t_h_j_offset = t_h_offset + j;
9938
                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
9939
9940
                            float r_val = r[t_h_j_offset];
9941
                            float w_val = w[t_h_j_offset];
9942
                            float k_val = k[t_h_j_offset];
9943
                            float b_val = b[t_h_j_offset];
9944
                            float kv_val = v_val * k_val;
9945
                            float prev_state_val = state_prev[h_2d_i_j_offset];
9946
                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
9947
                            result += state_cur[h_2d_i_j_offset] * r_val;
9948
                        }
9949
                        dst_data[t_h_i_offset] = result;
9950
                    }
9951
                }
9952
            }
9953
        #else
9954
0
            for (int64_t t = 0; t < T; t++) {
9955
0
                int64_t t_offset = t * t_stride;
9956
0
                int64_t state_offset = head_size * C * (t / (T / n_seqs));
9957
0
                float * state_cur = state + state_offset;
9958
0
                float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
9959
9960
0
                for (int64_t h = h_start; h < h_end; h++) {
9961
0
                    int64_t h_offset = h * h_stride;
9962
0
                    int64_t t_h_offset = t_offset + h_offset;
9963
0
                    int64_t h_2d_offset = h * h_stride_2d;
9964
9965
0
                    for (int64_t ii = 0; ii < head_size; ii++) {
9966
0
                        int64_t t_h_i_offset = t_h_offset + ii;
9967
0
                        int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
9968
9969
0
                        GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
9970
9971
0
                        float sa = 0;
9972
0
                        {
9973
0
                            GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
9974
0
                            GGML_F32_VEC ax[GGML_F32_ARR];
9975
0
                            GGML_F32_VEC ay[GGML_F32_ARR];
9976
0
                            for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
9977
0
                                for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
9978
0
                                    ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
9979
0
                                    ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
9980
0
                                    sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
9981
0
                                }
9982
0
                            }
9983
0
                            GGML_F32_VEC_REDUCE(sa, sum);
9984
0
                        }
9985
9986
0
                        GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
9987
9988
0
                        int64_t j = 0;
9989
0
                        GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
9990
0
                        for (; j < head_size; j += GGML_F32_STEP) {
9991
0
                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
9992
0
                                int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
9993
0
                                int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
9994
9995
0
                                GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
9996
0
                                GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
9997
0
                                GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
9998
0
                                GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
9999
10000
0
                                k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
10001
10002
0
                                GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
10003
                                // kv + s * decay + sa * b
10004
0
                                state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
10005
0
                                state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
10006
0
                                GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
10007
10008
0
                                result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
10009
0
                            }
10010
0
                        }
10011
0
                        GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
10012
10013
                        // There shouldn't be any left-overs, but handle them just in case.
10014
0
                        for (; j < head_size; j++) {
10015
0
                            int64_t t_h_j_offset = t_h_offset + j;
10016
0
                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
10017
10018
0
                            float r_val = r[t_h_j_offset];
10019
0
                            float w_val = w[t_h_j_offset];
10020
0
                            float k_val = k[t_h_j_offset];
10021
0
                            float b_val = b[t_h_j_offset];
10022
0
                            float kv_val = v[t_h_i_offset] * k_val;
10023
10024
0
                            float prev_state_val = state_prev[h_2d_i_j_offset];
10025
0
                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
10026
0
                            dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
10027
0
                        }
10028
0
                    }
10029
0
                }
10030
0
            }
10031
0
        #endif
10032
    #else
10033
        for (int64_t t = 0; t < T; t++) {
10034
            int64_t t_offset = t * t_stride;
10035
            int64_t state_offset = head_size * C * (t / (T / n_seqs));
10036
            float * state_cur = state + state_offset;
10037
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
10038
10039
            for (int64_t h = h_start; h < h_end; h++) {
10040
                int64_t h_offset = h * h_stride;
10041
                int64_t t_h_offset = t_offset + h_offset;
10042
                int64_t h_2d_offset = h * h_stride_2d;
10043
10044
                for (int64_t i = 0; i < head_size; i++) {
10045
                    int64_t t_h_i_offset = t_h_offset + i;
10046
                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
10047
10048
                    float v_val = v[t_h_i_offset];
10049
10050
                    float sa = 0, result = 0;
10051
                    for (int64_t j = 0; j < head_size; j++) {
10052
                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
10053
                    }
10054
10055
                    for (int64_t j = 0; j < head_size; j++) {
10056
                        int64_t t_h_j_offset = t_h_offset + j;
10057
                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
10058
10059
                        float r_val = r[t_h_j_offset];
10060
                        float w_val = w[t_h_j_offset];
10061
                        float k_val = k[t_h_j_offset];
10062
                        float b_val = b[t_h_j_offset];
10063
                        float kv_val = v_val * k_val;
10064
                        float prev_state_val = state_prev[h_2d_i_j_offset];
10065
                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
10066
                        result += state_cur[h_2d_i_j_offset] * r_val;
10067
                    }
10068
                    dst_data[t_h_i_offset] = result;
10069
                }
10070
            }
10071
        }
10072
    #endif
10073
0
}
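
The scalar fallback, the SIMD path, and the non-SIMD path above all compute the same per-head recurrence. As a reading aid, here it is in math form, with S the head_size x head_size state matrix and r, w, k, v, a, b the per-token, per-head vectors loaded above (the notation is mine, not from the source):

    sa_i    = \sum_j a_j \, S^{prev}_{i,j}
    S_{i,j} = S^{prev}_{i,j} \, w_j + v_i \, k_j + sa_i \, b_j
    out_i   = \sum_j S_{i,j} \, r_j

In the SIMD branch, the two GGML_F32_VEC_FMA calls implement the state update in two fused steps (kv + s*decay, then + sa*b), and the trailing scalar loop only runs when head_size is not a multiple of GGML_F32_STEP.
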
10074
10075
10076
void ggml_compute_forward_rwkv_wkv7(
10077
        const ggml_compute_params * params,
10078
0
        ggml_tensor * dst) {
10079
10080
0
    const ggml_tensor * src0 = dst->src[0];
10081
10082
0
    switch (src0->type) {
10083
0
        case GGML_TYPE_F32:
10084
0
            {
10085
0
                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
10086
0
            } break;
10087
0
        default:
10088
0
            {
10089
0
                GGML_ABORT("fatal error");
10090
0
            }
10091
0
    }
10092
0
}
10093
10094
// ggml_compute_forward_map_custom1
10095
10096
void ggml_compute_forward_map_custom1(
10097
        const ggml_compute_params * params,
10098
0
              ggml_tensor * dst) {
10099
10100
0
    const ggml_tensor * a = dst->src[0];
10101
10102
0
    struct ggml_map_custom1_op_params p;
10103
0
    memcpy(&p, dst->op_params, sizeof(p));
10104
10105
0
    p.fun(dst, a, params->ith, params->nth, p.userdata);
10106
0
}
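
For context, this dispatcher (and the map_custom2/map_custom3/custom variants that follow) only unpacks op_params and forwards to a user callback registered at graph-build time. A minimal sketch of such a callback, assuming the ggml_map_custom1()/ggml_custom1_op_t declarations from ggml.h, contiguous F32 tensors, and the same row-splitting convention the other ops in this file use (the names my_scale_op and scale are hypothetical):

    // Scales every element of `a` by the value passed through userdata.
    // Assumes contiguous F32 data, so row ir lives at data + ir*nb[1].
    static void my_scale_op(struct ggml_tensor * dst, const struct ggml_tensor * a,
                            int ith, int nth, void * userdata) {
        const float s = *(const float *) userdata;

        const int64_t nr  = ggml_nrows(a);
        const int64_t dr  = (nr + nth - 1)/nth;                // rows per thread
        const int64_t ir0 = dr*ith;                            // row range for this thread
        const int64_t ir1 = ir0 + dr < nr ? ir0 + dr : nr;

        for (int64_t ir = ir0; ir < ir1; ++ir) {
            const float * src = (const float *) ((const char *) a->data   + ir*a->nb[1]);
            float       * out = (float       *) ((char       *) dst->data + ir*dst->nb[1]);
            for (int64_t i = 0; i < a->ne[0]; ++i) {
                out[i] = s*src[i];
            }
        }
    }

    // At graph-build time (hypothetical usage):
    //   static float scale = 2.0f;
    //   struct ggml_tensor * y = ggml_map_custom1(ctx, x, my_scale_op, GGML_N_TASKS_MAX, &scale);
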
10107
10108
// ggml_compute_forward_map_custom2
10109
10110
void ggml_compute_forward_map_custom2(
10111
        const ggml_compute_params * params,
10112
0
              ggml_tensor * dst) {
10113
10114
0
    const ggml_tensor * a = dst->src[0];
10115
0
    const ggml_tensor * b = dst->src[1];
10116
10117
0
    struct ggml_map_custom2_op_params p;
10118
0
    memcpy(&p, dst->op_params, sizeof(p));
10119
10120
0
    p.fun(dst, a, b, params->ith, params->nth, p.userdata);
10121
0
}
10122
10123
// ggml_compute_forward_map_custom3
10124
10125
void ggml_compute_forward_map_custom3(
10126
        const ggml_compute_params * params,
10127
0
              ggml_tensor * dst) {
10128
10129
0
    const ggml_tensor * a = dst->src[0];
10130
0
    const ggml_tensor * b = dst->src[1];
10131
0
    const ggml_tensor * c = dst->src[2];
10132
10133
0
    struct ggml_map_custom3_op_params p;
10134
0
    memcpy(&p, dst->op_params, sizeof(p));
10135
10136
0
    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
10137
0
}
10138
10139
// ggml_compute_forward_custom
10140
10141
void ggml_compute_forward_custom(
10142
    const struct ggml_compute_params * params,
10143
0
          struct ggml_tensor * dst) {
10144
10145
0
    struct ggml_custom_op_params p;
10146
0
    memcpy(&p, dst->op_params, sizeof(p));
10147
10148
0
    p.fun(dst, params->ith, params->nth, p.userdata);
10149
0
}
10150
10151
// ggml_compute_forward_cross_entropy_loss
10152
10153
static void ggml_compute_forward_cross_entropy_loss_f32(
10154
        const ggml_compute_params * params,
10155
0
        ggml_tensor * dst) {
10156
10157
0
    const ggml_tensor * src0 = dst->src[0];
10158
0
    const ggml_tensor * src1 = dst->src[1];
10159
10160
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
10161
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
10162
0
    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
10163
0
    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
10164
0
    GGML_ASSERT(ggml_are_same_shape(src0, src1));
10165
0
    GGML_ASSERT(ggml_is_scalar(dst));
10166
0
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
10167
10168
    // TODO: handle transposed/permuted matrices
10169
0
    const int64_t nc = src0->ne[0];
10170
0
    const int64_t nr = ggml_nrows(src0);
10171
10172
0
    const int ith = params->ith;
10173
0
    const int nth = params->nth;
10174
10175
0
    float * sums =  (float *) params->wdata;
10176
0
    float * st   = ((float *) params->wdata) + nth + ith*nc;
10177
0
    float sum_thread = 0.0f;
10178
10179
0
    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
10180
10181
    // rows per thread
10182
0
    const int64_t dr = (nr + nth - 1)/nth;
10183
10184
    // row range for this thread
10185
0
    const int64_t ir0 = dr*ith;
10186
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
10187
10188
0
    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
10189
0
        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
10190
0
        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
10191
10192
#ifndef NDEBUG
10193
        for (int64_t i = 0; i < nc; ++i) {
10194
            //printf("p[%d] = %f\n", i, p[i]);
10195
            assert(!isnan(s0[i]));
10196
            assert(!isnan(s1[i]));
10197
        }
10198
#endif
10199
10200
0
        float max = -INFINITY;
10201
0
        ggml_vec_max_f32(nc, &max, s0);
10202
0
        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
10203
0
        assert(sum_softmax >= 0.0);
10204
10205
0
        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
10206
0
        ggml_vec_mul_f32(nc, st, st, s1);
10207
10208
0
        float sum_st = 0.0f;
10209
0
        ggml_vec_sum_f32(nc, &sum_st, st);
10210
0
        sum_thread += sum_st;
10211
10212
#ifndef NDEBUG
10213
        for (int64_t i = 0; i < nc; ++i) {
10214
            assert(!isnan(st[i]));
10215
            assert(!isinf(st[i]));
10216
        }
10217
#endif
10218
0
    }
10219
0
    sums[ith] = sum_thread;
10220
0
    ggml_barrier(params->threadpool);
10221
10222
0
    if (ith == 0) {
10223
0
        float * dp = (float *) dst->data;
10224
0
        ggml_vec_sum_f32(nth, dp, sums);
10225
0
        dp[0] *= -1.0f / (float) nr;
10226
0
    }
10227
0
}
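
Per row, the loop above takes the log-softmax of s0, multiplies it elementwise by s1, and accumulates the result into the per-thread slot sums[ith]; the reduction on thread 0 then yields the mean cross-entropy. As an equation (my summary, not a comment from the source):

    dst = -\frac{1}{nr} \sum_{i=1}^{nr} \sum_{c=1}^{nc} s1_{i,c} \, \log \mathrm{softmax}(s0_i)_c
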
10228
10229
void ggml_compute_forward_cross_entropy_loss(
10230
        const ggml_compute_params * params,
10231
0
        ggml_tensor * dst) {
10232
10233
0
    const ggml_tensor * src0 = dst->src[0];
10234
10235
0
    switch (src0->type) {
10236
0
        case GGML_TYPE_F32:
10237
0
            {
10238
0
                ggml_compute_forward_cross_entropy_loss_f32(params, dst);
10239
0
            } break;
10240
0
        default:
10241
0
            {
10242
0
                GGML_ABORT("fatal error");
10243
0
            }
10244
0
    }
10245
0
}
10246
10247
// ggml_compute_forward_cross_entropy_loss_back
10248
10249
static void ggml_compute_forward_cross_entropy_loss_back_f32(
10250
        const ggml_compute_params * params,
10251
0
        ggml_tensor * dst) {
10252
10253
0
    const ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
10254
0
    const ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
10255
0
    const ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
10256
10257
0
    GGML_ASSERT(ggml_is_contiguous(dst));
10258
0
    GGML_ASSERT(ggml_is_contiguous(src0f));
10259
0
    GGML_ASSERT(ggml_is_contiguous(src1f));
10260
0
    GGML_ASSERT(ggml_is_contiguous(grad));
10261
0
    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
10262
10263
0
    const int64_t ith = params->ith;
10264
0
    const int64_t nth = params->nth;
10265
10266
    // TODO: handle transposed/permuted matrices
10267
0
    const int64_t nc = src0f->ne[0];
10268
0
    const int64_t nr = ggml_nrows(src0f);
10269
10270
    // rows per thread
10271
0
    const int64_t dr = (nr + nth - 1)/nth;
10272
10273
    // row range for this thread
10274
0
    const int64_t ir0 = dr*ith;
10275
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
10276
10277
0
    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
10278
10279
0
    for (int64_t i1 = ir0; i1 < ir1; i1++) {
10280
0
        float       * ds0 = (float       *)((char       *) dst->data   + i1*dst->nb[1]);
10281
0
        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
10282
0
        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
10283
10284
#ifndef NDEBUG
10285
        for (int64_t i = 0; i < nc; ++i) {
10286
            //printf("p[%d] = %f\n", i, p[i]);
10287
            assert(!isnan(s0[i]));
10288
            assert(!isnan(s1[i]));
10289
        }
10290
#endif
10291
10292
        // soft_max
10293
0
        float max = -INFINITY;
10294
0
        ggml_vec_max_f32(nc, &max, s0);
10295
0
        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
10296
0
        assert(sum > 0.0);
10297
0
        ggml_vec_scale_f32(nc, ds0, 1.0/sum);
10298
10299
        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
10300
0
        ggml_vec_sub_f32(nc, ds0, ds0, s1);
10301
0
        ggml_vec_scale_f32(nc, ds0, d_by_nr);
10302
10303
#ifndef NDEBUG
10304
        for (int64_t i = 0; i < nc; ++i) {
10305
            assert(!isnan(ds0[i]));
10306
            assert(!isinf(ds0[i]));
10307
        }
10308
#endif
10309
0
    }
10310
0
}
10311
10312
void ggml_compute_forward_cross_entropy_loss_back(
10313
        const ggml_compute_params * params,
10314
0
        ggml_tensor * dst) {
10315
10316
0
    const ggml_tensor * src0 = dst->src[0];
10317
10318
0
    switch (src0->type) {
10319
0
        case GGML_TYPE_F32:
10320
0
            {
10321
0
                ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
10322
0
            } break;
10323
0
        default:
10324
0
            {
10325
0
                GGML_ABORT("fatal error");
10326
0
            }
10327
0
    }
10328
0
}
10329
10330
static void ggml_compute_forward_opt_step_adamw_f32(
10331
        const ggml_compute_params * params,
10332
0
        ggml_tensor * dst) {
10333
10334
0
    const ggml_tensor * src0         = dst->src[0];
10335
0
    const ggml_tensor * src0_grad    = dst->src[1];
10336
0
    const ggml_tensor * src0_grad_m  = dst->src[2];
10337
0
    const ggml_tensor * src0_grad_v  = dst->src[3];
10338
0
    const ggml_tensor * adamw_params = dst->src[4];
10339
10340
0
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
10341
0
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
10342
0
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
10343
0
    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
10344
10345
0
    const int ith = params->ith;
10346
0
    const int nth = params->nth;
10347
10348
0
    const int nr  = ggml_nrows(src0);
10349
10350
0
    GGML_TENSOR_UNARY_OP_LOCALS
10351
0
    GGML_ASSERT(nb00 == sizeof(float));
10352
10353
    // rows per thread
10354
0
    const int dr = (nr + nth - 1)/nth;
10355
10356
    // row range for this thread
10357
0
    const int ir0 = dr*ith;
10358
0
    const int ir1 = MIN(ir0 + dr, nr);
10359
10360
0
    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
10361
10362
0
    const float alpha  = adamw_params_ptr[0];
10363
0
    const float beta1  = adamw_params_ptr[1];
10364
0
    const float beta2  = adamw_params_ptr[2];
10365
0
    const float eps    = adamw_params_ptr[3];
10366
0
    const float wd     = adamw_params_ptr[4];
10367
0
    const float beta1h = adamw_params_ptr[5];
10368
0
    const float beta2h = adamw_params_ptr[6];
10369
0
    const float keep   = 1.f - alpha * wd;
10370
0
    for (int ir = ir0; ir < ir1; ++ir) {
10371
0
        const int64_t i03 = ir/(ne02*ne01);
10372
0
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
10373
0
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
10374
10375
0
        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
10376
10377
0
        float       * w = (float       *) ((char       *) src0->data        + offset); // weight
10378
0
        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
10379
0
        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);
10380
0
        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);
10381
10382
0
        for (int i00 = 0; i00 < ne00; ++i00) {
10383
0
            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);
10384
0
            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
10385
10386
0
            const float mh =       m[i00]*beta1h;
10387
0
            const float vh = sqrtf(v[i00]*beta2h) + eps;
10388
10389
            // The weight decay is applied independently of the Adam momenta m and v.
10390
            // This is NOT equivalent to L2 regularization, which adds w[i00]*w[i00] to the loss.
10391
            // See: https://arxiv.org/pdf/1711.05101v3.pdf
10392
0
            w[i00] = w[i00] * keep - alpha * mh / vh;
10393
0
        }
10394
0
    }
10395
0
}
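
Written out, the inner loop is the decoupled-weight-decay (AdamW) update. Assuming beta1h and beta2h carry the bias-correction factors 1/(1 - beta1^t) and 1/(1 - beta2^t), which are precomputed by the caller and not visible in this file, the per-element update is:

    m \leftarrow \beta_1 m + (1-\beta_1)\, g
    v \leftarrow \beta_2 v + (1-\beta_2)\, g^2
    w \leftarrow (1 - \alpha\, wd)\, w - \alpha \, \frac{m\,\beta_{1h}}{\sqrt{v\,\beta_{2h}} + \epsilon}

which matches w[i00] = w[i00]*keep - alpha*mh/vh above: the weight decay multiplies w directly instead of being folded into the gradient, which is the decoupling the comment and the linked paper refer to.
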
10396
10397
void ggml_compute_forward_opt_step_adamw(
10398
        const ggml_compute_params * params,
10399
0
        ggml_tensor * dst) {
10400
10401
0
    const ggml_tensor * src0 = dst->src[0];
10402
10403
0
    switch (src0->type) {
10404
0
        case GGML_TYPE_F32:
10405
0
            {
10406
0
                ggml_compute_forward_opt_step_adamw_f32(params, dst);
10407
0
            } break;
10408
0
        default:
10409
0
            {
10410
0
                GGML_ABORT("fatal error");
10411
0
            }
10412
0
    }
10413
0
}
10414
10415
0
static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
10416
0
    const ggml_tensor * src0       = dst->src[0];
10417
0
    const ggml_tensor * src0_grad  = dst->src[1];
10418
0
    const ggml_tensor * sgd_params = dst->src[2];
10419
10420
0
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
10421
0
    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
10422
10423
0
    const int ith = params->ith;
10424
0
    const int nth = params->nth;
10425
10426
0
    const int nr = ggml_nrows(src0);
10427
10428
0
    GGML_TENSOR_UNARY_OP_LOCALS
10429
0
    GGML_ASSERT(nb00 == sizeof(float));
10430
10431
    // rows per thread
10432
0
    const int dr = (nr + nth - 1) / nth;
10433
10434
    // row range for this thread
10435
0
    const int ir0 = dr * ith;
10436
0
    const int ir1 = MIN(ir0 + dr, nr);
10437
10438
    // uses the subset of the AdamW parameters that SGD needs (alpha, wd); could have a separate struct
10439
0
    const float * sgd_params_ptr   = ggml_get_data_f32(sgd_params);
10440
0
    const float   alpha            = sgd_params_ptr[0];
10441
0
    const float   keep             = 1.f - alpha * sgd_params_ptr[1];
10442
10443
0
    for (int ir = ir0; ir < ir1; ++ir) {
10444
0
        const int64_t i03 = ir / (ne02 * ne01);
10445
0
        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
10446
0
        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
10447
10448
0
        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
10449
10450
0
        float *       w = (float *) ((char *) src0->data + offset);                   // weight
10451
0
        const float * g = (const float *) ((const char *) src0_grad->data + offset);  // grad
10452
10453
0
        for (int i00 = 0; i00 < ne00; ++i00) {
10454
0
            w[i00] = w[i00] * keep - alpha * g[i00];
10455
0
        }
10456
0
    }
10457
0
}
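
The update here is plain SGD with the same decoupled weight decay as the AdamW path above (my restatement):

    w \leftarrow (1 - \alpha\, wd)\, w - \alpha\, g

with alpha = sgd_params_ptr[0] and wd = sgd_params_ptr[1].
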
10458
10459
0
void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
10460
0
    const ggml_tensor * src0 = dst->src[0];
10461
10462
0
    switch (src0->type) {
10463
0
        case GGML_TYPE_F32:
10464
0
            {
10465
0
                ggml_compute_forward_opt_step_sgd_f32(params, dst);
10466
0
            }
10467
0
            break;
10468
0
        default:
10469
0
            {
10470
0
                GGML_ABORT("fatal error - sgd is F32 only");
10471
0
            }
10472
0
    }
10473
0
}