Skip to content

Commit

Permalink
+ MHA_shift
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed Aug 13, 2021
1 parent 4096fff commit a31a3b2
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 16 deletions.
37 changes: 29 additions & 8 deletions src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ def forward(self, x):
k = torch.exp(k)
sum_k = torch.cumsum(k, dim=1)

k = k.view(B, T, self.n_head, self.head_size)
v = v.view(B, T, self.n_head, self.head_size)
kv = (k * v).view(B, T, self.n_head, self.head_size)

wkv = (torch.einsum('htu,buhc->bthc', w, kv)).contiguous().view(B, T, C)

wkv = (torch.einsum('htu,buhc->bthc', w, k * v)).contiguous().view(B, T, C)
rwkv = torch.sigmoid(r) * wkv / sum_k

return self.output(rwkv) * self.time_gamma[:T, :]
Expand All @@ -83,6 +83,7 @@ def forward(self, x):
r = self.receptance(x)

wkv = self.weight(F.mish(k) * v) # seems mish is a bit better than gelu

rwkv = torch.sigmoid(r) * wkv

return rwkv
Expand Down Expand Up @@ -120,14 +121,17 @@ def apply_rotary_pos_emb(q, k, cos, sin):
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

class MHA_rotary(nn.Module):
def __init__(self, config, layer_id):
def __init__(self, config, layer_id, time_shift = False):
super().__init__()
self.layer_id = layer_id
assert config.n_embd % config.n_head == 0
self.n_head = config.n_head
self.ctx_len = config.ctx_len
self.head_size = config.n_embd // config.n_head

if time_shift:
self.time_shift = nn.ZeroPad2d((0,0,1,0))

self.query = nn.Linear(config.n_embd, config.n_embd)
self.key = nn.Linear(config.n_embd, config.n_embd)
self.value = nn.Linear(config.n_embd, config.n_embd)
Expand All @@ -142,6 +146,9 @@ def __init__(self, config, layer_id):
def forward(self, x):
B, T, C = x.size()

if hasattr(self, 'time_shift'):
x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)

q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
Expand All @@ -160,19 +167,27 @@ def forward(self, x):
x = att @ v # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
x = x.transpose(1, 2).contiguous().view(B, T, C) # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

x = self.output(x) # output projection
x = self.output(x)
return x

class GeGLU(torch.nn.Module):
def __init__(self, config, layer_id):
def __init__(self, config, layer_id, time_shift = False):
super().__init__()
self.layer_id = layer_id

if time_shift:
self.time_shift = nn.ZeroPad2d((0,0,1,0))

hidden_sz = 3 * config.n_embd
self.key = nn.Linear(config.n_embd, hidden_sz)
self.value = nn.Linear(config.n_embd, hidden_sz)
self.weight = nn.Linear(hidden_sz, config.n_embd)

def forward(self, x):
B, T, C = x.size()
if hasattr(self, 'time_shift'):
x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)

k = self.key(x)
v = self.value(x)
y = self.weight(F.gelu(k) * v)
Expand Down Expand Up @@ -205,7 +220,7 @@ def __init__(self, config, layer_id):
self.rotary_ndims = int(self.head_size * 0.5)
self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False) # talking heads
self.head_mix = nn.Conv2d(self.n_head, self.n_head, kernel_size=1, bias=False) # talking heads

self.output = nn.Linear(config.n_embd, config.n_embd)

Expand All @@ -218,7 +233,7 @@ def forward(self, x):
w = w[:, :, TT-1:] # w is now a circulant matrix
w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]

x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) # time-mixing
x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) # time-shift mixing
q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
Expand Down Expand Up @@ -300,9 +315,15 @@ def __init__(self, config, layer_id):
if config.model_type == 'RWKV':
self.attn = RWKV_TimeMix(config, layer_id)
self.mlp = RWKV_ChannelMix(config, layer_id)

elif config.model_type == 'MHA_rotary':
self.attn = MHA_rotary(config, layer_id)
self.mlp = GeGLU(config, layer_id)

elif config.model_type == 'MHA_shift':
self.attn = MHA_rotary(config, layer_id, time_shift=True)
self.mlp = GeGLU(config, layer_id, time_shift=True)

elif config.model_type == 'MHA_pro':
self.attn = MHA_pro(config, layer_id)
self.mlp = RWKV_ChannelMix(config, layer_id)
Expand Down
4 changes: 3 additions & 1 deletion src/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ def __init__(self, model, train_dataset, test_dataset, config):

if 'wandb' in sys.modules:
cfg = model.config
for k in config.__dict__:
setattr(cfg, k, config.__dict__[k]) # combine cfg
run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
wandb.init(project="RWKV-LM", name=run_name + '-' + wandb.util.generate_id(), config=config, save_code=False)
wandb.init(project="RWKV-LM", name=run_name + '-' + wandb.util.generate_id(), config=cfg, save_code=False)

# take over whatever gpus are on the system
self.device = 'cpu'
Expand Down
15 changes: 8 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,)

# RWKV - our new model - fastest when ctx_len is long - VRAM friendly - good performance
# MHA_rotary - usual Multi-head Attention+Rotary+GeGLU - not as good
# MHA_pro - slow (lots of tricks) - VRAM hungry - good performance
model_type = 'RWKV' # 'RWKV' or 'MHA_rotary' or 'MHA_pro'
# RWKV : our new model - fastest when ctx_len is long - VRAM friendly - good performance
# MHA_rotary : usual Multi-head Attention+Rotary+GeGLU - not as good
# MHA_shift : with time-shift - good performance
# MHA_pro : slow (lots of tricks) - VRAM hungry - very good performance
model_type = 'RWKV'

# datafile = u"V:\\NLP\\text8"
# datafile = u"V:\\NLP\\enwik8"
datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt" # https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt"
datafile_encoding = 'utf-8'
# datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
# datafile_encoding = 'utf-16'
Expand Down Expand Up @@ -60,8 +61,8 @@ def __init__(self, data, model_level, ctx_len):
print('splitting token...')
data = data.lower().split(' ')
unique = sorted(list(set(data)))
for u in unique:
print(u, end=' ')
# for u in unique:
# print(u, end=' ')
data_size, vocab_size = len(data), len(unique)
print('\n\ndata has %d %ss, %d unique.' % (data_size, model_level, vocab_size))
self.stoi = { ch:i for i,ch in enumerate(unique) }
Expand Down

0 comments on commit a31a3b2

Please sign in to comment.