forked from mnick/scikit-kge
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path run_hole.py
65 lines (54 loc) · 2.13 KB
/
run_hole.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python
import numpy as np
from base import Experiment, FilteredRankingEval
from skge.util import ccorr
from skge import StochasticTrainer, PairwiseStochasticTrainer, HolE
from skge import activation_functions as afs
class HolEEval(FilteredRankingEval):
    """Filtered-ranking evaluator specialised for HolE models.

    ``prepare`` must be called for a relation ``p`` before ``scores_o`` /
    ``scores_s``: it caches ``self.ER = ccorr(mdl.R[p], mdl.E)``, and both
    scoring methods are then a single ``np.dot`` against that cache.
    """

    def prepare(self, mdl, p):
        """Cache ``ccorr(mdl.R[p], mdl.E)`` on ``self.ER`` for relation ``p``."""
        # NOTE(review): presumably ccorr is circular correlation broadcast over
        # the rows of mdl.E, so ER has one row per row of E -- confirm against
        # skge.util.ccorr.
        relation_vec = mdl.R[p]
        self.ER = ccorr(relation_vec, mdl.E)

    def scores_o(self, mdl, s, p):
        """Return ``np.dot(self.ER, mdl.E[s])`` for subject ``s``.

        Uses the cache built by ``prepare``; ``p`` itself is not read here.
        """
        cached, subject_vec = self.ER, mdl.E[s]
        return np.dot(cached, subject_vec)

    def scores_s(self, mdl, o, p):
        """Return ``np.dot(mdl.E, self.ER[o])`` for object ``o``.

        Uses the cache built by ``prepare``; ``p`` itself is not read here.
        """
        entities, object_row = mdl.E, self.ER[o]
        return np.dot(entities, object_row)
class ExpHolE(Experiment):
    """Command-line experiment driver for the HolE model.

    Adds HolE-specific options to the base ``Experiment`` argument parser
    and builds the model/trainer pair used for a run.
    """

    def __init__(self):
        super(ExpHolE, self).__init__()
        add_arg = self.parser.add_argument
        add_arg('--ncomp', type=int, help='Number of latent components (dimensions)')
        add_arg('--rparam', type=float, help='Regularization for W', default=0)
        add_arg('--afs', type=str, default='sigmoid', help='Activation function')
        # Stored as the class itself (not an instance); presumably the base
        # Experiment instantiates it -- confirm in base.py.
        self.evaluator = HolEEval
        self.algo = "HolE"

    def setup_trainer(self, sz, sampler):
        """Build a HolE model and wrap it in a stochastic trainer.

        Args:
            sz: forwarded unchanged as HolE's first positional argument.
            sampler: object whose ``sample`` method is passed as ``samplef``.

        Returns:
            A ``StochasticTrainer`` when ``self.args.no_pairwise`` is truthy,
            otherwise a ``PairwiseStochasticTrainer`` (the only one of the
            two that receives ``margin``).
        """
        args = self.args
        model = HolE(
            sz,
            args.ncomp,
            rparam=args.rparam,
            af=afs[args.afs],
            init=args.init
        )
        # Keyword arguments common to both trainer variants.
        trainer_kwargs = dict(
            nbatches=args.nb,
            max_epochs=args.me,
            post_epoch=[self.callback],
            learning_rate=args.lr,
            samplef=sampler.sample
        )
        if args.no_pairwise:
            return StochasticTrainer(model, **trainer_kwargs)
        # margin is passed only to the pairwise trainer.
        trainer_kwargs['margin'] = args.margin
        return PairwiseStochasticTrainer(model, **trainer_kwargs)
if __name__ == '__main__':
    # Construct the experiment, then start it. run() is not defined on
    # ExpHolE in this file, so it is inherited from Experiment.
    experiment = ExpHolE()
    experiment.run()