Coverage Report

Created: 2026-02-14 07:42

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/libjxl/lib/jxl/enc_optimize.h
Line
Count
Source
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
// Utility functions for optimizing multi-dimensional nonlinear functions.
7
8
#ifndef LIB_JXL_OPTIMIZE_H_
9
#define LIB_JXL_OPTIMIZE_H_
10
11
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>

#include "lib/jxl/base/status.h"
15
16
namespace jxl {
17
namespace optimize {
18
19
// An array type of numeric values that supports math operations with operator-,
20
// operator+, etc.
21
template <typename T, size_t N>
22
class Array {
23
 public:
24
  Array() = default;
25
  explicit Array(T v) {
26
    for (size_t i = 0; i < N; i++) v_[i] = v;
27
  }
28
29
7.51k
  size_t size() const { return N; }
30
31
29.9k
  T& operator[](size_t index) {
32
29.9k
    JXL_DASSERT(index < N);
33
29.9k
    return v_[index];
34
29.9k
  }
35
38.1k
  T operator[](size_t index) const {
36
38.1k
    JXL_DASSERT(index < N);
37
38.1k
    return v_[index];
38
38.1k
  }
39
40
 private:
41
  // The values used by this Array.
42
  T v_[N];
43
};
44
45
template <typename T, size_t N>
46
480
Array<T, N> operator+(const Array<T, N>& x, const Array<T, N>& y) {
47
480
  Array<T, N> z;
48
4.32k
  for (size_t i = 0; i < N; ++i) {
49
3.84k
    z[i] = x[i] + y[i];
50
3.84k
  }
51
480
  return z;
52
480
}
53
54
template <typename T, size_t N>
55
158
Array<T, N> operator-(const Array<T, N>& x, const Array<T, N>& y) {
56
158
  Array<T, N> z;
57
1.42k
  for (size_t i = 0; i < N; ++i) {
58
1.26k
    z[i] = x[i] - y[i];
59
1.26k
  }
60
158
  return z;
61
158
}
62
63
template <typename T, size_t N>
64
480
Array<T, N> operator*(T v, const Array<T, N>& x) {
65
480
  Array<T, N> y;
66
4.32k
  for (size_t i = 0; i < N; ++i) {
67
3.84k
    y[i] = v * x[i];
68
3.84k
  }
69
480
  return y;
70
480
}
71
72
template <typename T, size_t N>
73
428
T operator*(const Array<T, N>& x, const Array<T, N>& y) {
74
428
  T r = 0.0;
75
3.85k
  for (size_t i = 0; i < N; ++i) {
76
3.42k
    r += x[i] * y[i];
77
3.42k
  }
78
428
  return r;
79
428
}
80
81
// Implementation of the Scaled Conjugate Gradient method described in the
82
// following paper:
83
//   Moller, M. "A Scaled Conjugate Gradient Algorithm for Fast Supervised
84
//   Learning", Neural Networks, Vol. 6. pp. 525-533, 1993
85
//   http://sci2s.ugr.es/keel/pdf/algorithm/articulo/moller1990.pdf
86
//
87
// The Function template parameter is a class that has the following method:
88
//
89
//   // Returns the value of the function at point w and sets *df to be the
90
//   // negative gradient vector of the function at point w.
91
//   double Compute(const optimize::Array<T, N>& w,
92
//                  optimize::Array<T, N>* df) const;
93
//
94
// Returns a vector w, such that |df(w)| < grad_norm_threshold.
95
template <typename T, size_t N, typename Function>
96
Array<T, N> OptimizeWithScaledConjugateGradientMethod(
97
    const Function& f, const Array<T, N>& w0, const T grad_norm_threshold,
98
10
    size_t max_iters) {
99
10
  const size_t n = w0.size();
100
10
  const T rsq_threshold = grad_norm_threshold * grad_norm_threshold;
101
10
  const T sigma0 = static_cast<T>(0.0001);
102
10
  const T l_min = static_cast<T>(1.0e-15);
103
10
  const T l_max = static_cast<T>(1.0e15);
104
105
10
  Array<T, N> w = w0;
106
10
  Array<T, N> wp;
107
10
  Array<T, N> r;
108
10
  Array<T, N> rt;
109
10
  Array<T, N> e;
110
10
  Array<T, N> p;
111
10
  T psq;
112
10
  T fp;
113
10
  T D;
114
10
  T d;
115
10
  T m;
116
10
  T a;
117
10
  T b;
118
10
  T s;
119
10
  T t;
120
121
10
  T fw = f.Compute(w, &r);
122
10
  T rsq = r * r;
123
10
  e = r;
124
10
  p = r;
125
10
  T l = static_cast<T>(1.0);
126
10
  bool success = true;
127
10
  size_t n_success = 0;
128
10
  size_t k = 0;
129
130
330
  while (k++ < max_iters) {
131
322
    if (success) {
132
88
      m = -(p * r);
133
88
      if (m >= 0) {
134
2
        p = r;
135
2
        m = -(p * r);
136
2
      }
137
88
      psq = p * p;
138
88
      s = sigma0 / std::sqrt(psq);
139
88
      f.Compute(w + (s * p), &rt);
140
88
      t = (p * (r - rt)) / s;
141
88
    }
142
143
322
    d = t + l * psq;
144
322
    if (d <= 0) {
145
0
      d = l * psq;
146
0
      l = l - t / psq;
147
0
    }
148
149
322
    a = -m / d;
150
322
    wp = w + a * p;
151
322
    fp = f.Compute(wp, &rt);
152
153
322
    D = 2.0 * (fp - fw) / (a * m);
154
322
    if (D >= 0.0) {
155
82
      success = true;
156
82
      n_success++;
157
82
      w = wp;
158
240
    } else {
159
240
      success = false;
160
240
    }
161
162
322
    if (success) {
163
82
      e = r;
164
82
      r = rt;
165
82
      rsq = r * r;
166
82
      fw = fp;
167
82
      if (rsq <= rsq_threshold) {
168
2
        break;
169
2
      }
170
82
    }
171
172
320
    if (D < 0.25) {
173
0
      l = std::min(4.0 * l, l_max);
174
320
    } else if (D > 0.75) {
175
80
      l = std::max(0.25 * l, l_min);
176
80
    }
177
178
320
    if ((n_success % n) == 0) {
179
250
      p = r;
180
250
      l = 1.0;
181
250
    } else if (success) {
182
70
      b = ((e - r) * r) / m;
183
70
      p = b * p + r;
184
70
    }
185
320
  }
186
187
10
  return w;
188
10
}
189
190
}  // namespace optimize
191
}  // namespace jxl
192
193
#endif  // LIB_JXL_OPTIMIZE_H_