Coverage Report

Created: 2026-06-22 06:47

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp
Line
Count
Source
1
// Copyright 2024 Mozilla Foundation
2
//
3
// Permission is hereby granted, free of charge, to any person obtaining
4
// a copy of this software and associated documentation files (the
5
// "Software"), to deal in the Software without restriction, including
6
// without limitation the rights to use, copy, modify, merge, publish,
7
// distribute, sublicense, and/or sell copies of the Software, and to
8
// permit persons to whom the Software is furnished to do so, subject to
9
// the following conditions:
10
//
11
// The above copyright notice and this permission notice shall be
12
// included in all copies or substantial portions of the Software.
13
//
14
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
18
// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
19
// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
20
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
// SOFTWARE.
22
23
//
24
//                   _   _          ___ _      _   ___
25
//                  | |_(_)_ _ _  _| _ ) |    /_\ / __|
26
//                  |  _| | ' \ || | _ \ |__ / _ \\__ \.
27
//                   \__|_|_||_\_, |___/____/_/ \_\___/
28
//                             |__/
29
//
30
//                    BASIC LINEAR ALGEBRA SUBPROGRAMS
31
//
32
//
33
// This file implements multithreaded CPU matrix multiplication for the
34
// common contiguous use case C = Aᵀ * B. These kernels are designed to
35
// have excellent performance[1] for matrices that fit in the CPU cache
36
// without imposing any overhead such as cache filling or malloc calls.
37
//
38
// This implementation does not guarantee any upper bound with rounding
39
// errors, which grow along with k. Our goal is to maximally exploit the
40
// hardware for performance, and then use whatever resources remain for
41
// improving numerical accuracy.
42
//
43
// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
44
//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
45
46
#if defined(__GNUC__)
47
#pragma GCC diagnostic ignored "-Wpedantic"
48
#pragma GCC diagnostic ignored "-Wignored-attributes"
49
#endif
50
51
#include "sgemm.h"
52
#include "ggml-impl.h"
53
#include "ggml-cpu-impl.h"
54
#include "ggml-quants.h"
55
#include "simd-mappings.h"
56
57
#include <array>
58
#include <type_traits>
59
60
// Portable spelling of "never inline this function" (applied to the gemm
// kernels further down).
#if defined(_MSC_VER)
#define NOINLINE __declspec(noinline)
#else
#define NOINLINE __attribute__((__noinline__))
#endif

// Vector register count assumed when picking tile shapes in matmul():
// NEON, AVX-512 and s390x VXE targets get 32, every other target 16.
#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)
#define VECTOR_REGISTERS 32
#else
#define VECTOR_REGISTERS 16
#endif

// RVV register-group multiplier; presumably consumed by RISC-V kernels
// elsewhere in this file (TODO confirm).
#if defined(__riscv_v_intrinsic)
#define LMUL 4
#endif

// Assemble a __m256i whose high 128-bit lane is `a` and whose low lane is `b`.
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
77
78
namespace {
79
80
0
/**
 * Converts a single ggml_fp16_t value to float using the CPU backend's
 * GGML_CPU_FP16_TO_FP32 conversion.
 */
inline float unhalf(ggml_fp16_t d) {
    const float widened = GGML_CPU_FP16_TO_FP32(d);
    return widened;
}
83
84
////////////////////////////////////////////////////////////////////////////////////////////////////
85
// VECTORIZED ARITHMETIC OPERATIONS
86
87
// Element-wise add/sub/mul overloads, one set per SIMD vector type, so the
// generic kernels can use a single spelling on every target ISA.

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m128 add(__m128 x, __m128 y) {
    return _mm_add_ps(x, y);
}
inline __m128 sub(__m128 x, __m128 y) {
    return _mm_sub_ps(x, y);
}
inline __m128 mul(__m128 x, __m128 y) {
    return _mm_mul_ps(x, y);
}
#endif // 128-bit SSE

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m256 add(__m256 x, __m256 y) {
    return _mm256_add_ps(x, y);
}
inline __m256 sub(__m256 x, __m256 y) {
    return _mm256_sub_ps(x, y);
}
inline __m256 mul(__m256 x, __m256 y) {
    return _mm256_mul_ps(x, y);
}
#endif // 256-bit AVX

#if defined(__AVX512F__)
inline __m512 add(__m512 x, __m512 y) {
    return _mm512_add_ps(x, y);
}
inline __m512 sub(__m512 x, __m512 y) {
    return _mm512_sub_ps(x, y);
}
inline __m512 mul(__m512 x, __m512 y) {
    return _mm512_mul_ps(x, y);
}
#endif // 512-bit AVX-512

#if defined(__ARM_NEON)
inline float32x4_t add(float32x4_t x, float32x4_t y) {
    return vaddq_f32(x, y);
}
inline float32x4_t sub(float32x4_t x, float32x4_t y) {
    return vsubq_f32(x, y);
}
inline float32x4_t mul(float32x4_t x, float32x4_t y) {
    return vmulq_f32(x, y);
}
#endif // NEON fp32

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
inline float16x8_t add(float16x8_t x, float16x8_t y) {
    return vaddq_f16(x, y);
}
inline float16x8_t sub(float16x8_t x, float16x8_t y) {
    return vsubq_f16(x, y);
}
inline float16x8_t mul(float16x8_t x, float16x8_t y) {
    return vmulq_f16(x, y);
}
#endif // NEON fp16

#if defined(__VXE__) || defined(__VXE2__)
inline float32x4_t add(float32x4_t x, float32x4_t y) {
    return vec_add(x, y);
}
inline float32x4_t sub(float32x4_t x, float32x4_t y) {
    return vec_sub(x, y);
}
inline float32x4_t mul(float32x4_t x, float32x4_t y) {
    return vec_mul(x, y);
}
#endif // s390x VXE

#if defined(__MMA__)
// Power ISA MMA (Matrix-Multiply Assist) operand and accumulator types.
typedef vector unsigned char vec_t;
typedef __vector_quad acc_t;
#endif // __MMA__
127
////////////////////////////////////////////////////////////////////////////////////////////////////
128
// VECTORIZED FUSED MULTIPLY ADD
129
130
/**
 * Multiply-accumulate: returns a * b + c.
 *
 * The generic definition is composed from mul() and add() (two separate
 * roundings). The explicit specializations below replace it with a single
 * fused or dot-product instruction where the target provides one.
 */
template <typename T, typename U>
inline U madd(T a, T b, U c) {
    return add(mul(a, b), c);
}

#if defined(__FMA__)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <>
inline __m256 madd<__m256, __m256>(__m256 a, __m256 b, __m256 c) {
    return _mm256_fmadd_ps(a, b, c);
}
#endif // AVX + FMA
#if defined(__AVX512F__)
template <>
inline __m512 madd<__m512, __m512>(__m512 a, __m512 b, __m512 c) {
    return _mm512_fmadd_ps(a, b, c);
}
#endif // AVX-512 + FMA
#if defined(__AVX512BF16__)
// bf16 operands, fp32 accumulator: pairwise bf16 dot product added into c.
template <>
inline __m512 madd<__m512bh, __m512>(__m512bh a, __m512bh b, __m512 c) {
    return _mm512_dpbf16_ps(c, a, b);
}
template <>
inline __m256 madd<__m256bh, __m256>(__m256bh a, __m256bh b, __m256 c) {
    return _mm256_dpbf16_ps(c, a, b);
}
#endif // __AVX512BF16__
#endif // __FMA__

#if defined(__ARM_FEATURE_FMA)
// vfmaq_*(acc, x, y) evaluates acc + x * y.
template <>
inline float32x4_t madd<float32x4_t, float32x4_t>(float32x4_t a, float32x4_t b, float32x4_t c) {
    return vfmaq_f32(c, b, a);
}
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
template <>
inline float16x8_t madd<float16x8_t, float16x8_t>(float16x8_t a, float16x8_t b, float16x8_t c) {
    return vfmaq_f16(c, b, a);
}
#endif // NEON fp16 arithmetic
#endif // __ARM_FEATURE_FMA

#if defined(__VXE__) || defined(__VXE2__)
template <>
inline float32x4_t madd<float32x4_t, float32x4_t>(float32x4_t a, float32x4_t b, float32x4_t c) {
    return vec_madd(a, b, c);
}
#endif // s390x VXE

#if defined(__riscv_v_intrinsic)
// fp32 * fp32 + fp32 over the whole register group (VLMAX elements).
template <>
inline vfloat32m1_t madd<vfloat32m1_t, vfloat32m1_t>(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
    return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
template <>
inline vfloat32m2_t madd<vfloat32m2_t, vfloat32m2_t>(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
    return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
template <>
inline vfloat32m4_t madd<vfloat32m4_t, vfloat32m4_t>(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
    return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
template <>
inline vfloat32m8_t madd<vfloat32m8_t, vfloat32m8_t>(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
    return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
#endif // __riscv_v_intrinsic

#if defined(__riscv_zvfh)
// fp16 operands widened into an fp32 accumulator of twice the LMUL.
template <>
inline vfloat32m1_t madd<vfloat16mf2_t, vfloat32m1_t>(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
    return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
template <>
inline vfloat32m2_t madd<vfloat16m1_t, vfloat32m2_t>(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
    return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
template <>
inline vfloat32m4_t madd<vfloat16m2_t, vfloat32m4_t>(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
    return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
template <>
inline vfloat32m8_t madd<vfloat16m4_t, vfloat32m8_t>(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
    return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
#endif // __riscv_zvfh

#if defined(__riscv_zvfbfwma)
// bf16 operands widened into an fp32 accumulator of twice the LMUL.
template <>
inline vfloat32m1_t madd<vbfloat16mf2_t, vfloat32m1_t>(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
    return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
template <>
inline vfloat32m2_t madd<vbfloat16m1_t, vfloat32m2_t>(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
    return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
template <>
inline vfloat32m4_t madd<vbfloat16m2_t, vfloat32m4_t>(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
    return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
template <>
inline vfloat32m8_t madd<vbfloat16m4_t, vfloat32m8_t>(vbfloat16m4_t a, vbfloat16m4_t b, vfloat32m8_t c) {
    return __riscv_vfwmaccbf16_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
#endif // __riscv_zvfbfwma
227
228
////////////////////////////////////////////////////////////////////////////////////////////////////
229
// VECTORIZED HORIZONTAL SUM
230
231
// Horizontal reductions: collapse all lanes of a SIMD vector into one float.

#if defined(__ARM_NEON)
inline float hsum(float32x4_t x) {
    return vaddvq_f32(x);
}
#endif // __ARM_NEON

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
// Widen both fp16 halves to fp32 before reducing.
inline float hsum(float16x8_t x) {
    const float32x4_t lo = vcvt_f32_f16(vget_low_f16(x));
    const float32x4_t hi = vcvt_f32_f16(vget_high_f16(x));
    return vaddvq_f32(vaddq_f32(lo, hi));
}
#endif // NEON fp16 arithmetic

#if defined(__VXE__) || defined(__VXE2__)
inline float hsum(float32x4_t x) {
    // Adding the lane-reversed vector puts x0+x3 in lane 0 and x1+x2 in lane 1.
    const float32x4_t folded = x + vec_reve(x);
    return folded[0] + folded[1];
}
#endif // s390x VXE

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m128 x) {
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
    // Fold the high pair onto the low pair, then add lane 1 into lane 0.
    const __m128 pairs = _mm_add_ps(x, _mm_movehl_ps(x, x));
    const __m128 total = _mm_add_ss(pairs, _mm_movehdup_ps(pairs));
#else
    // Plain-SSE path: add swapped neighbours, then fold the high half down.
    const __m128 swapped = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
    const __m128 pairs = _mm_add_ps(x, swapped);
    const __m128 total = _mm_add_ss(pairs, _mm_movehl_ps(swapped, pairs));
#endif
    return _mm_cvtss_f32(total);
}
#endif // 128-bit SSE

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
// Add the two 128-bit halves, then reduce the remaining four lanes.
inline float hsum(__m256 x) {
    const __m128 upper = _mm256_extractf128_ps(x, 1);
    const __m128 lower = _mm256_castps256_ps128(x);
    return hsum(_mm_add_ps(upper, lower));
}
#endif // 256-bit AVX

#if defined(__AVX512F__)
inline float hsum(__m512 x) {
    return _mm512_reduce_add_ps(x);
}
#endif // 512-bit AVX-512

#if defined(__riscv_v_intrinsic)
// Unordered sum reduction (vfredusum) seeded with a single 0.0 element.
inline float hsum(vfloat32m1_t x) {
    const vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(0, 1);
    const vfloat32m1_t sum = __riscv_vfredusum_vs_f32m1_f32m1(x, seed, __riscv_vsetvlmax_e32m1());
    return __riscv_vfmv_f_s_f32m1_f32(sum);
}
inline float hsum(vfloat32m2_t x) {
    const vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(0, 1);
    const vfloat32m1_t sum = __riscv_vfredusum_vs_f32m2_f32m1(x, seed, __riscv_vsetvlmax_e32m2());
    return __riscv_vfmv_f_s_f32m1_f32(sum);
}
inline float hsum(vfloat32m4_t x) {
    const vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(0, 1);
    const vfloat32m1_t sum = __riscv_vfredusum_vs_f32m4_f32m1(x, seed, __riscv_vsetvlmax_e32m4());
    return __riscv_vfmv_f_s_f32m1_f32(sum);
}
inline float hsum(vfloat32m8_t x) {
    const vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(0, 1);
    const vfloat32m1_t sum = __riscv_vfredusum_vs_f32m8_f32m1(x, seed, __riscv_vsetvlmax_e32m8());
    return __riscv_vfmv_f_s_f32m1_f32(sum);
}
#endif // __riscv_v_intrinsic
298
299
////////////////////////////////////////////////////////////////////////////////////////////////////
300
// VECTORIZED MEMORY LOADING
301
302
/**
 * load<V>(p): read one SIMD vector of type V starting at p.
 *
 * Each specialization does one of three things: loads V directly, widens
 * 16-bit (fp16/bf16) sources to fp32 lanes, or (AVX512-BF16 only) narrows
 * fp32 sources to bf16 lanes.
 */
template <typename T, typename U> T load(const U *);

#if defined(__ARM_NEON)
template <>
inline float32x4_t load<float32x4_t, float>(const float *p) {
    return vld1q_f32(p);
}
#if !defined(_MSC_VER)
// FIXME: this should check for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
inline float16x8_t load<float16x8_t, ggml_fp16_t>(const ggml_fp16_t *p) {
    return vld1q_f16(reinterpret_cast<const float16_t *>(p));
}
template <>
inline float32x4_t load<float32x4_t, ggml_fp16_t>(const ggml_fp16_t *p) {
    // Four fp16 values widened to fp32.
    return vcvt_f32_f16(vld1_f16(reinterpret_cast<const float16_t *>(p)));
}
#endif // !_MSC_VER
#endif // __ARM_NEON

#if defined(__VXE__) || defined(__VXE2__)
template <>
inline float32x4_t load<float32x4_t, ggml_fp16_t>(const ggml_fp16_t * p) {
    // Convert lane by lane through a scalar staging buffer.
    float lanes[4];
    for (int lane = 0; lane < 4; ++lane) {
        lanes[lane] = GGML_CPU_FP16_TO_FP32(p[lane]);
    }
    return vec_xl(0, static_cast<const float *>(lanes));
}
template <>
inline float32x4_t load<float32x4_t, float>(const float * p) {
    return vec_xl(0, p);
}
#endif // s390x VXE

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <>
inline __m128 load<__m128, float>(const float *p) {
    return _mm_loadu_ps(p);
}
#endif // 128-bit SSE

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <>
inline __m256 load<__m256, float>(const float *p) {
    return _mm256_loadu_ps(p);
}
#endif // 256-bit AVX

#if defined(__AVX2__) || defined(__AVX512F__)
template <>
inline __m256 load<__m256, ggml_bf16_t>(const ggml_bf16_t *p) {
    // Zero-extend eight 16-bit values and shift them into the high half of each fp32 lane.
    const __m128i raw = _mm_loadu_si128(reinterpret_cast<const __m128i *>(p));
    const __m256i widened = _mm256_cvtepu16_epi32(raw);
    return _mm256_castsi256_ps(_mm256_slli_epi32(widened, 16));
}
#endif // __AVX2__

#if defined(__F16C__)
template <>
inline __m256 load<__m256, ggml_fp16_t>(const ggml_fp16_t *p) {
    return _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i *>(p)));
}
#endif // __F16C__

#if defined(__AVX512F__)
template <>
inline __m512 load<__m512, float>(const float *p) {
    return _mm512_loadu_ps(p);
}
template <>
inline __m512 load<__m512, ggml_fp16_t>(const ggml_fp16_t *p) {
    return _mm512_cvtph_ps(_mm256_loadu_si256(reinterpret_cast<const __m256i *>(p)));
}
template <>
inline __m512 load<__m512, ggml_bf16_t>(const ggml_bf16_t *p) {
    // Same bf16 -> fp32 widening as the AVX2 path, sixteen lanes at a time.
    const __m256i raw = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(p));
    const __m512i widened = _mm512_cvtepu16_epi32(raw);
    return _mm512_castsi512_ps(_mm512_slli_epi32(widened, 16));
}
#endif // __AVX512F__

#if defined(__AVX512BF16__)
template <>
inline __m512bh load<__m512bh, ggml_bf16_t>(const ggml_bf16_t *p) {
    return (__m512bh)_mm512_loadu_ps(reinterpret_cast<const float *>(p));
}
template <>
inline __m256bh load<__m256bh, ggml_bf16_t>(const ggml_bf16_t *p) {
    return (__m256bh)_mm256_loadu_ps(reinterpret_cast<const float *>(p));
}
template <>
inline __m512bh load<__m512bh, float>(const float *p) {
    // 32 floats narrowed to bf16: p[0..15] fill the low half, p[16..31] the high half.
    const __m512 high = _mm512_loadu_ps(p + 16);
    const __m512 low = _mm512_loadu_ps(p);
    return _mm512_cvtne2ps_pbh(high, low);
}
template <>
inline __m256bh load<__m256bh, float>(const float *p) {
    return _mm512_cvtneps_pbh(_mm512_loadu_ps(p));
}
#endif // __AVX512BF16__

#if defined(__riscv_v_intrinsic)
template <>
inline vfloat32m1_t load<vfloat32m1_t, float>(const float *p) {
    return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
}
template <>
inline vfloat32m2_t load<vfloat32m2_t, float>(const float *p) {
    return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2());
}
template <>
inline vfloat32m4_t load<vfloat32m4_t, float>(const float *p) {
    return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4());
}
template <>
inline vfloat32m8_t load<vfloat32m8_t, float>(const float *p) {
    return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8());
}
#endif // __riscv_v_intrinsic

#if defined(__riscv_zvfh)
template <>
inline vfloat16mf2_t load<vfloat16mf2_t, ggml_fp16_t>(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
}
template <>
inline vfloat16m1_t load<vfloat16m1_t, ggml_fp16_t>(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
}
template <>
inline vfloat16m2_t load<vfloat16m2_t, ggml_fp16_t>(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
}
template <>
inline vfloat16m4_t load<vfloat16m4_t, ggml_fp16_t>(const ggml_fp16_t *p) {
    return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
}
#endif // __riscv_zvfh

#if defined(__riscv_zvfbfwma)
template <>
inline vbfloat16mf2_t load<vbfloat16mf2_t, ggml_bf16_t>(const ggml_bf16_t *p) {
    return __riscv_vle16_v_bf16mf2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16mf2());
}
template <>
inline vbfloat16m1_t load<vbfloat16m1_t, ggml_bf16_t>(const ggml_bf16_t *p) {
    return __riscv_vle16_v_bf16m1(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m1());
}
template <>
inline vbfloat16m2_t load<vbfloat16m2_t, ggml_bf16_t>(const ggml_bf16_t *p) {
    return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2());
}
template <>
inline vbfloat16m4_t load<vbfloat16m4_t, ggml_bf16_t>(const ggml_bf16_t *p) {
    return __riscv_vle16_v_bf16m4(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m4());
}
#endif // __riscv_zvfbfwma
431
432
#if defined(__riscv_v_intrinsic)
// set_zero<V>(): a register group of type V with every element set to 0.0f.
template <typename T> T set_zero();

template <>
inline vfloat32m1_t set_zero<vfloat32m1_t>() {
    return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
}
template <>
inline vfloat32m2_t set_zero<vfloat32m2_t>() {
    return __riscv_vfmv_v_f_f32m2(0.0f, __riscv_vsetvlmax_e32m2());
}
template <>
inline vfloat32m4_t set_zero<vfloat32m4_t>() {
    return __riscv_vfmv_v_f_f32m4(0.0f, __riscv_vsetvlmax_e32m4());
}
template <>
inline vfloat32m8_t set_zero<vfloat32m8_t>() {
    return __riscv_vfmv_v_f_f32m8(0.0f, __riscv_vsetvlmax_e32m8());
}
#endif // __riscv_v_intrinsic

#if defined(__riscv_v_intrinsic)
// vlmax<V>(): VLMAX (element count of a full register group) for vector type V,
// or 0 when V is not one of the listed types.
template <typename T> size_t vlmax() {
    if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
    if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); }
    if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); }
    if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); }
