/rust/registry/src/index.crates.io-6f17d22bba15001f/matrixmultiply-0.3.9/src/cgemm_kernel.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright 2016 - 2021 Ulrik Sverdrup "bluss" |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or |
4 | | // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license |
5 | | // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your |
6 | | // option. This file may not be copied, modified, or distributed |
7 | | // except according to those terms. |
8 | | |
9 | | use crate::kernel::GemmKernel; |
10 | | use crate::kernel::GemmSelect; |
11 | | use crate::kernel::{U2, U4, c32, Element, c32_mul as mul}; |
12 | | use crate::archparam; |
13 | | use crate::cgemm_common::pack_complex; |
14 | | |
15 | | #[cfg(any(target_arch="x86", target_arch="x86_64"))] |
16 | | struct KernelAvx2; |
17 | | #[cfg(any(target_arch="x86", target_arch="x86_64"))] |
18 | | struct KernelFma; |
19 | | |
20 | | #[cfg(target_arch = "aarch64")] |
21 | | #[cfg(has_aarch64_simd)] |
22 | | struct KernelNeon; |
23 | | |
24 | | struct KernelFallback; |
25 | | |
26 | | type T = c32; |
27 | | type TReal = f32; |
28 | | |
29 | | /// Detect which implementation to use and select it using the selector's |
30 | | /// .select(Kernel) method. |
31 | | /// |
32 | | /// This function is called one or more times during a whole program's |
33 | | /// execution, it may be called for each gemm kernel invocation or fewer times. |
34 | | #[inline] |
35 | 0 | pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> { |
36 | 0 | // dispatch to specific compiled versions |
37 | 0 | #[cfg(any(target_arch="x86", target_arch="x86_64"))] |
38 | 0 | { |
39 | 0 | if is_x86_feature_detected_!("fma") { |
40 | 0 | if is_x86_feature_detected_!("avx2") { |
41 | 0 | return selector.select(KernelAvx2); |
42 | 0 | } |
43 | 0 | return selector.select(KernelFma); |
44 | 0 | } |
45 | 0 | } |
46 | 0 | #[cfg(target_arch = "aarch64")] |
47 | 0 | #[cfg(has_aarch64_simd)] |
48 | 0 | { |
49 | 0 | if is_aarch64_feature_detected_!("neon") { |
50 | 0 | return selector.select(KernelNeon); |
51 | 0 | } |
52 | 0 | } |
53 | 0 | return selector.select(KernelFallback); |
54 | 0 | } |
55 | | |
56 | | #[cfg(any(target_arch="x86", target_arch="x86_64"))] |
57 | | impl GemmKernel for KernelAvx2 { |
58 | | type Elem = T; |
59 | | |
60 | | type MRTy = U4; |
61 | | type NRTy = U4; |
62 | | |
63 | | #[inline(always)] |
64 | 0 | fn align_to() -> usize { 32 } |
65 | | |
66 | | #[inline(always)] |
67 | 0 | fn always_masked() -> bool { KernelFallback::always_masked() } |
68 | | |
69 | | #[inline(always)] |
70 | 0 | fn nc() -> usize { archparam::C_NC } |
71 | | #[inline(always)] |
72 | 0 | fn kc() -> usize { archparam::C_KC } |
73 | | #[inline(always)] |
74 | 0 | fn mc() -> usize { archparam::C_MC } |
75 | | |
76 | | pack_methods!{} |
77 | | |
78 | | #[inline(always)] |
79 | 0 | unsafe fn kernel( |
80 | 0 | k: usize, |
81 | 0 | alpha: T, |
82 | 0 | a: *const T, |
83 | 0 | b: *const T, |
84 | 0 | beta: T, |
85 | 0 | c: *mut T, rsc: isize, csc: isize) { |
86 | 0 | kernel_target_avx2(k, alpha, a, b, beta, c, rsc, csc) |
87 | 0 | } |
88 | | } |
89 | | |
90 | | #[cfg(any(target_arch="x86", target_arch="x86_64"))] |
91 | | impl GemmKernel for KernelFma { |
92 | | type Elem = T; |
93 | | |
94 | | type MRTy = U4; |
95 | | type NRTy = U4; |
96 | | |
97 | | #[inline(always)] |
98 | 0 | fn align_to() -> usize { 16 } |
99 | | |
100 | | #[inline(always)] |
101 | 0 | fn always_masked() -> bool { KernelFallback::always_masked() } |
102 | | |
103 | | #[inline(always)] |
104 | 0 | fn nc() -> usize { archparam::C_NC } |
105 | | #[inline(always)] |
106 | 0 | fn kc() -> usize { archparam::C_KC } |
107 | | #[inline(always)] |
108 | 0 | fn mc() -> usize { archparam::C_MC } |
109 | | |
110 | | pack_methods!{} |
111 | | |
112 | | #[inline(always)] |
113 | 0 | unsafe fn kernel( |
114 | 0 | k: usize, |
115 | 0 | alpha: T, |
116 | 0 | a: *const T, |
117 | 0 | b: *const T, |
118 | 0 | beta: T, |
119 | 0 | c: *mut T, rsc: isize, csc: isize) { |
120 | 0 | kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc) |
121 | 0 | } |
122 | | } |
123 | | |
124 | | #[cfg(target_arch = "aarch64")] |
125 | | #[cfg(has_aarch64_simd)] |
126 | | impl GemmKernel for KernelNeon { |
127 | | type Elem = T; |
128 | | |
129 | | type MRTy = U4; |
130 | | type NRTy = U2; |
131 | | |
132 | | #[inline(always)] |
133 | | fn align_to() -> usize { 16 } |
134 | | |
135 | | #[inline(always)] |
136 | | fn always_masked() -> bool { KernelFallback::always_masked() } |
137 | | |
138 | | #[inline(always)] |
139 | | fn nc() -> usize { archparam::C_NC } |
140 | | #[inline(always)] |
141 | | fn kc() -> usize { archparam::C_KC } |
142 | | #[inline(always)] |
143 | | fn mc() -> usize { archparam::C_MC } |
144 | | |
145 | | pack_methods!{} |
146 | | |
147 | | #[inline(always)] |
148 | | unsafe fn kernel( |
149 | | k: usize, |
150 | | alpha: T, |
151 | | a: *const T, |
152 | | b: *const T, |
153 | | beta: T, |
154 | | c: *mut T, rsc: isize, csc: isize) { |
155 | | kernel_target_neon(k, alpha, a, b, beta, c, rsc, csc) |
156 | | } |
157 | | } |
158 | | |
159 | | impl GemmKernel for KernelFallback { |
160 | | type Elem = T; |
161 | | |
162 | | type MRTy = U4; |
163 | | type NRTy = U2; |
164 | | |
165 | | #[inline(always)] |
166 | 0 | fn align_to() -> usize { 0 } |
167 | | |
168 | | #[inline(always)] |
169 | 0 | fn always_masked() -> bool { true } |
170 | | |
171 | | #[inline(always)] |
172 | 0 | fn nc() -> usize { archparam::C_NC } |
173 | | #[inline(always)] |
174 | 0 | fn kc() -> usize { archparam::C_KC } |
175 | | #[inline(always)] |
176 | 0 | fn mc() -> usize { archparam::C_MC } |
177 | | |
178 | | pack_methods!{} |
179 | | |
180 | | #[inline(always)] |
181 | 0 | unsafe fn kernel( |
182 | 0 | k: usize, |
183 | 0 | alpha: T, |
184 | 0 | a: *const T, |
185 | 0 | b: *const T, |
186 | 0 | beta: T, |
187 | 0 | c: *mut T, rsc: isize, csc: isize) { |
188 | 0 | kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc) |
189 | 0 | } |
190 | | } |
191 | | |
192 | | // Kernel AVX2 |
193 | | #[cfg(any(target_arch="x86", target_arch="x86_64"))] |
194 | | macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; } |
195 | | #[cfg(any(target_arch="x86", target_arch="x86_64"))] |
196 | | macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; } |
197 | | |
198 | | #[cfg(any(target_arch="x86", target_arch="x86_64"))] |
199 | | kernel_fallback_impl_complex! { |
200 | | // instantiate separately |
201 | | [inline target_feature(enable="avx2") target_feature(enable="fma")] [fma_yes] |
202 | | kernel_target_avx2, T, TReal, KernelAvx2::MR, KernelAvx2::NR, 4 |
203 | | } |
204 | | |
205 | | |
206 | | // Kernel Fma |
207 | | #[cfg(any(target_arch="x86", target_arch="x86_64"))] |
208 | | macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; } |
209 | | #[cfg(any(target_arch="x86", target_arch="x86_64"))] |
210 | | macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; } |
211 | | |
212 | | #[cfg(any(target_arch="x86", target_arch="x86_64"))] |
213 | | kernel_fallback_impl_complex! { |
214 | | // instantiate separately |
215 | | [inline target_feature(enable="fma")] [fma_no] |
216 | | kernel_target_fma, T, TReal, KernelFma::MR, KernelFma::NR, 2 |
217 | | } |
218 | | |
219 | | // Kernel neon |
220 | | |
221 | | #[cfg(target_arch = "aarch64")] |
222 | | #[cfg(has_aarch64_simd)] |
223 | | macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; } |
224 | | #[cfg(target_arch = "aarch64")] |
225 | | #[cfg(has_aarch64_simd)] |
226 | | macro_rules! loop_n { ($j:ident, $e:expr) => { loop2!($j, $e) }; } |
227 | | |
228 | | #[cfg(target_arch = "aarch64")] |
229 | | #[cfg(has_aarch64_simd)] |
230 | | kernel_fallback_impl_complex! { |
231 | | [inline target_feature(enable="neon")] [fma_yes] |
232 | | kernel_target_neon, T, TReal, KernelNeon::MR, KernelNeon::NR, 1 |
233 | | } |
234 | | |
235 | | // Kernel fallback |
236 | | |
237 | | macro_rules! loop_m { ($i:ident, $e:expr) => { loop4!($i, $e) }; } |
238 | | macro_rules! loop_n { ($j:ident, $e:expr) => { loop2!($j, $e) }; } |
239 | | |
240 | | kernel_fallback_impl_complex! { |
241 | | [inline(always)] [fma_no] |
242 | | kernel_fallback_impl, T, TReal, KernelFallback::MR, KernelFallback::NR, 1 |
243 | | } |
244 | | |
245 | | #[inline(always)] |
246 | 0 | unsafe fn at(ptr: *const TReal, i: usize) -> TReal { |
247 | 0 | *ptr.add(i) |
248 | 0 | } |
249 | | |
250 | | #[cfg(test)] |
251 | | mod tests { |
252 | | use super::*; |
253 | | use crate::kernel::test::test_complex_packed_kernel; |
254 | | |
255 | | #[test] |
256 | | fn test_kernel_fallback_impl() { |
257 | | test_complex_packed_kernel::<KernelFallback, _, TReal>("kernel"); |
258 | | } |
259 | | |
260 | | #[cfg(target_arch = "aarch64")] |
261 | | #[cfg(has_aarch64_simd)] |
262 | | mod test_kernel_aarch64 { |
263 | | use super::test_complex_packed_kernel; |
264 | | use super::super::*; |
265 | | #[cfg(feature = "std")] |
266 | | use std::println; |
267 | | macro_rules! test_arch_kernels { |
268 | | ($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => { |
269 | | $( |
270 | | #[test] |
271 | | fn $name() { |
272 | | if is_aarch64_feature_detected_!($feature_name) { |
273 | | test_complex_packed_kernel::<$kernel_ty, _, TReal>(stringify!($name)); |
274 | | } else { |
275 | | #[cfg(feature = "std")] |
276 | | println!("Skipping, host does not have feature: {:?}", $feature_name); |
277 | | } |
278 | | } |
279 | | )* |
280 | | } |
281 | | } |
282 | | |
283 | | test_arch_kernels! { |
284 | | "neon", neon, KernelNeon |
285 | | } |
286 | | } |
287 | | |
288 | | #[cfg(any(target_arch="x86", target_arch="x86_64"))] |
289 | | mod test_arch_kernels { |
290 | | use super::test_complex_packed_kernel; |
291 | | use super::super::*; |
292 | | #[cfg(feature = "std")] |
293 | | use std::println; |
294 | | macro_rules! test_arch_kernels_x86 { |
295 | | ($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => { |
296 | | $( |
297 | | #[test] |
298 | | fn $name() { |
299 | | if is_x86_feature_detected_!($feature_name) { |
300 | | test_complex_packed_kernel::<$kernel_ty, _, TReal>(stringify!($name)); |
301 | | } else { |
302 | | #[cfg(feature = "std")] |
303 | | println!("Skipping, host does not have feature: {:?}", $feature_name); |
304 | | } |
305 | | } |
306 | | )* |
307 | | } |
308 | | } |
309 | | |
310 | | test_arch_kernels_x86! { |
311 | | "fma", fma, KernelFma, |
312 | | "avx2", avx2, KernelAvx2 |
313 | | } |
314 | | } |
315 | | } |