From d6e4dbe42039c5f89f48b0ae2b5b64e4093ad4b5 Mon Sep 17 00:00:00 2001 From: lyk423 Date: Wed, 21 Jan 2015 14:31:51 +0800 Subject: [PATCH] Add trainning code of TransH --- .DS_Store | Bin 6148 -> 6148 bytes TransH/Train_TransH.cpp | 449 ++++++++++++++++++++++++++++++++++++++++ TransR/Train_TransR.cpp | 35 ---- 3 files changed, 449 insertions(+), 35 deletions(-) create mode 100644 TransH/Train_TransH.cpp diff --git a/.DS_Store b/.DS_Store index 2d71c7501aad571b7886370c34f153d0ba3a4621..4e1a66919f1588c25a24fbdc9e98f7a3a801f343 100644 GIT binary patch delta 132 zcmZoMXfc=|#>B!ku~2NHo+2aL#(>?7iw&5W7}+-SFo`qrvN41(6fq<+B)qu~2NHo+2aD#(>?7lMO^zHuJNHFi!Sn-?LePgPmn!gW_g(4t@@x b!p(vl- +#include +#include +#include +#include +#include +#include +#include +#include +using namespace std; + + +#define pi 3.1415926535897932384626433832795 + +//normal distribution +double rand(double min, double max) +{ + return min+(max-min)*rand()/(RAND_MAX+1.0); +} +double normal(double x, double miu,double sigma) +{ + return 1.0/sqrt(2*pi)/sigma*exp(-1*(x-miu)*(x-miu)/(2*sigma*sigma)); +} +double randn(double miu,double sigma, double min ,double max) +{ + double x,y,dScope; + do{ + x=rand(min,max); + y=normal(x,miu,sigma); + dScope=rand(0.0,normal(miu,miu,sigma)); + }while(dScope>y); + return x; +} + +double sqr(double x) +{ + return x*x; +} + +double vec_len(vector &a) +{ + double res=0; + for (int i=0; i relation2id,entity2id; +map id2entity,id2relation; + + + +map > > left_entity,right_entity; +map left_mean,right_mean,left_var,right_var; +map > left_candidate_ok,right_candidate_ok; +map > left_candidate,right_candidate; +map entity2num; + +class Train{ + +public: + map, map > ok; + void add(int x,int y,int z) + { + fb_h.push_back(x); + fb_r.push_back(z); + fb_l.push_back(y); + ok[make_pair(x,z)][y]=1; + } + void run(int n_in,double rate_in,double margin_in,int method_in) + { + n = n_in; + rate = rate_in; + margin = margin_in; + method = method_in; + A.resize(relation_num); + for (int i=0; i fb_h,fb_l,fb_r; + vector > feature; + vector > A, A_tmp; + vector > relation_vec,entity_vec,relation_tmp,entity_tmp; + double norm(vector &a) + { + double x = vec_len(a); + if (x>1) + for (int ii=0; ii &a) + { + double x = vec_len(a); + for (int ii=0; ii &a, vector &A) + { + norm2one(A); + double sum=0; + while (true) + { + for (int i=0; i0.1) + { + for (int ii=0; ii0) + j=rand_max(entity_num); + train_kb(fb_h[i],fb_l[i],fb_r[i],fb_h[i],j,fb_r[i]); + } + else + { + while (ok[make_pair(j,fb_r[i])].count(fb_l[i])>0) + j=rand_max(entity_num); + train_kb(fb_h[i],fb_l[i],fb_r[i],j,fb_l[i],fb_r[i]); + } + norm(entity_tmp[fb_h[i]]); + norm(entity_tmp[fb_l[i]]); + norm(entity_tmp[j]); + norm(entity_tmp[fb_h[i]],A_tmp[fb_r[i]]); + norm(entity_tmp[fb_l[i]],A_tmp[fb_r[i]]); + norm(entity_tmp[j],A_tmp[fb_r[i]]); + } + + A = A_tmp; + relation_vec = relation_tmp; + entity_vec = entity_tmp; + } + cout<0) + x=1; + else + x=-1; + sum_x+=x*A[rel][ii]; + relation_tmp[rel][ii]-=belta*rate*x; + entity_tmp[e1][ii]-=belta*rate*x; + entity_tmp[e2][ii]+=belta*rate*x; + A[rel][ii]+=belta*rate*x*tmp1; + A[rel][ii]-=belta*rate*x*tmp2; + } + for (int ii=0; iisum2) + { + res+=margin+sum1-sum2; + gradient( e1_a, e2_a, rel_a, -1); + gradient(e1_b, e2_b, rel_b,1); + } + } +}; + +Train train; +void prepare() +{ + FILE* f1 = fopen("../data/entity2id.txt","r"); + FILE* f2 = fopen("../data/relation2id.txt","r"); + int x; + while (fscanf(f1,"%s%d",buf,&x)==2) + { + string st=buf; + entity2id[st]=x; + id2entity[x]=st; + entity_num++; + } + while (fscanf(f2,"%s%d",buf,&x)==2) + { + string st=buf; + relation2id[st]=x; + id2relation[x]=st; + relation_num++; + } + FILE* f_kb = fopen("../data/train.txt","r"); + while (fscanf(f_kb,"%s",buf)==1) + { + string s1=buf; + fscanf(f_kb,"%s",buf); + string s2=buf; + fscanf(f_kb,"%s",buf); + string s3=buf; + if (entity2id.count(s1)==0) + { + cout<<"miss entity:"< >::iterator it = left_entity[i].begin(); it!=left_entity[i].end(); it++) + { + sum1++; + sum2+=it->second.size(); + sum3+=sqr(it->second.size()); + } + left_mean[i]=sum2/sum1; + + left_var[i]=sum3/sum1-sqr(left_mean[i]); + } + for (int i=0; i >::iterator it = right_entity[i].begin(); it!=right_entity[i].end(); it++) + { + sum1++; + sum2+=it->second.size(); + sum3+=sqr(it->second.size()); + } + right_mean[i]=sum2/sum1; + right_var[i]=sum3/sum1-sqr(right_mean[i]); + } + + for (int i=0; i 0) n = atoi(argv[i + 1]); + if ((i = ArgPos((char *)"-margin", argc, argv)) > 0) margin = atoi(argv[i + 1]); + if ((i = ArgPos((char *)"-method", argc, argv)) > 0) method = atoi(argv[i + 1]); + cout<<"size = "<