Strassen's Matrix Multiplication
Dr. Mili Dhar
Basic Matrix Multiplication
Suppose we want to multiply two N × N matrices, A × B = C. If each matrix is split into four N/2 × N/2 blocks, the blocks of C are:
C11 = A11B11 + A12B21
C12 = A11B12 + A12B22
C21 = A21B11 + A22B21
C22 = A21B12 + A22B22
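#define N 4   /* matrix dimension: any compile-time constant */

/* Multiply two N x N matrices: C = A x B */
void multiply(int A[][N], int B[][N], int C[][N])
{
    for (int i = 0; i < N; i++)              /* each row of A */
    {
        for (int j = 0; j < N; j++)          /* each column of B */
        {
            C[i][j] = 0;
            for (int k = 0; k < N; k++)      /* dot product of row i and column j */
            {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }
}

Time analysis:
Each of the three nested loops runs N times, so the statement C[i][j] += A[i][k] * B[k][j] executes $N^3$ times. The time complexity of the above method is therefore $O(N^3)$.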
Matrix Multiplication using Divide and Conquer
The following is a simple divide-and-conquer method for multiplying two square N × N matrices (N a power of 2).
1. Divide A and B into four sub-matrices of size N/2 × N/2 each, as shown in the diagram below.
2. Recursively calculate A11B11 + A12B21, A11B12 + A12B22, A21B11 + A22B21 and A21B12 + A22B22; these are the four blocks of the result.
Divide and Conquer Matrix Multiply
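$$
\begin{pmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{pmatrix}
\begin{pmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{pmatrix}
=
\begin{pmatrix} A_{11}B_{11}+A_{12}B_{21} & A_{11}B_{12}+A_{12}B_{22} \\ A_{21}B_{11}+A_{22}B_{21} & A_{21}B_{12}+A_{22}B_{22} \end{pmatrix}
$$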
• Divide the matrices into sub-matrices: A11, A12, A21, etc.
• Use the blocked matrix multiplication equations.
• Recursively multiply the sub-matrices.
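As a minimal sketch (assuming N = 4, an even size, and plain `int` matrices), the blocked equations can be checked at a single level, before any recursion: each of the eight block products below is an ordinary triple loop over an N/2 × N/2 block, and each result is accumulated into the matching quadrant of C. The names `block_mult_add` and `block_multiply` are illustrative, not taken from the slides.

```c
#include <string.h>

#define N 4          /* full matrix size (assumed even) */
#define H (N / 2)    /* size of each block */

/* Add the product of two H x H blocks into a block of C.
   (ar, ac), (br, bc) and (cr, cc) are the top-left corners
   of the blocks inside the N x N arrays A, B and C. */
static void block_mult_add(int A[N][N], int ar, int ac,
                           int B[N][N], int br, int bc,
                           int C[N][N], int cr, int cc)
{
    for (int i = 0; i < H; i++)
        for (int j = 0; j < H; j++)
            for (int k = 0; k < H; k++)
                C[cr + i][cc + j] += A[ar + i][ac + k] * B[br + k][bc + j];
}

/* One level of the blocked equations, e.g. C11 = A11 B11 + A12 B21. */
void block_multiply(int A[N][N], int B[N][N], int C[N][N])
{
    memset(C, 0, sizeof(int) * N * N);          /* start from C = 0 */
    block_mult_add(A, 0, 0, B, 0, 0, C, 0, 0);  /* C11 += A11 B11 */
    block_mult_add(A, 0, H, B, H, 0, C, 0, 0);  /* C11 += A12 B21 */
    block_mult_add(A, 0, 0, B, 0, H, C, 0, H);  /* C12 += A11 B12 */
    block_mult_add(A, 0, H, B, H, H, C, 0, H);  /* C12 += A12 B22 */
    block_mult_add(A, H, 0, B, 0, 0, C, H, 0);  /* C21 += A21 B11 */
    block_mult_add(A, H, H, B, H, 0, C, H, 0);  /* C21 += A22 B21 */
    block_mult_add(A, H, 0, B, 0, H, C, H, H);  /* C22 += A21 B12 */
    block_mult_add(A, H, H, B, H, H, C, H, H);  /* C22 += A22 B22 */
}
```

Replacing each `block_mult_add` call by a recursive call on the blocks is exactly the step that turns this into the divide-and-conquer method.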
Divide-and-Conquer
Divide-and-conquer is a general algorithm design paradigm:
◼ Divide: divide the input data S into two or more disjoint subsets S1, S2, …
◼ Recur: solve the subproblems recursively.
◼ Conquer: combine the solutions for S1, S2, … into a solution for S.
The base cases for the recursion are subproblems of constant size.
The running time can be analysed using recurrence equations.
Divide and Conquer Matrix Multiply
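For 1 × 1 matrices A = (a0) and B = (b0), the product R is a single scalar multiplication:
$$ (a_0)(b_0) = (a_0 b_0) $$
• Terminate the recursion with this simple base case.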
Example:
Array A          Array B
1 1 1 1          1 1 1 1
2 2 2 2          2 2 2 2
3 3 3 3          3 3 3 3
2 2 2 2          2 2 2 2

Result Array (A × B)
 8  8  8  8
16 16 16 16
24 24 24 24
16 16 16 16
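Every column of B is (1, 2, 3, 2), whose entries sum to 8, so row i of the result is 8 times the constant value in row i of A: 8, 16, 24, 16. The sketch below reproduces the result array with the basic triple loop; `example_product` and `print_matrix` are hypothetical helper names.

```c
#include <stdio.h>

#define N 4

/* Compute the example product C = A x B with the basic triple loop. */
void example_product(int C[N][N])
{
    int A[N][N] = {{1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}, {2, 2, 2, 2}};
    int B[N][N] = {{1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}, {2, 2, 2, 2}};
    for (int i = 0; i < N; i++)
        for (int j = 0; j < N; j++) {
            C[i][j] = 0;
            for (int k = 0; k < N; k++)
                C[i][j] += A[i][k] * B[k][j];
        }
}

/* Print an N x N matrix, one row per line. */
void print_matrix(int C[N][N])
{
    for (int i = 0; i < N; i++) {
        for (int j = 0; j < N; j++)
            printf("%4d", C[i][j]);
        printf("\n");
    }
}
```

Calling `example_product(C)` followed by `print_matrix(C)` prints the four rows of the result array.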
MMult(A, B, n)
1. If n = 1 Output A × B
2. Else
3. Compute A11, B11, . . ., A22, B22 % by computing m = n/2
4. X1 ← MMult(A11, B11, n/2)
5. X2 ← MMult(A12, B21, n/2)
6. X3 ← MMult(A11, B12, n/2)
7. X4 ← MMult(A12, B22, n/2)
8. X5 ← MMult(A21, B11, n/2)
9. X6 ← MMult(A22, B21, n/2)
10. X7 ← MMult(A21, B12, n/2)
11. X8 ← MMult(A22, B22, n/2)
12. C11 ← X1 + X2
13. C12 ← X3 + X4
14. C21 ← X5 + X6
15. C22 ← X7 + X8
16. Output C
17. End If
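A minimal C sketch of MMult, assuming n is a power of 2 and matrices are stored row-major in flat `int` arrays. Instead of copying quadrants, it locates them with pointer offsets and a row stride (so line 3 really is constant time), and instead of storing X1, …, X8 separately it accumulates each product straight into the matching quadrant of C, which performs the additions of lines 12–15. The function names and stride parameters are assumptions of this sketch.

```c
/* C += A * B for n x n blocks stored inside larger row-major arrays.
   lda, ldb, ldc are the row strides (column counts) of the full arrays;
   n is assumed to be a power of 2. */
static void mmult_add(int n, const int *A, int lda, const int *B, int ldb,
                      int *C, int ldc)
{
    if (n == 1) {                 /* base case: 1 x 1 product */
        C[0] += A[0] * B[0];
        return;
    }
    int m = n / 2;
    /* Quadrants: X11 at (0,0), X12 at (0,m), X21 at (m,0), X22 at (m,m). */
    const int *A11 = A, *A12 = A + m, *A21 = A + m * lda, *A22 = A + m * lda + m;
    const int *B11 = B, *B12 = B + m, *B21 = B + m * ldb, *B22 = B + m * ldb + m;
    int *C11 = C, *C12 = C + m, *C21 = C + m * ldc, *C22 = C + m * ldc + m;

    /* The 8 recursive calls; accumulating into C does the block additions. */
    mmult_add(m, A11, lda, B11, ldb, C11, ldc);   /* X1 */
    mmult_add(m, A12, lda, B21, ldb, C11, ldc);   /* X2 */
    mmult_add(m, A11, lda, B12, ldb, C12, ldc);   /* X3 */
    mmult_add(m, A12, lda, B22, ldb, C12, ldc);   /* X4 */
    mmult_add(m, A21, lda, B11, ldb, C21, ldc);   /* X5 */
    mmult_add(m, A22, lda, B21, ldb, C21, ldc);   /* X6 */
    mmult_add(m, A21, lda, B12, ldb, C22, ldc);   /* X7 */
    mmult_add(m, A22, lda, B22, ldb, C22, ldc);   /* X8 */
}

/* C = A * B for n x n row-major matrices (n a power of 2). */
void mmult(int n, const int *A, const int *B, int *C)
{
    for (int i = 0; i < n * n; i++)
        C[i] = 0;
    mmult_add(n, A, n, B, n, C, n);
}
```

Accumulating in place avoids allocating the eight temporaries, but it still performs the same 8 multiplications of half-size blocks per level, so the analysis below is unchanged.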
Analysis:
The operations on line 3 take constant time (the quadrants can be located by index arithmetic with m = n/2). The combining cost (lines 12–15) is $\Theta(n^2)$, since adding two $n/2 \times n/2$ matrices takes $n^2/4 = \Theta(n^2)$ time. There are 8 recursive calls (lines 4–11). So, letting $T(n)$ be the total number of arithmetic operations performed by MMult(A, B, n),
$$T(n) = 8\,T(n/2) + \Theta(n^2).$$
Here $a = 8$, $b = 2$ and $n^{\log_2 8} = n^3$ grows polynomially faster than $n^2$, so the Master Theorem gives $T(n) = \Theta(n^{\log_2 8}) = \Theta(n^3)$.
So this is not an improvement on the "obvious" algorithm given earlier, which uses $n^3$ multiplications.
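As a cross-check on the Master Theorem, the recurrence can also be unrolled. Assuming n is a power of 2 and writing the $\Theta(n^2)$ term as $c\,n^2$ for some constant $c$, level $i$ of the recursion has $8^i$ subproblems of size $n/2^i$, contributing $8^i \cdot c\,(n/2^i)^2 = c\,n^2\,2^i$:

$$
\begin{aligned}
T(n) &= 8\,T(n/2) + c\,n^2 \\
&= 8^k\,T\!\left(n/2^k\right) + c\,n^2 \sum_{i=0}^{k-1} 2^i \\
&= n^3\,T(1) + c\,n^2\,(n-1) \qquad (k = \log_2 n,\; 8^{\log_2 n} = n^3) \\
&= \Theta(n^3).
\end{aligned}
$$

The $n^3\,T(1)$ term comes from the $8^k$ leaves, so the cost is dominated by the number of base-case multiplications.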
Strassen's Matrix Multiplication
Strassen showed that the product of two 2 × 2 matrices can be computed with 7 multiplications (instead of 8) and 18 additions or subtractions. Note that $7 = 2^{\log_2 7} \approx 2^{2.807}$.
This reduction can be exploited with the divide-and-conquer approach, by applying it recursively to N/2 × N/2 blocks.
In Strassen's method, the four sub-matrices of the result are calculated using the following formulae.
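$$
\begin{pmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{pmatrix}
\begin{pmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{pmatrix}
=
\begin{pmatrix} P_5+P_4-P_2+P_6 & P_1+P_2 \\ P_3+P_4 & P_1+P_5-P_3-P_7 \end{pmatrix}
$$
where P1, …, P7 are the seven products computed in lines 4–10 of the Strassen algorithm below.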
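The same seven products can also be numbered differently; for example, P1 below is (A11 + A22)(B11 + B22), which the Strassen algorithm calls P5. With this numbering:
P1 = (A11 + A22) * (B11 + B22)
P2 = (A21 + A22) * B11
P3 = A11 * (B12 - B22)
P4 = A22 * (B21 - B11)
P5 = (A11 + A12) * B22
P6 = (A21 - A11) * (B11 + B12)
P7 = (A12 - A22) * (B21 + B22)

C11 = P1 + P4 - P5 + P7
C12 = P3 + P5
C21 = P2 + P4
C22 = P1 + P3 - P2 + P6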
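On plain numbers (1 × 1 blocks) these formulae can be checked directly. This sketch uses the P1, …, P7 labels just given and a hypothetical function name `strassen2x2`; it computes a 2 × 2 product with 7 multiplications and 18 additions or subtractions (10 to form the operands, 8 to combine). Matrices are passed as arrays (m11, m12, m21, m22).

```c
/* Strassen's 2 x 2 step on scalars: 7 multiplications instead of 8.
   a, b, c hold (m11, m12, m21, m22) in that order. */
void strassen2x2(const long a[4], const long b[4], long c[4])
{
    long a11 = a[0], a12 = a[1], a21 = a[2], a22 = a[3];
    long b11 = b[0], b12 = b[1], b21 = b[2], b22 = b[3];

    long p1 = (a11 + a22) * (b11 + b22);
    long p2 = (a21 + a22) * b11;
    long p3 = a11 * (b12 - b22);
    long p4 = a22 * (b21 - b11);
    long p5 = (a11 + a12) * b22;
    long p6 = (a21 - a11) * (b11 + b12);
    long p7 = (a12 - a22) * (b21 + b22);

    c[0] = p1 + p4 - p5 + p7;   /* C11 = a11 b11 + a12 b21 */
    c[1] = p3 + p5;             /* C12 = a11 b12 + a12 b22 */
    c[2] = p2 + p4;             /* C21 = a21 b11 + a22 b21 */
    c[3] = p1 + p3 - p2 + p6;   /* C22 = a21 b12 + a22 b22 */
}
```

The identities only use ring operations and never swap the order of a product, which is why they remain valid when the entries are matrix blocks rather than numbers.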
Comparison
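$$
\begin{aligned}
C_{11} &= P_1 + P_4 - P_5 + P_7 \\
&= (A_{11}+A_{22})(B_{11}+B_{22}) + A_{22}(B_{21}-B_{11}) - (A_{11}+A_{12})B_{22} + (A_{12}-A_{22})(B_{21}+B_{22}) \\
&= A_{11}B_{11} + A_{11}B_{22} + A_{22}B_{11} + A_{22}B_{22} + A_{22}B_{21} - A_{22}B_{11} - A_{11}B_{22} - A_{12}B_{22} \\
&\quad + A_{12}B_{21} + A_{12}B_{22} - A_{22}B_{21} - A_{22}B_{22} \\
&= A_{11}B_{11} + A_{12}B_{21}
\end{aligned}
$$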
Strassen Algorithm
Strassen(A, B, n)
1. If n = 1 Output A × B
2. Else
3. Compute A11, B11, . . ., A22, B22 % by computing m = n/2
4. P1 ← Strassen(A11, B12 − B22, n/2)
5. P2 ← Strassen(A11 + A12, B22, n/2)
6. P3 ← Strassen(A21 + A22, B11, n/2)
7. P4 ← Strassen(A22, B21 − B11, n/2)
8. P5 ← Strassen(A11 + A22, B11 + B22, n/2)
9. P6 ← Strassen(A12 − A22, B21 + B22, n/2)
10. P7 ← Strassen(A11 − A21, B11 + B12, n/2)
11. C11 ← P5 + P4 − P2 + P6
12. C12 ← P1 + P2
13. C21 ← P3 + P4
14. C22 ← P1 + P5 − P3 − P7
15. Output C
16. End If
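A C sketch of the Strassen pseudocode, using the P1, …, P7 of lines 4–10 and assuming n is a power of 2 with row-major `int` arrays. For simplicity it copies the quadrants into a scratch buffer (an extra $\Theta(n^2)$ per call, which does not change the recurrence) rather than using index arithmetic; the helper names `get_block`, `put_block` and `add` are illustrative.

```c
#include <stdlib.h>

/* Copy the m x m quadrant with top-left corner (r, c) out of an n x n matrix. */
static void get_block(int n, const int *X, int r, int c, int *out)
{
    int m = n / 2;
    for (int i = 0; i < m; i++)
        for (int j = 0; j < m; j++)
            out[i * m + j] = X[(r + i) * n + (c + j)];
}

/* Write an m x m block into the quadrant of an n x n matrix at (r, c). */
static void put_block(int n, int *X, int r, int c, const int *in)
{
    int m = n / 2;
    for (int i = 0; i < m; i++)
        for (int j = 0; j < m; j++)
            X[(r + i) * n + (c + j)] = in[i * m + j];
}

/* out = X + sign * Y over k entries (out may alias X). */
static void add(int k, const int *X, const int *Y, int sign, int *out)
{
    for (int i = 0; i < k; i++)
        out[i] = X[i] + sign * Y[i];
}

/* C = A * B for n x n matrices using 7 recursive products. */
void strassen(int n, const int *A, const int *B, int *C)
{
    if (n == 1) {                          /* line 1: base case */
        C[0] = A[0] * B[0];
        return;
    }
    int m = n / 2, k = m * m;
    /* scratch: 8 quadrants, 7 products, 2 temporaries S and T */
    int *buf = malloc(sizeof(int) * 17 * k);
    if (!buf) abort();
    int *A11 = buf, *A12 = A11 + k, *A21 = A12 + k, *A22 = A21 + k;
    int *B11 = A22 + k, *B12 = B11 + k, *B21 = B12 + k, *B22 = B21 + k;
    int *P[7];
    for (int i = 0; i < 7; i++) P[i] = B22 + k * (i + 1);
    int *S = B22 + 8 * k, *T = S + k;

    /* line 3: split A and B into quadrants */
    get_block(n, A, 0, 0, A11); get_block(n, A, 0, m, A12);
    get_block(n, A, m, 0, A21); get_block(n, A, m, m, A22);
    get_block(n, B, 0, 0, B11); get_block(n, B, 0, m, B12);
    get_block(n, B, m, 0, B21); get_block(n, B, m, m, B22);

    /* lines 4-10: the seven recursive products */
    add(k, B12, B22, -1, T); strassen(m, A11, T, P[0]);                        /* P1 */
    add(k, A11, A12, +1, S); strassen(m, S, B22, P[1]);                        /* P2 */
    add(k, A21, A22, +1, S); strassen(m, S, B11, P[2]);                        /* P3 */
    add(k, B21, B11, -1, T); strassen(m, A22, T, P[3]);                        /* P4 */
    add(k, A11, A22, +1, S); add(k, B11, B22, +1, T); strassen(m, S, T, P[4]); /* P5 */
    add(k, A12, A22, -1, S); add(k, B21, B22, +1, T); strassen(m, S, T, P[5]); /* P6 */
    add(k, A11, A21, -1, S); add(k, B11, B12, +1, T); strassen(m, S, T, P[6]); /* P7 */

    /* lines 11-14: combine into the quadrants of C */
    add(k, P[4], P[3], +1, S); add(k, S, P[1], -1, S); add(k, S, P[5], +1, S);
    put_block(n, C, 0, 0, S);                                  /* C11 = P5+P4-P2+P6 */
    add(k, P[0], P[1], +1, S); put_block(n, C, 0, m, S);       /* C12 = P1+P2 */
    add(k, P[2], P[3], +1, S); put_block(n, C, m, 0, S);       /* C21 = P3+P4 */
    add(k, P[0], P[4], +1, S); add(k, S, P[2], -1, S); add(k, S, P[6], -1, S);
    put_block(n, C, m, m, S);                                  /* C22 = P1+P5-P3-P7 */
    free(buf);
}
```

Recursing all the way down to 1 × 1 blocks mirrors the pseudocode; practical implementations usually stop at a larger cutoff and fall back to the triple loop, because the extra additions dominate on small blocks.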
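Analysis: The operations on line 3 take constant time. The matrix additions and subtractions that build the arguments in lines 4–10, and the combining cost (lines 11–14), are each $\Theta(n^2)$, since there is a constant number (18) of additions or subtractions of $n/2 \times n/2$ matrices. There are 7 recursive calls (lines 4–10). So, letting $T(n)$ be the total number of arithmetic operations performed by Strassen on $n \times n$ matrices,
$$T(n) = 7\,T(n/2) + \Theta(n^2).$$
Since $n^{\log_2 7} \approx n^{2.807}$ grows polynomially faster than $n^2$, the Master Theorem gives
$$T(n) = \Theta(n^{\log_2 7}) = O(n^{2.81}),$$
which is asymptotically better than the $\Theta(n^3)$ of the two earlier methods.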
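To see what the smaller exponent buys, here is a sketch that counts only the scalar multiplications of each recursion (base case M(1) = 1, n a power of 2), ignoring the $\Theta(n^2)$ additions; the function names are hypothetical.

```c
/* Scalar multiplications of the 8-call divide-and-conquer method:
   M(n) = 8 M(n/2), M(1) = 1, so M(n) = n^3. */
unsigned long long mults_dc(unsigned long long n)
{
    return n == 1 ? 1 : 8 * mults_dc(n / 2);
}

/* Scalar multiplications of Strassen's method:
   M(n) = 7 M(n/2), M(1) = 1, so M(n) = n^(log2 7). */
unsigned long long mults_strassen(unsigned long long n)
{
    return n == 1 ? 1 : 7 * mults_strassen(n / 2);
}
```

For n = 1024 = 2^10 these give 8^10 = 1,073,741,824 (that is, $n^3$) versus 7^10 = 282,475,249, roughly 3.8 times fewer multiplications. Strassen pays for this with more additions per level, so the advantage only appears for fairly large n.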