Coverage Report

Created: 2025-12-14 06:24

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp
Line
Count
Source
1
// Copyright 2024 Mozilla Foundation
2
//
3
// Permission is hereby granted, free of charge, to any person obtaining
4
// a copy of this software and associated documentation files (the
5
// "Software"), to deal in the Software without restriction, including
6
// without limitation the rights to use, copy, modify, merge, publish,
7
// distribute, sublicense, and/or sell copies of the Software, and to
8
// permit persons to whom the Software is furnished to do so, subject to
9
// the following conditions:
10
//
11
// The above copyright notice and this permission notice shall be
12
// included in all copies or substantial portions of the Software.
13
//
14
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
18
// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
19
// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
20
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
// SOFTWARE.
22
23
//
24
//                   _   _          ___ _      _   ___
25
//                  | |_(_)_ _ _  _| _ ) |    /_\ / __|
26
//                  |  _| | ' \ || | _ \ |__ / _ \\__ \.
27
//                   \__|_|_||_\_, |___/____/_/ \_\___/
28
//                             |__/
29
//
30
//                    BASIC LINEAR ALGEBRA SUBPROGRAMS
31
//
32
//
33
// This file implements multithreaded CPU matrix multiplication for the
34
// common contiguous use case C = Aᵀ * B. These kernels are designed to
35
// have excellent performance[1] for matrices that fit in the CPU cache
36
// without imposing any overhead such as cache filling or malloc calls.
37
//
38
// This implementation does not guarantee any upper bound with rounding
39
// errors, which grow along with k. Our goal's to maximally exploit the
40
// hardware for performance, and then use whatever resources remain for
41
// improving numerical accuracy.
42
//
43
// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
44
//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
45
46
#if defined(__GNUC__)
47
#pragma GCC diagnostic ignored "-Wpedantic"
48
#pragma GCC diagnostic ignored "-Wignored-attributes"
49
#endif
50
51
#include "sgemm.h"
52
#include "ggml-impl.h"
53
#include "ggml-cpu-impl.h"
54
#include "ggml-quants.h"
55
#include "simd-mappings.h"
56
57
#include <array>
58
#include <type_traits>
59
60
#ifdef _MSC_VER
61
#define NOINLINE __declspec(noinline)
62
#else
63
#define NOINLINE __attribute__((__noinline__))
64
#endif
65
66
#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)
67
#define VECTOR_REGISTERS 32
68
#else
69
#define VECTOR_REGISTERS 16
70
#endif
71
72
0
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
73
74
namespace {
75
76
0
inline float unhalf(ggml_fp16_t d) {
77
0
    return GGML_CPU_FP16_TO_FP32(d);
78
0
}
79
80
////////////////////////////////////////////////////////////////////////////////////////////////////
81
// VECTORIZED ARITHMETIC OPERATIONS
82
83
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
84
0
inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
85
0
inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
86
0
inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
87
#endif  // __SSE__
88
89
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
90
0
inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
91
0
inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
92
0
inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
93
#endif // __AVX__
94
95
#if defined(__AVX512F__)
96
inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
97
inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
98
inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
99
#endif // __AVX512F__
100
101
#if defined(__ARM_NEON)
102
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
103
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
104
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
105
#endif // __ARM_NEON
106
107
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
108
inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
109
inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
110
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
111
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
112
113
#if defined(__VXE__) || defined(__VXE2__)
114
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vec_add(x, y); }
115
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vec_sub(x, y); }
116
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
117
#endif
118
119
#if defined(__MMA__)
120
#include "sgemm-ppc.h"
121
#endif
122
////////////////////////////////////////////////////////////////////////////////////////////////////
123
// VECTORIZED FUSED MULTIPLY ADD
124
125
/**
126
 * Computes a * b + c.
127
 */
128
template <typename T, typename U>
129
inline U madd(T a, T b, U c) {
130
    return add(mul(a, b), c);
131
}
132
133
#if defined(__FMA__)
134
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
135
template <>
136
0
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
137
0
    return _mm256_fmadd_ps(a, b, c);
138
0
}
139
#endif
140
#if defined(__AVX512F__)
141
template <>
142
inline __m512 madd(__m512 a, __m512 b, __m512 c) {
143
    return _mm512_fmadd_ps(a, b, c);
144
}
145
#endif
146
#if defined(__AVX512BF16__)
147
template <>
148
inline __m512 madd(__m512bh a, __m512bh b, __m512 c) {
149
    return _mm512_dpbf16_ps(c, a, b);
150
}
151
template <>
152
inline __m256 madd(__m256bh a, __m256bh b, __m256 c) {
153
    return _mm256_dpbf16_ps(c, a, b);
154
}
155
#endif
156
#endif
157
158
#if defined(__ARM_FEATURE_FMA)
159
template <>
160
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
161
    return vfmaq_f32(c, b, a);
162
}
163
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
164
template <>
165
inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
166
    return vfmaq_f16(c, b, a);
167
}
168
#endif
169
#endif
170
171
#if defined(__VXE__) || defined(__VXE2__)
172
template <>
173
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
174
    return vec_madd(a, b, c);
175
}
176
#endif
177
178
////////////////////////////////////////////////////////////////////////////////////////////////////
179
// VECTORIZED HORIZONTAL SUM
180
181
#if defined(__ARM_NEON)
182
inline float hsum(float32x4_t x) {
183
    return vaddvq_f32(x);
184
}
185
#endif // __ARM_NEON
186
187
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
188
inline float hsum(float16x8_t x) {
189
    return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
190
                                vcvt_f32_f16(vget_high_f16(x))));
191
}
192
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
193
194
#if defined(__VXE__) || defined(__VXE2__)
195
inline float hsum(float32x4_t x) {
196
    float32x4_t tmp = x + vec_reve(x);
197
    return tmp[0] + tmp[1];
198
}
199
#endif
200
201
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
202
0
inline float hsum(__m128 x) {
203
0
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
204
0
    x = _mm_add_ps(x, _mm_movehl_ps(x, x));
205
0
    x = _mm_add_ss(x, _mm_movehdup_ps(x));
206
#else
207
    __m128 t;
208
    t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
209
    x = _mm_add_ps(x, t);
210
    t = _mm_movehl_ps(t, x);
211
    x = _mm_add_ss(x, t);
212
#endif
213
0
    return _mm_cvtss_f32(x);
214
0
}
215
#endif
216
217
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
218
0
inline float hsum(__m256 x) {
219
0
    return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
220
0
                           _mm256_castps256_ps128(x)));
221
0
}
222
#endif // __AVX__
223
224
#if defined(__AVX512F__)
225
inline float hsum(__m512 x) {
226
    return _mm512_reduce_add_ps(x);
227
}
228
#endif // __AVX512F__
229
230
////////////////////////////////////////////////////////////////////////////////////////////////////
231
// VECTORIZED MEMORY LOADING
232
233
template <typename T, typename U> T load(const U *);
234
235
#if defined(__ARM_NEON)
236
template <> inline float32x4_t load(const float *p) {
237
    return vld1q_f32(p);
238
}
239
#if !defined(_MSC_VER)
240
// FIXME: this should check for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
241
template <> inline float16x8_t load(const ggml_fp16_t *p) {
242
    return vld1q_f16((const float16_t *)p);
243
}
244
template <> inline float32x4_t load(const ggml_fp16_t *p) {
245
    return vcvt_f32_f16(vld1_f16((const float16_t *)p));
246
}
247
#endif // _MSC_VER
248
#endif // __ARM_NEON
249
250
#if defined(__VXE__) || defined(__VXE2__)
251
template <> inline float32x4_t load(const ggml_fp16_t * p) {
252
    float tmp[4];
253
254
    for (int i = 0; i < 4; i++) {
255
        tmp[i] = GGML_CPU_FP16_TO_FP32(p[i]);
256
    }
257
258
    return vec_xl(0, (const float *)(tmp));
259
}
260
template <> inline float32x4_t load(const float * p) {
261
    return vec_xl(0, p);
262
}
263
#endif
264
265
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
266
0
template <> inline __m128 load(const float *p) {
267
0
    return _mm_loadu_ps(p);
268
0
}
269
#endif  // __SSE__
270
271
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
272
0
template <> inline __m256 load(const float *p) {
273
0
    return _mm256_loadu_ps(p);
274
0
}
275
#endif // __AVX__
276
277
#if defined(__AVX2__) || defined(__AVX512F__)
278
0
template <> inline __m256 load(const ggml_bf16_t *p) {
279
0
    return _mm256_castsi256_ps(
280
0
        _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16));
281
0
}
282
#endif // __AVX2__
283
284
#if defined(__F16C__)
285
0
template <> inline __m256 load(const ggml_fp16_t *p) {
286
0
    return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
287
0
}
288
#endif // __F16C__
289
290
#if defined(__AVX512F__)
291
template <> inline __m512 load(const float *p) {
292
    return _mm512_loadu_ps(p);
293
}
294
template <> inline __m512 load(const ggml_fp16_t *p) {
295
    return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
296
}
297
template <> inline __m512 load(const ggml_bf16_t *p) {
298
    return _mm512_castsi512_ps(
299
        _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16));
300
}
301
#endif // __AVX512F__
302
303
#if defined(__AVX512BF16__)
304
template <> inline __m512bh load(const ggml_bf16_t *p) {
305
    return (__m512bh)_mm512_loadu_ps((const float *)p);
306
}
307
template <> inline __m256bh load(const ggml_bf16_t *p) {
308
    return (__m256bh)_mm256_loadu_ps((const float *)p);
309
}
310
template <> inline __m512bh load(const float *p) {
311
    return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));
312
}
313
template <> inline __m256bh load(const float *p) {
314
    return _mm512_cvtneps_pbh(_mm512_loadu_ps(p));
315
}
316
#endif
317
318
////////////////////////////////////////////////////////////////////////////////////////////////////
319
// FLOATING POINT MATRIX MULTIPLICATION
320
321
template <int M>
322
0
static inline int64_t BLOCK_SIZE(size_t m) {
323
0
    const int64_t NB_BLOC_M = (m + M - 1) / M;
324
0
    return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
325
0
}
326
327
0
static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
328
0
    return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
