Coverage Report

Created: 2025-06-22 08:04

/src/libjxl/lib/jxl/dec_patch_dictionary.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/dec_patch_dictionary.h"
7
8
#include <jxl/memory_manager.h>
9
#include <sys/types.h>
10
11
#include <algorithm>
12
#include <cstdint>
13
#include <cstdlib>
14
#include <utility>
15
#include <vector>
16
17
#include "lib/jxl/base/printf_macros.h"
18
#include "lib/jxl/base/status.h"
19
#include "lib/jxl/blending.h"
20
#include "lib/jxl/common.h"  // kMaxNumReferenceFrames
21
#include "lib/jxl/dec_ans.h"
22
#include "lib/jxl/image.h"
23
#include "lib/jxl/image_bundle.h"
24
#include "lib/jxl/pack_signed.h"
25
#include "lib/jxl/patch_dictionary_internal.h"
26
27
namespace jxl {
28
29
// Decodes a patch dictionary from `br` for a frame of `xsize` x `ysize`
// pixels with `num_extra_channels` extra channels, replacing whatever this
// object held before.
//
// For every reference patch (a rectangle inside one of the reference frames)
// the stream lists the frame positions it is pasted at. For each position it
// lists one PatchBlending per channel group (color, then each extra channel).
// `*uses_extra_channels` is set to true when a decoded blend mode uses alpha,
// or when a mode other than kNone is given for an extra channel. It is never
// reset to false here.
//
// Returns an error status on malformed input. This covers:
//  * too many patches;
//  * an invalid reference frame or rectangle;
//  * positions falling outside the frame;
//  * an invalid blend mode or alpha channel;
//  * an ANS final-state mismatch.
Status PatchDictionary::Decode(JxlMemoryManager* memory_manager, BitReader* br,
                               size_t xsize, size_t ysize,
                               size_t num_extra_channels,
                               bool* uses_extra_channels) {
  // blendings_ is indexed by (position index * blendings_stride_), and every
  // position refers into ref_positions_, so all three containers must start
  // out empty together. Clearing only positions_ would leave stale entries
  // that misalign blendings_ and add spurious bits to GetReferences(). That
  // happens when the object is decoded into more than once, or after an
  // earlier decode failed partway.
  positions_.clear();
  ref_positions_.clear();
  blendings_.clear();
  blendings_stride_ = num_extra_channels + 1;
  std::vector<uint8_t> context_map;
  ANSCode code;
  JXL_RETURN_IF_ERROR(DecodeHistograms(
      memory_manager, br, kNumPatchDictionaryContexts, &code, &context_map));
  JXL_ASSIGN_OR_RETURN(ANSSymbolReader decoder,
                       ANSSymbolReader::Create(&code, br));

  auto read_num = [&](size_t context) {
    size_t r = decoder.ReadHybridUint(context, br, context_map);
    return r;
  };

  size_t num_ref_patch = read_num(kNumRefPatchContext);
  // Limit max memory usage of patches to about 66 bytes per pixel (assuming 8
  // bytes per size_t)
  const size_t num_pixels = xsize * ysize;
  const size_t max_ref_patches = 1024 + num_pixels / 4;
  const size_t max_patches = max_ref_patches * 4;
  const size_t max_blending_infos = max_patches * 4;
  if (num_ref_patch > max_ref_patches) {
    return JXL_FAILURE("Too many patches in dictionary");
  }

  // An explicit alpha channel index is only coded when there is more than one
  // extra channel to choose from; otherwise channel 0 is used.
  const bool choose_alpha = (num_extra_channels > 1);
  size_t total_patches = 0;
  // Reservation target for positions_/blendings_; grows geometrically and is
  // capped at max_patches.
  size_t next_size = 1;

  for (size_t id = 0; id < num_ref_patch; id++) {
    PatchReferencePosition ref_pos;
    ref_pos.ref = read_num(kReferenceFrameContext);
    if (ref_pos.ref >= kMaxNumReferenceFrames ||
        reference_frames_->at(ref_pos.ref).frame->xsize() == 0) {
      return JXL_FAILURE("Invalid reference frame ID");
    }
    if (!reference_frames_->at(ref_pos.ref).ib_is_in_xyb) {
      return JXL_FAILURE(
          "Patches cannot use frames saved post color transforms");
    }
    const ImageBundle& ib = *reference_frames_->at(ref_pos.ref).frame;
    ref_pos.x0 = read_num(kPatchReferencePositionContext);
    ref_pos.y0 = read_num(kPatchReferencePositionContext);
    ref_pos.xsize = read_num(kPatchSizeContext) + 1;
    ref_pos.ysize = read_num(kPatchSizeContext) + 1;
    // The source rectangle must lie inside the reference frame.
    if (ref_pos.x0 + ref_pos.xsize > ib.xsize()) {
      return JXL_FAILURE("Invalid position specified in reference frame");
    }
    if (ref_pos.y0 + ref_pos.ysize > ib.ysize()) {
      return JXL_FAILURE("Invalid position specified in reference frame");
    }
    // The stream codes (number of positions - 1).
    size_t id_count = read_num(kPatchCountContext);
    if (id_count > max_patches) {
      return JXL_FAILURE("Too many patches in dictionary");
    }
    id_count++;
    total_patches += id_count;
    if (total_patches > max_patches) {
      return JXL_FAILURE("Too many patches in dictionary");
    }
    if (next_size < total_patches) {
      next_size *= 2;
      next_size = std::min<size_t>(next_size, max_patches);
    }
    if (next_size * blendings_stride_ > max_blending_infos) {
      return JXL_FAILURE("Too many patches in dictionary");
    }
    positions_.reserve(next_size);
    blendings_.reserve(next_size * blendings_stride_);
    for (size_t i = 0; i < id_count; i++) {
      PatchPosition pos;
      pos.ref_pos_idx = ref_positions_.size();
      if (i == 0) {
        pos.x = read_num(kPatchPositionContext);
        pos.y = read_num(kPatchPositionContext);
      } else {
        // Later positions are delta-coded against the previous position;
        // reject deltas that would make a coordinate negative.
        ssize_t deltax = UnpackSigned(read_num(kPatchOffsetContext));
        if (deltax < 0 && static_cast<size_t>(-deltax) > positions_.back().x) {
          return JXL_FAILURE("Invalid patch: negative x coordinate (%" PRIuS
                             " base x %" PRIdS " delta x)",
                             positions_.back().x, deltax);
        }
        pos.x = positions_.back().x + deltax;
        ssize_t deltay = UnpackSigned(read_num(kPatchOffsetContext));
        if (deltay < 0 && static_cast<size_t>(-deltay) > positions_.back().y) {
          return JXL_FAILURE("Invalid patch: negative y coordinate (%" PRIuS
                             " base y %" PRIdS " delta y)",
                             positions_.back().y, deltay);
        }
        pos.y = positions_.back().y + deltay;
      }
      // The pasted rectangle must lie inside the frame being decoded.
      if (pos.x + ref_pos.xsize > xsize) {
        return JXL_FAILURE("Invalid patch x: at %" PRIuS " + %" PRIuS
                           " > %" PRIuS,
                           pos.x, ref_pos.xsize, xsize);
      }
      if (pos.y + ref_pos.ysize > ysize) {
        return JXL_FAILURE("Invalid patch y: at %" PRIuS " + %" PRIuS
                           " > %" PRIuS,
                           pos.y, ref_pos.ysize, ysize);
      }
      // One blending descriptor for color (j == 0), then one per extra
      // channel.
      for (size_t j = 0; j < blendings_stride_; j++) {
        uint32_t blend_mode = read_num(kPatchBlendModeContext);
        if (blend_mode >= kNumPatchBlendModes) {
          return JXL_FAILURE("Invalid patch blend mode: %u", blend_mode);
        }
        PatchBlending info;
        info.mode = static_cast<PatchBlendMode>(blend_mode);
        if (UsesAlpha(info.mode)) {
          *uses_extra_channels = true;
        }
        if (info.mode != PatchBlendMode::kNone && j > 0) {
          *uses_extra_channels = true;
        }
        if (UsesAlpha(info.mode) && choose_alpha) {
          info.alpha_channel = read_num(kPatchAlphaChannelContext);
          if (info.alpha_channel >= num_extra_channels) {
            return JXL_FAILURE(
                "Invalid alpha channel for blending: %u out of %u\n",
                info.alpha_channel, static_cast<uint32_t>(num_extra_channels));
          }
        } else {
          info.alpha_channel = 0;
        }
        if (UsesClamp(info.mode)) {
          info.clamp = static_cast<bool>(read_num(kPatchClampContext));
        } else {
          info.clamp = false;
        }
        blendings_.push_back(info);
      }
      positions_.emplace_back(pos);
    }
    ref_positions_.emplace_back(ref_pos);
  }
  positions_.shrink_to_fit();

  if (!decoder.CheckANSFinalState()) {
    return JXL_FAILURE("ANS checksum failure.");
  }

  ComputePatchTree();
  return true;
}
177
178
2.16k
int PatchDictionary::GetReferences() const {
179
2.16k
  int result = 0;
180
265k
  for (const auto& ref_pos : ref_positions_) {
181
265k
    result |= (1 << static_cast<int>(ref_pos.ref));
182
265k
  }
183
2.16k
  return result;
184
2.16k
}
185
186
namespace {
// Half-open range of rows [y0, y1) covered by one patch, used while building
// the patch interval tree.
struct PatchInterval {
  size_t idx;  // Index of the patch in positions_.
  size_t y0;   // First covered row.
  size_t y1;   // One past the last covered row.
};
}  // namespace
192
193
43.1k
void PatchDictionary::ComputePatchTree() {
194
43.1k
  patch_tree_.clear();
195
43.1k
  num_patches_.clear();
196
43.1k
  sorted_patches_y0_.clear();
197
43.1k
  sorted_patches_y1_.clear();
198
43.1k
  if (positions_.empty()) {
199
40.9k
    return;
200
40.9k
  }
201
  // Create a y-interval for each patch.
202
2.19k
  std::vector<PatchInterval> intervals(positions_.size());
203
3.34M
  for (size_t i = 0; i < positions_.size(); ++i) {
204
3.34M
    const auto& pos = positions_[i];
205
3.34M
    intervals[i].idx = i;
206
3.34M
    intervals[i].y0 = pos.y;
207
3.34M
    intervals[i].y1 = pos.y + ref_positions_[pos.ref_pos_idx].ysize;
208
3.34M
  }
209
223k
  auto sort_by_y0 = [&intervals](size_t start, size_t end) {
210
223k
    std::sort(intervals.data() + start, intervals.data() + end,
211
97.3M
              [](const PatchInterval& i0, const PatchInterval& i1) {
212
97.3M
                return i0.y0 < i1.y0;
213
97.3M
              });
214
223k
  };
215
113k
  auto sort_by_y1 = [&intervals](size_t start, size_t end) {
216
113k
    std::sort(intervals.data() + start, intervals.data() + end,
217
83.2M
              [](const PatchInterval& i0, const PatchInterval& i1) {
218
83.2M
                return i0.y1 < i1.y1;
219
83.2M
              });
220
113k
  };
221
  // Count the number of patches for each row.
222
2.19k
  sort_by_y1(0, intervals.size());
223
2.19k
  num_patches_.resize(intervals.back().y1);
224
3.34M
  for (auto iv : intervals) {
225
81.1M
    for (size_t y = iv.y0; y < iv.y1; ++y) num_patches_[y]++;
226
3.34M
  }
227
2.19k
  PatchTreeNode root;
228
2.19k
  root.start = 0;
229
2.19k
  root.num = intervals.size();
230
2.19k
  patch_tree_.push_back(root);
231
2.19k
  size_t next = 0;
232
113k
  while (next < patch_tree_.size()) {
233
111k
    auto& node = patch_tree_[next];
234
111k
    size_t start = node.start;
235
111k
    size_t end = node.start + node.num;
236
    // Choose the y_center for this node to be the median of interval starts.
237
111k
    sort_by_y0(start, end);
238
111k
    size_t middle_idx = start + node.num / 2;
239
111k
    node.y_center = intervals[middle_idx].y0;
240
    // Divide the intervals in [start, end) into three groups:
241
    //   * those completely to the right of y_center: [right_start, end)
242
    //   * those overlapping y_center: [left_end, right_start)
243
    //   * those completely to the left of y_center: [start, left_end)
244
111k
    size_t right_start = middle_idx;
245
916k
    while (right_start < end && intervals[right_start].y0 == node.y_center) {
246
804k
      ++right_start;
247
804k
    }
248
111k
    sort_by_y1(start, right_start);
249
111k
    size_t left_end = right_start;
250
3.45M
    while (left_end > start && intervals[left_end - 1].y1 > node.y_center) {
251
3.34M
      --left_end;
252
3.34M
    }
253
    // Fill in sorted_patches_y0_ and sorted_patches_y1_ for the current node.
254
111k
    node.num = right_start - left_end;
255
111k
    node.start = sorted_patches_y0_.size();
256
111k
    for (ssize_t i = static_cast<ssize_t>(right_start) - 1;
257
3.45M
         i >= static_cast<ssize_t>(left_end); --i) {
258
3.34M
      sorted_patches_y1_.emplace_back(intervals[i].y1, intervals[i].idx);
259
3.34M
    }
260
111k
    sort_by_y0(left_end, right_start);
261
3.45M
    for (size_t i = left_end; i < right_start; ++i) {
262
3.34M
      sorted_patches_y0_.emplace_back(intervals[i].y0, intervals[i].idx);
263
3.34M
    }
264
    // Create the left and right nodes (if not empty).
265
111k
    node.left_child = node.right_child = -1;
266
111k
    if (left_end > start) {
267
53.9k
      PatchTreeNode left;
268
53.9k
      left.start = start;
269
53.9k
      left.num = left_end - left.start;
270
53.9k
      patch_tree_[next].left_child = patch_tree_.size();
271
53.9k
      patch_tree_.push_back(left);
272
53.9k
    }
273
111k
    if (right_start < end) {
274
55.4k
      PatchTreeNode right;
275
55.4k
      right.start = right_start;
276
55.4k
      right.num = end - right.start;
277
55.4k
      patch_tree_[next].right_child = patch_tree_.size();
278
55.4k
      patch_tree_.push_back(right);
279
55.4k
    }
280
111k
    ++next;
281
111k
  }
282
2.19k
}
283
284
297k
std::vector<size_t> PatchDictionary::GetPatchesForRow(size_t y) const {
285
297k
  std::vector<size_t> result;
286
297k
  if (y < num_patches_.size() && num_patches_[y] > 0) {
287
173k
    result.reserve(num_patches_[y]);
288
1.21M
    for (ssize_t tree_idx = 0; tree_idx != -1;) {
289
1.04M
      JXL_DASSERT(tree_idx < static_cast<ssize_t>(patch_tree_.size()));
290
1.04M
      const auto& node = patch_tree_[tree_idx];
291
1.04M
      if (y <= node.y_center) {
292
15.3M
        for (size_t i = 0; i < node.num; ++i) {
293
15.3M
          const auto& p = sorted_patches_y0_[node.start + i];
294
15.3M
          if (y < p.first) break;
295
15.0M
          result.push_back(p.second);
296
15.0M
        }
297
369k
        tree_idx = y < node.y_center ? node.left_child : -1;
298
671k
      } else {
299
33.2M
        for (size_t i = 0; i < node.num; ++i) {
300
33.1M
          const auto& p = sorted_patches_y1_[node.start + i];
301
33.1M
          if (y >= p.first) break;
302
32.6M
          result.push_back(p.second);
303
32.6M
        }
304
671k
        tree_idx = node.right_child;
305
671k
      }
306
1.04M
    }
307
    // Ensure that he relative order of patches that affect the same pixels is
308
    // preserved. This is important for patches that have a blend mode
309
    // different from kAdd.
310
173k
    std::sort(result.begin(), result.end());
311
173k
  }
312
297k
  return result;
313
297k
}
314
315
// Adds patches to a segment of `xsize` pixels, starting at `inout`, assumed
316
// to be located at position (x0, y) in the frame.
317
Status PatchDictionary::AddOneRow(
318
    float* const* inout, size_t y, size_t x0, size_t xsize,
319
297k
    const std::vector<ExtraChannelInfo>& extra_channel_info) const {
320
297k
  size_t num_ec = extra_channel_info.size();
321
297k
  JXL_ENSURE(num_ec + 1 <= blendings_stride_);
322
297k
  std::vector<const float*> fg_ptrs(3 + num_ec);
323
47.6M
  for (size_t pos_idx : GetPatchesForRow(y)) {
324
47.6M
    const size_t blending_idx = pos_idx * blendings_stride_;
325
47.6M
    const PatchPosition& pos = positions_[pos_idx];
326
47.6M
    const PatchReferencePosition& ref_pos = ref_positions_[pos.ref_pos_idx];
327
47.6M
    size_t by = pos.y;
328
47.6M
    size_t bx = pos.x;
329
47.6M
    size_t patch_xsize = ref_pos.xsize;
330
47.6M
    JXL_ENSURE(y >= by);
331
47.6M
    JXL_ENSURE(y < by + ref_pos.ysize);
332
47.6M
    size_t iy = y - by;
333
47.6M
    size_t ref = ref_pos.ref;
334
47.6M
    if (bx >= x0 + xsize) continue;
335
47.6M
    if (bx + patch_xsize < x0) continue;
336
47.6M
    size_t patch_x0 = std::max(bx, x0);
337
47.6M
    size_t patch_x1 = std::min(bx + patch_xsize, x0 + xsize);
338
190M
    for (size_t c = 0; c < 3; c++) {
339
142M
      fg_ptrs[c] = reference_frames_->at(ref).frame->color()->ConstPlaneRow(
340
142M
                       c, ref_pos.y0 + iy) +
341
142M
                   ref_pos.x0 + x0 - bx;
342
142M
    }
343
47.7M
    for (size_t i = 0; i < num_ec; i++) {
344
165k
      fg_ptrs[3 + i] =
345
165k
          reference_frames_->at(ref).frame->extra_channels()[i].ConstRow(
346
165k
              ref_pos.y0 + iy) +
347
165k
          ref_pos.x0 + x0 - bx;
348
165k
    }
349
47.6M
    JXL_RETURN_IF_ERROR(PerformBlending(
350
47.6M
        memory_manager_, inout, fg_ptrs.data(), inout, patch_x0 - x0,
351
47.6M
        patch_x1 - patch_x0, blendings_[blending_idx],
352
47.6M
        blendings_.data() + blending_idx + 1, extra_channel_info));
353
47.6M
  }
354
297k
  return true;
355
297k
}
356
}  // namespace jxl