-
Notifications
You must be signed in to change notification settings - Fork 82
/
textgcn.cpp
80 lines (68 loc) · 3.36 KB
/
textgcn.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#define _GLIBCXX_USE_CXX11_ABI 1
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "../src/model/TEXTGCN.h"
/*
 * Find the position of a command-line flag.
 *
 * Searches argv[1] .. argv[argc - 1] (argv[0], the program name, is never
 * matched) for the first entry equal to `str`.
 *
 * Returns the index of that entry, or -1 when the flag is absent.
 * If the flag is the final argument, so no value follows it, this prints
 * "Argument missing for <flag>" and terminates the process with exit(1).
 */
int ArgPos(char *str, int argc, char **argv) {
    int pos = 1;
    while (pos < argc) {
        if (strcmp(argv[pos], str) == 0) {
            /* Every flag expects a value in the next slot. */
            if (pos + 1 == argc) {
                printf("Argument missing for %s\n", str);
                exit(1);
            }
            return pos;
        }
        ++pos;
    }
    return -1;
}
/* Print the command-line help for textgcn (shown when no arguments are given). */
static void PrintUsage(void) {
    printf("[proNet-core]\n");
    printf("\tcommand line interface for proNet-core\n\n");
    printf("Options Description:\n");
    printf("\t-train <string>\n");
    printf("\t\tTrain the Network data\n");
    printf("\t-save <string>\n");
    printf("\t\tSave the representation data\n");
    printf("\t-field <string>\n");
    printf("\t\tField data\n");
    printf("\t-dimensions <int>\n");
    printf("\t\tDimension of vertex representation; default is 64\n");
    printf("\t-undirected <int>\n");
    printf("\t\tWhether the edge is undirected; default is 1\n");
    printf("\t-negative_samples <int>\n");
    printf("\t\tNumber of negative examples; default is 5\n");
    printf("\t-walk_steps <int>\n");
    printf("\t\tStep of random walk; default is 5\n");
    printf("\t-sample_times <int>\n");
    /* Keep in sync with the sample_times initializer in main(). */
    printf("\t\tNumber of training samples *Million; default is 10\n");
    printf("\t-threads <int>\n");
    printf("\t\tNumber of training threads; default is 1\n");
    printf("\t-reg <float>\n");
    printf("\t\tRegularization term; default is 0.01\n");
    printf("\t-alpha <float>\n");
    printf("\t\tInit learning rate; default is 0.025\n");
    printf("Usage:\n");
    printf("./textgcn -train net.txt -field field.txt -save rep.txt -undirected 0 -dimensions 64 -reg 0.01 -sample_times 5 -walk_steps 5 -negative_samples 5 -alpha 0.025 -threads 1\n");
}

/*
 * Copy a flag's value into a fixed-size path buffer.
 *
 * Unlike strcpy, this never writes past dst[dst_size - 1].
 * If the value does not fit, it reports the offending flag on stderr and
 * exits with status 1 rather than silently truncating the path.
 */
static void CopyPathArg(char *dst, size_t dst_size, const char *flag, const char *value) {
    int written = snprintf(dst, dst_size, "%s", value);
    if (written < 0 || (size_t)written >= dst_size) {
        fprintf(stderr, "Value for %s is too long (max %zu characters)\n",
                flag, dst_size - 1);
        exit(1);
    }
}

/*
 * Entry point for the textgcn command-line tool.
 *
 * With no arguments, this prints usage and returns 0. Otherwise it:
 *   1. parses the flags;
 *   2. builds a TEXTGCN model;
 *   3. calls LoadEdgeList(-train, -undirected) and LoadFieldMeta(-field);
 *   4. calls Init(-dimensions);
 *   5. calls Train(...) with the sampling and optimisation flags;
 *   6. calls SaveWeights(-save).
 *
 * Returns 1 if any of -train, -field or -save is missing.
 */
int main(int argc, char **argv){
    int i;
    if (argc == 1) {
        PrintUsage();
        return 0;
    }

    /* Empty-initialised so an omitted flag is detectable below, instead of
     * handing indeterminate buffer contents to the loaders. */
    char network_file[100] = "", rep_file[100] = "", field_file[100] = "";
    int dimensions=64, undirected=1, negative_samples=5, walk_steps=5, sample_times=10, threads=1;
    double init_alpha=0.025, reg=0.01;

    if ((i = ArgPos((char *)"-train", argc, argv)) > 0) CopyPathArg(network_file, sizeof network_file, "-train", argv[i + 1]);
    if ((i = ArgPos((char *)"-save", argc, argv)) > 0) CopyPathArg(rep_file, sizeof rep_file, "-save", argv[i + 1]);
    if ((i = ArgPos((char *)"-field", argc, argv)) > 0) CopyPathArg(field_file, sizeof field_file, "-field", argv[i + 1]);
    if ((i = ArgPos((char *)"-undirected", argc, argv)) > 0) undirected = atoi(argv[i + 1]);
    if ((i = ArgPos((char *)"-dimensions", argc, argv)) > 0) dimensions = atoi(argv[i + 1]);
    if ((i = ArgPos((char *)"-negative_samples", argc, argv)) > 0) negative_samples = atoi(argv[i + 1]);
    if ((i = ArgPos((char *)"-walk_steps", argc, argv)) > 0) walk_steps = atoi(argv[i + 1]);
    if ((i = ArgPos((char *)"-sample_times", argc, argv)) > 0) sample_times = atoi(argv[i + 1]);
    if ((i = ArgPos((char *)"-reg", argc, argv)) > 0) reg = atof(argv[i + 1]);
    if ((i = ArgPos((char *)"-alpha", argc, argv)) > 0) init_alpha = atof(argv[i + 1]);
    if ((i = ArgPos((char *)"-threads", argc, argv)) > 0) threads = atoi(argv[i + 1]);

    /* All three paths are consumed unconditionally below, so each is required. */
    if (network_file[0] == '\0' || field_file[0] == '\0' || rep_file[0] == '\0') {
        fprintf(stderr, "Missing required option: -train, -field and -save must all be given\n");
        return 1;
    }

    TEXTGCN *textgcn;
    textgcn = new TEXTGCN();
    textgcn->LoadEdgeList(network_file, undirected);
    textgcn->LoadFieldMeta(field_file);
    textgcn->Init(dimensions);
    textgcn->Train(sample_times, walk_steps, negative_samples, reg, init_alpha, threads);
    textgcn->SaveWeights(rep_file);
    return 0;
}