Coverage Report

Created: 2025-02-21 07:11

/rust/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/stacking.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2014-2020 bluss and ndarray developers.
2
//
3
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6
// option. This file may not be copied, modified, or distributed
7
// except according to those terms.
8
9
use alloc::vec::Vec;
10
11
use crate::dimension;
12
use crate::error::{from_kind, ErrorKind, ShapeError};
13
use crate::imp_prelude::*;
14
15
/// Stack arrays along the new axis.
16
///
17
/// ***Errors*** if the arrays have mismatching shapes.
18
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
19
/// if the result is larger than is possible to represent.
20
///
21
/// ```
22
/// extern crate ndarray;
23
///
24
/// use ndarray::{arr2, arr3, stack, Axis};
25
///
26
/// # fn main() {
27
///
28
/// let a = arr2(&[[2., 2.],
29
///                [3., 3.]]);
30
/// assert!(
31
///     stack(Axis(0), &[a.view(), a.view()])
32
///     == Ok(arr3(&[[[2., 2.],
33
///                   [3., 3.]],
34
///                  [[2., 2.],
35
///                   [3., 3.]]]))
36
/// );
37
/// # }
38
/// ```
39
0
pub fn stack<A, D>(
40
0
    axis: Axis,
41
0
    arrays: &[ArrayView<A, D>],
42
0
) -> Result<Array<A, D::Larger>, ShapeError>
43
0
where
44
0
    A: Clone,
45
0
    D: Dimension,
46
0
    D::Larger: RemoveAxis,
47
0
{
48
0
    #[allow(deprecated)]
49
0
    stack_new_axis(axis, arrays)
50
0
}
51
52
/// Concatenate arrays along the given axis.
53
///
54
/// ***Errors*** if the arrays have mismatching shapes, apart from along `axis`.
55
/// (may be made more flexible in the future).<br>
56
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
57
/// if the result is larger than is possible to represent.
58
///
59
/// ```
60
/// use ndarray::{arr2, Axis, concatenate};
61
///
62
/// let a = arr2(&[[2., 2.],
63
///                [3., 3.]]);
64
/// assert!(
65
///     concatenate(Axis(0), &[a.view(), a.view()])
66
///     == Ok(arr2(&[[2., 2.],
67
///                  [3., 3.],
68
///                  [2., 2.],
69
///                  [3., 3.]]))
70
/// );
71
/// ```
72
0
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
73
0
where
74
0
    A: Clone,
75
0
    D: RemoveAxis,
76
0
{
77
0
    if arrays.is_empty() {
78
0
        return Err(from_kind(ErrorKind::Unsupported));
79
0
    }
80
0
    let mut res_dim = arrays[0].raw_dim();
81
0
    if axis.index() >= res_dim.ndim() {
82
0
        return Err(from_kind(ErrorKind::OutOfBounds));
83
0
    }
84
0
    let common_dim = res_dim.remove_axis(axis);
85
0
    if arrays
86
0
        .iter()
87
0
        .any(|a| a.raw_dim().remove_axis(axis) != common_dim)
88
    {
89
0
        return Err(from_kind(ErrorKind::IncompatibleShape));
90
0
    }
91
0
92
0
    let stacked_dim = arrays.iter().fold(0, |acc, a| acc + a.len_of(axis));
93
0
    res_dim.set_axis(axis, stacked_dim);
94
0
    let new_len = dimension::size_of_shape_checked(&res_dim)?;
95
96
    // start with empty array with precomputed capacity
97
    // append's handling of empty arrays makes sure `axis` is ok for appending
98
0
    res_dim.set_axis(axis, 0);
99
0
    let mut res = unsafe {
100
0
        // Safety: dimension is size 0 and vec is empty
101
0
        Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len))
102
    };
103
104
0
    for array in arrays {
105
0
        res.append(axis, array.clone())?;
106
    }
107
0
    debug_assert_eq!(res.len_of(axis), stacked_dim);
108
0
    Ok(res)
109
0
}
110
111
#[deprecated(note="Use under the name stack instead.", since="0.15.0")]
112
/// Stack arrays along the new axis.
113
///
114
/// ***Errors*** if the arrays have mismatching shapes.
115
/// ***Errors*** if `arrays` is empty, if `axis` is out of bounds,
116
/// if the result is larger than is possible to represent.
117
///
118
/// ```
119
/// extern crate ndarray;
120
///
121
/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
122
///
123
/// # fn main() {
124
///
125
/// let a = arr2(&[[2., 2.],
126
///                [3., 3.]]);
127
/// assert!(
128
///     stack_new_axis(Axis(0), &[a.view(), a.view()])
129
///     == Ok(arr3(&[[[2., 2.],
130
///                   [3., 3.]],
131
///                  [[2., 2.],
132
///                   [3., 3.]]]))
133
/// );
134
/// # }
135
/// ```
136
0
pub fn stack_new_axis<A, D>(
137
0
    axis: Axis,
138
0
    arrays: &[ArrayView<A, D>],
139
0
) -> Result<Array<A, D::Larger>, ShapeError>
140
0
where
141
0
    A: Clone,
