Coverage Report

Created: 2025-06-02 07:01

/rust/registry/src/index.crates.io-6f17d22bba15001f/ucd-trie-0.1.6/src/owned.rs
Line
Count
Source (jump to first uncovered line)
1
use std::borrow::Borrow;
2
use std::collections::HashMap;
3
use std::error;
4
use std::fmt;
5
use std::io;
6
use std::result;
7
8
use super::{TrieSetSlice, CHUNK_SIZE};
9
10
// This implementation was pretty much cribbed from raphlinus' contribution
11
// to the standard library: https://github.com/rust-lang/rust/pull/33098/files
12
//
13
// The fundamental principle guiding this implementation is to take advantage
14
// of the fact that similar Unicode codepoints are often grouped together, and
15
// that most boolean Unicode properties are quite sparse over the entire space
16
// of Unicode codepoints.
17
//
18
// To do this, we represent sets using something like a trie (which gives us
19
// prefix compression). The "final" states of the trie are embedded in leaves
20
// or "chunks," where each chunk is a 64 bit integer. Each bit position of the
21
// integer corresponds to whether a particular codepoint is in the set or not.
22
// These chunks are not just a compact representation of the final states of
23
// the trie, but are also a form of suffix compression. In particular, if
24
// multiple ranges of 64 contiguous codepoints have the same set membership
25
// ordering, then they all map to the exact same chunk in the trie.
26
//
27
// We organize this structure by partitioning the space of Unicode codepoints
28
// into three disjoint sets. The first set corresponds to codepoints
29
// [0, 0x800), the second [0x800, 0x10000) and the third [0x10000, 0x110000).
30
// These partitions conveniently correspond to the space of 1 or 2 byte UTF-8
31
// encoded codepoints, 3 byte UTF-8 encoded codepoints and 4 byte UTF-8 encoded
32
// codepoints, respectively.
33
//
34
// Each partition has its own tree with its own root. The first partition is
35
// the simplest, since the tree is completely flat. In particular, to determine
36
// the set membership of a Unicode codepoint (that is less than `0x800`), we
37
// do the following (where `cp` is the codepoint we're testing):
38
//
39
//     let chunk_address = cp >> 6;
40
//     let chunk_bit = cp & 0b111111;
41
//     let chunk = tree1[chunk_address];
42
//     let is_member = 1 == ((chunk >> chunk_bit) & 1);
43
//
44
// We do something similar for the second partition:
45
//
46
//     // we subtract 0x20 since (0x800 >> 6) == 0x20.
47
//     let child_address = (cp >> 6) - 0x20;
48
//     let chunk_address = tree2_level1[child_address];
49
//     let chunk_bit = cp & 0b111111;
50
//     let chunk = tree2_level2[chunk_address];
51
//     let is_member = 1 == ((chunk >> chunk_bit) & 1);
52
//
53
// And so on for the third partition.
54
//
55
// Note that as a special case, if the second or third partitions are empty,
56
// then the trie will store empty slices for those levels. The `contains`
57
// check knows to return `false` in those cases.
58
59
// Number of `CHUNK_SIZE`-codepoint chunks needed to cover the whole codepoint
// space `[0, 0x110000)`. `CHUNK_SIZE` is imported from the parent module.
const CHUNKS: usize = 0x110000 / CHUNK_SIZE;
60
61
/// A type alias that maps to `std::result::Result<T, ucd_trie::Error>`.
62
pub type Result<T> = result::Result<T, Error>;
63
64
/// An error that can occur during construction of a trie.
65
#[derive(Clone, Debug)]
66
pub enum Error {
67
    /// This error is returned when an invalid codepoint is given to
68
    /// `TrieSetOwned::from_codepoints`. An invalid codepoint is a `u32` that
69
    /// is greater than `0x10FFFF`.
    ///
    /// The payload is the offending value that was rejected.
70
    InvalidCodepoint(u32),
71
    /// This error is returned when a set of Unicode codepoints could not be
72
    /// sufficiently compressed into the trie provided by this crate. There is
73
    /// no work-around for this error at this time.
    ///
    /// Concretely, this happens when a trie level would need more than 256
    /// distinct children, so that a child index no longer fits in a `u8`.
74
    GaveUp,
75
}
76
77
// No trait methods are overridden here; the message text comes from the
// `Display` and `Debug` implementations of `Error`.
impl error::Error for Error {}
78
79
impl fmt::Display for Error {
80
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81
0
        match *self {
82
0
            Error::InvalidCodepoint(cp) => write!(
83
0
                f,
84
0
                "could not construct trie set containing an \
85
0
                 invalid Unicode codepoint: 0x{:X}",
86
0
                cp
87
0
            ),
88
            Error::GaveUp => {
89
0
                write!(f, "could not compress codepoint set into a trie")
90
            }
91
        }
92
0
    }
