|
| 1 | +// matrix_chain_multiply finds the minimum number of multiplications to perform a chain of matrix |
| 2 | +// multiplications. The input matrices represents the dimensions of matrices. For example [1,2,3,4] |
| 3 | +// represents matrices of dimension (1x2), (2x3), and (3x4) |
| 4 | +// |
| 5 | +// Lets say we are given [4, 3, 2, 1]. If we naively multiply left to right, we get: |
| 6 | +// |
| 7 | +// (4*3*2) + (4*2*1) = 20 |
| 8 | +// |
| 9 | +// We can reduce the multiplications by reordering the matrix multiplications: |
| 10 | +// |
| 11 | +// (3*2*1) + (4*3*1) = 18 |
| 12 | +// |
| 13 | +// We solve this problem with dynamic programming and tabulation. table[i][j] holds the optimal |
| 14 | +// number of multiplications in range matrices[i..j] (inclusive). Note this means that table[i][i] |
| 15 | +// and table[i][i+1] are always zero, since those represent a single vector/matrix and do not |
| 16 | +// require any multiplications. |
| 17 | +// |
| 18 | +// For any i, j, and k = i+1, i+2, ..., j-1: |
| 19 | +// |
| 20 | +// table[i][j] = min(table[i][k] + table[k][j] + matrices[i] * matrices[k] * matrices[j]) |
| 21 | +// |
| 22 | +// table[i][k] holds the optimal solution to matrices[i..k] |
| 23 | +// |
| 24 | +// table[k][j] holds the optimal solution to matrices[k..j] |
| 25 | +// |
| 26 | +// matrices[i] * matrices[k] * matrices[j] computes the number of multiplications to join the two |
| 27 | +// matrices together. |
| 28 | +// |
| 29 | +// Runs in O(n^3) time and O(n^2) space. |
| 30 | + |
| 31 | +pub fn matrix_chain_multiply(matrices: Vec<u32>) -> u32 { |
| 32 | + let n = matrices.len(); |
| 33 | + if n <= 2 { |
| 34 | + // No multiplications required. |
| 35 | + return 0; |
| 36 | + } |
| 37 | + let mut table = vec![vec![0; n]; n]; |
| 38 | + |
| 39 | + for length in 2..n { |
| 40 | + for i in 0..n - length { |
| 41 | + let j = i + length; |
| 42 | + table[i][j] = u32::MAX; |
| 43 | + for k in i + 1..j { |
| 44 | + let multiplications = |
| 45 | + table[i][k] + table[k][j] + matrices[i] * matrices[k] * matrices[j]; |
| 46 | + if multiplications < table[i][j] { |
| 47 | + table[i][j] = multiplications; |
| 48 | + } |
| 49 | + } |
| 50 | + } |
| 51 | + } |
| 52 | + |
| 53 | + table[0][n - 1] |
| 54 | +} |
| 55 | + |
| 56 | +#[cfg(test)] |
| 57 | +mod tests { |
| 58 | + use super::*; |
| 59 | + |
| 60 | + #[test] |
| 61 | + fn basic() { |
| 62 | + assert_eq!(matrix_chain_multiply(vec![1, 2, 3, 4]), 18); |
| 63 | + assert_eq!(matrix_chain_multiply(vec![4, 3, 2, 1]), 18); |
| 64 | + assert_eq!(matrix_chain_multiply(vec![40, 20, 30, 10, 30]), 26000); |
| 65 | + assert_eq!(matrix_chain_multiply(vec![1, 2, 3, 4, 3]), 30); |
| 66 | + assert_eq!(matrix_chain_multiply(vec![1, 2, 3, 4, 3]), 30); |
| 67 | + assert_eq!(matrix_chain_multiply(vec![4, 10, 3, 12, 20, 7]), 1344); |
| 68 | + } |
| 69 | + |
| 70 | + #[test] |
| 71 | + fn zero() { |
| 72 | + assert_eq!(matrix_chain_multiply(vec![]), 0); |
| 73 | + assert_eq!(matrix_chain_multiply(vec![10]), 0); |
| 74 | + assert_eq!(matrix_chain_multiply(vec![10, 20]), 0); |
| 75 | + } |
| 76 | +} |
0 commit comments