Skip to content

Commit

Permalink
llama logits优化
Browse files: browse the repository at this point in the history
  • Loading branch information
黄宇扬 committed Aug 14, 2024
1 parent 220d7b6 commit 4d1b144
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions src/models/llama.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -813,6 +813,19 @@ namespace fastllm {
std::vector <Data> curLogits;
curLogits.resize(batch);

if (batch > 1 && !all1) {
int total = 0;
std::vector <Data> lastTokens;
std::vector <Data*> lastTokenPointers;
lastTokens.resize(seqLens.size());
for (int b = 0; b < seqLens.size(); b++) {
Split(hiddenStates, 1, total + seqLens[b] - 1, total + seqLens[b], lastTokens[b]);
total += seqLens[b];
lastTokenPointers.push_back(&lastTokens[b]);
}
CatBatch(lastTokenPointers, 1, hiddenStates);
}

RMSNorm(hiddenStates, weight["model.norm.weight"], rms_norm_eps, hiddenStates);
Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
ToDataType(logits, DataType::FLOAT32);
Expand All @@ -832,7 +845,7 @@ namespace fastllm {
maxTopK = std::max(maxTopK, generationConfigs[b].top_k);
}

if (all1 && batch > 1 && allSimple) {
if (batch > 1 && allSimple) {
Data topk;
TopK(logits, topk, 1);
topk.ToDevice(DataDevice::CPU);
Expand All @@ -841,7 +854,7 @@ namespace fastllm {
lastRet.push_back((int) (topkData[0] + 1e-3));
topkData += topk.Count(2);
}
} else if (all1 && batch > 1 && maxTopK <= 50 && !needLogits) {
} else if (batch > 1 && maxTopK <= 50 && !needLogits) {
int maxTokenSetSize = 0;
for (int b = 0; b < batch; b++) {
maxTokenSetSize = std::max(maxTokenSetSize, (int)lastTokens.units[b].tokenSet.size());
Expand All @@ -867,17 +880,10 @@ namespace fastllm {
lastRet.push_back(LLMSamplingOnly(topk, b, generationConfigs[b]));
}
} else {
if (all1 && batch > 1) {
for (int b = 0; b < batch; b++) {
pointersK[b] = (&curLogits[b]);
}
SplitBatch(logits, 1, batch, pointersK);
} else {
for (int b = 0; b < batch; b++) {
Split(logits, 1, total + seqLens[b] - 1, total + seqLens[b], curLogits[b]);
total += seqLens[b];
}
for (int b = 0; b < batch; b++) {
pointersK[b] = (&curLogits[b]);
}
SplitBatch(logits, 1, batch, pointersK);

for (int b = 0; b < batch; b++) {
Data &curLogit = curLogits[b];
Expand Down

0 comments on commit 4d1b144

Please sign in to comment.