/src/libjxl/lib/extras/metrics.cc
Line | Count | Source |
1 | | // Copyright (c) the JPEG XL Project Authors. All rights reserved. |
2 | | // |
3 | | // Use of this source code is governed by a BSD-style |
4 | | // license that can be found in the LICENSE file. |
5 | | |
6 | | #include "lib/extras/metrics.h" |
7 | | |
8 | | #include <jxl/cms_interface.h> |
9 | | #include <jxl/memory_manager.h> |
10 | | |
11 | | #include <atomic> |
12 | | #include <cmath> |
13 | | #include <cstdlib> |
14 | | #include <limits> |
15 | | |
16 | | #include "lib/jxl/butteraugli/butteraugli.h" |
17 | | #include "lib/jxl/image.h" |
18 | | #include "lib/jxl/image_bundle.h" |
19 | | #include "lib/jxl/image_ops.h" |
20 | | |
21 | | #undef HWY_TARGET_INCLUDE |
22 | | #define HWY_TARGET_INCLUDE "lib/extras/metrics.cc" |
23 | | #include <hwy/foreach_target.h> |
24 | | #include <hwy/highway.h> |
25 | | |
26 | | #include "lib/jxl/base/compiler_specific.h" |
27 | | #include "lib/jxl/base/rect.h" |
28 | | #include "lib/jxl/base/status.h" |
29 | | #include "lib/jxl/color_encoding_internal.h" |
30 | | #include "lib/jxl/memory_manager_internal.h" |
31 | | HWY_BEFORE_NAMESPACE(); |
32 | | namespace jxl { |
33 | | namespace HWY_NAMESPACE { |
34 | | |
35 | | // These templates are not found via ADL. |
36 | | using hwy::HWY_NAMESPACE::Add; |
37 | | using hwy::HWY_NAMESPACE::GetLane; |
38 | | using hwy::HWY_NAMESPACE::Mul; |
39 | | using hwy::HWY_NAMESPACE::Rebind; |
40 | | |
41 | | StatusOr<double> ComputeDistanceP(const ImageF& distmap, |
42 | 0 | const ButteraugliParams& params, double p) { |
43 | 0 | if (distmap.xsize() == 0 || distmap.ysize() == 0) { |
44 | 0 | return 0.0; |
45 | 0 | } |
46 | 0 | JxlMemoryManager* memory_manager = distmap.memory_manager(); |
47 | 0 | JXL_ENSURE(memory_manager != nullptr); |
48 | 0 | const double onePerPixels = 1.0 / (distmap.ysize() * distmap.xsize()); |
49 | 0 | if (std::abs(p - 3.0) < 1E-6) { |
50 | 0 | double sum1[3] = {0.0}; |
51 | | |
52 | | // Prefer double if possible, but otherwise use float rather than scalar. |
53 | 0 | #if HWY_CAP_FLOAT64 |
54 | 0 | using T = double; |
55 | 0 | const Rebind<float, HWY_FULL(double)> df; |
56 | | #else |
57 | | using T = float; |
58 | | #endif |
59 | 0 | const HWY_FULL(T) d; |
60 | 0 | JXL_ASSIGN_OR_RETURN( |
61 | 0 | AlignedMemory sum_totals, |
62 | 0 | AlignedMemory::Create(memory_manager, 3 * Lanes(d) * sizeof(T))); |
63 | | // Manually aligned storage to avoid asan crash on clang-7 due to |
64 | | // unaligned spill. |
65 | 0 | T* sum_totals0 = sum_totals.address<T>(); |
66 | 0 | T* sum_totals1 = sum_totals0 + Lanes(d); |
67 | 0 | T* sum_totals2 = sum_totals1 + Lanes(d); |
68 | 0 | Store(Zero(d), d, sum_totals0); |
69 | 0 | Store(Zero(d), d, sum_totals1); |
70 | 0 | Store(Zero(d), d, sum_totals2); |
71 | |
|
72 | 0 | for (size_t y = 0; y < distmap.ysize(); ++y) { |
73 | 0 | const float* JXL_RESTRICT row = distmap.ConstRow(y); |
74 | |
|
75 | 0 | auto sums0 = Zero(d); |
76 | 0 | auto sums1 = Zero(d); |
77 | 0 | auto sums2 = Zero(d); |
78 | |
|
79 | 0 | size_t x = 0; |
80 | 0 | for (; x + Lanes(d) <= distmap.xsize(); x += Lanes(d)) { |
81 | 0 | #if HWY_CAP_FLOAT64 |
82 | 0 | const auto d1 = PromoteTo(d, Load(df, row + x)); |
83 | | #else |
84 | | const auto d1 = Load(d, row + x); |
85 | | #endif |
86 | 0 | const auto d2 = Mul(d1, Mul(d1, d1)); |
87 | 0 | sums0 = Add(sums0, d2); |
88 | 0 | const auto d3 = Mul(d2, d2); |
89 | 0 | sums1 = Add(sums1, d3); |
90 | 0 | const auto d4 = Mul(d3, d3); |
91 | 0 | sums2 = Add(sums2, d4); |
92 | 0 | } |
93 | |
|
94 | 0 | Store(Add(sums0, Load(d, sum_totals0)), d, sum_totals0); |
95 | 0 | Store(Add(sums1, Load(d, sum_totals1)), d, sum_totals1); |
96 | 0 | Store(Add(sums2, Load(d, sum_totals2)), d, sum_totals2); |
97 | |
|
98 | 0 | for (; x < distmap.xsize(); ++x) { |
99 | 0 | const double d1 = row[x]; |
100 | 0 | double d2 = d1 * d1 * d1; |
101 | 0 | sum1[0] += d2; |
102 | 0 | d2 *= d2; |
103 | 0 | sum1[1] += d2; |
104 | 0 | d2 *= d2; |
105 | 0 | sum1[2] += d2; |
106 | 0 | } |
107 | 0 | } |
108 | 0 | double v = 0; |
109 | 0 | v += pow( |
110 | 0 | onePerPixels * (sum1[0] + GetLane(SumOfLanes(d, Load(d, sum_totals0)))), |
111 | 0 | 1.0 / (p * 1.0)); |
112 | 0 | v += pow( |
113 | 0 | onePerPixels * (sum1[1] + GetLane(SumOfLanes(d, Load(d, sum_totals1)))), |
114 | 0 | 1.0 / (p * 2.0)); |
115 | 0 | v += pow( |
116 | 0 | onePerPixels * (sum1[2] + GetLane(SumOfLanes(d, Load(d, sum_totals2)))), |
117 | 0 | 1.0 / (p * 4.0)); |
118 | 0 | v /= 3.0; |
119 | 0 | return v; |
120 | 0 | } else { |
121 | 0 | static std::atomic<uint32_t> once{0}; |
122 | 0 | if (once.fetch_add(1, std::memory_order_relaxed) == 0) { |
123 | 0 | JXL_WARNING("WARNING: using slow ComputeDistanceP"); |
124 | 0 | } |
125 | 0 | double sum1[3] = {0.0}; |
126 | 0 | for (size_t y = 0; y < distmap.ysize(); ++y) { |
127 | 0 | const float* JXL_RESTRICT row = distmap.ConstRow(y); |
128 | 0 | for (size_t x = 0; x < distmap.xsize(); ++x) { |
129 | 0 | double d2 = std::pow(row[x], p); |
130 | 0 | sum1[0] += d2; |
131 | 0 | d2 *= d2; |
132 | 0 | sum1[1] += d2; |
133 | 0 | d2 *= d2; |
134 | 0 | sum1[2] += d2; |
135 | 0 | } |
136 | 0 | } |
137 | 0 | double v = 0; |
138 | 0 | for (int i = 0; i < 3; ++i) { |
139 | 0 | v += pow(onePerPixels * (sum1[i]), 1.0 / (p * (1 << i))); |
140 | 0 | } |
141 | 0 | v /= 3.0; |
142 | 0 | return v; |
143 | 0 | } |
144 | 0 | } Unexecuted instantiation: jxl::N_SSE4::ComputeDistanceP(jxl::Plane<float> const&, jxl::ButteraugliParams const&, double) Unexecuted instantiation: jxl::N_AVX2::ComputeDistanceP(jxl::Plane<float> const&, jxl::ButteraugliParams const&, double) Unexecuted instantiation: jxl::N_SSE2::ComputeDistanceP(jxl::Plane<float> const&, jxl::ButteraugliParams const&, double) |
145 | | |
146 | | void ComputeSumOfSquares(const ImageBundle& ib1, const ImageBundle& ib2, |
147 | 0 | const JxlCmsInterface& cms, double sum_of_squares[3]) { |
148 | 0 | sum_of_squares[0] = sum_of_squares[1] = sum_of_squares[2] = |
149 | 0 | std::numeric_limits<double>::max(); |
150 | | // Convert to sRGB - closer to perception than linear. |
151 | 0 | const Image3F* srgb1 = &ib1.color(); |
152 | 0 | Image3F copy1; |
153 | 0 | if (!ib1.IsSRGB()) { |
154 | 0 | if (!ib1.CopyTo(Rect(ib1), ColorEncoding::SRGB(ib1.IsGray()), cms, ©1)) |
155 | 0 | return; |
156 | 0 | srgb1 = ©1; |
157 | 0 | } |
158 | 0 | const Image3F* srgb2 = &ib2.color(); |
159 | 0 | Image3F copy2; |
160 | 0 | if (!ib2.IsSRGB()) { |
161 | 0 | if (!ib2.CopyTo(Rect(ib2), ColorEncoding::SRGB(ib2.IsGray()), cms, ©2)) |
162 | 0 | return; |
163 | 0 | srgb2 = ©2; |
164 | 0 | } |
165 | | |
166 | 0 | if (!SameSize(*srgb1, *srgb2)) return; |
167 | | |
168 | 0 | sum_of_squares[0] = sum_of_squares[1] = sum_of_squares[2] = 0.0; |
169 | | |
170 | | // TODO(veluca): SIMD. |
171 | 0 | float yuvmatrix[3][3] = {{0.299, 0.587, 0.114}, |
172 | 0 | {-0.14713, -0.28886, 0.436}, |
173 | 0 | {0.615, -0.51499, -0.10001}}; |
174 | 0 | for (size_t y = 0; y < srgb1->ysize(); ++y) { |
175 | 0 | const float* JXL_RESTRICT row1[3]; |
176 | 0 | const float* JXL_RESTRICT row2[3]; |
177 | 0 | for (size_t j = 0; j < 3; j++) { |
178 | 0 | row1[j] = srgb1->ConstPlaneRow(j, y); |
179 | 0 | row2[j] = srgb2->ConstPlaneRow(j, y); |
180 | 0 | } |
181 | 0 | for (size_t x = 0; x < srgb1->xsize(); ++x) { |
182 | 0 | float cdiff[3] = {}; |
183 | | // YUV conversion is linear, so we can run it on the difference. |
184 | 0 | for (size_t j = 0; j < 3; j++) { |
185 | 0 | cdiff[j] = row1[j][x] - row2[j][x]; |
186 | 0 | } |
187 | 0 | float yuvdiff[3] = {}; |
188 | 0 | for (size_t j = 0; j < 3; j++) { |
189 | 0 | for (size_t k = 0; k < 3; k++) { |
190 | 0 | yuvdiff[j] += yuvmatrix[j][k] * cdiff[k]; |
191 | 0 | } |
192 | 0 | } |
193 | 0 | for (size_t j = 0; j < 3; j++) { |
194 | 0 | sum_of_squares[j] += static_cast<double>(yuvdiff[j]) * yuvdiff[j]; |
195 | 0 | } |
196 | 0 | } |
197 | 0 | } |
198 | 0 | } Unexecuted instantiation: jxl::N_SSE4::ComputeSumOfSquares(jxl::ImageBundle const&, jxl::ImageBundle const&, JxlCmsInterface const&, double*) Unexecuted instantiation: jxl::N_AVX2::ComputeSumOfSquares(jxl::ImageBundle const&, jxl::ImageBundle const&, JxlCmsInterface const&, double*) Unexecuted instantiation: jxl::N_SSE2::ComputeSumOfSquares(jxl::ImageBundle const&, jxl::ImageBundle const&, JxlCmsInterface const&, double*) |
199 | | |
200 | | // NOLINTNEXTLINE(google-readability-namespace-comments) |
201 | | } // namespace HWY_NAMESPACE |
202 | | } // namespace jxl |
203 | | HWY_AFTER_NAMESPACE(); |
204 | | |
205 | | #if HWY_ONCE |
206 | | namespace jxl { |
207 | | HWY_EXPORT(ComputeDistanceP); |
208 | | StatusOr<double> ComputeDistanceP(const ImageF& distmap, |
209 | 0 | const ButteraugliParams& params, double p) { |
210 | 0 | return HWY_DYNAMIC_DISPATCH(ComputeDistanceP)(distmap, params, p); |
211 | 0 | } |
212 | | |
213 | | HWY_EXPORT(ComputeSumOfSquares); |
214 | | |
215 | | double ComputeDistance2(const ImageBundle& ib1, const ImageBundle& ib2, |
216 | 0 | const JxlCmsInterface& cms) { |
217 | 0 | double sum_of_squares[3] = {}; |
218 | 0 | HWY_DYNAMIC_DISPATCH(ComputeSumOfSquares)(ib1, ib2, cms, sum_of_squares); |
219 | | // Weighted PSNR as in JPEG-XL: chroma counts 1/8. |
220 | 0 | const float weights[3] = {6.0f / 8, 1.0f / 8, 1.0f / 8}; |
221 | | // Avoid squaring the weight - 1/64 is too extreme. |
222 | 0 | double norm = 0; |
223 | 0 | for (size_t i = 0; i < 3; i++) { |
224 | 0 | norm += std::sqrt(sum_of_squares[i]) * weights[i]; |
225 | 0 | } |
226 | | // This function returns distance *squared*. |
227 | 0 | return norm * norm; |
228 | 0 | } |
229 | | |
230 | | double ComputePSNR(const ImageBundle& ib1, const ImageBundle& ib2, |
231 | 0 | const JxlCmsInterface& cms) { |
232 | 0 | if (!SameSize(ib1, ib2)) return 0.0; |
233 | 0 | double sum_of_squares[3] = {}; |
234 | 0 | HWY_DYNAMIC_DISPATCH(ComputeSumOfSquares)(ib1, ib2, cms, sum_of_squares); |
235 | 0 | constexpr double kChannelWeights[3] = {6.0 / 8, 1.0 / 8, 1.0 / 8}; |
236 | 0 | double avg_psnr = 0; |
237 | 0 | const size_t input_pixels = ib1.xsize() * ib1.ysize(); |
238 | 0 | for (int i = 0; i < 3; ++i) { |
239 | 0 | const double rmse = std::sqrt(sum_of_squares[i] / input_pixels); |
240 | 0 | const double psnr = |
241 | 0 | sum_of_squares[i] == 0 ? 99.99 : (20 * std::log10(1 / rmse)); |
242 | 0 | avg_psnr += kChannelWeights[i] * psnr; |
243 | 0 | } |
244 | 0 | return avg_psnr; |
245 | 0 | } |
246 | | |
247 | | } // namespace jxl |
248 | | #endif |