diff --git a/src/model/TEXTGCN.cpp b/src/model/TEXTGCN.cpp index 532e5fe..567e1b1 100644 --- a/src/model/TEXTGCN.cpp +++ b/src/model/TEXTGCN.cpp @@ -15,16 +15,34 @@ void TEXTGCN::SaveWeights(string model_name){ counter++; model << counter << " " << dim << endl; + vector w_avg; + w_avg.resize(dim, 0.0); + + long temp_vid; for (long vid=0; vid!=pnet.MAX_vid; vid++) { if (pnet.field[vid].fields[0]==1) continue; model << pnet.vertex_hash.keys[vid]; - if (pnet.field[vid].fields[0]==0) + { + for (int b=0; b >& w_vertex, vector< vector >& w_vertex, vector< vector >& w_context, long vertex, long context, int dimension, double reg, int walk_steps, int negative_samples, double alpha){ - vector vertices; - vector w_avg, back_err; + vector vertex_contexts, context_contexts; + vector w_avg, c_avg, back_err, back_c_err; w_avg.resize(dimension, 0.0); + c_avg.resize(dimension, 0.0); back_err.resize(dimension, 0.0); - long vertex_context; + back_c_err.resize(dimension, 0.0); + long vertex_context, context_context; - //vertices.push_back(vertex); + // context-context (user-event) + for (int i=0; i!=walk_steps; ++i) + { + context_context = TargetSample(context); + if (context_context==-1) break; + context_contexts.push_back(context_context); + } + + // vertex-context (event-word) for (int i=0; i!=walk_steps; ++i) { vertex_context = TargetSample(vertex); if (vertex_context==-1) break; - vertices.push_back(vertex_context); + vertex_contexts.push_back(vertex_context); } - //double decay = 1.0; vector* w_ptr; - for (auto v: vertices) + // context-context (user-event) + for (auto c: context_contexts) + { + w_ptr = &w_vertex[c]; + for (int d=0; d!=dimension;++d) + { + c_avg[d] += (*w_ptr)[d]; + } + } + // vertex-context (event-word) + for (auto v: vertex_contexts) { w_ptr = &w_vertex[v]; for (int d=0; d!=dimension;++d) @@ -2776,21 +2795,73 @@ void proNet::UpdateCBOW(vector< vector >& w_vertex, vector< vector