Coverage Report

Created: 2025-02-21 07:11

/rust/registry/src/index.crates.io-6f17d22bba15001f/matrixmultiply-0.3.9/src/cgemm_common.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2021-2023 Ulrik Sverdrup "bluss"
2
//
3
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6
// option. This file may not be copied, modified, or distributed
7
// except according to those terms.
8
9
use core::mem;
10
use core::ptr::copy_nonoverlapping;
11
12
use rawpointer::PointerExt;
13
14
use crate::kernel::Element;
15
use crate::kernel::ConstNum;
16
17
#[cfg(feature = "std")]
18
macro_rules! fmuladd {
19
    // conceptually $dst += $a * $b, optionally use fused multiply-add
20
    (fma_yes, $dst:expr, $a:expr, $b:expr) => {
21
        {
22
            $dst = $a.mul_add($b, $dst);
23
        }
24
    };
25
    (fma_no, $dst:expr, $a:expr, $b:expr) => {
26
        {
27
            $dst += $a * $b;
28
        }
29
    };
30
}
31
32
#[cfg(not(feature = "std"))]
33
macro_rules! fmuladd {
34
    ($any:tt, $dst:expr, $a:expr, $b:expr) => {
35
        {
36
            $dst += $a * $b;
37
        }
38
    };
39
}
40
41
42
// kernel fallback impl macro
43
// Depends on a couple of macro and function definitions to be in scope - loop_m/_n, at, etc.
44
// $fma_opt: fma_yes or fma_no to use f32::mul_add etc or not
45
macro_rules! kernel_fallback_impl_complex {
46
    ([$($attr:meta)*] [$fma_opt:tt] $name:ident, $elem_ty:ty, $real_ty:ty, $mr:expr, $nr:expr, $unroll:tt) => {
47
    $(#[$attr])*
48
0
    unsafe fn $name(k: usize, alpha: $elem_ty, a: *const $elem_ty, b: *const $elem_ty,
49
0
                    beta: $elem_ty, c: *mut $elem_ty, rsc: isize, csc: isize)
50
0
    {
51
        const MR: usize = $mr;
52
        const NR: usize = $nr;
53
54
0
        debug_assert_eq!(beta, <$elem_ty>::zero(), "Beta must be 0 or is not masked");
55
56
0
        let mut pp  = [<$real_ty>::zero(); MR];
57
0
        let mut qq  = [<$real_ty>::zero(); MR];
58
0
        let mut rr  = [<$real_ty>::zero(); NR];
59
0
        let mut ss  = [<$real_ty>::zero(); NR];
60
0
61
0
        let mut ab: [[$elem_ty; NR]; MR] = [[<$elem_ty>::zero(); NR]; MR];
62
0
        let mut areal = a as *const $real_ty;
63
0
        let mut breal = b as *const $real_ty;
64
0
65
0
        unroll_by!($unroll => k, {
66
            // We set:
67
            // P + Q i = A
68
            // R + S i = B
69
            //
70
            // see pack_complex for how data is packed
71
0
            let aimag = areal.add(MR);
72
0
            let bimag = breal.add(NR);
73
74
            // AB = PR - QS + i (QR + PS)
75
0
            loop_m!(i, {
76
0
                pp[i] = at(areal, i);
77
0
                qq[i] = at(aimag, i);
78
0
            });
79
0
            loop_n!(j, {
80
0
                rr[j] = at(breal, j);
81
0
                ss[j] = at(bimag, j);
82
0
            });
83
0
            loop_m!(i, {
84
0
                loop_n!(j, {
85
0
                    // optionally use fma
86
0
                    fmuladd!($fma_opt, ab[i][j][0], pp[i], rr[j]);
87
0
                    fmuladd!($fma_opt, ab[i][j][1], pp[i], ss[j]);
88
0
                    fmuladd!($fma_opt, ab[i][j][0], -qq[i], ss[j]);
89
0
                    fmuladd!($fma_opt, ab[i][j][1], qq[i], rr[j]);
90
0
                })
91
0
            });
92
0
93
0
            areal = aimag.add(MR);
94
0
            breal = bimag.add(NR);
95
        });
96
97
        macro_rules! c {
98
            ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
99
        }
100
101
        // set C = α A B
102
0
        loop_n!(j, loop_m!(i, *c![i, j] = mul(alpha, ab[i][j])));
103
0
    }
Unexecuted instantiation: matrixmultiply::cgemm_kernel::kernel_fallback_impl
Unexecuted instantiation: matrixmultiply::cgemm_kernel::kernel_target_avx2
Unexecuted instantiation: matrixmultiply::cgemm_kernel::kernel_target_fma
Unexecuted instantiation: matrixmultiply::zgemm_kernel::kernel_target_avx2
Unexecuted instantiation: matrixmultiply::zgemm_kernel::kernel_target_fma
Unexecuted instantiation: matrixmultiply::zgemm_kernel::kernel_fallback_impl
104
    };
105
}
106
107
/// GemmKernel packing trait methods
108
macro_rules! pack_methods {
109
    () => {
110
        #[inline]
111
0
        unsafe fn pack_mr(kc: usize, mc: usize, pack: &mut [Self::Elem],
112
0
                          a: *const Self::Elem, rsa: isize, csa: isize)
113
0
        {
114
0
            pack_complex::<Self::MRTy, T, TReal>(kc, mc, pack, a, rsa, csa)
115
0
        }
Unexecuted instantiation: <matrixmultiply::cgemm_kernel::KernelAvx2 as matrixmultiply::kernel::GemmKernel>::pack_mr
Unexecuted instantiation: <matrixmultiply::cgemm_kernel::KernelFma as matrixmultiply::kernel::GemmKernel>::pack_mr
Unexecuted instantiation: <matrixmultiply::cgemm_kernel::KernelFallback as matrixmultiply::kernel::GemmKernel>::pack_mr
Unexecuted instantiation: <matrixmultiply::zgemm_kernel::KernelAvx2 as matrixmultiply::kernel::GemmKernel>::pack_mr
Unexecuted instantiation: <matrixmultiply::zgemm_kernel::KernelFma as matrixmultiply::kernel::GemmKernel>::pack_mr
Unexecuted instantiation: <matrixmultiply::zgemm_kernel::KernelFallback as matrixmultiply::kernel::GemmKernel>::pack_mr
116
117
        #[inline]
118
0
        unsafe fn pack_nr(kc: usize, mc: usize, pack: &mut [Self::Elem],
119
0
                        a: *const Self::Elem, rsa: isize, csa: isize)
120
0
        {
121
0
            pack_complex::<Self::NRTy, T, TReal>(kc, mc, pack, a, rsa, csa)
122
0
        }
Unexecuted instantiation: <matrixmultiply::cgemm_kernel::KernelAvx2 as matrixmultiply::kernel::GemmKernel>::pack_nr
Unexecuted instantiation: <matrixmultiply::cgemm_kernel::KernelFma as matrixmultiply::kernel::GemmKernel>::pack_nr
Unexecuted instantiation: <matrixmultiply::cgemm_kernel::KernelFallback as matrixmultiply::kernel::GemmKernel>::pack_nr
Unexecuted instantiation: <matrixmultiply::zgemm_kernel::KernelAvx2 as matrixmultiply::kernel::GemmKernel>::pack_nr
Unexecuted instantiation: <matrixmultiply::zgemm_kernel::KernelFma as matrixmultiply::kernel::GemmKernel>::pack_nr
Unexecuted instantiation: <matrixmultiply::zgemm_kernel::KernelFallback as matrixmultiply::kernel::GemmKernel>::pack_nr
123
    }
124
}
125
126
127
/// Pack complex: similar to general packing but separate rows for real and imag parts.
128
///
129
/// Source matrix contains [p0 + q0i, p1 + q1i, p2 + q2i, ..] and it's packed into
130
/// alternate rows of real and imaginary parts.
131
///
132
/// [ p0 p1 p2 p3 .. (MR repeats)
133
///   q0 q1 q2 q3 .. (MR repeats)
134
///   px p_ p_ p_ .. (x = MR)
135
///   qx q_ q_ q_ .. (x = MR)
136
///   py p_ p_ p_ .. (y = 2 * MR)
137
///   qy q_ q_ q_ .. (y = 2 * MR)
138
///   ...
139
/// ]
140
0
pub(crate) unsafe fn pack_complex<MR, T, TReal>(kc: usize, mc: usize, pack: &mut [T],
141
0
                                                a: *const T, rsa: isize, csa: isize)
142
0
    where MR: ConstNum,
143
0
          T: Element,
144
0
          TReal: Element,
145
0
{
146
0
    // use pointers as pointer to TReal
147
0
    let pack = pack.as_mut_ptr() as *mut TReal;
148
0
    let areal = a as *const TReal;
149
0
    let aimag = areal.add(1);
150
0
151
0
    assert_eq!(mem::size_of::<T>(), 2 * mem::size_of::<TReal>());
152
153
0
    let mr = MR::VALUE;
154
0
    let mut p = 0; // offset into pack
155
156
    // general layout case (no contig case when stride != 1)
157
0
    for ir in 0..mc/mr {
158
0
        let row_offset = ir * mr;
159
0
        for j in 0..kc {
160
            // real row
161
0
            for i in 0..mr {
162
0
                let a_elt = areal.stride_offset(2 * rsa, i + row_offset)
163
0
                                 .stride_offset(2 * csa, j);
164
0
                copy_nonoverlapping(a_elt, pack.add(p), 1);
165
0
                p += 1;
166
0
            }
167
            // imag row
168
0
            for i in 0..mr {
169
0
                let a_elt = aimag.stride_offset(2 * rsa, i + row_offset)
170
0
                                 .stride_offset(2 * csa, j);
171
0
                copy_nonoverlapping(a_elt, pack.add(p), 1);
172
0
                p += 1;
173
0
            }
174
        }
175
    }
176
177
0
    let zero = TReal::zero();
178
0
179
0
    // Pad with zeros to multiple of kernel size (uneven mc)
180
0
    let rest = mc % mr;
181
0
    if rest > 0 {
182
0
        let row_offset = (mc/mr) * mr;
183
0
        for j in 0..kc {
184
            // real row
185
0
            for i in 0..mr {
186
0
                if i < rest {
187
0
                    let a_elt = areal.stride_offset(2 * rsa, i + row_offset)
188
0
                                     .stride_offset(2 * csa, j);
189
0
                    copy_nonoverlapping(a_elt, pack.add(p), 1);
190
0
                } else {
191
0
                    *pack.add(p) = zero;
192
0
                }
193
0
                p += 1;
194
            }
195
            // imag row
196
0
            for i in 0..mr {
197
0
                if i < rest {
198
0
                    let a_elt = aimag.stride_offset(2 * rsa, i + row_offset)
199
0
                                     .stride_offset(2 * csa, j);
200
0
                    copy_nonoverlapping(a_elt, pack.add(p), 1);
201
0
                } else {
202
0
                    *pack.add(p) = zero;
203
0
                }
204
0
                p += 1;
205
            }
206
        }
207
0
    }
208
0
}
Unexecuted instantiation: matrixmultiply::cgemm_common::pack_complex::<matrixmultiply::kernel::U2, [f64; 2], f64>
Unexecuted instantiation: matrixmultiply::cgemm_common::pack_complex::<matrixmultiply::kernel::U2, [f32; 2], f32>
Unexecuted instantiation: matrixmultiply::cgemm_common::pack_complex::<matrixmultiply::kernel::U4, [f64; 2], f64>
Unexecuted instantiation: matrixmultiply::cgemm_common::pack_complex::<matrixmultiply::kernel::U4, [f32; 2], f32>