/src/libjxl/lib/jxl/enc_ans_simd.cc
Line | Count | Source |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #include "lib/jxl/enc_ans_simd.h" |
7 | | |
8 | | #include <cstdint> |
9 | | |
10 | | #include "lib/jxl/base/status.h" |
11 | | #include "lib/jxl/dec_ans.h" |
12 | | #include "lib/jxl/memory_manager_internal.h" |
13 | | |
14 | | #undef HWY_TARGET_INCLUDE |
15 | | #define HWY_TARGET_INCLUDE "lib/jxl/enc_ans_simd.cc" |
16 | | #include <hwy/foreach_target.h> |
17 | | #include <hwy/highway.h> |
18 | | |
19 | | HWY_BEFORE_NAMESPACE(); |
20 | | namespace jxl { |
21 | | namespace HWY_NAMESPACE { |
22 | | |
23 | | // These templates are not found via ADL. |
24 | | using hwy::HWY_NAMESPACE::Add; |
25 | | using hwy::HWY_NAMESPACE::And; |
26 | | using hwy::HWY_NAMESPACE::Ge; |
27 | | using hwy::HWY_NAMESPACE::GetLane; |
28 | | using hwy::HWY_NAMESPACE::Gt; |
29 | | using hwy::HWY_NAMESPACE::IfThenElse; |
30 | | using hwy::HWY_NAMESPACE::IfThenElseZero; |
31 | | using hwy::HWY_NAMESPACE::Iota; |
32 | | using hwy::HWY_NAMESPACE::LoadU; |
33 | | using hwy::HWY_NAMESPACE::Lt; |
34 | | using hwy::HWY_NAMESPACE::Mul; |
35 | | using hwy::HWY_NAMESPACE::Or; |
36 | | using hwy::HWY_NAMESPACE::Set; |
37 | | using hwy::HWY_NAMESPACE::ShiftRight; |
38 | | using hwy::HWY_NAMESPACE::Store; |
39 | | using hwy::HWY_NAMESPACE::Sub; |
40 | | using hwy::HWY_NAMESPACE::Zero; |
41 | | |
42 | | template <size_t E, size_t M, size_t L> |
43 | | uint32_t EstimateTokenCostImpl(uint32_t* JXL_RESTRICT values, size_t len, |
44 | | uint32_t* JXL_RESTRICT out) { |
45 | | const HWY_FULL(uint32_t) du; |
46 | | const HWY_FULL(float) df; |
47 | | const auto kZero = Zero(du); |
48 | | const auto kSplit = Set(du, 1 << E); |
49 | | const auto kExpOffset = Set(du, 127); |
50 | | const auto kEBOffset = Set(du, 127 + M + L); |
51 | | const auto kBase = Set(du, static_cast<uint32_t>((1 << E) - (E << (M + L)))); |
52 | | const auto kMulN = Set(du, 1 << (M + L)); |
53 | | const auto kMaskL = Set(du, (1 << L) - 1); |
54 | | const auto kMaskM = Set(du, ((1 << M) - 1) << L); |
55 | | const auto kLargeThreshold = Set(du, (1 << 22) - 1); |
56 | | constexpr size_t kLargeShiftVal = 10; |
57 | | const auto kLargeShift = Set(du, kLargeShiftVal); |
58 | | |
59 | | auto extra_bits = kZero; |
60 | | size_t last_full = Lanes(du) * (len / Lanes(du)); |
61 | | for (size_t i = 0; i < last_full; i += Lanes(du)) { |
62 | | const auto val = LoadU(du, values + i); |
63 | | const auto is_large = Gt(val, kLargeThreshold); |
64 | | const auto val_shifted = ShiftRight<kLargeShiftVal>(val); |
65 | | const auto not_literal = Ge(val, kSplit); |
66 | | const auto val_fixed = IfThenElse(is_large, val_shifted, val); |
67 | | const auto b = BitCast(du, ConvertTo(df, val_fixed)); |
68 | | const auto l = And(val, kMaskL); |
69 | | const auto exp = ShiftRight<23>(b); |
70 | | const auto exp_fixed = IfThenElse(is_large, Add(exp, kLargeShift), exp); |
71 | | const auto n = Sub(exp_fixed, kExpOffset); |
72 | | const auto eb = Sub(exp_fixed, kEBOffset); |
73 | | const auto m = ShiftRight<23 - M - L>(b); |
74 | | const auto a = Add(kBase, Mul(n, kMulN)); |
75 | | const auto d = And(m, kMaskM); |
76 | | const auto eb_fixed = IfThenElseZero(not_literal, eb); |
77 | | const auto c = Or(a, l); |
78 | | extra_bits = Add(extra_bits, eb_fixed); |
79 | | const auto t = Or(c, d); |
80 | | const auto t_fixed = IfThenElse(not_literal, t, val); |
81 | | Store(t_fixed, du, out + i); |
82 | | } |
83 | | if (last_full < len) { |
84 | | const auto stop = Set(du, len); |
85 | | const auto fence = Iota(du, last_full); |
86 | | const auto take = Lt(fence, stop); |
87 | | const auto val = LoadU(du, values + last_full); |
88 | | const auto is_large = Gt(val, kLargeThreshold); |
89 | | const auto val_shifted = ShiftRight<kLargeShiftVal>(val); |
90 | | const auto not_literal = Ge(val, kSplit); |
91 | | const auto val_fixed = IfThenElse(is_large, val_shifted, val); |
92 | | const auto b = BitCast(du, ConvertTo(df, val_fixed)); |
93 | | const auto l = And(val, kMaskL); |
94 | | const auto exp = ShiftRight<23>(b); |
95 | | const auto exp_fixed = IfThenElse(is_large, Add(exp, kLargeShift), exp); |
96 | | const auto n = Sub(exp_fixed, kExpOffset); |
97 | | const auto eb = Sub(exp_fixed, kEBOffset); |
98 | | const auto m = ShiftRight<23 - M - L>(b); |
99 | | const auto a = Add(kBase, Mul(n, kMulN)); |
100 | | const auto d = And(m, kMaskM); |
101 | | const auto eb_fixed = IfThenElseZero(not_literal, eb); |
102 | | const auto eb_masked = IfThenElseZero(take, eb_fixed); |
103 | | const auto c = Or(a, l); |
104 | | extra_bits = Add(extra_bits, eb_masked); |
105 | | const auto t = Or(c, d); |
106 | | const auto t_fixed = IfThenElse(not_literal, t, val); |
107 | | Store(t_fixed, du, out + last_full); |
108 | | } |
109 | | return GetLane(SumOfLanes(du, extra_bits)); |
110 | | } |
111 | | |
112 | | uint32_t EstimateTokenCost(uint32_t* JXL_RESTRICT values, size_t len, |
113 | 0 | HybridUintConfig cfg, AlignedMemory& tokens) { |
114 | 0 | uint32_t* JXL_RESTRICT out = tokens.address<uint32_t>(); |
115 | 0 | #if HWY_TARGET == HWY_SCALAR |
116 | 0 | uint32_t extra_bits = 0; |
117 | 0 | for (size_t i = 0; i < len; ++i) { |
118 | 0 | uint32_t v = values[i]; |
119 | 0 | uint32_t tok, nbits, bits; |
120 | 0 | cfg.Encode(v, &tok, &nbits, &bits); |
121 | 0 | extra_bits += nbits; |
122 | 0 | out[i] = tok; |
123 | 0 | } |
124 | 0 | return extra_bits; |
125 | | #else |
126 | | if (cfg.split_exponent == 0) { |
127 | | return EstimateTokenCostImpl<0, 0, 0>(values, len, out); |
128 | | } else if (cfg.split_exponent == 2) { |
129 | | JXL_DASSERT((cfg.msb_in_token == 0) && (cfg.lsb_in_token == 1)); |
130 | | return EstimateTokenCostImpl<2, 0, 1>(values, len, out); |
131 | | } else if (cfg.split_exponent == 3) { |
132 | | if (cfg.msb_in_token == 1) { |
133 | | if (cfg.lsb_in_token == 0) { |
134 | | return EstimateTokenCostImpl<3, 1, 0>(values, len, out); |
135 | | } else { |
136 | | JXL_DASSERT(cfg.lsb_in_token == 2); |
137 | | return EstimateTokenCostImpl<3, 1, 2>(values, len, out); |
138 | | } |
139 | | } else { |
140 | | JXL_DASSERT(cfg.msb_in_token == 2); |
141 | | if (cfg.lsb_in_token == 0) { |
142 | | return EstimateTokenCostImpl<3, 2, 0>(values, len, out); |
143 | | } else { |
144 | | JXL_DASSERT(cfg.lsb_in_token == 1); |
145 | | return EstimateTokenCostImpl<3, 2, 1>(values, len, out); |
146 | | } |
147 | | } |
148 | | } else if (cfg.split_exponent == 4) { |
149 | | if (cfg.msb_in_token == 1) { |
150 | | if (cfg.lsb_in_token == 0) { |
151 | | return EstimateTokenCostImpl<4, 1, 0>(values, len, out); |
152 | | } else if (cfg.lsb_in_token == 2) { |
153 | | return EstimateTokenCostImpl<4, 1, 2>(values, len, out); |
154 | | } else { |
155 | | JXL_DASSERT(cfg.lsb_in_token == 3); |
156 | | return EstimateTokenCostImpl<4, 1, 3>(values, len, out); |
157 | | } |
158 | | } else { |
159 | | JXL_DASSERT(cfg.msb_in_token == 2); |
160 | | if (cfg.lsb_in_token == 0) { |
161 | | return EstimateTokenCostImpl<4, 2, 0>(values, len, out); |
162 | | } else if (cfg.lsb_in_token == 1) { |
163 | | return EstimateTokenCostImpl<4, 2, 1>(values, len, out); |
164 | | } else { |
165 | | JXL_DASSERT(cfg.lsb_in_token == 2); |
166 | | return EstimateTokenCostImpl<4, 2, 2>(values, len, out); |
167 | | } |
168 | | } |
169 | | } else if (cfg.split_exponent == 5) { |
170 | | if (cfg.msb_in_token == 1) { |
171 | | if (cfg.lsb_in_token == 0) { |
172 | | return EstimateTokenCostImpl<5, 1, 0>(values, len, out); |
173 | | } else if (cfg.lsb_in_token == 2) { |
174 | | return EstimateTokenCostImpl<5, 1, 2>(values, len, out); |
175 | | } else { |
176 | | JXL_DASSERT(cfg.lsb_in_token == 4); |
177 | | return EstimateTokenCostImpl<5, 1, 4>(values, len, out); |
178 | | } |
179 | | } else { |
180 | | JXL_DASSERT(cfg.msb_in_token == 2); |
181 | | if (cfg.lsb_in_token == 0) { |
182 | | return EstimateTokenCostImpl<5, 2, 0>(values, len, out); |
183 | | } else if (cfg.lsb_in_token == 1) { |
184 | | return EstimateTokenCostImpl<5, 2, 1>(values, len, out); |
185 | | } else if (cfg.lsb_in_token == 2) { |
186 | | return EstimateTokenCostImpl<5, 2, 2>(values, len, out); |
187 | | } else { |
188 | | JXL_DASSERT(cfg.lsb_in_token == 3); |
189 | | return EstimateTokenCostImpl<5, 2, 3>(values, len, out); |
190 | | } |
191 | | } |
192 | | } else if (cfg.split_exponent == 6) { |
193 | | if (cfg.msb_in_token == 0) { |
194 | | JXL_DASSERT(cfg.lsb_in_token == 0); |
195 | | return EstimateTokenCostImpl<6, 0, 0>(values, len, out); |
196 | | } else if (cfg.msb_in_token == 1) { |
197 | | JXL_DASSERT(cfg.lsb_in_token == 5); |
198 | | return EstimateTokenCostImpl<6, 1, 5>(values, len, out); |
199 | | } else { |
200 | | JXL_DASSERT(cfg.msb_in_token == 2); |
201 | | JXL_DASSERT(cfg.lsb_in_token == 4); |
202 | | return EstimateTokenCostImpl<6, 2, 4>(values, len, out); |
203 | | } |
204 | | } else if (cfg.split_exponent >= 7 && cfg.split_exponent <= 12) { |
205 | | JXL_DASSERT(cfg.msb_in_token == 0); |
206 | | JXL_DASSERT(cfg.lsb_in_token == 0); |
207 | | if (cfg.split_exponent == 7) { |
208 | | return EstimateTokenCostImpl<7, 0, 0>(values, len, out); |
209 | | } else if (cfg.split_exponent == 8) { |
210 | | return EstimateTokenCostImpl<8, 0, 0>(values, len, out); |
211 | | } else if (cfg.split_exponent == 9) { |
212 | | return EstimateTokenCostImpl<9, 0, 0>(values, len, out); |
213 | | } else if (cfg.split_exponent == 10) { |
214 | | return EstimateTokenCostImpl<10, 0, 0>(values, len, out); |
215 | | } else if (cfg.split_exponent == 11) { |
216 | | return EstimateTokenCostImpl<11, 0, 0>(values, len, out); |
217 | | } else { |
218 | | return EstimateTokenCostImpl<12, 0, 0>(values, len, out); |
219 | | } |
220 | | } else { |
221 | | JXL_DASSERT(false); |
222 | | } |
223 | | return ~0; |
224 | | #endif |
225 | 0 | } |
226 | | |
227 | | // NOLINTNEXTLINE(google-readability-namespace-comments) |
228 | | } // namespace HWY_NAMESPACE |
229 | | } // namespace jxl |
230 | | HWY_AFTER_NAMESPACE(); |
231 | | |
232 | | #if HWY_ONCE |
233 | | namespace jxl { |
234 | | |
235 | | HWY_EXPORT(EstimateTokenCost); |
236 | | |
237 | | uint32_t EstimateTokenCost(uint32_t* JXL_RESTRICT values, size_t len, |
238 | 0 | HybridUintConfig cfg, AlignedMemory& tokens) { |
239 | 0 | JXL_DASSERT(cfg.lsb_in_token + cfg.msb_in_token <= cfg.split_exponent); |
240 | 0 | return HWY_DYNAMIC_DISPATCH(EstimateTokenCost)(values, len, cfg, tokens); |
241 | 0 | } |
242 | | |
243 | | } // namespace jxl |
244 | | #endif |