/src/zstd/lib/common/entropy_common.c
Line | Count | Source (jump to first uncovered line) |
1 | | /* ****************************************************************** |
2 | | * Common functions of New Generation Entropy library |
3 | | * Copyright (c) Meta Platforms, Inc. and affiliates. |
4 | | * |
5 | | * You can contact the author at : |
6 | | * - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy |
7 | | * - Public forum : https://groups.google.com/forum/#!forum/lz4c |
8 | | * |
9 | | * This source code is licensed under both the BSD-style license (found in the |
10 | | * LICENSE file in the root directory of this source tree) and the GPLv2 (found |
11 | | * in the COPYING file in the root directory of this source tree). |
12 | | * You may select, at your option, one of the above-listed licenses. |
13 | | ****************************************************************** */ |
14 | | |
15 | | /* ************************************* |
16 | | * Dependencies |
17 | | ***************************************/ |
18 | | #include "mem.h" |
19 | | #include "error_private.h" /* ERR_*, ERROR */ |
20 | | #define FSE_STATIC_LINKING_ONLY /* FSE_MIN_TABLELOG */ |
21 | | #include "fse.h" |
22 | | #include "huf.h" |
#include "bits.h"       /* ZSTD_highbit32, ZSTD_countTrailingZeros32 */
24 | | |
25 | | |
26 | | /*=== Version ===*/ |
27 | 0 | unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; } |
28 | | |
29 | | |
30 | | /*=== Error Management ===*/ |
/* FSE_isError() : non-zero when `code` encodes an error (delegates to ERR_isError). */
unsigned FSE_isError(size_t code)
{
    return ERR_isError(code);
}

/* FSE_getErrorName() : readable name for an error `code` (delegates to ERR_getErrorName). */
const char* FSE_getErrorName(size_t code)
{
    return ERR_getErrorName(code);
}

/* HUF_isError() : non-zero when `code` encodes an error (delegates to ERR_isError). */
unsigned HUF_isError(size_t code)
{
    return ERR_isError(code);
}

/* HUF_getErrorName() : readable name for an error `code` (delegates to ERR_getErrorName). */
const char* HUF_getErrorName(size_t code)
{
    return ERR_getErrorName(code);
}
36 | | |
37 | | |
38 | | /*-************************************************************** |
39 | | * FSE NCount encoding-decoding |
40 | | ****************************************************************/ |
/* FSE_readNCount_body() :
 *  Decodes an FSE "normalized count" header from `headerBuffer` (hbSize bytes).
 *  Inputs  : *maxSVPtr = largest symbol value the caller can accept.
 *            normalizedCounter must have room for (*maxSVPtr + 1) shorts;
 *            that many entries are zeroed before decoding starts.
 *  Outputs : normalizedCounter[0 .. charnum-1] = decoded counts (-1, 0 or positive),
 *            *tableLogPtr = table log read from the low 4 bits of the header,
 *            *maxSVPtr    = index of the last decoded symbol (charnum - 1).
 *  @return : number of header bytes consumed, or an error code
 *            (tableLog_tooLarge, corruption_detected, maxSymbolValue_tooSmall).
 *  Note : *tableLogPtr (and, via the short-buffer path, *maxSVPtr) can be
 *         written even when an error is ultimately returned.
 */