93
}
94
95
impl From<Error> for io::Error {
96
0
    fn from(err: Error) -> io::Error {
97
0
        io::Error::new(io::ErrorKind::Other, err)
98
0
    }
99
}
100
101
/// An owned trie set.
///
/// The set is stored as three independent trees, one per codepoint range:
/// `[0, 0x800)`, `[0x800, 0x10000)` and `[0x10000, 0x110000)`. The leaves of
/// each tree are `u64` bit chunks in which every bit records the membership
/// of one codepoint.
102
#[derive(Clone)]
103
pub struct TrieSetOwned {
104
    // Tree 1, `[0, 0x800)`: the bit chunks themselves, stored flat.
    tree1_level1: Vec<u64>,
105
    // Tree 2, `[0x800, 0x10000)`: for each chunk-sized run of codepoints, the
    // position of its bit chunk in `tree2_level2`. Empty when the range has no
    // members.
    tree2_level1: Vec<u8>,
106
    // Tree 2: the distinct bit chunks, in order of first appearance. Empty
    // when the range has no members.
    tree2_level2: Vec<u64>,
107
    // Tree 3, `[0x10000, 0x110000)`: for each group of 64 chunk-sized runs,
    // the ordinal of its 64-entry block within `tree3_level2`. Empty when the
    // range has no members.
    tree3_level1: Vec<u8>,
108
    // Tree 3: the distinct 64-entry blocks, concatenated; each entry is a
    // position in `tree3_level3`. Empty when the range has no members.
    tree3_level2: Vec<u8>,
109
    // Tree 3: the distinct bit chunks, in order of first appearance. Empty
    // when the range has no members.
    tree3_level3: Vec<u64>,
110
}
111
112
impl fmt::Debug for TrieSetOwned {
113
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114
0
        write!(f, "TrieSetOwned(...)")
115
0
    }
116
}
117
118
impl TrieSetOwned {
119
0
    fn new(all: &[bool]) -> Result<TrieSetOwned> {
120
0
        let mut bitvectors = Vec::with_capacity(CHUNKS);
121
0
        for i in 0..CHUNKS {
122
0
            let mut bitvector = 0u64;
123
0
            for j in 0..CHUNK_SIZE {
124
0
                if all[i * CHUNK_SIZE + j] {
125
0
                    bitvector |= 1 << j;
126
0
                }
127
            }
128
0
            bitvectors.push(bitvector);
129
        }
130
131
0
        let tree1_level1 =
132
0
            bitvectors.iter().cloned().take(0x800 / CHUNK_SIZE).collect();
133
134
0
        let (mut tree2_level1, mut tree2_level2) = compress_postfix_leaves(
135
0
            &bitvectors[0x800 / CHUNK_SIZE..0x10000 / CHUNK_SIZE],
136
0
        )?;
137
0
        if tree2_level2.len() == 1 && tree2_level2[0] == 0 {
138
0
            tree2_level1.clear();
139
0
            tree2_level2.clear();
140
0
        }
141
142
0
        let (mid, mut tree3_level3) = compress_postfix_leaves(
143
0
            &bitvectors[0x10000 / CHUNK_SIZE..0x110000 / CHUNK_SIZE],
144
0
        )?;
145
0
        let (mut tree3_level1, mut tree3_level2) =
146
0
            compress_postfix_mid(&mid, 64)?;
147
0
        if tree3_level3.len() == 1 && tree3_level3[0] == 0 {
148
0
            tree3_level1.clear();
149
0
            tree3_level2.clear();
150
0
            tree3_level3.clear();
151
0
        }
152
153
0
        Ok(TrieSetOwned {
154
0
            tree1_level1,
155
0
            tree2_level1,
156
0
            tree2_level2,
157
0
            tree3_level1,
158
0
            tree3_level2,
159
0
            tree3_level3,
160
0
        })
161
0
    }
162
163
    /// Create a new trie set from a set of Unicode scalar values.
164
    ///
165
    /// This returns an error if a set could not be sufficiently compressed to
166
    /// fit into a trie.
167
0
    pub fn from_scalars<I, C>(scalars: I) -> Result<TrieSetOwned>
168
0
    where
169
0
        I: IntoIterator<Item = C>,
170
0
        C: Borrow<char>,
171
0
    {
172
0
        let mut all = vec![false; 0x110000];
173
0
        for s in scalars {
174
0
            all[*s.borrow() as usize] = true;
175
0
        }
176
0
        TrieSetOwned::new(&all)
177
0
    }
178
179
    /// Create a new trie set from a set of Unicode scalar values.
180
    ///
181
    /// This returns an error if a set could not be sufficiently compressed to
182
    /// fit into a trie. This also returns an error if any of the given
183
    /// codepoints are greater than `0x10FFFF`.
184
0
    pub fn from_codepoints<I, C>(codepoints: I) -> Result<TrieSetOwned>
185
0
    where
186
0
        I: IntoIterator<Item = C>,
187
0
        C: Borrow<u32>,
188
0
    {
189
0
        let mut all = vec![false; 0x110000];
190
0
        for cp in codepoints {
191
0
            let cp = *cp.borrow();
192
0
            if cp > 0x10FFFF {
193
0
                return Err(Error::InvalidCodepoint(cp));
194
0
            }
195
0
            all[cp as usize] = true;
196
        }
197
0
        TrieSetOwned::new(&all)
198
0
    }
199
200
    /// Return this set as a slice.
201
    #[inline(always)]
202
0
    pub fn as_slice(&self) -> TrieSetSlice<'_> {
203
0
        TrieSetSlice {
204
0
            tree1_level1: &self.tree1_level1,
205
0
            tree2_level1: &self.tree2_level1,
206
0
            tree2_level2: &self.tree2_level2,
207
0
            tree3_level1: &self.tree3_level1,
208
0
            tree3_level2: &self.tree3_level2,
209
0
            tree3_level3: &self.tree3_level3,
210
0
        }
211
0
    }
