Coverage Report

Created: 2025-06-16 07:00

/src/libjxl/lib/jxl/dec_patch_dictionary.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/dec_patch_dictionary.h"
7
8
#include <jxl/memory_manager.h>
9
10
#include <algorithm>
11
#include <cstdint>
12
#include <cstdlib>
13
#include <utility>
14
#include <vector>
15
16
#include "lib/jxl/base/compiler_specific.h"  // ssize_t
17
#include "lib/jxl/base/printf_macros.h"
18
#include "lib/jxl/base/status.h"
19
#include "lib/jxl/blending.h"
20
#include "lib/jxl/common.h"  // kMaxNumReferenceFrames
21
#include "lib/jxl/dec_ans.h"
22
#include "lib/jxl/dec_bit_reader.h"
23
#include "lib/jxl/image.h"
24
#include "lib/jxl/image_bundle.h"
25
#include "lib/jxl/image_metadata.h"
26
#include "lib/jxl/pack_signed.h"
27
#include "lib/jxl/patch_dictionary_internal.h"
28
29
namespace jxl {
30
31
// Reads a patch dictionary from `br` and stores it in this object, replacing
// whatever a previous call left behind.
//
// Parameters:
//   memory_manager      - forwarded to DecodeHistograms.
//   br                  - bit reader the dictionary is read from.
//   xsize, ysize        - bounds every placed patch has to fit inside; also
//                         used to derive the limits on the number of patches.
//   num_extra_channels  - each patch carries num_extra_channels + 1 blending
//                         entries (this becomes blendings_stride_).
//   uses_extra_channels - set to true when a patch uses an alpha-based blend
//                         mode, or when a blending entry other than the first
//                         has a mode different from kNone. Never set to false
//                         here.
//
// Returns true on success. Returns a failure status when the histograms or
// the symbol reader cannot be set up, when a count exceeds its limit, when a
// reference frame, position, blend mode or alpha channel is invalid, or when
// the final ANS state check fails. On failure the members may hold a
// partially decoded dictionary and the patch tree is not rebuilt.
Status PatchDictionary::Decode(JxlMemoryManager* memory_manager, BitReader* br,
                               size_t xsize, size_t ysize,
                               size_t num_extra_channels,
                               bool* uses_extra_channels) {
  // Start from a clean slate. blendings_ is addressed as
  // `patch index * blendings_stride_`, so entries surviving from an earlier
  // call would be applied to the patches decoded now. Leftover ref_positions_
  // entries would keep contributing bits to GetReferences().
  positions_.clear();
  ref_positions_.clear();
  blendings_.clear();
  blendings_stride_ = num_extra_channels + 1;
  std::vector<uint8_t> context_map;
  ANSCode code;
  JXL_RETURN_IF_ERROR(DecodeHistograms(
      memory_manager, br, kNumPatchDictionaryContexts, &code, &context_map));
  JXL_ASSIGN_OR_RETURN(ANSSymbolReader decoder,
                       ANSSymbolReader::Create(&code, br));

  // Reads one value using the given context.
  auto read_num = [&](size_t context) -> size_t {
    return decoder.ReadHybridUint(context, br, context_map);
  };

  const size_t num_ref_patch = read_num(kNumRefPatchContext);
  // Limit max memory usage of patches to about 66 bytes per pixel (assuming 8
  // bytes per size_t)
  const size_t num_pixels = xsize * ysize;
  const size_t max_ref_patches = 1024 + num_pixels / 4;
  const size_t max_patches = max_ref_patches * 4;
  const size_t max_blending_infos = max_patches * 4;
  if (num_ref_patch > max_ref_patches) {
    return JXL_FAILURE("Too many patches in dictionary");
  }

  size_t total_patches = 0;
  // Capacity requested from the vectors. It at most doubles per reference
  // patch, so a large announced count does not trigger a large reservation
  // up front.
  size_t next_size = 1;
  // The alpha channel index is only read from the stream when there is more
  // than one extra channel to choose from.
  const bool choose_alpha = (num_extra_channels > 1);

  for (size_t id = 0; id < num_ref_patch; id++) {
    PatchReferencePosition ref_pos;
    ref_pos.ref = read_num(kReferenceFrameContext);
    if (ref_pos.ref >= kMaxNumReferenceFrames ||
        reference_frames_->at(ref_pos.ref).frame->xsize() == 0) {
      return JXL_FAILURE("Invalid reference frame ID");
    }
    if (!reference_frames_->at(ref_pos.ref).ib_is_in_xyb) {
      return JXL_FAILURE(
          "Patches cannot use frames saved post color transforms");
    }
    const ImageBundle& ib = *reference_frames_->at(ref_pos.ref).frame;
    ref_pos.x0 = read_num(kPatchReferencePositionContext);
    ref_pos.y0 = read_num(kPatchReferencePositionContext);
    ref_pos.xsize = read_num(kPatchSizeContext) + 1;
    ref_pos.ysize = read_num(kPatchSizeContext) + 1;
    // The rectangle has to lie inside the reference frame. The checks are
    // written as subtractions so they do not depend on `x0 + xsize` (or
    // `y0 + ysize`) being representable in size_t.
    if (ref_pos.xsize > ib.xsize() ||
        ref_pos.x0 > ib.xsize() - ref_pos.xsize) {
      return JXL_FAILURE("Invalid position specified in reference frame");
    }
    if (ref_pos.ysize > ib.ysize() ||
        ref_pos.y0 > ib.ysize() - ref_pos.ysize) {
      return JXL_FAILURE("Invalid position specified in reference frame");
    }
    size_t id_count = read_num(kPatchCountContext);
    if (id_count > max_patches) {
      return JXL_FAILURE("Too many patches in dictionary");
    }
    // The stream stores the number of placements minus one.
    id_count++;
    total_patches += id_count;
    if (total_patches > max_patches) {
      return JXL_FAILURE("Too many patches in dictionary");
    }
    if (next_size < total_patches) {
      next_size *= 2;
      next_size = std::min<size_t>(next_size, max_patches);
    }
    if (next_size * blendings_stride_ > max_blending_infos) {
      return JXL_FAILURE("Too many patches in dictionary");
    }
    positions_.reserve(next_size);
    blendings_.reserve(next_size * blendings_stride_);
    for (size_t i = 0; i < id_count; i++) {
      PatchPosition pos;
      // ref_pos is appended to ref_positions_ only after all of its
      // placements have been read, so the current size is its future index.
      pos.ref_pos_idx = ref_positions_.size();
      if (i == 0) {
        // First placement of this reference patch: absolute coordinates.
        pos.x = read_num(kPatchPositionContext);
        pos.y = read_num(kPatchPositionContext);
      } else {
        // Later placements: signed offsets from the previous placement.
        ssize_t deltax = UnpackSigned(read_num(kPatchOffsetContext));
        if (deltax < 0 && static_cast<size_t>(-deltax) > positions_.back().x) {
          return JXL_FAILURE("Invalid patch: negative x coordinate (%" PRIuS
                             " base x %" PRIdS " delta x)",
                             positions_.back().x, deltax);
        }
        pos.x = positions_.back().x + deltax;
        ssize_t deltay = UnpackSigned(read_num(kPatchOffsetContext));
        if (deltay < 0 && static_cast<size_t>(-deltay) > positions_.back().y) {
          return JXL_FAILURE("Invalid patch: negative y coordinate (%" PRIuS
                             " base y %" PRIdS " delta y)",
                             positions_.back().y, deltay);
        }
        pos.y = positions_.back().y + deltay;
      }
      // The placed patch has to lie inside xsize x ysize; subtraction form
      // for the same reason as the reference-frame checks above.
      if (ref_pos.xsize > xsize || pos.x > xsize - ref_pos.xsize) {
        return JXL_FAILURE("Invalid patch x: at %" PRIuS " + %" PRIuS
                           " > %" PRIuS,
                           pos.x, ref_pos.xsize, xsize);
      }
      if (ref_pos.ysize > ysize || pos.y > ysize - ref_pos.ysize) {
        return JXL_FAILURE("Invalid patch y: at %" PRIuS " + %" PRIuS
                           " > %" PRIuS,
                           pos.y, ref_pos.ysize, ysize);
      }
      // One blending entry per slot of the stride, appended in order so that
      // the entries of patch p start at p * blendings_stride_.
      for (size_t j = 0; j < blendings_stride_; j++) {
        uint32_t blend_mode = read_num(kPatchBlendModeContext);
        if (blend_mode >= kNumPatchBlendModes) {
          return JXL_FAILURE("Invalid patch blend mode: %u", blend_mode);
        }
        PatchBlending info;
        info.mode = static_cast<PatchBlendMode>(blend_mode);
        if (UsesAlpha(info.mode)) {
          *uses_extra_channels = true;
        }
        if (info.mode != PatchBlendMode::kNone && j > 0) {
          *uses_extra_channels = true;
        }
        if (UsesAlpha(info.mode) && choose_alpha) {
          info.alpha_channel = read_num(kPatchAlphaChannelContext);
          if (info.alpha_channel >= num_extra_channels) {
            return JXL_FAILURE(
                "Invalid alpha channel for blending: %u out of %u\n",
                info.alpha_channel, static_cast<uint32_t>(num_extra_channels));
          }
        } else {
          info.alpha_channel = 0;
        }
        if (UsesClamp(info.mode)) {
          info.clamp = static_cast<bool>(read_num(kPatchClampContext));
        } else {
          info.clamp = false;
        }
        blendings_.push_back(info);
      }
      positions_.emplace_back(pos);
    }
    ref_positions_.emplace_back(ref_pos);
  }
  positions_.shrink_to_fit();

  if (!decoder.CheckANSFinalState()) {
    return JXL_FAILURE("ANS checksum failure.");
  }

  ComputePatchTree();
  return true;
}
179
180
275
int PatchDictionary::GetReferences() const {
181
275
  int result = 0;
182
1.21k
  for (const auto& ref_pos : ref_positions_) {
183
1.21k
    result |= (1 << static_cast<int>(ref_pos.ref));
184
1.21k
  }
185
275
  return result;
186
275
}
187
188
namespace {
// Half-open row range [y0, y1) covered by one patch, tagged with the index
// of that patch in positions_.
struct PatchInterval {
  size_t idx;  // index of the patch in positions_
  size_t y0;   // first row covered
  size_t y1;   // one past the last row covered
};
}  // namespace
194
195
25.6k
void PatchDictionary::ComputePatchTree() {
196
25.6k
  patch_tree_.clear();
197
25.6k
  num_patches_.clear();
198
25.6k
  sorted_patches_y0_.clear();
199
25.6k
  sorted_patches_y1_.clear();
200
25.6k
  if (positions_.empty()) {
201
25.3k
    return;
202
25.3k
  }
203
  // Create a y-interval for each patch.
204
273
  std::vector<PatchInterval> intervals(positions_.size());
205
757k
  for (size_t i = 0; i < positions_.size(); ++i) {
206
756k
    const auto& pos = positions_[i];
207
756k
    intervals[i].idx = i;
208
756k
    intervals[i].y0 = pos.y;
209
756k
    intervals[i].y1 = pos.y + ref_positions_[pos.ref_pos_idx].ysize;
210
756k
  }
211
42.5k
  auto sort_by_y0 = [&intervals](size_t start, size_t end) {
212
42.5k
    std::sort(intervals.data() + start, intervals.data() + end,
213
29.2M
              [](const PatchInterval& i0, const PatchInterval& i1) {
214
29.2M
                return i0.y0 < i1.y0;
215
29.2M
              });
216
42.5k
  };
217
21.5k
  auto sort_by_y1 = [&intervals](size_t start, size_t end) {
218
21.5k
    std::sort(intervals.data() + start, intervals.data() + end,
219
22.1M
              [](const PatchInterval& i0, const PatchInterval& i1) {
220
22.1M
                return i0.y1 < i1.y1;
221
22.1M
              });
222
21.5k
  };
223
  // Count the number of patches for each row.
224
273
  sort_by_y1(0, intervals.size());
225
273
  num_patches_.resize(intervals.back().y1);
226
756k
  for (auto iv : intervals) {
227
2.08M
    for (size_t y = iv.y0; y < iv.y1; ++y) num_patches_[y]++;
228
756k
  }
229
273
  PatchTreeNode root;
230
273
  root.start = 0;
231
273
  root.num = intervals.size();
232
273
  patch_tree_.push_back(root);
233
273
  size_t next = 0;
234
21.5k
  while (next < patch_tree_.size()) {
235
21.2k
    auto& node = patch_tree_[next];
236
21.2k
    size_t start = node.start;
237
21.2k
    size_t end = node.start + node.num;
238
    // Choose the y_center for this node to be the median of interval starts.
239
21.2k
    sort_by_y0(start, end);
240
21.2k
    size_t middle_idx = start + node.num / 2;
241
21.2k
    node.y_center = intervals[middle_idx].y0;
242
    // Divide the intervals in [start, end) into three groups:
243
    //   * those completely to the right of y_center: [right_start, end)
244
    //   * those overlapping y_center: [left_end, right_start)
245
    //   * those completely to the left of y_center: [start, left_end)
246
21.2k
    size_t right_start = middle_idx;
247
367k
    while (right_start < end && intervals[right_start].y0 == node.y_center) {
248
345k
      ++right_start;
249
345k
    }
250
21.2k
    sort_by_y1(start, right_start);
251
21.2k
    size_t left_end = right_start;
252
778k
    while (left_end > start && intervals[left_end - 1].y1 > node.y_center) {
253
756k
      --left_end;
254
756k
    }
255
    // Fill in sorted_patches_y0_ and sorted_patches_y1_ for the current node.
256
21.2k
    node.num = right_start - left_end;
257
21.2k
    node.start = sorted_patches_y0_.size();
258
21.2k
    for (ssize_t i = static_cast<ssize_t>(right_start) - 1;
259
778k
         i >= static_cast<ssize_t>(left_end); --i) {
260
756k
      sorted_patches_y1_.emplace_back(intervals[i].y1, intervals[i].idx);
261
756k
    }
262
21.2k
    sort_by_y0(left_end, right_start);
263
778k
    for (size_t i = left_end; i < right_start; ++i) {
264
756k
      sorted_patches_y0_.emplace_back(intervals[i].y0, intervals[i].idx);
265
756k
    }
266
    // Create the left and right nodes (if not empty).
267
21.2k
    node.left_child = node.right_child = -1;
268
21.2k
    if (left_end > start) {
269
10.0k
      PatchTreeNode left;
270
10.0k
      left.start = start;
271
10.0k
      left.num = left_end - left.start;
272
10.0k
      patch_tree_[next].left_child = patch_tree_.size();
273
10.0k
      patch_tree_.push_back(left);
274
10.0k
    }
275
21.2k
    if (right_start < end) {
276
10.9k
      PatchTreeNode right;
277
10.9k
      right.start = right_start;
278
10.9k
      right.num = end - right.start;
279
10.9k
      patch_tree_[next].right_child = patch_tree_.size();
280
10.9k
      patch_tree_.push_back(right);
281
10.9k
    }
282
21.2k
    ++next;
283
21.2k
  }
284
273
}
285
286
304k
std::vector<size_t> PatchDictionary::GetPatchesForRow(size_t y) const {
287
304k
  std::vector<size_t> result;
288
304k
  if (y < num_patches_.size() && num_patches_[y] > 0) {
289
232k
    result.reserve(num_patches_[y]);
290
1.93M
    for (ssize_t tree_idx = 0; tree_idx != -1;) {
291
1.70M
      JXL_DASSERT(tree_idx < static_cast<ssize_t>(patch_tree_.size()));
292
1.70M
      const auto& node = patch_tree_[tree_idx];
293
1.70M
      if (y <= node.y_center) {
294
7.13M
        for (size_t i = 0; i < node.num; ++i) {
295
6.98M
          const auto& p = sorted_patches_y0_[node.start + i];
296
6.98M
          if (y < p.first) break;
297
6.22M
          result.push_back(p.second);
298
6.22M
        }
299
902k
        tree_idx = y < node.y_center ? node.left_child : -1;
300
902k
      } else {
301
3.79M
        for (size_t i = 0; i < node.num; ++i) {
302
3.76M
          const auto& p = sorted_patches_y1_[node.start + i];
303
3.76M
          if (y >= p.first) break;
304
2.99M
          result.push_back(p.second);
305
2.99M
        }
306
801k
        tree_idx = node.right_child;
307
801k
      }
308
1.70M
    }
309
    // Ensure that he relative order of patches that affect the same pixels is
310
    // preserved. This is important for patches that have a blend mode
311
    // different from kAdd.
312
232k
    std::sort(result.begin(), result.end());
313
232k
  }
314
304k
  return result;
315
304k
}
316
317
// Adds patches to a segment of `xsize` pixels, starting at `inout`, assumed
318
// to be located at position (x0, y) in the frame.
319
Status PatchDictionary::AddOneRow(
320
    float* const* inout, size_t y, size_t x0, size_t xsize,
321
239k
    const std::vector<ExtraChannelInfo>& extra_channel_info) const {
322
239k
  size_t num_ec = extra_channel_info.size();
323
239k
  JXL_ENSURE(num_ec + 1 <= blendings_stride_);
324
239k
  std::vector<const float*> fg_ptrs(3 + num_ec);
325
7.90M
  for (size_t pos_idx : GetPatchesForRow(y)) {
326
7.90M
    const size_t blending_idx = pos_idx * blendings_stride_;
327
7.90M
    const PatchPosition& pos = positions_[pos_idx];
328
7.90M
    const PatchReferencePosition& ref_pos = ref_positions_[pos.ref_pos_idx];
329
7.90M
    size_t by = pos.y;
330
7.90M
    size_t bx = pos.x;
331
7.90M
    size_t patch_xsize = ref_pos.xsize;
332
7.90M
    JXL_ENSURE(y >= by);
333
7.90M
    JXL_ENSURE(y < by + ref_pos.ysize);
334
7.90M
    size_t iy = y - by;
335
7.90M
    size_t ref = ref_pos.ref;
336
7.90M
    if (bx >= x0 + xsize) continue;
337
6.02M
    if (bx + patch_xsize < x0) continue;
338
3.98M
    size_t patch_x0 = std::max(bx, x0);
339
3.98M
    size_t patch_x1 = std::min(bx + patch_xsize, x0 + xsize);
340
15.9M
    for (size_t c = 0; c < 3; c++) {
341
11.9M
      fg_ptrs[c] = reference_frames_->at(ref).frame->color()->ConstPlaneRow(
342
11.9M
                       c, ref_pos.y0 + iy) +
343
11.9M
                   ref_pos.x0 + x0 - bx;
344
11.9M
    }
345
3.98M
    for (size_t i = 0; i < num_ec; i++) {
346
196
      fg_ptrs[3 + i] =
347
196
          reference_frames_->at(ref).frame->extra_channels()[i].ConstRow(
348
196
              ref_pos.y0 + iy) +
349
196
          ref_pos.x0 + x0 - bx;
350
196
    }
351
3.98M
    JXL_RETURN_IF_ERROR(PerformBlending(
352
3.98M
        memory_manager_, inout, fg_ptrs.data(), inout, patch_x0 - x0,
353
3.98M
        patch_x1 - patch_x0, blendings_[blending_idx],
354
3.98M
        blendings_.data() + blending_idx + 1, extra_channel_info));
355
3.98M
  }
356
239k
  return true;
357
239k
}
358
}  // namespace jxl