diff --git a/core/src/string.rs b/core/src/string.rs index 8185c6213..35577fb6d 100644 --- a/core/src/string.rs +++ b/core/src/string.rs @@ -12,6 +12,7 @@ mod common; mod avm; mod buf; mod ops; +mod pattern; mod raw; mod slice; mod tables; @@ -24,6 +25,7 @@ pub use avm::AvmString; pub use buf::WString; pub use common::{BorrowWStr, BorrowWStrMut, Units}; pub use ops::{Iter, Split}; +pub use pattern::Pattern; pub use slice::{WStr, WStrMut}; use common::panic_on_invalid_length; diff --git a/core/src/string/avm.rs b/core/src/string/avm.rs index 8df0f83a2..bf13de21f 100644 --- a/core/src/string/avm.rs +++ b/core/src/string/avm.rs @@ -70,6 +70,7 @@ impl<'gc> AvmString<'gc> { lifetime: '_; self: &Self; deref: self.as_ucs2(); + pattern['a,]: 'a, &'a Self; } } diff --git a/core/src/string/buf.rs b/core/src/string/buf.rs index e08c4a5b1..ff326af5e 100644 --- a/core/src/string/buf.rs +++ b/core/src/string/buf.rs @@ -136,6 +136,7 @@ impl WString { lifetime: '_; self: &Self; deref: self.borrow(); + pattern['a,]: 'a, &'a Self; } impl_str_mut_methods! { diff --git a/core/src/string/common.rs b/core/src/string/common.rs index 233ce8721..7c6f0dadf 100644 --- a/core/src/string/common.rs +++ b/core/src/string/common.rs @@ -83,6 +83,7 @@ macro_rules! impl_str_methods { lifetime: $lt:lifetime; $self:ident: $receiver:ty; deref: $deref:expr; + pattern[$($pat_gen:tt)*]: $pat_lt:lifetime, $pat_self:ty; ) => { /// Provides access to the underlying buffer. #[inline] @@ -152,24 +153,29 @@ macro_rules! impl_str_methods { crate::string::ops::str_cmp_ignore_case($deref, other) } - /// Analogue of [`str::find`]. - // TODO: add our own Pattern trait to support several kinds of needles? + /// Returns `true` if the string contains only LATIN1 characters. + /// + /// Note that this doesn't necessarily mean that `self.is_wide()` is `false`.
#[inline] - pub fn find($self: $receiver, needle: WStr<'_>) -> Option { - crate::string::ops::str_find($deref, needle) + pub fn is_latin1($self: $receiver) -> bool { + crate::string::ops::str_is_latin1($deref) + } + + /// Analogue of [`str::find`]. + #[inline] + pub fn find<$($pat_gen)* P: crate::string::Pattern<$pat_lt>>($self: $pat_self, pattern: P) -> Option { + crate::string::ops::str_find($deref, pattern) } /// Analogue of [`str::rfind`]. - // TODO: add our own Pattern trait to support several kinds of needles? #[inline] - pub fn rfind($self: $receiver, needle: WStr<'_>) -> Option { - crate::string::ops::str_rfind($deref, needle) + pub fn rfind<$($pat_gen)* P: crate::string::Pattern<$pat_lt>>($self: $pat_self, pattern: P) -> Option { + crate::string::ops::str_rfind($deref, pattern) } /// Analogue of [`str::split`]. - // TODO: add our own Pattern trait to support several kinds of needles? #[inline] - pub fn split<'s>($self: $receiver, separator: WStr<'s>) -> crate::string::ops::Split<$lt, 's> { + pub fn split<$($pat_gen)* P: crate::string::Pattern<$pat_lt>>($self: $pat_self, separator: P) -> crate::string::ops::Split<$pat_lt, P> { crate::string::ops::str_split($deref, separator) } } diff --git a/core/src/string/ops.rs b/core/src/string/ops.rs index af99a0b3f..85eaf134e 100644 --- a/core/src/string/ops.rs +++ b/core/src/string/ops.rs @@ -2,7 +2,8 @@ use std::fmt::{self, Write}; use std::hash::Hasher; use std::slice::Iter as SliceIter; -use super::{utils, Units, WStr}; +use super::pattern::Searcher; +use super::{utils, Pattern, WStr, Units}; pub struct Iter<'a> { inner: Units, SliceIter<'a, u16>>, @@ -113,53 +114,56 @@ pub fn str_hash(s: WStr<'_>, state: &mut H) { } } -pub fn str_find(haystack: WStr<'_>, needle: WStr<'_>) -> Option { - let max = haystack.len().checked_sub(needle.len())?; - - (0..=max).find(|i| haystack.slice(*i..*i + needle.len()) == needle) +pub fn str_is_latin1(s: WStr<'_>) -> bool { + match s.units() { + Units::Bytes(_) => true, + 
Units::Wide(us) => us.iter().all(|c| *c <= u16::from(u8::MAX)), + } } -pub fn str_rfind(haystack: WStr<'_>, needle: WStr<'_>) -> Option { - let max = haystack.len().checked_sub(needle.len())?; +pub fn str_find<'a, P: Pattern<'a>>(haystack: WStr<'a>, pattern: P) -> Option { + pattern + .into_searcher(haystack) + .next_match() + .map(|(start, _)| start) +} - (0..=max) - .rev() - .find(|i| haystack.slice(*i..*i + needle.len()) == needle) +pub fn str_rfind<'a, P: Pattern<'a>>(haystack: WStr<'a>, pattern: P) -> Option { + pattern + .into_searcher(haystack) + .next_match_back() + .map(|(start, _)| start) } #[inline] -pub fn str_split<'a, 'b>(string: WStr<'a>, separator: WStr<'b>) -> Split<'a, 'b> { +pub fn str_split<'a, P: Pattern<'a>>(string: WStr<'a>, pattern: P) -> Split<'a, P> { Split { - string, - separator, - done: false, + string: Some(string), + searcher: pattern.into_searcher(string), + prev_end: 0, } } -pub struct Split<'a, 'b> { - string: WStr<'a>, - separator: WStr<'b>, - done: bool, +pub struct Split<'a, P: Pattern<'a>> { + string: Option>, + searcher: P::Searcher, + prev_end: usize, } -impl<'a, 'b> Iterator for Split<'a, 'b> { +impl<'a, P: Pattern<'a>> Iterator for Split<'a, P> { type Item = WStr<'a>; fn next(&mut self) -> Option { - if self.done { - return None; - } + let string = self.string?; - match self.string.find(self.separator) { - Some(i) => { - let prefix = self.string.slice(..i); - let suffix = self.string.slice((i + self.separator.len())..); - self.string = suffix; - Some(prefix) + match self.searcher.next_match() { + Some((start, end)) => { + let end = std::mem::replace(&mut self.prev_end, end); + Some(string.slice(end..start)) } None => { - self.done = true; - Some(self.string) + self.string = None; + Some(string.slice(self.prev_end..)) } } } diff --git a/core/src/string/pattern.rs b/core/src/string/pattern.rs new file mode 100644 index 000000000..34f74aeef --- /dev/null +++ b/core/src/string/pattern.rs @@ -0,0 +1,533 @@ +//! 
Like [`std::str::pattern::Pattern`], but for [`WStr`]. + +// TODO: Is performance good? Ideas for improvements: +// - add some inlines? +// - remove implicit bounds checks? +// - use memchr crate? + +use super::{WStr, Units}; + +/// A pattern that can be searched in a [`WStr`]. +/// +/// - `WStr` searches for the given string. +/// - `u8` searches for a single LATIN1 code unit. +/// - `u16` searches for a single UCS2 code unit. +/// - `&[u8]` searches for any of the given LATIN1 code units. +/// - `&[u16]` searches for any of the given UCS2 code units. +/// - `FnMut(u16) -> bool` searches for code units matching the predicate. +pub trait Pattern<'a> { + type Searcher: Searcher<'a>; + + fn into_searcher(self, haystack: WStr<'a>) -> Self::Searcher; +} + +pub enum SearchStep { + Match(usize, usize), + Reject(usize, usize), + Done, +} + +pub trait Searcher<'a> { + fn next(&mut self) -> SearchStep; + + fn next_back(&mut self) -> SearchStep; + + fn next_match(&mut self) -> Option<(usize, usize)> { + loop { + break match self.next() { + SearchStep::Match(i, j) => Some((i, j)), + SearchStep::Reject(_, _) => continue, + SearchStep::Done => None, + }; + } + } + + fn next_match_back(&mut self) -> Option<(usize, usize)> { + loop { + break match self.next_back() { + SearchStep::Match(i, j) => Some((i, j)), + SearchStep::Reject(_, _) => continue, + SearchStep::Done => None, + }; + } + } + + fn next_reject(&mut self) -> Option<(usize, usize)> { + loop { + break match self.next() { + SearchStep::Match(_, _) => continue, + SearchStep::Reject(i, j) => Some((i, j)), + SearchStep::Done => None, + }; + } + } + + fn next_reject_back(&mut self) -> Option<(usize, usize)> { + loop { + break match self.next_back() { + SearchStep::Match(_, _) => continue, + SearchStep::Reject(i, j) => Some((i, j)), + SearchStep::Done => None, + }; + } + } +} + +impl<'a> Pattern<'a> for u8 { + type Searcher = Either, PredSearcher<'a, u16, u16>>; + + fn into_searcher(self, haystack: WStr<'a>) -> Self::Searcher { + match
haystack.units() { + Units::Bytes(h) => Either::Left(PredSearcher::new(true, h, self)), + Units::Wide(h) => Either::Right(PredSearcher::new(true, h, self.into())), + } + } +} + +impl<'a> Pattern<'a> for u16 { + type Searcher = Either, PredSearcher<'a, u16, u16>>; + + fn into_searcher(self, haystack: WStr<'a>) -> Self::Searcher { + let is_latin1 = self <= u8::MAX as u16; + match haystack.units() { + Units::Bytes(h) => Either::Left(PredSearcher::new(is_latin1, h, self as u8)), + Units::Wide(h) => Either::Right(PredSearcher::new(true, h, self)), + } + } +} + +impl<'a> Pattern<'a> for &'a [u8] { + type Searcher = + Either>, PredSearcher<'a, u16, AnyOf<'a, u8>>>; + + fn into_searcher(self, haystack: WStr<'a>) -> Self::Searcher { + let can_match = !self.is_empty(); + match haystack.units() { + Units::Bytes(h) => Either::Left(PredSearcher::new(can_match, h, AnyOf(self))), + Units::Wide(h) => Either::Right(PredSearcher::new(can_match, h, AnyOf(self))), + } + } +} + +impl<'a> Pattern<'a> for &'a [u16] { + type Searcher = + Either>, PredSearcher<'a, u16, AnyOf<'a, u16>>>; + + fn into_searcher(self, haystack: WStr<'a>) -> Self::Searcher { + let can_match = + !self.is_empty() && (haystack.is_wide() || self.iter().any(|c| *c <= u8::MAX as u16)); + match haystack.units() { + Units::Bytes(h) => Either::Left(PredSearcher::new(can_match, h, AnyOf(self))), + Units::Wide(h) => Either::Right(PredSearcher::new(can_match, h, AnyOf(self))), + } + } +} + +impl<'a, F: FnMut(u16) -> bool> Pattern<'a> for F { + type Searcher = Either>, PredSearcher<'a, u16, FnPred>>; + + fn into_searcher(self, haystack: WStr<'a>) -> Self::Searcher { + match haystack.units() { + Units::Bytes(h) => Either::Left(PredSearcher::new(true, h, FnPred(self))), + Units::Wide(h) => Either::Right(PredSearcher::new(true, h, FnPred(self))), + } + } +} + +impl<'a> Pattern<'a> for WStr<'a> { + #[allow(clippy::type_complexity)] + type Searcher = Either< + Either, SliceSearcher<'a, u16>>, StrSearcher<'a>>, + EmptySearcher, + 
>; + + fn into_searcher(self, haystack: WStr<'a>) -> Self::Searcher { + if self.is_empty() { + return Either::Right(EmptySearcher::new(haystack.len())); + } + + let s = match (haystack.units(), self.units()) { + (Units::Bytes(h), Units::Bytes(n)) => Either::Left(SliceSearcher::new(h, n)), + (Units::Wide(h), Units::Wide(n)) => Either::Right(SliceSearcher::new(h, n)), + (Units::Bytes(_), _) if self.len() > haystack.len() || !self.is_latin1() => { + Either::Left(SliceSearcher::new(&[], &[0])) + } + _ => return Either::Left(Either::Right(StrSearcher::new(haystack, self))), + }; + + Either::Left(Either::Left(s)) + } +} + +pub enum Either { + Left(T), + Right(U), +} + +impl<'a, T: Searcher<'a>, U: Searcher<'a>> Searcher<'a> for Either { + fn next(&mut self) -> SearchStep { + match self { + Either::Left(s) => s.next(), + Either::Right(s) => s.next(), + } + } + + fn next_back(&mut self) -> SearchStep { + match self { + Either::Left(s) => s.next_back(), + Either::Right(s) => s.next_back(), + } + } + + fn next_match(&mut self) -> Option<(usize, usize)> { + match self { + Either::Left(s) => s.next_match(), + Either::Right(s) => s.next_match(), + } + } + + fn next_match_back(&mut self) -> Option<(usize, usize)> { + match self { + Either::Left(s) => s.next_match_back(), + Either::Right(s) => s.next_match_back(), + } + } + + fn next_reject(&mut self) -> Option<(usize, usize)> { + match self { + Either::Left(s) => s.next_reject(), + Either::Right(s) => s.next_reject(), + } + } + + fn next_reject_back(&mut self) -> Option<(usize, usize)> { + match self { + Either::Left(s) => s.next_reject_back(), + Either::Right(s) => s.next_reject_back(), + } + } +} + +pub struct EmptySearcher { + range: std::ops::Range, +} + +impl EmptySearcher { + // The empty needle matches on every char boundary. 
+ fn new(len: usize) -> Self { + Self { + range: 0..(len + 1), + } + } +} + +impl<'a> Searcher<'a> for EmptySearcher { + fn next(&mut self) -> SearchStep { + match self.range.next() { + Some(i) => SearchStep::Match(i, i), + None => SearchStep::Done, + } + } + + fn next_back(&mut self) -> SearchStep { + match self.range.next_back() { + Some(i) => SearchStep::Match(i, i), + None => SearchStep::Done, + } + } +} + +pub struct PredSearcher<'a, T, P> { + haystack: &'a [T], + predicate: P, + front: usize, +} + +pub trait Predicate { + fn is_match(&mut self, c: T) -> bool; +} + +impl Predicate for T { + fn is_match(&mut self, c: T) -> bool { + *self == c + } +} + +pub struct AnyOf<'a, T>(&'a [T]); + +impl<'a, T: Copy, U: Copy + Eq + TryFrom> Predicate for AnyOf<'a, U> { + fn is_match(&mut self, c: T) -> bool { + self.0.iter().any(|m| U::try_from(c).ok() == Some(*m)) + } +} + +pub struct FnPred(F); + +impl<'a, T: Into, F: FnMut(u16) -> bool> Predicate for FnPred { + fn is_match(&mut self, c: T) -> bool { + (self.0)(c.into()) + } +} + +impl<'a, T: Copy, P: Predicate> PredSearcher<'a, T, P> { + #[inline] + fn new(can_match: bool, haystack: &'a [T], predicate: P) -> Self { + Self { + haystack, + predicate, + front: if can_match { 0 } else { haystack.len() }, + } + } +} + +impl<'a, T: Copy, M: Predicate> Searcher<'a> for PredSearcher<'a, T, M> { + fn next(&mut self) -> SearchStep { + let c = match self.haystack.get(self.front) { + None => return SearchStep::Done, + Some(c) => *c, + }; + + let i = self.front; + self.front += 1; + if self.predicate.is_match(c) { + SearchStep::Match(i, i + 1) + } else { + SearchStep::Reject(i, i + 1) + } + } + + fn next_back(&mut self) -> SearchStep { + let len = self.haystack.len(); + if self.front >= len { + return SearchStep::Done; + } + let c = self.haystack[len - 1]; + self.haystack = &self.haystack[..len - 1]; + if self.predicate.is_match(c) { + SearchStep::Match(len - 1, len) + } else { + SearchStep::Reject(len - 1, len) + } + } +} + +pub 
struct SliceSearcher<'a, T> { + haystack: &'a [T], + needle: &'a [T], + front: usize, + back: usize, +} + +impl<'a, T> SliceSearcher<'a, T> { + fn new(haystack: &'a [T], needle: &'a [T]) -> Self { + debug_assert!(!needle.is_empty()); + let (front, back) = match haystack.len().checked_sub(needle.len()) { + Some(i) => (0, i), + None => (1, 0), + }; + Self { + haystack, + needle, + front, + back, + } + } +} + +impl<'a, T: Eq> Searcher<'a> for SliceSearcher<'a, T> { + fn next(&mut self) -> SearchStep { + if self.front > self.back { + return SearchStep::Done; + } + + let start = self.front; + let end = self.front + self.needle.len(); + if &self.haystack[start..end] == self.needle { + self.front = end; + SearchStep::Match(start, end) + } else { + self.front += 1; + SearchStep::Reject(start, start + 1) + } + } + + fn next_back(&mut self) -> SearchStep { + if self.front > self.back { + return SearchStep::Done; + } + + let start = self.back; + let end = self.back + self.needle.len(); + if &self.haystack[start..end] == self.needle { + if let Some(back) = start.checked_sub(self.needle.len()) { + self.back = back; + } else { + self.front = 1; + self.back = 0; + } + SearchStep::Match(start, end) + } else { + if self.back == 0 { + self.front = 1; + } else { + self.back -= 1; + } + SearchStep::Reject(end - 1, end) + } + } +} + +pub struct StrSearcher<'a> { + haystack: WStr<'a>, + needle: WStr<'a>, + front: usize, + back: usize, +} + +impl<'a> StrSearcher<'a> { + fn new(haystack: WStr<'a>, needle: WStr<'a>) -> Self { + debug_assert!(!needle.is_empty()); + let (front, back) = match haystack.len().checked_sub(needle.len()) { + Some(i) => (0, i), + None => (1, 0), + }; + Self { + haystack, + needle, + front, + back, + } + } +} + +impl<'a> Searcher<'a> for StrSearcher<'a> { + fn next(&mut self) -> SearchStep { + if self.front > self.back { + return SearchStep::Done; + } + + let start = self.front; + let end = self.front + self.needle.len(); + if self.haystack.slice(start..end) == 
self.needle { + self.front = end; + SearchStep::Match(start, end) + } else { + self.front += 1; + SearchStep::Reject(start, start + 1) + } + } + + fn next_back(&mut self) -> SearchStep { + if self.front > self.back { + return SearchStep::Done; + } + + let start = self.back; + let end = start + self.needle.len(); + if self.haystack.slice(start..end) == self.needle { + if let Some(back) = start.checked_sub(self.needle.len()) { + self.back = back; + } else { + self.front = 1; + self.back = 0; + } + SearchStep::Match(start, end) + } else { + if self.back == 0 { + self.front = 1; + } else { + self.back -= 1; + } + SearchStep::Reject(end - 1, end) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fmt::Debug; + + macro_rules! bstr { + ($str:literal) => { + WStr::from_units($str) + }; + } + + macro_rules! wstr { + ($($char:literal)*) => { + WStr::from_units(&[$($char as u16),*]) + } + } + + fn test_pattern<'a, P: Pattern<'a> + Clone + Debug>( + haystack: WStr<'a>, + pattern: P, + forwards: &[(usize, usize)], + backwards: Option<&[(usize, usize)]>, + ) { + let mut searcher = pattern.clone().into_searcher(haystack); + let mut actual: Vec<_> = std::iter::from_fn(|| searcher.next_match()).collect(); + assert_eq!( + actual, forwards, + "incorrect forwards matching: haystack={:?}; pattern={:?}", + haystack, pattern + ); + + searcher = pattern.clone().into_searcher(haystack); + actual = std::iter::from_fn(|| searcher.next_match_back()).collect(); + actual.reverse(); + assert_eq!( + actual, + backwards.unwrap_or(forwards), + "incorrect backwards matching: haystack={:?}; pattern={:?}", + haystack, + pattern + ); + } + + #[test] + fn char_patterns() { + test_pattern(bstr!(b"a"), b'a', &[(0, 1)], None); + + let bytes = bstr!(b"abaabbcab"); + test_pattern(bytes, b'b', &[(1, 2), (4, 5), (5, 6), (8, 9)], None); + test_pattern(bytes, b'd', &[], None); + test_pattern(bytes, 'c' as u16, &[(6, 7)], None); + test_pattern(bytes, '↓' as u16, &[], None); + + let wide = 
wstr!('↓''a''a''↓''a'); + test_pattern(wide, b'c', &[], None); + test_pattern(wide, '↓' as u16, &[(0, 1), (3, 4)], None); + } + + #[test] + fn multi_char_patterns() { + let bytes = bstr!(b"abcdabcd"); + let matches = &[(0, 1), (2, 3), (4, 5), (6, 7)]; + test_pattern(bytes, &[b'a', b'c'][..], matches, None); + test_pattern(bytes, &['a' as u16, 'c' as u16][..], matches, None); + + let wide = wstr!('↓''a''b''↓''b''c'); + test_pattern(wide, &[b'a', b'b'][..], &[(1, 2), (2, 3), (4, 5)], None); + test_pattern(wide, &['↓' as u16, '−' as u16][..], &[(0, 1), (3, 4)], None); + + // Don't test `FnMut(u16) -> bool` because it isn't `Debug` + } + + #[test] + fn str_patterns() { + test_pattern(bstr!(b"aa"), bstr!(b""), &[(0, 0), (1, 1), (2, 2)], None); + test_pattern(bstr!(b"abcde"), bstr!(b"abcde"), &[(0, 5)], None); + + let bytes = bstr!(b"bbabbbabbbba"); + let matches = &[(0, 2), (3, 5), (7, 9), (9, 11)]; + let matches_rev = &[(0, 2), (4, 6), (7, 9), (9, 11)]; + test_pattern(bytes, bstr!(b"bb"), matches, Some(matches_rev)); + test_pattern(bytes, wstr!('b''b'), matches, Some(matches_rev)); + + let wide = wstr!('↓''↓''a''a''↓''↓''a''a''↓''↓'); + test_pattern(wide, bstr!(b"aa"), &[(2, 4), (6, 8)], None); + test_pattern(wide, wstr!('↓''a'), &[(1, 3), (5, 7)], None); + } +} diff --git a/core/src/string/slice.rs b/core/src/string/slice.rs index 2e1320679..33b339dc8 100644 --- a/core/src/string/slice.rs +++ b/core/src/string/slice.rs @@ -46,6 +46,7 @@ impl<'a> WStr<'a> { lifetime: 'a; self: Self; deref: self; + pattern[]: 'a, Self; } } @@ -110,6 +111,7 @@ impl<'a> WStrMut<'a> { lifetime: '_; self: &Self; deref: self.borrow(); + pattern['b,]: 'b, &'b Self; } impl_str_mut_methods! {