diff --git a/TransR/Train_TransR.cpp b/TransR/Train_TransR.cpp index 3d79044..17165b8 100644 --- a/TransR/Train_TransR.cpp +++ b/TransR/Train_TransR.cpp @@ -1,417 +1,415 @@ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -using namespace std; - - -#define pi 3.1415926535897932384626433832795 - -bool L1_flag=1; - -//normal distribution -double rand(double min, double max) -{ - return min+(max-min)*rand()/(RAND_MAX+1.0); -} -double normal(double x, double miu,double sigma) -{ - return 1.0/sqrt(2*pi)/sigma*exp(-1*(x-miu)*(x-miu)/(2*sigma*sigma)); -} -double sigmod(double x) -{ - return 1.0/(1+exp(-2*x)); -} -double randn(double miu,double sigma, double min ,double max) -{ - double x,y,dScope; - do{ - x=rand(min,max); - y=normal(x,miu,sigma); - dScope=rand(0.0,normal(miu,miu,sigma)); - }while(dScope>y); - return x; -} - -double sqr(double x) -{ - return x*x; -} - -double vec_len(vector &a) -{ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace std; + + +#define pi 3.1415926535897932384626433832795 + +bool L1_flag=1; + +//normal distribution +double rand(double min, double max) +{ + return min+(max-min)*rand()/(RAND_MAX+1.0); +} +double normal(double x, double miu,double sigma) +{ + return 1.0/sqrt(2*pi)/sigma*exp(-1*(x-miu)*(x-miu)/(2*sigma*sigma)); +} +double sigmod(double x) +{ + return 1.0/(1+exp(-2*x)); +} +double randn(double miu,double sigma, double min ,double max) +{ + double x,y,dScope; + do{ + x=rand(min,max); + y=normal(x,miu,sigma); + dScope=rand(0.0,normal(miu,miu,sigma)); + }while(dScope>y); + return x; +} + +double sqr(double x) +{ + return x*x; +} + +double vec_len(vector &a) +{ double res=0; - for (int i=0; i a) -{ - for (int i=0; i relation2id,entity2id; -map id2entity,id2relation; - -map entity2num; - - -map > > left_entity,right_entity; -map left_mean,right_mean,left_var,right_var; - -class Train{ - -public: - map, map > ok; - void add(int x,int y,int z) - { - fb_h.push_back(x); - fb_r.push_back(z); - fb_l.push_back(y); - ok[make_pair(x,z)][y]=1; - } - void run(int n_in,double rate_in,double margin_in,int method_in) - { + for (int i=0; i a) +{ + for (int i=0; i relation2id,entity2id; +map id2entity,id2relation; + +map entity2num; + + +map > > left_entity,right_entity; +map left_mean,right_mean,left_var,right_var; + +class Train{ + +public: + map, map > ok; + void add(int x,int y,int z) + { + fb_h.push_back(x); + fb_r.push_back(z); + fb_l.push_back(y); + ok[make_pair(x,z)][y]=1; + } + void run(int n_in,double rate_in,double margin_in,int method_in) + { n = n_in; - m = n_in; + m = n_in; rate = rate_in; margin = margin_in; - method = method_in; - A.resize(relation_num); - for (int i=0; i fb_h,fb_l,fb_r; - vector > relation_vec,entity_vec; - vector > relation_tmp,entity_tmp; - vector > > A,A_tmp; - void norm(vector &a) - { - double x = vec_len(a); - if (x>1) - for (int ii=0; ii &a, vector > &A) - { - while (true) - { - double x=0; - for (int ii=0; ii1) - { - double lambda=1; - for (int ii=0; ii fb_h,fb_l,fb_r; + vector > relation_vec,entity_vec; + vector > relation_tmp,entity_tmp; + vector > > A,A_tmp; + void norm(vector &a) + { + double x = vec_len(a); + if (x>1) + for (int ii=0; ii &a, vector > &A) + { + while (true) + { + double x=0; + for (int ii=0; ii1) + { + double lambda=1; + for (int ii=0; ii0) - j=rand_max(entity_num); - train_kb(fb_h[i],fb_l[i],fb_r[i],fb_h[i],j,fb_r[i]); - - } - else - { - while (ok[make_pair(j,fb_r[i])].count(fb_l[i])>0) - j=rand_max(entity_num); - train_kb(fb_h[i],fb_l[i],fb_r[i],j,fb_l[i],fb_r[i]); - } - norm(relation_tmp[fb_r[i]]); - norm(entity_tmp[fb_h[i]]); - norm(entity_tmp[fb_l[i]]); - norm(entity_tmp[j]); - norm(entity_tmp[fb_h[i]],A_tmp[fb_r[i]]); - norm(entity_tmp[fb_l[i]],A_tmp[fb_r[i]]); - norm(entity_tmp[j],A_tmp[fb_r[i]]); - norm(entity_tmp[k]); - norm(entity_tmp[k],A_tmp[fb_r[i]]); - } - relation_vec = relation_tmp; - entity_vec = entity_tmp; - A = A_tmp; - } - cout<<"epoch:"< e1_vec; - e1_vec.resize(m); - vector e2_vec; - e2_vec.resize(m); - for (int ii=0; ii0) - x=1; - else - x=-1; - for (int jj=0; jjsum2) - { - res+=margin+sum1-sum2; - gradient( e1_a, e2_a, rel_a, e1_b, e2_b, rel_b); - } - } -}; - -Train train; -void prepare() -{ - FILE* f1 = fopen("../data/entity2id.txt","r"); - FILE* f2 = fopen("../data/relation2id.txt","r"); - int x; - while (fscanf(f1,"%s%d",buf,&x)==2) - { - string st=buf; - entity2id[st]=x; - id2entity[x]=st; - entity_num++; - } - while (fscanf(f2,"%s%d",buf,&x)==2) - { - string st=buf; - relation2id[st]=x; - id2relation[x]=st; - relation_num++; - } - FILE* f_kb = fopen("../data/train.txt","r"); - while (fscanf(f_kb,"%s",buf)==1) - { - string s1=buf; - fscanf(f_kb,"%s",buf); - string s2=buf; - fscanf(f_kb,"%s",buf); - string s3=buf; - if (entity2id.count(s1)==0) - { - cout<<"miss entity:"< >::iterator it = left_entity[i].begin(); it!=left_entity[i].end(); it++) - { - sum1++; - sum2+=it->second.size(); - sum3+=sqr(it->second.size()); - } - left_mean[i]=sum2/sum1; - - left_var[i]=sum3/sum1-sqr(left_mean[i]); - } - for (int i=0; i >::iterator it = right_entity[i].begin(); it!=right_entity[i].end(); it++) - { - sum1++; - sum2+=it->second.size(); - sum3+=sqr(it->second.size()); - } - right_mean[i]=sum2/sum1; - right_var[i]=sum3/sum1-sqr(right_mean[i]); - } - - fclose(f_kb); - cout<<"relation_num="< e1_vec; + e1_vec.resize(m); + vector e2_vec; + e2_vec.resize(m); + for (int ii=0; ii0) + x=1; + else + x=-1; + for (int jj=0; jjsum2) + { + res+=margin+sum1-sum2; + gradient( e1_a, e2_a, rel_a, e1_b, e2_b, rel_b); + } + } +}; + +Train train; +void prepare() +{ + FILE* f1 = fopen("../data/entity2id.txt","r"); + FILE* f2 = fopen("../data/relation2id.txt","r"); + int x; + while (fscanf(f1,"%s%d",buf,&x)==2) + { + string st=buf; + entity2id[st]=x; + id2entity[x]=st; + entity_num++; + } + while (fscanf(f2,"%s%d",buf,&x)==2) + { + string st=buf; + relation2id[st]=x; + id2relation[x]=st; + relation_num++; + } + FILE* f_kb = fopen("../data/train.txt","r"); + while (fscanf(f_kb,"%s",buf)==1) + { + string s1=buf; + fscanf(f_kb,"%s",buf); + string s2=buf; + fscanf(f_kb,"%s",buf); + string s3=buf; + if (entity2id.count(s1)==0) + { + cout<<"miss entity:"< >::iterator it = left_entity[i].begin(); it!=left_entity[i].end(); it++) + { + sum1++; + sum2+=it->second.size(); + sum3+=sqr(it->second.size()); + } + left_mean[i]=sum2/sum1; + + left_var[i]=sum3/sum1-sqr(left_mean[i]); + } + for (int i=0; i >::iterator it = right_entity[i].begin(); it!=right_entity[i].end(); it++) + { + sum1++; + sum2+=it->second.size(); + sum3+=sqr(it->second.size()); + } + right_mean[i]=sum2/sum1; + right_var[i]=sum3/sum1-sqr(right_mean[i]); + } + + fclose(f_kb); + cout<<"relation_num="<