|
1 |
| -use std::cmp::min; |
| 1 | +//! A module providing a Segment Tree data structure for efficient range queries |
| 2 | +//! and updates. It supports operations like finding the minimum, maximum, |
| 3 | +//! and sum of segments in an array. |
| 4 | +
|
2 | 5 | use std::fmt::Debug;
|
3 | 6 | use std::ops::Range;
|
4 | 7 |
|
5 |
| -/// This data structure implements a segment-tree that can efficiently answer range (interval) queries on arrays. |
6 |
| -/// It represents this array as a binary tree of merged intervals. From top to bottom: [aggregated value for the overall array], then [left-hand half, right hand half], etc. until [each individual value, ...] |
7 |
| -/// It is generic over a reduction function for each segment or interval: basically, to describe how we merge two intervals together. |
8 |
| -/// Note that this function should be commutative and associative |
9 |
| -/// It could be `std::cmp::min(interval_1, interval_2)` or `std::cmp::max(interval_1, interval_2)`, or `|a, b| a + b`, `|a, b| a * b` |
10 |
| -pub struct SegmentTree<T: Debug + Default + Ord + Copy> { |
11 |
| - len: usize, // length of the represented |
12 |
| - tree: Vec<T>, // represents a binary tree of intervals as an array (as a BinaryHeap does, for instance) |
13 |
| - merge: fn(T, T) -> T, // how we merge two values together |
| 8 | +/// Custom error types representing possible errors that can occur during operations on the `SegmentTree`. |
| 9 | +#[derive(Debug, PartialEq, Eq)] |
| 10 | +pub enum SegmentTreeError { |
| 11 | + /// Error indicating that an index is out of bounds. |
| 12 | + IndexOutOfBounds, |
| 13 | + /// Error indicating that a range provided for a query is invalid. |
| 14 | + InvalidRange, |
| 15 | +} |
| 16 | + |
| 17 | +/// A structure representing a Segment Tree. This tree can be used to efficiently |
| 18 | +/// perform range queries and updates on an array of elements. |
| 19 | +pub struct SegmentTree<T, F> |
| 20 | +where |
| 21 | + T: Debug + Default + Ord + Copy, |
| 22 | + F: Fn(T, T) -> T, |
| 23 | +{ |
| 24 | + /// The length of the input array for which the segment tree is built. |
| 25 | + size: usize, |
| 26 | + /// A vector representing the segment tree. |
| 27 | + nodes: Vec<T>, |
| 28 | + /// A merging function defined as a closure or callable type. |
| 29 | + merge_fn: F, |
14 | 30 | }
|
15 | 31 |
|
16 |
| -impl<T: Debug + Default + Ord + Copy> SegmentTree<T> { |
17 |
| - /// Builds a SegmentTree from an array and a merge function |
18 |
| - pub fn from_vec(arr: &[T], merge: fn(T, T) -> T) -> Self { |
19 |
| - let len = arr.len(); |
20 |
| - let mut buf: Vec<T> = vec![T::default(); 2 * len]; |
21 |
| - // Populate the tree bottom-up, from right to left |
22 |
| - buf[len..(2 * len)].clone_from_slice(&arr[0..len]); // last len pos is the bottom of the tree -> every individual value |
23 |
| - for i in (1..len).rev() { |
24 |
| - // a nice property of this "flat" representation of a tree: the parent of an element at index i is located at index i/2 |
25 |
| - buf[i] = merge(buf[2 * i], buf[2 * i + 1]); |
| 32 | +impl<T, F> SegmentTree<T, F> |
| 33 | +where |
| 34 | + T: Debug + Default + Ord + Copy, |
| 35 | + F: Fn(T, T) -> T, |
| 36 | +{ |
| 37 | + /// Creates a new `SegmentTree` from the provided slice of elements. |
| 38 | + /// |
| 39 | + /// # Arguments |
| 40 | + /// |
| 41 | + /// * `arr`: A slice of elements of type `T` to initialize the segment tree. |
| 42 | + /// * `merge`: A merging function that defines how to merge two elements of type `T`. |
| 43 | + /// |
| 44 | + /// # Returns |
| 45 | + /// |
| 46 | + /// A new `SegmentTree` instance populated with the given elements. |
| 47 | + pub fn from_vec(arr: &[T], merge: F) -> Self { |
| 48 | + let size = arr.len(); |
| 49 | + let mut buffer: Vec<T> = vec![T::default(); 2 * size]; |
| 50 | + |
| 51 | + // Populate the leaves of the tree |
| 52 | + buffer[size..(2 * size)].clone_from_slice(arr); |
| 53 | + for idx in (1..size).rev() { |
| 54 | + buffer[idx] = merge(buffer[2 * idx], buffer[2 * idx + 1]); |
26 | 55 | }
|
| 56 | + |
27 | 57 | SegmentTree {
|
28 |
| - len, |
29 |
| - tree: buf, |
30 |
| - merge, |
| 58 | + size, |
| 59 | + nodes: buffer, |
| 60 | + merge_fn: merge, |
31 | 61 | }
|
32 | 62 | }
|
33 | 63 |
|
34 |
| - /// Query the range (exclusive) |
35 |
| - /// returns None if the range is out of the array's boundaries (eg: if start is after the end of the array, or start > end, etc.) |
36 |
| - /// return the aggregate of values over this range otherwise |
37 |
| - pub fn query(&self, range: Range<usize>) -> Option<T> { |
38 |
| - let mut l = range.start + self.len; |
39 |
| - let mut r = min(self.len, range.end) + self.len; |
40 |
| - let mut res = None; |
41 |
| - // Check Wikipedia or other detailed explanations here for how to navigate the tree bottom-up to limit the number of operations |
42 |
| - while l < r { |
43 |
| - if l % 2 == 1 { |
44 |
| - res = Some(match res { |
45 |
| - None => self.tree[l], |
46 |
| - Some(old) => (self.merge)(old, self.tree[l]), |
| 64 | + /// Queries the segment tree for the result of merging the elements in the given range. |
| 65 | + /// |
| 66 | + /// # Arguments |
| 67 | + /// |
| 68 | + /// * `range`: A range specified as `Range<usize>`, indicating the start (inclusive) |
| 69 | + /// and end (exclusive) indices of the segment to query. |
| 70 | + /// |
| 71 | + /// # Returns |
| 72 | + /// |
| 73 | + /// * `Ok(Some(result))` if the query was successful and there are elements in the range, |
| 74 | + /// * `Ok(None)` if the range is empty, |
| 75 | + /// * `Err(SegmentTreeError::InvalidRange)` if the provided range is invalid. |
| 76 | + pub fn query(&self, range: Range<usize>) -> Result<Option<T>, SegmentTreeError> { |
| 77 | + if range.start >= self.size || range.end > self.size { |
| 78 | + return Err(SegmentTreeError::InvalidRange); |
| 79 | + } |
| 80 | + |
| 81 | + let mut left = range.start + self.size; |
| 82 | + let mut right = range.end + self.size; |
| 83 | + let mut result = None; |
| 84 | + |
| 85 | + // Iterate through the segment tree to accumulate results |
| 86 | + while left < right { |
| 87 | + if left % 2 == 1 { |
| 88 | + result = Some(match result { |
| 89 | + None => self.nodes[left], |
| 90 | + Some(old) => (self.merge_fn)(old, self.nodes[left]), |
47 | 91 | });
|
48 |
| - l += 1; |
| 92 | + left += 1; |
49 | 93 | }
|
50 |
| - if r % 2 == 1 { |
51 |
| - r -= 1; |
52 |
| - res = Some(match res { |
53 |
| - None => self.tree[r], |
54 |
| - Some(old) => (self.merge)(old, self.tree[r]), |
| 94 | + if right % 2 == 1 { |
| 95 | + right -= 1; |
| 96 | + result = Some(match result { |
| 97 | + None => self.nodes[right], |
| 98 | + Some(old) => (self.merge_fn)(old, self.nodes[right]), |
55 | 99 | });
|
56 | 100 | }
|
57 |
| - l /= 2; |
58 |
| - r /= 2; |
| 101 | + left /= 2; |
| 102 | + right /= 2; |
59 | 103 | }
|
60 |
| - res |
| 104 | + |
| 105 | + Ok(result) |
61 | 106 | }
|
62 | 107 |
|
63 |
| - /// Updates the value at index `idx` in the original array with a new value `val` |
64 |
| - pub fn update(&mut self, idx: usize, val: T) { |
65 |
| - // change every value where `idx` plays a role, bottom -> up |
66 |
| - // 1: change in the right-hand side of the tree (bottom row) |
67 |
| - let mut idx = idx + self.len; |
68 |
| - self.tree[idx] = val; |
69 |
| - |
70 |
| - // 2: then bubble up |
71 |
| - idx /= 2; |
72 |
| - while idx != 0 { |
73 |
| - self.tree[idx] = (self.merge)(self.tree[2 * idx], self.tree[2 * idx + 1]); |
74 |
| - idx /= 2; |
| 108 | + /// Updates the value at the specified index in the segment tree. |
| 109 | + /// |
| 110 | + /// # Arguments |
| 111 | + /// |
| 112 | + /// * `idx`: The index (0-based) of the element to update. |
| 113 | + /// * `val`: The new value of type `T` to set at the specified index. |
| 114 | + /// |
| 115 | + /// # Returns |
| 116 | + /// |
| 117 | + /// * `Ok(())` if the update was successful, |
| 118 | + /// * `Err(SegmentTreeError::IndexOutOfBounds)` if the index is out of bounds. |
| 119 | + pub fn update(&mut self, idx: usize, val: T) -> Result<(), SegmentTreeError> { |
| 120 | + if idx >= self.size { |
| 121 | + return Err(SegmentTreeError::IndexOutOfBounds); |
| 122 | + } |
| 123 | + |
| 124 | + let mut index = idx + self.size; |
| 125 | + if self.nodes[index] == val { |
| 126 | + return Ok(()); |
75 | 127 | }
|
| 128 | + |
| 129 | + self.nodes[index] = val; |
| 130 | + while index > 1 { |
| 131 | + index /= 2; |
| 132 | + self.nodes[index] = (self.merge_fn)(self.nodes[2 * index], self.nodes[2 * index + 1]); |
| 133 | + } |
| 134 | + |
| 135 | + Ok(()) |
76 | 136 | }
|
77 | 137 | }
|
78 | 138 |
|
79 | 139 | #[cfg(test)]
|
80 | 140 | mod tests {
|
81 | 141 | use super::*;
|
82 |
| - use quickcheck::TestResult; |
83 |
| - use quickcheck_macros::quickcheck; |
84 | 142 | use std::cmp::{max, min};
|
85 | 143 |
|
86 | 144 | #[test]
|
87 | 145 | fn test_min_segments() {
|
88 | 146 | let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8];
|
89 |
| - let min_seg_tree = SegmentTree::from_vec(&vec, min); |
90 |
| - assert_eq!(Some(-5), min_seg_tree.query(4..7)); |
91 |
| - assert_eq!(Some(-30), min_seg_tree.query(0..vec.len())); |
92 |
| - assert_eq!(Some(-30), min_seg_tree.query(0..2)); |
93 |
| - assert_eq!(Some(-4), min_seg_tree.query(1..3)); |
94 |
| - assert_eq!(Some(-5), min_seg_tree.query(1..7)); |
| 147 | + let mut min_seg_tree = SegmentTree::from_vec(&vec, min); |
| 148 | + assert_eq!(min_seg_tree.query(4..7), Ok(Some(-5))); |
| 149 | + assert_eq!(min_seg_tree.query(0..vec.len()), Ok(Some(-30))); |
| 150 | + assert_eq!(min_seg_tree.query(0..2), Ok(Some(-30))); |
| 151 | + assert_eq!(min_seg_tree.query(1..3), Ok(Some(-4))); |
| 152 | + assert_eq!(min_seg_tree.query(1..7), Ok(Some(-5))); |
| 153 | + assert_eq!(min_seg_tree.update(5, 10), Ok(())); |
| 154 | + assert_eq!(min_seg_tree.update(14, -8), Ok(())); |
| 155 | + assert_eq!(min_seg_tree.query(4..7), Ok(Some(3))); |
| 156 | + assert_eq!( |
| 157 | + min_seg_tree.update(15, 100), |
| 158 | + Err(SegmentTreeError::IndexOutOfBounds) |
| 159 | + ); |
| 160 | + assert_eq!(min_seg_tree.query(5..5), Ok(None)); |
| 161 | + assert_eq!( |
| 162 | + min_seg_tree.query(10..16), |
| 163 | + Err(SegmentTreeError::InvalidRange) |
| 164 | + ); |
| 165 | + assert_eq!( |
| 166 | + min_seg_tree.query(15..20), |
| 167 | + Err(SegmentTreeError::InvalidRange) |
| 168 | + ); |
95 | 169 | }
|
96 | 170 |
|
97 | 171 | #[test]
|
98 | 172 | fn test_max_segments() {
|
99 |
| - let val_at_6 = 6; |
100 |
| - let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8]; |
| 173 | + let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; |
101 | 174 | let mut max_seg_tree = SegmentTree::from_vec(&vec, max);
|
102 |
| - assert_eq!(Some(15), max_seg_tree.query(0..vec.len())); |
103 |
| - let max_4_to_6 = 6; |
104 |
| - assert_eq!(Some(max_4_to_6), max_seg_tree.query(4..7)); |
105 |
| - let delta = 2; |
106 |
| - max_seg_tree.update(6, val_at_6 + delta); |
107 |
| - assert_eq!(Some(val_at_6 + delta), max_seg_tree.query(4..7)); |
| 175 | + assert_eq!(max_seg_tree.query(0..vec.len()), Ok(Some(15))); |
| 176 | + assert_eq!(max_seg_tree.query(3..5), Ok(Some(7))); |
| 177 | + assert_eq!(max_seg_tree.query(4..8), Ok(Some(11))); |
| 178 | + assert_eq!(max_seg_tree.query(8..10), Ok(Some(9))); |
| 179 | + assert_eq!(max_seg_tree.query(9..12), Ok(Some(15))); |
| 180 | + assert_eq!(max_seg_tree.update(4, 10), Ok(())); |
| 181 | + assert_eq!(max_seg_tree.update(14, -8), Ok(())); |
| 182 | + assert_eq!(max_seg_tree.query(3..5), Ok(Some(10))); |
| 183 | + assert_eq!( |
| 184 | + max_seg_tree.update(15, 100), |
| 185 | + Err(SegmentTreeError::IndexOutOfBounds) |
| 186 | + ); |
| 187 | + assert_eq!(max_seg_tree.query(5..5), Ok(None)); |
| 188 | + assert_eq!( |
| 189 | + max_seg_tree.query(10..16), |
| 190 | + Err(SegmentTreeError::InvalidRange) |
| 191 | + ); |
| 192 | + assert_eq!( |
| 193 | + max_seg_tree.query(15..20), |
| 194 | + Err(SegmentTreeError::InvalidRange) |
| 195 | + ); |
108 | 196 | }
|
109 | 197 |
|
110 | 198 | #[test]
|
111 | 199 | fn test_sum_segments() {
|
112 |
| - let val_at_6 = 6; |
113 |
| - let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8]; |
| 200 | + let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; |
114 | 201 | let mut sum_seg_tree = SegmentTree::from_vec(&vec, |a, b| a + b);
|
115 |
| - for (i, val) in vec.iter().enumerate() { |
116 |
| - assert_eq!(Some(*val), sum_seg_tree.query(i..(i + 1))); |
117 |
| - } |
118 |
| - let sum_4_to_6 = sum_seg_tree.query(4..7); |
119 |
| - assert_eq!(Some(4), sum_4_to_6); |
120 |
| - let delta = 3; |
121 |
| - sum_seg_tree.update(6, val_at_6 + delta); |
| 202 | + assert_eq!(sum_seg_tree.query(0..vec.len()), Ok(Some(38))); |
| 203 | + assert_eq!(sum_seg_tree.query(1..4), Ok(Some(5))); |
| 204 | + assert_eq!(sum_seg_tree.query(4..7), Ok(Some(4))); |
| 205 | + assert_eq!(sum_seg_tree.query(6..9), Ok(Some(-3))); |
| 206 | + assert_eq!(sum_seg_tree.query(9..vec.len()), Ok(Some(37))); |
| 207 | + assert_eq!(sum_seg_tree.update(5, 10), Ok(())); |
| 208 | + assert_eq!(sum_seg_tree.update(14, -8), Ok(())); |
| 209 | + assert_eq!(sum_seg_tree.query(4..7), Ok(Some(19))); |
122 | 210 | assert_eq!(
|
123 |
| - sum_4_to_6.unwrap() + delta, |
124 |
| - sum_seg_tree.query(4..7).unwrap() |
| 211 | + sum_seg_tree.update(15, 100), |
| 212 | + Err(SegmentTreeError::IndexOutOfBounds) |
| 213 | + ); |
| 214 | + assert_eq!(sum_seg_tree.query(5..5), Ok(None)); |
| 215 | + assert_eq!( |
| 216 | + sum_seg_tree.query(10..16), |
| 217 | + Err(SegmentTreeError::InvalidRange) |
| 218 | + ); |
| 219 | + assert_eq!( |
| 220 | + sum_seg_tree.query(15..20), |
| 221 | + Err(SegmentTreeError::InvalidRange) |
125 | 222 | );
|
126 |
| - } |
127 |
| - |
128 |
| - // Some properties over segment trees: |
129 |
| - // When asking for the range of the overall array, return the same as iter().min() or iter().max(), etc. |
130 |
| - // When asking for an interval containing a single value, return this value, no matter the merge function |
131 |
| - |
132 |
| - #[quickcheck] |
133 |
| - fn check_overall_interval_min(array: Vec<i32>) -> TestResult { |
134 |
| - let seg_tree = SegmentTree::from_vec(&array, min); |
135 |
| - TestResult::from_bool(array.iter().min().copied() == seg_tree.query(0..array.len())) |
136 |
| - } |
137 |
| - |
138 |
| - #[quickcheck] |
139 |
| - fn check_overall_interval_max(array: Vec<i32>) -> TestResult { |
140 |
| - let seg_tree = SegmentTree::from_vec(&array, max); |
141 |
| - TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len())) |
142 |
| - } |
143 |
| - |
144 |
| - #[quickcheck] |
145 |
| - fn check_overall_interval_sum(array: Vec<i32>) -> TestResult { |
146 |
| - let seg_tree = SegmentTree::from_vec(&array, max); |
147 |
| - TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len())) |
148 |
| - } |
149 |
| - |
150 |
| - #[quickcheck] |
151 |
| - fn check_single_interval_min(array: Vec<i32>) -> TestResult { |
152 |
| - let seg_tree = SegmentTree::from_vec(&array, min); |
153 |
| - for (i, value) in array.into_iter().enumerate() { |
154 |
| - let res = seg_tree.query(i..(i + 1)); |
155 |
| - if res != Some(value) { |
156 |
| - return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); |
157 |
| - } |
158 |
| - } |
159 |
| - TestResult::passed() |
160 |
| - } |
161 |
| - |
162 |
| - #[quickcheck] |
163 |
| - fn check_single_interval_max(array: Vec<i32>) -> TestResult { |
164 |
| - let seg_tree = SegmentTree::from_vec(&array, max); |
165 |
| - for (i, value) in array.into_iter().enumerate() { |
166 |
| - let res = seg_tree.query(i..(i + 1)); |
167 |
| - if res != Some(value) { |
168 |
| - return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); |
169 |
| - } |
170 |
| - } |
171 |
| - TestResult::passed() |
172 |
| - } |
173 |
| - |
174 |
| - #[quickcheck] |
175 |
| - fn check_single_interval_sum(array: Vec<i32>) -> TestResult { |
176 |
| - let seg_tree = SegmentTree::from_vec(&array, max); |
177 |
| - for (i, value) in array.into_iter().enumerate() { |
178 |
| - let res = seg_tree.query(i..(i + 1)); |
179 |
| - if res != Some(value) { |
180 |
| - return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); |
181 |
| - } |
182 |
| - } |
183 |
| - TestResult::passed() |
184 | 223 | }
|
185 | 224 | }
|
0 commit comments