329
0
}
330
331
template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
332
class tinyBLAS {
333
  public:
334
    tinyBLAS(const ggml_compute_params * params, int64_t k,
335
             const TA *A, int64_t lda,
336
             const TB *B, int64_t ldb,
337
             TC *C, int64_t ldc)
338
0
        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
339
0
    }
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::tinyBLAS(ggml_compute_params const*, long, float const*, long, float const*, long, float*, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::tinyBLAS(ggml_compute_params const*, long, ggml_bf16_t const*, long, ggml_bf16_t const*, long, float*, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::tinyBLAS(ggml_compute_params const*, long, unsigned short const*, long, unsigned short const*, long, float*, long)
340
341
0
    bool matmul(int64_t m, int64_t n) {
342
0
        if (k % KN != 0)
343
0
            return false;
344
        // compute RM for only need tile with size RM&RM-1
345
#if VECTOR_REGISTERS == 32
346
        if (m % 16 == 0 && (m/16 >= params->nth)) {
347
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
348
            mnpack<4, 6, 4>(m, n, SIZE_N, 12);
349
            return true;
350
        }
351
        if (m % 8 == 0 ) {
352
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
353
            mnpack<4, 6, 2>(m, n, SIZE_N, 12);
354
            return true;
355
        }
356
        if (m % 4 == 0) {
357
            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
358
            mnpack<4, 6, 1>(m, n, SIZE_N, 12);
359
            return true;
360
        }
361
#else  // VECTOR_REGISTERS == 16
362
0
        if (m % 16 == 0 && (m/16 >= params->nth)) {
363
0
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
364
0
            mnpack<4, 3, 4>(m, n, SIZE_N, 24);
365
0
            return true;
366
0
        }
367
0
        if (m % 8 == 0 ) {
368
0
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
369
0
            mnpack<4, 3, 2>(m, n, SIZE_N, 24);
370
0
            return true;
371
0
        }
372
0
        if (m % 4 == 0) {
373
0
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
374
0
            mnpack<4, 3, 1>(m, n, SIZE_N, 24);
375
0
            return true;
376
0
        }
377
0
#endif
378
0
        return false;
379
0
    }
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::matmul(long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::matmul(long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::matmul(long, long)
380
381
  private:
382
    template <int RM, int RN, int BM>
383
0
    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
384
0
        if (SIZE_N == RN) {
385
0
            return gemm<RM, RN, BM>(m, n, BN);
386
0
        }
387
0
        if constexpr (RN > 1) {
388
0
            return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
389
0
        } else {
390
0
            GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
391
0
            GGML_ASSERT(false); // we have miss something.
392
0
        }
393
0
    }
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 3, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 2, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 1, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 3, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 2, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 1, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 3, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 2, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::mnpack<4, 1, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 3, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 2, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 1, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 3, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 2, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 1, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 3, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 2, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::mnpack<4, 1, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 3, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 2, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 1, 4>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 3, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 2, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 1, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 3, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 2, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::mnpack<4, 1, 1>(long, long, long, long)
394
395
    template <int RM, int RN>
396
0
    inline void gemm_bloc(int64_t ii, int64_t jj) {
397
0
        D Cv[RN][RM] = {};
398
0
        for (int64_t l = 0; l < k; l += KN) {
399
            // help compiler for op order.
400
            if constexpr (RM <= RN) {
401
                V Av[RM];
402
                for (int64_t i = 0; i < RM; ++i) {
403
                    Av[i] = load<V>(A + lda * (ii + i) + l);
404
                }
405
                for (int64_t j = 0; j < RN; ++j) {
406
                    V Bv = load<V>(B + ldb * (jj + j) + l);
407
                    for (int64_t i = 0; i < RM; ++i) {
408
                        Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
409
                    }
410
                }
411
0
            } else {
412
0
                V Bv[RN];
413
0
                for (int64_t j = 0; j < RN; ++j) {
414
0
                    Bv[j] = load<V>(B + ldb * (jj + j) + l);
415
0
                }
416
0
                for (int64_t i = 0; i < RM; ++i) {
417
0
                    V Av = load<V>(A + lda * (ii + i) + l);
418
0
                    for (int64_t j = 0; j < RN; ++j) {
419
0
                        Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
420
0
                    }
421
0
                }
422
0
            }
423
0
        }
424
0
        for (int64_t j = 0; j < RN; ++j)
425
0
            for (int64_t i = 0; i < RM; ++i)
426
0
                C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
427
0
    }
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm_bloc<4, 3>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm_bloc<4, 2>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm_bloc<4, 1>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm_bloc<4, 3>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm_bloc<4, 2>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm_bloc<4, 1>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm_bloc<4, 3>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm_bloc<4, 2>(long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm_bloc<4, 1>(long, long)
428
429
    template <int RM, int RN, int BM>
430
0
    NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
431
0
        GGML_ASSERT(m % (RM * BM) == 0);
432
0
        const int64_t ytiles = m / (RM * BM);
433
0
        const int64_t xtiles = (n + RN -1) / RN;
434
0
        const int64_t jj_RN = (xtiles - (xtiles * RN - n));
435
436
        // "round" bloc_size to "nearest" BN
437
0
        const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
438
0
        const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
439
0
        const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
440
0
        const int64_t nb_job = ytiles * NB_BN;
441
442
0
        if (params->ith == 0) {
443
0
            GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
444
            // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
445
0
            ggml_threadpool_chunk_set(params->threadpool, params->nth);
446
0
        }
447
448
0
        ggml_barrier(params->threadpool);
449
450
0
        int64_t job = params->ith;
451
0
        while (job < nb_job) {
452
0
            const int64_t ii = (job % ytiles) * RM * BM;
453
0
            const int64_t jb =  job / ytiles;
454
0
            const int64_t jr0 = BLOC_POS(jb  , jj_BN, SIZE_BN);
455
0
            const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
456
457
0
            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
458
0
            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
459
0
            const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
460
461
0
            for (int64_t bi = 0; bi < BM * RM; bi += RM) {
462
0
                int64_t jj = jj0;
463
0
                for (; jj < jj1; jj += RN) {
464
0
                    gemm_bloc<RM, RN>(ii + bi, jj);
465
0
                }
466
0
                if constexpr (RN > 1) {
467
0
                    for (; jj < jj2; jj += RN - 1) {
468
0
                        gemm_bloc<RM, RN-1>(ii + bi, jj);
469
0
                    }
470
0
                }
471
0
                GGML_ASSERT(jj == jj2);
472
0
            }
473
474
0
            job = ggml_threadpool_chunk_add(params->threadpool, 1);
475
0
        }
476
477
0
        ggml_barrier(params->threadpool);
478
0
        return;
479
0
    }
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 3, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 2, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 1, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 3, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 2, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 1, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 3, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 2, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), float, float, float>::gemm<4, 1, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 3, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 2, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 1, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 3, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 2, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 1, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 3, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 2, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), ggml_bf16_t, ggml_bf16_t, float>::gemm<4, 1, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 3, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 2, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 1, 4>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 3, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 2, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 1, 2>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 3, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 2, 1>(long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS<8, float __vector(8), float __vector(8), unsigned short, unsigned short, float>::gemm<4, 1, 1>(long, long, long)
480
481
    const ggml_compute_params * params;
482
    const TA *const A;
483
    const TB *const B;
484
    TC *const C;
485
    const int64_t k;
486
    const int64_t lda;
487
    const int64_t ldb;
488
    const int64_t ldc;
489
};
490
491
//////////////////////////////////////////////////////////////////////////////////////////
492
// QUANT ZERO MATRIX MULTIPLICATION
493
494
#if defined(__ARM_FEATURE_DOTPROD)
495
template <typename TA>
496
class tinyBLAS_Q0_ARM {
497
  public:
498
    tinyBLAS_Q0_ARM(int64_t k,
499
                    const TA *A, int64_t lda,
500
                    const block_q8_0 *B, int64_t ldb,
501
                    float *C, int64_t ldc,
502
                    int ith, int nth)
503
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
504
    }
505
506
    void matmul(int64_t m, int64_t n) {
507
        mnpack(0, m, 0, n);
508
    }
509
510
  private:
511
    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
512
        int64_t mc, nc, mp, np;
513
        switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
514
        case 0x33:
515
            mc = 3;
516
            nc = 3;
517
            gemm<3, 3>(m0, m, n0, n);
518
            break;
519
        case 0x32:
520
            mc = 3;
521
            nc = 2;
522
            gemm<3, 2>(m0, m, n0, n);
523
            break;
524
        case 0x23:
525
            mc = 2;
526
            nc = 3;
527
            gemm<2, 3>(m0, m, n0, n);
528
            break;
529
        case 0x22:
530
            mc = 2;
531
            nc = 2;
532
            gemm<2, 2>(m0, m, n0, n);
533
            break;
534
        case 0x31:
535
            mc = 3;
536
            nc = 1;
537
            gemm<3, 1>(m0, m, n0, n);
538
            break;
539
        case 0x13:
540
            mc = 1;
541
            nc = 3;
542
            gemm<1, 3>(m0, m, n0, n);
543
            break;
544
        case 0x21:
545
            mc = 2;
546
            nc = 1;
547
            gemm<2, 1>(m0, m, n0, n);
548
            break;
549
        case 0x12:
550
            mc = 1;
551
            nc = 2;
552
            gemm<1, 2>(m0, m, n0, n);
553
            break;
554
        case 0x11:
555
            mc = 1;
556
            nc = 1;
557
            gemm<1, 1>(m0, m, n0, n);
558
            break;
559
        default:
560
            return;
561
        }
562
        mp = m0 + (m - m0) / mc * mc;
563
        np = n0 + (n - n0) / nc * nc;
564
        mnpack(mp, m, n0, np);
565
        mnpack(m0, m, np, n);
566
    }
567
568
    template <int RM, int RN>
569
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
570
        int64_t ytiles = (m - m0) / RM;
571
        int64_t xtiles = (n - n0) / RN;
572
        int64_t tiles = xtiles * ytiles;
573
        int64_t duty = (tiles + nth - 1) / nth;
574
        int64_t start = duty * ith;
575
        int64_t end = start + duty;
576
        if (end > tiles)
577
            end = tiles;
578
        for (int64_t job = start; job < end; ++job) {
579
            int64_t ii = m0 + job / xtiles * RM;
580
            int64_t jj = n0 + job % xtiles * RN;
581
            float32x4_t Cv[RN][RM] = {};
582
            for (int64_t l = 0; l < k; ++l)
583
                for (int64_t j = 0; j < RN; ++j)
584
                    for (int64_t i = 0; i < RM; ++i)
585
                        Cv[j][i] = vmlaq_n_f32(Cv[j][i],
586
                                               vcvtq_f32_s32(vdotq_s32(
587
                                                   vdotq_s32(vdupq_n_s32(0),
588
                                                             load_lo(A + lda * (ii + i) + l),
589
                                                             load_lo(B + ldb * (jj + j) + l)),
590
                                                   load_hi(A + lda * (ii + i) + l),
591
                                                   load_hi(B + ldb * (jj + j) + l))),
592
                                               unhalf(A[lda * (ii + i) + l].d) *
593
                                               unhalf(B[ldb * (jj + j) + l].d));
594
            for (int64_t j = 0; j < RN; ++j)
595
                for (int64_t i = 0; i < RM; ++i)
596
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
597
        }
598
    }
599
600
    inline int8x16_t load_lo(const block_q8_0 *b) {
601
        return vld1q_s8(b->qs);
602
    }
603
604
    inline int8x16_t load_hi(const block_q8_0 *b) {
605
        return vld1q_s8(b->qs + 16);
606
    }
607
608
    inline int8x16_t load_lo(const block_q4_0 *b) {
609
        return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
610
                                                     vdupq_n_u8(0x0f))),
611
                        vdupq_n_s8(0x8));
612
    }
613
614
    inline int8x16_t load_hi(const block_q4_0 *b) {
615
        return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
616
                        vdupq_n_s8(0x8));
617
    }
618
619
    const TA *const A;
620
    const block_q8_0 *const B;
621
    float *const C;
622
    const int64_t k;
623
    const int64_t lda;
624
    const int64_t ldb;
625
    const int64_t ldc;
626
    const int ith;
627
    const int nth;