#if defined (__riscv_zvfh)
    if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
    if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
    if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
    if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
#endif
#if defined (__riscv_zvfbfwma)
    if constexpr (std::is_same_v<T, vbfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
    if constexpr (std::is_same_v<T, vbfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
    if constexpr (std::is_same_v<T, vbfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
    if constexpr (std::is_same_v<T, vbfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
#endif
    return 0;
}
#endif // __riscv_v_intrinsic
470
471
////////////////////////////////////////////////////////////////////////////////////////////////////
472
// FLOATING POINT MATRIX MULTIPLICATION
473
474
/**
 * Balanced tile width for splitting an extent of `m` into tiles of at most M.
 *
 * The tile count is ceil(m / M), and the returned width is ceil(m / count).
 * Every tile therefore holds either that width or one less. For example,
 * BLOCK_SIZE<6>(7) == 4 (tiles of 4 and 3) instead of a tile of 6 plus a
 * tile of 1. matmul() passes the output column count n here.
 *
 * @param m  extent to split
 * @return   width in [1, M] when m > 0; 0 when m == 0 (no tiles)
 */
template <int M>
static inline int64_t BLOCK_SIZE(size_t m) {
    // With m == 0 the tile count below would be 0, and `m % NB_BLOC_M`
    // would then divide by zero (undefined behaviour).
    if (m == 0) {
        return 0;
    }
    const int64_t NB_BLOC_M = (m + M - 1) / M;
    return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
}
479
480
0
/**
 * Start offset of block `ib` in a layout where the first `ibN` blocks are
 * `bloc_size` long and every later block is `bloc_size - 1` long.
 */
static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
    if (ib < ibN) {
        return ib * bloc_size;
    }
    // Past the long blocks: skip all of them, then count short blocks.
    return ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
}
483
484
template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
485
class tinyBLAS {
486
  public:
487
    tinyBLAS(const ggml_compute_params * params, int64_t k,
488
             const TA *A, int64_t lda,
489
             const TB *B, int64_t ldb,
490
             TC *C, int64_t ldc)
491
0
        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
492
0
    }
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::tinyBLAS(ggml_compute_params const*, long, float const*, long, float const*, long, float*, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::tinyBLAS(ggml_compute_params const*, long, ggml_bf16_t const*, long, ggml_bf16_t const*, long, float*, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::tinyBLAS(ggml_compute_params const*, long, unsigned short const*, long, unsigned short const*, long, float*, long)
493
494
0
    bool matmul(int64_t m, int64_t n) {
495
0
        if (k % KN != 0)
496
0
            return false;
497
        // compute RM for only need tile with size RM&RM-1
498
#if VECTOR_REGISTERS == 32
499
        if (m % 16 == 0 && (m/16 >= params->nth)) {
500
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
501
            mnpack<4, 6, 4>(m, n, SIZE_N, 12);
502
            return true;
503
        }
504
        if (m % 8 == 0 ) {
505
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
506
            mnpack<4, 6, 2>(m, n, SIZE_N, 12);
507
            return true;
508
        }
509
        if (m % 4 == 0) {
510
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
511
            mnpack<4, 6, 1>(m, n, SIZE_N, 12);
512
            return true;
513
        }
514
#else  // VECTOR_REGISTERS == 16
515
0
        if (m % 16 == 0 && (m/16 >= params->nth)) {
516
0
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
517
0
            mnpack<4, 3, 4>(m, n, SIZE_N, 24);
518
0
            return true;
519
0
        }
520
0
        if (m % 8 == 0 ) {
521
0
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
522
0
            mnpack<4, 3, 2>(m, n, SIZE_N, 24);
523
0
            return true;
524
0
        }
525
0
        if (m % 4 == 0) {
526
0
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
527
0
            mnpack<4, 3, 1>(m, n, SIZE_N, 24);
528
0
            return true;
529
0
        }
530
0
#endif
531
0
        return false;
532
0
    }
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::matmul(long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::matmul(long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::matmul(long, long)
533
534
  private:
535
    template <int RM, int RN, int BM>
536
0
    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
537
0
        if (SIZE_N == RN) {
538
0
            return gemm<RM, RN, BM>(m, n, BN);
539
0
        }
540
0
        if constexpr (RN > 1) {
541
0
            return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
542
0
        } else {
543
0
            GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N);
544
0
            GGML_ASSERT(false); // we have miss something.
545
0
        }
546
0
    }
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 3, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 2, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 1, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 3, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 2, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 1, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 3, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 2, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 1, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 3, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 2, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 1, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 3, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 2, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 1, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 3, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 2, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 1, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 3, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 2, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 1, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 3, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 2, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 1, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 3, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 2, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 1, 1>(long, long, long, long)
547
548
    template <int RM, int RN>
549
0
    inline void gemm_bloc(int64_t ii, int64_t jj) {
550
0
        D Cv[RN][RM] = {};
551
0
        for (int64_t l = 0; l < k; l += KN) {
552
            // help compiler for op order.
553
            if constexpr (RM <= RN) {
554
                V Av[RM];
555
                for (int64_t i = 0; i < RM; ++i) {
556
                    Av[i] = load<V>(A + lda * (ii + i) + l);
557
                }
558
                for (int64_t j = 0; j < RN; ++j) {
559
                    V Bv = load<V>(B + ldb * (jj + j) + l);
560
                    for (int64_t i = 0; i < RM; ++i) {
561
                        Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
562
                    }
563
                }
564
0
            } else {
565
0
                V Bv[RN];
566
0
                for (int64_t j = 0; j < RN; ++j) {
567
0
                    Bv[j] = load<V>(B + ldb * (jj + j) + l);
568
0
                }
569
0
                for (int64_t i = 0; i < RM; ++i) {
570
0
                    V Av = load<V>(A + lda * (ii + i) + l);
571
0
                    for (int64_t j = 0; j < RN; ++j) {
572
0
                        Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
573
0
                    }
574
0
                }
575
0
            }
576
0
        }
577
0
        for (int64_t j = 0; j < RN; ++j)
578
0
            for (int64_t i = 0; i < RM; ++i)
579
0
                C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
580
0
    }
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm_bloc<4, 3>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm_bloc<4, 2>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm_bloc<4, 1>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm_bloc<4, 3>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm_bloc<4, 2>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm_bloc<4, 1>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm_bloc<4, 3>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm_bloc<4, 2>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm_bloc<4, 1>(long, long)
581
582
    template <int RM, int RN, int BM>
583
0
    NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
584
0
        GGML_ASSERT(m % (RM * BM) == 0);
585
0
        const int64_t ytiles = m / (RM * BM);
586
0
        const int64_t xtiles = (n + RN -1) / RN;
587
0
        const int64_t jj_RN = (xtiles - (xtiles * RN - n));
588
589
        // "round" bloc_size to "nearest" BN
590
0
        const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
591
0
        const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
592
0
        const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
593
0
        const int64_t nb_job = ytiles * NB_BN;
594
595
0
        if (params->ith == 0) {
596
0
            GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
597
            // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
598
0
            ggml_threadpool_chunk_set(params->threadpool, params->nth);
599
0
        }
600
601
0
        ggml_barrier(params->threadpool);
602
603
0
        int64_t job = params->ith;
604
0
        while (job < nb_job) {
605
0
            const int64_t ii = (job % ytiles) * RM * BM;
606
0
            const int64_t jb =  job / ytiles;
607
0
            const int64_t jr0 = BLOC_POS(jb  , jj_BN, SIZE_BN);
608
0
            const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
609
610
0
            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
611
0
            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
612
0
            const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
613
614
0
            for (int64_t bi = 0; bi < BM * RM; bi += RM) {
615
0
                int64_t jj = jj0;
616
0
                for (; jj < jj1; jj += RN) {
617
0
                    gemm_bloc<RM, RN>(ii + bi, jj);
618
0
                }
619
0
                if constexpr (RN > 1) {
620
0
                    for (; jj < jj2; jj += RN - 1) {
621
0
                        gemm_bloc<RM, RN-1>(ii + bi, jj);
622
0
                    }
623
0
                }
624
0
                GGML_ASSERT(jj == jj2);
625
0
            }
626
627
0
            job = ggml_threadpool_chunk_add(params->threadpool, 1);
628
0
        }
629
630
0
        ggml_barrier(params->threadpool);
631
0
        return;
632
0
    }
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 3, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 2, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 1, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 3, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 2, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 1, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 3, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 2, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 1, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 3, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 2, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 1, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 3, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 2, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 1, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 3, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 2, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 1, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 3, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 2, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 1, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 3, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 2, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 1, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 3, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 2, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 1, 1>(long, long, long)
633
634
    const ggml_compute_params * params;
635
    const TA *const A;
636
    const TB *const B;
637
    TC *const C;
638
    const int64_t k;
639
    const int64_t lda;
640
    const int64_t ldb;
641
    const int64_t ldc;
642
};
643
644
#if defined(__riscv_v_intrinsic)
645
template <typename D, typename V, typename TA, typename TB, typename TC>
646
class tinyBLAS_RVV {
647
  public:
648
    tinyBLAS_RVV(const ggml_compute_params * params, int64_t k,
649
             const TA *A, int64_t lda,
650
             const TB *B, int64_t ldb,
651
             TC *C, int64_t ldc)
652
        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
653
    }
654
655
    bool matmul(int64_t m, int64_t n) {
656
        if (k % vlmax<V>() != 0) {
657
            return false;
658
        }
659
660
#if LMUL == 1
661
        if (m % 16 == 0 && (m/16 >= params->nth)) {
662
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
663
            mnpack<4, 6, 4>(m, n, SIZE_N, 12);
664
            return true;
665
        }
666
        if (m % 8 == 0 ) {
667
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
668
            mnpack<4, 6, 2>(m, n, SIZE_N, 12);
669
            return true;
670
        }
671
        if (m % 4 == 0) {
672
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
673
            mnpack<4, 6, 1>(m, n, SIZE_N, 12);
674
            return true;
675
        }
676
#elif LMUL == 2
677
        if (m % 16 == 0 && (m/16 >= params->nth)) {
678
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
679
            mnpack<4, 3, 4>(m, n, SIZE_N, 24);
680
            return true;
681
        }
682
        if (m % 8 == 0 ) {
683
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
684
            mnpack<4, 3, 2>(m, n, SIZE_N, 24);
685
            return true;
686
        }
687
        if (m % 4 == 0) {
688
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
689
            mnpack<4, 3, 1>(m, n, SIZE_N, 24);
690
            return true;
691
        }
692
#else // LMUL = 4
693
        if (m % 16 == 0 && (m/16 >= params->nth)) {
694
            const int64_t SIZE_N = BLOCK_SIZE<2>(n);
695
            mnpack<2, 2, 8>(m, n, SIZE_N, 36);
696
            return true;
697
        }
698
        if (m % 8 == 0 ) {
699
            const int64_t SIZE_N = BLOCK_SIZE<2>(n);
700
            mnpack<2, 2, 4>(m, n, SIZE_N, 36);
701
            return true;
702
        }
703
        if (m % 4 == 0) {
704
            const int64_t SIZE_N = BLOCK_SIZE<2>(n);
705
            mnpack<2, 2, 2>(m, n, SIZE_N, 36);
706
            return true;
707
        }
708
#endif
709
        return false;
710
    }
711
712
  private:
713
    template<int RM, int RN, int BM>
714
    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
715
        if (SIZE_N == RN) {
716
            return gemm<RM, RN, BM>(m, n, BN);
717
        }
718
        if constexpr (RN > 1) {
719
            return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
720
        } else {
721
            GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N);
722
            GGML_ASSERT(false); // we have miss something.
723
        }
724
    }
725
726
    inline void gemm_bloc_4x6(int64_t ii, int64_t jj) {
727
        size_t vl = vlmax<V>();
728
        D Cv00 = set_zero<D>();
729
        D Cv01 = set_zero<D>();
730
        D Cv02 = set_zero<D>();
731
        D Cv03 = set_zero<D>();
732
        D Cv10 = set_zero<D>();
733
        D Cv11 = set_zero<D>();
734
        D Cv12 = set_zero<D>();
735
        D Cv13 = set_zero<D>();
736
        D Cv20 = set_zero<D>();
737
        D Cv21 = set_zero<D>();
738
        D Cv22 = set_zero<D>();
739
        D Cv23 = set_zero<D>();
740
        D Cv30 = set_zero<D>();
741
        D Cv31 = set_zero<D>();
742
        D Cv32 = set_zero<D>();
743
        D Cv33 = set_zero<D>();
744
        D Cv40 = set_zero<D>();
745
        D Cv41 = set_zero<D>();
746
        D Cv42 = set_zero<D>();
747
        D Cv43 = set_zero<D>();
748
        D Cv50 = set_zero<D>();
749
        D Cv51 = set_zero<D>();
750
        D Cv52 = set_zero<D>();
751
        D Cv53 = set_zero<D>();
752
753
        for (int64_t l = 0; l < k; l += vl) {
754
            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
755
            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
756
            V Bv2 = load<V>(B + ldb * (jj + 2) + l);
757
            V Bv3 = load<V>(B + ldb * (jj + 3) + l);
758
            V Bv4 = load<V>(B + ldb * (jj + 4) + l);
759
            V Bv5 = load<V>(B + ldb * (jj + 5) + l);
760
761
            V Av0 = load<V>(A + lda * (ii + 0) + l);
762
            Cv00 = madd(Av0, Bv0, Cv00);
763
            Cv10 = madd(Av0, Bv1, Cv10);
764
            Cv20 = madd(Av0, Bv2, Cv20);
765
            Cv30 = madd(Av0, Bv3, Cv30);
766
            Cv40 = madd(Av0, Bv4, Cv40);
767
            Cv50 = madd(Av0, Bv5, Cv50);
768
769
            V Av1 = load<V>(A + lda * (ii + 1) + l);
770
            Cv01 = madd(Av1, Bv0, Cv01);
771
            Cv11 = madd(Av1, Bv1, Cv11);
772
            Cv21 = madd(Av1, Bv2, Cv21);
773
            Cv31 = madd(Av1, Bv3, Cv31);
774
            Cv41 = madd(Av1, Bv4, Cv41);
775
            Cv51 = madd(Av1, Bv5, Cv51);
776
777
            V Av2 = load<V>(A + lda * (ii + 2) + l);
778
            Cv02 = madd(Av2, Bv0, Cv02);
779
            Cv12 = madd(Av2, Bv1, Cv12);
780
            Cv22 = madd(Av2, Bv2, Cv22);
781
            Cv32 = madd(Av2, Bv3, Cv32);
782
            Cv42 = madd(Av2, Bv4, Cv42);
783
            Cv52 = madd(Av2, Bv5, Cv52);
784
785
            V Av3 = load<V>(A + lda * (ii + 3) + l);
786
            Cv03 = madd(Av3, Bv0, Cv03);
787
            Cv13 = madd(Av3, Bv1, Cv13);
788
            Cv23 = madd(Av3, Bv2, Cv23);
789
            Cv33 = madd(Av3, Bv3, Cv33);
790
            Cv43 = madd(Av3, Bv4, Cv43);
791
            Cv53 = madd(Av3, Bv5, Cv53);
792
        }
793
794
        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
795
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
796
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
797
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
798
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
799
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
800
        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
801
        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
802
        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
803
        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
804
        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
805
        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
806
        C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
807
        C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
808
        C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
809
        C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
810
        C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
811
        C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
812
        C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
813
        C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
814
        C[ldc * (jj + 5) + (ii + 0)] = hsum(Cv50);
815
        C[ldc * (jj + 5) + (ii + 1)] = hsum(Cv51);
816
        C[ldc * (jj + 5) + (ii + 2)] = hsum(Cv52);
817
        C[ldc * (jj + 5) + (ii + 3)] = hsum(Cv53);
818
    }
819
820
    inline void gemm_bloc_4x5(int64_t ii, int64_t jj) {
821
        size_t vl = vlmax<V>();
822
        D Cv00 = set_zero<D>();
823
        D Cv01 = set_zero<D>();
824
        D Cv02 = set_zero<D>();
825
        D Cv03 = set_zero<D>();
826
        D Cv10 = set_zero<D>();
827
        D Cv11 = set_zero<D>();
828
        D Cv12 = set_zero<D>();
829
        D Cv13 = set_zero<D>();
830
        D Cv20 = set_zero<D>();
831
        D Cv21 = set_zero<D>();
832
        D Cv22 = set_zero<D>();
833
        D Cv23 = set_zero<D>();
834
        D Cv30 = set_zero<D>();
835
        D Cv31 = set_zero<D>();
836
        D Cv32 = set_zero<D>();
837
        D Cv33 = set_zero<D>();
838
        D Cv40 = set_zero<D>();
839
        D Cv41 = set_zero<D>();
840
        D Cv42 = set_zero<D>();
841
        D Cv43 = set_zero<D>();
842
843
        for (int64_t l = 0; l < k; l += vl) {
844
            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
845
            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
846
            V Bv2 = load<V>(B + ldb * (jj + 2) + l);
847
            V Bv3 = load<V>(B + ldb * (jj + 3) + l);
848
            V Bv4 = load<V>(B + ldb * (jj + 4) + l);
849
850
            V Av0 = load<V>(A + lda * (ii + 0) + l);
851
            Cv00 = madd(Av0, Bv0, Cv00);
852
            Cv10 = madd(Av0, Bv1, Cv10);
853
            Cv20 = madd(Av0, Bv2, Cv20);
854
            Cv30 = madd(Av0, Bv3, Cv30);
855
            Cv40 = madd(Av0, Bv4, Cv40);
856
857
            V Av1 = load<V>(A + lda * (ii + 1) + l);
858
            Cv01 = madd(Av1, Bv0, Cv01);
859
            Cv11 = madd(Av1, Bv1, Cv11);
860
            Cv21 = madd(Av1, Bv2, Cv21);
861
            Cv31 = madd(Av1, Bv3, Cv31);
862
            Cv41 = madd(Av1, Bv4, Cv41);
863
864
            V Av2 = load<V>(A + lda * (ii + 2) + l);
865
            Cv02 = madd(Av2, Bv0, Cv02);
866
            Cv12 = madd(Av2, Bv1, Cv12);
867
            Cv22 = madd(Av2, Bv2, Cv22);
868
            Cv32 = madd(Av2, Bv3, Cv32);
869
            Cv42 = madd(Av2, Bv4, Cv42);
870
871
            V Av3 = load<V>(A + lda * (ii + 3) + l);
872
            Cv03 = madd(Av3, Bv0, Cv03);
873
            Cv13 = madd(Av3, Bv1, Cv13);
874
            Cv23 = madd(Av3, Bv2, Cv23);
875
            Cv33 = madd(Av3, Bv3, Cv33);
876
            Cv43 = madd(Av3, Bv4, Cv43);
877
        }
878
879
        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
880
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
881
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
882
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
883
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
884
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
885
        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
886
        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
887
        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
888
        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
889
        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
890
        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
891
        C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
892
        C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
893
        C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
894
        C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
895
        C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
896
        C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
897
        C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
898
        C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
899
    }
