Coverage Report

Created: 2026-01-18 06:10

/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp
Source (execution count is 0 for every instrumented line shown below; lines without a count were not instrumented):
#include "ops.h"

#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "binary-ops.h"
#include "ggml.h"
#include "unary-ops.h"
#include "vec.h"

#include <algorithm>
#include <cfloat>
#include <cmath>

// ggml_compute_forward_dup

static void ggml_compute_forward_dup_same_cont(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
    GGML_ASSERT(src0->type == dst->type);

    const size_t nb0 = ggml_type_size(src0->type);

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by blocks
    const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);
    const int dr = (nk + nth - 1) / nth;
    const int k0 = dr * ith;
    const int k1 = MIN(k0 + dr, nk);

    if (k0 < k1) {
        memcpy(
            ((char *)  dst->data + k0*nb0),
            ((char *) src0->data + k0*nb0),
            (k1 - k0) * nb0);
    }
}

template<typename src_t, typename dst_t>
static void ggml_compute_forward_dup_flt(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
    GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));

    GGML_TENSOR_UNARY_OP_LOCALS

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by rows
    const int nr = ne01;
    // number of rows per thread
    const int dr = (nr + nth - 1) / nth;
    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    // case: type & row size equal
    if (src0->type == dst->type &&
        ne00 == ne0 &&
        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
        // copy by rows
        const size_t rs = ne00*nb00;
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    memcpy(
                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
                        rs);
                }
            }
        }
        return;
    }

    // case: dst tensor is contiguous
    if (ggml_is_contiguous(dst)) {
        if (nb00 == sizeof(src_t)) {
            if constexpr (std::is_same_v<dst_t, src_t>) {
                // same type
                size_t id = 0;
                const size_t rs = ne00 * nb00;
                char * dst_ptr = (char *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
                            memcpy(dst_ptr + id, src0_ptr, rs);
                            id += rs;
                        }
                        id += rs * (ne01 - ir1);
                    }
                }
            } else {
                // casting between non-quantized types
                size_t id = 0;
                dst_t * dst_ptr = (dst_t *) dst->data;

                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += ne00 * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                            for (int i00 = 0; i00 < ne00; i00++) {
                                float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
                                dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
                                id++;
                            }
                        }
                        id += ne00 * (ne01 - ir1);
                    }
                }
            }
        } else {
            //printf("%s: this is not optimal - fix me\n", __func__);

            size_t id = 0;
            dst_t * dst_ptr = (dst_t *) dst->data;

            for (int i03 = 0; i03 < ne03; i03++) {
                for (int i02 = 0; i02 < ne02; i02++) {
                    id += ne00 * ir0;
                    for (int i01 = ir0; i01 < ir1; i01++) {
                        for (int i00 = 0; i00 < ne00; i00++) {
                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

                            float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
                            dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
                            id++;
                        }
                    }
                    id += ne00 * (ne01 - ir1);
                }
            }
        }
        return;
    }

    // dst counters
    int64_t i10 = 0;
    int64_t i11 = 0;
    int64_t i12 = 0;
    int64_t i13 = 0;

    if constexpr (std::is_same_v<dst_t, src_t>) {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);

                        memcpy(dst_ptr, src0_ptr, sizeof(dst_t));

                        if (++i10 == ne00) {
                            i10 = 0;
                            if (++i11 == ne01) {
                                i11 = 0;
                                if (++i12 == ne02) {
                                    i12 = 0;
                                    if (++i13 == ne03) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }

    } else {
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                i10 += ne00 * ir0;
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);

                        float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
                        *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);

                        if (++i10 == ne0) {
                            i10 = 0;
                            if (++i11 == ne1) {
                                i11 = 0;
                                if (++i12 == ne2) {
                                    i12 = 0;
                                    if (++i13 == ne3) {
                                        i13 = 0;
                                    }
                                }
                            }
                        }
                    }
                }
                i10 += ne00 * (ne01 - ir1);
                while (i10 >= ne0) {
                    i10 -= ne0;
                    if (++i11 == ne1) {
                        i11 = 0;
                        if (++i12 == ne2) {
                            i12 = 0;
                            if (++i13 == ne3) {
                                i13 = 0;
                            }
                        }
                    }
                }
            }
        }
    }
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<unsigned short, unsigned short>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<unsigned short, ggml_bf16_t>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<unsigned short, float>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<ggml_bf16_t, unsigned short>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<ggml_bf16_t, float>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<float, unsigned short>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<float, ggml_bf16_t>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<float, float>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<float, int>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_flt<int, float>(ggml_compute_params const*, ggml_tensor*)


template<typename src_t>
static void ggml_compute_forward_dup_to_q(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
    GGML_ASSERT(!ggml_is_quantized(src0->type));

    GGML_TENSOR_UNARY_OP_LOCALS

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by rows
    const int nr = ne01;
    // number of rows per thread
    const int dr = (nr + nth - 1) / nth;
    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    if (ggml_is_contiguous(dst) &&
            nb00 == sizeof(src_t) &&
            ggml_get_type_traits_cpu(dst->type)->from_float) {
        // casting non-quantized types --> intermediate f32 --> quantized
        ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
        float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

        size_t id = 0;
        size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
        char * dst_ptr = (char *) dst->data;

        for (int i03 = 0; i03 < ne03; i03++) {
            for (int i02 = 0; i02 < ne02; i02++) {
                id += rs * ir0;
                for (int i01 = ir0; i01 < ir1; i01++) {
                    const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

                    for (int i00 = 0; i00 < ne00; i00++) {
                        src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
                    }

                    quantize_row_q(src0_f32, dst_ptr + id, ne00);
                    id += rs;
                }
                id += rs * (ne01 - ir1);
            }
        }
    } else {
        // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
        GGML_ABORT("not implemented");
    }
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_to_q<unsigned short>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_to_q<ggml_bf16_t>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_dup_to_q<float>(ggml_compute_params const*, ggml_tensor*)

// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
static void ggml_compute_forward_dup_bytes(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
    GGML_ASSERT(src0->type == dst->type);

    GGML_TENSOR_UNARY_OP_LOCALS;

    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
        ggml_compute_forward_dup_same_cont(params, dst);
        return;
    }

    const size_t type_size = ggml_type_size(src0->type);

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by rows
    const int nr = ne01;
    // number of rows per thread
    const int dr = (nr + nth - 1) / nth;
    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    if (src0->type == dst->type &&
        ggml_are_same_shape(src0, dst) &&
        nb00 == type_size && nb0 == type_size) {
        // copy by rows
        const size_t rs = ggml_row_size(src0->type, ne00);
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    memcpy(
                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
                        rs);
                }
            }
        }
        return;
    }

    if (ggml_is_contiguous(dst)) {
        size_t id = 0;
        char * dst_ptr = (char *) dst->data;
        const size_t rs = ne00 * type_size;

        if (nb00 == type_size) {
            // src0 is contigous on first dimension, copy by rows
            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
                        memcpy(dst_ptr + id, src0_ptr, rs);
                        id += rs;
                    }
                    id += rs * (ne01 - ir1);
                }
            }
        } else {
            //printf("%s: this is not optimal - fix me\n", __func__);

            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        for (int64_t i00 = 0; i00 < ne00; i00++) {
                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
                            memcpy(dst_ptr + id, src0_ptr, type_size);

                            id += type_size;
                        }
                    }
                    id += rs * (ne01 - ir1);
                }
            }
        }

        return;
    }

    // dst counters
    int64_t k10 = 0;
    int64_t i11 = 0;
    int64_t i12 = 0;
    int64_t i13 = 0;

    // number of blocks in a row
    const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
    const int64_t nk0  = ne0  / ggml_blck_size(dst->type);

    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            k10 += nk00 * ir0;
            while (k10 >= nk0) {
                k10 -= nk0;
                if (++i11 == ne1) {
                    i11 = 0;
                    if (++i12 == ne2) {
                        i12 = 0;
                        if (++i13 == ne3) {
                            i13 = 0;
                        }
                    }
                }
            }
            for (int64_t i01 = ir0; i01 < ir1; i01++) {
                for (int64_t k00 = 0; k00 < nk00; k00++) {
                    const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                          char * dst_ptr  = ((char *)  dst->data + k10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);

                    memcpy(dst_ptr, src0_ptr, type_size);

                    if (++k10 == nk0) {
                        k10 = 0;
                        if (++i11 == ne1) {
                            i11 = 0;
                            if (++i12 == ne2) {
                                i12 = 0;
                                if (++i13 == ne3) {
                                    i13 = 0;
                                }
                            }
                        }
                    }
                }
            }
            k10 += nk00 * (ne01 - ir1);
            while (k10 >= nk0) {
                k10 -= nk0;
                if (++i11 == ne1) {
                    i11 = 0;
                    if (++i12 == ne2) {
                        i12 = 0;
                        if (++i13 == ne3) {
                            i13 = 0;
                        }
                    }
                }
            }
        }
    }
}

static void ggml_compute_forward_dup_from_q(
        const ggml_compute_params * params,
              ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_TENSOR_BINARY_OP_LOCALS

    const ggml_type type = src0->type;
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;

    size_t qk = ggml_blck_size(type);
    const int64_t nr = ggml_nelements(src1) / qk;

    // destination must be contiguous in the first dimension
    GGML_ASSERT(nb10 == ggml_type_size(dst->type));
    // must either have first dimension large enough to hold a row, or fully contiguous
    GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));

    const int ith = params->ith;
    const int nth = params->nth;

    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int64_t ir = ir0; ir < ir1; ++ir) {

        uint32_t i = ir * qk;

        const int64_t i03 = i/(ne00 * ne01 * ne02);
        const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
        const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
        const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
        const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;

        const int64_t i13 = i/(ne10 * ne11 * ne12);
        const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
        const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
        const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
        const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

        dequantize_row_q(
                (const void *) ((char *) src0->data + x_offset),
                     (float *) ((char *)  dst->data + dst_offset), qk);
    }
}

void ggml_compute_forward_dup(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    if (src0->type == dst->type) {
        ggml_compute_forward_dup_bytes(params, dst);
        return;
    }

    switch (src0->type) {
        case GGML_TYPE_F16:
            {
                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_fp16_t, float      >(params, dst);
                else ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
            } break;
        case GGML_TYPE_BF16:
            {
                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_bf16_t, float      >(params, dst);
                else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
            } break;
        case GGML_TYPE_F32:
            {
                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<float, float      >(params, dst);
                else if (dst->type == GGML_TYPE_I32)  ggml_compute_forward_dup_flt<float, int32_t    >(params, dst);
                else ggml_compute_forward_dup_to_q<float>(params, dst);
            } break;
        case GGML_TYPE_I32:
            {
                if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
                else GGML_ABORT("not implemented");
            } break;
        default:
            {
                if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
                    ggml_compute_forward_dup_from_q(params, dst);
                    break;
                }
                GGML_ABORT("fatal error");
            }
    }
}
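
Note: every kernel in this file splits its row (or block) range across threads with the same ceiling-division idiom seen above (dr = (nr + nth - 1)/nth, then clamping with MIN). A minimal standalone sketch of that partitioning, using a hypothetical thread_rows helper that is not part of ops.cpp, is:

#include <algorithm>
#include <cstdio>

// Illustration only: the [ir0, ir1) row range assigned to thread ith out of nth,
// matching the arithmetic used by the kernels above.
static void thread_rows(int nr, int nth, int ith, int * ir0, int * ir1) {
    const int dr = (nr + nth - 1) / nth;  // rows per thread, rounded up
    *ir0 = dr * ith;
    *ir1 = std::min(*ir0 + dr, nr);       // trailing threads may get an empty range
}

int main() {
    // 10 rows split across 4 threads -> chunks of 3, 3, 3 and 1 rows
    for (int ith = 0; ith < 4; ith++) {
        int ir0, ir1;
        thread_rows(10, 4, ith, &ir0, &ir1);
        std::printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}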

// ggml_compute_forward_add

static void ggml_compute_forward_add_q_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));

    const int nr  = ggml_nrows(src0);

    GGML_TENSOR_BINARY_OP_LOCALS

    const int ith = params->ith;
    const int nth = params->nth;

    const ggml_type type = src0->type;
    const ggml_type dtype = dst->type;
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dtype)->from_float;

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == ggml_type_size(type));
    GGML_ASSERT(nb10 == sizeof(float));

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 <= nb1);
    GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);

    GGML_ASSERT(ggml_is_quantized(src0->type));
    GGML_ASSERT(src1->type == GGML_TYPE_F32);

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 indices
        const int i03 = ir/(ne02*ne01);
        const int i02 = (ir - i03*ne02*ne01)/ne01;
        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);

        // src1 and dst are same shape as src0 => same indices
        const int i13 = i03;
        const int i12 = i02;
        const int i11 = i01;

        const int i3 = i03;
        const int i2 = i02;
        const int i1 = i01;

        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));

        assert(ne00 % 32 == 0);

        // unquantize row from src0 to temp buffer
        dequantize_row_q(src0_row, wdata, ne00);
        // add src1
        ggml_vec_acc_f32(ne00, wdata, src1_row);
        // quantize row to dst
        if (quantize_row_q != NULL) {
            quantize_row_q(wdata, dst_row, ne00);
        } else {
            memcpy(dst_row, wdata, ne0*nb0);
        }
    }
}

void ggml_compute_forward_add(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
        case GGML_TYPE_BF16:
            {
                ggml_compute_forward_add_non_quantized(params, dst);
            } break;
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_TQ1_0:
        case GGML_TYPE_TQ2_0:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ2_S:
            {
                ggml_compute_forward_add_q_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}
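
Note: for the quantized cases above, ggml_compute_forward_add_q_f32 handles one row per iteration: dequantize the src0 row into a per-thread f32 scratch buffer, accumulate the f32 src1 row into it, then requantize into dst. A minimal sketch of that per-row pipeline, assuming hypothetical stand-ins for the ggml type-traits callbacks (not the actual ops.cpp code), is:

#include <vector>

// Hypothetical stand-ins for the to_float / from_float callbacks obtained
// from ggml_get_type_traits / ggml_get_type_traits_cpu in the real kernel.
typedef void (*to_float_fn)(const void * src, float * dst, int n);
typedef void (*from_float_fn)(const float * src, void * dst, int n);

// Sketch of the per-row flow: dequantize -> accumulate -> requantize.
void add_row_q_f32_sketch(const void * src0_row, const float * src1_row, void * dst_row,
                          int ne00, to_float_fn dequantize_row, from_float_fn quantize_row) {
    std::vector<float> wdata(ne00);      // scratch row (a slice of params->wdata per thread in ops.cpp)
    dequantize_row(src0_row, wdata.data(), ne00);
    for (int i = 0; i < ne00; i++) {
        wdata[i] += src1_row[i];         // ggml_vec_acc_f32 in the real kernel
    }
    quantize_row(wdata.data(), dst_row, ne00);
}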

// ggml_compute_forward_add_id

static void ggml_compute_forward_add_id_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];

    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(src2->type == GGML_TYPE_I32);

    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr  = ggml_nrows(src0);

    GGML_TENSOR_TERNARY_OP_LOCALS

    GGML_ASSERT( nb0 == sizeof(float));
    GGML_ASSERT(nb10 == sizeof(float));

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 indices
        const int i3 = ir/(ne2*ne1);
        const int i2 = (ir - i3*ne2*ne1)/ne1;
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);

        // src1 indices
        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);

        GGML_ASSERT(i11 >= 0 && i11 < ne11);

        ggml_vec_add_f32(ne0,
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
                (float *) ((char *) src1->data + i11*nb11));
    }
}

void ggml_compute_forward_add_id(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_add_id_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
            }
    }
}

// ggml_compute_forward_add1

static void ggml_compute_forward_add1_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    GGML_ASSERT(ggml_is_scalar(src1));

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr  = ggml_nrows(src0);

    GGML_TENSOR_UNARY_OP_LOCALS

    GGML_ASSERT( nb0 == sizeof(float));
    GGML_ASSERT(nb00 == sizeof(float));

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 and dst are same shape => same indices
        const int i3 = ir/(ne2*ne1);
        const int i2 = (ir - i3*ne2*ne1)/ne1;
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);

#ifdef GGML_USE_ACCELERATE
        GGML_UNUSED(ggml_vec_add1_f32);

        vDSP_vadd(
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
                (float *) ((char *) src1->data), 0,
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
                ne0);
#else
        ggml_vec_add1_f32(ne0,
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
               *(float *) src1->data);
#endif
    }
}

static void ggml_compute_forward_add1_f16_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    GGML_ASSERT(ggml_is_scalar(src1));

    // scalar to add
    const float v = *(float *) src1->data;

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr  = ggml_nrows(src0);

    GGML_TENSOR_UNARY_OP_LOCALS

    GGML_ASSERT(src0->type == GGML_TYPE_F16);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F16);

    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 and dst are same shape => same indices
        const int i3 = ir/(ne2*ne1);
        const int i2 = (ir - i3*ne2*ne1)/ne1;
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);

        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
        for (int i = 0; i < ne0; i++) {
            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
        }
    }
}

static void ggml_compute_forward_add1_f16_f16(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    GGML_ASSERT(ggml_is_scalar(src1));

    // scalar to add
    const float v = GGML_CPU_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr  = ggml_nrows(src0);

    GGML_TENSOR_UNARY_OP_LOCALS

    GGML_ASSERT(src0->type == GGML_TYPE_F16);
    GGML_ASSERT(src1->type == GGML_TYPE_F16);
    GGML_ASSERT(dst->type  == GGML_TYPE_F16);

    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 and dst are same shape => same indices
        const int i3 = ir/(ne2*ne1);
        const int i2 = (ir - i3*ne2*ne1)/ne1;
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);

        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
        for (int i = 0; i < ne0; i++) {
            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
        }
    }
}

static void ggml_compute_forward_add1_q_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    GGML_ASSERT(ggml_is_scalar(src1));

    // scalar to add
    const float v = *(float *) src1->data;

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr  = ggml_nrows(src0);

    GGML_TENSOR_UNARY_OP_LOCALS

    const ggml_type type = src0->type;
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type)->from_float;

    // we don't support permuted src0
    GGML_ASSERT(nb00 == ggml_type_size(type));

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 <= nb1);
    GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);

    GGML_ASSERT(ggml_is_quantized(src0->type));
    GGML_ASSERT(dst->type == src0->type);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 and dst are same shape => same indices
        const int i3 = ir/(ne2*ne1);
        const int i2 = (ir - i3*ne2*ne1)/ne1;
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);

        void  * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
        void  * dst_row  = (void *) ((char *)  dst->data + (i1*nb1  + i2*nb2  + i3*nb0 ));

        assert(ne0 % 32 == 0);

        // unquantize row from src0 to temp buffer
        dequantize_row_q(src0_row, wdata, ne0);
        // add src1
        ggml_vec_acc1_f32(ne0, wdata, v);
        // quantize row to dst
        quantize_row_q(wdata, dst_row, ne0);
    }
}

static void ggml_compute_forward_add1_bf16_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    GGML_ASSERT(ggml_is_scalar(src1));

    // scalar to add
    const float v = *(float *) src1->data;

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr  = ggml_nrows(src0);

    GGML_TENSOR_UNARY_OP_LOCALS

    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);

    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 and dst are same shape => same indices
        const int i3 = ir/(ne2*ne1);
        const int i2 = (ir - i3*ne2*ne1)/ne1;
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);

        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
        for (int i = 0; i < ne0; i++) {
            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
        }
    }
}

static void ggml_compute_forward_add1_bf16_bf16(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    GGML_ASSERT(ggml_is_scalar(src1));

    // scalar to add
    const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr  = ggml_nrows(src0);

    GGML_TENSOR_UNARY_OP_LOCALS

    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
    GGML_ASSERT(src1->type == GGML_TYPE_BF16);
    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);

    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 and dst are same shape => same indices
        const int i3 = ir/(ne2*ne1);
        const int i2 = (ir - i3*ne2*ne1)/ne1;
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);

        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
        for (int i = 0; i < ne0; i++) {
            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
        }
    }
}

void ggml_compute_forward_add1(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_add1_f32(params, dst);
            } break;
        case GGML_TYPE_F16:
            {
                if (src1->type == GGML_TYPE_F16) {
                    ggml_compute_forward_add1_f16_f16(params, dst);
                }
                else if (src1->type == GGML_TYPE_F32) {
                    ggml_compute_forward_add1_f16_f32(params, dst);
                }
                else {
                    GGML_ABORT("fatal error");
                }
            } break;
        case GGML_TYPE_BF16:
            {
                if (src1->type == GGML_TYPE_BF16) {
                    ggml_compute_forward_add1_bf16_bf16(params, dst);
                }
                else if (src1->type == GGML_TYPE_F32) {
                    ggml_compute_forward_add1_bf16_f32(params, dst);
                }
                else {
                    GGML_ABORT("fatal error");
                }
            } break;
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q8_1:
        case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_TQ1_0:
        case GGML_TYPE_TQ2_0:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ2_S:
            {
                ggml_compute_forward_add1_q_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

// ggml_compute_forward_acc

static void ggml_compute_forward_acc_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));

    // view src0 and dst with these strides and data offset inbytes during acc
    // nb0 is implicitly element_size because src0 and dst are contiguous
    size_t nb1     = ((int32_t *) dst->op_params)[0];
    size_t nb2     = ((int32_t *) dst->op_params)[1];
    size_t nb3     = ((int32_t *) dst->op_params)[2];
    size_t offset  = ((int32_t *) dst->op_params)[3];
    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];

    if (!inplace) {
        if (params->ith == 0) {
            // memcpy needs to be synchronized across threads to avoid race conditions.
            // => do it in INIT phase
            memcpy(
                ((char *)  dst->data),
                ((char *) src0->data),
                ggml_nbytes(dst));
        }
        ggml_barrier(params->threadpool);
    }

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr = ggml_nrows(src1);
    const int nc = src1->ne[0];

    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)

    // src0 and dst as viewed during acc
    const size_t nb0 = ggml_element_size(src0);

    const size_t nb00 = nb0;
    const size_t nb01 = nb1;
    const size_t nb02 = nb2;
    const size_t nb03 = nb3;

    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0  + (ne11 == 0 ? 0 : ne11-1)*nb1  + (ne12 == 0 ? 0 : ne12-1)*nb2  + (ne13 == 0 ? 0 : ne13-1)*nb3  < ggml_nbytes(dst));
    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));

    GGML_ASSERT(nb10 == sizeof(float));

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 and dst are viewed with shape of src1 and offset
        // => same indices
        const int i3 = ir/(ne12*ne11);
        const int i2 = (ir - i3*ne12*ne11)/ne11;
        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);

#ifdef GGML_USE_ACCELERATE
        vDSP_vadd(
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + offset), 1, nc);
#else
        ggml_vec_add_f32(nc,
                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
#endif
    }
}

void ggml_compute_forward_acc(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_acc_f32(params, dst);
            } break;
        case GGML_TYPE_F16:
        case GGML_TYPE_BF16:
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q8_1:
        case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_TQ1_0:
        case GGML_TYPE_TQ2_0:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ2_S:
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

// ggml_compute_forward_sum

static void ggml_compute_forward_sum_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    if (params->ith != 0) {
        return;
    }

    assert(ggml_is_scalar(dst));
    assert(src0->nb[0] == sizeof(float));

    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)

    ggml_float sum     = 0;
    ggml_float row_sum = 0;

    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            for (int64_t i01 = 0; i01 < ne01; i01++) {
                ggml_vec_sum_f32_ggf(ne00,
                        &row_sum,
                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
                sum += row_sum;
            }
        }
    }
    ((float *) dst->data)[0] = sum;
}

static void ggml_compute_forward_sum_f16(
    const ggml_compute_params * params,
          ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    if (params->ith != 0) {
        return;
    }

    assert(ggml_is_scalar(dst));

    assert(src0->nb[0] == sizeof(ggml_fp16_t));

    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)

    float sum = 0;
    float row_sum = 0;

    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            for (int64_t i01 = 0; i01 < ne01; i01++) {
                ggml_vec_sum_f16_ggf(ne00,
                    &row_sum,
                    (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
                sum += row_sum;
            }
        }
    }
    ((ggml_fp16_t *) dst->data)[0] = GGML_CPU_FP32_TO_FP16(sum);
}

static void ggml_compute_forward_sum_bf16(
    const ggml_compute_params * params,
          ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    if (params->ith != 0) {
        return;
    }

    assert(ggml_is_scalar(dst));

    assert(src0->nb[0] == sizeof(ggml_bf16_t));

    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)

    float sum = 0;
    float row_sum = 0;

    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            for (int64_t i01 = 0; i01 < ne01; i01++) {
                ggml_vec_sum_bf16_ggf(ne00,
                    &row_sum,
                    (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
                sum += row_sum;
            }
        }
    }
    ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
}

void ggml_compute_forward_sum(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_sum_f32(params, dst);
            } break;
        case GGML_TYPE_F16:
            {
                ggml_compute_forward_sum_f16(params, dst);
            } break;
        case GGML_TYPE_BF16:
            {
                ggml_compute_forward_sum_bf16(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}
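
Note: the f32 reduction above accumulates into ggml_float (a double in ggml), while the f16/bf16 variants accumulate into a plain float after converting each element. A small self-contained illustration of why a wider accumulator matters for long sums (hypothetical example, not part of ops.cpp):

#include <cstdio>

int main() {
    // Summing 0.1f ten million times: the float accumulator's rounding error
    // grows as the running sum gets large, while the double accumulator stays
    // essentially exact for this range.
    float  sum_f32 = 0.0f;
    double sum_f64 = 0.0;
    for (int i = 0; i < 10000000; i++) {
        sum_f32 += 0.1f;
        sum_f64 += 0.1f;
    }
    std::printf("float accumulator:  %f\n", sum_f32); // drifts noticeably away from 1e6
    std::printf("double accumulator: %f\n", sum_f64); // stays very close to 1e6
    return 0;
}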
1397
1398
// ggml_compute_forward_cumsum
1399
1400
static void ggml_compute_forward_cumsum_f32(
1401
        const ggml_compute_params * params,
1402
0
        ggml_tensor * dst) {
1403
1404
0
    const ggml_tensor * src0 = dst->src[0];
1405
1406
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
1407
0
    GGML_ASSERT(dst->nb[0] == sizeof(float));
1408
1409
0
    GGML_TENSOR_UNARY_OP_LOCALS
1410
1411
0
    GGML_ASSERT(ne0 == ne00);
1412
0
    GGML_ASSERT(ne1 == ne01);
1413
0
    GGML_ASSERT(ne2 == ne02);
1414
0
    GGML_ASSERT(ne3 == ne03);
1415
1416
0
    const auto [ir0, ir1] = get_thread_range(params, src0);
1417
1418
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
1419
0
        const int64_t i03 = ir/(ne02*ne01);
1420
0
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
1421
0
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
1422
1423
0
        float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
1424
0
        float * dst_row = (float *) ((char *) dst->data  + i01*nb1  + i02*nb2  + i03*nb3);
1425
1426
0
        ggml_vec_cumsum_f32(ne00, dst_row, src_row);
1427
0
    }
1428
0
}
1429
1430
void ggml_compute_forward_cumsum(
1431
        const ggml_compute_params * params,
1432
0
        ggml_tensor * dst) {
1433
1434
0
    const ggml_tensor * src0 = dst->src[0];
1435
1436
0
    switch (src0->type) {
1437
0
        case GGML_TYPE_F32:
1438
0
            {
1439
0
                ggml_compute_forward_cumsum_f32(params, dst);
1440
0
            } break;
1441
0
        default:
1442
0
            {
1443
0
                GGML_ABORT("fatal error");
1444
0
            }
1445
0
    }
1446
0
}
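
The cumsum kernel above walks a flat row index ir over all ne01*ne02*ne03 rows and decomposes it back into (i01, i02, i03); the same arithmetic shows up in several later ops. A standalone sketch of the decomposition with made-up extents, checking that it round-trips:

// Illustrative sketch only - not part of ops.cpp. Shows how a flat row index
// ir in [0, ne01*ne02*ne03) maps back to per-dimension indices, mirroring the
// arithmetic used in ggml_compute_forward_cumsum_f32 above.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne01 = 4, ne02 = 3, ne03 = 2; // hypothetical row extents

    for (int64_t ir = 0; ir < ne01*ne02*ne03; ++ir) {
        const int64_t i03 =  ir / (ne02*ne01);
        const int64_t i02 = (ir - i03*ne02*ne01) / ne01;
        const int64_t i01 =  ir - i03*ne02*ne01 - i02*ne01;

        // the decomposition must round-trip to the original flat index
        assert(ir == i01 + i02*ne01 + i03*ne01*ne02);
        printf("ir=%2lld -> (i01=%lld, i02=%lld, i03=%lld)\n",
               (long long) ir, (long long) i01, (long long) i02, (long long) i03);
    }
    return 0;
}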
1447
1448
// ggml_compute_forward_sum_rows
1449
1450
static void ggml_compute_forward_sum_rows_f32(
1451
        const ggml_compute_params * params,
1452
0
        ggml_tensor * dst) {
1453
1454
0
    const ggml_tensor * src0 = dst->src[0];
1455
1456
0
    if (params->ith != 0) {
1457
0
        return;
1458
0
    }
1459
1460
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
1461
0
    GGML_ASSERT(dst->nb[0] == sizeof(float));
1462
1463
0
    GGML_TENSOR_UNARY_OP_LOCALS
1464
1465
0
    GGML_ASSERT(ne0 == 1);
1466
0
    GGML_ASSERT(ne1 == ne01);
1467
0
    GGML_ASSERT(ne2 == ne02);
1468
0
    GGML_ASSERT(ne3 == ne03);
1469
1470
0
    for (int64_t i3 = 0; i3 < ne03; i3++) {
1471
0
        for (int64_t i2 = 0; i2 < ne02; i2++) {
1472
0
            for (int64_t i1 = 0; i1 < ne01; i1++) {
1473
0
                float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
1474
0
                float * dst_row = (float *) ((char *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);
1475
0
                float row_sum = 0;
1476
0
                ggml_vec_sum_f32(ne00, &row_sum, src_row);
1477
0
                dst_row[0] = row_sum;
1478
0
            }
1479
0
        }
1480
0
    }
1481
0
}
1482
1483
void ggml_compute_forward_sum_rows(
1484
        const ggml_compute_params * params,
1485
0
        ggml_tensor * dst) {
1486
1487
0
    const ggml_tensor * src0 = dst->src[0];
1488
1489
0
    switch (src0->type) {
1490
0
        case GGML_TYPE_F32:
1491
0
            {
1492
0
                ggml_compute_forward_sum_rows_f32(params, dst);
1493
0
            } break;
1494
0
        default:
1495
0
            {
1496
0
                GGML_ABORT("fatal error");
1497
0
            }
1498
0
    }
1499
0
}
1500
1501
// ggml_compute_forward_mean
1502
1503
static void ggml_compute_forward_mean_f32(
1504
        const ggml_compute_params * params,
1505
0
        ggml_tensor * dst) {
1506
1507
0
    const ggml_tensor * src0 = dst->src[0];
1508
1509
0
    if (params->ith != 0) {
1510
0
        return;
1511
0
    }
1512
1513
0
    assert(src0->nb[0] == sizeof(float));
1514
1515
0
    GGML_TENSOR_UNARY_OP_LOCALS
1516
1517
0
    assert(ne0 == 1);
1518
0
    assert(ne1 == ne01);
1519
0
    assert(ne2 == ne02);
1520
0
    assert(ne3 == ne03);
1521
1522
0
    GGML_UNUSED(ne0);
1523
0
    GGML_UNUSED(ne1);
1524
0
    GGML_UNUSED(ne2);
1525
0
    GGML_UNUSED(ne3);
1526
1527
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
1528
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
1529
0
            for (int64_t i01 = 0; i01 < ne01; i01++) {
1530
0
                ggml_vec_sum_f32(ne00,
1531
0
                        (float *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
1532
0
                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
1533
1534
0
                *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
1535
0
            }
1536
0
        }
1537
0
    }
1538
0
}
1539
1540
void ggml_compute_forward_mean(
1541
        const ggml_compute_params * params,
1542
0
        ggml_tensor * dst) {
1543
1544
0
    const ggml_tensor * src0 = dst->src[0];
1545
1546
0
    switch (src0->type) {
1547
0
        case GGML_TYPE_F32:
1548
0
            {
1549
0
                ggml_compute_forward_mean_f32(params, dst);
1550
0
            } break;
1551
0
        default:
1552
0
            {
1553
0
                GGML_ABORT("fatal error");
1554
0
            }
1555
0
    }
1556
0
}
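
Per row, mean_f32 writes the row sum into dst and then divides that slot by ne00 in place. A scalar sketch of the same computation on a plain array (the ggml_vec_sum_f32 call replaced by an explicit loop):

// Illustrative sketch only - a scalar version of the per-row mean computed by
// ggml_compute_forward_mean_f32: sum the ne00 elements of a row, then divide.
#include <cstdio>

int main() {
    const float row[] = {1.0f, 2.0f, 3.0f, 6.0f};
    const int ne00 = 4;

    float sum = 0.0f;
    for (int i = 0; i < ne00; ++i) {
        sum += row[i];
    }
    const float mean = sum / (float) ne00; // 12/4 = 3
    printf("sum = %f, mean = %f\n", sum, mean);
    return 0;
}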
1557
1558
// ggml_compute_forward_argmax
1559
1560
static void ggml_compute_forward_argmax_f32(
1561
        const ggml_compute_params * params,
1562
0
        ggml_tensor * dst) {
1563
1564
0
    const ggml_tensor * src0 = dst->src[0];
1565
1566
0
    if (params->ith != 0) {
1567
0
        return;
1568
0
    }
1569
1570
0
    assert(src0->nb[0] == sizeof(float));
1571
0
    assert(dst->nb[0] == sizeof(float));
1572
1573
0
    const int64_t ne00 = src0->ne[0];
1574
0
    const int64_t ne01 = src0->ne[1];
1575
1576
0
    const size_t nb01 = src0->nb[1];
1577
0
    const size_t nb0 = dst->nb[0];
1578
1579
0
    for (int64_t i1 = 0; i1 < ne01; i1++) {
1580
0
        float * src = (float *) ((char *) src0->data + i1*nb01);
1581
0
        int32_t * dst_ = (int32_t *) ((char *)  dst->data + i1*nb0);
1582
0
        int v = 0;
1583
0
        ggml_vec_argmax_f32(ne00, &v, src);
1584
0
        dst_[0] = v;
1585
0
    }
1586
0
}
1587
1588
void ggml_compute_forward_argmax(
1589
        const ggml_compute_params * params,
1590
0
        ggml_tensor * dst) {
1591
1592
0
    const ggml_tensor * src0 = dst->src[0];
1593
1594
0
    switch (src0->type) {
1595
0
        case GGML_TYPE_F32:
1596
0
            {
1597
0
                ggml_compute_forward_argmax_f32(params, dst);
1598
0
            } break;
1599
0
        default:
1600
0
            {
1601
0
                GGML_ABORT("fatal error");
1602
0
            }
1603
0
    }
1604
0
}
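
argmax_f32 stores, for each row, the int32 index of the row's largest value; the scan itself lives in ggml_vec_argmax_f32. A scalar sketch of that per-row scan (tie handling here is illustrative, not a statement about the vec helper):

// Illustrative sketch only - scalar equivalent of the per-row argmax:
// return the index of the largest element in a row of ne00 floats.
#include <cstdio>

static int row_argmax(int ne00, const float * x) {
    int   idx = 0;
    float max = x[0];
    for (int i = 1; i < ne00; ++i) {
        if (x[i] > max) { max = x[i]; idx = i; }
    }
    return idx;
}

int main() {
    const float row[] = {0.5f, -1.0f, 2.25f, 2.0f};
    printf("argmax = %d\n", row_argmax(4, row)); // prints 2
    return 0;
}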
1605
1606
// ggml_compute_forward_count_equal
1607
1608
static void ggml_compute_forward_count_equal_i32(
1609
        const ggml_compute_params * params,
1610
0
        ggml_tensor * dst) {
1611
1612
0
    const ggml_tensor * src0 = dst->src[0];
1613
0
    const ggml_tensor * src1 = dst->src[1];
1614
1615
0
    GGML_TENSOR_BINARY_OP_LOCALS;
1616
1617
0
    GGML_ASSERT(src0->type == GGML_TYPE_I32);
1618
0
    GGML_ASSERT(src1->type == GGML_TYPE_I32);
1619
0
    GGML_ASSERT(ggml_are_same_shape(src0, src1));
1620
0
    GGML_ASSERT(ggml_is_scalar(dst));
1621
0
    GGML_ASSERT(dst->type == GGML_TYPE_I64);
1622
1623
0
    const int64_t nr = ggml_nrows(src0);
1624
1625
0
    const int ith = params->ith;
1626
0
    const int nth = params->nth;
1627
1628
0
    int64_t * sums = (int64_t *) params->wdata;
1629
0
    int64_t sum_thread = 0;
1630
1631
    // rows per thread
1632
0
    const int64_t dr = (nr + nth - 1)/nth;
1633
1634
    // row range for this thread
1635
0
    const int64_t ir0 = dr*ith;
1636
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
1637
1638
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
1639
0
        const int64_t i03 =  ir                        / (ne02*ne01);
1640
0
        const int64_t i02 = (ir - i03*ne03)            /       ne01;
1641
0
        const int64_t i01 =  ir - i03*ne03 - i02*ne02;
1642
1643
0
        const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
1644
0
        const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
1645
1646
0
        for (int64_t i00 = 0; i00 < ne00; ++i00) {
1647
0
            const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
1648
0
            const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
1649
1650
0
            sum_thread += val0 == val1;
1651
0
        }
1652
0
    }
1653
0
    if (ith != 0) {
1654
0
        sums[ith] = sum_thread;
1655
0
    }
1656
0
    ggml_barrier(params->threadpool);
1657
1658
0
    if (ith != 0) {
1659
0
        return;
1660
0
    }
1661
1662
0
    for (int ith_other = 1; ith_other < nth; ++ith_other) {
1663
0
        sum_thread += sums[ith_other];
1664
0
    }
1665
0
    *((int64_t *) dst->data) = sum_thread;
1666
0
}
1667
1668
void ggml_compute_forward_count_equal(
1669
        const ggml_compute_params * params,
1670
0
        ggml_tensor * dst) {
1671
1672
0
    const ggml_tensor * src0 = dst->src[0];
1673
1674
0
    switch (src0->type) {
1675
0
        case GGML_TYPE_I32:
1676
0
            {
1677
0
                ggml_compute_forward_count_equal_i32(params, dst);
1678
0
            } break;
1679
0
        default:
1680
0
            {
1681
0
                GGML_ABORT("fatal error");
1682
0
            }
1683
0
    }
1684
0
}
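
count_equal is the first op in this stretch that both splits rows across threads with the ceil-divide pattern dr = (nr + nth - 1)/nth and reduces per-thread partial sums through params->wdata after the barrier. A standalone sketch of just the range arithmetic, with made-up nr and nth:

// Illustrative sketch only - shows the row ranges produced by the
// dr = (nr + nth - 1)/nth partitioning used throughout these kernels.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t nr  = 10; // hypothetical number of rows
    const int     nth = 4;  // hypothetical number of threads

    const int64_t dr = (nr + nth - 1)/nth; // rows per thread, rounded up

    for (int ith = 0; ith < nth; ++ith) {
        const int64_t ir0 = dr*ith;
        const int64_t ir1 = std::min(ir0 + dr, nr);
        // the last thread may get a short (or empty) range
        printf("thread %d: rows [%lld, %lld)\n", ith, (long long) ir0, (long long) ir1);
    }
    return 0;
}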
1685
1686
// ggml_compute_forward_repeat
1687
1688
static void ggml_compute_forward_repeat_f32(
1689
        const ggml_compute_params * params,
1690
0
        ggml_tensor * dst) {
1691
1692
0
    const ggml_tensor * src0 = dst->src[0];
1693
1694
0
    if (params->ith != 0) {
1695
0
        return;
1696
0
    }
1697
1698
0
    GGML_ASSERT(ggml_can_repeat(src0, dst));
1699
1700
0
    GGML_TENSOR_UNARY_OP_LOCALS
1701
1702
    // guaranteed to be an integer due to the check in ggml_can_repeat
1703
0
    const int nr0 = (int)(ne0/ne00);
1704
0
    const int nr1 = (int)(ne1/ne01);
1705
0
    const int nr2 = (int)(ne2/ne02);
1706
0
    const int nr3 = (int)(ne3/ne03);
1707
1708
    // TODO: support for transposed / permuted tensors
1709
0
    GGML_ASSERT(nb0  == sizeof(float));
1710
0
    GGML_ASSERT(nb00 == sizeof(float));
1711
1712
    // TODO: maybe this is not optimal?
1713
0
    for                         (int i3 = 0; i3 < nr3;  i3++) {
1714
0
        for                     (int k3 = 0; k3 < ne03; k3++) {
1715
0
            for                 (int i2 = 0; i2 < nr2;  i2++) {
1716
0
                for             (int k2 = 0; k2 < ne02; k2++) {
1717
0
                    for         (int i1 = 0; i1 < nr1;  i1++) {
1718
0
                        for     (int k1 = 0; k1 < ne01; k1++) {
1719
0
                            for (int i0 = 0; i0 < nr0;  i0++) {
1720
0
                                ggml_vec_cpy_f32(ne00,
1721
0
                                        (float *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0),
1722
0
                                        (float *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01));
1723
0
                            }
1724
0
                        }
1725
0
                    }
1726
0
                }
1727
0
            }
1728
0
        }
1729
0
    }
1730
0
}
1731
1732
static void ggml_compute_forward_repeat_f16(
1733
        const ggml_compute_params * params,
1734
0
        ggml_tensor * dst) {
1735
1736
0
    const ggml_tensor * src0 = dst->src[0];
1737
1738
0
    if (params->ith != 0) {
1739
0
        return;
1740
0
    }
1741
1742
0
    GGML_ASSERT(ggml_can_repeat(src0, dst));
1743
1744
0
    GGML_TENSOR_UNARY_OP_LOCALS
1745
1746
    // guaranteed to be an integer due to the check in ggml_can_repeat
1747
0
    const int nr0 = (int)(ne0/ne00);
1748
0
    const int nr1 = (int)(ne1/ne01);
1749
0
    const int nr2 = (int)(ne2/ne02);
1750
0
    const int nr3 = (int)(ne3/ne03);
1751
1752
    // TODO: support for transposed / permuted tensors
1753
0
    GGML_ASSERT(nb0  == sizeof(ggml_fp16_t));
1754
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
1755
1756
    // TODO: maybe this is not optimal?
1757
0
    for                         (int i3 = 0; i3 < nr3;  i3++) {
1758
0
        for                     (int k3 = 0; k3 < ne03; k3++) {
1759
0
            for                 (int i2 = 0; i2 < nr2;  i2++) {
1760
0
                for             (int k2 = 0; k2 < ne02; k2++) {
1761
0
                    for         (int i1 = 0; i1 < nr1;  i1++) {
1762
0
                        for     (int k1 = 0; k1 < ne01; k1++) {
1763
0
                            for (int i0 = 0; i0 < nr0;  i0++) {
1764
0
                                ggml_fp16_t * y = (ggml_fp16_t *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0);
1765
0
                                ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01);
1766
                                // ggml_vec_cpy_f16(ne00, y, x)
1767
0
                                for (int i = 0; i < ne00; ++i) {
1768
0
                                    y[i]  = x[i];
1769
0
                                }
1770
0
                            }
1771
0
                        }
1772
0
                    }
1773
0
                }
1774
0
            }
1775
0
        }
1776
0
    }
1777
0
}
1778
1779
void ggml_compute_forward_repeat(
1780
        const ggml_compute_params * params,
1781
0
        ggml_tensor * dst) {
1782
1783
0
    const ggml_tensor * src0 = dst->src[0];
1784
1785
0
    switch (src0->type) {
1786
0
        case GGML_TYPE_F16:
1787
0
        case GGML_TYPE_BF16:
1788
0
        case GGML_TYPE_I16:
1789
0
            {
1790
0
                ggml_compute_forward_repeat_f16(params, dst);
1791
0
            } break;
1792
0
        case GGML_TYPE_F32:
1793
0
        case GGML_TYPE_I32:
1794
0
            {
1795
0
                ggml_compute_forward_repeat_f32(params, dst);
1796
0
            } break;
1797
        // TODO: templatify the implementation and add support for I64
1798
        //       ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
1799
        //case GGML_TYPE_I64:
1800
        //    {
1801
        //        ggml_compute_forward_repeat_i64(params, dst);
1802
        //    } break;
1803
0
        default:
1804
0
            {
1805
0
                GGML_ABORT("fatal error");
1806
0
            }
1807
0
    }
1808
0
}
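
In the repeat kernels, nr0..nr3 = ne/ne0x are the per-dimension repeat counts, and the destination offset (i1*ne01 + k1)*nb1 interleaves the copy index with the source index. A one-dimensional sketch of that mapping with hypothetical sizes:

// Illustrative sketch only - 1-D view of the repeat index mapping: a source
// dimension of ne01 rows tiled nr1 times lands at dst row i1*ne01 + k1.
#include <cstdio>

int main() {
    const int ne01 = 3;          // hypothetical source extent
    const int ne1  = 6;          // hypothetical destination extent
    const int nr1  = ne1/ne01;   // repeat factor (guaranteed integer by ggml_can_repeat)

    for (int i1 = 0; i1 < nr1; ++i1) {        // which copy
        for (int k1 = 0; k1 < ne01; ++k1) {   // which source row
            printf("dst row %d <- src row %d\n", i1*ne01 + k1, k1);
        }
    }
    return 0;
}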
1809
1810
// ggml_compute_forward_repeat_back
1811
1812
static void ggml_compute_forward_repeat_back_f32(
1813
        const ggml_compute_params * params,
1814
0
        ggml_tensor * dst) {
1815
1816
0
    const ggml_tensor * src0 = dst->src[0];
1817
1818
0
    if (params->ith != 0) {
1819
0
        return;
1820
0
    }
1821
1822
0
    GGML_ASSERT(ggml_can_repeat(dst, src0));
1823
1824
0
    GGML_TENSOR_UNARY_OP_LOCALS
1825
1826
    // guaranteed to be an integer due to the check in ggml_can_repeat
1827
0
    const int nr0 = (int)(ne00/ne0);
1828
0
    const int nr1 = (int)(ne01/ne1);
1829
0
    const int nr2 = (int)(ne02/ne2);
1830
0
    const int nr3 = (int)(ne03/ne3);
1831
1832
    // TODO: support for transposed / permuted tensors
1833
0
    GGML_ASSERT(nb0  == sizeof(float));
1834
0
    GGML_ASSERT(nb00 == sizeof(float));
1835
1836
0
    if (ggml_is_contiguous(dst)) {
1837
0
        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
1838
0
    } else {
1839
0
        for         (int k3 = 0; k3 < ne3; k3++) {
1840
0
            for     (int k2 = 0; k2 < ne2; k2++) {
1841
0
                for (int k1 = 0; k1 < ne1; k1++) {
1842
0
                    ggml_vec_set_f32(ne0,
1843
0
                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
1844
0
                        0);
1845
0
                }
1846
0
            }
1847
0
        }
1848
0
    }
1849
1850
    // TODO: maybe this is not optimal?
1851
0
    for                         (int i3 = 0; i3 < nr3; i3++) {
1852
0
        for                     (int k3 = 0; k3 < ne3; k3++) {
1853
0
            for                 (int i2 = 0; i2 < nr2; i2++) {
1854
0
                for             (int k2 = 0; k2 < ne2; k2++) {
1855
0
                    for         (int i1 = 0; i1 < nr1; i1++) {
1856
0
                        for     (int k1 = 0; k1 < ne1; k1++) {
1857
0
                            for (int i0 = 0; i0 < nr0; i0++) {
1858
0
                                ggml_vec_acc_f32(ne0,
1859
0
                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
1860
0
                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
1861
0
                            }
1862
0
                        }
1863
0
                    }
1864
0
                }
1865
0
            }
1866
0
        }
1867
0
    }
1868
0
}
1869
1870
void ggml_compute_forward_repeat_back(
1871
        const ggml_compute_params * params,
1872
0
        ggml_tensor * dst) {
1873
1874
0
    const ggml_tensor * src0 = dst->src[0];
1875
1876
0
    switch (src0->type) {
1877
0
        case GGML_TYPE_F32:
1878
0
            {
1879
0
                ggml_compute_forward_repeat_back_f32(params, dst);
1880
0
            } break;
1881
0
        default:
1882
0
            {
1883
0
                GGML_ABORT("fatal error");
1884
0
            }
1885
0
    }
1886
0
}
1887
1888
// ggml_compute_forward_concat
1889
1890
static void ggml_compute_forward_concat_any(
1891
    const ggml_compute_params * params,
1892
0
    ggml_tensor * dst) {
1893
1894
0
    const ggml_tensor * src0 = dst->src[0];
1895
0
    const ggml_tensor * src1 = dst->src[1];
1896
1897
0
    const size_t len = ggml_type_size(src0->type);
1898
1899
0
    const int ith = params->ith;
1900
0
    const int nth = params->nth;
1901
1902
0
    GGML_TENSOR_BINARY_OP_LOCALS
1903
1904
0
    const int32_t dim = ggml_get_op_params_i32(dst, 0);
1905
1906
0
    GGML_ASSERT(dim >= 0 && dim < 4);
1907
1908
0
    int64_t o[4] = {0, 0, 0, 0};
1909
0
    o[dim] = src0->ne[dim];
1910
1911
0
    const char * x;
1912
1913
    // TODO: smarter multi-threading
1914
0
    for (int i3 = 0; i3 < ne3; i3++) {
1915
0
        for (int i2 = ith; i2 < ne2; i2 += nth) {
1916
0
            for (int i1 = 0; i1 < ne1; i1++) {
1917
0
                for (int i0 = 0; i0 < ne0; i0++) {
1918
0
                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
1919
0
                        x = (const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03;
1920
0
                    } else {
1921
0
                        x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
1922
0
                    }
1923
1924
0
                    char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
1925
1926
0
                    memcpy(y, x, len);
1927
0
                }
1928
0
            }
1929
0
        }
1930
0
    }
1931
0
}
1932
1933
static void ggml_compute_forward_concat_i8(
1934
    const ggml_compute_params * params,
1935
0
    ggml_tensor * dst) {
1936
1937
0
    const ggml_tensor * src0 = dst->src[0];
1938
0
    const ggml_tensor * src1 = dst->src[1];
1939
1940
0
    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
1941
1942
0
    const int ith = params->ith;
1943
0
    const int nth = params->nth;
1944
1945
0
    GGML_TENSOR_BINARY_OP_LOCALS
1946
1947
0
    const int32_t dim = ggml_get_op_params_i32(dst, 0);
1948
1949
0
    GGML_ASSERT(dim >= 0 && dim < 4);
1950
1951
0
    int64_t o[4] = {0, 0, 0, 0};
1952
0
    o[dim] = src0->ne[dim];
1953
1954
0
    const int8_t * x;
1955
1956
    // TODO: smarter multi-threading
1957
0
    for (int i3 = 0; i3 < ne3; i3++) {
1958
0
        for (int i2 = ith; i2 < ne2; i2 += nth) {
1959
0
            for (int i1 = 0; i1 < ne1; i1++) {
1960
0
                for (int i0 = 0; i0 < ne0; i0++) {
1961
0
                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
1962
0
                        x = (const int8_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
1963
0
                    } else {
1964
0
                        x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
1965
0
                    }
1966
1967
0
                    int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
1968
1969
0
                    *y = *x;
1970
0
                }
1971
0
            }
1972
0
        }
1973
0
    }
1974
0
}
1975
1976
static void ggml_compute_forward_concat_f16(
1977
    const ggml_compute_params * params,
1978
0
    ggml_tensor * dst) {
1979
1980
0
    const ggml_tensor * src0 = dst->src[0];
1981
0
    const ggml_tensor * src1 = dst->src[1];
1982
1983
0
    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
1984
1985
0
    const int ith = params->ith;
1986
0
    const int nth = params->nth;
1987
1988
0
    GGML_TENSOR_BINARY_OP_LOCALS
1989
1990
0
    const int32_t dim = ggml_get_op_params_i32(dst, 0);
1991
1992
0
    GGML_ASSERT(dim >= 0 && dim < 4);
1993
1994
0
    int64_t o[4] = {0, 0, 0, 0};
1995
0
    o[dim] = src0->ne[dim];
1996
1997
0
    const ggml_fp16_t * x;
1998
1999
    // TODO: smarter multi-threading
2000
0
    for (int i3 = 0; i3 < ne3; i3++) {
2001
0
        for (int i2 = ith; i2 < ne2; i2 += nth) {
2002
0
            for (int i1 = 0; i1 < ne1; i1++) {
2003
0
                for (int i0 = 0; i0 < ne0; i0++) {
2004
0
                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
2005
0
                        x = (const ggml_fp16_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
2006
0
                    } else {
2007
0
                        x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
2008
0
                    }
2009
2010
0
                    ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
2011
2012
0
                    *y = *x;
2013
0
                }
2014
0
            }
2015
0
        }
2016
0
    }
2017
0
}
2018
2019
static void ggml_compute_forward_concat_f32(
2020
    const ggml_compute_params * params,
2021
0
    ggml_tensor * dst) {
2022
2023
0
    const ggml_tensor * src0 = dst->src[0];
2024
0
    const ggml_tensor * src1 = dst->src[1];
2025
2026
0
    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));
2027
2028
0
    const int ith = params->ith;
2029
0
    const int nth = params->nth;
2030
2031
0
    GGML_TENSOR_BINARY_OP_LOCALS
2032
2033
0
    const int32_t dim = ggml_get_op_params_i32(dst, 0);
2034
2035
0
    GGML_ASSERT(dim >= 0 && dim < 4);
2036
2037
0
    int64_t o[4] = {0, 0, 0, 0};
2038
0
    o[dim] = src0->ne[dim];
2039
2040
0
    const float * x;
2041
2042
    // TODO: smarter multi-threading
2043
0
    for (int i3 = 0; i3 < ne3; i3++) {
2044
0
        for (int i2 = ith; i2 < ne2; i2 += nth) {
2045
0
            for (int i1 = 0; i1 < ne1; i1++) {
2046
0
                for (int i0 = 0; i0 < ne0; i0++) {
2047
0
                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
2048
0
                        x = (const float *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
2049
0
                    } else {
2050
0
                        x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
2051
0
                    }
2052
2053
0
                    float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
2054
2055
0
                    *y = *x;
2056
0
                }
2057
0
            }
2058
0
        }
2059
0
    }
2060
0
}
2061
2062
void ggml_compute_forward_concat(
2063
    const ggml_compute_params * params,
2064
0
    ggml_tensor * dst) {
2065
2066
0
    const ggml_tensor * src0 = dst->src[0];
2067
2068
0
    switch (src0->type) {
2069
0
        case GGML_TYPE_F16:
2070
0
        case GGML_TYPE_BF16:
2071
0
        case GGML_TYPE_I16:
2072
0
            {
2073
0
                ggml_compute_forward_concat_f16(params, dst);
2074
0
            } break;
2075
0
        case GGML_TYPE_I8:
2076
0
            {
2077
0
                ggml_compute_forward_concat_i8(params, dst);
2078
0
            } break;
2079
0
        case GGML_TYPE_F32:
2080
0
        case GGML_TYPE_I32:
2081
0
            {
2082
0
                ggml_compute_forward_concat_f32(params, dst);
2083
0
            } break;
2084
0
        default:
2085
0
            {
2086
0
                ggml_compute_forward_concat_any(params, dst);
2087
0
            }
2088
0
    }
2089
0
}
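
Every concat variant uses the same selection rule: o is zero except o[dim] = src0->ne[dim]; while the destination index is still inside src0's extents the element comes from src0, otherwise from src1 at the index shifted back by o. A scalar sketch of the rule along a single concatenation dimension, with made-up extents:

// Illustrative sketch only - element selection for concatenation along one
// dimension: indices below src0's extent come from src0, the rest from src1
// shifted back by src0's extent (the o[dim] offset in the kernels above).
#include <cstdio>

int main() {
    const float src0[] = {1.0f, 2.0f, 3.0f};       // ne00 = 3
    const float src1[] = {10.0f, 20.0f};           // ne10 = 2
    const int   ne00 = 3, ne0 = 5;                 // dst extent = 3 + 2

    for (int i0 = 0; i0 < ne0; ++i0) {
        const float y = (i0 < ne00) ? src0[i0] : src1[i0 - ne00];
        printf("dst[%d] = %g\n", i0, y);
    }
    return 0;
}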
2090
2091
// ggml_compute_forward_gelu
2092
2093
static void ggml_compute_forward_gelu_f32(
2094
        const ggml_compute_params * params,
2095
0
        ggml_tensor * dst) {
2096
2097
0
    const ggml_tensor * src0 = dst->src[0];
2098
2099
0
    assert(ggml_is_contiguous_1(src0));
2100
0
    assert(ggml_is_contiguous_1(dst));
2101
0
    assert(ggml_are_same_shape(src0, dst));
2102
2103
0
    const int ith = params->ith;
2104
0
    const int nth = params->nth;
2105
2106
0
    const int nc = src0->ne[0];
2107
0
    const int nr = ggml_nrows(src0);
2108
2109
    // rows per thread
2110
0
    const int dr = (nr + nth - 1)/nth;
2111
2112
    // row range for this thread
2113
0
    const int ir0 = dr*ith;
2114
0
    const int ir1 = MIN(ir0 + dr, nr);
2115
2116
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2117
0
        ggml_vec_gelu_f32(nc,
2118
0
                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
2119
0
                (float *) ((char *) src0->data + i1*(src0->nb[1])));
2120
2121
#ifndef NDEBUG
2122
        for (int k = 0; k < nc; k++) {
2123
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2124
            GGML_UNUSED(x);
2125
            assert(!isnan(x));
2126
            assert(!isinf(x));
2127
        }
2128
#endif
2129
0
    }
2130
0
}
2131
2132
static void ggml_compute_forward_gelu_f16(
2133
    const ggml_compute_params * params,
2134
0
    ggml_tensor * dst) {
2135
2136
0
    const ggml_tensor * src0 = dst->src[0];
2137
2138
0
    assert(ggml_is_contiguous_1(src0));
2139
0
    assert(ggml_is_contiguous_1(dst));
2140
0
    assert(ggml_are_same_shape(src0, dst));
2141
2142
0
    const int ith = params->ith;
2143
0
    const int nth = params->nth;
2144
2145
0
    const int nc = src0->ne[0];
2146
0
    const int nr = ggml_nrows(src0);
2147
2148
    // rows per thread
2149
0
    const int dr = (nr + nth - 1)/nth;
2150
2151
    // row range for this thread
2152
0
    const int ir0 = dr*ith;
2153
0
    const int ir1 = MIN(ir0 + dr, nr);
2154
2155
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2156
0
        ggml_vec_gelu_f16(nc,
2157
0
                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
2158
0
                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2159
2160
#ifndef NDEBUG
2161
        for (int k = 0; k < nc; k++) {
2162
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2163
            const float v = GGML_CPU_FP16_TO_FP32(x);
2164
            GGML_UNUSED(v);
2165
            assert(!isnan(v));
2166
            assert(!isinf(v));
2167
        }
2168
#endif
2169
0
    }
2170
0
}
2171
2172
static void ggml_compute_forward_gelu(
2173
        const ggml_compute_params * params,
2174
0
        ggml_tensor * dst) {
2175
2176
0
    const ggml_tensor * src0 = dst->src[0];
2177
2178
0
    switch (src0->type) {
2179
0
        case GGML_TYPE_F32:
2180
0
            {
2181
0
                ggml_compute_forward_gelu_f32(params, dst);
2182
0
            } break;
2183
0
        case GGML_TYPE_F16:
2184
0
            {
2185
0
                ggml_compute_forward_gelu_f16(params, dst);
2186
0
            } break;
2187
0
        default:
2188
0
            {
2189
0
                GGML_ABORT("fatal error");
2190
0
            }
2191
0
    }
2192
0
}
2193
2194
// ggml_compute_forward_fill
2195
2196
0
static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
2197
0
    const float c = ggml_get_op_params_f32(dst, 0);
2198
2199
0
    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
2200
0
    GGML_TENSOR_LOCALS(size_t,  nb, dst, nb);
2201
2202
0
    const auto [ir0, ir1] = get_thread_range(params, dst);
2203
2204
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
2205
0
        const int64_t i03 = ir/(ne2*ne1);
2206
0
        const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
2207
0
        const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
2208
2209
0
        float * dst_ptr  = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
2210
2211
0
        ggml_vec_set_f32(ne0, dst_ptr, c);
2212
0
    }
2213
0
}
2214
2215
0
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
2216
0
    ggml_compute_forward_fill_f32(params, dst);
2217
0
}
2218
2219
// ggml_compute_forward_tri
2220
2221
0
static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
2222
0
    const ggml_tensor * src0 = dst->src[0];
2223
2224
0
    const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
2225
2226
0
    GGML_ASSERT(ggml_is_contiguous(src0));
2227
2228
0
    GGML_TENSOR_UNARY_OP_LOCALS
2229
2230
0
    const auto [ir0, ir1] = get_thread_range(params, src0);
2231
2232
0
    bool (*bipred)(int, int);
2233
2234
0
    switch (ttype) {
2235
0
        case GGML_TRI_TYPE_LOWER:      bipred = [](int i, int r) { return i <  r; }; break;
2236
0
        case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
2237
0
        case GGML_TRI_TYPE_UPPER:      bipred = [](int i, int r) { return i >  r; }; break;
2238
0
        case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
2239
0
        default: GGML_ABORT("invalid tri type");
2240
0
    }
2241
2242
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
2243
0
        const int64_t i03 = ir/(ne02*ne01);
2244
0
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
2245
0
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
2246
2247
0
        const float * src_ptr = (const float  *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
2248
0
              float * dst_ptr = (      float  *) ((      char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1);
2249
2250
0
        for (int i0 = 0; i0 < ne0; ++i0) {
2251
0
            dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
2252
0
        }
2253
0
    }
2254
0
}
2255
2256
0
void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
2257
0
    const ggml_tensor * src0 = dst->src[0];
2258
2259
0
    switch (src0->type) {
2260
0
        case GGML_TYPE_F32:
2261
0
            {
2262
0
                ggml_compute_forward_tri_f32(params, dst);
2263
0
            } break;
2264
0
        default:
2265
0
            {
2266
0
                GGML_ABORT("fatal error");
2267
0
            }
2268
0
    }
2269
0
}
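
tri_f32 keeps the element at column i0 of row i01 only when bipred(i0, i01) holds, so the four ggml_tri_type values select strict or inclusive lower/upper triangles. A small sketch printing the keep masks those predicates produce on a 3x3 grid:

// Illustrative sketch only - the keep/zero masks produced by the four
// triangular predicates used in ggml_compute_forward_tri_f32 (1 = keep src).
#include <cstdio>

int main() {
    bool (*preds[4])(int, int) = {
        [](int i, int r) { return i <  r; },  // LOWER        (strictly below the diagonal)
        [](int i, int r) { return i <= r; },  // LOWER_DIAG   (below or on the diagonal)
        [](int i, int r) { return i >  r; },  // UPPER        (strictly above the diagonal)
        [](int i, int r) { return i >= r; },  // UPPER_DIAG   (above or on the diagonal)
    };
    const char * names[4] = {"LOWER", "LOWER_DIAG", "UPPER", "UPPER_DIAG"};

    for (int p = 0; p < 4; ++p) {
        printf("%s:\n", names[p]);
        for (int r = 0; r < 3; ++r) {            // row index (i01)
            for (int i = 0; i < 3; ++i) {        // column index (i0)
                printf(" %d", preds[p](i, r) ? 1 : 0);
            }
            printf("\n");
        }
    }
    return 0;
}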
2270
2271
// ggml_compute_forward_gelu_erf
2272
2273
static void ggml_compute_forward_gelu_erf_f32(
2274
        const ggml_compute_params * params,
2275
0
        ggml_tensor * dst) {
2276
2277
0
    const ggml_tensor * src0 = dst->src[0];
2278
2279
0
    assert(ggml_is_contiguous_1(src0));
2280
0
    assert(ggml_is_contiguous_1(dst));
2281
0
    assert(ggml_are_same_shape(src0, dst));
2282
2283
0
    const int ith = params->ith;
2284
0
    const int nth = params->nth;
2285
2286
0
    const int nc = src0->ne[0];
2287
0
    const int nr = ggml_nrows(src0);
2288
2289
    // rows per thread
2290
0
    const int dr = (nr + nth - 1)/nth;
2291
2292
    // row range for this thread
2293
0
    const int ir0 = dr*ith;
2294
0
    const int ir1 = MIN(ir0 + dr, nr);
2295
2296
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2297
0
        ggml_vec_gelu_erf_f32(nc,
2298
0
                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
2299
0
                (float *) ((char *) src0->data + i1*(src0->nb[1])));
2300
2301
#ifndef NDEBUG
2302
        for (int k = 0; k < nc; k++) {
2303
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2304
            GGML_UNUSED(x);
2305
            assert(!isnan(x));
2306
            assert(!isinf(x));
2307
        }
2308
#endif
2309
0
    }
2310
0
}
2311
2312
static void ggml_compute_forward_gelu_erf_f16(
2313
    const ggml_compute_params * params,
2314
0
    ggml_tensor * dst) {
2315
2316
0
    const ggml_tensor * src0 = dst->src[0];
2317
2318
0
    assert(ggml_is_contiguous_1(src0));
2319
0
    assert(ggml_is_contiguous_1(dst));
2320
0
    assert(ggml_are_same_shape(src0, dst));
2321
2322
0
    const int ith = params->ith;
2323
0
    const int nth = params->nth;
2324
2325
0
    const int nc = src0->ne[0];
2326
0
    const int nr = ggml_nrows(src0);
2327
2328
    // rows per thread
2329
0
    const int dr = (nr + nth - 1)/nth;
2330
2331
    // row range for this thread
2332
0
    const int ir0 = dr*ith;
2333
0
    const int ir1 = MIN(ir0 + dr, nr);
2334
2335
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2336
0
        ggml_vec_gelu_erf_f16(nc,
2337
0
                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
2338
0
                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2339
2340
#ifndef NDEBUG
2341
        for (int k = 0; k < nc; k++) {
2342
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2343
            const float v = GGML_CPU_FP16_TO_FP32(x);
2344
            GGML_UNUSED(v);
2345
            assert(!isnan(v));
2346
            assert(!isinf(v));
2347
        }
2348
#endif
2349
0
    }
2350
0
}
2351
2352
static void ggml_compute_forward_gelu_erf(
2353
        const ggml_compute_params * params,
2354
0
        ggml_tensor * dst) {
2355
2356
0
    const ggml_tensor * src0 = dst->src[0];
2357
2358
0
    switch (src0->type) {
2359
0
        case GGML_TYPE_F32:
2360
0
            {
2361
0
                ggml_compute_forward_gelu_erf_f32(params, dst);
2362
0
            } break;
2363
0
        case GGML_TYPE_F16:
2364
0
            {
2365
0
                ggml_compute_forward_gelu_erf_f16(params, dst);
2366
0
            } break;
2367
0
        default:
2368
0
            {
2369
0
                GGML_ABORT("fatal error");
2370
0
            }
2371
0
    }
2372
0
}
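
gelu_erf uses the erf-based form of GELU, 0.5*x*(1 + erf(x/sqrt(2))); the vectorized math lives in ggml_vec_gelu_erf_f32, and treating it as exactly this formula is an assumption based on the name. A scalar sketch:

// Illustrative sketch only - scalar erf-based GELU: 0.5*x*(1 + erf(x/sqrt(2))).
#include <cmath>
#include <cstdio>

int main() {
    const float xs[] = {-1.0f, 0.0f, 1.0f, 2.0f};
    for (float x : xs) {
        const float y = 0.5f*x*(1.0f + std::erf(x/std::sqrt(2.0f)));
        printf("gelu_erf(%5.2f) = %f\n", x, y);
    }
    return 0;
}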
2373
2374
// ggml_compute_forward_gelu_quick
2375
2376
static void ggml_compute_forward_gelu_quick_f32(
2377
        const ggml_compute_params * params,
2378
0
        ggml_tensor * dst) {
2379
2380
0
    const ggml_tensor * src0 = dst->src[0];
2381
2382
0
    assert(ggml_is_contiguous_1(src0));
2383
0
    assert(ggml_is_contiguous_1(dst));
2384
0
    assert(ggml_are_same_shape(src0, dst));
2385
2386
0
    const int ith = params->ith;
2387
0
    const int nth = params->nth;
2388
2389
0
    const int nc = src0->ne[0];
2390
0
    const int nr = ggml_nrows(src0);
2391
2392
    // rows per thread
2393
0
    const int dr = (nr + nth - 1)/nth;
2394
2395
    // row range for this thread
2396
0
    const int ir0 = dr*ith;
2397
0
    const int ir1 = MIN(ir0 + dr, nr);
2398
2399
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2400
0
        ggml_vec_gelu_quick_f32(nc,
2401
0
                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
2402
0
                (float *) ((char *) src0->data + i1*(src0->nb[1])));
2403
2404
#ifndef NDEBUG
2405
        for (int k = 0; k < nc; k++) {
2406
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2407
            GGML_UNUSED(x);
2408
            assert(!isnan(x));
2409
            assert(!isinf(x));
2410
        }
2411
#endif
2412
0
    }
2413
0
}
2414
2415
static void ggml_compute_forward_gelu_quick_f16(
2416
    const ggml_compute_params * params,
2417
0
    ggml_tensor * dst) {
2418
2419
0
    const ggml_tensor * src0 = dst->src[0];
2420
2421
0
    assert(ggml_is_contiguous_1(src0));
2422
0
    assert(ggml_is_contiguous_1(dst));
2423
0
    assert(ggml_are_same_shape(src0, dst));
2424
2425
0
    const int ith = params->ith;
2426
0
    const int nth = params->nth;
2427
2428
0
    const int nc = src0->ne[0];
2429
0
    const int nr = ggml_nrows(src0);
2430
2431
    // rows per thread
2432
0
    const int dr = (nr + nth - 1)/nth;
2433
2434
    // row range for this thread
2435
0
    const int ir0 = dr*ith;
2436
0
    const int ir1 = MIN(ir0 + dr, nr);
2437
2438
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2439
0
        ggml_vec_gelu_quick_f16(nc,
2440
0
                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
2441
0
                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2442
2443
#ifndef NDEBUG
2444
        for (int k = 0; k < nc; k++) {
2445
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2446
            const float v = GGML_CPU_FP16_TO_FP32(x);
2447
            GGML_UNUSED(v);
2448
            assert(!isnan(v));
2449
            assert(!isinf(v));
2450
        }
2451
#endif
2452
0
    }
2453
0
}
2454
2455
static void ggml_compute_forward_gelu_quick(
2456
        const ggml_compute_params * params,
2457
0
        ggml_tensor * dst) {
2458
2459
0
    const ggml_tensor * src0 = dst->src[0];
2460
2461
0
    switch (src0->type) {
2462
0
        case GGML_TYPE_F32:
2463
0
            {
2464
0
                ggml_compute_forward_gelu_quick_f32(params, dst);
2465
0
            } break;
2466
0
        case GGML_TYPE_F16:
2467
0
            {
2468
0
                ggml_compute_forward_gelu_quick_f16(params, dst);
2469
0
            } break;
2470
0
        default:
2471
0
            {
2472
0
                GGML_ABORT("fatal error");
2473
0
            }
2474
0
    }
2475
0
}
2476
2477
// ggml_compute_forward_silu
2478
2479
static void ggml_compute_forward_silu_f32(
2480
        const ggml_compute_params * params,
2481
0
        ggml_tensor * dst) {
2482
2483
0
    const ggml_tensor * src0 = dst->src[0];
2484
2485
0
    assert(ggml_is_contiguous_1(src0));
2486
0
    assert(ggml_is_contiguous_1(dst));
2487
0
    assert(ggml_are_same_shape(src0, dst));
2488
2489
0
    const int ith = params->ith;
2490
0
    const int nth = params->nth;
2491
2492
0
    const int nc = src0->ne[0];
2493
0
    const int nr = ggml_nrows(src0);
2494
2495
    // rows per thread
2496
0
    const int dr = (nr + nth - 1)/nth;
2497
2498
    // row range for this thread
2499
0
    const int ir0 = dr*ith;
2500
0
    const int ir1 = MIN(ir0 + dr, nr);
2501
2502
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2503
0
        ggml_vec_silu_f32(nc,
2504
0
                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
2505
0
                (float *) ((char *) src0->data + i1*(src0->nb[1])));
2506
2507
#ifndef NDEBUG
2508
        for (int k = 0; k < nc; k++) {
2509
            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
2510
            GGML_UNUSED(x);
2511
            assert(!isnan(x));
2512
            assert(!isinf(x));
2513
        }
2514
#endif
2515
0
    }
2516
0
}
2517
2518
static void ggml_compute_forward_silu_f16(
2519
    const ggml_compute_params * params,
2520
0
    ggml_tensor * dst) {
2521
2522
0
    const ggml_tensor * src0 = dst->src[0];
2523
2524
0
    assert(ggml_is_contiguous_1(src0));
2525
0
    assert(ggml_is_contiguous_1(dst));
2526
0
    assert(ggml_are_same_shape(src0, dst));
2527
2528
0
    const int ith = params->ith;
2529
0
    const int nth = params->nth;
2530
2531
0
    const int nc = src0->ne[0];
2532
0
    const int nr = ggml_nrows(src0);
2533
2534
    // rows per thread
2535
0
    const int dr = (nr + nth - 1)/nth;
2536
2537
    // row range for this thread
2538
0
    const int ir0 = dr*ith;
2539
0
    const int ir1 = MIN(ir0 + dr, nr);
2540
2541
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2542
0
        ggml_vec_silu_f16(nc,
2543
0
                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
2544
0
                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2545
2546
#ifndef NDEBUG
2547
        for (int k = 0; k < nc; k++) {
2548
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
2549
            const float v = GGML_CPU_FP16_TO_FP32(x);
2550
            GGML_UNUSED(v);
2551
            assert(!isnan(v));
2552
            assert(!isinf(v));
2553
        }
2554
#endif
2555
0
    }
2556
0
}
2557
2558
static void ggml_compute_forward_silu(
2559
        const ggml_compute_params * params,
2560
0
        ggml_tensor * dst) {
2561
2562
0
    const ggml_tensor * src0 = dst->src[0];
2563
2564
0
    switch (src0->type) {
2565
0
        case GGML_TYPE_F32:
2566
0
            {
2567
0
                ggml_compute_forward_silu_f32(params, dst);
2568
0
            } break;
2569
0
        case GGML_TYPE_F16:
2570
0
            {
2571
0
                ggml_compute_forward_silu_f16(params, dst);
2572
0
            } break;
2573
0
        default:
2574
0
            {
2575
0
                GGML_ABORT("fatal error");
2576
0
            }
2577
0
    }
2578
0
}
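
silu applies the SiLU/swish activation x * sigmoid(x) per element, with the same row-per-thread split as the other activations above. A scalar sketch of the formula:

// Illustrative sketch only - scalar SiLU: x * sigmoid(x).
#include <cmath>
#include <cstdio>

int main() {
    const float xs[] = {-2.0f, 0.0f, 1.0f};
    for (float x : xs) {
        const float y = x/(1.0f + std::exp(-x)); // x * sigmoid(x)
        printf("silu(%5.2f) = %f\n", x, y);
    }
    return 0;
}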
2579
// ggml_compute_forward_leaky_relu
2580
2581
static void ggml_compute_forward_leaky_relu_f32(
2582
        const ggml_compute_params * params,
2583
0
        ggml_tensor * dst) {
2584
2585
0
    const ggml_tensor * src0 = dst->src[0];
2586
2587
0
    if (params->ith != 0) {
2588
0
        return;
2589
0
    }
2590
2591
0
    assert(ggml_is_contiguous_1(src0));
2592
0
    assert(ggml_is_contiguous_1(dst));
2593
0
    assert(ggml_are_same_shape(src0, dst));
2594
2595
0
    const int n  = ggml_nrows(src0);
2596
0
    const int nc = src0->ne[0];
2597
2598
0
    float negative_slope;
2599
0
    memcpy(&negative_slope, dst->op_params, sizeof(float));
2600
2601
0
    assert(dst->nb[0]  == sizeof(float));
2602
0
    assert(src0->nb[0] == sizeof(float));
2603
2604
0
    for (int i = 0; i < n; i++) {
2605
0
        ggml_vec_leaky_relu_f32(nc,
2606
0
                (float *) ((char *) dst->data  + i*( dst->nb[1])),
2607
0
                (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
2608
0
    }
2609
0
}
2610
2611
static void ggml_compute_forward_leaky_relu_f16(
2612
    const ggml_compute_params * params,
2613
0
    ggml_tensor * dst) {
2614
2615
0
    const ggml_tensor * src0 = dst->src[0];
2616
2617
0
    if (params->ith != 0) {
2618
0
        return;
2619
0
    }
2620
2621
0
    assert(ggml_is_contiguous_1(src0));
2622
0
    assert(ggml_is_contiguous_1(dst));
2623
0
    assert(ggml_are_same_shape(src0, dst));
2624
2625
0
    const int n  = ggml_nrows(src0);
2626
0
    const int nc = src0->ne[0];
2627
2628
0
    float negative_slope;
2629
0
    memcpy(&negative_slope, dst->op_params, sizeof(float));
2630
2631
0
    assert(dst->nb[0]  == sizeof(ggml_fp16_t));
2632
0
    assert(src0->nb[0] == sizeof(ggml_fp16_t));
2633
2634
0
    for (int i = 0; i < n; i++) {
2635
0
        ggml_vec_leaky_relu_f16(nc,
2636
0
                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
2637
0
                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
2638
0
    }
2639
0
}
2640
2641
void ggml_compute_forward_leaky_relu(
2642
        const ggml_compute_params * params,
2643
0
        ggml_tensor * dst) {
2644
2645
0
    const ggml_tensor * src0 = dst->src[0];
2646
2647
0
    switch (src0->type) {
2648
0
        case GGML_TYPE_F32:
2649
0
            {
2650
0
                ggml_compute_forward_leaky_relu_f32(params, dst);
2651
0
            } break;
2652
0
        case GGML_TYPE_F16:
2653
0
            {
2654
0
                ggml_compute_forward_leaky_relu_f16(params, dst);
2655
0
            } break;
2656
0
        default:
2657
0
            {
2658
0
                GGML_ABORT("fatal error");
2659
0
            }
2660
0
    }
2661
0
}
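
leaky_relu reads negative_slope as a raw float memcpy'd out of dst->op_params and then, inside ggml_vec_leaky_relu_f32, leaves positive inputs alone while scaling negative ones by the slope. A scalar sketch of both the param unpacking and the element-wise rule:

// Illustrative sketch only - scalar leaky ReLU with the slope read the same
// way the kernel reads it: as a raw float copied out of an op-params buffer.
#include <cstdio>
#include <cstring>

int main() {
    const float stored_slope = 0.1f;
    char op_params[sizeof(float)];
    std::memcpy(op_params, &stored_slope, sizeof(float)); // how the param is packed

    float negative_slope;
    std::memcpy(&negative_slope, op_params, sizeof(float)); // how the kernel reads it

    const float xs[] = {-2.0f, -0.5f, 0.0f, 3.0f};
    for (float x : xs) {
        const float y = x > 0.0f ? x : negative_slope*x;
        printf("leaky_relu(%5.2f) = %g\n", x, y);
    }
    return 0;
}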
2662
2663
// ggml_compute_forward_silu_back
2664
2665
static void ggml_compute_forward_silu_back_f32(
2666
        const ggml_compute_params * params,
2667
0
        ggml_tensor * dst) {
2668
2669
0
    const ggml_tensor * grad = dst->src[0];
2670
0
    const ggml_tensor * src1 = dst->src[1];
2671
2672
0
    assert(ggml_is_contiguous_1(grad));
2673
0
    assert(ggml_is_contiguous_1(src1));
2674
0
    assert(ggml_is_contiguous_1(dst));
2675
0
    assert(ggml_are_same_shape(src1, dst));
2676
0
    assert(ggml_are_same_shape(src1, grad));
2677
2678
0
    const int ith = params->ith;
2679
0
    const int nth = params->nth;
2680
2681
0
    const int nc = src1->ne[0];
2682
0
    const int nr = ggml_nrows(src1);
2683
2684
    // rows per thread
2685
0
    const int dr = (nr + nth - 1)/nth;
2686
2687
    // row range for this thread
2688
0
    const int ir0 = dr*ith;
2689
0
    const int ir1 = MIN(ir0 + dr, nr);
2690
2691
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2692
0
        ggml_vec_silu_backward_f32(nc,
2693
0
                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
2694
0
                (float *) ((char *) src1->data + i1*(src1->nb[1])),
2695
0
                (float *) ((char *) grad->data + i1*(grad->nb[1])));
2696
2697
#ifndef NDEBUG
2698
        for (int k = 0; k < nc; k++) {
2699
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2700
            GGML_UNUSED(x);
2701
            assert(!isnan(x));
2702
            assert(!isinf(x));
2703
        }
2704
#endif
2705
0
    }
2706
0
}
2707
2708
static void ggml_compute_forward_silu_back_f16(
2709
    const ggml_compute_params * params,
2710
0
    ggml_tensor * dst) {
2711
2712
0
    const ggml_tensor * grad = dst->src[0];
2713
0
    const ggml_tensor * src1 = dst->src[1];
2714
2715
0
    assert(ggml_is_contiguous_1(grad));
2716
0
    assert(ggml_is_contiguous_1(src1));
2717
0
    assert(ggml_is_contiguous_1(dst));
2718
0
    assert(ggml_are_same_shape(src1, dst));
2719
0
    assert(ggml_are_same_shape(src1, grad));
2720
2721
0
    const int ith = params->ith;
2722
0
    const int nth = params->nth;
2723
2724
0
    const int nc = src1->ne[0];
2725
0
    const int nr = ggml_nrows(src1);
2726
2727
    // rows per thread
2728
0
    const int dr = (nr + nth - 1)/nth;
2729
2730
    // row range for this thread
2731
0
    const int ir0 = dr*ith;
2732
0
    const int ir1 = MIN(ir0 + dr, nr);
2733
2734
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2735
0
        ggml_vec_silu_backward_f16(nc,
2736
0
                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
2737
0
                (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
2738
0
                (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
2739
2740
    #ifndef NDEBUG
2741
        for (int k = 0; k < nc; k++) {
2742
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2743
            const float v = GGML_CPU_FP16_TO_FP32(x);
2744
            GGML_UNUSED(v);
2745
            assert(!isnan(v));
2746
            assert(!isinf(v));
2747
        }
2748
    #endif
2749
0
    }
2750
0
}
2751
2752
void ggml_compute_forward_silu_back(
2753
        const ggml_compute_params * params,
2754
0
        ggml_tensor * dst) {
2755
2756
0
    const ggml_tensor * src0 = dst->src[0];
2757
2758
0
    switch (src0->type) {
2759
0
        case GGML_TYPE_F32:
2760
0
            {
2761
0
                ggml_compute_forward_silu_back_f32(params, dst);
2762
0
            } break;
2763
0
        case GGML_TYPE_F16:
2764
0
            {
2765
0
                ggml_compute_forward_silu_back_f16(params, dst);
2766
0
            } break;
2767
0
        default:
2768
0
            {
2769
0
                GGML_ABORT("fatal error");
2770
0
            }
2771
0
    }
2772
0
}
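
silu_back scales the incoming gradient by the derivative of SiLU: with s(x) = 1/(1 + exp(-x)) and silu(x) = x*s(x), the derivative is s(x)*(1 + x*(1 - s(x))). A scalar sketch of that chain-rule step (ggml_vec_silu_backward_f32 is the vectorized version; its exact numerics are not shown here):

// Illustrative sketch only - scalar SiLU backward: dy/dx for silu(x) = x*sigmoid(x)
// is sigmoid(x)*(1 + x*(1 - sigmoid(x))); the kernel scales the incoming grad by it.
#include <cmath>
#include <cstdio>

int main() {
    const float x    = 0.5f;  // forward input
    const float grad = 1.0f;  // incoming gradient

    const float s  = 1.0f/(1.0f + std::exp(-x));
    const float dx = grad * s * (1.0f + x*(1.0f - s));

    printf("d silu / dx at %.2f = %f\n", x, dx);
    return 0;
}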
2773
2774
// ggml_compute_forward_reglu
2775
2776
static void ggml_compute_forward_reglu_f32(
2777
        const ggml_compute_params * params,
2778
0
        ggml_tensor * dst) {
2779
2780
0
    const ggml_tensor * src0 = dst->src[0];
2781
0
    const ggml_tensor * src1 = dst->src[1];
2782
0
    char * src0_d = (char *) src0->data;
2783
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
2784
0
    const size_t src0_o = src0->nb[1];
2785
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2786
2787
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
2788
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
2789
2790
0
    if (src1) {
2791
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
2792
0
        GGML_ASSERT(src0->type == src1->type);
2793
0
    }
2794
2795
0
    const int ith = params->ith;
2796
0
    const int nth = params->nth;
2797
2798
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
2799
0
    const int nr = ggml_nrows(src0);
2800
2801
0
    GGML_ASSERT(dst->ne[0] == nc);
2802
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
2803
2804
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
2805
2806
    // rows per thread
2807
0
    const int dr = (nr + nth - 1)/nth;
2808
2809
    // row range for this thread
2810
0
    const int ir0 = dr*ith;
2811
0
    const int ir1 = MIN(ir0 + dr, nr);
2812
2813
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2814
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
2815
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
2816
2817
0
        if (!src1) {
2818
0
            src0_p += swapped ? nc : 0;
2819
0
            src1_p += swapped ? 0 : nc;
2820
0
        }
2821
2822
0
        ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
2823
2824
#ifndef NDEBUG
2825
        for (int k = 0; k < nc; k++) {
2826
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2827
            GGML_UNUSED(x);
2828
            assert(!isnan(x));
2829
            assert(!isinf(x));
2830
        }
2831
#endif
2832
0
    }
2833
0
}
2834
2835
static void ggml_compute_forward_reglu_f16(
2836
    const ggml_compute_params * params,
2837
0
    ggml_tensor * dst) {
2838
2839
0
    const ggml_tensor * src0 = dst->src[0];
2840
0
    const ggml_tensor * src1 = dst->src[1];
2841
0
    char * src0_d = (char *) src0->data;
2842
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
2843
0
    const size_t src0_o = src0->nb[1];
2844
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2845
2846
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
2847
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
2848
2849
0
    if (src1) {
2850
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
2851
0
        GGML_ASSERT(src0->type == src1->type);
2852
0
    }
2853
2854
0
    const int ith = params->ith;
2855
0
    const int nth = params->nth;
2856
2857
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
2858
0
    const int nr = ggml_nrows(src0);
2859
2860
0
    GGML_ASSERT(dst->ne[0] == nc);
2861
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
2862
2863
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
2864
2865
    // rows per thread
2866
0
    const int dr = (nr + nth - 1)/nth;
2867
2868
    // row range for this thread
2869
0
    const int ir0 = dr*ith;
2870
0
    const int ir1 = MIN(ir0 + dr, nr);
2871
2872
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2873
0
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
2874
0
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
2875
2876
0
        if (!src1) {
2877
0
            src0_p += swapped ? nc : 0;
2878
0
            src1_p += swapped ? 0 : nc;
2879
0
        }
2880
2881
0
        ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
2882
2883
#ifndef NDEBUG
2884
        for (int k = 0; k < nc; k++) {
2885
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2886
            const float v = GGML_FP16_TO_FP32(x);
2887
            GGML_UNUSED(v);
2888
            assert(!isnan(v));
2889
            assert(!isinf(v));
2890
        }
2891
#endif
2892
0
    }
2893
0
}
2894
2895
static void ggml_compute_forward_reglu(
2896
        const ggml_compute_params * params,
2897
0
        ggml_tensor * dst) {
2898
2899
0
    const ggml_tensor * src0 = dst->src[0];
2900
2901
0
    switch (src0->type) {
2902
0
        case GGML_TYPE_F32:
2903
0
            {
2904
0
                ggml_compute_forward_reglu_f32(params, dst);
2905
0
            } break;
2906
0
        case GGML_TYPE_F16:
2907
0
            {
2908
0
                ggml_compute_forward_reglu_f16(params, dst);
2909
0
            } break;
2910
0
        default:
2911
0
            {
2912
0
                GGML_ABORT("fatal error");
2913
0
            }
2914
0
    }
2915
0
}
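
reglu and the geglu/swiglu kernels that follow share one layout convention: with a single input, each row of width 2*nc is split into two halves and the swapped op param decides which half is activated and which gates; with two inputs, src0 is activated and src1 gates. A scalar sketch of the single-input REGLU case, assuming the vec helper computes relu(x)*g over its two row pointers:

// Illustrative sketch only - the split used by the single-input GLU kernels:
// a row of 2*nc values is treated as two halves, and (assuming the vec helper
// computes relu(x)*g over its two pointers) the result has nc elements.
#include <algorithm>
#include <cstdio>

int main() {
    const int   nc      = 2;
    const float row[4]  = {-1.0f, 3.0f, 0.5f, 2.0f}; // hypothetical fused row, width 2*nc
    const bool  swapped = false;

    const float * x = row + (swapped ? nc : 0); // activated half
    const float * g = row + (swapped ? 0 : nc); // gating half

    for (int i = 0; i < nc; ++i) {
        const float y = std::max(x[i], 0.0f) * g[i]; // reglu: relu(x)*g
        printf("dst[%d] = %g\n", i, y);
    }
    return 0;
}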
2916
2917
// ggml_compute_forward_geglu
2918
2919
static void ggml_compute_forward_geglu_f32(
2920
        const ggml_compute_params * params,
2921
0
        ggml_tensor * dst) {
2922
2923
0
    const ggml_tensor * src0 = dst->src[0];
2924
0
    const ggml_tensor * src1 = dst->src[1];
2925
0
    char * src0_d = (char *) src0->data;
2926
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
2927
0
    const size_t src0_o = src0->nb[1];
2928
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2929
2930
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
2931
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
2932
2933
0
    if (src1) {
2934
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
2935
0
        GGML_ASSERT(src0->type == src1->type);
2936
0
    }
2937
2938
0
    const int ith = params->ith;
2939
0
    const int nth = params->nth;
2940
2941
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
2942
0
    const int nr = ggml_nrows(src0);
2943
2944
0
    GGML_ASSERT(dst->ne[0] == nc);
2945
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
2946
2947
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
2948
2949
    // rows per thread
2950
0
    const int dr = (nr + nth - 1)/nth;
2951
2952
    // row range for this thread
2953
0
    const int ir0 = dr*ith;
2954
0
    const int ir1 = MIN(ir0 + dr, nr);
2955
2956
0
    for (int i1 = ir0; i1 < ir1; i1++) {
2957
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
2958
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
2959
2960
0
        if (!src1) {
2961
0
            src0_p += swapped ? nc : 0;
2962
0
            src1_p += swapped ? 0 : nc;
2963
0
        }
2964
2965
0
        ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
2966
2967
#ifndef NDEBUG
2968
        for (int k = 0; k < nc; k++) {
2969
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2970
            GGML_UNUSED(x);
2971
            assert(!isnan(x));
2972
            assert(!isinf(x));
2973
        }
2974
#endif
2975
0
    }
2976
0
}
2977
2978
static void ggml_compute_forward_geglu_f16(
2979
    const ggml_compute_params * params,
2980
0
    ggml_tensor * dst) {
2981
2982
0
    const ggml_tensor * src0 = dst->src[0];
2983
0
    const ggml_tensor * src1 = dst->src[1];
2984
0
    char * src0_d = (char *) src0->data;
2985
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
2986
0
    const size_t src0_o = src0->nb[1];
2987
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2988
2989
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
2990
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
2991
2992
0
    if (src1) {
2993
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
2994
0
        GGML_ASSERT(src0->type == src1->type);
2995
0
    }
2996
2997
0
    const int ith = params->ith;
2998
0
    const int nth = params->nth;
2999
3000
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3001
0
    const int nr = ggml_nrows(src0);
3002
3003
0
    GGML_ASSERT(dst->ne[0] == nc);
3004
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3005
3006
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3007
3008
    // rows per thread
3009
0
    const int dr = (nr + nth - 1)/nth;
3010
3011
    // row range for this thread
3012
0
    const int ir0 = dr*ith;
3013
0
    const int ir1 = MIN(ir0 + dr, nr);
3014
3015
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3016
0
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3017
0
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3018
3019
0
        if (!src1) {
3020
0
            src0_p += swapped ? nc : 0;
3021
0
            src1_p += swapped ? 0 : nc;
3022
0
        }
3023
3024
0
        ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3025
3026
#ifndef NDEBUG
3027
        for (int k = 0; k < nc; k++) {
3028
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3029
            const float v = GGML_FP16_TO_FP32(x);
3030
            GGML_UNUSED(v);
3031
            assert(!isnan(v));
3032
            assert(!isinf(v));
3033
        }
3034
#endif
3035
0
    }
3036
0
}
3037
3038
static void ggml_compute_forward_geglu(
3039
        const ggml_compute_params * params,
3040
0
        ggml_tensor * dst) {
3041
3042
0
    const ggml_tensor * src0 = dst->src[0];
3043
3044
0
    switch (src0->type) {
3045
0
        case GGML_TYPE_F32:
3046
0
            {
3047
0
                ggml_compute_forward_geglu_f32(params, dst);
3048
0
            } break;
3049
0
        case GGML_TYPE_F16:
3050
0
            {
3051
0
                ggml_compute_forward_geglu_f16(params, dst);
3052
0
            } break;
3053
0
        default:
3054
0
            {
3055
0
                GGML_ABORT("fatal error");
3056
0
            }
3057
0
    }
3058
0
}
3059
3060
// ggml_compute_forward_swiglu
3061
3062
static void ggml_compute_forward_swiglu_f32(
3063
        const ggml_compute_params * params,
3064
0
        ggml_tensor * dst) {
3065
3066
0
    const ggml_tensor * src0 = dst->src[0];
3067
0
    const ggml_tensor * src1 = dst->src[1];
3068
0
    char * src0_d = (char *) src0->data;
3069
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3070
0
    const size_t src0_o = src0->nb[1];
3071
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3072
3073
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3074
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3075
3076
0
    if (src1) {
3077
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3078
0
        GGML_ASSERT(src0->type == src1->type);
3079
0
    }
3080
3081
0
    const int ith = params->ith;
3082
0
    const int nth = params->nth;
3083
3084
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3085
0
    const int nr = ggml_nrows(src0);
3086
3087
0
    GGML_ASSERT(dst->ne[0] == nc);
3088
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3089
3090
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3091
3092
    // rows per thread
3093
0
    const int dr = (nr + nth - 1)/nth;
3094
3095
    // row range for this thread
3096
0
    const int ir0 = dr*ith;
3097
0
    const int ir1 = MIN(ir0 + dr, nr);
3098
3099
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3100
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
3101
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
3102
3103
0
        if (!src1) {
3104
0
            src0_p += swapped ? nc : 0;
3105
0
            src1_p += swapped ? 0 : nc;
3106
0
        }
3107
3108
0
        ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3109
3110
#ifndef NDEBUG
3111
        for (int k = 0; k < nc; k++) {
3112
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3113
            GGML_UNUSED(x);
3114
            assert(!isnan(x));
3115
            assert(!isinf(x));
3116
        }
3117
#endif
3118
0
    }
3119
0
}
3120
3121
static void ggml_compute_forward_swiglu_f16(
3122
    const ggml_compute_params * params,
3123
0
    ggml_tensor * dst) {
3124
3125
0
    const ggml_tensor * src0 = dst->src[0];
3126
0
    const ggml_tensor * src1 = dst->src[1];
3127
0
    char * src0_d = (char *) src0->data;
3128
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3129
0
    const size_t src0_o = src0->nb[1];
3130
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3131
3132
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3133
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3134
3135
0
    if (src1) {
3136
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3137
0
        GGML_ASSERT(src0->type == src1->type);
3138
0
    }
3139
3140
0
    const int ith = params->ith;
3141
0
    const int nth = params->nth;
3142
3143
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3144
0
    const int nr = ggml_nrows(src0);
3145
3146
0
    GGML_ASSERT(dst->ne[0] == nc);
3147
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3148
3149
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3150
3151
    // rows per thread
3152
0
    const int dr = (nr + nth - 1)/nth;
3153
3154
    // row range for this thread
3155
0
    const int ir0 = dr*ith;
3156
0
    const int ir1 = MIN(ir0 + dr, nr);
3157
3158
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3159
0
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3160
0
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3161
3162
0
        if (!src1) {
3163
0
            src0_p += swapped ? nc : 0;
3164
0
            src1_p += swapped ? 0 : nc;
3165
0
        }
3166
3167
0
        ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3168
3169
#ifndef NDEBUG
3170
        for (int k = 0; k < nc; k++) {
3171
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3172
            const float v = GGML_FP16_TO_FP32(x);
3173
            GGML_UNUSED(v);
3174
            assert(!isnan(v));
3175
            assert(!isinf(v));
3176
        }
3177
#endif
3178
0
    }
3179
0
}
3180
3181
static void ggml_compute_forward_swiglu(
3182
        const ggml_compute_params * params,
3183
0
        ggml_tensor * dst) {
3184
3185
0
    const ggml_tensor * src0 = dst->src[0];
3186
3187
0
    switch (src0->type) {
3188
0
        case GGML_TYPE_F32:
3189
0
            {
3190
0
                ggml_compute_forward_swiglu_f32(params, dst);
3191
0
            } break;
3192
0
        case GGML_TYPE_F16:
3193
0
            {
3194
0
                ggml_compute_forward_swiglu_f16(params, dst);
3195
0
            } break;
3196
0
        default:
3197
0
            {
3198
0
                GGML_ABORT("fatal error");
3199
0
            }
3200
0
    }
3201
0
}
3202
3203
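For reference, the per-element operation the SwiGLU kernels above delegate to ggml_vec_swiglu_f32 / ggml_vec_swiglu_f16 is expected to be SiLU gating: dst[k] = silu(x[k]) * g[k], with silu(x) = x / (1 + exp(-x)). A scalar sketch of that step (illustrative only; the real kernel lives in vec.h and is vectorized):

    #include <cmath>
    #include <cstdio>

    static float silu(float x) { return x / (1.0f + std::exp(-x)); }

    int main() {
        const float x[] = {-1.0f, 0.0f, 2.0f}; // activation half
        const float g[] = { 0.5f, 1.0f, 3.0f}; // gate half
        for (int k = 0; k < 3; ++k) {
            printf("dst[%d] = %f\n", k, silu(x[k]) * g[k]);
        }
        return 0;
    }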
// ggml_compute_forward_swiglu_oai
3204
3205
static void ggml_compute_forward_swiglu_oai_f32(
3206
        const ggml_compute_params * params,
3207
0
        ggml_tensor * dst) {
3208
3209
0
    const ggml_tensor * src0 = dst->src[0];
3210
0
    const ggml_tensor * src1 = dst->src[1];
3211
0
    char * src0_d = (char *) src0->data;
3212
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3213
0
    const size_t src0_o = src0->nb[1];
3214
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3215
3216
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3217
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3218
3219
0
    if (src1) {
3220
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3221
0
        GGML_ASSERT(src0->type == src1->type);
3222
0
    }
3223
3224
0
    const int ith = params->ith;
3225
0
    const int nth = params->nth;
3226
3227
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3228
0
    const int nr = ggml_nrows(src0);
3229
3230
0
    GGML_ASSERT(dst->ne[0] == nc);
3231
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3232
3233
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3234
0
    const float alpha = ggml_get_op_params_f32(dst, 2);
3235
0
    const float limit = ggml_get_op_params_f32(dst, 3);
3236
3237
    // rows per thread
3238
0
    const int dr = (nr + nth - 1)/nth;
3239
3240
    // row range for this thread
3241
0
    const int ir0 = dr*ith;
3242
0
    const int ir1 = MIN(ir0 + dr, nr);
3243
3244
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3245
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
3246
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
3247
0
        float * dst_p  = (float *) ((char *) dst->data + i1*(dst->nb[1]));
3248
3249
0
        if (!src1) {
3250
0
            src0_p += swapped ? nc : 0;
3251
0
            src1_p += swapped ? 0 : nc;
3252
0
        }
3253
3254
0
        for (int k = 0; k < nc; k++) {
3255
0
            const float x = std::min(src0_p[k], limit);
3256
0
            const float y = std::clamp(src1_p[k], -limit, limit);
3257
0
            const float out_glu = x / (1.f + expf(alpha * (-x)));
3258
0
            dst_p[k] = out_glu * (y + 1.f);
3259
0
        }
3260
3261
#ifndef NDEBUG
3262
        for (int k = 0; k < nc; k++) {
3263
            const float x = dst_p[k];
3264
            GGML_UNUSED(x);
3265
            assert(!isnan(x));
3266
            assert(!isinf(x));
3267
        }
3268
#endif
3269
0
    }
3270
0
}
3271
3272
static void ggml_compute_forward_swiglu_oai(
3273
        const ggml_compute_params * params,
3274
0
        ggml_tensor * dst) {
3275
3276
0
    const ggml_tensor * src0 = dst->src[0];
3277
3278
0
    switch (src0->type) {
3279
0
        case GGML_TYPE_F32:
3280
0
            {
3281
0
                ggml_compute_forward_swiglu_oai_f32(params, dst);
3282
0
            } break;
3283
0
        default:
3284
0
            {
3285
0
                GGML_ABORT("fatal error");
3286
0
            }
3287
0
    }
3288
0
}
3289
3290
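Written out, the scalar inner loop of the swiglu_oai kernel above computes (a direct transcription of the code, with alpha and limit read from the op params):

    dst_k = x_k / (1 + e^{-\alpha x_k}) \cdot (y_k + 1),
    where x_k = \min(\mathrm{src0}_k, \mathrm{limit}),\;
          y_k = \mathrm{clamp}(\mathrm{src1}_k, -\mathrm{limit}, \mathrm{limit})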
// ggml_compute_forward_geglu_erf
3291
3292
static void ggml_compute_forward_geglu_erf_f32(
3293
        const ggml_compute_params * params,
3294
0
        ggml_tensor * dst) {
3295
3296
0
    const ggml_tensor * src0 = dst->src[0];
3297
0
    const ggml_tensor * src1 = dst->src[1];
3298
0
    char * src0_d = (char *) src0->data;
3299
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3300
0
    const size_t src0_o = src0->nb[1];
3301
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3302
3303
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3304
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3305
3306
0
    if (src1) {
3307
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3308
0
        GGML_ASSERT(src0->type == src1->type);
3309
0
    }
3310
3311
0
    const int ith = params->ith;
3312
0
    const int nth = params->nth;
3313
3314
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3315
0
    const int nr = ggml_nrows(src0);
3316
3317
0
    GGML_ASSERT(dst->ne[0] == nc);
3318
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3319
3320
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3321
3322
    // rows per thread
3323
0
    const int dr = (nr + nth - 1)/nth;
3324
3325
    // row range for this thread
3326
0
    const int ir0 = dr*ith;
3327
0
    const int ir1 = MIN(ir0 + dr, nr);
3328
3329
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3330
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
3331
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
3332
3333
0
        if (!src1) {
3334
0
            src0_p += swapped ? nc : 0;
3335
0
            src1_p += swapped ? 0 : nc;
3336
0
        }
3337
3338
0
        ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3339
3340
#ifndef NDEBUG
3341
        for (int k = 0; k < nc; k++) {
3342
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3343
            GGML_UNUSED(x);
3344
            assert(!isnan(x));
3345
            assert(!isinf(x));
3346
        }
3347
#endif
3348
0
    }
3349
0
}
3350
3351
static void ggml_compute_forward_geglu_erf_f16(
3352
    const ggml_compute_params * params,
3353
0
    ggml_tensor * dst) {
3354
3355
0
    const ggml_tensor * src0 = dst->src[0];
3356
0
    const ggml_tensor * src1 = dst->src[1];
3357
0
    char * src0_d = (char *) src0->data;
3358
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3359
0
    const size_t src0_o = src0->nb[1];
3360
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3361
3362
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3363
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3364
3365
0
    if (src1) {
3366
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3367
0
        GGML_ASSERT(src0->type == src1->type);
3368
0
    }
3369
3370
0
    const int ith = params->ith;
3371
0
    const int nth = params->nth;
3372
3373
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3374
0
    const int nr = ggml_nrows(src0);
3375
3376
0
    GGML_ASSERT(dst->ne[0] == nc);
3377
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3378
3379
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3380
3381
    // rows per thread
3382
0
    const int dr = (nr + nth - 1)/nth;
3383
3384
    // row range for this thread
3385
0
    const int ir0 = dr*ith;
3386
0
    const int ir1 = MIN(ir0 + dr, nr);
3387
3388
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3389
0
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3390
0
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3391
3392
0
        if (!src1) {
3393
0
            src0_p += swapped ? nc : 0;
3394
0
            src1_p += swapped ? 0 : nc;
3395
0
        }
3396
3397
0
        ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3398
3399
#ifndef NDEBUG
3400
        for (int k = 0; k < nc; k++) {
3401
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3402
            const float v = GGML_FP16_TO_FP32(x);
3403
            GGML_UNUSED(v);
3404
            assert(!isnan(v));
3405
            assert(!isinf(v));
3406
        }
3407
#endif
3408
0
    }
3409
0
}
3410
3411
static void ggml_compute_forward_geglu_erf(
3412
        const ggml_compute_params * params,
3413
0
        ggml_tensor * dst) {
3414
3415
0
    const ggml_tensor * src0 = dst->src[0];
3416
3417
0
    switch (src0->type) {
3418
0
        case GGML_TYPE_F32:
3419
0
            {
3420
0
                ggml_compute_forward_geglu_erf_f32(params, dst);
3421
0
            } break;
3422
0
        case GGML_TYPE_F16:
3423
0
            {
3424
0
                ggml_compute_forward_geglu_erf_f16(params, dst);
3425
0
            } break;
3426
0
        default:
3427
0
            {
3428
0
                GGML_ABORT("fatal error");
3429
0
            }
3430
0
    }
3431
0
}
3432
3433
// ggml_compute_forward_geglu_quick
3434
3435
static void ggml_compute_forward_geglu_quick_f32(
3436
        const ggml_compute_params * params,
3437
0
        ggml_tensor * dst) {
3438
3439
0
    const ggml_tensor * src0 = dst->src[0];
3440
0
    const ggml_tensor * src1 = dst->src[1];
3441
0
    char * src0_d = (char *) src0->data;
3442
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3443
0
    const size_t src0_o = src0->nb[1];
3444
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3445
3446
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3447
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3448
3449
0
    if (src1) {
3450
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3451
0
        GGML_ASSERT(src0->type == src1->type);
3452
0
    }
3453
3454
0
    const int ith = params->ith;
3455
0
    const int nth = params->nth;
3456
3457
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3458
0
    const int nr = ggml_nrows(src0);
3459
3460
0
    GGML_ASSERT(dst->ne[0] == nc);
3461
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3462
3463
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3464
3465
    // rows per thread
3466
0
    const int dr = (nr + nth - 1)/nth;
3467
3468
    // row range for this thread
3469
0
    const int ir0 = dr*ith;
3470
0
    const int ir1 = MIN(ir0 + dr, nr);
3471
3472
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3473
0
        float * src0_p = (float *) (src0_d + i1*src0_o);
3474
0
        float * src1_p = (float *) (src1_d + i1*src1_o);
3475
3476
0
        if (!src1) {
3477
0
            src0_p += swapped ? nc : 0;
3478
0
            src1_p += swapped ? 0 : nc;
3479
0
        }
3480
3481
0
        ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3482
3483
#ifndef NDEBUG
3484
        for (int k = 0; k < nc; k++) {
3485
            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3486
            GGML_UNUSED(x);
3487
            assert(!isnan(x));
3488
            assert(!isinf(x));
3489
        }
3490
#endif
3491
0
    }
3492
0
}
3493
3494
static void ggml_compute_forward_geglu_quick_f16(
3495
    const ggml_compute_params * params,
3496
0
    ggml_tensor * dst) {
3497
3498
0
    const ggml_tensor * src0 = dst->src[0];
3499
0
    const ggml_tensor * src1 = dst->src[1];
3500
0
    char * src0_d = (char *) src0->data;
3501
0
    char * src1_d = (char *) (src1 ? src1->data : src0->data);
3502
0
    const size_t src0_o = src0->nb[1];
3503
0
    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3504
3505
0
    GGML_ASSERT(ggml_is_contiguous_1(src0));
3506
0
    GGML_ASSERT(ggml_is_contiguous_1(dst));
3507
3508
0
    if (src1) {
3509
0
        GGML_ASSERT(ggml_is_contiguous_1(src1));
3510
0
        GGML_ASSERT(src0->type == src1->type);
3511
0
    }
3512
3513
0
    const int ith = params->ith;
3514
0
    const int nth = params->nth;
3515
3516
0
    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3517
0
    const int nr = ggml_nrows(src0);
3518
3519
0
    GGML_ASSERT(dst->ne[0] == nc);
3520
0
    GGML_ASSERT(ggml_nrows(dst) == nr);
3521
3522
0
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3523
3524
    // rows per thread
3525
0
    const int dr = (nr + nth - 1)/nth;
3526
3527
    // row range for this thread
3528
0
    const int ir0 = dr*ith;
3529
0
    const int ir1 = MIN(ir0 + dr, nr);
3530
3531
0
    for (int i1 = ir0; i1 < ir1; i1++) {
3532
0
        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3533
0
        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3534
3535
0
        if (!src1) {
3536
0
            src0_p += swapped ? nc : 0;
3537
0
            src1_p += swapped ? 0 : nc;
3538
0
        }
3539
3540
0
        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3541
3542
#ifndef NDEBUG
3543
        for (int k = 0; k < nc; k++) {
3544
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3545
            const float v = GGML_FP16_TO_FP32(x);
3546
            GGML_UNUSED(v);
3547
            assert(!isnan(v));
3548
            assert(!isinf(v));
3549
        }
3550
#endif
3551
0
    }
3552
0
}
3553
3554
static void ggml_compute_forward_geglu_quick(
3555
        const ggml_compute_params * params,
3556
0
        ggml_tensor * dst) {
3557
3558
0
    const ggml_tensor * src0 = dst->src[0];
3559
3560
0
    switch (src0->type) {
3561
0
        case GGML_TYPE_F32:
3562
0
            {
3563
0
                ggml_compute_forward_geglu_quick_f32(params, dst);
3564
0
            } break;
3565
0
        case GGML_TYPE_F16:
3566
0
            {
3567
0
                ggml_compute_forward_geglu_quick_f16(params, dst);
3568
0
            } break;
3569
0
        default:
3570
0
            {
3571
0
                GGML_ABORT("fatal error");
3572
0
            }
3573
0
    }
3574
0
}
3575
3576
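All of the GLU kernels above (and most kernels that follow) split work across threads with the same ceil-divide pattern: each thread gets a contiguous chunk of roughly nr/nth rows, and the last thread's range may be shorter or empty. A standalone illustration of that partitioning (plain C++, illustrative values):

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int nr  = 10;                  // total rows
        const int nth = 4;                   // number of threads
        const int dr  = (nr + nth - 1)/nth;  // rows per thread (ceil division)
        for (int ith = 0; ith < nth; ++ith) {
            const int ir0 = dr*ith;
            const int ir1 = std::min(ir0 + dr, nr);
            printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
        }
        return 0;
    }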
// ggml_compute_forward_norm
3577
3578
static void ggml_compute_forward_norm_f32(
3579
        const ggml_compute_params * params,
3580
0
        ggml_tensor * dst) {
3581
3582
0
    const ggml_tensor * src0 = dst->src[0];
3583
3584
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
3585
3586
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
3587
3588
0
    const int ith = params->ith;
3589
0
    const int nth = params->nth;
3590
3591
0
    GGML_TENSOR_UNARY_OP_LOCALS
3592
3593
0
    float eps;
3594
0
    memcpy(&eps, dst->op_params, sizeof(float));
3595
3596
0
    GGML_ASSERT(eps >= 0.0f);
3597
3598
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
3599
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
3600
0
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3601
0
                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3602
3603
0
                float sum = 0.0;
3604
0
                ggml_vec_sum_f32(ne00, &sum, x);
3605
0
                float mean = sum/ne00;
3606
3607
0
                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3608
0
                float variance = 0;
3609
3610
#ifdef GGML_USE_ACCELERATE
3611
                mean = -mean;
3612
                vDSP_vsadd(x, 1, &mean, y, 1, ne00);
3613
                vDSP_measqv(y, 1, &variance, ne00);
3614
#else
3615
0
                variance = ggml_vec_cvar_f32(ne00, y, x, mean);
3616
0
#endif //GGML_USE_ACCELERATE
3617
3618
0
                const float scale = 1.0f/sqrtf(variance + eps);
3619
0
                ggml_vec_scale_f32(ne00, y, scale);
3620
0
            }
3621
0
        }
3622
0
    }
3623
0
}
3624
3625
void ggml_compute_forward_norm(
3626
        const ggml_compute_params * params,
3627
0
        ggml_tensor * dst) {
3628
3629
0
    const ggml_tensor * src0 = dst->src[0];
3630
3631
0
    switch (src0->type) {
3632
0
        case GGML_TYPE_F32:
3633
0
            {
3634
0
                ggml_compute_forward_norm_f32(params, dst);
3635
0
            } break;
3636
0
        default:
3637
0
            {
3638
0
                GGML_ABORT("fatal error");
3639
0
            }
3640
0
    }
3641
0
}
3642
3643
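As a plain scalar reference for the per-row normalization done by ggml_compute_forward_norm_f32 above (mean and variance over ne00 elements, then scaling by 1/sqrt(variance + eps)); this is a sketch for illustration, not the Accelerate or vectorized path:

    #include <cmath>
    #include <cstdio>

    int main() {
        const int   n   = 4;
        const float eps = 1e-5f;
        const float x[] = {1.0f, 2.0f, 3.0f, 4.0f};
        float y[n];

        float mean = 0.0f;
        for (int i = 0; i < n; ++i) mean += x[i];
        mean /= n;

        float var = 0.0f;
        for (int i = 0; i < n; ++i) {
            y[i] = x[i] - mean;   // centered value, kept in the output row
            var += y[i]*y[i];
        }
        var /= n;

        const float scale = 1.0f/std::sqrt(var + eps);
        for (int i = 0; i < n; ++i) {
            printf("y[%d] = %f\n", i, y[i]*scale);
        }
        return 0;
    }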
// ggml_compute_forward_rms_norm
3644
3645
static void ggml_compute_forward_rms_norm_f32(
3646
        const ggml_compute_params * params,
3647
0
        ggml_tensor * dst) {
3648
3649
0
    const ggml_tensor * src0 = dst->src[0];
3650
3651
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
3652
3653
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
3654
3655
0
    const int ith = params->ith;
3656
0
    const int nth = params->nth;
3657
3658
0
    GGML_TENSOR_UNARY_OP_LOCALS
3659
3660
0
    float eps;
3661
0
    memcpy(&eps, dst->op_params, sizeof(float));
3662
3663
0
    GGML_ASSERT(eps >= 0.0f);
3664
3665
    // TODO: optimize
3666
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
3667
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
3668
0
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3669
0
                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3670
3671
0
                ggml_float sum = 0.0;
3672
0
                for (int64_t i00 = 0; i00 < ne00; i00++) {
3673
0
                    sum += (ggml_float)(x[i00] * x[i00]);
3674
0
                }
3675
3676
0
                const float mean = sum/ne00;
3677
3678
0
                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3679
3680
0
                memcpy(y, x, ne00 * sizeof(float));
3681
                // for (int i00 = 0; i00 < ne00; i00++) {
3682
                //     y[i00] = x[i00];
3683
                // }
3684
3685
0
                const float scale = 1.0f/sqrtf(mean + eps);
3686
3687
                // if you hit this, likely you got an inf somewhere earlier
3688
0
                assert(scale > 0.0f);
3689
3690
0
                ggml_vec_scale_f32(ne00, y, scale);
3691
0
            }
3692
0
        }
3693
0
    }
3694
0
}
3695
3696
void ggml_compute_forward_rms_norm(
3697
        const ggml_compute_params * params,
3698
0
        ggml_tensor * dst) {
3699
3700
0
    const ggml_tensor * src0 = dst->src[0];
3701
3702
0
    switch (src0->type) {
3703
0
        case GGML_TYPE_F32:
3704
0
            {
3705
0
                ggml_compute_forward_rms_norm_f32(params, dst);
3706
0
            } break;
3707
0
        default:
3708
0
            {
3709
0
                GGML_ABORT("fatal error");
3710
0
            }
3711
0
    }
3712
0
}
3713
3714
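The RMS-norm forward above reduces to y = x / sqrt(mean(x^2) + eps) per row. A scalar sketch of the same computation for a single row (illustrative only, not the optimized kernel):

    #include <cmath>
    #include <cstdio>

    int main() {
        const int   n   = 4;
        const float eps = 1e-6f;
        const float x[] = {1.0f, -2.0f, 3.0f, -4.0f};

        double sum = 0.0;                            // sum of squares in higher precision
        for (int i = 0; i < n; ++i) sum += (double)x[i]*x[i];

        const float scale = 1.0f/std::sqrt((float)(sum/n) + eps);
        for (int i = 0; i < n; ++i) {
            printf("y[%d] = %f\n", i, x[i]*scale);
        }
        return 0;
    }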
static void ggml_compute_forward_rms_norm_back_f32(
3715
        const ggml_compute_params * params,
3716
0
        ggml_tensor * dst) {
3717
3718
0
    const ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
3719
0
    const ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
3720
3721
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
3722
3723
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
3724
0
    GGML_ASSERT(src1->nb[0] == sizeof(float));
3725
3726
0
    const int ith = params->ith;
3727
0
    const int nth = params->nth;
3728
3729
0
    GGML_TENSOR_BINARY_OP_LOCALS
3730
3731
0
    float eps;
3732
0
    memcpy(&eps, dst->op_params, sizeof(float));
3733
3734
    // TODO: optimize
3735
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
3736
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
3737
0
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3738
                // src1 is same shape as src0 => same indices
3739
0
                const int64_t i11 = i01;
3740
0
                const int64_t i12 = i02;
3741
0
                const int64_t i13 = i03;
3742
3743
0
                const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3744
0
                const float * x  = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
3745
3746
0
                ggml_float sum_xx  = 0.0;
3747
0
                ggml_float sum_xdz = 0.0;
3748
3749
0
                for (int64_t i00 = 0; i00 < ne00; i00++) {
3750
0
                    sum_xx  += (ggml_float)(x[i00] * x[i00]);
3751
0
                    sum_xdz += (ggml_float)(x[i00] * dz[i00]);
3752
0
                }
3753
3754
                //const float mean     = (float)(sum_xx)/ne00;
3755
0
                const float mean_eps = (float)(sum_xx)/ne00 + eps;
3756
0
                const float sum_eps  = (float)(sum_xx) + eps*ne00;
3757
                //const float mean_xdz = (float)(sum_xdz)/ne00;
3758
                // we could cache rms from forward pass to improve performance.
3759
                // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
3760
                //const float rms      = sqrtf(mean_eps);
3761
0
                const float rrms     = 1.0f / sqrtf(mean_eps);
3762
                //const float scale    = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
3763
3764
0
                {
3765
                    // z = rms_norm(x)
3766
                    //
3767
                    // rms_norm(src1) =
3768
                    //     scale(
3769
                    //         src1,
3770
                    //         div(
3771
                    //             1,
3772
                    //             sqrt(
3773
                    //                 add(
3774
                    //                     scale(
3775
                    //                         sum(
3776
                    //                             sqr(
3777
                    //                                 src1)),
3778
                    //                         (1.0/N)),
3779
                    //                     eps))));
3780
3781
                    // postorder:
3782
                    // ## op    args         grad
3783
                    // 00 param src1         grad[#00]
3784
                    // 01 const 1
3785
                    // 02 sqr   (#00)        grad[#02]
3786
                    // 03 sum   (#02)        grad[#03]
3787
                    // 04 const 1/N
3788
                    // 05 scale (#03, #04)   grad[#05]
3789
                    // 06 const eps
3790
                    // 07 add   (#05, #06)   grad[#07]
3791
                    // 08 sqrt  (#07)        grad[#08]
3792
                    // 09 div   (#01,#08)    grad[#09]
3793
                    // 10 scale (#00,#09)    grad[#10]
3794
                    //
3795
                    // backward pass, given grad[#10]
3796
                    // #10: scale
3797
                    // grad[#00] += scale(grad[#10],#09)
3798
                    // grad[#09] += sum(mul(grad[#10],#00))
3799
                    // #09: div
3800
                    // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
3801
                    // #08: sqrt
3802
                    // grad[#07] += mul(grad[#08], div(0.5, #08))
3803
                    // #07: add
3804
                    // grad[#05] += grad[#07]
3805
                    // #05: scale
3806
                    // grad[#03] += scale(grad[#05],#04)
3807
                    // #03: sum
3808
                    // grad[#02] += repeat(grad[#03], #02)
3809
                    // #02:
3810
                    // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
3811
                    //
3812
                    // substitute and simplify:
3813
                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
3814
                    // grad[#02] = repeat(grad[#03], #02)
3815
                    // grad[#02] = repeat(scale(grad[#05],#04), #02)
3816
                    // grad[#02] = repeat(scale(grad[#07],#04), #02)
3817
                    // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
3818
                    // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
3819
                    // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
3820
                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
3821
                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
3822
                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
3823
                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
3824
                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
3825
                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
3826
                    // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
3827
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
3828
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
3829
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
3830
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
3831
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
3832
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
3833
                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
3834
                    // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
3835
                    // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
3836
                    // a = b*c + d*e
3837
                    // a = b*c*f/f + d*e*f/f
3838
                    // a = (b*c*f + d*e*f)*(1/f)
3839
                    // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
3840
                    // a = (b + d*e/c)*c
3841
                    // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
3842
                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
3843
                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
3844
                    // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
3845
                    // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
3846
                    // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
3847
                    // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
3848
                    // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
3849
                    // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
3850
                    // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
3851
0
                }
3852
                // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
3853
                // post-order:
3854
                // dx := x
3855
                // dx := scale(dx,-mean_xdz/mean_eps)
3856
                // dx := add(dx, dz)
3857
                // dx := scale(dx, rrms)
3858
0
                float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3859
3860
                // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
3861
0
                ggml_vec_cpy_f32  (ne00, dx, x);
3862
                // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
3863
0
                ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
3864
0
                ggml_vec_acc_f32  (ne00, dx, dz);
3865
0
                ggml_vec_scale_f32(ne00, dx, rrms);
3866
0
            }
3867
0
        }
3868
0
    }
3869
0
}
3870
3871
void ggml_compute_forward_rms_norm_back(
3872
        const ggml_compute_params * params,
3873
0
        ggml_tensor * dst) {
3874
3875
0
    const ggml_tensor * src0 = dst->src[0];
3876
3877
0
    switch (src0->type) {
3878
0
        case GGML_TYPE_F32:
3879
0
            {
3880
0
                ggml_compute_forward_rms_norm_back_f32(params, dst);
3881
0
            } break;
3882
0
        default:
3883
0
            {
3884
0
                GGML_ABORT("fatal error");
3885
0
            }
3886
0
    }
3887
0
}
3888
3889
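The long derivation in the comments above collapses to the following closed form, which is what the final four vector calls implement (copy x into dx, scale by -sum_xdz/sum_eps, add dz, scale by rrms):

    dx = \left( dz - x \cdot \frac{\overline{x\,dz}}{\overline{x^2} + \epsilon} \right)
         \cdot \frac{1}{\sqrt{\overline{x^2} + \epsilon}},
    \quad \overline{x\,dz} = \frac{1}{N}\sum_i x_i\,dz_i,\;
          \overline{x^2}   = \frac{1}{N}\sum_i x_i^2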
// ggml_compute_forward_group_norm
3890
3891
static void ggml_compute_forward_group_norm_f32(
3892
    const ggml_compute_params * params,
3893
0
    ggml_tensor * dst) {
3894
3895
0
    const ggml_tensor * src0 = dst->src[0];
3896
3897
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
3898
3899
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
3900
3901
0
    const int ith = params->ith;
3902
0
    const int nth = params->nth;
3903
3904
0
    GGML_TENSOR_UNARY_OP_LOCALS
3905
3906
    // TODO: optimize
3907
3908
0
    float eps;
3909
0
    memcpy(&eps, dst->op_params + 1, sizeof(float));
3910
3911
0
    int n_channels = src0->ne[2];
3912
0
    int n_groups = dst->op_params[0];
3913
0
    int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
3914
0
    for (int i = ith; i < n_groups; i += nth) {
3915
0
        int start = i * n_channels_per_group;
3916
0
        int end = start + n_channels_per_group;
3917
0
        if (end > n_channels) {
3918
0
            end = n_channels;
3919
0
        }
3920
0
        int step = end - start;
3921
3922
0
        for (int64_t i03 = 0; i03 < ne03; i03++) {
3923
0
            ggml_float sum = 0.0;
3924
0
            for (int64_t i02 = start; i02 < end; i02++) {
3925
0
                for (int64_t i01 = 0; i01 < ne01; i01++) {
3926
0
                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3927
3928
0
                    ggml_float sumr = 0.0;
3929
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
3930
0
                        sumr += (ggml_float)x[i00];
3931
0
                    }
3932
0
                    sum += sumr;
3933
0
                }
3934
0
            }
3935
0
            const float mean = sum / (ne00 * ne01 * step);
3936
3937
0
            ggml_float sum2 = 0.0;
3938
0
            for (int64_t i02 = start; i02 < end; i02++) {
3939
0
                for (int64_t i01 = 0; i01 < ne01; i01++) {
3940
0
                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3941
3942
0
                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
3943
3944
0
                    ggml_float sumr = 0.0;
3945
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
3946
0
                        float v = x[i00] - mean;
3947
0
                        y[i00] = v;
3948
0
                        sumr += (ggml_float)(v * v);
3949
0
                    }
3950
0
                    sum2 += sumr;
3951
0
                }
3952
0
            }
3953
0
            const float variance = sum2 / (ne00 * ne01 * step);
3954
0
            const float scale = 1.0f / sqrtf(variance + eps);
3955
3956
0
            for (int64_t i02 = start; i02 < end; i02++) {
3957
0
                for (int64_t i01 = 0; i01 < ne01; i01++) {
3958
0
                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
3959
0
                    ggml_vec_scale_f32(ne00, y, scale);
3960
0
                }
3961
0
            }
3962
0
        }
3963
0
    }
3964
0
}
3965
3966
void ggml_compute_forward_group_norm(
3967
    const ggml_compute_params * params,
3968
0
    ggml_tensor * dst) {
3969
3970
0
    const ggml_tensor * src0 = dst->src[0];
3971
3972
0
    switch (src0->type) {
3973
0
        case GGML_TYPE_F32:
3974
0
            {
3975
0
                ggml_compute_forward_group_norm_f32(params, dst);
3976
0
            } break;
3977
0
        default:
3978
0
            {
3979
0
                GGML_ABORT("fatal error");
3980
0
            }
3981
0
    }
3982
0
}
3983
3984
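Group norm above first maps channels to groups with a ceiling division and clamps the last group's end to the channel count; each group is then normalized over ne00*ne01*step elements. A quick standalone illustration of the [start, end) channel ranges this produces (illustrative values):

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int n_channels = 10;
        const int n_groups   = 4;
        const int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;

        for (int i = 0; i < n_groups; ++i) {
            const int start = i * n_channels_per_group;
            const int end   = std::min(start + n_channels_per_group, n_channels);
            printf("group %d: channels [%d, %d)\n", i, start, end); // last group is shorter
        }
        return 0;
    }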
// ggml_compute_forward_l2_norm
3985
3986
static void ggml_compute_forward_l2_norm_f32(
3987
    const ggml_compute_params * params,
3988
0
    ggml_tensor * dst) {
3989
3990
0
    const ggml_tensor * src0 = dst->src[0];
3991
3992
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
3993
3994
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
3995
3996
0
    const int ith = params->ith;
3997
0
    const int nth = params->nth;
3998
3999
0
    GGML_TENSOR_UNARY_OP_LOCALS
4000
4001
0
    float eps;
4002
0
    memcpy(&eps, dst->op_params, sizeof(float));
4003
4004
0
    GGML_ASSERT(eps >= 0.0f);
4005
4006
    // TODO: optimize
4007
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
4008
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
4009
0
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
4010
0
                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
4011
4012
0
                ggml_float sum = 0.0;
4013
0
                for (int64_t i00 = 0; i00 < ne00; i00++) {
4014
0
                    sum += (ggml_float)(x[i00] * x[i00]);
4015
0
                }
4016
4017
0
                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
4018
4019
0
                memcpy(y, x, ne00 * sizeof(float));
4020
4021
0
                const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
4022
4023
0
                ggml_vec_scale_f32(ne00, y, scale);
4024
0
            }
4025
0
        }
4026
0
    }
4027
0
}
4028
4029
void ggml_compute_forward_l2_norm(
4030
    const ggml_compute_params * params,
4031
0
    ggml_tensor * dst) {
4032
4033
0
    const ggml_tensor * src0 = dst->src[0];
4034
4035
0
    switch (src0->type) {
4036
0
        case GGML_TYPE_F32:
4037
0
            {
4038
0
                ggml_compute_forward_l2_norm_f32(params, dst);
4039
0
            } break;
4040
0
        default:
4041
0
            {
4042
0
                GGML_ABORT("fatal error");
4043
0
            }
4044
0
    }
4045
0
}
4046
4047
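The l2_norm kernel above normalizes each row by its Euclidean norm; note that eps acts as a lower bound on the denominator rather than being added under the square root:

    y_i = \frac{x_i}{\max\!\left(\sqrt{\sum_j x_j^2},\; \epsilon\right)}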
// ggml_compute_forward_out_prod
4048
4049
static void ggml_compute_forward_out_prod_f32(
4050
        const ggml_compute_params * params,
4051
0
              ggml_tensor * dst) {
4052
4053
0
    const ggml_tensor * src0 = dst->src[0];
4054
0
    const ggml_tensor * src1 = dst->src[1];
4055
4056
0
    GGML_TENSOR_BINARY_OP_LOCALS
4057
4058
0
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
4059
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
4060
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
4061
4062
0
    const int ith = params->ith;
4063
0
    const int nth = params->nth;
4064
4065
0
    GGML_ASSERT(ne0 == ne00);
4066
0
    GGML_ASSERT(ne1 == ne10);
4067
0
    GGML_ASSERT(ne2 == ne12);
4068
0
    GGML_ASSERT(ne3 == ne13);
4069
4070
0
    GGML_ASSERT(ne2 % ne02 == 0);
4071
0
    GGML_ASSERT(ne3 % ne03 == 0);
4072
4073
    // we don't support permuted src0 or src1
4074
0
    GGML_ASSERT(nb00 == sizeof(float));
4075
4076
    // dst cannot be transposed or permuted
4077
0
    GGML_ASSERT(nb0 == sizeof(float));
4078
    // GGML_ASSERT(nb0 <= nb1);
4079
    // GGML_ASSERT(nb1 <= nb2);
4080
    // GGML_ASSERT(nb2 <= nb3);
4081
4082
    // nb01 >= nb00 - src0 is not transposed
4083
    //   compute by src0 rows
4084
4085
0
    if (ith == 0) {
4086
0
        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
4087
0
    }
4088
0
    ggml_barrier(params->threadpool);
4089
4090
    // dst[:,:,:,:] = 0
4091
    // for i2,i3:
4092
    //   for i1:
4093
    //     for i01:
4094
    //       for i0:
4095
    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
4096
4097
    // parallelize by last three dimensions
4098
4099
    // total rows in dst
4100
0
    const int64_t nr = ne1*ne2*ne3;
4101
4102
    // rows per thread
4103
0
    const int64_t dr = (nr + nth - 1)/nth;
4104
4105
    // row range for this thread
4106
0
    const int64_t ir0 = dr*ith;
4107
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
4108
4109
    // block-tiling attempt
4110
0
    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
4111
0
    const int64_t blck_1 = 16;
4112
4113
    // dps == dst per src0, used for group query attention
4114
0
    const int64_t dps2 = ne2 / ne02;
4115
0
    const int64_t dps3 = ne3 / ne03;
4116
4117
0
    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
4118
0
        const int64_t bir1 = MIN(bir + blck_1, ir1);
4119
0
        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
4120
0
            const int64_t bne01 = MIN(bi01 + blck_0, ne01);
4121
0
            for (int64_t ir = bir; ir < bir1; ++ir) {
4122
                // dst indices
4123
0
                const int64_t i3 = ir/(ne2*ne1);
4124
0
                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
4125
0
                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
4126
4127
0
                const int64_t i02 = i2 / dps2;
4128
0
                const int64_t i03 = i3 / dps3;
4129
4130
                //const int64_t i10 = i1;
4131
0
                const int64_t i12 = i2;
4132
0
                const int64_t i13 = i3;
4133
4134
0
#if GGML_VEC_MAD_UNROLL > 2
4135
0
                const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
4136
0
                for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
4137
0
                    const int64_t i11 = i01;
4138
4139
0
                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
4140
0
                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4141
0
                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
4142
4143
0
                    ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
4144
0
                }
4145
0
                for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
4146
0
                    const int64_t i11 = i01;
4147
4148
0
                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
4149
0
                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4150
0
                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
4151
4152
0
                    ggml_vec_mad_f32(ne0, d, s0, *s1);
4153
0
                }
4154
#else
4155
                for (int64_t i01 = bi01; i01 < bne01; ++i01) {
4156
                    const int64_t i11 = i01;
4157
4158
                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
4159
                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4160
                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
4161
4162
                    ggml_vec_mad_f32(ne0, d, s0, *s1);
4163
                }
4164
#endif
4165
0
            }
4166
0
        }
4167
0
    }
4168
0
}
4169
4170
static void ggml_compute_forward_out_prod_q_f32(
4171
        const ggml_compute_params * params,
4172
0
              ggml_tensor * dst) {
4173
4174
0
    const ggml_tensor * src0 = dst->src[0];
4175
0
    const ggml_tensor * src1 = dst->src[1];
4176
4177
0
    GGML_TENSOR_BINARY_OP_LOCALS;
4178
4179
0
    const int ith = params->ith;
4180
0
    const int nth = params->nth;
4181
4182
0
    const ggml_type type = src0->type;
4183
0
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
4184
4185
0
    GGML_ASSERT(ne02 == ne12);
4186
0
    GGML_ASSERT(ne03 == ne13);
4187
0
    GGML_ASSERT(ne2  == ne12);
4188
0
    GGML_ASSERT(ne3  == ne13);
4189
4190
    // we don't support permuted src0 dim0
4191
0
    GGML_ASSERT(nb00 == ggml_type_size(type));
4192
4193
    // dst dim0 cannot be transposed or permuted
4194
0
    GGML_ASSERT(nb0 == sizeof(float));
4195
    // GGML_ASSERT(nb0 <= nb1);
4196
    // GGML_ASSERT(nb1 <= nb2);
4197
    // GGML_ASSERT(nb2 <= nb3);
4198
4199
0
    GGML_ASSERT(ne0 == ne00);
4200
0
    GGML_ASSERT(ne1 == ne10);
4201
0
    GGML_ASSERT(ne2 == ne02);
4202
0
    GGML_ASSERT(ne3 == ne03);
4203
4204
    // nb01 >= nb00 - src0 is not transposed
4205
    //   compute by src0 rows
4206
4207
0
    if (ith == 0) {
4208
0
        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
4209
0
    }
4210
0
    ggml_barrier(params->threadpool);
4211
4212
    // parallelize by last three dimensions
4213
4214
    // total rows in dst
4215
0
    const int64_t nr = ne1*ne2*ne3;
4216
4217
    // rows per thread
4218
0
    const int64_t dr = (nr + nth - 1)/nth;
4219
4220
    // row range for this thread
4221
0
    const int64_t ir0 = dr*ith;
4222
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
4223
4224
    // dst[:,:,:,:] = 0
4225
    // for i2,i3:
4226
    //   for i1:
4227
    //     for i01:
4228
    //       for i0:
4229
    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
4230
4231
0
    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
4232
4233
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
4234
        // dst indices
4235
0
        const int64_t i3 = ir/(ne2*ne1);
4236
0
        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
4237
0
        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
4238
4239
0
        const int64_t i02 = i2;
4240
0
        const int64_t i03 = i3;
4241
4242
        //const int64_t i10 = i1;
4243
0
        const int64_t i12 = i2;
4244
0
        const int64_t i13 = i3;
4245
4246
0
        for (int64_t i01 = 0; i01 < ne01; ++i01) {
4247
0
            const int64_t i11 = i01;
4248
4249
0
            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
4250
0
            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4251
0
            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
4252
4253
0
            dequantize_row_q(s0, wdata, ne0);
4254
0
            ggml_vec_mad_f32(ne0, d, wdata, *s1);
4255
0
        }
4256
0
    }
4257
0
}
4258
4259
void ggml_compute_forward_out_prod(
4260
        const ggml_compute_params * params,
4261
0
        ggml_tensor * dst) {
4262
4263
0
    const ggml_tensor * src0 = dst->src[0];
4264
4265
0
    switch (src0->type) {
4266
0
        case GGML_TYPE_Q4_0:
4267
0
        case GGML_TYPE_Q4_1:
4268
0
        case GGML_TYPE_Q5_0:
4269
0
        case GGML_TYPE_Q5_1:
4270
0
        case GGML_TYPE_Q8_0:
4271
0
        case GGML_TYPE_MXFP4:
4272
0
        case GGML_TYPE_Q2_K:
4273
0
        case GGML_TYPE_Q3_K:
4274
0
        case GGML_TYPE_Q4_K:
4275
0
        case GGML_TYPE_Q5_K:
4276
0
        case GGML_TYPE_Q6_K:
4277
0
        case GGML_TYPE_TQ1_0:
4278
0
        case GGML_TYPE_TQ2_0:
4279
0
        case GGML_TYPE_IQ2_XXS:
4280
0
        case GGML_TYPE_IQ2_XS:
4281
0
        case GGML_TYPE_IQ3_XXS:
4282
0
        case GGML_TYPE_IQ1_S:
4283
0
        case GGML_TYPE_IQ1_M:
4284
0
        case GGML_TYPE_IQ4_NL:
4285
0
        case GGML_TYPE_IQ4_XS:
4286
0
        case GGML_TYPE_IQ3_S:
4287
0
        case GGML_TYPE_IQ2_S:
4288
0
            {
4289
0
                ggml_compute_forward_out_prod_q_f32(params, dst);
4290
0
            } break;
4291
0
        case GGML_TYPE_F16:
4292
0
            {
4293
0
                GGML_ABORT("fatal error"); // todo
4294
                // ggml_compute_forward_out_prod_f16_f32(params, dst);
4295
0
            }
4296
0
        case GGML_TYPE_F32:
4297
0
            {
4298
0
                ggml_compute_forward_out_prod_f32(params, dst);
4299
0
            } break;
4300
0
        default:
4301
0
            {
4302
0
                GGML_ABORT("fatal error");
4303
0
            }
4304
0
    }
4305
0
}
4306
4307
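Ignoring the blocking, the unrolled MAD path, and the GQA broadcast, the accumulation above is the outer-product update already stated in the comment: dst[i0,i1] += src0[i0,i01] * src1[i1,i01] for every i01. A minimal 2-D sketch with made-up sizes (plain C++, illustrative only):

    #include <cstdio>

    int main() {
        const int ne0  = 2; // length of a src0 row (== length of a dst row)
        const int ne01 = 3; // contracted dimension (rows of src0 / src1)
        const int ne1  = 2; // length of a src1 row (== number of dst rows)

        // indexed as [i01][i0] and [i01][i1] to match the comment's dst[i0,i1] += src0[i0,i01]*src1[i1,i01]
        const float src0[ne01][ne0] = {{1, 2}, {3, 4}, {5, 6}};
        const float src1[ne01][ne1] = {{1, 0}, {0, 1}, {1, 1}};
        float dst[ne1][ne0] = {};

        for (int i1 = 0; i1 < ne1; ++i1) {
            for (int i01 = 0; i01 < ne01; ++i01) {
                for (int i0 = 0; i0 < ne0; ++i0) {
                    dst[i1][i0] += src0[i01][i0] * src1[i01][i1];
                }
            }
        }
        for (int i1 = 0; i1 < ne1; ++i1) {
            for (int i0 = 0; i0 < ne0; ++i0) {
                printf("dst[%d][%d] = %g\n", i1, i0, dst[i1][i0]);
            }
        }
        return 0;
    }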
// ggml_compute_forward_scale
4308
4309
static void ggml_compute_forward_scale_f32(
4310
        const ggml_compute_params * params,
4311
0
        ggml_tensor * dst) {
4312
4313
0
    const ggml_tensor * src0 = dst->src[0];
4314
4315
0
    GGML_ASSERT(ggml_is_contiguous(src0));
4316
0
    GGML_ASSERT(ggml_is_contiguous(dst));
4317
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
4318
4319
0
    float s; // scale factor
4320
0
    float b; // bias
4321
4322
0
    memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
4323
0
    memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
4324
4325
0
    const int ith = params->ith;
4326
0
    const int nth = params->nth;
4327
4328
0
    const int nc = src0->ne[0];
4329
0
    const int nr = ggml_nrows(src0);
4330
4331
    // rows per thread
4332
0
    const int dr = (nr + nth - 1)/nth;
4333
4334
    // row range for this thread
4335
0
    const int ir0 = dr*ith;
4336
0
    const int ir1 = MIN(ir0 + dr, nr);
4337
4338
0
    const size_t nb01 = src0->nb[1];
4339
4340
0
    const size_t nb1 = dst->nb[1];
4341
4342
0
    if (b == 0.0f) {
4343
0
        for (int i1 = ir0; i1 < ir1; i1++) {
4344
0
            if (dst->data != src0->data) {
4345
                // src0 is same shape as dst => same indices
4346
                // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
4347
0
                memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
4348
0
            }
4349
0
            ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
4350
0
        }
4351
0
    } else {
4352
0
        for (int i1 = ir0; i1 < ir1; i1++) {
4353
0
            ggml_vec_mad1_f32(nc,
4354
0
                (float *) ((char *) dst->data  + i1*nb1),
4355
0
                (float *) ((char *) src0->data + i1*nb1),
4356
0
                s, b);
4357
0
        }
4358
0
    }
4359
0
}
4360
4361
void ggml_compute_forward_scale(
4362
        const ggml_compute_params * params,
4363
0
        ggml_tensor * dst) {
4364
4365
0
    const ggml_tensor * src0 = dst->src[0];
4366
4367
0
    switch (src0->type) {
4368
0
        case GGML_TYPE_F32:
4369
0
            {
4370
0
                ggml_compute_forward_scale_f32(params, dst);
4371
0
            } break;
4372
0
        default:
4373
0
            {
4374
0
                GGML_ABORT("fatal error");
4375
0
            }
4376
0
    }
4377
0
}
4378
4379
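The scale kernel above effectively computes y = s*x + b per element, taking the in-place ggml_vec_scale_f32 fast path when b == 0 and ggml_vec_mad1_f32 otherwise. A scalar sketch of the general case (illustrative only):

    #include <cstdio>

    int main() {
        const float s = 2.0f, b = 0.5f;   // scale and bias; read from op_params in the real kernel
        const float x[] = {1.0f, -1.0f, 3.0f};
        for (int i = 0; i < 3; ++i) {
            printf("y[%d] = %f\n", i, s*x[i] + b);
        }
        return 0;
    }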
// ggml_compute_forward_set
4380
4381
static void ggml_compute_forward_set_f32(
4382
        const ggml_compute_params * params,
4383
0
        ggml_tensor * dst) {
4384
4385
0
    const ggml_tensor * src0 = dst->src[0];
4386
0
    const ggml_tensor * src1 = dst->src[1];
4387
4388
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
4389
0
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
4390
4391
    // view src0 and dst with these strides and data offset in bytes during set
4392
    // nb0 is implicitly element_size because src0 and dst are contiguous
4393
0
    size_t nb1     = ((int32_t *) dst->op_params)[0];
4394
0
    size_t nb2     = ((int32_t *) dst->op_params)[1];
4395
0
    size_t nb3     = ((int32_t *) dst->op_params)[2];
4396
0
    size_t offset  = ((int32_t *) dst->op_params)[3];
4397
0
    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
4398
4399
0
    if (!inplace) {
4400
0
        if (params->ith == 0) {
4401
            // memcpy needs to be synchronized across threads to avoid race conditions.
4402
            // => do it in INIT phase
4403
0
            memcpy(
4404
0
                ((char *)  dst->data),
4405
0
                ((char *) src0->data),
4406
0
                ggml_nbytes(dst));
4407
0
        }
4408
0
        ggml_barrier(params->threadpool);
4409
0
    }
4410
4411
0
    const int ith = params->ith;
4412
0
    const int nth = params->nth;
4413
4414
0
    const int nr = ggml_nrows(src1);
4415
0
    const int nc = src1->ne[0];
4416
4417
0
    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
4418
0
    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
4419
4420
    // src0 and dst as viewed during set
4421
0
    const size_t nb0 = ggml_element_size(src0);
4422
4423
0
    const int im0 = (ne10 == 0 ? 0 : ne10-1);
4424
0
    const int im1 = (ne11 == 0 ? 0 : ne11-1);
4425
0
    const int im2 = (ne12 == 0 ? 0 : ne12-1);
4426
0
    const int im3 = (ne13 == 0 ? 0 : ne13-1);
4427
4428
0
    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
4429
4430
0
    GGML_ASSERT(nb10 == sizeof(float));
4431
4432
    // rows per thread
4433
0
    const int dr = (nr + nth - 1)/nth;
4434
4435
    // row range for this thread
4436
0
    const int ir0 = dr*ith;
4437
0
    const int ir1 = MIN(ir0 + dr, nr);
4438
4439
0
    for (int ir = ir0; ir < ir1; ++ir) {
4440
        // src0 and dst are viewed with shape of src1 and offset
4441
        // => same indices
4442
0
        const int i3 = ir/(ne12*ne11);
4443
0
        const int i2 = (ir - i3*ne12*ne11)/ne11;
4444
0
        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
4445
4446
0
        ggml_vec_cpy_f32(nc,
4447
0
                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
4448
0
                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
4449
0
    }
4450
0
}
4451
4452
static void ggml_compute_forward_set_i32(
4453
        const ggml_compute_params * params,
4454
0
        ggml_tensor * dst) {
4455
4456
0
    const ggml_tensor * src0 = dst->src[0];
4457
0
    const ggml_tensor * src1 = dst->src[1];
4458
4459
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
4460
0
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
4461
4462
    // view src0 and dst with these strides and data offset in bytes during set
4463
    // nb0 is implicitly element_size because src0 and dst are contiguous
4464
0
    size_t nb1     = ((int32_t *) dst->op_params)[0];
4465
0
    size_t nb2     = ((int32_t *) dst->op_params)[1];
4466
0
    size_t nb3     = ((int32_t *) dst->op_params)[2];
4467
0
    size_t offset  = ((int32_t *) dst->op_params)[3];
4468
0
    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
4469
4470
0
    if (!inplace) {
4471
0
        if (params->ith == 0) {
4472
            // memcpy needs to be synchronized across threads to avoid race conditions.
4473
            // => do it in INIT phase
4474
0
            memcpy(
4475
0
                ((char *)  dst->data),
4476
0
                ((char *) src0->data),
4477
0
                ggml_nbytes(dst));
4478
0
        }
4479
0
        ggml_barrier(params->threadpool);
4480
0
    }
4481
4482
0
    const int ith = params->ith;
4483
0
    const int nth = params->nth;
4484
4485
0
    const int nr = ggml_nrows(src1);
4486
0
    const int nc = src1->ne[0];
4487
4488
0
    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
4489
0
    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
4490
4491
    // src0 and dst as viewed during set
4492
0
    const size_t nb0 = ggml_element_size(src0);
4493
4494
0
    const int im0 = (ne10 == 0 ? 0 : ne10-1);
4495
0
    const int im1 = (ne11 == 0 ? 0 : ne11-1);
4496
0
    const int im2 = (ne12 == 0 ? 0 : ne12-1);
4497
0
    const int im3 = (ne13 == 0 ? 0 : ne13-1);
4498
4499
0
    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
4500
4501
0
    GGML_ASSERT(nb10 == sizeof(int32_t));
4502
4503
    // rows per thread
4504
0
    const int dr = (nr + nth - 1)/nth;
4505
4506
    // row range for this thread
4507
0
    const int ir0 = dr*ith;
4508
0
    const int ir1 = MIN(ir0 + dr, nr);
4509
4510
0
    for (int ir = ir0; ir < ir1; ++ir) {
4511
        // src0 and dst are viewed with shape of src1 and offset
4512
        // => same indices
4513
0
        const int i3 = ir/(ne12*ne11);
4514
0
        const int i2 = (ir - i3*ne12*ne11)/ne11;
4515
0
        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
4516
4517
0
        ggml_vec_cpy_i32(nc,
4518
0
                (int32_t *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
4519
0
                (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
4520
0
    }
4521
0
}
4522
4523
void ggml_compute_forward_set(
4524
        const ggml_compute_params * params,
4525
0
        ggml_tensor * dst) {
4526
4527
0
    const ggml_tensor * src0 = dst->src[0];
4528
4529
0
    switch (src0->type) {
4530
0
        case GGML_TYPE_F32:
4531
0
            {
4532
0
                ggml_compute_forward_set_f32(params, dst);
4533
0
            } break;
4534
0
        case GGML_TYPE_I32:
4535
0
            {
4536
0
                ggml_compute_forward_set_i32(params, dst);
4537
0
            } break;
4538
0
        case GGML_TYPE_F16:
4539
0
        case GGML_TYPE_BF16:
4540
0
        case GGML_TYPE_Q4_0:
4541
0
        case GGML_TYPE_Q4_1:
4542
0
        case GGML_TYPE_Q5_0:
4543
0
        case GGML_TYPE_Q5_1:
4544
0
        case GGML_TYPE_Q8_0:
4545
0
        case GGML_TYPE_Q8_1:
4546
0
        case GGML_TYPE_MXFP4:
4547
0
        case GGML_TYPE_Q2_K:
4548
0
        case GGML_TYPE_Q3_K:
4549
0
        case GGML_TYPE_Q4_K:
4550
0
        case GGML_TYPE_Q5_K:
4551
0
        case GGML_TYPE_Q6_K:
4552
0
        case GGML_TYPE_TQ1_0:
4553
0
        case GGML_TYPE_TQ2_0:
4554
0
        case GGML_TYPE_IQ2_XXS:
4555
0
        case GGML_TYPE_IQ2_XS:
4556
0
        case GGML_TYPE_IQ3_XXS:
4557
0
        case GGML_TYPE_IQ1_S:
4558
0
        case GGML_TYPE_IQ1_M:
4559
0
        case GGML_TYPE_IQ4_NL:
4560
0
        case GGML_TYPE_IQ4_XS:
4561
0
        case GGML_TYPE_IQ3_S:
4562
0
        case GGML_TYPE_IQ2_S:
4563
0
        default:
4564
0
            {
4565
0
                GGML_ABORT("fatal error");
4566
0
            }
4567
0
    }
4568
0
}
4569
4570
// ggml_compute_forward_cpy
4571
4572
void ggml_compute_forward_cpy(
4573
        const ggml_compute_params * params,
4574
0
        ggml_tensor * dst) {
4575
0
    ggml_compute_forward_dup(params, dst);
4576
0
}
4577
4578
// ggml_compute_forward_cont
4579
4580
void ggml_compute_forward_cont(
4581
        const ggml_compute_params * params,
4582
0
        ggml_tensor * dst) {
4583
0
    ggml_compute_forward_dup(params, dst);
4584
0
}
4585
4586
// ggml_compute_forward_get_rows
4587
4588
static void ggml_compute_forward_get_rows_q(
4589
        const ggml_compute_params * params,
4590
0
              ggml_tensor * dst) {
4591
4592
0
    const ggml_tensor * src0 = dst->src[0];
4593
0
    const ggml_tensor * src1 = dst->src[1];
4594
4595
0
    GGML_TENSOR_BINARY_OP_LOCALS
4596
4597
0
    const int64_t nc = ne00;
4598
0
    const int64_t nr = ggml_nelements(src1);
4599
4600
0
    const ggml_type type = src0->type;
4601
0
    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
4602
4603
0
    assert(ne0  == nc);
4604
0
    assert(ne02 == ne11);
4605
0
    assert(nb00 == ggml_type_size(type));
4606
0
    assert(ggml_nrows(dst) == nr);
4607
4608
0
    const int ith = params->ith;
4609
0
    const int nth = params->nth;
4610
4611
    // rows per thread
4612
0
    const int dr = (nr + nth - 1)/nth;
4613
4614
    // row range for this thread
4615
0
    const int ir0 = dr*ith;
4616
0
    const int ir1 = MIN(ir0 + dr, nr);
4617
4618
0
    for (int64_t i = ir0; i < ir1; ++i) {
4619
0
        const int64_t i12 = i/(ne11*ne10);
4620
0
        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4621
0
        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4622
0
        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4623
4624
0
        GGML_ASSERT(i01 >= 0 && i01 < ne01);
4625
4626
0
        dequantize_row_q(
4627
0
                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
4628
0
                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
4629
0
    }
4630
0
}
4631
4632
static void ggml_compute_forward_get_rows_f16(
4633
        const ggml_compute_params * params,
4634
0
              ggml_tensor * dst) {
4635
4636
0
    const ggml_tensor * src0 = dst->src[0];
4637
0
    const ggml_tensor * src1 = dst->src[1];
4638
4639
0
    GGML_TENSOR_BINARY_OP_LOCALS
4640
4641
0
    const int64_t nc = ne00;
4642
0
    const int64_t nr = ggml_nelements(src1);
4643
4644
0
    assert(ne0  == nc);
4645
0
    assert(ne02 == ne11);
4646
0
    assert(nb00 == sizeof(ggml_fp16_t));
4647
0
    assert(ggml_nrows(dst) == nr);
4648
4649
0
    const int ith = params->ith;
4650
0
    const int nth = params->nth;
4651
4652
    // rows per thread
4653
0
    const int dr = (nr + nth - 1)/nth;
4654
4655
    // row range for this thread
4656
0
    const int ir0 = dr*ith;
4657
0
    const int ir1 = MIN(ir0 + dr, nr);
4658
4659
0
    for (int64_t i = ir0; i < ir1; ++i) {
4660
0
        const int64_t i12 = i/(ne11*ne10);
4661
0
        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4662
0
        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4663
0
        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4664
4665
0
        GGML_ASSERT(i01 >= 0 && i01 < ne01);
4666
4667
0
        ggml_cpu_fp16_to_fp32(
4668
0
            (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
4669
0
                       (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
4670
0
    }
4671
0
}
4672
4673
static void ggml_compute_forward_get_rows_bf16(
4674
        const ggml_compute_params * params,
4675
0
              ggml_tensor * dst) {
4676
4677
0
    const ggml_tensor * src0 = dst->src[0];
4678
0
    const ggml_tensor * src1 = dst->src[1];
4679
4680
0
    GGML_TENSOR_BINARY_OP_LOCALS
4681
4682
0
    const int64_t nc = ne00;
4683
0
    const int64_t nr = ggml_nelements(src1);
4684
4685
0
    assert(ne0  == nc);
4686
0
    assert(ne02 == ne11);
4687
0
    assert(nb00 == sizeof(ggml_bf16_t));
4688
0
    assert(ggml_nrows(dst) == nr);
4689
4690
0
    const int ith = params->ith;
4691
0
    const int nth = params->nth;
4692
4693
    // rows per thread
4694
0
    const int dr = (nr + nth - 1)/nth;
4695
4696
    // row range for this thread
4697
0
    const int ir0 = dr*ith;
4698
0
    const int ir1 = MIN(ir0 + dr, nr);
4699
4700
0
    for (int64_t i = ir0; i < ir1; ++i) {
4701
0
        const int64_t i12 = i/(ne11*ne10);
4702
0
        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4703
0
        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4704
0
        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4705
4706
0
        GGML_ASSERT(i01 >= 0 && i01 < ne01);
4707
4708
0
        ggml_cpu_bf16_to_fp32(
4709
0
            (const ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
4710
0
                        (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
4711
0
    }
4712
0
}
4713
4714
static void ggml_compute_forward_get_rows_f32(
4715
        const ggml_compute_params * params,
4716
0
              ggml_tensor * dst) {
4717
4718
0
    const ggml_tensor * src0 = dst->src[0];
4719
0
    const ggml_tensor * src1 = dst->src[1];
4720
4721
0
    GGML_TENSOR_BINARY_OP_LOCALS
4722
4723
0
    const int64_t nc = ne00;
4724
0
    const int64_t nr = ggml_nelements(src1);
4725
4726
0
    assert(ne0  == nc);
4727
0
    assert(ne02 == ne11);
4728
0
    assert(nb00 == sizeof(float));
4729
0
    assert(ggml_nrows(dst) == nr);
4730
4731
0
    const int ith = params->ith;
4732
0
    const int nth = params->nth;
4733
4734
    // rows per thread
4735
0
    const int dr = (nr + nth - 1)/nth;
4736
4737
    // row range for this thread
4738
0
    const int ir0 = dr*ith;
4739
0
    const int ir1 = MIN(ir0 + dr, nr);
4740
4741
0
    for (int64_t i = ir0; i < ir1; ++i) {
4742
0
        const int64_t i12 = i/(ne11*ne10);
4743
0
        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4744
0
        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4745
0
        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4746
4747
0
        GGML_ASSERT(i01 >= 0 && i01 < ne01);
4748
4749
0
        ggml_vec_cpy_f32(nc,
4750
0
                (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
4751
0
                (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
4752
0
    }
4753
0
}
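// [Editor's sketch, not part of ops.cpp; helper name is illustrative only]
// The get_rows kernels above all implement the same gather: each element of src1 is a
// row index into src0, and the selected row is written (dequantized/converted to f32)
// into the corresponding row of dst. A minimal standalone illustration of that
// indexing, ignoring the broadcasting over dims 2/3:
#include <cstdint>
#include <vector>

static std::vector<float> get_rows_ref(const std::vector<float> & src0, int64_t ncols,
                                       const std::vector<int32_t> & rows) {
    std::vector<float> dst(rows.size() * ncols);
    for (size_t i = 0; i < rows.size(); ++i) {
        const int64_t r = rows[i];                    // row index taken from src1
        for (int64_t c = 0; c < ncols; ++c) {
            dst[i*ncols + c] = src0[r*ncols + c];     // copy the selected row of src0
        }
    }
    return dst;
}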
4754
4755
void ggml_compute_forward_get_rows(
4756
        const ggml_compute_params * params,
4757
0
        ggml_tensor * dst) {
4758
4759
0
    const ggml_tensor * src0 = dst->src[0];
4760
4761
0
    switch (src0->type) {
4762
0
        case GGML_TYPE_Q4_0:
4763
0
        case GGML_TYPE_Q4_1:
4764
0
        case GGML_TYPE_Q5_0:
4765
0
        case GGML_TYPE_Q5_1:
4766
0
        case GGML_TYPE_Q8_0:
4767
0
        case GGML_TYPE_Q8_1:
4768
0
        case GGML_TYPE_MXFP4:
4769
0
        case GGML_TYPE_Q2_K:
4770
0
        case GGML_TYPE_Q3_K:
4771
0
        case GGML_TYPE_Q4_K:
4772
0
        case GGML_TYPE_Q5_K:
4773
0
        case GGML_TYPE_Q6_K:
4774
0
        case GGML_TYPE_TQ1_0:
4775
0
        case GGML_TYPE_TQ2_0:
4776
0
        case GGML_TYPE_IQ2_XXS:
4777
0
        case GGML_TYPE_IQ2_XS:
4778
0
        case GGML_TYPE_IQ3_XXS:
4779
0
        case GGML_TYPE_IQ1_S:
4780
0
        case GGML_TYPE_IQ1_M:
4781
0
        case GGML_TYPE_IQ4_NL:
4782
0
        case GGML_TYPE_IQ4_XS:
4783
0
        case GGML_TYPE_IQ3_S:
4784
0
        case GGML_TYPE_IQ2_S:
4785
0
            {
4786
0
                ggml_compute_forward_get_rows_q(params, dst);
4787
0
            } break;
4788
0
        case GGML_TYPE_F16:
4789
0
            {
4790
0
                ggml_compute_forward_get_rows_f16(params, dst);
4791
0
            } break;
4792
0
        case GGML_TYPE_BF16:
4793
0
            {
4794
0
                ggml_compute_forward_get_rows_bf16(params, dst);
4795
0
            } break;
4796
0
        case GGML_TYPE_F32:
4797
0
        case GGML_TYPE_I32:
4798
0
            {
4799
0
                ggml_compute_forward_get_rows_f32(params, dst);
4800
0
            } break;
4801
0
        default:
4802
0
            {
4803
0
                GGML_ABORT("fatal error");
4804
0
            }
4805
0
    }
4806
4807
    //static bool first = true;
4808
    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
4809
    //if (first) {
4810
    //    first = false;
4811
    //} else {
4812
    //    for (int k = 0; k < dst->ne[1]; ++k) {
4813
    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
4814
    //            for (int i = 0; i < 16; ++i) {
4815
    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
4816
    //            }
4817
    //            printf("\n");
4818
    //        }
4819
    //        printf("\n");
4820
    //    }
4821
    //    printf("\n");
4822
    //    exit(0);
4823
    //}
4824
0
}
4825
4826
template<typename idx_t>
4827
static void ggml_compute_forward_set_rows_f32(
4828
        const ggml_compute_params * params,
4829
0
              ggml_tensor * dst) {
4830
4831
0
    const ggml_tensor * src0 = dst->src[0];
4832
0
    const ggml_tensor * src1 = dst->src[1];
4833
4834
0
    GGML_TENSOR_BINARY_OP_LOCALS
4835
4836
0
    const int64_t nc = ne00;
4837
0
    const int64_t nr = ne01;
4838
4839
0
    assert(ne0  == nc);
4840
0
    assert(ne2  == ne02);
4841
0
    assert(ne3  == ne03);
4842
0
    assert(src0->type == GGML_TYPE_F32);
4843
0
    assert(ne02 % ne11 == 0);
4844
0
    assert(ne03 % ne12 == 0);
4845
4846
0
    const int ith = params->ith;
4847
0
    const int nth = params->nth;
4848
4849
    // rows per thread
4850
0
    const int64_t dr = (nr + nth - 1)/nth;
4851
4852
    // row range for this thread
4853
0
    const int64_t ir0 = dr*ith;
4854
0
    const int64_t ir1 = std::min(ir0 + dr, nr);
4855
4856
0
    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
4857
4858
0
    for (int64_t i03 = 0; i03 < ne03; ++i03) {
4859
0
        for (int64_t i02 = 0; i02 < ne02; ++i02) {
4860
0
            for (int64_t i = ir0; i < ir1; ++i) {
4861
0
                const int64_t i12 = i03%ne12;
4862
0
                const int64_t i11 = i02%ne11;
4863
0
                const int64_t i10 = i;
4864
4865
0
                const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4866
4867
0
                GGML_ASSERT(i1 >= 0 && i1 < ne1);
4868
4869
0
                from_float(
4870
0
                        (const float *) ((char *) src0->data +  i*nb01 + i02*nb02 + i03*nb03),
4871
0
                                        ((char *)  dst->data + i1*nb1  + i02*nb2  + i03*nb3), nc);
4872
0
            }
4873
0
        }
4874
0
    }
4875
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_set_rows_f32<long>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_set_rows_f32<int>(ggml_compute_params const*, ggml_tensor*)
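// [Editor's sketch, not part of ops.cpp; helper name is illustrative only]
// set_rows is the scatter counterpart of get_rows: row i of src0 is converted with the
// destination type's from_float and written to row idx[i] of dst. A minimal f32->f32
// illustration of that scatter, again ignoring the dim-2/3 broadcasting:
#include <cstdint>
#include <vector>

static void set_rows_ref(std::vector<float> & dst, int64_t ncols,
                         const std::vector<float> & src, const std::vector<int64_t> & idx) {
    for (size_t i = 0; i < idx.size(); ++i) {
        const int64_t r = idx[i];                     // destination row for source row i
        for (int64_t c = 0; c < ncols; ++c) {
            dst[r*ncols + c] = src[i*ncols + c];
        }
    }
}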
4876
4877
void ggml_compute_forward_set_rows(
4878
        const ggml_compute_params * params,
4879
0
        ggml_tensor * dst) {
4880
4881
0
    const ggml_tensor * src0 = dst->src[0];
4882
0
    const ggml_tensor * src1 = dst->src[1];
4883
4884
0
    switch (src0->type) {
4885
0
        case GGML_TYPE_F32:
4886
0
            {
4887
0
                if (src1->type == GGML_TYPE_I64) {
4888
0
                    ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
4889
0
                } else if (src1->type == GGML_TYPE_I32) {
4890
0
                    ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
4891
0
                } else {
4892
0
                    GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
4893
0
                }
4894
0
            } break;
4895
0
        default:
4896
0
            {
4897
0
                GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
4898
0
            }
4899
0
    }
4900
0
}
4901
4902
// ggml_compute_forward_get_rows_back
4903
4904
static void ggml_compute_forward_get_rows_back_f32_f16(
4905
        const ggml_compute_params * params,
4906
0
              ggml_tensor * dst) {
4907
4908
0
    const ggml_tensor * src0 = dst->src[0];
4909
0
    const ggml_tensor * src1 = dst->src[1];
4910
4911
0
    if (params->ith != 0) {
4912
0
        return;
4913
0
    }
4914
4915
0
    GGML_ASSERT(ggml_is_contiguous(dst));
4916
4917
    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
4918
4919
0
    memset(dst->data, 0, ggml_nbytes(dst));
4920
4921
0
    const int nc = src0->ne[0];
4922
0
    const int nr = ggml_nelements(src1);
4923
4924
0
    GGML_ASSERT( dst->ne[0] == nc);
4925
0
    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
4926
4927
0
    for (int i = 0; i < nr; ++i) {
4928
0
        const int r = ((int32_t *) src1->data)[i];
4929
4930
0
        for (int j = 0; j < nc; ++j) {
4931
0
            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
4932
0
            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_CPU_FP16_TO_FP32(v);
4933
0
        }
4934
0
    }
4935
0
}
4936
4937
static void ggml_compute_forward_get_rows_back_f32(
4938
        const ggml_compute_params * params,
4939
0
              ggml_tensor * dst) {
4940
4941
0
    const ggml_tensor * src0 = dst->src[0];
4942
0
    const ggml_tensor * src1 = dst->src[1];
4943
4944
0
    if (params->ith != 0) {
4945
0
        return;
4946
0
    }
4947
4948
0
    GGML_ASSERT(ggml_is_contiguous(dst));
4949
4950
    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
4951
4952
0
    memset(dst->data, 0, ggml_nbytes(dst));
4953
4954
0
    const int nc = src0->ne[0];
4955
0
    const int nr = ggml_nelements(src1);
4956
4957
0
    GGML_ASSERT( dst->ne[0] == nc);
4958
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
4959
4960
0
    for (int i = 0; i < nr; ++i) {
4961
0
        const int r = ((int32_t *) src1->data)[i];
4962
4963
0
        ggml_vec_add_f32(nc,
4964
0
                (float *) ((char *)  dst->data + r*dst->nb[1]),
4965
0
                (float *) ((char *)  dst->data + r*dst->nb[1]),
4966
0
                (float *) ((char *) src0->data + i*src0->nb[1]));
4967
0
    }
4968
0
}
4969
4970
void ggml_compute_forward_get_rows_back(
4971
        const ggml_compute_params * params,
4972
0
        ggml_tensor * dst) {
4973
4974
0
    const ggml_tensor * src0 = dst->src[0];
4975
4976
0
    switch (src0->type) {
4977
0
        case GGML_TYPE_F16:
4978
0
            {
4979
0
                ggml_compute_forward_get_rows_back_f32_f16(params, dst);
4980
0
            } break;
4981
0
        case GGML_TYPE_F32:
4982
0
            {
4983
0
                ggml_compute_forward_get_rows_back_f32(params, dst);
4984
0
            } break;
4985
0
        default:
4986
0
            {
4987
0
                GGML_ABORT("fatal error");
4988
0
            }
4989
0
    }
4990
4991
    //static bool first = true;
4992
    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
4993
    //if (first) {
4994
    //    first = false;
4995
    //} else {
4996
    //    for (int k = 0; k < dst->ne[1]; ++k) {
4997
    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
4998
    //            for (int i = 0; i < 16; ++i) {
4999
    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
5000
    //            }
5001
    //            printf("\n");
5002
    //        }
5003
    //        printf("\n");
5004
    //    }
5005
    //    printf("\n");
5006
    //    exit(0);
5007
    //}
5008
0
}
5009
5010
// ggml_compute_forward_diag
5011
5012
static void ggml_compute_forward_diag_f32(
5013
        const ggml_compute_params * params,
5014
0
        ggml_tensor * dst) {
5015
5016
0
    const ggml_tensor * src0 = dst->src[0];
5017
5018
0
    if (params->ith != 0) {
5019
0
        return;
5020
0
    }
5021
5022
    // TODO: handle transposed/permuted matrices
5023
5024
0
    GGML_TENSOR_UNARY_OP_LOCALS
5025
5026
0
    GGML_ASSERT(ne00 == ne0);
5027
0
    GGML_ASSERT(ne00 == ne1);
5028
0
    GGML_ASSERT(ne01 == 1);
5029
0
    GGML_ASSERT(ne02 == ne2);
5030
0
    GGML_ASSERT(ne03 == ne3);
5031
5032
0
    GGML_ASSERT(nb00 == sizeof(float));
5033
0
    GGML_ASSERT(nb0  == sizeof(float));
5034
5035
0
    for (int i3 = 0; i3 < ne3; i3++) {
5036
0
        for (int i2 = 0; i2 < ne2; i2++) {
5037
0
            for (int i1 = 0; i1 < ne1; i1++) {
5038
0
                float * d = (float *)((char *)  dst->data + i3*nb3  + i2*nb2 + i1*nb1);
5039
0
                float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
5040
0
                for (int i0 = 0; i0 < i1; i0++) {
5041
0
                    d[i0] = 0;
5042
0
                }
5043
0
                d[i1] = s[i1];
5044
0
                for (int i0 = i1+1; i0 < ne0; i0++) {
5045
0
                    d[i0] = 0;
5046
0
                }
5047
0
            }
5048
0
        }
5049
0
    }
5050
0
}
5051
5052
void ggml_compute_forward_diag(
5053
        const ggml_compute_params * params,
5054
0
        ggml_tensor * dst) {
5055
5056
0
    const ggml_tensor * src0 = dst->src[0];
5057
5058
0
    switch (src0->type) {
5059
0
        case GGML_TYPE_F32:
5060
0
            {
5061
0
                ggml_compute_forward_diag_f32(params, dst);
5062
0
            } break;
5063
0
        default:
5064
0
            {
5065
0
                GGML_ABORT("fatal error");
5066
0
            }
5067
0
    }
5068
0
}
5069
5070
// ggml_compute_forward_diag_mask_inf
5071
5072
static void ggml_compute_forward_diag_mask_f32(
5073
        const ggml_compute_params * params,
5074
        ggml_tensor * dst,
5075
0
        const float value) {
5076
5077
0
    const ggml_tensor * src0 = dst->src[0];
5078
5079
0
    const int ith = params->ith;
5080
0
    const int nth = params->nth;
5081
5082
0
    const int  n_past  = ((int32_t *) dst->op_params)[0];
5083
0
    const bool inplace = src0->data == dst->data;
5084
5085
0
    GGML_ASSERT(n_past >= 0);
5086
5087
0
    if (!inplace) {
5088
0
        if (ith == 0) {
5089
            // memcpy needs to be synchronized across threads to avoid race conditions.
5090
            // => do it in INIT phase
5091
0
            GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5092
0
            GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
5093
0
            memcpy(
5094
0
                ((char *)  dst->data),
5095
0
                ((char *) src0->data),
5096
0
                ggml_nbytes(dst));
5097
0
        }
5098
0
        ggml_barrier(params->threadpool);
5099
0
    }
5100
5101
    // TODO: handle transposed/permuted matrices
5102
5103
0
    const int n  = ggml_nrows(src0);
5104
0
    const int nc = src0->ne[0];
5105
0
    const int nr = src0->ne[1];
5106
0
    const int nz = n/nr;
5107
5108
0
    GGML_ASSERT( dst->nb[0] == sizeof(float));
5109
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
5110
5111
0
    for (int k = 0; k < nz; k++) {
5112
0
        for (int j = ith; j < nr; j += nth) {
5113
0
            for (int i = n_past; i < nc; i++) {
5114
0
                if (i > n_past + j) {
5115
0
                    *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
5116
0
                }
5117
0
            }
5118
0
        }
5119
0
    }
5120
0
}
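// [Editor's sketch, not part of ops.cpp; helper name is illustrative only]
// The masking rule applied above: within each 2D slice, column i of row j is
// overwritten with `value` whenever i > n_past + j, which is the usual causal mask
// when value is -infinity. A standalone illustration for a single slice:
#include <vector>

static void diag_mask_ref(std::vector<float> & x, int rows, int cols, int n_past, float value) {
    for (int j = 0; j < rows; ++j) {
        for (int i = n_past + j + 1; i < cols; ++i) { // everything right of the allowed band
            x[j*cols + i] = value;
        }
    }
}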
5121
5122
void ggml_compute_forward_diag_mask_inf(
5123
        const ggml_compute_params * params,
5124
0
        ggml_tensor * dst) {
5125
5126
0
    const ggml_tensor * src0 = dst->src[0];
5127
5128
0
    switch (src0->type) {
5129
0
        case GGML_TYPE_F32:
5130
0
            {
5131
0
                ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
5132
0
            } break;
5133
0
        default:
5134
0
            {
5135
0
                GGML_ABORT("fatal error");
5136
0
            }
5137
0
    }
5138
0
}
5139
5140
void ggml_compute_forward_diag_mask_zero(
5141
        const ggml_compute_params * params,
5142
0
        ggml_tensor * dst) {
5143
5144
0
    const ggml_tensor * src0 = dst->src[0];
5145
5146
0
    switch (src0->type) {
5147
0
        case GGML_TYPE_F32:
5148
0
            {
5149
0
                ggml_compute_forward_diag_mask_f32(params, dst, 0);
5150
0
            } break;
5151
0
        default:
5152
0
            {
5153
0
                GGML_ABORT("fatal error");
5154
0
            }
5155
0
    }
5156
0
}
5157
5158
// ggml_compute_forward_soft_max
5159
5160
static void ggml_compute_forward_soft_max_f32(
5161
        const ggml_compute_params * params,
5162
0
              ggml_tensor * dst) {
5163
5164
0
    const ggml_tensor * src0 = dst->src[0];
5165
0
    const ggml_tensor * src1 = dst->src[1];
5166
0
    const ggml_tensor * src2 = dst->src[2];
5167
5168
0
    assert(ggml_is_contiguous(dst));
5169
0
    assert(ggml_are_same_shape(src0, dst));
5170
5171
0
    float scale    = 1.0f;
5172
0
    float max_bias = 0.0f;
5173
5174
0
    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
5175
0
    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
5176
5177
0
    const int ith = params->ith;
5178
0
    const int nth = params->nth;
5179
5180
0
    GGML_TENSOR_UNARY_OP_LOCALS
5181
5182
0
    const int64_t nb11 = src1 ? src1->nb[1] : 1;
5183
0
    const int64_t nb12 = src1 ? src1->nb[2] : 1;
5184
0
    const int64_t nb13 = src1 ? src1->nb[3] : 1;
5185
5186
0
    const int64_t ne12 = src1 ? src1->ne[2] : 1;
5187
0
    const int64_t ne13 = src1 ? src1->ne[3] : 1;
5188
5189
    // TODO: is this supposed to be ceil instead of floor?
5190
    //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
5191
0
    const uint32_t n_head      = ne02;
5192
0
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
5193
5194
0
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
5195
0
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
5196
5197
0
    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
5198
5199
0
    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
5200
5201
    // sinks
5202
0
    const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
5203
5204
0
    for (int64_t i03 = 0; i03 < ne03; i03++) {
5205
0
        for (int64_t i02 = 0; i02 < ne02; i02++) {
5206
0
            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
5207
0
                const int64_t i11 = i01;
5208
0
                const int64_t i12 = i02%ne12;
5209
0
                const int64_t i13 = i03%ne13;
5210
5211
                // ALiBi
5212
0
                const uint32_t h = i02; // head
5213
0
                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
5214
5215
0
                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5216
0
                float * dp = (float *)((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3);
5217
5218
                // broadcast the mask across rows
5219
0
                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5220
0
                float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5221
5222
0
                ggml_vec_cpy_f32  (ne00, wp, sp);
5223
0
                ggml_vec_scale_f32(ne00, wp, scale);
5224
0
                if (mp_f32) {
5225
0
                    if (use_f16) {
5226
0
                        for (int i = 0; i < ne00; ++i) {
5227
0
                            wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
5228
0
                        }
5229
0
                    } else {
5230
0
                        for (int i = 0; i < ne00; ++i) {
5231
0
                            wp[i] += slope*mp_f32[i];
5232
0
                        }
5233
0
                    }
5234
0
                }
5235
5236
#ifndef NDEBUG
5237
                for (int i = 0; i < ne00; ++i) {
5238
                    //printf("p[%d] = %f\n", i, p[i]);
5239
                    assert(!isnan(wp[i]));
5240
                }
5241
#endif
5242
5243
0
                float max = -INFINITY;
5244
0
                ggml_vec_max_f32(ne00, &max, wp);
5245
5246
                // if we have sinks, make a correction as if they were included in the softmax
5247
0
                if (sk) {
5248
0
                    max = MAX(max, sk[i02]);
5249
0
                }
5250
5251
0
                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
5252
0
                assert(sum > 0.0);
5253
5254
0
                if (sk) {
5255
0
                    sum += (ggml_float) expf(sk[i02] - max);
5256
0
                }
5257
5258
0
                sum = 1.0/sum;
5259
0
                ggml_vec_scale_f32(ne00, dp, sum);
5260
5261
#ifndef NDEBUG
5262
                for (int i = 0; i < ne00; ++i) {
5263
                    assert(!isnan(dp[i]));
5264
                    assert(!isinf(dp[i]));
5265
                }
5266
#endif
5267
0
            }
5268
0
        }
5269
0
    }
5270
0
}
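// [Editor's sketch, not part of ops.cpp; helper name is illustrative only]
// Per row, the function above computes a numerically stable softmax of
// scale*x + slope*mask, where the optional per-head "sink" value participates in the
// max and in the normalizer but is not written to the output. A minimal standalone
// version of that row computation:
#include <algorithm>
#include <cmath>
#include <vector>

static void soft_max_row_ref(std::vector<float> & x, const float * sink /* nullable */) {
    float max = -INFINITY;
    for (float v : x) max = std::max(max, v);
    if (sink) max = std::max(max, *sink);             // sink raises the max if needed

    double sum = 0.0;
    for (float & v : x) { v = std::exp(v - max); sum += v; }
    if (sink) sum += std::exp(*sink - max);           // ...and contributes to the normalizer

    for (float & v : x) v = (float)(v / sum);
}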
5271
5272
void ggml_compute_forward_soft_max(
5273
        const ggml_compute_params * params,
5274
0
              ggml_tensor * dst) {
5275
5276
0
    const ggml_tensor * src0 = dst->src[0];
5277
5278
0
    switch (src0->type) {
5279
0
        case GGML_TYPE_F32:
5280
0
            {
5281
0
                ggml_compute_forward_soft_max_f32(params, dst);
5282
0
            } break;
5283
0
        default:
5284
0
            {
5285
0
                GGML_ABORT("fatal error");
5286
0
            }
5287
0
    }
5288
0
}
5289
5290
5291
// ggml_compute_forward_soft_max_ext_back
5292
5293
static void ggml_compute_forward_soft_max_ext_back_f32(
5294
        const ggml_compute_params * params,
5295
0
        ggml_tensor * dst) {
5296
5297
0
    const ggml_tensor * src0 = dst->src[0];
5298
0
    const ggml_tensor * src1 = dst->src[1];
5299
5300
0
    GGML_ASSERT(ggml_is_contiguous(src0));
5301
0
    GGML_ASSERT(ggml_is_contiguous(src1));
5302
0
    GGML_ASSERT(ggml_is_contiguous(dst));
5303
0
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
5304
0
    GGML_ASSERT(ggml_are_same_shape(src1, dst));
5305
5306
0
    float scale    = 1.0f;
5307
0
    float max_bias = 0.0f;
5308
5309
0
    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
5310
0
    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
5311
5312
0
    GGML_ASSERT(max_bias == 0.0f);
5313
5314
    // TODO: handle transposed/permuted matrices
5315
5316
0
    const int ith = params->ith;
5317
0
    const int nth = params->nth;
5318
5319
0
    const int nc = src0->ne[0];
5320
0
    const int nr = ggml_nrows(src0);
5321
5322
    // rows per thread
5323
0
    const int dr = (nr + nth - 1)/nth;
5324
5325
    // row range for this thread
5326
0
    const int ir0 = dr*ith;
5327
0
    const int ir1 = MIN(ir0 + dr, nr);
5328
5329
0
    for (int i1 = ir0; i1 < ir1; i1++) {
5330
0
        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
5331
0
        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
5332
0
        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
5333
5334
#ifndef NDEBUG
5335
        for (int i = 0; i < nc; ++i) {
5336
            //printf("p[%d] = %f\n", i, p[i]);
5337
            assert(!isnan(dy[i]));
5338
            assert(!isnan(y[i]));
5339
        }
5340
#endif
5341
        // Jii = yi - yi*yi
5342
        // Jij = -yi*yj
5343
        // J = diag(y)-y.T*y
5344
        // dx = J * dy
5345
        // dxk = sum_i(Jki * dyi)
5346
        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
5347
        // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
5348
        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
5349
        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
5350
        // dxk = -yk * dot(y, dy) + yk*dyk
5351
        // dxk = yk * (- dot(y, dy) + dyk)
5352
        // dxk = yk * (dyk - dot(y, dy))
5353
        //
5354
        // post-order:
5355
        // dot_y_dy := dot(y, dy)
5356
        // dx := dy
5357
        // dx := dx - dot_y_dy
5358
        // dx := dx * y
5359
5360
        // linear runtime, no additional memory
5361
0
        float dot_y_dy = 0;
5362
0
        ggml_vec_dot_f32  (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
5363
0
        ggml_vec_cpy_f32  (nc, dx, dy);
5364
0
        ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
5365
0
        ggml_vec_mul_f32  (nc, dx, dx, y);
5366
0
        ggml_vec_scale_f32(nc, dx, scale);
5367
5368
#ifndef NDEBUG
5369
        for (int i = 0; i < nc; ++i) {
5370
            assert(!isnan(dx[i]));
5371
            assert(!isinf(dx[i]));
5372
        }
5373
#endif
5374
0
    }
5375
0
}
5376
5377
void ggml_compute_forward_soft_max_ext_back(
5378
        const ggml_compute_params * params,
5379
0
        ggml_tensor * dst) {
5380
5381
0
    const ggml_tensor * src0 = dst->src[0];
5382
5383
0
    switch (src0->type) {
5384
0
        case GGML_TYPE_F32:
5385
0
            {
5386
0
                ggml_compute_forward_soft_max_ext_back_f32(params, dst);
5387
0
            } break;
5388
0
        default:
5389
0
            {
5390
0
                GGML_ABORT("fatal error");
5391
0
            }
5392
0
    }
5393
0
}
5394
5395
// ggml_compute_forward_clamp
5396
5397
static void ggml_compute_forward_clamp_f32(
5398
        const ggml_compute_params * params,
5399
0
        ggml_tensor * dst) {
5400
5401
0
    const ggml_tensor * src0 = dst->src[0];
5402
5403
0
    float min;
5404
0
    float max;
5405
0
    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
5406
0
    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
5407
5408
0
    const int ith = params->ith;
5409
0
    const int nth = params->nth;
5410
5411
0
    const int n  = ggml_nrows(src0);
5412
0
    const int nc = src0->ne[0];
5413
5414
0
    const size_t nb00 = src0->nb[0];
5415
0
    const size_t nb01 = src0->nb[1];
5416
5417
0
    const size_t nb0 = dst->nb[0];
5418
0
    const size_t nb1 = dst->nb[1];
5419
5420
0
    GGML_ASSERT( nb0 == sizeof(float));
5421
0
    GGML_ASSERT(nb00 == sizeof(float));
5422
5423
0
    for (int j = ith; j < n; j += nth) {
5424
0
        float * dst_ptr  = (float *) ((char *)  dst->data + j*nb1);
5425
0
        float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
5426
5427
0
        for (int i = 0; i < nc; i++) {
5428
0
            dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
5429
0
        }
5430
0
    }
5431
0
}
5432
5433
static void ggml_compute_forward_clamp_f16(
5434
    const ggml_compute_params * params,
5435
0
    ggml_tensor * dst) {
5436
5437
0
    const ggml_tensor * src0 = dst->src[0];
5438
5439
0
    float min;
5440
0
    float max;
5441
0
    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
5442
0
    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
5443
5444
0
    const int ith = params->ith;
5445
0
    const int nth = params->nth;
5446
5447
0
    const int n  = ggml_nrows(src0);
5448
0
    const int nc = src0->ne[0];
5449
5450
0
    const size_t nb00 = src0->nb[0];
5451
0
    const size_t nb01 = src0->nb[1];
5452
5453
0
    const size_t nb0 = dst->nb[0];
5454
0
    const size_t nb1 = dst->nb[1];
5455
5456
0
    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
5457
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
5458
5459
0
    for (int j = ith; j < n; j += nth) {
5460
0
        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *)  dst->data + j*nb1);
5461
0
        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
5462
5463
0
        for (int i = 0; i < nc; i++) {
5464
0
            float v = GGML_CPU_FP16_TO_FP32(src0_ptr[i]);
5465
0
            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(MAX(MIN(v, max), min));
5466
0
        }
5467
0
    }
5468
0
}
5469
5470
void ggml_compute_forward_clamp(
5471
        const ggml_compute_params * params,
5472
0
        ggml_tensor * dst) {
5473
5474
0
    const ggml_tensor * src0 = dst->src[0];
5475
5476
0
    switch (src0->type) {
5477
0
        case GGML_TYPE_F32:
5478
0
            {
5479
0
                ggml_compute_forward_clamp_f32(params, dst);
5480
0
            } break;
5481
0
        case GGML_TYPE_F16:
5482
0
            {
5483
0
                ggml_compute_forward_clamp_f16(params, dst);
5484
0
            } break;
5485
0
        case GGML_TYPE_BF16:
5486
0
        case GGML_TYPE_Q4_0:
5487
0
        case GGML_TYPE_Q4_1:
5488
0
        case GGML_TYPE_Q5_0:
5489
0
        case GGML_TYPE_Q5_1:
5490
0
        case GGML_TYPE_Q8_0:
5491
0
        case GGML_TYPE_Q8_1:
5492
0
        case GGML_TYPE_MXFP4:
5493
0
        case GGML_TYPE_Q2_K:
5494
0
        case GGML_TYPE_Q3_K:
5495
0
        case GGML_TYPE_Q4_K:
5496
0
        case GGML_TYPE_Q5_K:
5497
0
        case GGML_TYPE_Q6_K:
5498
0
        case GGML_TYPE_TQ1_0:
5499
0
        case GGML_TYPE_TQ2_0:
5500
0
        case GGML_TYPE_IQ2_XXS:
5501
0
        case GGML_TYPE_IQ2_XS:
5502
0
        case GGML_TYPE_IQ3_XXS:
5503
0
        case GGML_TYPE_IQ1_S:
5504
0
        case GGML_TYPE_IQ1_M:
5505
0
        case GGML_TYPE_IQ4_NL:
5506
0
        case GGML_TYPE_IQ4_XS:
5507
0
        case GGML_TYPE_IQ3_S:
5508
0
        case GGML_TYPE_IQ2_S:
5509
0
        case GGML_TYPE_Q8_K:
5510
0
        case GGML_TYPE_I8:
5511
0
        case GGML_TYPE_I16:
5512
0
        case GGML_TYPE_I32:
5513
0
        case GGML_TYPE_I64:
5514
0
        case GGML_TYPE_F64:
5515
0
        case GGML_TYPE_COUNT:
5516
0
            {
5517
0
                GGML_ABORT("fatal error");
5518
0
            }
5519
0
    }
5520
0
}
5521
5522
// ggml_compute_forward_rope
5523
5524
0
static float rope_yarn_ramp(const float low, const float high, const int i0) {
5525
0
    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
5526
0
    return 1 - MIN(1, MAX(0, y));
5527
0
}
5528
5529
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
5530
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
5531
static void rope_yarn(
5532
    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
5533
0
    float * cos_theta, float * sin_theta) {
5534
    // Get n-d rotational scaling corrected for extrapolation
5535
0
    float theta_interp = freq_scale * theta_extrap;
5536
0
    float theta = theta_interp;
5537
0
    if (ext_factor != 0.0f) {
5538
0
        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
5539
0
        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
5540
5541
        // Get n-d magnitude scaling corrected for interpolation
5542
0
        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
5543
0
    }
5544
0
    *cos_theta = cosf(theta) * mscale;
5545
0
    *sin_theta = sinf(theta) * mscale;
5546
0
}
5547
5548
static void ggml_rope_cache_init(
5549
     float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
5550
0
     float * cache, float sin_sign, float theta_scale) {
5551
    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
5552
0
    float theta = theta_base;
5553
0
    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
5554
0
        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
5555
0
        rope_yarn(
5556
0
            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
5557
0
        );
5558
0
        cache[i0 + 1] *= sin_sign;
5559
5560
0
        theta *= theta_scale;
5561
0
    }
5562
0
}
5563
5564
static void ggml_mrope_cache_init(
5565
     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
5566
     float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
5567
0
     float * cache, float sin_sign, float theta_scale) {
5568
    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
5569
0
    float theta_t = theta_base_t;
5570
0
    float theta_h = theta_base_h;
5571
0
    float theta_w = theta_base_w;
5572
0
    float theta_e = theta_base_e;  // extra position id for vision encoder
5573
0
    int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
5574
0
    int sec_w = sections[1] + sections[0];
5575
0
    int sec_e = sections[2] + sec_w;
5576
0
    GGML_ASSERT(sect_dims <= ne0);
5577
5578
0
    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
5579
0
        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
5580
5581
0
        int sector = (i0 / 2) % sect_dims;
5582
0
        if (indep_sects) {
5583
            // compute theta independently for each dim section
5584
            // (i.e. reset the corresponding theta when `i0` goes from one section to another)
5585
0
            if (sector == 0) {
5586
0
                theta_t = theta_base_t;
5587
0
            }
5588
0
            else if (sector == sections[0]) {
5589
0
                theta_h = theta_base_h;
5590
0
            }
5591
0
            else if (sector == sec_w) {
5592
0
                theta_w = theta_base_w;
5593
0
            }
5594
0
            else if (sector == sec_e) {
5595
0
                theta_e = theta_base_e;
5596
0
            }
5597
0
        }
5598
5599
0
        float theta = theta_t;
5600
0
        if (is_imrope) { // qwen3vl applies interleaved mrope
5601
0
            if (sector % 3 == 1 && sector < 3 * sections[1]) {
5602
0
                theta = theta_h;
5603
0
            } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
5604
0
                theta = theta_w;
5605
0
            } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
5606
0
                theta = theta_t;
5607
0
            } else {
5608
0
                theta = theta_e;
5609
0
            }
5610
0
        } else {
5611
0
            if (sector >= sections[0] && sector < sec_w) {
5612
0
                theta = theta_h;
5613
0
            }
5614
0
            else if (sector >= sec_w && sector < sec_w + sections[2]) {
5615
0
                theta = theta_w;
5616
0
            }
5617
0
            else if (sector >= sec_w + sections[2]) {
5618
0
                theta = theta_e;
5619
0
            }
5620
0
        }
5621
5622
0
        rope_yarn(
5623
0
            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
5624
0
        );
5625
0
        cache[i0 + 1] *= sin_sign;
5626
5627
0
        theta_t *= theta_scale;
5628
0
        theta_w *= theta_scale;
5629
0
        theta_h *= theta_scale;
5630
0
        theta_e *= theta_scale;
5631
0
    }
5632
0
}
5633
5634
5635
template<typename T>
5636
0
static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
5637
0
  for (int64_t i0 = 0; i0 < n; i0 += 2) {
5638
0
    const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
5639
5640
0
    const float cos_theta = cache[i0 + 0];
5641
0
    const float sin_theta = cache[i0 + 1];
5642
5643
0
    const T * const src = src_data + ic;
5644
0
    T * dst             = dst_data + ic;
5645
5646
0
    const float x0 = type_conversion_table<T>::to_f32(src[0]);
5647
0
    const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
5648
5649
0
    dst[0]        = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
5650
0
    dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
5651
0
  }
5652
0
}
Unexecuted instantiation: ops.cpp:void rotate_pairs<unsigned short>(long, long, float const*, unsigned short const*, unsigned short*, int)
Unexecuted instantiation: ops.cpp:void rotate_pairs<float>(long, long, float const*, float const*, float*, int)
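// [Editor's sketch, not part of ops.cpp; helper name is illustrative only]
// rotate_pairs applies a plain 2D rotation to each pair (x[ic], x[ic + n_offset]); the
// rope mode only changes which two elements form a pair (adjacent for NORMAL, split
// halves for NEOX/MROPE/VISION). In ops.cpp the cached cos/sin already include the
// YaRN magnitude scale and the forward/backward sin sign; the core update is:
#include <cmath>

static inline void rotate_pair(float & x0, float & x1, float theta) {
    const float c = std::cos(theta);
    const float s = std::sin(theta);
    const float r0 = x0*c - x1*s;                     // first element of the rotated pair
    const float r1 = x0*s + x1*c;                     // second element of the rotated pair
    x0 = r0;
    x1 = r1;
}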
5653
5654
template<typename T> //float or ggml_fp16_t
5655
static void ggml_compute_forward_rope_flt(
5656
        const ggml_compute_params * params,
5657
        ggml_tensor * dst,
5658
0
        const bool forward) {
5659
5660
0
    const ggml_tensor * src0 = dst->src[0];
5661
0
    const ggml_tensor * src1 = dst->src[1];
5662
0
    const ggml_tensor * src2 = dst->src[2];
5663
5664
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
5665
0
    GGML_ASSERT(src1->type == GGML_TYPE_I32);
5666
5667
0
    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5668
0
    int sections[4];
5669
5670
    //const int n_past     = ((int32_t *) dst->op_params)[0];
5671
0
    const int n_dims     = ((int32_t *) dst->op_params)[1];
5672
0
    const int mode       = ((int32_t *) dst->op_params)[2];
5673
    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
5674
0
    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
5675
5676
0
    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
5677
0
    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
5678
0
    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
5679
0
    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
5680
0
    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
5681
0
    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
5682
0
    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
5683
5684
0
    GGML_TENSOR_UNARY_OP_LOCALS
5685
5686
    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5687
    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5688
5689
0
    GGML_ASSERT(nb0 == nb00);
5690
0
    GGML_ASSERT(nb0 == sizeof(T));
5691
5692
0
    const int ith = params->ith;
5693
0
    const int nth = params->nth;
5694
5695
0
    const int nr = ggml_nrows(dst);
5696
5697
0
    GGML_ASSERT(n_dims <= ne0);
5698
0
    GGML_ASSERT(n_dims % 2 == 0);
5699
5700
    // rows per thread
5701
0
    const int dr = (nr + nth - 1)/nth;
5702
5703
    // row range for this thread
5704
0
    const int ir0 = dr*ith;
5705
0
    const int ir1 = MIN(ir0 + dr, nr);
5706
5707
    // row index used to determine which thread to use
5708
0
    int ir = 0;
5709
5710
0
    const float theta_scale = powf(freq_base, -2.0f/n_dims);
5711
5712
0
    float corr_dims[2];
5713
0
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
5714
5715
0
    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl applies interleaved mrope
5716
0
    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
5717
0
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
5718
5719
0
    if (mrope_used) {
5720
0
        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
5721
0
    }
5722
5723
0
    if (is_vision) {
5724
0
        GGML_ASSERT(n_dims == ne0/2);
5725
0
    }
5726
5727
0
    const float * freq_factors = NULL;
5728
0
    if (src2 != NULL) {
5729
0
        GGML_ASSERT(src2->type == GGML_TYPE_F32);
5730
0
        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
5731
0
        freq_factors = (const float *) src2->data;
5732
0
    }
5733
5734
    // backward process uses inverse rotation by cos and sin.
5735
    // cos and sin build a rotation matrix, where the inverse is the transpose.
5736
    // this essentially just switches the sign of sin.
5737
0
    const float sin_sign = forward ? 1.0f : -1.0f;
5738
5739
0
    const int32_t * pos = (const int32_t *) src1->data;
5740
5741
0
    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
5742
0
        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
5743
5744
0
            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5745
0
            if (!mrope_used) {
5746
0
                const int64_t p = pos[i2];
5747
0
                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5748
0
            }
5749
0
            else {
5750
0
                const int64_t p_t = pos[i2];
5751
0
                const int64_t p_h = pos[i2 + ne2];
5752
0
                const int64_t p_w = pos[i2 + ne2 * 2];
5753
0
                const int64_t p_e = pos[i2 + ne2 * 3];
5754
0
                ggml_mrope_cache_init(
5755
0
                    p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
5756
0
                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5757
0
            }
5758
5759
0
            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
5760
0
                if (ir++ < ir0) continue;
5761
0
                if (ir   > ir1) break;
5762
5763
0
                T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5764
0
                T * dst_data  = (T *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1);
5765
5766
0
                switch (mode) {
5767
0
                    case GGML_ROPE_TYPE_NORMAL:
5768
0
                        rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
5769
0
                        break;
5770
0
                    case GGML_ROPE_TYPE_NEOX:
5771
0
                    case GGML_ROPE_TYPE_MROPE:
5772
0
                    case GGML_ROPE_TYPE_IMROPE:
5773
0
                        rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
5774
0
                        break;
5775
0
                    case GGML_ROPE_TYPE_VISION:
5776
0
                        rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
5777
0
                        break;
5778
0
                    default:
5779
0
                        GGML_ABORT("rope type not supported");
5780
0
                }
5781
5782
0
                if (!is_vision) {
5783
                    // fill the remaining channels with data from the src tensor
5784
0
                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5785
0
                        const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5786
0
                        T * dst_data  = (T *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
5787
5788
0
                        dst_data[0] = src[0];
5789
0
                        dst_data[1] = src[1];
5790
0
                    }
5791
0
                }
5792
0
            } //attn-heads
5793
0
        }
5794
0
    }
5795
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_rope_flt<unsigned short>(ggml_compute_params const*, ggml_tensor*, bool)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_rope_flt<float>(ggml_compute_params const*, ggml_tensor*, bool)
5796
5797
void ggml_compute_forward_rope(
5798
        const ggml_compute_params * params,
5799
0
        ggml_tensor * dst) {
5800
5801
0
    const ggml_tensor * src0 = dst->src[0];
5802
5803
0
    switch (src0->type) {
5804
0
        case GGML_TYPE_F16:
5805
0
            {
5806
0
                ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
5807
0
            } break;
5808
0
        case GGML_TYPE_F32:
5809
0
            {
5810
0
                ggml_compute_forward_rope_flt<float>(params, dst, true);
5811
0
            } break;
5812
0
        default:
5813
0
            {
5814
0
                GGML_ABORT("fatal error");
5815
0
            }
5816
0
    }
5817
0
}
5818
5819
// ggml_compute_forward_rope_back
5820
5821
void ggml_compute_forward_rope_back(
5822
        const ggml_compute_params * params,
5823
0
        ggml_tensor * dst) {
5824
5825
0
    const ggml_tensor * src0 = dst->src[0];
5826
5827
0
    switch (src0->type) {
5828
0
        case GGML_TYPE_F16:
5829
0
            {
5830
0
                ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
5831
0
            } break;
5832
0
        case GGML_TYPE_F32:
5833
0
            {
5834
0
                ggml_compute_forward_rope_flt<float>(params, dst, false);
5835
0
            } break;
5836
0
        default:
5837
0
            {
5838
0
                GGML_ABORT("fatal error");
5839
0
            }
5840
0
    }
5841
0
}
5842
5843
// ggml_compute_forward_conv_transpose_1d
5844
5845
static void ggml_compute_forward_conv_transpose_1d_f16_f32(
5846
        const ggml_compute_params * params,
5847
0
              ggml_tensor * dst) {
5848
5849
0
    const ggml_tensor * src0 = dst->src[0];
5850
0
    const ggml_tensor * src1 = dst->src[1];
5851
5852
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
5853
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
5854
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
5855
5856
0
    GGML_TENSOR_BINARY_OP_LOCALS
5857
5858
0
    const int ith = params->ith;
5859
0
    const int nth = params->nth;
5860
5861
0
    const int nk = ne00*ne01*ne02;
5862
5863
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
5864
0
    GGML_ASSERT(nb10 == sizeof(float));
5865
5866
0
    if (ith == 0) {
5867
0
        memset(params->wdata, 0, params->wsize);
5868
5869
        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
5870
0
        {
5871
0
            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
5872
5873
0
            for (int64_t i02 = 0; i02 < ne02; i02++) {
5874
0
                for (int64_t i01 = 0; i01 < ne01; i01++) {
5875
0
                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
5876
0
                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
5877
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
5878
0
                        dst_data[i00*ne02 + i02] = src[i00];
5879
0
                    }
5880
0
                }
5881
0
            }
5882
0
        }
5883
5884
        // permute source data (src1) from (L x Cin) to (Cin x L)
5885
0
        {
5886
0
            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
5887
0
            ggml_fp16_t * dst_data = wdata;
5888
5889
0
            for (int64_t i11 = 0; i11 < ne11; i11++) {
5890
0
                const float * const src = (float *)((char *) src1->data + i11*nb11);
5891
0
                for (int64_t i10 = 0; i10 < ne10; i10++) {
5892
0
                    dst_data[i10*ne11 + i11] = GGML_CPU_FP32_TO_FP16(src[i10]);
5893
0
                }
5894
0
            }
5895
0
        }
5896
5897
        // need to zero dst since we are accumulating into it
5898
0
        memset(dst->data, 0, ggml_nbytes(dst));
5899
0
    }
5900
0
    ggml_barrier(params->threadpool);
5901
5902
0
    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
5903
5904
    // total rows in dst
5905
0
    const int nr = ne1;
5906
5907
    // rows per thread
5908
0
    const int dr = (nr + nth - 1)/nth;
5909
5910
    // row range for this thread
5911
0
    const int ir0 = dr*ith;
5912
0
    const int ir1 = MIN(ir0 + dr, nr);
5913
5914
0
    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
5915
0
    ggml_fp16_t * const wdata_src = wdata + nk;
5916
5917
0
    for (int i1 = ir0; i1 < ir1; i1++) {
5918
0
        float * dst_data = (float *)((char *) dst->data + i1*nb1);
5919
0
        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
5920
0
        for (int i10 = 0; i10 < ne10; i10++) {
5921
0
            const int i1n = i10*ne11;
5922
0
            for (int i00 = 0; i00 < ne00; i00++) {
5923
0
                float v = 0;
5924
0
                ggml_vec_dot_f16(ne02, &v, 0,
5925
0
                        (ggml_fp16_t *)    wdata_src + i1n, 0,
5926
0
                        (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
5927
0
                dst_data[i10*s0 + i00] += v;
5928
0
            }
5929
0
        }
5930
0
    }
5931
0
}
5932
5933
static void ggml_compute_forward_conv_transpose_1d_f32(
5934
        const ggml_compute_params * params,
5935
0
              ggml_tensor * dst) {
5936
5937
0
    const ggml_tensor * src0 = dst->src[0];
5938
0
    const ggml_tensor * src1 = dst->src[1];
5939
5940
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
5941
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
5942
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
5943
5944
0
    GGML_TENSOR_BINARY_OP_LOCALS
5945
5946
0
    const int ith = params->ith;
5947
0
    const int nth = params->nth;
5948
5949
0
    const int nk = ne00*ne01*ne02;
5950
5951
0
    GGML_ASSERT(nb00 == sizeof(float));
5952
0
    GGML_ASSERT(nb10 == sizeof(float));
5953
5954
0
    if (ith == 0) {
5955
0
        memset(params->wdata, 0, params->wsize);
5956
5957
        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
5958
0
        {
5959
0
            float * const wdata = (float *) params->wdata + 0;
5960
5961
0
            for (int64_t i02 = 0; i02 < ne02; i02++) {
5962
0
                for (int64_t i01 = 0; i01 < ne01; i01++) {
5963
0
                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
5964
0
                    float * dst_data = wdata + i01*ne00*ne02;
5965
0
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
5966
0
                        dst_data[i00*ne02 + i02] = src[i00];
5967
0
                    }
5968
0
                }
5969
0
            }
5970
0
        }
5971
5972
        // prepare source data (src1)
5973
0
        {
5974
0
            float * const wdata = (float *) params->wdata + nk;
5975
0
            float * dst_data = wdata;
5976
5977
0
            for (int64_t i11 = 0; i11 < ne11; i11++) {
5978
0
                const float * const src = (float *)((char *) src1->data + i11*nb11);
5979
0
                for (int64_t i10 = 0; i10 < ne10; i10++) {
5980
0
                    dst_data[i10*ne11 + i11] = src[i10];
5981
0
                }
5982
0
            }
5983
0
        }
5984
5985
        // need to zero dst since we are accumulating into it
5986
0
        memset(dst->data, 0, ggml_nbytes(dst));
5987
0
    }
5988
0
    ggml_barrier(params->threadpool);
5989
5990
0
    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
5991
5992
    // total rows in dst
5993
0
    const int nr = ne1;
5994
5995
    // rows per thread
5996
0
    const int dr = (nr + nth - 1)/nth;
5997
5998
    // row range for this thread
5999
0
    const int ir0 = dr*ith;
6000
0
    const int ir1 = MIN(ir0 + dr, nr);
6001
6002
0
    float * const wdata     = (float *) params->wdata + 0;
6003
0
    float * const wdata_src = wdata + nk;
6004
6005
0
    for (int i1 = ir0; i1 < ir1; i1++) {
6006
0
        float * dst_data = (float *)((char *) dst->data + i1*nb1);
6007
0
        float * wdata_kernel = wdata + i1*ne02*ne00;
6008
0
        for (int i10 = 0; i10 < ne10; i10++) {
6009
0
            const int i1n = i10*ne11;
6010
0
            for (int i00 = 0; i00 < ne00; i00++) {
6011
0
                float v = 0;
6012
0
                ggml_vec_dot_f32(ne02, &v, 0,
6013
0
                        wdata_src + i1n, 0,
6014
0
                        wdata_kernel + i00*ne02, 0, 1);
6015
0
                dst_data[i10*s0 + i00] += v;
6016
0
            }
6017
0
        }
6018
0
    }
6019
0
}
6020
6021
void ggml_compute_forward_conv_transpose_1d(
6022
        const ggml_compute_params * params,
6023
0
              ggml_tensor * dst) {
6024
6025
0
    const ggml_tensor * src0 = dst->src[0];
6026
6027
0
    switch (src0->type) {
6028
0
        case GGML_TYPE_F16:
6029
0
            {
6030
0
                ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
6031
0
            } break;
6032
0
        case GGML_TYPE_F32:
6033
0
            {
6034
0
                ggml_compute_forward_conv_transpose_1d_f32(params, dst);
6035
0
            } break;
6036
0
        default:
6037
0
            {
6038
0
                GGML_ABORT("fatal error");
6039
0
            }
6040
0
    }
6041
0
}
6042
6043
// ggml_compute_forward_im2col_f32
6044
// src0: kernel [OC, IC, KH, KW]
6045
// src1: image [N, IC, IH, IW]
6046
// dst:  result [N, OH, OW, IC*KH*KW]
6047
static void ggml_compute_forward_im2col_f32(
6048
        const ggml_compute_params * params,
6049
0
              ggml_tensor * dst) {
6050
6051
0
    const ggml_tensor * src0 = dst->src[0];
6052
0
    const ggml_tensor * src1 = dst->src[1];
6053
6054
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6055
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
6056
6057
0
    GGML_TENSOR_BINARY_OP_LOCALS;
6058
6059
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6060
0
    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6061
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6062
0
    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
6063
0
    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
6064
0
    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
6065
0
    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
6066
6067
0
    const int ith = params->ith;
6068
0
    const int nth = params->nth;
6069
6070
0
    const int64_t N  = is_2D ? ne13 : ne12;
6071
0
    const int64_t IC = is_2D ? ne12 : ne11;
6072
0
    const int64_t IH = is_2D ? ne11 : 1;
6073
0
    const int64_t IW = ne10;
6074
6075
0
    const int64_t KH = is_2D ? ne01 : 1;
6076
0
    const int64_t KW = ne00;
6077
6078
0
    const int64_t OH = is_2D ? ne2 : 1;
6079
0
    const int64_t OW = ne1;
6080
6081
0
    int ofs0 = is_2D ? nb13 : nb12;
6082
0
    int ofs1 = is_2D ? nb12 : nb11;
6083
6084
0
    GGML_ASSERT(nb10 == sizeof(float));
6085
6086
    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6087
0
    {
6088
0
        float * const wdata = (float *) dst->data;
6089
6090
0
        for (int64_t in = 0; in < N; in++) {
6091
0
            for (int64_t ioh = 0; ioh < OH; ioh++) { // OH == 1 when !is_2D
6092
0
                for (int64_t iow = 0; iow < OW; iow++) {
6093
0
                    for (int64_t iic = ith; iic < IC; iic += nth) {
6094
6095
                        // micro kernel
6096
0
                        float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6097
0
                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
6098
6099
0
                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // KH == 1 when !is_2D
6100
0
                            for (int64_t ikw = 0; ikw < KW; ikw++) {
6101
0
                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
6102
0
                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
6103
6104
0
                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6105
0
                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
6106
0
                                } else {
6107
0
                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
6108
0
                                }
6109
0
                            }
6110
0
                        }
6111
0
                    }
6112
0
                }
6113
0
            }
6114
0
        }
6115
0
    }
6116
0
}
6117
6118
6119
// ggml_compute_forward_im2col_f16
6120
// src0: kernel [OC, IC, KH, KW]
6121
// src1: image [N, IC, IH, IW]
6122
// dst:  result [N, OH, OW, IC*KH*KW]
6123
static void ggml_compute_forward_im2col_f16(
6124
        const ggml_compute_params * params,
6125
0
              ggml_tensor * dst) {
6126
6127
0
    const ggml_tensor * src0 = dst->src[0];
6128
0
    const ggml_tensor * src1 = dst->src[1];
6129
6130
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
6131
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6132
0
    GGML_ASSERT( dst->type == GGML_TYPE_F16);
6133
6134
0
    GGML_TENSOR_BINARY_OP_LOCALS;
6135
6136
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6137
0
    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6138
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6139
0
    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
6140
0
    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
6141
0
    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
6142
0
    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
6143
6144
0
    const int ith = params->ith;
6145
0
    const int nth = params->nth;
6146
6147
0
    const int64_t N  = is_2D ? ne13 : ne12;
6148
0
    const int64_t IC = is_2D ? ne12 : ne11;
6149
0
    const int64_t IH = is_2D ? ne11 : 1;
6150
0
    const int64_t IW = ne10;
6151
6152
0
    const int64_t KH = is_2D ? ne01 : 1;
6153
0
    const int64_t KW = ne00;
6154
6155
0
    const int64_t OH = is_2D ? ne2 : 1;
6156
0
    const int64_t OW = ne1;
6157
6158
0
    int ofs0 = is_2D ? nb13 : nb12;
6159
0
    int ofs1 = is_2D ? nb12 : nb11;
6160
6161
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6162
0
    GGML_ASSERT(nb10 == sizeof(float));
6163
6164
    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6165
0
    {
6166
0
        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
6167
6168
0
        for (int64_t in = 0; in < N; in++) {
6169
0
            for (int64_t ioh = 0; ioh < OH; ioh++) { // OH == 1 when !is_2D
6170
0
                for (int64_t iow = 0; iow < OW; iow++) {
6171
0
                    for (int64_t iic = ith; iic < IC; iic += nth) {
6172
6173
                        // micro kernel
6174
0
                        ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6175
0
                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
6176
6177
0
                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // KH == 1 when !is_2D
6178
0
                            for (int64_t ikw = 0; ikw < KW; ikw++) {
6179
0
                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
6180
0
                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
6181
6182
0
                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6183
0
                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
6184
0
                                } else {
6185
0
                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
6186
0
                                }
6187
0
                            }
6188
0
                        }
6189
0
                    }
6190
0
                }
6191
0
            }
6192
0
        }
6193
0
    }
6194
0
}
6195
6196
void ggml_compute_forward_im2col(
6197
        const ggml_compute_params * params,
6198
0
              ggml_tensor * dst) {
6199
0
    switch (dst->type) {
6200
0
        case GGML_TYPE_F16:
6201
0
            {
6202
0
                ggml_compute_forward_im2col_f16(params, dst);
6203
0
            } break;
6204
0
        case GGML_TYPE_F32:
6205
0
            {
6206
0
                ggml_compute_forward_im2col_f32(params, dst);
6207
0
            } break;
6208
0
        default:
6209
0
            {
6210
0
                GGML_ABORT("fatal error");
6211
0
            }
6212
0
    }
6213
0
}
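// Illustrative sketch (not part of ops.cpp): the im2col kernels above only consume the shapes already
// stored in dst; per dimension those shapes are expected to follow the usual convolution output size.
// A hypothetical helper stating that formula:
static int64_t im2col_out_size(int64_t in, int64_t kernel, int64_t stride, int64_t pad, int64_t dilation) {
    return (in + 2*pad - dilation*(kernel - 1) - 1)/stride + 1; // e.g. in = 5, kernel = 3, s = 1, p = 1, d = 1 -> 5
}
// Each dst row (one (in, ioh, iow) triple) then holds IC*KH*KW values:
//     dst[((in*OH + ioh)*OW + iow)*(IC*KH*KW) + iic*(KH*KW) + ikh*KW + ikw] = src1[iih, iiw] (or 0 when padded),
// with iih = ioh*s1 + ikh*d1 - p1 and iiw = iow*s0 + ikw*d0 - p0, exactly as computed above.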
6214
6215
// ggml_compute_forward_im2col_back_f32
6216
6217
void ggml_compute_forward_im2col_back_f32(
6218
        const ggml_compute_params * params,
6219
0
              ggml_tensor * dst) {
6220
6221
0
    const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
6222
0
    const ggml_tensor * src1 = dst->src[1]; // convolution kernel
6223
6224
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
6225
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6226
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
6227
6228
0
    GGML_TENSOR_BINARY_OP_LOCALS;
6229
6230
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6231
0
    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6232
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6233
0
    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
6234
0
    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
6235
0
    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
6236
0
    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
6237
6238
0
    const int ith = params->ith;
6239
0
    const int nth = params->nth;
6240
6241
0
    const int64_t N  = is_2D ? ne3 : ne2;
6242
0
    const int64_t IC = is_2D ? ne2 : ne1;
6243
0
    const int64_t IH = is_2D ? ne1 : 1;
6244
0
    const int64_t IW = ne0;
6245
6246
0
    const int64_t KH = is_2D ? ne11 : 1;
6247
0
    const int64_t KW = ne10;
6248
6249
0
    const int64_t OH = is_2D ? ne02 : 1;
6250
0
    const int64_t OW = ne01;
6251
6252
0
    int ofs0 = is_2D ? nb3 : nb2;
6253
0
    int ofs1 = is_2D ? nb2 : nb1;
6254
6255
0
    GGML_ASSERT(nb0  == sizeof(float));
6256
6257
    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6258
0
    {
6259
0
        float * const wdata = (float *) dst->data;
6260
6261
0
        for (int64_t in = 0; in < N; in++) {
6262
0
            for (int64_t iic = ith; iic < IC; iic += nth) {
6263
0
                for (int64_t iih = 0; iih < IH; iih++) {
6264
0
                    for (int64_t iiw = 0; iiw < IW; iiw++) {
6265
6266
                        // micro kernel
6267
0
                        float grad = 0.0f;
6268
0
                        for (int64_t ikh = 0; ikh < KH; ikh++) {
6269
0
                            for (int64_t ikw = 0; ikw < KW; ikw++) {
6270
                                // For s0 > 1 some values were skipped over in the forward pass.
6271
                                // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
6272
0
                                const int64_t tmpw = (iiw + p0 - ikw*d0);
6273
0
                                if (tmpw % s0 != 0) {
6274
0
                                    continue;
6275
0
                                }
6276
0
                                const int64_t iow = tmpw / s0;
6277
6278
                                // Equivalent logic as above except for s1.
6279
0
                                int64_t ioh;
6280
0
                                if (is_2D) {
6281
0
                                    const int64_t tmph = iih + p1 - ikh*d1;
6282
6283
0
                                    if (tmph % s1 != 0) {
6284
0
                                        continue;
6285
0
                                    }
6286
6287
0
                                    ioh = tmph / s1;
6288
0
                                } else {
6289
0
                                    ioh = 0;
6290
0
                                }
6291
6292
0
                                if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
6293
0
                                    continue;
6294
0
                                }
6295
6296
0
                                const float * const grad_in = (const float *) src0->data
6297
0
                                    + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6298
0
                                grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
6299
0
                            }
6300
0
                        }
6301
0
                        float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
6302
0
                        dst_data[iih*IW + iiw] = grad;
6303
0
                    }
6304
0
                }
6305
0
            }
6306
0
        }
6307
0
    }
6308
0
}
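// Illustrative sketch (not part of ops.cpp): the backward kernel above inverts the forward index map.
// For a fixed kernel tap ikw, an input column iiw only receives gradient if
//     iiw == iow*s0 + ikw*d0 - p0   for some integer iow in [0, OW),
// which is what the tmpw % s0 check enforces. A hypothetical stand-alone version of that test:
static bool im2col_back_contributes(int64_t iiw, int64_t ikw,
                                    int64_t s0, int64_t d0, int64_t p0,
                                    int64_t OW, int64_t * iow_out) {
    const int64_t tmpw = iiw + p0 - ikw*d0;
    if (tmpw % s0 != 0) {
        return false; // this position was skipped over in the forward pass
    }
    const int64_t iow = tmpw / s0;
    if (iow < 0 || iow >= OW) {
        return false; // outside the output range
    }
    *iow_out = iow;
    return true;
}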
6309
6310
6311
// ggml_compute_forward_im2col_3d_f16
6312
// src0: kernel [OC*IC, KD, KH, KW]
6313
// src1: image [N*IC, ID, IH, IW]
6314
// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
6315
static void ggml_compute_forward_im2col_3d_f16(
6316
        const ggml_compute_params * params,
6317
0
              ggml_tensor * dst) {
6318
6319
0
    const ggml_tensor * src0 = dst->src[0];
6320
0
    const ggml_tensor * src1 = dst->src[1];
6321
6322
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
6323
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6324
0
    GGML_ASSERT( dst->type == GGML_TYPE_F16);
6325
6326
0
    GGML_TENSOR_BINARY_OP_LOCALS;
6327
6328
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6329
0
    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6330
0
    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
6331
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
6332
0
    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
6333
0
    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
6334
0
    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
6335
0
    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
6336
0
    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
6337
0
    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
6338
6339
6340
0
    const int ith = params->ith;
6341
0
    const int nth = params->nth;
6342
6343
0
    const int64_t N  = ne13 / IC;
6344
0
    const int64_t ID = ne12;
6345
0
    const int64_t IH = ne11;
6346
0
    const int64_t IW = ne10;
6347
6348
0
    const int64_t OC = ne03 / IC;
6349
0
    GGML_UNUSED(OC);
6350
0
    const int64_t KD = ne02;
6351
0
    const int64_t KH = ne01;
6352
0
    const int64_t KW = ne00;
6353
6354
0
    const int64_t OD = ne3 / N;
6355
0
    const int64_t OH = ne2;
6356
0
    const int64_t OW = ne1;
6357
0
    const int64_t OH_OW = OH*OW;
6358
0
    const int64_t KD_KH_KW = KD*KH*KW;
6359
0
    const int64_t KH_KW = KH*KW;
6360
0
    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
6361
6362
0
    GGML_ASSERT(nb10 == sizeof(float));
6363
6364
    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
6365
0
    {
6366
0
        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
6367
6368
0
        for (int64_t in = 0; in < N; in++) {
6369
0
            for (int64_t iod = 0; iod < OD; iod++) {
6370
0
                for (int64_t ioh = 0; ioh < OH; ioh++) {
6371
0
                    for (int64_t iow = 0; iow < OW; iow++) {
6372
0
                        for (int64_t iic = ith; iic < IC; iic += nth) {
6373
6374
                            // micro kernel
6375
0
                            ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
6376
0
                            const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
6377
6378
0
                            for (int64_t ikd = 0; ikd < KD; ikd++) {
6379
0
                                for (int64_t ikh = 0; ikh < KH; ikh++) {
6380
0
                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
6381
0
                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
6382
0
                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
6383
0
                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
6384
6385
0
                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6386
0
                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
6387
0
                                        } else {
6388
0
                                            const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
6389
0
                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
6390
0
                                        }
6391
0
                                    }
6392
0
                                }
6393
0
                            }
6394
0
                        }
6395
0
                    }
6396
0
                }
6397
0
            }
6398
0
        }
6399
0
    }
6400
0
}
6401
6402
// ggml_compute_forward_im2col_3d_f32
6403
// src0: kernel [OC*IC, KD, KH, KW]
6404
// src1: image [N*IC, ID, IH, IW]
6405
// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
6406
static void ggml_compute_forward_im2col_3d_f32(
6407
        const ggml_compute_params * params,
6408
0
              ggml_tensor * dst) {
6409
6410
0
    const ggml_tensor * src0 = dst->src[0];
6411
0
    const ggml_tensor * src1 = dst->src[1];
6412
6413
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6414
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
6415
6416
0
    GGML_TENSOR_BINARY_OP_LOCALS;
6417
6418
0
    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6419
0
    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6420
0
    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
6421
0
    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
6422
0
    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
6423
0
    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
6424
0
    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
6425
0
    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
6426
0
    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
6427
0
    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
6428
6429
6430
0
    const int ith = params->ith;
6431
0
    const int nth = params->nth;
6432
6433
0
    const int64_t N  = ne13 / IC;
6434
0
    const int64_t ID = ne12;
6435
0
    const int64_t IH = ne11;
6436
0
    const int64_t IW = ne10;
6437
6438
0
    const int64_t OC = ne03 / IC;
6439
0
    GGML_UNUSED(OC);
6440
0
    const int64_t KD = ne02;
6441
0
    const int64_t KH = ne01;
6442
0
    const int64_t KW = ne00;
6443
6444
0
    const int64_t OD = ne3 / N;
6445
0
    const int64_t OH = ne2;
6446
0
    const int64_t OW = ne1;
6447
6448
0
    const int64_t OH_OW = OH*OW;
6449
0
    const int64_t KD_KH_KW = KD*KH*KW;
6450
0
    const int64_t KH_KW = KH*KW;
6451
0
    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
6452
6453
0
    GGML_ASSERT(nb10 == sizeof(float));
6454
6455
    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
6456
0
    {
6457
0
        float * const wdata = (float *) dst->data;
6458
6459
0
        for (int64_t in = 0; in < N; in++) {
6460
0
            for (int64_t iod = 0; iod < OD; iod++) {
6461
0
                for (int64_t ioh = 0; ioh < OH; ioh++) {
6462
0
                    for (int64_t iow = 0; iow < OW; iow++) {
6463
0
                        for (int64_t iic = ith; iic < IC; iic += nth) {
6464
6465
                            // micro kernel
6466
0
                            float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
6467
0
                            const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
6468
6469
0
                            for (int64_t ikd = 0; ikd < KD; ikd++) {
6470
0
                                for (int64_t ikh = 0; ikh < KH; ikh++) {
6471
0
                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
6472
0
                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
6473
0
                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
6474
0
                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
6475
6476
0
                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6477
0
                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
6478
0
                                        } else {
6479
0
                                            const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
6480
0
                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
6481
0
                                        }
6482
0
                                    }
6483
0
                                }
6484
0
                            }
6485
0
                        }
6486
0
                    }
6487
0
                }
6488
0
            }
6489
0
        }
6490
0
    }
6491
0
}
6492
6493
6494
void ggml_compute_forward_im2col_3d(
6495
        const ggml_compute_params * params,
6496
0
              ggml_tensor * dst) {
6497
0
    switch (dst->type) {
6498
0
        case GGML_TYPE_F16:
6499
0
            {
6500
0
                ggml_compute_forward_im2col_3d_f16(params, dst);
6501
0
            } break;
6502
0
        case GGML_TYPE_F32:
6503
0
            {
6504
0
                ggml_compute_forward_im2col_3d_f32(params, dst);
6505
0
            } break;
6506
0
        default:
6507
0
            {
6508
0
                GGML_ABORT("fatal error");
6509
0
            }
6510
0
    }
6511
0
}
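// Note (illustrative): the 3-D variants above use the same per-dimension output-size convention as the
// 2-D case (see the hypothetical im2col_out_size sketch earlier), applied to depth as well with s2/p2/d2.
// Within one dst row of length IC*KD*KH*KW the flattened offset is
//     iic*(KD*KH*KW) + ikd*(KH*KW) + ikh*KW + ikw,
// e.g. with IC = 3 and KD = KH = KW = 2 the tap (iic = 1, ikd = 1, ikh = 0, ikw = 1) lands at 1*8 + 1*4 + 0*2 + 1 = 13.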
6512
6513
static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
6514
0
                              void * a, void * b, float * c) {
6515
0
    const ggml_type_traits * traits = ggml_get_type_traits(type);
6516
0
    struct ggml_tensor src1 = {};
6517
0
    src1.type  = type;
6518
0
    src1.ne[0] = k;
6519
0
    src1.ne[1] = m;
6520
0
    src1.ne[2] = 1;
6521
0
    src1.ne[3] = 1;
6522
0
    src1.nb[0] = traits->type_size;
6523
0
    src1.nb[1] = k * traits->type_size;
6524
0
    src1.nb[2] = src1.nb[1];
6525
0
    src1.nb[3] = src1.nb[2];
6526
0
    src1.data  = a;
6527
6528
0
    struct ggml_tensor src0 = {};
6529
0
    src0.type  = type;
6530
0
    src0.ne[0] = k;
6531
0
    src0.ne[1] = n;
6532
0
    src0.ne[2] = 1;
6533
0
    src0.ne[3] = 1;
6534
0
    src0.nb[0] = traits->type_size;
6535
0
    src0.nb[1] = k * traits->type_size;
6536
0
    src0.nb[2] = src0.nb[1];
6537
0
    src0.nb[3] = src0.nb[2];
6538
0
    src0.data  = b;
6539
6540
0
    struct ggml_tensor dst = {};
6541
0
    dst.ne[0] = n;
6542
0
    dst.ne[1] = m;
6543
0
    dst.ne[2] = 1;
6544
0
    dst.ne[3] = 1;
6545
0
    dst.nb[0] = sizeof(float);
6546
0
    dst.nb[1] = n * sizeof(float);
6547
0
    dst.nb[2] = dst.nb[1];
6548
0
    dst.nb[3] = dst.nb[2];
6549
0
    dst.data  = c;
6550
0
    dst.src[0] = &src0;
6551
0
    dst.src[1] = &src1;
6552
6553
0
    ggml_compute_forward_mul_mat(params, &dst);
6554
0
}
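// Note (illustrative): as set up here, this helper produces a row-major f32 matrix c of shape m x n with
//     c[i*n + j] = dot(a_row_i, b_row_j),   a: m x k, b: n x k, both of type `type`,
// i.e. C = A * B^T. A hypothetical call site, mirroring the conv_2d path below (buffer names are
// assumptions):
//     ggml_call_mul_mat(GGML_TYPE_F16, params, /*m=*/patch_n, /*n=*/c_out, /*k=*/knl_n,
//                       im2col_buf, kernel_data, gemm_out);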
6555
6556
0
static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
6557
0
    return (coord + size) % size; // adding size keeps the result non-negative for coord >= -size
6558
0
}
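// Note (illustrative): ggml_wrap_around(-1, 5) == 4 and ggml_wrap_around(5, 5) == 0; the single added
// `size` keeps the C++ remainder non-negative as long as coord >= -size.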
6559
6560
// ggml_compute_forward_conv_2d
6561
6562
6563
static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
6564
                                              const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
6565
                                              const ggml_tensor *         src,     // [W, H, C, N]
6566
                                              ggml_tensor *               dst,     // [OW, OH, OC, N]
6567
0
                                              ggml_type                   kernel_type) {
6568
6569
0
    GGML_ASSERT(ggml_is_contiguous(kernel));
6570
0
    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
6571
0
    GGML_ASSERT(kernel->type == kernel_type);
6572
6573
0
    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
6574
6575
0
    const int32_t stride_x   = dst->op_params[0];
6576
0
    const int32_t stride_y   = dst->op_params[1];
6577
0
    const int32_t pad_x      = dst->op_params[2];
6578
0
    const int32_t pad_y      = dst->op_params[3];
6579
0
    const int32_t dilation_x = dst->op_params[4];
6580
0
    const int32_t dilation_y = dst->op_params[5];
6581
6582
0
    const int64_t c_in  = src->ne[2];
6583
0
    const int64_t c_out = kernel->ne[3];
6584
0
    GGML_ASSERT(c_in == kernel->ne[2]);
6585
6586
0
    const int64_t src_w = src->ne[0];
6587
0
    const int64_t src_h = src->ne[1];
6588
0
    const int64_t knl_w = kernel->ne[0];
6589
0
    const int64_t knl_h = kernel->ne[1];
6590
0
    const int64_t dst_w = dst->ne[0];
6591
0
    const int64_t dst_h = dst->ne[1];
6592
6593
0
    const float * src_data = (float *) src->data;
6594
0
    void  * knl_data       = kernel->data;
6595
0
    float * dst_data       = (float *) dst->data;
6596
6597
0
    const int64_t knl_n           = knl_w * knl_h * c_in;
6598
0
    const int64_t patch_total     = dst->ne[3] * dst_w * dst_h;
6599
6600
0
    const int64_t space_per_patch   = knl_n * traits->type_size + c_out * sizeof(float);
6601
0
    const int64_t batch_size        = params->wsize / space_per_patch;
6602
0
    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6603
0
    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
6604
6605
0
    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6606
6607
0
    void * tmp = params->wdata;
6608
6609
0
    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
6610
6611
0
        const int64_t patch_start_batch = batch_i * patches_per_batch;
6612
0
        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch,
6613
0
                                              patch_total);
6614
0
        const int64_t patch_n           = patch_end_batch - patch_start_batch;
6615
6616
0
        const int64_t patch_per_thread  = (patch_n + params->nth - 1) / params->nth;
6617
0
        const int64_t patch_start       = patch_start_batch + params->ith * patch_per_thread;
6618
0
        const int64_t patch_end         = std::min(patch_start + patch_per_thread, patch_end_batch);
6619
6620
        // im2col for a patch
6621
0
        for (int64_t p = patch_start; p < patch_end; ++p) {
6622
0
            const int64_t  batch_n     =  p / (dst_w * dst_h);
6623
0
            const int64_t  dst_y       = (p / dst_w) % dst_h;
6624
0
            const int64_t  dst_x       =  p % dst_w;
6625
6626
0
            const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
6627
0
            char *        dst_row  = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
6628
6629
0
            for (int64_t ic = 0; ic < c_in; ++ic) {
6630
0
                for (int64_t ky = 0; ky < knl_h; ++ky) {
6631
0
                    for (int64_t kx = 0; kx < knl_w; ++kx) {
6632
0
                        const int64_t sy = dst_y * stride_y + ky * dilation_y - pad_y;
6633
0
                        const int64_t sx = dst_x * stride_x + kx * dilation_x - pad_x;
6634
6635
0
                        int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
6636
6637
0
                        float src_val;
6638
0
                        if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6639
0
                            src_val = 0.0f;
6640
0
                        } else {
6641
0
                            const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
6642
0
                            src_val               = *src_ptr;
6643
0
                        }
6644
6645
0
                        char * element_ptr = dst_row + dst_idx * traits->type_size;
6646
0
                        if (kernel_type == GGML_TYPE_F32) {
6647
0
                            *(float *) element_ptr = src_val;
6648
0
                        } else if (kernel_type == GGML_TYPE_F16) {
6649
0
                            *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
6650
0
                        }
6651
0
                    }
6652
0
                }
6653
0
            }
6654
0
        }   // patches handled by this thread
6655
6656
0
        ggml_barrier(params->threadpool);
6657
6658
0
        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
6659
6660
0
        GGML_ASSERT((char *) (gemm_output + patch_n * c_out) <= (char *) tmp + params->wsize);
6661
6662
        // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
6663
0
        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
6664
6665
0
        ggml_barrier(params->threadpool);
6666
6667
6668
        // scatter the GEMM output, laid out as [N, OH, OW, OC], back into dst with layout [N, OC, OH, OW]
6669
0
        const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
6670
0
        const int64_t permute_start = params->ith * permute_per_thread;
6671
0
        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
6672
6673
0
        for (int64_t i = permute_start; i < permute_end; ++i) {
6674
0
            const int64_t p       = patch_start_batch + i;
6675
0
            const int64_t batch_n = p / (dst_w * dst_h);
6676
0
            const int64_t dst_y   = (p / dst_w) % dst_h;
6677
0
            const int64_t dst_x   = p % dst_w;
6678
6679
0
            for (int64_t oc = 0; oc < c_out; ++oc) {
6680
0
                const float value = gemm_output[i * c_out + oc];
6681
0
                float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
6682
0
                *dst_ptr = value;
6683
0
            }
6684
0
        }
6685
0
    }
6686
0
}
6687
6688
void ggml_compute_forward_conv_2d(
6689
        const ggml_compute_params * params,
6690
0
        ggml_tensor * dst) {
6691
6692
0
    const ggml_tensor * src0 = dst->src[0];
6693
0
    const ggml_tensor * src1 = dst->src[1];
6694
6695
0
    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
6696
0
}
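// Illustrative sketch (not part of ops.cpp): conv_2d above processes patches in batches, each batch doing
// im2col into params->wdata followed by one GEMM. Matching the sizing logic above, a hypothetical helper
// for the number of patches handled per batch:
static int64_t conv_2d_patches_per_batch(int64_t wsize, int64_t knl_n, int64_t c_out, size_t type_size) {
    // per patch: one im2col row of knl_n kernel-typed values plus c_out floats of GEMM output
    const int64_t space_per_patch = knl_n*(int64_t) type_size + c_out*(int64_t) sizeof(float);
    const int64_t batch_size      = wsize/space_per_patch;
    return batch_size > 8 ? (batch_size/8)*8 : batch_size; // round down to a multiple of 8 when possible
}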
6697
6698
// ggml_compute_forward_conv_3d
6699
6700
static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
6701
                                              const ggml_tensor *         kernel,
6702
                                              const ggml_tensor *         src,
6703
                                              ggml_tensor *               dst,
6704
0
                                              ggml_type                   kernel_type) {
6705
6706
0
    GGML_ASSERT(ggml_is_contiguous(kernel));
6707
0
    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
6708
0
    GGML_ASSERT(kernel->type == kernel_type);
6709
6710
0
    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
6711
6712
0
    const int32_t s0 = dst->op_params[0];
6713
0
    const int32_t s1 = dst->op_params[1];
6714
0
    const int32_t s2 = dst->op_params[2];
6715
0
    const int32_t p0 = dst->op_params[3];
6716
0
    const int32_t p1 = dst->op_params[4];
6717
0
    const int32_t p2 = dst->op_params[5];
6718
0
    const int32_t d0 = dst->op_params[6];
6719
0
    const int32_t d1 = dst->op_params[7];
6720
0
    const int32_t d2 = dst->op_params[8];
6721
0
    const int32_t c  = dst->op_params[9];
6722
0
    const int32_t n  = dst->op_params[10];
6723
0
    const int32_t oc = dst->op_params[11];
6724
6725
0
    const int64_t src_w = src->ne[0];
6726
0
    const int64_t src_h = src->ne[1];
6727
0
    const int64_t src_d = src->ne[2];
6728
0
    const int64_t knl_w = kernel->ne[0];
6729
0
    const int64_t knl_h = kernel->ne[1];
6730
0
    const int64_t knl_d = kernel->ne[2];
6731
0
    const int64_t dst_w = dst->ne[0];
6732
0
    const int64_t dst_h = dst->ne[1];
6733
0
    const int64_t dst_d = dst->ne[2];
6734
6735
0
    const float * src_data = (float *) src->data;
6736
0
    void  * knl_data       = kernel->data;
6737
0
    float * dst_data       = (float *) dst->data;
6738
6739
0
    const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
6740
0
    const int64_t knl_n_total       = knl_n_per_channel * c;
6741
0
    const int64_t patch_total       = n * dst_w * dst_h * dst_d;
6742
6743
0
    const int64_t space_per_patch   = knl_n_total * traits->type_size + oc * sizeof(float);
6744
0
    const int64_t batch_size        = params->wsize / space_per_patch;
6745
0
    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6746
0
    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
6747
6748
0
    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6749
6750
0
    void * tmp = params->wdata;
6751
6752
0
    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
6753
0
        const int64_t patch_start_batch = batch_i * patches_per_batch;
6754
0
        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch, patch_total);
6755
0
        const int64_t patch_n_in_batch  = patch_end_batch - patch_start_batch;
6756
6757
0
        const int64_t patch_per_thread  = (patch_n_in_batch + params->nth - 1) / params->nth;
6758
0
        const int64_t patch_start       = patch_start_batch + params->ith * patch_per_thread;
6759
0
        const int64_t patch_end         = std::min(patch_start + patch_per_thread, patch_end_batch);
6760
6761
0
        for (int64_t p = patch_start; p < patch_end; ++p) {
6762
0
            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
6763
0
            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
6764
0
            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
6765
0
            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
6766
0
            const int64_t dst_y      = p_in_depth / dst_w;
6767
0
            const int64_t dst_x      = p_in_depth % dst_w;
6768
6769
0
            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
6770
6771
0
            for (int64_t ic = 0; ic < c; ++ic) {
6772
0
                for (int64_t kz = 0; kz < knl_d; ++kz) {
6773
0
                    for (int64_t ky = 0; ky < knl_h; ++ky) {
6774
0
                        for (int64_t kx = 0; kx < knl_w; ++kx) {
6775
0
                            const int64_t sz = dst_z * s2 + kz * d2 - p2;
6776
0
                            const int64_t sy = dst_y * s1 + ky * d1 - p1;
6777
0
                            const int64_t sx = dst_x * s0 + kx * d0 - p0;
6778
6779
0
                            int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
6780
6781
0
                            float src_val;
6782
0
                            if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6783
0
                                src_val = 0.0f;
6784
0
                            } else {
6785
0
                                const int64_t cn_idx = batch_idx * c + ic;
6786
0
                                const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
6787
0
                                src_val = *src_ptr;
6788
0
                            }
6789
6790
0
                            char * element_ptr = dst_row + dst_idx * traits->type_size;
6791
0
                            if (kernel_type == GGML_TYPE_F32) {
6792
0
                                *(float *)element_ptr = src_val;
6793
0
                            } else if (kernel_type == GGML_TYPE_F16) {
6794
0
                                *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
6795
0
                            }
6796
0
                        }
6797
0
                    }
6798
0
                }
6799
0
            }
6800
0
        }
6801
6802
0
        ggml_barrier(params->threadpool);
6803
6804
0
        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
6805
0
        ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
6806
6807
0
        ggml_barrier(params->threadpool);
6808
6809
0
        const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
6810
0
        const int64_t permute_start = params->ith * permute_per_thread;
6811
0
        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
6812
6813
0
        for (int64_t i = permute_start; i < permute_end; ++i) {
6814
0
            const int64_t p = patch_start_batch + i;
6815
0
            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
6816
0
            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
6817
0
            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
6818
0
            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
6819
0
            const int64_t dst_y      = p_in_depth / dst_w;
6820
0
            const int64_t dst_x      = p_in_depth % dst_w;
6821
6822
0
            for (int64_t ioc = 0; ioc < oc; ++ioc) {
6823
0
                const float value = gemm_output[i * oc + ioc];
6824
0
                const int64_t ocn_idx = batch_idx * oc + ioc;
6825
0
                float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
6826
0
                *dst_ptr = value;
6827
0
            }
6828
0
        }
6829
0
    }
6830
0
}
6831
6832
void ggml_compute_forward_conv_3d(
6833
        const ggml_compute_params * params,
6834
0
        ggml_tensor * dst) {
6835
0
    const ggml_tensor * src0 = dst->src[0];
6836
0
    const ggml_tensor * src1 = dst->src[1];
6837
0
    ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
6838
0
}
6839
6840
// ggml_compute_forward_conv_transpose_2d
6841
6842
void ggml_compute_forward_conv_transpose_2d(
6843
        const ggml_compute_params * params,
6844
0
              ggml_tensor * dst) {
6845
6846
0
    const ggml_tensor * src0 = dst->src[0];
6847
0
    const ggml_tensor * src1 = dst->src[1];
6848
6849
0
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
6850
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
6851
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
6852
6853
0
    GGML_TENSOR_BINARY_OP_LOCALS
6854
6855
0
    const int ith = params->ith;
6856
0
    const int nth = params->nth;
6857
6858
0
    const int nk = ne00*ne01*ne02*ne03;
6859
6860
0
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6861
0
    GGML_ASSERT(nb10 == sizeof(float));
6862
6863
0
    if (ith == 0) {
6864
0
        memset(params->wdata, 0, params->wsize);
6865
6866
        // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
6867
0
        {
6868
0
            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
6869
6870
0
            for (int64_t i03 = 0; i03 < ne03; i03++) {
6871
0
                for (int64_t i02 = 0; i02 < ne02; i02++) {
6872
0
                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
6873
0
                    ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
6874
0
                    for (int64_t i01 = 0; i01 < ne01; i01++) {
6875
0
                        for (int64_t i00 = 0; i00 < ne00; i00++) {
6876
0
                            dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
6877
0
                        }
6878
0
                    }
6879
0
                }
6880
0
            }
6881
0
        }
6882
6883
        // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
6884
0
        {
6885
0
            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
6886
0
            for (int i12 = 0; i12 < ne12; i12++) {
6887
0
                for (int i11 = 0; i11 < ne11; i11++) {
6888
0
                    const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
6889
0
                    ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
6890
0
                    for (int i10 = 0; i10 < ne10; i10++) {
6891
0
                        dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
6892
0
                    }
6893
0
                }
6894
0
            }
6895
0
        }
6896
6897
0
        memset(dst->data, 0, ggml_nbytes(dst));
6898
0
    }
6899
0
    ggml_barrier(params->threadpool);
6900
6901
0
    const int32_t stride = ggml_get_op_params_i32(dst, 0);
6902
6903
    // total patches in dst
6904
0
    const int np = ne2;
6905
6906
    // patches per thread
6907
0
    const int dp = (np + nth - 1)/nth;
6908
6909
    // patch range for this thread
6910
0
    const int ip0 = dp*ith;
6911
0
    const int ip1 = MIN(ip0 + dp, np);
6912
6913
0
    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
6914
0
    ggml_fp16_t * const wdata_src = wdata + nk;
6915
6916
0
    for (int i2 = ip0; i2 < ip1; i2++) { // Cout
6917
0
        float * dst_data = (float *)((char *) dst->data + i2*nb2);
6918
0
        ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
6919
0
        for (int i11 = 0; i11 < ne11; i11++) {
6920
0
            for (int i10 = 0; i10 < ne10; i10++) {
6921
0
                const int i1n = i11*ne10*ne12 + i10*ne12;
6922
0
                for (int i01 = 0; i01 < ne01; i01++) {
6923
0
                    for (int i00 = 0; i00 < ne00; i00++) {
6924
0
                        float v = 0;
6925
0
                        ggml_vec_dot_f16(ne03, &v, 0,
6926
0
                                wdata_src + i1n, 0,
6927
0
                                wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
6928
0
                        dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
6929
0
                    }
6930
0
                }
6931
0
            }
6932
0
        }
6933
0
    }
6934
0
}
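// Note (illustrative): the accumulation above writes
//     dst_data[(i11*stride + i01)*ne0 + i10*stride + i00],
// so the transposed 2-D convolution fills OW = (ne10 - 1)*stride + ne00 and OH = (ne11 - 1)*stride + ne01,
// e.g. a 4x4 input with a 3x3 kernel and stride 2 produces a 9x9 output.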
6935
6936
// ggml_compute_forward_conv_2d_dw
6937
6938
struct ggml_conv_2d_dw_params {
6939
    int64_t channels;
6940
    int64_t batch;
6941
    int64_t src_w;
6942
    int64_t src_h;
6943
    int64_t dst_w;
6944
    int64_t dst_h;
6945
    int64_t knl_w;
6946
    int64_t knl_h;
6947
    int stride_x;
6948
    int stride_y;
6949
    int pad_x;
6950
    int pad_y;
6951
    int dilation_x;
6952
    int dilation_y;
6953
};
6954
6955
static void ggml_compute_forward_conv_2d_dw_cwhn(
6956
        const ggml_compute_params * params,
6957
        const ggml_tensor * src,
6958
        const ggml_tensor * kernel,
6959
        ggml_tensor * dst,
6960
0
        const ggml_conv_2d_dw_params & p) {
6961
6962
0
    const int64_t c = p.channels;
6963
0
    const float * knl_data = (const float *)kernel->data;
6964
6965
0
    const int64_t rows_total = p.dst_h * p.batch;
6966
0
    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
6967
0
    const int64_t row_start = params->ith * rows_per_thread;
6968
0
    const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
6969
6970
0
#ifdef GGML_SIMD
6971
    #if defined(__ARM_FEATURE_SVE)
6972
        const int64_t pkg_size = svcntw();
6973
    #else
6974
0
        const int64_t pkg_size = GGML_F32_EPR;
6975
0
    #endif
6976
0
    const int64_t pkg_count = c / pkg_size;
6977
0
    const int64_t c_pkg_end = pkg_count * pkg_size;
6978
#else
6979
    const int64_t c_pkg_end = 0;
6980
#endif
6981
6982
0
    for (int64_t row = row_start; row < row_end; ++row) {
6983
0
        const int64_t dst_y = row % p.dst_h;
6984
0
        const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;
6985
0
        for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
6986
0
            float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;
6987
0
            const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;
6988
0
            const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;
6989
6990
0
#ifdef GGML_SIMD
6991
            // Vectorized loop
6992
0
            for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
6993
0
                GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
6994
0
                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
6995
0
                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
6996
0
                    if (src_y < 0 || src_y >= p.src_h) {
6997
0
                        continue;
6998
0
                    }
6999
0
                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
7000
0
                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
7001
0
                        if (src_x < 0 || src_x >= p.src_w) {
7002
0
                            continue;
7003
0
                        }
7004
0
                        GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
7005
0
                        GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);
7006
0
                        sum = GGML_F32_VEC_FMA(sum, k, s);
7007
0
                    }
7008
0
                }
7009
0
                GGML_F32_VEC_STORE(dst_data + c_i, sum);
7010
0
            }
7011
0
#endif
7012
            // Scalar loop
7013
0
            for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
7014
0
                float sum = 0.0f;
7015
0
                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
7016
0
                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
7017
0
                    if (src_y < 0 || src_y >= p.src_h) {
7018
0
                        continue;
7019
0
                    }
7020
0
                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
7021
0
                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
7022
0
                        if (src_x < 0 || src_x >= p.src_w) {
7023
0
                            continue;
7024
0
                        }
7025
0
                        sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
7026
0
                             * src_data[(src_y * p.src_w + src_x) * c + c_i];
7027
0
                    }
7028
0
                }
7029
0
                dst_data[c_i] = sum;
7030
0
            }
7031
0
        }
7032
0
    }
7033
0
}
7034
7035
static void ggml_compute_forward_conv_2d_dw_whcn(
7036
        const ggml_compute_params * params,
7037
        const ggml_tensor * src,
7038
        const ggml_tensor * kernel,
7039
        ggml_tensor * dst,
7040
0
        const ggml_conv_2d_dw_params & p) {
7041
7042
0
    const int64_t n = p.channels * p.batch;
7043
0
    const int64_t per_thread = (n + params->nth - 1) / params->nth;
7044
0
    const int64_t start = params->ith * per_thread;
7045
0
    const int64_t end = MIN(start + per_thread, n);
7046
7047
0
    for (int64_t i = start; i < end; ++i) {
7048
0
        const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h;
7049
0
        const float * src_data = (const float *)src->data + i * p.src_w * p.src_h;
7050
0
        float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h;
7051
7052
0
        for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) {
7053
0
            for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
7054
7055
0
                float sum = 0.0f;
7056
0
                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
7057
0
                    const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
7058
0
                    if (src_y < 0 || src_y >= p.src_h) {
7059
0
                        continue;
7060
0
                    }
7061
0
                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
7062
0
                        const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
7063
0
                        if (src_x < 0 || src_x >= p.src_w) {
7064
0
                            continue;
7065
0
                        }
7066
0
                        sum += knl_data[knl_y * p.knl_w + knl_x]
7067
0
                             * src_data[src_y * p.src_w + src_x];
7068
0
                    }
7069
0
                }
7070
0
                dst_data[dst_y * p.dst_w + dst_x] = sum;
7071
0
            }
7072
0
        }
7073
0
    }
7074
0
}
7075
7076
void ggml_compute_forward_conv_2d_dw(
7077
        const ggml_compute_params * params,
7078
0
        ggml_tensor * dst) {
7079
7080
0
    const ggml_tensor * kernel = dst->src[0];
7081
0
    const ggml_tensor * src = dst->src[1];
7082
0
    ggml_conv_2d_dw_params p;
7083
0
    p.channels = src->ne[2];
7084
0
    p.batch = src->ne[3];
7085
0
    p.src_w = src->ne[0];
7086
0
    p.src_h = src->ne[1];
7087
0
    p.dst_w = dst->ne[0];
7088
0
    p.dst_h = dst->ne[1];
7089
0
    p.knl_w = kernel->ne[0];
7090
0
    p.knl_h = kernel->ne[1];
7091
0
    p.stride_x = dst->op_params[0];
7092
0
    p.stride_y = dst->op_params[1];
7093
0
    p.pad_x = dst->op_params[2];
7094
0
    p.pad_y = dst->op_params[3];
7095
0
    p.dilation_x = dst->op_params[4];
7096
0
    p.dilation_y = dst->op_params[5];
7097
7098
0
    GGML_ASSERT(kernel->ne[3] == p.channels);
7099
0
    GGML_ASSERT(dst->ne[3] == p.batch);
7100
7101
0
    if (ggml_is_contiguous(src)) {
7102
0
        ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p);
7103
0
    } else if (ggml_is_contiguous_channels(src)) {
7104
        // kernel should also have channels most contiguous in memory
7105
0
        GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]);
7106
0
        ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p);
7107
0
    } else {
7108
0
        GGML_ABORT("non-contiguous memory layout not supported");
7109
0
    }
7110
0
}
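// Note (illustrative): the dispatcher above picks the layout-specific kernel according to the indexing the
// two implementations use. The WHCN path reads one (channel, batch) plane at a time,
//     src_data[src_y*src_w + src_x],
// while the CWHN path keeps channels innermost,
//     src_data[(src_y*src_w + src_x)*c + c_i],
// which is what allows it to vectorize across channels with GGML_F32_VEC.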
7111
7112
// ggml_compute_forward_pool_1d_ksp
7113
static void ggml_compute_forward_pool_1d_ksp(
7114
        const ggml_compute_params * params,
7115
        const ggml_op_pool op,
7116
        const int k,
7117
        const int s,
7118
        const int p,
7119
0
        ggml_tensor * dst) {
7120
7121
0
    const ggml_tensor * src = dst->src[0];
7122
7123
0
    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
7124
7125
0
    if (params->ith != 0) {
7126
0
        return;
7127
0
    }
7128
7129
0
    const int64_t IW = src->ne[0];
7130
0
    const int64_t OW = dst->ne[0];
7131
7132
0
    const int64_t nr = ggml_nrows(src);
7133
7134
0
    for (int64_t ir = 0; ir < nr; ++ir) {
7135
0
        const char * srow_bytes =            (const char *) src->data + ir * src->nb[1];
7136
0
        float      * drow       = (float *) ((      char *) dst->data + ir * dst->nb[1]);
7137
7138
0
        for (int64_t ow = 0; ow < OW; ++ow) {
7139
0
            float res = 0;
7140
0
            switch (op) {
7141
0
                case GGML_OP_POOL_AVG: res = 0.0f;     break;
7142
0
                case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
7143
0
                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7144
0
            }
7145
7146
0
            int count = 0;
7147
0
            const int base = (int) ow * s - p;
7148
7149
0
            for (int ki = 0; ki < k; ++ki) {
7150
0
                const int j = base + ki;
7151
0
                if (j < 0 || j >= (int) IW) {
7152
0
                    continue;
7153
0
                }
7154
7155
0
                float v;
7156
0
                if (src->type == GGML_TYPE_F32) {
7157
0
                    v = ((const float *) srow_bytes)[j];
7158
0
                } else {
7159
0
                    v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
7160
0
                }
7161
7162
0
                switch (op) {
7163
0
                    case GGML_OP_POOL_AVG: res += v;                break;
7164
0
                    case GGML_OP_POOL_MAX: res =  std::max(v, res); break;
7165
0
                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7166
0
                }
7167
7168
0
                ++count;
7169
0
            }
7170
7171
0
            switch (op) {
7172
0
                case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
7173
0
                case GGML_OP_POOL_MAX:                                           break;
7174
0
                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7175
0
            }
7176
7177
0
            drow[ow] = res;
7178
0
        }
7179
0
    }
7180
0
}
7181
7182
// ggml_compute_forward_pool_1d
7183
7184
void ggml_compute_forward_pool_1d(
7185
        const ggml_compute_params * params,
7186
0
              ggml_tensor * dst) {
7187
7188
0
    const int32_t * opts = (const int32_t *)dst->op_params;
7189
0
    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7190
0
    const int k0 = opts[1];
7191
0
    const int s0 = opts[2];
7192
0
    const int p0 = opts[3];
7193
7194
0
    ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
7195
0
}
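// Note (illustrative): pool_1d reads its parameters from dst->op_params as [op, k0, s0, p0] and, per output
// element ow, reduces over the in-range taps j = ow*s0 - p0 + ki with ki in [0, k0). The pooled width the
// graph builder is expected to have set (an assumption, not computed here) is
//     OW = (IW + 2*p0 - k0)/s0 + 1        // e.g. IW = 7, k0 = 3, s0 = 2, p0 = 1 -> 4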
7196
7197
// ggml_compute_forward_pool_2d
7198
7199
void ggml_compute_forward_pool_2d(
7200
        const ggml_compute_params * params,
7201
0
        ggml_tensor * dst) {
7202
7203
0
    const ggml_tensor * src = dst->src[0];
7204
7205
0
    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
7206
7207
0
    if (params->ith != 0) {
7208
0
        return;
7209
0
    }
7210
7211
0
    const int32_t * opts = (const int32_t *)dst->op_params;
7212
7213
0
    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7214
0
    const int k0 = opts[1];
7215
0
    const int k1 = opts[2];
7216
0
    const int s0 = opts[3];
7217
0
    const int s1 = opts[4];
7218
0
    const int p0 = opts[5];
7219
0
    const int p1 = opts[6];
7220
0
    const char * cdata = (const char*)src->data;
7221
0
    const char * const data_end = cdata + ggml_nbytes(src);
7222
7223
0
    const int64_t px = dst->ne[0];
7224
0
    const int64_t py = dst->ne[1];
7225
0
    const int64_t pa = px * py;
7226
7227
0
    float * dplane = (float *)dst->data;
7228
7229
0
    const int ka = k0 * k1;
7230
0
    const int offset0 = -p0;
7231
0
    const int offset1 = -p1;
7232
7233
0
    while (cdata < data_end) {
7234
0
        for (int oy = 0; oy < py; ++oy) {
7235
0
            float * const drow = dplane + oy * px;
7236
0
            float * const out  = drow;
7237
7238
0
            for (int ox = 0; ox < px; ++ox) {
7239
0
                float res = 0;
7240
0
                switch (op) {
7241
0
                    case GGML_OP_POOL_AVG: res = 0;        break;
7242
0
                    case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
7243
0
                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7244
0
                }
7245
7246
0
                const int ix = offset0 + ox * s0;
7247
0
                const int iy = offset1 + oy * s1;
7248
7249
0
                for (int ky = 0; ky < k1; ++ky) {
7250
0
                    if (iy + ky < 0 || iy + ky >= src->ne[1]) {
7251
0
                        continue;
7252
0
                    }
7253
7254
0
                    const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
7255
0
                    for (int kx = 0; kx < k0; ++kx) {
7256
0
                        int j = ix + kx;
7257
0
                        if (j < 0 || j >= src->ne[0]) {
7258
0
                            continue;
7259
0
                        }
7260
7261
0
                        const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7262
0
                        switch (op) {
7263
0
                            case GGML_OP_POOL_AVG: res += srow_j;                break;
7264
0
                            case GGML_OP_POOL_MAX: res =  std::max(srow_j, res); break;
7265
0
                            case GGML_OP_POOL_COUNT:               GGML_ABORT("fatal error");
7266
0
                        }
7267
0
                    }
7268
0
                }
7269
0
                switch (op) {
7270
0
                    case GGML_OP_POOL_AVG:           res /= ka; break;
7271
0
                    case GGML_OP_POOL_MAX:                      break;
7272
0
                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7273
0
                }
7274
7275
0
                out[ox] = res;
7276
0
            }
7277
0
        }
7278
7279
0
        cdata  += src->nb[2];
7280
0
        dplane += pa;
7281
0
    }
7282
0
}
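// Note: unlike the 1-D path above, which divides an average by the number of in-range taps, pool_2d divides
// by the full kernel area ka = k0*k1 even when part of the window falls outside the input, and it walks the
// source and destination plane by plane via src->nb[2] and the dst plane size pa = px*py.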
7283
7284
// ggml_compute_forward_pool_2d_back
7285
7286
void ggml_compute_forward_pool_2d_back(
7287
        const ggml_compute_params * params,
7288
0
        ggml_tensor * dst) {
7289
7290
0
    const ggml_tensor * src  = dst->src[0];
7291
0
    const ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
7292
7293
0
    assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
7294
7295
0
    if (params->ith != 0) {
7296
0
        return;
7297
0
    }
7298
7299
0
    const int32_t * opts = (const int32_t *)dst->op_params;
7300
0
    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7301
0
    const int k0 = opts[1];
7302
0
    const int k1 = opts[2];
7303
0
    const int s0 = opts[3];
7304
0
    const int s1 = opts[4];
7305
0
    const int p0 = opts[5];
7306
0
    const int p1 = opts[6];
7307
7308
0
    char       * cdata  = (char       *) dst->data;
7309
0
    const char * cdataf = (const char *) dstf->data;
7310
0
    const char * const data_end = cdata + ggml_nbytes(dst);
7311
7312
0
    GGML_ASSERT(params->ith == 0);
7313
0
    memset(cdata, 0, ggml_nbytes(dst));
7314
7315
0
    const int64_t px = src->ne[0];
7316
0
    const int64_t py = src->ne[1];
7317
0
    const int64_t pa = px * py;
7318
7319
0
    const float * splane = (const float *) src->data;
7320
7321
0
    const int ka = k0 * k1;
7322
0
    const int offset0 = -p0;
7323
0
    const int offset1 = -p1;
7324
7325
0
    while (cdata < data_end) {
7326
0
        for (int oy = 0; oy < py; ++oy) {
7327
0
            const float * const srow = splane + oy * px;
7328
0
            for (int ox = 0; ox < px; ++ox) {
7329
0
                const float grad0 = srow[ox];
7330
7331
0
                const int ix = offset0 + ox * s0;
7332
0
                const int iy = offset1 + oy * s1;
7333
7334
0
                if (op == GGML_OP_POOL_MAX) {
7335
0
                    float maxval = -FLT_MAX;
7336
0
                    int kxmax = -1;
7337
0
                    int kymax = -1;
7338
7339
0
                    for (int ky = 0; ky < k1; ++ky) {
7340
0
                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
7341
0
                            continue;
7342
0
                        }
7343
0
                        const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
7344
0
                        for (int kx = 0; kx < k0; ++kx) {
7345
0
                            int j = ix + kx;
7346
0
                            if (j < 0 || j >= dst->ne[0]) {
7347
0
                                continue;
7348
0
                            }
7349
7350
0
                            const float val = dst->type == GGML_TYPE_F32 ?
7351
0
                                ((const float *) drowf)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
7352
0
                            if (val <= maxval) {
7353
0
                                continue;
7354
0
                            }
7355
7356
0
                            maxval = val;
7357
0
                            kxmax = kx;
7358
0
                            kymax = ky;
7359
0
                        }
7360
0
                    }
7361
7362
0
                    if (kxmax == -1 || kymax == -1) {
7363
0
                        continue;
7364
0
                    }
7365
7366
0
                    void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
7367
0
                    const int j = ix + kxmax;
7368
0
                    if (dst->type == GGML_TYPE_F32) {
7369
0
                        ((float *) drow)[j] += grad0;
7370
0
                    } else {
7371
0
                        ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad0 + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
7372
0
                    }
7373
0
                } else if (op == GGML_OP_POOL_AVG) {
7374
0
                    const float grad = grad0 / ka;
7375
7376
0
                    for (int ky = 0; ky < k1; ++ky) {
7377
0
                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
7378
0
                            continue;
7379
0
                        }
7380
0
                        void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
7381
0
                        for (int kx = 0; kx < k0; ++kx) {
7382
0
                            int j = ix + kx;
7383
0
                            if (j < 0 || j >= dst->ne[0]) {
7384
0
                                continue;
7385
0
                            }
7386
7387
0
                            if (dst->type == GGML_TYPE_F32) {
7388
0
                                ((float *) drow)[j] += grad;
7389
0
                            } else {
7390
0
                                ((ggml_fp16_t *) drow)[j] += GGML_CPU_FP32_TO_FP16(grad);
7391
0
                            }
7392
0
                        }
7393
0
                    }
7394
0
                } else {
7395
0
                    GGML_ASSERT(false);
7396
0
                }
7397
0
            }
7398
0
        }
7399
7400
0
        cdata  += dst->nb[2];
7401
0
        cdataf += dst->nb[2];
7402
0
        splane += pa;
7403
0
    }
7404
0
}
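// Illustrative sketch (not part of ops.cpp): a 1D reference of the GGML_OP_POOL_MAX backward
// rule implemented above — each output gradient is routed only to the argmax position of its
// window. pool1d_max_backward_ref is a hypothetical helper; dx is assumed zero-initialized by
// the caller, mirroring the memset() of the gradient tensor above.
static void pool1d_max_backward_ref(const float * x, const float * dout, float * dx,
                                    int64_t n, int64_t n_out, int k, int s, int p) {
    for (int64_t o = 0; o < n_out; ++o) {
        const int64_t ix = o*s - p;
        int64_t jmax = -1;
        float   vmax = -FLT_MAX;
        for (int kx = 0; kx < k; ++kx) {
            const int64_t j = ix + kx;
            if (j < 0 || j >= n) continue;
            if (x[j] > vmax) { vmax = x[j]; jmax = j; }
        }
        if (jmax >= 0) {
            dx[jmax] += dout[o]; // only the window maximum receives the gradient
        }
    }
}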
7405
7406
// ggml_compute_forward_upscale
7407
7408
static void ggml_compute_forward_upscale_f32(
7409
    const ggml_compute_params * params,
7410
0
    ggml_tensor * dst) {
7411
7412
0
    const ggml_tensor * src0 = dst->src[0];
7413
7414
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
7415
7416
0
    const int ith = params->ith;
7417
0
    const int nth = params->nth;
7418
7419
0
    GGML_TENSOR_UNARY_OP_LOCALS
7420
7421
0
    float sf0 = (float)ne0/src0->ne[0];
7422
0
    float sf1 = (float)ne1/src0->ne[1];
7423
0
    float sf2 = (float)ne2/src0->ne[2];
7424
0
    float sf3 = (float)ne3/src0->ne[3];
7425
0
    float pixel_offset = 0.5f;
7426
7427
0
    const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
7428
0
    const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
7429
7430
0
    if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7431
0
        pixel_offset = 0.0f;
7432
0
        sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
7433
0
        sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
7434
0
    }
7435
7436
0
    if (mode == GGML_SCALE_MODE_NEAREST) {
7437
0
        for (int64_t i3 = 0; i3 < ne3; i3++) {
7438
0
            const int64_t i03 = i3 / sf3;
7439
0
            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7440
0
                const int64_t i02 = i2 / sf2;
7441
0
                for (int64_t i1 = 0; i1 < ne1; i1++) {
7442
0
                    const int64_t i01 = i1 / sf1;
7443
0
                    for (int64_t i0 = 0; i0 < ne0; i0++) {
7444
0
                        const int64_t i00 = i0 / sf0;
7445
7446
0
                        const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7447
0
                              float * y = (float *)((char *)  dst->data +  i0*nb0  +  i1*nb1  +  i2*nb2  +  i3*nb3);
7448
7449
0
                        *y = *x;
7450
0
                    }
7451
0
                }
7452
0
            }
7453
0
        }
7454
0
    } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
7455
        // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
7456
        // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
7457
0
        auto triangle_filter = [](float x) -> float {
7458
0
            return std::max(1.0f - fabsf(x), 0.0f);
7459
0
        };
7460
7461
        // support and invscale, minimum 1 pixel for bilinear
7462
0
        const float support1  = std::max(1.0f, 1.0f / sf1);
7463
0
        const float invscale1 = 1.0f / support1;
7464
0
        const float support0  = std::max(1.0f, 1.0f / sf0);
7465
0
        const float invscale0 = 1.0f / support0;
7466
7467
0
        for (int64_t i3 = 0; i3 < ne3; i3++) {
7468
0
            const int64_t i03 = i3 / sf3;
7469
0
            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7470
0
                const int64_t i02 = i2 / sf2;
7471
0
                for (int64_t i1 = 0; i1 < ne1; i1++) {
7472
0
                    const float y = ((float) i1 + pixel_offset) / sf1;
7473
0
                    for (int64_t i0 = 0; i0 < ne0; i0++) {
7474
0
                        const float x = ((float) i0 + pixel_offset) / sf0;
7475
7476
                        // the range of source pixels that contribute
7477
0
                        const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
7478
0
                        const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
7479
0
                        const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
7480
0
                        const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
7481
7482
                        // bilinear filter with antialiasing
7483
0
                        float val = 0.0f;
7484
0
                        float total_weight = 0.0f;
7485
7486
0
                        for (int64_t sy = y_min; sy < y_max; sy++) {
7487
0
                            const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
7488
7489
0
                            for (int64_t sx = x_min; sx < x_max; sx++) {
7490
0
                                const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
7491
0
                                const float weight = weight_x * weight_y;
7492
7493
0
                                if (weight <= 0.0f) {
7494
0
                                    continue;
7495
0
                                }
7496
7497
0
                                const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
7498
0
                                val += pixel * weight;
7499
0
                                total_weight += weight;
7500
0
                            }
7501
0
                        }
7502
7503
0
                        if (total_weight > 0.0f) {
7504
0
                            val /= total_weight;
7505
0
                        }
7506
7507
0
                        float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7508
0
                        *dst_ptr = val;
7509
0
                    }
7510
0
                }
7511
0
            }
7512
0
        }
7513
0
    } else if (mode == GGML_SCALE_MODE_BILINEAR) {
7514
0
        for (int64_t i3 = 0; i3 < ne3; i3++) {
7515
0
            const int64_t i03 = i3 / sf3;
7516
0
            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7517
0
                const int64_t i02 = i2 / sf2;
7518
0
                for (int64_t i1 = 0; i1 < ne1; i1++) {
7519
0
                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
7520
0
                    int64_t y0 = (int64_t)floorf(y);
7521
0
                    int64_t y1 = y0 + 1;
7522
7523
0
                    y0 = std::max(int64_t(0), std::min(y0, ne01 - 1));
7524
0
                    y1 = std::max(int64_t(0), std::min(y1, ne01 - 1));
7525
7526
0
                    float dy = y - (float)y0;
7527
0
                    dy = std::max(0.0f, std::min(dy, 1.0f));
7528
7529
0
                    for (int64_t i0 = 0; i0 < ne0; i0++) {
7530
0
                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
7531
0
                        int64_t x0 = (int64_t)floorf(x);
7532
0
                        int64_t x1 = x0 + 1;
7533
7534
0
                        x0 = std::max(int64_t(0), std::min(x0, ne00 - 1));
7535
0
                        x1 = std::max(int64_t(0), std::min(x1, ne00 - 1));
7536
7537
0
                        float dx = x - (float)x0;
7538
0
                        dx = std::max(0.0f, std::min(dx, 1.0f));
7539
7540
                        // fetch the four surrounding pixel values and interpolate
7541
0
                        const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
7542
0
                        const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
7543
0
                        const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
7544
0
                        const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
7545
7546
0
                        const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
7547
7548
0
                        float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7549
0
                        *y_dst = val;
7550
0
                    }
7551
0
                }
7552
0
            }
7553
0
        }
7554
0
    } else if (mode == GGML_SCALE_MODE_BICUBIC) {
7555
        // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
7556
0
        const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
7557
0
        auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
7558
0
        auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
7559
0
        auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
7560
0
            const float w0 = weight2(x + 1);
7561
0
            const float w1 = weight1(x + 0);
7562
0
            const float w2 = weight1(1 - x);
7563
0
            const float w3 = weight2(2 - x);
7564
0
            return p0*w0 + p1*w1 + p2*w2 + p3*w3;
7565
0
        };
7566
7567
0
        for (int64_t i3 = 0; i3 < ne3; i3++) {
7568
0
            const int64_t i03 = i3 / sf3;
7569
0
            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7570
0
                const int64_t i02 = i2 / sf2;
7571
0
                for (int64_t i1 = 0; i1 < ne1; i1++) {
7572
0
                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
7573
0
                    const int64_t y0 = (int64_t)floorf(y);
7574
0
                    const float dy = y - (float)y0;
7575
7576
0
                    for (int64_t i0 = 0; i0 < ne0; i0++) {
7577
0
                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
7578
0
                        const int64_t x0 = (int64_t)floorf(x);
7579
0
                        const float dx = x - (float)x0;
7580
7581
0
                        auto p = [=](int64_t x_off, int64_t y_off) -> float {
7582
0
                            int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
7583
0
                            int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
7584
0
                            return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7585
0
                        };
7586
7587
0
                        const float val = bicubic(
7588
0
                            bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
7589
0
                            bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
7590
0
                            bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
7591
0
                            bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
7592
7593
0
                        float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7594
0
                        *y_dst = val;
7595
0
                    }
7596
0
                }
7597
0
            }
7598
0
        }
7599
0
    } else {
7600
0
        GGML_ABORT("unsupported upscale mode");
7601
0
    }
7602
0
}
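// Illustrative sketch (not part of ops.cpp): a 1D version of the bicubic convolution kernel
// used in the GGML_SCALE_MODE_BICUBIC branch above (alpha = -0.75, as in PyTorch). The four
// weights w0..w3 always sum to 1, so interpolating a constant signal returns the same constant.
// cubic_interp_1d_ref is a hypothetical helper; p0..p3 are the samples at x0-1, x0, x0+1, x0+2
// and t = x - x0 is the fractional offset in [0, 1).
static float cubic_interp_1d_ref(float p0, float p1, float p2, float p3, float t) {
    const float a = -0.75f;
    auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };          // |x| <= 1
    auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };    // 1 < |x| < 2
    const float w0 = weight2(t + 1);
    const float w1 = weight1(t + 0);
    const float w2 = weight1(1 - t);
    const float w3 = weight2(2 - t);
    // partition of unity: w0 + w1 + w2 + w3 == 1 for any t in [0, 1)
    return p0*w0 + p1*w1 + p2*w2 + p3*w3;
}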
7603
7604
void ggml_compute_forward_upscale(
7605
    const ggml_compute_params * params,
7606
0
    ggml_tensor * dst) {
7607
7608
0
    const ggml_tensor * src0 = dst->src[0];
7609
7610
0
    switch (src0->type) {
7611
0
        case GGML_TYPE_F32:
7612
0
            {
7613
0
                ggml_compute_forward_upscale_f32(params, dst);
7614
0
            } break;
7615
0
        default:
7616
0
            {
7617
0
                GGML_ABORT("fatal error");
7618
0
            }
7619
0
    }
7620
0
}
7621
7622
7623
// ggml_compute_forward_pad
7624
7625
template<bool circular_t>
7626
static void ggml_compute_forward_pad_f32(
7627
    const ggml_compute_params * params,
7628
0
          ggml_tensor * dst) {
7629
7630
0
    const ggml_tensor * src0 = dst->src[0];
7631
7632
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
7633
0
    GGML_ASSERT( dst->nb[0] == sizeof(float));
7634
7635
0
    const int ith = params->ith;
7636
0
    const int nth = params->nth;
7637
7638
0
    GGML_TENSOR_UNARY_OP_LOCALS
7639
7640
0
    float * dst_ptr = (float *) dst->data;
7641
0
    const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
7642
0
    const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
7643
0
    const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
7644
0
    const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
7645
0
    const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
7646
0
    const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
7647
0
    const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
7648
0
    const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
7649
7650
    // TODO: optimize
7651
7652
0
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
7653
0
        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
7654
0
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
7655
0
                for (int64_t i3 = 0; i3 < ne3; ++i3) {
7656
                    // circular means wrap around on a torus, so x and y loop around
7657
0
                    if constexpr (circular_t) {
7658
0
                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7659
0
                        const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
7660
0
                        const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
7661
0
                        const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
7662
0
                        const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
7663
7664
0
                        const int64_t src_idx =
7665
0
                            src_i3*nb03 +
7666
0
                            src_i2*nb02 +
7667
0
                            src_i1*nb01 +
7668
0
                            src_i0*nb00;
7669
7670
0
                        const float * src_ptr = (const float *)((char *) src0->data + src_idx);
7671
0
                        dst_ptr[dst_idx] = *src_ptr;
7672
0
                    } else {
7673
0
                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7674
0
                        if ((i0 >= lp0 && i0 < ne0 - rp0) \
7675
0
                            && (i1 >= lp1 && i1 < ne1 - rp1) \
7676
0
                            && (i2 >= lp2 && i2 < ne2 - rp2) \
7677
0
                            && (i3 >= lp3 && i3 < ne3 - rp3)) {
7678
0
                            const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
7679
0
                            const float * src_ptr = (const float *)((char *) src0->data + src_idx);
7680
0
                            dst_ptr[dst_idx] = *src_ptr;
7681
0
                        } else {
7682
0
                            dst_ptr[dst_idx] = 0;
7683
0
                        }
7684
0
                    }
7685
0
                }
7686
0
            }
7687
0
        }
7688
0
    }
7689
0
}
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_pad_f32<true>(ggml_compute_params const*, ggml_tensor*)
Unexecuted instantiation: ops.cpp:void ggml_compute_forward_pad_f32<false>(ggml_compute_params const*, ggml_tensor*)
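// Illustrative sketch (not part of ops.cpp): the index mapping used by the circular branch
// above — a destination index is shifted back by the left padding and then wrapped into
// [0, ne), which is what ggml_wrap_around() is assumed to do for the per-dimension source
// indices. wrap_index_ref is a hypothetical helper.
static int64_t wrap_index_ref(int64_t i, int64_t ne) {
    // ((i % ne) + ne) % ne keeps the result in [0, ne) even for negative i
    return ((i % ne) + ne) % ne;
}
// example: with ne00 = 4 and lp0 = 2, destination columns 0..5 read source columns
//   wrap_index_ref(i0 - 2, 4) == 2, 3, 0, 1, 2, 3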
7690
7691
7692
void ggml_compute_forward_pad(
7693
    const ggml_compute_params * params,
7694
0
    ggml_tensor * dst) {
7695
0
    const ggml_tensor * src0 = dst->src[0];
7696
0
    const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
7697
0
    switch (src0->type) {
7698
0
        case GGML_TYPE_F32:
7699
0
            {
7700
0
                if (circular) {
7701
0
                    ggml_compute_forward_pad_f32<true>(params, dst);
7702
0
                } else {
7703
0
                    ggml_compute_forward_pad_f32<false>(params, dst);
7704
0
                }
7705
0
            } break;
7706
0
        default:
7707
0
            {
7708
0
                GGML_ABORT("fatal error");
7709
0
            }
7710
0
    }
7711
0
}
7712
7713
// ggml_compute_forward_pad_reflect_1d
7714
7715
void ggml_compute_forward_pad_reflect_1d(
7716
        const ggml_compute_params * params,
7717
0
              ggml_tensor * dst) {
7718
7719
0
    const ggml_tensor * src0 = dst->src[0];
7720
7721
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
7722
0
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
7723
7724
0
    const int ith = params->ith;
7725
0
    const int nth = params->nth;
7726
7727
0
    const int32_t * opts = (const int32_t *) dst->op_params;
7728
0
    const int p0 = opts[0];
7729
0
    const int p1 = opts[1];
7730
7731
0
    GGML_TENSOR_UNARY_OP_LOCALS
7732
7733
0
    for (int64_t i3 = 0; i3 < ne3; i3++) {
7734
0
        for (int64_t i2 = 0; i2 < ne2; i2++) {
7735
0
            for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
7736
0
                float * left  = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 +         p0*nb0);
7737
0
                float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
7738
7739
0
                ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
7740
7741
0
                for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0];   }
7742
0
                for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
7743
0
            }
7744
0
        }
7745
0
    }
7746
0
}
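// Illustrative sketch (not part of ops.cpp): a tiny numeric example of the reflect-1d loops
// above. After the row [a, b, c, d] is copied to dst starting at offset p0, the two loops
// mirror the interior samples (excluding the edge sample itself) into the padding.
// With p0 = 2 and p1 = 2 the padded row becomes:
//   [c, b, a, b, c, d, c, b]
//    ^^^^  left[-i0] = left[i0]      ^^^^  right[i0] = right[-i0]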
7747
7748
// ggml_compute_forward_roll
7749
7750
0
static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
7751
0
    if (i < 0) {
7752
0
        return i + ne;
7753
0
    } else if (i >= ne) {
7754
0
        return i - ne;
7755
0
    }
7756
0
    return i;
7757
0
}
7758
7759
static void ggml_compute_forward_roll_f32(
7760
        const ggml_compute_params * params,
7761
0
        ggml_tensor * dst) {
7762
7763
0
    const ggml_tensor * src0 = dst->src[0];
7764
0
    const float * src_data = (const float *) src0->data;
7765
0
    float * dst_data = (float *) dst->data;
7766
7767
0
    GGML_TENSOR_UNARY_OP_LOCALS
7768
7769
0
    const int s0 = ggml_get_op_params_i32(dst, 0);
7770
0
    const int s1 = ggml_get_op_params_i32(dst, 1);
7771
0
    const int s2 = ggml_get_op_params_i32(dst, 2);
7772
0
    const int s3 = ggml_get_op_params_i32(dst, 3);
7773
7774
0
    const int64_t total = ne1 * ne2 * ne3;
7775
0
    const int64_t per_thread = (total + params->nth) / params->nth;
7776
0
    const int64_t start = params->ith * per_thread;
7777
0
    const int64_t end   = std::min(start + per_thread, total);
7778
7779
0
    for (int64_t i = start; i < end; ++i) {
7780
0
        const int64_t i1 = i % ne1;
7781
0
        const int64_t i2 = (i / ne1) % ne2;
7782
0
        const int64_t i3 = i / (ne2 * ne1);
7783
0
        float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
7784
7785
0
        const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
7786
0
        const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
7787
0
        const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
7788
0
        const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
7789
7790
0
        const int64_t s = ggml_wrap_index(-s0, ne00);
7791
0
        const int64_t n = ne00 - s;
7792
0
        ggml_vec_cpy_f32(n, dst_row,     src_row + s);
7793
0
        ggml_vec_cpy_f32(s, dst_row + n, src_row);
7794
0
    }
7795
0
}
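// Illustrative sketch (not part of ops.cpp): the innermost dimension is rolled with two
// contiguous copies, exactly as the two ggml_vec_cpy_f32 calls above — for a shift s0 the row
// is split at s = wrap(-s0, ne00) and the two pieces are swapped. roll_row_ref is a
// hypothetical helper and assumes |s0| < ne00, the range ggml_wrap_index handles.
static void roll_row_ref(const float * src, float * dst, int64_t ne00, int64_t s0) {
    const int64_t s = ggml_wrap_index(-s0, ne00);
    const int64_t n = ne00 - s;
    std::copy(src + s, src + ne00, dst);     // first n elements come from the tail
    std::copy(src,     src + s,    dst + n); // remaining s elements come from the head
}
// example: src = [0, 1, 2, 3, 4], s0 = 2  ->  dst = [3, 4, 0, 1, 2]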
7796
7797
void ggml_compute_forward_roll(
7798
        const ggml_compute_params * params,
7799
0
        ggml_tensor * dst) {
7800
7801
0
    const ggml_tensor * src0 = dst->src[0];
7802
7803
0
    switch (src0->type) {
7804
0
        case GGML_TYPE_F32:
7805
0
            {
7806
0
                ggml_compute_forward_roll_f32(params, dst);
7807
0
            } break;
7808
0
        default:
7809
0
            {
7810
0
                GGML_ABORT("fatal error");
7811
0
            }
7812
0
    }
7813
0
}
7814
7815
// ggml_compute_forward_arange
7816
7817
static void ggml_compute_forward_arange_f32(
7818
    const ggml_compute_params * params,
7819
0
    ggml_tensor * dst) {
7820
7821
0
    GGML_ASSERT(dst->nb[0] == sizeof(float));
7822
7823
0
    const int ith = params->ith;
7824
0
    const int nth = params->nth;
7825
7826
0
    const float start = ggml_get_op_params_f32(dst, 0);
7827
0
    const float stop  = ggml_get_op_params_f32(dst, 1);
7828
0
    const float step  = ggml_get_op_params_f32(dst, 2);
7829
7830
0
    const int64_t steps = (int64_t) ceilf((stop - start) / step);
7831
7832
0
    GGML_ASSERT(ggml_nelements(dst) == steps);
7833
7834
0
    for (int64_t i = ith; i < steps; i+= nth) {
7835
0
        float value = start + step * i;
7836
0
        ((float *)dst->data)[i] = value;
7837
0
    }
7838
0
}
7839
7840
void ggml_compute_forward_arange(
7841
    const ggml_compute_params * params,
7842
0
    ggml_tensor * dst) {
7843
0
    switch (dst->type) {
7844
0
        case GGML_TYPE_F32:
7845
0
            {
7846
0
                ggml_compute_forward_arange_f32(params, dst);
7847
0
            } break;
7848
0
        default:
7849
0
            {
7850
0
                GGML_ABORT("fatal error");
7851
0
            }
7852
0
    }
7853
0
}
7854
7855
static void ggml_compute_forward_timestep_embedding_f32(
7856
    const ggml_compute_params * params,
7857
0
    ggml_tensor * dst) {
7858
7859
0
    const ggml_tensor * src0 = dst->src[0];
7860
7861
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
7862
7863
0
    const int ith = params->ith;
7864
0
    const int nth = params->nth;
7865
7866
0
    GGML_TENSOR_UNARY_OP_LOCALS
7867
7868
0
    const int dim = ggml_get_op_params_i32(dst, 0);
7869
0
    const int max_period = ggml_get_op_params_i32(dst, 1);
7870
7871
0
    int half = dim / 2;
7872
7873
0
    for (int64_t i = 0; i < ne00; i++) {
7874
0
        float * embed_data = (float *)((char *)  dst->data +  i*nb1);
7875
0
        for (int64_t j = ith; j < half; j += nth) {
7876
0
            float timestep = ((float *)src0->data)[i];
7877
0
            float freq = (float)expf(-logf(max_period) * j / half);
7878
0
            float arg = timestep * freq;
7879
0
            embed_data[j] = cosf(arg);
7880
0
            embed_data[j + half] = sinf(arg);
7881
0
        }
7882
0
        if (dim % 2 != 0 && ith == 0) {
7883
0
            embed_data[2 * half] = 0.f;
7884
0
        }
7885
0
    }
7886
0
}
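// Illustrative sketch (not part of ops.cpp): the embedding above follows the usual sinusoidal
// timestep scheme — for j < half, freq_j = expf(-logf(max_period) * j / half), and the output
// row is [cos(t*freq_0), ..., cos(t*freq_{half-1}), sin(t*freq_0), ..., sin(t*freq_{half-1})],
// with one trailing zero when dim is odd. timestep_embedding_row_ref is a hypothetical helper.
static void timestep_embedding_row_ref(float timestep, float * out, int dim, int max_period) {
    const int half = dim / 2;
    for (int j = 0; j < half; ++j) {
        const float freq = expf(-logf((float) max_period) * j / half);
        const float arg  = timestep * freq;
        out[j]        = cosf(arg);
        out[j + half] = sinf(arg);
    }
    if (dim % 2 != 0) {
        out[2*half] = 0.0f;
    }
}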
7887
7888
void ggml_compute_forward_timestep_embedding(
7889
    const ggml_compute_params * params,
7890
0
    ggml_tensor * dst) {
7891
7892
0
    const ggml_tensor * src0 = dst->src[0];
7893
7894
0
    switch (src0->type) {
7895
0
        case GGML_TYPE_F32:
7896
0
            {
7897
0
                ggml_compute_forward_timestep_embedding_f32(params, dst);
7898
0
            } break;
7899
0
        default:
7900
0
            {
7901
0
                GGML_ABORT("fatal error");
7902
0
            }
7903
0
    }
7904
0
}
7905
7906
// ggml_compute_forward_argsort
7907
7908
template<enum ggml_sort_order order>
7909
struct cmp_argsort {
7910
    const float * data;
7911
0
    bool operator()(int32_t a, int32_t b) const {
7912
0
        if constexpr (order == GGML_SORT_ORDER_ASC) {
7913
0
            return data[a] < data[b];
7914
0
        } else {
7915
0
            return data[a] > data[b];
7916
0
        }
7917
0
    }
Unexecuted instantiation: cmp_argsort<(ggml_sort_order)0>::operator()(int, int) const
Unexecuted instantiation: cmp_argsort<(ggml_sort_order)1>::operator()(int, int) const
7918
};
7919
7920
static void ggml_compute_forward_argsort_f32(
7921
    const ggml_compute_params * params,
7922
0
    ggml_tensor * dst) {
7923
7924
0
    const ggml_tensor * src0 = dst->src[0];
7925
7926
0
    GGML_TENSOR_UNARY_OP_LOCALS
7927
7928
0
    GGML_ASSERT(nb0 == sizeof(float));
7929
7930
0
    const int ith = params->ith;
7931
0
    const int nth = params->nth;
7932
7933
0
    const int64_t nr = ggml_nrows(src0);
7934
7935
0
    ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
7936
7937
0
    for (int64_t i = ith; i < nr; i += nth) {
7938
0
        const float * src_data = (float *)((char *) src0->data + i*nb01);
7939
7940
0
        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
7941
7942
0
        for (int64_t j = 0; j < ne0; j++) {
7943
0
            dst_data[j] = j;
7944
0
        }
7945
7946
0
        switch (order) {
7947
0
            case GGML_SORT_ORDER_ASC:
7948
0
                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
7949
0
                break;
7950
7951
0
            case GGML_SORT_ORDER_DESC:
7952
0
                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
7953
0
                break;
7954
7955
0
            default:
7956
0
                GGML_ABORT("invalid sort order");
7957
0
        }
7958
0
    }
7959
0
}
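// Illustrative sketch (not part of ops.cpp): the comparator template above sorts row indices by
// the values they point at; the row itself is never moved. A minimal standalone equivalent
// (argsort_row_ref is a hypothetical helper):
static void argsort_row_ref(const float * row, int32_t * idx, int64_t n, bool ascending) {
    for (int64_t j = 0; j < n; ++j) {
        idx[j] = (int32_t) j; // start from the identity permutation
    }
    std::sort(idx, idx + n, [&](int32_t a, int32_t b) {
        return ascending ? row[a] < row[b] : row[a] > row[b];
    });
}
// example: row = [0.3, -1.0, 2.5], ascending  ->  idx = [1, 0, 2]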
7960
7961
void ggml_compute_forward_argsort(
7962
    const ggml_compute_params * params,
7963
0
    ggml_tensor * dst) {
7964
7965
0
    const ggml_tensor * src0 = dst->src[0];
7966
7967
0
    switch (src0->type) {
7968
0
        case GGML_TYPE_F32:
7969
0
            {
7970
0
                ggml_compute_forward_argsort_f32(params, dst);
7971
0
            } break;
7972
0
        default:
7973
0
            {
7974
0
                GGML_ABORT("fatal error");
7975
0
            }
7976
0
    }
7977
0
}
7978
7979
// ggml_compute_forward_top_k
7980
7981
struct cmp_top_k {
7982
    const float * data;
7983
0
    bool operator()(int32_t a, int32_t b) const {
7984
0
        return data[a] > data[b];
7985
0
    }
7986
};
7987
7988
static void ggml_compute_forward_top_k_f32(
7989
    const ggml_compute_params * params,
7990
0
    ggml_tensor * dst) {
7991
7992
0
    const ggml_tensor * src0 = dst->src[0];
7993
7994
0
    GGML_TENSOR_UNARY_OP_LOCALS
7995
7996
0
    GGML_ASSERT(nb0 == sizeof(float));
7997
7998
0
    const int ith = params->ith;
7999
0
    const int nth = params->nth;
8000
8001
0
    const int64_t nr = ggml_nrows(src0);
8002
8003
0
    const int top_k = ne0;
8004
8005
0
    int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
8006
8007
0
    for (int64_t i = ith; i < nr; i += nth) {
8008
0
        const float * src_data = (float *)((char *) src0->data + i*nb01);
8009
8010
0
        for (int64_t j = 0; j < ne00; j++) {
8011
0
            tmp[j] = j;
8012
0
        }
8013
8014
0
        std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
8015
8016
0
        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
8017
8018
0
        std::copy(tmp, tmp + top_k, dst_data);
8019
8020
        // deliberately perturb the order to emphasize that the ordering of the top-k indices is not guaranteed
8021
0
        if (top_k > 1) {
8022
0
            std::swap(dst_data[0], dst_data[1]);
8023
0
        }
8024
0
    }
8025
0
}
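// Illustrative sketch (not part of ops.cpp): std::partial_sort above places the k largest
// indices (by value, descending) in the first k slots of the scratch buffer in roughly
// n*log(k) comparisons; the order of the remaining n-k indices is unspecified, which is why
// only the first k are copied out. top_k_row_ref is a hypothetical helper; tmp is a
// caller-provided scratch buffer of length n, mirroring the wdata buffer above.
static void top_k_row_ref(const float * row, int32_t * tmp, int32_t * out, int64_t n, int64_t k) {
    for (int64_t j = 0; j < n; ++j) {
        tmp[j] = (int32_t) j;
    }
    std::partial_sort(tmp, tmp + k, tmp + n,
                      [&](int32_t a, int32_t b) { return row[a] > row[b]; });
    std::copy(tmp, tmp + k, out);
}
// example: row = [0.1, 0.9, 0.5, 0.7], k = 2  ->  out = [1, 3]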
8026
8027
void ggml_compute_forward_top_k(
8028
    const ggml_compute_params * params,
8029
0
    ggml_tensor * dst) {
8030
8031
0
    const ggml_tensor * src0 = dst->src[0];
8032
8033
0
    switch (src0->type) {
8034
0
        case GGML_TYPE_F32:
8035
0
            {
8036
0
                ggml_compute_forward_top_k_f32(params, dst);
8037
0
            } break;
8038
0
        default:
8039
0
            {
8040
0
                GGML_ABORT("fatal error");
8041
0
            }
8042
0
    }
8043
0
}
8044
8045
// ggml_compute_forward_flash_attn_ext
8046
8047
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8048
        const ggml_compute_params * params,
8049
        ggml_tensor * dst,
8050
0
        int ir0, int ir1) {
8051
0
    const ggml_tensor * q     = dst->src[0];
8052
0
    const ggml_tensor * k     = dst->src[1];
8053
0
    const ggml_tensor * v     = dst->src[2];
8054
0
    const ggml_tensor * mask  = dst->src[3];
8055
0
    const ggml_tensor * sinks = dst->src[4];
8056
8057
0
    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
8058
0
    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
8059
0
    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
8060
0
    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
8061
0
    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
8062
0
    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
8063
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
8064
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
8065
8066
0
    const int64_t DK = nek0;
8067
0
    const int64_t DV = nev0;
8068
0
    const int64_t N  = neq1;
8069
8070
0
    GGML_ASSERT(ne0 == DV);
8071
0
    GGML_ASSERT(ne2 == N);
8072
8073
    // input tensor rows must be contiguous
8074
0
    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8075
0
    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8076
0
    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8077
8078
0
    GGML_ASSERT(neq0 == DK);
8079
0
    GGML_ASSERT(nek0 == DK);
8080
0
    GGML_ASSERT(nev0 == DV);
8081
8082
0
    GGML_ASSERT(neq1 == N);
8083
8084
    // dst cannot be transposed or permuted
8085
0
    GGML_ASSERT(nb0 == sizeof(float));
8086
0
    GGML_ASSERT(nb0 <= nb1);
8087
0
    GGML_ASSERT(nb1 <= nb2);
8088
0
    GGML_ASSERT(nb2 <= nb3);
8089
8090
    // broadcast factors
8091
0
    const int64_t rk2 = neq2/nek2;
8092
0
    const int64_t rk3 = neq3/nek3;
8093
8094
0
    const int64_t rv2 = neq2/nev2;
8095
0
    const int64_t rv3 = neq3/nev3;
8096
8097
    // parallelize by q rows using ggml_vec_dot_f32
8098
8099
0
    float scale         = 1.0f;
8100
0
    float max_bias      = 0.0f;
8101
0
    float logit_softcap = 0.0f;
8102
8103
0
    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
8104
0
    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
8105
0
    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
8106
8107
0
    if (logit_softcap != 0) {
8108
0
        scale /= logit_softcap;
8109
0
    }
8110
8111
0
    const uint32_t n_head      = neq2;
8112
0
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
8113
8114
0
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
8115
0
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
8116
8117
0
    ggml_type         const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
8118
0
    ggml_from_float_t const q_to_vec_dot   = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
8119
0
    ggml_vec_dot_t    const kq_vec_dot     = ggml_get_type_traits_cpu(k->type)->vec_dot;
8120
0
    ggml_to_float_t   const v_to_float     = ggml_get_type_traits(v->type)->to_float;
8121
8122
0
    GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
8123
0
    GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
8124
8125
0
    int ith = params->ith;
8126
8127
    // loop over n_batch and n_head
8128
0
    for (int ir = ir0; ir < ir1; ++ir) {
8129
        // q indices
8130
0
        const int iq3 = ir/(neq2*neq1);
8131
0
        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
8132
0
        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
8133
8134
0
        const uint32_t h = iq2; // head index
8135
0
        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
8136
8137
0
        float S = 0.0f;      // sum
8138
0
        float M = -INFINITY; // maximum KQ value
8139
8140
0
        float       * VKQ32 = (float       *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
8141
0
        float       * V32   =                 (VKQ32 + 1*DV); // (temporary) FP32 V buffer
8142
0
        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
8143
0
        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
8144
8145
0
        if (v->type == GGML_TYPE_F16) {
8146
0
            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
8147
0
        } else {
8148
0
            memset(VKQ32, 0, DV*sizeof(float));
8149
0
        }
8150
8151
0
        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
8152
8153
        // k indices
8154
0
        const int ik3 = iq3 / rk3;
8155
0
        const int ik2 = iq2 / rk2;
8156
8157
        // v indices
8158
0
        const int iv3 = iq3 / rv3;
8159
0
        const int iv2 = iq2 / rv2;
8160
8161
0
        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
8162
0
        q_to_vec_dot(pq, Q_q, DK);
8163
8164
        // online softmax / attention
8165
        // loop over n_kv and n_head_kv
8166
        // ref: https://arxiv.org/pdf/2112.05682.pdf
8167
0
        for (int64_t ic = 0; ic < nek1; ++ic) {
8168
0
            const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
8169
0
            if (mv == -INFINITY) {
8170
0
                continue;
8171
0
            }
8172
8173
0
            float s; // KQ value
8174
8175
0
            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
8176
0
            kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
8177
8178
0
            s = s*scale; // scale KQ value
8179
8180
0
            if (logit_softcap != 0.0f) {
8181
0
                s = logit_softcap*tanhf(s);
8182
0
            }
8183
8184
0
            s += mv; // apply mask
8185
8186
0
            const float Mold = M;
8187
8188
0
            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
8189
0
            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
8190
8191
0
            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
8192
8193
0
            if (v->type == GGML_TYPE_F16) {
8194
0
                if (s > M) {
8195
                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
8196
0
                    M = s;
8197
0
                    ms = expf(Mold - M);
8198
8199
                    // V = V*expf(Mold - M)
8200
0
                    ggml_vec_scale_f16(DV, VKQ16, ms);
8201
0
                } else {
8202
                    // no new maximum, ms == 1.0f, vs != 1.0f
8203
0
                    vs = expf(s - M);
8204
0
                }
8205
8206
                // V += v*expf(s - M)
8207
0
                ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
8208
0
            } else {
8209
0
                if (s > M) {
8210
                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
8211
0
                    M = s;
8212
0
                    ms = expf(Mold - M);
8213
8214
                    // V = V*expf(Mold - M)
8215
0
                    ggml_vec_scale_f32(DV, VKQ32, ms);
8216
0
                } else {
8217
                    // no new maximum, ms == 1.0f, vs != 1.0f
8218
0
                    vs = expf(s - M);
8219
0
                }
8220
8221
                // V += v*expf(s - M)
8222
0
                if (v_to_float) {
8223
0
                    v_to_float(v_data, V32, DV);
8224
0
                    ggml_vec_mad_f32(DV, VKQ32, V32, vs);
8225
0
                } else {
8226
                    // V is F32
8227
0
                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
8228
0
                }
8229
0
            }
8230
8231
0
            S = S*ms + vs; // scale and increment sum with partial sum
8232
0
        }
8233
8234
0
        if (v->type == GGML_TYPE_F16) {
8235
0
            for (int64_t d = 0; d < DV; ++d) {
8236
0
                VKQ32[d] = GGML_CPU_FP16_TO_FP32(VKQ16[d]);
8237
0
            }
8238
0
        }
8239
8240
        // sinks
8241
0
        if (sinks) {
8242
0
            const float s = ((float *)((char *) sinks->data))[h];
8243
8244
0
            float ms = 1.0f;
8245
0
            float vs = 1.0f;
8246
8247
0
            if (s > M) {
8248
0
                ms = expf(M - s);
8249
0
                ggml_vec_scale_f32(DV, VKQ32, ms);
8250
0
            } else {
8251
0
                vs = expf(s - M);
8252
0
            }
8253
8254
0
            S = S*ms + vs;
8255
0
        }
8256
8257
        // V /= S
8258
0
        const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
8259
0
        ggml_vec_scale_f32(DV, VKQ32, S_inv);
8260
8261
        // dst indices
8262
0
        const int i1 = iq1;
8263
0
        const int i2 = iq2;
8264
0
        const int i3 = iq3;
8265
8266
        // original
8267
        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
8268
8269
        // permute(0, 2, 1, 3)
8270
0
        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
8271
0
    }
8272
0
}
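// Illustrative sketch (not part of ops.cpp): the inner loop above is the streaming ("online")
// softmax from https://arxiv.org/pdf/2112.05682.pdf — a running maximum M and running sum S are
// kept, and whenever a new maximum appears the accumulator is rescaled by expf(Mold - M) instead
// of making a second pass over the scores. A scalar version with a 1-element value accumulator
// (online_softmax_weighted_sum_ref is a hypothetical helper):
static float online_softmax_weighted_sum_ref(const float * s, const float * v, int64_t n) {
    float M   = -INFINITY; // running maximum of the scores seen so far
    float S   = 0.0f;      // running sum of expf(s[i] - M)
    float acc = 0.0f;      // running sum of expf(s[i] - M) * v[i]
    for (int64_t i = 0; i < n; ++i) {
        float ms = 1.0f;
        float vs = 1.0f;
        if (s[i] > M) {
            ms  = expf(M - s[i]); // rescale previous contributions to the new maximum
            M   = s[i];
            acc *= ms;
        } else {
            vs = expf(s[i] - M);
        }
        acc += vs * v[i];
        S    = S*ms + vs;
    }
    return S == 0.0f ? 0.0f : acc / S; // dot(softmax(s), v)
}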
8273
8274
static void ggml_compute_forward_flash_attn_ext_f16(
8275
        const ggml_compute_params * params,
8276
0
        ggml_tensor * dst) {
8277
8278
0
    const ggml_tensor * q     = dst->src[0];
8279
0
    const ggml_tensor * k     = dst->src[1];
8280
0
    const ggml_tensor * v     = dst->src[2];
8281
8282
0
    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
8283
0
    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
8284
0
    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
8285
0
    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
8286
0
    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
8287
0
    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
8288
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
8289
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
8290
8291
0
    const int64_t DK = nek0;
8292
0
    const int64_t DV = nev0;
8293
0
    const int64_t N  = neq1;
8294
8295
0
    GGML_ASSERT(ne0 == DV);
8296
0
    GGML_ASSERT(ne2 == N);
8297
8298
    // input tensor rows must be contiguous
8299
0
    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8300
0
    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8301
0
    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8302
8303
0
    GGML_ASSERT(neq0 == DK);
8304
0
    GGML_ASSERT(nek0 == DK);
8305
0
    GGML_ASSERT(nev0 == DV);
8306
8307
0
    GGML_ASSERT(neq1 == N);
8308
8309
    // dst cannot be transposed or permuted
8310
0
    GGML_ASSERT(nb0 == sizeof(float));
8311
0
    GGML_ASSERT(nb0 <= nb1);
8312
0
    GGML_ASSERT(nb1 <= nb2);
8313
0
    GGML_ASSERT(nb2 <= nb3);
8314
8315
    // parallelize by q rows using ggml_vec_dot_f32
8316
8317
    // total rows in q
8318
0
    const int64_t nr = neq1*neq2*neq3;
8319
8320
    // rows per thread
8321
0
    const int ith = params->ith;
8322
0
    const int nth = params->nth;
8323
8324
    // disable for NUMA
8325
0
    const bool disable_chunking = ggml_is_numa();
8326
8327
    // 4x chunks per thread
8328
0
    int nth_scaled = nth * 4;
8329
0
    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
8330
0
    int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
8331
8332
0
    if (nth == 1 || nchunk < nth || disable_chunking) {
8333
0
        nchunk = nth;
8334
0
    }
8335
8336
0
    if (ith == 0) {
8337
        // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
8338
0
        ggml_threadpool_chunk_set(params->threadpool, nth);
8339
0
    }
8340
8341
0
    ggml_barrier(params->threadpool);
8342
8343
    // The number of q rows in each chunk
8344
0
    const int64_t dr = (nr + nchunk - 1) / nchunk;
8345
8346
    // The first chunk comes from our thread_id, the rest will get auto-assigned.
8347
0
    int current_chunk = ith;
8348
8349
0
    while (current_chunk < nchunk) {
8350
0
        const int64_t ir0 = dr * current_chunk;
8351
0
        const int64_t ir1 = MIN(ir0 + dr, nr);
8352
8353
0
        ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
8354
8355
0
        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
8356
0
    }
8357
0
}
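// Illustrative sketch (not part of ops.cpp): the chunking above is assumed to behave like a
// shared atomic work counter — every thread processes its own chunk id first, then repeatedly
// claims the next unprocessed chunk with a fetch-and-add until all chunks are done. A minimal
// standalone analogue; run_chunks_ref and process_chunk are hypothetical, and <atomic> is part
// of the sketch only (ops.cpp itself does not include it).
#include <atomic>
static void run_chunks_ref(std::atomic<int> & next_chunk, int ith, int nth, int nchunk,
                           void (*process_chunk)(int chunk)) {
    if (ith == 0) {
        next_chunk.store(nth); // chunks 0..nth-1 are taken implicitly by the threads themselves
    }
    // (a barrier between the store and the loop is assumed, as with ggml_barrier above)
    int current = ith;
    while (current < nchunk) {
        process_chunk(current);
        current = next_chunk.fetch_add(1); // claim the next unprocessed chunk
    }
}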
8358
8359
void ggml_compute_forward_flash_attn_ext(
8360
        const ggml_compute_params * params,
8361
0
        ggml_tensor * dst) {
8362
0
    switch (dst->op_params[3]) {
8363
0
        case GGML_PREC_DEFAULT:
8364
0
        case GGML_PREC_F32:
8365
0
            {
8366
                // uses F32 accumulators
8367
0
                ggml_compute_forward_flash_attn_ext_f16(params, dst);
8368
0
            } break;
8369
0
        default:
8370
0
            {
8371
0
                GGML_ABORT("fatal error");
8372
0
            }
8373
0
    }
8374
0
}
8375
8376
// ggml_compute_forward_flash_attn_back
8377
8378
static void ggml_compute_forward_flash_attn_back_f32(
8379
        const ggml_compute_params * params,
8380
        const bool masked,
8381
0
              ggml_tensor * dst) {
8382
8383
0
    const ggml_tensor * q = dst->src[0];
8384
0
    const ggml_tensor * k = dst->src[1];
8385
0
    const ggml_tensor * v = dst->src[2];
8386
0
    const ggml_tensor * d = dst->src[3];
8387
8388
0
    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
8389
0
    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
8390
0
    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
8391
0
    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
8392
0
    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
8393
0
    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
8394
0
    GGML_TENSOR_LOCALS(int64_t, ned, d,   ne)
8395
0
    GGML_TENSOR_LOCALS(size_t,  nbd, d,   nb)
8396
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
8397
0
    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
8398
8399
0
    const int ith = params->ith;
8400
0
    const int nth = params->nth;
8401
8402
0
    const int64_t D = neq0;
8403
0
    const int64_t N = neq1;
8404
0
    const int64_t P = nek1 - N;
8405
0
    const int64_t M = P + N;
8406
8407
0
    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
8408
0
    const int mxDM = MAX(D, Mup);
8409
8410
    // GGML_ASSERT(ne0 == D);
8411
    // GGML_ASSERT(ne1 == N);
8412
0
    GGML_ASSERT(P >= 0);
8413
8414
0
    GGML_ASSERT(nbq0 == sizeof(float));
8415
0
    GGML_ASSERT(nbk0 == sizeof(float));
8416
0
    GGML_ASSERT(nbv0 == sizeof(float));
8417
8418
0
    GGML_ASSERT(neq0 == D);
8419
0
    GGML_ASSERT(nek0 == D);
8420
0
    GGML_ASSERT(nev1 == D);
8421
0
    GGML_ASSERT(ned0 == D);
8422
8423
0
    GGML_ASSERT(neq1 == N);
8424
0
    GGML_ASSERT(nek1 == N + P);
8425
0
    GGML_ASSERT(nev1 == D);
8426
0
    GGML_ASSERT(ned1 == N);
8427
8428
    // dst cannot be transposed or permuted
8429
0
    GGML_ASSERT(nb0 == sizeof(float));
8430
0
    GGML_ASSERT(nb0 <= nb1);
8431
0
    GGML_ASSERT(nb1 <= nb2);
8432
0
    GGML_ASSERT(nb2 <= nb3);
8433
8434
0
    if (ith == 0) {
8435
0
        memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
8436
0
    }
8437
0
    ggml_barrier(params->threadpool);
8438
8439
0
    const int64_t elem_q = ggml_nelements(q);
8440
0
    const int64_t elem_k = ggml_nelements(k);
8441
8442
0
    ggml_type result_type = dst->type;
8443
0
    GGML_ASSERT(ggml_blck_size(result_type) == 1);
8444
0
    const size_t tsize = ggml_type_size(result_type);
8445
8446
0
    const size_t offs_q = 0;
8447
0
    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
8448
0
    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
8449
8450
0
    void * grad_q = (char *) dst->data;
8451
0
    void * grad_k = (char *) dst->data + offs_k;
8452
0
    void * grad_v = (char *) dst->data + offs_v;
8453
8454
0
    const size_t nbgq1 = nb0*neq0;
8455
0
    const size_t nbgq2 = nb0*neq0*neq1;
8456
0
    const size_t nbgq3 = nb0*neq0*neq1*neq2;
8457
8458
0
    const size_t nbgk1 = nb0*nek0;
8459
0
    const size_t nbgk2 = nb0*nek0*nek1;
8460
0
    const size_t nbgk3 = nb0*nek0*nek1*neq2;
8461
8462
0
    const size_t nbgv1 = nb0*nev0;
8463
0
    const size_t nbgv2 = nb0*nev0*nev1;
8464
0
    const size_t nbgv3 = nb0*nev0*nev1*neq2;
8465
8466
    // parallelize by k rows using ggml_vec_dot_f32
8467
8468
    // total rows in k
8469
0
    const int nr = nek2*nek3;
8470
8471
    // rows per thread
8472
0
    const int dr = (nr + nth - 1)/nth;
8473
8474
    // row range for this thread
8475
0
    const int ir0 = dr*ith;
8476
0
    const int ir1 = MIN(ir0 + dr, nr);
8477
8478
0
    const float scale = 1.0f/sqrtf(D);
8479
8480
    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
8481
8482
    // how often k2 (and v2) is repeated in q2
8483
0
    int nrep = neq2/nek2;
8484
8485
0
    for (int ir = ir0; ir < ir1; ++ir) {
8486
        // q indices
8487
0
        const int ik3 = ir/(nek2);
8488
0
        const int ik2 = ir - ik3*nek2;
8489
8490
0
        const int iq3 = ik3;
8491
0
        const int id3 = ik3;
8492
0
        const int iv3 = ik3;
8493
0
        const int iv2 = ik2;
8494
8495
0
        for (int irep = 0; irep < nrep; ++irep) {
8496
0
            const int iq2 = ik2 + irep*nek2;
8497
0
            const int id2 = iq2;
8498
8499
            // (ik2 + irep*nek2) % nek2 == ik2
8500
0
            for (int iq1 = 0; iq1 < neq1; ++iq1) {
8501
0
                const int id1 = iq1;
8502
8503
                // not sure about CACHE_LINE_SIZE_F32..
8504
                // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
8505
0
                float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
8506
0
                float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
8507
8508
0
                for (int i = M; i < Mup; ++i) {
8509
0
                    S[i] = -INFINITY;
8510
0
                }
8511
8512
0
                const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
8513
0
                for (int64_t ic = 0; ic < masked_begin; ++ic) {
8514
                    // k indices
8515
0
                    const int ik1 = ic;
8516
8517
                    // S indices
8518
0
                    const int i1 = ik1;
8519
8520
0
                    ggml_vec_dot_f32(neq0,
8521
0
                            S + i1, 0,
8522
0
                            (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
8523
0
                            (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
8524
0
                }
8525
8526
                // scale
8527
0
                ggml_vec_scale_f32(masked_begin, S, scale);
8528
8529
0
                for (int64_t i = masked_begin; i < M; i++) {
8530
0
                    S[i] = -INFINITY;
8531
0
                }
8532
8533
                // softmax
8534
                // exclude known -INF S[..] values from max and loop
8535
                // don't forget to set their SM values to zero
8536
0
                {
8537
0
                    float max = -INFINITY;
8538
0
                    ggml_vec_max_f32(masked_begin, &max, S);
8539
8540
0
                    ggml_float sum = 0.0;
8541
0
                    {
8542
#ifdef GGML_SOFT_MAX_ACCELERATE
8543
                        max = -max;
8544
                        vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
8545
                        vvexpf(SM, SM, &Mup);
8546
                        ggml_vec_sum_f32(Mup, &sum, SM);
8547
#else
8548
0
                        sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
8549
0
#endif
8550
0
                    }
8551
8552
0
                    assert(sum > 0.0);
8553
8554
0
                    sum = 1.0/sum;
8555
0
                    ggml_vec_scale_f32(masked_begin, SM, sum);
8556
8557
0
                }
8558
8559
                // step-by-step explanation
8560
0
                {
8561
                    // forward-process                    shape      grads from backward process
8562
                    // parallel_for ik2,ik3:
8563
                    //  for irep:
8564
                    //   iq2 = ik2 + irep*nek2
8565
                    //   k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,ik2,ik3]  += grad[kcur]
8566
                    //   q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
8567
                    //   v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iv2,iv3]  += grad[vcur]
8568
                    //   for iq1:
8569
                    //    kcur   = k[:D,:M,ik2,ik3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
8570
                    //    qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
8571
                    //    vcur   = v[:M,:D,iv2,iv3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
8572
                    //    S0     = -Inf                   [D,1,1,1]
8573
                    //   ~S1[i]  = dot(kcur[:D,i], qcur)
8574
                    //    S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale
8575
                    //    S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
8576
                    //    S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
8577
                    //    S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
8578
                    //   ~S5[i]  = dot(vcur[:,i], S4)
8579
                    //    S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,id1,id2,id3]
8580
                    //   ~dst[i,iq1,iq2,iq3]  = S5[i]              ^
8581
                    //    dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
8582
                    // dst                               backward-/ grad[dst]                 = d
8583
                    //
8584
                    // output gradients with their dependencies:
8585
                    //
8586
                    // grad[kcur] = grad[S1].T @ qcur
8587
                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
8588
                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
8589
                    // grad[S4]   = grad[S5] @ vcur
8590
                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
8591
                    // grad[qcur] = grad[S1]   @ kcur
8592
                    // grad[vcur] = grad[S5].T @ S4
8593
                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
8594
                    //
8595
                    // in post-order:
8596
                    //
8597
                    // S1         = qcur @ kcur.T
8598
                    // S2         = S1 * scale
8599
                    // S3         = diag_mask_inf(S2, P)
8600
                    // S4         = softmax(S3)
8601
                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
8602
                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
8603
                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
8604
                    // grad[qcur] = grad[S1]   @ kcur
8605
                    // grad[kcur] = grad[S1].T @ qcur
8606
                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
8607
                    //
8608
                    // using less variables (SM=S4):
8609
                    //
8610
                    // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
8611
                    // SM            = softmax(S)
8612
                    // S             = d[:D,iq1,iq2,iq3] @ vcur
8613
                    // dot_SM_gradSM = dot(SM, S)
8614
                    // S             = SM * (S - dot(SM, S))
8615
                    // S             = diag_mask_zero(S, P) * scale
8616
                    //
8617
                    // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
8618
                    // grad[k][:D,:M,ik2,ik3]  += S.T @ qcur
8619
                    // grad[v][:M,:D,iv2,iv3]  += d[:D,id1,id2,id3].T @ SM
8620
0
                }
8621
8622
                // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
8623
                // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
8624
                // for ic:
8625
                //   S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
8626
                // exclude known future zero S[..] values from operation
8627
0
                ggml_vec_set_f32(masked_begin, S, 0);
8628
0
                for (int64_t ic = 0; ic < D; ++ic) {
8629
0
                    ggml_vec_mad_f32(masked_begin,
8630
0
                            S,
8631
0
                             (float *) ((char *) v->data + (          ic*nbv1  + iv2*nbv2 + iv3*nbv3)),
8632
0
                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
8633
0
                }
8634
8635
                // S = SM * (S - dot(SM, S))
8636
0
                float dot_SM_gradSM = 0;
8637
0
                ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
8638
0
                ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
8639
0
                ggml_vec_mul_f32 (masked_begin, S, S, SM);
8640
8641
                // S = diag_mask_zero(S, P) * scale
8642
                // already done by above ggml_vec_set_f32
8643
8644
                // exclude known zero S[..] values from operation
8645
0
                ggml_vec_scale_f32(masked_begin, S, scale);
8646
8647
                // S    shape [M,1]
8648
                // SM   shape [M,1]
8649
                // kcur shape [D,M]
8650
                // qcur shape [D,1]
8651
                // vcur shape [M,D]
8652
8653
                // grad[q][:D,iq1,iq2,iq3] += S @ kcur
8654
                // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
8655
                // for ic:
8656
                //  grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
8657
                // exclude known zero S[..] values from loop
8658
0
                for (int64_t ic = 0; ic < masked_begin; ++ic) {
8659
0
                    ggml_vec_mad_f32(D,
8660
0
                            (float *) ((char *) grad_q  + (iq1*nbgq1 + iq2*nbgq2  + iq3*nbgq3)),
8661
0
                            (float *) ((char *) k->data + (ic*nbk1   + ik2*nbk2   + ik3*nbk3)),
8662
0
                            S[ic]);
8663
0
                }
8664
8665
                // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
8666
                // for ic:
8667
                //  grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
8668
                //  grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
8669
                // exclude known zero S[..] values from loop
8670
0
                for (int64_t ic = 0; ic < masked_begin; ++ic) {
8671
0
                    ggml_vec_mad_f32(D,
8672
0
                            (float *) ((char *) grad_k  + (ic*nbgk1  + ik2*nbgk2  + ik3*nbgk3)),
8673
0
                            (float *) ((char *) q->data + (iq1*nbq1  + iq2*nbq2   + iq3*nbq3)),
8674
0
                            S[ic]);
8675
0
                }
8676
8677
                // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T       @ SM
8678
                // for ic:
8679
                //  grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
8680
                //  grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3]         * SM[:M]
8681
                // exclude known zero SM[..] values from mad
8682
0
                for (int64_t ic = 0; ic < D; ++ic) {
8683
0
                    ggml_vec_mad_f32(masked_begin,
8684
0
                            (float *) ((char *) grad_v   + (          ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
8685
0
                            SM,
8686
0
                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2  + id3*nbd3)));
8687
0
                }
8688
0
            }
8689
0
        }
8690
0
    }
8691
0
}
8692
8693
void ggml_compute_forward_flash_attn_back(
8694
        const ggml_compute_params * params,
8695
        const bool masked,
8696
0
        ggml_tensor * dst) {
8697
8698
0
    const ggml_tensor * q = dst->src[0];
8699
8700
0
    switch (q->type) {
8701
0
        case GGML_TYPE_F32:
8702
0
            {
8703
0
                ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
8704
0
            } break;
8705
0
        default:
8706
0
            {
8707
0
                GGML_ABORT("fatal error");
8708
0
            }
8709
0
    }
8710
0
}
8711
8712
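The three vector calls in the block above (ggml_vec_dot_f32, ggml_vec_acc1_f32, ggml_vec_mul_f32) together implement the softmax backward step S = SM * (S - dot(SM, S)). A minimal scalar sketch of that step for one row, ignoring masking and scaling; the helper name is hypothetical:

    // Sketch only (not part of ops.cpp): softmax Jacobian-vector product for one row.
    // sm   - softmax output of the row, length n
    // grad - on input dL/d(sm), on output dL/d(pre-softmax logits)
    static void softmax_backward_row(int n, const float * sm, float * grad) {
        float dot = 0.0f;
        for (int i = 0; i < n; ++i) {
            dot += sm[i] * grad[i];              // dot(SM, gradSM)
        }
        for (int i = 0; i < n; ++i) {
            grad[i] = sm[i] * (grad[i] - dot);   // SM * (gradSM - dot(SM, gradSM))
        }
    }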
// ggml_compute_forward_ssm_conv
8713
8714
static void ggml_compute_forward_ssm_conv_f32(
8715
        const ggml_compute_params * params,
8716
0
        ggml_tensor * dst) {
8717
0
    const ggml_tensor * src0 = dst->src[0]; // conv_x
8718
0
    const ggml_tensor * src1 = dst->src[1]; // conv1d.weight
8719
8720
0
    const int ith = params->ith;
8721
0
    const int nth = params->nth;
8722
8723
0
    const int nc  = src1->ne[0]; // d_conv
8724
0
    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
8725
0
    const int nr  = src0->ne[1]; // d_inner
8726
0
    const int n_t =  dst->ne[1]; // tokens per sequence
8727
0
    const int n_s =  dst->ne[2]; // number of sequences in the batch
8728
8729
0
    GGML_ASSERT( dst->ne[0] == nr);
8730
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
8731
0
    GGML_ASSERT(src1->nb[0] == sizeof(float));
8732
0
    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
8733
8734
    // rows per thread
8735
0
    const int dr = (nr + nth - 1)/nth;
8736
8737
    // row range for this thread
8738
0
    const int ir0 = dr*ith;
8739
0
    const int ir1 = MIN(ir0 + dr, nr);
8740
0
    const int ir  = ir1 - ir0;
8741
8742
0
    for (int i3 = 0; i3 < n_s; ++i3) {
8743
0
        for (int i2 = 0; i2 < n_t; ++i2) {
8744
            // {d_conv - 1 + n_t, d_inner, n_seqs}
8745
            // sliding window
8746
0
            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
8747
0
            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
8748
0
            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
8749
8750
            // TODO: transpose the output for smaller strides for big batches?
8751
            // d_inner
8752
0
            for (int i1 = 0; i1 < ir; ++i1) {
8753
                // rowwise dot product
8754
                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
8755
0
                float sumf = 0.0f;
8756
8757
                // d_conv
8758
0
                for (int i0 = 0; i0 < nc; ++i0) {
8759
0
                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
8760
0
                }
8761
0
                x[i1] = sumf;
8762
0
            }
8763
0
        }
8764
0
    }
8765
0
}
8766
8767
void ggml_compute_forward_ssm_conv(
8768
        const ggml_compute_params * params,
8769
0
        ggml_tensor * dst) {
8770
0
    switch (dst->src[0]->type) {
8771
0
        case GGML_TYPE_F32:
8772
0
            {
8773
0
                ggml_compute_forward_ssm_conv_f32(params, dst);
8774
0
            } break;
8775
0
        default:
8776
0
            {
8777
0
                GGML_ABORT("fatal error");
8778
0
            }
8779
0
    }
8780
0
}
8781
8782
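ggml_compute_forward_ssm_conv_f32 above computes a depthwise causal convolution as a per-channel dot product over a sliding window of the pre-padded input. A minimal scalar sketch of the same computation on contiguous row-major buffers; the name ssm_conv_ref and the explicit std::vector layout are assumptions for illustration:

    // Sketch only (not part of ops.cpp): per-channel sliding-window dot product.
    #include <vector>

    static void ssm_conv_ref(int d_conv, int n_t, int d_inner,
                             const std::vector<float> & conv_x,  // [d_conv - 1 + n_t, d_inner]
                             const std::vector<float> & weight,  // [d_conv, d_inner]
                             std::vector<float> & out) {         // [d_inner, n_t], preallocated
        const int ncs = d_conv - 1 + n_t;
        for (int t = 0; t < n_t; ++t) {
            for (int r = 0; r < d_inner; ++r) {
                float sum = 0.0f;                 // single precision, as in the kernel
                for (int i = 0; i < d_conv; ++i) {
                    sum += conv_x[r*ncs + t + i] * weight[r*d_conv + i];
                }
                out[t*d_inner + r] = sum;
            }
        }
    }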
// ggml_compute_forward_ssm_scan
8783
8784
static void ggml_compute_forward_ssm_scan_f32(
8785
        const ggml_compute_params * params,
8786
0
        ggml_tensor * dst) {
8787
0
    const ggml_tensor * src0 = dst->src[0]; // s  {d_state, dim, n_head, n_seqs+}
8788
0
    const ggml_tensor * src1 = dst->src[1]; // x  {dim, n_head, n_seq_tokens, n_seqs}
8789
0
    const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
8790
0
    const ggml_tensor * src3 = dst->src[3]; // A  {d_state, n_head} or {1, n_head}
8791
0
    const ggml_tensor * src4 = dst->src[4]; // B  {d_state, n_group, n_seq_tokens, n_seqs}
8792
0
    const ggml_tensor * src5 = dst->src[5]; // C  {d_state, n_group, n_seq_tokens, n_seqs}
8793
0
    const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
8794
8795
0
    const int ith = params->ith;
8796
0
    const int nth = params->nth;
8797
8798
0
    const int64_t nc = src0->ne[0]; // d_state
8799
0
    const int64_t nr = src0->ne[1]; // dim
8800
0
    const int64_t nh = src1->ne[1]; // n_head
8801
0
    const int64_t ng = src4->ne[1]; // n_group
8802
0
    const int64_t nt = src1->ne[2]; // number of tokens per sequence
8803
0
    const int64_t ns = src1->ne[3]; // number of sequences in the batch
8804
8805
    // can't use ggml_nbytes because src1 is not necessarily contiguous
8806
0
    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
8807
8808
0
    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
8809
0
    GGML_ASSERT(src0->nb[0] == sizeof(float));
8810
0
    GGML_ASSERT(src1->nb[0] == sizeof(float));
8811
0
    GGML_ASSERT(src2->nb[0] == sizeof(float));
8812
0
    GGML_ASSERT(src3->nb[0] == sizeof(float));
8813
0
    GGML_ASSERT(src4->nb[0] == sizeof(float));
8814
0
    GGML_ASSERT(src5->nb[0] == sizeof(float));
8815
0
    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
8816
0
    GGML_ASSERT(nh % ng == 0);
8817
8818
    // heads per thread
8819
0
    const int dh = (nh + nth - 1)/nth;
8820
8821
    // head range for this thread
8822
0
    const int ih0 = dh*ith;
8823
0
    const int ih1 = MIN(ih0 + dh, nh);
8824
8825
0
    const int32_t * ids = (const int32_t *) src6->data;
8826
8827
0
    for (int i3 = 0; i3 < ns; ++i3) {
8828
0
        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
8829
0
              float * s  = (      float *) ((      char *) dst->data  + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
8830
8831
0
        for (int i2 = 0; i2 < nt; ++i2) {
8832
0
            const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
8833
0
            const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
8834
0
            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
8835
0
            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
8836
0
            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
8837
0
                  float * y  = (      float *) ((      char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
8838
8839
0
            if (src3->ne[0] == 1) {
8840
                // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
8841
8842
                // n_head
8843
0
                for (int h = ih0; h < ih1; ++h) {
8844
                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8845
0
                    const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
8846
0
                    const float dA = expf(dt_soft_plus * A[h]);
8847
0
                    const int g = h / (nh / ng); // repeat_interleave
8848
8849
                    // dim
8850
0
                    for (int i1 = 0; i1 < nr; ++i1) {
8851
0
                        const int ii = i1 + h*nr;
8852
0
                        const float x_dt = x[ii] * dt_soft_plus;
8853
0
                        float sumf = 0.0f;
8854
0
#if defined(GGML_SIMD)
8855
    #if defined(__ARM_FEATURE_SVE)
8856
                        const int ggml_f32_epr = svcntw();
8857
                        const int ggml_f32_step = 1 * ggml_f32_epr;
8858
8859
                        const int np = (nc & ~(ggml_f32_step - 1));
8860
8861
                        GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
8862
8863
                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
8864
                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
8865
8866
                        for (int i = 0; i < np; i += ggml_f32_step) {
8867
                            // TODO: maybe unroll more?
8868
                            for (int j = 0; j < 1; j++) {
8869
                                GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
8870
                                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
8871
                                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
8872
8873
                                t0 = GGML_F32_VEC_MUL(t0, adA);
8874
                                t1 = GGML_F32_VEC_MUL(t1, axdt);
8875
8876
                                t0 = GGML_F32_VEC_ADD(t0, t1);
8877
8878
                                sum = GGML_F32_VEC_FMA(sum, t0, t2);
8879
8880
                                GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
8881
                            }
8882
                        }
8883
8884
                        sumf = GGML_F32xt_REDUCE_ONE(sum);
8885
    #elif defined(__riscv_v_intrinsic)
8886
                        // todo: RVV implementation
8887
                        const int np = 0;
8888
    #else
8889
0
                        const int np = (nc & ~(GGML_F32_STEP - 1));
8890
8891
0
                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8892
8893
0
                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
8894
0
                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
8895
8896
0
                        GGML_F32_VEC ax[GGML_F32_ARR];
8897
0
                        GGML_F32_VEC ay[GGML_F32_ARR];
8898
0
                        GGML_F32_VEC az[GGML_F32_ARR];
8899
8900
0
                        for (int i = 0; i < np; i += GGML_F32_STEP) {
8901
0
                            for (int j = 0; j < GGML_F32_ARR; j++) {
8902
0
                                ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
8903
0
                                ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
8904
0
                                az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
8905
8906
0
                                ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
8907
0
                                ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
8908
8909
0
                                ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
8910
8911
0
                                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
8912
8913
0
                                GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
8914
0
                            }
8915
0
                        }
8916
8917
                        // reduce sum0..sum3 to sum0
8918
0
                        GGML_F32_VEC_REDUCE(sumf, sum);
8919
0
    #endif
8920
#else
8921
                        const int np = 0;
8922
#endif
8923
                        // d_state
8924
0
                        for (int i0 = np; i0 < nc; ++i0) {
8925
0
                            const int i = i0 + ii*nc;
8926
0
                            const int ig = i0 + g*nc;
8927
                            // state = prev_state * dA + dB * x
8928
0
                            const float state = (s0[i] * dA) + (B[ig] * x_dt);
8929
                            // y = rowwise_dotprod(state, C)
8930
0
                            sumf += state * C[ig];
8931
0
                            s[i] = state;
8932
0
                        }
8933
0
                        y[ii] = sumf;
8934
0
                    }
8935
0
                }
8936
0
            } else {
8937
                // Mamba-1 has an element-wise decay factor for the states
8938
8939
                // n_head
8940
0
                for (int h = ih0; h < ih1; ++h) {
8941
                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8942
0
                    const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
8943
0
                    const int g = h / (nh / ng); // repeat_interleave
8944
8945
                    // dim
8946
0
                    for (int i1 = 0; i1 < nr; ++i1) {
8947
0
                        const int ii = i1 + h*nr;
8948
0
                        const float x_dt = x[ii] * dt_soft_plus;
8949
#if defined(__ARM_FEATURE_SVE)
8950
                        svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
8951
                        svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
8952
                        svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
8953
8954
                        // d_state
8955
                        // TODO: what happens when (d_state % svcntw()) != 0?
8956
                        for (int64_t k = 0; k < nc; k += svcntw()) {
8957
                            svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
8958
                            svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
8959
                            svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
8960
                            svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
8961
8962
                            svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
8963
                            t1 = exp_ps_sve(svptrue_b32(), t1);
8964
                            svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
8965
8966
                            vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
8967
                            r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
8968
8969
                            GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
8970
                        }
8971
                        y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
8972
#else
8973
0
                        float sumf = 0.0f;
8974
                        // NOTE: can't really use GGML_SIMD here because d_state is usually 16
8975
                        //       and also because expf is used within the loop.
8976
                        // d_state
8977
0
                        for (int i0 = 0; i0 < nc; ++i0) {
8978
0
                            const int i = i0 + ii*nc;
8979
0
                            const int ig = i0 + g*nc;
8980
                            // state = prev_state * dA + dB * x
8981
0
                            const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
8982
                            // y = rowwise_dotprod(state, C)
8983
0
                            sumf += state * C[ig];
8984
0
                            s[i] = state;
8985
0
                        }
8986
0
                        y[ii] = sumf;
8987
0
#endif
8988
0
                    }
8989
0
                }
8990
0
            }
8991
            // use the output as the source when it's not the first token-wise iteration
8992
0
            s0 = s;
8993
0
        }
8994
0
    }
8995
0
}
8996
8997
void ggml_compute_forward_ssm_scan(
8998
        const ggml_compute_params * params,
8999
0
        ggml_tensor * dst) {
9000
0
    switch (dst->src[0]->type) {
9001
0
        case GGML_TYPE_F32:
9002
0
            {
9003
0
                ggml_compute_forward_ssm_scan_f32(params, dst);
9004
0
            } break;
9005
0
        default:
9006
0
            {
9007
0
                GGML_ABORT("fatal error");
9008
0
            }
9009
0
    }
9010
0
}
9011
9012
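Per channel, the scan above applies the selective state-space recurrence state = state * exp(softplus(dt) * A) + B * x * softplus(dt), then reduces y = dot(state, C). A minimal scalar sketch of one token step for a single channel with element-wise decay (the Mamba-1 branch); the helper name is hypothetical and the softplus guard threshold is an assumption, not copied from ggml_compute_softplus_f32:

    // Sketch only (not part of ops.cpp): one SSM recurrence step for one channel.
    #include <cmath>

    static float ssm_scan_step(int d_state,
                               float * state,                     // [d_state], updated in place
                               const float * A,                   // [d_state]
                               const float * B, const float * C,  // [d_state] each
                               float x, float dt) {
        const float dt_sp = dt > 20.0f ? dt : log1pf(expf(dt));   // softplus (assumed guard)
        const float x_dt  = x * dt_sp;
        float y = 0.0f;
        for (int i = 0; i < d_state; ++i) {
            const float s = state[i] * expf(dt_sp * A[i]) + B[i] * x_dt;  // state = prev*dA + dB*x
            y += s * C[i];                                                // y = rowwise_dotprod(state, C)
            state[i] = s;
        }
        return y;
    }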
// ggml_compute_forward_win_part
9013
9014
static void ggml_compute_forward_win_part_f32(
9015
        const ggml_compute_params * params,
9016
0
        ggml_tensor * dst) {
9017
0
    GGML_UNUSED(params);
9018
9019
0
    const ggml_tensor * src0 = dst->src[0];
9020
9021
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
9022
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
9023
9024
0
    const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
9025
0
    const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
9026
0
    const int32_t w    = ((const int32_t *)(dst->op_params))[2];
9027
9028
0
    assert(ne00 == ne0);
9029
0
    assert(ne3  == nep0*nep1);
9030
9031
    // TODO: optimize / multi-thread
9032
0
    for (int py = 0; py < nep1; ++py) {
9033
0
        for (int px = 0; px < nep0; ++px) {
9034
0
            const int64_t i3 = py*nep0 + px;
9035
0
            for (int64_t i2 = 0; i2 < ne2; ++i2) {
9036
0
                for (int64_t i1 = 0; i1 < ne1; ++i1) {
9037
0
                    for (int64_t i0 = 0; i0 < ne0; ++i0) {
9038
0
                        const int64_t i02 = py*w + i2;
9039
0
                        const int64_t i01 = px*w + i1;
9040
0
                        const int64_t i00 = i0;
9041
9042
0
                        const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0    + i1*ne0   + i0;
9043
0
                        const int64_t j =                  i02*ne01*ne00 + i01*ne00 + i00;
9044
9045
0
                        if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
9046
0
                            ((float *) dst->data)[i] = 0.0f;
9047
0
                        } else {
9048
0
                            ((float *) dst->data)[i] = ((float *) src0->data)[j];
9049
0
                        }
9050
0
                    }
9051
0
                }
9052
0
            }
9053
0
        }
9054
0
    }
9055
0
}
9056
9057
void ggml_compute_forward_win_part(
9058
        const ggml_compute_params * params,
9059
0
        ggml_tensor * dst) {
9060
9061
0
    const ggml_tensor * src0 = dst->src[0];
9062
9063
0
    switch (src0->type) {
9064
0
        case GGML_TYPE_F32:
9065
0
            {
9066
0
                ggml_compute_forward_win_part_f32(params, dst);
9067
0
            } break;
9068
0
        default:
9069
0
            {
9070
0
                GGML_ABORT("fatal error");
9071
0
            }
9072
0
    }
9073
0
}
9074
9075
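ggml_compute_forward_win_part_f32 splits the input into non-overlapping w x w windows, zero-filling positions that fall past the ragged edge. A minimal sketch of the forward index mapping, assuming the destination window dimensions are both w (as ggml_win_part sets them up); the name win_part_ref is hypothetical:

    // Sketch only (not part of ops.cpp): window-partition index mapping.
    #include <vector>

    static void win_part_ref(int w, int ne0, int ne01, int ne02, int nep0, int nep1,
                             const std::vector<float> & src,   // [ne0, ne01, ne02]
                             std::vector<float> & dst) {       // [ne0, w, w, nep0*nep1], preallocated
        for (int py = 0; py < nep1; ++py)
        for (int px = 0; px < nep0; ++px)
        for (int i2 = 0; i2 < w;    ++i2)
        for (int i1 = 0; i1 < w;    ++i1)
        for (int i0 = 0; i0 < ne0;  ++i0) {
            const int i02 = py*w + i2;   // source coordinates before the padding check
            const int i01 = px*w + i1;
            const long long d = (((long long)(py*nep0 + px)*w + i2)*w + i1)*ne0 + i0;
            dst[d] = (i02 >= ne02 || i01 >= ne01) ? 0.0f
                                                  : src[((long long)i02*ne01 + i01)*ne0 + i0];
        }
    }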
// ggml_compute_forward_win_unpart
9076
9077
static void ggml_compute_forward_win_unpart_f32(
9078
        const ggml_compute_params * params,
9079
0
        ggml_tensor * dst) {
9080
0
    GGML_UNUSED(params);
9081
9082
0
    const ggml_tensor * src0 = dst->src[0];
9083
9084
0
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
9085
0
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
9086
9087
0
    const int32_t w = ((const int32_t *)(dst->op_params))[0];
9088
9089
    // padding
9090
0
    const int px = (w - ne1%w)%w;
9091
    //const int py = (w - ne2%w)%w;
9092
9093
0
    const int npx = (px + ne1)/w;
9094
    //const int npy = (py + ne2)/w;
9095
9096
0
    assert(ne0 == ne00);
9097
9098
    // TODO: optimize / multi-thread
9099
0
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
9100
0
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
9101
0
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
9102
0
                const int ip2 = i2/w;
9103
0
                const int ip1 = i1/w;
9104
9105
0
                const int64_t i02 = i2%w;
9106
0
                const int64_t i01 = i1%w;
9107
0
                const int64_t i00 = i0;
9108
9109
0
                const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
9110
0
                const int64_t j =                                  i2*ne1*ne0    + i1*ne0   + i0;
9111
9112
0
                ((float *) dst->data)[j] = ((float *) src0->data)[i];
9113
0
            }
9114
0
        }
9115
0
    }
9116
0
}
9117
9118
void ggml_compute_forward_win_unpart(
9119
        const ggml_compute_params * params,
9120
0
        ggml_tensor * dst) {
9121
9122
0
    const ggml_tensor * src0 = dst->src[0];
9123
9124
0
    switch (src0->type) {
9125
0
        case GGML_TYPE_F32:
9126
0
            {
9127
0
                ggml_compute_forward_win_unpart_f32(params, dst);
9128
0
            } break;
9129
0
        default:
9130
0
            {
9131
0
                GGML_ABORT("fatal error");
9132
0
            }
9133
0
    }
9134
0
}
9135
9136
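ggml_compute_forward_win_unpart_f32 above reverses the partition: each output element is read back from its window, and the padded positions are simply never visited. A minimal sketch of the inverse index mapping under the same layout assumptions; the name win_unpart_ref is hypothetical:

    // Sketch only (not part of ops.cpp): window-unpartition index mapping.
    #include <vector>

    static void win_unpart_ref(int w, int ne0, int ne1, int ne2, int npx,
                               const std::vector<float> & win,  // [ne0, w, w, npx*npy]
                               std::vector<float> & out) {      // [ne0, ne1, ne2], preallocated
        for (int i2 = 0; i2 < ne2; ++i2)
        for (int i1 = 0; i1 < ne1; ++i1)
        for (int i0 = 0; i0 < ne0; ++i0) {
            const long long s = ((((long long)(i2/w)*npx + i1/w)*w + i2%w)*w + i1%w)*ne0 + i0;
            out[((long long)i2*ne1 + i1)*ne0 + i0] = win[s];
        }
    }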
// ggml_compute_forward_unary
9137
9138
void ggml_compute_forward_unary(
9139
        const ggml_compute_params * params,
9140
0
        ggml_tensor * dst) {
9141
9142
0
    const ggml_unary_op op = ggml_get_unary_op(dst);
9143
9144
0
    switch (op) {
9145
0
        case GGML_UNARY_OP_ABS:
9146
0
            {
9147
0
                ggml_compute_forward_abs(params, dst);
9148
0
            } break;
9149
0
        case GGML_UNARY_OP_SGN:
9150
0
            {
9151
0
                ggml_compute_forward_sgn(params, dst);
9152
0
            } break;
9153
0
        case GGML_UNARY_OP_NEG:
9154
0
            {
9155
0
                ggml_compute_forward_neg(params, dst);
9156
0
            } break;
9157
0
        case GGML_UNARY_OP_STEP:
9158
0
            {
9159
0
                ggml_compute_forward_step(params, dst);
9160
0
            } break;
9161
0
        case GGML_UNARY_OP_TANH:
9162
0
            {
9163
0
                ggml_compute_forward_tanh(params, dst);
9164
0
            } break;
9165
0
        case GGML_UNARY_OP_ELU:
9166
0
            {
9167
0
                ggml_compute_forward_elu(params, dst);
9168
0
            } break;
9169
0
        case GGML_UNARY_OP_RELU:
9170
0
            {
9171
0
                ggml_compute_forward_relu(params, dst);
9172
0
            } break;
9173
0
        case GGML_UNARY_OP_SIGMOID:
9174
0
            {
9175
0
                ggml_compute_forward_sigmoid(params, dst);
9176
0
            } break;
9177
0
        case GGML_UNARY_OP_GELU:
9178
0
            {
9179
0
                ggml_compute_forward_gelu(params, dst);
9180
0
            } break;
9181
0
        case GGML_UNARY_OP_GELU_ERF:
9182
0
            {
9183
0
                ggml_compute_forward_gelu_erf(params, dst);
9184
0
            } break;
9185
0
        case GGML_UNARY_OP_GELU_QUICK:
9186
0
            {
9187
0
                ggml_compute_forward_gelu_quick(params, dst);
9188
0
            } break;
9189
0
        case GGML_UNARY_OP_SILU:
9190
0
            {
9191
0
                ggml_compute_forward_silu(params, dst);
9192
0
            } break;
9193
0
        case GGML_UNARY_OP_HARDSWISH:
9194
0
            {
9195
0
                ggml_compute_forward_hardswish(params, dst);
9196
0
            } break;
9197
0
        case GGML_UNARY_OP_HARDSIGMOID:
9198
0
            {
9199
0
                ggml_compute_forward_hardsigmoid(params, dst);
9200
0
            } break;
9201
0
        case GGML_UNARY_OP_EXP:
9202
0
            {
9203
0
                ggml_compute_forward_exp(params, dst);
9204
0
            } break;
9205
0
        case GGML_UNARY_OP_FLOOR:
9206
0
            {
9207
0
                ggml_compute_forward_floor(params, dst);
9208
0
            } break;
9209
0
        case GGML_UNARY_OP_CEIL:
9210
0
            {
9211
0
                ggml_compute_forward_ceil(params, dst);
9212
0
            } break;
9213
0
        case GGML_UNARY_OP_ROUND:
9214
0
            {
9215
0
                ggml_compute_forward_round(params, dst);
9216
0
            } break;
9217
0
        case GGML_UNARY_OP_TRUNC:
9218
0
            {
9219
0
                ggml_compute_forward_trunc(params, dst);
9220
0
            } break;
9221
0
        case GGML_UNARY_OP_XIELU:
9222
0
            {
9223
0
                ggml_compute_forward_xielu(params, dst);
9224
0
            } break;
9225
0
        case GGML_UNARY_OP_EXPM1:
9226
0
            {
9227
0
                ggml_compute_forward_expm1(params, dst);
9228
0
            } break;
9229
0
        case GGML_UNARY_OP_SOFTPLUS:
9230
0
            {
9231
0
                ggml_compute_forward_softplus(params, dst);
9232
0
            } break;
9233
0
        default:
9234
0
            {
9235
0
                GGML_ABORT("fatal error");
9236
0
            }
9237
0
    }
9238
0
}
9239
9240
// ggml_compute_forward_glu
9241
9242
void ggml_compute_forward_glu(
9243
        const ggml_compute_params * params,
9244
0
        ggml_tensor * dst) {
9245
9246
0
    const ggml_glu_op op = ggml_get_glu_op(dst);
9247
9248
0
    switch (op) {
9249
0
        case GGML_GLU_OP_REGLU:
9250
0
            {
9251
0
                ggml_compute_forward_reglu(params, dst);
9252
0
            } break;
9253
0
        case GGML_GLU_OP_GEGLU:
9254
0
            {
9255
0
                ggml_compute_forward_geglu(params, dst);
9256
0
            } break;
9257
0
        case GGML_GLU_OP_SWIGLU:
9258
0
            {
9259
0
                ggml_compute_forward_swiglu(params, dst);
9260
0
            } break;
9261
0
        case GGML_GLU_OP_SWIGLU_OAI:
9262
0
            {
9263
0
                ggml_compute_forward_swiglu_oai(params, dst);
9264
0
            } break;
9265
0
        case GGML_GLU_OP_GEGLU_ERF:
9266
0
            {
9267
0
                ggml_compute_forward_geglu_erf(params, dst);
9268
0
            } break;
9269
0
        case GGML_GLU_OP_GEGLU_QUICK:
9270
0
            {
9271
0
                ggml_compute_forward_geglu_quick(params, dst);
9272
0
            } break;
9273
0
        default:
9274
0
            {
9275
0
                GGML_ABORT("fatal error");
9276
0
            }
9277
0
    }
9278
0
}
9279
9280
// ggml_compute_forward_get_rel_pos
9281
9282
static void ggml_compute_forward_get_rel_pos_f16(
9283
        const ggml_compute_params * params,
9284
0
        ggml_tensor * dst) {
9285
0
    GGML_UNUSED(params);
9286
9287
0
    const ggml_tensor * src0 = dst->src[0];
9288
9289
    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
9290
9291
0
    GGML_TENSOR_UNARY_OP_LOCALS
9292
9293
0
    const int64_t w = ne1;
9294
9295
0
    ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
9296
0
    ggml_fp16_t * dst_data  = (ggml_fp16_t *) dst->data;
9297
9298
0
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
9299
0
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
9300
0
            const int64_t pos = (w - i1 - 1) + i2;
9301
0
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
9302
0
                dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
9303
0
            }
9304
0
        }
9305
0
    }
9306
0
}
9307
9308
void ggml_compute_forward_get_rel_pos(
9309
        const ggml_compute_params * params,
9310
0
        ggml_tensor * dst) {
9311
9312
0
    const ggml_tensor * src0 = dst->src[0];
9313
9314
0
    switch (src0->type) {
9315
0
        case GGML_TYPE_F16:
9316
0
        case GGML_TYPE_BF16:
9317
0
            {
9318
0
                ggml_compute_forward_get_rel_pos_f16(params, dst);
9319
0
            } break;
9320
0
        default:
9321
0
            {
9322
0
                GGML_ABORT("fatal error");
9323
0
            }
9324
0
    }
9325
0
}
9326
9327
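The gather above selects, for each (query, key) pair, the embedding row at the shifted relative offset (w - 1) + (query - key), following the segment-anything reference linked in the code. A minimal float sketch of the same index mapping (the kernel itself operates on F16/BF16 data); the name get_rel_pos_ref is hypothetical:

    // Sketch only (not part of ops.cpp): gather of relative positional embeddings.
    #include <vector>

    static void get_rel_pos_ref(int w, int ne0,
                                const std::vector<float> & table,  // [ne0, 2*w - 1] embeddings
                                std::vector<float> & out) {        // [ne0, w, w], preallocated
        for (int q = 0; q < w; ++q) {              // i2 in the kernel
            for (int k = 0; k < w; ++k) {          // i1 in the kernel
                const int pos = (w - k - 1) + q;   // == (w - 1) + (q - k)
                for (int c = 0; c < ne0; ++c) {
                    out[((long long)q*w + k)*ne0 + c] = table[(long long)pos*ne0 + c];
                }
            }
        }
    }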
// ggml_compute_forward_add_rel_pos
9328
9329
static void ggml_compute_forward_add_rel_pos_f32(
9330
        const ggml_compute_params * params,
9331
0
        ggml_tensor * dst) {
9332
9333
0
    const ggml_tensor * src0 = dst->src[0];
9334
0
    const ggml_tensor * src1 = dst->src[1];
9335
0
    const ggml_tensor * src2 = dst->src[2];
9336
9337
0
    const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
9338
0
    if (!inplace) {
9339
0
        if (params->ith == 0) {
9340
0
            memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
9341
0
        }
9342
0
        ggml_barrier(params->threadpool);
9343
0
    }
9344
    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
9345
9346
0
    float * src1_data = (float *) src1->data;
9347
0
    float * src2_data = (float *) src2->data;
9348
0
    float * dst_data  = (float *) dst->data;
9349
9350
0
    const int64_t ne10 = src1->ne[0];
9351
0
    const int64_t ne11 = src1->ne[1];
9352
0
    const int64_t ne12 = src1->ne[2];
9353
0
    const int64_t ne13 = src1->ne[3];
9354
9355
0
    const int ith = params->ith;
9356
0
    const int nth = params->nth;
9357
9358
    // total patches in dst
9359
0
    const int np = ne13;
9360
9361
    // patches per thread
9362
0
    const int dp = (np + nth - 1)/nth;
9363
9364
    // patch range for this thread
9365
0
    const int ip0 = dp*ith;
9366
0
    const int ip1 = MIN(ip0 + dp, np);
9367
9368
0
    for (int64_t i13 = ip0; i13 < ip1; ++i13) {
9369
0
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
9370
0
            for (int64_t i11 = 0; i11 < ne11; ++i11) {
9371
0
                const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
9372
0
                for (int64_t i10 = 0; i10 < ne10; ++i10) {
9373
0
                    const int64_t jp0  = jp1 + i10;
9374
0
                    const float src1_e = src1_data[jp0];
9375
0
                    const float src2_e = src2_data[jp0];
9376
9377
0
                    const int64_t jdh = jp0 * ne10;
9378
0
                    const int64_t jdw = jdh - (ne10 - 1) * i10;
9379
9380
0
                    for (int64_t j = 0; j < ne10; ++j) {
9381
0
                        dst_data[jdh + j     ] += src2_e;
9382
0
                        dst_data[jdw + j*ne10] += src1_e;
9383
0
                    }
9384
0
                }
9385
0
            }
9386
0
        }
9387
0
    }
9388
0
}
9389
9390
void ggml_compute_forward_add_rel_pos(
9391
        const ggml_compute_params * params,
9392
0
        ggml_tensor * dst) {
9393
9394
0
    const ggml_tensor * src0 = dst->src[0];
9395
9396
0
    switch (src0->type) {
9397
0
        case GGML_TYPE_F32:
9398
0
            {
9399
0
                ggml_compute_forward_add_rel_pos_f32(params, dst);
9400
0
            } break;
9401
0
        default:
9402
0
            {
9403
0
                GGML_ABORT("fatal error");
9404
0
            }
9405
0
    }
9406
0
}
9407
9408
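The loop above adds decomposed relative-position terms to a flattened attention map, in the spirit of SAM's attn + rel_h[..., :, None] + rel_w[..., None, :]: for each query position one scalar is broadcast across a row and the other down a column of an n x n tile, where n = ne10. A minimal sketch of that inner broadcast for a single tile; the names are placeholders and the exact correspondence of src1/src2 to the height and width terms is not asserted here:

    // Sketch only (not part of ops.cpp): inner broadcast of add_rel_pos on one n x n tile.
    // rel_a[i] is spread across row i (like src2_e above),
    // rel_b[i] is spread down column i (like src1_e above).
    static void add_rel_pos_tile(int n, float * attn /* [n, n], row-major */,
                                 const float * rel_a, const float * rel_b) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                attn[i*n + j] += rel_a[i];   // broadcast over the last dimension
                attn[j*n + i] += rel_b[i];   // broadcast over the second-to-last dimension
            }
        }
    }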
// ggml_compute_forward_rwkv_wkv6
9409
9410
static void ggml_compute_forward_rwkv_wkv6_f32(
9411
        const ggml_compute_params * params,
9412
0
        ggml_tensor * dst) {
9413
0
    const int64_t T = dst->src[1]->ne[2];
9414
0
    const int64_t C = dst->ne[0];
9415
0
    const int64_t HEADS = dst->src[1]->ne[1];
9416
0
    const int64_t n_seqs = dst->src[5]->ne[1];
9417
0
    const int64_t head_size = C / HEADS;
9418
9419
0
    float * dst_data = (float *) dst->data;
9420
0
    float * state = ((float *) dst->data) + C * T;
9421
9422
0
    const int ith = params->ith;
9423
0
    const int nth = params->nth;
9424
9425
0
    if (ith >= HEADS) {
9426
0
        return;
9427
0
    }
9428
9429
0
    const int h_start = (HEADS * ith) / nth;
9430
0
    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
9431
0
                (HEADS * (ith + 1)) / nth : HEADS;
9432
9433
0
    float * k =          (float *) dst->src[0]->data;
9434
0
    float * v =          (float *) dst->src[1]->data;
9435
0
    float * r =          (float *) dst->src[2]->data;
9436
0
    float * time_faaaa = (float *) dst->src[3]->data;
9437
0
    float * time_decay = (float *) dst->src[4]->data;
9438
9439
0
    size_t t_stride = HEADS * head_size; // Same as C
9440
9441
0
    size_t h_stride = C / HEADS;
9442
0
    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
9443
0
    size_t h_stride_2d = head_size * head_size;
9444
9445
0
    if (ith == 0) {
9446
0
        memset(dst_data, 0, T * C * sizeof(float));
9447
0
    }
9448
0
    ggml_barrier(params->threadpool);
9449
9450
9451
0
    #if defined(__AVX__) && !defined(__AVX512F__)
9452
0
        #define GGML_F32X GGML_F32x8
9453
0
        #define GGML_F32X_SET1 GGML_F32x8_SET1
9454
0
        #define GGML_F32X_LOAD GGML_F32x8_LOAD
9455
0
        #define GGML_F32X_STORE GGML_F32x8_STORE
9456
0
        #define GGML_F32X_MUL GGML_F32x8_MUL
9457
0
        #define GGML_F32X_FMA GGML_F32x8_FMA
9458
0
        #define WKV_VECTOR_SIZE 8
9459
    #elif defined(__AVX512F__)
9460
        #define GGML_F32X GGML_F32x16
9461
        #define GGML_F32X_SET1 GGML_F32x16_SET1
9462
        #define GGML_F32X_LOAD GGML_F32x16_LOAD
9463
        #define GGML_F32X_STORE GGML_F32x16_STORE
9464
        #define GGML_F32X_MUL GGML_F32x16_MUL
9465
        #define GGML_F32X_FMA GGML_F32x16_FMA
9466
        #define WKV_VECTOR_SIZE 16
9467
    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
9468
        #define GGML_F32X GGML_F32xt
9469
        #define GGML_F32X_SET1 GGML_F32xt_SET1
9470
        #define GGML_F32X_LOAD GGML_F32xt_LOAD
9471
        #define GGML_F32X_STORE GGML_F32xt_STORE
9472
        #define GGML_F32X_MUL GGML_F32xt_MUL
9473
        #define GGML_F32X_FMA GGML_F32xt_FMA
9474
        #define WKV_VECTOR_SIZE 8
9475
    #elif defined(__ARM_NEON) && defined(__aarch64__)
9476
        #define GGML_F32X GGML_F32x4
9477
        #define GGML_F32X_SET1 GGML_F32x4_SET1
9478
        #define GGML_F32X_LOAD GGML_F32x4_LOAD
9479
        #define GGML_F32X_STORE GGML_F32x4_STORE
9480
        #define GGML_F32X_MUL GGML_F32x4_MUL
9481
        #define GGML_F32X_FMA GGML_F32x4_FMA
9482
        #define WKV_VECTOR_SIZE 4
9483
    #endif
9484
9485
0
    #ifdef WKV_VECTOR_SIZE
9486
0
        int wkv_vector_size;
9487
        #if defined(__ARM_FEATURE_SVE)
9488
            wkv_vector_size = svcntw();
9489
        #else
9490
0
            wkv_vector_size = WKV_VECTOR_SIZE;
9491
0
        #endif
9492
0
        const int64_t vec_count = head_size / wkv_vector_size;
9493
9494
0
        for (int64_t t = 0; t < T; t++) {
9495
0
            size_t t_offset = t * t_stride;
9496
0
            size_t state_offset = head_size * C * (t / (T / n_seqs));
9497
0
            float * state_cur = state + state_offset;
9498
0
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
9499
9500
0
            for (int64_t h = h_start; h < h_end; h++) {
9501
0
                size_t h_offset = h * h_stride;
9502
0
                size_t t_h_offset = t_offset + h_offset;
9503
0
                size_t h_2d_offset = h * h_stride_2d;
9504
9505
0
                for (int64_t i = 0; i < head_size; i++) {
9506
0
                    size_t t_h_i_offset = t_h_offset + i;
9507
0
                    size_t h_i_offset = h_offset + i;
9508
0
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
9509
9510
0
                    float k_val = k[t_h_i_offset];
9511
0
                    float r_val = r[t_h_i_offset];
9512
0
                    float time_faaaa_val = time_faaaa[h_i_offset];
9513
0
                    float time_decay_val = time_decay[t_h_i_offset];
9514
9515
                    // Broadcast scalar values to vectors
9516
0
                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
9517
0
                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
9518
0
                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
9519
0
                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
9520
9521
0
                    for (int64_t j = 0; j < vec_count; j++) {
9522
0
                        size_t base_j = j * wkv_vector_size;
9523
0
                        size_t t_h_j_offset = t_h_offset + base_j;
9524
0
                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
9525
9526
                        // Load a full SIMD vector of elements at once
9527
0
                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
9528
0
                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
9529
0
                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
9530
9531
                        // Compute kv = v * k
9532
0
                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
9533
9534
                        // Compute temp = kv * time_faaaa + prev_state
9535
0
                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
9536
9537
                        // Update dst: dst += temp * r
9538
0
                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
9539
0
                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
9540
9541
                        // Update state: state = prev_state * time_decay + kv
9542
0
                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
9543
0
                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
9544
0
                    }
9545
9546
                    // Handle remaining elements; normally unused since head_size is a multiple of the vector size.
9547
0
                    for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
9548
0
                        size_t t_h_j_offset = t_h_offset + j;
9549
0
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
9550
0
                        float v_val = v[t_h_j_offset];
9551
0
                        float kv_val = v_val * k_val;
9552
0
                        float prev_state_val = state_prev[h_2d_i_j_offset];
9553
0
                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
9554
0
                        dst_data[t_h_j_offset] += temp_val * r_val;
9555
0
                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
9556
0
                    }
9557
0
                }
9558
0
            }
9559
0
        }
9560
9561
    #else
9562
        // basically fused operations:
9563
        // dst = r @ (time_faaaa * (k @ v) + state),
9564
        // state = time_decay * state + (k @ v),
9565
        // applied recurrently, token by token
9566
        for (int64_t t = 0; t < T; t++) {
9567
            size_t t_offset = t * t_stride;
9568
            size_t state_offset = head_size * C * (t / (T / n_seqs));
9569
            float * state_cur = state + state_offset;
9570
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
9571
9572
            for (int64_t h = h_start; h < h_end; h++) {
9573
                size_t h_offset = h * h_stride;
9574
                size_t t_h_offset = t_offset + h_offset;
9575
                size_t h_2d_offset = h * h_stride_2d;
9576
9577
                for (int64_t i = 0; i < head_size; i++) {
9578
                    size_t t_h_i_offset = t_h_offset + i;
9579
                    size_t h_i_offset = h_offset + i;
9580
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
9581
9582
                    float k_val = k[t_h_i_offset];
9583
                    float r_val = r[t_h_i_offset];
9584
                    float time_faaaa_val = time_faaaa[h_i_offset];
9585
                    // RWKV v6: different time_decay for each token.
9586
                    float time_decay_val = time_decay[t_h_i_offset];
9587
9588
                    for (int64_t j = 0; j < head_size; j++) {
9589
                        size_t t_h_j_offset = t_h_offset + j;
9590
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
9591
9592
                        float v_val = v[t_h_j_offset];
9593
                        float kv_val = v_val * k_val;
9594
                        float prev_state_val = state_prev[h_2d_i_j_offset];
9595
                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
9596
                        dst_data[t_h_j_offset] += temp_val * r_val;
9597
                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
9598
                    }
9599
                }
9600
            }
9601
        }
9602
    #endif
9603
0
}
9604
9605
9606
void ggml_compute_forward_rwkv_wkv6(
9607
        const ggml_compute_params * params,
9608
0
        ggml_tensor * dst) {
9609
9610
0
    const ggml_tensor * src0 = dst->src[0];
9611
9612
0
    switch (src0->type) {
9613
0
        case GGML_TYPE_F32:
9614
0
            {
9615
0
                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
9616
0
            } break;
9617
0
        default:
9618
0
            {
9619
0
                GGML_ABORT("fatal error");
9620
0
            }
9621
0
    }
9622
0
}
9623
9624
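Written out per head, the recurrence above is out[j] += r[i] * (time_faaaa[i] * k[i] * v[j] + state[i][j]) and state[i][j] = time_decay[i] * state[i][j] + k[i] * v[j], applied token by token. A minimal scalar sketch of one token step for a single head, matching the non-SIMD fallback; the name wkv6_step is hypothetical:

    // Sketch only (not part of ops.cpp): one token of the WKV6 recurrence for one head.
    static void wkv6_step(int n,
                          float * state,             // [n, n] row-major, updated in place
                          float * out,               // [n], zero-initialized for this token
                          const float * k, const float * v, const float * r,
                          const float * time_faaaa, const float * time_decay) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                const float kv = k[i] * v[j];
                out[j]         += r[i] * (kv * time_faaaa[i] + state[i*n + j]);
                state[i*n + j]  = state[i*n + j] * time_decay[i] + kv;
            }
        }
    }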
// ggml_compute_forward_gla
9625
9626
static void ggml_compute_forward_gla_f32(
9627
        const ggml_compute_params * params,
9628
0
        ggml_tensor * dst) {
9629
0
    const int64_t T = dst->src[1]->ne[2];
9630
0
    const int64_t C = dst->ne[0];
9631
0
    const int64_t HEADS = dst->src[1]->ne[1];
9632
0
    const int64_t n_seqs = dst->src[4]->ne[1];
9633
0
    const int64_t head_size = C / HEADS;
9634
0
    const float scale = ggml_get_op_params_f32(dst, 0);
9635
9636
0
    float * dst_data = (float *) dst->data;
9637
0
    float * state = ((float *) dst->data) + C * T;
9638
9639
0
    const int ith = params->ith;
9640
0
    const int nth = params->nth;
9641
9642
0
    if (ith >= HEADS) {
9643
0
        return;
9644
0
    }
9645
9646
0
    const int h_start = (HEADS * ith) / nth;
9647
0
    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
9648
0
                (HEADS * (ith + 1)) / nth : HEADS;
9649
9650
0
    float * k = (float *) dst->src[0]->data;
9651
0
    float * v = (float *) dst->src[1]->data;
9652
0
    float * q = (float *) dst->src[2]->data;
9653
0
    float * g = (float *) dst->src[3]->data;
9654
9655
0
    size_t t_stride = HEADS * head_size; // Same as C
9656
9657
0
    size_t h_stride = C / HEADS;
9658
0
    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
9659
0
    size_t h_stride_2d = head_size * head_size;
9660
9661
0
    if (ith == 0) {
9662
0
        memset(dst_data, 0, T * C * sizeof(float));
9663
0
    }
9664
0
    ggml_barrier(params->threadpool);
9665
9666
9667
0
    #if defined(__AVX__) && !defined(__AVX512F__)
9668
0
        #define GGML_F32X GGML_F32x8
9669
0
        #define GGML_F32X_SET1 GGML_F32x8_SET1
9670
0
        #define GGML_F32X_LOAD GGML_F32x8_LOAD
9671
0
        #define GGML_F32X_STORE GGML_F32x8_STORE
9672
0
        #define GGML_F32X_MUL GGML_F32x8_MUL
9673
0
        #define GGML_F32X_FMA GGML_F32x8_FMA
9674
0
        #define GLA_VECTOR_SIZE 8
9675
    #elif defined(__AVX512F__)
9676
        #define GGML_F32X GGML_F32x16
9677
        #define GGML_F32X_SET1 GGML_F32x16_SET1
9678
        #define GGML_F32X_LOAD GGML_F32x16_LOAD
9679
        #define GGML_F32X_STORE GGML_F32x16_STORE
9680
        #define GGML_F32X_MUL GGML_F32x16_MUL
9681
        #define GGML_F32X_FMA GGML_F32x16_FMA
9682
        #define GLA_VECTOR_SIZE 16
9683
    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
9684
        #define GGML_F32X GGML_F32xt
9685
        #define GGML_F32X_SET1 GGML_F32xt_SET1
9686
        #define GGML_F32X_LOAD GGML_F32xt_LOAD
9687
        #define GGML_F32X_STORE GGML_F32xt_STORE
9688
        #define GGML_F32X_MUL GGML_F32xt_MUL
9689
        #define GGML_F32X_FMA GGML_F32xt_FMA
9690
        #define GLA_VECTOR_SIZE 8
9691
    #elif defined(__ARM_NEON) && defined(__aarch64__)
9692
        #define GGML_F32X GGML_F32x4
9693
        #define GGML_F32X_SET1 GGML_F32x4_SET1
9694
        #define GGML_F32X_LOAD GGML_F32x4_LOAD
9695
        #define GGML_F32X_STORE GGML_F32x4_STORE
9696
        #define GGML_F32X_MUL GGML_F32x4_MUL
9697
        #define GGML_F32X_FMA GGML_F32x4_FMA
9698
        #define GLA_VECTOR_SIZE 4
9699
    #endif
9700
9701
0
    #ifdef GLA_VECTOR_SIZE
9702
0
        int gla_vector_size;
9703
        #if defined(__ARM_FEATURE_SVE)
9704
            gla_vector_size = svcntw();
9705
        #else
9706
0
            gla_vector_size = GLA_VECTOR_SIZE;
9707
0
        #endif
9708
0
        const int64_t vec_count = head_size / gla_vector_size;
9709
9710
0
        for (int64_t t = 0; t < T; t++) {
9711
0
            size_t t_offset = t * t_stride;
9712
0
            size_t state_offset = head_size * C * (t / (T / n_seqs));
9713
0
            float * state_cur = state + state_offset;
9714
0
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
9715
9716
0
            for (int64_t h = h_start; h < h_end; h++) {
9717
0
                size_t h_offset = h * h_stride;
9718
0
                size_t t_h_offset = t_offset + h_offset;
9719
0
                size_t h_2d_offset = h * h_stride_2d;
9720
9721
0
                for (int64_t i = 0; i < head_size; i++) {
9722
0
                    size_t t_h_i_offset = t_h_offset + i;
9723
0
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
9724
9725
0
                    float k_val = k[t_h_i_offset];
9726
0
                    float q_val = q[t_h_i_offset] * scale;
9727
0
                    float g_val = g[t_h_i_offset];
9728
9729
                    // Broadcast scalar values to vectors
9730
0
                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
9731
0
                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
9732
0
                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
9733
9734
0
                    for (int64_t j = 0; j < vec_count; j++) {
9735
0
                        size_t base_j = j * gla_vector_size;
9736
0
                        size_t t_h_j_offset = t_h_offset + base_j;
9737
0
                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
9738
9739
                        // Load a full SIMD vector of elements at once
9740
0
                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
9741
0
                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
9742
0
                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
9743
9744
                        // Compute kv = v * k
9745
0
                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
9746
9747
                        // Compute temp = prev_state * g + kv
9748
0
                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
9749
9750
                        // Update dst: dst += temp * q
9751
0
                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
9752
0
                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
9753
9754
                        // Update state
9755
0
                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
9756
0
                    }
9757
9758
                    // Handle remaining elements; normally unused since head_size is a multiple of the vector size.
9759
0
                    for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
9760
0
                        size_t t_h_j_offset = t_h_offset + j;
9761
0
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
9762
0
                        float v_val = v[t_h_j_offset];
9763
0
                        float kv_val = v_val * k_val;
9764
0
                        float prev_state_val = state_prev[h_2d_i_j_offset];
9765
0
                        float temp_val = kv_val + prev_state_val * g_val;
9766
0
                        dst_data[t_h_j_offset] += temp_val * q_val;
9767
0
                        state_cur[h_2d_i_j_offset] = temp_val;
9768
0
                    }
9769
0
                }
9770
0
            }
9771
0
        }
9772
9773
    #else
9774
        for (int64_t t = 0; t < T; t++) {
9775
            size_t t_offset = t * t_stride;
9776
            size_t state_offset = head_size * C * (t / (T / n_seqs));
9777
            float * state_cur = state + state_offset;
9778
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
9779
9780
            for (int64_t h = h_start; h < h_end; h++) {
9781
                size_t h_offset = h * h_stride;
9782
                size_t t_h_offset = t_offset + h_offset;
9783
                size_t h_2d_offset = h * h_stride_2d;
9784
9785
                for (int64_t i = 0; i < head_size; i++) {
9786
                    size_t t_h_i_offset = t_h_offset + i;
9787
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
9788
9789
                    float k_val = k[t_h_i_offset];
9790
                    float q_val = q[t_h_i_offset] * scale;
9791
                    float g_val = g[t_h_i_offset];
9792
9793
                    for (int64_t j = 0; j < head_size; j++) {
9794
                        size_t t_h_j_offset = t_h_offset + j;
9795
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
9796
9797
                        float v_val = v[t_h_j_offset];
9798
                        float kv_val = v_val * k_val;
9799
                        float prev_state_val = state_prev[h_2d_i_j_offset];
9800
                        float temp_val = prev_state_val * g_val + kv_val;
9801
                        dst_data[t_h_j_offset] += temp_val * q_val;
9802
                        state_cur[h_2d_i_j_offset] = temp_val;
9803
                    }
9804
                }
9805
            }
9806
        }
9807
    #endif
9808
0
}
9809
9810
9811
void ggml_compute_forward_gla(
9812
        const ggml_compute_params * params,
9813
0
        ggml_tensor * dst) {
9814
9815
0
    const ggml_tensor * src0 = dst->src[0];
9816
9817
0
    switch (src0->type) {
9818
0
        case GGML_TYPE_F32:
9819
0
            {
9820
0
                ggml_compute_forward_gla_f32(params, dst);
9821
0
            } break;
9822
0
        default:
9823
0
            {
9824
0
                GGML_ABORT("fatal error");
9825
0
            }
9826
0
    }
9827
0
}
9828
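The gated-linear-attention kernel above uses the simpler recurrence state[i][j] = g[i] * state[i][j] + k[i] * v[j] followed by out[j] += scale * q[i] * state[i][j]. A minimal scalar sketch of one token step for a single head, matching the non-SIMD fallback; the name gla_step is hypothetical:

    // Sketch only (not part of ops.cpp): one token of the GLA recurrence for one head.
    static void gla_step(int n, float scale,
                         float * state,              // [n, n] row-major, updated in place
                         float * out,                // [n], zero-initialized for this token
                         const float * k, const float * v, const float * q, const float * g) {
        for (int i = 0; i < n; ++i) {
            const float qs = q[i] * scale;
            for (int j = 0; j < n; ++j) {
                const float s = state[i*n + j] * g[i] + k[i] * v[j];  // decay old state, add k[i]*v[j]
                out[j]        += s * qs;
                state[i*n + j] = s;
            }
        }
    }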
9829
0
static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
9830
0
    const struct ggml_tensor * src0 = dst->src[0];  // A (lower triangular)
9831
0
    const struct ggml_tensor * src1 = dst->src[1];  // B (RHS)
9832
9833
0
    GGML_TENSOR_BINARY_OP_LOCALS;
9834
9835
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
9836
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
9837
0
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
9838
9839
0
    GGML_ASSERT(ne00 == ne01); // A must be square
9840
0
    GGML_ASSERT(ne0  == ne10); // solution cols == B cols
9841
0
    GGML_ASSERT(ne1  == ne11); // solution rows == B rows
9842
9843
0
    GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
9844
0
    GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
9845
9846
0
    const int ith = params->ith;
9847
0
    const int nth = params->nth;
9848
9849
0
    const int64_t k = ne10;   // number of RHS columns
9850
0
    const int64_t n = ne11;   // A is n×n
9851
0
    const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
9852
9853
    // chunks per thread
9854
0
    const int64_t dr = (nr + nth - 1)/nth;
9855
9856
    // chunk range for this thread
9857
0
    const int64_t ir0 = dr*ith;
9858
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
9859
9860
0
    const float * A = (const float *) src0->data;  // [n, n, B1, B2]
9861
0
    const float * B = (const float *) src1->data;  // [n, k, B1, B2]
9862
0
          float * X = (      float *) dst->data;   // [n, k, B1, B2]
9863
9864
0
    for (int64_t ir = ir0; ir < ir1; ++ir) {
9865
0
        const int64_t i03 = ir/(ne02*k);
9866
0
        const int64_t i02 = (ir - i03*ne02*k)/k;
9867
0
        const int64_t i01 = (ir - i03*ne02*k - i02*k);
9868
9869
0
        const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
9870
0
        const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
9871
9872
0
        float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
9873
9874
0
        for (int64_t i00 = 0; i00 < n; ++i00) {
9875
0
            float sum = 0.0f;
9876
0
            for (int64_t t = 0; t < i00; ++t) {
9877
0
                sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
9878
0
            }
9879
9880
0
            const float diag = A_batch[i00 * n + i00];
9881
0
            assert(diag != 0.0f && "Zero diagonal in triangular matrix");
9882
9883
0
            X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
9884
0
        }
9885
0
    }
9886
0
}
9887
9888
0
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
9889
0
    const ggml_tensor * src0 = dst->src[0];
9890
0
    const ggml_tensor * src1 = dst->src[1];
9891
9892
0
    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
9893
0
        ggml_compute_forward_solve_tri_f32(params, dst);
9894
0
    } else {
9895
0
        GGML_ABORT("fatal error");
9896
0
    }
9897
0
}
9898
9899
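ggml_compute_forward_solve_tri_f32 above solves A * X = B for a lower-triangular A by forward substitution, one right-hand-side column per work item. A minimal sketch of the same substitution for a single column; the name solve_lower_tri is hypothetical:

    // Sketch only (not part of ops.cpp): forward substitution for a lower-triangular system.
    #include <cassert>

    static void solve_lower_tri(int n,
                                const float * A,   // [n, n] row-major, lower triangular
                                const float * b,   // [n] right-hand side
                                float * x) {       // [n] solution, written in order
        for (int i = 0; i < n; ++i) {
            float sum = 0.0f;
            for (int t = 0; t < i; ++t) {
                sum += A[i*n + t] * x[t];          // contributions of already-solved entries
            }
            assert(A[i*n + i] != 0.0f && "zero diagonal in triangular matrix");
            x[i] = (b[i] - sum) / A[i*n + i];
        }
    }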
// ggml_compute_forward_rwkv_wkv7
9900
9901
static void ggml_compute_forward_rwkv_wkv7_f32(
9902
        const ggml_compute_params * params,
9903
0
        ggml_tensor * dst) {
9904
0
    const int64_t T = dst->src[1]->ne[2];
9905
0
    const int64_t C = dst->ne[0];
9906
0
    const int64_t HEADS = dst->src[1]->ne[1];
9907
0
    const int64_t n_seqs = dst->src[6]->ne[1];
9908
0
    const int64_t head_size = C / HEADS;
9909
9910
0
    float * dst_data = (float *) dst->data;
9911
0
    float * state = ((float *) dst->data) + C * T;
9912
9913
0
    const int ith = params->ith;
9914
0
    const int nth = params->nth;
9915
9916
0
    if (ith >= HEADS) {
9917
0
        return;
9918
0
    }
9919
9920
0
    const int h_start = (HEADS * ith) / nth;
9921
0
    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
9922
0
                (HEADS * (ith + 1)) / nth : HEADS;
9923
9924
0
    float * r = (float *) dst->src[0]->data;
9925
0
    float * w = (float *) dst->src[1]->data;
9926
0
    float * k = (float *) dst->src[2]->data;
9927
0
    float * v = (float *) dst->src[3]->data;
9928
0
    float * a = (float *) dst->src[4]->data;
9929
0
    float * b = (float *) dst->src[5]->data;
9930
9931
0
    int64_t t_stride = HEADS * head_size; // Same as C
9932
9933
0
    int64_t h_stride = C / HEADS;
9934
0
    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
9935
0
    int64_t h_stride_2d = head_size * head_size;
9936
9937
0
    #if defined(GGML_SIMD)
9938
        #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
9939
            // route to the scalar implementation   // TODO: write SVE and RVV code
9940
            for (int64_t t = 0; t < T; t++) {
9941
                int64_t t_offset = t * t_stride;
9942
                int64_t state_offset = head_size * C * (t / (T / n_seqs));
9943
                float * state_cur = state + state_offset;
9944
                float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
9945
9946
                for (int64_t h = h_start; h < h_end; h++) {
9947
                    int64_t h_offset = h * h_stride;
9948
                    int64_t t_h_offset = t_offset + h_offset;
9949
                    int64_t h_2d_offset = h * h_stride_2d;
9950
9951
                    for (int64_t i = 0; i < head_size; i++) {
9952
                        int64_t t_h_i_offset = t_h_offset + i;
9953
                        int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
9954
9955
                        float v_val = v[t_h_i_offset];
9956
9957
                        float sa = 0, result = 0;
9958
                        for (int64_t j = 0; j < head_size; j++) {
9959
                            sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
9960
                        }
9961
9962
                        for (int64_t j = 0; j < head_size; j++) {
9963
                            int64_t t_h_j_offset = t_h_offset + j;
9964
                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
9965
9966
                            float r_val = r[t_h_j_offset];
9967
                            float w_val = w[t_h_j_offset];
9968
                            float k_val = k[t_h_j_offset];
9969
                            float b_val = b[t_h_j_offset];
9970
                            float kv_val = v_val * k_val;
9971
                            float prev_state_val = state_prev[h_2d_i_j_offset];
9972
                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
9973
                            result += state_cur[h_2d_i_j_offset] * r_val;
9974
                        }
9975
                        dst_data[t_h_i_offset] = result;
9976
                    }
9977
                }
9978
            }
9979
        #else
9980
0
            for (int64_t t = 0; t < T; t++) {
9981
0
                int64_t t_offset = t * t_stride;
9982
0
                int64_t state_offset = head_size * C * (t / (T / n_seqs));
9983
0
                float * state_cur = state + state_offset;
9984
0
                float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
9985
9986
0
                for (int64_t h = h_start; h < h_end; h++) {
9987
0
                    int64_t h_offset = h * h_stride;
9988
0
                    int64_t t_h_offset = t_offset + h_offset;
9989
0
                    int64_t h_2d_offset = h * h_stride_2d;
9990
9991
0
                    for (int64_t ii = 0; ii < head_size; ii++) {
9992
0
                        int64_t t_h_i_offset = t_h_offset + ii;
9993
0
                        int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
9994
9995
0
                        GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
9996
9997
0
                        float sa = 0;
9998
0
                        {
9999
0
                            GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
10000
0
                            GGML_F32_VEC ax[GGML_F32_ARR];
10001
0
                            GGML_F32_VEC ay[GGML_F32_ARR];
10002
0
                            for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
10003
0
                                for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
10004
0
                                    ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
10005
0
                                    ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
10006
0
                                    sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
10007
0
                                }
10008
0
                            }
10009
0
                            GGML_F32_VEC_REDUCE(sa, sum);
10010
0
                        }
10011
10012
0
                        GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
10013
10014
0
                        int64_t j = 0;
10015
0
                        GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
10016
0
                        for (; j < head_size; j += GGML_F32_STEP) {
10017
0
                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
10018
0
                                int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
10019
0
                                int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
10020
10021
0
                                GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
10022
0
                                GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
10023
0
                                GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
10024
0
                                GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
10025
10026
0
                                k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
10027
10028
0
                                GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
10029
                                // kv + s * decay + sa * b
10030
0
                                state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
10031
0
                                state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
10032
0
                                GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
10033
10034
0
                                result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
10035
0
                            }
10036
0
                        }
10037
0
                        GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
10038
10039
                        // scalar tail for any left-overs (there shouldn't be any)
10040
0
                        for (; j < head_size; j++) {
10041
0
                            int64_t t_h_j_offset = t_h_offset + j;
10042
0
                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
10043
10044
0
                            float r_val = r[t_h_j_offset];
10045
0
                            float w_val = w[t_h_j_offset];
10046
0
                            float k_val = k[t_h_j_offset];
10047
0
                            float b_val = b[t_h_j_offset];
10048
0
                            float kv_val = v[t_h_i_offset] * k_val;
10049
10050
0
                            float prev_state_val = state_prev[h_2d_i_j_offset];
10051
0
                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
10052
0
                            dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
10053
0
                        }
10054
0
                    }
10055
0
                }
10056
0
            }
10057
0
        #endif
10058
    #else
10059
        for (int64_t t = 0; t < T; t++) {
10060
            int64_t t_offset = t * t_stride;
10061
            int64_t state_offset = head_size * C * (t / (T / n_seqs));
10062
            float * state_cur = state + state_offset;
10063
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
10064
10065
            for (int64_t h = h_start; h < h_end; h++) {
10066
                int64_t h_offset = h * h_stride;
10067
                int64_t t_h_offset = t_offset + h_offset;
10068
                int64_t h_2d_offset = h * h_stride_2d;
10069
10070
                for (int64_t i = 0; i < head_size; i++) {
10071
                    int64_t t_h_i_offset = t_h_offset + i;
10072
                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
10073
10074
                    float v_val = v[t_h_i_offset];
10075
10076
                    float sa = 0, result = 0;
10077
                    for (int64_t j = 0; j < head_size; j++) {
10078
                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
10079
                    }
10080
10081
                    for (int64_t j = 0; j < head_size; j++) {
10082
                        int64_t t_h_j_offset = t_h_offset + j;
10083
                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
10084
10085
                        float r_val = r[t_h_j_offset];
10086
                        float w_val = w[t_h_j_offset];
10087
                        float k_val = k[t_h_j_offset];
10088
                        float b_val = b[t_h_j_offset];
10089
                        float kv_val = v_val * k_val;
10090
                        float prev_state_val = state_prev[h_2d_i_j_offset];
10091
                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
10092
                        result += state_cur[h_2d_i_j_offset] * r_val;
10093
                    }
10094
                    dst_data[t_h_i_offset] = result;
10095
                }
10096
            }
10097
        }
10098
    #endif
10099
0
}
10100
10101
10102
void ggml_compute_forward_rwkv_wkv7(
10103
        const ggml_compute_params * params,
10104
0
        ggml_tensor * dst) {
10105
10106
0
    const ggml_tensor * src0 = dst->src[0];
10107
10108
0
    switch (src0->type) {
10109
0
        case GGML_TYPE_F32:
10110
0
            {
10111
0
                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
10112
0
            } break;
10113
0
        default:
10114
0
            {
10115
0
                GGML_ABORT("fatal error");
10116
0
            }
10117
0
    }
10118
0
}
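
For reference, the recurrence computed by the kernel above, for a single head and a single time step, can be written as the following standalone scalar sketch. The helper name wkv7_step_ref is hypothetical and the flat row-major layout of the state is an assumption; this mirrors the scalar branch of ggml_compute_forward_rwkv_wkv7_f32 and is not part of ops.cpp.

#include <cstdint>

// One WKV7 step for one head: state is head_size x head_size (row i = output dim),
// r, w, k, v, a, b are per-token vectors of length head_size.
static void wkv7_step_ref(int64_t head_size,
                          const float * r, const float * w, const float * k,
                          const float * v, const float * a, const float * b,
                          const float * state_prev, float * state_cur, float * out) {
    for (int64_t i = 0; i < head_size; i++) {
        const float v_val = v[i];

        // sa = <a, state_prev[i, :]>
        float sa = 0.0f;
        for (int64_t j = 0; j < head_size; j++) {
            sa += a[j] * state_prev[i*head_size + j];
        }

        float result = 0.0f;
        for (int64_t j = 0; j < head_size; j++) {
            const float prev = state_prev[i*head_size + j];
            // state = prev*decay + k*v + sa*b
            const float cur  = prev * w[j] + k[j] * v_val + sa * b[j];
            state_cur[i*head_size + j] = cur;
            result += cur * r[j];
        }
        out[i] = result;
    }
}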
10119
10120
// ggml_compute_forward_map_custom1
10121
10122
void ggml_compute_forward_map_custom1(
10123
        const ggml_compute_params * params,
10124
0
              ggml_tensor * dst) {
10125
10126
0
    const ggml_tensor * a = dst->src[0];
10127
10128
0
    struct ggml_map_custom1_op_params p;
10129
0
    memcpy(&p, dst->op_params, sizeof(p));
10130
10131
0
    p.fun(dst, a, params->ith, params->nth, p.userdata);
10132
0
}
10133
10134
// ggml_compute_forward_map_custom2
10135
10136
void ggml_compute_forward_map_custom2(
10137
        const ggml_compute_params * params,
10138
0
              ggml_tensor * dst) {
10139
10140
0
    const ggml_tensor * a = dst->src[0];
10141
0
    const ggml_tensor * b = dst->src[1];
10142
10143
0
    struct ggml_map_custom2_op_params p;
10144
0
    memcpy(&p, dst->op_params, sizeof(p));
10145
10146
0
    p.fun(dst, a, b, params->ith, params->nth, p.userdata);
10147
0
}
10148
10149
// ggml_compute_forward_map_custom3
10150
10151
void ggml_compute_forward_map_custom3(
10152
        const ggml_compute_params * params,
10153
0
              ggml_tensor * dst) {
10154
10155
0
    const ggml_tensor * a = dst->src[0];
10156
0
    const ggml_tensor * b = dst->src[1];
10157
0
    const ggml_tensor * c = dst->src[2];
10158
10159
0
    struct ggml_map_custom3_op_params p;
10160
0
    memcpy(&p, dst->op_params, sizeof(p));
10161
10162
0
    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
10163
0
}
10164
10165
// ggml_compute_forward_custom
10166
10167
void ggml_compute_forward_custom(
10168
    const struct ggml_compute_params * params,
10169
0
          struct ggml_tensor * dst) {
10170
10171
0
    struct ggml_custom_op_params p;
10172
0
    memcpy(&p, dst->op_params, sizeof(p));
10173
10174
0
    p.fun(dst, params->ith, params->nth, p.userdata);
10175
0
}
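
The map_custom and custom dispatchers above only forward to a user-supplied callback together with the thread index and thread count, so the callback itself is responsible for splitting the work. A minimal sketch of such a callback is shown below; it is a hypothetical example (my_scale_op is not part of ops.cpp), its signature is inferred from the call p.fun(dst, a, ith, nth, userdata) above, and it assumes dst and a are F32 tensors of the same shape with contiguous rows. The row split follows the dr/ir0/ir1 pattern used throughout this file.

static void my_scale_op(struct ggml_tensor * dst, const struct ggml_tensor * a,
                        int ith, int nth, void * userdata) {
    const float scale = *(const float *) userdata;

    const int64_t nr = ggml_nrows(a);

    // rows per thread / row range for this thread
    const int64_t dr  = (nr + nth - 1)/nth;
    const int64_t ir0 = dr*ith;
    const int64_t ir1 = MIN(ir0 + dr, nr);

    const int64_t ne0 = a->ne[0];
    for (int64_t ir = ir0; ir < ir1; ++ir) {
        const float * src = (const float *) ((const char *) a->data   + ir*a->nb[1]);
        float       * out = (float       *) ((char       *) dst->data + ir*dst->nb[1]);
        for (int64_t i = 0; i < ne0; ++i) {
            out[i] = scale * src[i];
        }
    }
}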
10176
10177
// ggml_compute_forward_cross_entropy_loss
10178
10179
static void ggml_compute_forward_cross_entropy_loss_f32(
10180
        const ggml_compute_params * params,
10181
0
        ggml_tensor * dst) {
10182
10183
0
    const ggml_tensor * src0 = dst->src[0];
10184
0
    const ggml_tensor * src1 = dst->src[1];
10185
10186
0
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
10187
0
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
10188
0
    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
10189
0
    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
10190
0
    GGML_ASSERT(ggml_are_same_shape(src0, src1));
10191
0
    GGML_ASSERT(ggml_is_scalar(dst));
10192
0
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
10193
10194
    // TODO: handle transposed/permuted matrices
10195
0
    const int64_t nc = src0->ne[0];
10196
0
    const int64_t nr = ggml_nrows(src0);
10197
10198
0
    const int ith = params->ith;
10199
0
    const int nth = params->nth;
10200
10201
0
    float * sums =  (float *) params->wdata;
10202
0
    float * st   = ((float *) params->wdata) + nth + ith*nc;
10203
0
    float sum_thread = 0.0f;
10204
10205
0
    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
10206
10207
    // rows per thread
10208
0
    const int64_t dr = (nr + nth - 1)/nth;
10209
10210
    // row range for this thread
10211
0
    const int64_t ir0 = dr*ith;
10212
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
10213
10214
0
    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
10215
0
        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
10216
0
        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
10217
10218
#ifndef NDEBUG
10219
        for (int64_t i = 0; i < nc; ++i) {
10220
            //printf("p[%d] = %f\n", i, p[i]);
10221
            assert(!isnan(s0[i]));
10222
            assert(!isnan(s1[i]));
10223
        }
10224
#endif
10225
10226
0
        float max = -INFINITY;
10227
0
        ggml_vec_max_f32(nc, &max, s0);
10228
0
        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
10229
0
        assert(sum_softmax >= 0.0);
10230
10231
0
        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
10232
0
        ggml_vec_mul_f32(nc, st, st, s1);
10233
10234
0
        float sum_st = 0.0f;
10235
0
        ggml_vec_sum_f32(nc, &sum_st, st);
10236
0
        sum_thread += sum_st;
10237
10238
#ifndef NDEBUG
10239
        for (int64_t i = 0; i < nc; ++i) {
10240
            assert(!isnan(st[i]));
10241
            assert(!isinf(st[i]));
10242
        }
10243
#endif
10244
0
    }
10245
0
    sums[ith] = sum_thread;
10246
0
    ggml_barrier(params->threadpool);
10247
10248
0
    if (ith == 0) {
10249
0
        float * dp = (float *) dst->data;
10250
0
        ggml_vec_sum_f32(nth, dp, sums);
10251
0
        dp[0] *= -1.0f / (float) nr;
10252
0
    }
10253
0
}
10254
10255
void ggml_compute_forward_cross_entropy_loss(
10256
        const ggml_compute_params * params,
10257
0
        ggml_tensor * dst) {
10258
10259
0
    const ggml_tensor * src0 = dst->src[0];
10260
10261
0
    switch (src0->type) {
10262
0
        case GGML_TYPE_F32:
10263
0
            {
10264
0
                ggml_compute_forward_cross_entropy_loss_f32(params, dst);
10265
0
            } break;
10266
0
        default:
10267
0
            {
10268
0
                GGML_ABORT("fatal error");
10269
0
            }
10270
0
    }
10271
0
}
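
In formula form, the forward pass above computes, per row, the dot product of the target row s1 with log_softmax(s0); the per-thread sums are combined and scaled by -1/nr at the end. A standalone scalar sketch of the per-row term follows (the helper name cross_entropy_row_ref is hypothetical; it omits the threaded reduction and is illustrative only).

#include <cmath>
#include <cstdint>

// per-row term: sum_i s1[i] * log_softmax(s0)[i]
// (the kernel accumulates these over all rows and scales by -1/nr at the end)
static double cross_entropy_row_ref(int64_t nc, const float * s0, const float * s1) {
    float max = -INFINITY;
    for (int64_t i = 0; i < nc; ++i) max = s0[i] > max ? s0[i] : max;

    double sum_exp = 0.0;
    for (int64_t i = 0; i < nc; ++i) sum_exp += exp((double) s0[i] - max);

    const double log_z = log(sum_exp);

    double acc = 0.0;
    for (int64_t i = 0; i < nc; ++i) {
        acc += (double) s1[i] * ((double) s0[i] - max - log_z); // s1 * log_softmax(s0)
    }
    return acc;
}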
10272
10273
// ggml_compute_forward_cross_entropy_loss_back
10274
10275
static void ggml_compute_forward_cross_entropy_loss_back_f32(
10276
        const ggml_compute_params * params,
10277
0
        ggml_tensor * dst) {
10278
10279
0
    const ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
10280
0
    const ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
10281
0
    const ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
10282
10283
0
    GGML_ASSERT(ggml_is_contiguous(dst));
10284
0
    GGML_ASSERT(ggml_is_contiguous(src0f));
10285
0
    GGML_ASSERT(ggml_is_contiguous(src1f));
10286
0
    GGML_ASSERT(ggml_is_contiguous(grad));
10287
0
    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
10288
10289
0
    const int64_t ith = params->ith;
10290
0
    const int64_t nth = params->nth;
10291
10292
    // TODO: handle transposed/permuted matrices
10293
0
    const int64_t nc = src0f->ne[0];
10294
0
    const int64_t nr = ggml_nrows(src0f);
10295
10296
    // rows per thread
10297
0
    const int64_t dr = (nr + nth - 1)/nth;
10298
10299
    // row range for this thread
10300
0
    const int64_t ir0 = dr*ith;
10301
0
    const int64_t ir1 = MIN(ir0 + dr, nr);
10302
10303
0
    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
10304
10305
0
    for (int64_t i1 = ir0; i1 < ir1; i1++) {
10306
0
        float       * ds0 = (float       *)((char       *) dst->data   + i1*dst->nb[1]);
10307
0
        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
10308
0
        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
10309
10310
#ifndef NDEBUG
10311
        for (int64_t i = 0; i < nc; ++i) {
10312
            //printf("p[%d] = %f\n", i, p[i]);
10313
            assert(!isnan(s0[i]));
10314
            assert(!isnan(s1[i]));
10315
        }
10316
#endif
10317
10318
        // soft_max
10319
0
        float max = -INFINITY;
10320
0
        ggml_vec_max_f32(nc, &max, s0);
10321
0
        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
10322
0
        assert(sum > 0.0);
10323
0
        ggml_vec_scale_f32(nc, ds0, 1.0/sum);
10324
10325
        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
10326
0
        ggml_vec_sub_f32(nc, ds0, ds0, s1);
10327
0
        ggml_vec_scale_f32(nc, ds0, d_by_nr);
10328
10329
#ifndef NDEBUG
10330
        for (int64_t i = 0; i < nc; ++i) {
10331
            assert(!isnan(ds0[i]));
10332
            assert(!isinf(ds0[i]));
10333
        }
10334
#endif
10335
0
    }
10336
0
}
10337
10338
void ggml_compute_forward_cross_entropy_loss_back(
10339
        const ggml_compute_params * params,
10340
0
        ggml_tensor * dst) {
10341
10342
0
    const ggml_tensor * src0 = dst->src[0];
10343
10344
0
    switch (src0->type) {
10345
0
        case GGML_TYPE_F32:
10346
0
            {
10347
0
                ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
10348
0
            } break;
10349
0
        default:
10350
0
            {
10351
0
                GGML_ABORT("fatal error");
10352
0
            }
10353
0
    }
10354
0
}
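
As the comment in the kernel states, the backward pass reduces to ds0 = (softmax(s0) - s1) * d/nr per row. A standalone scalar sketch of one row (hypothetical helper name, illustrative only):

#include <cmath>
#include <cstdint>

// gradient wrt one row of src0f: (softmax(s0) - s1) * d_by_nr
static void cross_entropy_back_row_ref(int64_t nc, float * ds0,
                                       const float * s0, const float * s1, float d_by_nr) {
    float max = -INFINITY;
    for (int64_t i = 0; i < nc; ++i) max = s0[i] > max ? s0[i] : max;

    double sum = 0.0;
    for (int64_t i = 0; i < nc; ++i) {
        ds0[i] = expf(s0[i] - max);
        sum += ds0[i];
    }
    const float inv_sum = (float) (1.0/sum);

    for (int64_t i = 0; i < nc; ++i) {
        ds0[i] = (ds0[i]*inv_sum - s1[i]) * d_by_nr;
    }
}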
10355
10356
static void ggml_compute_forward_opt_step_adamw_f32(
10357
        const ggml_compute_params * params,
10358
0
        ggml_tensor * dst) {
10359
10360
0
    const ggml_tensor * src0         = dst->src[0];
10361
0
    const ggml_tensor * src0_grad    = dst->src[1];
10362
0
    const ggml_tensor * src0_grad_m  = dst->src[2];
10363
0
    const ggml_tensor * src0_grad_v  = dst->src[3];
10364
0
    const ggml_tensor * adamw_params = dst->src[4];
10365
10366
0
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
10367
0
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
10368
0
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
10369
0
    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
10370
10371
0
    const int ith = params->ith;
10372
0
    const int nth = params->nth;
10373
10374
0
    const int nr  = ggml_nrows(src0);
10375
10376
0
    GGML_TENSOR_UNARY_OP_LOCALS
10377
0
    GGML_ASSERT(nb00 == sizeof(float));
10378
10379
    // rows per thread
10380
0
    const int dr = (nr + nth - 1)/nth;
10381
10382
    // row range for this thread
10383
0
    const int ir0 = dr*ith;
10384
0
    const int ir1 = MIN(ir0 + dr, nr);
10385
10386
0
    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
10387
10388
0
    const float alpha  = adamw_params_ptr[0];
10389
0
    const float beta1  = adamw_params_ptr[1];
10390
0
    const float beta2  = adamw_params_ptr[2];
10391
0
    const float eps    = adamw_params_ptr[3];
10392
0
    const float wd     = adamw_params_ptr[4];
10393
0
    const float beta1h = adamw_params_ptr[5];
10394
0
    const float beta2h = adamw_params_ptr[6];
10395
0
    const float keep   = 1.f - alpha * wd;
10396
0
    for (int ir = ir0; ir < ir1; ++ir) {
10397
0
        const int64_t i03 = ir/(ne02*ne01);
10398
0
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
10399
0
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
10400
10401
0
        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
10402
10403
0
        float       * w = (float       *) ((char       *) src0->data        + offset); // weight
10404
0
        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
10405
0
        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);
10406
0
        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);
10407
10408
0
        for (int i00 = 0; i00 < ne00; ++i00) {
10409
0
            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);
10410
0
            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
10411
10412
0
            const float mh =       m[i00]*beta1h;
10413
0
            const float vh = sqrtf(v[i00]*beta2h) + eps;
10414
10415
            // The weight decay is applied independently of the Adam momenta m and v.
10416
            // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
10417
            // See: https://arxiv.org/pdf/1711.05101v3.pdf
10418
0
            w[i00] = w[i00] * keep - alpha * mh / vh;
10419
0
        }
10420
0
    }
10421
0
}
10422
10423
void ggml_compute_forward_opt_step_adamw(
10424
        const ggml_compute_params * params,
10425
0
        ggml_tensor * dst) {
10426
10427
0
    const ggml_tensor * src0 = dst->src[0];
10428
10429
0
    switch (src0->type) {
10430
0
        case GGML_TYPE_F32:
10431
0
            {
10432
0
                ggml_compute_forward_opt_step_adamw_f32(params, dst);
10433
0
            } break;
10434
0
        default:
10435
0
            {
10436
0
                GGML_ABORT("fatal error");
10437
0
            }
10438
0
    }
10439
0
}
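
For a single scalar parameter, the AdamW update above reads as the sketch below. The helper name adamw_step_ref is hypothetical; beta1h and beta2h come from adamw_params and are presumably the precomputed bias-correction factors 1/(1 - beta1^t) and 1/(1 - beta2^t), but their exact meaning is determined by the caller, so treat that as an assumption.

#include <cmath>

// One AdamW step for a single weight w with gradient g and moment estimates m, v.
// alpha: learning rate, wd: weight decay, eps: epsilon,
// beta1h/beta2h: caller-provided scalings of m and v (see note above).
static void adamw_step_ref(float & w, float g, float & m, float & v,
                           float alpha, float beta1, float beta2, float eps,
                           float wd, float beta1h, float beta2h) {
    m = m*beta1 +   g*(1.0f - beta1);
    v = v*beta2 + g*g*(1.0f - beta2);

    const float mh =       m*beta1h;
    const float vh = sqrtf(v*beta2h) + eps;

    // decoupled weight decay, then the Adam step
    w = w*(1.0f - alpha*wd) - alpha*mh/vh;
}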
10440
10441
0
static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
10442
0
    const ggml_tensor * src0       = dst->src[0];
10443
0
    const ggml_tensor * src0_grad  = dst->src[1];
10444
0
    const ggml_tensor * sgd_params = dst->src[2];
10445
10446
0
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
10447
0
    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
10448
10449
0
    const int ith = params->ith;
10450
0
    const int nth = params->nth;
10451
10452
0
    const int nr = ggml_nrows(src0);
10453
10454
0
    GGML_TENSOR_UNARY_OP_LOCALS
10455
0
    GGML_ASSERT(nb00 == sizeof(float));
10456
10457
    // rows per thread
10458
0
    const int dr = (nr + nth - 1) / nth;
10459
10460
    // row range for this thread
10461
0
    const int ir0 = dr * ith;
10462
0
    const int ir1 = MIN(ir0 + dr, nr);
10463
10464
    // uses only the subset of the adamw params we care about (alpha, wd) - could have a separate struct
10465
0
    const float * sgd_params_ptr   = ggml_get_data_f32(sgd_params);
10466
0
    const float   alpha            = sgd_params_ptr[0];
10467
0
    const float   keep             = 1.f - alpha * sgd_params_ptr[1];
10468
10469
0
    for (int ir = ir0; ir < ir1; ++ir) {
10470
0
        const int64_t i03 = ir / (ne02 * ne01);
10471
0
        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
10472
0
        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
10473
10474
0
        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
10475
10476
0
        float *       w = (float *) ((char *) src0->data + offset);                   // weight
10477
0
        const float * g = (const float *) ((const char *) src0_grad->data + offset);  // grad
10478
10479
0
        for (int i00 = 0; i00 < ne00; ++i00) {
10480
0
            w[i00] = w[i00] * keep - alpha * g[i00];
10481
0
        }
10482
0
    }
10483
0
}
10484
10485
0
void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
10486
0
    const ggml_tensor * src0 = dst->src[0];
10487
10488
0
    switch (src0->type) {
10489
0
        case GGML_TYPE_F32:
10490
0
            {
10491
0
                ggml_compute_forward_opt_step_sgd_f32(params, dst);
10492
0
            }
10493
0
            break;
10494
0
        default:
10495
0
            {
10496
0
                GGML_ABORT("fatal error - sgd is F32 only");
10497
0
            }
10498
0
    }
10499
0
}
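
Per element, the SGD kernel above is just decoupled weight decay followed by a plain gradient step, i.e. w <- w*(1 - alpha*wd) - alpha*g. A one-line sketch (hypothetical helper name, illustrative only):

// One SGD step with decoupled weight decay for a single weight.
static inline void sgd_step_ref(float & w, float g, float alpha, float wd) {
    w = w*(1.0f - alpha*wd) - alpha*g;
}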