Skip to content

Commit

Permalink
Fix run_llava.py
Browse files Browse the repository at this point in the history
  • Loading branch information
haotian-liu committed Feb 2, 2024
1 parent 137173b commit 498e18d
Showing 1 changed file with 3 additions and 18 deletions.
21 changes: 3 additions & 18 deletions llava/eval/run_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def eval_model(args):

image_files = image_parser(args)
images = load_images(image_files)
image_sizes = [x.size for x in images]
images_tensor = process_images(
images,
image_processor,
Expand All @@ -107,36 +108,20 @@ def eval_model(args):
.cuda()
)

stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=images_tensor,
image_sizes=image_sizes,
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
top_p=args.top_p,
num_beams=args.num_beams,
max_new_tokens=args.max_new_tokens,
use_cache=True,
stopping_criteria=[stopping_criteria],
)

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(
f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
)
outputs = tokenizer.batch_decode(
output_ids[:, input_token_len:], skip_special_tokens=True
)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[: -len(stop_str)]
outputs = outputs.strip()
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print(outputs)


Expand Down

0 comments on commit 498e18d

Please sign in to comment.