Coverage Report

Created: 2025-06-22 08:04

/src/libjxl/lib/jxl/enc_patch_dictionary.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
#include "lib/jxl/enc_patch_dictionary.h"
7
8
#include <jxl/memory_manager.h>
9
#include <jxl/types.h>
10
#include <sys/types.h>
11
12
#include <algorithm>
13
#include <atomic>
14
#include <cstdint>
15
#include <cstdlib>
16
#include <utility>
17
#include <vector>
18
19
#include "lib/jxl/base/common.h"
20
#include "lib/jxl/base/compiler_specific.h"
21
#include "lib/jxl/base/override.h"
22
#include "lib/jxl/base/printf_macros.h"
23
#include "lib/jxl/base/random.h"
24
#include "lib/jxl/base/rect.h"
25
#include "lib/jxl/base/status.h"
26
#include "lib/jxl/dec_cache.h"
27
#include "lib/jxl/dec_frame.h"
28
#include "lib/jxl/enc_ans.h"
29
#include "lib/jxl/enc_aux_out.h"
30
#include "lib/jxl/enc_cache.h"
31
#include "lib/jxl/enc_debug_image.h"
32
#include "lib/jxl/enc_dot_dictionary.h"
33
#include "lib/jxl/enc_frame.h"
34
#include "lib/jxl/frame_header.h"
35
#include "lib/jxl/image.h"
36
#include "lib/jxl/image_bundle.h"
37
#include "lib/jxl/image_ops.h"
38
#include "lib/jxl/pack_signed.h"
39
#include "lib/jxl/patch_dictionary_internal.h"
40
41
namespace jxl {
42
43
// Reference-frame index used for the patch dictionary frame built by
// FindBestPatchDictionary: it is stored as PatchReferencePosition::ref for
// every generated patch and passed as the frame index (`idx`) to
// RoundtripPatchFrame.
static constexpr size_t kPatchFrameReferenceId = 3;
44
45
// static
46
Status PatchDictionaryEncoder::Encode(const PatchDictionary& pdic,
47
                                      BitWriter* writer, LayerType layer,
48
0
                                      AuxOut* aux_out) {
49
0
  JXL_ENSURE(pdic.HasAny());
50
0
  JxlMemoryManager* memory_manager = writer->memory_manager();
51
0
  std::vector<std::vector<Token>> tokens(1);
52
53
0
  auto add_num = [&](int context, size_t num) {
54
0
    tokens[0].emplace_back(context, num);
55
0
  };
56
0
  size_t num_ref_patch = 0;
57
0
  for (size_t i = 0; i < pdic.positions_.size();) {
58
0
    size_t ref_pos_idx = pdic.positions_[i].ref_pos_idx;
59
0
    while (i < pdic.positions_.size() &&
60
0
           pdic.positions_[i].ref_pos_idx == ref_pos_idx) {
61
0
      i++;
62
0
    }
63
0
    num_ref_patch++;
64
0
  }
65
0
  add_num(kNumRefPatchContext, num_ref_patch);
66
0
  size_t blend_pos = 0;
67
0
  size_t blending_stride = pdic.blendings_stride_;
68
  // blending_stride == num_ec + 1; num_ec > 1 =>
69
0
  bool choose_alpha = (blending_stride > 1 + 1);
70
0
  for (size_t i = 0; i < pdic.positions_.size();) {
71
0
    size_t i_start = i;
72
0
    size_t ref_pos_idx = pdic.positions_[i].ref_pos_idx;
73
0
    const auto& ref_pos = pdic.ref_positions_[ref_pos_idx];
74
0
    while (i < pdic.positions_.size() &&
75
0
           pdic.positions_[i].ref_pos_idx == ref_pos_idx) {
76
0
      i++;
77
0
    }
78
0
    size_t num = i - i_start;
79
0
    JXL_ENSURE(num > 0);
80
0
    add_num(kReferenceFrameContext, ref_pos.ref);
81
0
    add_num(kPatchReferencePositionContext, ref_pos.x0);
82
0
    add_num(kPatchReferencePositionContext, ref_pos.y0);
83
0
    add_num(kPatchSizeContext, ref_pos.xsize - 1);
84
0
    add_num(kPatchSizeContext, ref_pos.ysize - 1);
85
0
    add_num(kPatchCountContext, num - 1);
86
0
    for (size_t j = i_start; j < i; j++) {
87
0
      const PatchPosition& pos = pdic.positions_[j];
88
0
      if (j == i_start) {
89
0
        add_num(kPatchPositionContext, pos.x);
90
0
        add_num(kPatchPositionContext, pos.y);
91
0
      } else {
92
0
        add_num(kPatchOffsetContext,
93
0
                PackSigned(pos.x - pdic.positions_[j - 1].x));
94
0
        add_num(kPatchOffsetContext,
95
0
                PackSigned(pos.y - pdic.positions_[j - 1].y));
96
0
      }
97
0
      for (size_t j = 0; j < blending_stride; ++j, ++blend_pos) {
98
0
        const PatchBlending& info = pdic.blendings_[blend_pos];
99
0
        add_num(kPatchBlendModeContext, static_cast<uint32_t>(info.mode));
100
0
        if (UsesAlpha(info.mode) && choose_alpha) {
101
0
          add_num(kPatchAlphaChannelContext, info.alpha_channel);
102
0
        }
103
0
        if (UsesClamp(info.mode)) {
104
0
          add_num(kPatchClampContext, TO_JXL_BOOL(info.clamp));
105
0
        }
106
0
      }
107
0
    }
108
0
  }
109
110
0
  EntropyEncodingData codes;
111
0
  std::vector<uint8_t> context_map;
112
0
  JXL_ASSIGN_OR_RETURN(
113
0
      size_t cost,
114
0
      BuildAndEncodeHistograms(memory_manager, HistogramParams(),
115
0
                               kNumPatchDictionaryContexts, tokens, &codes,
116
0
                               &context_map, writer, layer, aux_out));
117
0
  (void)cost;
118
0
  JXL_RETURN_IF_ERROR(
119
0
      WriteTokens(tokens[0], codes, context_map, 0, writer, layer, aux_out));
120
0
  return true;
121
0
}
122
123
// static
124
Status PatchDictionaryEncoder::SubtractFrom(const PatchDictionary& pdic,
125
0
                                            Image3F* opsin) {
126
  // TODO(veluca): this can likely be optimized knowing it runs on full images.
127
0
  for (size_t y = 0; y < opsin->ysize(); y++) {
128
0
    float* JXL_RESTRICT rows[3] = {
129
0
        opsin->PlaneRow(0, y),
130
0
        opsin->PlaneRow(1, y),
131
0
        opsin->PlaneRow(2, y),
132
0
    };
133
0
    size_t blending_stride = pdic.blendings_stride_;
134
0
    for (size_t pos_idx : pdic.GetPatchesForRow(y)) {
135
0
      const size_t blending_idx = pos_idx * blending_stride;
136
0
      const PatchPosition& pos = pdic.positions_[pos_idx];
137
0
      const PatchReferencePosition& ref_pos =
138
0
          pdic.ref_positions_[pos.ref_pos_idx];
139
0
      const PatchBlendMode mode = pdic.blendings_[blending_idx].mode;
140
0
      size_t by = pos.y;
141
0
      size_t bx = pos.x;
142
0
      size_t xsize = ref_pos.xsize;
143
0
      JXL_ENSURE(y >= by);
144
0
      JXL_ENSURE(y < by + ref_pos.ysize);
145
0
      size_t iy = y - by;
146
0
      size_t ref = ref_pos.ref;
147
0
      const float* JXL_RESTRICT ref_rows[3] = {
148
0
          pdic.reference_frames_->at(ref).frame->color()->ConstPlaneRow(
149
0
              0, ref_pos.y0 + iy) +
150
0
              ref_pos.x0,
151
0
          pdic.reference_frames_->at(ref).frame->color()->ConstPlaneRow(
152
0
              1, ref_pos.y0 + iy) +
153
0
              ref_pos.x0,
154
0
          pdic.reference_frames_->at(ref).frame->color()->ConstPlaneRow(
155
0
              2, ref_pos.y0 + iy) +
156
0
              ref_pos.x0,
157
0
      };
158
0
      for (size_t ix = 0; ix < xsize; ix++) {
159
0
        for (size_t c = 0; c < 3; c++) {
160
0
          if (mode == PatchBlendMode::kAdd) {
161
0
            rows[c][bx + ix] -= ref_rows[c][ix];
162
0
          } else if (mode == PatchBlendMode::kReplace) {
163
0
            rows[c][bx + ix] = 0;
164
0
          } else if (mode == PatchBlendMode::kNone) {
165
            // Nothing to do.
166
0
          } else {
167
0
            return JXL_UNREACHABLE("blending mode %u not yet implemented",
168
0
                                   static_cast<uint32_t>(mode));
169
0
          }
170
0
        }
171
0
      }
172
0
    }
173
0
  }
174
0
  return true;
175
0
}
176
177
namespace {
178
179
// Per-colorspace constants used by the text-like patch search:
//  - quantization steps for patch residuals (kChannelDequant);
//  - per-channel weights of the weighted L1 pixel distance (kChannelWeights).
// XYB input and non-XYB input use different tables. The non-XYB constants
// are written relative to a 255 scale.
struct PatchColorspaceInfo {
  // Quantization step per channel; Quantize() divides by it.
  float kChannelDequant[3];
  // Weight per channel in the L1 distance computed by is_similar_v().
  float kChannelWeights[3];

