forked from bigdata-ustc/EduCDM
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request bigdata-ustc#27 from bigdata-ustc/IRR
[FEATURE] Item Response Ranking with DINA, MIRT and NCDM
- Loading branch information
Showing
23 changed files
with
1,227 additions
and
282 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# coding: utf-8 | ||
# 2021/7/1 @ tongshiwei | ||
|
||
import pandas as pd | ||
import numpy as np | ||
import torch | ||
from torch import nn | ||
from EduCDM import GDDINA | ||
from .loss import PairSCELoss, HarmonicLoss, loss_mask | ||
from tqdm import tqdm | ||
from longling.ML.metrics import ranking_report | ||
|
||
|
||
class DINA(GDDINA):
    """DINA model trained with a combined point-wise and pair-wise loss.

    Extends ``GDDINA`` with a training loop that adds a pair-wise term
    (``PairSCELoss``) computed against negative-sampled users to the usual
    point-wise ``BCELoss``; the two terms are merged by ``HarmonicLoss``.

    Parameters
    ----------
    user_num, item_num, knowledge_num, ste
        Forwarded unchanged to ``GDDINA.__init__``.
    zeta
        Passed to ``HarmonicLoss`` when the training loss is built
        (presumably the point/pair mixing weight -- confirm in ``.loss``).
    """

    def __init__(self, user_num, item_num, knowledge_num, ste=False, zeta=0.5):
        super(DINA, self).__init__(user_num, item_num, knowledge_num, ste)
        self.zeta = zeta

    def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
        """Fit ``self.dina_net`` on ``train_data`` for ``epoch`` epochs.

        Parameters
        ----------
        train_data
            Iterable of batches, each unpacking to
            ``(user_id, item_id, knowledge, score, n_samples, *neg_users)``.
            ``neg_users`` may be empty, in which case the pair-wise term is 0.
        test_data
            Optional iterable accepted by :meth:`eval`; when given, a ranking
            report is printed after every epoch.
        epoch
            Number of passes over ``train_data`` (keyword-only).
        device
            Device the network and every batch tensor are moved to.
        lr
            Learning rate of the Adam optimizer.
        """
        point_loss_function = nn.BCELoss()
        pair_loss_function = PairSCELoss()
        loss_function = HarmonicLoss(self.zeta)

        # The batches are moved to `device` below, so the network has to live
        # there as well; move it before the optimizer captures its parameters.
        self.dina_net.to(device)
        trainer = torch.optim.Adam(self.dina_net.parameters(), lr, weight_decay=1e-4)

        for e in range(epoch):
            point_losses = []
            pair_losses = []
            losses = []
            for batch_data in tqdm(train_data, "Epoch %s" % e):
                user_id, item_id, knowledge, score, n_samples, *neg_users = batch_data
                user_id: torch.Tensor = user_id.to(device)
                item_id: torch.Tensor = item_id.to(device)
                knowledge: torch.Tensor = knowledge.to(device)
                score: torch.Tensor = score.to(device)

                predicted_pos_score: torch.Tensor = self.dina_net(user_id, item_id, knowledge)
                point_loss = point_loss_function(predicted_pos_score, score)

                if neg_users:
                    # Third argument of the pair loss: score - (1 - score).
                    # It does not depend on the negative sample, so build it once.
                    pair_target = score - (1 - score)
                    pair_pred_loss_list = [
                        pair_loss_function(
                            predicted_pos_score,
                            # negative users must sit on the same device as the net
                            self.dina_net(neg_user.to(device), item_id, knowledge),
                            pair_target,
                        )
                        for neg_user in neg_users
                    ]
                    pair_loss = sum(loss_mask(pair_pred_loss_list, n_samples))
                else:
                    pair_loss = 0

                loss = loss_function(point_loss, pair_loss)

                # back propagation
                trainer.zero_grad()
                loss.backward()
                trainer.step()

                point_losses.append(point_loss.mean().item())
                # pair_loss is the plain int 0 when the batch has no negatives
                pair_losses.append(pair_loss if isinstance(pair_loss, int) else pair_loss.mean().item())
                losses.append(loss.item())
            print(
                "[Epoch %d] Loss: %.6f, PointLoss: %.6f, PairLoss: %.6f" % (
                    e, float(np.mean(losses)), float(np.mean(point_losses)), float(np.mean(pair_losses))
                )
            )

            if test_data is not None:
                # evaluate on the device the network was trained on
                eval_data = self.eval(test_data, device=device)
                print("[Epoch %d]\n%s" % (e, eval_data))

    def eval(self, test_data, device="cpu"):
        """Score ``test_data`` and return a per-item ranking report.

        Parameters
        ----------
        test_data
            Iterable of batches unpacking to
            ``(user_id, item_id, knowledge, response)``.
        device
            Device the batch tensors are moved to before the forward pass;
            the network itself is not moved here.

        Returns
        -------
        The value of ``ranking_report`` called on the true responses and the
        predictions grouped by ``item_id`` (with ``coerce="padding"``).
        """
        self.dina_net.eval()
        y_pred = []
        y_true = []
        items = []
        try:
            # inference only: no autograd graph is needed
            with torch.no_grad():
                for batch_data in tqdm(test_data, "evaluating"):
                    user_id, item_id, knowledge, response = batch_data
                    user_id: torch.Tensor = user_id.to(device)
                    item_id: torch.Tensor = item_id.to(device)
                    knowledge: torch.Tensor = knowledge.to(device)
                    pred: torch.Tensor = self.dina_net(user_id, item_id, knowledge)
                    y_pred.extend(pred.tolist())
                    y_true.extend(response.tolist())
                    items.extend(item_id.tolist())
        finally:
            # always hand the network back in training mode, even on failure
            self.dina_net.train()

        df = pd.DataFrame({
            "item_id": items,
            "score": y_true,
            "pred": y_pred,
        })

        ground_truth = []
        prediction = []

        # one entry per item: the responses / predictions of that item's rows
        for _, group_df in tqdm(df.groupby("item_id"), "formatting item df"):
            ground_truth.append(group_df["score"].values)
            prediction.append(group_df["pred"].values)

        return ranking_report(
            ground_truth,
            y_pred=prediction,
            coerce="padding"
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# coding: utf-8 | ||
# 2021/7/1 @ tongshiwei | ||
|
||
|
||
import torch | ||
from torch import nn | ||
from tqdm import tqdm | ||
from EduCDM import MIRT as PointMIRT | ||
import numpy as np | ||
import pandas as pd | ||
from .loss import PairSCELoss, HarmonicLoss, loss_mask | ||
from longling.ML.metrics import ranking_report | ||
|
||
__all__ = ["MIRT"] | ||
|
||
|
||
class MIRT(PointMIRT):
    """MIRT model trained with a combined point-wise and pair-wise loss.

    Extends the point-wise ``MIRT`` (imported as ``PointMIRT``) with a
    training loop that adds a pair-wise term (``PairSCELoss``) computed
    against negative-sampled users to the usual ``BCELoss``; the two terms
    are merged by ``HarmonicLoss``.

    Parameters
    ----------
    user_num, item_num
        Forwarded unchanged to ``PointMIRT.__init__``.
    knowledge_num
        Stored on the instance; also used as ``latent_dim`` when that is None.
    latent_dim
        Third argument forwarded to ``PointMIRT.__init__``; defaults to
        ``knowledge_num``.
    zeta
        Passed to ``HarmonicLoss`` when the training loss is built
        (presumably the point/pair mixing weight -- confirm in ``.loss``).
    """

    def __init__(self, user_num, item_num, knowledge_num, latent_dim=None, zeta=0.5):
        latent_dim = knowledge_num if latent_dim is None else latent_dim
        super(MIRT, self).__init__(user_num, item_num, latent_dim)
        self.knowledge_num = knowledge_num
        self.zeta = zeta

    def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
        """Fit ``self.irt_net`` on ``train_data`` for ``epoch`` epochs.

        Parameters
        ----------
        train_data
            Iterable of batches, each unpacking to
            ``(user_id, item_id, _, score, n_samples, *neg_users)``; the third
            field is ignored. ``neg_users`` may be empty, in which case the
            pair-wise term is 0.
        test_data
            Optional iterable accepted by :meth:`eval`; when given, a ranking
            report is printed after every epoch.
        epoch
            Number of passes over ``train_data`` (keyword-only).
        device
            Device the network and every batch tensor are moved to.
        lr
            Learning rate of the Adam optimizer.
        """
        point_loss_function = nn.BCELoss()
        pair_loss_function = PairSCELoss()
        loss_function = HarmonicLoss(self.zeta)

        # The batches are moved to `device` below, so the network has to live
        # there as well; move it before the optimizer captures its parameters.
        self.irt_net.to(device)
        trainer = torch.optim.Adam(self.irt_net.parameters(), lr, weight_decay=1e-4)

        for e in range(epoch):
            point_losses = []
            pair_losses = []
            losses = []
            for batch_data in tqdm(train_data, "Epoch %s" % e):
                user_id, item_id, _, score, n_samples, *neg_users = batch_data
                user_id: torch.Tensor = user_id.to(device)
                item_id: torch.Tensor = item_id.to(device)
                score: torch.Tensor = score.to(device)

                predicted_pos_score: torch.Tensor = self.irt_net(user_id, item_id)
                point_loss = point_loss_function(predicted_pos_score, score)

                if neg_users:
                    # Third argument of the pair loss: score - (1 - score).
                    # It does not depend on the negative sample, so build it once.
                    pair_target = score - (1 - score)
                    pair_pred_loss_list = [
                        pair_loss_function(
                            predicted_pos_score,
                            # negative users must sit on the same device as the net
                            self.irt_net(neg_user.to(device), item_id),
                            pair_target,
                        )
                        for neg_user in neg_users
                    ]
                    pair_loss = sum(loss_mask(pair_pred_loss_list, n_samples))
                else:
                    pair_loss = 0

                loss = loss_function(point_loss, pair_loss)

                # back propagation
                trainer.zero_grad()
                loss.backward()
                trainer.step()

                point_losses.append(point_loss.mean().item())
                # pair_loss is the plain int 0 when the batch has no negatives
                pair_losses.append(pair_loss if isinstance(pair_loss, int) else pair_loss.mean().item())
                losses.append(loss.item())
            print(
                "[Epoch %d] Loss: %.6f, PointLoss: %.6f, PairLoss: %.6f" % (
                    e, float(np.mean(losses)), float(np.mean(point_losses)), float(np.mean(pair_losses))
                )
            )

            if test_data is not None:
                # evaluate on the device the network was trained on
                eval_data = self.eval(test_data, device=device)
                print("[Epoch %d]\n%s" % (e, eval_data))

    def eval(self, test_data, device="cpu"):
        """Score ``test_data`` and return a per-item ranking report.

        Parameters
        ----------
        test_data
            Iterable of batches unpacking to
            ``(user_id, item_id, _, response)``; the third field is ignored.
        device
            Device the batch tensors are moved to before the forward pass;
            the network itself is not moved here.

        Returns
        -------
        The value of ``ranking_report`` called on the true responses and the
        predictions grouped by ``item_id`` (with ``coerce="padding"``).
        """
        self.irt_net.eval()
        y_pred = []
        y_true = []
        items = []
        try:
            # inference only: no autograd graph is needed
            with torch.no_grad():
                for batch_data in tqdm(test_data, "evaluating"):
                    user_id, item_id, _, response = batch_data
                    user_id: torch.Tensor = user_id.to(device)
                    item_id: torch.Tensor = item_id.to(device)
                    pred: torch.Tensor = self.irt_net(user_id, item_id)
                    y_pred.extend(pred.tolist())
                    y_true.extend(response.tolist())
                    items.extend(item_id.tolist())
        finally:
            # always hand the network back in training mode, even on failure
            self.irt_net.train()

        df = pd.DataFrame({
            "item_id": items,
            "score": y_true,
            "pred": y_pred,
        })

        ground_truth = []
        prediction = []

        # one entry per item: the responses / predictions of that item's rows
        for _, group_df in tqdm(df.groupby("item_id"), "formatting item df"):
            ground_truth.append(group_df["score"].values)
            prediction.append(group_df["pred"].values)

        return ranking_report(
            ground_truth,
            y_pred=prediction,
            coerce="padding"
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# coding: utf-8 | ||
# 2021/7/1 @ tongshiwei | ||
|
||
import pandas as pd | ||
import numpy as np | ||
import torch | ||
from torch import nn | ||
from EduCDM import NCDM as PointNCDM | ||
from .loss import PairSCELoss, HarmonicLoss, loss_mask | ||
from tqdm import tqdm | ||
from longling.ML.metrics import ranking_report | ||
|
||
|
||
class NCDM(PointNCDM):
    """NCDM model trained with a combined point-wise and pair-wise loss.

    Extends the point-wise ``NCDM`` (imported as ``PointNCDM``) with a
    training loop that adds a pair-wise term (``PairSCELoss``) computed
    against negative-sampled users to the usual ``BCELoss``; the two terms
    are merged by ``HarmonicLoss``.

    Parameters
    ----------
    user_num, item_num, knowledge_num
        Forwarded to ``PointNCDM.__init__`` in the order
        ``(knowledge_num, item_num, user_num)``.
    zeta
        Passed to ``HarmonicLoss`` when the training loss is built
        (presumably the point/pair mixing weight -- confirm in ``.loss``).
    """

    def __init__(self, user_num, item_num, knowledge_num, zeta=0.5):
        # note the reversed argument order expected by the parent constructor
        super(NCDM, self).__init__(knowledge_num, item_num, user_num)
        self.zeta = zeta

    def train(self, train_data, test_data=None, epoch=10, device="cpu", lr=0.002, silence=False) -> ...:
        """Fit ``self.ncdm_net`` on ``train_data`` for ``epoch`` epochs.

        Parameters
        ----------
        train_data
            Iterable of batches, each unpacking to
            ``(user_id, item_id, knowledge, score, n_samples, *neg_users)``.
            ``neg_users`` may be empty, in which case the pair-wise term is 0.
        test_data
            Optional iterable accepted by :meth:`eval`; when given, a ranking
            report is printed after every epoch.
        epoch
            Number of passes over ``train_data``.
        device
            Device the network and every batch tensor are moved to.
        lr
            Learning rate of the Adam optimizer.
        silence
            Accepted for signature compatibility; currently not used by this
            method.
        """
        point_loss_function = nn.BCELoss()
        pair_loss_function = PairSCELoss()
        loss_function = HarmonicLoss(self.zeta)

        # The batches are moved to `device` below, so the network has to live
        # there as well; move it before the optimizer captures its parameters.
        self.ncdm_net.to(device)
        trainer = torch.optim.Adam(self.ncdm_net.parameters(), lr, weight_decay=1e-4)

        for e in range(epoch):
            point_losses = []
            pair_losses = []
            losses = []
            for batch_data in tqdm(train_data, "Epoch %s" % e):
                user_id, item_id, knowledge, score, n_samples, *neg_users = batch_data
                user_id: torch.Tensor = user_id.to(device)
                item_id: torch.Tensor = item_id.to(device)
                knowledge: torch.Tensor = knowledge.to(device)
                score: torch.Tensor = score.to(device)

                predicted_pos_score: torch.Tensor = self.ncdm_net(user_id, item_id, knowledge)
                point_loss = point_loss_function(predicted_pos_score, score)

                if neg_users:
                    # Third argument of the pair loss: score - (1 - score).
                    # It does not depend on the negative sample, so build it once.
                    pair_target = score - (1 - score)
                    pair_pred_loss_list = [
                        pair_loss_function(
                            predicted_pos_score,
                            # negative users must sit on the same device as the net
                            self.ncdm_net(neg_user.to(device), item_id, knowledge),
                            pair_target,
                        )
                        for neg_user in neg_users
                    ]
                    pair_loss = sum(loss_mask(pair_pred_loss_list, n_samples))
                else:
                    pair_loss = 0

                loss = loss_function(point_loss, pair_loss)

                # back propagation
                trainer.zero_grad()
                loss.backward()
                trainer.step()

                point_losses.append(point_loss.mean().item())
                # pair_loss is the plain int 0 when the batch has no negatives
                pair_losses.append(pair_loss if isinstance(pair_loss, int) else pair_loss.mean().item())
                losses.append(loss.item())
            print(
                "[Epoch %d] Loss: %.6f, PointLoss: %.6f, PairLoss: %.6f" % (
                    e, float(np.mean(losses)), float(np.mean(point_losses)), float(np.mean(pair_losses))
                )
            )

            if test_data is not None:
                # evaluate on the device the network was trained on
                eval_data = self.eval(test_data, device=device)
                print("[Epoch %d]\n%s" % (e, eval_data))

    def eval(self, test_data, device="cpu"):
        """Score ``test_data`` and return a per-item ranking report.

        Parameters
        ----------
        test_data
            Iterable of batches unpacking to
            ``(user_id, item_id, knowledge, response)``.
        device
            Device the batch tensors are moved to before the forward pass;
            the network itself is not moved here.

        Returns
        -------
        The value of ``ranking_report`` called on the true responses and the
        predictions grouped by ``item_id`` (with ``coerce="padding"``).
        """
        self.ncdm_net.eval()
        y_pred = []
        y_true = []
        items = []
        try:
            # inference only: no autograd graph is needed
            with torch.no_grad():
                for batch_data in tqdm(test_data, "evaluating"):
                    user_id, item_id, knowledge, response = batch_data
                    user_id: torch.Tensor = user_id.to(device)
                    item_id: torch.Tensor = item_id.to(device)
                    knowledge: torch.Tensor = knowledge.to(device)
                    pred: torch.Tensor = self.ncdm_net(user_id, item_id, knowledge)
                    y_pred.extend(pred.tolist())
                    y_true.extend(response.tolist())
                    items.extend(item_id.tolist())
        finally:
            # always hand the network back in training mode, even on failure
            self.ncdm_net.train()

        df = pd.DataFrame({
            "item_id": items,
            "score": y_true,
            "pred": y_pred,
        })

        ground_truth = []
        prediction = []

        # one entry per item: the responses / predictions of that item's rows
        for _, group_df in tqdm(df.groupby("item_id"), "formatting item df"):
            ground_truth.append(group_df["score"].values)
            prediction.append(group_df["pred"].values)

        return ranking_report(
            ground_truth,
            y_pred=prediction,
            coerce="padding"
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.