Skip to content

Commit

Permalink
Add opt_rp_spmv_exec() for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
huanghua1994 committed Apr 9, 2024
1 parent c32e959 commit e4f15b6
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/test_para2d_spmm.c
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ int main(int argc, char **argv)
et = get_wtime_sec();
if (my_rank == 0)
{
printf("%.2f\n", et - st);
printf("%.6f\n", et - st);
fflush(stdout);
}
}
Expand Down
14 changes: 10 additions & 4 deletions examples/test_rp_spmm.c
Original file line number Diff line number Diff line change
Expand Up @@ -144,20 +144,26 @@ int main(int argc, char **argv)
loc_A_srow, loc_A_nrow, loc_A_rowptr, loc_A_colidx, loc_A_csrval,
x_displs, glb_n, MPI_COMM_WORLD, &opt_rp_spmm
);
opt_rp_spmm_exec(opt_rp_spmm, layout, loc_B, loc_B_ld, loc_C, loc_C_ld); // Warm up
if (glb_n > 1) opt_rp_spmm_exec(opt_rp_spmm, layout, loc_B, loc_B_ld, loc_C, loc_C_ld); // Warm up
else opt_rp_spmv_exec(opt_rp_spmm, loc_B, loc_C);
opt_rp_spmm_clear_stat(opt_rp_spmm);
}
for (int i = 0; i < n_test; i++)
{
MPI_Barrier(MPI_COMM_WORLD);
st = get_wtime_sec();
if (use_opt == 0) rp_spmm_exec(rp_spmm, layout, loc_B, loc_B_ld, loc_C, loc_C_ld);
else opt_rp_spmm_exec(opt_rp_spmm, layout, loc_B, loc_B_ld, loc_C, loc_C_ld);
if (use_opt == 0)
{
rp_spmm_exec(rp_spmm, layout, loc_B, loc_B_ld, loc_C, loc_C_ld);
} else {
if (glb_n > 1) opt_rp_spmm_exec(opt_rp_spmm, layout, loc_B, loc_B_ld, loc_C, loc_C_ld);
else opt_rp_spmv_exec(opt_rp_spmm, loc_B, loc_C);
}
MPI_Barrier(MPI_COMM_WORLD);
et = get_wtime_sec();
if (my_rank == 0)
{
printf("%.2f\n", et - st);
printf("%.6f\n", et - st);
fflush(stdout);
}
}
Expand Down
106 changes: 106 additions & 0 deletions src/opt_rp_spmm.c
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,112 @@ void opt_rp_spmm_exec(
opt_rp_spmm->n_exec++;
}