FORCE_INLINE_TEMPLATE
size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
                           const void* headerBuffer, size_t hbSize)
{
    const BYTE* const istart = (const BYTE*) headerBuffer;
    const BYTE* const iend = istart + hbSize;
    const BYTE* ip = istart;
    int nbBits;        /* bits needed for the next count (the short form uses nbBits-1) */
    int remaining;     /* probability mass still to distribute, plus 1 */
    int threshold;     /* 1 << (nbBits-1) */
    U32 bitStream;     /* 32 bits read at ip, already shifted right by bitCount */
    int bitCount;      /* bits already consumed from the 32-bit word at ip */
    unsigned charnum = 0;
    unsigned const maxSV1 = *maxSVPtr + 1;
    int previous0 = 0; /* last decoded count was 0 => a zero-run repeat code follows */

    if (hbSize < 8) {
        /* This function only works when hbSize >= 8 */
        /* Re-parse from a zero-padded 8-byte copy, then reject results that
         * claim to have consumed bytes beyond the real input. */
        char buffer[8] = {0};
        ZSTD_memcpy(buffer, headerBuffer, hbSize);
        {   size_t const countSize = FSE_readNCount(normalizedCounter, maxSVPtr, tableLogPtr,
                                                    buffer, sizeof(buffer));
            if (FSE_isError(countSize)) return countSize;
            if (countSize > hbSize) return ERROR(corruption_detected);
            return countSize;
    }   }
    assert(hbSize >= 8);

    /* init */
    ZSTD_memset(normalizedCounter, 0, (*maxSVPtr+1) * sizeof(normalizedCounter[0]));   /* all symbols not present in NCount have a frequency of 0 */
    bitStream = MEM_readLE32(ip);
    nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG;   /* extract tableLog */
    if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge);
    bitStream >>= 4;
    bitCount = 4;
    *tableLogPtr = nbBits;
    /* Total mass is 1<<tableLog; the extra +1 lets the loop stop at remaining <= 1
     * and the final check demand remaining == 1 exactly. */
    remaining = (1<<nbBits)+1;
    threshold = 1<<nbBits;
    nbBits++;

    for (;;) {
        if (previous0) {
            /* Count the number of repeats. Each time the
             * 2-bit repeat code is 0b11 there is another
             * repeat.
             * Avoid UB by setting the high bit to 1.
             */
            int repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1;
            /* Fast skip: 12 repeat codes = 24 bits = 3 bytes = 36 zero symbols. */
            while (repeats >= 12) {
                charnum += 3 * 12;
                if (LIKELY(ip <= iend-7)) {
                    ip += 3;
                } else {
                    /* Near the end: pin ip to the last readable 4-byte window
                     * and express the 3-byte advance through bitCount instead. */
                    bitCount -= (int)(8 * (iend - 7 - ip));
                    bitCount &= 31;
                    ip = iend - 4;
                }
                bitStream = MEM_readLE32(ip) >> bitCount;
                repeats = ZSTD_countTrailingZeros32(~bitStream | 0x80000000) >> 1;
            }
            /* Each 0b11 code contributes 3 zero-count symbols. */
            charnum += 3 * repeats;
            bitStream >>= 2 * repeats;
            bitCount += 2 * repeats;

            /* Add the final repeat which isn't 0b11. */
            assert((bitStream & 3) < 3);
            charnum += bitStream & 3;
            bitCount += 2;

            /* This is an error, but break and return an error
             * at the end, because returning out of a loop makes
             * it harder for the compiler to optimize.
             */
            if (charnum >= maxSV1) break;

            /* We don't need to set the normalized count to 0
             * because we already memset the whole buffer to 0.
             */

            /* Refill: advance ip by whole consumed bytes, but never past iend-4
             * so the 4-byte read below stays inside the buffer. */
            if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
                assert((bitCount >> 3) <= 3); /* For first condition to work */
                ip += bitCount>>3;
                bitCount &= 7;
            } else {
                bitCount -= (int)(8 * (iend - 4 - ip));
                bitCount &= 31;
                ip = iend - 4;
            }
            bitStream = MEM_readLE32(ip) >> bitCount;
        }
        {
            /* Variable-width count: values below `max` fit in nbBits-1 bits,
             * the rest need nbBits bits. */
            int const max = (2*threshold-1) - remaining;
            int count;

            if ((bitStream & (threshold-1)) < (U32)max) {
                count = bitStream & (threshold-1);
                bitCount += nbBits-1;
            } else {
                count = bitStream & (2*threshold-1);
                if (count >= threshold) count -= max;
                bitCount += nbBits;
            }

            count--;   /* extra accuracy */
            /* When it matters (small blocks), this is a
             * predictable branch, because we don't use -1.
             */
            /* Both branches reduce `remaining` by |count| (a -1 consumes 1). */
            if (count >= 0) {
                remaining -= count;
            } else {
                assert(count == -1);
                remaining += count;
            }
            normalizedCounter[charnum++] = (short)count;
            previous0 = !count;

            assert(threshold > 1);
            if (remaining < threshold) {
                /* This branch can be folded into the
                 * threshold update condition because we
                 * know that threshold > 1.
                 */
                if (remaining <= 1) break;
                /* Shrink the field width to what the remaining mass needs. */
                nbBits = ZSTD_highbit32(remaining) + 1;
                threshold = 1 << (nbBits - 1);
            }
            if (charnum >= maxSV1) break;

            /* Refill, same clamping scheme as above. */
            if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
                ip += bitCount>>3;
                bitCount &= 7;
            } else {
                bitCount -= (int)(8 * (iend - 4 - ip));
                bitCount &= 31;
                ip = iend - 4;
            }
            bitStream = MEM_readLE32(ip) >> bitCount;
    }   }
    /* The counts must use up the whole 1<<tableLog mass exactly. */
    if (remaining != 1) return ERROR(corruption_detected);
    /* Only possible when there are too many zeros. */
    if (charnum > maxSV1) return ERROR(maxSymbolValue_tooSmall);
    /* More than 32 bits consumed from the word at ip means bits past iend were used. */
    if (bitCount > 32) return ERROR(corruption_detected);
    *maxSVPtr = charnum-1;

    /* Round consumed bits up to whole bytes. */
    ip += (bitCount+7)>>3;
    return ip-istart;
}
188 | | |
/* Out-of-line, portable instantiation of FSE_readNCount_body().
 * Keeping it in its own static function prevents the FORCE_INLINE body
 * from being expanded at every caller. */
