Skip to content

Commit

Permalink
add gcn dev version
Browse files Browse the repository at this point in the history
  • Loading branch information
chihming committed Mar 28, 2019
1 parent af1b8d7 commit c7865c5
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 2 deletions.
2 changes: 1 addition & 1 deletion cli/Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
CC= g++
CPPFLAGS= -std=c++11 -fopenmp -lm -Ofast
OBJS= deepwalk line walklets hpe app mf bpr hoprec warp nemf nerank eco gcn textgcn
OBJS= deepwalk line walklets hpe app mf bpr hoprec warp nemf nerank eco gcn textgcn textgcndev
LIBS= -L ../bin -lproNet

all: $(OBJS)
Expand Down
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
CC = g++
CPPFLAGS = -std=c++11 -fPIC -fopenmp -lm -Ofast
OBJECTS = util.o random.o proNet.o ./model/DeepWalk.o ./model/Walklets.o ./model/LINE.o ./model/HPE.o ./model/APP.o ./model/MF.o ./model/BPR.o ./model/HBPR.o ./model/NEMF.o ./model/WARP.o ./model/NERANK.o ./model/ECO.o ./model/GCN.o ./model/TEXTGCN.o
OBJECTS = util.o random.o proNet.o ./model/DeepWalk.o ./model/Walklets.o ./model/LINE.o ./model/HPE.o ./model/APP.o ./model/MF.o ./model/BPR.o ./model/HBPR.o ./model/NEMF.o ./model/WARP.o ./model/NERANK.o ./model/ECO.o ./model/GCN.o ./model/TEXTGCN.o ./model/TEXTGCNdev.o
all: $(OBJECTS)
mkdir -p ../bin
ar rcs ../bin/libproNet.a $(OBJECTS)
Expand Down
65 changes: 65 additions & 0 deletions src/proNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2670,6 +2670,71 @@ void proNet::UpdatePairs(vector< vector<double> >& w_vertex, vector< vector<doub

}

void proNet::UpdateCBOWdev(vector< vector<double> >& w_vertex, vector< vector<double> >& w_context, long vertex, long context, int dimension, double reg, int walk_steps, int negative_samples, double alpha){

vector<long> vertices, neg_vertices;
vector<double> w_avg, back_err;
w_avg.resize(dimension, 0.0);
back_err.resize(dimension, 0.0);
long vertex_context;

//vertices.push_back(vertex);
for (int i=0; i!=walk_steps; ++i)
{
vertex_context = TargetSample(vertex);
if (vertex_context==-1) break;
vertices.push_back(vertex_context);
}

//double decay = 1.0;
vector<double>* w_ptr;
for (auto v: vertices)
{
w_ptr = &w_vertex[v];
for (int d=0; d!=dimension;++d)
{
w_avg[d] += (*w_ptr)[d];
}
}
/*
int num = vertices.size();
for (int d=0; d!=dimension;++d)
{
w_avg[d] /= num;
}
*/

long neg_context;
double label;

// positive training
label = 1.0;
Opt_SigmoidRegSGD(w_avg, w_context[context], label, alpha, reg, back_err, w_context[context]);
//Opt_SGD(w_vertex[vertex], w_context[context], label, alpha, reg, back_err, w_context[context]);

// negative sampling
label = 0.0;
for (int neg=0; neg!=negative_samples; ++neg)
{
neg_context = NegativeSample();
while(field[neg_context].fields[0]!=2)
neg_context = NegativeSample();
Opt_SigmoidRegSGD(w_avg, w_context[neg_context], label, alpha, reg, back_err, w_context[neg_context]);
//Opt_SGD(w_vertex[vertex], w_context[context], label, alpha, reg, back_err, w_context[context]);
}

for (auto v: vertices)
{
w_ptr = &w_vertex[v];
for (int d=0; d!=dimension;++d)
{
(*w_ptr)[d] += back_err[d];
}
}

}


void proNet::UpdateCBOW(vector< vector<double> >& w_vertex, vector< vector<double> >& w_context, long vertex, long context, int dimension, double reg, int walk_steps, int negative_samples, double alpha){

vector<long> vertices;
Expand Down
1 change: 1 addition & 0 deletions src/proNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ class proNet {
// vertex vector, context vector, vertex, context, dimension, regularization, walk steps, negative samples, alpha
void UpdateCBOW(vector< vector<double> >&, vector< vector<double> >&, long, long, int, double, int, int, double);
void UpdateCBOWs(vector< vector<double> >&, vector< vector<double> >&, vector<long>&, vector<long>&, int, double, int, int, double);
void UpdateCBOWdev(vector< vector<double> >&, vector< vector<double> >&, long, long, int, double, int, int, double);

// vertex vector, context vector, vertex series, context series, dimension, negative samples, alpha
void UpdatePairs(vector< vector<double> >&, vector< vector<double> >&, vector<long>&, vector<long>&, int, int, double);
Expand Down

0 comments on commit c7865c5

Please sign in to comment.