Skip to content

Commit

Permalink
[feat] use PosLinear to replace clipper operation
Browse files Browse the repository at this point in the history
  • Loading branch information
tswsxk committed Jan 29, 2022
1 parent ee5ddfe commit 166e094
Showing 1 changed file with 9 additions and 20 deletions.
29 changes: 9 additions & 20 deletions EduCDM/NCDM/NCDM.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
from EduCDM import CDM


class PosLinear(nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
weight = 2 * F.relu(1 * torch.neg(self.weight)) + self.weight
return F.linear(input, weight, self.bias)


class Net(nn.Module):

def __init__(self, knowledge_n, exer_n, student_n):
Expand All @@ -28,11 +34,11 @@ def __init__(self, knowledge_n, exer_n, student_n):
self.student_emb = nn.Embedding(self.emb_num, self.stu_dim)
self.k_difficulty = nn.Embedding(self.exer_n, self.knowledge_dim)
self.e_difficulty = nn.Embedding(self.exer_n, 1)
self.prednet_full1 = nn.Linear(self.prednet_input_len, self.prednet_len1)
self.prednet_full1 = PosLinear(self.prednet_input_len, self.prednet_len1)
self.drop_1 = nn.Dropout(p=0.5)
self.prednet_full2 = nn.Linear(self.prednet_len1, self.prednet_len2)
self.prednet_full2 = PosLinear(self.prednet_len1, self.prednet_len2)
self.drop_2 = nn.Dropout(p=0.5)
self.prednet_full3 = nn.Linear(self.prednet_len2, 1)
self.prednet_full3 = PosLinear(self.prednet_len2, 1)

# initialize
for name, param in self.named_parameters():
Expand All @@ -53,22 +59,6 @@ def forward(self, stu_id, input_exercise, input_knowledge_point):

return output_1.view(-1)

def apply_clipper(self):
clipper = NoneNegClipper()
self.prednet_full1.apply(clipper)
self.prednet_full2.apply(clipper)
self.prednet_full3.apply(clipper)


class NoneNegClipper(object):
def __init__(self):
super(NoneNegClipper, self).__init__()

def __call__(self, module):
if hasattr(module, 'weight'):
w = module.weight.data
module.weight.data = torch.clamp(w, min=0.).detach()


class NCDM(CDM):
'''Neural Cognitive Diagnosis Model'''
Expand Down Expand Up @@ -98,7 +88,6 @@ def train(self, train_data, test_data=None, epoch=10, device="cpu", lr=0.002, si
optimizer.zero_grad()
loss.backward()
optimizer.step()
self.ncdm_net.apply_clipper()

epoch_losses.append(loss.mean().item())

Expand Down

0 comments on commit 166e094

Please sign in to comment.