Skip to content

Commit

Permalink
Update decoder_layer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
hyunwoongko authored Mar 3, 2021
1 parent 65d91f3 commit 489bb06
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions models/blocks/decoder_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
def forward(self, dec, enc, t_mask, s_mask):
# 1. compute self attention
_x = dec
x = self.self_attention(dec, dec, dec, mask=trg_mask)
x = self.self_attention(q=dec, k=dec, v=dec, mask=trg_mask)

# 2. add and norm
x = self.norm1(x + _x)
Expand All @@ -38,7 +38,7 @@ def forward(self, dec, enc, t_mask, s_mask):
if enc is not None:
# 3. compute encoder - decoder attention
_x = x
x = self.enc_dec_attention(x, enc, enc, mask=src_mask)
x = self.enc_dec_attention(q=x, k=enc, v=enc, mask=src_mask)

# 4. add and norm
x = self.norm2(x + _x)
Expand All @@ -51,5 +51,4 @@ def forward(self, dec, enc, t_mask, s_mask):
# 6. add and norm
x = self.norm3(x + _x)
x = self.dropout3(x)

return x

0 comments on commit 489bb06

Please sign in to comment.