212
213
    /// Returns true if and only if the given Unicode scalar value is in this
214
    /// set.
215
0
    pub fn contains_char(&self, c: char) -> bool {
216
0
        self.as_slice().contains_char(c)
217
0
    }
218
219
    /// Returns true if and only if the given codepoint is in this set.
220
    ///
221
    /// If the given value exceeds the codepoint range (i.e., it's greater
222
    /// than `0x10FFFF`), then this returns false.
223
0
    pub fn contains_u32(&self, cp: u32) -> bool {
224
0
        self.as_slice().contains_u32(cp)
225
0
    }
226
}
227
228
0
fn compress_postfix_leaves(chunks: &[u64]) -> Result<(Vec<u8>, Vec<u64>)> {
229
0
    let mut root = vec![];
230
0
    let mut children = vec![];
231
0
    let mut bychild = HashMap::new();
232
0
    for &chunk in chunks {
233
0
        if !bychild.contains_key(&chunk) {
234
0
            let start = bychild.len();
235
0
            if start > ::std::u8::MAX as usize {
236
0
                return Err(Error::GaveUp);
237
0
            }
238
0
            bychild.insert(chunk, start as u8);
239
0
            children.push(chunk);
240
0
        }
241
0
        root.push(bychild[&chunk]);
242
    }
243
0
    Ok((root, children))
244
0
}
245
246
0
fn compress_postfix_mid(
247
0
    chunks: &[u8],
248
0
    chunk_size: usize,
249
0
) -> Result<(Vec<u8>, Vec<u8>)> {
250
0
    let mut root = vec![];
251
0
    let mut children = vec![];
252
0
    let mut bychild = HashMap::new();
253
0
    for i in 0..(chunks.len() / chunk_size) {
254
0
        let chunk = &chunks[i * chunk_size..(i + 1) * chunk_size];
255
0
        if !bychild.contains_key(chunk) {
256
0
            let start = bychild.len();
257
0
            if start > ::std::u8::MAX as usize {
258
0
                return Err(Error::GaveUp);
259
0
            }
260
0
            bychild.insert(chunk, start as u8);
261
0
            children.extend(chunk);
262
0
        }
263
0
        root.push(bychild[chunk]);
264
    }
265
0
    Ok((root, children))
266
0
}
267
268
#[cfg(test)]
mod tests {
    use super::TrieSetOwned;
    use crate::general_category;
    use std::collections::HashSet;

    /// Builds a trie set from the given scalar values, panicking on failure.
    fn mk(scalars: &[char]) -> TrieSetOwned {
        TrieSetOwned::from_scalars(scalars).unwrap()
    }

