Coverage Report

Created: 2023-04-25 07:07

/rust/registry/src/index.crates.io-6f17d22bba15001f/pulldown-cmark-0.8.0/src/escape.rs
Line
Count
Source (jump to first uncovered line)
1
// Copyright 2015 Google Inc. All rights reserved.
2
//
3
// Permission is hereby granted, free of charge, to any person obtaining a copy
4
// of this software and associated documentation files (the "Software"), to deal
5
// in the Software without restriction, including without limitation the rights
6
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
// copies of the Software, and to permit persons to whom the Software is
8
// furnished to do so, subject to the following conditions:
9
//
10
// The above copyright notice and this permission notice shall be included in
11
// all copies or substantial portions of the Software.
12
//
13
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19
// THE SOFTWARE.
20
21
//! Utility functions for HTML escaping. Only useful when building your own
22
//! HTML renderer.
23
24
use std::fmt::{Arguments, Write as FmtWrite};
25
use std::io::{self, ErrorKind, Write};
26
use std::str::from_utf8;
27
28
#[rustfmt::skip]
29
static HREF_SAFE: [u8; 128] = [
30
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
31
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
32
    0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
33
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
34
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
35
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1,
36
    0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
37
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
38
];
39
40
static HEX_CHARS: &[u8] = b"0123456789ABCDEF";
41
static AMP_ESCAPE: &str = "&amp;";
42
static SLASH_ESCAPE: &str = "&#x27;";
43
44
/// This wrapper exists because we can't have both a blanket implementation
45
/// for all types implementing `Write` and types of the form `&mut W` where
46
/// `W: StrWrite`. Since we need the latter a lot, we choose to wrap
47
/// `Write` types.
48
pub struct WriteWrapper<W>(pub W);
49
50
/// Trait that allows writing string slices. This is basically an extension
51
/// of `std::io::Write` in order to include `String`.
52
pub trait StrWrite {
53
    fn write_str(&mut self, s: &str) -> io::Result<()>;
54
55
    fn write_fmt(&mut self, args: Arguments) -> io::Result<()>;
56
}
57
58
impl<W> StrWrite for WriteWrapper<W>
59
where
60
    W: Write,
61
{
62
    #[inline]
63
0
    fn write_str(&mut self, s: &str) -> io::Result<()> {
64
0
        self.0.write_all(s.as_bytes())
65
0
    }
66
67
    #[inline]
68
0
    fn write_fmt(&mut self, args: Arguments) -> io::Result<()> {
69
0
        self.0.write_fmt(args)
70
0
    }
71
}
72
73
impl<'w> StrWrite for String {
74
    #[inline]
75
0
    fn write_str(&mut self, s: &str) -> io::Result<()> {
76
0
        self.push_str(s);
77
0
        Ok(())
78
0
    }
79
80
    #[inline]
81
0
    fn write_fmt(&mut self, args: Arguments) -> io::Result<()> {
82
0
        // FIXME: translate fmt error to io error?
83
0
        FmtWrite::write_fmt(self, args).map_err(|_| ErrorKind::Other.into())
84
0
    }
85
}
86
87
impl<W> StrWrite for &'_ mut W
88
where
89
    W: StrWrite,
90
{
91
    #[inline]
92
0
    fn write_str(&mut self, s: &str) -> io::Result<()> {
93
0
        (**self).write_str(s)
94
0
    }
95
96
    #[inline]
97
0
    fn write_fmt(&mut self, args: Arguments) -> io::Result<()> {
98
0
        (**self).write_fmt(args)
99
0
    }
100
}
101
102
/// Writes an href to the buffer, escaping href unsafe bytes.
103
0
pub fn escape_href<W>(mut w: W, s: &str) -> io::Result<()>
104
0
where
105
0
    W: StrWrite,