static size_t FSE_readNCount_body_default(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    size_t const result = FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr,
                                              headerBuffer, hbSize);
    return result;
}
196 | | |
#if DYNAMIC_BMI2
/* BMI2-targeted instantiation of FSE_readNCount_body(); only compiled when
 * runtime BMI2 dispatch (DYNAMIC_BMI2) is enabled. */
BMI2_TARGET_ATTRIBUTE static size_t FSE_readNCount_body_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    size_t const result = FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr,
                                              headerBuffer, hbSize);
    return result;
}
#endif
205 | | |
/* FSE_readNCount_bmi2() :
 *  Same contract as FSE_readNCount(). When built with DYNAMIC_BMI2 and
 *  `bmi2` is non-zero, the BMI2-targeted body is used; otherwise the
 *  portable body is used and `bmi2` is ignored. */
size_t FSE_readNCount_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize, int bmi2)
{
#if DYNAMIC_BMI2
    if (bmi2)
        return FSE_readNCount_body_bmi2(normalizedCounter, maxSVPtr, tableLogPtr,
                                        headerBuffer, hbSize);
#else
    (void)bmi2;
#endif
    return FSE_readNCount_body_default(normalizedCounter, maxSVPtr, tableLogPtr,
                                       headerBuffer, hbSize);
}
218 | | |
/* FSE_readNCount() :
 *  Portable entry point: decodes a normalized-count header without BMI2
 *  dispatch. See FSE_readNCount_bmi2() for the dispatching variant. */
