Skip to content

Commit

Permalink
Revert changes for masking
Browse files Browse the repository at this point in the history
  • Loading branch information
hyunwoongko authored and hyunwoongko committed Jun 19, 2023
1 parent 24ea43a commit 243672c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 32 deletions.
2 changes: 1 addition & 1 deletion models/model/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ def forward(self, trg, enc_src, trg_mask, src_mask):

# pass to LM head
output = self.linear(trg)
return output
return output
2 changes: 1 addition & 1 deletion models/model/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ def forward(self, x, s_mask):
for layer in self.layers:
x = layer(x, s_mask)

return x
return x
42 changes: 12 additions & 30 deletions models/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,37 +38,19 @@ def __init__(self, src_pad_idx, trg_pad_idx, trg_sos_idx, enc_voc_size, dec_voc_
device=device)

def forward(self, src, trg):
src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)

src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx)

trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx) * \
self.make_no_peak_mask(trg, trg)

src_mask = self.make_src_mask(src)
trg_mask = self.make_trg_mask(trg)
enc_src = self.encoder(src, src_mask)
output = self.decoder(trg, enc_src, trg_mask, src_trg_mask)
output = self.decoder(trg, enc_src, trg_mask, src_mask)
return output

def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
len_q, len_k = q.size(1), k.size(1)

# batch_size x 1 x 1 x len_k
k = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
# batch_size x 1 x len_q x len_k
k = k.repeat(1, 1, len_q, 1)

# batch_size x 1 x len_q x 1
q = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
# batch_size x 1 x len_q x len_k
q = q.repeat(1, 1, 1, len_k)

mask = k & q
return mask

def make_no_peak_mask(self, q, k):
len_q, len_k = q.size(1), k.size(1)

# len_q x len_k
mask = torch.tril(torch.ones(len_q, len_k)).type(torch.BoolTensor).to(self.device)
def make_src_mask(self, src):
src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
return src_mask

return mask
def make_trg_mask(self, trg):
trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(3)
trg_len = trg.shape[1]
trg_sub_mask = torch.tril(torch.ones(trg_len, trg_len)).type(torch.ByteTensor).to(self.device)
trg_mask = trg_pad_mask & trg_sub_mask
return trg_mask

0 comments on commit 243672c

Please sign in to comment.