-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
310 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
#define _GLIBCXX_USE_CXX11_ABI 1 | ||
#include "../src/model/SkewOPT.h" | ||
|
||
// Locate a command-line flag and make sure a value can follow it.
//
// Scans argv[1] .. argv[argc-1] (argv[0] is skipped) for the first entry that
// compares equal to `str`.
//
//   str   flag to look for, e.g. "-train"
//   argc  number of entries in argv
//   argv  argument vector as received by main()
//
// Returns the index of the flag in argv, or -1 when the flag is not present.
// If the flag is the very last entry (so no value follows it), a message is
// printed and the process exits with status 1.
int ArgPos(char *str, int argc, char **argv) {
    for (int a = 1; a < argc; a++) {
        if (strcmp(str, argv[a]) != 0) {
            continue;
        }
        if (a == argc - 1) {
            printf("Argument missing for %s\n", str);
            exit(1);
        }
        return a;
    }
    return -1;
}
|
||
int main(int argc, char **argv){ | ||
|
||
int i; | ||
|
||
if (argc == 1) { | ||
printf("[proNet-core]\n"); | ||
printf("\tcommand spr interface for proNet-core\n\n"); | ||
printf("Options Description:\n"); | ||
printf("\t-train <string>\n"); | ||
printf("\t\tTrain the Network data\n"); | ||
printf("\t-save <string>\n"); | ||
printf("\t\tSave the representation data\n"); | ||
printf("\t-dimensions <int>\n"); | ||
printf("\t\tDimension of vertex representation; default is 64\n"); | ||
printf("\t-sample_times <int>\n"); | ||
printf("\t\tNumber of training samples *Million; default is 10\n"); | ||
printf("\t-threads <int>\n"); | ||
printf("\t\tNumber of training threads; default is 1\n"); | ||
printf("\t-reg <float>\n"); | ||
printf("\t\tThe regularization term; default is 0.01\n"); | ||
printf("\t-xi <float>\n"); | ||
printf("\t\tThe xi term; default is 10.0\n"); | ||
printf("\t-omega <float>\n"); | ||
printf("\t\tThe omega term; default is 3.0\n"); | ||
printf("\t-eta <float>\n"); | ||
printf("\t\tThe eta term; default is 3.0\n"); | ||
printf("\t-alpha <float>\n"); | ||
printf("\t\tInit learning rate; default is 0.025\n"); | ||
|
||
printf("Usage:\n"); | ||
printf("\n[SkewOPT]\n"); | ||
printf("./spr -train net.txt -save rep.txt -dimensions 64 -sample_times 10 -alpha 0.025 -threads 1\n"); | ||
|
||
return 0; | ||
} | ||
|
||
char network_file[100], rep_file[100]; | ||
int dimensions=64, negative_samples=5, sample_times=10, threads=1, eta=3; | ||
double init_alpha=0.025, reg=0.01, xi=10.0, omega=3.0; | ||
|
||
if ((i = ArgPos((char *)"-train", argc, argv)) > 0) strcpy(network_file, argv[i + 1]); | ||
if ((i = ArgPos((char *)"-save", argc, argv)) > 0) strcpy(rep_file, argv[i + 1]); | ||
if ((i = ArgPos((char *)"-dimensions", argc, argv)) > 0) dimensions = atoi(argv[i + 1]); | ||
if ((i = ArgPos((char *)"-sample_times", argc, argv)) > 0) sample_times = atoi(argv[i + 1]); | ||
if ((i = ArgPos((char *)"-reg", argc, argv)) > 0) reg = atof(argv[i + 1]); | ||
if ((i = ArgPos((char *)"-xi", argc, argv)) > 0) xi = atof(argv[i + 1]); | ||
if ((i = ArgPos((char *)"-omega", argc, argv)) > 0) omega = atof(argv[i + 1]); | ||
if ((i = ArgPos((char *)"-eta", argc, argv)) > 0) eta = atoi(argv[i + 1]); | ||
if ((i = ArgPos((char *)"-alpha", argc, argv)) > 0) init_alpha = atof(argv[i + 1]); | ||
if ((i = ArgPos((char *)"-threads", argc, argv)) > 0) threads = atoi(argv[i + 1]); | ||
|
||
SPR *spr; | ||
spr = new SPR(); | ||
spr->LoadEdgeList(network_file, 0); | ||
spr->Init(dimensions); | ||
spr->Train(sample_times, negative_samples, init_alpha, reg, xi, omega, eta, threads); | ||
spr->SaveWeights(rep_file); | ||
|
||
return 0; | ||
|
||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
#include "SkewOPT.h" | ||
#include <omp.h> | ||
|
||
// Construct the model and select the "no_degrees" negative method on the
// underlying network object.
SPR::SPR() {
    char method_name[15] = "no_degrees";
    pnet.SetNegativeMethod(method_name);
}
// No explicit cleanup is performed here.
SPR::~SPR() {}
|
||
// Forward the edge-list file name and the `undirect` flag, unchanged, to the
// underlying proNet object.
void SPR::LoadEdgeList(string filename, bool undirect)
{
    pnet.LoadEdgeList(filename, undirect);
}
|
||
// Write the learned vertex vectors to `model_name` as text.
//
// File layout: a header line "<MAX_vid> <dim>", then one line per vertex id
// holding the vertex key followed by its `dim` weights, space separated.
// Progress and failure messages go to stdout; if the file cannot be opened
// nothing is written.
void SPR::SaveWeights(string model_name) {

    cout << "Save Model:" << endl;
    ofstream model(model_name);
    if (!model)
    {
        cout << "\tfail to open file" << endl;
        return;
    }

    // '\n' rather than endl inside the file: endl would flush the stream
    // after every vertex row. The stream is flushed when `model` is closed.
    model << pnet.MAX_vid << " " << dim << '\n';
    for (long vid = 0; vid < pnet.MAX_vid; ++vid)
    {
        model << pnet.vertex_hash.keys[vid];
        for (int d = 0; d < dim; ++d)
            model << " " << w_vertex[vid][d];
        model << '\n';
    }
    cout << "\tSave to <" << model_name << ">" << endl;
}
|
||
void SPR::Init(int dim) { | ||
|
||
this->dim = dim; | ||
cout << "Model Setting:" << endl; | ||
cout << "\tdimension:\t\t" << dim << endl; | ||
|
||
w_vertex.resize(pnet.MAX_vid); | ||
|
||
for (long vid=0; vid<pnet.MAX_vid; ++vid) | ||
{ | ||
w_vertex[vid].resize(dim); | ||
for (int d=0; d<dim;++d) | ||
w_vertex[vid][d] = (rand()/(double)RAND_MAX - 0.5) / dim + 0.01; | ||
} | ||
|
||
} | ||
|
||
|
||
void SPR::Train(int sample_times, int negative_samples, double alpha, double reg, double xi, double omega, int eta, int workers){ | ||
|
||
omp_set_num_threads(workers); | ||
|
||
cout << "Model:" << endl; | ||
cout << "\t[Skew-OPT]" << endl; | ||
|
||
cout << "Learning Parameters:" << endl; | ||
cout << "\tsample_times:\t\t" << sample_times << endl; | ||
cout << "\talpha:\t\t\t" << alpha << endl; | ||
cout << "\tregularization:\t\t" << reg << endl; | ||
cout << "\txi:\t\t\t" << xi << endl; | ||
cout << "\tomega:\t\t\t" << omega << endl; | ||
cout << "\teta:\t\t\t" << eta << endl; | ||
cout << "\tworkers:\t\t" << workers << endl; | ||
|
||
cout << "Start Training:" << endl; | ||
|
||
unsigned long long total_sample_times = (unsigned long long)sample_times*1000000; | ||
double alpha_min = alpha * 0.0001; | ||
double alpha_last; | ||
|
||
unsigned long long current_sample = 0; | ||
unsigned long long jobs = total_sample_times/workers; | ||
|
||
#pragma omp parallel for | ||
for (int worker=0; worker<workers; ++worker) | ||
{ | ||
|
||
long v1, v2, v3; | ||
unsigned long long count = 0; | ||
double _alpha = alpha; | ||
|
||
while (count<jobs) | ||
{ | ||
v1 = pnet.SourceSample(); | ||
v2 = pnet.TargetSample(v1); | ||
|
||
pnet.UpdateSBPRPair(w_vertex, w_vertex, v1, v2, dim, reg, xi, omega, eta, _alpha); | ||
|
||
count ++; | ||
if (count % MONITOR == 0) | ||
{ | ||
_alpha = alpha* ( 1.0 - (double)(current_sample)/total_sample_times ); | ||
current_sample += MONITOR; | ||
if (_alpha < alpha_min) _alpha = alpha_min; | ||
alpha_last = _alpha; | ||
printf("\tAlpha: %.6f\tProgress: %.3f %%%c", _alpha, (double)(current_sample)/total_sample_times * 100, 13); | ||
fflush(stdout); | ||
} | ||
} | ||
|
||
} | ||
printf("\tAlpha: %.6f\tProgress: 100.00 %%\n", alpha_last); | ||
|
||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#ifndef SPR_H | ||
#define SPR_H | ||
|
||
#include "../proNet.h" | ||
|
||
/***** | ||
* SkewOPT | ||
* **************************************************************/ | ||
|
||
class SPR { | ||
|
||
public: | ||
|
||
SPR(); | ||
~SPR(); | ||
|
||
proNet pnet; | ||
|
||
// parameters | ||
int dim; // representation dimensions | ||
vector< vector<double> > w_vertex; | ||
|
||
// data function | ||
void LoadEdgeList(string, bool); | ||
void SaveWeights(string); | ||
|
||
// model function | ||
void Init(int); | ||
void Train(int, int, double, double, double, double, int, int); | ||
|
||
}; | ||
|
||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters