/src/libjxl/lib/jxl/dec_patch_dictionary.cc
Line | Count | Source |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #include "lib/jxl/dec_patch_dictionary.h" |
7 | | |
8 | | #include <jxl/memory_manager.h> |
9 | | |
10 | | #include <algorithm> |
11 | | #include <cstddef> |
12 | | #include <cstdint> |
13 | | #include <cstdlib> |
14 | | #include <utility> |
15 | | #include <vector> |
16 | | |
17 | | #include "lib/jxl/base/printf_macros.h" |
18 | | #include "lib/jxl/base/status.h" |
19 | | #include "lib/jxl/blending.h" |
20 | | #include "lib/jxl/common.h" // kMaxNumReferenceFrames |
21 | | #include "lib/jxl/dec_ans.h" |
22 | | #include "lib/jxl/dec_bit_reader.h" |
23 | | #include "lib/jxl/image.h" |
24 | | #include "lib/jxl/image_bundle.h" |
25 | | #include "lib/jxl/image_metadata.h" |
26 | | #include "lib/jxl/pack_signed.h" |
27 | | #include "lib/jxl/patch_dictionary_internal.h" |
28 | | |
29 | | namespace jxl { |
30 | | |
31 | | Status PatchDictionary::Decode(JxlMemoryManager* memory_manager, BitReader* br, |
32 | | size_t xsize, size_t ysize, |
33 | | size_t num_extra_channels, |
34 | 3.14k | bool* uses_extra_channels) { |
35 | 3.14k | positions_.clear(); |
36 | 3.14k | blendings_stride_ = num_extra_channels + 1; |
37 | 3.14k | std::vector<uint8_t> context_map; |
38 | 3.14k | ANSCode code; |
39 | 3.14k | JXL_RETURN_IF_ERROR(DecodeHistograms( |
40 | 3.14k | memory_manager, br, kNumPatchDictionaryContexts, &code, &context_map)); |
41 | 5.32k | JXL_ASSIGN_OR_RETURN(ANSSymbolReader decoder, |
42 | 5.32k | ANSSymbolReader::Create(&code, br)); |
43 | | |
44 | 8.65M | auto read_num = [&](size_t context) -> size_t { |
45 | 8.65M | size_t r = decoder.ReadHybridUint(context, br, context_map); |
46 | 8.65M | return r; |
47 | 8.65M | }; |
48 | | |
49 | 5.32k | size_t num_ref_patch = read_num(kNumRefPatchContext); |
50 | | // Limit max memory usage of patches to about 66 bytes per pixel (assuming 8 |
51 | | // bytes per size_t) |
52 | 5.32k | const size_t num_pixels = xsize * ysize; |
53 | 5.32k | const size_t max_ref_patches = 1024 + num_pixels / 4; |
54 | 5.32k | const size_t max_patches = max_ref_patches * 4; |
55 | 5.32k | const size_t max_blending_infos = max_patches * 4; |
56 | 5.32k | if (num_ref_patch > max_ref_patches) { |
57 | 10 | return JXL_FAILURE("Too many patches in dictionary"); |
58 | 10 | } |
59 | | |
60 | 2.65k | size_t total_patches = 0; |
61 | 2.65k | size_t next_size = 1; |
62 | | |
63 | 96.7k | for (size_t id = 0; id < num_ref_patch; id++) { |
64 | 94.2k | PatchReferencePosition ref_pos; |
65 | 94.2k | ref_pos.ref = read_num(kReferenceFrameContext); |
66 | 94.2k | if (ref_pos.ref >= kMaxNumReferenceFrames || |
67 | 94.2k | reference_frames_->at(ref_pos.ref).frame->xsize() == 0) { |
68 | 47 | return JXL_FAILURE("Invalid reference frame ID"); |
69 | 47 | } |
70 | 94.1k | if (!reference_frames_->at(ref_pos.ref).ib_is_in_xyb) { |
71 | 2 | return JXL_FAILURE( |
72 | 2 | "Patches cannot use frames saved post color transforms"); |
73 | 2 | } |
74 | 94.1k | const ImageBundle& ib = *reference_frames_->at(ref_pos.ref).frame; |
75 | 94.1k | ref_pos.x0 = read_num(kPatchReferencePositionContext); |
76 | 94.1k | ref_pos.y0 = read_num(kPatchReferencePositionContext); |
77 | 94.1k | ref_pos.xsize = read_num(kPatchSizeContext) + 1; |
78 | 94.1k | ref_pos.ysize = read_num(kPatchSizeContext) + 1; |
79 | 94.1k | if (ref_pos.x0 + ref_pos.xsize > ib.xsize()) { |
80 | 9 | return JXL_FAILURE("Invalid position specified in reference frame"); |
81 | 9 | } |
82 | 94.1k | if (ref_pos.y0 + ref_pos.ysize > ib.ysize()) { |
83 | 3 | return JXL_FAILURE("Invalid position specified in reference frame"); |
84 | 3 | } |
85 | 94.1k | size_t id_count = read_num(kPatchCountContext); |
86 | 94.1k | if (id_count > max_patches) { |
87 | 2 | return JXL_FAILURE("Too many patches in dictionary"); |
88 | 2 | } |
89 | 94.1k | id_count++; |
90 | 94.1k | total_patches += id_count; |
91 | 94.1k | if (total_patches > max_patches) { |
92 | 5 | return JXL_FAILURE("Too many patches in dictionary"); |
93 | 5 | } |
94 | 94.1k | if (next_size < total_patches) { |
95 | 17.5k | next_size *= 2; |
96 | 17.5k | next_size = std::min<size_t>(next_size, max_patches); |
97 | 17.5k | } |
98 | 94.1k | if (next_size * blendings_stride_ > max_blending_infos) { |
99 | 0 | return JXL_FAILURE("Too many patches in dictionary"); |
100 | 0 | } |
101 | 94.1k | positions_.reserve(next_size); |
102 | 94.1k | blendings_.reserve(next_size * blendings_stride_); |
103 | 94.1k | bool choose_alpha = (num_extra_channels > 1); |
104 | 2.77M | for (size_t i = 0; i < id_count; i++) { |
105 | 2.68M | PatchPosition pos; |
106 | 2.68M | pos.ref_pos_idx = ref_positions_.size(); |
107 | 2.68M | if (i == 0) { |
108 | 94.1k | pos.x = read_num(kPatchPositionContext); |
109 | 94.1k | pos.y = read_num(kPatchPositionContext); |
110 | 2.58M | } else { |
111 | 2.58M | ptrdiff_t deltax = UnpackSigned(read_num(kPatchOffsetContext)); |
112 | 2.58M | if (deltax < 0 && static_cast<size_t>(-deltax) > positions_.back().x) { |
113 | 9 | return JXL_FAILURE("Invalid patch: negative x coordinate (%" PRIuS |
114 | 9 | " base x %" PRIdS " delta x)", |
115 | 9 | positions_.back().x, deltax); |
116 | 9 | } |
117 | 2.58M | pos.x = positions_.back().x + deltax; |
118 | 2.58M | ptrdiff_t deltay = UnpackSigned(read_num(kPatchOffsetContext)); |
119 | 2.58M | if (deltay < 0 && static_cast<size_t>(-deltay) > positions_.back().y) { |
120 | 6 | return JXL_FAILURE("Invalid patch: negative y coordinate (%" PRIuS |
121 | 6 | " base y %" PRIdS " delta y)", |
122 | 6 | positions_.back().y, deltay); |
123 | 6 | } |
124 | 2.58M | pos.y = positions_.back().y + deltay; |
125 | 2.58M | } |
126 | 2.68M | if (pos.x + ref_pos.xsize > xsize) { |
127 | 6 | return JXL_FAILURE("Invalid patch x: at %" PRIuS " + %" PRIuS |
128 | 6 | " > %" PRIuS, |
129 | 6 | pos.x, ref_pos.xsize, xsize); |
130 | 6 | } |
131 | 2.68M | if (pos.y + ref_pos.ysize > ysize) { |
132 | 4 | return JXL_FAILURE("Invalid patch y: at %" PRIuS " + %" PRIuS |
133 | 4 | " > %" PRIuS, |
134 | 4 | pos.y, ref_pos.ysize, ysize); |
135 | 4 | } |
136 | 5.38M | for (size_t j = 0; j < blendings_stride_; j++) { |
137 | 2.70M | uint32_t blend_mode = read_num(kPatchBlendModeContext); |
138 | 2.70M | if (blend_mode >= kNumPatchBlendModes) { |
139 | 7 | return JXL_FAILURE("Invalid patch blend mode: %u", blend_mode); |
140 | 7 | } |
141 | 2.70M | PatchBlending info; |
142 | 2.70M | info.mode = static_cast<PatchBlendMode>(blend_mode); |
143 | 2.70M | if (UsesAlpha(info.mode)) { |
144 | 18.2k | *uses_extra_channels = true; |
145 | 18.2k | } |
146 | 2.70M | if (info.mode != PatchBlendMode::kNone && j > 0) { |
147 | 44 | *uses_extra_channels = true; |
148 | 44 | } |
149 | 2.70M | if (UsesAlpha(info.mode) && choose_alpha) { |
150 | 0 | info.alpha_channel = read_num(kPatchAlphaChannelContext); |
151 | 0 | if (info.alpha_channel >= num_extra_channels) { |
152 | 0 | return JXL_FAILURE( |
153 | 0 | "Invalid alpha channel for blending: %u out of %u\n", |
154 | 0 | info.alpha_channel, static_cast<uint32_t>(num_extra_channels)); |
155 | 0 | } |
156 | 2.70M | } else { |
157 | 2.70M | info.alpha_channel = 0; |
158 | 2.70M | } |
159 | 2.70M | if (UsesClamp(info.mode)) { |
160 | 21.2k | info.clamp = static_cast<bool>(read_num(kPatchClampContext)); |
161 | 2.68M | } else { |
162 | 2.68M | info.clamp = false; |
163 | 2.68M | } |
164 | 2.70M | blendings_.push_back(info); |
165 | 2.70M | } |
166 | 2.68M | positions_.emplace_back(pos); |
167 | 2.68M | } |
168 | 94.1k | ref_positions_.emplace_back(ref_pos); |
169 | 94.1k | } |
170 | 2.55k | positions_.shrink_to_fit(); |
171 | | |
172 | 2.55k | if (!decoder.CheckANSFinalState()) { |
173 | 0 | return JXL_FAILURE("ANS checksum failure."); |
174 | 0 | } |
175 | | |
176 | 2.55k | ComputePatchTree(); |
177 | 2.55k | return true; |
178 | 2.55k | } |
179 | | |
180 | 1.46k | int PatchDictionary::GetReferences() const { |
181 | 1.46k | int result = 0; |
182 | 111k | for (const auto& ref_pos : ref_positions_) { |
183 | 111k | result |= (1 << static_cast<int>(ref_pos.ref)); |
184 | 111k | } |
185 | 1.46k | return result; |
186 | 1.46k | } |
187 | | |
namespace {
// Vertical extent of one patch, used while building the patch tree.
struct PatchInterval {
  // Index of the patch this interval was created for.
  size_t idx;
  // First row of the interval.
  size_t y0;
  // One past the last row of the interval.
  size_t y1;
};
}  // namespace
194 | | |
195 | 201k | void PatchDictionary::ComputePatchTree() { |
196 | 201k | patch_tree_.clear(); |
197 | 201k | num_patches_.clear(); |
198 | 201k | sorted_patches_y0_.clear(); |
199 | 201k | sorted_patches_y1_.clear(); |
200 | 201k | if (positions_.empty()) { |
201 | 196k | return; |
202 | 196k | } |
203 | | // Create a y-interval for each patch. |
204 | 4.09k | std::vector<PatchInterval> intervals(positions_.size()); |
205 | 4.74M | for (size_t i = 0; i < positions_.size(); ++i) { |
206 | 4.74M | const auto& pos = positions_[i]; |
207 | 4.74M | intervals[i].idx = i; |
208 | 4.74M | intervals[i].y0 = pos.y; |
209 | 4.74M | intervals[i].y1 = pos.y + ref_positions_[pos.ref_pos_idx].ysize; |
210 | 4.74M | } |
211 | 545k | auto sort_by_y0 = [&intervals](size_t start, size_t end) { |
212 | 545k | std::sort(intervals.data() + start, intervals.data() + end, |
213 | 176M | [](const PatchInterval& i0, const PatchInterval& i1) { |
214 | 176M | return i0.y0 < i1.y0; |
215 | 176M | }); |
216 | 545k | }; |
217 | 276k | auto sort_by_y1 = [&intervals](size_t start, size_t end) { |
218 | 276k | std::sort(intervals.data() + start, intervals.data() + end, |
219 | 135M | [](const PatchInterval& i0, const PatchInterval& i1) { |
220 | 135M | return i0.y1 < i1.y1; |
221 | 135M | }); |
222 | 276k | }; |
223 | | // Count the number of patches for each row. |
224 | 4.09k | sort_by_y1(0, intervals.size()); |
225 | 4.09k | num_patches_.resize(intervals.back().y1); |
226 | 4.74M | for (auto iv : intervals) { |
227 | 102M | for (size_t y = iv.y0; y < iv.y1; ++y) num_patches_[y]++; |
228 | 4.74M | } |
229 | 4.09k | PatchTreeNode root; |
230 | 4.09k | root.start = 0; |
231 | 4.09k | root.num = intervals.size(); |
232 | 4.09k | patch_tree_.push_back(root); |
233 | 4.09k | size_t next = 0; |
234 | 276k | while (next < patch_tree_.size()) { |
235 | 272k | auto& node = patch_tree_[next]; |
236 | 272k | size_t start = node.start; |
237 | 272k | size_t end = node.start + node.num; |
238 | | // Choose the y_center for this node to be the median of interval starts. |
239 | 272k | sort_by_y0(start, end); |
240 | 272k | size_t middle_idx = start + node.num / 2; |
241 | 272k | node.y_center = intervals[middle_idx].y0; |
242 | | // Divide the intervals in [start, end) into three groups: |
243 | | // * those completely to the right of y_center: [right_start, end) |
244 | | // * those overlapping y_center: [left_end, right_start) |
245 | | // * those completely to the left of y_center: [start, left_end) |
246 | 272k | size_t right_start = middle_idx; |
247 | 1.97M | while (right_start < end && intervals[right_start].y0 == node.y_center) { |
248 | 1.70M | ++right_start; |
249 | 1.70M | } |
250 | 272k | sort_by_y1(start, right_start); |
251 | 272k | size_t left_end = right_start; |
252 | 5.01M | while (left_end > start && intervals[left_end - 1].y1 > node.y_center) { |
253 | 4.74M | --left_end; |
254 | 4.74M | } |
255 | | // Fill in sorted_patches_y0_ and sorted_patches_y1_ for the current node. |
256 | 272k | node.num = right_start - left_end; |
257 | 272k | node.start = sorted_patches_y0_.size(); |
258 | 272k | for (ptrdiff_t i = static_cast<ptrdiff_t>(right_start) - 1; |
259 | 5.01M | i >= static_cast<ptrdiff_t>(left_end); --i) { |
260 | 4.74M | sorted_patches_y1_.emplace_back(intervals[i].y1, intervals[i].idx); |
261 | 4.74M | } |
262 | 272k | sort_by_y0(left_end, right_start); |
263 | 5.01M | for (size_t i = left_end; i < right_start; ++i) { |
264 | 4.74M | sorted_patches_y0_.emplace_back(intervals[i].y0, intervals[i].idx); |
265 | 4.74M | } |
266 | | // Create the left and right nodes (if not empty). |
267 | 272k | node.left_child = node.right_child = -1; |
268 | 272k | if (left_end > start) { |
269 | 129k | PatchTreeNode left; |
270 | 129k | left.start = start; |
271 | 129k | left.num = left_end - left.start; |
272 | 129k | patch_tree_[next].left_child = patch_tree_.size(); |
273 | 129k | patch_tree_.push_back(left); |
274 | 129k | } |
275 | 272k | if (right_start < end) { |
276 | 139k | PatchTreeNode right; |
277 | 139k | right.start = right_start; |
278 | 139k | right.num = end - right.start; |
279 | 139k | patch_tree_[next].right_child = patch_tree_.size(); |
280 | 139k | patch_tree_.push_back(right); |
281 | 139k | } |
282 | 272k | ++next; |
283 | 272k | } |
284 | 4.09k | } |
285 | | |
286 | 5.60M | std::vector<size_t> PatchDictionary::GetPatchesForRow(size_t y) const { |
287 | 5.60M | std::vector<size_t> result; |
288 | 5.60M | if (y < num_patches_.size() && num_patches_[y] > 0) { |
289 | 2.30M | result.reserve(num_patches_[y]); |
290 | 16.6M | for (ptrdiff_t tree_idx = 0; tree_idx != -1;) { |
291 | 14.3M | JXL_DASSERT(tree_idx < static_cast<ptrdiff_t>(patch_tree_.size())); |
292 | 14.3M | const auto& node = patch_tree_[tree_idx]; |
293 | 14.3M | if (y <= node.y_center) { |
294 | 30.6M | for (size_t i = 0; i < node.num; ++i) { |
295 | 29.5M | const auto& p = sorted_patches_y0_[node.start + i]; |
296 | 29.5M | if (y < p.first) break; |
297 | 23.3M | result.push_back(p.second); |
298 | 23.3M | } |
299 | 7.27M | tree_idx = y < node.y_center ? node.left_child : -1; |
300 | 7.27M | } else { |
301 | 29.6M | for (size_t i = 0; i < node.num; ++i) { |
302 | 28.9M | const auto& p = sorted_patches_y1_[node.start + i]; |
303 | 28.9M | if (y >= p.first) break; |
304 | 22.5M | result.push_back(p.second); |
305 | 22.5M | } |
306 | 7.07M | tree_idx = node.right_child; |
307 | 7.07M | } |
308 | 14.3M | } |
309 | | // Ensure that he relative order of patches that affect the same pixels is |
310 | | // preserved. This is important for patches that have a blend mode |
311 | | // different from kAdd. |
312 | 2.30M | std::sort(result.begin(), result.end()); |
313 | 2.30M | } |
314 | 5.60M | return result; |
315 | 5.60M | } |
316 | | |
317 | | // Adds patches to a segment of `xsize` pixels, starting at `inout`, assumed |
318 | | // to be located at position (x0, y) in the frame. |
319 | | Status PatchDictionary::AddOneRow( |
320 | | float* const* inout, size_t y, size_t x0, size_t xsize, |
321 | 4.00M | const std::vector<ExtraChannelInfo>& extra_channel_info) const { |
322 | 4.00M | size_t num_ec = extra_channel_info.size(); |
323 | 4.00M | JXL_ENSURE(num_ec + 1 <= blendings_stride_); |
324 | 4.00M | std::vector<const float*> fg_ptrs(3 + num_ec); |
325 | 41.2M | for (size_t pos_idx : GetPatchesForRow(y)) { |
326 | 41.2M | const size_t blending_idx = pos_idx * blendings_stride_; |
327 | 41.2M | const PatchPosition& pos = positions_[pos_idx]; |
328 | 41.2M | const PatchReferencePosition& ref_pos = ref_positions_[pos.ref_pos_idx]; |
329 | 41.2M | size_t by = pos.y; |
330 | 41.2M | size_t bx = pos.x; |
331 | 41.2M | size_t patch_xsize = ref_pos.xsize; |
332 | 41.2M | JXL_ENSURE(y >= by); |
333 | 41.2M | JXL_ENSURE(y < by + ref_pos.ysize); |
334 | 41.2M | size_t iy = y - by; |
335 | 41.2M | size_t ref = ref_pos.ref; |
336 | 41.2M | if (bx >= x0 + xsize) continue; |
337 | 35.8M | if (bx + patch_xsize < x0) continue; |
338 | 29.5M | size_t patch_x0 = std::max(bx, x0); |
339 | 29.5M | size_t patch_x1 = std::min(bx + patch_xsize, x0 + xsize); |
340 | 118M | for (size_t c = 0; c < 3; c++) { |
341 | 88.7M | fg_ptrs[c] = reference_frames_->at(ref).frame->color()->ConstPlaneRow( |
342 | 88.7M | c, ref_pos.y0 + iy) + |
343 | 88.7M | ref_pos.x0 + x0 - bx; |
344 | 88.7M | } |
345 | 29.5M | for (size_t i = 0; i < num_ec; i++) { |
346 | 2.07k | fg_ptrs[3 + i] = |
347 | 2.07k | reference_frames_->at(ref).frame->extra_channels()[i].ConstRow( |
348 | 2.07k | ref_pos.y0 + iy) + |
349 | 2.07k | ref_pos.x0 + x0 - bx; |
350 | 2.07k | } |
351 | 29.5M | JXL_RETURN_IF_ERROR(PerformBlending( |
352 | 29.5M | memory_manager_, inout, fg_ptrs.data(), inout, patch_x0 - x0, |
353 | 29.5M | patch_x1 - patch_x0, blendings_[blending_idx], |
354 | 29.5M | blendings_.data() + blending_idx + 1, extra_channel_info)); |
355 | 29.5M | } |
356 | 4.00M | return true; |
357 | 4.00M | } |
358 | | } // namespace jxl |