Skip to content

Commit 767e34c

Browse files
authored
Improve heap (TheAlgorithms#560)
1 parent fec10e7 commit 767e34c

File tree

2 files changed

+96
-85
lines changed

2 files changed

+96
-85
lines changed

src/data_structures/heap.rs

Lines changed: 93 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,91 +1,110 @@
11
// Heap data structure
22
// Takes a closure as a comparator to allow for min-heap, max-heap, and works with custom key functions
33

4-
use std::cmp::Ord;
5-
use std::default::Default;
4+
use std::{cmp::Ord, slice::Iter};
65

7-
pub struct Heap<T>
8-
where
9-
T: Default,
10-
{
11-
count: usize,
6+
pub struct Heap<T> {
127
items: Vec<T>,
138
comparator: fn(&T, &T) -> bool,
149
}
1510

16-
impl<T> Heap<T>
17-
where
18-
T: Default,
19-
{
11+
impl<T> Heap<T> {
2012
pub fn new(comparator: fn(&T, &T) -> bool) -> Self {
2113
Self {
22-
count: 0,
2314
// Add a default in the first spot to offset indexes
2415
// for the parent/child math to work out.
2516
// Vecs have to have all the same type so using Default
2617
// is a way to add an unused item.
27-
items: vec![T::default()],
18+
items: vec![],
2819
comparator,
2920
}
3021
}
3122

3223
pub fn len(&self) -> usize {
33-
self.count
24+
self.items.len()
3425
}
3526

3627
pub fn is_empty(&self) -> bool {
3728
self.len() == 0
3829
}
3930

4031
pub fn add(&mut self, value: T) {
41-
self.count += 1;
4232
self.items.push(value);
4333

4434
// Heapify Up
45-
let mut idx = self.count;
46-
while self.parent_idx(idx) > 0 {
47-
let pdx = self.parent_idx(idx);
35+
let mut idx = self.len() - 1;
36+
while let Some(pdx) = self.parent_idx(idx) {
4837
if (self.comparator)(&self.items[idx], &self.items[pdx]) {
4938
self.items.swap(idx, pdx);
5039
}
5140
idx = pdx;
5241
}
5342
}
5443

55-
fn parent_idx(&self, idx: usize) -> usize {
56-
idx / 2
44+
pub fn pop(&mut self) -> Option<T> {
45+
if self.is_empty() {
46+
return None;
47+
}
48+
// This feels like a function built for heap impl :)
49+
// Removes an item at an index and fills in with the last item
50+
// of the Vec
51+
let next = Some(self.items.swap_remove(0));
52+
53+
if !self.is_empty() {
54+
// Heapify Down
55+
let mut idx = 0;
56+
while self.children_present(idx) {
57+
let cdx = {
58+
if self.right_child_idx(idx) >= self.len() {
59+
self.left_child_idx(idx)
60+
} else {
61+
let ldx = self.left_child_idx(idx);
62+
let rdx = self.right_child_idx(idx);
63+
if (self.comparator)(&self.items[ldx], &self.items[rdx]) {
64+
ldx
65+
} else {
66+
rdx
67+
}
68+
}
69+
};
70+
if !(self.comparator)(&self.items[idx], &self.items[cdx]) {
71+
self.items.swap(idx, cdx);
72+
}
73+
idx = cdx;
74+
}
75+
}
76+
77+
next
78+
}
79+
80+
pub fn iter(&self) -> Iter<'_, T> {
81+
self.items.iter()
82+
}
83+
84+
fn parent_idx(&self, idx: usize) -> Option<usize> {
85+
if idx > 0 {
86+
Some((idx - 1) / 2)
87+
} else {
88+
None
89+
}
5790
}
5891

5992
fn children_present(&self, idx: usize) -> bool {
60-
self.left_child_idx(idx) <= self.count
93+
self.left_child_idx(idx) <= (self.len() - 1)
6194
}
6295

6396
fn left_child_idx(&self, idx: usize) -> usize {
64-
idx * 2
97+
idx * 2 + 1
6598
}
6699

67100
fn right_child_idx(&self, idx: usize) -> usize {
68101
self.left_child_idx(idx) + 1
69102
}
70-
71-
fn smallest_child_idx(&self, idx: usize) -> usize {
72-
if self.right_child_idx(idx) > self.count {
73-
self.left_child_idx(idx)
74-
} else {
75-
let ldx = self.left_child_idx(idx);
76-
let rdx = self.right_child_idx(idx);
77-
if (self.comparator)(&self.items[ldx], &self.items[rdx]) {
78-
ldx
79-
} else {
80-
rdx
81-
}
82-
}
83-
}
84103
}
85104

86105
impl<T> Heap<T>
87106
where
88-
T: Default + Ord,
107+
T: Ord,
89108
{
90109
/// Create a new MinHeap
91110
pub fn new_min() -> Heap<T> {
@@ -98,45 +117,13 @@ where
98117
}
99118
}
100119

101-
impl<T> Iterator for Heap<T>
102-
where
103-
T: Default,
104-
{
105-
type Item = T;
106-
107-
fn next(&mut self) -> Option<T> {
108-
if self.count == 0 {
109-
return None;
110-
}
111-
// This feels like a function built for heap impl :)
112-
// Removes an item at an index and fills in with the last item
113-
// of the Vec
114-
let next = Some(self.items.swap_remove(1));
115-
self.count -= 1;
116-
117-
if self.count > 0 {
118-
// Heapify Down
119-
let mut idx = 1;
120-
while self.children_present(idx) {
121-
let cdx = self.smallest_child_idx(idx);
122-
if !(self.comparator)(&self.items[idx], &self.items[cdx]) {
123-
self.items.swap(idx, cdx);
124-
}
125-
idx = cdx;
126-
}
127-
}
128-
129-
next
130-
}
131-
}
132-
133120
#[cfg(test)]
134121
mod tests {
135122
use super::*;
136123
#[test]
137124
fn test_empty_heap() {
138125
let mut heap: Heap<i32> = Heap::new_max();
139-
assert_eq!(heap.next(), None);
126+
assert_eq!(heap.pop(), None);
140127
}
141128

142129
#[test]
@@ -147,11 +134,11 @@ mod tests {
147134
heap.add(9);
148135
heap.add(11);
149136
assert_eq!(heap.len(), 4);
150-
assert_eq!(heap.next(), Some(2));
151-
assert_eq!(heap.next(), Some(4));
152-
assert_eq!(heap.next(), Some(9));
137+
assert_eq!(heap.pop(), Some(2));
138+
assert_eq!(heap.pop(), Some(4));
139+
assert_eq!(heap.pop(), Some(9));
153140
heap.add(1);
154-
assert_eq!(heap.next(), Some(1));
141+
assert_eq!(heap.pop(), Some(1));
155142
}
156143

157144
#[test]
@@ -162,14 +149,13 @@ mod tests {
162149
heap.add(9);
163150
heap.add(11);
164151
assert_eq!(heap.len(), 4);
165-
assert_eq!(heap.next(), Some(11));
166-
assert_eq!(heap.next(), Some(9));
167-
assert_eq!(heap.next(), Some(4));
152+
assert_eq!(heap.pop(), Some(11));
153+
assert_eq!(heap.pop(), Some(9));
154+
assert_eq!(heap.pop(), Some(4));
168155
heap.add(1);
169-
assert_eq!(heap.next(), Some(2));
156+
assert_eq!(heap.pop(), Some(2));
170157
}
171158

172-
#[derive(Default)]
173159
struct Point(/* x */ i32, /* y */ i32);
174160

175161
#[test]
@@ -179,9 +165,34 @@ mod tests {
179165
heap.add(Point(3, 10));
180166
heap.add(Point(-2, 4));
181167
assert_eq!(heap.len(), 3);
182-
assert_eq!(heap.next().unwrap().0, -2);
183-
assert_eq!(heap.next().unwrap().0, 1);
168+
assert_eq!(heap.pop().unwrap().0, -2);
169+
assert_eq!(heap.pop().unwrap().0, 1);
184170
heap.add(Point(50, 34));
185-
assert_eq!(heap.next().unwrap().0, 3);
171+
assert_eq!(heap.pop().unwrap().0, 3);
172+
}
173+
174+
#[test]
175+
fn test_iter_heap() {
176+
let mut heap = Heap::new_min();
177+
heap.add(4);
178+
heap.add(2);
179+
heap.add(9);
180+
heap.add(11);
181+
182+
// test iterator, which is not in order except the first one.
183+
let mut iter = heap.iter();
184+
assert_eq!(iter.next(), Some(&2));
185+
assert_ne!(iter.next(), None);
186+
assert_ne!(iter.next(), None);
187+
assert_ne!(iter.next(), None);
188+
assert_eq!(iter.next(), None);
189+
190+
// test the heap after run iterator.
191+
assert_eq!(heap.len(), 4);
192+
assert_eq!(heap.pop(), Some(2));
193+
assert_eq!(heap.pop(), Some(4));
194+
assert_eq!(heap.pop(), Some(9));
195+
assert_eq!(heap.pop(), Some(11));
196+
assert_eq!(heap.pop(), None);
186197
}
187198
}

src/searching/kth_smallest_heap.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use std::cmp::{Ord, Ordering};
1212
/// operation's complexity is O(log(k)).
1313
pub fn kth_smallest_heap<T>(input: &[T], k: usize) -> Option<T>
1414
where
15-
T: Default + Ord + Copy,
15+
T: Ord + Copy,
1616
{
1717
if input.len() < k {
1818
return None;
@@ -37,7 +37,7 @@ where
3737

3838
for &val in input.iter().skip(k) {
3939
// compare new value to the current kth smallest value
40-
let cur_big = heap.next().unwrap(); // heap.next() can't be None
40+
let cur_big = heap.pop().unwrap(); // heap.pop() can't be None
4141
match val.cmp(&cur_big) {
4242
Ordering::Greater => {
4343
heap.add(cur_big);
@@ -48,7 +48,7 @@ where
4848
}
4949
}
5050

51-
heap.next()
51+
heap.pop()
5252
}
5353

5454
#[cfg(test)]

0 commit comments

Comments
 (0)