900
901
    inline void gemm_bloc_4x4(int64_t ii, int64_t jj) {
902
        size_t vl = vlmax<V>();
903
        D Cv00 = set_zero<D>();
904
        D Cv01 = set_zero<D>();
905
        D Cv02 = set_zero<D>();
906
        D Cv03 = set_zero<D>();
907
        D Cv10 = set_zero<D>();
908
        D Cv11 = set_zero<D>();
909
        D Cv12 = set_zero<D>();
910
        D Cv13 = set_zero<D>();
911
        D Cv20 = set_zero<D>();
912
        D Cv21 = set_zero<D>();
913
        D Cv22 = set_zero<D>();
914
        D Cv23 = set_zero<D>();
915
        D Cv30 = set_zero<D>();
916
        D Cv31 = set_zero<D>();
917
        D Cv32 = set_zero<D>();
918
        D Cv33 = set_zero<D>();
919
920
        for (int64_t l = 0; l < k; l += vl) {
921
            V Av0 = load<V>(A + lda * (ii + 0) + l);
922
            V Av1 = load<V>(A + lda * (ii + 1) + l);
923
            V Av2 = load<V>(A + lda * (ii + 2) + l);
924
            V Av3 = load<V>(A + lda * (ii + 3) + l);
925
926
            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
927
            Cv00 = madd(Av0, Bv0, Cv00);
928
            Cv01 = madd(Av1, Bv0, Cv01);
929
            Cv02 = madd(Av2, Bv0, Cv02);
930
            Cv03 = madd(Av3, Bv0, Cv03);
931
932
            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
933
            Cv10 = madd(Av0, Bv1, Cv10);
934
            Cv11 = madd(Av1, Bv1, Cv11);
935
            Cv12 = madd(Av2, Bv1, Cv12);
936
            Cv13 = madd(Av3, Bv1, Cv13);
937
938
            V Bv2 = load<V>(B + ldb * (jj + 2) + l);
939
            Cv20 = madd(Av0, Bv2, Cv20);
940
            Cv21 = madd(Av1, Bv2, Cv21);
941
            Cv22 = madd(Av2, Bv2, Cv22);
942
            Cv23 = madd(Av3, Bv2, Cv23);
943
944
            V Bv3 = load<V>(B + ldb * (jj + 3) + l);
945
            Cv30 = madd(Av0, Bv3, Cv30);
946
            Cv31 = madd(Av1, Bv3, Cv31);
947
            Cv32 = madd(Av2, Bv3, Cv32);
948
            Cv33 = madd(Av3, Bv3, Cv33);
949
        }
950
951
        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
952
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
953
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
954
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
955
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
956
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
957
        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
958
        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
959
        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
960
        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
961
        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
962
        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
963
        C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
964
        C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
965
        C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
966
        C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
967
    }
968
969
    inline void gemm_bloc_4x3(int64_t ii, int64_t jj) {
970
        size_t vl = vlmax<V>();
971
        D Cv00 = set_zero<D>();
972
        D Cv01 = set_zero<D>();
973
        D Cv02 = set_zero<D>();
974
        D Cv03 = set_zero<D>();
975
        D Cv10 = set_zero<D>();
976
        D Cv11 = set_zero<D>();
977
        D Cv12 = set_zero<D>();
978
        D Cv13 = set_zero<D>();
979
        D Cv20 = set_zero<D>();
980
        D Cv21 = set_zero<D>();
981
        D Cv22 = set_zero<D>();
982
        D Cv23 = set_zero<D>();
983
984
        for (int64_t l = 0; l < k; l += vl) {
985
            V Av0 = load<V>(A + lda * (ii + 0) + l);
986
            V Av1 = load<V>(A + lda * (ii + 1) + l);
987
            V Av2 = load<V>(A + lda * (ii + 2) + l);
988
            V Av3 = load<V>(A + lda * (ii + 3) + l);
989
990
            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
991
            Cv00 = madd(Av0, Bv0, Cv00);
992
            Cv01 = madd(Av1, Bv0, Cv01);
993
            Cv02 = madd(Av2, Bv0, Cv02);
994
            Cv03 = madd(Av3, Bv0, Cv03);
995
996
            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
997
            Cv10 = madd(Av0, Bv1, Cv10);
998
            Cv11 = madd(Av1, Bv1, Cv11);
999
            Cv12 = madd(Av2, Bv1, Cv12);
1000
            Cv13 = madd(Av3, Bv1, Cv13);
1001
1002
            V Bv2 = load<V>(B + ldb * (jj + 2) + l);
1003
            Cv20 = madd(Av0, Bv2, Cv20);
1004
            Cv21 = madd(Av1, Bv2, Cv21);
1005
            Cv22 = madd(Av2, Bv2, Cv22);
1006
            Cv23 = madd(Av3, Bv2, Cv23);
1007
        }
1008
1009
        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1010
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1011
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
1012
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
1013
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
1014
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
1015
        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
1016
        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
1017
        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
1018
        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
1019
        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
1020
        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
1021
    }
1022
1023
    inline void gemm_bloc_4x2(int64_t ii, int64_t jj) {
1024
        size_t vl = vlmax<V>();
1025
        D Cv00 = set_zero<D>();
1026
        D Cv01 = set_zero<D>();
1027
        D Cv02 = set_zero<D>();
1028
        D Cv03 = set_zero<D>();
1029
        D Cv10 = set_zero<D>();
1030
        D Cv11 = set_zero<D>();
1031
        D Cv12 = set_zero<D>();
1032
        D Cv13 = set_zero<D>();
1033
1034
        for (int64_t l = 0; l < k; l += vl) {
1035
            V Av0 = load<V>(A + lda * (ii + 0) + l);
1036
            V Av1 = load<V>(A + lda * (ii + 1) + l);
1037
            V Av2 = load<V>(A + lda * (ii + 2) + l);
1038
            V Av3 = load<V>(A + lda * (ii + 3) + l);
1039
1040
            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1041
            Cv00 = madd(Av0, Bv0, Cv00);
1042
            Cv01 = madd(Av1, Bv0, Cv01);
1043
            Cv02 = madd(Av2, Bv0, Cv02);
1044
            Cv03 = madd(Av3, Bv0, Cv03);
1045
1046
            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
1047
            Cv10 = madd(Av0, Bv1, Cv10);
1048
            Cv11 = madd(Av1, Bv1, Cv11);
1049
            Cv12 = madd(Av2, Bv1, Cv12);
1050
            Cv13 = madd(Av3, Bv1, Cv13);
1051
        }
1052
1053
        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1054
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1055
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
1056
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
1057
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
1058
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
1059
        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
1060
        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
1061
    }
1062
1063
    inline void gemm_bloc_4x1(int64_t ii, int64_t jj) {
1064
        size_t vl = vlmax<V>();
1065
        D Cv00 = set_zero<D>();
1066
        D Cv01 = set_zero<D>();
1067
        D Cv02 = set_zero<D>();
1068
        D Cv03 = set_zero<D>();
1069
1070
        for (int64_t l = 0; l < k; l += vl) {
1071
            V Av0 = load<V>(A + lda * (ii + 0) + l);
1072
            V Av1 = load<V>(A + lda * (ii + 1) + l);
1073
            V Av2 = load<V>(A + lda * (ii + 2) + l);
1074
            V Av3 = load<V>(A + lda * (ii + 3) + l);
1075
1076
            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1077
            Cv00 = madd(Av0, Bv0, Cv00);
1078
            Cv01 = madd(Av1, Bv0, Cv01);
1079
            Cv02 = madd(Av2, Bv0, Cv02);
1080
            Cv03 = madd(Av3, Bv0, Cv03);
1081
        }
1082
1083
        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1084
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1085
        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
1086
        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
1087
    }
1088
1089
    inline void gemm_bloc_2x2(int64_t ii, int64_t jj) {
1090
        size_t vl = vlmax<V>();
1091
        D Cv00 = set_zero<D>();
1092
        D Cv01 = set_zero<D>();
1093
        D Cv10 = set_zero<D>();
1094
        D Cv11 = set_zero<D>();
1095
1096
        for (int64_t l = 0; l < k; l += vl) {
1097
            V Av0 = load<V>(A + lda * (ii + 0) + l);
1098
            V Av1 = load<V>(A + lda * (ii + 1) + l);
1099
1100
            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1101
            Cv00 = madd(Av0, Bv0, Cv00);
1102
            Cv01 = madd(Av1, Bv0, Cv01);
1103
1104
            V Bv1 = load<V>(B + ldb * (jj + 1) + l);
1105
            Cv10 = madd(Av0, Bv1, Cv10);
1106
            Cv11 = madd(Av1, Bv1, Cv11);
1107
        }
1108
1109
        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1110
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1111
        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
1112
        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
1113
    }
1114
1115
    inline void gemm_bloc_2x1(int64_t ii, int64_t jj) {
1116
        size_t vl = vlmax<V>();
1117
        D Cv00 = set_zero<D>();
1118
        D Cv01 = set_zero<D>();
1119
1120
        for (int64_t l = 0; l < k; l += vl) {
1121
            V Av0 = load<V>(A + lda * (ii + 0) + l);
1122
            V Av1 = load<V>(A + lda * (ii + 1) + l);
1123
1124
            V Bv0 = load<V>(B + ldb * (jj + 0) + l);
1125
            Cv00 = madd(Av0, Bv0, Cv00);
1126
            Cv01 = madd(Av1, Bv0, Cv01);
1127
        }
1128
1129
        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
1130
        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
1131
    }
1132
1133
    template <int RM, int RN>
1134
    inline void gemm_bloc(int64_t ii, int64_t jj) {
1135
        if constexpr (RM == 4) {
1136
            if constexpr (RN == 6) { return gemm_bloc_4x6(ii, jj); }
1137
            if constexpr (RN == 5) { return gemm_bloc_4x5(ii, jj); }
1138
            if constexpr (RN == 4) { return gemm_bloc_4x4(ii, jj); }
1139
            if constexpr (RN == 3) { return gemm_bloc_4x3(ii, jj); }
1140
            if constexpr (RN == 2) { return gemm_bloc_4x2(ii, jj); }
1141
            if constexpr (RN == 1) { return gemm_bloc_4x1(ii, jj); }
1142
        } else if constexpr (RM == 2) {
1143
            if constexpr (RN == 2) { return gemm_bloc_2x2(ii, jj); }
1144
            if constexpr (RN == 1) { return gemm_bloc_2x1(ii, jj); }
1145
        }
1146
    }
1147
1148
    template <int RM, int RN, int BM>
1149
    NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
1150
        GGML_ASSERT(m % (RM * BM) == 0);
1151
        const int64_t ytiles = m / (RM * BM);
1152
        const int64_t xtiles = (n + RN -1) / RN;
1153
        const int64_t jj_RN = (xtiles - (xtiles * RN - n));
1154
1155
        // "round" bloc_size to "nearest" BN
1156
        const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
1157
        const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
1158
        const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
1159
        const int64_t nb_job = ytiles * NB_BN;
1160
1161
        if (params->ith == 0) {
1162
            GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
1163
            // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
1164
            ggml_threadpool_chunk_set(params->threadpool, params->nth);
1165
        }
1166
1167
        ggml_barrier(params->threadpool);
1168
1169
        int64_t job = params->ith;
1170
        while (job < nb_job) {
1171
            const int64_t ii = (job % ytiles) * RM * BM;
1172
            const int64_t jb =  job / ytiles;
1173
            const int64_t jr0 = BLOC_POS(jb  , jj_BN, SIZE_BN);
1174
            const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
1175
1176
            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
1177
            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
1178
            const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
1179
1180
            for (int64_t bi = 0; bi < BM * RM; bi += RM) {
1181
                int64_t jj = jj0;
1182
                for (; jj < jj1; jj += RN) {
1183
                    gemm_bloc<RM, RN>(ii + bi, jj);
1184
                }
1185
                if constexpr (RN > 1) {
1186
                    for (; jj < jj2; jj += RN - 1) {
1187
                        gemm_bloc<RM, RN-1>(ii + bi, jj);
1188
                    }
1189
                }
1190
                GGML_ASSERT(jj == jj2);
1191
            }
1192
1193
            job = ggml_threadpool_chunk_add(params->threadpool, 1);
1194
        }
1195
1196
        ggml_barrier(params->threadpool);
1197
        return;
1198
    }
1199
1200
    const ggml_compute_params * params;
1201
    const TA *const A;
1202
    const TB *const B;
1203
    TC *const C;
1204
    const int64_t k;
1205
    const int64_t lda;
1206
    const int64_t ldb;
1207
    const int64_t ldc;
1208
};
1209
#endif
1210
1211
//////////////////////////////////////////////////////////////////////////////////////////
1212
// QUANT ZERO MATRIX MULTIPLICATION
1213
1214
#if defined(__ARM_FEATURE_DOTPROD)
1215
template <typename TA>
class tinyBLAS_Q0_ARM {
  public:
    // Quantized matmul on NEON dot-product hardware.
    //   A: rows of TA blocks, row stride `lda` (in blocks).
    //   B: rows of block_q8_0, row stride `ldb` (in blocks).
    //   C: float output; the result for A-row r and B-row c goes to C[ldc * c + r].
    //   k: number of quant blocks along the shared dimension.
    // The output tiles are split into `nth` contiguous chunks; this object
    // computes chunk `ith`.
    tinyBLAS_Q0_ARM(int64_t k,
                    const TA *A, int64_t lda,
                    const block_q8_0 *B, int64_t ldb,
                    float *C, int64_t ldc,
                    int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    // Computes this thread's share of the whole m x n output.
    void matmul(int64_t m, int64_t n) {
        mnpack(0, m, 0, n);
    }

  private:
    // Covers [m0, m) x [n0, n) with the largest tile (at most 3x3) that fits
    // the remaining extent. It then recurses on the leftover bottom strip and
    // right strip.
    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        const int64_t rows_left = MIN(m - m0, 3);
        const int64_t cols_left = MIN(n - n0, 3ll);
        int64_t tile_m;
        int64_t tile_n;
        switch ((rows_left << 4) | cols_left) {
        case 0x33: tile_m = 3; tile_n = 3; gemm<3, 3>(m0, m, n0, n); break;
        case 0x32: tile_m = 3; tile_n = 2; gemm<3, 2>(m0, m, n0, n); break;
        case 0x23: tile_m = 2; tile_n = 3; gemm<2, 3>(m0, m, n0, n); break;
        case 0x22: tile_m = 2; tile_n = 2; gemm<2, 2>(m0, m, n0, n); break;
        case 0x31: tile_m = 3; tile_n = 1; gemm<3, 1>(m0, m, n0, n); break;
        case 0x13: tile_m = 1; tile_n = 3; gemm<1, 3>(m0, m, n0, n); break;
        case 0x21: tile_m = 2; tile_n = 1; gemm<2, 1>(m0, m, n0, n); break;
        case 0x12: tile_m = 1; tile_n = 2; gemm<1, 2>(m0, m, n0, n); break;
        case 0x11: tile_m = 1; tile_n = 1; gemm<1, 1>(m0, m, n0, n); break;
        default:
            // Nothing left in at least one dimension.
            return;
        }
        const int64_t m_split = m0 + (m - m0) / tile_m * tile_m;
        const int64_t n_split = n0 + (n - n0) / tile_n * tile_n;
        mnpack(m_split, m, n0, n_split);
        mnpack(m0, m, n_split, n);
    }

    // RM x RN register-tile kernel over the tile-aligned part of
    // [m0, m) x [n0, n). Tiles are numbered row-major; this thread handles a
    // contiguous run of them.
    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        const int64_t tiles_m = (m - m0) / RM;
        const int64_t tiles_n = (n - n0) / RN;
        const int64_t total = tiles_m * tiles_n;
        const int64_t per_thread = (total + nth - 1) / nth;
        const int64_t first = per_thread * ith;
        const int64_t unclamped = first + per_thread;
        const int64_t last = unclamped > total ? total : unclamped;
        for (int64_t job = first; job < last; ++job) {
            const int64_t ii = m0 + job / tiles_n * RM;
            const int64_t jj = n0 + job % tiles_n * RN;
            float32x4_t acc[RN][RM] = {};
            for (int64_t l = 0; l < k; ++l) {
                for (int64_t j = 0; j < RN; ++j) {
                    const block_q8_0 *bp = B + ldb * (jj + j) + l;
                    for (int64_t i = 0; i < RM; ++i) {
                        const TA *ap = A + lda * (ii + i) + l;
                        // Integer dot product of both 16-byte halves, then scale
                        // by the product of the two block scales.
                        const int32x4_t dot = vdotq_s32(
                            vdotq_s32(vdupq_n_s32(0), load_lo(ap), load_lo(bp)),
                            load_hi(ap), load_hi(bp));
                        acc[j][i] = vmlaq_n_f32(acc[j][i], vcvtq_f32_s32(dot),
                                                unhalf(ap->d) * unhalf(bp->d));
                    }
                }
            }
            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(acc[j][i]);
        }
    }

    // First 16 signed quants of a q8_0 block.
    inline int8x16_t load_lo(const block_q8_0 *b) {
        return vld1q_s8(b->qs);
    }

    // Last 16 signed quants of a q8_0 block.
    inline int8x16_t load_hi(const block_q8_0 *b) {
        return vld1q_s8(b->qs + 16);
    }

    // Low nibbles of the q4_0 payload, re-centred by subtracting 8.
    inline int8x16_t load_lo(const block_q4_0 *b) {
        const uint8x16_t packed = vld1q_u8(b->qs);
        const uint8x16_t low = vandq_u8(packed, vdupq_n_u8(0x0f));
        return vsubq_s8(vreinterpretq_s8_u8(low), vdupq_n_s8(0x8));
    }

    // High nibbles of the q4_0 payload, re-centred by subtracting 8.
    inline int8x16_t load_hi(const block_q4_0 *b) {
        const uint8x16_t high = vshrq_n_u8(vld1q_u8(b->qs), 4);
        return vsubq_s8(vreinterpretq_s8_u8(high), vdupq_n_s8(0x8));
    }

    const TA *const A;
    const block_q8_0 *const B;
    float *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
