/rust/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/stacking.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright 2014-2020 bluss and ndarray developers. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or |
4 | | // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license |
5 | | // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your |
6 | | // option. This file may not be copied, modified, or distributed |
7 | | // except according to those terms. |
8 | | |
9 | | use alloc::vec::Vec; |
10 | | |
11 | | use crate::dimension; |
12 | | use crate::error::{from_kind, ErrorKind, ShapeError}; |
13 | | use crate::imp_prelude::*; |
14 | | |
15 | | /// Stack arrays along the new axis. |
16 | | /// |
17 | | /// ***Errors*** if the arrays have mismatching shapes. |
18 | | /// ***Errors*** if `arrays` is empty, if `axis` is out of bounds, |
19 | | /// if the result is larger than is possible to represent. |
20 | | /// |
21 | | /// ``` |
22 | | /// extern crate ndarray; |
23 | | /// |
24 | | /// use ndarray::{arr2, arr3, stack, Axis}; |
25 | | /// |
26 | | /// # fn main() { |
27 | | /// |
28 | | /// let a = arr2(&[[2., 2.], |
29 | | /// [3., 3.]]); |
30 | | /// assert!( |
31 | | /// stack(Axis(0), &[a.view(), a.view()]) |
32 | | /// == Ok(arr3(&[[[2., 2.], |
33 | | /// [3., 3.]], |
34 | | /// [[2., 2.], |
35 | | /// [3., 3.]]])) |
36 | | /// ); |
37 | | /// # } |
38 | | /// ``` |
39 | 0 | pub fn stack<A, D>( |
40 | 0 | axis: Axis, |
41 | 0 | arrays: &[ArrayView<A, D>], |
42 | 0 | ) -> Result<Array<A, D::Larger>, ShapeError> |
43 | 0 | where |
44 | 0 | A: Clone, |
45 | 0 | D: Dimension, |
46 | 0 | D::Larger: RemoveAxis, |
47 | 0 | { |
48 | 0 | #[allow(deprecated)] |
49 | 0 | stack_new_axis(axis, arrays) |
50 | 0 | } |
51 | | |
52 | | /// Concatenate arrays along the given axis. |
53 | | /// |
54 | | /// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`. |
55 | | /// (may be made more flexible in the future).<br> |
56 | | /// ***Errors*** if `arrays` is empty, if `axis` is out of bounds, |
57 | | /// if the result is larger than is possible to represent. |
58 | | /// |
59 | | /// ``` |
60 | | /// use ndarray::{arr2, Axis, concatenate}; |
61 | | /// |
62 | | /// let a = arr2(&[[2., 2.], |
63 | | /// [3., 3.]]); |
64 | | /// assert!( |
65 | | /// concatenate(Axis(0), &[a.view(), a.view()]) |
66 | | /// == Ok(arr2(&[[2., 2.], |
67 | | /// [3., 3.], |
68 | | /// [2., 2.], |
69 | | /// [3., 3.]])) |
70 | | /// ); |
71 | | /// ``` |
72 | 0 | pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError> |
73 | 0 | where |
74 | 0 | A: Clone, |
75 | 0 | D: RemoveAxis, |
76 | 0 | { |
77 | 0 | if arrays.is_empty() { |
78 | 0 | return Err(from_kind(ErrorKind::Unsupported)); |
79 | 0 | } |
80 | 0 | let mut res_dim = arrays[0].raw_dim(); |
81 | 0 | if axis.index() >= res_dim.ndim() { |
82 | 0 | return Err(from_kind(ErrorKind::OutOfBounds)); |
83 | 0 | } |
84 | 0 | let common_dim = res_dim.remove_axis(axis); |
85 | 0 | if arrays |
86 | 0 | .iter() |
87 | 0 | .any(|a| a.raw_dim().remove_axis(axis) != common_dim) |
88 | | { |
89 | 0 | return Err(from_kind(ErrorKind::IncompatibleShape)); |
90 | 0 | } |
91 | 0 |
|
92 | 0 | let stacked_dim = arrays.iter().fold(0, |acc, a| acc + a.len_of(axis)); |
93 | 0 | res_dim.set_axis(axis, stacked_dim); |
94 | 0 | let new_len = dimension::size_of_shape_checked(&res_dim)?; |
95 | | |
96 | | // start with empty array with precomputed capacity |
97 | | // append's handling of empty arrays makes sure `axis` is ok for appending |
98 | 0 | res_dim.set_axis(axis, 0); |
99 | 0 | let mut res = unsafe { |
100 | 0 | // Safety: dimension is size 0 and vec is empty |
101 | 0 | Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len)) |
102 | | }; |
103 | | |
104 | 0 | for array in arrays { |
105 | 0 | res.append(axis, array.clone())?; |
106 | | } |
107 | 0 | debug_assert_eq!(res.len_of(axis), stacked_dim); |
108 | 0 | Ok(res) |
109 | 0 | } |
110 | | |
111 | | #[deprecated(note="Use under the name stack instead.", since="0.15.0")] |
112 | | /// Stack arrays along the new axis. |
113 | | /// |
114 | | /// ***Errors*** if the arrays have mismatching shapes. |
115 | | /// ***Errors*** if `arrays` is empty, if `axis` is out of bounds, |
116 | | /// if the result is larger than is possible to represent. |
117 | | /// |
118 | | /// ``` |
119 | | /// extern crate ndarray; |
120 | | /// |
121 | | /// use ndarray::{arr2, arr3, stack_new_axis, Axis}; |
122 | | /// |
123 | | /// # fn main() { |
124 | | /// |
125 | | /// let a = arr2(&[[2., 2.], |
126 | | /// [3., 3.]]); |
127 | | /// assert!( |
128 | | /// stack_new_axis(Axis(0), &[a.view(), a.view()]) |
129 | | /// == Ok(arr3(&[[[2., 2.], |
130 | | /// [3., 3.]], |
131 | | /// [[2., 2.], |
132 | | /// [3., 3.]]])) |
133 | | /// ); |
134 | | /// # } |
135 | | /// ``` |
136 | 0 | pub fn stack_new_axis<A, D>( |
137 | 0 | axis: Axis, |
138 | 0 | arrays: &[ArrayView<A, D>], |
139 | 0 | ) -> Result<Array<A, D::Larger>, ShapeError> |
140 | 0 | where |
141 | 0 | A: Clone, |
142 | 0 | D: Dimension, |
143 | 0 | D::Larger: RemoveAxis, |
144 | 0 | { |
145 | 0 | if arrays.is_empty() { |
146 | 0 | return Err(from_kind(ErrorKind::Unsupported)); |
147 | 0 | } |
148 | 0 | let common_dim = arrays[0].raw_dim(); |
149 | 0 | // Avoid panic on `insert_axis` call, return an Err instead of it. |
150 | 0 | if axis.index() > common_dim.ndim() { |
151 | 0 | return Err(from_kind(ErrorKind::OutOfBounds)); |
152 | 0 | } |
153 | 0 | let mut res_dim = common_dim.insert_axis(axis); |
154 | 0 |
|
155 | 0 | if arrays.iter().any(|a| a.raw_dim() != common_dim) { |
156 | 0 | return Err(from_kind(ErrorKind::IncompatibleShape)); |
157 | 0 | } |
158 | 0 |
|
159 | 0 | res_dim.set_axis(axis, arrays.len()); |
160 | | |
161 | 0 | let new_len = dimension::size_of_shape_checked(&res_dim)?; |
162 | | |
163 | | // start with empty array with precomputed capacity |
164 | | // append's handling of empty arrays makes sure `axis` is ok for appending |
165 | 0 | res_dim.set_axis(axis, 0); |
166 | 0 | let mut res = unsafe { |
167 | 0 | // Safety: dimension is size 0 and vec is empty |
168 | 0 | Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len)) |
169 | | }; |
170 | | |
171 | 0 | for array in arrays { |
172 | 0 | res.append(axis, array.clone().insert_axis(axis))?; |
173 | | } |
174 | | |
175 | 0 | debug_assert_eq!(res.len_of(axis), arrays.len()); |
176 | 0 | Ok(res) |
177 | 0 | } |
178 | | |
/// Stacks the given arrays along a new axis, panicking on error.
///
/// Each argument `a` is converted with `ArrayView::from(&a)`. The resulting
/// views are passed to the [`stack()`] function and its `Result` is unwrapped.
/// A trailing comma after the last array is accepted.
///
/// # Panics
///
/// Panics if the [`stack()`] function would return an error.
///
/// # Examples
///
/// ```
/// use ndarray::{arr2, arr3, stack, Axis};
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// assert_eq!(
///     stack![Axis(0), a, a],
///     arr3(&[[[1., 2.],
///             [3., 4.]],
///            [[1., 2.],
///             [3., 4.]]]),
/// );
/// assert_eq!(
///     stack![Axis(1), a, a,],
///     arr3(&[[[1., 2.],
///             [1., 2.]],
///            [[3., 4.],
///             [3., 4.]]]),
/// );
/// assert_eq!(
///     stack![Axis(2), a, a],
///     arr3(&[[[1., 1.],
///             [2., 2.]],
///            [[3., 3.],
///             [4., 4.]]]),
/// );
/// ```
#[macro_export]
macro_rules! stack {
    ($ax:expr, $( $arr:expr ),+ $(,)?) => {
        $crate::stack($ax, &[ $( $crate::ArrayView::from(&$arr) ),+ ]).unwrap()
    };
}
227 | | |
/// Concatenates the given arrays along an existing axis, panicking on error.
///
/// Each argument `a` is converted with `ArrayView::from(&a)`. The resulting
/// views are passed to the [`concatenate()`] function and its `Result` is
/// unwrapped. A trailing comma after the last array is accepted.
///
/// # Panics
///
/// Panics if the [`concatenate()`] function would return an error.
///
/// # Examples
///
/// ```
/// use ndarray::{arr2, concatenate, Axis};
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// assert_eq!(
///     concatenate![Axis(0), a, a],
///     arr2(&[[1., 2.],
///            [3., 4.],
///            [1., 2.],
///            [3., 4.]]),
/// );
/// assert_eq!(
///     concatenate![Axis(1), a, a,],
///     arr2(&[[1., 2., 1., 2.],
///            [3., 4., 3., 4.]]),
/// );
/// ```
#[macro_export]
macro_rules! concatenate {
    ($ax:expr, $( $arr:expr ),+ $(,)?) => {
        $crate::concatenate($ax, &[ $( $crate::ArrayView::from(&$arr) ),+ ]).unwrap()
    };
}
267 | | |
/// Stacks the given arrays along a new axis, panicking on error
/// (deprecated; use `stack!`).
///
/// Each argument `a` is converted with `ArrayView::from(&a)`. The resulting
/// views are passed to the [`stack_new_axis()`] function and its `Result` is
/// unwrapped. Like `stack!` and `concatenate!`, a trailing comma after the
/// last array is accepted.
///
/// # Panics
///
/// Panics if the [`stack_new_axis()`] function would return an error.
///
/// # Examples
///
/// ```
/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
///
/// let a = arr2(&[[2., 2.],
///                [3., 3.]]);
/// assert!(
///     stack_new_axis![Axis(0), a, a]
///     == arr3(&[[[2., 2.],
///                [3., 3.]],
///               [[2., 2.],
///                [3., 3.]]])
/// );
/// ```
#[macro_export]
#[deprecated(note="Use under the name stack instead.", since="0.15.0")]
macro_rules! stack_new_axis {
    // `$(,)?` accepts an optional trailing comma, matching the sibling
    // `stack!` / `concatenate!` macros; inputs without one are unaffected.
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::stack_new_axis($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
    };
}