// Compute y := A * x for a single vector x
// Input parameters:
//   opt_rp_spmm : Pointer to an initialized opt_rp_spmm struct; the call is a no-op if NULL
//   x_local     : Local part of the x vector
// Output parameter:
//   y_local     : Local part of the y vector, overwritten with the result
// Side effects:
//   Accumulates timings into opt_rp_spmm->{t_comm, t_pack, t_spmm, t_exec}
//   and increments opt_rp_spmm->n_exec.
void opt_rp_spmv_exec(opt_rp_spmm_p opt_rp_spmm, const double *x_local, double *y_local)
{
    if (opt_rp_spmm == NULL) return;
    int my_rank    = opt_rp_spmm->my_rank;
    int nproc      = opt_rp_spmm->nproc;
    int *B_scnts   = opt_rp_spmm->B_scnts;
    int *B_sdispls = opt_rp_spmm->B_sdispls;
    int *B_sridxs  = opt_rp_spmm->B_sridxs;
    int *B_rcnts   = opt_rp_spmm->B_rcnts;
    int *B_rdispls = opt_rp_spmm->B_rdispls;

    double st, et;
    double exec_s = get_wtime_sec();

    // 1. Allocate work memory
    int n_send = 0, n_recv = 0;
    const int B_remote_nrow = opt_rp_spmm->B_remote_nrow;
    const int B_sbuf_nrow   = opt_rp_spmm->B_sdispls[nproc];
    // malloc(0) is allowed to return NULL, which would trip the assertion below
    // even though nothing went wrong (e.g. a single rank, or a rank that neither
    // sends nor receives any x entries). Always request at least one element.
    const size_t x_sbuf_len   = (B_sbuf_nrow   > 0) ? (size_t) B_sbuf_nrow   : 1;
    const size_t x_remote_len = (B_remote_nrow > 0) ? (size_t) B_remote_nrow : 1;
    double *x_sbuf   = (double *) malloc(sizeof(double) * x_sbuf_len);
    double *x_remote = (double *) malloc(sizeof(double) * x_remote_len);
    MPI_Request *x_sreqs = (MPI_Request *) malloc(sizeof(MPI_Request) * nproc);
    MPI_Request *x_rreqs = (MPI_Request *) malloc(sizeof(MPI_Request) * nproc);
    ASSERT_PRINTF(
        x_sbuf != NULL && x_remote != NULL && x_sreqs != NULL && x_rreqs != NULL,
        "Failed to allocate work memory\n"
    );

    // 2. Post all Irecv
    // The message tag is the sender's rank, matching the Isend in step 3.
    st = get_wtime_sec();
    for (int shift = 1; shift < nproc; shift++)
    {
        int p = (my_rank + shift) % nproc;
        if (B_rcnts[p] == 0) continue;
        MPI_Irecv(
            x_remote + B_rdispls[p], B_rcnts[p], MPI_DOUBLE,
            p, p, opt_rp_spmm->comm, x_rreqs + n_recv
        );
        n_recv++;
    }
    et = get_wtime_sec();
    opt_rp_spmm->t_comm += et - st;

    // 3. Pack the x send buffer for each proc and post Isend
    st = get_wtime_sec();
    for (int shift = 1; shift < nproc; shift++)
    {
        int p = (my_rank + nproc - shift) % nproc;
        if (B_scnts[p] == 0) continue;
        int *p_x_sridxs  = B_sridxs + B_sdispls[p];
        double *p_x_sbuf = x_sbuf + B_sdispls[p];
        // Gather the entries of x_local that proc p needs into its send segment
        #pragma omp simd
        for (int i = 0; i < B_scnts[p]; i++) p_x_sbuf[i] = x_local[p_x_sridxs[i]];
        MPI_Isend(
            x_sbuf + B_sdispls[p], B_scnts[p], MPI_DOUBLE,
            p, my_rank, opt_rp_spmm->comm, x_sreqs + n_send
        );
        n_send++;
    }  // End of shift loop
    et = get_wtime_sec();
    opt_rp_spmm->t_pack += et - st;

    // 4. Compute y = A_diag * x_local while the messages are in flight
    st = get_wtime_sec();
    const double d_one = 1.0, d_zero = 0.0;
    struct matrix_descr mkl_descA;
    mkl_descA.type = SPARSE_MATRIX_TYPE_GENERAL;
    mkl_descA.mode = SPARSE_FILL_MODE_FULL;
    mkl_descA.diag = SPARSE_DIAG_NON_UNIT;
    sparse_matrix_t mkl_A_diag = (sparse_matrix_t) opt_rp_spmm->mkl_A_diag;
    mkl_sparse_d_mv(
        SPARSE_OPERATION_NON_TRANSPOSE, d_one, mkl_A_diag,
        mkl_descA, x_local, d_zero, y_local
    );
    et = get_wtime_sec();
    opt_rp_spmm->t_spmm += et - st;

    // 5. Wait for all Isend and Irecv
    st = get_wtime_sec();
    MPI_Waitall(n_send, x_sreqs, MPI_STATUSES_IGNORE);
    MPI_Waitall(n_recv, x_rreqs, MPI_STATUSES_IGNORE);
    et = get_wtime_sec();
    opt_rp_spmm->t_comm += et - st;

    // 6. Compute y += A_offd * x_remote
    st = get_wtime_sec();
    sparse_matrix_t mkl_A_offd = (sparse_matrix_t) opt_rp_spmm->mkl_A_offd;
    mkl_sparse_d_mv(
        SPARSE_OPERATION_NON_TRANSPOSE, d_one, mkl_A_offd,
        mkl_descA, x_remote, d_one, y_local
    );
    et = get_wtime_sec();
    opt_rp_spmm->t_spmm += et - st;

    // 7. Free work memory
    free(x_sbuf);
    free(x_remote);
    free(x_sreqs);
    free(x_rreqs);

    double exec_e = get_wtime_sec();
    opt_rp_spmm->t_exec += exec_e - exec_s;
    opt_rp_spmm->n_exec++;
}

// Print statistic info of opt_rp_spmm_p
void opt_rp_spmm_print_stat(opt_rp_spmm_p opt_rp_spmm)
{
Expand Down
8 changes: 8 additions & 0 deletions src/opt_rp_spmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ void opt_rp_spmm_exec(
double *C_local, const int ldC
);

// Compute y := A * x for a single vector x
// Input parameters:
// opt_rp_spmm : Pointer to an initialized opt_rp_spmm struct; the call does nothing if it is NULL
// x_local : Size opt_rp_spmm->B_local_nrow, local x vector
// Output parameter:
// y_local : Size opt_rp_spmm->A_nrow, local y vector
// Note: the call also accumulates timings and the execution count in opt_rp_spmm
// (t_comm, t_pack, t_spmm, t_exec, n_exec)
void opt_rp_spmv_exec(opt_rp_spmm_p opt_rp_spmm, const double *x_local, double *y_local);

// Print statistic info of an opt_rp_spmm struct
void opt_rp_spmm_print_stat(opt_rp_spmm_p opt_rp_spmm);

Expand Down

0 comments on commit e4f15b6

Please sign in to comment.