Skip to content

Commit 209c1b4

Browse files
authored
Refactor Rabin Karp (TheAlgorithms#810)
* ref: refactor rabin karp * chore(tests): add some edge tests
1 parent c3da55f commit 209c1b4

File tree

1 file changed

+93
-107
lines changed

1 file changed

+93
-107
lines changed

src/string/rabin_karp.rs

Lines changed: 93 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,137 +1,123 @@
1-
const MODULUS: u16 = 101;
2-
const BASE: u16 = 256;
3-
4-
pub fn rabin_karp(target: &str, pattern: &str) -> Vec<usize> {
5-
// Quick exit
6-
if target.is_empty() || pattern.is_empty() || pattern.len() > target.len() {
1+
//! This module implements the Rabin-Karp string searching algorithm.
2+
//! It uses a rolling hash technique to find all occurrences of a pattern
3+
//! within a target string efficiently.
4+
5+
const MOD: usize = 101;
6+
const RADIX: usize = 256;
7+
8+
/// Finds all starting indices where the `pattern` appears in the `text`.
9+
///
10+
/// # Arguments
11+
/// * `text` - The string where the search is performed.
12+
/// * `pattern` - The substring pattern to search for.
13+
///
14+
/// # Returns
15+
/// A vector of starting indices where the pattern is found.
16+
pub fn rabin_karp(text: &str, pattern: &str) -> Vec<usize> {
17+
if text.is_empty() || pattern.is_empty() || pattern.len() > text.len() {
718
return vec![];
819
}
920

10-
let pattern_hash = hash(pattern);
21+
let pat_hash = compute_hash(pattern);
22+
let mut radix_pow = 1;
1123

12-
// Pre-calculate BASE^(n-1)
13-
let mut pow_rem: u16 = 1;
24+
// Compute RADIX^(n-1) % MOD
1425
for _ in 0..pattern.len() - 1 {
15-
pow_rem *= BASE;
16-
pow_rem %= MODULUS;
26+
radix_pow = (radix_pow * RADIX) % MOD;
1727
}
1828

1929
let mut rolling_hash = 0;
20-
let mut ret = vec![];
21-
for i in 0..=target.len() - pattern.len() {
30+
let mut result = vec![];
31+
for i in 0..=text.len() - pattern.len() {
2232
rolling_hash = if i == 0 {
23-
hash(&target[0..pattern.len()])
33+
compute_hash(&text[0..pattern.len()])
2434
} else {
25-
recalculate_hash(target, i - 1, i + pattern.len() - 1, rolling_hash, pow_rem)
35+
update_hash(text, i - 1, i + pattern.len() - 1, rolling_hash, radix_pow)
2636
};
27-
if rolling_hash == pattern_hash && pattern[..] == target[i..i + pattern.len()] {
28-
ret.push(i);
37+
if rolling_hash == pat_hash && pattern[..] == text[i..i + pattern.len()] {
38+
result.push(i);
2939
}
3040
}
31-
ret
41+
result
3242
}
3343

34-
// hash(s) is defined as BASE^(n-1) * s_0 + BASE^(n-2) * s_1 + ... + BASE^0 * s_(n-1)
35-
fn hash(s: &str) -> u16 {
36-
let mut res: u16 = 0;
37-
for &c in s.as_bytes().iter() {
38-
res = (res * BASE % MODULUS + c as u16) % MODULUS;
39-
}
40-
res
44+
/// Calculates the hash of a string using the Rabin-Karp formula.
45+
///
46+
/// # Arguments
47+
/// * `s` - The string to calculate the hash for.
48+
///
49+
/// # Returns
50+
/// The hash value of the string modulo `MOD`.
51+
fn compute_hash(s: &str) -> usize {
52+
let mut hash_val = 0;
53+
for &byte in s.as_bytes().iter() {
54+
hash_val = (hash_val * RADIX + byte as usize) % MOD;
55+
}
56+
hash_val
4157
}
4258

43-
// new_hash = (old_hash - BASE^(n-1) * s_(i-n)) * BASE + s_i
44-
fn recalculate_hash(
59+
/// Updates the rolling hash when shifting the search window.
60+
///
61+
/// # Arguments
62+
/// * `s` - The full text where the search is performed.
63+
/// * `old_idx` - The index of the character that is leaving the window.
64+
/// * `new_idx` - The index of the new character entering the window.
65+
/// * `old_hash` - The hash of the previous substring.
66+
/// * `radix_pow` - The precomputed value of RADIX^(n-1) % MOD.
67+
///
68+
/// # Returns
69+
/// The updated hash for the new substring.
70+
fn update_hash(
4571
s: &str,
46-
old_index: usize,
47-
new_index: usize,
48-
old_hash: u16,
49-
pow_rem: u16,
50-
) -> u16 {
72+
old_idx: usize,
73+
new_idx: usize,
74+
old_hash: usize,
75+
radix_pow: usize,
76+
) -> usize {
5177
let mut new_hash = old_hash;
52-
let (old_ch, new_ch) = (
53-
s.as_bytes()[old_index] as u16,
54-
s.as_bytes()[new_index] as u16,
55-
);
56-
new_hash = (new_hash + MODULUS - pow_rem * old_ch % MODULUS) % MODULUS;
57-
new_hash = (new_hash * BASE + new_ch) % MODULUS;
78+
let old_char = s.as_bytes()[old_idx] as usize;
79+
let new_char = s.as_bytes()[new_idx] as usize;
80+
new_hash = (new_hash + MOD - (old_char * radix_pow % MOD)) % MOD;
81+
new_hash = (new_hash * RADIX + new_char) % MOD;
5882
new_hash
5983
}
6084

6185
#[cfg(test)]
6286
mod tests {
6387
use super::*;
6488

65-
#[test]
66-
fn hi_hash() {
67-
let hash_result = hash("hi");
68-
assert_eq!(hash_result, 65);
69-
}
70-
71-
#[test]
72-
fn abr_hash() {
73-
let hash_result = hash("abr");
74-
assert_eq!(hash_result, 4);
75-
}
76-
77-
#[test]
78-
fn bra_hash() {
79-
let hash_result = hash("bra");
80-
assert_eq!(hash_result, 30);
81-
}
82-
83-
// Attribution to @pgimalac for his tests from Knuth-Morris-Pratt
84-
#[test]
85-
fn each_letter_matches() {
86-
let index = rabin_karp("aaa", "a");
87-
assert_eq!(index, vec![0, 1, 2]);
88-
}
89-
90-
#[test]
91-
fn a_few_separate_matches() {
92-
let index = rabin_karp("abababa", "ab");
93-
assert_eq!(index, vec![0, 2, 4]);
94-
}
95-
96-
#[test]
97-
fn one_match() {
98-
let index = rabin_karp("ABC ABCDAB ABCDABCDABDE", "ABCDABD");
99-
assert_eq!(index, vec![15]);
100-
}
101-
102-
#[test]
103-
fn lots_of_matches() {
104-
let index = rabin_karp("aaabaabaaaaa", "aa");
105-
assert_eq!(index, vec![0, 1, 4, 7, 8, 9, 10]);
106-
}
107-
108-
#[test]
109-
fn lots_of_intricate_matches() {
110-
let index = rabin_karp("ababababa", "aba");
111-
assert_eq!(index, vec![0, 2, 4, 6]);
112-
}
113-
114-
#[test]
115-
fn not_found0() {
116-
let index = rabin_karp("abcde", "f");
117-
assert_eq!(index, vec![]);
118-
}
119-
120-
#[test]
121-
fn not_found1() {
122-
let index = rabin_karp("abcde", "ac");
123-
assert_eq!(index, vec![]);
124-
}
125-
126-
#[test]
127-
fn not_found2() {
128-
let index = rabin_karp("ababab", "bababa");
129-
assert_eq!(index, vec![]);
89+
macro_rules! test_cases {
90+
($($name:ident: $inputs:expr,)*) => {
91+
$(
92+
#[test]
93+
fn $name() {
94+
let (text, pattern, expected) = $inputs;
95+
assert_eq!(rabin_karp(text, pattern), expected);
96+
}
97+
)*
98+
};
13099
}
131100

132-
#[test]
133-
fn empty_string() {
134-
let index = rabin_karp("", "abcdef");
135-
assert_eq!(index, vec![]);
101+
test_cases! {
102+
single_match_at_start: ("hello world", "hello", vec![0]),
103+
single_match_at_end: ("hello world", "world", vec![6]),
104+
single_match_in_middle: ("abc def ghi", "def", vec![4]),
105+
multiple_matches: ("ababcabc", "abc", vec![2, 5]),
106+
overlapping_matches: ("aaaaa", "aaa", vec![0, 1, 2]),
107+
no_match: ("abcdefg", "xyz", vec![]),
108+
pattern_is_entire_string: ("abc", "abc", vec![0]),
109+
target_is_multiple_patterns: ("abcabcabc", "abc", vec![0, 3, 6]),
110+
empty_text: ("", "abc", vec![]),
111+
empty_pattern: ("abc", "", vec![]),
112+
empty_text_and_pattern: ("", "", vec![]),
113+
pattern_larger_than_text: ("abc", "abcd", vec![]),
114+
large_text_small_pattern: (&("a".repeat(1000) + "b"), "b", vec![1000]),
115+
single_char_match: ("a", "a", vec![0]),
116+
single_char_no_match: ("a", "b", vec![]),
117+
large_pattern_no_match: ("abc", "defghi", vec![]),
118+
repeating_chars: ("aaaaaa", "aa", vec![0, 1, 2, 3, 4]),
119+
special_characters: ("abc$def@ghi", "$def@", vec![3]),
120+
numeric_and_alphabetic_mix: ("abc123abc456", "123abc", vec![3]),
121+
case_sensitivity: ("AbcAbc", "abc", vec![]),
136122
}
137123
}

0 commit comments

Comments
 (0)