1349
#endif // __ARM_FEATURE_DOTPROD
1350
1351
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
1352
template <typename TA, typename TB, typename TC>
1353
class tinyBLAS_Q0_AVX {
1354
  public:
1355
    tinyBLAS_Q0_AVX(int64_t k,
1356
                    const TA *A, int64_t lda,
1357
                    const TB *B, int64_t ldb,
1358
                    TC *C, int64_t ldc,
1359
                    int ith, int nth)
1360
0
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1361
0
        const int8_t kvalues_iq4nl[16] = {
1362
0
            -127, -104, -83, -65,
1363
0
            -49,  -35,  -22, -10,
1364
0
              1,   13,   25,  38,
1365
0
             53,   69,   89, 113
1366
0
        };
1367
1368
0
        iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
1369
0
    }
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::tinyBLAS_Q0_AVX(long, block_q8_0 const*, long, block_q8_0 const*, long, float*, long, int, int)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::tinyBLAS_Q0_AVX(long, block_q4_0 const*, long, block_q8_0 const*, long, float*, long, int, int)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::tinyBLAS_Q0_AVX(long, block_q5_0 const*, long, block_q8_0 const*, long, float*, long, int, int)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::tinyBLAS_Q0_AVX(long, block_iq4_nl const*, long, block_q8_0 const*, long, float*, long, int, int)
1370
1371
0
    void matmul(int64_t m, int64_t n) {
1372
0
        mnpack(0, m, 0, n);
1373
0
    }
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::matmul(long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::matmul(long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::matmul(long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::matmul(long, long)
1374
1375
  private:
1376
0
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1377
0
        int64_t mc, nc, mp, np;
1378
0
        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
1379
#if VECTOR_REGISTERS == 32
1380
        case 0x44:
1381
            mc = 4;
1382
            nc = 4;
1383
#if defined(__AVX2__) && defined(__F16C__)
1384
            gemm4xN<4>(m0, m, n0, n);
1385
#else
1386
            gemm<4, 4>(m0, m, n0, n);
1387
#endif
1388
            break;
1389
        case 0x43:
1390
            mc = 4;
1391
            nc = 3;
1392
#if defined(__AVX2__) && defined(__F16C__)
1393
            gemm4xN<3>(m0, m, n0, n);
1394
#else
1395
            gemm<4, 3>(m0, m, n0, n);
1396
#endif
1397
            break;
1398
        case 0x34:
1399
            mc = 3;
1400
            nc = 4;
1401
#if defined(__AVX2__) && defined(__F16C__)
1402
            gemmMx4<3>(m0, m, n0, n);
1403
#else
1404
            gemm<3, 4>(m0, m, n0, n);
1405
#endif
1406
            break;
1407
        case 0x33:
1408
            mc = 3;
1409
            nc = 3;
1410
            gemm<3, 3>(m0, m, n0, n);
1411
            break;
1412
        case 0x42:
1413
            mc = 4;
1414
            nc = 2;
1415
#if defined(__AVX2__) && defined(__F16C__)
1416
            gemm4xN<2>(m0, m, n0, n);
1417
#else
1418
            gemm<4, 2>(m0, m, n0, n);
1419
#endif
1420
            break;
1421
        case 0x24:
1422
            mc = 2;
1423
            nc = 4;
1424
#if defined(__AVX2__) && defined(__F16C__)
1425
            gemmMx4<2>(m0, m, n0, n);
1426
#else
1427
            gemm<2, 4>(m0, m, n0, n);
1428
#endif
1429
            break;
1430
#else
1431
0
        case 0x44:
1432
0
        case 0x43:
1433
0
        case 0x42:
1434
0
            mc = 4;
1435
0
            nc = 2;
1436
0
#if defined(__AVX2__) && defined(__F16C__)
1437
0
            gemm4xN<2>(m0, m, n0, n);
1438
#else
1439
            gemm<4, 2>(m0, m, n0, n);
1440
#endif
1441
0
            break;
1442
0
        case 0x34:
1443
0
        case 0x24:
1444
0
            mc = 2;
1445
0
            nc = 4;
1446
0
#if defined(__AVX2__) && defined(__F16C__)
1447
0
            gemmMx4<2>(m0, m, n0, n);
1448
#else
1449
            gemm<2, 4>(m0, m, n0, n);
1450
#endif
1451
0
            break;
1452
0
        case 0x33:
1453
0
#endif
1454
0
        case 0x32:
1455
0
            mc = 3;
1456
0
            nc = 2;
1457
0
            gemm<3, 2>(m0, m, n0, n);
1458
0
            break;
1459
0
        case 0x23:
1460
0
            mc = 2;
1461
0
            nc = 3;
1462
0
            gemm<2, 3>(m0, m, n0, n);
1463
0
            break;
1464
0
        case 0x41:
1465
0
            mc = 4;
1466
0
            nc = 1;
1467
0
#if defined(__AVX2__) && defined(__F16C__)
1468
0
            gemm4xN<1>(m0, m, n0, n);
1469
#else
1470
            gemm<4, 1>(m0, m, n0, n);
1471
#endif
1472
0
            break;
1473
0
        case 0x22:
1474
0
            mc = 2;
1475
0
            nc = 2;
1476
0
            gemm<2, 2>(m0, m, n0, n);
1477
0
            break;
1478
0
        case 0x14:
1479
0
            mc = 1;
1480
0
            nc = 4;
1481
0
#if defined(__AVX2__) && defined(__F16C__)
1482
0
            gemmMx4<1>(m0, m, n0, n);
1483
#else
1484
            gemm<1, 4>(m0, m, n0, n);
1485
#endif
1486
0
            break;
1487
0
        case 0x31:
1488
0
            mc = 3;
1489
0
            nc = 1;
1490
0
            gemm<3, 1>(m0, m, n0, n);
1491
0
            break;
1492
0
        case 0x13:
1493
0
            mc = 1;
1494
0
            nc = 3;
1495
0
            gemm<1, 3>(m0, m, n0, n);
1496
0
            break;
1497
0
        case 0x21:
1498
0
            mc = 2;
1499
0
            nc = 1;
1500
0
            gemm<2, 1>(m0, m, n0, n);
1501
0
            break;
1502
0
        case 0x12:
1503
0
            mc = 1;
1504
0
            nc = 2;
1505
0
            gemm<1, 2>(m0, m, n0, n);
1506
0
            break;
1507
0
        case 0x11:
1508
0
            mc = 1;
1509
0
            nc = 1;
1510
0
            gemm<1, 1>(m0, m, n0, n);
1511
0
            break;
1512
0
        default:
1513
0
            return;
1514
0
        }
1515
0
        mp = m0 + (m - m0) / mc * mc;
1516
0
        np = n0 + (n - n0) / nc * nc;
1517
0
        mnpack(mp, m, n0, np);
1518
0
        mnpack(m0, m, np, n);
1519
0
    }
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::mnpack(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::mnpack(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::mnpack(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::mnpack(long, long, long, long)
1520
1521
#if defined(__AVX2__) && defined(__F16C__)
1522
// Templated functions for gemm of dimensions 4xN
1523
    template <int RN>
1524
0
    NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1525
0
        int64_t ytiles = (m - m0) / 4;
1526
0
        int64_t xtiles = (n - n0) / RN;
1527
0
        int64_t tiles = xtiles * ytiles;
1528
0
        int64_t duty = (tiles + nth - 1) / nth;
1529
0
        int64_t start = duty * ith;
1530
0
        int64_t end = start + duty;
1531
0
        if (end > tiles)
1532
0
            end = tiles;
1533
0
        for (int64_t job = start; job < end; ++job) {
1534
0
            int64_t ii = m0 + job / xtiles * 4;
1535
0
            int64_t jj = n0 + job % xtiles * RN;
1536
0
            __m256 Cv[RN][4] = {};
1537
0
            for (int64_t l = 0; l < k; ++l) {
1538
0
                uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
1539
                // Convert delta values for four blocks to float values
1540
0
                __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
1541
0
                __m256i avec0 = load(A + lda * (ii + 0) + l);
1542
0
                __m256i avec1 = load(A + lda * (ii + 1) + l);
1543
0
                __m256i avec2 = load(A + lda * (ii + 2) + l);
1544
0
                __m256i avec3 = load(A + lda * (ii + 3) + l);
1545
0
                for (int64_t j = 0; j < RN; ++j) {
1546
0
                        __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
1547
                        // Computation of product of delta values for four blocks and replicate it across 256 bit lane
1548
0
                        __m256 dvec =  _mm256_castps128_ps256(_mm_mul_ps(da, db));
1549
0
                        dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
1550
                        // Computation of dot product and multiplication with appropriate delta value products
1551
0
                        Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
1552
0
                                    updot(_mm256_sign_epi8(avec0, avec0),
1553
0
                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
1554
0
                                    Cv[j][0]);
1555
0
                        Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
1556
0
                                    updot(_mm256_sign_epi8(avec1, avec1),
1557
0
                                            _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
1558
0
                                    Cv[j][1]);
1559
0
                        Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
1560
0
                                    updot(_mm256_sign_epi8(avec2, avec2),
1561
0
                                            _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
1562
0
                                    Cv[j][2]);
1563
0
                        Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
1564
0
                                    updot(_mm256_sign_epi8(avec3, avec3),
1565
0
                                            _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
1566
0
                                    Cv[j][3]);
1567
0
                }
1568
0
            }
1569
1570
0
            for (int64_t j = 0; j < RN; ++j)
1571
0
                for (int64_t i = 0; i < 4; ++i)
1572
0
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
1573
0
        }
1574
0
    }
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm4xN<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm4xN<1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm4xN<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm4xN<1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm4xN<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm4xN<1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm4xN<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm4xN<1>(long, long, long, long)
1575
1576
    // Templated functions for gemm of dimensions Mx4
1577
    template <int RM>
1578
0
    NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1579
0
        int64_t ytiles = (m - m0) / RM;
1580
0
        int64_t xtiles = (n - n0) / 4;
1581
0
        int64_t tiles = xtiles * ytiles;
1582
0
        int64_t duty = (tiles + nth - 1) / nth;
1583
0
        int64_t start = duty * ith;
1584
0
        int64_t end = start + duty;
1585
0
        if (end > tiles)
1586
0
            end = tiles;
1587
0
        for (int64_t job = start; job < end; ++job) {
1588
0
            int64_t ii = m0 + job / xtiles * RM;
1589
0
            int64_t jj = n0 + job % xtiles * 4;
1590
0
            __m256 Cv[4][RM] = {};
1591
0
            for (int64_t l = 0; l < k; ++l) {
1592
0
                uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
1593
                // Convert delta values for four blocks to float values
1594
0
                __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
1595
0
                __m256i bvec0 = load(B + ldb * (jj + 0) + l);
1596
0
                __m256i bvec1 = load(B + ldb * (jj + 1) + l);
1597
0
                __m256i bvec2 = load(B + ldb * (jj + 2) + l);
1598
0
                __m256i bvec3 = load(B + ldb * (jj + 3) + l);
1599
0
                for (int64_t i = 0; i < RM; ++i) {
1600
0
                    __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
1601
                    // Computation of product of delta values for four blocks and replicate it across 256 bit lane
1602
0
                    __m256 dvec =  _mm256_castps128_ps256(_mm_mul_ps(da, db));
1603
0
                    dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
1604
                    // Computation of dot product and multiplication with appropriate delta value products
1605
0
                    Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
1606
0
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
1607
0
                                                            load(A + lda * (ii + i) + l)),
1608
0
                                            _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
1609
0
                                    Cv[0][i]);
1610
0
                    Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
1611
0
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
1612
0
                                                            load(A + lda * (ii + i) + l)),
1613
0
                                            _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
1614
0
                                    Cv[1][i]);
1615
0
                    Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
1616
0
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
1617
0
                                                            load(A + lda * (ii + i) + l)),
1618
0
                                            _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
1619
0
                                    Cv[2][i]);
1620
0
                    Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
1621
0
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
1622
0
                                                            load(A + lda * (ii + i) + l)),
1623
0
                                            _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
1624
0
                                    Cv[3][i]);
1625
0
                }
1626
0
            }
1627
0
            for (int64_t j = 0; j < 4; ++j)
1628
0
                for (int64_t i = 0; i < RM; ++i)
1629
0
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
1630
0
        }
1631
0
    }
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemmMx4<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemmMx4<1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemmMx4<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemmMx4<1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemmMx4<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemmMx4<1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemmMx4<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemmMx4<1>(long, long, long, long)
1632
#endif
1633
1634
    template <int RM, int RN>
1635
0
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1636
0
        int64_t ytiles = (m - m0) / RM;
1637
0
        int64_t xtiles = (n - n0) / RN;
1638
0
        int64_t tiles = xtiles * ytiles;
1639
0
        int64_t duty = (tiles + nth - 1) / nth;
1640
0
        int64_t start = duty * ith;
1641
0
        int64_t end = start + duty;
1642
0
        if (end > tiles)
1643
0
            end = tiles;
1644
0
        for (int64_t job = start; job < end; ++job) {
1645
0
            int64_t ii = m0 + job / xtiles * RM;
1646
0
            int64_t jj = n0 + job % xtiles * RN;
1647
0
            __m256 Cv[RN][RM] = {};
1648
0
            for (int64_t l = 0; l < k; ++l)
1649
0
                for (int64_t j = 0; j < RN; ++j)
1650
0
                    for (int64_t i = 0; i < RM; ++i) {
1651
0
#if defined(__AVX2__)
1652
0
                        __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
1653
0
                                                              load(A + lda * (ii + i) + l)),
1654
0
                                             _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
1655
0
                                                              load(A + lda * (ii + i) + l)));
1656
#else
1657
                        __m128i ali0 = load0(A + lda * (ii + i) + l);
1658
                        __m128i ali1 = load1(A + lda * (ii + i) + l);
1659
                        __m128i blj0 = load0(B + ldb * (jj + j) + l);
1660
                        __m128i blj1 = load1(B + ldb * (jj + j) + l);
1661
1662
                        __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
1663
                        __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
1664
                        __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
1665
                        __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
1666
1667
                        // updot
1668
                        const __m128i oneFill = _mm_set1_epi16(1);
1669
                        __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
1670
                        __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
1671
                        __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
1672
#endif
1673
0
                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
1674
0
                                                       unhalf(B[ldb * (jj + j) + l].d)),
1675
0
                                                       udTmp,
1676
0
                                                       Cv[j][i]);
1677
0
                    }
1678
0
            for (int64_t j = 0; j < RN; ++j)
1679
0
                for (int64_t i = 0; i < RM; ++i)
1680
0
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
1681
0
        }
1682
0
    }
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<3, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<2, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<2, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<3, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<1, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<2, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<1, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<1, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<3, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<2, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<2, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<3, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<1, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<2, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<1, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<1, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<3, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<2, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<2, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<3, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<1, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<2, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<1, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<1, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<3, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<2, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<2, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<3, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<1, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<2, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<1, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<1, 1>(long, long, long, long)
1683
1684
0
    inline __m256i load(const block_q8_0 *b) {
1685
0
        return _mm256_loadu_si256((const __m256i *)b->qs);
1686
0
    }
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::load(block_q8_0 const*)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::load(block_q8_0 const*)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::load(block_q8_0 const*)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::load(block_q8_0 const*)
1687
1688
    inline __m128i load0(const block_q8_0 *b) {
1689
        return _mm_loadu_si128((const __m128i *)b->qs);
1690
    }
1691
1692
    inline __m128i load1(const block_q8_0 *b) {
1693
        return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
1694
    }
1695
1696
0
    inline __m256i load(const block_q4_0 *b) {
1697
0
        return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
1698
0
    }
1699
1700
    inline __m128i load0(const block_q4_0 *b) {
1701
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
1702
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
1703
    }
1704
1705
    inline __m128i load1(const block_q4_0 *b) {
1706
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
1707
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
1708
    }
1709
1710
0
    inline __m256i load(const block_q5_0 *b) {
1711
0
        return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
1712
0
    }
1713
1714
    inline __m128i load0(const block_q5_0* b) {
1715
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
1716
        uint32_t x32;
1717
        memcpy(&x32, b->qh, sizeof(uint32_t));
1718
        __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
1719
        __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
1720
                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
1721
                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
1722
                                                                      _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
1723
        bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
1724
        return _mm_or_si128(qxl, bytesl);
1725
    }
1726
1727
    inline __m128i load1(const block_q5_0* b) {
1728
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
1729
        uint32_t x32;
1730
        memcpy(&x32, b->qh, sizeof(uint32_t));
1731
        __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
1732
        __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
1733
                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
1734
                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
1735
                                                                      _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
1736
        bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
1737
        return _mm_or_si128(qxh, bytesh);
1738
    }
1739
1740
0
    inline __m256i load(const block_iq4_nl *b) {
1741
0
        return MM256_SET_M128I(load1(b), load0(b));
1742
0
    }
1743
1744
0
    inline __m128i load0(const block_iq4_nl *b) {
1745
0
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
1746
0
        return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
1747
0
    }
1748
1749
0
    inline __m128i load1(const block_iq4_nl *b) {
1750
0
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
1751
0
        return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
1752
0
    }
1753
1754
0
    // Multiplies unsigned bytes `u` by signed bytes `s` and sums each group of
    // four adjacent products into one of eight int32 lanes. The result is
    // returned as floats. Uses a VNNI dot-product instruction when the build
    // enables one.
    inline __m256 updot(__m256i u, __m256i s) {
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
        const __m256i sums = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
#elif defined(__AVXVNNI__)
        const __m256i sums = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
#else
        // maddubs: int16 sums of byte pairs; madd with 1: int32 sums of those pairs.
        const __m256i pairs = _mm256_maddubs_epi16(u, s);
        const __m256i sums = _mm256_madd_epi16(_mm256_set1_epi16(1), pairs);
#endif
        return _mm256_cvtepi32_ps(sums);
    }
1765
1766
0
    // Expands 16 packed bytes at p into 32 nibble values (0..15): the low
    // nibbles occupy the lower 128-bit half, the high nibbles the upper half.
    static inline __m256i denibble(const uint8_t *p) {
        const __m128i packed = _mm_loadu_si128((const __m128i *)p);
        const __m128i high   = _mm_srli_epi16(packed, 4);
        const __m256i joined = _mm256_insertf128_si256(_mm256_castsi128_si256(packed), high, 1);
        return _mm256_and_si256(_mm256_set1_epi8(15), joined);
    }
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::denibble(unsigned char const*)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::denibble(unsigned char const*)
1772
1773
0
    // Expands the 32 bits stored at p into 32 bytes: byte i is 0x00 when bit i
    // is set and 0xF0 when it is clear.
    static inline __m256i bittobyte(const uint8_t *p) {
        uint32_t bits;
        memcpy(&bits, p, sizeof(uint32_t));

        // Lane i receives source byte i / 8 of the broadcast word.
        const __m256i byte_select = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
                                                      0x0101010101010101, 0x0000000000000000);
        const __m256i spread = _mm256_shuffle_epi8(_mm256_set1_epi32(bits), byte_select);

        // All bits except bit (i % 8) are forced on, so the lane compares equal
        // to 0xFF only if its own bit was set.
        const __m256i lane_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
        const __m256i bit_set = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
                                                  _mm256_or_si256(lane_mask, spread));
        return _mm256_andnot_si256(bit_set, _mm256_set1_epi8((char)0xF0));
    }
1783
1784
    const TA *const A;
1785
    const TB *const B;
1786
    TC *const C;
1787
    const int64_t k;
1788
    const int64_t lda;
1789
    const int64_t ldb;
1790
    const int64_t ldc;
1791
    const int ith;
1792
    const int nth;
1793
    __m128i iq4nlt;
