Coverage Report

Created: 2025-02-21 07:11

/rust/registry/src/index.crates.io-6f17d22bba15001f/matrixmultiply-0.3.9/src/cgemm_kernel.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2016 - 2021 Ulrik Sverdrup "bluss"
2
//
3
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6
// option. This file may not be copied, modified, or distributed
7
// except according to those terms.
8
9
use crate::kernel::GemmKernel;
10
use crate::kernel::GemmSelect;
11
use crate::kernel::{U2, U4, c32, Element, c32_mul as mul};
12
use crate::archparam;
13
use crate::cgemm_common::pack_complex;
14
15
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
16
struct KernelAvx2;
17
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
18
struct KernelFma;
19
20
#[cfg(target_arch = "aarch64")]
21
#[cfg(has_aarch64_simd)]
22
struct KernelNeon;
23
24
struct KernelFallback;
25
26
type T = c32;
27
type TReal = f32;
28
29
/// Detect which implementation to use and select it using the selector's
30
/// .select(Kernel) method.
31
///
32
/// This function is called one or more times during a whole program's
33
/// execution, it may be called for each gemm kernel invocation or fewer times.
34
#[inline]
35
0
pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
36
0
    // dispatch to specific compiled versions
37
0
    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
38
0
    {
39
0
        if is_x86_feature_detected_!("fma") {
40
0
            if is_x86_feature_detected_!("avx2") {
41
0
                return selector.select(KernelAvx2);
42
0
            }
43
0
            return selector.select(KernelFma);
44
0
        }
45
0
    }
46
0
    #[cfg(target_arch = "aarch64")]
47
0
    #[cfg(has_aarch64_simd)]
48
0
    {
49
0
        if is_aarch64_feature_detected_!("neon") {
50
0
            return selector.select(KernelNeon);
51
0
        }
52
0
    }
53
0
    return selector.select(KernelFallback);
54
0
}
55
56
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
57
impl GemmKernel for KernelAvx2 {
58
    type Elem = T;
59
60
    type MRTy = U4;
61
    type NRTy = U4;
62
63
    #[inline(always)]
64
0
    fn align_to() -> usize { 32 }
65
66
    #[inline(always)]
67
0
    fn always_masked() -> bool { KernelFallback::always_masked() }
68
69
    #[inline(always)]
70
0
    fn nc() -> usize { archparam::C_NC }
71
    #[inline(always)]
72
0
    fn kc() -> usize { archparam::C_KC }
73
    #[inline(always)]
74
0
    fn mc() -> usize { archparam::C_MC }
75
76
    pack_methods!{}
77
78
    #[inline(always)]
79
0
    unsafe fn kernel(
80
0
        k: usize,
81
0
        alpha: T,
82
0
        a: *const T,
83
0
        b: *const T,
84
0
        beta: T,
85
0
        c: *mut T, rsc: isize, csc: isize) {
86
0
        kernel_target_avx2(k, alpha, a, b, beta, c, rsc, csc)
87
0
    }
88
}
89
90
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
91
impl GemmKernel for KernelFma {
92
    type Elem = T;
93
94
    type MRTy = U4;
95
    type NRTy = U4;
96
97
    #[inline(always)]
98
0
    fn align_to() -> usize { 16 }
99
100
    #[inline(always)]
101
0
    fn always_masked() -> bool { KernelFallback::always_masked() }
102
103
    #[inline(always)]
104
0
    fn nc() -> usize { archparam::C_NC }
105
    #[inline(always)]
106
0
    fn kc() -> usize { archparam::C_KC }
107
    #[inline(always)]
108
0
    fn mc() -> usize { archparam::C_MC }
109
110
    pack_methods!{}
111
112
    #[inline(always)]
113
0
    unsafe fn kernel(
114
0
        k: usize,
115
0
        alpha: T,
116
0
        a: *const T,
117
0
        b: *const T,
118
0
        beta: T,
119
0
        c: *mut T, rsc: isize, csc: isize) {
120
0
        kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc)
121
0
    }