628
};
629
#endif // __ARM_FEATURE_DOTPROD
630
631
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
632
template <typename TA, typename TB, typename TC>
633
class tinyBLAS_Q0_AVX {
634
  public:
635
    tinyBLAS_Q0_AVX(int64_t k,
636
                    const TA *A, int64_t lda,
637
                    const TB *B, int64_t ldb,
638
                    TC *C, int64_t ldc,
639
                    int ith, int nth)
640
0
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
641
0
        const int8_t kvalues_iq4nl[16] = {
642
0
            -127, -104, -83, -65,
643
0
            -49,  -35,  -22, -10,
644
0
              1,   13,   25,  38,
645
0
             53,   69,   89, 113
646
0
        };
647
648
0
        iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
649
0
    }
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::tinyBLAS_Q0_AVX(long, block_q8_0 const*, long, block_q8_0 const*, long, float*, long, int, int)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::tinyBLAS_Q0_AVX(long, block_q4_0 const*, long, block_q8_0 const*, long, float*, long, int, int)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::tinyBLAS_Q0_AVX(long, block_q5_0 const*, long, block_q8_0 const*, long, float*, long, int, int)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::tinyBLAS_Q0_AVX(long, block_iq4_nl const*, long, block_q8_0 const*, long, float*, long, int, int)
650
651
0
    void matmul(int64_t m, int64_t n) {
652
0
        mnpack(0, m, 0, n);
653
0
    }
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::matmul(long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::matmul(long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::matmul(long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::matmul(long, long)
654
655
  private:
656
0
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
657
0
        int64_t mc, nc, mp, np;
658
0
        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
659
#if VECTOR_REGISTERS == 32
660
        case 0x44:
661
            mc = 4;
662
            nc = 4;
663
#if defined(__AVX2__) && defined(__F16C__)
664
            gemm4xN<4>(m0, m, n0, n);
665
#else
666
            gemm<4, 4>(m0, m, n0, n);
667
#endif
668
            break;
669
        case 0x43:
670
            mc = 4;
671
            nc = 3;
672
#if defined(__AVX2__) && defined(__F16C__)
673
            gemm4xN<3>(m0, m, n0, n);
674
#else
675
            gemm<4, 3>(m0, m, n0, n);
676
#endif
677
            break;
678
        case 0x34:
679
            mc = 3;
680
            nc = 4;
681
#if defined(__AVX2__) && defined(__F16C__)
682
            gemmMx4<3>(m0, m, n0, n);
683
#else
684
            gemm<3, 4>(m0, m, n0, n);
685
#endif
686
            break;
687
        case 0x33:
688
            mc = 3;
689
            nc = 3;
690
            gemm<3, 3>(m0, m, n0, n);
691
            break;
692
        case 0x42:
693
            mc = 4;
694
            nc = 2;
695
#if defined(__AVX2__) && defined(__F16C__)
696
            gemm4xN<2>(m0, m, n0, n);
697
#else
698
            gemm<4, 2>(m0, m, n0, n);
699
#endif
700
            break;
701
        case 0x24:
702
            mc = 2;
703
            nc = 4;
704
#if defined(__AVX2__) && defined(__F16C__)
705
            gemmMx4<2>(m0, m, n0, n);
706
#else
707
            gemm<2, 4>(m0, m, n0, n);
708
#endif
709
            break;
710
#else
711
0
        case 0x44:
712
0
        case 0x43:
713
0
        case 0x42:
714
0
            mc = 4;
715
0
            nc = 2;
716
0
#if defined(__AVX2__) && defined(__F16C__)
717
0
            gemm4xN<2>(m0, m, n0, n);
718
#else
719
            gemm<4, 2>(m0, m, n0, n);
720
#endif
721
0
            break;
722
0
        case 0x34:
723
0
        case 0x24:
724
0
            mc = 2;
725
0
            nc = 4;
726
0
#if defined(__AVX2__) && defined(__F16C__)
727
0
            gemmMx4<2>(m0, m, n0, n);
728
#else
729
            gemm<2, 4>(m0, m, n0, n);
730
#endif
731
0
            break;
732
0
        case 0x33:
733
0
#endif
734
0
        case 0x32:
735
0
            mc = 3;
736
0
            nc = 2;
737
0
            gemm<3, 2>(m0, m, n0, n);
738
0
            break;
739
0
        case 0x23:
740
0
            mc = 2;
741
0
            nc = 3;
742
0
            gemm<2, 3>(m0, m, n0, n);
743
0
            break;
744
0
        case 0x41:
745
0
            mc = 4;
746
0
            nc = 1;
747
0
#if defined(__AVX2__) && defined(__F16C__)
748
0
            gemm4xN<1>(m0, m, n0, n);
749
#else
750
            gemm<4, 1>(m0, m, n0, n);
751
#endif
752
0
            break;
753
0
        case 0x22:
754
0
            mc = 2;
755
0
            nc = 2;
756
0
            gemm<2, 2>(m0, m, n0, n);
757
0
            break;
758
0
        case 0x14:
759
0
            mc = 1;
760
0
            nc = 4;
761
0
#if defined(__AVX2__) && defined(__F16C__)
762
0
            gemmMx4<1>(m0, m, n0, n);
763
#else
764
            gemm<1, 4>(m0, m, n0, n);
765
#endif
766
0
            break;
767
0
        case 0x31:
768
0
            mc = 3;
769
0
            nc = 1;
770
0
            gemm<3, 1>(m0, m, n0, n);
771
0
            break;
772
0
        case 0x13:
773
0
            mc = 1;
774
0
            nc = 3;
775
0
            gemm<1, 3>(m0, m, n0, n);
776
0
            break;
777
0
        case 0x21:
778
0
            mc = 2;
779
0
            nc = 1;
780
0
            gemm<2, 1>(m0, m, n0, n);
781
0
            break;
782
0
        case 0x12:
783
0
            mc = 1;
784
0
            nc = 2;
785
0
            gemm<1, 2>(m0, m, n0, n);
786
0
            break;
787
0
        case 0x11:
788
0
            mc = 1;
789
0
            nc = 1;
790
0
            gemm<1, 1>(m0, m, n0, n);
791
0
            break;
792
0
        default:
793
0
            return;
794
0
        }
795
0
        mp = m0 + (m - m0) / mc * mc;
796
0
        np = n0 + (n - n0) / nc * nc;
797
0
        mnpack(mp, m, n0, np);
798
0
        mnpack(m0, m, np, n);
799
0
    }
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::mnpack(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::mnpack(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::mnpack(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::mnpack(long, long, long, long)
800
801
#if defined(__AVX2__) && defined(__F16C__)
802
// Templated functions for gemm of dimensions 4xN
803
    template <int RN>
804
0
    NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
805
0
        int64_t ytiles = (m - m0) / 4;
806
0
        int64_t xtiles = (n - n0) / RN;
807
0
        int64_t tiles = xtiles * ytiles;
808
0
        int64_t duty = (tiles + nth - 1) / nth;
809
0
        int64_t start = duty * ith;
810
0
        int64_t end = start + duty;
811
0
        if (end > tiles)
812
0
            end = tiles;
813
0
        for (int64_t job = start; job < end; ++job) {
814
0
            int64_t ii = m0 + job / xtiles * 4;
815
0
            int64_t jj = n0 + job % xtiles * RN;
816
0
            __m256 Cv[RN][4] = {};
817
0
            for (int64_t l = 0; l < k; ++l) {
818
0
                uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
819
                // Convert delta values for four blocks to float values
820
0
                __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
821
0
                __m256i avec0 = load(A + lda * (ii + 0) + l);
822
0
                __m256i avec1 = load(A + lda * (ii + 1) + l);
823
0
                __m256i avec2 = load(A + lda * (ii + 2) + l);
824
0
                __m256i avec3 = load(A + lda * (ii + 3) + l);
825
0
                for (int64_t j = 0; j < RN; ++j) {
826
0
                        __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
827
                        // Computation of product of delta values for four blocks and replicate it across 256 bit lane
828
0
                        __m256 dvec =  _mm256_castps128_ps256(_mm_mul_ps(da, db));
829
0
                        dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
830
                        // Computation of dot product and multiplication with appropriate delta value products
831
0
                        Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
832
0
                                    updot(_mm256_sign_epi8(avec0, avec0),
833
0
                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
834
0
                                    Cv[j][0]);
835
0
                        Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
836
0
                                    updot(_mm256_sign_epi8(avec1, avec1),
837
0
                                            _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
838
0
                                    Cv[j][1]);
839
0
                        Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
840
0
                                    updot(_mm256_sign_epi8(avec2, avec2),
841
0
                                            _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
842
0
                                    Cv[j][2]);
843
0
                        Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
844
0
                                    updot(_mm256_sign_epi8(avec3, avec3),
845
0
                                            _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
846
0
                                    Cv[j][3]);
847
0
                }
848
0
            }
849
850
0
            for (int64_t j = 0; j < RN; ++j)
851
0
                for (int64_t i = 0; i < 4; ++i)
852
0
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
853
0
        }
854
0
    }
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm4xN<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm4xN<1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm4xN<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm4xN<1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm4xN<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm4xN<1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm4xN<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm4xN<1>(long, long, long, long)
855
856
    // Templated functions for gemm of dimensions Mx4
857
    template <int RM>
858
0
    NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
859
0
        int64_t ytiles = (m - m0) / RM;
860
0
        int64_t xtiles = (n - n0) / 4;
861
0
        int64_t tiles = xtiles * ytiles;
862
0
        int64_t duty = (tiles + nth - 1) / nth;
863
0
        int64_t start = duty * ith;
864
0
        int64_t end = start + duty;
865
0
        if (end > tiles)
866
0
            end = tiles;
867
0
        for (int64_t job = start; job < end; ++job) {
868
0
            int64_t ii = m0 + job / xtiles * RM;
869
0
            int64_t jj = n0 + job % xtiles * 4;
870
0
            __m256 Cv[4][RM] = {};
871
0
            for (int64_t l = 0; l < k; ++l) {
872
0
                uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
873
                // Convert delta values for four blocks to float values
874
0
                __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
875
0
                __m256i bvec0 = load(B + ldb * (jj + 0) + l);
876
0
                __m256i bvec1 = load(B + ldb * (jj + 1) + l);
877
0
                __m256i bvec2 = load(B + ldb * (jj + 2) + l);
878
0
                __m256i bvec3 = load(B + ldb * (jj + 3) + l);
879
0
                for (int64_t i = 0; i < RM; ++i) {
880
0
                    __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
881
                    // Computation of product of delta values for four blocks and replicate it across 256 bit lane
882
0
                    __m256 dvec =  _mm256_castps128_ps256(_mm_mul_ps(da, db));
883
0
                    dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
884
                    // Computation of dot product and multiplication with appropriate delta value products
885
0
                    Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
886
0
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
887
0
                                                            load(A + lda * (ii + i) + l)),
888
0
                                            _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
889
0
                                    Cv[0][i]);
890
0
                    Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
891
0
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
892
0
                                                            load(A + lda * (ii + i) + l)),
893
0
                                            _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
894
0
                                    Cv[1][i]);
895
0
                    Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
896
0
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
897
0
                                                            load(A + lda * (ii + i) + l)),
898
0
                                            _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
899
0
                                    Cv[2][i]);
900
0
                    Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
901
0
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
902
0
                                                            load(A + lda * (ii + i) + l)),
903
0
                                            _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
904
0
                                    Cv[3][i]);
905
0
                }
906
0
            }
907
0
            for (int64_t j = 0; j < 4; ++j)
908
0
                for (int64_t i = 0; i < RM; ++i)
909
0
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
910
0
        }
911
0
    }
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemmMx4<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemmMx4<1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemmMx4<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemmMx4<1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemmMx4<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemmMx4<1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemmMx4<2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemmMx4<1>(long, long, long, long)
912
#endif
913
914
    template <int RM, int RN>
915
0
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
916
0
        int64_t ytiles = (m - m0) / RM;
917
0
        int64_t xtiles = (n - n0) / RN;
918
0
        int64_t tiles = xtiles * ytiles;
919
0
        int64_t duty = (tiles + nth - 1) / nth;
920
0
        int64_t start = duty * ith;
921
0
        int64_t end = start + duty;
922
0
        if (end > tiles)
923
0
            end = tiles;
924
0
        for (int64_t job = start; job < end; ++job) {
925
0
            int64_t ii = m0 + job / xtiles * RM;
926
0
            int64_t jj = n0 + job % xtiles * RN;
927
0
            __m256 Cv[RN][RM] = {};
928
0
            for (int64_t l = 0; l < k; ++l)
929
0
                for (int64_t j = 0; j < RN; ++j)
930
0
                    for (int64_t i = 0; i < RM; ++i) {
931
0
#if defined(__AVX2__)
932
0
                        __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
933
0
                                                              load(A + lda * (ii + i) + l)),
934
0
                                             _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
935
0
                                                              load(A + lda * (ii + i) + l)));
936
#else
937
                        __m128i ali0 = load0(A + lda * (ii + i) + l);
938
                        __m128i ali1 = load1(A + lda * (ii + i) + l);
939
                        __m128i blj0 = load0(B + ldb * (jj + j) + l);
940
                        __m128i blj1 = load1(B + ldb * (jj + j) + l);
941
942
                        __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
943
                        __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
944
                        __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
945
                        __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
946
947
                        // updot
948
                        const __m128i oneFill = _mm_set1_epi16(1);
949
                        __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
950
                        __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
951
                        __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
952
#endif
953
0
                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
954
0
                                                       unhalf(B[ldb * (jj + j) + l].d)),
955
0
                                                       udTmp,
956
0
                                                       Cv[j][i]);
957
0
                    }
958
0
            for (int64_t j = 0; j < RN; ++j)
959
0
                for (int64_t i = 0; i < RM; ++i)
960
0
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
961
0
        }
962
0
    }
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<3, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<2, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<2, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<3, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<1, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<2, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<1, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::gemm<1, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<3, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<2, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<2, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<3, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<1, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<2, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<1, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::gemm<1, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<3, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<2, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<2, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<3, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<1, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<2, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<1, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::gemm<1, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<3, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<2, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<2, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<3, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<1, 3>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<2, 1>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<1, 2>(long, long, long, long)
Unexecuted instantiation: sgemm.cpp:void (anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::gemm<1, 1>(long, long, long, long)
963
964
0
    inline __m256i load(const block_q8_0 *b) {
965
0
        return _mm256_loadu_si256((const __m256i *)b->qs);
966
0
    }
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::load(block_q8_0 const*)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::load(block_q8_0 const*)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::load(block_q8_0 const*)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::load(block_q8_0 const*)
967
968
    inline __m128i load0(const block_q8_0 *b) {
969
        return _mm_loadu_si128((const __m128i *)b->qs);
970
    }
971
972
    inline __m128i load1(const block_q8_0 *b) {
973
        return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
974
    }
975
976
0
    inline __m256i load(const block_q4_0 *b) {
977
0
        return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
978
0
    }
979
980
    inline __m128i load0(const block_q4_0 *b) {
981
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
982
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
983
    }
984
985
    inline __m128i load1(const block_q4_0 *b) {
986
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
987
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
988
    }
989
990
0
    inline __m256i load(const block_q5_0 *b) {
991
0
        return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
992
0
    }
993
994
    inline __m128i load0(const block_q5_0* b) {
995
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
996
        uint32_t x32;
997
        memcpy(&x32, b->qh, sizeof(uint32_t));
998
        __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
999
        __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
1000
                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
1001
                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
1002
                                                                      _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
1003
        bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
1004
        return _mm_or_si128(qxl, bytesl);
1005
    }
1006
1007
    inline __m128i load1(const block_q5_0* b) {
1008
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
1009
        uint32_t x32;
1010
        memcpy(&x32, b->qh, sizeof(uint32_t));
1011
        __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
1012
        __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
1013
                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
1014
                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
1015
                                                                      _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
1016
        bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
1017
        return _mm_or_si128(qxh, bytesh);
1018
    }
1019
1020
0
    inline __m256i load(const block_iq4_nl *b) {
1021
0
        return MM256_SET_M128I(load1(b), load0(b));
1022
0
    }
1023
1024
0
    inline __m128i load0(const block_iq4_nl *b) {
1025
0
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
1026
0
        return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
1027
0
    }
1028
1029
0
    inline __m128i load1(const block_iq4_nl *b) {
1030
0
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
1031
0
        return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
1032
0
    }
1033
1034
0
    // Dot product of 32 unsigned bytes (u) with 32 signed bytes (s): adjacent
    // products are summed into eight 32-bit lanes, then converted to float.
    // Uses a single VNNI instruction where available, otherwise the classic
    // maddubs (u8 x s8 -> i16 pairs) + madd-by-1 (i16 pairs -> i32) sequence.
    inline __m256 updot(__m256i u, __m256i s) {
        __m256i res;
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
        res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
#elif defined(__AVXVNNI__)
        res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
#else
        res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
#endif
        return _mm256_cvtepi32_ps(res);
    }
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float>::updot(long long __vector(4), long long __vector(4))
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::updot(long long __vector(4), long long __vector(4))
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::updot(long long __vector(4), long long __vector(4))
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float>::updot(long long __vector(4), long long __vector(4))
1045
1046
0
    // Unpack 32 4-bit values stored in 16 bytes at p into 32 bytes in [0,15]:
    // the low nibbles land in the lower 128-bit lane, the high nibbles
    // (shifted down by 4) in the upper lane, then everything is masked to 4 bits.
    static inline __m256i denibble(const uint8_t *p) {
        __m128i x = _mm_loadu_si128((const __m128i *)p);
        return _mm256_and_si256(_mm256_set1_epi8(15),
                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),
                                                        _mm_srli_epi16(x, 4), 1));
    }
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float>::denibble(unsigned char const*)
Unexecuted instantiation: sgemm.cpp:(anonymous namespace)::tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float>::denibble(unsigned char const*)
1052
1053
0
    // Expand the 32 packed bits at p (the q5 qh field) into 32 bytes:
    // lane i becomes 0x00 when bit i is set and 0xF0 when it is clear.
    // OR-ing that result into a 4-bit base value folds in the -16 bias of the
    // q5 format (nibble - 16 == nibble | 0xF0 in 8-bit two's complement).
    static inline __m256i bittobyte(const uint8_t *p) {
        uint32_t x32;
        // qh is packed/unaligned; copy the 32 bits out.
        memcpy(&x32, p, sizeof(uint32_t));
        // Replicate source byte j of x32 across lanes 8j..8j+7, OR with a mask
        // that leaves only lane i's own bit unset, and compare against
        // all-ones: lane i is 0xFF exactly when bit i of x32 is set.
        __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
                                          _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
                                                          _mm256_shuffle_epi8(_mm256_set1_epi32(x32),
                                                                              _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
                                                                                                0x0101010101010101, 0x0000000000000000))));
        // Invert and select the 0xF0 bias for clear bits.
        return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
    }