  explicit PatchColorspaceInfo(bool is_xyb) {
    if (is_xyb) {
      kChannelDequant[0] = 0.01615;
      kChannelDequant[1] = 0.08875;
      kChannelDequant[2] = 0.1922;
      kChannelWeights[0] = 30.0;
      kChannelWeights[1] = 3.0;
      kChannelWeights[2] = 1.0;
    } else {
      kChannelDequant[0] = 20.0f / 255;
      kChannelDequant[1] = 22.0f / 255;
      kChannelDequant[2] = 20.0f / 255;
      kChannelWeights[0] = 0.017 * 255;
      kChannelWeights[1] = 0.02 * 255;
      kChannelWeights[2] = 0.017 * 255;
    }
  }

  // Returns `val` expressed in units of channel `c`'s quantization step.
  float ScaleForQuantization(float val, size_t c) const {
    return val / kChannelDequant[c];
  }

  // Quantizes `val` for channel `c`, truncating toward zero.
  int Quantize(float val, size_t c) const {
    return static_cast<int>(truncf(ScaleForQuantization(val, c)));
  }

  // Returns true when the weighted L1 distance between the 3-channel pixels
  // `v1` and `v2` is at most `threshold`.
  bool is_similar_v(const float v1[3], const float v2[3],
                    float threshold) const {
    float distance = 0;
    for (size_t c = 0; c < 3; c++) {
      distance += std::fabs(v1[c] - v2[c]) * kChannelWeights[c];
    }
    return distance <= threshold;
  }
};
217
218
// Finds small, text-like connected components ("patches") in `opsin`.
//
// Steps:
//  1. Mark naturally aligned kPatchSide x kPatchSide blocks whose pixels are
//     all equal (up to 1e-4) and for which at least 7/8 of the surrounding
//     window (kExtraSide pixels on each side, clipped to the image) matches
//     the block's top-left pixel ("screenshot-like" blocks).
//  2. Flood-fill from those blocks over "similar enough" pixels that lie
//     within kDistanceLimit Manhattan distance of their seed. This builds a
//     background mask and a background color image holding each pixel's seed
//     color.
//  3. Keep each connected component of non-background pixels that:
//     - has a border of mutually very similar background colors;
//     - fits within kMaxPatchSize;
//     - has some opsin pixel near its bounding box similar to the border
//       color;
//     - has a quantized residual (pixel minus border color) reaching kMinPeak.
//  4. Merge identical patches and keep those occurring at least
//     kMinPatchOccurrences times. Drop everything if no patch has at least
//     kMinMaxPatchSize pixels.
//
// Returns an empty vector when patches are disabled, when the `patches`
// override rejects the screenshot-area decision, or when no patch survives.
// `aux_out` is currently unused.
StatusOr<std::vector<PatchInfo>> FindTextLikePatches(
    const CompressParams& cparams, const Image3F& opsin,
    const PassesEncoderState* JXL_RESTRICT state, ThreadPool* pool,
    AuxOut* aux_out, bool is_xyb) {
  std::vector<PatchInfo> info;
  if (state->cparams.patches == Override::kOff) return info;
  const auto& frame_dim = state->shared.frame_dim;
  JxlMemoryManager* memory_manager = opsin.memory_manager();

  PatchColorspaceInfo pci(is_xyb);
  constexpr float kSimilarThreshold = 0.8f;

  // Weighted-L1 similarity of two pixels of a 3-plane image given by `rows`.
  auto is_similar_impl = [&pci](std::pair<uint32_t, uint32_t> p1,
                                std::pair<uint32_t, uint32_t> p2,
                                const float* JXL_RESTRICT rows[3],
                                size_t stride, float threshold) {
    float v1[3];
    float v2[3];
    for (size_t c = 0; c < 3; c++) {
      v1[c] = rows[c][p1.second * stride + p1.first];
      v2[c] = rows[c][p2.second * stride + p2.first];
    }
    return pci.is_similar_v(v1, v2, threshold);
  };

  std::atomic<bool> has_screenshot_areas{false};
  const size_t opsin_stride = opsin.PixelsPerRow();
  const float* JXL_RESTRICT opsin_rows[3] = {opsin.ConstPlaneRow(0, 0),
                                             opsin.ConstPlaneRow(1, 0),
                                             opsin.ConstPlaneRow(2, 0)};

  // Two opsin pixels are "the same" if every channel differs by <= 1e-4.
  auto is_same = [&opsin_rows, opsin_stride](std::pair<uint32_t, uint32_t> p1,
                                             std::pair<uint32_t, uint32_t> p2) {
    for (auto& opsin_row : opsin_rows) {
      float v1 = opsin_row[p1.second * opsin_stride + p1.first];
      float v2 = opsin_row[p2.second * opsin_stride + p2.first];
      if (std::fabs(v1 - v2) > 1e-4) {
        return false;
      }
    }
    return true;
  };

  auto is_similar = [&](std::pair<uint32_t, uint32_t> p1,
                        std::pair<uint32_t, uint32_t> p2) {
    return is_similar_impl(p1, p2, opsin_rows, opsin_stride, kSimilarThreshold);
  };

  constexpr int64_t kPatchSide = 4;
  constexpr int64_t kExtraSide = 4;

  // Look for kPatchSide size squares, naturally aligned, that all have the same
  // pixel values.
  JXL_ASSIGN_OR_RETURN(
      ImageB is_screenshot_like,
      ImageB::Create(memory_manager, DivCeil(frame_dim.xsize, kPatchSide),
                     DivCeil(frame_dim.ysize, kPatchSide)));
  ZeroFillImage(&is_screenshot_like);
  uint8_t* JXL_RESTRICT screenshot_row = is_screenshot_like.Row(0);
  const size_t screenshot_stride = is_screenshot_like.PixelsPerRow();
  const auto process_row = [&](const uint32_t y,
                               size_t /* thread */) -> Status {
    for (uint64_t x = 0; x < frame_dim.xsize / kPatchSide; x++) {
      // Top-left pixel of the block; every comparison below is against it.
      const std::pair<uint32_t, uint32_t> corner(
          static_cast<uint32_t>(x * kPatchSide),
          static_cast<uint32_t>(y * kPatchSide));
      bool all_same = true;
      // Stop at the first mismatch (both loops), the result is already known.
      for (size_t iy = 0; all_same && iy < static_cast<size_t>(kPatchSide);
           iy++) {
        for (size_t ix = 0; ix < static_cast<size_t>(kPatchSide); ix++) {
          size_t cx = x * kPatchSide + ix;
          size_t cy = y * kPatchSide + iy;
          if (!is_same({cx, cy}, corner)) {
            all_same = false;
            break;
          }
        }
      }
      if (!all_same) continue;
      size_t num = 0;
      size_t num_same = 0;
      for (int64_t iy = -kExtraSide; iy < kExtraSide + kPatchSide; iy++) {
        for (int64_t ix = -kExtraSide; ix < kExtraSide + kPatchSide; ix++) {
          // Signed arithmetic so that offsets left/above the image are < 0.
          int64_t cx = static_cast<int64_t>(x) * kPatchSide + ix;
          int64_t cy = static_cast<int64_t>(y) * kPatchSide + iy;
          if (cx < 0 || static_cast<uint64_t>(cx) >= frame_dim.xsize ||  //
              cy < 0 || static_cast<uint64_t>(cy) >= frame_dim.ysize) {
            continue;
          }
          num++;
          if (is_same({cx, cy}, corner)) num_same++;
        }
      }
      // Too few equal pixels nearby.
      if (num_same * 8 < num * 7) continue;
      screenshot_row[y * screenshot_stride + x] = 1;
      has_screenshot_areas = true;
    }
    return true;
  };
  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, frame_dim.ysize / kPatchSide,
                                ThreadPool::NoInit, process_row,
                                "IsScreenshotLike"));

