/src/libjxl/lib/jxl/enc_quant_weights.cc
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #include "lib/jxl/enc_quant_weights.h" |
7 | | |
8 | | #include <jxl/memory_manager.h> |
9 | | #include <jxl/types.h> |
10 | | |
11 | | #include <cmath> |
12 | | #include <cstdlib> |
13 | | |
14 | | #include "lib/jxl/base/common.h" |
15 | | #include "lib/jxl/base/status.h" |
16 | | #include "lib/jxl/enc_aux_out.h" |
17 | | #include "lib/jxl/enc_bit_writer.h" |
18 | | #include "lib/jxl/enc_modular.h" |
19 | | #include "lib/jxl/fields.h" |
20 | | #include "lib/jxl/modular/encoding/encoding.h" |
21 | | |
22 | | namespace jxl { |
23 | | |
24 | | namespace { |
25 | | |
26 | 0 | Status EncodeDctParams(const DctQuantWeightParams& params, BitWriter* writer) { |
27 | 0 | JXL_ENSURE(params.num_distance_bands >= 1); |
28 | 0 | writer->Write(DctQuantWeightParams::kLog2MaxDistanceBands, |
29 | 0 | params.num_distance_bands - 1); |
30 | 0 | for (size_t c = 0; c < 3; c++) { |
31 | 0 | for (size_t i = 0; i < params.num_distance_bands; i++) { |
32 | 0 | JXL_RETURN_IF_ERROR(F16Coder::Write( |
33 | 0 | params.distance_bands[c][i] * (i == 0 ? (1 / 64.0f) : 1.0f), writer)); |
34 | 0 | } |
35 | 0 | } |
36 | 0 | return true; |
37 | 0 | } |
38 | | |
39 | | Status EncodeQuant(JxlMemoryManager* memory_manager, |
40 | | const QuantEncoding& encoding, size_t idx, size_t size_x, |
41 | | size_t size_y, BitWriter* writer, |
42 | 0 | ModularFrameEncoder* modular_frame_encoder) { |
43 | 0 | writer->Write(kLog2NumQuantModes, encoding.mode); |
44 | 0 | size_x *= kBlockDim; |
45 | 0 | size_y *= kBlockDim; |
46 | 0 | switch (encoding.mode) { |
47 | 0 | case QuantEncoding::kQuantModeLibrary: { |
48 | 0 | writer->Write(kCeilLog2NumPredefinedTables, encoding.predefined); |
49 | 0 | break; |
50 | 0 | } |
51 | 0 | case QuantEncoding::kQuantModeID: { |
52 | 0 | for (size_t c = 0; c < 3; c++) { |
53 | 0 | for (size_t i = 0; i < 3; i++) { |
54 | 0 | JXL_RETURN_IF_ERROR( |
55 | 0 | F16Coder::Write(encoding.idweights[c][i] * (1.0f / 64), writer)); |
56 | 0 | } |
57 | 0 | } |
58 | 0 | break; |
59 | 0 | } |
60 | 0 | case QuantEncoding::kQuantModeDCT2: { |
61 | 0 | for (size_t c = 0; c < 3; c++) { |
62 | 0 | for (size_t i = 0; i < 6; i++) { |
63 | 0 | JXL_RETURN_IF_ERROR(F16Coder::Write( |
64 | 0 | encoding.dct2weights[c][i] * (1.0f / 64), writer)); |
65 | 0 | } |
66 | 0 | } |
67 | 0 | break; |
68 | 0 | } |
69 | 0 | case QuantEncoding::kQuantModeDCT4X8: { |
70 | 0 | for (size_t c = 0; c < 3; c++) { |
71 | 0 | JXL_RETURN_IF_ERROR( |
72 | 0 | F16Coder::Write(encoding.dct4x8multipliers[c], writer)); |
73 | 0 | } |
74 | 0 | JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); |
75 | 0 | break; |
76 | 0 | } |
77 | 0 | case QuantEncoding::kQuantModeDCT4: { |
78 | 0 | for (size_t c = 0; c < 3; c++) { |
79 | 0 | for (size_t i = 0; i < 2; i++) { |
80 | 0 | JXL_RETURN_IF_ERROR( |
81 | 0 | F16Coder::Write(encoding.dct4multipliers[c][i], writer)); |
82 | 0 | } |
83 | 0 | } |
84 | 0 | JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); |
85 | 0 | break; |
86 | 0 | } |
87 | 0 | case QuantEncoding::kQuantModeDCT: { |
88 | 0 | JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); |
89 | 0 | break; |
90 | 0 | } |
91 | 0 | case QuantEncoding::kQuantModeRAW: { |
92 | 0 | JXL_RETURN_IF_ERROR(ModularFrameEncoder::EncodeQuantTable( |
93 | 0 | memory_manager, size_x, size_y, writer, encoding, idx, |
94 | 0 | modular_frame_encoder)); |
95 | 0 | break; |
96 | 0 | } |
97 | 0 | case QuantEncoding::kQuantModeAFV: { |
98 | 0 | for (size_t c = 0; c < 3; c++) { |
99 | 0 | for (size_t i = 0; i < 9; i++) { |
100 | 0 | JXL_RETURN_IF_ERROR(F16Coder::Write( |
101 | 0 | encoding.afv_weights[c][i] * (i < 6 ? 1.0f / 64 : 1.0f), writer)); |
102 | 0 | } |
103 | 0 | } |
104 | 0 | JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); |
105 | 0 | JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params_afv_4x4, writer)); |
106 | 0 | break; |
107 | 0 | } |
108 | 0 | } |
109 | 0 | return true; |
110 | 0 | } |
111 | | |
112 | | } // namespace |
113 | | |
114 | | Status DequantMatricesEncode(JxlMemoryManager* memory_manager, |
115 | | const DequantMatrices& matrices, BitWriter* writer, |
116 | | LayerType layer, AuxOut* aux_out, |
117 | 0 | ModularFrameEncoder* modular_frame_encoder) { |
118 | 0 | bool all_default = true; |
119 | 0 | const std::vector<QuantEncoding>& encodings = matrices.encodings(); |
120 | |
|
121 | 0 | for (const auto& encoding : encodings) { |
122 | 0 | if (encoding.mode != QuantEncoding::kQuantModeLibrary || |
123 | 0 | encoding.predefined != 0) { |
124 | 0 | all_default = false; |
125 | 0 | } |
126 | 0 | } |
127 | | // TODO(janwas): better bound |
128 | 0 | return writer->WithMaxBits(512 * 1024, layer, aux_out, [&]() -> Status { |
129 | 0 | writer->Write(1, TO_JXL_BOOL(all_default)); |
130 | 0 | if (!all_default) { |
131 | 0 | for (size_t i = 0; i < encodings.size(); i++) { |
132 | 0 | JXL_RETURN_IF_ERROR(EncodeQuant(memory_manager, encodings[i], i, |
133 | 0 | DequantMatrices::required_size_x[i], |
134 | 0 | DequantMatrices::required_size_y[i], |
135 | 0 | writer, modular_frame_encoder)); |
136 | 0 | } |
137 | 0 | } |
138 | 0 | return true; |
139 | 0 | }); |
140 | 0 | } |
141 | | |
142 | | Status DequantMatricesEncodeDC(const DequantMatrices& matrices, |
143 | | BitWriter* writer, LayerType layer, |
144 | 0 | AuxOut* aux_out) { |
145 | 0 | bool all_default = true; |
146 | 0 | const float* dc_quant = matrices.DCQuants(); |
147 | 0 | for (size_t c = 0; c < 3; c++) { |
148 | 0 | if (dc_quant[c] != kDCQuant[c]) { |
149 | 0 | all_default = false; |
150 | 0 | } |
151 | 0 | } |
152 | 0 | return writer->WithMaxBits( |
153 | 0 | 1 + sizeof(float) * kBitsPerByte * 3, layer, aux_out, [&]() -> Status { |
154 | 0 | writer->Write(1, TO_JXL_BOOL(all_default)); |
155 | 0 | if (!all_default) { |
156 | 0 | for (size_t c = 0; c < 3; c++) { |
157 | 0 | JXL_RETURN_IF_ERROR(F16Coder::Write(dc_quant[c] * 128.0f, writer)); |
158 | 0 | } |
159 | 0 | } |
160 | 0 | return true; |
161 | 0 | }); |
162 | 0 | } |
163 | | |
164 | | Status DequantMatricesSetCustomDC(JxlMemoryManager* memory_manager, |
165 | 0 | DequantMatrices* matrices, const float* dc) { |
166 | 0 | matrices->SetDCQuant(dc); |
167 | | // Roundtrip encode/decode DC to ensure same values as decoder. |
168 | 0 | BitWriter writer{memory_manager}; |
169 | | // TODO(eustas): should it be LayerType::Quant? |
170 | 0 | JXL_RETURN_IF_ERROR( |
171 | 0 | DequantMatricesEncodeDC(*matrices, &writer, LayerType::Header, nullptr)); |
172 | 0 | writer.ZeroPadToByte(); |
173 | 0 | BitReader br(writer.GetSpan()); |
174 | | // Called only in the encoder: should fail only for programmer errors. |
175 | 0 | JXL_RETURN_IF_ERROR(matrices->DecodeDC(&br)); |
176 | 0 | JXL_RETURN_IF_ERROR(br.Close()); |
177 | 0 | return true; |
178 | 0 | } |
179 | | |
180 | | Status DequantMatricesScaleDC(JxlMemoryManager* memory_manager, |
181 | 0 | DequantMatrices* matrices, const float scale) { |
182 | 0 | float dc[3]; |
183 | 0 | for (size_t c = 0; c < 3; ++c) { |
184 | 0 | dc[c] = matrices->InvDCQuant(c) * (1.0f / scale); |
185 | 0 | } |
186 | 0 | JXL_RETURN_IF_ERROR(DequantMatricesSetCustomDC(memory_manager, matrices, dc)); |
187 | 0 | return true; |
188 | 0 | } |
189 | | |
190 | | Status DequantMatricesRoundtrip(JxlMemoryManager* memory_manager, |
191 | 0 | DequantMatrices* matrices) { |
192 | | // Do not pass modular en/decoder, as they only change entropy and not |
193 | | // values. |
194 | 0 | BitWriter writer{memory_manager}; |
195 | | // TODO(eustas): should it be LayerType::Quant? |
196 | 0 | JXL_RETURN_IF_ERROR(DequantMatricesEncode(memory_manager, *matrices, &writer, |
197 | 0 | LayerType::Header, nullptr)); |
198 | 0 | writer.ZeroPadToByte(); |
199 | 0 | BitReader br(writer.GetSpan()); |
200 | | // Called only in the encoder: should fail only for programmer errors. |
201 | 0 | JXL_RETURN_IF_ERROR(matrices->Decode(memory_manager, &br)); |
202 | 0 | JXL_RETURN_IF_ERROR(br.Close()); |
203 | 0 | return true; |
204 | 0 | } |
205 | | |
206 | | Status DequantMatricesSetCustom(DequantMatrices* matrices, |
207 | | const std::vector<QuantEncoding>& encodings, |
208 | 0 | ModularFrameEncoder* encoder) { |
209 | 0 | JXL_ENSURE(encoder != nullptr); |
210 | 0 | JXL_ENSURE(encodings.size() == kNumQuantTables); |
211 | 0 | JxlMemoryManager* memory_manager = encoder->memory_manager(); |
212 | 0 | matrices->SetEncodings(encodings); |
213 | 0 | for (size_t i = 0; i < encodings.size(); i++) { |
214 | 0 | if (encodings[i].mode == QuantEncodingInternal::kQuantModeRAW) { |
215 | 0 | JXL_RETURN_IF_ERROR(encoder->AddQuantTable( |
216 | 0 | DequantMatrices::required_size_x[i] * kBlockDim, |
217 | 0 | DequantMatrices::required_size_y[i] * kBlockDim, encodings[i], i)); |
218 | 0 | } |
219 | 0 | } |
220 | 0 | JXL_RETURN_IF_ERROR(DequantMatricesRoundtrip(memory_manager, matrices)); |
221 | 0 | return true; |
222 | 0 | } |
223 | | |
224 | | } // namespace jxl |