Coverage Report

Created: 2026-02-14 07:11

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/libjxl/lib/jxl/dec_patch_dictionary.cc
Line
Count
Source
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/dec_patch_dictionary.h"
7
8
#include <jxl/memory_manager.h>
9
10
#include <algorithm>
11
#include <cstdint>
12
#include <cstdlib>
13
#include <utility>
14
#include <vector>
15
16
#include "lib/jxl/base/printf_macros.h"
17
#include "lib/jxl/base/status.h"
18
#include "lib/jxl/blending.h"
19
#include "lib/jxl/common.h"  // kMaxNumReferenceFrames
20
#include "lib/jxl/dec_ans.h"
21
#include "lib/jxl/dec_bit_reader.h"
22
#include "lib/jxl/image.h"
23
#include "lib/jxl/image_bundle.h"
24
#include "lib/jxl/image_metadata.h"
25
#include "lib/jxl/pack_signed.h"
26
#include "lib/jxl/patch_dictionary_internal.h"
27
28
namespace jxl {
29
30
Status PatchDictionary::Decode(JxlMemoryManager* memory_manager, BitReader* br,
31
                               size_t xsize, size_t ysize,
32
                               size_t num_extra_channels,
33
1.20k
                               bool* uses_extra_channels) {
34
1.20k
  positions_.clear();
35
1.20k
  blendings_stride_ = num_extra_channels + 1;
36
1.20k
  std::vector<uint8_t> context_map;
37
1.20k
  ANSCode code;
38
1.20k
  JXL_RETURN_IF_ERROR(DecodeHistograms(
39
1.20k
      memory_manager, br, kNumPatchDictionaryContexts, &code, &context_map));
40
2.38k
  JXL_ASSIGN_OR_RETURN(ANSSymbolReader decoder,
41
2.38k
                       ANSSymbolReader::Create(&code, br));
42
43
4.57M
  auto read_num = [&](size_t context) -> size_t {
44
4.57M
    size_t r = decoder.ReadHybridUint(context, br, context_map);
45
4.57M
    return r;
46
4.57M
  };
47
48
2.38k
  size_t num_ref_patch = read_num(kNumRefPatchContext);
49
  // Limit max memory usage of patches to about 66 bytes per pixel (assuming 8
50
  // bytes per size_t)
51
2.38k
  const size_t num_pixels = xsize * ysize;
52
2.38k
  const size_t max_ref_patches = 1024 + num_pixels / 4;
53
2.38k
  const size_t max_patches = max_ref_patches * 4;
54
2.38k
  const size_t max_blending_infos = max_patches * 4;
55
2.38k
  if (num_ref_patch > max_ref_patches) {
56
2
    return JXL_FAILURE("Too many patches in dictionary");
57
2
  }
58
59
1.19k
  size_t total_patches = 0;
60
1.19k
  size_t next_size = 1;
61
62
36.5k
  for (size_t id = 0; id < num_ref_patch; id++) {
63
35.3k
    PatchReferencePosition ref_pos;
64
35.3k
    ref_pos.ref = read_num(kReferenceFrameContext);
65
35.3k
    if (ref_pos.ref >= kMaxNumReferenceFrames ||
66
35.3k
        reference_frames_->at(ref_pos.ref).frame->xsize() == 0) {
67
8
      return JXL_FAILURE("Invalid reference frame ID");
68
8
    }
69
35.3k
    if (!reference_frames_->at(ref_pos.ref).ib_is_in_xyb) {
70
2
      return JXL_FAILURE(
71
2
          "Patches cannot use frames saved post color transforms");
72
2
    }
73
35.3k
    const ImageBundle& ib = *reference_frames_->at(ref_pos.ref).frame;
74
35.3k
    ref_pos.x0 = read_num(kPatchReferencePositionContext);
75
35.3k
    ref_pos.y0 = read_num(kPatchReferencePositionContext);
76
35.3k
    ref_pos.xsize = read_num(kPatchSizeContext) + 1;
77
35.3k
    ref_pos.ysize = read_num(kPatchSizeContext) + 1;
78
35.3k
    if (ref_pos.x0 + ref_pos.xsize > ib.xsize()) {
79
2
      return JXL_FAILURE("Invalid position specified in reference frame");
80
2
    }
81
35.3k
    if (ref_pos.y0 + ref_pos.ysize > ib.ysize()) {
82
1
      return JXL_FAILURE("Invalid position specified in reference frame");
83
1
    }
84
35.3k
    size_t id_count = read_num(kPatchCountContext);
85
35.3k
    if (id_count > max_patches) {
86
0
      return JXL_FAILURE("Too many patches in dictionary");
87
0
    }
88
35.3k
    id_count++;
89
35.3k
    total_patches += id_count;
90
35.3k
    if (total_patches > max_patches) {
91
2
      return JXL_FAILURE("Too many patches in dictionary");
92
2
    }
93
35.3k
    if (next_size < total_patches) {
94
7.55k
      next_size *= 2;
95
7.55k
      next_size = std::min<size_t>(next_size, max_patches);
96
7.55k
    }
97
35.3k
    if (next_size * blendings_stride_ > max_blending_infos) {
98
0
      return JXL_FAILURE("Too many patches in dictionary");
99
0
    }
100
35.3k
    positions_.reserve(next_size);
101
35.3k
    blendings_.reserve(next_size * blendings_stride_);
102
35.3k
    bool choose_alpha = (num_extra_channels > 1);
103
1.48M
    for (size_t i = 0; i < id_count; i++) {
104
1.44M
      PatchPosition pos;
105
1.44M
      pos.ref_pos_idx = ref_positions_.size();
106
1.44M
      if (i == 0) {
107
35.3k
        pos.x = read_num(kPatchPositionContext);
108
35.3k
        pos.y = read_num(kPatchPositionContext);
109
1.41M
      } else {
110
1.41M
        ptrdiff_t deltax = UnpackSigned(read_num(kPatchOffsetContext));
111
1.41M
        if (deltax < 0 && static_cast<size_t>(-deltax) > positions_.back().x) {
112
5
          return JXL_FAILURE("Invalid patch: negative x coordinate (%" PRIuS
113
5
                             " base x %" PRIdS " delta x)",
114
5
                             positions_.back().x, deltax);
115
5
        }
116
1.41M
        pos.x = positions_.back().x + deltax;
117
1.41M
        ptrdiff_t deltay = UnpackSigned(read_num(kPatchOffsetContext));
118
1.41M
        if (deltay < 0 && static_cast<size_t>(-deltay) > positions_.back().y) {
119
5
          return JXL_FAILURE("Invalid patch: negative y coordinate (%" PRIuS
120
5
                             " base y %" PRIdS " delta y)",
121
5
                             positions_.back().y, deltay);
122
5
        }
123
1.41M
        pos.y = positions_.back().y + deltay;
124
1.41M
      }
125
1.44M
      if (pos.x + ref_pos.xsize > xsize) {
126
1
        return JXL_FAILURE("Invalid patch x: at %" PRIuS " + %" PRIuS
127
1
                           " > %" PRIuS,
128
1
                           pos.x, ref_pos.xsize, xsize);
129
1
      }
130
1.44M
      if (pos.y + ref_pos.ysize > ysize) {
131
1
        return JXL_FAILURE("Invalid patch y: at %" PRIuS " + %" PRIuS
132
1
                           " > %" PRIuS,
133
1
                           pos.y, ref_pos.ysize, ysize);
134
1
      }
135
2.90M
      for (size_t j = 0; j < blendings_stride_; j++) {
136
1.45M
        uint32_t blend_mode = read_num(kPatchBlendModeContext);
137
1.45M
        if (blend_mode >= kNumPatchBlendModes) {
138
3
          return JXL_FAILURE("Invalid patch blend mode: %u", blend_mode);
139
3
        }
140
1.45M
        PatchBlending info;
141
1.45M
        info.mode = static_cast<PatchBlendMode>(blend_mode);
142
1.45M
        if (UsesAlpha(info.mode)) {
143
10.1k
          *uses_extra_channels = true;
144
10.1k
        }
145
1.45M
        if (info.mode != PatchBlendMode::kNone && j > 0) {
146
1.01k
          *uses_extra_channels = true;
147
1.01k
        }
148
1.45M
        if (UsesAlpha(info.mode) && choose_alpha) {
149
701
          info.alpha_channel = read_num(kPatchAlphaChannelContext);
150
701
          if (info.alpha_channel >= num_extra_channels) {
151
4
            return JXL_FAILURE(
152
4
                "Invalid alpha channel for blending: %u out of %u\n",
153
4
                info.alpha_channel, static_cast<uint32_t>(num_extra_channels));
154
4
          }
155
1.45M
        } else {
156
1.45M
          info.alpha_channel = 0;
157
1.45M
        }
158
1.45M
        if (UsesClamp(info.mode)) {
159
13.7k
          info.clamp = static_cast<bool>(read_num(kPatchClampContext));
160
1.43M
        } else {
161
1.43M
          info.clamp = false;
162
1.43M
        }
163
1.45M
        blendings_.push_back(info);
164
1.45M
      }
165
1.44M
      positions_.emplace_back(pos);
166
1.44M
    }
167
35.3k
    ref_positions_.emplace_back(ref_pos);
168
35.3k
  }
169
1.15k
  positions_.shrink_to_fit();
170
171
1.15k
  if (!decoder.CheckANSFinalState()) {
172
0
    return JXL_FAILURE("ANS checksum failure.");
173
0
  }
174
175
1.15k
  ComputePatchTree();
176
1.15k
  return true;
177
1.15k
}
178
179
859
int PatchDictionary::GetReferences() const {
180
859
  int result = 0;
181
119k
  for (const auto& ref_pos : ref_positions_) {
182
119k
    result |= (1 << static_cast<int>(ref_pos.ref));
183
119k
  }
184
859
  return result;
185
859
}
186
187
namespace {
188
struct PatchInterval {
189
  size_t idx;
190
  size_t y0, y1;
191
};
192
}  // namespace
193
194
33.2k
void PatchDictionary::ComputePatchTree() {
195
33.2k
  patch_tree_.clear();
196
33.2k
  num_patches_.clear();
197
33.2k
  sorted_patches_y0_.clear();
198
33.2k
  sorted_patches_y1_.clear();
199
33.2k
  if (positions_.empty()) {
200
31.6k
    return;
201
31.6k
  }
202
  // Create a y-interval for each patch.
203
1.63k
  std::vector<PatchInterval> intervals(positions_.size());
204
2.76M
  for (size_t i = 0; i < positions_.size(); ++i) {
205
2.76M
    const auto& pos = positions_[i];
206
2.76M
    intervals[i].idx = i;
207
2.76M
    intervals[i].y0 = pos.y;
208
2.76M
    intervals[i].y1 = pos.y + ref_positions_[pos.ref_pos_idx].ysize;
209
2.76M
  }
210
298k
  auto sort_by_y0 = [&intervals](size_t start, size_t end) {
211
298k
    std::sort(intervals.data() + start, intervals.data() + end,
212
92.4M
              [](const PatchInterval& i0, const PatchInterval& i1) {
213
92.4M
                return i0.y0 < i1.y0;
214
92.4M
              });
215
298k
  };
216
150k
  auto sort_by_y1 = [&intervals](size_t start, size_t end) {
217
150k
    std::sort(intervals.data() + start, intervals.data() + end,
218
76.8M
              [](const PatchInterval& i0, const PatchInterval& i1) {
219
76.8M
                return i0.y1 < i1.y1;
220
76.8M
              });
221
150k
  };
222
  // Count the number of patches for each row.
223
1.63k
  sort_by_y1(0, intervals.size());
224
1.63k
  num_patches_.resize(intervals.back().y1);
225
2.76M
  for (auto iv : intervals) {
226
41.3M
    for (size_t y = iv.y0; y < iv.y1; ++y) num_patches_[y]++;
227
2.76M
  }
228
1.63k
  PatchTreeNode root;
229
1.63k
  root.start = 0;
230
1.63k
  root.num = intervals.size();
231
1.63k
  patch_tree_.push_back(root);
232
1.63k
  size_t next = 0;
233
150k
  while (next < patch_tree_.size()) {
234
149k
    auto& node = patch_tree_[next];
235
149k
    size_t start = node.start;
236
149k
    size_t end = node.start + node.num;
237
    // Choose the y_center for this node to be the median of interval starts.
238
149k
    sort_by_y0(start, end);
239
149k
    size_t middle_idx = start + node.num / 2;
240
149k
    node.y_center = intervals[middle_idx].y0;
241
    // Divide the intervals in [start, end) into three groups:
242
    //   * those completely to the right of y_center: [right_start, end)
243
    //   * those overlapping y_center: [left_end, right_start)
244
    //   * those completely to the left of y_center: [start, left_end)
245
149k
    size_t right_start = middle_idx;
246
1.08M
    while (right_start < end && intervals[right_start].y0 == node.y_center) {
247
935k
      ++right_start;
248
935k
    }
249
149k
    sort_by_y1(start, right_start);
250
149k
    size_t left_end = right_start;
251
2.91M
    while (left_end > start && intervals[left_end - 1].y1 > node.y_center) {
252
2.76M
      --left_end;
253
2.76M
    }
254
    // Fill in sorted_patches_y0_ and sorted_patches_y1_ for the current node.
255
149k
    node.num = right_start - left_end;
256
149k
    node.start = sorted_patches_y0_.size();
257
149k
    for (ptrdiff_t i = static_cast<ptrdiff_t>(right_start) - 1;
258
2.91M
         i >= static_cast<ptrdiff_t>(left_end); --i) {
259
2.76M
      sorted_patches_y1_.emplace_back(intervals[i].y1, intervals[i].idx);
260
2.76M
    }
261
149k
    sort_by_y0(left_end, right_start);
262
2.91M
    for (size_t i = left_end; i < right_start; ++i) {
263
2.76M
      sorted_patches_y0_.emplace_back(intervals[i].y0, intervals[i].idx);
264
2.76M
    }
265
    // Create the left and right nodes (if not empty).
266
149k
    node.left_child = node.right_child = -1;
267
149k
    if (left_end > start) {
268
72.2k
      PatchTreeNode left;
269
72.2k
      left.start = start;
270
72.2k
      left.num = left_end - left.start;
271
72.2k
      patch_tree_[next].left_child = patch_tree_.size();
272
72.2k
      patch_tree_.push_back(left);
273
72.2k
    }
274
149k
    if (right_start < end) {
275
75.3k
      PatchTreeNode right;
276
75.3k
      right.start = right_start;
277
75.3k
      right.num = end - right.start;
278
75.3k
      patch_tree_[next].right_child = patch_tree_.size();
279
75.3k
      patch_tree_.push_back(right);
280
75.3k
    }
281
149k
    ++next;
282
149k
  }
283
1.63k
}
284
285
2.77M
std::vector<size_t> PatchDictionary::GetPatchesForRow(size_t y) const {
286
2.77M
  std::vector<size_t> result;
287
2.77M
  if (y < num_patches_.size() && num_patches_[y] > 0) {
288
1.36M
    result.reserve(num_patches_[y]);
289
10.8M
    for (ptrdiff_t tree_idx = 0; tree_idx != -1;) {
290
9.46M
      JXL_DASSERT(tree_idx < static_cast<ptrdiff_t>(patch_tree_.size()));
291
9.46M
      const auto& node = patch_tree_[tree_idx];
292
9.46M
      if (y <= node.y_center) {
293
24.2M
        for (size_t i = 0; i < node.num; ++i) {
294
23.4M
          const auto& p = sorted_patches_y0_[node.start + i];
295
23.4M
          if (y < p.first) break;
296
19.3M
          result.push_back(p.second);
297
19.3M
        }
298
4.86M
        tree_idx = y < node.y_center ? node.left_child : -1;
299
4.86M
      } else {
300
26.6M
        for (size_t i = 0; i < node.num; ++i) {
301
26.3M
          const auto& p = sorted_patches_y1_[node.start + i];
302
26.3M
          if (y >= p.first) break;
303
22.0M
          result.push_back(p.second);
304
22.0M
        }
305
4.60M
        tree_idx = node.right_child;
306
4.60M
      }
307
9.46M
    }
308
    // Ensure that he relative order of patches that affect the same pixels is
309
    // preserved. This is important for patches that have a blend mode
310
    // different from kAdd.
311
1.36M
    std::sort(result.begin(), result.end());
312
1.36M
  }
313
2.77M
  return result;
314
2.77M
}
315
316
// Adds patches to a segment of `xsize` pixels, starting at `inout`, assumed
317
// to be located at position (x0, y) in the frame.
318
Status PatchDictionary::AddOneRow(
319
    float* const* inout, size_t y, size_t x0, size_t xsize,
320
2.06M
    const std::vector<ExtraChannelInfo>& extra_channel_info) const {
321
2.06M
  size_t num_ec = extra_channel_info.size();
322
2.06M
  JXL_ENSURE(num_ec + 1 <= blendings_stride_);
323
2.06M
  std::vector<const float*> fg_ptrs(3 + num_ec);
324
38.5M
  for (size_t pos_idx : GetPatchesForRow(y)) {
325
38.5M
    const size_t blending_idx = pos_idx * blendings_stride_;
326
38.5M
    const PatchPosition& pos = positions_[pos_idx];
327
38.5M
    const PatchReferencePosition& ref_pos = ref_positions_[pos.ref_pos_idx];
328
38.5M
    size_t by = pos.y;
329
38.5M
    size_t bx = pos.x;
330
38.5M
    size_t patch_xsize = ref_pos.xsize;
331
38.5M
    JXL_ENSURE(y >= by);
332
38.5M
    JXL_ENSURE(y < by + ref_pos.ysize);
333
38.5M
    size_t iy = y - by;
334
38.5M
    size_t ref = ref_pos.ref;
335
38.5M
    if (bx >= x0 + xsize) continue;
336
34.6M
    if (bx + patch_xsize < x0) continue;
337
30.0M
    size_t patch_x0 = std::max(bx, x0);
338
30.0M
    size_t patch_x1 = std::min(bx + patch_xsize, x0 + xsize);
339
120M
    for (size_t c = 0; c < 3; c++) {
340
90.1M
      fg_ptrs[c] = reference_frames_->at(ref).frame->color()->ConstPlaneRow(
341
90.1M
                       c, ref_pos.y0 + iy) +
342
90.1M
                   ref_pos.x0 + x0 - bx;
343
90.1M
    }
344
30.0M
    for (size_t i = 0; i < num_ec; i++) {
345
142
      fg_ptrs[3 + i] =
346
142
          reference_frames_->at(ref).frame->extra_channels()[i].ConstRow(
347
142
              ref_pos.y0 + iy) +
348
142
          ref_pos.x0 + x0 - bx;
349
142
    }
350
30.0M
    JXL_RETURN_IF_ERROR(PerformBlending(
351
30.0M
        memory_manager_, inout, fg_ptrs.data(), inout, patch_x0 - x0,
352
30.0M
        patch_x1 - patch_x0, blendings_[blending_idx],
353
30.0M
        blendings_.data() + blending_idx + 1, extra_channel_info));
354
30.0M
  }
355
2.06M
  return true;
356
2.06M
}
357
}  // namespace jxl