Coverage Report

Created: 2025-10-14 06:57

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/rust/registry/src/index.crates.io-1949cf8c6b5b557f/zune-jpeg-0.4.21/src/idct/avx2.rs
Line
Count
Source
1
/*
2
 * Copyright (c) 2023.
3
 *
4
 * This software is free software;
5
 *
6
 * You can redistribute it or modify it under terms of the MIT, Apache License or Zlib license
7
 */
8
9
#![cfg(any(target_arch = "x86", target_arch = "x86_64"))]
10
//! AVX optimised IDCT.
11
//!
12
//! Okay, not *that* optimised.
13
//!
14
//!
15
//! # The implementation
16
//! The implementation is neatly broken down into two operations.
17
//!
18
//! 1. Test for zeroes
19
//! > There is a shortcut for the IDCT: when all AC values are zero, we can get the answer quickly
20
//!  by scaling 1/8th of the block's DC coefficient to the whole block and level shifting.
21
//!
22
//! 2. If above fails, we proceed to carry out IDCT as a two pass one dimensional algorithm.
23
//! It does two whole passes, each carrying out a 1-D IDCT on all eight rows.
24
//! After each pass the data is transposed in registers (thank you, x86 SIMD), so the second
25
//! pass effectively operates on the columns.
26
//!
27
//! The code is not super optimized: it does its arithmetic in 32-bit lanes so that it produces
28
//! bit-identical results with the scalar code (rather than using narrower 16-bit operations
29
//! such as `_mm256_add_epi16`), and this also has the advantage of making it easy to maintain.
30
31
#![cfg(feature = "x86")]
32
#![allow(dead_code)]
33
34
#[cfg(target_arch = "x86")]
35
use core::arch::x86::*;
36
#[cfg(target_arch = "x86_64")]
37
use core::arch::x86_64::*;
38
39
use crate::unsafe_utils::{transpose, YmmRegister};
40
41
/// Bias added in the second (column) IDCT pass before its final `>> 17`.
///
/// * `128 << 17` becomes exactly `+128` after the shift (the level shift).
/// * `1 << 16` is half of `1 << 17`, giving round-to-nearest for the shift.
/// * `512` is the same bias the first (row) pass uses. NOTE(review): presumably
///   kept so results stay bit-identical with the scalar implementation; confirm
///   before changing.
const SCALE_BITS: i32 = (128 << 17) + (1 << 16) + 512;
42
43
/// SAFETY
44
/// ------
45
///
46
/// It is the responsibility of the CALLER to ensure that  this function is
47
/// called in contexts where the CPU supports it
48
///
49
///
50
/// For documentation see module docs.
51
52
231M
pub fn idct_avx2(in_vector: &mut [i32; 64], out_vector: &mut [i16], stride: usize) {
53
231M
    unsafe {
54
231M
        // We don't call this method directly because we need to flag the code function
55
231M
        // with #[target_feature] so that the compiler does do weird stuff with
56
231M
        // it
57
231M
        idct_int_avx2_inner(in_vector, out_vector, stride);
58
231M
    }
59
231M
}
60
61
/// Integer inverse DCT of one 8x8 block using AVX2.
///
/// `in_vector` is read as eight rows of eight `i32` coefficients
/// (`in_vector[8 * r..8 * r + 8]` is row `r`). It is only read here, never
/// written. The result is eight rows of eight `i16` samples clamped to
/// `0..=255`, and row `r` is stored at `out_vector[r * stride..r * stride + 8]`.
///
/// There are two paths:
/// * If every coefficient except `in_vector[0]` is zero, all 64 outputs are
///   `((in_vector[0] + 4 + 1024) >> 3).clamp(0, 255)`.
/// * Otherwise the code runs a 1-D pass over the rows (bias 512, `>> 10`) and
///   transposes. It then runs a second pass (bias `SCALE_BITS`, `>> 17`) and
///   transposes back. Finally it packs the results to `i16`, clamps them and
///   stores them.
///
/// # Panics
///
/// Panics if `out_vector.len() < 7 * stride + 8`, because the row stores use
/// `get_mut(..).unwrap()`.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.
#[target_feature(enable = "avx2")]
#[allow(
    clippy::too_many_lines,
    clippy::cast_possible_truncation,
    clippy::similar_names,
    clippy::op_ref,
    unused_assignments,
    clippy::zero_prefixed_literal
)]
pub unsafe fn idct_int_avx2_inner(
    in_vector: &mut [i32; 64], out_vector: &mut [i16], stride: usize
) {
    // Output index of the next row to store; advanced by `stride` per row.
    let mut pos = 0;

    // Load the eight rows into registers.
    //
    // All arithmetic below is done in 32-bit lanes for extra precision. The
    // results are reduced to i16 only at the end, when they are packed for
    // storing.

    let rw0 = _mm256_loadu_si256(in_vector[00..].as_ptr().cast());
    let rw1 = _mm256_loadu_si256(in_vector[08..].as_ptr().cast());
    let rw2 = _mm256_loadu_si256(in_vector[16..].as_ptr().cast());
    let rw3 = _mm256_loadu_si256(in_vector[24..].as_ptr().cast());
    let rw4 = _mm256_loadu_si256(in_vector[32..].as_ptr().cast());
    let rw5 = _mm256_loadu_si256(in_vector[40..].as_ptr().cast());
    let rw6 = _mm256_loadu_si256(in_vector[48..].as_ptr().cast());
    let rw7 = _mm256_loadu_si256(in_vector[56..].as_ptr().cast());

    // Forward DCT and quantization often leave all the AC terms zero, and
    // such blocks can take a shortcut.
    //
    // When all 63 AC terms are zero, every output sample is the same value:
    // the DC term divided by 8 (rounded), plus the level shift. We first test
    // whether all AC terms are zero and, if so, take the short way.
    //
    // Per the original author's measurement, this reduces IDCT overhead from
    // about 39% to 18%.

    // Load row 0 again, offset by one element (in_vector[1..9]), so the DC
    // value in_vector[0] is excluded from the zero test. rw0 itself is not
    // tested. in_vector[8] is checked twice, which is harmless.
    let rw8 = _mm256_loadu_si256(in_vector[1..].as_ptr().cast());

    let zero = _mm256_setzero_si256();

    // Sum of the byte masks. Each compare sets a lane to all-ones when it equals
    // zero, so a fully-zero register gives movemask 0xFFFF_FFFF, i.e. -1 as
    // i32. Comparing 64-bit lanes (rows 3..=7) detects zero just as well,
    // because a 64-bit lane is zero iff both of its 32-bit halves are. Eight
    // all-zero registers therefore sum to -8.
    // NOTE(review): for these mask shapes the sum reaches -8 only when every
    // mask is -1, so the test is exact. AND-ing the compare results (or OR-ing
    // the rows and using `_mm256_testz_si256`) would make that more obvious.
    let mut non_zero = 0;

    non_zero += _mm256_movemask_epi8(_mm256_cmpeq_epi32(rw8, zero));
    non_zero += _mm256_movemask_epi8(_mm256_cmpeq_epi32(rw1, zero));
    non_zero += _mm256_movemask_epi8(_mm256_cmpeq_epi32(rw2, zero));
    non_zero += _mm256_movemask_epi8(_mm256_cmpeq_epi64(rw3, zero));

    non_zero += _mm256_movemask_epi8(_mm256_cmpeq_epi64(rw4, zero));
    non_zero += _mm256_movemask_epi8(_mm256_cmpeq_epi64(rw5, zero));
    non_zero += _mm256_movemask_epi8(_mm256_cmpeq_epi64(rw6, zero));
    non_zero += _mm256_movemask_epi8(_mm256_cmpeq_epi64(rw7, zero));

    if non_zero == -8 {
        // All AC terms are zero. Every output sample is in_vector[0] / 8,
        // rounded (`+ 4`), plus the 128 level shift (`+ 1024` == 128 << 3),
        // clamped to 0..=255.
        // NOTE(review): assumes in_vector[0] already holds the dequantized DC
        // value (coeff[0] * qt[0]); confirm in the caller.
        let coeff = ((in_vector[0] + 4 + 1024) >> 3).clamp(0, 255) as i16;
        let idct_value = _mm_set1_epi16(coeff);

        // Write one 8-sample row at `$pos`, then advance `$pos` by one row.
        macro_rules! store {
            ($pos:tt,$value:tt) => {
                // store
                _mm_storeu_si128(
                    out_vector
                        .get_mut($pos..$pos + 8)
                        .unwrap()
                        .as_mut_ptr()
                        .cast(),
                    $value
                );
                $pos += stride;
            };
        }
        store!(pos, idct_value);
        store!(pos, idct_value);
        store!(pos, idct_value);
        store!(pos, idct_value);

        store!(pos, idct_value);
        store!(pos, idct_value);
        store!(pos, idct_value);
        store!(pos, idct_value);

        return;
    }

    let mut row0 = YmmRegister { mm256: rw0 };
    let mut row1 = YmmRegister { mm256: rw1 };
    let mut row2 = YmmRegister { mm256: rw2 };
    let mut row3 = YmmRegister { mm256: rw3 };

    let mut row4 = YmmRegister { mm256: rw4 };
    let mut row5 = YmmRegister { mm256: rw5 };
    let mut row6 = YmmRegister { mm256: rw6 };
    let mut row7 = YmmRegister { mm256: rw7 };

    // One 1-D IDCT pass over row0..row7, applied lane-wise, so it transforms
    // eight independent 8-point vectors at once. The results overwrite
    // row0..row7. `$SCALE_BITS` is added as a rounding/offset bias before the
    // final arithmetic right shift by `$scale`.
    macro_rules! dct_pass {
        ($SCALE_BITS:tt,$scale:tt) => {
            // There are many ways to do this, but to keep it simple this is a
            // direct translation of the scalar code. That keeps the code
            // transparent, and this version and the non-AVX one should produce
            // identical results.
            //
            // The integer multipliers look like 12-bit fixed-point factors
            // (e.g. 2217 ~= 0.5411961 * 4096), which matches the `<< 12` applied
            // to the even part. NOTE(review): confirm against the scalar code.

            // even part
            let p1 = (row2 + row6) * 2217;

            let mut t2 = p1 + row6 * -7567;
            let mut t3 = p1 + row2 * 3135;

            let mut t0 = YmmRegister {
                mm256: _mm256_slli_epi32((row0 + row4).mm256, 12)
            };
            let mut t1 = YmmRegister {
                mm256: _mm256_slli_epi32((row0 - row4).mm256, 12)
            };

            let x0 = t0 + t3 + $SCALE_BITS;
            let x3 = t0 - t3 + $SCALE_BITS;
            let x1 = t1 + t2 + $SCALE_BITS;
            let x2 = t1 - t2 + $SCALE_BITS;

            // odd part (rows 1, 3, 5, 7); t0..t3 are reused as accumulators
            let p3 = row7 + row3;
            let p4 = row5 + row1;
            let p1 = row7 + row1;
            let p2 = row5 + row3;
            let p5 = (p3 + p4) * 4816;

            t0 = row7 * 1223;
            t1 = row5 * 8410;
            t2 = row3 * 12586;
            t3 = row1 * 6149;

            let p1 = p5 + p1 * -3685;
            let p2 = p5 + (p2 * -10497);
            let p3 = p3 * -8034;
            let p4 = p4 * -1597;

            t3 += p1 + p4;
            t2 += p2 + p3;
            t1 += p2 + p4;
            t0 += p1 + p3;

            // Butterfly the even and odd halves, then scale down.
            row0.mm256 = _mm256_srai_epi32((x0 + t3).mm256, $scale);
            row1.mm256 = _mm256_srai_epi32((x1 + t2).mm256, $scale);
            row2.mm256 = _mm256_srai_epi32((x2 + t1).mm256, $scale);
            row3.mm256 = _mm256_srai_epi32((x3 + t0).mm256, $scale);

            row4.mm256 = _mm256_srai_epi32((x3 - t0).mm256, $scale);
            row5.mm256 = _mm256_srai_epi32((x2 - t1).mm256, $scale);
            row6.mm256 = _mm256_srai_epi32((x1 - t2).mm256, $scale);
            row7.mm256 = _mm256_srai_epi32((x0 - t3).mm256, $scale);
        };
    }

    // Process rows
    dct_pass!(512, 10);
    // Transpose so the next pass operates on what were the columns.
    transpose(
        &mut row0, &mut row1, &mut row2, &mut row3, &mut row4, &mut row5, &mut row6, &mut row7
    );

    // process columns
    dct_pass!(SCALE_BITS, 17);
    // Transpose back so each register again holds one output row.
    transpose(
        &mut row0, &mut row1, &mut row2, &mut row3, &mut row4, &mut row5, &mut row6, &mut row7
    );

    // Pack i32 to i16's,
    // clamp them to be between 0-255
    // Undo shuffling
    // Store back to array
    //
    // Stores rows `$x` and `$y` as two consecutive output rows at `$index`,
    // advancing `$index` by `stride` after each one.
    macro_rules! permute_store {
        ($x:tt,$y:tt,$index:tt,$out:tt) => {
            // Signed-saturating pack, done per 128-bit lane:
            // [x0..3, y0..3, x4..7, y4..7] as i16.
            let a = _mm256_packs_epi32($x, $y);

            // Clamp the values after packing, we can clamp more values at once
            let b = clamp_avx(a);

            // Undo the per-lane interleave: select 64-bit chunks in order
            // 0, 2, 1, 3, giving [x0..7, y0..7].
            let c = _mm256_permute4x64_epi64(b, shuffle(3, 1, 2, 0));

            // store first vector (row $x)
            _mm_storeu_si128(
                ($out)
                    .get_mut($index..$index + 8)
                    .unwrap()
                    .as_mut_ptr()
                    .cast(),
                _mm256_extractf128_si256::<0>(c)
            );
            $index += stride;
            // second vector (row $y)
            _mm_storeu_si128(
                ($out)
                    .get_mut($index..$index + 8)
                    .unwrap()
                    .as_mut_ptr()
                    .cast(),
                _mm256_extractf128_si256::<1>(c)
            );
            $index += stride;
        };
    }
    // Pack and write the values back to the array
    permute_store!((row0.mm256), (row1.mm256), pos, out_vector);
    permute_store!((row2.mm256), (row3.mm256), pos, out_vector);
    permute_store!((row4.mm256), (row5.mm256), pos, out_vector);
    permute_store!((row6.mm256), (row7.mm256), pos, out_vector);
}
272
273
/// Clamps each of the sixteen signed 16-bit lanes of `reg` to `0..=255`.
///
/// # Safety
///
/// The CPU must support AVX2.
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn clamp_avx(reg: __m256i) -> __m256i {
    let lower = _mm256_set1_epi16(0);
    let upper = _mm256_set1_epi16(255);
    // Raise negative lanes to 0 first, then cap the result at 255:
    // min(max(a, 0), 255).
    _mm256_min_epi16(_mm256_max_epi16(reg, lower), upper)
}
283
284
/// A copy of `_MM_SHUFFLE()` that doesn't require a nightly compiler.
///
/// Packs four 2-bit lane selectors into an 8-bit immediate: `z` goes in bits
/// 7..6, `y` in bits 5..4, `x` in bits 3..2 and `w` in bits 1..0. For example,
/// `shuffle(3, 1, 2, 0) == 0b11_01_10_00`.
///
/// Arguments are expected to be in `0..=3`. Larger values are not masked and
/// spill into the neighbouring fields.
#[inline]
const fn shuffle(z: i32, y: i32, x: i32, w: i32) -> i32 {
    (z << 6) | (y << 4) | (x << 2) | w
}