/rust/registry/src/index.crates.io-6f17d22bba15001f/matrixmultiply-0.3.9/src/cgemm_common.rs
// Copyright 2021-2023 Ulrik Sverdrup "bluss"
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use core::mem;
use core::ptr::copy_nonoverlapping;

use rawpointer::PointerExt;

use crate::kernel::Element;
use crate::kernel::ConstNum;

#[cfg(feature = "std")]
macro_rules! fmuladd {
    // conceptually $dst += $a * $b, optionally use fused multiply-add
    (fma_yes, $dst:expr, $a:expr, $b:expr) => {
        {
            $dst = $a.mul_add($b, $dst);
        }
    };
    (fma_no, $dst:expr, $a:expr, $b:expr) => {
        {
            $dst += $a * $b;
        }
    };
}

#[cfg(not(feature = "std"))]
macro_rules! fmuladd {
    ($any:tt, $dst:expr, $a:expr, $b:expr) => {
        {
            $dst += $a * $b;
        }
    };
}
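
// A minimal sketch (hypothetical test, not in the original) illustrating the two
// arms: `fma_yes` expands to `$dst = $a.mul_add($b, $dst)`, a fused multiply-add
// with a single rounding step, while `fma_no` expands to the plain `$dst += $a * $b`
// with two rounding steps; both compute dst + a * b.
#[cfg(all(test, feature = "std"))]
#[test]
fn fmuladd_arms_agree() {
    let (a, b) = (1.5f32, 2.25f32);
    let mut fused = 0.5f32;
    let mut plain = 0.5f32;
    fmuladd!(fma_yes, fused, a, b);
    fmuladd!(fma_no, plain, a, b);
    // exact equality holds here: 1.5 * 2.25 + 0.5 = 3.875 is representable in f32
    assert_eq!(fused, plain);
}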

// kernel fallback impl macro
// Depends on a couple of macro and function definitions being in scope - loop_m/_n, at, etc.
// $fma_opt: fma_yes or fma_no to use f32::mul_add etc. or not
macro_rules! kernel_fallback_impl_complex {
    ([$($attr:meta)*] [$fma_opt:tt] $name:ident, $elem_ty:ty, $real_ty:ty, $mr:expr, $nr:expr, $unroll:tt) => {
        $(#[$attr])*
        unsafe fn $name(k: usize, alpha: $elem_ty, a: *const $elem_ty, b: *const $elem_ty,
                        beta: $elem_ty, c: *mut $elem_ty, rsc: isize, csc: isize)
        {
            const MR: usize = $mr;
            const NR: usize = $nr;

            debug_assert_eq!(beta, <$elem_ty>::zero(), "Beta must be 0 or is not masked");

            let mut pp = [<$real_ty>::zero(); MR];
            let mut qq = [<$real_ty>::zero(); MR];
            let mut rr = [<$real_ty>::zero(); NR];
            let mut ss = [<$real_ty>::zero(); NR];

            let mut ab: [[$elem_ty; NR]; MR] = [[<$elem_ty>::zero(); NR]; MR];
            let mut areal = a as *const $real_ty;
            let mut breal = b as *const $real_ty;

            unroll_by!($unroll => k, {
                // We set:
                // P + Q i = A
                // R + S i = B
                //
                // see pack_complex for how data is packed
                let aimag = areal.add(MR);
                let bimag = breal.add(NR);

                // AB = PR - QS + i (QR + PS)
                loop_m!(i, {
                    pp[i] = at(areal, i);
                    qq[i] = at(aimag, i);
                });
                loop_n!(j, {
                    rr[j] = at(breal, j);
                    ss[j] = at(bimag, j);
                });
                loop_m!(i, {
                    loop_n!(j, {
                        // optionally use fma
                        fmuladd!($fma_opt, ab[i][j][0], pp[i], rr[j]);
                        fmuladd!($fma_opt, ab[i][j][1], pp[i], ss[j]);
                        fmuladd!($fma_opt, ab[i][j][0], -qq[i], ss[j]);
                        fmuladd!($fma_opt, ab[i][j][1], qq[i], rr[j]);
                    })
                });

                areal = aimag.add(MR);
                breal = bimag.add(NR);
            });

            macro_rules! c {
                ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
            }

            // set C = α A B
            loop_n!(j, loop_m!(i, *c![i, j] = mul(alpha, ab[i][j])));
        }
    };
}
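
// A worked check (hypothetical test, not in the original) of the accumulation
// formula used above: for a = p + q i and b = r + s i,
// a * b = (p r - q s) + (q r + p s) i, which is exactly what the four
// fmuladd! updates compute per (i, j) entry.
#[test]
fn complex_product_formula() {
    // (1 + 2i) * (3 + 4i) = (1*3 - 2*4) + (2*3 + 1*4) i = -5 + 10i
    let (p, q, r, s) = (1.0f64, 2.0, 3.0, 4.0);
    let mut ab = [0.0f64; 2];
    ab[0] += p * r - q * s; // real part: PR - QS
    ab[1] += q * r + p * s; // imag part: QR + PS
    assert_eq!(ab, [-5.0, 10.0]);
}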

/// GemmKernel packing trait methods
macro_rules! pack_methods {
    () => {
        #[inline]
        unsafe fn pack_mr(kc: usize, mc: usize, pack: &mut [Self::Elem],
                          a: *const Self::Elem, rsa: isize, csa: isize)
        {
            pack_complex::<Self::MRTy, T, TReal>(kc, mc, pack, a, rsa, csa)
        }

        #[inline]
        unsafe fn pack_nr(kc: usize, mc: usize, pack: &mut [Self::Elem],
                          a: *const Self::Elem, rsa: isize, csa: isize)
        {
            pack_complex::<Self::NRTy, T, TReal>(kc, mc, pack, a, rsa, csa)
        }
    }
}
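
// Illustrative sketch only (simplified; the real impls live in cgemm_kernel.rs and
// zgemm_kernel.rs): each kernel type invokes pack_methods! inside its GemmKernel
// impl, roughly
//
//     impl GemmKernel for KernelFallback {
//         type Elem = T;   // [f32; 2] for cgemm, [f64; 2] for zgemm
//         // ... MRTy/NRTy, the kernel fn, and the rest of the trait ...
//         pack_methods!{}
//     }
//
// so that pack_mr/pack_nr forward to pack_complex with the kernel's own tile sizes.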

/// Pack complex: similar to general packing but separate rows for real and imag parts.
///
/// Source matrix contains [p0 + q0i, p1 + q1i, p2 + q2i, ..] and it's packed into
/// alternate rows of real and imaginary parts.
///
/// [ p0 p1 p2 p3 .. (MR repeats)
///   q0 q1 q2 q3 .. (MR repeats)
///   px p_ p_ p_ .. (x = MR)
///   qx q_ q_ q_ .. (x = MR)
///   py p_ p_ p_ .. (y = 2 * MR)
///   qy q_ q_ q_ .. (y = 2 * MR)
///   ...
/// ]
pub(crate) unsafe fn pack_complex<MR, T, TReal>(kc: usize, mc: usize, pack: &mut [T],
                                                a: *const T, rsa: isize, csa: isize)
    where MR: ConstNum,
          T: Element,
          TReal: Element,
{
    // treat the pointers as pointers to TReal
    let pack = pack.as_mut_ptr() as *mut TReal;
    let areal = a as *const TReal;
    let aimag = areal.add(1);

    assert_eq!(mem::size_of::<T>(), 2 * mem::size_of::<TReal>());

    let mr = MR::VALUE;
    let mut p = 0; // offset into pack

    // general layout case (there is no specialized case for contiguous, stride == 1, input)
    for ir in 0..mc/mr {
        let row_offset = ir * mr;
        for j in 0..kc {
            // real row
            for i in 0..mr {
                let a_elt = areal.stride_offset(2 * rsa, i + row_offset)
                                 .stride_offset(2 * csa, j);
                copy_nonoverlapping(a_elt, pack.add(p), 1);
                p += 1;
            }
            // imag row
            for i in 0..mr {
                let a_elt = aimag.stride_offset(2 * rsa, i + row_offset)
                                 .stride_offset(2 * csa, j);
                copy_nonoverlapping(a_elt, pack.add(p), 1);
                p += 1;
            }
        }
    }

    let zero = TReal::zero();

    // Pad with zeros to a multiple of the kernel size (uneven mc)
    let rest = mc % mr;
    if rest > 0 {
        let row_offset = (mc/mr) * mr;
        for j in 0..kc {
            // real row
            for i in 0..mr {
                if i < rest {
                    let a_elt = areal.stride_offset(2 * rsa, i + row_offset)
                                     .stride_offset(2 * csa, j);
                    copy_nonoverlapping(a_elt, pack.add(p), 1);
                } else {
                    *pack.add(p) = zero;
                }
                p += 1;
            }
            // imag row
            for i in 0..mr {
                if i < rest {
                    let a_elt = aimag.stride_offset(2 * rsa, i + row_offset)
                                     .stride_offset(2 * csa, j);
                    copy_nonoverlapping(a_elt, pack.add(p), 1);
                } else {
                    *pack.add(p) = zero;
                }
                p += 1;
            }
        }
    }
}
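
// A minimal sketch (hypothetical test, not in the original) of the layout documented
// above: packing one column of two complex numbers with MR = 2 should emit the real
// row [p0, p1] followed by the imaginary row [q0, q1].
#[test]
fn pack_complex_separates_real_and_imag_rows() {
    use crate::kernel::U2;
    // column vector [1 + 2i, 3 + 4i], row stride 1 (in units of complex elements)
    let a: [[f32; 2]; 2] = [[1., 2.], [3., 4.]];
    let mut pack = [[0.0f32; 2]; 2];
    unsafe {
        // kc = 1 column, mc = 2 rows, rsa = 1, csa = 2 (column-major)
        pack_complex::<U2, [f32; 2], f32>(1, 2, &mut pack, a.as_ptr(), 1, 2);
    }
    // flattened pack is [p0, p1, q0, q1] = [1, 3, 2, 4]
    assert_eq!(pack, [[1., 3.], [2., 4.]]);
}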