
Commit 74f8bfe

add cache to speed up generating
SkyTNT committed Oct 8, 2024
1 parent 3f1edd3
Showing 2 changed files with 35 additions and 11 deletions.
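
This commit adds key/value caching to autoregressive generation: both the event-level model (MIDIModel.forward) and the per-event token decoder (MIDIModel.forward_token) now accept a transformers DynamicCache, so each decoding step feeds only the newest position instead of re-running attention over the whole prefix. A minimal sketch of that decoding pattern follows, assuming transformers >= 4.36 (where DynamicCache lives); the tiny LlamaModel and its config values are illustrative only, not the repo's MIDI model.

import torch
from transformers import LlamaConfig, LlamaModel, DynamicCache

# Toy, randomly initialised decoder, just to show the cached decoding pattern.
config = LlamaConfig(vocab_size=128, hidden_size=64, intermediate_size=128,
                     num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4)
model = LlamaModel(config).eval()

cache = DynamicCache()
prompt = torch.randint(0, config.vocab_size, (1, 10))  # (batch, seq_len)

with torch.no_grad():
    # First call: feed the whole prompt so its keys/values land in the cache.
    hidden = model(input_ids=prompt, past_key_values=cache, use_cache=True).last_hidden_state
    for _ in range(5):
        next_token = torch.randint(0, config.vocab_size, (1, 1))  # stand-in for a sampled token
        # Later calls: feed only the newest position; earlier ones are served from the cache.
        hidden = model(input_ids=next_token, past_key_values=cache, use_cache=True).last_hidden_state
        assert hidden.shape[1] == 1  # only the new position is recomputed
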
app.py (11 changes: 9 additions & 2 deletions)
@@ -12,6 +12,7 @@
 import torch.nn.functional as F
 import tqdm
 from huggingface_hub import hf_hub_download
+from transformers import DynamicCache
 
 import MIDI
 from midi_model import MIDIModel, config_name_list, MIDIModelConfig
@@ -49,12 +50,14 @@ def generate(prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k
         input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
     cur_len = input_tensor.shape[1]
     bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
+    cache1 = DynamicCache()
     with bar:
         while cur_len < max_len:
             end = [False] * batch_size
-            hidden = model.forward(input_tensor)[:, -1]
+            hidden = model.forward(input_tensor[:, -1:], cache=cache1)[:, -1]
             next_token_seq = None
             event_names = [""] * batch_size
+            cache2 = DynamicCache()
             for i in range(max_token_seq):
                 mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=model.device)
                 for b in range(batch_size):
@@ -79,7 +82,11 @@ def generate(prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k
                             mask_ids = [i for i in mask_ids if i not in disable_channels]
                         mask[b, mask_ids] = 1
                 mask = mask.unsqueeze(1)
-                logits = model.forward_token(hidden, next_token_seq)[:, -1:]
+                x = next_token_seq
+                if i != 0:
+                    hidden = None
+                    x = x[:, -1:]
+                logits = model.forward_token(hidden, x, cache=cache2)[:, -1:]
                 scores = torch.softmax(logits / temp, dim=-1) * mask
                 samples = model.sample_top_p_k(scores, top_p, top_k, generator=generator)
                 if i == 0:
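
In the updated app.py loop, cache1 persists across the outer while loop and holds the event-level context, while cache2 is re-created for every event and accumulates the event's hidden state plus the sub-tokens sampled so far; that is why later inner iterations can pass hidden=None and only the last sampled token. The speedup rests on a standard property of cached decoding: once the earlier positions are in the cache, feeding just the newest one reproduces the same last hidden state as a full recomputation. A self-contained check of that property with a toy LlamaModel (illustrative sizes, not the repo's model):

import torch
from transformers import LlamaConfig, LlamaModel, DynamicCache

torch.manual_seed(0)
config = LlamaConfig(vocab_size=64, hidden_size=32, intermediate_size=64,
                     num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4)
model = LlamaModel(config).eval()
ids = torch.randint(0, config.vocab_size, (1, 12))

with torch.no_grad():
    # Reference: recompute attention over the whole sequence.
    full = model(input_ids=ids).last_hidden_state[:, -1]

    # Cached: feed the prefix once, then only the final token.
    cache = DynamicCache()
    model(input_ids=ids[:, :-1], past_key_values=cache, use_cache=True)
    step = model(input_ids=ids[:, -1:], past_key_values=cache, use_cache=True).last_hidden_state[:, -1]

print(torch.allclose(full, step, atol=1e-5))  # expected: True
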
midi_model.py (35 changes: 26 additions & 9 deletions)
@@ -7,7 +7,7 @@
 import tqdm
 import lightning as pl
 from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
-from transformers import LlamaModel, LlamaConfig
+from transformers import LlamaModel, LlamaConfig, DynamicCache
 from transformers.integrations import PeftAdapterMixin
 
 from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
@@ -78,30 +78,40 @@ def load_merge_lora(self, model_id):
         set_peft_model_state_dict(self, adapter_state_dict, "default")
         return model.merge_and_unload()
 
-    def forward_token(self, hidden_state, x=None):
+    def forward_token(self, hidden_state=None, x=None, cache=None):
         """
         :param hidden_state: (batch_size, n_embd)
         :param x: (batch_size, token_sequence_length)
+        :param cache: Cache
         :return: (batch_size, 1 + token_sequence_length, vocab_size)
         """
-        hidden_state = hidden_state.unsqueeze(1)  # (batch_size, 1, n_embd)
+        if hidden_state is not None:
+            # if you use cache, you don't need to pass in hidden_state
+            hidden_state = hidden_state.unsqueeze(1)  # (batch_size, 1, n_embd)
         if x is not None:
             x = self.net_token.embed_tokens(x)
-            hidden_state = torch.cat([hidden_state, x], dim=1)
-        hidden_state = self.net_token.forward(inputs_embeds=hidden_state).last_hidden_state
+            if hidden_state is not None:
+                x = torch.cat([hidden_state, x], dim=1)
+            hidden_state = x
+        hidden_state = self.net_token.forward(inputs_embeds=hidden_state,
+                                              past_key_values=cache,
+                                              use_cache=cache is not None).last_hidden_state
         return self.lm_head(hidden_state)
 
-    def forward(self, x):
+    def forward(self, x, cache=None):
         """
         :param x: (batch_size, midi_sequence_length, token_sequence_length)
+        :param cache: Cache
         :return: hidden (batch_size, midi_sequence_length, n_embd)
         """
 
         # merge token sequence
         x = self.net.embed_tokens(x)
         x = x.sum(dim=-2)
-        x = self.net.forward(inputs_embeds=x)
+        x = self.net.forward(inputs_embeds=x,
+                             past_key_values=cache,
+                             use_cache=cache is not None)
         return x.last_hidden_state
 
     def sample_top_p_k(self, probs, p, k, generator=None):
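
forward_token keeps its original trick of injecting the event-level hidden state as position 0 of the token decoder via inputs_embeds, but now only on the first inner step: once that position is stored in the cache, later calls pass hidden_state=None and just the newest sampled sub-token. The toy stand-in below mirrors that structure; ToyTokenDecoder and its sizes are hypothetical and only approximate the repo's net_token/lm_head setup.

import torch
from torch import nn
from transformers import LlamaConfig, LlamaModel, DynamicCache


class ToyTokenDecoder(nn.Module):
    def __init__(self, vocab_size=64, n_embd=32):
        super().__init__()
        cfg = LlamaConfig(vocab_size=vocab_size, hidden_size=n_embd, intermediate_size=64,
                          num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4)
        self.net_token = LlamaModel(cfg)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward_token(self, hidden_state=None, x=None, cache=None):
        if hidden_state is not None:
            hidden_state = hidden_state.unsqueeze(1)           # (batch, 1, n_embd)
        if x is not None:
            x = self.net_token.embed_tokens(x)                 # (batch, seq, n_embd)
            if hidden_state is not None:
                x = torch.cat([hidden_state, x], dim=1)
            hidden_state = x
        hidden_state = self.net_token(inputs_embeds=hidden_state,
                                      past_key_values=cache,
                                      use_cache=cache is not None).last_hidden_state
        return self.lm_head(hidden_state)


dec = ToyTokenDecoder().eval()
hidden = torch.randn(1, 32)      # event-level hidden state from the outer model
cache2 = DynamicCache()
with torch.no_grad():
    logits = dec.forward_token(hidden, None, cache=cache2)[:, -1:]   # i == 0: inject hidden state
    tok = logits.argmax(-1)                                          # stand-in for sampling
    for _ in range(3):
        # i != 0: the hidden state and earlier sub-tokens live in the cache,
        # so only the newest sub-token is fed.
        logits = dec.forward_token(None, tok, cache=cache2)[:, -1:]
        tok = logits.argmax(-1)
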
@@ -144,12 +154,14 @@ def generate(self, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98,

         cur_len = input_tensor.shape[1]
         bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
+        cache1 = DynamicCache()
         with bar:
             while cur_len < max_len:
                 end = [False] * batch_size
-                hidden = self.forward(input_tensor)[:, -1]
+                hidden = self.forward(input_tensor[:, -1:], cache=cache1)[:, -1]
                 next_token_seq = None
                 event_names = [""] * batch_size
+                cache2 = DynamicCache()
                 for i in range(max_token_seq):
                     mask = torch.zeros((batch_size, tokenizer.vocab_size), dtype=torch.int64, device=self.device)
                     for b in range(batch_size):
@@ -165,7 +177,12 @@ def generate(self, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98,
                                 continue
                             mask[b, tokenizer.parameter_ids[param_names[i - 1]]] = 1
                     mask = mask.unsqueeze(1)
-                    logits = self.forward_token(hidden, next_token_seq)[:, -1:]
+                    x = next_token_seq
+                    if i != 0:
+                        # cached
+                        hidden = None
+                        x = x[:, -1:]
+                    logits = self.forward_token(hidden, x, cache=cache2)[:, -1:]
                     scores = torch.softmax(logits / temp, dim=-1) * mask
                     samples = self.sample_top_p_k(scores, top_p, top_k, generator=generator)
                     if i == 0:
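
MIDIModel.generate now follows the same two-cache pattern as app.py. A rough way to see the speedup the commit message refers to is to time full-prefix recomputation against single-token steps with a DynamicCache; the toy model and step count below are illustrative, and absolute numbers depend on hardware.

import time
import torch
from transformers import LlamaConfig, LlamaModel, DynamicCache

config = LlamaConfig(vocab_size=128, hidden_size=256, intermediate_size=512,
                     num_hidden_layers=4, num_attention_heads=4, num_key_value_heads=4)
model = LlamaModel(config).eval()
steps = 256

with torch.no_grad():
    # No cache: the input grows by one token each step and is fully recomputed.
    ids = torch.randint(0, config.vocab_size, (1, 1))
    t0 = time.time()
    for _ in range(steps):
        model(input_ids=ids)
        ids = torch.cat([ids, torch.randint(0, config.vocab_size, (1, 1))], dim=1)
    no_cache = time.time() - t0

    # With cache: only the newest token is fed; past keys/values are reused.
    cache = DynamicCache()
    t0 = time.time()
    for _ in range(steps):
        model(input_ids=torch.randint(0, config.vocab_size, (1, 1)),
              past_key_values=cache, use_cache=True)
    with_cache = time.time() - t0

print(f"no cache: {no_cache:.2f}s, with cache: {with_cache:.2f}s")
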