1063
1064
    const TA *const A;    // left operand: array of quantized blocks, row stride lda
    const TB *const B;    // right operand: array of quantized blocks, row stride ldb
    TC *const C;          // output matrix (float), column stride ldc
    const int64_t k;      // shared inner dimension, counted in blocks
    const int64_t lda;    // leading dimension of A, in blocks
    const int64_t ldb;    // leading dimension of B, in blocks
    const int64_t ldc;    // leading dimension of C, in elements
    const int ith;        // this thread's index in [0, nth)
    const int nth;        // total number of threads sharing the multiply
    // iq4_nl codebook used by the load0/load1(block_iq4_nl*) lookups —
    // presumably loaded once at construction; ctor not visible in this chunk.
    __m128i iq4nlt;
1074
};
1075
#endif // __AVX__
1076
1077
//PPC Implementation
1078
#if defined(__MMA__)
1079
1080
// Disassemble an MMA accumulator (a 4x4 tile of f32 partial sums) into the
// local vec_t vec_C[4] and scatter it into C at tile origin (ii, jj).
// Requires `vec_C`, `C` and `ldc` to be in scope at the expansion site.
#define SAVE_ACC(ACC, ii, jj) \
   __builtin_mma_disassemble_acc(vec_C, ACC); \
   for (int I = 0; I < 4; I++) { \
      for (int J = 0; J < 4; J++) { \
         *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
      } \
   } \
1087
1088
// Multithreaded bf16 matrix multiplication for POWER (MMA). Output tiles of
// up to 8x8 are accumulated with __builtin_mma_xvbf16ger2pp rank-2 outer
// products; inputs are repacked on the fly into the operand layout the MMA
// instructions expect. TA/TB are assumed to be 16-bit bf16 storage types —
// TODO confirm at the instantiation site (not visible in this chunk).
template <typename TA, typename TB, typename TC>
class tinyBLAS_BF16_PPC {
  public:
    // Records operand pointers, leading dimensions, the shared K extent and
    // this thread's id/count. No work happens until matmul() is called.
    tinyBLAS_BF16_PPC(int64_t k,
                const TA *A, int64_t lda,
                const TB *B, int64_t ldb,
                TC *C, int64_t ldc,
                int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    // Entry point: recursively tile and compute the whole m x n output.
    void matmul(int64_t m, int64_t n) {
        mnpack(0, m, 0, n);
    }

  private:
    // Interleave 2, 4 or 8 16-byte input vectors into the fixed operand
    // layout consumed by the MMA kernels and store them at vecOffset.
    // swiz1/swiz2 interleave 4-byte groups of two vectors; swiz3/swiz4 then
    // interleave 8-byte halves of those intermediates.
    void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) {
        vec_t t[8], s[8];
        vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
        vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
        vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
        vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};

        if (numVec == 2) {
            t[0] = vec_perm(c[0], c[1], swiz1);
            t[1] = vec_perm(c[2], c[3], swiz1);
            s[0] = vec_perm(t[0], t[1], swiz3);
            s[1] = vec_perm(t[0], t[1], swiz4);
            vec_xst(s[0], 0, (vec_t*)vecOffset);
            vec_xst(s[1], 0, (vec_t*)(vecOffset + 16));
        } else if (numVec == 4) {
            t[0] = vec_perm(c[0], c[1], swiz1);
            t[1] = vec_perm(c[0], c[1], swiz2);
            t[2] = vec_perm(c[2], c[3], swiz1);
            t[3] = vec_perm(c[2], c[3], swiz2);
            s[0] = vec_perm(t[0], t[2], swiz3);
            s[1] = vec_perm(t[0], t[2], swiz4);
            s[2] = vec_perm(t[1], t[3], swiz3);
            s[3] = vec_perm(t[1], t[3], swiz4);
            for (int i = 0; i < 4; ++i)
                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
        } else if (numVec == 8) {
            // Two independent groups of four vectors, permuted pairwise first.
            for (int i = 0; i < 4; i += 2) {
                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
            }
            for (int i = 4; i < 8; i += 2) {
                t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
                t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
            }
            s[0] = vec_perm(t[0], t[2], swiz3);
            s[1] = vec_perm(t[0], t[2], swiz4);
            s[2] = vec_perm(t[1], t[3], swiz3);
            s[3] = vec_perm(t[1], t[3], swiz4);
            s[4] = vec_perm(t[4], t[6], swiz3);
            s[5] = vec_perm(t[4], t[6], swiz4);
            s[6] = vec_perm(t[5], t[7], swiz3);
            s[7] = vec_perm(t[5], t[7], swiz4);
            for (int i = 0; i < 8; ++i)
                vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
        }
    }

    // Repack a rows x cols sub-panel of `a` (row stride lda, elements are
    // 16-bit) into `vec` in the interleaved MMA operand layout. Handles rows
    // in groups of 8, then a remainder of 4, then a remainder of 1..3.
    void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {
        int64_t i, j;
        TA *aoffset = NULL;
        unsigned char *vecOffset = NULL;
        TA * aoffsets[8];
        vector unsigned char c_arr[8];
        aoffset = const_cast<TA*>(a);
        vecOffset = vec;
        j = (rows >> 3);
        if (j > 0) {
            // Full 8-row panels.
            do {
                if (cols == 4) {
                    aoffsets[0] = aoffset;
                    for (int it = 1; it < 4; ++it)
                        aoffsets[it] = aoffsets[it-1] + lda;
                    aoffset += 4 * lda;
                    for (int i = 0; i < 4; ++i)
                        c_arr[i] = vec_xl(0, (vector unsigned char*)aoffsets[i]);
                    vector_permute_store(c_arr, 4, vecOffset);
                    for (int i = 0; i<4; i++)
                        aoffsets[i] = aoffsets[i]+lda;
                    vecOffset +=64;
                }
                i = (cols >> 3);
                if (i > 0) {
                    aoffsets[0] = aoffset;
                    for (int it = 1; it < 8; ++it) {
                        aoffsets[it] = aoffsets[it-1] + lda;
                    }
                    aoffset += 8 * lda;
                    // Walk the panel 8 columns (one vec_xl) at a time.
                    do {
                        for (int it = 0; it < 8; ++it)
                            c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
                        vector_permute_store(c_arr, 8, vecOffset);
                        for (int it = 0; it < 8; ++it)
                            aoffsets[it] = aoffsets[it] + 8*lda;
                        vecOffset += 128;
                        i--;
                    } while(i > 0);
                }
                j--;
            } while(j > 0);
        }
        // Remainder of exactly 4 rows.
        if (rows & 4) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 4; ++it)
                aoffsets[it] = aoffsets[it-1] + lda;
            aoffset += 4 * lda;
            if (cols == 4) {
                for (int it = 0; it < 4; ++it)
                    c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
                vector_permute_store(c_arr, 2, vecOffset);
                for (int it = 0; it< 4; it++)
                    aoffsets[it] = aoffsets[it] + lda;
                vecOffset += 32;
            }
            i = (cols >> 3);
            if (i > 0) {
                do {
                    for (int it = 0; it < 4; ++it)
                        c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
                    vector_permute_store(c_arr, 4, vecOffset);
                    for (int it = 0; it< 4; it++)
                        aoffsets[it] = aoffsets[it] + 8*lda;
                    vecOffset += 64;
                    i--;
                } while(i > 0);
            }
        }
        // Remainder of 1..3 rows; rows not loaded keep c_arr's previous
        // contents (stale or uninitialized) — presumably ignored downstream
        // for the unused lanes; TODO confirm against the consuming kernels.
        if (rows & 3) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 4; ++it)
                aoffsets[it] = aoffsets[it-1] + lda;
            if (cols == 4) {
                // Intentional fallthrough: load rows 3, 2, 1 as applicable.
                switch(rows) {
                    case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
                    case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
                    case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
                        break;
                }
                vector_permute_store(c_arr, 2, vecOffset);
                for (int it = 0; it< 4; it++)
                     aoffsets[it] = aoffsets[it] + lda;
                vecOffset += 32;
            }
            i = (cols >> 3);
            if (i > 0) {
                do {
                    // Intentional fallthrough, as above.
                    switch(rows) {
                        case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
                        case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
                        case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
                            break;
                    }
                    vector_permute_store(c_arr, 4, vecOffset);
                    for (int it = 0; it <4; it++)
                         aoffsets[it] = aoffsets[it] + 8* lda;
                    vecOffset += 64;
                    i--;
                } while(i > 0);
            }
        }
    }

    // Dispatch the region [m0,m) x [n0,n) to the widest kernel its remainder
    // sizes allow, then recurse on the two uncovered edge strips. mc/nc record
    // the tile shape actually used so mp/np mark the boundary of the covered
    // sub-rectangle.
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        int m_rem = MIN(m - m0, 8);
        int n_rem = MIN(n - n0, 8);

        if (m_rem >= 8 && n_rem >= 8) {
            mc = 8;
            nc = 8;
            gemm<8,8>(m0, m, n0, n);
        } else if (m_rem >= 4 && n_rem >= 8) {
            mc = 4;
            nc = 8;
            gemm<4,8>(m0, m, n0, n);
        } else if (m_rem >=8 && n_rem >=4){
                mc = 8;
                nc = 4;
                gemm<8,4>(m0, m, n0, n);
        } else if ((m_rem < 4) && (n_rem >= 8)) {
            // Tall-and-skinny left edge: 1..3 rows against 8 columns.
            nc = 8;
            switch(m_rem) {
                case 1:
                    mc = 1;
                    gemm_Mx8<1>(m0, m, n0, n);
                    break;
                case 2:
                    mc = 2;
                    gemm_Mx8<2>(m0, m, n0, n);
                    break;
                case 3:
                    mc = 3;
                    gemm_Mx8<3>(m0, m, n0, n);
                    break;
                default:
                    return;
            }
        } else if (m_rem >= 4 && n_rem >= 4) {
            mc = 4;
            nc = 4;
            gemm_small<4, 4>(m0, m, n0, n);
        } else if ((m_rem > 4) && (n_rem < 4)) {
            mc = 4;
            switch(n_rem) {
                case 1:
                    nc = 1;
                    gemm_small<4, 1>(m0, m, n0, n);
                    break;
                case 2:
                    nc = 2;
                    gemm_small<4, 2>(m0, m, n0, n);
                    break;
                case 3:
                    nc = 3;
                    gemm_small<4, 3>(m0, m, n0, n);
                    break;

                default:
                    return;
            }
        } else {
            // Small corner cases, keyed as (m_rem << 4) | n_rem.
            switch((m_rem << 4) | n_rem) {
                case 0x43:
                    mc = 4;
                    nc = 3;
                    gemm_small<4, 3>(m0, m, n0, n);
                    break;
                case 0x42:
                    mc = 4;
                    nc = 2;
                    gemm_small<4, 2>(m0, m, n0, n);
                    break;
                case 0x41:
                    mc = 4;
                    nc = 1;
                    gemm_small<4, 1>(m0, m, n0, n);
                    break;
                case 0x34:
                    mc = 3;
                    nc = 4;
                    gemm_small<3, 4>(m0, m, n0, n);
                    break;
                case 0x33:
                    mc = 3;
                    nc = 3;
                    gemm_small<3, 3>(m0, m, n0, n);
                    break;
                case 0x32:
                    mc = 3;
                    nc = 2;
                    gemm_small<3, 2>(m0, m, n0, n);
                    break;
                case 0x31:
                    mc = 3;
                    nc = 1;
                    gemm_small<3, 1>(m0, m, n0, n);
                    break;
                case 0x24:
                    mc = 2;
                    nc = 4;
                    gemm_small<2,4>(m0, m, n0, n);
                    break;
                case 0x23:
                    mc = 2;
                    nc = 3;
                    gemm_small<2, 3>(m0, m, n0, n);
                    break;
                case 0x22:
                    mc = 2;
                    nc = 2;
                    gemm_small<2, 2>(m0, m, n0, n);
                    break;
                case 0x21:
                    mc = 2;
                    nc = 1;
                    gemm_small<2, 1>(m0, m, n0, n);
                    break;
                case 0x14:
                    mc = 1;
                    nc = 4;
                    gemm_small<1, 4>(m0, m, n0, n);
                    break;
                case 0x13:
                    mc = 1;
                    nc = 3;
                    gemm_small<1, 3>(m0, m, n0, n);
                    break;
                case 0x12:
                    mc = 1;
                    nc = 2;
                    gemm_small<1, 2>(m0, m, n0, n);
                    break;
                case 0x11:
                    mc = 1;
                    nc = 1;
                    gemm_small<1, 1>(m0, m, n0, n);
                    break;
                default:
                    return;
            }
        }
        // Recurse on the right strip (rows covered, columns np..n) and the
        // bottom strip (rows mp..m, all columns up to np).
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

    // 4x8 output tile: two accumulators side by side over the N dimension.
    void KERNEL_4x8(int64_t ii, int64_t jj) {
        vec_t vec_A[4], vec_B[8] , vec_C[4];
        acc_t acc_0, acc_1;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int l = 0; l < k; l+=8) {
            packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
            for (int x = 0; x < 4; x++) {
                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
            }
        }
        SAVE_ACC(&acc_0, ii, jj);
        SAVE_ACC(&acc_1, ii, jj+4);
    }

    // 8x4 output tile: two accumulators stacked over the M dimension.
    void KERNEL_8x4(int64_t ii, int64_t jj) {
        vec_t vec_A[8], vec_B[4] , vec_C[4];
        acc_t acc_0, acc_1;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int l = 0; l < k; l+=8) {
            packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
            packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
            for (int x = 0; x < 4; x++) {
                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
            }
        }
        SAVE_ACC(&acc_0, ii, jj);
        SAVE_ACC(&acc_1, ii+4, jj);
    }


    // 8x8 output tile: four accumulators forming a 2x2 grid of 4x4 blocks.
    void KERNEL_8x8(int64_t ii, int64_t jj) {
        vec_t vec_A[8], vec_B[8], vec_C[4];
        acc_t acc_0, acc_1, acc_2, acc_3;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        __builtin_mma_xxsetaccz(&acc_2);
        __builtin_mma_xxsetaccz(&acc_3);
        for (int l = 0; l < k; l+=8) {
            packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
            packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
            for (int x = 0; x < 4; x++) {
                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
                __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
                __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
                __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
            }
        }

        SAVE_ACC(&acc_0, ii, jj);
        SAVE_ACC(&acc_1, ii, jj+4);
        SAVE_ACC(&acc_2, ii+4, jj);
        SAVE_ACC(&acc_3, ii+4, jj+4);
    }

    // Generic small tile (RM x RN, both <= 4) with a single accumulator;
    // work is split across threads by flat tile index.
    template<int RM, int RN>
    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            vec_t vec_C[4];
            acc_t acc_0;
            __builtin_mma_xxsetaccz(&acc_0);
            vec_t vec_A[2], vec_B[2];
            for (int l=0; l<k; l+=4) {
                packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
                packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
                for (int x = 0; x<2; x++) {
                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
                }
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
            // Store only the RM x RN corner of the 4x4 accumulator.
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < RN; J++) {
                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                }
            }
        }
    }

    // RM x 8 edge tile (RM in 1..3): two accumulators cover columns jj..jj+3
    // and jj+4..jj+7; only the first RM rows of each are stored.
    template<int RM>
    void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int RN = 8;
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            vec_t vec_C[4];
            acc_t acc_0, acc_1;
            __builtin_mma_xxsetaccz(&acc_0);
            __builtin_mma_xxsetaccz(&acc_1);
            vec_t vec_A[4], vec_B[8];
            for (int l=0; l<k; l+=8) {
                packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
                packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
                for (int x = 0; x<4; x++) {
                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
                    __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
                }
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < 4; J++) {
                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                }
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_1);
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < 4; J++) {
                    *((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                }
            }
        }
    }

    // Compile-time dispatch from tile shape to the matching KERNEL_*.
    template<int RM, int RN>
    inline void kernel(int64_t ii, int64_t jj) {
       if constexpr(RM == 4 && RN == 8) {
          KERNEL_4x8(ii,jj);
       } else if constexpr(RM == 8 && RN == 8) {
          KERNEL_8x8(ii,jj);
       } else if constexpr(RM == 8 && RN == 4) {
          KERNEL_8x4(ii,jj);
       } else {
          assert(false && "RN/RM values not supported");
       }
    }

    // Drive kernel<RM,RN> over the tile grid, statically partitioned by
    // flat tile index across nth threads.
    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            kernel<RM, RN>(ii, jj);
        }
    }

    const TA *const A;    // left operand (bf16), row stride lda
    const TB *const B;    // right operand (bf16), row stride ldb
    TC *C;                // output matrix, column stride ldc
    const int64_t k;      // shared inner dimension
    const int64_t lda;    // leading dimension of A
    const int64_t ldb;    // leading dimension of B
    const int64_t ldc;    // leading dimension of C
    const int ith;        // this thread's index in [0, nth)
    const int nth;        // total number of threads
};
1574
1575
    // Construct the PPC quantized (Q0) GEMM driver: record operand pointers,
    // leading dimensions, the shared K extent and this thread's id/count.
    // kc is the K-dimension blocking factor used by the tiled path.
    template <typename TA>
    tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
        const TA *A, int64_t lda,
        const block_q8_0 *B, int64_t ldb,
        float *C, int64_t ldc,
        int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
                kc = 64;
    }