122
}
123
124
#[cfg(target_arch = "aarch64")]
125
#[cfg(has_aarch64_simd)]
126
impl GemmKernel for KernelNeon {
127
    type Elem = T;
128
129
    type MRTy = U4;
130
    type NRTy = U2;
131
132
    #[inline(always)]
133
    fn align_to() -> usize { 16 }
134
135
    #[inline(always)]
136
    fn always_masked() -> bool { KernelFallback::always_masked() }
137
138
    #[inline(always)]
139
    fn nc() -> usize { archparam::C_NC }
140
    #[inline(always)]
141
    fn kc() -> usize { archparam::C_KC }
142
    #[inline(always)]
143
    fn mc() -> usize { archparam::C_MC }
144
145
    pack_methods!{}
146
147
    #[inline(always)]
148
    unsafe fn kernel(
149
        k: usize,
150
        alpha: T,
151
        a: *const T,
152
        b: *const T,
153
        beta: T,
154
        c: *mut T, rsc: isize, csc: isize) {
155
        kernel_target_neon(k, alpha, a, b, beta, c, rsc, csc)
156
    }
157
}
158
159
impl GemmKernel for KernelFallback {
160
    type Elem = T;
161
162
    type MRTy = U4;
163
    type NRTy = U2;
164
165
    #[inline(always)]
166
0
    fn align_to() -> usize { 0 }
167
168
    #[inline(always)]
169
0
    fn always_masked() -> bool { true }
170
171
    #[inline(always)]
172
0
    fn nc() -> usize { archparam::C_NC }
173
    #[inline(always)]
174
0
    fn kc() -> usize { archparam::C_KC }
175
    #[inline(always)]
176
0
    fn mc() -> usize { archparam::C_MC }
177
178
    pack_methods!{}
179
180
    #[inline(always)]
181
0
    unsafe fn kernel(
182
0
        k: usize,
183
0
        alpha: T,
184
0
        a: *const T,
185
0
        b: *const T,
186
0
        beta: T,
187
0
        c: *mut T, rsc: isize, csc: isize) {
188
0
        kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
189
0
    }
190
}
191
192
// Kernel AVX2
193
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
194
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
195
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
196
macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; }
197
198
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
199
kernel_fallback_impl_complex! {
200
    // instantiate separately
201
    [inline target_feature(enable="avx2") target_feature(enable="fma")] [fma_yes]
202
    kernel_target_avx2, T, TReal, KernelAvx2::MR, KernelAvx2::NR, 4
203
}
204
205
206
// Kernel Fma
207
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
208
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
209
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
210
macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; }
211
212
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
213
kernel_fallback_impl_complex! {
214
    // instantiate separately
215
    [inline target_feature(enable="fma")] [fma_no]
216
    kernel_target_fma, T, TReal, KernelFma::MR, KernelFma::NR, 2
217
}
218
219
// Kernel neon
220
221
#[cfg(target_arch = "aarch64")]
222
#[cfg(has_aarch64_simd)]
223
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
224
#[cfg(target_arch = "aarch64")]
225
#[cfg(has_aarch64_simd)]
226
macro_rules! loop_n { ($j:ident, $e:expr) => { loop2!($j, $e) }; }
227
228
#[cfg(target_arch = "aarch64")]
229
#[cfg(has_aarch64_simd)]
230
kernel_fallback_impl_complex! {
231
    [inline target_feature(enable="neon")] [fma_yes]
232
    kernel_target_neon, T, TReal, KernelNeon::MR, KernelNeon::NR, 1
233
}
234
235
// Kernel fallback
236
237
macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; }
238
macro_rules! loop_n { ($j:ident, $e:expr) => { loop2!($j, $e) }; }
239
240
kernel_fallback_impl_complex! {
241
    [inline(always)] [fma_no]
242
    kernel_fallback_impl, T, TReal, KernelFallback::MR, KernelFallback::NR, 1
243
}
244
245
#[inline(always)]
246
0
unsafe fn at(ptr: *const TReal, i: usize) -> TReal {
247
0
    *ptr.add(i)
248
0
}
249
250
#[cfg(test)]
251
mod tests {
252
    use super::*;
253
    use crate::kernel::test::test_complex_packed_kernel;
254
255
    #[test]
256
    fn test_kernel_fallback_impl() {
257
        test_complex_packed_kernel::<KernelFallback, _, TReal>("kernel");
258
    }
259
260
    #[cfg(target_arch = "aarch64")]
261
    #[cfg(has_aarch64_simd)]
262
    mod test_kernel_aarch64 {
263
        use super::test_complex_packed_kernel;
264
        use super::super::*;
265
        #[cfg(feature = "std")]
266
        use std::println;
267
        macro_rules! test_arch_kernels {
268
            ($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => {
269
                $(
270
                #[test]
271
                fn $name() {
272
                    if is_aarch64_feature_detected_!($feature_name) {
273
                        test_complex_packed_kernel::<$kernel_ty, _, TReal>(stringify!($name));
274
                    } else {
275
                        #[cfg(feature = "std")]
276
                        println!("Skipping, host does not have feature: {:?}", $feature_name);
277
                    }
278
                }
279
                )*
280
            }
281
        }
282
283
        test_arch_kernels! {
284
            "neon", neon, KernelNeon
285
        }
286
    }
287
288
    #[cfg(any(target_arch="x86", target_arch="x86_64"))]
289
    mod test_arch_kernels {
290
        use super::test_complex_packed_kernel;
291
        use super::super::*;
292
        #[cfg(feature = "std")]
293
        use std::println;
294
        macro_rules! test_arch_kernels_x86 {
295
            ($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => {
296
                $(
297
                #[test]
298
                fn $name() {
299
                    if is_x86_feature_detected_!($feature_name) {
300
                        test_complex_packed_kernel::<$kernel_ty, _, TReal>(stringify!($name));
301
                    } else {
302
                        #[cfg(feature = "std")]
303
                        println!("Skipping, host does not have feature: {:?}", $feature_name);
304
                    }
305
                }
306
                )*
307
            }
308
        }
309
310
        test_arch_kernels_x86! {
311
            "fma", fma, KernelFma,
312
            "avx2", avx2, KernelAvx2
313
        }
314
    }
315
}