Skip to content

Commit 0331777

Browse files
authored
Add matrix chain multiplication (TheAlgorithms#504)
1 parent e59c04a commit 0331777

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// matrix_chain_multiply finds the minimum number of multiplications to perform a chain of matrix
2+
// multiplications. The input matrices represents the dimensions of matrices. For example [1,2,3,4]
3+
// represents matrices of dimension (1x2), (2x3), and (3x4)
4+
//
5+
// Lets say we are given [4, 3, 2, 1]. If we naively multiply left to right, we get:
6+
//
7+
// (4*3*2) + (4*2*1) = 20
8+
//
9+
// We can reduce the multiplications by reordering the matrix multiplications:
10+
//
11+
// (3*2*1) + (4*3*1) = 18
12+
//
13+
// We solve this problem with dynamic programming and tabulation. table[i][j] holds the optimal
14+
// number of multiplications in range matrices[i..j] (inclusive). Note this means that table[i][i]
15+
// and table[i][i+1] are always zero, since those represent a single vector/matrix and do not
16+
// require any multiplications.
17+
//
18+
// For any i, j, and k = i+1, i+2, ..., j-1:
19+
//
20+
// table[i][j] = min(table[i][k] + table[k][j] + matrices[i] * matrices[k] * matrices[j])
21+
//
22+
// table[i][k] holds the optimal solution to matrices[i..k]
23+
//
24+
// table[k][j] holds the optimal solution to matrices[k..j]
25+
//
26+
// matrices[i] * matrices[k] * matrices[j] computes the number of multiplications to join the two
27+
// matrices together.
28+
//
29+
// Runs in O(n^3) time and O(n^2) space.
30+
31+
pub fn matrix_chain_multiply(matrices: Vec<u32>) -> u32 {
32+
let n = matrices.len();
33+
if n <= 2 {
34+
// No multiplications required.
35+
return 0;
36+
}
37+
let mut table = vec![vec![0; n]; n];
38+
39+
for length in 2..n {
40+
for i in 0..n - length {
41+
let j = i + length;
42+
table[i][j] = u32::MAX;
43+
for k in i + 1..j {
44+
let multiplications =
45+
table[i][k] + table[k][j] + matrices[i] * matrices[k] * matrices[j];
46+
if multiplications < table[i][j] {
47+
table[i][j] = multiplications;
48+
}
49+
}
50+
}
51+
}
52+
53+
table[0][n - 1]
54+
}
55+
56+
#[cfg(test)]
57+
mod tests {
58+
use super::*;
59+
60+
#[test]
61+
fn basic() {
62+
assert_eq!(matrix_chain_multiply(vec![1, 2, 3, 4]), 18);
63+
assert_eq!(matrix_chain_multiply(vec![4, 3, 2, 1]), 18);
64+
assert_eq!(matrix_chain_multiply(vec![40, 20, 30, 10, 30]), 26000);
65+
assert_eq!(matrix_chain_multiply(vec![1, 2, 3, 4, 3]), 30);
66+
assert_eq!(matrix_chain_multiply(vec![1, 2, 3, 4, 3]), 30);
67+
assert_eq!(matrix_chain_multiply(vec![4, 10, 3, 12, 20, 7]), 1344);
68+
}
69+
70+
#[test]
71+
fn zero() {
72+
assert_eq!(matrix_chain_multiply(vec![]), 0);
73+
assert_eq!(matrix_chain_multiply(vec![10]), 0);
74+
assert_eq!(matrix_chain_multiply(vec![10, 20]), 0);
75+
}
76+
}

src/dynamic_programming/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod longest_common_subsequence;
88
mod longest_common_substring;
99
mod longest_continuous_increasing_subsequence;
1010
mod longest_increasing_subsequence;
11+
mod matrix_chain_multiply;
1112
mod maximal_square;
1213
mod maximum_subarray;
1314
mod rod_cutting;
@@ -29,6 +30,7 @@ pub use self::longest_common_subsequence::longest_common_subsequence;
2930
pub use self::longest_common_substring::longest_common_substring;
3031
pub use self::longest_continuous_increasing_subsequence::longest_continuous_increasing_subsequence;
3132
pub use self::longest_increasing_subsequence::longest_increasing_subsequence;
33+
pub use self::matrix_chain_multiply::matrix_chain_multiply;
3234
pub use self::maximal_square::maximal_square;
3335
pub use self::maximum_subarray::maximum_subarray;
3436
pub use self::rod_cutting::rod_cut;

0 commit comments

Comments
 (0)