Fix First Token Time metric in generate (pytorch#942)
vmpuri authored Jul 23, 2024
1 parent fab7b6c commit 4789750
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions generate.py
@@ -499,6 +499,7 @@ def generate(
**sampling_kwargs,
)
time_to_first_token = time.perf_counter() - prefill_t0
+ yield None, {"time_to_first_token": time_to_first_token}
seq[T] = next_token
# max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)
@@ -527,9 +528,9 @@ def generate(
accept_counts[len(next_tokens) - 1] += 1
num_added = min(T_new - input_pos - 1, len(next_tokens))
seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[:num_added]
- for i in next_tokens[:num_added,]:
-     callback(i)
-     yield i, {}
+ for token in next_tokens[:num_added,]:
+     callback(token)
+     yield token, None
input_pos = input_pos + num_added
next_token = next_tokens[-1]
else:
@@ -550,7 +551,7 @@ def generate(
**sampling_kwargs,
):
generated_tokens.append(generated_token)
- yield generated_token, {}
+ yield generated_token, None

seq[T + 1 : T + 1 + len(generated_tokens)] = torch.cat(generated_tokens)
seq = seq[
@@ -559,9 +560,8 @@ def generate(

generate_stats = {
"accept_counts": accept_counts,
"time_to_first_token": time_to_first_token,
}
- return seq, generate_stats
+ yield None, generate_stats

def encode_tokens(self, string, bos=True, device="cpu"):
tokens = self.tokenizer.encode(string)
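
Why this hunk: generate() is a generator, so the value of its trailing `return seq, generate_stats` was never seen by the caller's `for` loop, and the per-token pairs it did yield carried empty metrics dicts; the old logging line therefore always reported a first-token time of 0.0. The fix yields the metric as a `(None, metrics)` pair right after prefill and yields the final stats the same way instead of returning them. A minimal sketch of the resulting protocol, with illustrative names only (`toy_generate` is not the torchchat API):

import time
from typing import Iterator, Optional, Tuple

def toy_generate(prompt: str, max_new_tokens: int) -> Iterator[Tuple[Optional[str], Optional[dict]]]:
    # Sketch of the yield protocol after this commit: (token, metrics) pairs,
    # where metrics is a dict only when there is something to report.
    prefill_t0 = time.perf_counter()
    first_token = prompt[-1] if prompt else "?"  # stand-in for the real prefill step
    time_to_first_token = time.perf_counter() - prefill_t0
    yield None, {"time_to_first_token": time_to_first_token}  # metric only, no token yet

    yield first_token, None  # first decoded token, no metrics
    for i in range(max_new_tokens - 1):
        yield f"tok{i}", None  # remaining decode steps

    yield None, {"accept_counts": []}  # final generate_stats, no token
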
@@ -747,7 +747,7 @@ def callback(x, *, done_generating=False):
t0 = time.perf_counter()
num_tokens_generated = 0
with prof:
- for y, metrics in self.generate(
+ generator_func = self.generate(
self.model,
encoded,
generator_args.max_new_tokens,
@@ -760,15 +760,14 @@
sequential_prefill=generator_args.sequential_prefill,
start_pos=start_pos,
max_seq_length=max_seq_length,
- ):
-     if metrics:
-         aggregate_metrics["accept_counts"].append(
-             metrics["accept_counts"]
-         )
-     start_pos += y.size(0)
-     num_tokens_generated += y.size(0)
-     yield y, metrics
-
+ )
+ for token_tensor, metrics in generator_func:
+     if token_tensor is not None:
+         start_pos += token_tensor.size(0)
+         num_tokens_generated += token_tensor.size(0)
+     if metrics is not None:
+         aggregate_metrics.update(metrics)
+     yield token_tensor, metrics
jit_compile = (i == 0) and (
generator_args.compile or generator_args.compile_prefill
)
@@ -798,7 +797,7 @@ def callback(x, *, done_generating=False):
# continue

logging.info(
f"\nTime for inference {i + 1}: {t:.02f} sec total, time to first token {metrics.get('time_to_first_token', 0.0):.02f} sec with {'sequential' if generator_args.sequential_prefill else 'parallel'} prefill, {num_tokens_generated} tokens, {tokens_sec:.02f} tokens/sec, {1000 / tokens_sec:.02f} ms/token"
f"\nTime for inference {i + 1}: {t:.02f} sec total, time to first token {aggregate_metrics.get('time_to_first_token', -1.0):.02f} sec with {'sequential' if generator_args.sequential_prefill else 'parallel'} prefill, {num_tokens_generated} tokens, {tokens_sec:.02f} tokens/sec, {1000 / tokens_sec:.02f} ms/token"
)
logging.info(
f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
