diff --git a/library/core/src/str/validations.rs b/library/core/src/str/validations.rs index 2acef432f2063..6172aa03666d7 100644 --- a/library/core/src/str/validations.rs +++ b/library/core/src/str/validations.rs @@ -112,128 +112,133 @@ where Some(ch) } +const WORD_BYTES: usize = mem::size_of::(); +/// A word-sized bitmask where every byte's MSB is set, +/// indicating a non-ASCII character. const NONASCII_MASK: usize = usize::repeat_u8(0x80); -/// Returns `true` if any byte in the word `x` is nonascii (>= 128). -#[inline] -const fn contains_nonascii(x: usize) -> bool { - (x & NONASCII_MASK) != 0 -} - -/// Walks through `v` checking that it's a valid UTF-8 sequence, +/// Walks through `buf` checking that it's a valid UTF-8 sequence, /// returning `Ok(())` in that case, or, if it is invalid, `Err(err)`. #[inline(always)] #[rustc_const_unstable(feature = "str_internals", issue = "none")] -pub(super) const fn run_utf8_validation(v: &[u8]) -> Result<(), Utf8Error> { - let mut index = 0; - let len = v.len(); - - let usize_bytes = mem::size_of::(); - let ascii_block_size = 2 * usize_bytes; - let blocks_end = if len >= ascii_block_size { len - ascii_block_size + 1 } else { 0 }; - let align = v.as_ptr().align_offset(usize_bytes); - - while index < len { - let old_offset = index; - macro_rules! err { - ($error_len: expr) => { - return Err(Utf8Error { valid_up_to: old_offset, error_len: $error_len }) - }; - } +pub(super) const fn run_utf8_validation(buf: &[u8]) -> Result<(), Utf8Error> { + // we check aligned blocks of up to 8 words at a time + const ASCII_BLOCK_8X: usize = 8 * WORD_BYTES; + const ASCII_BLOCK_4X: usize = 4 * WORD_BYTES; + const ASCII_BLOCK_2X: usize = 2 * WORD_BYTES; - macro_rules! next { - () => {{ - index += 1; - // we needed data, but there was none: error! - if index >= len { - err!(None) - } - v[index] - }}; - } + // establish buffer extent + let (mut curr, end) = (0, buf.len()); + let start = buf.as_ptr(); + // calculate the byte offset until the first word aligned block + let align_offset = start.align_offset(WORD_BYTES); + + // calculate the maximum byte at which a block of size N could begin, + // without taking alignment into account + let block_end_8x = block_end(end, ASCII_BLOCK_8X); + let block_end_4x = block_end(end, ASCII_BLOCK_4X); + let block_end_2x = block_end(end, ASCII_BLOCK_2X); + + while curr < end { + if buf[curr] < 128 { + // `align_offset` can basically only be `usize::MAX` for ZST + // pointers, so the first check is almost certainly optimized away + if align_offset == usize::MAX { + curr += 1; + continue; + } + + // check if `curr`'s pointer is word-aligned + let offset = align_offset.wrapping_sub(curr) % WORD_BYTES; + if offset == 0 { + let len = 'block: loop { + macro_rules! block_loop { + ($N:expr) => { + // SAFETY: we have checked before that there are + // still at least `N * size_of::()` in the + // buffer and that the current byte is word-aligned + let block = unsafe { &*(start.add(curr) as *const [usize; $N]) }; + if has_non_ascii_byte(block) { + break 'block Some($N); + } - let first = v[index]; - if first >= 128 { - let w = utf8_char_width(first); - // 2-byte encoding is for codepoints \u{0080} to \u{07ff} - // first C2 80 last DF BF - // 3-byte encoding is for codepoints \u{0800} to \u{ffff} - // first E0 A0 80 last EF BF BF - // excluding surrogates codepoints \u{d800} to \u{dfff} - // ED A0 80 to ED BF BF - // 4-byte encoding is for codepoints \u{1000}0 to \u{10ff}ff - // first F0 90 80 80 last F4 8F BF BF - // - // Use the UTF-8 syntax from the RFC - // - // https://tools.ietf.org/html/rfc3629 - // UTF8-1 = %x00-7F - // UTF8-2 = %xC2-DF UTF8-tail - // UTF8-3 = %xE0 %xA0-BF UTF8-tail / %xE1-EC 2( UTF8-tail ) / - // %xED %x80-9F UTF8-tail / %xEE-EF 2( UTF8-tail ) - // UTF8-4 = %xF0 %x90-BF 2( UTF8-tail ) / %xF1-F3 3( UTF8-tail ) / - // %xF4 %x80-8F 2( UTF8-tail ) - match w { - 2 => { - if next!() as i8 >= -64 { - err!(Some(1)) + curr += $N * WORD_BYTES; + }; } - } - 3 => { - match (first, next!()) { - (0xE0, 0xA0..=0xBF) - | (0xE1..=0xEC, 0x80..=0xBF) - | (0xED, 0x80..=0x9F) - | (0xEE..=0xEF, 0x80..=0xBF) => {} - _ => err!(Some(1)), + + // check 8-word blocks for non-ASCII bytes + while curr < block_end_8x { + block_loop!(8); } - if next!() as i8 >= -64 { - err!(Some(2)) + + // check 4-word blocks for non-ASCII bytes + while curr < block_end_4x { + block_loop!(4); } - } - 4 => { - match (first, next!()) { - (0xF0, 0x90..=0xBF) | (0xF1..=0xF3, 0x80..=0xBF) | (0xF4, 0x80..=0x8F) => {} - _ => err!(Some(1)), + + // check 2-word blocks for non-ASCII bytes + while curr < block_end_2x { + block_loop!(2); } - if next!() as i8 >= -64 { - err!(Some(2)) + + // `(size_of::() * 2) + (align_of:: - 1)` + // bytes remain at most + break None; + }; + + // if the block loops were stopped due to a non-ascii byte + // in some block, do another block-wise search using the last + // used block-size for the specific byte in the previous block + // in order to skip checking all bytes up to that one + // individually. + // NOTE: this operation does not auto-vectorize well, so it is + // done only in case a non-ASCII byte is actually found + if let Some(len) = len { + // SAFETY: `curr` has not changed since the last block loop, + // so it still points at a byte marking the beginning of a + // word-sized block of the given `len` + let block = unsafe { + let ptr = start.add(curr) as *const usize; + core::slice::from_raw_parts(ptr, len) + }; + + // calculate the amount of bytes that can be skipped without + // having to check them individually + let (skip, non_ascii) = non_ascii_byte_position(block); + curr += skip; + + // if a non-ASCII byte was found, skip the subsequent + // byte-wise loop and go straight back to the main loop + if non_ascii { + continue; } - if next!() as i8 >= -64 { - err!(Some(3)) + } + + // ...otherwise, fall back to byte-wise checks + while curr < end && buf[curr] < 128 { + curr += 1; + } + } else { + // byte is < 128 (ASCII), but pointer is not word-aligned, skip + // until the loop reaches the next word-aligned block) + let mut i = 0; + while i < offset { + // no need to check alignment again for every byte, so skip + // up to `offset` valid ASCII bytes if possible + curr += 1; + if !(curr < end && buf[curr] < 128) { + break; } + + i += 1; } - _ => err!(Some(1)), } - index += 1; } else { - // Ascii case, try to skip forward quickly. - // When the pointer is aligned, read 2 words of data per iteration - // until we find a word containing a non-ascii byte. - if align != usize::MAX && align.wrapping_sub(index) % usize_bytes == 0 { - let ptr = v.as_ptr(); - while index < blocks_end { - // SAFETY: since `align - index` and `ascii_block_size` are - // multiples of `usize_bytes`, `block = ptr.add(index)` is - // always aligned with a `usize` so it's safe to dereference - // both `block` and `block.add(1)`. - unsafe { - let block = ptr.add(index) as *const usize; - // break if there is a nonascii byte - let zu = contains_nonascii(*block); - let zv = contains_nonascii(*block.add(1)); - if zu || zv { - break; - } - } - index += ascii_block_size; - } - // step from the point where the wordwise loop stopped - while index < len && v[index] < 128 { - index += 1; - } - } else { - index += 1; + // non-ASCII case: validate up to 4 bytes, then advance `curr` + // accordingly + match validate_non_acii_bytes(buf, curr) { + Ok(next) => curr = next, + Err(e) => return Err(e), } } } @@ -241,33 +246,159 @@ pub(super) const fn run_utf8_validation(v: &[u8]) -> Result<(), Utf8Error> { Ok(()) } -// https://tools.ietf.org/html/rfc3629 -const UTF8_CHAR_WIDTH: &[u8; 256] = &[ - // 1 2 3 4 5 6 7 8 9 A B C D E F - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0 - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 1 - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 2 - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 3 - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 4 - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 5 - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 6 - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 7 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 8 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 9 - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // A - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // B - 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // C - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // D - 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, // E - 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // F -]; +#[inline] +const fn validate_non_acii_bytes(buf: &[u8], mut curr: usize) -> Result { + const fn subarray(buf: &[u8], idx: usize) -> Option<[u8; N]> { + if buf.len() - idx < N { + return None; + } + + // SAFETY: checked in previous condition + Some(unsafe { *(buf.as_ptr().add(idx) as *const [u8; N]) }) + } + + let prev = curr; + macro_rules! err { + ($error_len: expr) => { + return Err(Utf8Error { valid_up_to: prev, error_len: $error_len }) + }; + } + + let b0 = buf[curr]; + match utf8_char_width(b0) { + 2 => { + let Some([_, b1]) = subarray(buf, curr) else { + err!(None); + }; + + if b1 as i8 >= -64 { + err!(Some(1)); + } + + curr += 2; + } + 3 => { + let Some([_, b1, b2]) = subarray(buf, curr) else { + err!(None); + }; + + match (b0, b1) { + (0xE0, 0xA0..=0xBF) + | (0xE1..=0xEC, 0x80..=0xBF) + | (0xED, 0x80..=0x9F) + | (0xEE..=0xEF, 0x80..=0xBF) => {} + _ => err!(Some(1)), + } + + if b2 as i8 >= -64 { + err!(Some(2)); + } + + curr += 3; + } + 4 => { + let Some([_, b1, b2, b3]) = subarray(buf, curr) else { + err!(None); + }; + + match (b0, b1) { + (0xF0, 0x90..=0xBF) | (0xF1..=0xF3, 0x80..=0xBF) | (0xF4, 0x80..=0x8F) => {} + _ => err!(Some(1)), + } + + if b2 as i8 >= -64 { + err!(Some(2)); + } + + if b3 as i8 >= -64 { + err!(Some(3)); + } + + curr += 4; + } + _ => err!(Some(1)), + } + + Ok(curr) +} + +/// Returns `true` if any one block is not a valid ASCII byte. +#[inline(always)] +const fn has_non_ascii_byte(block: &[usize; N]) -> bool { + let mut vector = [0; N]; + + let mut i = 0; + while i < N { + vector[i] = block[i] & NONASCII_MASK; + i += 1; + } + + i = 0; + while i < N { + if vector[i] > 0 { + return true; + } + i += 1; + } + + false +} + +/// Returns the number of consecutive ASCII bytes within `block` until the first +/// non-ASCII byte and `true`, if a non-ASCII byte was found. +/// +/// Returns `block.len() * size_of::()` and `false`, if all bytes are +/// ASCII bytes. +#[inline(always)] +const fn non_ascii_byte_position(block: &[usize]) -> (usize, bool) { + let mut i = 0; + while i < block.len() { + let mask = block[i] & NONASCII_MASK; + let ctz = mask.trailing_zeros() as usize; + let byte = ctz / WORD_BYTES; + + if byte != WORD_BYTES { + return (byte + (i * WORD_BYTES), true); + } + + i += 1; + } + + (WORD_BYTES * block.len(), false) +} + +#[inline(always)] +const fn block_end(end: usize, block_size: usize) -> usize { + if end >= block_size { end - block_size + 1 } else { 0 } +} /// Given a first byte, determines how many bytes are in this UTF-8 character. #[unstable(feature = "str_internals", issue = "none")] #[must_use] #[inline] -pub const fn utf8_char_width(b: u8) -> usize { - UTF8_CHAR_WIDTH[b as usize] as usize +pub const fn utf8_char_width(byte: u8) -> usize { + // https://tools.ietf.org/html/rfc3629 + const UTF8_CHAR_WIDTH: [u8; 256] = [ + // 1 2 3 4 5 6 7 8 9 A B C D E F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 1 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 2 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 3 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 4 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 5 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 6 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 7 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 8 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 9 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // A + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // B + 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // C + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // D + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, // E + 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // F + ]; + + UTF8_CHAR_WIDTH[byte as usize] as usize } /// Mask of the value bits of a continuation byte.