Skip to content
This repository was archived by the owner on Nov 29, 2020. It is now read-only.

Commit d41d685

Browse files
committed
PSquareMatrixMultiply.cpp
1 parent 845e468 commit d41d685

File tree

3 files changed

+283
-3
lines changed

3 files changed

+283
-3
lines changed

PSquareMatrixMultiply.cpp

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
//
2+
// algorithm - some algorithms in "Introduction to Algorithms", third edition
3+
// Copyright (C) 2018 lxylxy123456
4+
//
5+
// This program is free software: you can redistribute it and/or modify
6+
// it under the terms of the GNU Affero General Public License as
7+
// published by the Free Software Foundation, either version 3 of the
8+
// License, or (at your option) any later version.
9+
//
10+
// This program is distributed in the hope that it will be useful,
11+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
// GNU Affero General Public License for more details.
14+
//
15+
// You should have received a copy of the GNU Affero General Public License
16+
// along with this program. If not, see <https://www.gnu.org/licenses/>.
17+
//
18+
19+
#ifndef MAIN
20+
#define MAIN
21+
#define MAIN_PSquareMatrixMultiply
22+
#endif
23+
24+
#ifndef FUNC_PSquareMatrixMultiply
25+
#define FUNC_PSquareMatrixMultiply
26+
27+
#include <thread>
28+
#include "utils.h"
29+
30+
#include "MatVec.cpp"
31+
32+
#define SM SubMatrix<T>
33+
34+
template <typename T>
35+
Matrix<T> PSquareMatrixMultiply(Matrix<T>&A, Matrix<T>&B, T T0) {
36+
Matrix<T> C(A.rows, B.cols, T0);
37+
parallel_for<size_t>(0, C.rows, [&](size_t i){
38+
parallel_for<size_t>(0, C.cols, [&](size_t j){
39+
int& loc = C.data[i][j];
40+
for(size_t k = 0; k < A.cols; k++) {
41+
loc += A[i][k] * B[k][j];
42+
}
43+
});
44+
});
45+
return C;
46+
}
47+
48+
template <typename T>
49+
void PMatrixMultiplyRecursive(SM A, SM B, SM C, T T0) {
50+
size_t a_row = A.rows(), a_col = A.cols();
51+
size_t b_row = B.rows(), b_col = B.cols();
52+
size_t c_row = C.rows(), c_col = C.cols();
53+
assert(a_row == c_row && b_col == c_col && a_col == b_row);
54+
switch(a_row * a_col * b_col) {
55+
case 1:
56+
C.get_elem(0, 0) += A.get_elem(0, 0) * B.get_elem(0, 0);
57+
case 0:
58+
return;
59+
default:
60+
Matrix<T> S(c_row, c_col, T0);
61+
size_t a_mid = c_row / 2; // Rows of A & C
62+
size_t b_mid = c_col / 2; // Cols of B & C
63+
size_t c_mid = a_col / 2; // Cols of A & Rows of B
64+
SM A11(A, 0, c_mid, 0, a_mid);
65+
SM A12(A, c_mid, a_col, 0, a_mid);
66+
SM A21(A, 0, c_mid, a_mid, a_row);
67+
SM A22(A, c_mid, a_col, a_mid, a_row);
68+
SM B11(B, 0, b_mid, 0, c_mid);
69+
SM B12(B, b_mid, b_col, 0, c_mid);
70+
SM B21(B, 0, b_mid, c_mid, b_row);
71+
SM B22(B, b_mid, b_col, c_mid, b_row);
72+
SM C11(C, 0, b_mid, 0, a_mid);
73+
SM C12(C, b_mid, b_col, 0, a_mid);
74+
SM C21(C, 0, b_mid, a_mid, c_row);
75+
SM C22(C, b_mid, c_col, a_mid, c_row);
76+
SM S11(S, 0, b_mid, 0, a_mid);
77+
SM S12(S, b_mid, b_col, 0, a_mid);
78+
SM S21(S, 0, b_mid, a_mid, c_row);
79+
SM S22(S, b_mid, c_col, a_mid, c_row);
80+
std::thread t1(PMatrixMultiplyRecursive<T>, A11, B11, C11, T0);
81+
std::thread t2(PMatrixMultiplyRecursive<T>, A12, B21, S11, T0);
82+
std::thread t3(PMatrixMultiplyRecursive<T>, A11, B12, C12, T0);
83+
std::thread t4(PMatrixMultiplyRecursive<T>, A12, B22, S12, T0);
84+
std::thread t5(PMatrixMultiplyRecursive<T>, A21, B11, C21, T0);
85+
std::thread t6(PMatrixMultiplyRecursive<T>, A22, B21, S21, T0);
86+
std::thread t7(PMatrixMultiplyRecursive<T>, A21, B12, C22, T0);
87+
PMatrixMultiplyRecursive(A22, B22, S22, T0);
88+
t1.join();
89+
t2.join();
90+
t3.join();
91+
t4.join();
92+
t5.join();
93+
t6.join();
94+
t7.join();
95+
parallel_for<size_t>(0, c_row, [&](size_t i){
96+
parallel_for<size_t>(0, c_col, [&](size_t j){
97+
C.get_elem(i, j) += S[i][j];
98+
});
99+
});
100+
}
101+
}
102+
103+
template <typename T>
104+
Matrix<T> PMatrixMultiplyRecursive(Matrix<T>& A, Matrix<T>& B, T T0) {
105+
Matrix<T> C(A.rows, B.cols, T0);
106+
SM A_sub(A), B_sub(B), C_sub(C);
107+
PMatrixMultiplyRecursive(A_sub, B_sub, C_sub, T0);
108+
return C;
109+
}
110+
111+
template <typename T>
112+
Matrix<T> PMatAdd(SM A, SM B) {
113+
size_t a_row = A.rows(), a_col = A.cols();
114+
size_t b_row = B.rows(), b_col = B.cols();
115+
assert(a_row == b_row && a_col == b_col);
116+
Matrix<T> C(a_row, a_col, 0);
117+
parallel_for<size_t>(0, a_row, [&](size_t i){
118+
parallel_for<size_t>(0, a_col, [&](size_t j){
119+
C[i][j] = A.get_elem(i, j) + B.get_elem(i, j);
120+
});
121+
});
122+
return C;
123+
}
124+
125+
template <typename T>
126+
Matrix<T> PMatAdd(Matrix<T> A, Matrix<T> B) {
127+
return PMatAdd(SM(A), SM(B));
128+
}
129+
130+
template <typename T>
131+
Matrix<T> PMatSub(SM A, SM B) {
132+
size_t a_row = A.rows(), a_col = A.cols();
133+
size_t b_row = B.rows(), b_col = B.cols();
134+
assert(a_row == b_row && a_col == b_col);
135+
Matrix<T> C(a_row, a_col, 0);
136+
parallel_for<size_t>(0, a_row, [&](size_t i){
137+
parallel_for<size_t>(0, a_col, [&](size_t j){
138+
C[i][j] = A.get_elem(i, j) - B.get_elem(i, j);
139+
});
140+
});
141+
return C;
142+
}
143+
144+
template <typename T>
145+
Matrix<T> PMatSub(Matrix<T> A, Matrix<T> B) {
146+
return PMatSub(SM(A), SM(B));
147+
}
148+
149+
template <typename T, T T0>
150+
void PMatrixMultiplyStrassen(SM A, SM B, SM CC) {
151+
size_t a_row = A.rows(), a_col = A.cols();
152+
size_t b_row = B.rows(), b_col = B.cols();
153+
assert(a_col == b_row);
154+
switch(a_row * a_col * b_col) {
155+
case 1:
156+
CC.data = Matrix<T>(1, 1, A.get_elem(0, 0) * B.get_elem(0, 0));
157+
break;
158+
case 0:
159+
CC.data = Matrix<T>(0, 0);
160+
break;
161+
default:
162+
size_t a_mid = a_row / 2; // Rows of A & C
163+
size_t b_mid = b_col / 2; // Cols of B & C
164+
size_t c_mid = a_col / 2; // Cols of A & Rows of B
165+
size_t a_end = a_mid * 2;
166+
size_t b_end = b_mid * 2;
167+
size_t c_end = c_mid * 2;
168+
SM A11(A, 0, c_mid, 0, a_mid);
169+
SM A12(A, c_mid, c_end, 0, a_mid);
170+
SM A21(A, 0, c_mid, a_mid, a_end);
171+
SM A22(A, c_mid, c_end, a_mid, a_end);
172+
SM B11(B, 0, b_mid, 0, c_mid);
173+
SM B12(B, b_mid, b_end, 0, c_mid);
174+
SM B21(B, 0, b_mid, c_mid, c_end);
175+
SM B22(B, b_mid, b_end, c_mid, c_end);
176+
Matrix<T> S1 = PMatSub(B12, B22);
177+
Matrix<T> S2 = PMatAdd(A11, A12);
178+
Matrix<T> S3 = PMatAdd(A21, A22);
179+
Matrix<T> S4 = PMatSub(B21, B11);
180+
Matrix<T> S5 = PMatAdd(A11, A22);
181+
Matrix<T> S6 = PMatAdd(B11, B22);
182+
Matrix<T> S7 = PMatSub(A12, A22);
183+
Matrix<T> S8 = PMatAdd(B21, B22);
184+
Matrix<T> S9 = PMatSub(A11, A21);
185+
Matrix<T> S10 = B11 + B12;
186+
Matrix<T> P1(0, 0), P2(P1), P3(P2), P4(P3), P5(P4), P6(P5), P7(P6);
187+
std::thread
188+
t1(PMatrixMultiplyStrassen<T, T0>, A11, SM(S1), SM(P1)),
189+
t2(PMatrixMultiplyStrassen<T, T0>, SM(S2), B22, SM(P2)),
190+
t3(PMatrixMultiplyStrassen<T, T0>, SM(S3), B11, SM(P3)),
191+
t4(PMatrixMultiplyStrassen<T, T0>, A22, SM(S4), SM(P4)),
192+
t5(PMatrixMultiplyStrassen<T, T0>, SM(S5), SM(S6), SM(P5)),
193+
t6(PMatrixMultiplyStrassen<T, T0>, SM(S7), SM(S8), SM(P6));
194+
PMatrixMultiplyStrassen<T, T0>(S9 , S10, P7);
195+
t1.join();
196+
t2.join();
197+
t3.join();
198+
t4.join();
199+
t5.join();
200+
t6.join();
201+
Matrix<T> C11 = PMatAdd(PMatSub(PMatAdd(P5, P4), P2), P6);
202+
Matrix<T> C12 = PMatAdd(P1, P2);
203+
Matrix<T> C21 = PMatAdd(P3, P4);
204+
Matrix<T> C22 = PMatSub(PMatSub(PMatAdd(P5, P1), P3), P7);
205+
Matrix<T>& C = CC.data;
206+
C = (C11.concat_h(C12)).concat_v(C21.concat_h(C22));
207+
if (a_end != a_row) {
208+
assert(a_end == a_row - 1);
209+
C.add_row(T0);
210+
parallel_for<size_t>(0, b_end, [&](size_t i){
211+
for (size_t j = 0; j < c_end; j++)
212+
C[a_end][i] += A.get_elem(a_end, j) * B.get_elem(j, i);
213+
});
214+
a_end += 1;
215+
}
216+
if (b_end != b_col) {
217+
assert(b_end == b_col - 1);
218+
C.add_col(T0);
219+
parallel_for<size_t>(0, a_end, [&](size_t i){
220+
for (size_t j = 0; j < c_end; j++)
221+
C[i][b_end] += A.get_elem(i, j) * B.get_elem(j, b_end);
222+
});
223+
b_end += 1;
224+
}
225+
if (c_end != a_col) {
226+
assert(c_end == a_col - 1);
227+
parallel_for<size_t>(0, a_end, [&](size_t i){
228+
for (size_t j = 0; j < b_end; j++)
229+
C[i][j] += A.get_elem(i, c_end) * B.get_elem(c_end, j);
230+
});
231+
}
232+
}
233+
}
234+
235+
template <typename T, T T0>
236+
Matrix<T> PMatrixMultiplyStrassen(Matrix<T>& A, Matrix<T>& B) {
237+
Matrix<T> C(A.rows, B.cols, T0);
238+
PMatrixMultiplyStrassen<T, T0>(SM(A), SM(B), SM(C));
239+
return C;
240+
}
241+
#endif
242+
243+
#ifdef MAIN_PSquareMatrixMultiply
244+
int main(int argc, char *argv[]) {
245+
const size_t n = get_argv(argc, argv, 1, 8);
246+
const size_t compute = get_argv(argc, argv, 2, 7);
247+
std::vector<int> buf_a, buf_b;
248+
random_integers(buf_a, 0, n, n * n);
249+
random_integers(buf_b, 0, n, n * n);
250+
Matrix<int> A(n, n, buf_a);
251+
Matrix<int> B(n, n, buf_b);
252+
std::cout << A << std::endl;
253+
std::cout << B << std::endl;
254+
Matrix<int> ans1(A);
255+
if (compute >> 0 & 1) {
256+
std::cout << "PSquareMatrixMultiply" << std::endl;
257+
ans1 = PSquareMatrixMultiply(A, B, 0);
258+
std::cout << ans1 << std::endl;
259+
}
260+
if (compute >> 1 & 1) {
261+
std::cout << "PMatrixMultiplyRecursive" << std::endl;
262+
Matrix<int> ans2 = PMatrixMultiplyRecursive(A, B, 0);
263+
std::cout << ans2 << std::endl;
264+
if (compute >> 0 & 1)
265+
std::cout << std::boolalpha << (ans1 == ans2) << std::endl;
266+
}
267+
if (compute >> 2 & 1) {
268+
std::cout << "PMatrixMultiplyStrassen" << std::endl;
269+
Matrix<int> ans3 = PMatrixMultiplyStrassen<int, 0>(A, B);
270+
std::cout << ans3 << std::endl;
271+
if (compute >> 0 & 1)
272+
std::cout << std::boolalpha << (ans1 == ans3) << std::endl;
273+
}
274+
return 0;
275+
}
276+
#endif
277+

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@
177177
| 27 | MatVec.cpp | Mat Vec Main Loop | 785 |
178178
| 27 | RaceExample.cpp | Race Example | 788 |
179179
| 27 | MatVec.cpp | Mat Vec Wrong | 790 |
180+
| 27 | PSquareMatrixMultiply.cpp | P Square Matrix Multiply | 793 |
181+
| 27 | PSquareMatrixMultiply.cpp | P Matrix Multiply Recursive | 794 |
182+
| 27 | PSquareMatrixMultiply.cpp | P Matrix Multiply Strassen | 794 |
180183

