/src/libjxl/lib/jxl/enc_optimize.h
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Utility functions for optimizing multi-dimensional nonlinear functions.

#ifndef LIB_JXL_OPTIMIZE_H_
#define LIB_JXL_OPTIMIZE_H_

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

#include "lib/jxl/base/status.h"

namespace jxl {
namespace optimize {

// An array type of numeric values that supports math operations with
// operator-, operator+, etc.
template <typename T, size_t N>
class Array {
 public:
  Array() = default;
  explicit Array(T v) {
    for (size_t i = 0; i < N; i++) v_[i] = v;
  }

  size_t size() const { return N; }

  T& operator[](size_t index) {
    JXL_DASSERT(index < N);
    return v_[index];
  }
  T operator[](size_t index) const {
    JXL_DASSERT(index < N);
    return v_[index];
  }

 private:
  // The values used by this Array.
  T v_[N];
};

template <typename T, size_t N>
Array<T, N> operator+(const Array<T, N>& x, const Array<T, N>& y) {
  Array<T, N> z;
  for (size_t i = 0; i < N; ++i) {
    z[i] = x[i] + y[i];
  }
  return z;
}

template <typename T, size_t N>
Array<T, N> operator-(const Array<T, N>& x, const Array<T, N>& y) {
  Array<T, N> z;
  for (size_t i = 0; i < N; ++i) {
    z[i] = x[i] - y[i];
  }
  return z;
}

template <typename T, size_t N>
Array<T, N> operator*(T v, const Array<T, N>& x) {
  Array<T, N> y;
  for (size_t i = 0; i < N; ++i) {
    y[i] = v * x[i];
  }
  return y;
}

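// Dot product of x and y.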
template <typename T, size_t N>
T operator*(const Array<T, N>& x, const Array<T, N>& y) {
  T r = 0.0;
  for (size_t i = 0; i < N; ++i) {
    r += x[i] * y[i];
  }
  return r;
}

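// A short illustration of the Array arithmetic defined above (the values in
// the comments are what these operators produce):
//
//   Array<double, 2> a(1.0);           // {1.0, 1.0}
//   Array<double, 2> b(2.0);           // {2.0, 2.0}
//   Array<double, 2> c = a + 0.5 * b;  // {2.0, 2.0}
//   double dot = a * b;                // 1*2 + 1*2 == 4.0
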
// Runs Nelder-Mead-like optimization for at most max_iterations iterations.
// fun gets called with a vector of dim parameters and returns the score for
// those parameters (lower is better). amount determines the size of the
// initial simplex. Returns a vector of dim+1 values, where the first value is
// the minimum found and the rest is the argmin. Use init to pass an initial
// guess of where the optimum is.
//
// Usage example:
//
//   RunSimplex(2, 0.1, 100, [](const std::vector<double>& v) {
//     return (v[0] - 5) * (v[0] - 5) + (v[1] - 7) * (v[1] - 7);
//   });
//
// Returns (0.0, 5, 7).
std::vector<double> RunSimplex(
    int dim, double amount, int max_iterations,
    const std::function<double(const std::vector<double>&)>& fun);
std::vector<double> RunSimplex(
    int dim, double amount, int max_iterations, const std::vector<double>& init,
    const std::function<double(const std::vector<double>&)>& fun);

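// A minimal sketch of reading the result of the example above (all names are
// illustrative):
//
//   std::vector<double> result = optimize::RunSimplex(
//       2, 0.1, 100, [](const std::vector<double>& v) {
//         return (v[0] - 5) * (v[0] - 5) + (v[1] - 7) * (v[1] - 7);
//       });
//   double minimum = result[0];                                    // ~0.0
//   std::vector<double> argmin(result.begin() + 1, result.end());  // ~{5, 7}
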
// Implementation of the Scaled Conjugate Gradient method described in the
// following paper:
//   Møller, M. "A Scaled Conjugate Gradient Algorithm for Fast Supervised
//   Learning", Neural Networks, Vol. 6, pp. 525-533, 1993.
//   http://sci2s.ugr.es/keel/pdf/algorithm/articulo/moller1990.pdf
//
// The Function template parameter is a class that has the following method:
//
//   // Returns the value of the function at point w and sets *df to be the
//   // negative gradient vector of the function at point w.
//   double Compute(const optimize::Array<T, N>& w,
//                  optimize::Array<T, N>* df) const;
//
// Returns a vector w such that |df(w)| < grad_norm_threshold, or the best
// point found once max_iters iterations have been exhausted.
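//
// A minimal sketch of a compatible Function (illustrative only, not part of
// the library), minimizing f(w) = |w - c|^2 with gradient 2(w - c), so
// Compute() returns f(w) and stores the negative gradient -2(w - c) in *df:
//
//   struct Quadratic {
//     optimize::Array<double, 2> c;
//     double Compute(const optimize::Array<double, 2>& w,
//                    optimize::Array<double, 2>* df) const {
//       double val = 0.0;
//       for (size_t i = 0; i < w.size(); ++i) {
//         const double diff = w[i] - c[i];
//         val += diff * diff;
//         (*df)[i] = -2.0 * diff;  // negative gradient component
//       }
//       return val;
//     }
//   };
//
//   // Converges to w == c for this convex objective:
//   const auto w = optimize::OptimizeWithScaledConjugateGradientMethod(
//       Quadratic{optimize::Array<double, 2>(3.0)},
//       /*w0=*/optimize::Array<double, 2>(0.0),
//       /*grad_norm_threshold=*/1e-8, /*max_iters=*/100);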
template <typename T, size_t N, typename Function>
Array<T, N> OptimizeWithScaledConjugateGradientMethod(
    const Function& f, const Array<T, N>& w0, const T grad_norm_threshold,
    size_t max_iters) {
  const size_t n = w0.size();
  const T rsq_threshold = grad_norm_threshold * grad_norm_threshold;
  const T sigma0 = static_cast<T>(0.0001);
  const T l_min = static_cast<T>(1.0e-15);
  const T l_max = static_cast<T>(1.0e15);

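  // Scratch state for the iteration; the single-letter names appear to follow
  // Møller's notation (p: search direction, r: negative gradient, l: lambda).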
  Array<T, N> w = w0;
  Array<T, N> wp;
  Array<T, N> r;
  Array<T, N> rt;
  Array<T, N> e;
  Array<T, N> p;
  T psq;
  T fp;
  T D;
  T d;
  T m;
  T a;
  T b;
  T s;
  T t;

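  // Evaluate the function and negative gradient at the starting point; the
  // first search direction is steepest descent.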
  T fw = f.Compute(w, &r);
  T rsq = r * r;
  e = r;
  p = r;
  T l = static_cast<T>(1.0);
  bool success = true;
  size_t n_success = 0;
  size_t k = 0;

  while (k++ < max_iters) {
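    // After a successful step, make sure p is a descent direction and
    // re-estimate the curvature along p with a finite difference of the
    // gradient.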
    if (success) {
      m = -(p * r);
      if (m >= 0) {
        p = r;
        m = -(p * r);
      }
      psq = p * p;
      s = sigma0 / std::sqrt(psq);
      f.Compute(w + (s * p), &rt);
      t = (p * (r - rt)) / s;
    }

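    // Scale the curvature estimate by lambda and force it positive if needed.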
    d = t + l * psq;
    if (d <= 0) {
      d = l * psq;
      l = l - t / psq;
    }

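    // Step size along p; evaluate the function at the trial point wp.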
    a = -m / d;
    wp = w + a * p;
    fp = f.Compute(wp, &rt);

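    // D compares the actual reduction to the reduction predicted by the local
    // quadratic model; accept the step whenever the value did not increase.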
    D = 2.0 * (fp - fw) / (a * m);
    if (D >= 0.0) {
      success = true;
      n_success++;
      w = wp;
    } else {
      success = false;
    }

    if (success) {
      e = r;
      r = rt;
      rsq = r * r;
      fw = fp;
      if (rsq <= rsq_threshold) {
        break;
      }
    }

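    // Grow lambda when the quadratic model predicted the reduction poorly and
    // shrink it when the prediction was good.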
    if (D < 0.25) {
      l = std::min<T>(4.0 * l, l_max);
    } else if (D > 0.75) {
      l = std::max<T>(0.25 * l, l_min);
    }

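    // Restart from steepest descent every n successful steps; otherwise apply
    // a conjugate-gradient-style update to the search direction.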
    if ((n_success % n) == 0) {
      p = r;
      l = 1.0;
    } else if (success) {
      b = ((e - r) * r) / m;
      p = b * p + r;
    }
  }

  return w;
}

}  // namespace optimize
}  // namespace jxl

#endif  // LIB_JXL_OPTIMIZE_H_