-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathrepack.cxx
117 lines (90 loc) · 2.26 KB
/
repack.cxx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/*Copyright (c) 2011, Edgar Solomonik, all rights reserved.*/
/** \addtogroup tests
* @{
* \defgroup repack repack
* @{
* \brief Tests repacking a tensor from nonsymmetric (NS) to symmetric (SY) storage and back
*/
#include <ctf.hpp>
using namespace CTF;
/**
 * \brief Round-trip test of tensor repacking between nonsymmetric (NS) and
 *        symmetric (SY) storage for an order-4 tensor with edge length n.
 *
 * Stage 1 (NS -> SY): random values are written at the locally owned indices
 * of a tensor with shape {NS,NS,SY,NS}. The same (index, value) pairs are
 * written into a fully nonsymmetric tensor. That tensor is then repacked to
 * {NS,NS,SY,NS} and must match the symmetric one.
 *
 * Stage 2 (SY -> NS), run only if stage 1 passed: the symmetric tensor is
 * repacked to NS and compared against an NS tensor holding only the
 * originally written pairs.
 *
 * A stage passes when norm2() of the difference tensor is below 1e-6.
 * Only rank 0 (of MPI_COMM_WORLD) prints the outcome.
 *
 * \param[in] n  edge length of every tensor mode
 * \param[in] dw CTF world the tensors are distributed over
 * \return 1 if all executed stages passed, 0 otherwise
 */
int repack(int n,
           World & dw){
  int rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  int shapeN4[] = {NS,NS,NS,NS};
  int shapeS4[] = {NS,NS,SY,NS};
  int sizeN4[]  = {n,n,n,n};

  //* Creates distributed tensors initialized with zeros
  Tensor<> An(4, sizeN4, shapeN4, dw);
  Tensor<> As(4, sizeN4, shapeS4, dw);

  // Fill this rank's entries of As with values in [-0.5, 0.5). Mirror the
  // same (index, value) pairs into the nonsymmetric tensor An.
  int64_t np;
  int64_t * indices;
  double * pairs;
  As.get_local_data(&np, &indices, &pairs);
  // int64_t index so the loop cannot overflow when np exceeds INT_MAX.
  for (int64_t i=0; i<np; i++) pairs[i] = drand48()-.5; //(1.E-3)*sin(indices[i]);
  As.write(np, indices, pairs);
  An.write(np, indices, pairs);

  // Stage 1: NS -> SY. Repack An into As's symmetry and subtract As.
  Tensor<> Anr(An, shapeS4);
  Anr["ijkl"] -= As["ijkl"];
  int pass = (Anr.norm2() < 1.E-6) ? 1 : 0;

  if (!pass){
    // Report from rank 0 only, like every other message in this test.
    if (rank == 0)
      printf("{ NS -> SY repack } failed \n");
  } else {
    // Stage 2: SY -> NS. Anur is As repacked to NS. Asur is an NS tensor
    // that holds exactly the pairs originally written.
    Tensor<> Anur(As, shapeN4);
    Tensor<> Asur(As, shapeN4);
    Asur["ijkl"] = 0.0;
    Asur.write(np, indices, pairs);
    Anur["ijkl"] -= Asur["ijkl"];
    if (Anur.norm2() < 1.E-6){
      pass = 1;
      if (rank == 0)
        printf("{ NS -> SY -> NS repack } passed \n");
    } else {
      pass = 0;
      if (rank == 0)
        printf("{ SY -> NS repack } failed \n");
    }
  }

  // NOTE(review): pairs is released with delete[] and indices with free(),
  // exactly as before. Confirm these match how get_local_data allocates.
  delete [] pairs;
  free(indices);
  return pass;
}
#ifndef TEST_SUITE
/**
 * \brief Looks up the value that follows a command-line flag.
 * \param[in] begin  first element of the argument range
 * \param[in] end    one past the last element of the argument range
 * \param[in] option flag to search for (e.g. "-n")
 * \return the argument right after the first occurrence of option, or a
 *         null pointer if the flag is absent or is the last argument
 */
char* getCmdOption(char ** begin,
                   char ** end,
                   const std::string & option){
  for (char ** cur = begin; cur != end; ++cur){
    if (option == *cur){
      // Only the first match counts; it needs a following argument.
      char ** value = cur + 1;
      return (value == end) ? nullptr : *value;
    }
  }
  return nullptr;
}
/**
 * \brief Standalone driver: parses "-n <int>" (default 7; negative values
 *        also fall back to 7) and runs the repack test once.
 * \return 0 if repack() reported success, 1 otherwise, so that scripts and
 *         CI can detect a failing run from the exit status.
 */
int main(int argc, char ** argv){
  // Keep the original argument count and vector so options are parsed from
  // what the user passed, even if MPI_Init rewrites argc/argv.
  int const in_num = argc;
  char ** input_str = argv;

  MPI_Init(&argc, &argv);

  // Look the option up once instead of twice.
  int n = 7;
  char * n_opt = getCmdOption(input_str, input_str+in_num, "-n");
  if (n_opt){
    n = atoi(n_opt);
    if (n < 0) n = 7;
  }

  int pass;
  {
    // The inner scope ends dw's lifetime before MPI_Finalize is called.
    World dw(argc, argv);
    pass = repack(n, dw);
  }

  MPI_Finalize();
  return pass ? 0 : 1;
}
/**
* @}
* @}
*/
#endif