|
| 1 | +// |
| 2 | +// algorithm - some algorithms in "Introduction to Algorithms", third edition |
| 3 | +// Copyright (C) 2018 lxylxy123456 |
| 4 | +// |
| 5 | +// This program is free software: you can redistribute it and/or modify |
| 6 | +// it under the terms of the GNU Affero General Public License as |
| 7 | +// published by the Free Software Foundation, either version 3 of the |
| 8 | +// License, or (at your option) any later version. |
| 9 | +// |
| 10 | +// This program is distributed in the hope that it will be useful, |
| 11 | +// but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 12 | +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 13 | +// GNU Affero General Public License for more details. |
| 14 | +// |
| 15 | +// You should have received a copy of the GNU Affero General Public License |
| 16 | +// along with this program. If not, see <https://www.gnu.org/licenses/>. |
| 17 | +// |
| 18 | + |
| 19 | +#ifndef MAIN |
| 20 | +#define MAIN |
| 21 | +#define MAIN_PSquareMatrixMultiply |
| 22 | +#endif |
| 23 | + |
| 24 | +#ifndef FUNC_PSquareMatrixMultiply |
| 25 | +#define FUNC_PSquareMatrixMultiply |
| 26 | + |
| 27 | +#include <thread> |
| 28 | +#include "utils.h" |
| 29 | + |
| 30 | +#include "MatVec.cpp" |
| 31 | + |
| 32 | +#define SM SubMatrix<T> |
| 33 | + |
| 34 | +template <typename T> |
| 35 | +Matrix<T> PSquareMatrixMultiply(Matrix<T>&A, Matrix<T>&B, T T0) { |
| 36 | + Matrix<T> C(A.rows, B.cols, T0); |
| 37 | + parallel_for<size_t>(0, C.rows, [&](size_t i){ |
| 38 | + parallel_for<size_t>(0, C.cols, [&](size_t j){ |
| 39 | + int& loc = C.data[i][j]; |
| 40 | + for(size_t k = 0; k < A.cols; k++) { |
| 41 | + loc += A[i][k] * B[k][j]; |
| 42 | + } |
| 43 | + }); |
| 44 | + }); |
| 45 | + return C; |
| 46 | +} |
| 47 | + |
| 48 | +template <typename T> |
| 49 | +void PMatrixMultiplyRecursive(SM A, SM B, SM C, T T0) { |
| 50 | + size_t a_row = A.rows(), a_col = A.cols(); |
| 51 | + size_t b_row = B.rows(), b_col = B.cols(); |
| 52 | + size_t c_row = C.rows(), c_col = C.cols(); |
| 53 | + assert(a_row == c_row && b_col == c_col && a_col == b_row); |
| 54 | + switch(a_row * a_col * b_col) { |
| 55 | + case 1: |
| 56 | + C.get_elem(0, 0) += A.get_elem(0, 0) * B.get_elem(0, 0); |
| 57 | + case 0: |
| 58 | + return; |
| 59 | + default: |
| 60 | + Matrix<T> S(c_row, c_col, T0); |
| 61 | + size_t a_mid = c_row / 2; // Rows of A & C |
| 62 | + size_t b_mid = c_col / 2; // Cols of B & C |
| 63 | + size_t c_mid = a_col / 2; // Cols of A & Rows of B |
| 64 | + SM A11(A, 0, c_mid, 0, a_mid); |
| 65 | + SM A12(A, c_mid, a_col, 0, a_mid); |
| 66 | + SM A21(A, 0, c_mid, a_mid, a_row); |
| 67 | + SM A22(A, c_mid, a_col, a_mid, a_row); |
| 68 | + SM B11(B, 0, b_mid, 0, c_mid); |
| 69 | + SM B12(B, b_mid, b_col, 0, c_mid); |
| 70 | + SM B21(B, 0, b_mid, c_mid, b_row); |
| 71 | + SM B22(B, b_mid, b_col, c_mid, b_row); |
| 72 | + SM C11(C, 0, b_mid, 0, a_mid); |
| 73 | + SM C12(C, b_mid, b_col, 0, a_mid); |
| 74 | + SM C21(C, 0, b_mid, a_mid, c_row); |
| 75 | + SM C22(C, b_mid, c_col, a_mid, c_row); |
| 76 | + SM S11(S, 0, b_mid, 0, a_mid); |
| 77 | + SM S12(S, b_mid, b_col, 0, a_mid); |
| 78 | + SM S21(S, 0, b_mid, a_mid, c_row); |
| 79 | + SM S22(S, b_mid, c_col, a_mid, c_row); |
| 80 | + std::thread t1(PMatrixMultiplyRecursive<T>, A11, B11, C11, T0); |
| 81 | + std::thread t2(PMatrixMultiplyRecursive<T>, A12, B21, S11, T0); |
| 82 | + std::thread t3(PMatrixMultiplyRecursive<T>, A11, B12, C12, T0); |
| 83 | + std::thread t4(PMatrixMultiplyRecursive<T>, A12, B22, S12, T0); |
| 84 | + std::thread t5(PMatrixMultiplyRecursive<T>, A21, B11, C21, T0); |
| 85 | + std::thread t6(PMatrixMultiplyRecursive<T>, A22, B21, S21, T0); |
| 86 | + std::thread t7(PMatrixMultiplyRecursive<T>, A21, B12, C22, T0); |
| 87 | + PMatrixMultiplyRecursive(A22, B22, S22, T0); |
| 88 | + t1.join(); |
| 89 | + t2.join(); |
| 90 | + t3.join(); |
| 91 | + t4.join(); |
| 92 | + t5.join(); |
| 93 | + t6.join(); |
| 94 | + t7.join(); |
| 95 | + parallel_for<size_t>(0, c_row, [&](size_t i){ |
| 96 | + parallel_for<size_t>(0, c_col, [&](size_t j){ |
| 97 | + C.get_elem(i, j) += S[i][j]; |
| 98 | + }); |
| 99 | + }); |
| 100 | + } |
| 101 | +} |
| 102 | + |
| 103 | +template <typename T> |
| 104 | +Matrix<T> PMatrixMultiplyRecursive(Matrix<T>& A, Matrix<T>& B, T T0) { |
| 105 | + Matrix<T> C(A.rows, B.cols, T0); |
| 106 | + SM A_sub(A), B_sub(B), C_sub(C); |
| 107 | + PMatrixMultiplyRecursive(A_sub, B_sub, C_sub, T0); |
| 108 | + return C; |
| 109 | +} |
| 110 | + |
| 111 | +template <typename T> |
| 112 | +Matrix<T> PMatAdd(SM A, SM B) { |
| 113 | + size_t a_row = A.rows(), a_col = A.cols(); |
| 114 | + size_t b_row = B.rows(), b_col = B.cols(); |
| 115 | + assert(a_row == b_row && a_col == b_col); |
| 116 | + Matrix<T> C(a_row, a_col, 0); |
| 117 | + parallel_for<size_t>(0, a_row, [&](size_t i){ |
| 118 | + parallel_for<size_t>(0, a_col, [&](size_t j){ |
| 119 | + C[i][j] = A.get_elem(i, j) + B.get_elem(i, j); |
| 120 | + }); |
| 121 | + }); |
| 122 | + return C; |
| 123 | +} |
| 124 | + |
| 125 | +template <typename T> |
| 126 | +Matrix<T> PMatAdd(Matrix<T> A, Matrix<T> B) { |
| 127 | + return PMatAdd(SM(A), SM(B)); |
| 128 | +} |
| 129 | + |
| 130 | +template <typename T> |
| 131 | +Matrix<T> PMatSub(SM A, SM B) { |
| 132 | + size_t a_row = A.rows(), a_col = A.cols(); |
| 133 | + size_t b_row = B.rows(), b_col = B.cols(); |
| 134 | + assert(a_row == b_row && a_col == b_col); |
| 135 | + Matrix<T> C(a_row, a_col, 0); |
| 136 | + parallel_for<size_t>(0, a_row, [&](size_t i){ |
| 137 | + parallel_for<size_t>(0, a_col, [&](size_t j){ |
| 138 | + C[i][j] = A.get_elem(i, j) - B.get_elem(i, j); |
| 139 | + }); |
| 140 | + }); |
| 141 | + return C; |
| 142 | +} |
| 143 | + |
| 144 | +template <typename T> |
| 145 | +Matrix<T> PMatSub(Matrix<T> A, Matrix<T> B) { |
| 146 | + return PMatSub(SM(A), SM(B)); |
| 147 | +} |
| 148 | + |
| 149 | +template <typename T, T T0> |
| 150 | +void PMatrixMultiplyStrassen(SM A, SM B, SM CC) { |
| 151 | + size_t a_row = A.rows(), a_col = A.cols(); |
| 152 | + size_t b_row = B.rows(), b_col = B.cols(); |
| 153 | + assert(a_col == b_row); |
| 154 | + switch(a_row * a_col * b_col) { |
| 155 | + case 1: |
| 156 | + CC.data = Matrix<T>(1, 1, A.get_elem(0, 0) * B.get_elem(0, 0)); |
| 157 | + break; |
| 158 | + case 0: |
| 159 | + CC.data = Matrix<T>(0, 0); |
| 160 | + break; |
| 161 | + default: |
| 162 | + size_t a_mid = a_row / 2; // Rows of A & C |
| 163 | + size_t b_mid = b_col / 2; // Cols of B & C |
| 164 | + size_t c_mid = a_col / 2; // Cols of A & Rows of B |
| 165 | + size_t a_end = a_mid * 2; |
| 166 | + size_t b_end = b_mid * 2; |
| 167 | + size_t c_end = c_mid * 2; |
| 168 | + SM A11(A, 0, c_mid, 0, a_mid); |
| 169 | + SM A12(A, c_mid, c_end, 0, a_mid); |
| 170 | + SM A21(A, 0, c_mid, a_mid, a_end); |
| 171 | + SM A22(A, c_mid, c_end, a_mid, a_end); |
| 172 | + SM B11(B, 0, b_mid, 0, c_mid); |
| 173 | + SM B12(B, b_mid, b_end, 0, c_mid); |
| 174 | + SM B21(B, 0, b_mid, c_mid, c_end); |
| 175 | + SM B22(B, b_mid, b_end, c_mid, c_end); |
| 176 | + Matrix<T> S1 = PMatSub(B12, B22); |
| 177 | + Matrix<T> S2 = PMatAdd(A11, A12); |
| 178 | + Matrix<T> S3 = PMatAdd(A21, A22); |
| 179 | + Matrix<T> S4 = PMatSub(B21, B11); |
| 180 | + Matrix<T> S5 = PMatAdd(A11, A22); |
| 181 | + Matrix<T> S6 = PMatAdd(B11, B22); |
| 182 | + Matrix<T> S7 = PMatSub(A12, A22); |
| 183 | + Matrix<T> S8 = PMatAdd(B21, B22); |
| 184 | + Matrix<T> S9 = PMatSub(A11, A21); |
| 185 | + Matrix<T> S10 = B11 + B12; |
| 186 | + Matrix<T> P1(0, 0), P2(P1), P3(P2), P4(P3), P5(P4), P6(P5), P7(P6); |
| 187 | + std::thread |
| 188 | + t1(PMatrixMultiplyStrassen<T, T0>, A11, SM(S1), SM(P1)), |
| 189 | + t2(PMatrixMultiplyStrassen<T, T0>, SM(S2), B22, SM(P2)), |
| 190 | + t3(PMatrixMultiplyStrassen<T, T0>, SM(S3), B11, SM(P3)), |
| 191 | + t4(PMatrixMultiplyStrassen<T, T0>, A22, SM(S4), SM(P4)), |
| 192 | + t5(PMatrixMultiplyStrassen<T, T0>, SM(S5), SM(S6), SM(P5)), |
| 193 | + t6(PMatrixMultiplyStrassen<T, T0>, SM(S7), SM(S8), SM(P6)); |
| 194 | + PMatrixMultiplyStrassen<T, T0>(S9 , S10, P7); |
| 195 | + t1.join(); |
| 196 | + t2.join(); |
| 197 | + t3.join(); |
| 198 | + t4.join(); |
| 199 | + t5.join(); |
| 200 | + t6.join(); |
| 201 | + Matrix<T> C11 = PMatAdd(PMatSub(PMatAdd(P5, P4), P2), P6); |
| 202 | + Matrix<T> C12 = PMatAdd(P1, P2); |
| 203 | + Matrix<T> C21 = PMatAdd(P3, P4); |
| 204 | + Matrix<T> C22 = PMatSub(PMatSub(PMatAdd(P5, P1), P3), P7); |
| 205 | + Matrix<T>& C = CC.data; |
| 206 | + C = (C11.concat_h(C12)).concat_v(C21.concat_h(C22)); |
| 207 | + if (a_end != a_row) { |
| 208 | + assert(a_end == a_row - 1); |
| 209 | + C.add_row(T0); |
| 210 | + parallel_for<size_t>(0, b_end, [&](size_t i){ |
| 211 | + for (size_t j = 0; j < c_end; j++) |
| 212 | + C[a_end][i] += A.get_elem(a_end, j) * B.get_elem(j, i); |
| 213 | + }); |
| 214 | + a_end += 1; |
| 215 | + } |
| 216 | + if (b_end != b_col) { |
| 217 | + assert(b_end == b_col - 1); |
| 218 | + C.add_col(T0); |
| 219 | + parallel_for<size_t>(0, a_end, [&](size_t i){ |
| 220 | + for (size_t j = 0; j < c_end; j++) |
| 221 | + C[i][b_end] += A.get_elem(i, j) * B.get_elem(j, b_end); |
| 222 | + }); |
| 223 | + b_end += 1; |
| 224 | + } |
| 225 | + if (c_end != a_col) { |
| 226 | + assert(c_end == a_col - 1); |
| 227 | + parallel_for<size_t>(0, a_end, [&](size_t i){ |
| 228 | + for (size_t j = 0; j < b_end; j++) |
| 229 | + C[i][j] += A.get_elem(i, c_end) * B.get_elem(c_end, j); |
| 230 | + }); |
| 231 | + } |
| 232 | + } |
| 233 | +} |
| 234 | + |
| 235 | +template <typename T, T T0> |
| 236 | +Matrix<T> PMatrixMultiplyStrassen(Matrix<T>& A, Matrix<T>& B) { |
| 237 | + Matrix<T> C(A.rows, B.cols, T0); |
| 238 | + PMatrixMultiplyStrassen<T, T0>(SM(A), SM(B), SM(C)); |
| 239 | + return C; |
| 240 | +} |
| 241 | +#endif |
| 242 | + |
| 243 | +#ifdef MAIN_PSquareMatrixMultiply |
| 244 | +int main(int argc, char *argv[]) { |
| 245 | + const size_t n = get_argv(argc, argv, 1, 8); |
| 246 | + const size_t compute = get_argv(argc, argv, 2, 7); |
| 247 | + std::vector<int> buf_a, buf_b; |
| 248 | + random_integers(buf_a, 0, n, n * n); |
| 249 | + random_integers(buf_b, 0, n, n * n); |
| 250 | + Matrix<int> A(n, n, buf_a); |
| 251 | + Matrix<int> B(n, n, buf_b); |
| 252 | + std::cout << A << std::endl; |
| 253 | + std::cout << B << std::endl; |
| 254 | + Matrix<int> ans1(A); |
| 255 | + if (compute >> 0 & 1) { |
| 256 | + std::cout << "PSquareMatrixMultiply" << std::endl; |
| 257 | + ans1 = PSquareMatrixMultiply(A, B, 0); |
| 258 | + std::cout << ans1 << std::endl; |
| 259 | + } |
| 260 | + if (compute >> 1 & 1) { |
| 261 | + std::cout << "PMatrixMultiplyRecursive" << std::endl; |
| 262 | + Matrix<int> ans2 = PMatrixMultiplyRecursive(A, B, 0); |
| 263 | + std::cout << ans2 << std::endl; |
| 264 | + if (compute >> 0 & 1) |
| 265 | + std::cout << std::boolalpha << (ans1 == ans2) << std::endl; |
| 266 | + } |
| 267 | + if (compute >> 2 & 1) { |
| 268 | + std::cout << "PMatrixMultiplyStrassen" << std::endl; |
| 269 | + Matrix<int> ans3 = PMatrixMultiplyStrassen<int, 0>(A, B); |
| 270 | + std::cout << ans3 << std::endl; |
| 271 | + if (compute >> 0 & 1) |
| 272 | + std::cout << std::boolalpha << (ans1 == ans3) << std::endl; |
| 273 | + } |
| 274 | + return 0; |
| 275 | +} |
| 276 | +#endif |
| 277 | + |
0 commit comments