Skip to content

Commit

Permalink
Merge pull request bigdata-ustc#32 from ViviHong200709/main
Browse files Browse the repository at this point in the history
[BUGFIX] Fix logical bug in IRTNet.forward and MIRTNet.forward
  • Loading branch information
tswsxk authored Oct 15, 2021
2 parents 49f7cc2 + ac6c755 commit 6ba2e91
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 7 deletions.
2 changes: 2 additions & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@

[Fangzhou Yao](https://github.com/fannazya)

[Yuting Hong](https://github.com/ViviHong200709)

4 changes: 2 additions & 2 deletions EduCDM/IRR/IRT.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@


class IRT(PointIRT):
def __init__(self, user_num, item_num, knowledge_num, zeta=0.5):
super(IRT, self).__init__(user_num, item_num)
def __init__(self, user_num, item_num, knowledge_num, value_range=10, zeta=0.5):
super(IRT, self).__init__(user_num, item_num, value_range=value_range)
self.knowledge_num = knowledge_num
self.zeta = zeta

Expand Down
13 changes: 9 additions & 4 deletions EduCDM/IRT/GD/IRT.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class IRTNet(nn.Module):
def __init__(self, user_num, item_num, irf_kwargs=None):
def __init__(self, user_num, item_num, value_range, irf_kwargs=None):
super(IRTNet, self).__init__()
self.user_num = user_num
self.item_num = item_num
Expand All @@ -21,23 +21,28 @@ def __init__(self, user_num, item_num, irf_kwargs=None):
self.a = nn.Embedding(self.item_num, 1)
self.b = nn.Embedding(self.item_num, 1)
self.c = nn.Embedding(self.item_num, 1)
self.value_range = value_range

def forward(self, user, item):
theta = torch.squeeze(self.theta(user), dim=-1)
theta = self.value_range * (torch.sigmoid(theta) - 0.5)
a = torch.squeeze(self.a(item), dim=-1)
a = torch.sigmoid(a)
b = torch.squeeze(self.b(item), dim=-1)
b = self.value_range * (torch.sigmoid(b) - 0.5)
c = torch.squeeze(self.c(item), dim=-1)
return torch.sigmoid(self.irf(theta, a, b, c, **self.irf_kwargs))
c = torch.sigmoid(c)
return self.irf(theta, a, b, c, **self.irf_kwargs)

@classmethod
def irf(cls, theta, a, b, c, **kwargs):
return irt3pl(theta, a, b, c, F=torch, **kwargs)


class IRT(CDM):
def __init__(self, user_num, item_num):
def __init__(self, user_num, item_num, value_range=10):
super(IRT, self).__init__()
self.irt_net = IRTNet(user_num, item_num)
self.irt_net = IRTNet(user_num, item_num, value_range)

def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
loss_function = nn.BCELoss()
Expand Down
2 changes: 1 addition & 1 deletion EduCDM/MIRT/MIRT.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def forward(self, user, item):
theta = torch.squeeze(self.theta(user), dim=-1)
a = torch.squeeze(self.a(item), dim=-1)
b = torch.squeeze(self.b(item), dim=-1)
return torch.sigmoid(self.irf(theta, a, b, **self.irf_kwargs))
return self.irf(theta, a, b, **self.irf_kwargs)

@classmethod
def irf(cls, theta, a, b, **kwargs):
Expand Down

0 comments on commit 6ba2e91

Please sign in to comment.