Skip to content

Commit

Permalink
Fix bug in README.md
Browse files Browse the repository at this point in the history
Related: hyunwoongko#4
  • Loading branch information
ayaka14732 committed Jan 18, 2022
1 parent 8f7aaa1 commit 665a671
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class MultiHeadAttention(nn.Module):
batch_size, length, d_model = tensor.size()

d_tensor = d_model // self.n_head
- tensor = tensor.view(batch_size, self.n_head, length, d_tensor)
+ tensor = tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2)
# this is similar to group convolution (split by number of heads)

return tensor
Expand All @@ -117,7 +117,7 @@ class MultiHeadAttention(nn.Module):
batch_size, head, length, d_tensor = tensor.size()
d_model = head * d_tensor

- tensor = tensor.view(batch_size, length, d_model)
+ tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)
return tensor
```
<br><br>
Expand All @@ -138,15 +138,15 @@ class ScaleDotProductAttention(nn.Module):

def __init__(self):
super(ScaleDotProductAttention, self).__init__()
- self.softmax = nn.Softmax()
+ self.softmax = nn.Softmax(dim=-1)

def forward(self, q, k, v, mask=None, e=1e-12):
# input is 4 dimension tensor
# [batch_size, head, length, d_tensor]
batch_size, head, length, d_tensor = k.size()

# 1. dot product Query with Key^T to compute similarity
- k_t = k.view(batch_size, head, d_tensor, length) # transpose
+ k_t = k.transpose(2, 3) # transpose
score = (q @ k_t) / math.sqrt(d_tensor) # scaled dot product

# 2. apply masking (opt)
Expand Down

0 comments on commit 665a671

Please sign in to comment.