# two_tower_base_retrieval.py
"""
This is a baseline implementation of a two-tower-based candidate generator
(retrieval) in a recommender system.
Ref: https://recsysml.substack.com/p/two-tower-models-for-retrieval-of

In training we create a user embedding from user features. In this example
we ignore user history features; they are handled in a follow-on derived
class. We compute item embeddings for the items in the batch and use the
user embedding and item embeddings to compute a softmax loss. We assume the
training data comprises all items impressed by the user, so it includes
both positives and hard negatives. We weight the loss by the
net_user_value, which is a linear combination of point-wise immediate
rewards. Hence the loss is effectively derived only from the "positives",
as is assumed in two-tower models.
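
In symbols, with user embedding u_i and in-batch item embeddings v_1..v_B,
compute_training_loss roughly implements

    loss = mean_i( w_i * cross_entropy(u_i . [v_1, ..., v_B], target=i) )

where w_i is net_user_value_i, clamped to a small positive epsilon and
normalized by its maximum in the batch.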
"""
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.baseline_mips_module import BaselineMIPSModule
class TwoTowerBaseRetrieval(nn.Module):
"""Two-tower model for candidate retrieval in recommender systems."""
def __init__(
self,
num_items: int,
user_id_hash_size: int,
user_id_embedding_dim: int,
user_features_size: int,
item_id_hash_size: int,
item_id_embedding_dim: int,
item_features_size: int,
user_value_weights: List[float],
mips_module: BaselineMIPSModule,
) -> None:
"""
Initialize the TwoTowerBaseRetrieval model.
        Args:
            num_items: the number of items to return per user/query
            user_id_hash_size: the size of the embedding table for users
            user_id_embedding_dim (DU): internal dimension
            user_features_size (IU): input feature size for users
            item_id_hash_size: the size of the embedding table for items
            item_id_embedding_dim (DI): internal dimension
            item_features_size (II): input feature size for items
            user_value_weights: T-dimensional weights such that a linear
                combination of point-wise immediate rewards is the best
                predictor of long-term user satisfaction.
mips_module: a module that computes the Maximum Inner Product
Search (MIPS) over the item embeddings given the user
embedding.
"""
super().__init__()
self.num_items = num_items
# [T] dimensional vector describing how positive each label is.
# TODO add device input.
self.user_value_weights = torch.tensor(user_value_weights)
self.mips_module = mips_module
# Create the machinery for user tower
# 1. Create a module to represent user preference by a table lookup.
# Please see https://github.com/gauravchak/user_preference_modeling
# for other ways to represent user preference embedding.
self.user_id_embedding_arch = nn.Embedding(
user_id_hash_size, user_id_embedding_dim
)
# 2. Create an arch to process the user_features. We are using one
# hidden layer of 256 dimensions. This is just a reasonable default.
# You can experiment with other architectures.
self.user_features_arch = nn.Sequential(
nn.Linear(user_features_size, 256),
nn.ReLU(),
nn.Linear(256, user_id_embedding_dim),
)
# 3. Create an arch to process the user_tower_input
# Input dimension =
# user_id_embedding_dim from get_user_embedding,
# essentially based on user_id
# + user_id_embedding_dim from user_features_arch,
# essentially based on user_features
# Output dimension = item_id_embedding_dim
# The output of this arch will be used for MIPS module.
        # Hence the output dimension needs to be the same as the item tower
        # output.
self.user_tower_arch = nn.Linear(
in_features=2 * user_id_embedding_dim,
out_features=item_id_embedding_dim,
)
# Create the archs for item tower
# 1. Embedding layers for item id
self.item_id_embedding_arch = nn.Embedding(
item_id_hash_size, item_id_embedding_dim
)
# 2. Create an arch to process the item_features
self.item_features_arch = nn.Sequential(
nn.Linear(item_features_size, 256),
nn.ReLU(),
nn.Linear(256, item_id_embedding_dim),
)
# 3. Create an arch to process the item_tower_input
self.item_tower_arch = nn.Linear(
in_features=2 * item_id_embedding_dim, # concat id and features
out_features=item_id_embedding_dim,
)
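
    # Shape summary (B: batch size, DU: user_id_embedding_dim,
    # DI: item_id_embedding_dim, IU/II: user/item feature sizes):
    #   user tower: [B] id + [B, IU] features -> [B, 2 * DU] -> [B, DI]
    #   item tower: [B] id + [B, II] features -> [B, 2 * DI] -> [B, DI]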
def get_user_embedding(
self,
user_id: torch.Tensor, # [B]
user_features: torch.Tensor, # [B, IU]
) -> torch.Tensor:
"""
        Extract the user representation via memorization/generalization.
        The API is the same as the multiple ways of user representation
        implemented in https://github.com/gauravchak/user_preference_modeling
        In particular, we recommend trying the Mixture of Representations
        implementation in https://github.com/gauravchak/user_preference_modeling/blob/main/src/user_mo_representations.py#L62
In this implementation we use an embedding table lookup approach.
"""
user_id_embedding = self.user_id_embedding_arch(user_id)
return user_id_embedding
def process_user_features(
self,
user_id: torch.Tensor, # [B]
user_features: torch.Tensor, # [B, IU]
user_history: torch.Tensor, # [B, H]
) -> torch.Tensor:
"""
Process the user features to compute the input to user tower arch.
Args:
user_id (torch.Tensor): Tensor containing the user IDs. Shape: [B]
user_features (torch.Tensor): Tensor containing the user features. Shape: [B, IU]
user_history (torch.Tensor): For each batch an H length history of ids. Shape: [B, H]
In this base implementation this is unused. In subclasses this
affects the computation.
Returns:
torch.Tensor: Shape: [B, 2 * DU]
"""
user_id_embedding = self.get_user_embedding(
user_id=user_id, user_features=user_features
) # [B, DU]
# Process user features
user_features_embedding = self.user_features_arch(
user_features
) # [B, DU]
        # Concatenate the id and feature embeddings. This will later be
        # passed through user_tower_arch to compute the user embedding.
user_tower_input = torch.cat(
[user_id_embedding, user_features_embedding], dim=1
)
return user_tower_input
def compute_user_embedding(
self,
user_id: torch.Tensor, # [B]
user_features: torch.Tensor, # [B, IU]
user_history: torch.Tensor, # [B, H]
) -> torch.Tensor:
"""
Compute the user embedding. This will be used to query mips.
Args:
            user_id: the user id
            user_features: the user features. We are assuming these are all
                dense features. In practice you will probably want to support
                sparse embedding features as well.
            user_history: for each user, the history of items they have
                interacted with. This is a tensor of item ids. Here we assume
                the history has a fixed length, but in practice you will
                probably want to support variable-length histories; jagged
                tensors are a good way to do this. This is NOT USED in this
                implementation; it is handled in a follow-on derived class.
Returns:
torch.Tensor: Tensor containing query user embeddings. Shape: [B, DI]
"""
user_tower_input = self.process_user_features(
user_id=user_id, user_features=user_features, user_history=user_history
)
# Compute the user embedding
user_embedding = self.user_tower_arch(user_tower_input) # [B, DI]
return user_embedding
def compute_item_embeddings(
self,
item_id: torch.Tensor, # [B]
item_features: torch.Tensor, # [B, II]
) -> torch.Tensor:
"""
Process item_id and item_features to compute item embeddings.
Args:
item_id (torch.Tensor): Tensor containing item IDs. Shape: [B]
item_features (torch.Tensor): Tensor containing item features. Shape: [B, II]
Returns:
torch.Tensor: Tensor containing item embeddings. Shape: [B, DI]
"""
# Process item id
item_id_embedding = self.item_id_embedding_arch(item_id)
# Process item features
item_features_embedding = self.item_features_arch(item_features)
# Concatenate the inputs and pass them through item_tower_arch to
# compute the item embedding.
item_tower_input = torch.cat(
[item_id_embedding, item_features_embedding], dim=1
)
# Compute the item embedding
item_embedding = self.item_tower_arch(item_tower_input) # [B, DI]
return item_embedding
def forward(
self,
user_id: torch.Tensor, # [B]
user_features: torch.Tensor, # [B, IU]
user_history: torch.Tensor, # [B, H]
) -> torch.Tensor:
"""This is used for inference.
Compute the user embedding and return the top num_items items using the mips module.
Args:
user_id (torch.Tensor): Tensor representing the user ID. Shape: [B]
user_features (torch.Tensor): Tensor representing the user features. Shape: [B, IU]
user_history (torch.Tensor): Tensor representing the user history. Shape: [B, H]
Returns:
torch.Tensor: Tensor representing the top num_items items. Shape: [B, num_items]
"""
# Compute the user embedding
user_embedding = self.compute_user_embedding(
user_id, user_features, user_history
)
# Query the mips module to get the top num_items items and their
# embeddings. The embeddings aren't strictly necessary in the base
# implementation.
top_items, _, _ = self.mips_module(
query_embedding=user_embedding, num_items=self.num_items
) # indices [B, num_items], mips_scores [B, NI], embeddings [B, NI, DI] # noqa
return top_items
def debias_net_user_value(
self,
net_user_value: torch.Tensor, # [B]
position: torch.Tensor, # [B]
user_embedding: torch.Tensor, # [B, DI]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Returns the processed net_user_value and any losses to be added
to the loss function.
The idea here is to model the user value as a function of purely
user and context features. This way the user and item interaction
can be tasked to only predict what is incremental over what could
have been predicted using user and position (context).
Args:
net_user_value (torch.Tensor): The net user value tensor [B].
position (torch.Tensor): The position tensor of shape [B].
user_embedding: same as what is used in MIPS # [B, DI]
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing the
processed net_user_value tensor and any losses to be added
to the loss function.
        This is written as a separate method, rather than inline in
        train_forward, to make it easier to override in a derived class.
"""
        return net_user_value, torch.tensor(0.0, device=net_user_value.device)
def compute_training_loss(
self,
user_embedding: torch.Tensor, # [B, DI]
item_embeddings: torch.Tensor, # [B, DI]
position: torch.Tensor, # [B]
labels: torch.Tensor, # [B, T]
) -> torch.Tensor:
# Compute the scores for every pair of user and item
scores = torch.matmul(user_embedding, item_embeddings.t()) # [B, B]
# You should either try to handle the popularity bias
# of in-batch negatives using log-Q correction or
# use random negatives.
# [Mixed Negative Sampling paper](https://research.google/pubs/mixed-negative-sampling-for-learning-two-tower-neural-networks-in-recommendations/) noqa
# suggests random negatives is a better approach.
        # Here we are restricting ourselves to in-batch negatives and we are
        # not implementing either correction due to time constraints.
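        # Illustrative sketch of the log-Q correction (not applied here):
        # given per-item in-batch sampling log-probabilities `item_log_q`
        # of shape [B] (a hypothetical input, not available in this class),
        # one would subtract them from the logits before the softmax:
        #   scores = scores - item_log_q.unsqueeze(0)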
# Compute softmax loss
# F.cross_entropy accepts target as
# ground truth class indices or class probabilities;
# Here we are using class indices
target = torch.arange(scores.shape[0]).to(scores.device) # [B]
# In the cross entropy computation below, we are not reducing
# to mean since not every row in the batch is a "positive" example.
# To only learn from positive examples, we are computing loss per row
# and then using per row weights. Specifically, we are weighting the
# loss by the net_user_value after this to give more weight to the
# positive examples and 0 weight to the hard-negative examples.
# Note that net_user_value is assumed to be non-negative.
loss = F.cross_entropy(
input=scores, target=target, reduction="none"
) # [B]
# Compute the weighted average of the labels using user_value_weights
# In the simplest case, assume you have a single label per item.
# This label is either 1 or 0 depending on whether the user engaged
# with this item when recommended. Then the net_user_value is 1 when
# the user has engaged with the item and 0 otherwise.
net_user_value = torch.matmul(labels, self.user_value_weights) # [B]
# Optionally debias the net_user_value by the part explained purely
# by position. Not implemented in this version. Hence net_user_value
# is unchanged and additional_loss is 0.
net_user_value, additional_loss = self.debias_net_user_value(
net_user_value=net_user_value,
position=position,
user_embedding=user_embedding,
        )  # [B], scalar
# Floor by epsilon to only preserve positive net_user_value
net_user_value = torch.clamp(
net_user_value, min=0.000001 # small epsilon to avoid divide by 0
) # [B]
        # Normalize net_user_value by its maximum value in the batch so that
        # it lies between 0 and 1.
net_user_value = net_user_value / torch.max(net_user_value) # [B]
# Compute the product of loss and net_user_value
loss = loss * net_user_value # [B]
loss = torch.mean(loss) # ()
# This loss helps us learn the debiasing archs
loss = loss + additional_loss
return loss
def train_forward(
self,
user_id: torch.Tensor, # [B]
user_features: torch.Tensor, # [B, IU]
user_history: torch.Tensor, # [B, H]
item_id: torch.Tensor, # [B]
item_features: torch.Tensor, # [B, II]
position: torch.Tensor, # [B]
labels: torch.Tensor, # [B, T]
    ) -> torch.Tensor:
"""
This function computes the loss during training.
Args:
user_id (torch.Tensor): User IDs. Shape: [B].
user_features (torch.Tensor): User features. Shape: [B, IU].
user_history (torch.Tensor): User history. Shape: [B, H].
item_id (torch.Tensor): Item IDs. Shape: [B].
item_features (torch.Tensor): Item features. Shape: [B, II].
position (torch.Tensor): Position. Shape: [B].
labels (torch.Tensor): Labels. Shape: [B, T].
        Returns:
            torch.Tensor: The computed loss (a scalar tensor).
        Notes:
            - The loss is a softmax (cross-entropy) loss over in-batch items,
              weighted by net_user_value.
            - Optionally, net_user_value can be debiased by the part explained
              purely by position.
            - net_user_value is clamped to be positive and normalized to
              [0, 1] before being used as a per-row weight.
"""
# Compute the user embedding
user_embedding = self.compute_user_embedding(
user_id, user_features, user_history
) # [B, DI]
# Compute item embeddings
item_embeddings = self.compute_item_embeddings(
item_id, item_features
) # [B, DI]
loss = self.compute_training_loss(
user_embedding=user_embedding,
item_embeddings=item_embeddings,
position=position,
labels=labels,
)
return loss # ()
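

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative only, not part of the library). It uses
# a toy brute-force MIPS module with the call signature this file assumes,
# i.e. (query_embedding, num_items) -> (indices, scores, embeddings); all
# hyperparameters and shapes below are made up for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ToyMIPSModule(nn.Module):
        """Brute-force stand-in MIPS module over a random item corpus."""

        def __init__(self, corpus_size: int, dim: int) -> None:
            super().__init__()
            # Fixed random item corpus embeddings, [corpus_size, dim]
            self.register_buffer("corpus", torch.randn(corpus_size, dim))

        def forward(
            self, query_embedding: torch.Tensor, num_items: int
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            scores = query_embedding @ self.corpus.t()  # [B, corpus_size]
            top_scores, top_indices = torch.topk(scores, k=num_items, dim=1)
            top_embeddings = self.corpus[top_indices]  # [B, num_items, dim]
            return top_indices, top_scores, top_embeddings

    B, IU, II, H, T = 4, 16, 12, 10, 3  # batch, feature sizes, history, tasks
    model = TwoTowerBaseRetrieval(
        num_items=5,
        user_id_hash_size=1000,
        user_id_embedding_dim=32,
        user_features_size=IU,
        item_id_hash_size=1000,
        item_id_embedding_dim=48,
        item_features_size=II,
        user_value_weights=[1.0, 0.5, 0.1],
        mips_module=_ToyMIPSModule(corpus_size=1000, dim=48),
    )

    # One training step on random data.
    loss = model.train_forward(
        user_id=torch.randint(0, 1000, (B,)),
        user_features=torch.randn(B, IU),
        user_history=torch.randint(0, 1000, (B, H)),
        item_id=torch.randint(0, 1000, (B,)),
        item_features=torch.randn(B, II),
        position=torch.randint(0, 10, (B,)),
        labels=torch.randint(0, 2, (B, T)).float(),
    )
    print("training loss:", loss.item())

    # Inference: retrieve the top num_items item indices per user.
    top_items = model(
        user_id=torch.randint(0, 1000, (B,)),
        user_features=torch.randn(B, IU),
        user_history=torch.randint(0, 1000, (B, H)),
    )
    print("retrieved item indices shape:", tuple(top_items.shape))  # (B, 5)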