Commit ea748c9

add AFM (shenweichen#3)
1 parent 3d73222 commit ea748c9

13 files changed: +321, -30 lines

README.md (+2, -2)

@@ -1,4 +1,4 @@
-# DeepCTR-Pytorch
+# DeepCTR-PyTorch
 
 [![Python Versions](https://img.shields.io/pypi/pyversions/deepctr.svg)](https://pypi.org/project/deepctr)
 [![Downloads](https://pepy.tech/badge/deepctr)](https://pepy.tech/project/deepctr)
@@ -32,7 +32,7 @@ please send a brief introduction of your background and experience to wcshen1994
 
 | Model | Paper |
 | :------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| Convolutional Click Prediction Model【in progress】 | [CIKM 2015][A Convolutional Click Prediction Model](http://ir.ia.ac.cn/bitstream/173211/12337/1/A%20Convolutional%20Click%20Prediction%20Model.pdf) |
+| Convolutional Click Prediction Model | [CIKM 2015][A Convolutional Click Prediction Model](http://ir.ia.ac.cn/bitstream/173211/12337/1/A%20Convolutional%20Click%20Prediction%20Model.pdf) |
 | Factorization-supported Neural Network | [ECIR 2016][Deep Learning over Multi-field Categorical Data: A Case Study on User Response Prediction](https://arxiv.org/pdf/1601.02376.pdf) |
 | Product-based Neural Network | [ICDM 2016][Product-based neural networks for user response prediction](https://arxiv.org/pdf/1611.00144.pdf) |
 | Wide & Deep | [DLRS 2016][Wide & Deep Learning for Recommender Systems](https://arxiv.org/pdf/1606.07792.pdf) |

deepctr_torch/__init__.py (+6)

@@ -0,0 +1,6 @@
+from . import layers
+from . import models
+from deepctr.utils import check_version
+
+__version__ = '0.0.1'
+check_version(__version__)

deepctr_torch/layers/__init__.py (+1, -1)

@@ -1,2 +1,2 @@
-from .interaction import FM
+from .interaction import FM,AFMLayer
 from .core import DNN,PredictionLayer

deepctr_torch/layers/core.py (+12, -10)

@@ -14,19 +14,21 @@ def __init__(self, inputs_dim, hidden_units, activation=F.relu, l2_reg=0, dropou
         self.l2_reg = l2_reg
         self.use_bn = use_bn
         hidden_units = [inputs_dim] + list(hidden_units)
-        self.linears = nn.ModuleList(
-            [nn.Linear(hidden_units[i], hidden_units[i + 1]) for i in range(len(hidden_units) - 1)])
-        for tensor in self.linears:
-            nn.init.normal_(tensor.weight, mean=0, std=init_std)
+        self.weight = nn.ParameterList([nn.Parameter(torch.Tensor(hidden_units[i+1],hidden_units[i])) for i in range(len(hidden_units)-1)])
+        self.bias = nn.ParameterList([nn.Parameter(torch.zeros((hidden_units[i+1],))) for i in range(len(hidden_units)-1)])
+        if self.use_bn:
+            self.bn = nn.ModuleList([nn.BatchNorm1d(hidden_units[i+1]) for i in range(len(hidden_units)-1)])
+        for tensor in self.weight:
+            nn.init.normal_(tensor, mean=0, std=init_std)
 
     def forward(self, inputs):
         deep_input = inputs
 
-        for i in range(len(self.linears)):
-            fc = self.linears[i](deep_input)
+        for i in range(len(self.weight)):
+            fc = F.linear(deep_input,self.weight[i],self.bias[i])
 
-            # if self.use_bn:
-            #     fc = self.bn_layers[i](fc, training=training)
+            if self.use_bn:
+                fc = self.bn[i](fc)
 
             fc = self.activation(fc)
 
@@ -50,12 +52,12 @@ def __init__(self, task='binary', use_bias=True, **kwargs):
         self.use_bias = use_bias
         self.task = task
         if self.use_bias:
-            self.global_bias = nn.Parameter(torch.zeros((1,)))
+            self.bias = nn.Parameter(torch.zeros((1,)))
 
     def forward(self, X):
         output = X
         if self.use_bias:
-            output += self.global_bias
+            output += self.bias
         if self.task == "binary":
             output = torch.sigmoid(output)
         return output
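For reference, the reworked DNN now holds explicit weight/bias ParameterLists, drives them through `F.linear`, and can apply a `BatchNorm1d` before each activation. The following is a minimal, standalone sketch of that pattern only; `TinyDNN`, its layer sizes, and the fixed ReLU activation are illustrative stand-ins, not code from this commit.

```python
# Minimal sketch of the DNN forward pattern in the diff above: explicit
# per-layer weight/bias parameters, F.linear, optional BatchNorm, then ReLU.
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyDNN(nn.Module):
    def __init__(self, inputs_dim, hidden_units, use_bn=False, init_std=0.0001):
        super(TinyDNN, self).__init__()
        self.use_bn = use_bn
        dims = [inputs_dim] + list(hidden_units)
        # One (out_features, in_features) weight and one bias per layer,
        # mirroring the ParameterList layout introduced in the diff.
        self.weight = nn.ParameterList(
            [nn.Parameter(torch.empty(dims[i + 1], dims[i])) for i in range(len(dims) - 1)])
        self.bias = nn.ParameterList(
            [nn.Parameter(torch.zeros(dims[i + 1])) for i in range(len(dims) - 1)])
        if self.use_bn:
            self.bn = nn.ModuleList([nn.BatchNorm1d(dims[i + 1]) for i in range(len(dims) - 1)])
        for w in self.weight:
            nn.init.normal_(w, mean=0, std=init_std)

    def forward(self, x):
        for i in range(len(self.weight)):
            x = F.linear(x, self.weight[i], self.bias[i])
            if self.use_bn:
                x = self.bn[i](x)   # BatchNorm before the activation, as in the new code path
            x = F.relu(x)
        return x


out = TinyDNN(16, (32, 8), use_bn=True)(torch.randn(4, 16))
print(out.shape)  # torch.Size([4, 8])
```

Keeping the weights in a ParameterList (rather than nn.Linear modules) is also what lets the models below pass `self.dnn.weight` directly to `add_regularization_loss`.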

deepctr_torch/layers/interaction.py (+85)

@@ -1,8 +1,21 @@
+import itertools
+
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 class FM(nn.Module):
+    """Factorization Machine models pairwise (order-2) feature interactions
+    without linear term and bias.
+      Input shape
+        - 3D tensor with shape: ``(batch_size,field_size,embedding_size)``.
+      Output shape
+        - 2D tensor with shape: ``(batch_size, 1)``.
+      References
+        - [Factorization Machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf)
+    """
+
     def __init__(self):
         super(FM, self).__init__()
 
@@ -15,3 +28,75 @@ def forward(self, inputs):
         cross_term = 0.5 * torch.sum(cross_term, dim=2, keepdim=False)
 
         return cross_term
+
+
+class AFMLayer(nn.Module):
+    """Attentional Factorization Machine models pairwise (order-2) feature
+    interactions without linear term and bias.
+      Input shape
+        - A list of 3D tensor with shape: ``(batch_size,1,embedding_size)``.
+      Output shape
+        - 2D tensor with shape: ``(batch_size, 1)``.
+      Arguments
+        - **attention_factor** : Positive integer, dimensionality of the
+          attention network output space.
+        - **l2_reg_w** : float between 0 and 1. L2 regularizer strength
+          applied to attention network.
+        - **dropout_rate** : float in [0,1). Fraction of the attention net output units to dropout.
+        - **seed** : A Python integer to use as random seed.
+      References
+        - [Attentional Factorization Machines : Learning the Weight of Feature
+          Interactions via Attention Networks](https://arxiv.org/pdf/1708.04617.pdf)
+    """
+
+    def __init__(self, in_feature, attention_factor=4, l2_reg_w=0, dropout_rate=0, seed=1024, device='cpu'):
+        super(AFMLayer, self).__init__()
+        self.attention_factor = attention_factor
+        self.l2_reg_w = l2_reg_w
+        self.dropout_rate = dropout_rate
+        self.seed = seed
+        embedding_size = in_feature
+
+        self.attention_W = nn.Parameter(torch.Tensor(embedding_size, self.attention_factor))
+
+        self.attention_b = nn.Parameter(torch.Tensor(self.attention_factor))
+
+        self.projection_h = nn.Parameter(torch.Tensor(self.attention_factor, 1))
+
+        self.projection_p = nn.Parameter(torch.Tensor(embedding_size, 1))
+
+        self.weight = self.attention_W
+
+        for tensor in [self.attention_W, self.projection_h, self.projection_p]:
+            nn.init.xavier_normal_(tensor, )
+
+        self.dropout = nn.Dropout(dropout_rate)
+
+        self.to(device)
+
+    def forward(self, inputs):
+        embeds_vec_list = inputs
+        row = []
+        col = []
+
+        for r, c in itertools.combinations(embeds_vec_list, 2):
+            row.append(r)
+            col.append(c)
+
+        p = torch.cat(row, dim=1)
+        q = torch.cat(col, dim=1)
+        inner_product = p * q
+
+        bi_interaction = inner_product
+        attention_temp = F.relu(torch.tensordot(
+            bi_interaction, self.attention_W, dims=([-1], [0])) + self.attention_b)
+
+        self.normalized_att_score = F.softmax(torch.tensordot(
+            attention_temp, self.projection_h, dims=([-1], [0])), dim=1)
+        attention_output = torch.sum(
+            self.normalized_att_score * bi_interaction, dim=1)
+
+        attention_output = self.dropout(attention_output)  # training
+
+        afm_out = torch.tensordot(attention_output, self.projection_p, dims=([-1], [0]))
+        return afm_out
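In short, AFMLayer takes the list of per-field embeddings, forms every pairwise element-wise product, scores each pair with a one-layer attention net (attention_W, attention_b, projection_h), and projects the attention-weighted sum to a single logit via projection_p. Below is a quick shape walk-through; it assumes the export added to deepctr_torch/layers/__init__.py above, and the batch, field, and embedding sizes are made up for illustration.

```python
# Shape walk-through for AFMLayer (sizes are illustrative only).
import torch
from deepctr_torch.layers import AFMLayer

batch_size, num_fields, embedding_size = 32, 5, 8
# A list of num_fields tensors, each (batch_size, 1, embedding_size):
# one embedding vector per sparse field.
embeds = [torch.randn(batch_size, 1, embedding_size) for _ in range(num_fields)]

layer = AFMLayer(embedding_size, attention_factor=4, dropout_rate=0.2)
out = layer(embeds)

# num_fields * (num_fields - 1) / 2 = 10 field pairs are scored internally;
# the attention-weighted sum of their products is projected to one logit.
print(out.shape)                          # torch.Size([32, 1])
print(layer.normalized_att_score.shape)   # torch.Size([32, 10, 1])
```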

deepctr_torch/models/__init__.py (+2, -1)

@@ -1,2 +1,3 @@
 from .wdl import WDL
-from .deepfm import DeepFM
+from .deepfm import DeepFM
+from .afm import AFM

deepctr_torch/models/afm.py (+62)

@@ -0,0 +1,62 @@
+import torch
+import torch.nn.functional as F
+
+from .basemodel import BaseModel
+from ..layers import FM, AFMLayer
+
+
+class AFM(BaseModel):
+
+    def __init__(self,
+                 linear_feature_columns, dnn_feature_columns, embedding_size=8, use_attention=True, attention_factor=8,
+                 l2_reg_linear=1e-5, l2_reg_embedding=1e-5, l2_reg_att=1e-5, afm_dropout=0, init_std=0.0001, seed=1024,
+                 task='binary', device='cpu'):
+        """Instantiates the Attentional Factorization Machine architecture.
+        :param linear_feature_columns: An iterable containing all the features used by linear part of the model.
+        :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
+        :param embedding_size: positive integer,sparse feature embedding_size
+        :param use_attention: bool, whether to use attention or not; if set to ``False``, it is the same as **standard Factorization Machine**
+        :param attention_factor: positive integer,units in attention net
+        :param l2_reg_linear: float. L2 regularizer strength applied to linear part
+        :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
+        :param l2_reg_att: float. L2 regularizer strength applied to attention net
+        :param afm_dropout: float in [0,1), Fraction of the attention net output units to dropout.
+        :param init_std: float,to use as the initialize std of embedding vector
+        :param seed: integer ,to use as random seed.
+        :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
+        :param device:
+        :return: A PyTorch model instance.
+        """
+
+        super(AFM, self).__init__(linear_feature_columns, dnn_feature_columns, embedding_size=embedding_size,
+                                  dnn_hidden_units=[],
+                                  l2_reg_linear=l2_reg_linear,
+                                  l2_reg_embedding=l2_reg_embedding, l2_reg_dnn=0, init_std=init_std,
+                                  seed=seed,
+                                  dnn_dropout=0, dnn_activation=F.relu,
+                                  task=task, device=device)
+
+        self.use_attention = use_attention
+
+        if use_attention:
+            self.fm = AFMLayer(embedding_size, attention_factor, l2_reg_att, afm_dropout,
+                               seed, device)
+            self.add_regularization_loss(self._modules['fm'].weight, l2_reg_att)
+        else:
+            self.fm = FM()
+
+        self.to(device)
+
+    def forward(self, X):
+
+        sparse_embedding_list, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns,
+                                                                                  self.embedding_dict)
+        logit = self.linear_model(X)
+        if self.use_attention:
+            logit += self.fm(sparse_embedding_list)
+        else:
+            logit += self.fm(torch.cat(sparse_embedding_list, dim=1))
+
+        y_pred = self.out(logit)
+
+        return y_pred
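A hypothetical end-to-end usage sketch of the new model follows. It leans on the SparseFeat feature-column helper and the Keras-style compile()/fit() interface that BaseModel provides elsewhere in the package; neither appears in this diff, so the import paths, the SparseFeat signature, and the fit() call should be read as assumptions rather than the committed API.

```python
# Hedged usage sketch only: SparseFeat, its import path, and the
# compile()/fit() signatures are assumptions, not shown in this commit.
import numpy as np
from deepctr_torch.models import AFM
from deepctr_torch.inputs import SparseFeat  # assumed location of feature columns

n_samples = 256
feature_columns = [SparseFeat('user_id', 100), SparseFeat('item_id', 200)]

X = {'user_id': np.random.randint(0, 100, n_samples),
     'item_id': np.random.randint(0, 200, n_samples)}
y = np.random.randint(0, 2, n_samples)

# use_attention=False falls back to the plain FM second-order term.
model = AFM(feature_columns, feature_columns, embedding_size=8,
            use_attention=True, attention_factor=8, afm_dropout=0.5,
            task='binary', device='cpu')
model.compile('adam', 'binary_crossentropy', metrics=['binary_crossentropy'])
model.fit(X, y, batch_size=64, epochs=1, validation_split=0.1)
```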

deepctr_torch/models/basemodel.py (+4, -1)

@@ -243,12 +243,15 @@ def predict(self, x, batch_size=256):
                 pred_ans.append(y_pred)
         return np.concatenate(pred_ans)
 
-    def input_from_feature_columns(self, X, feature_columns, embedding_dict):
+    def input_from_feature_columns(self, X, feature_columns, embedding_dict,support_dense=True):
         sparse_feature_columns = list(
             filter(lambda x: isinstance(x, SparseFeat), feature_columns)) if len(feature_columns) else []
         dense_feature_columns = list(
             filter(lambda x: isinstance(x, DenseFeat), feature_columns)) if len(feature_columns) else []
 
+        if not support_dense and len(dense_feature_columns) > 0:
+            raise ValueError("DenseFeat is not supported in dnn_feature_columns")
+
         sparse_embedding_list = [embedding_dict[feat.embedding_name](
             X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]].long()) for
             feat in sparse_feature_columns]

deepctr_torch/models/deepfm.py (+31, -9)

@@ -15,6 +15,24 @@ def __init__(self,
                 l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024,
                 dnn_dropout=0,
                 dnn_activation=F.relu, dnn_use_bn=False, task='binary', device='cpu'):
+        """Instantiates the DeepFM Network architecture.
+        :param linear_feature_columns: An iterable containing all the features used by linear part of the model.
+        :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
+        :param embedding_size: positive integer,sparse feature embedding_size
+        :param use_fm: bool,use FM part or not
+        :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN
+        :param l2_reg_linear: float. L2 regularizer strength applied to linear part
+        :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
+        :param l2_reg_dnn: float. L2 regularizer strength applied to DNN
+        :param init_std: float,to use as the initialize std of embedding vector
+        :param seed: integer ,to use as random seed.
+        :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
+        :param dnn_activation: Activation function to use in DNN
+        :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in DNN
+        :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
+        :param device:
+        :return: A PyTorch model instance.
+        """
 
         super(DeepFM, self).__init__(linear_feature_columns, dnn_feature_columns, embedding_size=embedding_size,
                                      dnn_hidden_units=dnn_hidden_units,
@@ -25,10 +43,15 @@ def __init__(self,
                                      task=task, device=device)
 
         self.dnn = DNN(self.compute_input_dim(dnn_feature_columns, embedding_size, ), dnn_hidden_units,
-                       activation=dnn_activation, l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, init_std=init_std)
+                       activation=dnn_activation, l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout,use_bn=dnn_use_bn, init_std=init_std)
         self.dnn_linear = nn.Linear(dnn_hidden_units[-1], 1, bias=False)
-        # self.add_regularization_loss(chain(self.dnn.parameters(), self.dnn_linear.parameters()), l2_reg_dnn)
-        self.fm = FM()
+
+        self.add_regularization_loss(self.dnn.weight, l2_reg_dnn)
+        self.add_regularization_loss(self.dnn_linear.weight,l2_reg_dnn)
+
+        if use_fm:
+            self.fm = FM()
+        self.use_fm = use_fm
         self.to(device)
 
     def forward(self, X):
@@ -37,16 +60,15 @@ def forward(self, X):
                                                                                   self.embedding_dict)
         linear_logit = self.linear_model(X)
 
-        if len(sparse_embedding_list) > 0:
-            fm_input = torch.cat(sparse_embedding_list, dim=1)
-            fm_out = self.fm(fm_input)
-        else:
-            fm_out = 0
         dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list)
 
         dnn_output = self.dnn(dnn_input)
         dnn_logit = self.dnn_linear(dnn_output)
-        logit = linear_logit + dnn_logit + fm_out
+        logit = linear_logit + dnn_logit
+
+        if self.use_fm:
+            fm_input = torch.cat(sparse_embedding_list, dim=1)
+            logit += self.fm(fm_input)
         y_pred = self.out(logit)
 
         return y_pred
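Two behavioural switches fall out of this change: use_fm now decides whether the FM term is built and added to the logit at all, and dnn_use_bn is finally forwarded to the DNN's use_bn argument. A small hedged sketch, reusing the assumed SparseFeat setup from the AFM example above:

```python
# Hedged sketch only: SparseFeat and its import path are assumptions,
# as in the AFM example above.
from deepctr_torch.models import DeepFM
from deepctr_torch.inputs import SparseFeat  # assumed location

feature_columns = [SparseFeat('user_id', 100), SparseFeat('item_id', 200)]

model = DeepFM(feature_columns, feature_columns, embedding_size=8,
               use_fm=False,            # skip the FM term: logit = linear + DNN only
               dnn_hidden_units=(128, 128),
               dnn_use_bn=True,         # now actually reaches DNN(use_bn=...)
               task='binary', device='cpu')
```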

deepctr_torch/models/wdl.py (+21, -3)

@@ -14,7 +14,23 @@ def __init__(self,
                 linear_feature_columns, dnn_feature_columns, embedding_size=8, dnn_hidden_units=(128, 128),
                 l2_reg_linear=1e-5,
                 l2_reg_embedding=1e-5, l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation=F.relu,
-                 task='binary', device='cpu'):
+                 dnn_use_bn=False,task='binary', device='cpu'):
+        """Instantiates the Wide&Deep Learning architecture.
+        :param linear_feature_columns: An iterable containing all the features used by linear part of the model.
+        :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
+        :param embedding_size: positive integer,sparse feature embedding_size
+        :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN
+        :param l2_reg_linear: float. L2 regularizer strength applied to wide part
+        :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
+        :param l2_reg_dnn: float. L2 regularizer strength applied to DNN
+        :param init_std: float,to use as the initialize std of embedding vector
+        :param seed: integer ,to use as random seed.
+        :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
+        :param dnn_activation: Activation function to use in DNN
+        :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
+        :param device:
+        :return: A PyTorch model instance.
+        """
         super(WDL, self).__init__(linear_feature_columns, dnn_feature_columns, embedding_size=embedding_size,
                                   dnn_hidden_units=dnn_hidden_units,
                                   l2_reg_linear=l2_reg_linear,
@@ -24,9 +40,11 @@ def __init__(self,
                                   task=task, device=device)
 
         self.dnn = DNN(self.compute_input_dim(dnn_feature_columns, embedding_size, ), dnn_hidden_units,
-                       activation=dnn_activation, l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, init_std=init_std)
+                       activation=dnn_activation, l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout,use_bn= dnn_use_bn,init_std=init_std)
         self.dnn_linear = nn.Linear(dnn_hidden_units[-1], 1, bias=False)
-        self.add_regularization_loss(chain(self.dnn.parameters(), self.dnn_linear.parameters()), l2_reg_dnn)
+        self.add_regularization_loss(self.dnn.weight, l2_reg_dnn)
+        self.add_regularization_loss(self.dnn_linear.weight, l2_reg_dnn)
+
         self.to(device)
 
     def forward(self, X):