181184
# Supplementary Files
182185
* `utils.h`: Utils

SquareMatrixMultiply.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,9 @@ Matrix<T> SquareMatrixMultiplyStrassen(SubMatrix<T> A, SubMatrix<T> B) {
249249
assert(a_col == b_row);
250250
switch(a_row * a_col * b_col) {
251251
case 1:
252-
return Matrix<T> (1, 1, A.get_elem(0, 0) * B.get_elem(0, 0));
252+
return Matrix<T>(1, 1, A.get_elem(0, 0) * B.get_elem(0, 0));
253253
case 0:
254-
return Matrix<T> (0, 0);
254+
return Matrix<T>(0, 0);
255255
default:
256256
size_t a_mid = a_row / 2; // Rows of A & C
257257
size_t b_mid = b_col / 2; // Cols of B & C
@@ -300,7 +300,7 @@ Matrix<T> SquareMatrixMultiplyStrassen(SubMatrix<T> A, SubMatrix<T> B) {
300300
}
301301
if (c_end != a_col) {
302302
assert(c_end == a_col - 1);
303-
for (size_t i = 0; i < a_end; i++)
303+
for (size_t i = 0; i < a_end; i++)
304304
for (size_t j = 0; j < b_end; j++)
305305
C[i][j] += A.get_elem(i, c_end) * B.get_elem(c_end, j);
306306
}

0 commit comments

Comments
 (0)