1584
1585
    template<typename TA>
1586
    void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
1587
        int mc = 64; int nc = 64;
1588
        if (n % 8 == 0 && n < nc) {
1589
                nc = n;
1590
                mc = 32 ;
1591
                kc = 32;
1592
        }
1593
        const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
1594
        if (is_aligned) {
1595
            this->matmul_tiled_q0(m, n, mc, nc, kc);
1596
        } else {
1597
            mnpack(0, m, 0, n);
1598
        }
1599
    }
1600
1601
   // Repack a rows x cols panel of q4 blocks into the interleaved int8 layout
   // the MMA kernels consume, while accumulating per-row sums into comparray
   // (used later to correct for the q4 zero-point — TODO confirm against the
   // consuming kernel, which is not visible in this chunk). Rows are handled
   // in groups of 8, then a remainder of 4, then 1..3.
   template<typename TA>
   template<int size>
   void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
        int64_t i, j;
        TA *aoffset = NULL;
        int8_t *vecOffset = NULL;
        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
        // c*[1] receives the raw 16-byte load; process_q4_elements presumably
        // splits nibbles into c*[0]/c*[1] and computes the row sum.
        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
        aoffset = const_cast<TA*>(a);
        vecOffset = vec;
        j = (rows >> 3);
        if (j > 0) {
            // Full 8-row panels.
            do {
                aoffset1 = aoffset;
                aoffset2 = aoffset1 + lda;
                aoffset3 = aoffset2 + lda;
                aoffset4 = aoffset3 + lda;
                aoffset5 = aoffset4 + lda;
                aoffset6 = aoffset5 + lda;
                aoffset7 = aoffset6 + lda;
                aoffset8 = aoffset7 + lda;
                aoffset += 8 * lda;
                i = (cols >> 2);
                if (i > 0) {
                    do {
                        c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
                        c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
                        c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
                        c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
                        c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
                        c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
                        c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
                        c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));

                        process_q4_elements(c1, &comparray[0]);
                        process_q4_elements(c2, &comparray[1]);
                        process_q4_elements(c3, &comparray[2]);
                        process_q4_elements(c4, &comparray[3]);
                        process_q4_elements(c5, &comparray[4]);
                        process_q4_elements(c6, &comparray[5]);
                        process_q4_elements(c7, &comparray[6]);
                        process_q4_elements(c8, &comparray[7]);
                        vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
                        vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
                        vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
                        vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
                        aoffset1 += lda;
                        aoffset2 += lda;
                        aoffset3 += lda;
                        aoffset4 += lda;
                        aoffset5 += lda;
                        aoffset6 += lda;
                        aoffset7 += lda;
                        aoffset8 += lda;
                        vecOffset += 256;
                        i--;
                    } while (i > 0);
                }
                j--;
            } while (j > 0);
        }

        // Remainder of exactly 4 rows.
        if (rows & 4) {
            aoffset1 = aoffset;
            aoffset2 = aoffset1 + lda;
            aoffset3 = aoffset2 + lda;
            aoffset4 = aoffset3 + lda;
            aoffset += 4 * lda;
            i = (cols >> 2);
            if (i > 0) {
                do {
                    c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
                    c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
                    c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
                    c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));

                    process_q4_elements(c1, &comparray[0]);
                    process_q4_elements(c2, &comparray[1]);
                    process_q4_elements(c3, &comparray[2]);
                    process_q4_elements(c4, &comparray[3]);
                    vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
                    vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
                    aoffset1 += lda;
                    aoffset2 += lda;
                    aoffset3 += lda;
                    aoffset4 += lda;
                    vecOffset += 128;
                    i--;
                } while (i > 0);
            }
        }

        // Remainder of 1..3 rows. Rows that are not loaded keep their previous
        // c* contents (zero on first use, stale after earlier passes); they are
        // still processed and stored as padding — presumably ignored by the
        // consumer for the unused lanes; TODO confirm.
        if (rows & 3) {
            aoffset1 = aoffset;
            aoffset2 = aoffset1 + lda;
            aoffset3 = aoffset2 + lda;
            i = (cols >> 2);
            if (i > 0) {
                do {
                    // Intentional fallthrough: load rows 3, 2, 1 as applicable.
                    switch(rows) {
                        case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
                        case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
                        case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
                            break;
                    }
                    process_q4_elements(c1, &comparray[0]);
                    process_q4_elements(c2, &comparray[1]);
                    process_q4_elements(c3, &comparray[2]);
                    process_q4_elements(c4, &comparray[3]);
                    vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
                    vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
                    aoffset1 += lda;
                    aoffset2 += lda;
                    aoffset3 += lda;
                    vecOffset += 128;
                    i--;
                } while(i > 0);
            }
        }
    }
1723
1724
    template<typename TA>
    template<typename VA, typename VB>
    // Packs a tile of q8_0 blocks (rows x cols quant values, row stride lda,
    // starting at 'a') into the interleaved layout consumed by the MMA
    // kernels.  Each source row contributes 32 int8 quants (one __vector_pair
    // load of ->qs); groups of 4 rows are transposed/interleaved by
    // vector_permute_store (defined earlier in this file).  'flip' is
    // forwarded to vector_permute_store (sign handling for the unsigned MMA
    // operand — see its definition).
    void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
        int64_t i, j;
        block_q8_0 *aoffset = NULL;
        VA *vecOffset = NULL;
        block_q8_0* aoffsets[8];   // per-row cursors for the current row group
        __vector_pair arr[8];      // 32-byte (two vector) loads of each row's qs
        VB c[8][2] = {0};
        VB c1[8] = {0}; VB c2[8] = {0};
        aoffset = const_cast<block_q8_0*>(a);
        vecOffset = vec;
        // Main path: process 8 rows at a time.
        j = (rows >> 3);
        if (j > 0) {
            do {
                aoffsets[0] = aoffset;
                for (int it = 1; it < 8; it++)
                    aoffsets[it] = aoffsets[it-1] + lda;
                aoffset += 8 * lda;

                i = (cols >> 3);
                if (i > 0) {
                do {
                    // Load 32 quants per row, split into low/high 16-byte halves.
                    for (int it = 0; it < 8; it++) {
                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
                        c1[it] = c[it][0];
                        c2[it] = c[it][1];
                    }
                    // Interleave rows 0-3 then rows 4-7; 256 output values per step.
                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
                    vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
                    vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
                    for (int it = 0; it < 8; it++)
                        aoffsets[it] += lda;
                    vecOffset += 256;
                    i--;
               } while(i > 0);
            }
            j--;
        } while(j > 0);
    }
    // Remainder: 4 more rows.
    if (rows & 4) {
            aoffsets[0]  = aoffset;
            for (int it = 1; it < 4; it++ )
                aoffsets[it] = aoffsets[it-1] + lda;
            aoffset += 4 * lda;
        i = (cols >> 3);
            if (i > 0) {
               do {
                    for (int it = 0; it < 4; it++) {
                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
                        c1[it] = c[it][0];
                        c2[it] = c[it][1];
                    }
                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
                    for (int it = 0; it < 4; it++) {
                        aoffsets[it] += lda;
                    }
                    vecOffset += 128;
                    i--;
               } while(i > 0);
            }
        }

        // Remainder: 1-3 rows.  Missing rows keep their zero-initialized
        // c1/c2 slots, so the permute below pads with zeros.
        if (rows & 3) {
            aoffsets[0]  = aoffset;
            for (int it = 1; it < 3; it++ )
                aoffsets[it] = aoffsets[it-1] + lda;
            i = (cols >> 3);
            if (i > 0) {
                do {
                    // NOTE(review): switches on the full 'rows', not 'rows & 3';
                    // callers appear to only reach here with rows <= 4, where the
                    // two coincide — confirm before calling with rows in 5..7.
                    switch(rows) {
                        case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
                                __builtin_vsx_disassemble_pair(c[2], &arr[2]);
                                c1[2] = c[2][0]; c2[2] = c[2][1];
                                // fall through
                        case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
                                __builtin_vsx_disassemble_pair(c[1], &arr[1]);
                                c1[1] = c[1][0]; c2[1] = c[1][1];
                                // fall through
                        case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
                                __builtin_vsx_disassemble_pair(c[0], &arr[0]);
                                c1[0] = c[0][0]; c2[0] = c[0][1];
                                break;
                    }
                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
                    for (int it = 0; it < 3; it++)
                         aoffsets[it] += lda;
                    vecOffset += 128;
                    i--;
               } while(i > 0);
            }
        }
    }