1794
};
1795
#endif // __AVX__
1796
1797
// PowerPC implementation (compiled only when MMA is available)
1798
#if defined(__MMA__)
1799
1800
// Stores the 4x4 float tile held in MMA accumulator ACC into C:
// tile element (I, J) goes to C[(ii) + I + ((jj) + J) * ldc].
// The expansion site must provide a `vec_t vec_C[4]` scratch array plus `C`
// and `ldc`.  The do/while(0) wrapper makes the macro a single statement, and
// the last line carries no continuation backslash so the macro no longer
// absorbs the following source line.
#define SAVE_ACC(ACC, ii, jj) \
   do { \
      __builtin_mma_disassemble_acc(vec_C, (ACC)); \
      for (int I = 0; I < 4; I++) { \
         for (int J = 0; J < 4; J++) { \
            *((float*)(C+(ii)+(((jj)+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
         } \
      } \
   } while (0)
1807
1808
// Maps a half-precision element type to the MMA outer-product builtin that
// consumes it.  The primary template is only declared; the explicit
// specializations that follow provide the definitions.
template <class T>
struct mma_instr;
1810
1811
template<>
struct mma_instr<ggml_bf16_t> {
    // Accumulates the bf16 rank-2 outer product of a and b into *acc
    // (__builtin_mma_xvbf16ger2pp).  In-class definitions are implicitly inline.
    static void outer_product(acc_t *acc, vec_t a, vec_t b) {
        __builtin_mma_xvbf16ger2pp(acc, a, b);
    }
};
1817
1818
template<>
struct mma_instr<ggml_fp16_t> {
    // Accumulates the fp16 rank-2 outer product of a and b into *acc
    // (__builtin_mma_xvf16ger2pp).  In-class definitions are implicitly inline.
    static void outer_product(acc_t *acc, vec_t a, vec_t b) {
        __builtin_mma_xvf16ger2pp(acc, a, b);
    }
};
1824
1825
// tinyBLAS GEMM for half-precision inputs on POWER MMA.
//
// For every output element it computes
//     C[i + j*ldc] = sum_{l < k} A[i*lda + l] * B[j*ldb + l]
// i.e. row i of A is dotted with row j of B along k.  Products are
// accumulated with mma_instr<TA>::outer_product on acc_t accumulators, and
// each 4x4 result tile is stored as float.
//
// The output is split into tiles; the tile indices are divided into nth
// contiguous ranges and this object computes the range for thread ith.
//
// NOTE(review): the kernels step l by 8 (by 4 in gemm_small) and load 16 bytes
// per row with vec_xl, so k is presumably constrained by the caller (e.g. to a
// multiple of 8) -- confirm.
template <typename TA, typename TB, typename TC>
class tinyBLAS_HP16_PPC {
  public:
    tinyBLAS_HP16_PPC(int64_t k,
                const TA *A, int64_t lda,
                const TB *B, int64_t ldb,
                TC *C, int64_t ldc,
                int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    // Computes this thread's share of the m x n output.
    void matmul(int64_t m, int64_t n) {
        mnpack(0, m, 0, n);
    }

  private:
    // Interleaves loaded rows into the operand layout used by the MMA outer
    // products and stores numVec 16-byte vectors at vecOffset:
    //   numVec == 2: built from the first 8 bytes of c[0..3]  (32 bytes written)
    //   numVec == 4: built from all 16 bytes of c[0..3]       (64 bytes written)
    //   numVec == 8: built from all 16 bytes of c[0..7]       (128 bytes written)
    // Any other value writes nothing.
    void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) {
        vec_t t[8], s[8];
        vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
        vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
        vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
        vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};

        if (numVec == 2) {
            t[0] = vec_perm(c[0], c[1], swiz1);
            t[1] = vec_perm(c[2], c[3], swiz1);
            s[0] = vec_perm(t[0], t[1], swiz3);
            s[1] = vec_perm(t[0], t[1], swiz4);
            vec_xst(s[0], 0, (vec_t*)vecOffset);
            vec_xst(s[1], 0, (vec_t*)(vecOffset + 16));
        } else if (numVec == 4) {
            t[0] = vec_perm(c[0], c[1], swiz1);
            t[1] = vec_perm(c[0], c[1], swiz2);
            t[2] = vec_perm(c[2], c[3], swiz1);
            t[3] = vec_perm(c[2], c[3], swiz2);
            s[0] = vec_perm(t[0], t[2], swiz3);
            s[1] = vec_perm(t[0], t[2], swiz4);
            s[2] = vec_perm(t[1], t[3], swiz3);
            s[3] = vec_perm(t[1], t[3], swiz4);
            for (int i = 0; i < 4; ++i)
                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
        } else if (numVec == 8) {
            for (int i = 0; i < 8; i += 2) {
                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
            }
            s[0] = vec_perm(t[0], t[2], swiz3);
            s[1] = vec_perm(t[0], t[2], swiz4);
            s[2] = vec_perm(t[1], t[3], swiz3);
            s[3] = vec_perm(t[1], t[3], swiz4);
            s[4] = vec_perm(t[4], t[6], swiz3);
            s[5] = vec_perm(t[4], t[6], swiz4);
            s[6] = vec_perm(t[5], t[7], swiz3);
            s[7] = vec_perm(t[5], t[7], swiz4);
            for (int i = 0; i < 8; ++i)
                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
        }
    }

    // Packs a panel starting at `a` (row stride lda) into `vec` in the
    // layout produced by vector_permute_store.  Each loaded row contributes 16
    // bytes (vec_xl); with cols == 4 only the first 8 bytes of each row are used.
    // Within this class it is called with rows in {1, 2, 3, 4, 8} and cols in {4, 8}.
    // NOTE(review): for cols > 8 the inner loops advance aoffsets by 8*lda
    // (whole rows) per step, which looks suspicious; no caller here passes cols > 8.
    void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {
        int64_t i, j;
        TA *aoffset = NULL;
        unsigned char *vecOffset = NULL;
        TA * aoffsets[8];
        vector unsigned char c_arr[8];
        aoffset = const_cast<TA*>(a);
        vecOffset = vec;
        j = (rows >> 3);
        if (j > 0) {
            do {
                if (cols == 4) {
                    aoffsets[0] = aoffset;
                    for (int it = 1; it < 4; ++it)
                        aoffsets[it] = aoffsets[it-1] + lda;
                    aoffset += 4 * lda;
                    for (int it = 0; it < 4; ++it)
                        c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
                    vector_permute_store(c_arr, 4, vecOffset);
                    for (int it = 0; it < 4; it++)
                        aoffsets[it] = aoffsets[it] + lda;
                    vecOffset += 64;
                }
                i = (cols >> 3);
                if (i > 0) {
                    aoffsets[0] = aoffset;
                    for (int it = 1; it < 8; ++it) {
                        aoffsets[it] = aoffsets[it-1] + lda;
                    }
                    aoffset += 8 * lda;
                    do {
                        for (int it = 0; it < 8; ++it)
                            c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
                        vector_permute_store(c_arr, 8, vecOffset);
                        for (int it = 0; it < 8; ++it)
                            aoffsets[it] = aoffsets[it] + 8*lda;
                        vecOffset += 128;
                        i--;
                    } while(i > 0);
                }
                j--;
            } while(j > 0);
        }
        if (rows & 4) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 4; ++it)
                aoffsets[it] = aoffsets[it-1] + lda;
            aoffset += 4 * lda;
            if (cols == 4) {
                for (int it = 0; it < 4; ++it)
                    c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
                vector_permute_store(c_arr, 2, vecOffset);
                for (int it = 0; it < 4; it++)
                    aoffsets[it] = aoffsets[it] + lda;
                vecOffset += 32;
            }
            i = (cols >> 3);
            if (i > 0) {
                do {
                    for (int it = 0; it < 4; ++it)
                        c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
                    vector_permute_store(c_arr, 4, vecOffset);
                    for (int it = 0; it < 4; it++)
                        aoffsets[it] = aoffsets[it] + 8*lda;
                    vecOffset += 64;
                    i--;
                } while(i > 0);
            }
        }
        if (rows & 3) {
            const int tail = rows & 3;
            // Only `tail` rows are loaded below, but vector_permute_store always
            // reads c_arr[0..3]; zero the unused slots so it never consumes
            // indeterminate vectors.
            for (int it = tail; it < 4; ++it)
                c_arr[it] = vec_splats((unsigned char)0);
            aoffsets[0] = aoffset;
            for (int it = 1; it < 4; ++it)
                aoffsets[it] = aoffsets[it-1] + lda;
            if (cols == 4) {
                // Dispatch on the tail row count (rows & 3), not on rows itself,
                // so the loads also happen when rows > 4.
                switch (tail) {
                    case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
                    /* fallthrough */
                    case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
                    /* fallthrough */
                    case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
                        break;
                    default:
                        break;
                }
                vector_permute_store(c_arr, 2, vecOffset);
                for (int it = 0; it < 4; it++)
                     aoffsets[it] = aoffsets[it] + lda;
                vecOffset += 32;
            }
            i = (cols >> 3);
            if (i > 0) {
                do {
                    switch (tail) {
                        case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
                        /* fallthrough */
                        case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
                        /* fallthrough */
                        case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
                            break;
                        default:
                            break;
                    }
                    vector_permute_store(c_arr, 4, vecOffset);
                    for (int it = 0; it < 4; it++)
                         aoffsets[it] = aoffsets[it] + 8* lda;
                    vecOffset += 64;
                    i--;
                } while(i > 0);
            }
        }
    }

    // Covers [m0, m) x [n0, n) with the largest tile shape that fits.
    // The part that is a whole number of mc x nc tiles is computed, then the
    // leftover bottom rows and right-hand columns are handled by recursion
    // with smaller tiles.
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        const int m_rem = MIN(m - m0, 8);
        const int n_rem = MIN(n - n0, 8);

        if (m_rem <= 0 || n_rem <= 0) {
            return;  // empty sub-problem
        }

        if (m_rem >= 8 && n_rem >= 8) {
            mc = 8;
            nc = 8;
            gemm<8,8>(m0, m, n0, n);
        } else if (m_rem >= 4 && n_rem >= 8) {
            mc = 4;
            nc = 8;
            gemm<4,8>(m0, m, n0, n);
        } else if (m_rem >= 8 && n_rem >= 4) {
            mc = 8;
            nc = 4;
            gemm<8,4>(m0, m, n0, n);
        } else if (n_rem >= 8) {
            // m_rem is 1..3 here: m_rem >= 4 with n_rem >= 8 was handled above.
            nc = 8;
            switch (m_rem) {
                case 1:
                    mc = 1;
                    gemm_Mx8<1>(m0, m, n0, n);
                    break;
                case 2:
                    mc = 2;
                    gemm_Mx8<2>(m0, m, n0, n);
                    break;
                case 3:
                    mc = 3;
                    gemm_Mx8<3>(m0, m, n0, n);
                    break;
                default:
                    return;
            }
        } else {
            // At least one dimension has fewer than 8 remaining (and the other
            // fewer than 4 or 8).  Use gemm_small with tiles of at most 4x4.
            // Every (rows, cols) pair in 1..4 x 1..4 is dispatched; previously
            // shapes such as 3 rows x 5..7 columns hit `default: return` and
            // those outputs were never computed.
            mc = MIN(m_rem, 4);
            nc = MIN(n_rem, 4);
            switch ((mc << 4) | nc) {
                case 0x44: gemm_small<4, 4>(m0, m, n0, n); break;
                case 0x43: gemm_small<4, 3>(m0, m, n0, n); break;
                case 0x42: gemm_small<4, 2>(m0, m, n0, n); break;
                case 0x41: gemm_small<4, 1>(m0, m, n0, n); break;
                case 0x34: gemm_small<3, 4>(m0, m, n0, n); break;
                case 0x33: gemm_small<3, 3>(m0, m, n0, n); break;
                case 0x32: gemm_small<3, 2>(m0, m, n0, n); break;
                case 0x31: gemm_small<3, 1>(m0, m, n0, n); break;
                case 0x24: gemm_small<2, 4>(m0, m, n0, n); break;
                case 0x23: gemm_small<2, 3>(m0, m, n0, n); break;
                case 0x22: gemm_small<2, 2>(m0, m, n0, n); break;
                case 0x21: gemm_small<2, 1>(m0, m, n0, n); break;
                case 0x14: gemm_small<1, 4>(m0, m, n0, n); break;
                case 0x13: gemm_small<1, 3>(m0, m, n0, n); break;
                case 0x12: gemm_small<1, 2>(m0, m, n0, n); break;
                case 0x11: gemm_small<1, 1>(m0, m, n0, n); break;
                default:
                    return;
            }
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

    // 4x8 output tile at (ii, jj): two accumulators cover columns jj..jj+3 and jj+4..jj+7.
    void KERNEL_4x8(int64_t ii, int64_t jj) {
        vec_t vec_A[4], vec_B[8], vec_C[4];
        acc_t acc_0, acc_1;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int64_t l = 0; l < k; l += 8) {
            packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
            for (int x = 0; x < 4; x++) {
                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
            }
        }
        SAVE_ACC(&acc_0, ii, jj);
        SAVE_ACC(&acc_1, ii, jj+4);
    }

    // 8x4 output tile at (ii, jj): two accumulators cover rows ii..ii+3 and ii+4..ii+7.
    void KERNEL_8x4(int64_t ii, int64_t jj) {
        vec_t vec_A[8], vec_B[4], vec_C[4];
        acc_t acc_0, acc_1;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int64_t l = 0; l < k; l += 8) {
            packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
            packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
            for (int x = 0; x < 4; x++) {
                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
            }
        }
        SAVE_ACC(&acc_0, ii, jj);
        SAVE_ACC(&acc_1, ii+4, jj);
    }

    // 8x8 output tile at (ii, jj) using four 4x4 accumulators.
    void KERNEL_8x8(int64_t ii, int64_t jj) {
        vec_t vec_A[8], vec_B[8], vec_C[4];
        acc_t acc_0, acc_1, acc_2, acc_3;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        __builtin_mma_xxsetaccz(&acc_2);
        __builtin_mma_xxsetaccz(&acc_3);
        for (int64_t l = 0; l < k; l += 8) {
            packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
            packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
            for (int x = 0; x < 4; x++) {
                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
                mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);
                mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);
            }
        }

        SAVE_ACC(&acc_0, ii, jj);
        SAVE_ACC(&acc_1, ii, jj+4);
        SAVE_ACC(&acc_2, ii+4, jj);
        SAVE_ACC(&acc_3, ii+4, jj+4);
    }

    // Computes RM x RN tiles (RM, RN <= 4) of [m0, m) x [n0, n) with a single
    // accumulator, stepping k by 4; only the top-left RM x RN corner is stored.
    template<int RM, int RN>
    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            vec_t vec_C[4];
            acc_t acc_0;
            __builtin_mma_xxsetaccz(&acc_0);
            vec_t vec_A[2], vec_B[2];
            for (int64_t l = 0; l < k; l += 4) {
                packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
                packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
                for (int x = 0; x < 2; x++) {
                    mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                }
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < RN; J++) {
                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                }
            }
        }
    }

    // Computes RM x 8 tiles (RM in 1..3) of [m0, m) x [n0, n) with two
    // accumulators (columns jj..jj+3 and jj+4..jj+7); only RM rows are stored.
    template<int RM>
    void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        constexpr int RN = 8;
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            vec_t vec_C[4];
            acc_t acc_0, acc_1;
            __builtin_mma_xxsetaccz(&acc_0);
            __builtin_mma_xxsetaccz(&acc_1);
            vec_t vec_A[4], vec_B[8];
            for (int64_t l = 0; l < k; l += 8) {
                packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
                packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
                for (int x = 0; x < 4; x++) {
                    mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                    mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
                }
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < 4; J++) {
                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                }
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_1);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < 4; J++) {
                    *((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                }
            }
        }
    }

    // Compile-time dispatch to the fixed-shape kernel for an RM x RN tile.
    template<int RM, int RN>
    inline void kernel(int64_t ii, int64_t jj) {
       if constexpr(RM == 4 && RN == 8) {
          KERNEL_4x8(ii,jj);
       } else if constexpr(RM == 8 && RN == 8) {
          KERNEL_8x8(ii,jj);
       } else if constexpr(RM == 8 && RN == 4) {
          KERNEL_8x4(ii,jj);
       } else {
          // Dependent on RM/RN, so it only fires if an unsupported shape is
          // actually instantiated; it rejects the shape at compile time instead
          // of silently computing nothing in NDEBUG builds.
          static_assert(RM == 0 && RN == 0, "RN/RM values not supported");
       }
    }

    // Splits [m0, m) x [n0, n) into RM x RN tiles and runs this thread's
    // contiguous range of tile indices.
    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            kernel<RM, RN>(ii, jj);
        }
    }

    const TA *const A;
    const TB *const B;
    TC *C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
2311
2312
template <typename TA>
2313
class tinyBLAS_Q0_PPC {
2314
  public:
2315
    tinyBLAS_Q0_PPC(int64_t k,
2316
             const TA * A, int64_t lda,
2317
             const block_q8_0 * B, int64_t ldb,
2318
             float * C, int64_t ldc,
2319
             int ith, int nth)
2320
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
2321
    }
2322
2323
    void matmul(int64_t m, int64_t n) {
2324
    #if defined(_AIX) || defined(__BIG_ENDIAN__)
2325
        mnpack(0, m, 0, n);
2326
    #else
2327
        const int64_t mc = 64;
2328
        const int64_t kc = 64;
2329
        int64_t nc = 64;
2330
        int64_t n_aligned = 0;
2331
        if (n % 64 == 0) {
2332
            n_aligned = n;
2333
        } else if (n == 4) {
2334
            n_aligned = 4;
2335
        } else if (n < 64) {
2336
            n_aligned = (n / 8) * 8;
2337
        } else {
2338
            n_aligned = (n / 64) * 64;
2339
        }
2340
        if (n_aligned > 0) {
2341
            if (n_aligned % 64 == 0)      nc = 64;
2342
            else if (n_aligned == n)      nc = n;
2343
            else if (n_aligned % 32 == 0) nc = 32;
2344
            else if (n_aligned % 24 == 0) nc = 24;
2345
            else if (n_aligned % 16 == 0) nc = 16;
2346
            else                          nc = 8;
2347
        }
2348
        bool can_use_tiled = n_aligned > 0 && (m % mc == 0);
2349
        if (can_use_tiled) {
2350
            matmul_tiled(m, n_aligned, mc, nc, kc);
2351
            if (n > n_aligned) {
2352
                mnpack(0, m, n_aligned, n);
2353
            }
2354
        } else {
2355
            mnpack(0, m, 0, n);
2356
        }
2357
    #endif
2358
    }
2359
2360
  private:
2361
    // Stores an RM x RN result tile (RM, RN <= 4) into C: lane J of
    // fin_res[idx + I] is written to C[ii + I + (jj + J) * ldc].
    // ii/jj are int64_t, matching save_acc/add_save_acc, so large row/column
    // offsets passed by callers are not narrowed to int.
    inline void save_res(int64_t ii, int64_t jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) {
        for (int I = 0; I < RM; I++) {
            const float * src = (const float *)&fin_res[idx + I];
            for (int J = 0; J < RN; J++) {
                *((float *)(C + ii + ((jj + J) * ldc) + I)) = src[J];
            }
        }
    }
2368
2369
    // Overwrites the 4x4 tile of C at (ii, jj) with the contents of ACC:
    // lane J of accumulator row I goes to C[ii + I + (jj + J) * ldc].
    inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
        vec_t rows[4];
        __builtin_mma_disassemble_acc(rows, ACC);
        for (int r = 0; r < 4; ++r) {
            const float * lanes = (const float *)&rows[r];
            for (int c = 0; c < 4; ++c) {
                float * dst = (float *)(C + ii + ((jj + c) * ldc) + r);
                *dst = lanes[c];
            }
        }
    }
2378
2379
    // Adds the 4x4 tile held in ACC into C at (ii, jj): lane J of accumulator
    // row I is accumulated into C[ii + I + (jj + J) * ldc].
    inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
        vec_t rows[4];
        __builtin_mma_disassemble_acc(rows, ACC);
        for (int r = 0; r < 4; ++r) {
            const float * lanes = (const float *)&rows[r];
            for (int c = 0; c < 4; ++c) {
                float * dst = (float *)(C + ii + ((jj + c) * ldc) + r);
                *dst += lanes[c];
            }
        }
    }
2389
2390
    // Converts the four int32 rows of ACC to float, adds the per-row
    // correction comparray[c_idx + i] * -128, and fuses the result into
    // fin_res[s_idx + i] scaled by vs[s_idx + i] (fin_res += value * vs).
    template<typename ArrayType>
    inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) {
        vector signed int int_rows[4];
        __builtin_mma_disassemble_acc(int_rows, ACC);
        for (int i = 0; i < 4; i++) {
            const vector float offset = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0));
            const vector float adjusted = vec_add(vec_ctf(int_rows[i], 0), offset);
            fin_res[s_idx + i] = vec_madd(adjusted, vs[s_idx + i], fin_res[s_idx + i]);
        }
    }
2402
2403
    // Unpacks the 4-bit values held in c[1]: c[0] receives the low nibbles and
    // c[1] the high nibbles, both shifted to the signed range by subtracting 8.
    // *ca receives the sum of all 32 resulting values.
    inline void process_q4_elements(vector signed char (&c)[2], int * ca) {
        const vector signed char packed = c[1];
        const vector signed char bias   = vec_splats((signed char)0x8);
        const vector signed int  zero   = vec_splats((signed int)0);

        c[0] = vec_sub(vec_and(packed, vec_splats((signed char)0xF)), bias);
        c[1] = vec_sub(vec_sr(packed, vec_splats((unsigned char)0x4)), bias);

        // Sum the two halves separately, then add (same order as before).
        const vector signed int low_sums  = vec_sum4s(c[0], zero);
        const vector signed int high_sums = vec_sum4s(c[1], zero);
        const vector signed int total     = vec_add(low_sums, high_sums);
        *ca = total[0] + total[1] + total[2] + total[3];
    }
2418
2419
    // Transposes four input vectors as 4-byte groups and writes four 16-byte
    // vectors to vecOffset, vecOffset + 16, + 32 and + 48 (in units of V1).
    // When flip is set, every output byte is XOR'ed with 0x80 before storing.
    template <typename V1, typename V2>
    inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) {
        const vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
        const vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
        const vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
        const vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};

        // First stage: pair up halves of (s1, s2) and (s3, s4).
        const V2 lo12 = vec_perm(s1, s2, swiz1);
        const V2 hi12 = vec_perm(s1, s2, swiz2);
        const V2 lo34 = vec_perm(s3, s4, swiz1);
        const V2 hi34 = vec_perm(s3, s4, swiz2);

        // Second stage: interleave 4-byte groups across the pairs.
        V2 out[4] = {
            vec_perm(lo12, lo34, swiz3),
            vec_perm(lo12, lo34, swiz4),
            vec_perm(hi12, hi34, swiz3),
            vec_perm(hi12, hi34, swiz4),
        };

        if (flip) {
            uint8_t sign_bit = 0x80;
            const vector unsigned char flipper = vec_splats(sign_bit);
            for (int i = 0; i < 4; ++i) {
                out[i] = vec_xor(out[i], flipper);
            }
        }
        for (int i = 0; i < 4; ++i) {
            vec_xst(out[i], 0, vecOffset + 16 * i);
        }
    }
2448
2449
    // Splits 16 packed bytes into low and high nibbles, each re-centred from
    // 0..15 to -8..7 by subtracting 8.
    inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) {
        const vector unsigned char shift = vec_splats((unsigned char)4);
        const vector signed char   mask  = vec_splats((signed char)0x0F);
        const vector signed char   bias  = vec_splats((signed char)0x08);
        // Assignments keep the original order: extract both, then bias both.
        lo = vec_and(packed, mask);
        hi = vec_sr(packed, shift);
        lo = vec_sub(lo, bias);
        hi = vec_sub(hi, bias);
    }
2458
2459
    // Interleaves eight input vectors c[0..7] (as 4-byte groups, then 8-byte
    // groups) and stores the eight results contiguously (128 bytes) at vecOffset.
    inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) {
        const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
        const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
        const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
        const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};

        vec_t pairs[8];
        for (int p = 0; p < 8; p += 2) {
            pairs[p]     = vec_perm(c[p], c[p + 1], swiz1);
            pairs[p + 1] = vec_perm(c[p], c[p + 1], swiz2);
        }

        // Each half (pairs[0..3], pairs[4..7]) yields four output vectors.
        vec_t out[8];
        for (int half = 0; half < 8; half += 4) {
            out[half + 0] = vec_perm(pairs[half + 0], pairs[half + 2], swiz3);
            out[half + 1] = vec_perm(pairs[half + 0], pairs[half + 2], swiz4);
            out[half + 2] = vec_perm(pairs[half + 1], pairs[half + 3], swiz3);
            out[half + 3] = vec_perm(pairs[half + 1], pairs[half + 3], swiz4);
        }

        for (int v = 0; v < 8; ++v) {
            vec_xst(out[v], 0, (vec_t *)(vecOffset + v * 16));
        }
    }
