Skip to content

Commit

Permalink
textgcn applys user:=events
Browse files Browse the repository at this point in the history
  • Loading branch information
chihming committed Apr 1, 2019
1 parent fc721a5 commit 9b9e8ac
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 15 deletions.
22 changes: 20 additions & 2 deletions src/model/TEXTGCN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,34 @@ void TEXTGCN::SaveWeights(string model_name){
counter++;
model << counter << " " << dim << endl;

vector<double> w_avg;
w_avg.resize(dim, 0.0);

long temp_vid;
for (long vid=0; vid!=pnet.MAX_vid; vid++)
{
if (pnet.field[vid].fields[0]==1)
continue;

model << pnet.vertex_hash.keys[vid];

if (pnet.field[vid].fields[0]==0)
{
for (int b=0; b<pnet.vertex[vid].branch; b++)
{
temp_vid = pnet.context[pnet.vertex[vid].offset + b].vid;
for (int d=0; d<dim; ++d)
{
w_avg[d] += w_vertex[temp_vid][d];
}
}
for (int d=0; d<dim; ++d)
model << " " << w_context[vid][d];
{
model << " " << w_avg[d];
w_avg[d] = 0.0;
}
//for (int d=0; d<dim; ++d)
// model << " " << w_context[vid][d];
}
if (pnet.field[vid].fields[0]==2)
for (int d=0; d<dim; ++d)
model << " " << w_vertex[vid][d];
Expand Down
97 changes: 84 additions & 13 deletions src/proNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2739,23 +2739,42 @@ void proNet::UpdateCBOWdev(vector< vector<double> >& w_vertex, vector< vector<do

void proNet::UpdateCBOW(vector< vector<double> >& w_vertex, vector< vector<double> >& w_context, long vertex, long context, int dimension, double reg, int walk_steps, int negative_samples, double alpha){

vector<long> vertices;
vector<double> w_avg, back_err;
vector<long> vertex_contexts, context_contexts;
vector<double> w_avg, c_avg, back_err, back_c_err;
w_avg.resize(dimension, 0.0);
c_avg.resize(dimension, 0.0);
back_err.resize(dimension, 0.0);
long vertex_context;
back_c_err.resize(dimension, 0.0);
long vertex_context, context_context;

//vertices.push_back(vertex);
// context-context (user-event)
for (int i=0; i!=walk_steps; ++i)
{
context_context = TargetSample(context);
if (context_context==-1) break;
context_contexts.push_back(context_context);
}

// vertex-context (event-word)
for (int i=0; i!=walk_steps; ++i)
{
vertex_context = TargetSample(vertex);
if (vertex_context==-1) break;
vertices.push_back(vertex_context);
vertex_contexts.push_back(vertex_context);
}

//double decay = 1.0;
vector<double>* w_ptr;
for (auto v: vertices)
// context-context (user-event)
for (auto c: context_contexts)
{
w_ptr = &w_vertex[c];
for (int d=0; d!=dimension;++d)
{
c_avg[d] += (*w_ptr)[d];
}
}
// vertex-context (event-word)
for (auto v: vertex_contexts)
{
w_ptr = &w_vertex[v];
for (int d=0; d!=dimension;++d)
Expand All @@ -2776,21 +2795,73 @@ void proNet::UpdateCBOW(vector< vector<double> >& w_vertex, vector< vector<doubl

// positive training
label = 1.0;
Opt_SigmoidRegSGD(w_avg, w_context[context], label, alpha, reg, back_err, w_context[context]);
Opt_SigmoidRegSGD(w_avg, c_avg, label, alpha, reg, back_err, back_c_err);
//Opt_SGD(w_vertex[vertex], w_context[context], label, alpha, reg, back_err, w_context[context]);

// update context_contexts (user-event)
for (auto c: context_contexts)
{
w_ptr = &w_vertex[c];
for (int d=0; d!=dimension;++d)
{
(*w_ptr)[d] += back_c_err[d];
}
}
for (int d=0; d!=dimension;++d)
{
c_avg[d] = 0.0;
back_c_err[d] = 0.0;
}

// negative sampling
label = 0.0;
for (int neg=0; neg!=negative_samples; ++neg)
{
neg_context = random_gen(0, MAX_vid);
while(field[neg_context].fields[0]!=0)
neg_context = random_gen(0, MAX_vid);
Opt_SigmoidRegSGD(w_avg, w_context[neg_context], label, alpha, reg, back_err, w_context[neg_context]);
//neg_context = random_gen(0, MAX_vid);
//while(field[neg_context].fields[0]!=0)
// neg_context = random_gen(0, MAX_vid);

// sample neg context-context (user-event)
context_contexts.clear();
for (int i=0; i!=walk_steps; ++i)
{
context_context = random_gen(0, MAX_vid);
while(field[context_context].fields[0]!=1) // event
context_context = random_gen(0, MAX_vid);
context_contexts.push_back(context_context);
}
// context-context (user-event)
for (auto c: context_contexts)
{
w_ptr = &w_vertex[c];
for (int d=0; d!=dimension;++d)
{
c_avg[d] += (*w_ptr)[d];
}
}

Opt_SigmoidRegSGD(w_avg, c_avg, label, alpha, reg, back_err, back_c_err);
//Opt_SGD(w_vertex[vertex], w_context[context], label, alpha, reg, back_err, w_context[context]);

// update context_contexts (user-event)
for (auto c: context_contexts)
{
w_ptr = &w_vertex[c];
for (int d=0; d!=dimension;++d)
{
(*w_ptr)[d] += back_c_err[d];
}
}
for (int d=0; d!=dimension;++d)
{
c_avg[d] = 0.0;
back_c_err[d] = 0.0;
}

}

for (auto v: vertices)
// update context_contexts (event-words)
for (auto v: vertex_contexts)
{
w_ptr = &w_vertex[v];
for (int d=0; d!=dimension;++d)
Expand Down

0 comments on commit 9b9e8ac

Please sign in to comment.