    /// Expands inclusive `(start, end)` ranges into the codepoints they cover.
    fn ranges_to_set(ranges: &[(u32, u32)]) -> Vec<u32> {
        // An inclusive range avoids the overflow that `end + 1` would hit for
        // `end == u32::MAX`.
        ranges.iter().flat_map(|&(start, end)| start..=end).collect()
    }

    #[test]
    fn set1() {
        let set = mk(&['a']);
        assert!(set.contains_char('a'));
        assert!(!set.contains_char('b'));
        assert!(!set.contains_char('β'));
        assert!(!set.contains_char('☃'));
        assert!(!set.contains_char('😼'));
    }

    #[test]
    fn set_combined() {
        let set = mk(&['a', 'b', 'β', '☃', '😼']);
        assert!(set.contains_char('a'));
        assert!(set.contains_char('b'));
        assert!(set.contains_char('β'));
        assert!(set.contains_char('☃'));
        assert!(set.contains_char('😼'));

        assert!(!set.contains_char('c'));
        assert!(!set.contains_char('θ'));
        assert!(!set.contains_char('⛇'));
        assert!(!set.contains_char('🐲'));
    }

    // Basic tests on all of the general category sets. We check that
    // membership is correct on every Unicode codepoint... because we can.

    macro_rules! category_test {
        ($name:ident, $ranges:ident) => {
            #[test]
            fn $name() {
                let set = ranges_to_set(general_category::$ranges);
                let hashset: HashSet<u32> = set.iter().cloned().collect();
                let trie = TrieSetOwned::from_codepoints(&set).unwrap();
                for cp in 0..0x110000 {
                    assert_eq!(
                        trie.contains_u32(cp),
                        hashset.contains(&cp),
                        "membership mismatch at codepoint 0x{:X}",
                        cp
                    );
                }
                // Test that an invalid codepoint is treated correctly.
                assert!(!trie.contains_u32(0x110000));
                assert!(!hashset.contains(&0x110000));
            }
        };
    }

    category_test!(gencat_cased_letter, CASED_LETTER);
    category_test!(gencat_close_punctuation, CLOSE_PUNCTUATION);
    category_test!(gencat_connector_punctuation, CONNECTOR_PUNCTUATION);
    category_test!(gencat_control, CONTROL);
    category_test!(gencat_currency_symbol, CURRENCY_SYMBOL);
    category_test!(gencat_dash_punctuation, DASH_PUNCTUATION);
    category_test!(gencat_decimal_number, DECIMAL_NUMBER);
    category_test!(gencat_enclosing_mark, ENCLOSING_MARK);
    category_test!(gencat_final_punctuation, FINAL_PUNCTUATION);
    category_test!(gencat_format, FORMAT);
    category_test!(gencat_initial_punctuation, INITIAL_PUNCTUATION);
    category_test!(gencat_letter, LETTER);
    category_test!(gencat_letter_number, LETTER_NUMBER);
    category_test!(gencat_line_separator, LINE_SEPARATOR);
    category_test!(gencat_lowercase_letter, LOWERCASE_LETTER);
    category_test!(gencat_math_symbol, MATH_SYMBOL);
    category_test!(gencat_mark, MARK);
    category_test!(gencat_modifier_letter, MODIFIER_LETTER);
    category_test!(gencat_modifier_symbol, MODIFIER_SYMBOL);
    category_test!(gencat_nonspacing_mark, NONSPACING_MARK);
    category_test!(gencat_number, NUMBER);
    category_test!(gencat_open_punctuation, OPEN_PUNCTUATION);
    category_test!(gencat_other, OTHER);
    category_test!(gencat_other_letter, OTHER_LETTER);
    category_test!(gencat_other_number, OTHER_NUMBER);
    category_test!(gencat_other_punctuation, OTHER_PUNCTUATION);
    category_test!(gencat_other_symbol, OTHER_SYMBOL);
    category_test!(gencat_paragraph_separator, PARAGRAPH_SEPARATOR);
    category_test!(gencat_private_use, PRIVATE_USE);
    category_test!(gencat_punctuation, PUNCTUATION);
    category_test!(gencat_separator, SEPARATOR);
    category_test!(gencat_space_separator, SPACE_SEPARATOR);
    category_test!(gencat_spacing_mark, SPACING_MARK);
    category_test!(gencat_surrogate, SURROGATE);
    category_test!(gencat_symbol, SYMBOL);
    category_test!(gencat_titlecase_letter, TITLECASE_LETTER);
    category_test!(gencat_unassigned, UNASSIGNED);
    category_test!(gencat_uppercase_letter, UPPERCASE_LETTER);
}