/src/libjxl/lib/jxl/enc_optimize.h
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | // Utility functions for optimizing multi-dimensional nonlinear functions. |
7 | | |
8 | | #ifndef LIB_JXL_OPTIMIZE_H_ |
9 | | #define LIB_JXL_OPTIMIZE_H_ |
10 | | |
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>

#include "lib/jxl/base/status.h"
15 | | |
16 | | namespace jxl { |
17 | | namespace optimize { |
18 | | |
19 | | // An array type of numeric values that supports math operations with operator-, |
20 | | // operator+, etc. |
21 | | template <typename T, size_t N> |
22 | | class Array { |
23 | | public: |
24 | | Array() = default; |
25 | | explicit Array(T v) { |
26 | | for (size_t i = 0; i < N; i++) v_[i] = v; |
27 | | } |
28 | | |
29 | 0 | size_t size() const { return N; } |
30 | | |
31 | 0 | T& operator[](size_t index) { |
32 | 0 | JXL_DASSERT(index < N); |
33 | 0 | return v_[index]; |
34 | 0 | } |
35 | 0 | T operator[](size_t index) const { |
36 | 0 | JXL_DASSERT(index < N); |
37 | 0 | return v_[index]; |
38 | 0 | } |
39 | | |
40 | | private: |
41 | | // The values used by this Array. |
42 | | T v_[N]; |
43 | | }; |
44 | | |
45 | | template <typename T, size_t N> |
46 | 0 | Array<T, N> operator+(const Array<T, N>& x, const Array<T, N>& y) { |
47 | 0 | Array<T, N> z; |
48 | 0 | for (size_t i = 0; i < N; ++i) { |
49 | 0 | z[i] = x[i] + y[i]; |
50 | 0 | } |
51 | 0 | return z; |
52 | 0 | } |
53 | | |
54 | | template <typename T, size_t N> |
55 | 0 | Array<T, N> operator-(const Array<T, N>& x, const Array<T, N>& y) { |
56 | 0 | Array<T, N> z; |
57 | 0 | for (size_t i = 0; i < N; ++i) { |
58 | 0 | z[i] = x[i] - y[i]; |
59 | 0 | } |
60 | 0 | return z; |
61 | 0 | } |
62 | | |
63 | | template <typename T, size_t N> |
64 | 0 | Array<T, N> operator*(T v, const Array<T, N>& x) { |
65 | 0 | Array<T, N> y; |
66 | 0 | for (size_t i = 0; i < N; ++i) { |
67 | 0 | y[i] = v * x[i]; |
68 | 0 | } |
69 | 0 | return y; |
70 | 0 | } |
71 | | |
72 | | template <typename T, size_t N> |
73 | 0 | T operator*(const Array<T, N>& x, const Array<T, N>& y) { |
74 | 0 | T r = 0.0; |
75 | 0 | for (size_t i = 0; i < N; ++i) { |
76 | 0 | r += x[i] * y[i]; |
77 | 0 | } |
78 | 0 | return r; |
79 | 0 | } |
80 | | |
81 | | // Implementation of the Scaled Conjugate Gradient method described in the |
82 | | // following paper: |
83 | | // Moller, M. "A Scaled Conjugate Gradient Algorithm for Fast Supervised |
84 | | // Learning", Neural Networks, Vol. 6. pp. 525-533, 1993 |
85 | | // http://sci2s.ugr.es/keel/pdf/algorithm/articulo/moller1990.pdf |
86 | | // |
87 | | // The Function template parameter is a class that has the following method: |
88 | | // |
89 | | // // Returns the value of the function at point w and sets *df to be the |
90 | | // // negative gradient vector of the function at point w. |
91 | | // double Compute(const optimize::Array<T, N>& w, |
92 | | // optimize::Array<T, N>* df) const; |
93 | | // |
94 | | // Returns a vector w, such that |df(w)| < grad_norm_threshold. |
95 | | template <typename T, size_t N, typename Function> |
96 | | Array<T, N> OptimizeWithScaledConjugateGradientMethod( |
97 | | const Function& f, const Array<T, N>& w0, const T grad_norm_threshold, |
98 | 0 | size_t max_iters) { |
99 | 0 | const size_t n = w0.size(); |
100 | 0 | const T rsq_threshold = grad_norm_threshold * grad_norm_threshold; |
101 | 0 | const T sigma0 = static_cast<T>(0.0001); |
102 | 0 | const T l_min = static_cast<T>(1.0e-15); |
103 | 0 | const T l_max = static_cast<T>(1.0e15); |
104 | |
|
105 | 0 | Array<T, N> w = w0; |
106 | 0 | Array<T, N> wp; |
107 | 0 | Array<T, N> r; |
108 | 0 | Array<T, N> rt; |
109 | 0 | Array<T, N> e; |
110 | 0 | Array<T, N> p; |
111 | 0 | T psq; |
112 | 0 | T fp; |
113 | 0 | T D; |
114 | 0 | T d; |
115 | 0 | T m; |
116 | 0 | T a; |
117 | 0 | T b; |
118 | 0 | T s; |
119 | 0 | T t; |
120 | |
|
121 | 0 | T fw = f.Compute(w, &r); |
122 | 0 | T rsq = r * r; |
123 | 0 | e = r; |
124 | 0 | p = r; |
125 | 0 | T l = static_cast<T>(1.0); |
126 | 0 | bool success = true; |
127 | 0 | size_t n_success = 0; |
128 | 0 | size_t k = 0; |
129 | |
|
130 | 0 | while (k++ < max_iters) { |
131 | 0 | if (success) { |
132 | 0 | m = -(p * r); |
133 | 0 | if (m >= 0) { |
134 | 0 | p = r; |
135 | 0 | m = -(p * r); |
136 | 0 | } |
137 | 0 | psq = p * p; |
138 | 0 | s = sigma0 / std::sqrt(psq); |
139 | 0 | f.Compute(w + (s * p), &rt); |
140 | 0 | t = (p * (r - rt)) / s; |
141 | 0 | } |
142 | |
|
143 | 0 | d = t + l * psq; |
144 | 0 | if (d <= 0) { |
145 | 0 | d = l * psq; |
146 | 0 | l = l - t / psq; |
147 | 0 | } |
148 | |
|
149 | 0 | a = -m / d; |
150 | 0 | wp = w + a * p; |
151 | 0 | fp = f.Compute(wp, &rt); |
152 | |
|
153 | 0 | D = 2.0 * (fp - fw) / (a * m); |
154 | 0 | if (D >= 0.0) { |
155 | 0 | success = true; |
156 | 0 | n_success++; |
157 | 0 | w = wp; |
158 | 0 | } else { |
159 | 0 | success = false; |
160 | 0 | } |
161 | |
|
162 | 0 | if (success) { |
163 | 0 | e = r; |
164 | 0 | r = rt; |
165 | 0 | rsq = r * r; |
166 | 0 | fw = fp; |
167 | 0 | if (rsq <= rsq_threshold) { |
168 | 0 | break; |
169 | 0 | } |
170 | 0 | } |
171 | | |
172 | 0 | if (D < 0.25) { |
173 | 0 | l = std::min(4.0 * l, l_max); |
174 | 0 | } else if (D > 0.75) { |
175 | 0 | l = std::max(0.25 * l, l_min); |
176 | 0 | } |
177 | |
|
178 | 0 | if ((n_success % n) == 0) { |
179 | 0 | p = r; |
180 | 0 | l = 1.0; |
181 | 0 | } else if (success) { |
182 | 0 | b = ((e - r) * r) / m; |
183 | 0 | p = b * p + r; |
184 | 0 | } |
185 | 0 | } |
186 | |
|
187 | 0 | return w; |
188 | 0 | } |
189 | | |
190 | | } // namespace optimize |
191 | | } // namespace jxl |
192 | | |
193 | | #endif // LIB_JXL_OPTIMIZE_H_ |