2485
2486
    // Widens 16 signed bytes to float, multiplies by v_scale, and packs the
    // results into two vectors of 8 half-precision values: out_hi holds bytes
    // 0..7 (the vec_unpackh half) and out_lo bytes 8..15.
    static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) {
        const vector signed short halves[2] = { vec_unpackh(raw), vec_unpackl(raw) };
        vector float scaled[4];
        for (int h = 0; h < 2; ++h) {
            scaled[2 * h + 0] = vec_mul(vec_ctf(vec_unpackh(halves[h]), 0), v_scale);
            scaled[2 * h + 1] = vec_mul(vec_ctf(vec_unpackl(halves[h]), 0), v_scale);
        }
        out_hi = vec_pack_to_short_fp32(scaled[0], scaled[1]);
        out_lo = vec_pack_to_short_fp32(scaled[2], scaled[3]);
    }
2497
2498
    void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
2499
        unsigned char * vecOffset = vec;
2500
        for (int i = 0; i < rows; i += 8) {
2501
            const block_q4_0 * rows_base[8];
2502
            for (int r = 0; r < 8; r++) {
2503
                rows_base[r] = a + (i + r) * lda;
2504
            }
2505
            for (int blk = 0; blk < blocks; blk++) {
2506
                vector unsigned short hp_res[8][4];
2507
                for (int r = 0; r < 8; r++) {
2508
                    const block_q4_0 * current_blk = rows_base[r] + blk;
2509
                    vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));
2510
                    vector signed char v_qs = vec_xl(0, (const vector signed char *)current_blk->qs);
2511
                    vector signed char c1, c2;
2512
                    unpack_q4_to_q8(v_qs, c1, c2);
2513
                    convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]);
2514
                    convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]);
2515
                }
2516
                for (int c = 0; c < 4; c++) {
2517
                    vector unsigned char c_arr[8];
2518
                    for (int r = 0; r < 8; r++) {
2519
                        c_arr[r] = (vector unsigned char)hp_res[r][c];
2520
                    }
2521
                    vector_permute_store_fp16((vec_t *)c_arr, vecOffset);
2522
                    vecOffset += 128;
2523
                }
2524
            }
2525
        }
2526
    }
2527
2528
    template <int chunk_size>
2529
    static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
2530
        unsigned char * vecOffset = vec;
2531
        const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
2532
        const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
2533
        const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
2534
        const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
2535
2536
        for (int i = 0; i < rows; i += chunk_size) {
2537
            const block_q8_0 * rows_base[chunk_size];
2538
            for (int r = 0; r < chunk_size; r++) {
2539
                rows_base[r] = a + (i + r) * lda;
2540
            }
2541
            for (int blk = 0; blk < blocks; blk++) {
2542
                vector unsigned short hp_res[chunk_size][4];
2543
                for (int r = 0; r < chunk_size; r++) {
2544
                    const block_q8_0 * b = rows_base[r] + blk;
2545
                    vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d));
2546
                    vector signed char c[2];
2547
                    __vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);
2548
                    __builtin_vsx_disassemble_pair(c, & pair);
2549
                    convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]);
2550
                    convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]);
2551
                }
2552
                for (int col = 0; col < 4; col++) {
2553
                    if constexpr (chunk_size == 8) {
2554
                        vec_t t[8];
2555
                        t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
2556
                        t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
2557
                        t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
2558
                        t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
2559
                        t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1);
2560
                        t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2);
2561
                        t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1);
2562
                        t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2);
2563
2564
                        vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0));
2565
                        vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16));
2566
                        vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32));
2567
                        vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48));
2568
                        vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64));
2569
                        vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80));
2570
                        vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96));
2571
                        vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112));
2572
                        vecOffset += 128;
2573
                    } else {
2574
                        vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
2575
                        vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
2576
                        vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
2577
                        vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
2578
2579
                        vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0));
2580
                        vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16));
2581
                        vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32));
2582
                        vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48));
2583
                        vecOffset += 64;
2584
                    }
2585
                }
2586
            }
2587
        }
2588
    }
2589
2590
    void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
2591
        if (rows == 4) {
2592
            pack_q8_block<4>(a, lda, rows, blocks, vec);
2593
        } else {
2594
            pack_q8_block<8>(a, lda, rows, blocks, vec);
2595
        }
2596
    }
2597
2598
    template<int size>
2599
    void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) {
2600
        int64_t i, j;
2601
        TA * aoffset = NULL;
2602
        int8_t * vecOffset = NULL;
2603
        TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL;
2604
        TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL;
2605
        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
2606
        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
2607
        aoffset = const_cast<TA *>(a);
2608
        vecOffset = vec;
2609
        j = (rows >> 3);
2610
        if (j > 0) {
2611
            do {
2612
                aoffset1 = aoffset;
2613
                aoffset2 = aoffset1 + lda;
2614
                aoffset3 = aoffset2 + lda;
2615
                aoffset4 = aoffset3 + lda;
2616
                aoffset5 = aoffset4 + lda;
2617
                aoffset6 = aoffset5 + lda;
2618
                aoffset7 = aoffset6 + lda;
2619
                aoffset8 = aoffset7 + lda;
2620
                aoffset += 8 * lda;
2621
                i = (cols >> 2);
2622
                if (i > 0) {
2623
                    do {
2624
                        c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
2625
                        c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
2626
                        c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
2627
                        c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);
2628
                        c5[1] = vec_xl(0, (const vector signed char *)aoffset5->qs);
2629
                        c6[1] = vec_xl(0, (const vector signed char *)aoffset6->qs);
2630
                        c7[1] = vec_xl(0, (const vector signed char *)aoffset7->qs);
2631
                        c8[1] = vec_xl(0, (const vector signed char *)aoffset8->qs);
2632
2633
                        process_q4_elements(c1, & comparray[0]);
2634
                        process_q4_elements(c2, & comparray[1]);
2635
                        process_q4_elements(c3, & comparray[2]);
2636
                        process_q4_elements(c4, & comparray[3]);
2637
                        process_q4_elements(c5, & comparray[4]);
2638
                        process_q4_elements(c6, & comparray[5]);
2639
                        process_q4_elements(c7, & comparray[6]);
2640
                        process_q4_elements(c8, & comparray[7]);
2641
                        vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
2642
                        vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
2643
                        vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false);
2644
                        vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false);
2645
                        aoffset1 += lda;
2646
                        aoffset2 += lda;
2647
                        aoffset3 += lda;
2648
                        aoffset4 += lda;
2649
                        aoffset5 += lda;
2650
                        aoffset6 += lda;
2651
                        aoffset7 += lda;
2652
                        aoffset8 += lda;
2653
                        vecOffset += 256;
2654
                        i--;
2655
                    } while (i > 0);
2656
                }
2657
                j--;
2658
            } while (j > 0);
2659
        }
2660
2661
        if (rows & 4) {
2662
            aoffset1 = aoffset;
2663
            aoffset2 = aoffset1 + lda;
2664
            aoffset3 = aoffset2 + lda;
2665
            aoffset4 = aoffset3 + lda;
2666
            aoffset += 4 * lda;
2667
            i = (cols >> 2);
2668
            if (i > 0) {
2669
                do {
2670
                    c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
2671
                    c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
2672
                    c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
2673
                    c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);
2674
2675
                    process_q4_elements(c1, & comparray[0]);
2676
                    process_q4_elements(c2, & comparray[1]);
2677
                    process_q4_elements(c3, & comparray[2]);
2678
                    process_q4_elements(c4, & comparray[3]);
2679
                    vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
2680
                    vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
2681
                    aoffset1 += lda;
2682
                    aoffset2 += lda;
2683
                    aoffset3 += lda;
2684
                    aoffset4 += lda;
2685
                    vecOffset += 128;
2686
                    i--;
2687
                } while (i > 0);
2688
            }
2689
        }
2690
2691
        if (rows & 3) {
2692
            aoffset1 = aoffset;
2693
            aoffset2 = aoffset1 + lda;
2694
            aoffset3 = aoffset2 + lda;
2695
            i = (cols >> 2);
2696
            if (i > 0) {
2697
                do {
2698
                    switch(rows) {
2699
                        case 3: c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
2700
                        case 2: c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
2701
                        case 1: c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
2702
                            break;
2703
                    }
2704
                    process_q4_elements(c1, & comparray[0]);
2705
                    process_q4_elements(c2, & comparray[1]);
2706
                    process_q4_elements(c3, & comparray[2]);
2707
                    process_q4_elements(c4, & comparray[3]);
2708
                    vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
2709
                    vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
2710
                    aoffset1 += lda;
2711
                    aoffset2 += lda;
2712
                    aoffset3 += lda;
2713
                    vecOffset += 128;
2714
                    i--;
2715
                } while(i > 0);
2716
            }
2717
        }
2718
    }
2719
2720
    template<typename VA, typename VB>
2721
    void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) {
2722
        int64_t i, j;
2723
        block_q8_0 * aoffset = NULL;
2724
        VA * vecOffset = NULL;
2725
        block_q8_0 * aoffsets[8];
2726
        __vector_pair arr[8];
2727
        VB c[8][2] = {0};
2728
        VB c1[8] = {0}; VB c2[8] = {0};
2729
        aoffset = const_cast<block_q8_0 *>(a);
2730
        vecOffset = vec;
2731
        j = (rows >> 3);
2732
        if (j > 0) {
2733
            do {
2734
                aoffsets[0] = aoffset;
2735
                for (int it = 1; it < 8; it++)
2736
                    aoffsets[it] = aoffsets[it - 1] + lda;
2737
                aoffset += 8 * lda;
2738
2739
                i = (cols >> 3);
2740
                if (i > 0) {
2741
                do {
2742
                    for (int it = 0; it < 8; it++) {
2743
                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
2744
                        __builtin_vsx_disassemble_pair(c[it], & arr[it]);
2745
                        c1[it] = c[it][0];
2746
                        c2[it] = c[it][1];
2747
                    }
2748
                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
2749
                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
2750
                    vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip);
2751
                    vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip);
2752
                    for (int it = 0; it < 8; it++)
2753
                        aoffsets[it] += lda;
2754
                    vecOffset += 256;
2755
                    i--;
2756
               } while(i > 0);
2757
            }
2758
            j--;
2759
        } while(j > 0);
2760
    }
2761
    if (rows & 4) {
2762
            aoffsets[0]  = aoffset;
2763
            for (int it = 1; it < 4; it++ )
2764
                aoffsets[it] = aoffsets[it-1] + lda;
2765
            aoffset += 4 * lda;
2766
        i = (cols >> 3);
2767
            if (i > 0) {
2768
               do {
2769
                    for (int it = 0; it < 4; it++) {
2770
                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
2771
                        __builtin_vsx_disassemble_pair(c[it], & arr[it]);
2772
                        c1[it] = c[it][0];
2773
                        c2[it] = c[it][1];
2774
                    }
2775
                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
2776
                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
2777
                    for (int it = 0; it < 4; it++) {
2778
                        aoffsets[it] += lda;
2779
                    }
2780
                    vecOffset += 128;
2781
                    i--;
2782
               } while(i > 0);
2783
            }
2784
        }
2785
2786
        if (rows & 3) {
2787
            aoffsets[0]  = aoffset;
2788
            for (int it = 1; it < 3; it++ )
2789
                aoffsets[it] = aoffsets[it - 1] + lda;
2790
            i = (cols >> 3);
2791
            if (i > 0) {
2792
                do {
2793
                    switch(rows) {
2794
                        case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs);
2795
                                __builtin_vsx_disassemble_pair(c[2], & arr[2]);
2796
                                c1[2] = c[2][0]; c2[2] = c[2][1];
2797
                        case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs);
2798
                                __builtin_vsx_disassemble_pair(c[1], & arr[1]);
2799
                                c1[1] = c[1][0]; c2[1] = c[1][1];
2800
                        case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs);
2801
                                __builtin_vsx_disassemble_pair(c[0], & arr[0]);
2802
                                c1[0] = c[0][0]; c2[0] = c[0][1];
2803
                                break;
2804
                    }
2805
                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
2806
                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
2807
                    for (int it = 0; it < 3; it++)
2808
                         aoffsets[it] += lda;
2809
                    vecOffset += 128;
2810
                    i--;
2811
               } while(i > 0);
2812
            }
2813
        }
2814
    }
2815
2816
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2817
        int m_rem = MIN(m - m0, 16);
2818
        int n_rem = MIN(n - n0, 16);
2819
2820
        int mc = 0, nc = 0;
2821
2822
        if (m_rem >= 8 && n_rem >= 8) {
2823
           mc = 8;
2824
           nc = 8;
2825
           gemm<8, 8>(m0, m, n0, n);
2826
        } else if (m_rem >= 4 && n_rem >= 8) {
2827
            mc = 4;
2828
            nc = 8;
2829
            gemm<4, 8>(m0, m, n0, n);
2830
        } else if (m_rem >= 8 && n_rem >= 4) {
2831
            mc = 8;
2832
            nc = 4;
2833
            gemm<8, 4>(m0, m, n0, n);
2834
        } else if (m_rem >= 4 && n_rem >= 4) {
2835
            mc = 4;
2836
            nc = 4;
2837
            gemm_small(m0, m, n0, n, mc, nc);
2838
        } else {
2839
            mc = (m_rem >= 4) ? 4 : m_rem;
2840
            nc = (n_rem >= 4) ? 4 : n_rem;
2841
            if (mc == 0 || nc == 0)
2842
               return;
2843
            gemm_small(m0, m, n0, n, mc, nc);
2844
        }
2845
2846
        int64_t mp = m0 + ((m - m0) / mc) * mc;
2847
        int64_t np = n0 + ((n - n0) / nc) * nc;
2848
        mnpack(mp, m, n0, np);
2849
        mnpack(m0, m, np, n);
2850
    }
2851
2852
2853
    void KERNEL_4x8(int64_t ii, int64_t jj) {
2854
        vec_t vec_A[8], vec_B[16] = {0};
2855
        acc_t acc_0, acc_1;
2856
        std::array<int, 4> comparray {};
2857
        vector float fin_res[8] = {0};
2858
        vector float vs[8] = {0};
2859
        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
2860
        for (int l = 0; l < k; l++) {
2861
            __builtin_mma_xxsetaccz(& acc_0);
2862
            __builtin_mma_xxsetaccz(& acc_1);
2863
            if (std::is_same_v<TA, block_q4_0>) {
2864
               packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray);
2865
            } else {
2866
               packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false);
2867
            }
2868
            packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
2869
            for(int x = 0; x < 8; x++) {
2870
                __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
2871
                __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]);
2872
            }
2873
            for (int I = 0; I<4; I++) {
2874
                for (int J = 0; J<4; J++) {
2875
                    *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
2876
                    *((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
2877
                }
2878
            }
2879
            if (!isAblock_q4) {
2880
                auto aoffset = A + (ii * lda) + l;
2881
                for (int i = 0; i < 4; i++) {
2882
                    comparray[i] = 0;
2883
                    int ca = 0;
2884
                    auto *at = aoffset->qs;
2885
                    for (int j = 0; j < 32; j++)
2886
                        ca += (int)*at++;
2887
                    comparray[i] = ca;
2888
                    aoffset += lda;
2889
                }
2890
            }
2891
            compute(& acc_0, 0, 0, comparray, vs, fin_res);
2892
            compute(& acc_1, 0, 4, comparray, vs, fin_res);
2893
        }
2894
        save_res(ii, jj, 0, fin_res);
2895
        save_res(ii, jj + 4, 4, fin_res);
2896
    }
2897
2898
    void KERNEL_8x4(int64_t ii, int64_t jj) {
2899
        vec_t vec_A[16], vec_B[8] = {0};
2900
        acc_t acc_0, acc_1;
2901
        std::array<int, 8> comparray {};
2902
        vector float fin_res[8] = {0};
2903
        vector float vs[8] = {0};
2904
        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
2905
        for (int l = 0; l < k; l++) {
2906
            __builtin_mma_xxsetaccz(& acc_0);
2907
            __builtin_mma_xxsetaccz(& acc_1);
2908
            if (std::is_same_v<TA, block_q4_0>) {
2909
               packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
2910
            } else {
2911
               packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
2912
            }
2913
            packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true);
2914
            for(int x = 0; x < 8; x++) {
2915
                __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
2916
                __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
2917
            }
2918
            for (int I = 0; I < 8; I++) {
2919
                for (int J = 0; J < 4; J++) {
2920
                    *((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
2921
                }
2922
            }
2923
            if (!isAblock_q4) {
2924
                auto aoffset = A + (ii * lda) + l;
2925
                for (int i = 0; i < 8; i++) {
2926
                    comparray[i] = 0;
2927
                    int ca = 0;
2928
                    auto *at = aoffset->qs;
2929
                    for (int j = 0; j < 32; j++)
2930
                        ca += (int)*at++;
2931
                    comparray[i] = ca;
2932
                    aoffset += lda;
2933
                }
2934
            }
2935
            compute(& acc_0, 0, 0, comparray, vs, fin_res);
2936
            compute(& acc_1, 4, 4, comparray, vs, fin_res);
2937
        }
2938
        save_res(ii, jj, 0, fin_res);
2939
        save_res(ii + 4, jj, 4, fin_res);
2940
    }
2941
2942
    void KERNEL_8x8(int64_t ii, int64_t jj) {
2943
        vec_t vec_A[16], vec_B[16] = {0};
2944
        acc_t acc_0, acc_1, acc_2, acc_3;
2945
        acc_t acc_4, acc_5, acc_6, acc_7;
2946
        std::array<int, 8> comparray {};
2947
        vector float fin_res[16] = {0};
2948
        vector float vs[16] = {0};
2949
        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
2950
        for (int l = 0; l < k; l++) {
2951
            __builtin_mma_xxsetaccz(& acc_0);
2952
            __builtin_mma_xxsetaccz(& acc_1);
2953
            __builtin_mma_xxsetaccz(& acc_2);
2954
            __builtin_mma_xxsetaccz(& acc_3);
2955
            if (std::is_same_v<TA, block_q4_0>) {
2956
               packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
2957
            } else {
2958
               packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
2959
            }
2960
            packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
2961
            for(int x = 0; x < 8; x++) {
2962
                __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
2963
                __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
2964
                __builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]);
2965
                __builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]);
2966
            }
2967
            for (int I = 0; I < 8 ; I++) {
2968
                for (int J = 0; J < 4; J++) {
2969
                    *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
2970
                    *((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
2971
                }
2972
            }
2973
            if (!isAblock_q4) {
2974
                auto aoffset = A + (ii * lda) + l;
2975
                for (int i = 0; i < 8; i++) {
2976
                    comparray[i] = 0;
2977
                    int ca = 0;
2978
                    auto *at = aoffset->qs;
2979
                    for (int j = 0; j < 32; j++)
2980
                        ca += (int)*at++;
2981
                    comparray[i] = ca;
2982
                    aoffset += lda;
2983
                }
2984
            }
2985
            compute(& acc_0, 0, 0, comparray, vs, fin_res);
2986
            compute(& acc_1, 4, 4, comparray, vs, fin_res);
2987
            compute(& acc_2, 0, 8, comparray, vs, fin_res);
2988
            compute(& acc_3, 4, 12, comparray, vs, fin_res);
2989
        }
2990
        save_res(ii, jj, 0, fin_res);
2991
        save_res(ii + 4, jj, 4, fin_res);
2992
        save_res(ii, jj + 4, 8, fin_res);
2993
        save_res(ii + 4, jj + 4, 12, fin_res);
2994
    }
2995
2996
    void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) {
2997
        acc_t acc[8];
2998
        for (int i = 0; i < mc ; i += 16) {
2999
            for (int j = 0; j < nc; j += 8) {
3000
                int A0_base = (i / 16) * (2 * 32 * kc);
3001
                int B0_base = (j / 8) * (32 * kc);
3002
                for (int x = 0; x < 8; x++) {
3003
                     __builtin_mma_xxsetaccz(&acc[x]);
3004
                }
3005
                for (int64_t kk = 0; kk < kc; kk++) {
3006
                    int A0_block_idx = A0_base + kk * 32;
3007
                    int B0_block_idx = B0_base + kk * 32;
3008
                    int A1_block_idx = A0_block_idx + 32 * kc;
3009
                    int B1_block_idx = B0_block_idx + 32 * kc;
3010
                    vec_t * A0_block = & vec_A[A0_block_idx];
3011
                    vec_t * B0_block = & vec_B[B0_block_idx];
3012
                    vec_t * A1_block = & vec_A[A1_block_idx];
3013
                    for (int it = 0; it < 4; it++) {
3014
                        for (int x = 0; x < 4; x++) {
3015
                            __builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]);
3016
                            __builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]);
3017
                            __builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]);
