+import itertools
+
import torch
import torch.nn as nn
+import torch.nn.functional as F


class FM(nn.Module):
+    """Factorization Machine models pairwise (order-2) feature interactions
+    without linear term and bias.
+      Input shape
+        - 3D tensor with shape: ``(batch_size, field_size, embedding_size)``.
+      Output shape
+        - 2D tensor with shape: ``(batch_size, 1)``.
+      References
+        - [Factorization Machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf)
+    """
+
    def __init__(self):
        super(FM, self).__init__()

@@ -15,3 +28,75 @@ def forward(self, inputs):
        cross_term = 0.5 * torch.sum(cross_term, dim=2, keepdim=False)

        return cross_term
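# A minimal usage sketch for FM (illustrative; the batch/field/embedding sizes
# below are assumptions, not taken from this commit). The field embeddings are
# stacked into a single 3D tensor before calling the layer:
#
#     fm = FM()
#     fm_input = torch.randn(4, 10, 8)   # (batch_size, field_size, embedding_size)
#     fm_logit = fm(fm_input)            # -> shape (4, 1)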
+
+
+class AFMLayer(nn.Module):
+    """Attentional Factorization Machine models pairwise (order-2) feature
+    interactions without linear term and bias.
+      Input shape
+        - A list of 3D tensors with shape: ``(batch_size, 1, embedding_size)``.
+      Output shape
+        - 2D tensor with shape: ``(batch_size, 1)``.
+      Arguments
+        - **attention_factor** : Positive integer, dimensionality of the
+         attention network output space.
+        - **l2_reg_w** : float between 0 and 1. L2 regularizer strength
+         applied to the attention network.
+        - **dropout_rate** : float in [0, 1). Fraction of the attention net output units to drop out.
+        - **seed** : A Python integer to use as random seed.
+      References
+        - [Attentional Factorization Machines : Learning the Weight of Feature
+         Interactions via Attention Networks](https://arxiv.org/pdf/1708.04617.pdf)
+    """
+
+    def __init__(self, in_feature, attention_factor=4, l2_reg_w=0, dropout_rate=0, seed=1024, device='cpu'):
+        super(AFMLayer, self).__init__()
+        self.attention_factor = attention_factor
+        self.l2_reg_w = l2_reg_w
+        self.dropout_rate = dropout_rate
+        self.seed = seed
+        embedding_size = in_feature
+
+        # Attention network: projects each pairwise interaction vector (length
+        # ``embedding_size``) into an ``attention_factor``-dimensional space,
+        # which is then scored by ``projection_h``.
+        self.attention_W = nn.Parameter(torch.Tensor(embedding_size, self.attention_factor))
+
+        self.attention_b = nn.Parameter(torch.Tensor(self.attention_factor))
+
+        self.projection_h = nn.Parameter(torch.Tensor(self.attention_factor, 1))
+
+        # Final projection from the attended interaction vector to a scalar logit.
+        self.projection_p = nn.Parameter(torch.Tensor(embedding_size, 1))
+
+        self.weight = self.attention_W
+
+        for tensor in [self.attention_W, self.projection_h, self.projection_p]:
+            nn.init.xavier_normal_(tensor)
+        # torch.Tensor allocates uninitialized storage, so the attention bias
+        # needs an explicit initialization as well.
+        nn.init.zeros_(self.attention_b)
+
+        self.dropout = nn.Dropout(dropout_rate)
+
+        self.to(device)
+
+    def forward(self, inputs):
+        embeds_vec_list = inputs
+        row = []
+        col = []
+
+        # Enumerate every distinct pair of field embeddings.
+        for r, c in itertools.combinations(embeds_vec_list, 2):
+            row.append(r)
+            col.append(c)
+
+        p = torch.cat(row, dim=1)
+        q = torch.cat(col, dim=1)
+        # Element-wise products of all pairs: (batch_size, num_pairs, embedding_size).
+        inner_product = p * q
+
+        bi_interaction = inner_product
+        attention_temp = F.relu(torch.tensordot(
+            bi_interaction, self.attention_W, dims=([-1], [0])) + self.attention_b)
+
+        # Attention weights normalized over the pairwise interactions.
+        self.normalized_att_score = F.softmax(torch.tensordot(
+            attention_temp, self.projection_h, dims=([-1], [0])), dim=1)
+        attention_output = torch.sum(
+            self.normalized_att_score * bi_interaction, dim=1)
+
+        attention_output = self.dropout(attention_output)  # only active in training mode
+
+        afm_out = torch.tensordot(attention_output, self.projection_p, dims=([-1], [0]))
+        return afm_out
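

# A minimal usage sketch for AFMLayer (illustrative; the three-field setup and
# the 8-dimensional embeddings are assumptions, not part of this commit). The
# layer expects a list of per-field embeddings, each (batch_size, 1, embedding_size):
#
#     embed_list = [torch.randn(4, 1, 8) for _ in range(3)]
#     afm = AFMLayer(in_feature=8, attention_factor=4, dropout_rate=0.2)
#     afm_logit = afm(embed_list)   # -> shape (4, 1)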