/rust/registry/src/index.crates.io-1949cf8c6b5b557f/av1-grain-0.2.4/src/diff/solver.rs
Line | Count | Source |
1 | | mod util; |
2 | | |
3 | | use std::ops::{Add, AddAssign}; |
4 | | |
5 | | use anyhow::anyhow; |
6 | | use arrayvec::ArrayVec; |
7 | | use v_frame::{frame::Frame, math::clamp, plane::Plane}; |
8 | | |
9 | | use self::util::{extract_ar_row, get_block_mean, get_noise_var, linsolve, multiply_mat}; |
10 | | use super::{NoiseStatus, BLOCK_SIZE, BLOCK_SIZE_SQUARED}; |
11 | | use crate::{ |
12 | | diff::solver::util::normalized_cross_correlation, GrainTableSegment, DEFAULT_GRAIN_SEED, |
13 | | NUM_UV_COEFFS, NUM_UV_POINTS, NUM_Y_COEFFS, NUM_Y_POINTS, |
14 | | }; |
15 | | |
16 | | const LOW_POLY_NUM_PARAMS: usize = 3; |
17 | | const NOISE_MODEL_LAG: usize = 3; |
18 | | const BLOCK_NORMALIZATION: f64 = 255.0f64; |
19 | | |
20 | | #[derive(Debug, Clone)] |
21 | | pub(super) struct FlatBlockFinder { |
22 | | a: Box<[f64]>, |
23 | | a_t_a_inv: [f64; LOW_POLY_NUM_PARAMS * LOW_POLY_NUM_PARAMS], |
24 | | } |
25 | | |
26 | | impl FlatBlockFinder { |
27 | | #[must_use] |
28 | 0 | pub fn new() -> Self { |
29 | 0 | let mut eqns = EquationSystem::new(LOW_POLY_NUM_PARAMS); |
30 | 0 | let mut a_t_a_inv = [0.0f64; LOW_POLY_NUM_PARAMS * LOW_POLY_NUM_PARAMS]; |
31 | 0 | let mut a = vec![0.0f64; LOW_POLY_NUM_PARAMS * BLOCK_SIZE_SQUARED]; |
32 | | |
33 | 0 | let bs_half = (BLOCK_SIZE / 2) as f64; |
34 | 0 | (0..BLOCK_SIZE).for_each(|y| { |
35 | 0 | let yd = (y as f64 - bs_half) / bs_half; |
36 | 0 | (0..BLOCK_SIZE).for_each(|x| { |
37 | 0 | let xd = (x as f64 - bs_half) / bs_half; |
38 | 0 | let coords = [yd, xd, 1.0f64]; |
39 | 0 | let row = y * BLOCK_SIZE + x; |
40 | 0 | a[LOW_POLY_NUM_PARAMS * row] = yd; |
41 | 0 | a[LOW_POLY_NUM_PARAMS * row + 1] = xd; |
42 | 0 | a[LOW_POLY_NUM_PARAMS * row + 2] = 1.0f64; |
43 | | |
44 | 0 | (0..LOW_POLY_NUM_PARAMS).for_each(|i| { |
45 | 0 | (0..LOW_POLY_NUM_PARAMS).for_each(|j| { |
46 | 0 | eqns.a[LOW_POLY_NUM_PARAMS * i + j] += coords[i] * coords[j]; |
47 | 0 | }); |
48 | 0 | }); |
49 | 0 | }); |
50 | 0 | }); |
51 | | |
52 | | // Lazy inverse using existing equation solver. |
53 | 0 | (0..LOW_POLY_NUM_PARAMS).for_each(|i| { |
54 | 0 | eqns.b.fill(0.0f64); |
55 | 0 | eqns.b[i] = 1.0f64; |
56 | 0 | eqns.solve(); |
57 | | |
58 | 0 | (0..LOW_POLY_NUM_PARAMS).for_each(|j| { |
59 | 0 | a_t_a_inv[j * LOW_POLY_NUM_PARAMS + i] = eqns.x[j]; |
60 | 0 | }); |
61 | 0 | }); |
62 | | |
63 | 0 | FlatBlockFinder { |
64 | 0 | a: a.into_boxed_slice(), |
65 | 0 | a_t_a_inv, |
66 | 0 | } |
67 | 0 | } |
68 | | |
69 | | // The gradient-based features used in this code are based on: |
70 | | // A. Kokaram, D. Kelly, H. Denman and A. Crawford, "Measuring noise |
71 | | // correlation for improved video denoising," 2012 19th, ICIP. |
72 | | // The thresholds are more lenient to allow for correct grain modeling |
73 | | // in extreme cases. |
74 | | #[must_use] |
75 | | #[allow(clippy::too_many_lines)] |
76 | 0 | pub fn run(&self, plane: &Plane<u8>) -> (Vec<u8>, usize) { |
77 | | const TRACE_THRESHOLD: f64 = 0.15f64 / BLOCK_SIZE_SQUARED as f64; |
78 | | const RATIO_THRESHOLD: f64 = 1.25f64; |
79 | | const NORM_THRESHOLD: f64 = 0.08f64 / BLOCK_SIZE_SQUARED as f64; |
80 | | const VAR_THRESHOLD: f64 = 0.005f64 / BLOCK_SIZE_SQUARED as f64; |
81 | | |
82 | | // The following weights are used to combine the above features to give |
83 | | // a sigmoid score for flatness. If the input was normalized to [0,100] |
84 | | // the magnitude of these values would be close to 1 (e.g., weights |
85 | | // corresponding to variance would be a factor of 10000x smaller). |
86 | | const VAR_WEIGHT: f64 = -6682f64; |
87 | | const RATIO_WEIGHT: f64 = -0.2056f64; |
88 | | const TRACE_WEIGHT: f64 = 13087f64; |
89 | | const NORM_WEIGHT: f64 = -12434f64; |
90 | | const OFFSET: f64 = 2.5694f64; |
91 | | |
92 | 0 | let num_blocks_w = (plane.cfg.width + BLOCK_SIZE - 1) / BLOCK_SIZE; |
93 | 0 | let num_blocks_h = (plane.cfg.height + BLOCK_SIZE - 1) / BLOCK_SIZE; |
94 | 0 | let num_blocks = num_blocks_w * num_blocks_h; |
95 | 0 | let mut flat_blocks = vec![0u8; num_blocks]; |
96 | 0 | let mut num_flat = 0; |
97 | 0 | let mut plane_result = [0.0f64; BLOCK_SIZE_SQUARED]; |
98 | 0 | let mut block_result = [0.0f64; BLOCK_SIZE_SQUARED]; |
99 | 0 | let mut scores = vec![IndexAndScore::default(); num_blocks]; |
100 | | |
101 | 0 | for by in 0..num_blocks_h { |
102 | 0 | for bx in 0..num_blocks_w { |
103 | | // Compute gradient covariance matrix. |
104 | 0 | let mut gxx = 0f64; |
105 | 0 | let mut gxy = 0f64; |
106 | 0 | let mut gyy = 0f64; |
107 | 0 | let mut var = 0f64; |
108 | 0 | let mut mean = 0f64; |
109 | | |
110 | 0 | self.extract_block( |
111 | 0 | plane, |
112 | 0 | bx * BLOCK_SIZE, |
113 | 0 | by * BLOCK_SIZE, |
114 | 0 | &mut plane_result, |
115 | 0 | &mut block_result, |
116 | | ); |
117 | 0 | for yi in 1..(BLOCK_SIZE - 1) { |
118 | 0 | for xi in 1..(BLOCK_SIZE - 1) { |
119 | | // SAFETY: We know the size of `block_result` and that we cannot exceed the bounds of it |
120 | 0 | unsafe { |
121 | 0 | let result_ptr = block_result.as_ptr().add(yi * BLOCK_SIZE + xi); |
122 | 0 |
|
123 | 0 | let gx = (*result_ptr.add(1) - *result_ptr.sub(1)) / 2f64; |
124 | 0 | let gy = |
125 | 0 | (*result_ptr.add(BLOCK_SIZE) - *result_ptr.sub(BLOCK_SIZE)) / 2f64; |
126 | 0 | gxx += gx * gx; |
127 | 0 | gxy += gx * gy; |
128 | 0 | gyy += gy * gy; |
129 | 0 |
|
130 | 0 | let block_val = *result_ptr; |
131 | 0 | mean += block_val; |
132 | 0 | var += block_val * block_val; |
133 | 0 | } |
134 | | } |
135 | | } |
136 | 0 | let block_size_norm_factor = (BLOCK_SIZE - 2).pow(2) as f64; |
137 | 0 | mean /= block_size_norm_factor; |
138 | | |
139 | | // Normalize gradients by block_size. |
140 | 0 | gxx /= block_size_norm_factor; |
141 | 0 | gxy /= block_size_norm_factor; |
142 | 0 | gyy /= block_size_norm_factor; |
143 | 0 | var = mean.mul_add(-mean, var / block_size_norm_factor); |
144 | | |
145 | 0 | let trace = gxx + gyy; |
146 | 0 | let det = gxx.mul_add(gyy, -gxy.powi(2)); |
147 | 0 | let e_sub = (trace.mul_add(trace, -4f64 * det)).max(0.).sqrt(); |
148 | 0 | let e1 = (trace + e_sub) / 2.0f64; |
149 | 0 | let e2 = (trace - e_sub) / 2.0f64; |
150 | | // Spectral norm |
151 | 0 | let norm = e1; |
152 | 0 | let ratio = e1 / e2.max(1.0e-6_f64); |
153 | 0 | let is_flat = trace < TRACE_THRESHOLD |
154 | 0 | && ratio < RATIO_THRESHOLD |
155 | 0 | && norm < NORM_THRESHOLD |
156 | 0 | && var > VAR_THRESHOLD; |
157 | | |
158 | 0 | let sum_weights = NORM_WEIGHT.mul_add( |
159 | 0 | norm, |
160 | 0 | TRACE_WEIGHT.mul_add( |
161 | 0 | trace, |
162 | 0 | VAR_WEIGHT.mul_add(var, RATIO_WEIGHT.mul_add(ratio, OFFSET)), |
163 | | ), |
164 | | ); |
165 | | // clamp the value to [-25.0, 100.0] to prevent overflow |
166 | 0 | let sum_weights = clamp(sum_weights, -25.0f64, 100.0f64); |
167 | 0 | let score = (1.0f64 / (1.0f64 + (-sum_weights).exp())) as f32; |
168 | | // SAFETY: We know the size of `flat_blocks` and `scores` and that we cannot exceed the bounds of it |
169 | | unsafe { |
170 | 0 | let index = by * num_blocks_w + bx; |
171 | 0 | *flat_blocks.get_unchecked_mut(index) = if is_flat { 255 } else { 0 }; |
172 | 0 | *scores.get_unchecked_mut(index) = IndexAndScore { |
173 | 0 | score: if var > VAR_THRESHOLD { score } else { 0f32 }, |
174 | 0 | index, |
175 | | }; |
176 | | } |
177 | 0 | if is_flat { |
178 | 0 | num_flat += 1; |
179 | 0 | } |
180 | | } |
181 | | } |
182 | | |
183 | 0 | scores.sort_unstable_by(|a, b| a.score.partial_cmp(&b.score).expect("Shouldn't be NaN")); |
184 | | // SAFETY: We know the size of `flat_blocks` and `scores` and that we cannot exceed the bounds of it |
185 | | unsafe { |
186 | 0 | let top_nth_percentile = num_blocks * 90 / 100; |
187 | 0 | let score_threshold = scores.get_unchecked(top_nth_percentile).score; |
188 | 0 | for score in &scores { |
189 | 0 | if score.score >= score_threshold { |
190 | 0 | let block_ref = flat_blocks.get_unchecked_mut(score.index); |
191 | 0 | if *block_ref == 0 { |
192 | 0 | num_flat += 1; |
193 | 0 | } |
194 | 0 | *block_ref |= 1; |
195 | 0 | } |
196 | | } |
197 | | } |
198 | | |
199 | 0 | (flat_blocks, num_flat) |
200 | 0 | } |
201 | | |
202 | 0 | fn extract_block( |
203 | 0 | &self, |
204 | 0 | plane: &Plane<u8>, |
205 | 0 | offset_x: usize, |
206 | 0 | offset_y: usize, |
207 | 0 | plane_result: &mut [f64; BLOCK_SIZE_SQUARED], |
208 | 0 | block_result: &mut [f64; BLOCK_SIZE_SQUARED], |
209 | 0 | ) { |
210 | 0 | let mut plane_coords = [0f64; LOW_POLY_NUM_PARAMS]; |
211 | 0 | let mut a_t_a_inv_b = [0f64; LOW_POLY_NUM_PARAMS]; |
212 | 0 | let plane_origin = plane.data_origin(); |
213 | | |
214 | 0 | for yi in 0..BLOCK_SIZE { |
215 | 0 | let y = clamp(offset_y + yi, 0, plane.cfg.height - 1); |
216 | 0 | for xi in 0..BLOCK_SIZE { |
217 | 0 | let x = clamp(offset_x + xi, 0, plane.cfg.width - 1); |
218 | | // SAFETY: We know the bounds of the plane data and `block_result` |
219 | | // and do not exceed them. |
220 | 0 | unsafe { |
221 | 0 | *block_result.get_unchecked_mut(yi * BLOCK_SIZE + xi) = |
222 | 0 | f64::from(*plane_origin.get_unchecked(y * plane.cfg.stride + x)) |
223 | 0 | / BLOCK_NORMALIZATION; |
224 | 0 | } |
225 | | } |
226 | | } |
227 | | |
228 | 0 | multiply_mat( |
229 | 0 | block_result, |
230 | 0 | &self.a, |
231 | 0 | &mut a_t_a_inv_b, |
232 | | 1, |
233 | | BLOCK_SIZE_SQUARED, |
234 | | LOW_POLY_NUM_PARAMS, |
235 | | ); |
236 | 0 | multiply_mat( |
237 | 0 | &self.a_t_a_inv, |
238 | 0 | &a_t_a_inv_b, |
239 | 0 | &mut plane_coords, |
240 | | LOW_POLY_NUM_PARAMS, |
241 | | LOW_POLY_NUM_PARAMS, |
242 | | 1, |
243 | | ); |
244 | 0 | multiply_mat( |
245 | 0 | &self.a, |
246 | 0 | &plane_coords, |
247 | 0 | plane_result, |
248 | | BLOCK_SIZE_SQUARED, |
249 | | LOW_POLY_NUM_PARAMS, |
250 | | 1, |
251 | | ); |
252 | | |
253 | 0 | for (block_res, plane_res) in block_result.iter_mut().zip(plane_result.iter()) { |
254 | 0 | *block_res -= *plane_res; |
255 | 0 | } |
256 | 0 | } |
257 | | } |
258 | | |
259 | | #[derive(Debug, Clone, Copy, Default)] |
260 | | struct IndexAndScore { |
261 | | pub index: usize, |
262 | | pub score: f32, |
263 | | } |
264 | | |
265 | | /// Wrapper of data required to represent linear system of eqns and soln. |
266 | | #[derive(Debug, Clone)] |
267 | | struct EquationSystem { |
268 | | a: Vec<f64>, |
269 | | b: Vec<f64>, |
270 | | x: Vec<f64>, |
271 | | n: usize, |
272 | | } |
273 | | |
274 | | impl EquationSystem { |
275 | | #[must_use] |
276 | 0 | pub fn new(n: usize) -> Self { |
277 | 0 | Self { |
278 | 0 | a: vec![0.0f64; n * n], |
279 | 0 | b: vec![0.0f64; n], |
280 | 0 | x: vec![0.0f64; n], |
281 | 0 | n, |
282 | 0 | } |
283 | 0 | } |
284 | | |
285 | 0 | pub fn solve(&mut self) -> bool { |
286 | 0 | let n = self.n; |
287 | 0 | let mut a = self.a.clone(); |
288 | 0 | let mut b = self.b.clone(); |
289 | | |
290 | 0 | linsolve(n, &mut a, self.n, &mut b, &mut self.x) |
291 | 0 | } |
292 | | |
293 | 0 | pub fn set_chroma_coefficient_fallback_solution(&mut self) { |
294 | | const TOLERANCE: f64 = 1.0e-6f64; |
295 | 0 | let last = self.n - 1; |
296 | | // Set all of the AR coefficients to zero, but try to solve for correlation |
297 | | // with the luma channel |
298 | 0 | self.x.fill(0f64); |
299 | 0 | if self.a[last * self.n + last].abs() > TOLERANCE { |
300 | 0 | self.x[last] = self.b[last] / self.a[last * self.n + last]; |
301 | 0 | } |
302 | 0 | } |
303 | | |
304 | 0 | pub fn copy_from(&mut self, other: &Self) { |
305 | 0 | assert_eq!(self.n, other.n); |
306 | 0 | self.a.copy_from_slice(&other.a); |
307 | 0 | self.x.copy_from_slice(&other.x); |
308 | 0 | self.b.copy_from_slice(&other.b); |
309 | 0 | } |
310 | | |
311 | 0 | pub fn clear(&mut self) { |
312 | 0 | self.a.fill(0f64); |
313 | 0 | self.b.fill(0f64); |
314 | 0 | self.x.fill(0f64); |
315 | 0 | } |
316 | | } |
317 | | |
318 | | impl Add<&EquationSystem> for EquationSystem { |
319 | | type Output = EquationSystem; |
320 | | |
321 | 0 | fn add(self, addend: &EquationSystem) -> Self::Output { |
322 | 0 | let mut dest = self.clone(); |
323 | 0 | let n = self.n; |
324 | 0 | for i in 0..n { |
325 | 0 | for j in 0..n { |
326 | 0 | dest.a[i * n + j] += addend.a[i * n + j]; |
327 | 0 | } |
328 | 0 | dest.b[i] += addend.b[i]; |
329 | | } |
330 | 0 | dest |
331 | 0 | } |
332 | | } |
333 | | |
334 | | impl AddAssign<&EquationSystem> for EquationSystem { |
335 | 0 | fn add_assign(&mut self, rhs: &EquationSystem) { |
336 | 0 | *self = self.clone() + rhs; |
337 | 0 | } |
338 | | } |
339 | | |
340 | | /// Representation of a piecewise linear curve |
341 | | /// |
342 | | /// Holds n points as (x, y) pairs, that store the curve. |
343 | | struct NoiseStrengthLut { |
344 | | points: Vec<[f64; 2]>, |
345 | | } |
346 | | |
347 | | impl NoiseStrengthLut { |
348 | | #[must_use] |
349 | 0 | pub fn new(num_bins: usize) -> Self { |
350 | 0 | assert!(num_bins > 0); |
351 | 0 | Self { |
352 | 0 | points: vec![[0f64; 2]; num_bins], |
353 | 0 | } |
354 | 0 | } |
355 | | } |
356 | | |
357 | | #[derive(Debug, Clone)] |
358 | | pub(super) struct NoiseModel { |
359 | | combined_state: [NoiseModelState; 3], |
360 | | latest_state: [NoiseModelState; 3], |
361 | | n: usize, |
362 | | coords: Vec<[isize; 2]>, |
363 | | } |
364 | | |
365 | | impl NoiseModel { |
366 | | #[must_use] |
367 | 0 | pub fn new() -> Self { |
368 | 0 | let n = Self::num_coeffs(); |
369 | 0 | let combined_state = [ |
370 | 0 | NoiseModelState::new(n), |
371 | 0 | NoiseModelState::new(n + 1), |
372 | 0 | NoiseModelState::new(n + 1), |
373 | 0 | ]; |
374 | 0 | let latest_state = [ |
375 | 0 | NoiseModelState::new(n), |
376 | 0 | NoiseModelState::new(n + 1), |
377 | 0 | NoiseModelState::new(n + 1), |
378 | 0 | ]; |
379 | 0 | let mut coords = Vec::new(); |
380 | | |
381 | 0 | let neg_lag = -(NOISE_MODEL_LAG as isize); |
382 | 0 | for y in neg_lag..=0 { |
383 | 0 | let max_x = if y == 0 { |
384 | 0 | -1isize |
385 | | } else { |
386 | 0 | NOISE_MODEL_LAG as isize |
387 | | }; |
388 | 0 | for x in neg_lag..=max_x { |
389 | 0 | coords.push([x, y]); |
390 | 0 | } |
391 | | } |
392 | 0 | assert!(n == coords.len()); |
393 | | |
394 | 0 | Self { |
395 | 0 | combined_state, |
396 | 0 | latest_state, |
397 | 0 | n, |
398 | 0 | coords, |
399 | 0 | } |
400 | 0 | } |
401 | | |
402 | 0 | pub fn update( |
403 | 0 | &mut self, |
404 | 0 | source: &Frame<u8>, |
405 | 0 | denoised: &Frame<u8>, |
406 | 0 | flat_blocks: &[u8], |
407 | 0 | ) -> NoiseStatus { |
408 | 0 | let num_blocks_w = (source.planes[0].cfg.width + BLOCK_SIZE - 1) / BLOCK_SIZE; |
409 | 0 | let num_blocks_h = (source.planes[0].cfg.height + BLOCK_SIZE - 1) / BLOCK_SIZE; |
410 | 0 | let mut y_model_different = false; |
411 | | |
412 | | // Clear the latest equation system |
413 | 0 | for i in 0..3 { |
414 | 0 | self.latest_state[i].eqns.clear(); |
415 | 0 | self.latest_state[i].num_observations = 0; |
416 | 0 | self.latest_state[i].strength_solver.clear(); |
417 | 0 | } |
418 | | |
419 | | // Check that we have enough flat blocks |
420 | 0 | let num_blocks = flat_blocks.iter().filter(|b| **b > 0).count(); |
421 | 0 | if num_blocks <= 1 { |
422 | 0 | return NoiseStatus::Error(anyhow!("Not enough flat blocks to update noise estimate")); |
423 | 0 | } |
424 | | |
425 | 0 | let frame_dims = (source.planes[0].cfg.width, source.planes[0].cfg.height); |
426 | 0 | for channel in 0..3 { |
427 | 0 | if source.planes[channel].data.is_empty() { |
428 | | // Monochrome source |
429 | 0 | break; |
430 | 0 | } |
431 | 0 | let is_chroma = channel > 0; |
432 | 0 | let alt_source = (channel > 0).then(|| &source.planes[0]); |
433 | 0 | let alt_denoised = (channel > 0).then(|| &denoised.planes[0]); |
434 | 0 | self.add_block_observations( |
435 | 0 | channel, |
436 | 0 | &source.planes[channel], |
437 | 0 | &denoised.planes[channel], |
438 | 0 | alt_source, |
439 | 0 | alt_denoised, |
440 | 0 | frame_dims, |
441 | 0 | flat_blocks, |
442 | 0 | num_blocks_w, |
443 | 0 | num_blocks_h, |
444 | | ); |
445 | 0 | if !self.latest_state[channel].ar_equation_system_solve(is_chroma) { |
446 | 0 | if is_chroma { |
447 | 0 | self.latest_state[channel] |
448 | 0 | .eqns |
449 | 0 | .set_chroma_coefficient_fallback_solution(); |
450 | 0 | } else { |
451 | 0 | return NoiseStatus::Error(anyhow!( |
452 | 0 | "Solving latest noise equation system failed on plane {}", |
453 | 0 | channel |
454 | 0 | )); |
455 | | } |
456 | 0 | } |
457 | 0 | self.add_noise_std_observations( |
458 | 0 | channel, |
459 | 0 | &source.planes[channel], |
460 | 0 | &denoised.planes[channel], |
461 | 0 | alt_source, |
462 | 0 | frame_dims, |
463 | 0 | flat_blocks, |
464 | 0 | num_blocks_w, |
465 | 0 | num_blocks_h, |
466 | | ); |
467 | 0 | if !self.latest_state[channel].strength_solver.solve() { |
468 | 0 | return NoiseStatus::Error(anyhow!( |
469 | 0 | "Failed to solve strength solver for latest state" |
470 | 0 | )); |
471 | 0 | } |
472 | | |
473 | | // Check noise characteristics and return if error |
474 | 0 | if channel == 0 |
475 | 0 | && self.combined_state[channel].strength_solver.num_equations > 0 |
476 | 0 | && self.is_different() |
477 | 0 | { |
478 | 0 | y_model_different = true; |
479 | 0 | } |
480 | | |
481 | 0 | if y_model_different { |
482 | 0 | continue; |
483 | 0 | } |
484 | | |
485 | 0 | self.combined_state[channel].num_observations += |
486 | 0 | self.latest_state[channel].num_observations; |
487 | 0 | self.combined_state[channel].eqns += &self.latest_state[channel].eqns; |
488 | 0 | if !self.combined_state[channel].ar_equation_system_solve(is_chroma) { |
489 | 0 | if is_chroma { |
490 | 0 | self.combined_state[channel] |
491 | 0 | .eqns |
492 | 0 | .set_chroma_coefficient_fallback_solution(); |
493 | 0 | } else { |
494 | 0 | return NoiseStatus::Error(anyhow!( |
495 | 0 | "Solving combined noise equation system failed on plane {}", |
496 | 0 | channel |
497 | 0 | )); |
498 | | } |
499 | 0 | } |
500 | | |
501 | 0 | self.combined_state[channel].strength_solver += |
502 | 0 | &self.latest_state[channel].strength_solver; |
503 | | |
504 | 0 | if !self.combined_state[channel].strength_solver.solve() { |
505 | 0 | return NoiseStatus::Error(anyhow!( |
506 | 0 | "Failed to solve strength solver for combined state" |
507 | 0 | )); |
508 | 0 | }; |
509 | | } |
510 | | |
511 | 0 | if y_model_different { |
512 | 0 | return NoiseStatus::DifferentType; |
513 | 0 | } |
514 | | |
515 | 0 | NoiseStatus::Ok |
516 | 0 | } |
517 | | |
518 | | #[allow(clippy::too_many_lines)] |
519 | | #[must_use] |
520 | 0 | pub fn get_grain_parameters(&self, start_ts: u64, end_ts: u64) -> GrainTableSegment { |
521 | | // Both the domain and the range of the scaling functions in the film_grain |
522 | | // are normalized to 8-bit (e.g., they are implicitly scaled during grain |
523 | | // synthesis). |
524 | 0 | let scaling_points_y = self.combined_state[0] |
525 | 0 | .strength_solver |
526 | 0 | .fit_piecewise(NUM_Y_POINTS) |
527 | 0 | .points; |
528 | 0 | let scaling_points_cb = self.combined_state[1] |
529 | 0 | .strength_solver |
530 | 0 | .fit_piecewise(NUM_UV_POINTS) |
531 | 0 | .points; |
532 | 0 | let scaling_points_cr = self.combined_state[2] |
533 | 0 | .strength_solver |
534 | 0 | .fit_piecewise(NUM_UV_POINTS) |
535 | 0 | .points; |
536 | | |
537 | 0 | let mut max_scaling_value: f64 = 1.0e-4f64; |
538 | 0 | for p in scaling_points_y |
539 | 0 | .iter() |
540 | 0 | .chain(scaling_points_cb.iter()) |
541 | 0 | .chain(scaling_points_cr.iter()) |
542 | 0 | .map(|p| p[1]) |
543 | | { |
544 | 0 | if p > max_scaling_value { |
545 | 0 | max_scaling_value = p; |
546 | 0 | } |
547 | | } |
548 | | |
549 | | // Scaling_shift values are in the range [8,11] |
550 | 0 | let max_scaling_value_log2 = |
551 | 0 | clamp((max_scaling_value.log2() + 1f64).floor() as u8, 2u8, 5u8); |
552 | 0 | let scale_factor = f64::from(1u32 << (8u8 - max_scaling_value_log2)); |
553 | 0 | let map_scaling_point = |p: [f64; 2]| { |
554 | 0 | [ |
555 | 0 | (p[0] + 0.5f64) as u8, |
556 | 0 | clamp(scale_factor.mul_add(p[1], 0.5f64) as i32, 0i32, 255i32) as u8, |
557 | 0 | ] |
558 | 0 | }; |
559 | | |
560 | 0 | let scaling_points_y: ArrayVec<_, NUM_Y_POINTS> = scaling_points_y |
561 | 0 | .into_iter() |
562 | 0 | .map(map_scaling_point) |
563 | 0 | .collect(); |
564 | 0 | let scaling_points_cb: ArrayVec<_, NUM_UV_POINTS> = scaling_points_cb |
565 | 0 | .into_iter() |
566 | 0 | .map(map_scaling_point) |
567 | 0 | .collect(); |
568 | 0 | let scaling_points_cr: ArrayVec<_, NUM_UV_POINTS> = scaling_points_cr |
569 | 0 | .into_iter() |
570 | 0 | .map(map_scaling_point) |
571 | 0 | .collect(); |
572 | | |
573 | | // Convert the ar_coeffs into 8-bit values |
574 | 0 | let n_coeff = self.combined_state[0].eqns.n; |
575 | 0 | let mut max_coeff = 1.0e-4f64; |
576 | 0 | let mut min_coeff = 1.0e-4f64; |
577 | 0 | let mut y_corr = [0f64; 2]; |
578 | 0 | let mut avg_luma_strength = 0f64; |
579 | 0 | for c in 0..3 { |
580 | 0 | let eqns = &self.combined_state[c].eqns; |
581 | 0 | for i in 0..n_coeff { |
582 | 0 | if eqns.x[i] > max_coeff { |
583 | 0 | max_coeff = eqns.x[i]; |
584 | 0 | } |
585 | 0 | if eqns.x[i] < min_coeff { |
586 | 0 | min_coeff = eqns.x[i]; |
587 | 0 | } |
588 | | } |
589 | | |
590 | | // Since the correlation between luma/chroma was computed in an already |
591 | | // scaled space, we adjust it in the un-scaled space. |
592 | 0 | let solver = &self.combined_state[c].strength_solver; |
593 | | // Compute a weighted average of the strength for the channel. |
594 | 0 | let mut average_strength = 0f64; |
595 | 0 | let mut total_weight = 0f64; |
596 | 0 | for i in 0..solver.eqns.n { |
597 | 0 | let mut w = 0f64; |
598 | 0 | for j in 0..solver.eqns.n { |
599 | 0 | w += solver.eqns.a[i * solver.eqns.n + j]; |
600 | 0 | } |
601 | 0 | w = w.sqrt(); |
602 | 0 | average_strength += solver.eqns.x[i] * w; |
603 | 0 | total_weight += w; |
604 | | } |
605 | 0 | if total_weight.abs() < f64::EPSILON { |
606 | 0 | average_strength = 1f64; |
607 | 0 | } else { |
608 | 0 | average_strength /= total_weight; |
609 | 0 | } |
610 | 0 | if c == 0 { |
611 | 0 | avg_luma_strength = average_strength; |
612 | 0 | } else { |
613 | 0 | y_corr[c - 1] = avg_luma_strength * eqns.x[n_coeff] / average_strength; |
614 | 0 | max_coeff = max_coeff.max(y_corr[c - 1]); |
615 | 0 | min_coeff = min_coeff.min(y_corr[c - 1]); |
616 | 0 | } |
617 | | } |
618 | | |
619 | | // Shift value: AR coeffs range (values 6-9) |
620 | | // 6: [-2, 2), 7: [-1, 1), 8: [-0.5, 0.5), 9: [-0.25, 0.25) |
621 | 0 | let ar_coeff_shift = clamp( |
622 | 0 | 7i32 - (1.0f64 + max_coeff.log2().floor()).max((-min_coeff).log2().ceil()) as i32, |
623 | 0 | 6i32, |
624 | 0 | 9i32, |
625 | 0 | ) as u8; |
626 | 0 | let scale_ar_coeff = f64::from(1u16 << ar_coeff_shift); |
627 | 0 | let ar_coeffs_y = self.get_ar_coeffs_y(n_coeff, scale_ar_coeff); |
628 | 0 | let ar_coeffs_cb = self.get_ar_coeffs_uv(1, n_coeff, scale_ar_coeff, y_corr); |
629 | 0 | let ar_coeffs_cr = self.get_ar_coeffs_uv(2, n_coeff, scale_ar_coeff, y_corr); |
630 | | |
631 | | GrainTableSegment { |
632 | 0 | random_seed: if start_ts == 0 { DEFAULT_GRAIN_SEED } else { 0 }, |
633 | 0 | start_time: start_ts, |
634 | 0 | end_time: end_ts, |
635 | 0 | ar_coeff_lag: NOISE_MODEL_LAG as u8, |
636 | 0 | scaling_points_y, |
637 | 0 | scaling_points_cb, |
638 | 0 | scaling_points_cr, |
639 | 0 | scaling_shift: 5 + (8 - max_scaling_value_log2), |
640 | 0 | ar_coeff_shift, |
641 | 0 | ar_coeffs_y, |
642 | 0 | ar_coeffs_cb, |
643 | 0 | ar_coeffs_cr, |
644 | | // At the moment, the noise modeling code assumes that the chroma scaling |
645 | | // functions are a function of luma. |
646 | | cb_mult: 128, |
647 | | cb_luma_mult: 192, |
648 | | cb_offset: 256, |
649 | | cr_mult: 128, |
650 | | cr_luma_mult: 192, |
651 | | cr_offset: 256, |
652 | | chroma_scaling_from_luma: false, |
653 | | grain_scale_shift: 0, |
654 | | overlap_flag: true, |
655 | | } |
656 | 0 | } |
657 | | |
658 | 0 | pub fn save_latest(&mut self) { |
659 | 0 | for c in 0..3 { |
660 | 0 | let latest_state = &self.latest_state[c]; |
661 | 0 | let combined_state = &mut self.combined_state[c]; |
662 | 0 | combined_state.eqns.copy_from(&latest_state.eqns); |
663 | 0 | combined_state |
664 | 0 | .strength_solver |
665 | 0 | .eqns |
666 | 0 | .copy_from(&latest_state.strength_solver.eqns); |
667 | 0 | combined_state.strength_solver.num_equations = |
668 | 0 | latest_state.strength_solver.num_equations; |
669 | 0 | combined_state.num_observations = latest_state.num_observations; |
670 | 0 | combined_state.ar_gain = latest_state.ar_gain; |
671 | 0 | } |
672 | 0 | } |
673 | | |
674 | | #[must_use] |
675 | 0 | const fn num_coeffs() -> usize { |
676 | 0 | let n = 2 * NOISE_MODEL_LAG + 1; |
677 | 0 | (n * n) / 2 |
678 | 0 | } |
679 | | |
680 | | #[must_use] |
681 | 0 | fn get_ar_coeffs_y(&self, n_coeff: usize, scale_ar_coeff: f64) -> ArrayVec<i8, NUM_Y_COEFFS> { |
682 | 0 | assert!(n_coeff <= NUM_Y_COEFFS); |
683 | 0 | let mut coeffs = ArrayVec::new(); |
684 | 0 | let eqns = &self.combined_state[0].eqns; |
685 | 0 | for i in 0..n_coeff { |
686 | 0 | coeffs.push(clamp((scale_ar_coeff * eqns.x[i]).round() as i32, -128i32, 127i32) as i8); |
687 | 0 | } |
688 | 0 | coeffs |
689 | 0 | } |
690 | | |
691 | | #[must_use] |
692 | 0 | fn get_ar_coeffs_uv( |
693 | 0 | &self, |
694 | 0 | channel: usize, |
695 | 0 | n_coeff: usize, |
696 | 0 | scale_ar_coeff: f64, |
697 | 0 | y_corr: [f64; 2], |
698 | 0 | ) -> ArrayVec<i8, NUM_UV_COEFFS> { |
699 | 0 | assert!(n_coeff <= NUM_Y_COEFFS); |
700 | 0 | let mut coeffs = ArrayVec::new(); |
701 | 0 | let eqns = &self.combined_state[channel].eqns; |
702 | 0 | for i in 0..n_coeff { |
703 | 0 | coeffs.push(clamp((scale_ar_coeff * eqns.x[i]).round() as i32, -128i32, 127i32) as i8); |
704 | 0 | } |
705 | 0 | coeffs.push(clamp( |
706 | 0 | (scale_ar_coeff * y_corr[channel - 1]).round() as i32, |
707 | 0 | -128i32, |
708 | 0 | 127i32, |
709 | 0 | ) as i8); |
710 | 0 | coeffs |
711 | 0 | } |
712 | | |
713 | | // Return true if the noise estimate appears to be different from the combined |
714 | | // (multi-frame) estimate. The difference is measured by checking whether the |
715 | | // AR coefficients have diverged (using a threshold on normalized cross |
716 | | // correlation), or whether the noise strength has changed. |
717 | | #[must_use] |
718 | 0 | fn is_different(&self) -> bool { |
719 | | const COEFF_THRESHOLD: f64 = 0.9f64; |
720 | | const STRENGTH_THRESHOLD: f64 = 0.005f64; |
721 | | |
722 | 0 | let latest = &self.latest_state[0]; |
723 | 0 | let combined = &self.combined_state[0]; |
724 | 0 | let corr = normalized_cross_correlation(&latest.eqns.x, &combined.eqns.x, combined.eqns.n); |
725 | 0 | if corr < COEFF_THRESHOLD { |
726 | 0 | return true; |
727 | 0 | } |
728 | | |
729 | 0 | let dx = 1.0f64 / latest.strength_solver.num_bins as f64; |
730 | 0 | let latest_eqns = &latest.strength_solver.eqns; |
731 | 0 | let combined_eqns = &combined.strength_solver.eqns; |
732 | 0 | let mut diff = 0.0f64; |
733 | 0 | let mut total_weight = 0.0f64; |
734 | 0 | for j in 0..latest_eqns.n { |
735 | 0 | let mut weight = 0.0f64; |
736 | 0 | for i in 0..latest_eqns.n { |
737 | 0 | weight += latest_eqns.a[i * latest_eqns.n + j]; |
738 | 0 | } |
739 | 0 | weight = weight.sqrt(); |
740 | 0 | diff += weight * (latest_eqns.x[j] - combined_eqns.x[j]).abs(); |
741 | 0 | total_weight += weight; |
742 | | } |
743 | | |
744 | 0 | diff * dx / total_weight > STRENGTH_THRESHOLD |
745 | 0 | } |
746 | | |
747 | | #[allow(clippy::too_many_arguments)] |
748 | 0 | fn add_block_observations( |
749 | 0 | &mut self, |
750 | 0 | channel: usize, |
751 | 0 | source: &Plane<u8>, |
752 | 0 | denoised: &Plane<u8>, |
753 | 0 | alt_source: Option<&Plane<u8>>, |
754 | 0 | alt_denoised: Option<&Plane<u8>>, |
755 | 0 | frame_dims: (usize, usize), |
756 | 0 | flat_blocks: &[u8], |
757 | 0 | num_blocks_w: usize, |
758 | 0 | num_blocks_h: usize, |
759 | 0 | ) { |
760 | 0 | let num_coords = self.n; |
761 | 0 | let state = &mut self.latest_state[channel]; |
762 | 0 | let a = &mut state.eqns.a; |
763 | 0 | let b = &mut state.eqns.b; |
764 | 0 | let mut buffer = vec![0f64; num_coords + 1].into_boxed_slice(); |
765 | 0 | let n = state.eqns.n; |
766 | 0 | let block_w = BLOCK_SIZE >> source.cfg.xdec; |
767 | 0 | let block_h = BLOCK_SIZE >> source.cfg.ydec; |
768 | | |
769 | 0 | let dec = (source.cfg.xdec, source.cfg.ydec); |
770 | 0 | let stride = source.cfg.stride; |
771 | 0 | let source_origin = source.data_origin(); |
772 | 0 | let denoised_origin = denoised.data_origin(); |
773 | 0 | let alt_stride = alt_source.map_or(0, |s| s.cfg.stride); |
774 | 0 | let alt_source_origin = alt_source.map(|s| s.data_origin()); |
775 | 0 | let alt_denoised_origin = alt_denoised.map(|s| s.data_origin()); |
776 | | |
777 | 0 | for by in 0..num_blocks_h { |
778 | 0 | let y_o = by * block_h; |
779 | 0 | for bx in 0..num_blocks_w { |
780 | | // SAFETY: We know the indexes we provide do not overflow the data bounds |
781 | | unsafe { |
782 | 0 | let flat_block_ptr = flat_blocks.as_ptr().add(by * num_blocks_w + bx); |
783 | 0 | let x_o = bx * block_w; |
784 | 0 | if *flat_block_ptr == 0 { |
785 | 0 | continue; |
786 | 0 | } |
787 | 0 | let y_start = if by > 0 && *flat_block_ptr.sub(num_blocks_w) > 0 { |
788 | 0 | 0 |
789 | | } else { |
790 | 0 | NOISE_MODEL_LAG |
791 | | }; |
792 | 0 | let x_start = if bx > 0 && *flat_block_ptr.sub(1) > 0 { |
793 | 0 | 0 |
794 | | } else { |
795 | 0 | NOISE_MODEL_LAG |
796 | | }; |
797 | 0 | let y_end = ((frame_dims.1 >> dec.1) - by * block_h).min(block_h); |
798 | 0 | let x_end = ((frame_dims.0 >> dec.0) - bx * block_w - NOISE_MODEL_LAG).min( |
799 | 0 | if bx + 1 < num_blocks_w && *flat_block_ptr.add(1) > 0 { |
800 | 0 | block_w |
801 | | } else { |
802 | 0 | block_w - NOISE_MODEL_LAG |
803 | | }, |
804 | | ); |
805 | 0 | for y in y_start..y_end { |
806 | 0 | for x in x_start..x_end { |
807 | 0 | let val = extract_ar_row( |
808 | 0 | &self.coords, |
809 | 0 | num_coords, |
810 | 0 | source_origin, |
811 | 0 | denoised_origin, |
812 | 0 | stride, |
813 | 0 | dec, |
814 | 0 | alt_source_origin, |
815 | 0 | alt_denoised_origin, |
816 | 0 | alt_stride, |
817 | 0 | x + x_o, |
818 | 0 | y + y_o, |
819 | 0 | &mut buffer, |
820 | | ); |
821 | 0 | for i in 0..n { |
822 | 0 | for j in 0..n { |
823 | 0 | *a.get_unchecked_mut(i * n + j) += (*buffer.get_unchecked(i) |
824 | 0 | * *buffer.get_unchecked(j)) |
825 | 0 | / BLOCK_NORMALIZATION.powi(2); |
826 | 0 | } |
827 | 0 | *b.get_unchecked_mut(i) += |
828 | 0 | (*buffer.get_unchecked(i) * val) / BLOCK_NORMALIZATION.powi(2); |
829 | | } |
830 | 0 | state.num_observations += 1; |
831 | | } |
832 | | } |
833 | | } |
834 | | } |
835 | | } |
836 | 0 | } |
837 | | |
838 | | #[allow(clippy::too_many_arguments)] |
839 | 0 | fn add_noise_std_observations( |
840 | 0 | &mut self, |
841 | 0 | channel: usize, |
842 | 0 | source: &Plane<u8>, |
843 | 0 | denoised: &Plane<u8>, |
844 | 0 | alt_source: Option<&Plane<u8>>, |
845 | 0 | frame_dims: (usize, usize), |
846 | 0 | flat_blocks: &[u8], |
847 | 0 | num_blocks_w: usize, |
848 | 0 | num_blocks_h: usize, |
849 | 0 | ) { |
850 | 0 | let coeffs = &self.latest_state[channel].eqns.x; |
851 | 0 | let num_coords = self.n; |
852 | 0 | let luma_gain = self.latest_state[0].ar_gain; |
853 | 0 | let noise_gain = self.latest_state[channel].ar_gain; |
854 | 0 | let block_w = BLOCK_SIZE >> source.cfg.xdec; |
855 | 0 | let block_h = BLOCK_SIZE >> source.cfg.ydec; |
856 | | |
857 | 0 | for by in 0..num_blocks_h { |
858 | 0 | let y_o = by * block_h; |
859 | 0 | for bx in 0..num_blocks_w { |
860 | 0 | let x_o = bx * block_w; |
861 | 0 | if flat_blocks[by * num_blocks_w + bx] == 0 { |
862 | 0 | continue; |
863 | 0 | } |
864 | 0 | let num_samples_h = ((frame_dims.1 >> source.cfg.ydec) - by * block_h).min(block_h); |
865 | 0 | let num_samples_w = ((frame_dims.0 >> source.cfg.xdec) - bx * block_w).min(block_w); |
866 | | // Make sure that we have a reasonable amount of samples to consider the |
867 | | // block |
868 | 0 | if num_samples_w * num_samples_h > BLOCK_SIZE { |
869 | 0 | let block_mean = get_block_mean( |
870 | 0 | alt_source.unwrap_or(source), |
871 | 0 | frame_dims, |
872 | 0 | x_o << source.cfg.xdec, |
873 | 0 | y_o << source.cfg.ydec, |
874 | | ); |
875 | 0 | let noise_var = get_noise_var( |
876 | 0 | source, |
877 | 0 | denoised, |
878 | 0 | ( |
879 | 0 | frame_dims.0 >> source.cfg.xdec, |
880 | 0 | frame_dims.1 >> source.cfg.ydec, |
881 | 0 | ), |
882 | 0 | x_o, |
883 | 0 | y_o, |
884 | 0 | block_w, |
885 | 0 | block_h, |
886 | | ); |
887 | | // We want to remove the part of the noise that came from being |
888 | | // correlated with luma. Note that the noise solver for luma must |
889 | | // have already been run. |
890 | 0 | let luma_strength = if channel > 0 { |
891 | 0 | luma_gain * self.latest_state[0].strength_solver.get_value(block_mean) |
892 | | } else { |
893 | 0 | 0f64 |
894 | | }; |
895 | 0 | let corr = if channel > 0 { |
896 | 0 | coeffs[num_coords] |
897 | | } else { |
898 | 0 | 0f64 |
899 | | }; |
900 | | // Chroma noise: |
901 | | // N(0, noise_var) = N(0, uncorr_var) + corr * N(0, luma_strength^2) |
902 | | // The uncorrelated component: |
903 | | // uncorr_var = noise_var - (corr * luma_strength)^2 |
904 | | // But don't allow fully correlated noise (hence the max), since the |
905 | | // synthesis cannot model it. |
906 | 0 | let uncorr_std = (noise_var / 16f64) |
907 | 0 | .max((corr * luma_strength).mul_add(-(corr * luma_strength), noise_var)) |
908 | 0 | .sqrt(); |
909 | 0 | let adjusted_strength = uncorr_std / noise_gain; |
910 | 0 | self.latest_state[channel] |
911 | 0 | .strength_solver |
912 | 0 | .add_measurement(block_mean, adjusted_strength); |
913 | 0 | } |
914 | | } |
915 | | } |
916 | 0 | } |
917 | | } |
918 | | |
919 | | #[derive(Debug, Clone)] |
920 | | struct NoiseModelState { |
921 | | eqns: EquationSystem, |
922 | | ar_gain: f64, |
923 | | num_observations: usize, |
924 | | strength_solver: StrengthSolver, |
925 | | } |
926 | | |
927 | | impl NoiseModelState { |
928 | | #[must_use] |
929 | 0 | pub fn new(n: usize) -> Self { |
930 | | const NUM_BINS: usize = 20; |
931 | | |
932 | 0 | Self { |
933 | 0 | eqns: EquationSystem::new(n), |
934 | 0 | ar_gain: 1.0f64, |
935 | 0 | num_observations: 0usize, |
936 | 0 | strength_solver: StrengthSolver::new(NUM_BINS), |
937 | 0 | } |
938 | 0 | } |
939 | | |
940 | 0 | pub fn ar_equation_system_solve(&mut self, is_chroma: bool) -> bool { |
941 | 0 | let ret = self.eqns.solve(); |
942 | 0 | self.ar_gain = 1.0f64; |
943 | 0 | if !ret { |
944 | 0 | return ret; |
945 | 0 | } |
946 | | |
947 | | // Update the AR gain from the equation system as it will be used to fit |
948 | | // the noise strength as a function of intensity. In the Yule-Walker |
949 | | // equations, the diagonal should be the variance of the correlated noise. |
950 | | // In the case of the least squares estimate, there will be some variability |
951 | | // in the diagonal. So use the mean of the diagonal as the estimate of |
952 | | // overall variance (this works for least squares or Yule-Walker formulation). |
953 | 0 | let mut var = 0f64; |
954 | 0 | let n_adjusted = self.eqns.n - usize::from(is_chroma); |
955 | 0 | for i in 0..n_adjusted { |
956 | 0 | var += self.eqns.a[i * self.eqns.n + i] / self.num_observations as f64; |
957 | 0 | } |
958 | 0 | var /= n_adjusted as f64; |
959 | | |
960 | | // Keep track of E(Y^2) = <b, x> + E(X^2) |
961 | | // In the case that we are using chroma and have an estimate of correlation |
962 | | // with luma we adjust that estimate slightly to remove the correlated bits by |
963 | | // subtracting out the last column of a scaled by our correlation estimate |
964 | | // from b. E(y^2) = <b - A(:, end)*x(end), x> |
965 | 0 | let mut sum_covar = 0f64; |
966 | 0 | for i in 0..n_adjusted { |
967 | 0 | let mut bi = self.eqns.b[i]; |
968 | 0 | if is_chroma { |
969 | 0 | bi -= self.eqns.a[i * self.eqns.n + n_adjusted] * self.eqns.x[n_adjusted]; |
970 | 0 | } |
971 | 0 | sum_covar += (bi * self.eqns.x[i]) / self.num_observations as f64; |
972 | | } |
973 | | |
974 | | // Now, get an estimate of the variance of uncorrelated noise signal and use |
975 | | // it to determine the gain of the AR filter. |
976 | 0 | let noise_var = (var - sum_covar).max(1e-6f64); |
977 | 0 | self.ar_gain = 1f64.max((var / noise_var).max(1e-6f64).sqrt()); |
978 | 0 | ret |
979 | 0 | } |
980 | | } |
981 | | |
982 | | #[derive(Debug, Clone)] |
983 | | struct StrengthSolver { |
984 | | eqns: EquationSystem, |
985 | | num_bins: usize, |
986 | | num_equations: usize, |
987 | | total: f64, |
988 | | } |
989 | | |
990 | | impl StrengthSolver { |
991 | | #[must_use] |
992 | 0 | pub fn new(num_bins: usize) -> Self { |
993 | 0 | Self { |
994 | 0 | eqns: EquationSystem::new(num_bins), |
995 | 0 | num_bins, |
996 | 0 | num_equations: 0usize, |
997 | 0 | total: 0f64, |
998 | 0 | } |
999 | 0 | } |
1000 | | |
1001 | 0 | pub fn add_measurement(&mut self, block_mean: f64, noise_std: f64) { |
1002 | 0 | let bin = self.get_bin_index(block_mean); |
1003 | 0 | let bin_i0 = bin.floor() as usize; |
1004 | 0 | let bin_i1 = (self.num_bins - 1).min(bin_i0 + 1); |
1005 | 0 | let a = bin - bin_i0 as f64; |
1006 | 0 | let n = self.num_bins; |
1007 | 0 | let eqns = &mut self.eqns; |
1008 | 0 | eqns.a[bin_i0 * n + bin_i0] += (1f64 - a).powi(2); |
1009 | 0 | eqns.a[bin_i1 * n + bin_i0] += a * (1f64 - a); |
1010 | 0 | eqns.a[bin_i1 * n + bin_i1] += a.powi(2); |
1011 | 0 | eqns.a[bin_i0 * n + bin_i1] += (1f64 - a) * a; |
1012 | 0 | eqns.b[bin_i0] += (1f64 - a) * noise_std; |
1013 | 0 | eqns.b[bin_i1] += a * noise_std; |
1014 | 0 | self.total += noise_std; |
1015 | 0 | self.num_equations += 1; |
1016 | 0 | } |
1017 | | |
1018 | 0 | pub fn solve(&mut self) -> bool { |
1019 | | // Add regularization proportional to the number of constraints |
1020 | 0 | let n = self.num_bins; |
1021 | 0 | let alpha = 2f64 * self.num_equations as f64 / n as f64; |
1022 | | |
1023 | | // Do this in a non-destructive manner so it is not confusing to the caller |
1024 | 0 | let old_a = self.eqns.a.clone(); |
1025 | 0 | for i in 0..n { |
1026 | 0 | let i_lo = i.saturating_sub(1); |
1027 | 0 | let i_hi = (n - 1).min(i + 1); |
1028 | 0 | self.eqns.a[i * n + i_lo] -= alpha; |
1029 | 0 | self.eqns.a[i * n + i] += 2f64 * alpha; |
1030 | 0 | self.eqns.a[i * n + i_hi] -= alpha; |
1031 | 0 | } |
1032 | | |
1033 | | // Small regularization to give average noise strength |
1034 | 0 | let mean = self.total / self.num_equations as f64; |
1035 | 0 | for i in 0..n { |
1036 | 0 | self.eqns.a[i * n + i] += 1f64 / 8192f64; |
1037 | 0 | self.eqns.b[i] += mean / 8192f64; |
1038 | 0 | } |
1039 | 0 | let result = self.eqns.solve(); |
1040 | 0 | self.eqns.a = old_a; |
1041 | 0 | result |
1042 | 0 | } |
1043 | | |
1044 | | #[must_use] |
1045 | 0 | pub fn fit_piecewise(&self, max_output_points: usize) -> NoiseStrengthLut { |
1046 | | const TOLERANCE: f64 = 0.00625f64; |
1047 | | |
1048 | 0 | let mut lut = NoiseStrengthLut::new(self.num_bins); |
1049 | 0 | for i in 0..self.num_bins { |
1050 | 0 | lut.points[i][0] = self.get_center(i); |
1051 | 0 | lut.points[i][1] = self.eqns.x[i]; |
1052 | 0 | } |
1053 | | |
1054 | 0 | let mut residual = vec![0.0f64; self.num_bins]; |
1055 | 0 | self.update_piecewise_linear_residual(&lut, &mut residual, 0, self.num_bins); |
1056 | | |
1057 | | // Greedily remove points if there are too many or if it doesn't hurt local |
1058 | | // approximation (never remove the end points) |
1059 | 0 | while lut.points.len() > 2 { |
1060 | 0 | let mut min_index = 1usize; |
1061 | 0 | for j in 1..(lut.points.len() - 1) { |
1062 | 0 | if residual[j] < residual[min_index] { |
1063 | 0 | min_index = j; |
1064 | 0 | } |
1065 | | } |
1066 | 0 | let dx = lut.points[min_index + 1][0] - lut.points[min_index - 1][0]; |
1067 | 0 | let avg_residual = residual[min_index] / dx; |
1068 | 0 | if lut.points.len() <= max_output_points && avg_residual > TOLERANCE { |
1069 | 0 | break; |
1070 | 0 | } |
1071 | | |
1072 | 0 | lut.points.remove(min_index); |
1073 | 0 | self.update_piecewise_linear_residual( |
1074 | 0 | &lut, |
1075 | 0 | &mut residual, |
1076 | 0 | min_index - 1, |
1077 | 0 | min_index + 1, |
1078 | | ); |
1079 | | } |
1080 | | |
1081 | 0 | lut |
1082 | 0 | } |
1083 | | |
1084 | | #[must_use] |
1085 | 0 | pub fn get_value(&self, x: f64) -> f64 { |
1086 | 0 | let bin = self.get_bin_index(x); |
1087 | 0 | let bin_i0 = bin.floor() as usize; |
1088 | 0 | let bin_i1 = (self.num_bins - 1).min(bin_i0 + 1); |
1089 | 0 | let a = bin - bin_i0 as f64; |
1090 | 0 | (1f64 - a).mul_add(self.eqns.x[bin_i0], a * self.eqns.x[bin_i1]) |
1091 | 0 | } |
1092 | | |
1093 | 0 | pub fn clear(&mut self) { |
1094 | 0 | self.eqns.clear(); |
1095 | 0 | self.num_equations = 0; |
1096 | 0 | self.total = 0f64; |
1097 | 0 | } |
1098 | | |
1099 | | #[must_use] |
1100 | 0 | fn get_bin_index(&self, value: f64) -> f64 { |
1101 | 0 | let max = 255f64; |
1102 | 0 | let val = clamp(value, 0f64, max); |
1103 | 0 | (self.num_bins - 1) as f64 * val / max |
1104 | 0 | } |
1105 | | |
1106 | 0 | fn update_piecewise_linear_residual( |
1107 | 0 | &self, |
1108 | 0 | lut: &NoiseStrengthLut, |
1109 | 0 | residual: &mut [f64], |
1110 | 0 | start: usize, |
1111 | 0 | end: usize, |
1112 | 0 | ) { |
1113 | 0 | let dx = 255f64 / self.num_bins as f64; |
1114 | | #[allow(clippy::needless_range_loop)] |
1115 | 0 | for i in start.max(1)..end.min(lut.points.len() - 1) { |
1116 | 0 | let lower = 0usize.max(self.get_bin_index(lut.points[i - 1][0]).floor() as usize); |
1117 | 0 | let upper = |
1118 | 0 | (self.num_bins - 1).min(self.get_bin_index(lut.points[i + 1][0]).ceil() as usize); |
1119 | 0 | let mut r = 0f64; |
1120 | 0 | for j in lower..=upper { |
1121 | 0 | let x = self.get_center(j); |
1122 | 0 | if x < lut.points[i - 1][0] || x >= lut.points[i + 1][0] { |
1123 | 0 | continue; |
1124 | 0 | } |
1125 | | |
1126 | 0 | let y = self.eqns.x[j]; |
1127 | 0 | let a = (x - lut.points[i - 1][0]) / (lut.points[i + 1][0] - lut.points[i - 1][0]); |
1128 | 0 | let estimate_y = lut.points[i - 1][1].mul_add(1f64 - a, lut.points[i + 1][1] * a); |
1129 | 0 | r += (y - estimate_y).abs(); |
1130 | | } |
1131 | 0 | residual[i] = r * dx; |
1132 | | } |
1133 | 0 | } |
1134 | | |
1135 | | #[must_use] |
1136 | 0 | fn get_center(&self, i: usize) -> f64 { |
1137 | 0 | let range = 255f64; |
1138 | 0 | let n = self.num_bins; |
1139 | 0 | i as f64 / (n - 1) as f64 * range |
1140 | 0 | } |
1141 | | } |
1142 | | |
1143 | | impl Add<&StrengthSolver> for StrengthSolver { |
1144 | | type Output = StrengthSolver; |
1145 | | |
1146 | 0 | fn add(self, addend: &StrengthSolver) -> Self::Output { |
1147 | 0 | let mut dest = self; |
1148 | 0 | dest.eqns += &addend.eqns; |
1149 | 0 | dest.num_equations += addend.num_equations; |
1150 | 0 | dest.total += addend.total; |
1151 | 0 | dest |
1152 | 0 | } |
1153 | | } |
1154 | | |
1155 | | impl AddAssign<&StrengthSolver> for StrengthSolver { |
1156 | 0 | fn add_assign(&mut self, rhs: &StrengthSolver) { |
1157 | 0 | *self = self.clone() + rhs; |
1158 | 0 | } |
1159 | | } |