Coverage Report

Created: 2024-05-21 06:26

/src/libjxl/lib/jxl/enc_optimize.h
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
2
//
3
// Use of this source code is governed by a BSD-style
4
// license that can be found in the LICENSE file.
5
6
// Utility functions for optimizing multi-dimensional nonlinear functions.
7
8
#ifndef LIB_JXL_OPTIMIZE_H_
9
#define LIB_JXL_OPTIMIZE_H_
10
11
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

#include "lib/jxl/base/status.h"
17
18
namespace jxl {
19
namespace optimize {
20
21
// An array type of numeric values that supports math operations with operator-,
22
// operator+, etc.
23
template <typename T, size_t N>
24
class Array {
25
 public:
26
  Array() = default;
27
  explicit Array(T v) {
28
    for (size_t i = 0; i < N; i++) v_[i] = v;
29
  }
30
31
0
  size_t size() const { return N; }
32
33
0
  T& operator[](size_t index) {
34
0
    JXL_DASSERT(index < N);
35
0
    return v_[index];
36
0
  }
37
0
  T operator[](size_t index) const {
38
0
    JXL_DASSERT(index < N);
39
0
    return v_[index];
40
0
  }
41
42
 private:
43
  // The values used by this Array.
44
  T v_[N];
45
};
46
47
template <typename T, size_t N>
48
0
Array<T, N> operator+(const Array<T, N>& x, const Array<T, N>& y) {
49
0
  Array<T, N> z;
50
0
  for (size_t i = 0; i < N; ++i) {
51
0
    z[i] = x[i] + y[i];
52
0
  }
53
0
  return z;
54
0
}
55
56
template <typename T, size_t N>
57
0
Array<T, N> operator-(const Array<T, N>& x, const Array<T, N>& y) {
58
0
  Array<T, N> z;
59
0
  for (size_t i = 0; i < N; ++i) {
60
0
    z[i] = x[i] - y[i];
61
0
  }
62
0
  return z;
63
0
}
64
65
template <typename T, size_t N>
66
0
Array<T, N> operator*(T v, const Array<T, N>& x) {
67
0
  Array<T, N> y;
68
0
  for (size_t i = 0; i < N; ++i) {
69
0
    y[i] = v * x[i];
70
0
  }
71
0
  return y;
72
0
}
73
74
template <typename T, size_t N>
75
0
T operator*(const Array<T, N>& x, const Array<T, N>& y) {
76
0
  T r = 0.0;
77
0
  for (size_t i = 0; i < N; ++i) {
78
0
    r += x[i] * y[i];
79
0
  }
80
0
  return r;
81
0
}
82
83
// Runs a Nelder-Mead-like optimization. Runs for max_iterations iterations;
84
// fun is called with a vector of size dim as argument and returns the score
85
// for those parameters (lower is better). Returns a vector of dim+1
86
// elements, where the first value is the optimal value of the function and
87
// the rest is the argmin value. Use init to pass an initial guess of where
88
// the optimal value is.
// NOTE(review): `amount` is not described in this header; presumably it
// controls the size of the initial simplex - confirm against the definition.
89
//
90
// Usage example:
91
//
92
// RunSimplex(2, 0.1, 100, [](const std::vector<double>& v) {
93
//   return (v[0] - 5) * (v[0] - 5) + (v[1] - 7) * (v[1] - 7);
94
// });
95
//
96
// Returns (0.0, 5, 7)
97
// Overload without an explicit initial guess.
std::vector<double> RunSimplex(
98
    int dim, double amount, int max_iterations,
99
    const std::function<double(const std::vector<double>&)>& fun);
100
// Overload that additionally takes the initial guess `init`.
std::vector<double> RunSimplex(
101
    int dim, double amount, int max_iterations, const std::vector<double>& init,
102
    const std::function<double(const std::vector<double>&)>& fun);
103
104
// Implementation of the Scaled Conjugate Gradient method described in the
105
// following paper:
106
//   Moller, M. "A Scaled Conjugate Gradient Algorithm for Fast Supervised
107
//   Learning", Neural Networks, Vol. 6. pp. 525-533, 1993
108
//   http://sci2s.ugr.es/keel/pdf/algorithm/articulo/moller1990.pdf
109
//
110
// The Function template parameter is a class that has the following method:
111
//
112
//   // Returns the value of the function at point w and sets *df to be the
113
//   // negative gradient vector of the function at point w.
114
//   double Compute(const optimize::Array<T, N>& w,
115
//                  optimize::Array<T, N>* df) const;
116
//
117
// Returns a vector w, such that |df(w)| < grad_norm_threshold.
118
template <typename T, size_t N, typename Function>
119
Array<T, N> OptimizeWithScaledConjugateGradientMethod(
120
    const Function& f, const Array<T, N>& w0, const T grad_norm_threshold,
121
0
    size_t max_iters) {
122
0
  const size_t n = w0.size();
123
0
  const T rsq_threshold = grad_norm_threshold * grad_norm_threshold;
124
0
  const T sigma0 = static_cast<T>(0.0001);
125
0
  const T l_min = static_cast<T>(1.0e-15);
126
0
  const T l_max = static_cast<T>(1.0e15);
127
128
0
  Array<T, N> w = w0;
129
0
  Array<T, N> wp;
130
0
  Array<T, N> r;
131
0
  Array<T, N> rt;
132
0
  Array<T, N> e;
133
0
  Array<T, N> p;
134
0
  T psq;
135
0
  T fp;
136
0
  T D;
137
0
  T d;
138
0
  T m;
139
0
  T a;
140
0
  T b;
141
0
  T s;
142
0
  T t;
143
144
0
  T fw = f.Compute(w, &r);
145
0
  T rsq = r * r;
146
0
  e = r;
147
0
  p = r;
148
0
  T l = static_cast<T>(1.0);
149
0
  bool success = true;
150
0
  size_t n_success = 0;
151
0
  size_t k = 0;
152
153
0
  while (k++ < max_iters) {
154
0
    if (success) {
155
0
      m = -(p * r);
156
0
      if (m >= 0) {
157
0
        p = r;
158
0
        m = -(p * r);
159
0
      }
160
0
      psq = p * p;
161
0
      s = sigma0 / std::sqrt(psq);
162
0
      f.Compute(w + (s * p), &rt);
163
0
      t = (p * (r - rt)) / s;
164
0
    }
165
166
0
    d = t + l * psq;
167
0
    if (d <= 0) {
168
0
      d = l * psq;
169
0
      l = l - t / psq;
170
0
    }
171
172
0
    a = -m / d;
173
0
    wp = w + a * p;
174
0
    fp = f.Compute(wp, &rt);
175
176
0
    D = 2.0 * (fp - fw) / (a * m);
177
0
    if (D >= 0.0) {
178
0
      success = true;
179
0
      n_success++;
180
0
      w = wp;
181
0
    } else {
182
0
      success = false;
183
0
    }
184
185
0
    if (success) {
186
0
      e = r;
187
0
      r = rt;
188
0
      rsq = r * r;
189
0
      fw = fp;
190
0
      if (rsq <= rsq_threshold) {
191
0
        break;
192
0
      }
193
0
    }
194
195
0
    if (D < 0.25) {
196
0
      l = std::min(4.0 * l, l_max);
197
0
    } else if (D > 0.75) {
198
0
      l = std::max(0.25 * l, l_min);
199
0
    }
200
201
0
    if ((n_success % n) == 0) {
202
0
      p = r;
203
0
      l = 1.0;
204
0
    } else if (success) {
205
0
      b = ((e - r) * r) / m;
206
0
      p = b * p + r;
207
0
    }
208
0
  }
209
210
0
  return w;
211
0
}
212
213
}  // namespace optimize
214
}  // namespace jxl
215
216
#endif  // LIB_JXL_OPTIMIZE_H_