
saving the qk matrix in the attention module for convenience
jongwook committed Dec 30, 2022
1 parent 0b5dcfd commit 68e44bd
Showing 1 changed file with 3 additions and 0 deletions.
whisper/model.py (3 additions, 0 deletions)
@@ -62,6 +62,7 @@ def __init__(self, n_state: int, n_head: int):
         self.key = Linear(n_state, n_state, bias=False)
         self.value = Linear(n_state, n_state)
         self.out = Linear(n_state, n_state)
+        self.last_qk = None
 
     def forward(
         self,
@@ -96,6 +97,8 @@ def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor]
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
 
+        self.last_qk = qk.detach()
+
         w = F.softmax(qk.float(), dim=-1).to(q.dtype)
         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
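The new last_qk attribute holds the pre-softmax attention logits from the most recent call to qkv_attention, and detach() keeps the stored tensor from retaining the autograd graph. A minimal sketch of how the saved matrix might be read back after a forward pass, assuming a checkout of whisper at this commit; the model name and audio path below are hypothetical placeholders:

    import whisper

    model = whisper.load_model("base")
    model.transcribe("audio.wav")  # hypothetical input file

    # Each decoder block has a self-attention (attn) and a cross-attention
    # (cross_attn) MultiHeadAttention module; after the pass above, last_qk
    # holds the logits from that module's most recent forward call, shaped
    # (batch, n_head, n_ctx_query, n_ctx_key).
    for i, block in enumerate(model.decoder.blocks):
        print(i, tuple(block.cross_attn.last_qk.shape))

Note that transcription runs the decoder many times, so last_qk reflects only the final forward pass, not the whole decoding loop.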