size_t FSE_readNCount(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    int const useBmi2 = 0;
    return FSE_readNCount_bmi2(normalizedCounter, maxSVPtr, tableLogPtr,
                               headerBuffer, hbSize, useBmi2);
}
225 | | |
226 | | |
227 | | /*! HUF_readStats() : |
228 | | Read compact Huffman tree, saved by HUF_writeCTable(). |
229 | | `huffWeight` is destination buffer. |
230 | | `rankStats` is assumed to be a table of at least HUF_TABLELOG_MAX U32. |
231 | | @return : size read from `src` , or an error Code . |
232 | | Note : Needed by HUF_readCTable() and HUF_readDTableX?() . |
233 | | */ |
234 | | size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats, |
235 | | U32* nbSymbolsPtr, U32* tableLogPtr, |
236 | | const void* src, size_t srcSize) |
237 | 0 | { |
238 | 0 | U32 wksp[HUF_READ_STATS_WORKSPACE_SIZE_U32]; |
239 | 0 | return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), /* flags */ 0); |
240 | 0 | } |
241 | | |
242 | | FORCE_INLINE_TEMPLATE size_t |
243 | | HUF_readStats_body(BYTE* huffWeight, size_t hwSize, U32* rankStats, |
244 | | U32* nbSymbolsPtr, U32* tableLogPtr, |
245 | | const void* src, size_t srcSize, |
246 | | void* workSpace, size_t wkspSize, |
247 | | int bmi2) |
248 | 25.0k | { |
249 | 25.0k | U32 weightTotal; |
250 | 25.0k | const BYTE* ip = (const BYTE*) src; |
251 | 25.0k | size_t iSize; |
252 | 25.0k | size_t oSize; |
253 | | |
254 | 25.0k | if (!srcSize) return ERROR(srcSize_wrong); |
255 | 25.0k | iSize = ip[0]; |
256 | | /* ZSTD_memset(huffWeight, 0, hwSize); *//* is not necessary, even though some analyzer complain ... */ |
257 | | |
258 | 25.0k | if (iSize >= 128) { /* special header */ |
259 | 3.81k | oSize = iSize - 127; |
260 | 3.81k | iSize = ((oSize+1)/2); |
261 | 3.81k | if (iSize+1 > srcSize) return ERROR(srcSize_wrong); |
262 | 3.80k | if (oSize >= hwSize) return ERROR(corruption_detected); |
263 | 3.80k | ip += 1; |
264 | 3.80k | { U32 n; |
265 | 18.5k | for (n=0; n<oSize; n+=2) { |
266 | 14.7k | huffWeight[n] = ip[n/2] >> 4; |
267 | 14.7k | huffWeight[n+1] = ip[n/2] & 15; |
268 | 14.7k | } } } |
269 | 21.2k | else { /* header compressed with FSE (normal case) */ |
270 | 21.2k | if (iSize+1 > srcSize) return ERROR(srcSize_wrong); |
271 | | /* max (hwSize-1) values decoded, as last one is implied */ |
272 | 21.2k | oSize = FSE_decompress_wksp_bmi2(huffWeight, hwSize-1, ip+1, iSize, 6, workSpace, wkspSize, bmi2); |
273 | 21.2k | if (FSE_isError(oSize)) return oSize; |
274 | 21.2k | } |
275 | | |
276 | | /* collect weight stats */ |
277 | 24.8k | ZSTD_memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32)); |
278 | 24.8k | weightTotal = 0; |
279 | 1.62M | { U32 n; for (n=0; n<oSize; n++) { |
280 | 1.59M | if (huffWeight[n] > HUF_TABLELOG_MAX) return ERROR(corruption_detected); |
281 | 1.59M | rankStats[huffWeight[n]]++; |
282 | 1.59M | weightTotal += (1 << huffWeight[n]) >> 1; |
283 | 1.59M | } } |
284 | 24.8k | if (weightTotal == 0) return ERROR(corruption_detected); |
285 | | |
286 | | /* get last non-null symbol weight (implied, total must be 2^n) */ |
287 | 24.8k | { U32 const tableLog = ZSTD_highbit32(weightTotal) + 1; |
288 | 24.8k | if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected); |
289 | 24.8k | *tableLogPtr = tableLog; |
290 | | /* determine last weight */ |
291 | 24.8k | { U32 const total = 1 << tableLog; |
292 | 24.8k | U32 const rest = total - weightTotal; |
293 | 24.8k | U32 const verif = 1 << ZSTD_highbit32(rest); |
294 | 24.8k | U32 const lastWeight = ZSTD_highbit32(rest) + 1; |
295 | 24.8k | if (verif != rest) return ERROR(corruption_detected); /* last value must be a clean power of 2 */ |
296 | 24.7k | huffWeight[oSize] = (BYTE)lastWeight; |
297 | 24.7k | rankStats[lastWeight]++; |
298 | 24.7k | } } |
299 | | |
300 | | /* check tree construction validity */ |
301 | 24.7k | if ((rankStats[1] < 2) || (rankStats[1] & 1)) return ERROR(corruption_detected); /* by construction : at least 2 elts of rank 1, must be even */ |
302 | | |
303 | | /* results */ |
304 | 24.7k | *nbSymbolsPtr = (U32)(oSize+1); |
305 | 24.7k | return iSize+1; |
306 | 24.7k | } |
307 | | |
308 | | /* Avoids the FORCE_INLINE of the _body() function. */ |
309 | | static size_t HUF_readStats_body_default(BYTE* huffWeight, size_t hwSize, U32* rankStats, |
310 | | U32* nbSymbolsPtr, U32* tableLogPtr, |
311 | | const void* src, size_t srcSize, |
312 | | void* workSpace, size_t wkspSize) |
313 | 0 | { |
314 | 0 | return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 0); |
315 | 0 | } |
316 | | |
#if DYNAMIC_BMI2
/* BMI2-targeted instantiation of HUF_readStats_body() (bmi2 == 1);
 * only compiled when runtime BMI2 dispatch (DYNAMIC_BMI2) is enabled. */
static BMI2_TARGET_ATTRIBUTE size_t HUF_readStats_body_bmi2(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                                                            U32* nbSymbolsPtr, U32* tableLogPtr,
                                                            const void* src, size_t srcSize,
                                                            void* workSpace, size_t wkspSize)
{
    int const useBmi2 = 1;
    return HUF_readStats_body(huffWeight, hwSize, rankStats,
                              nbSymbolsPtr, tableLogPtr,
                              src, srcSize,
                              workSpace, wkspSize, useBmi2);
}
#endif
326 | | |
327 | | size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats, |
328 | | U32* nbSymbolsPtr, U32* tableLogPtr, |
329 | | const void* src, size_t srcSize, |
330 | | void* workSpace, size_t wkspSize, |
331 | | int flags) |
332 | 25.0k | { |
333 | 25.0k | #if DYNAMIC_BMI2 |
334 | 25.0k | if (flags & HUF_flags_bmi2) { |
335 | 25.0k | return HUF_readStats_body_bmi2(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize); |
336 | 25.0k | } |
337 | 0 | #endif |
338 | 0 | (void)flags; |
339 | 0 | return HUF_readStats_body_default(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize); |
340 | 25.0k | } |