Skip to content

Commit

Permalink
use softplus to ensure a>0
Browse files Browse the repository at this point in the history
  • Loading branch information
ViviHong200709 committed Oct 19, 2021
1 parent def72f3 commit 4097e75
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 11 deletions.
4 changes: 2 additions & 2 deletions EduCDM/IRT/GD/IRT.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from EduCDM import CDM
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
from ..irt import irt3pl
from sklearn.metrics import roc_auc_score, accuracy_score
Expand All @@ -27,14 +28,13 @@ def forward(self, user, item):
theta = torch.squeeze(self.theta(user), dim=-1)
theta = torch.sigmoid(theta) - 0.5
a = torch.squeeze(self.a(item), dim=-1)
a = torch.sigmoid(a)
a = F.softplus(a)
b = torch.squeeze(self.b(item), dim=-1)
b = torch.sigmoid(b) - 0.5
c = torch.squeeze(self.c(item), dim=-1)
c = torch.sigmoid(c)
if self.value_range is not None:
theta = self.value_range * theta
a = self.value_range * a
b = self.value_range * b
if torch.max(theta != theta) or torch.max(a != a) or torch.max(b != b):
raise Exception('Error:theta,a,b may contains nan! The value_range is too large.')
Expand Down
15 changes: 6 additions & 9 deletions EduCDM/MIRT/MIRT.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
from EduCDM import CDM
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score

Expand Down Expand Up @@ -41,25 +42,21 @@ def irt2pl(theta, a, b, *, F=np):


class MIRTNet(nn.Module):
def __init__(self, user_num, item_num, latent_dim, value_range, irf_kwargs=None):
def __init__(self, user_num, item_num, latent_dim, irf_kwargs=None):
super(MIRTNet, self).__init__()
self.user_num = user_num
self.item_num = item_num
self.irf_kwargs = irf_kwargs if irf_kwargs is not None else {}
self.theta = nn.Embedding(self.user_num, latent_dim)
self.a = nn.Embedding(self.item_num, latent_dim)
self.b = nn.Embedding(self.item_num, 1)
self.value_range = value_range

def forward(self, user, item):
theta = torch.squeeze(self.theta(user), dim=-1)
a = torch.squeeze(self.a(item), dim=-1)
a = torch.sigmoid(a)
a = F.softplus(a)
b = torch.squeeze(self.b(item), dim=-1)
if self.value_range is not None:
a = self.value_range * a
if torch.max(theta != theta) or torch.max(a != a) or torch.max(b != b):
raise Exception('Error:theta,a,b may contains nan! The value_range is too large.')

return self.irf(theta, a, b, **self.irf_kwargs)

@classmethod
Expand All @@ -68,9 +65,9 @@ def irf(cls, theta, a, b, **kwargs):


class MIRT(CDM):
def __init__(self, user_num, item_num, latent_dim, value_range=None):
def __init__(self, user_num, item_num, latent_dim):
super(MIRT, self).__init__()
self.irt_net = MIRTNet(user_num, item_num, latent_dim, value_range)
self.irt_net = MIRTNet(user_num, item_num, latent_dim)

def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
loss_function = nn.BCELoss()
Expand Down

0 comments on commit 4097e75

Please sign in to comment.