1820
1821
    template<typename TA>
1822
    void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1823
        int m_rem = MIN(m - m0, 16);
1824
        int n_rem = MIN(n - n0, 16);
1825
1826
        int mc = 0, nc = 0;
1827
1828
        if (m_rem >= 8 && n_rem >= 8) {
1829
           mc = 8;
1830
           nc = 8;
1831
           gemm<8, 8>(m0, m, n0, n);
1832
        } else if (m_rem >= 4 && n_rem >= 8) {
1833
            mc = 4;
1834
            nc = 8;
1835
            gemm<4, 8>(m0, m, n0, n);
1836
        } else if (m_rem >= 8 && n_rem >= 4) {
1837
            mc = 8;
1838
            nc = 4;
1839
            gemm<8, 4>(m0, m, n0, n);
1840
        } else if (m_rem >= 4 && n_rem >= 4) {
1841
            mc = 4;
1842
            nc = 4;
1843
            gemm_small(m0, m, n0, n, mc, nc);
1844
        } else {
1845
            mc = (m_rem >= 4) ? 4 : m_rem;
1846
            nc = (n_rem >= 4) ? 4 : n_rem;
1847
            if (mc == 0 || nc == 0)
1848
               return;
1849
            gemm_small(m0, m, n0, n, mc, nc);
1850
        }
1851
1852
        int64_t mp = m0 + ((m - m0) / mc) * mc;
1853
        int64_t np = n0 + ((n - n0) / nc) * nc;
1854
        mnpack(mp, m, n0, np);
1855
        mnpack(m0, m, np, n);
1856
    }
1857
1858
1859
    template<typename TA>
    // Computes one 4x8 output tile (rows ii..ii+3, cols jj..jj+7) of the
    // quantized GEMM.  Per k-block: pack 4 rows of A and 8 rows of B, run
    // int8 MMA rank-4 updates (acc_0 = cols jj..jj+3, acc_1 = jj+4..jj+7),
    // then scale with the per-block d_A*d_B products and accumulate via
    // compute() (defined earlier in this file).
    void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
        vec_t vec_A[8], vec_B[16] = {0};
        acc_t acc_0, acc_1;
        std::array<int, 4> comparray {};   // per-row sums of A quants (bias correction)
        vector float fin_res[8] = {0};     // accumulated float results across k
        vector float vs[8] = {0};          // per-element scale products d_A * d_B
        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
        for (int l = 0; l < k; l++) {
            __builtin_mma_xxsetaccz(&acc_0);
            __builtin_mma_xxsetaccz(&acc_1);
            // q4_0 A is unpacked to int8 (packNormalInt4 also fills comparray);
            // q8_0 A is packed directly.
            if (std::is_same_v<TA, block_q4_0>) {
               packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
            } else {
               packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
            }
            // B is packed with flip=true (unsigned operand for xvi8ger4pp).
            packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
            for(int x = 0; x < 8; x++) {
                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
            }
            // Scale factors: vs[I] covers cols jj..jj+3, vs[I+4] cols jj+4..jj+7.
            for (int I = 0; I<4; I++) {
                for (int J = 0; J<4; J++) {
                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                    *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
                }
            }
            // For q8_0 A, comparray was not filled by packing: sum the 32
            // quants of each row here (used by compute() to undo the +128
            // bias applied to B — presumably; confirm in compute()).
            if (!isAblock_q4) {
                auto aoffset = A+(ii*lda)+l;
                for (int i = 0; i < 4; i++) {
                    comparray[i] = 0;
                    int ca = 0;
                    auto *at = aoffset->qs;
                    for (int j = 0; j < 32; j++)
                        ca += (int)*at++;
                    comparray[i] = ca;
                    aoffset += lda;
                }
            }
            compute(&acc_0, 0, 0, comparray, vs, fin_res);
            compute(&acc_1, 0, 4, comparray, vs, fin_res);
        }
        save_res(ii, jj, 0, fin_res);
        save_res(ii, jj+4, 4, fin_res);
    }
1904
1905
    template<typename TA>
    // Computes one 8x4 output tile (rows ii..ii+7, cols jj..jj+3): acc_0
    // covers rows ii..ii+3, acc_1 rows ii+4..ii+7.  Mirror image of
    // KERNEL_4x8 with the wide dimension on A instead of B.
    void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
        vec_t vec_A[16], vec_B[8] = {0};
        acc_t acc_0, acc_1;
        std::array<int, 8> comparray {};   // per-row sums of A quants (bias correction)
        vector float fin_res[8] = {0};     // accumulated float results across k
        vector float vs[8] = {0};          // per-element scale products d_A * d_B
        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
        for (int l = 0; l < k; l++) {
            __builtin_mma_xxsetaccz(&acc_0);
            __builtin_mma_xxsetaccz(&acc_1);
            if (std::is_same_v<TA, block_q4_0>) {
               packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
            } else {
               packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
            }
            // B packed with flip=true (unsigned operand for xvi8ger4pp).
            packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
            for(int x = 0; x < 8; x++) {
                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
            }
            for (int I = 0; I<8; I++) {
                for (int J = 0; J<4; J++) {
                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                }
            }
            // q8_0 A: fill comparray with per-row quant sums (q4_0 path fills
            // it inside packNormalInt4).
            if (!isAblock_q4) {
                auto aoffset = A+(ii*lda)+l;
                for (int i = 0; i < 8; i++) {
                    comparray[i] = 0;
                    int ca = 0;
                    auto *at = aoffset->qs;
                    for (int j = 0; j < 32; j++)
                        ca += (int)*at++;
                    comparray[i] = ca;
                    aoffset += lda;
                }
            }
            compute(&acc_0, 0, 0, comparray, vs, fin_res);
            compute(&acc_1, 4, 4, comparray, vs, fin_res);
        }
        save_res(ii, jj, 0, fin_res);
        save_res(ii+4, jj, 4, fin_res);
    }
1949
1950
    template<typename TA>
    // Computes one 8x8 output tile with four 4x4 accumulators:
    //   acc_0 = rows ii..+3   x cols jj..+3     acc_1 = rows ii+4..+7 x cols jj..+3
    //   acc_2 = rows ii..+3   x cols jj+4..+7   acc_3 = rows ii+4..+7 x cols jj+4..+7
    void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
        vec_t vec_A[16], vec_B[16] = {0};
        acc_t acc_0, acc_1, acc_2, acc_3;
        acc_t acc_4, acc_5, acc_6, acc_7;   // NOTE(review): acc_4..acc_7 are unused here
        std::array<int, 8> comparray {};    // per-row sums of A quants (bias correction)
        vector float fin_res[16] = {0};     // accumulated float results across k
        vector float vs[16] = {0};          // per-element scale products d_A * d_B
        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
        for (int l = 0; l < k; l++) {
            __builtin_mma_xxsetaccz(&acc_0);
            __builtin_mma_xxsetaccz(&acc_1);
            __builtin_mma_xxsetaccz(&acc_2);
            __builtin_mma_xxsetaccz(&acc_3);
            if (std::is_same_v<TA, block_q4_0>) {
               packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
            } else {
               packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
            }
            // B packed with flip=true (unsigned operand for xvi8ger4pp).
            packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
            for(int x = 0; x < 8; x++) {
                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
                __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
                __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
            }
            // vs[I] = left column half, vs[I+8] = right column half.
            for (int I = 0; I<8; I++) {
                for (int J = 0; J<4; J++) {
                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                    *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
                }
            }
            // q8_0 A: fill comparray with per-row quant sums.
            if (!isAblock_q4) {
                auto aoffset = A+(ii*lda)+l;
                for (int i = 0; i < 8; i++) {
                    comparray[i] = 0;
                    int ca = 0;
                    auto *at = aoffset->qs;
                    for (int j = 0; j < 32; j++)
                        ca += (int)*at++;
                    comparray[i] = ca;
                    aoffset += lda;
                }
            }
            compute(&acc_0, 0, 0, comparray, vs, fin_res);
            compute(&acc_1, 4, 4, comparray, vs, fin_res);
            compute(&acc_2, 0, 8, comparray, vs, fin_res);
            compute(&acc_3, 4, 12, comparray, vs, fin_res);
        }
        save_res(ii, jj, 0, fin_res);
        save_res(ii+4, jj, 4, fin_res);
        save_res(ii, jj+4, 8, fin_res);
        save_res(ii+4, jj+4, 12, fin_res);
    }
2004
2005
    template<typename TA>
    // Edge-tile path for arbitrary small RM x RN tiles (RM, RN <= 4).  Work
    // is split across threads by tile index; each tile runs the same
    // pack/MMA/scale pipeline as the fixed-size kernels but with a single
    // accumulator and the dequantization done inline.
    void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;   // tiles per thread (ceil)
        int64_t start = duty * ith;
        int64_t end = start + duty;
        vec_t vec_A[8] = {0}, vec_B[8] = {0};
        vector signed int vec_C[4];
        acc_t acc_0;
        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;

        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            std::array<int, 4> comparray{};   // per-row sums of A quants
            vector float res[4] = {0};
            vector float fin_res[4] = {0};    // accumulated float results across k
            vector float vs[4] = {0};         // per-element scale products d_A * d_B
            vector float CA[4] = {0};         // bias correction terms
            __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
            __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
            for (int l = 0; l < k; l++) {
                __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
                __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
                __builtin_mma_xxsetaccz(&acc_0);
                if (isAblock_q4) {
                   packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
                } else {
                   packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
                }
                // B packed with flip=true (unsigned operand for xvi8ger4pp).
                packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
                for(int x = 0; x < 8; x+=4) {
                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
                }
                for (int I = 0; I<RM; I++) {
                    for (int J = 0; J<RN; J++) {
                        *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
                    }
                }
                __builtin_mma_disassemble_acc(vec_C, &acc_0);
                // q8_0 A: fill comparray with per-row quant sums (q4_0 path
                // fills it inside packNormalInt4).
                if (!isAblock_q4) {
                    auto aoffset = A+(ii*lda)+l;
                    for (int i = 0; i < RM; i++) {
                        comparray[i] = 0;
                        int ca = 0;
                        auto *at = aoffset->qs;
                        for (int j = 0; j < 32; j++)
                            ca += (int)*at++;
                        comparray[i] = ca;
                        aoffset += lda;
                    }
                }
                // comparray[i] * -128 compensates for the +128 bias applied
                // to B when packed with flip=true (presumably — confirm
                // against vector_permute_store); then scale and accumulate.
                for (int i = 0; i < RM; i++) {
                    CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
                    res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
                    fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
                }
            }
            save_res(ii, jj, 0, fin_res, RM, RN);
        }
    }
2073
2074
    template<typename TA>
2075
    template <int RM, int RN>
2076
    NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2077
        int64_t ytiles = (m - m0) / RM;
2078
        int64_t xtiles = (n - n0) / RN;
2079
        int64_t tiles = xtiles * ytiles;
2080
        int64_t duty = (tiles + nth - 1) / nth;
2081
        int64_t start = duty * ith;
2082
        int64_t end = start + duty;
2083
        if (end > tiles)
2084
            end = tiles;
2085
        for (int64_t job = start; job < end; ++job) {
2086
            int64_t ii = m0 + job / xtiles * RM;
2087
            int64_t jj = n0 + job % xtiles * RN;
2088
            this->kernel<RM, RN>(ii, jj);
2089
        }
2090
    }
2091
2092
// Explicit instantiations for the two quantized A-operand types this file
// dispatches on (q4_0 and q8_0 blocks).
template class tinyBLAS_Q0_PPC<block_q4_0>;
template class tinyBLAS_Q0_PPC<block_q8_0>;
2094
2095
class tinyBLAS_PPC {
2096
  public:
2097
    // Stores non-owning views of the operands and the threading setup.
    //   k          - shared (inner) dimension of the multiply
    //   A, lda     - left operand and its leading dimension (row stride)
    //   B, ldb     - right operand and its leading dimension
    //   C, ldc     - output and its leading dimension
    //   ith, nth   - this thread's index and the total thread count
    // The caller keeps A/B/C alive for the lifetime of this object.
    tinyBLAS_PPC(int64_t k,
                const float * A, int64_t lda,
                const float * B, int64_t ldb,
                float * C, int64_t ldc,
                int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }
2104
2105
    // Entry point: use the cache-blocked path when m, n and k are all exact
    // multiples of the 256-wide tile; otherwise fall back to the recursive
    // kernel-selection path.
    void matmul(int64_t m, int64_t n) {
        const int64_t tile_m = 256;
        const int64_t tile_n = 256;
        const int64_t tile_k = 256;
        const bool divisible = (m % tile_m == 0) && (n % tile_n == 0) && (k % tile_k == 0);
        if (!divisible) {
            mnpack(0, m, 0, n);
            return;
        }
        matmul_tiled(m, n, tile_m, tile_n, tile_k);
    }
2113
2114
  private:
2115
2116
    // Stores one 4x4 MMA accumulator tile to C, overwriting:
    //   C[ii + I + (jj + J) * ldc] = acc[I][J]
    inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
        vec_t vec_C[4];
        // Split the accumulator into four 4-float vectors.
        __builtin_mma_disassemble_acc(vec_C, ACC);
        for (int I = 0; I < 4; I++) {
            for (int J = 0; J < 4; J++) {
                *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
            }
        }
    }
