/src/gdal/frmts/vrt/vrtreclassifier.cpp
Line  | Count  | Source (jump to first uncovered line)  | 
1  |  | /******************************************************************************  | 
2  |  | *  | 
3  |  |  * Project:  Virtual GDAL Datasets  | 
4  |  |  * Purpose:  Implementation of Reclassifier  | 
5  |  |  * Author:   Daniel Baston  | 
6  |  |  *  | 
7  |  |  ******************************************************************************  | 
8  |  |  * Copyright (c) 2025, ISciences LLC  | 
9  |  |  *  | 
10  |  |  * SPDX-License-Identifier: MIT  | 
11  |  |  ****************************************************************************/  | 
12  |  |  | 
13  |  | #include "cpl_conv.h"  | 
14  |  | #include "vrtreclassifier.h"  | 
15  |  |  | 
16  |  | #include <algorithm>  | 
17  |  | #include <cmath>  | 
18  |  | #include <limits>  | 
19  |  |  | 
20  |  | namespace gdal  | 
21  |  | { | 
22  |  |  | 
23  |  | bool Reclassifier::Interval::Overlaps(const Interval &other) const  | 
24  | 0  | { | 
25  | 0  |     if (dfMin > other.dfMax || dfMax < other.dfMin)  | 
26  | 0  |     { | 
27  | 0  |         return false;  | 
28  | 0  |     }  | 
29  |  |  | 
30  | 0  |     return true;  | 
31  | 0  | }  | 
32  |  |  | 
33  |  | CPLErr Reclassifier::Interval::Parse(const char *s, char **rest)  | 
34  | 0  | { | 
35  | 0  |     const char *start = s;  | 
36  | 0  |     bool bMinIncluded;  | 
37  | 0  |     bool bMaxIncluded;  | 
38  |  | 
  | 
39  | 0  |     while (isspace(*start))  | 
40  | 0  |     { | 
41  | 0  |         start++;  | 
42  | 0  |     }  | 
43  |  | 
  | 
44  | 0  |     char *end;  | 
45  |  | 
  | 
46  | 0  |     if (*start == '(') | 
47  | 0  |     { | 
48  | 0  |         bMinIncluded = false;  | 
49  | 0  |     }  | 
50  | 0  |     else if (*start == '[')  | 
51  | 0  |     { | 
52  | 0  |         bMinIncluded = true;  | 
53  | 0  |     }  | 
54  | 0  |     else  | 
55  | 0  |     { | 
56  | 0  |         double dfVal = CPLStrtod(start, &end);  | 
57  |  | 
  | 
58  | 0  |         if (end == start)  | 
59  | 0  |         { | 
60  | 0  |             CPLError(CE_Failure, CPLE_AppDefined,  | 
61  | 0  |                      "Interval must start with '(' or ']'"); | 
62  | 0  |             return CE_Failure;  | 
63  | 0  |         }  | 
64  |  |  | 
65  | 0  |         SetToConstant(dfVal);  | 
66  |  | 
  | 
67  | 0  |         if (rest != nullptr)  | 
68  | 0  |         { | 
69  | 0  |             *rest = end;  | 
70  | 0  |         }  | 
71  |  | 
  | 
72  | 0  |         return CE_None;  | 
73  | 0  |     }  | 
74  | 0  |     start++;  | 
75  |  | 
  | 
76  | 0  |     while (isspace(*start))  | 
77  | 0  |     { | 
78  | 0  |         start++;  | 
79  | 0  |     }  | 
80  |  | 
  | 
81  | 0  |     if (STARTS_WITH_CI(start, "-inf"))  | 
82  | 0  |     { | 
83  | 0  |         dfMin = -std::numeric_limits<double>::infinity();  | 
84  | 0  |         end = const_cast<char *>(start + 4);  | 
85  | 0  |     }  | 
86  | 0  |     else  | 
87  | 0  |     { | 
88  | 0  |         dfMin = CPLStrtod(start, &end);  | 
89  | 0  |     }  | 
90  |  | 
  | 
91  | 0  |     if (end == start || *end != ',')  | 
92  | 0  |     { | 
93  | 0  |         CPLError(CE_Failure, CPLE_AppDefined, "Expected a number");  | 
94  | 0  |         return CE_Failure;  | 
95  | 0  |     }  | 
96  | 0  |     start = end + 1;  | 
97  |  | 
  | 
98  | 0  |     while (isspace(*start))  | 
99  | 0  |     { | 
100  | 0  |         start++;  | 
101  | 0  |     }  | 
102  |  | 
  | 
103  | 0  |     if (STARTS_WITH_CI(start, "inf"))  | 
104  | 0  |     { | 
105  | 0  |         dfMax = std::numeric_limits<double>::infinity();  | 
106  | 0  |         end = const_cast<char *>(start + 3);  | 
107  | 0  |     }  | 
108  | 0  |     else  | 
109  | 0  |     { | 
110  | 0  |         dfMax = CPLStrtod(start, &end);  | 
111  | 0  |     }  | 
112  |  | 
  | 
113  | 0  |     if (end == start || (*end != ')' && *end != ']'))  | 
114  | 0  |     { | 
115  | 0  |         CPLError(CE_Failure, CPLE_AppDefined,  | 
116  | 0  |                  "Interval must end with ')' or ']");  | 
117  | 0  |         return CE_Failure;  | 
118  | 0  |     }  | 
119  | 0  |     if (*end == ')')  | 
120  | 0  |     { | 
121  | 0  |         bMaxIncluded = false;  | 
122  | 0  |     }  | 
123  | 0  |     else  | 
124  | 0  |     { | 
125  | 0  |         bMaxIncluded = true;  | 
126  | 0  |     }  | 
127  |  | 
  | 
128  | 0  |     if (rest != nullptr)  | 
129  | 0  |     { | 
130  | 0  |         *rest = end + 1;  | 
131  | 0  |     }  | 
132  |  | 
  | 
133  | 0  |     if (std::isnan(dfMin) || std::isnan(dfMax))  | 
134  | 0  |     { | 
135  | 0  |         CPLError(CE_Failure, CPLE_AppDefined,  | 
136  | 0  |                  "NaN is not a valid value for bounds of interval");  | 
137  | 0  |         return CE_Failure;  | 
138  | 0  |     }  | 
139  |  |  | 
140  | 0  |     if (dfMin > dfMax)  | 
141  | 0  |     { | 
142  | 0  |         CPLError(  | 
143  | 0  |             CE_Failure, CPLE_AppDefined,  | 
144  | 0  |             "Lower bound of interval must be lower or equal to upper bound");  | 
145  | 0  |         return CE_Failure;  | 
146  | 0  |     }  | 
147  |  |  | 
148  | 0  |     if (!bMinIncluded)  | 
149  | 0  |     { | 
150  | 0  |         dfMin = std::nextafter(dfMin, std::numeric_limits<double>::infinity());  | 
151  | 0  |     }  | 
152  | 0  |     if (!bMaxIncluded)  | 
153  | 0  |     { | 
154  | 0  |         dfMax = std::nextafter(dfMax, -std::numeric_limits<double>::infinity());  | 
155  | 0  |     }  | 
156  |  | 
  | 
157  | 0  |     return CE_None;  | 
158  | 0  | }  | 
159  |  |  | 
160  |  | void Reclassifier::Interval::SetToConstant(double dfVal)  | 
161  | 0  | { | 
162  | 0  |     dfMin = dfVal;  | 
163  | 0  |     dfMax = dfVal;  | 
164  | 0  | }  | 
165  |  |  | 
166  |  | CPLErr Reclassifier::Finalize()  | 
167  | 0  | { | 
168  | 0  |     std::sort(m_aoIntervalMappings.begin(), m_aoIntervalMappings.end(),  | 
169  | 0  |               [](const auto &a, const auto &b)  | 
170  | 0  |               { return a.first.dfMin < b.first.dfMin; }); | 
171  |  | 
  | 
172  | 0  |     for (std::size_t i = 1; i < m_aoIntervalMappings.size(); i++)  | 
173  | 0  |     { | 
174  | 0  |         if (m_aoIntervalMappings[i - 1].first.Overlaps(  | 
175  | 0  |                 m_aoIntervalMappings[i].first))  | 
176  | 0  |         { | 
177  |  |             // Don't use [, ) notation because we will have modified those values for an open interval  | 
178  | 0  |             CPLError(CE_Failure, CPLE_AppDefined,  | 
179  | 0  |                      "Interval from %g to %g (mapped to %g) overlaps with "  | 
180  | 0  |                      "interval from %g to %g (mapped to %g)",  | 
181  | 0  |                      m_aoIntervalMappings[i - 1].first.dfMin,  | 
182  | 0  |                      m_aoIntervalMappings[i - 1].first.dfMax,  | 
183  | 0  |                      m_aoIntervalMappings[i - 1].second.value_or(  | 
184  | 0  |                          std::numeric_limits<double>::quiet_NaN()),  | 
185  | 0  |                      m_aoIntervalMappings[i].first.dfMin,  | 
186  | 0  |                      m_aoIntervalMappings[i].first.dfMax,  | 
187  | 0  |                      m_aoIntervalMappings[i].second.value_or(  | 
188  | 0  |                          std::numeric_limits<double>::quiet_NaN()));  | 
189  | 0  |             return CE_Failure;  | 
190  | 0  |         }  | 
191  | 0  |     }  | 
192  |  |  | 
193  | 0  |     return CE_None;  | 
194  | 0  | }  | 
195  |  |  | 
196  |  | void Reclassifier::AddMapping(const Interval &interval,  | 
197  |  |                               std::optional<double> dfDstVal)  | 
198  | 0  | { | 
199  | 0  |     m_aoIntervalMappings.emplace_back(interval, dfDstVal);  | 
200  | 0  | }  | 
201  |  |  | 
202  |  | CPLErr Reclassifier::Init(const char *pszText,  | 
203  |  |                           std::optional<double> noDataValue,  | 
204  |  |                           GDALDataType eBufType)  | 
205  | 0  | { | 
206  | 0  |     const char *start = pszText;  | 
207  | 0  |     char *end = const_cast<char *>(start);  | 
208  |  | 
  | 
209  | 0  |     while (*end != '\0')  | 
210  | 0  |     { | 
211  | 0  |         while (isspace(*start))  | 
212  | 0  |         { | 
213  | 0  |             start++;  | 
214  | 0  |         }  | 
215  |  | 
  | 
216  | 0  |         Interval sInt{}; | 
217  | 0  |         bool bFromIsDefault = false;  | 
218  | 0  |         bool bPassThrough = false;  | 
219  | 0  |         bool bFromNaN = false;  | 
220  |  | 
  | 
221  | 0  |         if (STARTS_WITH_CI(start, "DEFAULT"))  | 
222  | 0  |         { | 
223  | 0  |             bFromIsDefault = true;  | 
224  | 0  |             end = const_cast<char *>(start + 7);  | 
225  | 0  |         }  | 
226  | 0  |         else if (STARTS_WITH_CI(start, "NO_DATA"))  | 
227  | 0  |         { | 
228  | 0  |             if (!noDataValue.has_value())  | 
229  | 0  |             { | 
230  | 0  |                 CPLError(  | 
231  | 0  |                     CE_Failure, CPLE_AppDefined,  | 
232  | 0  |                     "Value mapped from NO_DATA, but NoData value is not set");  | 
233  | 0  |                 return CE_Failure;  | 
234  | 0  |             }  | 
235  |  |  | 
236  | 0  |             sInt.SetToConstant(noDataValue.value());  | 
237  | 0  |             end = const_cast<char *>(start + 7);  | 
238  | 0  |         }  | 
239  | 0  |         else if (STARTS_WITH_CI(start, "NAN"))  | 
240  | 0  |         { | 
241  | 0  |             bFromNaN = true;  | 
242  | 0  |             end = const_cast<char *>(start + 3);  | 
243  | 0  |         }  | 
244  | 0  |         else  | 
245  | 0  |         { | 
246  | 0  |             if (auto eErr = sInt.Parse(start, &end); eErr != CE_None)  | 
247  | 0  |             { | 
248  | 0  |                 return eErr;  | 
249  | 0  |             }  | 
250  | 0  |         }  | 
251  |  |  | 
252  | 0  |         while (isspace(*end))  | 
253  | 0  |         { | 
254  | 0  |             end++;  | 
255  | 0  |         }  | 
256  |  | 
  | 
257  | 0  |         if (*end != MAPPING_FROMTO_SEP_CHAR)  | 
258  | 0  |         { | 
259  | 0  |             CPLError(CE_Failure, CPLE_AppDefined,  | 
260  | 0  |                      "Failed to parse mapping (expected '%c', got '%c')",  | 
261  | 0  |                      MAPPING_FROMTO_SEP_CHAR, *end);  | 
262  | 0  |             return CE_Failure;  | 
263  | 0  |         }  | 
264  |  |  | 
265  | 0  |         start = end + 1;  | 
266  |  | 
  | 
267  | 0  |         while (isspace(*start))  | 
268  | 0  |         { | 
269  | 0  |             start++;  | 
270  | 0  |         }  | 
271  |  | 
  | 
272  | 0  |         std::optional<double> dfDstVal{}; | 
273  | 0  |         if (STARTS_WITH(start, "NO_DATA"))  | 
274  | 0  |         { | 
275  | 0  |             if (!noDataValue.has_value())  | 
276  | 0  |             { | 
277  | 0  |                 CPLError(  | 
278  | 0  |                     CE_Failure, CPLE_AppDefined,  | 
279  | 0  |                     "Value mapped to NO_DATA, but NoData value is not set");  | 
280  | 0  |                 return CE_Failure;  | 
281  | 0  |             }  | 
282  | 0  |             dfDstVal = noDataValue.value();  | 
283  | 0  |             end = const_cast<char *>(start) + 7;  | 
284  | 0  |         }  | 
285  | 0  |         else if (STARTS_WITH(start, "PASS_THROUGH"))  | 
286  | 0  |         { | 
287  | 0  |             bPassThrough = true;  | 
288  | 0  |             end = const_cast<char *>(start + 12);  | 
289  | 0  |         }  | 
290  | 0  |         else  | 
291  | 0  |         { | 
292  | 0  |             dfDstVal = CPLStrtod(start, &end);  | 
293  | 0  |             if (start == end)  | 
294  | 0  |             { | 
295  | 0  |                 CPLError(CE_Failure, CPLE_AppDefined,  | 
296  | 0  |                          "Failed to parse output value (expected number or "  | 
297  | 0  |                          "NO_DATA)");  | 
298  | 0  |                 return CE_Failure;  | 
299  | 0  |             }  | 
300  | 0  |         }  | 
301  |  |  | 
302  | 0  |         while (isspace(*end))  | 
303  | 0  |         { | 
304  | 0  |             end++;  | 
305  | 0  |         }  | 
306  |  | 
  | 
307  | 0  |         if (*end != '\0' && *end != MAPPING_INTERVAL_SEP_CHAR)  | 
308  | 0  |         { | 
309  | 0  |             CPLError(CE_Failure, CPLE_AppDefined,  | 
310  | 0  |                      "Failed to parse mapping (expected '%c' or end of string, "  | 
311  | 0  |                      "got '%c')",  | 
312  | 0  |                      MAPPING_INTERVAL_SEP_CHAR, *end);  | 
313  | 0  |             return CE_Failure;  | 
314  | 0  |         }  | 
315  |  |  | 
316  | 0  |         if (dfDstVal.has_value() &&  | 
317  | 0  |             !GDALIsValueExactAs(dfDstVal.value(), eBufType))  | 
318  | 0  |         { | 
319  | 0  |             CPLError(CE_Failure, CPLE_AppDefined,  | 
320  | 0  |                      "Value %g cannot be represented as data type %s",  | 
321  | 0  |                      dfDstVal.value(), GDALGetDataTypeName(eBufType));  | 
322  | 0  |             return CE_Failure;  | 
323  | 0  |         }  | 
324  |  |  | 
325  | 0  |         if (bFromNaN)  | 
326  | 0  |         { | 
327  | 0  |             SetNaNValue(bPassThrough ? std::numeric_limits<double>::quiet_NaN()  | 
328  | 0  |                                      : dfDstVal.value());  | 
329  | 0  |         }  | 
330  | 0  |         else if (bFromIsDefault)  | 
331  | 0  |         { | 
332  | 0  |             if (bPassThrough)  | 
333  | 0  |             { | 
334  | 0  |                 SetDefaultPassThrough(true);  | 
335  | 0  |             }  | 
336  | 0  |             else  | 
337  | 0  |             { | 
338  | 0  |                 SetDefaultValue(dfDstVal.value());  | 
339  | 0  |             }  | 
340  | 0  |         }  | 
341  | 0  |         else  | 
342  | 0  |         { | 
343  | 0  |             AddMapping(sInt, dfDstVal);  | 
344  | 0  |         }  | 
345  |  | 
  | 
346  | 0  |         start = end + 1;  | 
347  | 0  |     }  | 
348  |  |  | 
349  | 0  |     return Finalize();  | 
350  | 0  | }  | 
351  |  |  | 
352  |  | static std::optional<size_t> FindInterval(  | 
353  |  |     const std::vector<std::pair<Reclassifier::Interval, std::optional<double>>>  | 
354  |  |         &arr,  | 
355  |  |     double srcVal)  | 
356  | 0  | { | 
357  | 0  |     if (arr.empty())  | 
358  | 0  |     { | 
359  | 0  |         return std::nullopt;  | 
360  | 0  |     }  | 
361  |  |  | 
362  | 0  |     size_t low = 0;  | 
363  | 0  |     size_t high = arr.size() - 1;  | 
364  |  | 
  | 
365  | 0  |     while (low <= high)  | 
366  | 0  |     { | 
367  | 0  |         auto mid = low + (high - low) / 2;  | 
368  |  | 
  | 
369  | 0  |         const auto &mid_interval = arr[mid].first;  | 
370  | 0  |         if (mid_interval.Contains(srcVal))  | 
371  | 0  |         { | 
372  | 0  |             return mid;  | 
373  | 0  |         }  | 
374  |  |  | 
375  |  |         // Could an interval exist to the left?  | 
376  | 0  |         if (srcVal < mid_interval.dfMin)  | 
377  | 0  |         { | 
378  | 0  |             if (mid == 0)  | 
379  | 0  |             { | 
380  | 0  |                 return std::nullopt;  | 
381  | 0  |             }  | 
382  | 0  |             high = mid - 1;  | 
383  | 0  |         }  | 
384  |  |         // Could an interval exist to the right?  | 
385  | 0  |         else if (srcVal > mid_interval.dfMax)  | 
386  | 0  |         { | 
387  | 0  |             low = mid + 1;  | 
388  | 0  |         }  | 
389  | 0  |         else  | 
390  | 0  |         { | 
391  | 0  |             return std::nullopt;  | 
392  | 0  |         }  | 
393  | 0  |     }  | 
394  |  |  | 
395  | 0  |     return std::nullopt;  | 
396  | 0  | }  | 
397  |  |  | 
398  |  | double Reclassifier::Reclassify(double srcVal, bool &bFoundInterval) const  | 
399  | 0  | { | 
400  | 0  |     bFoundInterval = false;  | 
401  |  | 
  | 
402  | 0  |     if (std::isnan(srcVal))  | 
403  | 0  |     { | 
404  | 0  |         if (m_NaNValue.has_value())  | 
405  | 0  |         { | 
406  | 0  |             bFoundInterval = true;  | 
407  | 0  |             return m_NaNValue.value();  | 
408  | 0  |         }  | 
409  | 0  |     }  | 
410  | 0  |     else  | 
411  | 0  |     { | 
412  | 0  |         auto nInterval = FindInterval(m_aoIntervalMappings, srcVal);  | 
413  | 0  |         if (nInterval.has_value())  | 
414  | 0  |         { | 
415  | 0  |             bFoundInterval = true;  | 
416  | 0  |             return m_aoIntervalMappings[nInterval.value()].second.value_or(  | 
417  | 0  |                 srcVal);  | 
418  | 0  |         }  | 
419  | 0  |     }  | 
420  |  |  | 
421  | 0  |     if (m_defaultValue.has_value())  | 
422  | 0  |     { | 
423  | 0  |         bFoundInterval = true;  | 
424  | 0  |         return m_defaultValue.value();  | 
425  | 0  |     }  | 
426  |  |  | 
427  | 0  |     if (m_defaultPassThrough)  | 
428  | 0  |     { | 
429  | 0  |         bFoundInterval = true;  | 
430  | 0  |         return srcVal;  | 
431  | 0  |     }  | 
432  |  |  | 
433  | 0  |     return 0;  | 
434  | 0  | }  | 
435  |  |  | 
436  |  | }  // namespace gdal  |