Skip to content

Commit 5a83939

Browse files
authored
Refactor Segment Tree Implementation (TheAlgorithms#835)
ref: refactor segment tree
1 parent 5151982 commit 5a83939

File tree

1 file changed

+181
-142
lines changed

1 file changed

+181
-142
lines changed

src/data_structures/segment_tree.rs

Lines changed: 181 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -1,185 +1,224 @@
1-
use std::cmp::min;
1+
//! A module providing a Segment Tree data structure for efficient range queries
2+
//! and updates. It supports operations like finding the minimum, maximum,
3+
//! and sum of segments in an array.
4+
25
use std::fmt::Debug;
36
use std::ops::Range;
47

5-
/// This data structure implements a segment-tree that can efficiently answer range (interval) queries on arrays.
6-
/// It represents this array as a binary tree of merged intervals. From top to bottom: [aggregated value for the overall array], then [left-hand half, right hand half], etc. until [each individual value, ...]
7-
/// It is generic over a reduction function for each segment or interval: basically, to describe how we merge two intervals together.
8-
/// Note that this function should be commutative and associative
9-
/// It could be `std::cmp::min(interval_1, interval_2)` or `std::cmp::max(interval_1, interval_2)`, or `|a, b| a + b`, `|a, b| a * b`
10-
pub struct SegmentTree<T: Debug + Default + Ord + Copy> {
11-
len: usize, // length of the represented
12-
tree: Vec<T>, // represents a binary tree of intervals as an array (as a BinaryHeap does, for instance)
13-
merge: fn(T, T) -> T, // how we merge two values together
8+
/// Custom error types representing possible errors that can occur during operations on the `SegmentTree`.
9+
#[derive(Debug, PartialEq, Eq)]
10+
pub enum SegmentTreeError {
11+
/// Error indicating that an index is out of bounds.
12+
IndexOutOfBounds,
13+
/// Error indicating that a range provided for a query is invalid.
14+
InvalidRange,
15+
}
16+
17+
/// A structure representing a Segment Tree. This tree can be used to efficiently
18+
/// perform range queries and updates on an array of elements.
19+
pub struct SegmentTree<T, F>
20+
where
21+
T: Debug + Default + Ord + Copy,
22+
F: Fn(T, T) -> T,
23+
{
24+
/// The length of the input array for which the segment tree is built.
25+
size: usize,
26+
/// A vector representing the segment tree.
27+
nodes: Vec<T>,
28+
/// A merging function defined as a closure or callable type.
29+
merge_fn: F,
1430
}
1531

16-
impl<T: Debug + Default + Ord + Copy> SegmentTree<T> {
17-
/// Builds a SegmentTree from an array and a merge function
18-
pub fn from_vec(arr: &[T], merge: fn(T, T) -> T) -> Self {
19-
let len = arr.len();
20-
let mut buf: Vec<T> = vec![T::default(); 2 * len];
21-
// Populate the tree bottom-up, from right to left
22-
buf[len..(2 * len)].clone_from_slice(&arr[0..len]); // last len pos is the bottom of the tree -> every individual value
23-
for i in (1..len).rev() {
24-
// a nice property of this "flat" representation of a tree: the parent of an element at index i is located at index i/2
25-
buf[i] = merge(buf[2 * i], buf[2 * i + 1]);
32+
impl<T, F> SegmentTree<T, F>
33+
where
34+
T: Debug + Default + Ord + Copy,
35+
F: Fn(T, T) -> T,
36+
{
37+
/// Creates a new `SegmentTree` from the provided slice of elements.
38+
///
39+
/// # Arguments
40+
///
41+
/// * `arr`: A slice of elements of type `T` to initialize the segment tree.
42+
/// * `merge`: A merging function that defines how to merge two elements of type `T`.
43+
///
44+
/// # Returns
45+
///
46+
/// A new `SegmentTree` instance populated with the given elements.
47+
pub fn from_vec(arr: &[T], merge: F) -> Self {
48+
let size = arr.len();
49+
let mut buffer: Vec<T> = vec![T::default(); 2 * size];
50+
51+
// Populate the leaves of the tree
52+
buffer[size..(2 * size)].clone_from_slice(arr);
53+
for idx in (1..size).rev() {
54+
buffer[idx] = merge(buffer[2 * idx], buffer[2 * idx + 1]);
2655
}
56+
2757
SegmentTree {
28-
len,
29-
tree: buf,
30-
merge,
58+
size,
59+
nodes: buffer,
60+
merge_fn: merge,
3161
}
3262
}
3363

34-
/// Query the range (exclusive)
35-
/// returns None if the range is out of the array's boundaries (eg: if start is after the end of the array, or start > end, etc.)
36-
/// return the aggregate of values over this range otherwise
37-
pub fn query(&self, range: Range<usize>) -> Option<T> {
38-
let mut l = range.start + self.len;
39-
let mut r = min(self.len, range.end) + self.len;
40-
let mut res = None;
41-
// Check Wikipedia or other detailed explanations here for how to navigate the tree bottom-up to limit the number of operations
42-
while l < r {
43-
if l % 2 == 1 {
44-
res = Some(match res {
45-
None => self.tree[l],
46-
Some(old) => (self.merge)(old, self.tree[l]),
64+
/// Queries the segment tree for the result of merging the elements in the given range.
65+
///
66+
/// # Arguments
67+
///
68+
/// * `range`: A range specified as `Range<usize>`, indicating the start (inclusive)
69+
/// and end (exclusive) indices of the segment to query.
70+
///
71+
/// # Returns
72+
///
73+
/// * `Ok(Some(result))` if the query was successful and there are elements in the range,
74+
/// * `Ok(None)` if the range is empty,
75+
/// * `Err(SegmentTreeError::InvalidRange)` if the provided range is invalid.
76+
pub fn query(&self, range: Range<usize>) -> Result<Option<T>, SegmentTreeError> {
77+
if range.start >= self.size || range.end > self.size {
78+
return Err(SegmentTreeError::InvalidRange);
79+
}
80+
81+
let mut left = range.start + self.size;
82+
let mut right = range.end + self.size;
83+
let mut result = None;
84+
85+
// Iterate through the segment tree to accumulate results
86+
while left < right {
87+
if left % 2 == 1 {
88+
result = Some(match result {
89+
None => self.nodes[left],
90+
Some(old) => (self.merge_fn)(old, self.nodes[left]),
4791
});
48-
l += 1;
92+
left += 1;
4993
}
50-
if r % 2 == 1 {
51-
r -= 1;
52-
res = Some(match res {
53-
None => self.tree[r],
54-
Some(old) => (self.merge)(old, self.tree[r]),
94+
if right % 2 == 1 {
95+
right -= 1;
96+
result = Some(match result {
97+
None => self.nodes[right],
98+
Some(old) => (self.merge_fn)(old, self.nodes[right]),
5599
});
56100
}
57-
l /= 2;
58-
r /= 2;
101+
left /= 2;
102+
right /= 2;
59103
}
60-
res
104+
105+
Ok(result)
61106
}
62107

63-
/// Updates the value at index `idx` in the original array with a new value `val`
64-
pub fn update(&mut self, idx: usize, val: T) {
65-
// change every value where `idx` plays a role, bottom -> up
66-
// 1: change in the right-hand side of the tree (bottom row)
67-
let mut idx = idx + self.len;
68-
self.tree[idx] = val;
69-
70-
// 2: then bubble up
71-
idx /= 2;
72-
while idx != 0 {
73-
self.tree[idx] = (self.merge)(self.tree[2 * idx], self.tree[2 * idx + 1]);
74-
idx /= 2;
108+
/// Updates the value at the specified index in the segment tree.
109+
///
110+
/// # Arguments
111+
///
112+
/// * `idx`: The index (0-based) of the element to update.
113+
/// * `val`: The new value of type `T` to set at the specified index.
114+
///
115+
/// # Returns
116+
///
117+
/// * `Ok(())` if the update was successful,
118+
/// * `Err(SegmentTreeError::IndexOutOfBounds)` if the index is out of bounds.
119+
pub fn update(&mut self, idx: usize, val: T) -> Result<(), SegmentTreeError> {
120+
if idx >= self.size {
121+
return Err(SegmentTreeError::IndexOutOfBounds);
122+
}
123+
124+
let mut index = idx + self.size;
125+
if self.nodes[index] == val {
126+
return Ok(());
75127
}
128+
129+
self.nodes[index] = val;
130+
while index > 1 {
131+
index /= 2;
132+
self.nodes[index] = (self.merge_fn)(self.nodes[2 * index], self.nodes[2 * index + 1]);
133+
}
134+
135+
Ok(())
76136
}
77137
}
78138

79139
#[cfg(test)]
80140
mod tests {
81141
use super::*;
82-
use quickcheck::TestResult;
83-
use quickcheck_macros::quickcheck;
84142
use std::cmp::{max, min};
85143

86144
#[test]
87145
fn test_min_segments() {
88146
let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8];
89-
let min_seg_tree = SegmentTree::from_vec(&vec, min);
90-
assert_eq!(Some(-5), min_seg_tree.query(4..7));
91-
assert_eq!(Some(-30), min_seg_tree.query(0..vec.len()));
92-
assert_eq!(Some(-30), min_seg_tree.query(0..2));
93-
assert_eq!(Some(-4), min_seg_tree.query(1..3));
94-
assert_eq!(Some(-5), min_seg_tree.query(1..7));
147+
let mut min_seg_tree = SegmentTree::from_vec(&vec, min);
148+
assert_eq!(min_seg_tree.query(4..7), Ok(Some(-5)));
149+
assert_eq!(min_seg_tree.query(0..vec.len()), Ok(Some(-30)));
150+
assert_eq!(min_seg_tree.query(0..2), Ok(Some(-30)));
151+
assert_eq!(min_seg_tree.query(1..3), Ok(Some(-4)));
152+
assert_eq!(min_seg_tree.query(1..7), Ok(Some(-5)));
153+
assert_eq!(min_seg_tree.update(5, 10), Ok(()));
154+
assert_eq!(min_seg_tree.update(14, -8), Ok(()));
155+
assert_eq!(min_seg_tree.query(4..7), Ok(Some(3)));
156+
assert_eq!(
157+
min_seg_tree.update(15, 100),
158+
Err(SegmentTreeError::IndexOutOfBounds)
159+
);
160+
assert_eq!(min_seg_tree.query(5..5), Ok(None));
161+
assert_eq!(
162+
min_seg_tree.query(10..16),
163+
Err(SegmentTreeError::InvalidRange)
164+
);
165+
assert_eq!(
166+
min_seg_tree.query(15..20),
167+
Err(SegmentTreeError::InvalidRange)
168+
);
95169
}
96170

97171
#[test]
98172
fn test_max_segments() {
99-
let val_at_6 = 6;
100-
let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8];
173+
let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8];
101174
let mut max_seg_tree = SegmentTree::from_vec(&vec, max);
102-
assert_eq!(Some(15), max_seg_tree.query(0..vec.len()));
103-
let max_4_to_6 = 6;
104-
assert_eq!(Some(max_4_to_6), max_seg_tree.query(4..7));
105-
let delta = 2;
106-
max_seg_tree.update(6, val_at_6 + delta);
107-
assert_eq!(Some(val_at_6 + delta), max_seg_tree.query(4..7));
175+
assert_eq!(max_seg_tree.query(0..vec.len()), Ok(Some(15)));
176+
assert_eq!(max_seg_tree.query(3..5), Ok(Some(7)));
177+
assert_eq!(max_seg_tree.query(4..8), Ok(Some(11)));
178+
assert_eq!(max_seg_tree.query(8..10), Ok(Some(9)));
179+
assert_eq!(max_seg_tree.query(9..12), Ok(Some(15)));
180+
assert_eq!(max_seg_tree.update(4, 10), Ok(()));
181+
assert_eq!(max_seg_tree.update(14, -8), Ok(()));
182+
assert_eq!(max_seg_tree.query(3..5), Ok(Some(10)));
183+
assert_eq!(
184+
max_seg_tree.update(15, 100),
185+
Err(SegmentTreeError::IndexOutOfBounds)
186+
);
187+
assert_eq!(max_seg_tree.query(5..5), Ok(None));
188+
assert_eq!(
189+
max_seg_tree.query(10..16),
190+
Err(SegmentTreeError::InvalidRange)
191+
);
192+
assert_eq!(
193+
max_seg_tree.query(15..20),
194+
Err(SegmentTreeError::InvalidRange)
195+
);
108196
}
109197

110198
#[test]
111199
fn test_sum_segments() {
112-
let val_at_6 = 6;
113-
let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8];
200+
let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8];
114201
let mut sum_seg_tree = SegmentTree::from_vec(&vec, |a, b| a + b);
115-
for (i, val) in vec.iter().enumerate() {
116-
assert_eq!(Some(*val), sum_seg_tree.query(i..(i + 1)));
117-
}
118-
let sum_4_to_6 = sum_seg_tree.query(4..7);
119-
assert_eq!(Some(4), sum_4_to_6);
120-
let delta = 3;
121-
sum_seg_tree.update(6, val_at_6 + delta);
202+
assert_eq!(sum_seg_tree.query(0..vec.len()), Ok(Some(38)));
203+
assert_eq!(sum_seg_tree.query(1..4), Ok(Some(5)));
204+
assert_eq!(sum_seg_tree.query(4..7), Ok(Some(4)));
205+
assert_eq!(sum_seg_tree.query(6..9), Ok(Some(-3)));
206+
assert_eq!(sum_seg_tree.query(9..vec.len()), Ok(Some(37)));
207+
assert_eq!(sum_seg_tree.update(5, 10), Ok(()));
208+
assert_eq!(sum_seg_tree.update(14, -8), Ok(()));
209+
assert_eq!(sum_seg_tree.query(4..7), Ok(Some(19)));
122210
assert_eq!(
123-
sum_4_to_6.unwrap() + delta,
124-
sum_seg_tree.query(4..7).unwrap()
211+
sum_seg_tree.update(15, 100),
212+
Err(SegmentTreeError::IndexOutOfBounds)
213+
);
214+
assert_eq!(sum_seg_tree.query(5..5), Ok(None));
215+
assert_eq!(
216+
sum_seg_tree.query(10..16),
217+
Err(SegmentTreeError::InvalidRange)
218+
);
219+
assert_eq!(
220+
sum_seg_tree.query(15..20),
221+
Err(SegmentTreeError::InvalidRange)
125222
);
126-
}
127-
128-
// Some properties over segment trees:
129-
// When asking for the range of the overall array, return the same as iter().min() or iter().max(), etc.
130-
// When asking for an interval containing a single value, return this value, no matter the merge function
131-
132-
#[quickcheck]
133-
fn check_overall_interval_min(array: Vec<i32>) -> TestResult {
134-
let seg_tree = SegmentTree::from_vec(&array, min);
135-
TestResult::from_bool(array.iter().min().copied() == seg_tree.query(0..array.len()))
136-
}
137-
138-
#[quickcheck]
139-
fn check_overall_interval_max(array: Vec<i32>) -> TestResult {
140-
let seg_tree = SegmentTree::from_vec(&array, max);
141-
TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len()))
142-
}
143-
144-
#[quickcheck]
145-
fn check_overall_interval_sum(array: Vec<i32>) -> TestResult {
146-
let seg_tree = SegmentTree::from_vec(&array, max);
147-
TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len()))
148-
}
149-
150-
#[quickcheck]
151-
fn check_single_interval_min(array: Vec<i32>) -> TestResult {
152-
let seg_tree = SegmentTree::from_vec(&array, min);
153-
for (i, value) in array.into_iter().enumerate() {
154-
let res = seg_tree.query(i..(i + 1));
155-
if res != Some(value) {
156-
return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res));
157-
}
158-
}
159-
TestResult::passed()
160-
}
161-
162-
#[quickcheck]
163-
fn check_single_interval_max(array: Vec<i32>) -> TestResult {
164-
let seg_tree = SegmentTree::from_vec(&array, max);
165-
for (i, value) in array.into_iter().enumerate() {
166-
let res = seg_tree.query(i..(i + 1));
167-
if res != Some(value) {
168-
return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res));
169-
}
170-
}
171-
TestResult::passed()
172-
}
173-
174-
#[quickcheck]
175-
fn check_single_interval_sum(array: Vec<i32>) -> TestResult {
176-
let seg_tree = SegmentTree::from_vec(&array, max);
177-
for (i, value) in array.into_iter().enumerate() {
178-
let res = seg_tree.query(i..(i + 1));
179-
if res != Some(value) {
180-
return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res));
181-
}
182-
}
183-
TestResult::passed()
184223
}
185224
}

0 commit comments

Comments
 (0)