Skip to content

Commit

Permalink
Merge pull request flintlib#218 from aaditya-thakkar/trunk
Browse files Browse the repository at this point in the history
Added function for strassen matrix multiplication for fmpz_mat
  • Loading branch information
wbhart committed Feb 3, 2016
2 parents f95c65f + 6df1d06 commit b2ad26b
Show file tree
Hide file tree
Showing 6 changed files with 286 additions and 4 deletions.
2 changes: 2 additions & 0 deletions fmpz_mat.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ FLINT_DLL void fmpz_mat_mul(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B

FLINT_DLL void fmpz_mat_mul_classical(fmpz_mat_t C, const fmpz_mat_t A,
const fmpz_mat_t B);

FLINT_DLL void fmpz_mat_mul_strassen(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B);

FLINT_DLL void fmpz_mat_mul_classical_inline(fmpz_mat_t C, const fmpz_mat_t A,
const fmpz_mat_t B);
Expand Down
6 changes: 6 additions & 0 deletions fmpz_mat/doc/fmpz_mat.txt
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,12 @@ void fmpz_mat_mul_classical(fmpz_mat_t C,

The matrices must have compatible dimensions for matrix multiplication.
No aliasing is allowed.

void fmpz_mat_mul_strassen(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B)

Sets $C = AB$. Dimensions must be compatible for matrix multiplication.
$C$ is not allowed to be aliased with $A$ or $B$. Uses Strassen
multiplication (the Strassen-Winograd variant).

void _fmpz_mat_mul_multi_mod(fmpz_mat_t C,
const fmpz_mat_t A, const fmpz_mat_t B, mp_bitcnt_t bits)
Expand Down
17 changes: 16 additions & 1 deletion fmpz_mat/mul.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
/******************************************************************************
Copyright (C) 2010,2011 Fredrik Johansson
Copyright (C) 2016 Aaditya Thakkar
******************************************************************************/

Expand Down Expand Up @@ -68,7 +69,21 @@ fmpz_mat_mul(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B)

if (5*(ab + bb) > dim * dim || (bits > FLINT_BITS - 3 && dim < 60))
{
fmpz_mat_mul_classical_inline(C, A, B);
if ((ab + bb) * dim < 17000)
{
fmpz_mat_mul_classical_inline(C, A, B);
}
else
{
if (dim > 75 && (ab + bb) > 650)
{
_fmpz_mat_mul_multi_mod(C, A, B, bits);
}
else
{
fmpz_mat_mul_strassen(C, A, B);
}
}
}
else
{
Expand Down
15 changes: 12 additions & 3 deletions fmpz_mat/profile/p-mul.c
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ void sample(void * arg, ulong count)
else if (algorithm == 3)
for (i = 0; i < count; i++)
fmpz_mat_mul_multi_mod(C, A, B);
else if (algorithm == 4)
for (i = 0; i < count; i++)
fmpz_mat_mul_strassen(C, A, B);

prof_stop();

Expand All @@ -87,7 +90,7 @@ void sample(void * arg, ulong count)

int main(void)
{
double min_default, min_classical, min_inline, min_multi_mod, max;
double min_default, min_classical, min_inline, min_multi_mod, min_strassen, max;
mat_mul_t params;
slong bits, dim;

Expand All @@ -114,15 +117,21 @@ int main(void)

params.algorithm = 3;
prof_repeat(&min_multi_mod, &max, sample, &params);

params.algorithm = 4;
prof_repeat(&min_strassen, &max, sample, &params);

flint_printf("dim = %wd default/classical/inline/multi_mod %.2f %.2f %.2f %.2f (us)\n",
dim, min_default, min_classical, min_inline, min_multi_mod);
flint_printf("dim = %wd default/classical/inline/multi_mod/strassen %.2f %.2f %.2f %.2f %.2f (us)\n",
dim, min_default, min_classical, min_inline, min_multi_mod, min_strassen);

if (min_multi_mod < 0.6*min_default)
flint_printf("BAD!\n");

if (min_inline < 0.6*min_default)
flint_printf("BAD!\n");

if (min_strassen < 0.7*min_default)
flint_printf("BAD!\n");

if (min_multi_mod < 0.7*min_inline)
break;
Expand Down
166 changes: 166 additions & 0 deletions fmpz_mat/strassen_mul.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*=============================================================================
This file is part of FLINT.
FLINT is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
FLINT is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with FLINT; if not, write to the Free Software
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
=============================================================================*/
/******************************************************************************
Copyright (C) 2016 Aaditya Thakkar
******************************************************************************/

#include "fmpz_mat.h"
#include "fmpz.h"
#include "fmpz_vec.h"
#include "flint.h"

void fmpz_mat_mul_strassen(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B)
{
slong a, b, c;
slong anr, anc, bnr, bnc;

fmpz_mat_t A11, A12, A21, A22;
fmpz_mat_t B11, B12, B21, B22;
fmpz_mat_t C11, C12, C21, C22;
fmpz_mat_t X1, X2;

a = A->r;
b = A->c;
c = B->c;

if (a <= 4 || b <= 4 || c <= 4)
{
fmpz_mat_mul(C, A, B);
return;
}

anr = a / 2;
anc = b / 2;
bnr = anc;
bnc = c / 2;

fmpz_mat_window_init(A11, A, 0, 0, anr, anc);
fmpz_mat_window_init(A12, A, 0, anc, anr, 2*anc);
fmpz_mat_window_init(A21, A, anr, 0, 2*anr, anc);
fmpz_mat_window_init(A22, A, anr, anc, 2*anr, 2*anc);

fmpz_mat_window_init(B11, B, 0, 0, bnr, bnc);
fmpz_mat_window_init(B12, B, 0, bnc, bnr, 2*bnc);
fmpz_mat_window_init(B21, B, bnr, 0, 2*bnr, bnc);
fmpz_mat_window_init(B22, B, bnr, bnc, 2*bnr, 2*bnc);

fmpz_mat_window_init(C11, C, 0, 0, anr, bnc);
fmpz_mat_window_init(C12, C, 0, bnc, anr, 2*bnc);
fmpz_mat_window_init(C21, C, anr, 0, 2*anr, bnc);
fmpz_mat_window_init(C22, C, anr, bnc, 2*anr, 2*bnc);

fmpz_mat_init(X1, anr, FLINT_MAX(bnc, anc));
fmpz_mat_init(X2, anc, bnc);

X1->c = anc;

fmpz_mat_sub(X1, A11, A21);
fmpz_mat_sub(X2, B22, B12);
fmpz_mat_mul(C21, X1, X2);

fmpz_mat_add(X1, A21, A22);
fmpz_mat_sub(X2, B12, B11);
fmpz_mat_mul(C22, X1, X2);

fmpz_mat_sub(X1, X1, A11);
fmpz_mat_sub(X2, B22, X2);
fmpz_mat_mul(C12, X1, X2);

fmpz_mat_sub(X1, A12, X1);
fmpz_mat_mul(C11, X1, B22);

X1->c = bnc;
fmpz_mat_mul(X1, A11, B11);
fmpz_mat_add(C12, X1, C12);
fmpz_mat_add(C21, C12, C21);
fmpz_mat_add(C12, C12, C22);
fmpz_mat_add(C22, C21, C22);
fmpz_mat_add(C12, C12, C11);
fmpz_mat_sub(X2, X2, B21);
fmpz_mat_mul(C11, A22, X2);

fmpz_mat_clear(X2);

fmpz_mat_sub(C21, C21, C11);
fmpz_mat_mul(C11, A12, B21);

fmpz_mat_add(C11, X1, C11);

fmpz_mat_clear(X1);

fmpz_mat_window_clear(A11);
fmpz_mat_window_clear(A12);
fmpz_mat_window_clear(A21);
fmpz_mat_window_clear(A22);

fmpz_mat_window_clear(B11);
fmpz_mat_window_clear(B12);
fmpz_mat_window_clear(B21);
fmpz_mat_window_clear(B22);

fmpz_mat_window_clear(C11);
fmpz_mat_window_clear(C12);
fmpz_mat_window_clear(C21);
fmpz_mat_window_clear(C22);

if (c > 2*bnc)
{
fmpz_mat_t Bc, Cc;
fmpz_mat_window_init(Bc, B, 0, 2*bnc, b, c);
fmpz_mat_window_init(Cc, C, 0, 2*bnc, a, c);
fmpz_mat_mul(Cc, A, Bc);
fmpz_mat_window_clear(Bc);
fmpz_mat_window_clear(Cc);
}

if (a > 2*anr)
{
fmpz_mat_t Ar, Cr;
fmpz_mat_window_init(Ar, A, 2*anr, 0, a, b);
fmpz_mat_window_init(Cr, C, 2*anr, 0, a, c);
fmpz_mat_mul(Cr, Ar, B);
fmpz_mat_window_clear(Ar);
fmpz_mat_window_clear(Cr);
}

if (b > 2*anc)
{
fmpz_mat_t Ac, Br, Cb;
fmpz_mat_window_init(Ac, A, 0, 2*anc, 2*anr, b);
fmpz_mat_window_init(Br, B, 2*bnr, 0, b, 2*bnc);
fmpz_mat_window_init(Cb, C, 0, 0, 2*anr, 2*bnc);

slong mt, kt, nt;

mt = Ac->r;
kt = Ac->c;
nt = Br->c;
fmpz_mat_t tmp;
fmpz_mat_init(tmp, mt, nt);
fmpz_mat_mul(tmp, Ac, Br);
fmpz_mat_add(Cb, Cb, tmp);
fmpz_mat_clear(tmp);
fmpz_mat_window_clear(Ac);
fmpz_mat_window_clear(Br);
fmpz_mat_window_clear(Cb);
}
}
84 changes: 84 additions & 0 deletions fmpz_mat/test/t-strassen_mul.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*=============================================================================
This file is part of FLINT.
FLINT is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
FLINT is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with FLINT; if not, write to the Free Software
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
=============================================================================*/
/******************************************************************************
Copyright (C) 2016 Aaditya Thakkar
******************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <gmp.h>
#include "flint.h"
#include "ulong_extras.h"
#include "fmpz_mat.h"
#include "fmpz.h"

int
main(void)
{
slong i;
FLINT_TEST_INIT(state);

flint_printf("strassen_mul....");
fflush(stdout);

for (i = 0; i < 20 * flint_test_multiplier(); i++)
{
fmpz_mat_t A, B, C, D;

slong m, k, n;

m = n_randint(state, 400);
k = n_randint(state, 400);
n = n_randint(state, 400);

fmpz_mat_init(A, m, n);
fmpz_mat_init(B, n, k);
fmpz_mat_init(C, m, k);
fmpz_mat_init(D, m, k);

fmpz_mat_randtest(A, state, n_randint(state, 200)+1);
fmpz_mat_randtest(B, state, n_randint(state, 200)+1);

fmpz_mat_mul_classical(C, A, B);
fmpz_mat_mul_strassen(D, A, B);

if (!fmpz_mat_equal(C, D))
{
flint_printf("FAIL: results not equal\n");
fmpz_mat_print_pretty(A);
fmpz_mat_print_pretty(B);
fmpz_mat_print_pretty(C);
fmpz_mat_print_pretty(D);
abort();
}

fmpz_mat_clear(A);
fmpz_mat_clear(B);
fmpz_mat_clear(C);
fmpz_mat_clear(D);
}

FLINT_TEST_CLEANUP(state);

flint_printf("PASS\n");
return 0;
}

0 comments on commit b2ad26b

Please sign in to comment.