  // TODO(veluca): also parallelize the rest of this function.
  if (WantDebugOutput(cparams)) {
    JXL_RETURN_IF_ERROR(
        DumpPlaneNormalized(cparams, "screenshot_like", is_screenshot_like));
  }

  constexpr int kSearchRadius = 1;

  if (!ApplyOverride(state->cparams.patches, has_screenshot_areas)) {
    return info;
  }

  // Search for "similar enough" pixels near the screenshot-like areas.
  JXL_ASSIGN_OR_RETURN(
      ImageB is_background,
      ImageB::Create(memory_manager, frame_dim.xsize, frame_dim.ysize));
  ZeroFillImage(&is_background);
  JXL_ASSIGN_OR_RETURN(
      Image3F background,
      Image3F::Create(memory_manager, frame_dim.xsize, frame_dim.ysize));
  ZeroFillImage(&background);
  constexpr size_t kDistanceLimit = 50;
  float* JXL_RESTRICT background_rows[3] = {
      background.PlaneRow(0, 0),
      background.PlaneRow(1, 0),
      background.PlaneRow(2, 0),
  };
  const size_t background_stride = background.PixelsPerRow();
  uint8_t* JXL_RESTRICT is_background_row = is_background.Row(0);
  const size_t is_background_stride = is_background.PixelsPerRow();
  // Breadth-first flood fill. Each entry is (pixel, seed pixel it was reached
  // from); the vector is consumed front-to-back through `queue_front`.
  std::vector<
      std::pair<std::pair<uint32_t, uint32_t>, std::pair<uint32_t, uint32_t>>>
      queue;
  size_t queue_front = 0;
  for (size_t y = 0; y < frame_dim.ysize; y++) {
    for (size_t x = 0; x < frame_dim.xsize; x++) {
      if (!screenshot_row[screenshot_stride * (y / kPatchSide) +
                          (x / kPatchSide)])
        continue;
      queue.push_back({{x, y}, {x, y}});
    }
  }
  while (queue.size() != queue_front) {
    // Copies, not references: emplace_back below may reallocate `queue`.
    std::pair<uint32_t, uint32_t> cur = queue[queue_front].first;
    std::pair<uint32_t, uint32_t> src = queue[queue_front].second;
    queue_front++;
    if (is_background_row[cur.second * is_background_stride + cur.first])
      continue;
    is_background_row[cur.second * is_background_stride + cur.first] = 1;
    for (size_t c = 0; c < 3; c++) {
      background_rows[c][cur.second * background_stride + cur.first] =
          opsin_rows[c][src.second * opsin_stride + src.first];
    }
    for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) {
      for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) {
        if (dx == 0 && dy == 0) continue;
        // Signed arithmetic (as in the CC search below) so that stepping
        // left/up from coordinate 0 yields -1 without unsigned wrap-around.
        int next_first = static_cast<int32_t>(cur.first) + dx;
        int next_second = static_cast<int32_t>(cur.second) + dy;
        if (next_first < 0 || next_second < 0 ||
            static_cast<uint32_t>(next_first) >= frame_dim.xsize ||
            static_cast<uint32_t>(next_second) >= frame_dim.ysize) {
          continue;
        }
        if (static_cast<uint32_t>(
                std::abs(next_first - static_cast<int>(src.first)) +
                std::abs(next_second - static_cast<int>(src.second))) >
            kDistanceLimit) {
          continue;
        }
        std::pair<uint32_t, uint32_t> next{next_first, next_second};
        if (is_similar(src, next)) {
          if (!screenshot_row[next.second / kPatchSide * screenshot_stride +
                              next.first / kPatchSide] ||
              is_same(src, next)) {
            if (!is_background_row[next.second * is_background_stride +
                                   next.first])
              queue.emplace_back(next, src);
          }
        }
      }
    }
  }
  queue.clear();

  ImageF ccs;
  Rng rng(0);
  bool paint_ccs = false;
  if (WantDebugOutput(cparams)) {
    JXL_RETURN_IF_ERROR(
        DumpPlaneNormalized(cparams, "is_background", is_background));
    if (is_xyb) {
      JXL_RETURN_IF_ERROR(DumpXybImage(cparams, "background", background));
    } else {
      JXL_RETURN_IF_ERROR(DumpImage(cparams, "background", background));
    }
    JXL_ASSIGN_OR_RETURN(
        ccs, ImageF::Create(memory_manager, frame_dim.xsize, frame_dim.ysize));
    ZeroFillImage(&ccs);
    paint_ccs = true;
  }

  constexpr float kVerySimilarThreshold = 0.03f;
  constexpr float kHasSimilarThreshold = 0.03f;

  const float* JXL_RESTRICT const_background_rows[3] = {
      background_rows[0], background_rows[1], background_rows[2]};
  auto is_similar_b = [&](std::pair<int, int> p1, std::pair<int, int> p2) {
    return is_similar_impl(p1, p2, const_background_rows, background_stride,
                           kVerySimilarThreshold);
  };

  constexpr int kMinPeak = 2;
  constexpr int kHasSimilarRadius = 2;

  // Find small CC outside the "similar enough" areas, compute bounding boxes,
  // and run heuristics to exclude some patches.
  JXL_ASSIGN_OR_RETURN(
      ImageB visited,
      ImageB::Create(memory_manager, frame_dim.xsize, frame_dim.ysize));
  ZeroFillImage(&visited);
  uint8_t* JXL_RESTRICT visited_row = visited.Row(0);
  const size_t visited_stride = visited.PixelsPerRow();
  std::vector<std::pair<uint32_t, uint32_t>> cc;
  std::vector<std::pair<uint32_t, uint32_t>> stack;
  for (size_t y = 0; y < frame_dim.ysize; y++) {
    for (size_t x = 0; x < frame_dim.xsize; x++) {
      // Background pixels are not part of any CC; visited pixels already
      // belong to a CC explored earlier (re-seeding them would find nothing).
      if (is_background_row[y * is_background_stride + x] ||
          visited_row[y * visited_stride + x]) {
        continue;
      }
      cc.clear();
      stack.clear();
      stack.emplace_back(x, y);
      size_t min_x = x;
      size_t max_x = x;
      size_t min_y = y;
      size_t max_y = y;
      // First background pixel found on the CC border; only meaningful when
      // found_border is true.
      std::pair<uint32_t, uint32_t> reference{};
      bool found_border = false;
      bool all_similar = true;
      while (!stack.empty()) {
        std::pair<uint32_t, uint32_t> cur = stack.back();
        stack.pop_back();
        if (visited_row[cur.second * visited_stride + cur.first]) continue;
        visited_row[cur.second * visited_stride + cur.first] = 1;
        if (cur.first < min_x) min_x = cur.first;
        if (cur.first > max_x) max_x = cur.first;
        if (cur.second < min_y) min_y = cur.second;
        if (cur.second > max_y) max_y = cur.second;
        if (paint_ccs) {
          cc.push_back(cur);
        }
        for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) {
          for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) {
            if (dx == 0 && dy == 0) continue;
            int next_first = static_cast<int32_t>(cur.first) + dx;
            int next_second = static_cast<int32_t>(cur.second) + dy;
            if (next_first < 0 || next_second < 0 ||
                static_cast<uint32_t>(next_first) >= frame_dim.xsize ||
                static_cast<uint32_t>(next_second) >= frame_dim.ysize) {
              continue;
            }
            std::pair<uint32_t, uint32_t> next{next_first, next_second};
            if (!is_background_row[next.second * is_background_stride +
                                   next.first]) {
              stack.push_back(next);
            } else {
              if (!found_border) {
                reference = next;
                found_border = true;
              } else {
                if (!is_similar_b(next, reference)) all_similar = false;
              }
            }
          }
        }
      }
      if (!found_border || !all_similar || max_x - min_x >= kMaxPatchSize ||
          max_y - min_y >= kMaxPatchSize) {
        continue;
      }
      size_t bpos = background_stride * reference.second + reference.first;
      float ref[3] = {background_rows[0][bpos], background_rows[1][bpos],
                      background_rows[2][bpos]};
      // Require some pixel near the bounding box to resemble the border color;
      // stop scanning as soon as one is found.
      bool has_similar = false;
      for (size_t iy = std::max<int>(
               static_cast<int32_t>(min_y) - kHasSimilarRadius, 0);
           !has_similar &&
           iy < std::min(max_y + kHasSimilarRadius + 1, frame_dim.ysize);
           iy++) {
        for (size_t ix = std::max<int>(
                 static_cast<int32_t>(min_x) - kHasSimilarRadius, 0);
             ix < std::min(max_x + kHasSimilarRadius + 1, frame_dim.xsize);
             ix++) {
          size_t opos = opsin_stride * iy + ix;
          float px[3] = {opsin_rows[0][opos], opsin_rows[1][opos],
                         opsin_rows[2][opos]};
          if (pci.is_similar_v(ref, px, kHasSimilarThreshold)) {
            has_similar = true;
            break;
          }
        }
      }
      if (!has_similar) continue;
      info.emplace_back();
      info.back().second.emplace_back(min_x, min_y);
      QuantizedPatch& patch = info.back().first;
      patch.xsize = max_x - min_x + 1;
      patch.ysize = max_y - min_y + 1;
      int max_value = 0;
      for (size_t c : {1, 0, 2}) {
        for (size_t iy = min_y; iy <= max_y; iy++) {
          for (size_t ix = min_x; ix <= max_x; ix++) {
            size_t offset = (iy - min_y) * patch.xsize + ix - min_x;
            patch.fpixels[c][offset] =
                opsin_rows[c][iy * opsin_stride + ix] - ref[c];
            int val = pci.Quantize(patch.fpixels[c][offset], c);
            patch.pixels[c][offset] = val;
            if (std::abs(val) > max_value) max_value = std::abs(val);
          }
        }
      }
      if (max_value < kMinPeak) {
        info.pop_back();
        continue;
      }
      if (paint_ccs) {
        float cc_color = rng.UniformF(0.5, 1.0);
        for (std::pair<uint32_t, uint32_t> p : cc) {
          ccs.Row(p.second)[p.first] = cc_color;
        }
      }
    }
  }

  if (paint_ccs) {
    JXL_ENSURE(WantDebugOutput(cparams));
    JXL_RETURN_IF_ERROR(DumpPlaneNormalized(cparams, "ccs", ccs));
  }
  if (info.empty()) {
    return info;
  }

  // Remove duplicates: after sorting, equal patches are adjacent; merge their
  // positions and keep only patches with enough occurrences.
  constexpr size_t kMinPatchOccurrences = 2;
  std::sort(info.begin(), info.end());
  size_t unique = 0;
  for (size_t i = 1; i < info.size(); i++) {
    if (info[i].first == info[unique].first) {
      info[unique].second.insert(info[unique].second.end(),
                                 info[i].second.begin(), info[i].second.end());
    } else {
      if (info[unique].second.size() >= kMinPatchOccurrences) {
        unique++;
      }
      // `unique` can reach `i` here; avoid self-move-assignment. Moved-from
      // entries are never read again: they are either overwritten or dropped
      // by the resize below.
      if (unique != i) info[unique] = std::move(info[i]);
    }
  }
  if (info[unique].second.size() >= kMinPatchOccurrences) {
    unique++;
  }
  info.resize(unique);

  size_t max_patch_size = 0;

  for (const auto& patch : info) {
    size_t pixels = patch.first.xsize * patch.first.ysize;
    if (pixels > max_patch_size) max_patch_size = pixels;
  }

  // don't use patches if all patches are smaller than this
  constexpr size_t kMinMaxPatchSize = 20;
  if (max_patch_size < kMinMaxPatchSize) {
    info.clear();
  }

  return info;
}
591
592
}  // namespace
593
594
// Builds the patch dictionary for the current frame.
//
// Patch source:
//  - text-like patches from FindTextLikePatches;
//  - if none are found and the `dots` override allows it, a dot dictionary
//    from FindDotDictionary. By default that requires speed_tier <= kSquirrel,
//    butteraugli_distance >= kMinButteraugliForDots, and perceptual
//    optimizations enabled.
//
// The selected patches are bin-packed, largest first, into a reference image
// that is cropped to the rows actually used. That image is round-tripped via
// RoundtripPatchFrame as reference kPatchFrameReferenceId, with patches
// disabled and subtract=true. Finally, one kAdd blending (plus a kNone
// blending per extra channel) per occurrence is installed through
// PatchDictionaryEncoder::SetPositions. Returns true without touching `state`
// when no patch is selected.
Status FindBestPatchDictionary(const Image3F& opsin,
                               PassesEncoderState* JXL_RESTRICT state,
                               const JxlCmsInterface& cms, ThreadPool* pool,
                               AuxOut* aux_out, bool is_xyb) {
  JXL_ASSIGN_OR_RETURN(
      std::vector<PatchInfo> info,
      FindTextLikePatches(state->cparams, opsin, state, pool, aux_out, is_xyb));
  JxlMemoryManager* memory_manager = opsin.memory_manager();

  // TODO(veluca): this doesn't work if both dots and patches are enabled.
  // For now, since dots and patches are not likely to occur in the same kind of
  // images, disable dots if some patches were found.
  if (info.empty() &&
      ApplyOverride(
          state->cparams.dots,
          state->cparams.speed_tier <= SpeedTier::kSquirrel &&
              state->cparams.butteraugli_distance >= kMinButteraugliForDots &&
              !state->cparams.disable_perceptual_optimizations)) {
    Rect rect(0, 0, state->shared.frame_dim.xsize,
              state->shared.frame_dim.ysize);
    JXL_ASSIGN_OR_RETURN(info,
                         FindDotDictionary(state->cparams, opsin, rect,
                                           state->shared.cmap.base(), pool));
  }

  if (info.empty()) return true;

  // Place the largest patches first.
  std::sort(
      info.begin(), info.end(), [&](const PatchInfo& a, const PatchInfo& b) {
        return a.first.xsize * a.first.ysize > b.first.xsize * b.first.ysize;
      });

  size_t max_x_size = 0;
  size_t max_y_size = 0;
  size_t total_pixels = 0;

  for (const auto& patch : info) {
    size_t pixels = patch.first.xsize * patch.first.ysize;
    if (max_x_size < patch.first.xsize) max_x_size = patch.first.xsize;
    if (max_y_size < patch.first.ysize) max_y_size = patch.first.ysize;
    total_pixels += pixels;
  }

  // Bin-packing & conversion of patches.
  constexpr float kBinPackingSlackness = 1.05f;
  size_t ref_xsize = std::max<float>(max_x_size, std::sqrt(total_pixels));
  size_t ref_ysize = std::max<float>(max_y_size, std::sqrt(total_pixels));
  std::vector<std::pair<size_t, size_t>> ref_positions(info.size());
  // TODO(veluca): allow partial overlaps of patches that have the same pixels.
  size_t max_y = 0;
  do {
    max_y = 0;
    // Increase packed image size.
    ref_xsize = ref_xsize * kBinPackingSlackness + 1;
    ref_ysize = ref_ysize * kBinPackingSlackness + 1;

    JXL_ASSIGN_OR_RETURN(ImageB occupied,
                         ImageB::Create(memory_manager, ref_xsize, ref_ysize));
    ZeroFillImage(&occupied);
    uint8_t* JXL_RESTRICT occupied_rows = occupied.Row(0);
    size_t occupied_stride = occupied.PixelsPerRow();

    bool success = true;
    // For every patch...
    for (size_t patch = 0; patch < info.size(); patch++) {
      size_t x0 = 0;
      size_t y0 = 0;
      const size_t xsize = info[patch].first.xsize;
      const size_t ysize = info[patch].first.ysize;
      bool found = false;
      // For every possible start position ...
      for (; y0 + ysize <= ref_ysize; y0++) {
        x0 = 0;
        for (; x0 + xsize <= ref_xsize; x0++) {
          // Check if the patch can be placed at (x0, y0). Stop at the first
          // occupied pixel (in both loops) so that `x` is its column.
          bool has_occupied_pixel = false;
          size_t x = x0;
          for (size_t y = y0; y < y0 + ysize && !has_occupied_pixel; y++) {
            for (x = x0; x < x0 + xsize; x++) {
              if (occupied_rows[y * occupied_stride + x]) {
                has_occupied_pixel = true;
                break;
              }
            }
          }  // end of positioning check
          if (!has_occupied_pixel) {
            found = true;
            break;
          }
          // Every start column in (x0, x] would still cover the occupied
          // pixel at column x in this row range, so resume at x + 1 (the loop
          // increment adds the 1).
          x0 = x;
        }
        if (found) break;
      }  // end of start position checking

      // We didn't find a possible position: repeat from the beginning with a
      // larger reference frame size.
      if (!found) {
        success = false;
        break;
      }

      // We found a position: mark the corresponding positions in the reference
      // image as used.
      ref_positions[patch] = {x0, y0};
      for (size_t y = y0; y < y0 + ysize; y++) {
        for (size_t x = x0; x < x0 + xsize; x++) {
          occupied_rows[y * occupied_stride + x] = JXL_TRUE;
        }
      }
      max_y = std::max(max_y, y0 + ysize);
    }

    if (success) break;
  } while (true);

  JXL_ENSURE(ref_ysize >= max_y);

  // Crop the reference frame to the rows actually used by the packing.
  ref_ysize = max_y;

  JXL_ASSIGN_OR_RETURN(Image3F reference_frame,
                       Image3F::Create(memory_manager, ref_xsize, ref_ysize));
  // TODO(veluca): figure out a better way to fill the image.
  ZeroFillImage(&reference_frame);
  std::vector<PatchPosition> positions;
  std::vector<PatchReferencePosition> pref_positions;
  std::vector<PatchBlending> blendings;
  float* JXL_RESTRICT ref_rows[3] = {
      reference_frame.PlaneRow(0, 0),
      reference_frame.PlaneRow(1, 0),
      reference_frame.PlaneRow(2, 0),
  };
  size_t ref_stride = reference_frame.PixelsPerRow();
  size_t num_ec = state->shared.metadata->m.num_extra_channels;

  for (size_t i = 0; i < info.size(); i++) {
    PatchReferencePosition ref_pos;
    ref_pos.xsize = info[i].first.xsize;
    ref_pos.ysize = info[i].first.ysize;
    ref_pos.x0 = ref_positions[i].first;
    ref_pos.y0 = ref_positions[i].second;
    ref_pos.ref = kPatchFrameReferenceId;
    for (size_t y = 0; y < ref_pos.ysize; y++) {
      for (size_t x = 0; x < ref_pos.xsize; x++) {
        for (size_t c = 0; c < 3; c++) {
          ref_rows[c][(y + ref_pos.y0) * ref_stride + x + ref_pos.x0] =
              info[i].first.fpixels[c][y * ref_pos.xsize + x];
        }
      }
    }
    for (const auto& pos : info[i].second) {
      JXL_DEBUG_V(4, "Patch %" PRIuS "x%" PRIuS " at position %u,%u",
                  ref_pos.xsize, ref_pos.ysize, pos.first, pos.second);
      positions.emplace_back(
          PatchPosition{pos.first, pos.second, pref_positions.size()});
      // Add blending for color channels, ignore other channels.
      blendings.push_back({PatchBlendMode::kAdd, 0, false});
      for (size_t j = 0; j < num_ec; ++j) {
        blendings.push_back({PatchBlendMode::kNone, 0, false});
      }
    }
    pref_positions.emplace_back(ref_pos);
  }

  CompressParams cparams = state->cparams;
  // Recursive application of patches could create very weird issues.
  cparams.patches = Override::kOff;

  if (WantDebugOutput(cparams)) {
    if (is_xyb) {
      JXL_RETURN_IF_ERROR(
          DumpXybImage(cparams, "patch_reference", reference_frame));
    } else {
      JXL_RETURN_IF_ERROR(
          DumpImage(cparams, "patch_reference", reference_frame));
    }
  }

  JXL_RETURN_IF_ERROR(RoundtripPatchFrame(&reference_frame, state,
                                          kPatchFrameReferenceId, cparams, cms,
                                          pool, aux_out, /*subtract=*/true));

  // TODO(veluca): this assumes that applying patches is commutative, which is
  // not true for all blending modes. This code only produces kAdd patches, so
  // this works out.
  PatchDictionaryEncoder::SetPositions(
      &state->shared.image_features.patches, std::move(positions),
      std::move(pref_positions), std::move(blendings), num_ec + 1);
  return true;
}
785
786
// Encodes `*reference_frame` as a standalone kReferenceOnly frame that is
// saved into reference slot `idx`, and makes that reference available to the
// rest of the encoder through `state->shared.reference_frames[idx]`.
//
// - `*reference_frame` is moved into the frame's ImageBundle; it is left in a
//   moved-from state on return.
// - `cparams` is modified in place (resampling, dots, noise, progressive
//   options, predictor, modular mode); callers that need their original
//   settings must pass a copy.
// - The encoded bytes are appended to `state->special_frames`.
// - If `aux_out` is non-null, the statistics of the encoded frame are merged
//   into its LayerType::Dictionary layer.
// - If `subtract` is true, the encoded bytes are decoded again and the
//   decoded reference frame (what a decoder will actually reconstruct) is
//   stored in `state->shared.reference_frames[idx]`; otherwise the un-encoded
//   ImageBundle is stored there directly.
Status RoundtripPatchFrame(Image3F* reference_frame,
                           PassesEncoderState* JXL_RESTRICT state, int idx,
                           CompressParams& cparams, const JxlCmsInterface& cms,
                           ThreadPool* pool, AuxOut* aux_out, bool subtract) {
  JxlMemoryManager* memory_manager = state->memory_manager();

  // Force a simple full-resolution, non-progressive modular encode without
  // dots or noise for the reference frame.
  cparams.resampling = 1;
  cparams.ec_resampling = 1;
  cparams.dots = Override::kOff;
  cparams.noise = Override::kOff;
  cparams.modular_mode = true;
  cparams.responsive = 0;
  cparams.progressive_dc = 0;
  cparams.progressive_mode = Override::kOff;
  cparams.qprogressive_mode = Override::kOff;
  // Use gradient predictor and not Predictor::Best.
  cparams.options.predictor = Predictor::Gradient;

  FrameInfo patch_frame_info;
  patch_frame_info.save_as_reference = idx;  // always saved.
  patch_frame_info.frame_type = FrameType::kReferenceOnly;
  patch_frame_info.save_before_color_transform = true;
  ImageBundle ib(memory_manager, &state->shared.metadata->m);
  // TODO(veluca): metadata.color_encoding is a lie: ib is in XYB, but there is
  // no simple way to express that yet.
  patch_frame_info.ib_needs_color_transform = false;
  JXL_RETURN_IF_ERROR(ib.SetFromImage(
      std::move(*reference_frame), state->shared.metadata->m.color_encoding));

  const size_t num_extra_channels = ib.metadata()->extra_channel_info.size();
  if (num_extra_channels != 0) {
    // Add placeholder extra channels to the patch image: patch encoding does
    // not yet support extra channels, but the codec expects that the amount of
    // extra channels in frames matches that in the metadata of the codestream.
    std::vector<ImageF> extra_channels;
    extra_channels.reserve(num_extra_channels);
    for (size_t i = 0; i < num_extra_channels; i++) {
      JXL_ASSIGN_OR_RETURN(
          ImageF ch, ImageF::Create(memory_manager, ib.xsize(), ib.ysize()));
      extra_channels.emplace_back(std::move(ch));
      // Must initialize the image with data to not affect blending with
      // uninitialized memory.
      // TODO(lode): patches must copy and use the real extra channels instead.
      ZeroFillImage(&extra_channels.back());
    }
    JXL_RETURN_IF_ERROR(ib.SetExtraChannels(std::move(extra_channels)));
  }

  auto special_frame = jxl::make_unique<BitWriter>(memory_manager);
  AuxOut patch_aux_out;
  JXL_RETURN_IF_ERROR(EncodeFrame(
      memory_manager, cparams, patch_frame_info, state->shared.metadata, ib,
      cms, pool, special_frame.get(), aux_out ? &patch_aux_out : nullptr));
  if (aux_out) {
    for (const auto& l : patch_aux_out.layers) {
      aux_out->layer(LayerType::Dictionary).Assimilate(l);
    }
  }
  // The span points into the BitWriter's own storage; moving the unique_ptr
  // into `special_frames` does not move that storage, so the span stays valid.
  const Span<const uint8_t> encoded = special_frame->GetSpan();
  state->special_frames.emplace_back(std::move(special_frame));

  if (!subtract) {
    *state->shared.reference_frames[idx].frame = std::move(ib);
    return true;
  }

  ImageBundle decoded(memory_manager, &state->shared.metadata->m);
  PassesDecoderState dec_state(memory_manager);
  JXL_RETURN_IF_ERROR(
      dec_state.output_encoding_info.SetFromMetadata(*state->shared.metadata));
  const uint8_t* frame_start = encoded.data();
  size_t encoded_size = encoded.size();

  // Decodes one frame at `frame_start` and advances the cursor past the bytes
  // it consumed. Every decode must go through here so that the final
  // "everything consumed" check below stays accurate.
  const auto decode_next_frame = [&]() -> Status {
    JXL_RETURN_IF_ERROR(DecodeFrame(&dec_state, pool, frame_start,
                                    encoded_size,
                                    /*frame_header=*/nullptr, &decoded,
                                    *state->shared.metadata));
    const size_t consumed = decoded.decoded_bytes();
    // Guard the unsigned subtraction below against wrap-around.
    JXL_ENSURE(consumed <= encoded_size);
    frame_start += consumed;
    encoded_size -= consumed;
    return true;
  };

  JXL_RETURN_IF_ERROR(decode_next_frame());
  const size_t ref_xsize =
      dec_state.shared_storage.reference_frames[idx].frame->color()->xsize();
  // if the frame itself uses patches, we need to decode another frame
  if (!ref_xsize) {
    JXL_RETURN_IF_ERROR(decode_next_frame());
  }
  JXL_ENSURE(encoded_size == 0);
  state->shared.reference_frames[idx] =
      std::move(dec_state.shared_storage.reference_frames[idx]);
  return true;
}
869
870
}  // namespace jxl