abstract_attention.py
from typing import Optional, Tuple, Dict

import torch.nn as nn
from torch import Tensor


class AbstractAttention(nn.Module):
    """Base class for attention modules; subclasses implement the abstract hooks."""

    def __init__(self, cross: bool = False, causal: bool = False, **kwargs) -> None:
        super().__init__()
        # Unique key for this module's entry in an incremental_state dict.
        self.name = f'{self.__class__.__name__}.{hash(self)}'
        self.causal = causal
        self.cross = cross

    def _reset_parameters(self):
        raise NotImplementedError
    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        query_padding_mask: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        need_head_weights: bool = False,
        attn_mask: Optional[Tensor] = None,
        static_kv: bool = False,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        **kwargs,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        raise NotImplementedError
    def _get_input_buffer(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    ) -> Dict[str, Optional[Tensor]]:
        # Return this module's saved state, or an empty dict if none exists yet.
        if incremental_state is not None and self.name in incremental_state:
            return incremental_state[self.name]
        return {}
    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        # Store this module's saved state under its unique name.
        incremental_state[self.name] = buffer
        return incremental_state

    def _apply_attention(self, *args, **kwargs):
        raise NotImplementedError

    def _get_saved_states(self, *args, **kwargs):
        raise NotImplementedError

    def _update_saved_states(self, *args, **kwargs):
        raise NotImplementedError
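
A minimal concrete subclass can make the contract clearer. The sketch below is illustrative and not part of the original file: `DotProductAttention` and `embed_dim` are hypothetical names, it implements single-head scaled dot-product attention over inputs shaped (batch, seq_len, embed_dim), and it assumes True entries in key_padding_mask mark padded key positions.

import math
import torch
import torch.nn.functional as F

class DotProductAttention(AbstractAttention):
    # Hypothetical subclass: single-head scaled dot-product attention.
    def __init__(self, embed_dim: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.embed_dim = embed_dim

    def _reset_parameters(self):
        pass  # no learnable parameters in this sketch

    def forward(self, query, key, value, query_padding_mask=None,
                key_padding_mask=None, need_weights=True,
                attn_mask=None, **kwargs):
        # (batch, tgt_len, d) @ (batch, d, src_len) -> (batch, tgt_len, src_len)
        scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(self.embed_dim)
        if attn_mask is not None:
            scores = scores + attn_mask
        if key_padding_mask is not None:
            # Broadcast (batch, src_len) -> (batch, 1, src_len); padded keys get -inf.
            scores = scores.masked_fill(key_padding_mask.unsqueeze(1), float('-inf'))
        weights = F.softmax(scores, dim=-1)
        out = torch.bmm(weights, value)
        return out, (weights if need_weights else None)

# Example: self-attention over a random batch.
# attn = DotProductAttention(embed_dim=64)
# x = torch.randn(2, 5, 64)
# out, weights = attn(x, x, x)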