/src/duckdb/third_party/zstd/common/entropy_common.cpp
Line | Count | Source |
1 | | /* ****************************************************************** |
2 | | * Common functions of New Generation Entropy library |
3 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
4 | | * |
5 | | * You can contact the author at : |
6 | | * - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy |
7 | | * - Public forum : https://groups.google.com/forum/#!forum/lz4c |
8 | | * |
9 | | * This source code is licensed under both the BSD-style license (found in the |
10 | | * LICENSE file in the root directory of this source tree) and the GPLv2 (found |
11 | | * in the COPYING file in the root directory of this source tree). |
12 | | * You may select, at your option, one of the above-listed licenses. |
13 | | ****************************************************************** */ |
14 | | |
15 | | /* ************************************* |
16 | | * Dependencies |
17 | | ***************************************/ |
18 | | #include "zstd/common/mem.h" |
19 | | #include "zstd/common/error_private.h" /* ERR_*, ERROR */ |
20 | | #define FSE_STATIC_LINKING_ONLY /* FSE_MIN_TABLELOG */ |
21 | | #include "zstd/common/fse.h" |
22 | | #include "zstd/common/huf.h" |
23 | | #include "zstd/common/bits.h" /* ZSDT_highbit32, ZSTD_countTrailingZeros32 */ |
24 | | |
25 | | namespace duckdb_zstd { |
26 | | |
27 | | /*=== Version ===*/ |
28 | 0 | unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; } |
29 | | |
30 | | |
31 | | /*=== Error Management ===*/ |
32 | 0 | unsigned FSE_isError(size_t code) { return ERR_isError(code); } |
33 | 0 | const char* FSE_getErrorName(size_t code) { return ERR_getErrorName(code); } |
34 | | |
35 | 0 | unsigned HUF_isError(size_t code) { return ERR_isError(code); } |
36 | 0 | const char* HUF_getErrorName(size_t code) { return ERR_getErrorName(code); } |
37 | | |
38 | | |
39 | | /*-************************************************************** |
40 | | * FSE NCount encoding-decoding |
41 | | ****************************************************************/ |
42 | | FORCE_INLINE_TEMPLATE |
43 | | size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr, |
44 | | const void* headerBuffer, size_t hbSize) |
45 | 0 | { |
46 | 0 | const BYTE* const istart = (const BYTE*) headerBuffer; |
47 | 0 | const BYTE* const iend = istart + hbSize; |
48 | 0 | const BYTE* ip = istart; |
49 | 0 | int nbBits; |
50 | 0 | int remaining; |
51 | 0 | int threshold; |
52 | 0 | U32 bitStream; |
53 | 0 | int bitCount; |
54 | 0 | unsigned charnum = 0; |
55 | 0 | unsigned const maxSV1 = *maxSVPtr + 1; |
56 | 0 | int previous0 = 0; |
57 | |
|
58 | 0 | if (hbSize < 8) { |
59 | | /* This function only works when hbSize >= 8 */ |
60 | 0 | char buffer[8] = {0}; |
61 | 0 | ZSTD_memcpy(buffer, headerBuffer, hbSize); |
62 | 0 | { size_t const countSize = FSE_readNCount(normalizedCounter, maxSVPtr, tableLogPtr, |
63 | 0 | buffer, sizeof(buffer)); |
64 | 0 | if (FSE_isError(countSize)) return countSize; |
65 | 0 | if (countSize > hbSize) return ERROR(corruption_detected); |
66 | 0 | return countSize; |
67 | 0 | } } |
68 | 0 | assert(hbSize >= 8); |
69 | | |
70 | | /* init */ |
71 | 0 | ZSTD_memset(normalizedCounter, 0, (*maxSVPtr+1) * sizeof(normalizedCounter[0])); /* all symbols not present in NCount have a frequency of 0 */ |
72 | 0 | bitStream = MEM_readLE32(ip); |
73 | 0 | nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG; /* extract tableLog */ |
74 | 0 | if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge); |
75 | 0 | bitStream >>= 4; |
76 | 0 | bitCount = 4; |
77 | 0 | *tableLogPtr = nbBits; |
78 | 0 | remaining = (1<<nbBits)+1; |
79 | 0 | threshold = 1<<nbBits; |
80 | 0 | nbBits++; |
81 | |
|
82 | 0 | for (;;) { |
83 | 0 | if (previous0) { |
84 | | /* Count the number of repeats. Each time the |
85 | | * 2-bit repeat code is 0b11 there is another |
86 | | * repeat. |
87 | | * Avoid UB by setting the high bit to 1. |
88 | | */ |
89 | 0 | int repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1; |
90 | 0 | while (repeats >= 12) { |
91 | 0 | charnum += 3 * 12; |
92 | 0 | if (LIKELY(ip <= iend-7)) { |
93 | 0 | ip += 3; |
94 | 0 | } else { |
95 | 0 | bitCount -= (int)(8 * (iend - 7 - ip)); |
96 | 0 | bitCount &= 31; |
97 | 0 | ip = iend - 4; |
98 | 0 | } |
99 | 0 | bitStream = MEM_readLE32(ip) >> bitCount; |
100 | 0 | repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1; |
101 | 0 | } |
102 | 0 | charnum += 3 * repeats; |
103 | 0 | bitStream >>= 2 * repeats; |
104 | 0 | bitCount += 2 * repeats; |
105 | | |
106 | | /* Add the final repeat which isn't 0b11. */ |
107 | 0 | assert((bitStream & 3) < 3); |
108 | 0 | charnum += bitStream & 3; |
109 | 0 | bitCount += 2; |
110 | | |
111 | | /* This is an error, but break and return an error |
112 | | * at the end, because returning out of a loop makes |
113 | | * it harder for the compiler to optimize. |
114 | | */ |
115 | 0 | if (charnum >= maxSV1) break; |
116 | | |
117 | | /* We don't need to set the normalized count to 0 |
118 | | * because we already memset the whole buffer to 0. |
119 | | */ |
120 | | |
121 | 0 | if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) { |
122 | 0 | assert((bitCount >> 3) <= 3); /* For first condition to work */ |
123 | 0 | ip += bitCount>>3; |
124 | 0 | bitCount &= 7; |
125 | 0 | } else { |
126 | 0 | bitCount -= (int)(8 * (iend - 4 - ip)); |
127 | 0 | bitCount &= 31; |
128 | 0 | ip = iend - 4; |
129 | 0 | } |
130 | 0 | bitStream = MEM_readLE32(ip) >> bitCount; |
131 | 0 | } |
132 | 0 | { |
133 | 0 | int const max = (2*threshold-1) - remaining; |
134 | 0 | int count; |
135 | |
|
136 | 0 | if ((bitStream & (threshold-1)) < (U32)max) { |
137 | 0 | count = bitStream & (threshold-1); |
138 | 0 | bitCount += nbBits-1; |
139 | 0 | } else { |
140 | 0 | count = bitStream & (2*threshold-1); |
141 | 0 | if (count >= threshold) count -= max; |
142 | 0 | bitCount += nbBits; |
143 | 0 | } |
144 | |
|
145 | 0 | count--; /* extra accuracy */ |
146 | | /* When it matters (small blocks), this is a |
147 | | * predictable branch, because we don't use -1. |
148 | | */ |
149 | 0 | if (count >= 0) { |
150 | 0 | remaining -= count; |
151 | 0 | } else { |
152 | 0 | assert(count == -1); |
153 | 0 | remaining += count; |
154 | 0 | } |
155 | 0 | normalizedCounter[charnum++] = (short)count; |
156 | 0 | previous0 = !count; |
157 | |
|
158 | 0 | assert(threshold > 1); |
159 | 0 | if (remaining < threshold) { |
160 | | /* This branch can be folded into the |
161 | | * threshold update condition because we |
162 | | * know that threshold > 1. |
163 | | */ |
164 | 0 | if (remaining <= 1) break; |
165 | 0 | nbBits = ZSTD_highbit32(remaining) + 1; |
166 | 0 | threshold = 1 << (nbBits - 1); |
167 | 0 | } |
168 | 0 | if (charnum >= maxSV1) break; |
169 | | |
170 | 0 | if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) { |
171 | 0 | ip += bitCount>>3; |
172 | 0 | bitCount &= 7; |
173 | 0 | } else { |
174 | 0 | bitCount -= (int)(8 * (iend - 4 - ip)); |
175 | 0 | bitCount &= 31; |
176 | 0 | ip = iend - 4; |
177 | 0 | } |
178 | 0 | bitStream = MEM_readLE32(ip) >> bitCount; |
179 | 0 | } } |
180 | 0 | if (remaining != 1) return ERROR(corruption_detected); |
181 | | /* Only possible when there are too many zeros. */ |
182 | 0 | if (charnum > maxSV1) return ERROR(maxSymbolValue_tooSmall); |
183 | 0 | if (bitCount > 32) return ERROR(corruption_detected); |
184 | 0 | *maxSVPtr = charnum-1; |
185 | |
|
186 | 0 | ip += (bitCount+7)>>3; |
187 | 0 | return ip-istart; |
188 | 0 | } |
189 | | |
190 | | /* Avoids the FORCE_INLINE of the _body() function. */ |
191 | | static size_t FSE_readNCount_body_default( |
192 | | short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr, |
193 | | const void* headerBuffer, size_t hbSize) |
194 | 0 | { |
195 | 0 | return FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize); |
196 | 0 | } |
197 | | |
198 | | #if DYNAMIC_BMI2 |
199 | | BMI2_TARGET_ATTRIBUTE static size_t FSE_readNCount_body_bmi2( |
200 | | short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr, |
201 | | const void* headerBuffer, size_t hbSize) |
202 | 0 | { |
203 | 0 | return FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize); |
204 | 0 | } |
205 | | #endif |
206 | | |
207 | | size_t FSE_readNCount_bmi2( |
208 | | short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr, |
209 | | const void* headerBuffer, size_t hbSize, int bmi2) |
210 | 0 | { |
211 | 0 | #if DYNAMIC_BMI2 |
212 | 0 | if (bmi2) { |
213 | 0 | return FSE_readNCount_body_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize); |
214 | 0 | } |
215 | 0 | #endif |
216 | 0 | (void)bmi2; |
217 | 0 | return FSE_readNCount_body_default(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize); |
218 | 0 | } |
219 | | |
220 | | size_t FSE_readNCount( |
221 | | short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr, |
222 | | const void* headerBuffer, size_t hbSize) |
223 | 0 | { |
224 | 0 | return FSE_readNCount_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize, /* bmi2 */ 0); |
225 | 0 | } |
226 | | |
227 | | |
228 | | /*! HUF_readStats() : |
229 | | Read compact Huffman tree, saved by HUF_writeCTable(). |
230 | | `huffWeight` is destination buffer. |
231 | | `rankStats` is assumed to be a table of at least HUF_TABLELOG_MAX U32. |
232 | | @return : size read from `src` , or an error Code . |
233 | | Note : Needed by HUF_readCTable() and HUF_readDTableX?() . |
234 | | */ |
235 | | size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats, |
236 | | U32* nbSymbolsPtr, U32* tableLogPtr, |
237 | | const void* src, size_t srcSize) |
238 | 0 | { |
239 | 0 | U32 wksp[HUF_READ_STATS_WORKSPACE_SIZE_U32]; |
240 | 0 | return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), /* flags */ 0); |
241 | 0 | } |
242 | | |
243 | | FORCE_INLINE_TEMPLATE size_t |
244 | | HUF_readStats_body(BYTE* huffWeight, size_t hwSize, U32* rankStats, |
245 | | U32* nbSymbolsPtr, U32* tableLogPtr, |
246 | | const void* src, size_t srcSize, |
247 | | void* workSpace, size_t wkspSize, |
248 | | int bmi2) |
249 | 0 | { |
250 | 0 | U32 weightTotal; |
251 | 0 | const BYTE* ip = (const BYTE*) src; |
252 | 0 | size_t iSize; |
253 | 0 | size_t oSize; |
254 | |
|
255 | 0 | if (!srcSize) return ERROR(srcSize_wrong); |
256 | 0 | iSize = ip[0]; |
257 | | /* ZSTD_memset(huffWeight, 0, hwSize); *//* is not necessary, even though some analyzer complain ... */ |
258 | |
|
259 | 0 | if (iSize >= 128) { /* special header */ |
260 | 0 | oSize = iSize - 127; |
261 | 0 | iSize = ((oSize+1)/2); |
262 | 0 | if (iSize+1 > srcSize) return ERROR(srcSize_wrong); |
263 | 0 | if (oSize >= hwSize) return ERROR(corruption_detected); |
264 | 0 | ip += 1; |
265 | 0 | { U32 n; |
266 | 0 | for (n=0; n<oSize; n+=2) { |
267 | 0 | huffWeight[n] = ip[n/2] >> 4; |
268 | 0 | huffWeight[n+1] = ip[n/2] & 15; |
269 | 0 | } } } |
270 | 0 | else { /* header compressed with FSE (normal case) */ |
271 | 0 | if (iSize+1 > srcSize) return ERROR(srcSize_wrong); |
272 | | /* max (hwSize-1) values decoded, as last one is implied */ |
273 | 0 | oSize = FSE_decompress_wksp_bmi2(huffWeight, hwSize-1, ip+1, iSize, 6, workSpace, wkspSize, bmi2); |
274 | 0 | if (FSE_isError(oSize)) return oSize; |
275 | 0 | } |
276 | | |
277 | | /* collect weight stats */ |
278 | 0 | ZSTD_memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32)); |
279 | 0 | weightTotal = 0; |
280 | 0 | { U32 n; for (n=0; n<oSize; n++) { |
281 | 0 | if (huffWeight[n] > HUF_TABLELOG_MAX) return ERROR(corruption_detected); |
282 | 0 | rankStats[huffWeight[n]]++; |
283 | 0 | weightTotal += (1 << huffWeight[n]) >> 1; |
284 | 0 | } } |
285 | 0 | if (weightTotal == 0) return ERROR(corruption_detected); |
286 | | |
287 | | /* get last non-null symbol weight (implied, total must be 2^n) */ |
288 | 0 | { U32 const tableLog = ZSTD_highbit32(weightTotal) + 1; |
289 | 0 | if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected); |
290 | 0 | *tableLogPtr = tableLog; |
291 | | /* determine last weight */ |
292 | 0 | { U32 const total = 1 << tableLog; |
293 | 0 | U32 const rest = total - weightTotal; |
294 | 0 | U32 const verif = 1 << ZSTD_highbit32(rest); |
295 | 0 | U32 const lastWeight = ZSTD_highbit32(rest) + 1; |
296 | 0 | if (verif != rest) return ERROR(corruption_detected); /* last value must be a clean power of 2 */ |
297 | 0 | huffWeight[oSize] = (BYTE)lastWeight; |
298 | 0 | rankStats[lastWeight]++; |
299 | 0 | } } |
300 | | |
301 | | /* check tree construction validity */ |
302 | 0 | if ((rankStats[1] < 2) || (rankStats[1] & 1)) return ERROR(corruption_detected); /* by construction : at least 2 elts of rank 1, must be even */ |
303 | | |
304 | | /* results */ |
305 | 0 | *nbSymbolsPtr = (U32)(oSize+1); |
306 | 0 | return iSize+1; |
307 | 0 | } |
308 | | |
309 | | /* Avoids the FORCE_INLINE of the _body() function. */ |
310 | | static size_t HUF_readStats_body_default(BYTE* huffWeight, size_t hwSize, U32* rankStats, |
311 | | U32* nbSymbolsPtr, U32* tableLogPtr, |
312 | | const void* src, size_t srcSize, |
313 | | void* workSpace, size_t wkspSize) |
314 | 0 | { |
315 | 0 | return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 0); |
316 | 0 | } |
317 | | |
318 | | #if DYNAMIC_BMI2 |
319 | | static BMI2_TARGET_ATTRIBUTE size_t HUF_readStats_body_bmi2(BYTE* huffWeight, size_t hwSize, U32* rankStats, |
320 | | U32* nbSymbolsPtr, U32* tableLogPtr, |
321 | | const void* src, size_t srcSize, |
322 | | void* workSpace, size_t wkspSize) |
323 | 0 | { |
324 | 0 | return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 1); |
325 | 0 | } |
326 | | #endif |
327 | | |
328 | | size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats, |
329 | | U32* nbSymbolsPtr, U32* tableLogPtr, |
330 | | const void* src, size_t srcSize, |
331 | | void* workSpace, size_t wkspSize, |
332 | | int flags) |
333 | 0 | { |
334 | 0 | #if DYNAMIC_BMI2 |
335 | 0 | if (flags & HUF_flags_bmi2) { |
336 | 0 | return HUF_readStats_body_bmi2(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize); |
337 | 0 | } |
338 | 0 | #endif |
339 | 0 | (void)flags; |
340 | 0 | return HUF_readStats_body_default(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize); |
341 | 0 | } |
342 | | |
343 | | } // namespace duckdb_zstd |