/src/libjxl/lib/jxl/enc_quant_weights.cc
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #include "lib/jxl/enc_quant_weights.h" |
7 | | |
8 | | #include <jxl/memory_manager.h> |
9 | | #include <jxl/types.h> |
10 | | |
11 | | #include <cstdlib> |
12 | | #include <vector> |
13 | | |
14 | | #include "lib/jxl/base/common.h" |
15 | | #include "lib/jxl/base/status.h" |
16 | | #include "lib/jxl/enc_aux_out.h" |
17 | | #include "lib/jxl/enc_bit_writer.h" |
18 | | #include "lib/jxl/enc_modular.h" |
19 | | #include "lib/jxl/fields.h" |
20 | | #include "lib/jxl/frame_dimensions.h" |
21 | | #include "lib/jxl/modular/encoding/encoding.h" |
22 | | #include "lib/jxl/quant_weights.h" |
23 | | |
24 | | namespace jxl { |
25 | | |
26 | | namespace { |
27 | | |
28 | 0 | Status EncodeDctParams(const DctQuantWeightParams& params, BitWriter* writer) { |
29 | 0 | JXL_ENSURE(params.num_distance_bands >= 1); |
30 | 0 | writer->Write(DctQuantWeightParams::kLog2MaxDistanceBands, |
31 | 0 | params.num_distance_bands - 1); |
32 | 0 | for (size_t c = 0; c < 3; c++) { |
33 | 0 | for (size_t i = 0; i < params.num_distance_bands; i++) { |
34 | 0 | JXL_RETURN_IF_ERROR(F16Coder::Write( |
35 | 0 | params.distance_bands[c][i] * (i == 0 ? (1 / 64.0f) : 1.0f), writer)); |
36 | 0 | } |
37 | 0 | } |
38 | 0 | return true; |
39 | 0 | } |
40 | | |
41 | | Status EncodeQuant(JxlMemoryManager* memory_manager, |
42 | | const QuantEncoding& encoding, size_t idx, size_t size_x, |
43 | | size_t size_y, BitWriter* writer, |
44 | 0 | ModularFrameEncoder* modular_frame_encoder) { |
45 | 0 | writer->Write(kLog2NumQuantModes, encoding.mode); |
46 | 0 | size_x *= kBlockDim; |
47 | 0 | size_y *= kBlockDim; |
48 | 0 | switch (encoding.mode) { |
49 | 0 | case QuantEncoding::kQuantModeLibrary: { |
50 | 0 | writer->Write(kCeilLog2NumPredefinedTables, encoding.predefined); |
51 | 0 | break; |
52 | 0 | } |
53 | 0 | case QuantEncoding::kQuantModeID: { |
54 | 0 | for (size_t c = 0; c < 3; c++) { |
55 | 0 | for (size_t i = 0; i < 3; i++) { |
56 | 0 | JXL_RETURN_IF_ERROR( |
57 | 0 | F16Coder::Write(encoding.idweights[c][i] * (1.0f / 64), writer)); |
58 | 0 | } |
59 | 0 | } |
60 | 0 | break; |
61 | 0 | } |
62 | 0 | case QuantEncoding::kQuantModeDCT2: { |
63 | 0 | for (size_t c = 0; c < 3; c++) { |
64 | 0 | for (size_t i = 0; i < 6; i++) { |
65 | 0 | JXL_RETURN_IF_ERROR(F16Coder::Write( |
66 | 0 | encoding.dct2weights[c][i] * (1.0f / 64), writer)); |
67 | 0 | } |
68 | 0 | } |
69 | 0 | break; |
70 | 0 | } |
71 | 0 | case QuantEncoding::kQuantModeDCT4X8: { |
72 | 0 | for (size_t c = 0; c < 3; c++) { |
73 | 0 | JXL_RETURN_IF_ERROR( |
74 | 0 | F16Coder::Write(encoding.dct4x8multipliers[c], writer)); |
75 | 0 | } |
76 | 0 | JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); |
77 | 0 | break; |
78 | 0 | } |
79 | 0 | case QuantEncoding::kQuantModeDCT4: { |
80 | 0 | for (size_t c = 0; c < 3; c++) { |
81 | 0 | for (size_t i = 0; i < 2; i++) { |
82 | 0 | JXL_RETURN_IF_ERROR( |
83 | 0 | F16Coder::Write(encoding.dct4multipliers[c][i], writer)); |
84 | 0 | } |
85 | 0 | } |
86 | 0 | JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); |
87 | 0 | break; |
88 | 0 | } |
89 | 0 | case QuantEncoding::kQuantModeDCT: { |
90 | 0 | JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); |
91 | 0 | break; |
92 | 0 | } |
93 | 0 | case QuantEncoding::kQuantModeRAW: { |
94 | 0 | JXL_RETURN_IF_ERROR(ModularFrameEncoder::EncodeQuantTable( |
95 | 0 | memory_manager, size_x, size_y, writer, encoding, idx, |
96 | 0 | modular_frame_encoder)); |
97 | 0 | break; |
98 | 0 | } |
99 | 0 | case QuantEncoding::kQuantModeAFV: { |
100 | 0 | for (size_t c = 0; c < 3; c++) { |
101 | 0 | for (size_t i = 0; i < 9; i++) { |
102 | 0 | JXL_RETURN_IF_ERROR(F16Coder::Write( |
103 | 0 | encoding.afv_weights[c][i] * (i < 6 ? 1.0f / 64 : 1.0f), writer)); |
104 | 0 | } |
105 | 0 | } |
106 | 0 | JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); |
107 | 0 | JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params_afv_4x4, writer)); |
108 | 0 | break; |
109 | 0 | } |
110 | 0 | } |
111 | 0 | return true; |
112 | 0 | } |
113 | | |
114 | | } // namespace |
115 | | |
116 | | Status DequantMatricesEncode(JxlMemoryManager* memory_manager, |
117 | | const DequantMatrices& matrices, BitWriter* writer, |
118 | | LayerType layer, AuxOut* aux_out, |
119 | 186 | ModularFrameEncoder* modular_frame_encoder) { |
120 | 186 | bool all_default = true; |
121 | 186 | const std::vector<QuantEncoding>& encodings = matrices.encodings(); |
122 | | |
123 | 3.16k | for (const auto& encoding : encodings) { |
124 | 3.16k | if (encoding.mode != QuantEncoding::kQuantModeLibrary || |
125 | 3.16k | encoding.predefined != 0) { |
126 | 0 | all_default = false; |
127 | 0 | } |
128 | 3.16k | } |
129 | | // TODO(janwas): better bound |
130 | 186 | return writer->WithMaxBits(512 * 1024, layer, aux_out, [&]() -> Status { |
131 | 186 | writer->Write(1, TO_JXL_BOOL(all_default)); |
132 | 186 | if (!all_default) { |
133 | 0 | for (size_t i = 0; i < encodings.size(); i++) { |
134 | 0 | JXL_RETURN_IF_ERROR(EncodeQuant(memory_manager, encodings[i], i, |
135 | 0 | DequantMatrices::required_size_x[i], |
136 | 0 | DequantMatrices::required_size_y[i], |
137 | 0 | writer, modular_frame_encoder)); |
138 | 0 | } |
139 | 0 | } |
140 | 186 | return true; |
141 | 186 | }); |
142 | 186 | } |
143 | | |
144 | | Status DequantMatricesEncodeDC(const DequantMatrices& matrices, |
145 | | BitWriter* writer, LayerType layer, |
146 | 566 | AuxOut* aux_out) { |
147 | 566 | bool all_default = true; |
148 | 566 | const float* dc_quant = matrices.DCQuants(); |
149 | 2.26k | for (size_t c = 0; c < 3; c++) { |
150 | 1.69k | if (dc_quant[c] != kDCQuant[c]) { |
151 | 582 | all_default = false; |
152 | 582 | } |
153 | 1.69k | } |
154 | 566 | return writer->WithMaxBits( |
155 | 566 | 1 + sizeof(float) * kBitsPerByte * 3, layer, aux_out, [&]() -> Status { |
156 | 566 | writer->Write(1, TO_JXL_BOOL(all_default)); |
157 | 566 | if (!all_default) { |
158 | 776 | for (size_t c = 0; c < 3; c++) { |
159 | 582 | JXL_RETURN_IF_ERROR(F16Coder::Write(dc_quant[c] * 128.0f, writer)); |
160 | 582 | } |
161 | 194 | } |
162 | 566 | return true; |
163 | 566 | }); |
164 | 566 | } |
165 | | |
166 | | Status DequantMatricesSetCustomDC(JxlMemoryManager* memory_manager, |
167 | 283 | DequantMatrices* matrices, const float* dc) { |
168 | 283 | matrices->SetDCQuant(dc); |
169 | | // Roundtrip encode/decode DC to ensure same values as decoder. |
170 | 283 | BitWriter writer{memory_manager}; |
171 | | // TODO(eustas): should it be LayerType::Quant? |
172 | 283 | JXL_RETURN_IF_ERROR( |
173 | 283 | DequantMatricesEncodeDC(*matrices, &writer, LayerType::Header, nullptr)); |
174 | 283 | writer.ZeroPadToByte(); |
175 | 283 | BitReader br(writer.GetSpan()); |
176 | | // Called only in the encoder: should fail only for programmer errors. |
177 | 283 | JXL_RETURN_IF_ERROR(matrices->DecodeDC(&br)); |
178 | 283 | JXL_RETURN_IF_ERROR(br.Close()); |
179 | 283 | return true; |
180 | 283 | } |
181 | | |
182 | | Status DequantMatricesScaleDC(JxlMemoryManager* memory_manager, |
183 | 186 | DequantMatrices* matrices, const float scale) { |
184 | 186 | float dc[3]; |
185 | 744 | for (size_t c = 0; c < 3; ++c) { |
186 | 558 | dc[c] = matrices->InvDCQuant(c) * (1.0f / scale); |
187 | 558 | } |
188 | 186 | JXL_RETURN_IF_ERROR(DequantMatricesSetCustomDC(memory_manager, matrices, dc)); |
189 | 186 | return true; |
190 | 186 | } |
191 | | |
192 | | Status DequantMatricesRoundtrip(JxlMemoryManager* memory_manager, |
193 | 0 | DequantMatrices* matrices) { |
194 | | // Do not pass modular en/decoder, as they only change entropy and not |
195 | | // values. |
196 | 0 | BitWriter writer{memory_manager}; |
197 | | // TODO(eustas): should it be LayerType::Quant? |
198 | 0 | JXL_RETURN_IF_ERROR(DequantMatricesEncode(memory_manager, *matrices, &writer, |
199 | 0 | LayerType::Header, nullptr)); |
200 | 0 | writer.ZeroPadToByte(); |
201 | 0 | BitReader br(writer.GetSpan()); |
202 | | // Called only in the encoder: should fail only for programmer errors. |
203 | 0 | JXL_RETURN_IF_ERROR(matrices->Decode(memory_manager, &br)); |
204 | 0 | JXL_RETURN_IF_ERROR(br.Close()); |
205 | 0 | return true; |
206 | 0 | } |
207 | | |
208 | | Status DequantMatricesSetCustom(DequantMatrices* matrices, |
209 | | const std::vector<QuantEncoding>& encodings, |
210 | 0 | ModularFrameEncoder* encoder) { |
211 | 0 | JXL_ENSURE(encoder != nullptr); |
212 | 0 | JXL_ENSURE(encodings.size() == kNumQuantTables); |
213 | 0 | JxlMemoryManager* memory_manager = encoder->memory_manager(); |
214 | 0 | matrices->SetEncodings(encodings); |
215 | 0 | for (size_t i = 0; i < encodings.size(); i++) { |
216 | 0 | if (encodings[i].mode == QuantEncodingInternal::kQuantModeRAW) { |
217 | 0 | JXL_RETURN_IF_ERROR(encoder->AddQuantTable( |
218 | 0 | DequantMatrices::required_size_x[i] * kBlockDim, |
219 | 0 | DequantMatrices::required_size_y[i] * kBlockDim, encodings[i], i)); |
220 | 0 | } |
221 | 0 | } |
222 | 0 | JXL_RETURN_IF_ERROR(DequantMatricesRoundtrip(memory_manager, matrices)); |
223 | 0 | return true; |
224 | 0 | } |
225 | | |
226 | | } // namespace jxl |