142
0
    D: Dimension,
143
0
    D::Larger: RemoveAxis,
144
0
{
145
0
    if arrays.is_empty() {
146
0
        return Err(from_kind(ErrorKind::Unsupported));
147
0
    }
148
0
    let common_dim = arrays[0].raw_dim();
149
0
    // Avoid panic on `insert_axis` call, return an Err instead of it.
150
0
    if axis.index() > common_dim.ndim() {
151
0
        return Err(from_kind(ErrorKind::OutOfBounds));
152
0
    }
153
0
    let mut res_dim = common_dim.insert_axis(axis);
154
0
155
0
    if arrays.iter().any(|a| a.raw_dim() != common_dim) {
156
0
        return Err(from_kind(ErrorKind::IncompatibleShape));
157
0
    }
158
0
159
0
    res_dim.set_axis(axis, arrays.len());
160
161
0
    let new_len = dimension::size_of_shape_checked(&res_dim)?;
162
163
    // start with empty array with precomputed capacity
164
    // append's handling of empty arrays makes sure `axis` is ok for appending
165
0
    res_dim.set_axis(axis, 0);
166
0
    let mut res = unsafe {
167
0
        // Safety: dimension is size 0 and vec is empty
168
0
        Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len))
169
    };
170
171
0
    for array in arrays {
172
0
        res.append(axis, array.clone().insert_axis(axis))?;
173
    }
174
175
0
    debug_assert_eq!(res.len_of(axis), arrays.len());
176
0
    Ok(res)
177
0
}
178
179
/// Stack arrays along the new axis.
///
/// Shorthand for the [`stack()`] function. Each argument `a` is passed as
/// `ArrayView::from(&a)`, so it is borrowed rather than moved. A trailing comma
/// after the last array is accepted.
///
/// ***Panics*** if the `stack` function would return an error.
///
/// ```
/// use ndarray::{arr2, arr3, stack, Axis};
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// assert_eq!(
///     stack![Axis(0), a, a],
///     arr3(&[[[1., 2.],
///             [3., 4.]],
///            [[1., 2.],
///             [3., 4.]]]),
/// );
/// assert_eq!(
///     stack![Axis(1), a, a,],
///     arr3(&[[[1., 2.],
///             [1., 2.]],
///            [[3., 4.],
///             [3., 4.]]]),
/// );
/// assert_eq!(
///     stack![Axis(2), a, a],
///     arr3(&[[[1., 1.],
///             [2., 2.]],
///            [[3., 3.],
///             [4., 4.]]]),
/// );
/// ```
#[macro_export]
macro_rules! stack {
    // One arm handles both forms: `$(,)?` swallows an optional trailing comma.
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::stack($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
    };
}
227
228
/// Concatenate arrays along the given axis.
///
/// Shorthand for the [`concatenate()`] function. Each argument `a` is passed as
/// `ArrayView::from(&a)`, so it is borrowed rather than moved. A trailing comma
/// after the last array is accepted.
///
/// ***Panics*** if the `concatenate` function would return an error.
///
/// ```
/// use ndarray::{arr2, concatenate, Axis};
///
/// let a = arr2(&[[1., 2.],
///                [3., 4.]]);
/// assert_eq!(
///     concatenate![Axis(0), a, a],
///     arr2(&[[1., 2.],
///            [3., 4.],
///            [1., 2.],
///            [3., 4.]]),
/// );
/// assert_eq!(
///     concatenate![Axis(1), a, a,],
///     arr2(&[[1., 2., 1., 2.],
///            [3., 4., 3., 4.]]),
/// );
/// ```
#[macro_export]
macro_rules! concatenate {
    // One arm handles both forms: `$(,)?` swallows an optional trailing comma.
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::concatenate($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
    };
}
267
268
/// Stack arrays along the new axis.
///
/// Uses the [`stack_new_axis()`] function, calling `ArrayView::from(&a)` on each
/// argument `a`. A trailing comma after the last array is accepted, like the
/// `stack!` and `concatenate!` macros.
///
/// ***Panics*** if the `stack_new_axis` function would return an error.
///
/// ```
/// use ndarray::{arr2, arr3, stack_new_axis, Axis};
///
/// let a = arr2(&[[2., 2.],
///                [3., 3.]]);
/// assert!(
///     stack_new_axis![Axis(0), a, a]
///     == arr3(&[[[2., 2.],
///                [3., 3.]],
///               [[2., 2.],
///                [3., 3.]]])
/// );
/// assert!(
///     stack_new_axis![Axis(0), a, a,]
///     == stack_new_axis![Axis(0), a, a]
/// );
/// ```
#[macro_export]
#[deprecated(note="Use under the name stack instead.", since="0.15.0")]
macro_rules! stack_new_axis {
    // `$(,)?` accepts an optional trailing comma; callers without one are
    // unaffected.
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::stack_new_axis($axis, &[ $($crate::ArrayView::from(&$array) ),* ]).unwrap()
    };
}