3018
                            __builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
3019
                            __builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]);
3020
                            __builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]);
3021
                            __builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]);
3022
                            __builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
3023
                        }
3024
                    }
3025
                }
3026
                if (l == 0) {
3027
                    save_acc(& acc[0], ii + i, jj + j);
3028
                    save_acc(& acc[1], ii + i, jj + j + 4);
3029
                    save_acc(& acc[2], ii + i + 4, jj + j);
3030
                    save_acc(& acc[3], ii + i + 4, jj + j + 4);
3031
                    save_acc(& acc[4], ii + i + 8, jj + j);
3032
                    save_acc(& acc[5], ii + i + 8, jj + j + 4);
3033
                    save_acc(& acc[6], ii + i + 12, jj + j);
3034
                    save_acc(& acc[7], ii + i + 12, jj + j + 4);
3035
                } else {
3036
                    add_save_acc(& acc[0], ii + i, jj + j);
3037
                    add_save_acc(& acc[1], ii + i, jj + j + 4);
3038
                    add_save_acc(& acc[2], ii + i + 4, jj + j);
3039
                    add_save_acc(& acc[3], ii + i + 4, jj + j + 4);
3040
                    add_save_acc(& acc[4], ii + i + 8, jj + j);
3041
                    add_save_acc(& acc[5], ii + i + 8, jj + j + 4);
3042
                    add_save_acc(& acc[6], ii + i + 12, jj + j);
3043
                    add_save_acc(& acc[7], ii + i + 12, jj + j + 4);
3044
                }
3045
            }
3046
        }
3047
    }
3048
3049
    void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
3050
        vec_t A_pack[mc * kc * 4];
3051
        vec_t B_pack[nc * kc * 4];
3052
        constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
3053
        int64_t ytiles = m / mc;
3054
        int64_t xtiles = n / nc;
3055
        int64_t tiles  = xtiles * ytiles;
3056
        int64_t duty = (tiles + nth - 1) / nth;
3057
        int64_t start = duty * ith;
3058
        int64_t end = start + duty;
3059
        if (end > tiles) {
3060
            end = tiles;
3061
        }
3062
        for (int64_t job = start; job < end; ++job) {
3063
            int64_t ii = (job / xtiles) * mc;
3064
            int64_t jj = (job % xtiles) * nc;
3065
            for (int64_t kk = 0; kk < k; kk += kc) {
3066
                int64_t k_cur = MIN(kc, k - kk);
3067
                if constexpr(is_Ablock_q4) {
3068
                    packNormal_q4_fp16(A + ii * lda + kk, lda, mc, k_cur, (uint8_t *)A_pack);
3069
                } else {
3070
                    packNormal_q8_fp16(A + ii * lda + kk, lda, mc, k_cur, (uint8_t *)A_pack);
3071
                }
3072
                packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, k_cur, (uint8_t *)B_pack);
3073
                KERNEL_Q0(ii, jj, mc, nc, k_cur, kk, A_pack, B_pack);
3074
            }
3075
        }
3076
    }
3077
3078
    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
3079
        int64_t ytiles = (m - m0) / RM;
3080
        int64_t xtiles = (n - n0) / RN;
3081
        int64_t tiles = xtiles * ytiles;
3082
        int64_t duty = (tiles + nth - 1) / nth;
3083
        int64_t start = duty * ith;
3084
        int64_t end = start + duty;
3085
        vec_t vec_A[8] = {0}, vec_B[8] = {0};
3086
        vector signed int vec_C[4];
3087
        acc_t acc_0;
3088
        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
3089
3090
        if (end > tiles)
3091
            end = tiles;
3092
        for (int64_t job = start; job < end; ++job) {
3093
            int64_t ii = m0 + job / xtiles * RM;
3094
            int64_t jj = n0 + job % xtiles * RN;
3095
            std::array<int, 4> comparray{};
3096
            vector float res[4] = {0};
3097
            vector float fin_res[4] = {0};
3098
            vector float vs[4] = {0};
3099
            vector float CA[4] = {0};
3100
            __builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value
3101
            __builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value
3102
            for (int l = 0; l < k; l++) {
3103
                __builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
3104
                __builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
3105
                __builtin_mma_xxsetaccz(& acc_0);
3106
                if (isAblock_q4) {
3107
                    packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray);
3108
                } else {
3109
                    packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false);
3110
                }
3111
                packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true);
3112
                for (int x = 0; x < 8; x += 4) {
3113
                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
3114
                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]);
3115
                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]);
3116
                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]);
3117
                }
3118
                for (int I = 0; I < RM; I++) {
3119
                    for (int J = 0; J < RN; J++) {
3120
                        *((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
3121
                    }
3122
                }
3123
                __builtin_mma_disassemble_acc(vec_C, & acc_0);
3124
                if (!isAblock_q4) {
3125
                    auto aoffset = A + (ii * lda) + l;
3126
                    for (int i = 0; i < RM; i++) {
3127
                        comparray[i] = 0;
3128
                        int ca = 0;
3129
                        auto *at = aoffset->qs;
3130
                        for (int j = 0; j < 32; j++)
3131
                            ca += (int)*at++;
3132
                        comparray[i] = ca;
3133
                        aoffset += lda;
3134
                    }
3135
                }
3136
                for (int i = 0; i < RM; i++) {
3137
                    CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
3138
                    res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
3139
                    fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
3140
                }
3141
            }
3142
            save_res(ii, jj, 0, fin_res, RM, RN);
3143
        }
3144
    }
3145
3146
    template<int RM, int RN>
3147
    inline void kernel(int64_t ii, int64_t jj) {
3148
        if constexpr(RM == 4 && RN == 8) {
3149
            KERNEL_4x8(ii,jj);
3150
        } else if constexpr(RM == 8 && RN == 4) {
3151
            KERNEL_8x4(ii,jj);
3152
        } else if constexpr(RM == 8 && RN == 8) {
3153
            KERNEL_8x8(ii,jj);
3154
        } else {
3155
            assert(false && "RN/RM values not supported");
3156
        }
3157
    }
3158
3159
    template <int RM, int RN>
3160
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
3161
        int64_t ytiles = (m - m0) / RM;
3162
        int64_t xtiles = (n - n0) / RN;
3163
        int64_t tiles = xtiles * ytiles;
3164
        int64_t duty = (tiles + nth - 1) / nth;
3165
        int64_t start = duty * ith;
3166
        int64_t end = start + duty;
3167
        if (end > tiles)
3168
            end = tiles;
3169
        for (int64_t job = start; job < end; ++job) {
3170
            int64_t ii = m0 + job / xtiles * RM;
3171
            int64_t jj = n0 + job % xtiles * RN;
3172
            kernel<RM, RN>(ii, jj);
3173
        }
3174
    }
3175
    const TA * const A;
3176
    const block_q8_0 * const B;
3177
    float * C;
3178
    const int64_t k;
3179
    int64_t kc;
3180
    const int64_t lda;
3181
    const int64_t ldb;
3182
    const int64_t ldc;
3183
    const int ith;
3184
    const int nth;
3185
};
3186
3187
class tinyBLAS_PPC {
3188
  public:
3189
    tinyBLAS_PPC(int64_t k,
3190
                const float * A, int64_t lda,
3191
                const float * B, int64_t ldb,
3192
                float * C, int64_t ldc,
3193
                int ith, int nth)
3194
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
3195
    }
3196
3197
    void matmul(int64_t m, int64_t n) {
3198
    #if defined(_AIX) || defined(__BIG_ENDIAN__)
3199
        mnpack(0, m, 0, n);
3200
    #else
3201
        int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
3202
        if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
3203
            matmul_tiled(m, n, mc, nc, kc);
3204
        } else {
3205
            mnpack(0, m, 0, n);
3206
        }
3207
    #endif
3208
    }
3209
3210
  private:
3211
3212
    __attribute__((always_inline))
3213
    inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
3214
        vec_t vec_C[4];
3215
        __builtin_mma_disassemble_acc(vec_C, ACC);
3216
        for (int I = 0; I < 4; I++) {
3217
            for (int J = 0; J < 4; J++) {
3218
                *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
3219
            }
3220
        }
3221
    }
3222
3223
    __attribute__((always_inline))
3224
    inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
3225
        vec_t vec_C[4];
3226
        __builtin_mma_disassemble_acc(vec_C, ACC);
3227
        for (int I = 0; I < 4; I++) {
3228
            for (int J = 0; J < 4; J++) {
3229
                float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
3230
                *c_ptr += *((float *)&vec_C[I]+J);
3231
            }
3232
        }
3233
    }
3234
3235
    inline void vector_permute_store_4(vector float * src, float * vecOffset) {
3236
        vector float t1, t2, t3, t4, t5, t6, t7, t8;
3237
        t1 = vec_mergeh(src[0], src[1]);
3238
        t2 = vec_mergeh(src[2], src[3]);
3239
        t3 = vec_mergel(src[0], src[1]);
3240
        t4 = vec_mergel(src[2], src[3]);
3241
3242
        t5 = vec_xxpermdi(t1, t2, 0);
3243
        t6 = vec_xxpermdi(t1, t2, 3);
3244
        t7 = vec_xxpermdi(t3, t4, 0);
3245
        t8 = vec_xxpermdi(t3, t4, 3);
3246
3247
        vec_xst(t5, 0, vecOffset);
3248
        vec_xst(t6, 0, vecOffset + 4);
3249
        vec_xst(t7, 0, vecOffset + 8);
3250
        vec_xst(t8, 0, vecOffset + 12);
3251
    }
3252
3253
    inline void vector_permute_store_8(vector float * src, float * vecOffset) {
3254
        vector float t1, t2, t3, t4, t5, t6, t7, t8;
3255
        t1 = vec_mergeh(src[0], src[1]);
3256
        t2 = vec_mergeh(src[2], src[3]);
3257
        t3 = vec_mergeh(src[4], src[5]);
3258
        t4 = vec_mergeh(src[6], src[7]);
3259
3260
        t5 = vec_xxpermdi(t1, t2, 0);
3261
        t6 = vec_xxpermdi(t3, t4, 0);
3262
        t7 = vec_xxpermdi(t1, t2, 3);
3263
        t8 = vec_xxpermdi(t3, t4, 3);
3264
3265
        vec_xst(t5, 0, vecOffset);
3266
        vec_xst(t6, 0, vecOffset + 4);
3267
        vec_xst(t7, 0, vecOffset + 8);
3268
        vec_xst(t8, 0, vecOffset + 12);
3269
3270
        t1 = vec_mergel(src[0], src[1]);
3271
        t2 = vec_mergel(src[2], src[3]);
3272
        t3 = vec_mergel(src[4], src[5]);
3273
        t4 = vec_mergel(src[6], src[7]);
3274
3275
        t5 = vec_xxpermdi(t1, t2, 0);
3276
        t6 = vec_xxpermdi(t3, t4, 0);
3277
        t7 = vec_xxpermdi(t1, t2, 3);
3278
        t8 = vec_xxpermdi(t3, t4, 3);
3279
3280
        vec_xst(t5, 0, vecOffset + 16);
3281
        vec_xst(t6, 0, vecOffset + 20);
3282
        vec_xst(t7, 0, vecOffset + 24);
3283
        vec_xst(t8, 0, vecOffset + 28);
3284
    }
3285
3286
    void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
3287
        int64_t i, j;
3288
        float * aoffsets[8];
3289
        float * aoffset = NULL, * boffset = NULL;
3290
        __vector_pair arr[8];
3291
        vector float c[8][2] = {0};
3292
        vector float c1[8] = {0};
3293
        vector float c2[8] = {0};
3294
        aoffset = const_cast<float *>(a);
3295
        boffset = vec;
3296
        j = (rows >> 3);
3297
        if (j > 0) {
3298
            do {
3299
                aoffsets[0] = aoffset;
3300
                for (int it = 1; it < 8; it++)
3301
                    aoffsets[it] = aoffsets[it-1] + lda;
3302
                aoffset += 8 * lda;
3303
                i = (cols >> 3);
3304
                if (i > 0) {
3305
                    do {
3306
                        for (int it = 0; it < 8; it++) {
3307
                            arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
3308
                            __builtin_vsx_disassemble_pair(c[it], &arr[it]);
3309
                            c1[it] = c[it][0];
3310
                            c2[it] = c[it][1];
3311
                        }
3312
3313
                        vector_permute_store_8(c1, boffset);
3314
                        vector_permute_store_8(c2, boffset + 32);
3315
                        boffset += 64;
3316
                        i--;
3317
                        if (i > 0) {
3318
                           for (int it = 0; it < 8; it++) {
3319
                               aoffsets[it] = aoffsets[it] + 8;
3320
                           }
3321
                        }
3322
                    } while(i > 0);
3323
                }
3324
                if (cols & 4) {
3325
                    for (int it = 0; it < 8 ; it++)
3326
                        c1[it] = vec_xl(0, aoffsets[it]);
3327
                    vector_permute_store_8(c1, boffset);
3328
                }
3329
            j--;
3330
            } while(j > 0);
3331
        }
3332
3333
        if (rows & 4) {
3334
            aoffsets[0] = aoffset;
3335
            for (int it = 1; it < 4; it++)
3336
                aoffsets[it] = aoffsets[it-1] + lda;
3337
            aoffset += 4 * lda;
3338
            i = (cols >> 3);
3339
            if (i > 0) {
3340
                do {
3341
                    for (int it = 0; it < 4; it++) {
3342
                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
3343
                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
3344
                        c1[it] = c[it][0];
3345
                        c2[it] = c[it][1];
3346
                    }
3347
                    vector_permute_store_4(c1, boffset);
3348
                    vector_permute_store_4(c2, boffset + 16);
3349
                    for (int it = 0; it < 4; it++)
3350
                        aoffsets[it] += 8 * lda;
3351
                    boffset += 32;
3352
                    i--;
3353
                } while(i > 0);
3354
            }
3355
3356
            if (cols & 4) {
3357
               for (int it = 0; it < 4; it++)
3358
                   c1[it] = vec_xl(0, aoffsets[it]);
3359
                vector_permute_store_4(c1, boffset);
3360
            }
3361
        }
3362
        if (rows & 3) {
3363
            aoffsets[0] = aoffset;
3364
            for (int it = 1; it < 3; it++)
3365
                aoffsets[it] = aoffsets[it-1] + lda;
3366
            if (cols & 4) {
3367
                for (int it = 0; it < 3; it++)
3368
                    c1[it] = vec_xl(0, aoffsets[it]);
3369
                vector_permute_store_4(c1, boffset);
3370
            }
3371
        }
3372
    }
3373
3374
    void KERNEL_4x4(int64_t ii, int64_t jj) {
3375
        vec_t vec_A[4], vec_B[4], vec_C[4];
3376
        acc_t acc_0;
3377
        __builtin_mma_xxsetaccz(&acc_0);
3378
        for (int l = 0; l < k; l += 4) {
3379
            packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
3380
            packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
3381
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
3382
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
3383
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
3384
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
3385
        }
3386
        save_acc(&acc_0, ii, jj);
3387
    }
3388
3389
    void KERNEL_4x8(int64_t ii, int64_t jj) {
3390
        vec_t vec_A[4], vec_B[8], vec_C[4];
3391
        acc_t acc_0, acc_1;
3392
        __builtin_mma_xxsetaccz(&acc_0);
3393
        __builtin_mma_xxsetaccz(&acc_1);
3394
        for (int64_t l = 0; l < k; l += 4) {
3395
            packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
3396
            packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
3397
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
3398
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
3399
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
3400
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
3401
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
3402
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
3403
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
3404
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
3405
        }
3406
        save_acc(&acc_0, ii, jj);
3407
        save_acc(&acc_1, ii, jj + 4);
3408
    }
3409
3410
    void KERNEL_8x4(int64_t ii, int64_t jj) {
3411
        vec_t vec_A[8], vec_B[4], vec_C[4];
3412
        acc_t acc_0, acc_1;
3413
        __builtin_mma_xxsetaccz(&acc_0);
3414
        __builtin_mma_xxsetaccz(&acc_1);
3415
        for (int64_t l = 0; l < k; l += 4) {
3416
            packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
3417
            packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
3418
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
3419
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
3420
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
3421
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
3422
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
3423
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
3424
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
3425
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
3426
        }
3427
        save_acc(&acc_0, ii, jj);
3428
        save_acc(&acc_1, ii + 4, jj);
3429
    }
3430
3431
    void KERNEL_8x8(int64_t ii, int64_t jj) {
3432
        vec_t vec_A[16], vec_B[16], vec_C[4];
3433
        acc_t acc_0, acc_1, acc_2, acc_3;
3434
        __builtin_mma_xxsetaccz(&acc_0);
3435
        __builtin_mma_xxsetaccz(&acc_1);
3436
        __builtin_mma_xxsetaccz(&acc_2);
3437
        __builtin_mma_xxsetaccz(&acc_3);
3438
        for (int l = 0; l < k; l+=8) {
3439
            packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
3440
            packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
3441
            for(int x = 0; x < 16; x+=2) {
3442
                __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
3443
                __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
3444
                __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
3445
                __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
3446
            }
3447
        }
3448
        save_acc(&acc_0, ii, jj);
3449
        save_acc(&acc_1, ii, jj + 4);
3450
        save_acc(&acc_2, ii + 4, jj);
3451
        save_acc(&acc_3, ii + 4, jj + 4);
3452
    }
3453
3454
    inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
3455
        for (int x = 0; x < 16; x += 2) {
3456
            __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
3457
            __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
3458
            __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
3459
            __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
3460
            __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
3461
            __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
3462
            __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
3463
            __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
3464
        }
3465
    }
3466
3467
    void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
3468
        for (int64_t i = 0; i < mc; i += 16) {
3469
            int A_base_addr = (mc / 8) * (i / 8) * 16;
3470
            for (int64_t j = 0; j < nc; j += 8) {
3471
                 int B_base_addr = (nc / 8) * (j / 8) * 16;
3472
                 acc_t acc[8];
3473
                 vec_t A0_block[16]; vec_t A1_block[16];
3474
                 for (int x = 0; x < 8; x++)
3475
                     __builtin_mma_xxsetaccz(&acc[x]);
3476
                 for (int64_t l = 0; l < kc; l += 8) {
3477
                     int A0_block_idx = A_base_addr + (l / 8) * 16;
3478
                     int A1_block_idx = A0_block_idx + (mc / 8) * 16;
3479
                     int B_block_idx = B_base_addr + (l / 8) * 16;
3480
                     vec_t* A0_block = &vec_A[A0_block_idx];
3481
                     vec_t* A1_block = &vec_A[A1_block_idx];
3482
                     vec_t* B_block = &vec_B[B_block_idx];
3483
                     MMA_16x8(A0_block, A1_block, B_block, acc);
3484
                 }
3485
                 if (kk == 0) {
3486
                     save_acc(&acc[0], ii + i, jj + j);
3487
                     save_acc(&acc[1], ii + i, jj + j + 4);
3488
                     save_acc(&acc[2], ii + i + 4, jj + j);
3489
                     save_acc(&acc[3], ii + i + 4, jj + j + 4);
3490
                     save_acc(&acc[4], ii + i + 8, jj + j);
3491
                     save_acc(&acc[5], ii + i + 8, jj + j + 4);
3492
                     save_acc(&acc[6], ii + i + 12, jj + j);
3493
                     save_acc(&acc[7], ii + i + 12, jj + j + 4);
3494
                 } else {
3495
                     add_save_acc(&acc[0], ii + i, jj + j);
3496
                     add_save_acc(&acc[1], ii + i, jj + j + 4);
3497
                     add_save_acc(&acc[2], ii + i + 4, jj + j);
3498
                     add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
3499
                     add_save_acc(&acc[4], ii + i + 8, jj + j);
3500
                     add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
3501
                     add_save_acc(&acc[6], ii + i + 12, jj + j);
3502
                     add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
3503
                 }
3504
            }
3505
        }
3506
    }
