1
1
import math
2
+ from typing import List , Tuple
2
3
3
- def default_matrix_multiplication (a , b ):
4
+
5
+ def default_matrix_multiplication (a : List , b : List ) -> List :
4
6
"""
5
7
Multiplication only for 2x2 matrices
6
8
"""
7
9
if len (a ) != 2 or len (a [0 ]) != 2 or len (b ) != 2 or len (b [0 ]) != 2 :
8
- raise Exception ('Matrices are not 2x2' )
9
- new_matrix = [[a [0 ][0 ] * b [0 ][0 ] + a [0 ][1 ] * b [1 ][0 ], a [0 ][0 ] * b [0 ][1 ] + a [0 ][1 ] * b [1 ][1 ]],
10
- [a [1 ][0 ] * b [0 ][0 ] + a [1 ][1 ] * b [1 ][0 ], a [1 ][0 ] * b [0 ][1 ] + a [1 ][1 ] * b [1 ][1 ]]]
10
+ raise Exception ("Matrices are not 2x2" )
11
+ new_matrix = [
12
+ [a [0 ][0 ] * b [0 ][0 ] + a [0 ][1 ] * b [1 ][0 ], a [0 ][0 ] * b [0 ][1 ] + a [0 ][1 ] * b [1 ][1 ]],
13
+ [a [1 ][0 ] * b [0 ][0 ] + a [1 ][1 ] * b [1 ][0 ], a [1 ][0 ] * b [0 ][1 ] + a [1 ][1 ] * b [1 ][1 ]],
14
+ ]
11
15
return new_matrix
12
16
13
- def matrix_addition (matrix_a , matrix_b ):
14
- return [[matrix_a [row ][col ] + matrix_b [row ][col ] for col in range (len (matrix_a [row ]))] for row in range (len (matrix_a ))]
15
17
16
- def matrix_subtraction (matrix_a , matrix_b ):
17
- return [[matrix_a [row ][col ] - matrix_b [row ][col ] for col in range (len (matrix_a [row ]))] for row in range (len (matrix_a ))]
18
+ def matrix_addition (matrix_a : List , matrix_b : List ):
19
+ return [
20
+ [matrix_a [row ][col ] + matrix_b [row ][col ] for col in range (len (matrix_a [row ]))]
21
+ for row in range (len (matrix_a ))
22
+ ]
23
+
24
+
25
+ def matrix_subtraction (matrix_a : List , matrix_b : List ):
26
+ return [
27
+ [matrix_a [row ][col ] - matrix_b [row ][col ] for col in range (len (matrix_a [row ]))]
28
+ for row in range (len (matrix_a ))
29
+ ]
18
30
19
31
20
- def split_matrix (a ) :
32
+ def split_matrix (a : List ,) -> Tuple [ List , List , List , List ] :
21
33
"""
22
34
Given an even length matrix, returns the top_left, top_right, bot_left, bot_right quadrant.
35
+
36
+ >>> split_matrix([[4,3,2,4],[2,3,1,1],[6,5,4,3],[8,4,1,6]])
37
+ ([[4, 3], [2, 3]], [[2, 4], [1, 1]], [[6, 5], [8, 4]], [[4, 3], [1, 6]])
38
+ >>> split_matrix([[4,3,2,4,4,3,2,4],[2,3,1,1,2,3,1,1],[6,5,4,3,6,5,4,3],[8,4,1,6,8,4,1,6],[4,3,2,4,4,3,2,4],[2,3,1,1,2,3,1,1],[6,5,4,3,6,5,4,3],[8,4,1,6,8,4,1,6]])
39
+ ([[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]])
23
40
"""
24
41
if len (a ) % 2 != 0 or len (a [0 ]) % 2 != 0 :
25
- raise Exception (' Odd matrices are not supported!' )
42
+ raise Exception (" Odd matrices are not supported!" )
26
43
27
44
matrix_length = len (a )
28
45
mid = matrix_length // 2
29
46
30
47
top_right = [[a [i ][j ] for j in range (mid , matrix_length )] for i in range (mid )]
31
- bot_right = [[a [i ][j ] for j in range (mid , matrix_length )] for i in range (mid , matrix_length )]
48
+ bot_right = [
49
+ [a [i ][j ] for j in range (mid , matrix_length )] for i in range (mid , matrix_length )
50
+ ]
32
51
33
52
top_left = [[a [i ][j ] for j in range (mid )] for i in range (mid )]
34
53
bot_left = [[a [i ][j ] for j in range (mid )] for i in range (mid , matrix_length )]
35
54
36
-
37
-
38
55
return top_left , top_right , bot_left , bot_right
39
56
40
- def matrix_dimensions (matrix ):
57
+
58
+ def matrix_dimensions (matrix : List ) -> Tuple [int , int ]:
41
59
return len (matrix ), len (matrix [0 ])
42
60
43
61
44
- def strassen (matrix_a , matrix_b ):
62
+ def print_matrix (matrix : List ) -> None :
63
+ for i in range (len (matrix )):
64
+ print (matrix [i ])
65
+
66
+
67
+ def actual_strassen (matrix_a : List , matrix_b : List ) -> List :
45
68
"""
46
69
Recursive function to calculate the product of two matrices, using the Strassen Algorithm.
47
70
It only supports even length matrices.
48
71
"""
49
72
if matrix_dimensions (matrix_a ) == (2 , 2 ):
50
73
return default_matrix_multiplication (matrix_a , matrix_b )
51
74
52
- a ,b , c , d = split_matrix (matrix_a )
53
- e ,f , g , h = split_matrix (matrix_b )
75
+ a , b , c , d = split_matrix (matrix_a )
76
+ e , f , g , h = split_matrix (matrix_b )
54
77
55
- t1 = strassen (a , matrix_subtraction (f , h ))
56
- t2 = strassen (matrix_addition (a , b ), h )
57
- t3 = strassen (matrix_addition (c , d ), e )
58
- t4 = strassen (d , matrix_subtraction (g , e ))
59
- t5 = strassen (matrix_addition (a , d ), matrix_addition (e , h ))
60
- t6 = strassen (matrix_subtraction (b , d ), matrix_addition (g , h ))
61
- t7 = strassen (matrix_subtraction (a , c ), matrix_addition (e , f ))
78
+ t1 = actual_strassen (a , matrix_subtraction (f , h ))
79
+ t2 = actual_strassen (matrix_addition (a , b ), h )
80
+ t3 = actual_strassen (matrix_addition (c , d ), e )
81
+ t4 = actual_strassen (d , matrix_subtraction (g , e ))
82
+ t5 = actual_strassen (matrix_addition (a , d ), matrix_addition (e , h ))
83
+ t6 = actual_strassen (matrix_subtraction (b , d ), matrix_addition (g , h ))
84
+ t7 = actual_strassen (matrix_subtraction (a , c ), matrix_addition (e , f ))
62
85
63
86
top_left = matrix_addition (matrix_subtraction (matrix_addition (t5 , t4 ), t2 ), t6 )
64
87
top_right = matrix_addition (t1 , t2 )
@@ -73,13 +96,18 @@ def strassen(matrix_a, matrix_b):
73
96
new_matrix .append (bot_left [i ] + bot_right [i ])
74
97
return new_matrix
75
98
76
- def print_matrix (matrix ):
77
- for i in range (len (matrix )):
78
- print (matrix [i ])
79
99
80
- def multiply_matrices (matrix1 , matrix2 ):
100
+ def strassen (matrix1 : List , matrix2 : List ) -> List :
101
+ """
102
+ >>> strassen([[2,1,3],[3,4,6],[1,4,2],[7,6,7]], [[4,2,3,4],[2,1,1,1],[8,6,4,2]])
103
+ [[34, 23, 19, 15], [68, 46, 37, 28], [28, 18, 15, 12], [96, 62, 55, 48]]
104
+ >>> strassen([[3,7,5,6,9],[1,5,3,7,8],[1,4,4,5,7]], [[2,4],[5,2],[1,7],[5,5],[7,8]])
105
+ [[139, 163], [121, 134], [100, 121]]
106
+ """
81
107
if matrix_dimensions (matrix1 )[1 ] != matrix_dimensions (matrix2 )[0 ]:
82
- raise Exception (f'Unable to multiply these matrices, please check the dimensions. \n Matrix A:{ matrix1 } \n Matrix B:{ matrix2 } ' )
108
+ raise Exception (
109
+ f"Unable to multiply these matrices, please check the dimensions. \n Matrix A:{ matrix1 } \n Matrix B:{ matrix2 } "
110
+ )
83
111
dimension1 = matrix_dimensions (matrix1 )
84
112
dimension2 = matrix_dimensions (matrix2 )
85
113
@@ -88,39 +116,46 @@ def multiply_matrices(matrix1, matrix2):
88
116
89
117
maximum = max (max (dimension1 ), max (dimension2 ))
90
118
maxim = int (math .pow (2 , math .ceil (math .log2 (maximum ))))
91
- print (max )
92
119
new_matrix1 = matrix1
93
120
new_matrix2 = matrix2
94
- """
95
- Adding zeros to the matrices so that the arrays dimensions are the same and also power of 2
96
- """
97
- for i in range (0 ,maxim ):
121
+
122
+ # Adding zeros to the matrices so that the arrays dimensions are the same and also power of 2
123
+ for i in range (0 , maxim ):
98
124
if i < dimension1 [0 ]:
99
- for j in range (dimension1 [1 ],maxim ):
125
+ for j in range (dimension1 [1 ], maxim ):
100
126
new_matrix1 [i ].append (0 )
101
127
else :
102
128
new_matrix1 .append ([0 ] * maxim )
103
129
if i < dimension2 [0 ]:
104
- for j in range (dimension2 [1 ],maxim ):
130
+ for j in range (dimension2 [1 ], maxim ):
105
131
new_matrix2 [i ].append (0 )
106
132
else :
107
133
new_matrix2 .append ([0 ] * maxim )
108
134
109
- final_matrix = strassen (new_matrix1 , new_matrix2 )
135
+ final_matrix = actual_strassen (new_matrix1 , new_matrix2 )
110
136
111
- """
112
- Removing the additional zeros
113
- """
114
- for i in range (0 ,maxim ):
137
+ # Removing the additional zeros
138
+ for i in range (0 , maxim ):
115
139
if i < dimension1 [0 ]:
116
- for j in range (dimension2 [1 ],maxim ):
140
+ for j in range (dimension2 [1 ], maxim ):
117
141
final_matrix [i ].pop ()
118
142
else :
119
143
final_matrix .pop ()
120
144
return final_matrix
121
145
122
146
123
- if __name__ == '__main__' :
124
- matrix1 = [[2 ,3 ,4 ,5 ],[6 ,4 ,3 ,1 ],[2 ,3 ,6 ,7 ],[3 ,1 ,2 ,4 ],[2 ,3 ,4 ,5 ],[6 ,4 ,3 ,1 ],[2 ,3 ,6 ,7 ],[3 ,1 ,2 ,4 ],[2 ,3 ,4 ,5 ],[6 ,2 ,3 ,1 ]]
125
- matrix2 = [[0 ,2 ,1 ,1 ],[16 ,2 ,3 ,3 ],[2 ,2 ,7 ,7 ],[13 ,11 ,22 ,4 ]]
126
- print_matrix (multiply_matrices (matrix1 ,matrix2 ))
147
+ if __name__ == "__main__" :
148
+ matrix1 = [
149
+ [2 , 3 , 4 , 5 ],
150
+ [6 , 4 , 3 , 1 ],
151
+ [2 , 3 , 6 , 7 ],
152
+ [3 , 1 , 2 , 4 ],
153
+ [2 , 3 , 4 , 5 ],
154
+ [6 , 4 , 3 , 1 ],
155
+ [2 , 3 , 6 , 7 ],
156
+ [3 , 1 , 2 , 4 ],
157
+ [2 , 3 , 4 , 5 ],
158
+ [6 , 2 , 3 , 1 ],
159
+ ]
160
+ matrix2 = [[0 , 2 , 1 , 1 ], [16 , 2 , 3 , 3 ], [2 , 2 , 7 , 7 ], [13 , 11 , 22 , 4 ]]
161
+ print (strassen (matrix1 , matrix2 ))
0 commit comments