Skip to content

Commit faccf7f

Browse files
authored
Optimize closest_points by pre-sorting array (TheAlgorithms#513)
This avoid nlog(n) sorts in each recursive call. The overall bounds change from O(nlog^2(n)) -> O(nlog(n))
1 parent 0d17ced commit faccf7f

File tree

1 file changed

+36
-21
lines changed

1 file changed

+36
-21
lines changed

src/geometry/closest_points.rs

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::geometry::Point;
22
use std::cmp::Ordering;
33

4-
fn point_cmp(p1: &Point, p2: &Point) -> Ordering {
4+
fn cmp_x(p1: &Point, p2: &Point) -> Ordering {
55
let acmp = f64_cmp(&p1.x, &p2.x);
66
match acmp {
77
Ordering::Equal => f64_cmp(&p1.y, &p2.y),
@@ -16,14 +16,19 @@ fn f64_cmp(a: &f64, b: &f64) -> Ordering {
1616
/// returns the two closest points
1717
/// or None if there are zero or one point
1818
pub fn closest_points(points: &[Point]) -> Option<(Point, Point)> {
19-
let mut points: Vec<Point> = points.to_vec();
20-
points.sort_by(point_cmp);
19+
let mut points_x: Vec<Point> = points.to_vec();
20+
points_x.sort_by(cmp_x);
21+
let mut points_y = points_x.clone();
22+
points_y.sort_by(|p1: &Point, p2: &Point| -> Ordering { p1.y.partial_cmp(&p2.y).unwrap() });
2123

22-
closest_points_aux(&points, 0, points.len())
24+
closest_points_aux(&points_x, points_y, 0, points_x.len())
2325
}
2426

27+
// We maintain two vectors with the same points, one sort by x coordinates and one sorted by y
28+
// coordinates.
2529
fn closest_points_aux(
26-
points: &[Point],
30+
points_x: &[Point],
31+
points_y: Vec<Point>,
2732
mut start: usize,
2833
mut end: usize,
2934
) -> Option<(Point, Point)> {
@@ -35,24 +40,38 @@ fn closest_points_aux(
3540

3641
if n <= 3 {
3742
// bruteforce
38-
let mut min = points[0].euclidean_distance(&points[1]);
39-
let mut pair = (points[0].clone(), points[1].clone());
43+
let mut min = points_x[0].euclidean_distance(&points_x[1]);
44+
let mut pair = (points_x[0].clone(), points_x[1].clone());
4045

4146
for i in 1..n {
4247
for j in (i + 1)..n {
43-
let new = points[i].euclidean_distance(&points[j]);
48+
let new = points_x[i].euclidean_distance(&points_x[j]);
4449
if new < min {
4550
min = new;
46-
pair = (points[i].clone(), points[j].clone());
51+
pair = (points_x[i].clone(), points_x[j].clone());
4752
}
4853
}
4954
}
5055
return Some(pair);
5156
}
5257

5358
let mid = (start + end) / 2;
54-
let left = closest_points_aux(points, start, mid);
55-
let right = closest_points_aux(points, mid, end);
59+
let mid_x = points_x[mid].x;
60+
61+
// Separate points into y_left and y_right vectors based on their x-coordinate. Since y is
62+
// already sorted by y_axis, y_left and y_right will also be sorted.
63+
let mut y_left = vec![];
64+
let mut y_right = vec![];
65+
for point in &points_y {
66+
if point.x < mid_x {
67+
y_left.push(point.clone());
68+
} else {
69+
y_right.push(point.clone());
70+
}
71+
}
72+
73+
let left = closest_points_aux(points_x, y_left, start, mid);
74+
let right = closest_points_aux(points_x, y_right, mid, end);
5675

5776
let (mut min_sqr_dist, mut pair) = match (left, right) {
5877
(Some((l1, l2)), Some((r1, r2))) => {
@@ -69,28 +88,24 @@ fn closest_points_aux(
6988
(None, None) => unreachable!(),
7089
};
7190

72-
let mid_x = points[mid].x;
7391
let dist = min_sqr_dist;
74-
while points[start].x < mid_x - dist {
92+
while points_x[start].x < mid_x - dist {
7593
start += 1;
7694
}
77-
while points[end - 1].x > mid_x + dist {
95+
while points_x[end - 1].x > mid_x + dist {
7896
end -= 1;
7997
}
8098

81-
let mut mids: Vec<&Point> = points[start..end].iter().collect();
82-
mids.sort_by(|a, b| f64_cmp(&a.y, &b.y));
83-
84-
for (i, e) in mids.iter().enumerate() {
99+
for (i, e) in points_y.iter().enumerate() {
85100
for k in 1..8 {
86-
if i + k >= mids.len() {
101+
if i + k >= points_y.len() {
87102
break;
88103
}
89104

90-
let new = e.euclidean_distance(mids[i + k]);
105+
let new = e.euclidean_distance(&points_y[i + k]);
91106
if new < min_sqr_dist {
92107
min_sqr_dist = new;
93-
pair = ((*e).clone(), mids[i + k].clone());
108+
pair = ((*e).clone(), points_y[i + k].clone());
94109
}
95110
}
96111
}

0 commit comments

Comments
 (0)