/rust/registry/src/index.crates.io-1949cf8c6b5b557f/linfa-linalg-0.2.1/src/reflection.rs
Line | Count | Source |
1 | | use ndarray::{ArrayBase, Data, DataMut, Ix1, Ix2, NdFloat}; |
2 | | |
3 | | /// Reflection with respect to a plane |
4 | | pub struct Reflection<A, D: Data<Elem = A>> { |
5 | | axis: ArrayBase<D, Ix1>, |
6 | | bias: A, |
7 | | } |
8 | | |
9 | | impl<A, D: Data<Elem = A>> Reflection<A, D> { |
10 | | /// Create a new reflection with respect to the plane orthogonal to the given axis and bias |
11 | | /// |
12 | | /// `axis` must be a unit vector |
13 | | /// `bias` is the position of the plane on the axis from the origin |
14 | 0 | pub fn new(axis: ArrayBase<D, Ix1>, bias: A) -> Self { |
15 | 0 | Self { axis, bias } |
16 | 0 | } |
17 | | |
18 | 0 | pub fn axis(&self) -> &ArrayBase<D, Ix1> { |
19 | 0 | &self.axis |
20 | 0 | } |
21 | | } |
22 | | |
23 | | // XXX Can use matrix multiplication algorithm instead of iterative algorithm for both reflections |
24 | | impl<A: NdFloat, D: Data<Elem = A>> Reflection<A, D> { |
25 | | /// Apply reflection to the columns of `rhs` |
26 | 0 | pub fn reflect_cols<M: DataMut<Elem = A>>(&self, rhs: &mut ArrayBase<M, Ix2>) { |
27 | 0 | for i in 0..rhs.ncols() { |
28 | 0 | let m_two = A::from(-2.0f64).unwrap(); |
29 | 0 | let factor = (self.axis.dot(&rhs.column(i)) - self.bias) * m_two; |
30 | 0 | rhs.column_mut(i).scaled_add(factor, &self.axis); |
31 | 0 | } |
32 | 0 | } |
33 | | |
34 | | /// Apply reflection to the rows of `lhs` |
35 | 0 | pub fn reflect_rows<M: DataMut<Elem = A>>(&self, lhs: &mut ArrayBase<M, Ix2>) { |
36 | 0 | self.reflect_cols(&mut lhs.view_mut().reversed_axes()); |
37 | 0 | } |
38 | | } |
39 | | |
40 | | #[cfg(test)] |
41 | | mod tests { |
42 | | use approx::assert_abs_diff_eq; |
43 | | use ndarray::array; |
44 | | |
45 | | use super::*; |
46 | | |
47 | | #[test] |
48 | | fn reflect_plane_col() { |
49 | | let y_axis = array![0., 1., 0.]; |
50 | | let refl = Reflection::new(y_axis.view(), 0.0); |
51 | | |
52 | | let mut v = array![[1., 2., 3.], [3., 4., 5.]].reversed_axes(); |
53 | | refl.reflect_cols(&mut v); |
54 | | assert_abs_diff_eq!(v, array![[1., -2., 3.], [3., -4., 5.]].reversed_axes()); |
55 | | refl.reflect_cols(&mut v); |
56 | | assert_abs_diff_eq!(v, array![[1., 2., 3.], [3., 4., 5.]].reversed_axes()); |
57 | | |
58 | | let refl = Reflection::new(y_axis.view(), 3.0); |
59 | | let mut v = array![[1., 2., 3.], [3., 4., 5.]].reversed_axes(); |
60 | | refl.reflect_cols(&mut v); |
61 | | assert_abs_diff_eq!(v, array![[1., 4., 3.], [3., 2., 5.]].reversed_axes()); |
62 | | } |
63 | | |
64 | | #[test] |
65 | | fn reflect_plane_row() { |
66 | | let y_axis = array![0., 1., 0.]; |
67 | | let refl = Reflection::new(y_axis.view(), 0.0); |
68 | | |
69 | | let mut v = array![[1., 2., 3.], [3., 4., 5.]]; |
70 | | refl.reflect_rows(&mut v); |
71 | | assert_abs_diff_eq!(v, array![[1., -2., 3.], [3., -4., 5.]]); |
72 | | refl.reflect_rows(&mut v); |
73 | | assert_abs_diff_eq!(v, array![[1., 2., 3.], [3., 4., 5.]]); |
74 | | |
75 | | let refl = Reflection::new(y_axis.view(), 3.0); |
76 | | let mut v = array![[1., 2., 3.], [3., 4., 5.]]; |
77 | | refl.reflect_rows(&mut v); |
78 | | assert_abs_diff_eq!(v, array![[1., 4., 3.], [3., 2., 5.]]); |
79 | | } |
80 | | } |