2125
2126
    // Accumulates one 4x4 MMA accumulator tile into C:
    //   C[ii + i + (jj + j) * ldc] += acc[i][j]
    // Used by the k-blocked path for every k-slice after the first.
    inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
        vec_t rows[4];
        __builtin_mma_disassemble_acc(rows, ACC);
        for (int i = 0; i < 4; ++i) {
            const float * src = (const float *)&rows[i];
            for (int j = 0; j < 4; ++j) {
                float * dst = (float *)(C + ii + ((jj + j) * ldc) + i);
                *dst += src[j];
            }
        }
    }
2136
2137
    // Transposes a 4x4 float block: src[0..3] are four 4-float vectors (rows);
    // the four vectors written at vecOffset are the corresponding columns.
    // mergeh/mergel interleave element pairs of two rows; xxpermdi with
    // selector 0/3 then picks the low/high doubleword halves.
    inline void vector_permute_store_4(vector float * src, float * vecOffset) {
        vector float t1, t2, t3, t4, t5, t6, t7, t8;
        t1 = vec_mergeh(src[0], src[1]);
        t2 = vec_mergeh(src[2], src[3]);
        t3 = vec_mergel(src[0], src[1]);
        t4 = vec_mergel(src[2], src[3]);

        t5 = vec_xxpermdi(t1, t2, 0);
        t6 = vec_xxpermdi(t1, t2, 3);
        t7 = vec_xxpermdi(t3, t4, 0);
        t8 = vec_xxpermdi(t3, t4, 3);

        vec_xst(t5, 0, vecOffset);
        vec_xst(t6, 0, vecOffset + 4);
        vec_xst(t7, 0, vecOffset + 8);
        vec_xst(t8, 0, vecOffset + 12);
    }
2154
2155
    // Transposes an 8x4 float block: src[0..7] are eight 4-float vectors
    // (rows).  Output layout is the 4x8 transpose — each group of 8 floats
    // holds one source column across all 8 rows: columns 0 and 1 go to
    // vecOffset[0..15] (mergeh halves), columns 2 and 3 to vecOffset[16..31]
    // (mergel halves).
    inline void vector_permute_store_8(vector float * src, float * vecOffset) {
        vector float t1, t2, t3, t4, t5, t6, t7, t8;
        t1 = vec_mergeh(src[0], src[1]);
        t2 = vec_mergeh(src[2], src[3]);
        t3 = vec_mergeh(src[4], src[5]);
        t4 = vec_mergeh(src[6], src[7]);

        t5 = vec_xxpermdi(t1, t2, 0);
        t6 = vec_xxpermdi(t3, t4, 0);
        t7 = vec_xxpermdi(t1, t2, 3);
        t8 = vec_xxpermdi(t3, t4, 3);

        vec_xst(t5, 0, vecOffset);
        vec_xst(t6, 0, vecOffset + 4);
        vec_xst(t7, 0, vecOffset + 8);
        vec_xst(t8, 0, vecOffset + 12);

        t1 = vec_mergel(src[0], src[1]);
        t2 = vec_mergel(src[2], src[3]);
        t3 = vec_mergel(src[4], src[5]);
        t4 = vec_mergel(src[6], src[7]);

        t5 = vec_xxpermdi(t1, t2, 0);
        t6 = vec_xxpermdi(t3, t4, 0);
        t7 = vec_xxpermdi(t1, t2, 3);
        t8 = vec_xxpermdi(t3, t4, 3);

        vec_xst(t5, 0, vecOffset + 16);
        vec_xst(t6, 0, vecOffset + 20);
        vec_xst(t7, 0, vecOffset + 24);
        vec_xst(t8, 0, vecOffset + 28);
    }
2187
2188
    // Packs a rows x cols float tile (row stride lda, starting at 'a') into
    // 'vec' in transposed 4-/8-row groups, using __vector_pair loads (8
    // floats per row per step) and the vector_permute_store_* helpers.
    void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
        int64_t i, j;
        float * aoffsets[8];                      // per-row cursors for the current row group
        float * aoffset = NULL, * boffset = NULL; // source group start / output cursor
        __vector_pair arr[8];
        vector float c[8][2] = {0};
        vector float c1[8] = {0};
        vector float c2[8] = {0};
        aoffset = const_cast<float *>(a);
        boffset = vec;
        // Main path: 8 rows at a time, 8 columns per inner step.
        j = (rows >> 3);
        if (j > 0) {
            do {
                aoffsets[0] = aoffset;
                for (int it = 1; it < 8; it++)
                    aoffsets[it] = aoffsets[it-1] + lda;
                aoffset += 8 * lda;
                i = (cols >> 3);
                if (i > 0) {
                    do {
                        for (int it = 0; it < 8; it++) {
                            arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
                            __builtin_vsx_disassemble_pair(c[it], &arr[it]);
                            c1[it] = c[it][0];
                            c2[it] = c[it][1];
                        }

                        vector_permute_store_8(c1, boffset);
                        vector_permute_store_8(c2, boffset + 32);
                        boffset += 64;
                        i--;
                        // Advance 8 columns, but only when another step remains.
                        if (i > 0) {
                           for (int it = 0; it < 8; it++) {
                               aoffsets[it] = aoffsets[it] + 8;
                           }
                        }
                    } while(i > 0);
                }
                // NOTE(review): this 4-column tail reuses aoffsets, which were
                // deliberately NOT advanced after the final 8-column step
                // above — so for cols >= 12 it would re-read columns already
                // packed.  Current callers appear to use cols in {4, 8, 256}
                // where the two paths never combine; confirm before reusing
                // with other shapes.
                if (cols & 4) {
                    for (int it = 0; it < 8 ; it++)
                        c1[it] = vec_xl(0, aoffsets[it]);
                    vector_permute_store_8(c1, boffset);
                }
            j--;
            } while(j > 0);
        }

        // Remainder: 4 more rows.
        if (rows & 4) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 4; it++)
                aoffsets[it] = aoffsets[it-1] + lda;
            aoffset += 4 * lda;
            i = (cols >> 3);
            if (i > 0) {
                do {
                    for (int it = 0; it < 4; it++) {
                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
                        c1[it] = c[it][0];
                        c2[it] = c[it][1];
                    }
                    vector_permute_store_4(c1, boffset);
                    vector_permute_store_4(c2, boffset + 16);
                    // NOTE(review): advances by 8*lda (8 rows) where the
                    // 8-row path above advances by 8 floats (8 columns).
                    // With current callers cols <= 8 here, so this loop body
                    // runs at most once and the stride is never exercised —
                    // confirm before calling with wider tiles.
                    for (int it = 0; it < 4; it++)
                        aoffsets[it] += 8 * lda;
                    boffset += 32;
                    i--;
                } while(i > 0);
            }

            if (cols & 4) {
               for (int it = 0; it < 4; it++)
                   c1[it] = vec_xl(0, aoffsets[it]);
                vector_permute_store_4(c1, boffset);
            }
        }
        // Remainder: 1-3 rows; only the 4-column case is handled (missing
        // rows keep their zero-initialized c1 slots).
        if (rows & 3) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 3; it++)
                aoffsets[it] = aoffsets[it-1] + lda;
            if (cols & 4) {
                for (int it = 0; it < 3; it++)
                    c1[it] = vec_xl(0, aoffsets[it]);
                vector_permute_store_4(c1, boffset);
            }
        }
    }
2275
2276
    // Computes a 4x4 output tile at (ii, jj): packs transposed 4x4 panels of
    // A and B for each 4-wide k-slice and accumulates rank-1 float outer
    // products into a single MMA accumulator.
    // Fix: dropped the unused local vec_C[4] (dead declaration, compiler
    // warning); the accumulator is written out by save_acc.
    void KERNEL_4x4(int64_t ii, int64_t jj) {
        vec_t vec_A[4], vec_B[4];
        acc_t acc_0;
        __builtin_mma_xxsetaccz(&acc_0);
        for (int l = 0; l < k; l += 4) {
            packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
            packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
        }
        save_acc(&acc_0, ii, jj);
    }
2290
2291
    // Computes a 4x8 output tile at (ii, jj).  B is packed 8 rows at a time,
    // so consecutive vec_B entries alternate between the left (jj..jj+3,
    // acc_0) and right (jj+4..jj+7, acc_1) column halves.
    // Fix: dropped the unused local vec_C[4] (dead declaration, compiler
    // warning); results are written out by save_acc.
    void KERNEL_4x8(int64_t ii, int64_t jj) {
        vec_t vec_A[4], vec_B[8];
        acc_t acc_0, acc_1;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int64_t l = 0; l < k; l += 4) {
            packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
            packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
            __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
            __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
        }
        save_acc(&acc_0, ii, jj);
        save_acc(&acc_1, ii, jj + 4);
    }
2311
2312
    // Computes an 8x4 output tile at (ii, jj).  A is packed 8 rows at a time,
    // so consecutive vec_A entries alternate between the top (ii..ii+3,
    // acc_0) and bottom (ii+4..ii+7, acc_1) row halves.
    // Fix: dropped the unused local vec_C[4] (dead declaration, compiler
    // warning); results are written out by save_acc.
    void KERNEL_8x4(int64_t ii, int64_t jj) {
        vec_t vec_A[8], vec_B[4];
        acc_t acc_0, acc_1;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        for (int64_t l = 0; l < k; l += 4) {
            packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
            packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
            __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
            __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
        }
        save_acc(&acc_0, ii, jj);
        save_acc(&acc_1, ii + 4, jj);
    }
2332
2333
    // Computes an 8x8 output tile at (ii, jj) with four 4x4 accumulators:
    //   acc_0 = top-left, acc_1 = top-right, acc_2 = bottom-left,
    //   acc_3 = bottom-right.  Even/odd packed entries alternate between the
    //   two halves of each operand; the k loop advances 8 per iteration.
    // Fix: dropped the unused local vec_C[4] (dead declaration, compiler
    // warning); results are written out by save_acc.
    void KERNEL_8x8(int64_t ii, int64_t jj) {
        vec_t vec_A[16], vec_B[16];
        acc_t acc_0, acc_1, acc_2, acc_3;
        __builtin_mma_xxsetaccz(&acc_0);
        __builtin_mma_xxsetaccz(&acc_1);
        __builtin_mma_xxsetaccz(&acc_2);
        __builtin_mma_xxsetaccz(&acc_3);
        for (int l = 0; l < k; l+=8) {
            packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
            packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
            for(int x = 0; x < 16; x+=2) {
                __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
                __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
                __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
                __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
            }
        }
        save_acc(&acc_0, ii, jj);
        save_acc(&acc_1, ii, jj + 4);
        save_acc(&acc_2, ii + 4, jj);
        save_acc(&acc_3, ii + 4, jj + 4);
    }
2355
2356
    // Accumulates a 16x8 output tile from packed panels: vec_A0 holds rows
    // 0-7, vec_A1 rows 8-15, vec_B the 8 columns; acc[0..7] are eight 4x4
    // accumulators laid out as (row quadrant 0..3) x (column half 0..1).
    // Even/odd packed entries alternate between the two 4-wide halves of
    // each 8-wide panel (see packTranspose).
    inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
        for (int x = 0; x < 16; x += 2) {
            __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
            __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
            __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
            __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
        }
    }