3507
3508
    void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) {
3509
        int64_t ytiles = m / mc;
3510
        int64_t xtiles = n / nc;
3511
        int64_t tiles = xtiles * ytiles;
3512
        int64_t duty = (tiles + nth - 1) / nth;
3513
        int64_t start = duty * ith;
3514
        int64_t end = start + duty;
3515
        if (end > tiles) {
3516
            end = tiles;
3517
        }
3518
        for (int64_t job = start; job < end; ++job) {
3519
            int64_t ii = (job / xtiles) * mc;
3520
            int64_t jj = (job % xtiles) * nc;
3521
            for (int64_t kk = 0; kk < k; kk += kc) {
3522
                 vec_t A_pack[kc * mc / 4];
3523
                 vec_t B_pack[kc * nc / 4];
3524
                 packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
3525
                 packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
3526
                 KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
3527
            }
3528
        }
3529
    }
3530
3531
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
3532
        int m_rem = MIN(m - m0, 8);
3533
        int n_rem = MIN(n - n0, 8);
3534
        int mc = 0, nc = 0;
3535
        if (m_rem >= 8 && n_rem >= 8) {
3536
            mc = 8;
3537
            nc = 8;
3538
            gemm<8, 8>(m0, m, n0, n);
3539
        } else if (m_rem >= 4 && n_rem >= 8) {
3540
            mc = 4;
3541
            nc = 8;
3542
            gemm<4, 8>(m0, m, n0, n);
3543
        } else if (m_rem >= 8 && n_rem >= 4) {
3544
            mc = 8;
3545
            nc = 4;
3546
            gemm<8, 4>(m0, m, n0, n);
3547
        } else if (m_rem >= 4 && n_rem >= 4) {
3548
            mc = 4;
3549
            nc = 4;
3550
            gemm<4, 4>(m0, m, n0, n);
3551
        } else {
3552
            mc = (m_rem >= 4) ? 4 : m_rem;
3553
            nc = (n_rem >= 4) ? 4 : n_rem;
3554
            if (mc == 0 || nc == 0)
3555
                return;
3556
            gemm_small(m0, m, n0, n, mc, nc);
3557
        }
3558
        int64_t mp = m0 + ((m - m0) / mc) * mc;
3559
        int64_t np = n0 + ((n - n0) / nc) * nc;
3560
        mnpack(mp, m, n0, np);
3561
        mnpack(m0, m, np, n);
3562
    }
3563
3564
    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
3565
        int64_t ytiles = (m - m0) / RM;
3566
        int64_t xtiles = (n - n0) / RN;
3567
        int64_t tiles = xtiles * ytiles;
3568
        int64_t duty = (tiles + nth - 1) / nth;
3569
        int64_t start = duty * ith;
3570
        int64_t end = start + duty;
3571
        if (end > tiles)
3572
            end = tiles;
3573
        for (int64_t job = start; job < end; ++job) {
3574
            int64_t ii = m0 + job / xtiles * RM;
3575
            int64_t jj = n0 + job % xtiles * RN;
3576
            vec_t vec_C[4];
3577
            acc_t acc_0;
3578
            __builtin_mma_xxsetaccz(&acc_0);
3579
            vec_t vec_A[4] = {0}, vec_B[4] = {0};
3580
            for (int l = 0; l < k; l += 4) {
3581
                /* 'GEMV Forwarding' concept is used in first two conditional loops.
3582
                 * when one of the matrix has a single row/column, the elements are
3583
                 * broadcasted, instead of using packing routine to prepack the
3584
                 * matrix elements.
3585
                 */
3586
                if (RM == 1) {
3587
                    float * a = const_cast<float *>(A + (ii) * lda + l);
3588
                    packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
3589
                    vec_A[0] = (vec_t)vec_xl(0,a);
3590
                    vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
3591
                    vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
3592
                    vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
3593
                } else if (RN == 1) {
3594
                    packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
3595
                    float * b = const_cast<float *>(B + (jj) * ldb + l);
3596
                    vec_B[0] = (vec_t)vec_xl(0,b);
3597
                    vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
3598
                    vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
3599
                    vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
3600
                } else {
3601
                    packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
3602
                    packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
3603
                }
3604
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
3605
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
3606
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
3607
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
3608
            }
3609
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
3610
            for (int I = 0; I < RM; I++) {
3611
                for (int J = 0; J < RN; J++) {
3612
                    *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
3613
                }
3614
            }
3615
       }
3616
    }
3617
3618
    template<int RM, int RN>
3619
    inline void kernel(int64_t ii, int64_t jj) {
3620
        if constexpr(RM == 4 && RN == 4) {
3621
            KERNEL_4x4(ii, jj);
3622
        } else if constexpr(RM == 4 && RN == 8) {
3623
            KERNEL_4x8(ii, jj);
3624
        } else if constexpr(RM == 8 && RN == 4) {
3625
            KERNEL_8x4(ii, jj);
3626
        } else if constexpr(RM == 8 && RN == 8) {
3627
            KERNEL_8x8(ii, jj);
3628
        } else {
3629
            static_assert(false, "RN/RM values not supported");
3630
        }
3631
    }
3632
3633
    template <int RM, int RN>
3634
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
3635
        int64_t ytiles = (m - m0) / RM;
3636
        int64_t xtiles = (n - n0) / RN;
3637
        int64_t tiles = xtiles * ytiles;
3638
        int64_t duty = (tiles + nth - 1) / nth;
3639
        int64_t start = duty * ith;
3640
        int64_t end = start + duty;
3641
        if (end > tiles)
3642
            end = tiles;
3643
        for (int64_t job = start; job < end; ++job) {
3644
            int64_t ii = m0 + job / xtiles * RM;
3645
            int64_t jj = n0 + job % xtiles * RN;
3646
            kernel<RM, RN>(ii, jj);
3647
        }
3648
    }
3649
3650
    const float * const A;
3651
    const float * const B;
3652
    float * C;
3653
    const int64_t k;
3654
    const int64_t lda;
3655
    const int64_t ldb;
3656
    const int64_t ldc;
3657
    const int ith;
3658
    const int nth;
3659
};
3660
#endif
3661
} // namespace
3662
3663
/**
3664
 * Performs optimized matrix multiplication on CPU.
3665
 *
3666
 * This subroutine may compute C = Aᵀ * B with column major ordering.
3667
 * Despite its name, this isn't a generalized implementation. Work is
3668
 * only performed when a handwritten kernel is written and available.
3669
 * Otherwise the caller should fall back to a general matmul routine.
3670
 *
3671
 * For example, for single-threaded single-precision GEMM you can say
3672
 *
3673
 *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
3674
 *                     0, 1,
3675
 *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
3676
 *
3677
 * @param m is rows in `A` and `C`
3678
 * @param n is cols in `B` and `C`
3679
 * @param k is cols in `A` and rows in `B`
3680
 * @param A is first input matrix (always transposed)
3681
 * @param lda is row stride of `A`
3682
 * @param B is second input matrix (never transposed)
3683
 * @param ldb is row stride of `B`
3684
 * @param C is input/output array of output matrices
3685
 * @param ldc is row stride of `C`
3686
 * @param ith is thread id (must be less than `nth`)
3687
 * @param nth is number of threads (must be greater than zero)
3688
 * @param Atype is GGML data type of `A`
3689
 * @param Btype is GGML data type of `B`
3690
 * @param Ctype is GGML data type of `C`
3691
 * @return true if this function was able to service the matmul request
3692
 */
3693
bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
3694
                     const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
3695
0
                     int64_t ldc, int Atype, int Btype, int Ctype) {
3696
3697
0
    assert(m >= 0);
3698
0
    assert(n >= 0);
3699
0
    assert(k >= 0);
3700
0
    assert(lda >= k);
3701
0
    assert(ldb >= k);
3702
0
    assert(ldc >= m);
3703
0
    assert(params->nth > 0);
3704
0
    assert(params->ith < params->nth);
3705
3706
    // only enable sgemm for prompt processing
3707
0
#if !defined(__MMA__)
3708
0
    if (n < 2)
3709
0
        return false;
3710
0
#endif
3711
3712
0
    if (Ctype != GGML_TYPE_F32)
3713
0
        return false;
3714
3715
0
    switch (Atype) {
3716
3717
0
    case GGML_TYPE_F32: {
3718
0
        if (Btype != GGML_TYPE_F32)
3719
0
            return false;
3720
#if defined(__AVX512F__)
3721
        tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
3722
            k, (const float *)A, lda,
3723
            (const float *)B, ldb,
3724
            (float *)C, ldc};
3725
        return tb.matmul(m, n);
3726
#elif defined(__AVX__) || defined(__AVX2__)
3727
0
        tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
3728
0
            k, (const float *)A, lda,
3729
0
            (const float *)B, ldb,
3730
0
            (float *)C, ldc};
3731
0
        return tb.matmul(m, n);
3732
#elif defined(__ARM_NEON)
3733
        if (n < 4)
3734
            return false;
3735
        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
3736
            k, (const float *)A, lda,
3737
            (const float *)B, ldb,
3738
            (float *)C, ldc};
3739
        return tb.matmul(m, n);
3740
#elif defined(__VXE__) || defined(__VXE2__)
3741
        if (n < 4)
3742
            return false;
3743
        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
3744
            k, (const float *)A, lda,
3745
            (const float *)B, ldb,
3746
            (float *)C, ldc};
3747
        return tb.matmul(m, n);
3748
#elif defined(__MMA__)
3749
        if (k % 8)
3750
            return false;
3751
        tinyBLAS_PPC tb{
3752
            k, (const float *)A, lda,
3753
            (const float *)B, ldb,
3754
            (float *)C, ldc,
3755
            params->ith, params->nth};
3756
        tb.matmul(m, n);
3757
        return true;
3758
#elif defined(__riscv_v_intrinsic)
3759
    #if LMUL == 1
3760
        tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
3761
            k, (const float *)A, lda,
3762
            (const float *)B, ldb,
3763
            (float *)C, ldc};
3764
    #elif LMUL == 2
3765
        tinyBLAS_RVV<vfloat32m2_t, vfloat32m2_t, float, float, float> tb{ params,
3766
            k, (const float *)A, lda,
3767
            (const float *)B, ldb,
3768
            (float *)C, ldc};
3769
    #else // LMUL = 4
3770
        tinyBLAS_RVV<vfloat32m4_t, vfloat32m4_t, float, float, float> tb{ params,
3771
            k, (const float *)A, lda,
3772
            (const float *)B, ldb,
3773
            (float *)C, ldc};
3774
    #endif
3775
        return tb.matmul(m, n);
3776
#else
3777
        return false;
3778
#endif
3779
0
    }
3780
3781
0
    case GGML_TYPE_BF16: {
3782
#if defined(__AVX512BF16__)
3783
        if (Btype == GGML_TYPE_BF16) {
3784
            tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
3785
                (const ggml_bf16_t *)A, lda,
3786
                (const ggml_bf16_t *)B, ldb,
3787
                (float *)C, ldc};
3788
            return tb.matmul(m, n);
3789
        }
3790
#elif defined(__AVX512F__)
3791
        if (Btype == GGML_TYPE_BF16) {
3792
            tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
3793
                (const ggml_bf16_t *)A, lda,
3794
                (const ggml_bf16_t *)B, ldb,
3795
                (float *)C, ldc};
3796
            return tb.matmul(m, n);
3797
        }
3798
#elif defined(__AVX2__)
3799
0
        if (Btype == GGML_TYPE_BF16) {
3800
0
            tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
3801
0
                (const ggml_bf16_t *)A, lda,
3802
0
                (const ggml_bf16_t *)B, ldb,
3803
0
                (float *)C, ldc};
3804
0
            return tb.matmul(m, n);
3805
0
        }
3806
#elif defined(__MMA__)
3807
        if (k % 8) {
3808
            return false;
3809
        }
3810
3811
        if (Btype == GGML_TYPE_BF16) {
3812
            tinyBLAS_HP16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
3813
                (const ggml_bf16_t *)A, lda,
3814
                (const ggml_bf16_t *)B, ldb,
3815
                (float *)C, ldc,
3816
                params->ith, params->nth };
3817
3818
            tb.matmul(m, n);
3819
            return true;
3820
        }
3821
#elif defined(__riscv_zvfbfwma)
3822
        if (Btype == GGML_TYPE_BF16) {
3823
            #if LMUL == 1
3824
                tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
3825
                    k, (const ggml_bf16_t *)A, lda,
3826
                    (const ggml_bf16_t *)B, ldb,
3827
                    (float *)C, ldc};
3828
            #elif LMUL == 2
3829
                tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
3830
                    k, (const ggml_bf16_t *)A, lda,
3831
                    (const ggml_bf16_t *)B, ldb,
3832
                    (float *)C, ldc};
3833
            #else // LMUL = 4
3834
                tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
3835
                    k, (const ggml_bf16_t *)A, lda,
3836
                    (const ggml_bf16_t *)B, ldb,
3837
                    (float *)C, ldc};
3838
            #endif
3839
                return tb.matmul(m, n);
3840
        }
3841
#endif
3842
0
        return false;
3843
0
    }
3844
3845
0
    case GGML_TYPE_F16: {
3846
#if defined(__AVX512F__)
3847
        if (Btype == GGML_TYPE_F16) {
3848
            tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
3849
                (const ggml_fp16_t *)A, lda,
3850
                (const ggml_fp16_t *)B, ldb,
3851
                (float *)C, ldc};
3852
            return tb.matmul(m, n);
3853
        }
3854
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
3855
0
        if (Btype == GGML_TYPE_F16) {
3856
0
            tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
3857
0
                (const ggml_fp16_t *)A, lda,
3858
0
                (const ggml_fp16_t *)B, ldb,
3859
0
                (float *)C, ldc};
3860
0
            return tb.matmul(m, n);
3861
0
        }
3862
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
3863
        if (n < 8)
3864
            return false;
3865
        if (Btype == GGML_TYPE_F16) {
3866
            tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
3867
                k, (const ggml_fp16_t *)A, lda,
3868
                (const ggml_fp16_t *)B, ldb,
3869
                (float *)C, ldc};
3870
            return tb.matmul(m, n);
3871
        }
3872
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
3873
        if (Btype == GGML_TYPE_F32) {
3874
            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params,
3875
                k, (const ggml_fp16_t *)A, lda,
3876
                (const float *)B, ldb,
3877
                (float *)C, ldc};
3878
            return tb.matmul(m, n);
3879
        }
3880
#elif defined(__VXE__) || defined(__VXE2__)
3881
        if (n < 4)
3882
            return false;
3883
        if (Btype == GGML_TYPE_F16) {
3884
            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
3885
                k, (const ggml_fp16_t *)A, lda,
3886
                (const ggml_fp16_t *)B, ldb,
3887
                (float *)C, ldc};
3888
            return tb.matmul(m, n);
3889
        }
3890
#elif defined(__riscv_zvfh)
3891
        if (Btype == GGML_TYPE_F16) {
3892
        #if LMUL == 1
3893
            tinyBLAS_RVV<vfloat32m1_t, vfloat16mf2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
3894
                k, (const ggml_fp16_t *)A, lda,
3895
                (const ggml_fp16_t *)B, ldb,
3896
                (float *)C, ldc};
3897
        #elif LMUL == 2
3898
            tinyBLAS_RVV<vfloat32m2_t, vfloat16m1_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
3899
                k, (const ggml_fp16_t *)A, lda,
3900
                (const ggml_fp16_t *)B, ldb,
3901
                (float *)C, ldc};
3902
        #else // LMUL = 4
3903
            tinyBLAS_RVV<vfloat32m4_t, vfloat16m2_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
3904
                k, (const ggml_fp16_t *)A, lda,
3905
                (const ggml_fp16_t *)B, ldb,
3906
                (float *)C, ldc};
3907
        #endif
3908
            return tb.matmul(m, n);
3909
        }
3910
#elif defined(__MMA__)
3911
        if (k % 8) {
3912
            return false;
3913
        }
3914
3915
        if (Btype == GGML_TYPE_F16) {
3916
            tinyBLAS_HP16_PPC<ggml_fp16_t, ggml_fp16_t, float> tb{ k,
3917
                (const ggml_fp16_t *)A, lda,
3918
                (const ggml_fp16_t *)B, ldb,
3919
                (float *)C, ldc,
3920
                params->ith, params->nth };
3921
3922
            tb.matmul(m, n);
3923
            return true;
3924
        }
3925
#endif
3926
0
        return false;
3927
0
    }
3928
3929
0
    case GGML_TYPE_Q8_0: {
3930
0
        if (Btype != GGML_TYPE_Q8_0)
3931
0
           return false;
3932
0
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
3933
0
        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
3934
0
            k, (const block_q8_0 *)A, lda,
3935
0
            (const block_q8_0 *)B, ldb,
3936
0
            (float *)C, ldc,
3937
0
            params->ith, params->nth};
3938
0
        tb.matmul(m, n);
3939
0
        return true;
3940
#elif defined(__ARM_FEATURE_DOTPROD)
3941
        tinyBLAS_Q0_ARM<block_q8_0> tb{
3942
            k, (const block_q8_0 *)A, lda,
3943
            (const block_q8_0 *)B, ldb,
3944
            (float *)C, ldc,
3945
            params->ith, params->nth};
3946
        tb.matmul(m, n);
3947
        return true;
3948
#elif defined(__MMA__)
3949
    //TO-DO: Remove this condition once gemv forwarding is enabled.
3950
        if (n < 8 && n != 4)
3951
           return false;
3952
        if (m < 8 && m != 4)
3953
           return false;
3954
        tinyBLAS_Q0_PPC<block_q8_0> tb{
3955
            k, (const block_q8_0 *)A, lda,
3956
            (const block_q8_0 *)B, ldb,
3957
            (float *)C, ldc,
3958
            params->ith, params->nth};
3959
        tb.matmul(m, n);
3960
        return true;
3961
#else
3962
        return false;
3963
#endif
3964
0
    }
3965
3966
0
    case GGML_TYPE_Q4_0: {
3967
0
        if (Btype != GGML_TYPE_Q8_0)
3968
0
            return false;
3969
0
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
3970
0
        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
3971
0
            k, (const block_q4_0 *)A, lda,
3972
0
            (const block_q8_0 *)B, ldb,
3973
0
            (float *)C, ldc,
3974
0
            params->ith, params->nth};
3975
0
        tb.matmul(m, n);
3976
0
        return true;
3977
#elif defined(__ARM_FEATURE_DOTPROD)
3978
        tinyBLAS_Q0_ARM<block_q4_0> tb{
3979
            k, (const block_q4_0 *)A, lda,
3980
            (const block_q8_0 *)B, ldb,
3981
            (float *)C, ldc,
3982
            params->ith, params->nth};
3983
        tb.matmul(m, n);
3984
        return true;
3985
#elif defined(__MMA__)
3986
    //TO-DO: Remove this condition once gemv forwarding is enabled.
3987
        if (n < 8 && n != 4)
3988
           return false;
3989
        if (m < 8 && m != 4)
3990
           return false;
3991
        tinyBLAS_Q0_PPC<block_q4_0> tb{
3992
            k, (const block_q4_0 *)A, lda,
3993
            (const block_q8_0 *)B, ldb,
3994
            (float *)C, ldc,
3995
            params->ith, params->nth};
3996
        tb.matmul(m, n);
3997
        return true;
3998
#else
3999
        return false;
4000
#endif
4001
0
    }
4002
4003
0
    case GGML_TYPE_Q5_0: {
4004
0
        if (Btype != GGML_TYPE_Q8_0)
4005
0
            return false;
4006
0
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
4007
0
        tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
4008
0
            k, (const block_q5_0 *)A, lda,
4009
0
            (const block_q8_0 *)B, ldb,
4010
0
            (float *)C, ldc,
4011
0
            params->ith, params->nth};
4012
0
        tb.matmul(m, n);
4013
0
        return true;
4014
#else
4015
        return false;
4016
#endif
4017
0
    }
4018
4019
0
    case GGML_TYPE_IQ4_NL: {
4020
0
        if (Btype != GGML_TYPE_Q8_0)
4021
0
            return false;
4022
0
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
4023
0
        tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
4024
0
            k, (const block_iq4_nl *)A, lda,
4025
0
            (const block_q8_0 *)B, ldb,
4026
0
            (float *)C, ldc,
4027
0
            params->ith, params->nth};
4028
0
        tb.matmul(m, n);
4029
0
        return true;
4030
#else
4031
        return false;
4032
#endif
4033
0
    }
4034
4035
0
    default:
4036
0
        return false;
4037
0
    }
4038
4039
0
    (void)params;
4040
0
    (void)m;
4041
0
    (void)n;
4042
0
    (void)k;
4043
0
    (void)A;
4044
0
    (void)lda;
4045
0
    (void)B;
4046
0
    (void)ldb;
4047
0
    (void)C;
4048
0
    (void)ldc;
4049
0
    (void)Atype;
4050
0
    (void)Btype;
4051
0
    (void)Ctype;
4052
0
}