106
0
{
107
0
    let bytes = s.as_bytes();
108
0
    let mut mark = 0;
109
0
    for i in 0..bytes.len() {
110
0
        let c = bytes[i];
111
0
        if c >= 0x80 || HREF_SAFE[c as usize] == 0 {
112
            // character needing escape
113
114
            // write partial substring up to mark
115
0
            if mark < i {
116
0
                w.write_str(&s[mark..i])?;
117
0
            }
118
0
            match c {
119
                b'&' => {
120
0
                    w.write_str(AMP_ESCAPE)?;
121
                }
122
                b'\'' => {
123
0
                    w.write_str(SLASH_ESCAPE)?;
124
                }
125
                _ => {
126
0
                    let mut buf = [0u8; 3];
127
0
                    buf[0] = b'%';
128
0
                    buf[1] = HEX_CHARS[((c as usize) >> 4) & 0xF];
129
0
                    buf[2] = HEX_CHARS[(c as usize) & 0xF];
130
0
                    let escaped = from_utf8(&buf).unwrap();
131
0
                    w.write_str(escaped)?;
132
                }
133
            }
134
0
            mark = i + 1; // all escaped characters are ASCII
135
0
        }
136
    }
137
0
    w.write_str(&s[mark..])
138
0
}
139
140
0
const fn create_html_escape_table() -> [u8; 256] {
141
0
    let mut table = [0; 256];
142
0
    table[b'"' as usize] = 1;
143
0
    table[b'&' as usize] = 2;
144
0
    table[b'<' as usize] = 3;
145
0
    table[b'>' as usize] = 4;
146
0
    table
147
0
}
148
149
static HTML_ESCAPE_TABLE: [u8; 256] = create_html_escape_table();
150
151
static HTML_ESCAPES: [&'static str; 5] = ["", "&quot;", "&amp;", "&lt;", "&gt;"];
152
153
/// Writes the given string to the Write sink, replacing special HTML bytes
154
/// (<, >, &, ") by escape sequences.
155
0
pub fn escape_html<W: StrWrite>(w: W, s: &str) -> io::Result<()> {
156
0
    #[cfg(all(target_arch = "x86_64", feature = "simd"))]
157
0
    {
158
0
        simd::escape_html(w, s)
159
0
    }
160
0
    #[cfg(not(all(target_arch = "x86_64", feature = "simd")))]
161
0
    {
162
0
        escape_html_scalar(w, s)
163
0
    }
164
0
}
165
166
0
fn escape_html_scalar<W: StrWrite>(mut w: W, s: &str) -> io::Result<()> {
167
0
    let bytes = s.as_bytes();
168
0
    let mut mark = 0;
169
0
    let mut i = 0;
170
0
    while i < s.len() {
171
0
        match bytes[i..]
172
0
            .iter()
173
0
            .position(|&c| HTML_ESCAPE_TABLE[c as usize] != 0)
174
        {
175
0
            Some(pos) => {
176
0
                i += pos;
177
0
            }
178
0
            None => break,
179
        }
180
0
        let c = bytes[i];
181
0
        let escape = HTML_ESCAPE_TABLE[c as usize];
182
0
        let escape_seq = HTML_ESCAPES[escape as usize];
183
0
        w.write_str(&s[mark..i])?;
184
0
        w.write_str(escape_seq)?;
185
0
        i += 1;
186
0
        mark = i; // all escaped characters are ASCII
187
    }
188
0
    w.write_str(&s[mark..])
189
0
}
190
191
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
192
mod simd {
193
    use super::StrWrite;
194
    use std::arch::x86_64::*;
195
    use std::io;
196
    use std::mem::size_of;
197
198
    const VECTOR_SIZE: usize = size_of::<__m128i>();
199
200
    pub(crate) fn escape_html<W: StrWrite>(mut w: W, s: &str) -> io::Result<()> {
201
        // The SIMD accelerated code uses the PSHUFB instruction, which is part
202
        // of the SSSE3 instruction set. Further, we can only use this code if
203
        // the buffer is at least one VECTOR_SIZE in length to prevent reading
204
        // out of bounds. If either of these conditions is not met, we fall back
205
        // to scalar code.
206
        if is_x86_feature_detected!("ssse3") && s.len() >= VECTOR_SIZE {
207
            let bytes = s.as_bytes();
208
            let mut mark = 0;
209
210
            unsafe {
211
                foreach_special_simd(bytes, 0, |i| {
212
                    let escape_ix = *bytes.get_unchecked(i) as usize;
213
                    let replacement =
214
                        super::HTML_ESCAPES[super::HTML_ESCAPE_TABLE[escape_ix] as usize];
215
                    w.write_str(&s.get_unchecked(mark..i))?;
216
                    mark = i + 1; // all escaped characters are ASCII
217
                    w.write_str(replacement)
218
                })?;
219
                w.write_str(&s.get_unchecked(mark..))
220
            }
221
        } else {
222
            super::escape_html_scalar(w, s)
223
        }
224
    }
225
226
    /// Creates the lookup table for use in `compute_mask`.
227
    const fn create_lookup() -> [u8; 16] {
228
        let mut table = [0; 16];
229
        table[(b'<' & 0x0f) as usize] = b'<';
230
        table[(b'>' & 0x0f) as usize] = b'>';
231
        table[(b'&' & 0x0f) as usize] = b'&';
232
        table[(b'"' & 0x0f) as usize] = b'"';
233
        table[0] = 0b0111_1111;
234
        table
235
    }
236
237
    #[target_feature(enable = "ssse3")]
238
    /// Computes a byte mask at given offset in the byte buffer. Its first 16 (least significant)
239
    /// bits correspond to whether there is an HTML special byte (&, <, ", >) at the 16 bytes
240
    /// `bytes[offset..]`. For example, the mask `(1 << 3)` states that there is an HTML byte
241
    /// at `offset + 3`. It is only safe to call this function when
242
    /// `bytes.len() >= offset + VECTOR_SIZE`.
243
    unsafe fn compute_mask(bytes: &[u8], offset: usize) -> i32 {
244
        debug_assert!(bytes.len() >= offset + VECTOR_SIZE);
245
246
        let table = create_lookup();
247
        let lookup = _mm_loadu_si128(table.as_ptr() as *const __m128i);
248
        let raw_ptr = bytes.as_ptr().offset(offset as isize) as *const __m128i;
249
250
        // Load the vector from memory.
251
        let vector = _mm_loadu_si128(raw_ptr);
252
        // We take the least significant 4 bits of every byte and use them as indices
253
        // to map into the lookup vector.
254
        // Note that shuffle maps bytes with their most significant bit set to lookup[0].
255
        // Bytes that share their lower nibble with an HTML special byte get mapped to that
256
        // corresponding special byte. Note that all HTML special bytes have distinct lower
257
        // nibbles. Other bytes either get mapped to 0 or 127.
258
        let expected = _mm_shuffle_epi8(lookup, vector);
259
        // We compare the original vector to the mapped output. Bytes that shared a lower
260
        // nibble with an HTML special byte match *only* if they are that special byte. Bytes
261
        // that have either a 0 lower nibble or their most significant bit set were mapped to
262
        // 127 and will hence never match. All other bytes have non-zero lower nibbles but
263
        // were mapped to 0 and will therefore also not match.
264
        let matches = _mm_cmpeq_epi8(expected, vector);
265
266
        // Translate matches to a bitmask, where every 1 corresponds to a HTML special character
267
        // and a 0 is a non-HTML byte.
268
        _mm_movemask_epi8(matches)
269
    }
270
271
    /// Calls the given function with the index of every byte in the given byteslice
272
    /// that is either ", &, <, or > and for no other byte.
273
    /// Make sure to only call this when `bytes.len() >= 16`, undefined behaviour may
274
    /// occur otherwise.
275
    #[target_feature(enable = "ssse3")]
276
    unsafe fn foreach_special_simd<F>(
277
        bytes: &[u8],
278
        mut offset: usize,
279
        mut callback: F,
280
    ) -> io::Result<()>
281
    where
282
        F: FnMut(usize) -> io::Result<()>,
283
    {
284
        // The strategy here is to walk the byte buffer in chunks of VECTOR_SIZE (16)
285
        // bytes at a time starting at the given offset. For each chunk, we compute a
286
        // bitmask indicating whether the corresponding byte is an HTML special byte.
287
        // We then iterate over all the 1 bits in this mask and call the callback function
288
        // with the corresponding index in the buffer.
289
        // When the number of HTML special bytes in the buffer is relatively low, this
290
        // allows us to quickly go through the buffer without a lookup for every
291
        // single byte.
292
293
        debug_assert!(bytes.len() >= VECTOR_SIZE);
294
        let upperbound = bytes.len() - VECTOR_SIZE;
295
        while offset < upperbound {
296
            let mut mask = compute_mask(bytes, offset);
297
            while mask != 0 {
298
                let ix = mask.trailing_zeros();
299
                callback(offset + ix as usize)?;
300
                mask ^= mask & -mask;
301
            }
302
            offset += VECTOR_SIZE;
303
        }
304
305
        // Final iteration. We align the read with the end of the slice and
306
        // shift off the bytes at start we have already scanned.
307
        let mut mask = compute_mask(bytes, upperbound);
308
        mask >>= offset - upperbound;
309
        while mask != 0 {
310
            let ix = mask.trailing_zeros();
311
            callback(offset + ix as usize)?;
312
            mask ^= mask & -mask;
313
        }
314
        Ok(())
315
    }
316
317
    #[cfg(test)]
318
    mod html_scan_tests {
319
        #[test]
320
        fn multichunk() {
321
            let mut vec = Vec::new();
322
            unsafe {
323
                super::foreach_special_simd("&aXaaaa.a'aa9a<>aab&".as_bytes(), 0, |ix| {
324
                    Ok(vec.push(ix))
325
                })
326
                .unwrap();
327
            }
328
            assert_eq!(vec, vec![0, 14, 15, 19]);
329
        }
330
331
        // only match these bytes, and when we match them, match them VECTOR_SIZE times
332
        #[test]
333
        fn only_right_bytes_matched() {
334
            for b in 0..255u8 {
335
                let right_byte = b == b'&' || b == b'<' || b == b'>' || b == b'"';
336
                let vek = vec![b; super::VECTOR_SIZE];
337
                let mut match_count = 0;
338
                unsafe {
339
                    super::foreach_special_simd(&vek, 0, |_| {
340
                        match_count += 1;
341
                        Ok(())
342
                    })
343
                    .unwrap();
344
                }
345
                assert!((match_count > 0) == (match_count == super::VECTOR_SIZE));
346
                assert_eq!(
347
                    (match_count == super::VECTOR_SIZE),
348
                    right_byte,
349
                    "match_count: {}, byte: {:?}",
350
                    match_count,
351
                    b as char
352
                );
353
            }
354
        }
355
    }
356
}