2368
2369
    // Multiplies one packed mc x kc A panel by one packed kc x nc B panel,
    // producing the (ii, jj) output tile in 16x8 sub-tiles.  On the first
    // k-slice (kk == 0) results overwrite C; on later slices they are added.
    // Fix: removed the dead local arrays A0_block[16]/A1_block[16], which
    // were immediately shadowed by the same-named pointers in the l-loop
    // (shadowing warning, never read).
    void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
        for (int64_t i = 0; i < mc; i += 16) {
            // Start of this 16-row group in the packed A panel.
            int A_base_addr = (mc / 8) * (i / 8) * 16;
            for (int64_t j = 0; j < nc; j += 8) {
                 // Start of this 8-column group in the packed B panel.
                 int B_base_addr = (nc / 8) * (j / 8) * 16;
                 acc_t acc[8];
                 for (int x = 0; x < 8; x++)
                     __builtin_mma_xxsetaccz(&acc[x]);
                 for (int64_t l = 0; l < kc; l += 8) {
                     int A0_block_idx = A_base_addr + (l / 8) * 16;
                     int A1_block_idx = A0_block_idx + (mc / 8) * 16;  // rows 8-15 of the group
                     int B_block_idx = B_base_addr + (l / 8) * 16;
                     vec_t* A0_block = &vec_A[A0_block_idx];
                     vec_t* A1_block = &vec_A[A1_block_idx];
                     vec_t* B_block = &vec_B[B_block_idx];
                     MMA_16x8(A0_block, A1_block, B_block, acc);
                 }
                 if (kk == 0) {
                     // First k-slice: overwrite C.
                     save_acc(&acc[0], ii + i, jj + j);
                     save_acc(&acc[1], ii + i, jj + j + 4);
                     save_acc(&acc[2], ii + i + 4, jj + j);
                     save_acc(&acc[3], ii + i + 4, jj + j + 4);
                     save_acc(&acc[4], ii + i + 8, jj + j);
                     save_acc(&acc[5], ii + i + 8, jj + j + 4);
                     save_acc(&acc[6], ii + i + 12, jj + j);
                     save_acc(&acc[7], ii + i + 12, jj + j + 4);
                 } else {
                     // Later k-slices: accumulate into C.
                     add_save_acc(&acc[0], ii + i, jj + j);
                     add_save_acc(&acc[1], ii + i, jj + j + 4);
                     add_save_acc(&acc[2], ii + i + 4, jj + j);
                     add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
                     add_save_acc(&acc[4], ii + i + 8, jj + j);
                     add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
                     add_save_acc(&acc[6], ii + i + 12, jj + j);
                     add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
                 }
            }
        }
    }
2409
2410
    // Cache-blocked path used when m, n, k are exact multiples of mc, nc, kc.
    // Tiles of mc x nc output are distributed across threads; each tile packs
    // the A and B panels per k-slice and runs KERNEL (overwrite on the first
    // slice, accumulate afterwards).
    void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) {
        int64_t ytiles = m / mc;
        int64_t xtiles = n / nc;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;  // tiles per thread (ceil)
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles) {
            end = tiles;
        }
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = (job / xtiles) * mc;
            int64_t jj = (job % xtiles) * nc;
            for (int64_t kk = 0; kk < k; kk += kc) {
                 // NOTE(review): variable-length arrays are a GNU extension,
                 // and at the default mc = nc = kc = 256 each buffer is
                 // 256*256/4 vec_t = 256 KiB of stack per thread — confirm
                 // the thread stack budget covers this.
                 vec_t A_pack[kc * mc / 4];
                 vec_t B_pack[kc * nc / 4];
                 packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
                 packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
                 KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
            }
        }
    }
2432
2433
    // Recursively cover the output region [m0,m) x [n0,n) with fixed-size
    // MMA kernels: pick the largest kernel (8x8 down to 4x4) that fits the
    // remaining rows/cols, run it over the bulk, then recurse on the two
    // leftover strips. Regions narrower than 4 in either dimension go
    // through gemm_small.
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int m_rem = MIN(m - m0, 8);   // remaining rows, capped at largest kernel
        int n_rem = MIN(n - n0, 8);   // remaining cols, capped at largest kernel
        int mc = 0, nc = 0;
        if (m_rem >= 8 && n_rem >= 8) {
            mc = 8;
            nc = 8;
            gemm<8, 8>(m0, m, n0, n);
        } else if (m_rem >= 4 && n_rem >= 8) {
            mc = 4;
            nc = 8;
            gemm<4, 8>(m0, m, n0, n);
        } else if (m_rem >= 8 && n_rem >= 4) {
            mc = 8;
            nc = 4;
            gemm<8, 4>(m0, m, n0, n);
        } else if (m_rem >= 4 && n_rem >= 4) {
            mc = 4;
            nc = 4;
            gemm<4, 4>(m0, m, n0, n);
        } else {
            // At least one dimension has fewer than 4 left; mc/nc may be
            // 1..4 here and are handled by the runtime-sized kernel.
            mc = (m_rem >= 4) ? 4 : m_rem;
            nc = (n_rem >= 4) ? 4 : n_rem;
            if (mc == 0 || nc == 0)
                return;   // nothing left to compute
            gemm_small(m0, m, n0, n, mc, nc);
        }
        // mp/np mark how far the chosen kernel tiled; recurse on the
        // bottom strip [mp,m) x [n0,np) and the right strip [m0,m) x [np,n).
        int64_t mp = m0 + ((m - m0) / mc) * mc;
        int64_t np = n0 + ((n - n0) / nc) * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }
2465
2466
    // Remainder kernel for tiles whose dimensions RM x RN (each 1..4) are
    // only known at runtime. Uses a single MMA accumulator per tile and
    // steps through K four columns at a time.
    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;   // tiles per thread, rounded up
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;  // tile origin row
            int64_t jj = n0 + job % xtiles * RN;  // tile origin col
            vec_t vec_C[4];
            acc_t acc_0;
            __builtin_mma_xxsetaccz(&acc_0);      // zero the MMA accumulator
            vec_t vec_A[4] = {0}, vec_B[4] = {0};
            for (int l = 0; l < k; l += 4) {
                /* 'GEMV Forwarding' concept is used in first two conditional loops.
                 * when one of the matrix has a single row/column, the elements are
                 * broadcasted, instead of using packing routine to prepack the
                 * matrix elements.
                 */
                if (RM == 1) {
                    float * a = const_cast<float *>(A + (ii) * lda + l);
                    packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
                    vec_A[0] = (vec_t)vec_xl(0,a);
                    // Broadcast the 2nd..4th scalars of the quad just loaded
                    // into vec_A[0] (the casts index floats within vec_A[0]).
                    vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
                    vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
                    vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
                } else if (RN == 1) {
                    packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
                    float * b = const_cast<float *>(B + (jj) * ldb + l);
                    vec_B[0] = (vec_t)vec_xl(0,b);
                    // Same broadcast trick for the single-column B case.
                    vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
                    vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
                    vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
                } else {
                    packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
                    packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
                }
                // Four rank-1 updates: acc += vec_A[t] (outer) vec_B[t].
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
                __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
            }
            __builtin_mma_disassemble_acc(vec_C, &acc_0);
            // Scatter the RM x RN result into column-major C.
            for (int I = 0; I < RM; I++) {
                for (int J = 0; J < RN; J++) {
                    *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
                }
            }
       }
    }
2519
2520
    template<int RM, int RN>
2521
    inline void kernel(int64_t ii, int64_t jj) {
2522
        if constexpr(RM == 4 && RN == 4) {
2523
            KERNEL_4x4(ii, jj);
2524
        } else if constexpr(RM == 4 && RN == 8) {
2525
            KERNEL_4x8(ii, jj);
2526
        } else if constexpr(RM == 8 && RN == 4) {
2527
            KERNEL_8x4(ii, jj);
2528
        } else if constexpr(RM == 8 && RN == 8) {
2529
            KERNEL_8x8(ii, jj);
2530
        } else {
2531
            static_assert(false, "RN/RM values not supported");
2532
        }
2533
    }
2534
2535
    template <int RM, int RN>
2536
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2537
        int64_t ytiles = (m - m0) / RM;
2538
        int64_t xtiles = (n - n0) / RN;
2539
        int64_t tiles = xtiles * ytiles;
2540
        int64_t duty = (tiles + nth - 1) / nth;
2541
        int64_t start = duty * ith;
2542
        int64_t end = start + duty;
2543
        if (end > tiles)
2544
            end = tiles;
2545
        for (int64_t job = start; job < end; ++job) {
2546
            int64_t ii = m0 + job / xtiles * RM;
2547
            int64_t jj = n0 + job % xtiles * RN;
2548
            kernel<RM, RN>(ii, jj);
2549
        }
2550
    }
2551
2552
    const float * const A;    // left operand (transposed per C = Aᵀ * B)
    const float * const B;    // right operand
    float * C;                // output matrix
    const int64_t k;          // shared inner dimension
    const int64_t lda;        // row stride of A
    const int64_t ldb;        // row stride of B
    const int64_t ldc;        // row stride of C
    const int ith;            // this thread's id (0 <= ith < nth)
    const int nth;            // total number of worker threads
};
2562
#endif
2563
} // namespace
2564
2565
/**
 * Performs optimized matrix multiplication on CPU.
 *
 * This subroutine may compute C = Aᵀ * B with column major ordering.
 * Despite its name, this isn't a generalized implementation. Work is
 * only performed when a handwritten kernel is written and available.
 * Otherwise the caller should fall back to a general matmul routine.
 *
 * For example, for single-precision GEMM you can say
 *
 *     llamafile_sgemm(params, m, n, k, A, lda, B, ldb, C, ldc,
 *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
 *
 * @param params is the ggml compute parameters, which supply the thread
 *        id `ith` (must be less than `nth`) and the thread count `nth`
 *        (must be greater than zero)
 * @param m is rows in `A` and `C`
 * @param n is cols in `B` and `C`
 * @param k is cols in `A` and rows in `B`
 * @param A is first input matrix (always transposed)
 * @param lda is row stride of `A`
 * @param B is second input matrix (never transposed)
 * @param ldb is row stride of `B`
 * @param C is input/output array of output matrices
 * @param ldc is row stride of `C`
 * @param Atype is GGML data type of `A`
 * @param Btype is GGML data type of `B`
 * @param Ctype is GGML data type of `C`
 * @return true if this function was able to service the matmul request
 */
2595
bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
                     const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
                     int64_t ldc, int Atype, int Btype, int Ctype) {

    // Validate dimensions, strides and threading params (debug builds only).
    assert(m >= 0);
    assert(n >= 0);
    assert(k >= 0);
    assert(lda >= k);
    assert(ldb >= k);
    assert(ldc >= m);
    assert(params->nth > 0);
    assert(params->ith < params->nth);

    // only enable sgemm for prompt processing
#if !defined(__MMA__)
    if (n < 2)
        return false;
#endif

    // Every kernel below produces float output.
    if (Ctype != GGML_TYPE_F32)
        return false;

    // Dispatch on the type of A to an ISA-specific kernel. Returning false
    // tells the caller to fall back to the generic ggml matmul path.
    switch (Atype) {

    case GGML_TYPE_F32: {
        if (Btype != GGML_TYPE_F32)
            return false;
#if defined(__AVX512F__)
        tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
#elif defined(__AVX__) || defined(__AVX2__)
        tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
#elif defined(__ARM_NEON)
        if (n < 4)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
#elif defined(__VXE__) || defined(__VXE2__)
        if (n < 4)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc};
        return tb.matmul(m, n);
#elif defined(__MMA__)
        // The PPC MMA kernel packs K in multiples of 8.
        if (k % 8)
            return false;
        tinyBLAS_PPC tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_BF16: {
#if defined(__AVX512BF16__)
        if (Btype == GGML_TYPE_BF16) {
            tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__AVX512F__)
        if (Btype == GGML_TYPE_BF16) {
            tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__AVX2__)
        if (Btype == GGML_TYPE_BF16) {
            tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                (const ggml_bf16_t *)A, lda,
                (const ggml_bf16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__MMA__)
        // The PPC BF16 kernel also requires K to be a multiple of 8.
        if ((k % 8))
                return false;
        if(Btype == GGML_TYPE_BF16) {
           tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
            (const ggml_bf16_t *)A, lda,
            (const ggml_bf16_t *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
        }
#endif
        return false;
    }

    case GGML_TYPE_F16: {
#if defined(__AVX512F__)
        if (Btype == GGML_TYPE_F16) {
            tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
                (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
        if (Btype == GGML_TYPE_F16) {
            tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
                (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
        if (n < 8)
            return false;
        if (Btype == GGML_TYPE_F16) {
            tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
        // Plain NEON path: F16 A against F32 B (mixed-precision kernel).
        if (Btype == GGML_TYPE_F32) {
            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const float *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#elif defined(__VXE__) || defined(__VXE2__)
        if (n < 4)
            return false;
        if (Btype == GGML_TYPE_F16) {
            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                k, (const ggml_fp16_t *)A, lda,
                (const ggml_fp16_t *)B, ldb,
                (float *)C, ldc};
            return tb.matmul(m, n);
        }
#endif
        return false;
    }

    case GGML_TYPE_Q8_0: {
        if (Btype != GGML_TYPE_Q8_0)
           return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q8_0> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__MMA__)
    //TO-DO: Remove this condition once gemv forwarding is enabled.
        if (n < 8 && n != 4)
           return false;
        if (m < 8 && m != 4)
           return false;
        tinyBLAS_Q0_PPC<block_q8_0> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_Q4_0: {
        // Quantized kernels take Q8_0 activations on the B side.
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q4_0> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#elif defined(__MMA__)
    //TO-DO: Remove this condition once gemv forwarding is enabled.
        if (n < 8 && n != 4)
           return false;
        if (m < 8 && m != 4)
           return false;
        tinyBLAS_Q0_PPC<block_q4_0> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_Q5_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
            k, (const block_q5_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_IQ4_NL: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
            k, (const block_iq4_nl *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            params->ith, params->nth};
        tb.matmul(m, n);
        return true;
#else
        return false;
#endif
    }

    default:
        return false;
    }

    // Unreachable: every switch case above returns. The casts silence
    // unused-parameter warnings in configurations where whole branches are
    // compiled out.
    (void)params;
    (void)m;
    (void)n;
    (void)k;
    (void)A;
    (void)lda;
    (void)B;
    (void)ldb;
    (void)C;
    (void)ldc;
    (void)Atype;
    (void)Btype;
    (void)Ctype;
}