Coverage Report

Created: 2025-12-20 06:48

/rust/registry/src/index.crates.io-1949cf8c6b5b557f/zune-jpeg-0.5.7/src/idct/avx2.rs
  Line|  Count|Source
     1|       |/*
     2|       | * Copyright (c) 2023.
     3|       | *
     4|       | * This software is free software;
     5|       | *
     6|       | * You can redistribute it or modify it under terms of the MIT, Apache License or Zlib license
     7|       | */
     8|       |
     9|       |#![cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    10|       |//! AVX optimised IDCT.
    11|       |//!
    12|       |//! Okay, not *that* optimised.
    13|       |//!
    14|       |//!
    15|       |//! # The implementation
    16|       |//! The implementation is neatly broken down into two operations.
    17|       |//!
    18|       |//! 1. Test for zeroes
    19|       |//! > There is a shortcut method for the IDCT: when all AC values are zero, we can get the answer quickly
    20|       |//! > by scaling 1/8th of the DC coefficient of the block to the whole block and level shifting.
    21|       |//!
    22|       |//! 2. If the above fails, we carry out the IDCT as a two-pass, one-dimensional algorithm.
    23|       |//! It does two whole passes, carrying out the 1-D IDCT on all items.
    24|       |//! After the first pass, the data is transposed in registers (thank you, x86 SIMD powers) and the second
    25|       |//! pass is carried out.
    26|       |//!
    27|       |//! The code is not super optimised; it produces bit-identical results with the scalar code (hence the
    28|       |//! `_mm256_add_epi16`-based integer arithmetic)
    29|       |//! and it also has the advantage of making this implementation easy to maintain.
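As a concrete illustration of the shortcut in step 1, the whole 8x8 block
collapses to a single value derived from the DC coefficient. A minimal scalar
sketch (the helper name is illustrative, not part of this crate):

    /// Scalar model of the AC-zero shortcut: descale by 8 with rounding
    /// (+4 == 0.5 * (1 << 3)), level-shift by +128 (pre-scaled as
    /// 1024 == 128 << 3), then clamp to 0..=255.
    fn dc_only_idct(dc: i32) -> i16 {
        ((dc + 4 + 1024) >> 3).clamp(0, 255) as i16
    }

Every one of the 64 output samples is then this one value, which the eight
store! invocations in idct_avx2 below broadcast with _mm_set1_epi16.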
    30|       |
    31|       |#![cfg(feature = "x86")]
    32|       |#![allow(dead_code)]
    33|       |
    34|       |#[cfg(target_arch = "x86")]
    35|       |use core::arch::x86::*;
    36|       |#[cfg(target_arch = "x86_64")]
    37|       |use core::arch::x86_64::*;
    38|       |
    39|       |use crate::unsafe_utils::{transpose, YmmRegister};
    40|       |
    41|       |const SCALE_BITS: i32 = 512 + 65536 + (128 << 17);
    42|       |
    43|       |// Pack i32's into i16's,
    44|       |// clamp them to be between 0-255,
    45|       |// undo the shuffling introduced by packing,
    46|       |// and store back to the output array.
    47|       |macro_rules! permute_store {
    48|       |    ($x:tt,$y:tt,$index:tt,$out:tt,$stride:tt) => {
    49|       |        let a = _mm256_packs_epi32($x, $y);
    50|       |
    51|       |        // Clamp the values after packing; we can clamp more values at once
    52|       |        let b = clamp_avx(a);
    53|       |
    54|       |        // Undo shuffling
    55|       |        let c = _mm256_permute4x64_epi64(b, shuffle(3, 1, 2, 0));
    56|       |
    57|       |        // store first vector
    58|       |        _mm_storeu_si128(
    59|       |            ($out)
    60|       |                .get_mut($index..$index + 8)
    61|       |                .unwrap()
    62|       |                .as_mut_ptr()
    63|       |                .cast(),
    64|       |            _mm256_extractf128_si256::<0>(c),
    65|       |        );
    66|       |        $index += $stride;
    67|       |        // second vector
    68|       |        _mm_storeu_si128(
    69|       |            ($out)
    70|       |                .get_mut($index..$index + 8)
    71|       |                .unwrap()
    72|       |                .as_mut_ptr()
    73|       |                .cast(),
    74|       |            _mm256_extractf128_si256::<1>(c),
    75|       |        );
    76|       |        $index += $stride;
    77|       |    };
    78|       |}
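Why the permute: _mm256_packs_epi32(x, y) packs within each 128-bit lane, so
the four packed qwords come out as [x0..3, y0..3, x4..7, y4..7]. The control
shuffle(3, 1, 2, 0) == 0b11_01_10_00 == 0xD8 selects qwords (0, 2, 1, 3) from
low to high, restoring [x0..7, y0..7] so each row is stored contiguously. A
scalar model of the macro's net effect (illustrative only):

    /// Pack two rows of i32 coefficients the way permute_store! does:
    /// narrow each value to 0..=255 and keep the two rows contiguous.
    fn pack_two_rows(x: [i32; 8], y: [i32; 8]) -> [i16; 16] {
        let mut out = [0i16; 16];
        for (o, v) in out.iter_mut().zip(x.iter().chain(y.iter())) {
            // packs saturates to i16; clamp_avx then narrows to 0..=255
            *o = v.clamp(0, 255) as i16;
        }
        out
    }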
    79|       |
    80|       |#[target_feature(enable = "avx2")]
    81|       |#[allow(
    82|       |    clippy::too_many_lines,
    83|       |    clippy::cast_possible_truncation,
    84|       |    clippy::similar_names,
    85|       |    clippy::op_ref,
    86|       |    unused_assignments,
    87|       |    clippy::zero_prefixed_literal
    88|       |)]
    89|   424M|pub unsafe fn idct_avx2(
    90|   424M|    in_vector: &mut [i32; 64], out_vector: &mut [i16], stride: usize,
    91|   424M|) {
    92|   424M|    let mut pos = 0;
    93|       |
    94|       |    // load into registers
    95|       |    //
    96|       |    // We sign extend the i16's to i32's, calculate with extended precision and
    97|       |    // later reduce back to i16's when we are done carrying out the IDCT
    98|       |
    99|   424M|    let rw0 = _mm256_loadu_si256(in_vector[00..].as_ptr().cast());
   100|   424M|    let rw1 = _mm256_loadu_si256(in_vector[08..].as_ptr().cast());
   101|   424M|    let rw2 = _mm256_loadu_si256(in_vector[16..].as_ptr().cast());
   102|   424M|    let rw3 = _mm256_loadu_si256(in_vector[24..].as_ptr().cast());
   103|   424M|    let rw4 = _mm256_loadu_si256(in_vector[32..].as_ptr().cast());
   104|   424M|    let rw5 = _mm256_loadu_si256(in_vector[40..].as_ptr().cast());
   105|   424M|    let rw6 = _mm256_loadu_si256(in_vector[48..].as_ptr().cast());
   106|   424M|    let rw7 = _mm256_loadu_si256(in_vector[56..].as_ptr().cast());
   107|       |
   108|       |    // Forward DCT and quantization may cause all the AC terms to be zero;
   109|       |    // for such cases we can accelerate the IDCT.
   110|       |
   111|       |    // The gist is that whenever the array has 63 zeroes, its IDCT is
   112|       |    // (arr[0] >> 3), i.e. arr[0] / 8, propagated to all the elements.
   113|       |    // We first test whether the AC terms are all zero and, if they are, take the
   114|       |    // short path.
   115|       |    //
   116|       |    // This reduces IDCT overhead from about 39% to 18%, almost half.
   117|       |
   118|       |    // Do another load for the first row, offset by one: we don't want to include
   119|       |    // the DC value, because we only care about the AC terms.
   120|   424M|    let rw8 = _mm256_loadu_si256(in_vector[1..].as_ptr().cast());
   121|       |
   122|   424M|    let mut bitmap = _mm256_or_si256(rw1, rw2);
   123|   424M|    bitmap = _mm256_or_si256(bitmap, rw3);
   124|   424M|    bitmap = _mm256_or_si256(bitmap, rw4);
   125|   424M|    bitmap = _mm256_or_si256(bitmap, rw5);
   126|   424M|    bitmap = _mm256_or_si256(bitmap, rw6);
   127|   424M|    bitmap = _mm256_or_si256(bitmap, rw7);
   128|   424M|    bitmap = _mm256_or_si256(bitmap, rw8);
   129|       |
   130|   424M|    if _mm256_testz_si256(bitmap, bitmap) == 1 {
   131|       |        // AC terms all zero; the IDCT of the block is (coeff[0] * qt[0]) / 8 + 128 (bias),
   132|       |        // clamped to 0..=255.
   133|       |        // Round by adding 0.5 * (1 << 3) = 4 and offset by adding (128 << 3) = 1024 before scaling.
   134|   378M|        let coeff = ((in_vector[0] + 4 + 1024) >> 3).clamp(0, 255) as i16;
   135|   378M|        let idct_value = _mm_set1_epi16(coeff);
   136|       |
   137|       |        macro_rules! store {
   138|       |            ($pos:tt,$value:tt) => {
   139|       |                // store
   140|       |                _mm_storeu_si128(
   141|       |                    out_vector
   142|   378M|                        .get_mut($pos..$pos + 8)
   143|       |                        .unwrap()
   144|       |                        .as_mut_ptr()
   145|       |                        .cast(),
   146|       |                    $value,
   147|       |                );
   148|       |                $pos += stride;
   149|       |            };
   150|       |        }
   151|   378M|        store!(pos, idct_value);
   152|   378M|        store!(pos, idct_value);
   153|   378M|        store!(pos, idct_value);
   154|   378M|        store!(pos, idct_value);
   155|       |
   156|   378M|        store!(pos, idct_value);
   157|   378M|        store!(pos, idct_value);
   158|   378M|        store!(pos, idct_value);
   159|   378M|        store!(pos, idct_value);
   160|       |
   161|   378M|        return;
   162|  45.4M|    }
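The bitmap test above relies on _mm256_testz_si256(x, x) returning 1 exactly
when every bit of x is zero, so OR-ing rows 1..=7 together with the reload of
row 0 shifted past the DC term detects an all-zero AC block in a handful of
instructions. A scalar equivalent of the same test:

    /// Scalar model of the zero test: true when all 63 AC coefficients
    /// (everything except in_vector[0], the DC term) are zero.
    fn ac_terms_all_zero(in_vector: &[i32; 64]) -> bool {
        in_vector[1..].iter().fold(0, |acc, &v| acc | v) == 0
    }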
   163|       |
   164|  45.4M|    let mut row0 = YmmRegister { mm256: rw0 };
   165|  45.4M|    let mut row1 = YmmRegister { mm256: rw1 };
   166|  45.4M|    let mut row2 = YmmRegister { mm256: rw2 };
   167|  45.4M|    let mut row3 = YmmRegister { mm256: rw3 };
   168|       |
   169|  45.4M|    let mut row4 = YmmRegister { mm256: rw4 };
   170|  45.4M|    let mut row5 = YmmRegister { mm256: rw5 };
   171|  45.4M|    let mut row6 = YmmRegister { mm256: rw6 };
   172|  45.4M|    let mut row7 = YmmRegister { mm256: rw7 };
   173|       |
   174|       |    macro_rules! dct_pass {
   175|       |        ($SCALE_BITS:tt,$scale:tt) => {
   176|       |            // There are a lot of ways to do this,
   177|       |            // but to keep it simple (and beautiful), I'll make a direct translation of the
   178|       |            // scalar code, which also keeps this code fully transparent (this version and the
   179|       |            // non-AVX one should produce bit-identical results).
   180|       |
   181|       |            // even part
   182|       |            let p1 = (row2 + row6) * 2217;
   183|       |
   184|       |            let mut t2 = p1 + row6 * -7567;
   185|       |            let mut t3 = p1 + row2 * 3135;
   186|       |
   187|       |            let mut t0 = YmmRegister {
   188|       |                mm256: _mm256_slli_epi32((row0 + row4).mm256, 12),
   189|       |            };
   190|       |            let mut t1 = YmmRegister {
   191|       |                mm256: _mm256_slli_epi32((row0 - row4).mm256, 12),
   192|       |            };
   193|       |
   194|       |            let x0 = t0 + t3 + $SCALE_BITS;
   195|       |            let x3 = t0 - t3 + $SCALE_BITS;
   196|       |            let x1 = t1 + t2 + $SCALE_BITS;
   197|       |            let x2 = t1 - t2 + $SCALE_BITS;
   198|       |
   199|       |            let p3 = row7 + row3;
   200|       |            let p4 = row5 + row1;
   201|       |            let p1 = row7 + row1;
   202|       |            let p2 = row5 + row3;
   203|       |            let p5 = (p3 + p4) * 4816;
   204|       |
   205|       |            t0 = row7 * 1223;
   206|       |            t1 = row5 * 8410;
   207|       |            t2 = row3 * 12586;
   208|       |            t3 = row1 * 6149;
   209|       |
   210|       |            let p1 = p5 + p1 * -3685;
   211|       |            let p2 = p5 + (p2 * -10497);
   212|       |            let p3 = p3 * -8034;
   213|       |            let p4 = p4 * -1597;
   214|       |
   215|       |            t3 += p1 + p4;
   216|       |            t2 += p2 + p3;
   217|       |            t1 += p2 + p4;
   218|       |            t0 += p1 + p3;
   219|       |
   220|       |            row0.mm256 = _mm256_srai_epi32((x0 + t3).mm256, $scale);
   221|       |            row1.mm256 = _mm256_srai_epi32((x1 + t2).mm256, $scale);
   222|       |            row2.mm256 = _mm256_srai_epi32((x2 + t1).mm256, $scale);
   223|       |            row3.mm256 = _mm256_srai_epi32((x3 + t0).mm256, $scale);
   224|       |
   225|       |            row4.mm256 = _mm256_srai_epi32((x3 - t0).mm256, $scale);
   226|       |            row5.mm256 = _mm256_srai_epi32((x2 - t1).mm256, $scale);
   227|       |            row6.mm256 = _mm256_srai_epi32((x1 - t2).mm256, $scale);
   228|       |            row7.mm256 = _mm256_srai_epi32((x0 - t3).mm256, $scale);
   229|       |        };
   230|       |    }
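The integer multipliers in dct_pass! are the scalar IDCT's floating-point
rotation constants in 12-bit fixed point (the comment inside idct_avx2_4x4
below notes that "constants scaled things up by 1<<12"). A sketch of the
conversion, with a hypothetical helper named f2f after the usual C idiom:

    /// Hypothetical float-to-fixed helper: scale by 1 << 12, add 0.5, then
    /// truncate toward zero (what `as i32` does), mirroring the C macro.
    fn f2f(x: f64) -> i32 {
        (x * 4096.0 + 0.5) as i32
    }

    // For example, for the even part of the pass:
    // f2f(0.541196100)  ==  2217
    // f2f(0.765366865)  ==  3135
    // f2f(-1.847759065) == -7567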
   231|       |
   232|       |    // Process rows
   233|  45.4M|    dct_pass!(512, 10);
   234|  45.4M|    transpose(
   235|  45.4M|        &mut row0, &mut row1, &mut row2, &mut row3, &mut row4, &mut row5, &mut row6, &mut row7,
   236|       |    );
   237|       |
   238|       |    // Process columns
   239|  45.4M|    dct_pass!(SCALE_BITS, 17);
   240|  45.4M|    transpose(
   241|  45.4M|        &mut row0, &mut row1, &mut row2, &mut row3, &mut row4, &mut row5, &mut row6, &mut row7,
   242|       |    );
   243|       |    // Pack and write the values back to the array
   244|  45.4M|    permute_store!((row0.mm256), (row1.mm256), pos, out_vector, stride);
   245|  45.4M|    permute_store!((row2.mm256), (row3.mm256), pos, out_vector, stride);
   246|  45.4M|    permute_store!((row4.mm256), (row5.mm256), pos, out_vector, stride);
   247|  45.4M|    permute_store!((row6.mm256), (row7.mm256), pos, out_vector, stride);
   248|   424M|}
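Because idct_avx2 is an unsafe #[target_feature] function, callers normally
gate it behind runtime CPU detection. A minimal dispatcher sketch, assuming
std is available and a scalar fallback named idct_scalar exists (the fallback
name is hypothetical, not part of this file):

    fn idct(in_vector: &mut [i32; 64], out_vector: &mut [i16], stride: usize) {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: we just confirmed this CPU supports AVX2.
            unsafe { idct_avx2(in_vector, out_vector, stride) }
        } else {
            idct_scalar(in_vector, out_vector, stride) // hypothetical scalar fallback
        }
    }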
   249|       |
   250|       |
   251|       |#[target_feature(enable = "avx2")]
   252|       |#[allow(
   253|       |    clippy::too_many_lines,
   254|       |    clippy::cast_possible_truncation,
   255|       |    clippy::similar_names,
   256|       |    clippy::op_ref,
   257|       |    unused_assignments,
   258|       |    clippy::zero_prefixed_literal
   259|       |)]
   260|  1.88M|pub unsafe fn idct_avx2_4x4(
   261|  1.88M|    in_vector: &mut [i32; 64], out_vector: &mut [i16], stride: usize,
   262|  1.88M|) {
   263|  1.88M|    let rw0 = _mm256_loadu_si256(in_vector[00..].as_ptr().cast());
   264|  1.88M|    let rw1 = _mm256_loadu_si256(in_vector[08..].as_ptr().cast());
   265|  1.88M|    let rw2 = _mm256_loadu_si256(in_vector[16..].as_ptr().cast());
   266|  1.88M|    let rw3 = _mm256_loadu_si256(in_vector[24..].as_ptr().cast());
   267|       |
   268|  1.88M|    let mut row0 = YmmRegister { mm256: rw0 };
   269|  1.88M|    let mut row1 = YmmRegister { mm256: rw1 };
   270|  1.88M|    let mut row2 = YmmRegister { mm256: rw2 };
   271|  1.88M|    let mut row3 = YmmRegister { mm256: rw3 };
   272|       |
   273|  1.88M|    let mut row4 = YmmRegister { mm256: rw0 };
   274|  1.88M|    let mut row5 = YmmRegister { mm256: rw0 };
   275|  1.88M|    let mut row6 = YmmRegister { mm256: rw0 };
   276|  1.88M|    let mut row7 = YmmRegister { mm256: rw0 };
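      |       |    // row4..row7 start as copies of rw0 purely as placeholders; all four
      |       |    // are fully overwritten by the _mm256_srai_epi32 assignments below.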
   277|       |
   278|  1.88M|    {
   279|  1.88M|        row0.mm256 = _mm256_slli_epi32(row0.mm256, 12);
   280|  1.88M|        row0 += 512;
   281|  1.88M|
   282|  1.88M|        let i2 = row2;
   283|  1.88M|
   284|  1.88M|        let p1 = i2 * 2217;
   285|  1.88M|        let p3 = i2 * 5352;
   286|  1.88M|
   287|  1.88M|        let x0 = row0 + p3;
   288|  1.88M|        let x1 = row0 + p1;
   289|  1.88M|        let x2 = row0 - p1;
   290|  1.88M|        let x3 = row0 - p3;
   291|  1.88M|
   292|  1.88M|        // odd part
   293|  1.88M|        let i4 = row3;
   294|  1.88M|        let i3 = row1;
   295|  1.88M|
   296|  1.88M|        let p5 = (i4 + i3) * 4816;
   297|  1.88M|
   298|  1.88M|        let p1 = p5 + i3 * -3685;
   299|  1.88M|        let p2 = p5 + i4 * -10497;
   300|  1.88M|
   301|  1.88M|        let t3 = p5 + i3 * 867;
   302|  1.88M|        let t2 = p5 + i4 * -5945;
   303|  1.88M|
   304|  1.88M|        let t1 = p2 + i3 * -1597;
   305|  1.88M|        let t0 = p1 + i4 * -8034;
   306|  1.88M|
   307|  1.88M|        row0.mm256 = _mm256_srai_epi32((x0 + t3).mm256, 10);
   308|  1.88M|        row1.mm256 = _mm256_srai_epi32((x1 + t2).mm256, 10);
   309|  1.88M|        row2.mm256 = _mm256_srai_epi32((x2 + t1).mm256, 10);
   310|  1.88M|        row3.mm256 = _mm256_srai_epi32((x3 + t0).mm256, 10);
   311|  1.88M|
   312|  1.88M|        row4.mm256 = _mm256_srai_epi32((x3 - t0).mm256, 10);
   313|  1.88M|        row5.mm256 = _mm256_srai_epi32((x2 - t1).mm256, 10);
   314|  1.88M|        row6.mm256 = _mm256_srai_epi32((x1 - t2).mm256, 10);
   315|  1.88M|        row7.mm256 = _mm256_srai_epi32((x0 - t3).mm256, 10);
   316|  1.88M|    }
   317|       |
   318|  1.88M|    transpose(
   319|  1.88M|        &mut row0, &mut row1, &mut row2, &mut row3, &mut row4, &mut row5, &mut row6, &mut row7,
   320|       |    );
   321|       |
   322|  1.88M|    {
   323|  1.88M|        let i2 = row2;
   324|  1.88M|        let i0 = row0;
   325|  1.88M|
   326|  1.88M|        row0.mm256 = _mm256_slli_epi32(i0.mm256, 12);
   327|  1.88M|        let t0 = row0 + SCALE_BITS;
   328|  1.88M|
   329|  1.88M|        let t2 = i2 * 2217;
   330|  1.88M|        let t3 = i2 * 5352;
   331|  1.88M|
   332|  1.88M|        // Constants scaled things up by 1<<12, plus we had 1<<2 from the first
   333|  1.88M|        // pass, plus horizontal and vertical each scale by sqrt(8), so together
   334|  1.88M|        // we've got an extra 1<<3; that is 1<<17 total we need to remove.
   335|  1.88M|        // So we want to round that, which means adding 0.5 * (1<<17),
   336|  1.88M|        // aka 65536. Also, we'll end up with values in -128..=127 that we want
   337|  1.88M|        // to encode as 0..=255 by adding 128, so we add that before the shift.
   338|  1.88M|        // The rounding constant is already added into `t0` via SCALE_BITS.
   339|  1.88M|        let x0 = t0 + t3;
   340|  1.88M|        let x3 = t0 - t3;
   341|  1.88M|        let x1 = t0 + t2;
   342|  1.88M|        let x2 = t0 - t2;
   343|  1.88M|
   344|  1.88M|        // odd part
   345|  1.88M|        let i3 = row3;
   346|  1.88M|        let i1 = row1;
   347|  1.88M|
   348|  1.88M|        let p5 = (i3 + i1) * 4816;
   349|  1.88M|
   350|  1.88M|        let p1 = p5 + i1 * -3685;
   351|  1.88M|        let p2 = p5 + i3 * -10497;
   352|  1.88M|
   353|  1.88M|        let t3 = p5 + i1 * 867;
   354|  1.88M|        let t2 = p5 + i3 * -5945;
   355|  1.88M|
   356|  1.88M|        let t1 = p2 + i1 * -1597;
   357|  1.88M|        let t0 = p1 + i3 * -8034;
   358|  1.88M|
   359|  1.88M|        row0.mm256 = _mm256_srai_epi32((x0 + t3).mm256, 17);
   360|  1.88M|        row1.mm256 = _mm256_srai_epi32((x1 + t2).mm256, 17);
   361|  1.88M|        row2.mm256 = _mm256_srai_epi32((x2 + t1).mm256, 17);
   362|  1.88M|        row3.mm256 = _mm256_srai_epi32((x3 + t0).mm256, 17);
   363|  1.88M|        row4.mm256 = _mm256_srai_epi32((x3 - t0).mm256, 17);
   364|  1.88M|        row5.mm256 = _mm256_srai_epi32((x2 - t1).mm256, 17);
   365|  1.88M|        row6.mm256 = _mm256_srai_epi32((x1 - t2).mm256, 17);
   366|  1.88M|        row7.mm256 = _mm256_srai_epi32((x0 - t3).mm256, 17);
   367|  1.88M|    }
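The bookkeeping in that comment lines up with the SCALE_BITS constant defined
near the top of the file; a sketch of the accounting (the role of the extra
512 inside SCALE_BITS is an assumption on my part, the rest follows from the
comment above):

    // SCALE_BITS = 512 + 65536 + (128 << 17):
    //   65536     == 0.5 * (1 << 17)   rounding for the final >> 17
    //   128 << 17                      +128 level shift, pre-scaled by 1 << 17
    //   512                            an extra rounding bias; compare
    //                                  dct_pass!(512, 10), where
    //                                  512 == 0.5 * (1 << 10) rounds the >> 10
    // Shift budget: 12 (fixed-point constants) + 2 (left over from pass one:
    // << 12 applied, >> 10 removed) + 3 (each 1-D pass scales by sqrt(8),
    // 8 == 1 << 3 together) == 17 bits, hence the final >> 17.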
   368|       |
   369|  1.88M|    transpose(
   370|  1.88M|        &mut row0, &mut row1, &mut row2, &mut row3, &mut row4, &mut row5, &mut row6, &mut row7,
   371|       |    );
   372|       |
   373|  1.88M|    let mut pos = 0;
   374|       |
   375|       |    // Pack and write the values back to the array
   376|  1.88M|    permute_store!((row0.mm256), (row1.mm256), pos, out_vector, stride);
   377|  1.88M|    permute_store!((row2.mm256), (row3.mm256), pos, out_vector, stride);
   378|  1.88M|    permute_store!((row4.mm256), (row5.mm256), pos, out_vector, stride);
   379|  1.88M|    permute_store!((row6.mm256), (row7.mm256), pos, out_vector, stride);
   380|  1.88M|}
   381|       |
   382|       |#[inline]
   383|       |#[target_feature(enable = "avx2")]
   384|   189M|unsafe fn clamp_avx(reg: __m256i) -> __m256i {
   385|   189M|    let min_s = _mm256_set1_epi16(0);
   386|   189M|    let max_s = _mm256_set1_epi16(255);
   387|       |
   388|   189M|    let max_v = _mm256_max_epi16(reg, min_s); // max(a, 0)
   389|   189M|    let min_v = _mm256_min_epi16(max_v, max_s); // min(max(a, 0), 255)
   390|   189M|    return min_v;
   391|   189M|}
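Per 16-bit lane this is the usual max-then-min clamp; a scalar model of what
clamp_avx computes for each of the sixteen lanes:

    /// Scalar per-lane model of clamp_avx: max(v, 0), then min(_, 255),
    /// i.e. v.clamp(0, 255).
    fn clamp_lane(v: i16) -> i16 {
        v.clamp(0, 255)
    }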
   392|       |
   393|       |/// A copy of `_MM_SHUFFLE()` that doesn't require
   394|       |/// a nightly compiler
   395|       |#[inline]
   396|      0|const fn shuffle(z: i32, y: i32, x: i32, w: i32) -> i32 {
   397|      0|    (z << 6) | (y << 4) | (x << 2) | w
   398|      0|}
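A worked example of the control value permute_store! builds with this helper
(the zero execution count on shuffle is expected if the call is evaluated at
compile time rather than at runtime):

    // Four 2-bit selectors, highest first, exactly like _MM_SHUFFLE:
    // shuffle(3, 1, 2, 0) == (3 << 6) | (1 << 4) | (2 << 2) | 0
    //                     == 0b11_01_10_00 == 0xD8,
    // which picks source qwords (0, 2, 1, 3) from low to high.
    assert_eq!(shuffle(3, 1, 2, 0), 0xD8);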