Commit

Update multi-gpu end2end inference support.
Update README.md about multi-round interaction.
CaraJ7 committed Sep 24, 2024
1 parent d7fcaa7 commit be8be35
Showing 8 changed files with 79 additions and 20 deletions.
1 change: 1 addition & 0 deletions README.md
100755 → 100644
@@ -17,6 +17,7 @@
---

## Announcement
- [2024-09] 🎉🎉 We welcome the new task [MMSearch](https://mmsearch.github.io/).
- [2024-09] 🎉🎉 We welcome the new task [MME-RealWorld](https://mme-realworld.github.io/).
- [2024-09] ⚙️️⚙️️️️ We upgrade `lmms-eval` to `0.2.3` with more tasks and features. We support a compact set of language tasks evaluations (code credit to [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)), and we remove the registration logic at start (for all models and tasks) to reduce the overhead. Now `lmms-eval` only launches necessary tasks/models. Please check the [release notes](https://github.com/EvolvingLMMs-Lab/lmms-eval/releases/tag/v0.2.3) for more details.
- [2024-08] 🎉🎉 We welcome the new models [LLaVA-OneVision](https://huggingface.co/papers/2408.03326) and [Mantis](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/162), and the new tasks [MVBench](https://huggingface.co/datasets/OpenGVLab/MVBench), [LongVideoBench](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/117), and [MMStar](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/158). We provide a new SGLang Runtime API feature for the llava-onevision model; please refer to the [doc](https://github.com/EvolvingLMMs-Lab/lmms-eval/blob/main/docs/commands.md) for inference acceleration.
43 changes: 41 additions & 2 deletions docs/task_guide.md
100755 → 100644
@@ -75,6 +75,44 @@ metadata:
- version: 0.0
```

Multi-round-generation-based tasks:

- MMSearch (`lmms_eval/tasks/mmsearch/mmsearch_end2end.yaml`)

```yaml
dataset_path: CaraJ/MMSearch
dataset_name: end2end
dataset_kwargs:
token: False
task: "mmsearch_end2end"
test_split: end2end
output_type: generate_until_multi_round # Note that we use the new output_type here for multi-round generation. It basically follows generate_until but incorporates multi-round inference
doc_to_visual: !function lmms_eval_utils.mmsearch_end2end_doc_to_visual
doc_to_text: !function lmms_eval_utils.mmsearch_end2end_doc_to_text
doc_to_target: "answer"
generation_kwargs:
until:
- "ASSISTANT:"
max_new_tokens: 512
temperature: 0
top_p: 0
num_beams: 1
do_sample: false
process_results: !function lmms_eval_utils.mmsearch_end2end_process_results
metric_list:
- metric: end2end_f1_score
aggregation: !function lmms_eval_utils.mmsearch_aggregate_results_f1_score
higher_is_better: true
- metric: requery_score
aggregation: !function lmms_eval_utils.mmsearch_aggregate_results_req_score
higher_is_better: true
lmms_eval_specific_kwargs: # Note that we cache the result of every sample as soon as it is inferred
middle_resules_dir: /data1/zrr/jdz/mmsearch/mmsearch_middile_results
result_cache_dir: /data1/zrr/jdz/mmsearch/mmsearch_result_cache_dir
```


## Configurations

Tasks are configured via the `TaskConfig` object. Below, we describe all fields usable within the object, and their role in defining a task.
@@ -96,8 +134,9 @@ Dataset configuration options:
- **process_docs** (`Callable`, *optional*) — Optionally define a function to apply to each HF dataset split, to preprocess all documents before being fed into prompt template rendering or other evaluation steps. Can be used to rename dataset columns, or to process documents into a format closer to that expected by a prompt template.

Prompting / in-context formatting options:
- **doc_to_text** (`Union[Callable, str]`, *optional*) — Column name or function to process a sample into the appropriate input for the model
- **doc_to_visual** (`Union[Callable, str]`, *optional*) — Function to process a sample into the appropriate input images for the model.
- **doc_to_text** (`Union[Callable, str]`, *optional*) — Column name or function to process a sample into the appropriate input for the model.

For multi-round generation (e.g., MMSearch), the function accepts additional parameters for the round index, the previous round's information, and the previous model output. It should return the input images for the next round, the input text for the next round, a boolean indicating whether round inference should terminate, the model outputs from all rounds, and extra information from previous rounds (see the sketch after this list).
- **doc_to_target** (`Union[Callable, str]`, *optional*) — Column name or function to process a sample into the appropriate target output for the model. For multiple choice tasks, this should return an index into the answer choices.
- **doc_to_choice** (`Union[Callable, str]`, *optional*) — Column name or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `generate_until` tasks.
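
To make the multi-round contract concrete, here is a minimal sketch of such a `doc_to_text` function, assuming the return convention described above; the function name and the `doc['question']` / `doc['image']` fields are illustrative, not the actual MMSearch implementation.

```python
def example_multiround_doc_to_text(doc, lmms_eval_specific_kwargs=None,
                                   previous_output=None, round_idx=None,
                                   previous_round_info=None):
    # Initial round: round_idx is None, so behave like a plain generate_until
    # task and return only the text prompt.
    if round_idx is None:
        return f"Question: {doc['question']}\nAnswer:"

    # Follow-up round: build the next prompt from the previous model output and
    # return (visuals, contexts, terminate, round_results, extra_info).
    if round_idx == 1:
        contexts = f"Your previous answer was: {previous_output[-1]}\nPlease refine it."
        visuals = [doc["image"]]  # images fed to the model in the next round
        return visuals, contexts, False, previous_output, {"carried": "to round 2"}

    # Any later round: terminate the multi-round loop.
    return None, None, True, previous_output, previous_round_info
```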

3 changes: 2 additions & 1 deletion lmms_eval/api/task.py
@@ -34,6 +34,7 @@
from PIL import ImageFile
from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed
from tqdm import tqdm
from functools import partial

from lmms_eval import utils
from lmms_eval.api import samplers
@@ -1382,7 +1383,7 @@ def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Inst
elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, copy.deepcopy(self.config.generation_kwargs), self.doc_to_visual, doc_id, self.config.task, split)
elif self.OUTPUT_TYPE == "generate_until_multi_round":
arguments = (ctx, copy.deepcopy(self.config.generation_kwargs), self.doc_to_visual, self.config.doc_to_text, doc_id, self.config.task, split)
arguments = (ctx, copy.deepcopy(self.config.generation_kwargs), self.doc_to_visual, partial(self.config.doc_to_text, lmms_eval_specific_kwargs=self.lmms_eval_specific_kwargs), doc_id, self.config.task, split)
return Instance(request_type=self.OUTPUT_TYPE, arguments=arguments, idx=0, **kwargs)

# TODO: we add a full_docs interface here for some evaluations that needs to access the full datasets during process_results function. we may have better ways to handle this.
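
For orientation, a model backend that implements `generate_until_multi_round` could consume these request arguments roughly as sketched below. This is a simplified illustration under assumed names (`self._generate` is a hypothetical single-round generation helper), not the actual lmms-eval model code.

```python
def generate_until_multi_round(self, requests):
    responses = []
    for ctx, gen_kwargs, doc_to_visual, doc_to_text, doc_id, task, split in [req.args for req in requests]:
        doc = self.task_dict[task][split][doc_id]   # same doc lookup pattern as generate_until
        visuals = doc_to_visual(doc)                # round-1 visuals
        contexts = ctx                              # round-1 prompt, already rendered by the task
        round_outputs, extra, round_idx = [], None, 0
        while True:
            round_outputs.append(self._generate(contexts, visuals, **gen_kwargs))
            round_idx += 1
            # doc_to_text is the partial built above, so lmms_eval_specific_kwargs is already bound
            visuals, contexts, terminate, round_outputs, extra = doc_to_text(
                doc, previous_output=round_outputs, round_idx=round_idx, previous_round_info=extra
            )
            if terminate:
                break
        responses.append(round_outputs)
    return responses
```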
4 changes: 2 additions & 2 deletions lmms_eval/tasks/mmsearch/get_final_scores.py
@@ -60,5 +60,5 @@ def parse_args():
for subfield in all_task_result_summary["end2end"]["subfield_dict"]:
final_result_summary["subfield_dict"][subfield] = sum([ratio * all_task_result_summary[task]["subfield_dict"][subfield]["average"] for task, ratio in task_ratio_dict.items()])

logger.info(f"Average final score: {final_result_summary['total_dict']['average']}")
json.dump(final_result_summary, open(args.save_path, "w"))
print(f"Average final score: {final_result_summary['total_dict']['average']}")
json.dump(final_result_summary, open(args.save_path, "w"), indent=4)
28 changes: 19 additions & 9 deletions lmms_eval/tasks/mmsearch/lmms_eval_utils.py
@@ -38,6 +38,14 @@


def mmsearch_end2end_doc_to_text(doc, lmms_eval_specific_kwargs=None, previous_output=None, round_idx=None, previous_round_info=None):
'''
Multi-round doc_to_text for the MMSearch end-to-end task.

In the initial round (round_idx is None), only the text query is returned,
matching the output format of other benchmarks. In later rounds it returns:
    visuals: input images for the next round
    contexts: input text for the next round
    terminal_signal: whether round inference should stop
    round_result: model outputs from all rounds so far
    previous_round_info: extra information carried to the next round
'''
# prepare save dir
middle_result_dir = lmms_eval_specific_kwargs["middle_resules_dir"] if lmms_eval_specific_kwargs is not None and "middle_resules_dir" in lmms_eval_specific_kwargs else "mmsearch_middile_results"
result_cache_dir = lmms_eval_specific_kwargs["result_cache_dir"] if lmms_eval_specific_kwargs is not None and "result_cache_dir" in lmms_eval_specific_kwargs else "mmsearch_result_cache_dir"
@@ -51,9 +59,9 @@ def mmsearch_end2end_doc_to_text(doc, lmms_eval_specific_kwargs=None, previous_o
query_has_image = True
prompt_template_dict = image_search_text_query_dict
query = doc["query"]
eval_logger.info(query)

# initial round: round_idx is None. This keeps the same output format as other benchmarks
eval_logger.info('----------------Round1: Requery----------------')
if round_idx is None:
prompt_template = prompt_template_dict["stage1"]
if not query_has_image:
@@ -63,13 +71,13 @@ def mmsearch_end2end_doc_to_text(doc, lmms_eval_specific_kwargs=None, previous_o
return text_query
# round2: search result + rerank
if round_idx == 1:
eval_logger.info("-------------------Stage2-------------------")

# if exist, return. This check has to be done here to avoid many
cache_path = os.path.join(result_cache_dir, f"{doc['sample_id']}.json")
if os.path.exists(cache_path):
eval_logger.info(f"{doc['sample_id']} already exists. Load the cache result.")
round_res = json.load(open(cache_path))["round_res"]
return None, None, True, round_res, None
eval_logger.info('----------------Round2: Rerank----------------')
# prepare
requery = previous_output[-1]
stage1_screenshot_dir = os.path.join(middle_result_dir, doc["sample_id"], "stage1")
@@ -83,14 +91,14 @@ def mmsearch_end2end_doc_to_text(doc, lmms_eval_specific_kwargs=None, previous_o
return None, None, True, round_res, None

website_information, input_image_list = get_website_information(result_brief)
input_image_list = [Image.open(f) for f in input_image_list]
input_image_list = [Image.open(f).convert("RGB") for f in input_image_list]

prompt_template = prompt_template_dict["stage2"]
if not query_has_image:
image_files = input_image_list
text_query = prompt_template.format(brief_result_num=brief_result_num, rerank_num=fullpage_num, question=query, website_information=website_information, incontext_example=get_rerank_incontext_example(fullpage_num))
else:
image_files = [doc["query_image"], doc["image_search_result"], *input_image_list]
image_files = [doc["query_image"].convert("RGB"), doc["image_search_result"].convert("RGB"), *input_image_list]
text_query = prompt_template.format(
brief_result_num=brief_result_num,
rerank_num=fullpage_num,
@@ -104,6 +112,7 @@ def mmsearch_end2end_doc_to_text(doc, lmms_eval_specific_kwargs=None, previous_o
return image_files, text_query, False, previous_output, dict(result_brief=result_brief)
# round3: get full page + summarization
if round_idx == 2:
eval_logger.info('----------------Round3: Summarization----------------')
# prepare
stage3_screenshot_dir = os.path.join(middle_result_dir, doc["sample_id"], "stage3")
requery = previous_output[0]
@@ -129,7 +138,7 @@ def mmsearch_end2end_doc_to_text(doc, lmms_eval_specific_kwargs=None, previous_o

website_full_information, input_image_list = get_full_website_information(result_full=result_full, image_dir=stage3_screenshot_dir, fullpage_split_dict=FULLPAGE_SPLIT_DICT)

input_image_list = [Image.open(f) for f in input_image_list]
input_image_list = [Image.open(f).convert("RGB") for f in input_image_list]
# text_query and input_image_list
prompt_template = prompt_template_dict["stage3"]
if not query_has_image:
@@ -140,7 +149,7 @@ def mmsearch_end2end_doc_to_text(doc, lmms_eval_specific_kwargs=None, previous_o
question=query,
)
else:
image_files = [*input_image_list, doc["image_search_result"], doc["query_image"]]
image_files = [*input_image_list, doc["image_search_result"].convert("RGB"), doc["query_image"].convert("RGB")]
# assume only 1 image in the query
text_query = prompt_template.format(rerank_num=fullpage_num, website_information=website_full_information, image_search_result=DEFAULT_IMAGE_TOKEN, question=DEFAULT_IMAGE_TOKEN + query)

@@ -275,6 +284,7 @@ def mmsearch_end2end_process_results(doc, results):
"area": doc["area"],
"subfield": doc["subfield"],
"gt_answer": doc["gt_answer"],
"gt_requery": doc["gt_requery"],
"alternative_gt_answers": doc["alternative_gt_answers"],
"requery_prediction": round_res[0],
"answer_prediction": round_res[2],
@@ -356,7 +366,7 @@ def mmsearch_aggregate_results_f1_score(results, args, *, calculate_gain=False,
def mmsearch_aggregate_results_req_score(results, args, *, calculate_gain=False, random_scores=None):
result_list = []
for inst in results:
prediction = inst["requery_prediction"]
requery = inst["requery_prediction"]
gt_requery = inst["gt_requery"]
req_score = get_requery_score(requery, gt_requery)
inst.update(
@@ -367,7 +377,7 @@ def mmsearch_aggregate_results_req_score(results, args, *, calculate_gain=False,
)
result_list.append(inst)

# assert len(result_list) == 300 # assert to be the benchmark length, or the get_result_summary function will not work
assert len(result_list) == 300 # assert to be the benchmark length, or the get_result_summary function will not work
# save results
path = generate_submission_file(f"{args.tasks}_requery_results.json", args)
with open(path, "w") as f:
6 changes: 3 additions & 3 deletions lmms_eval/tasks/mmsearch/mmsearch_end2end.yaml
@@ -4,7 +4,7 @@ dataset_path: CaraJ/MMSearch
dataset_name: end2end
dataset_kwargs:
token: False
task: "mmsearch_end2end2"
task: "mmsearch_end2end"
test_split: end2end
output_type: generate_until_multi_round
doc_to_visual: !function lmms_eval_utils.mmsearch_end2end_doc_to_visual
@@ -27,5 +27,5 @@ metric_list:
aggregation: !function lmms_eval_utils.mmsearch_aggregate_results_req_score
higher_is_better: true
lmms_eval_specific_kwargs: # whenever a sample is inferred, save it
middle_resules_dir: mmsearch_logs/mmsearch_middile_results
result_cache_dir: mmsearch_logs/mmsearch_result_cache_dir
middle_resules_dir: /data1/zrr/jdz/mmsearch/mmsearch_middile_results
result_cache_dir: /data1/zrr/jdz/mmsearch/mmsearch_result_cache_dir
3 changes: 2 additions & 1 deletion lmms_eval/tasks/mmsearch/retrieve_content/retriever.py
@@ -28,7 +28,8 @@ def __init__(self):
self.tokenizer_offsets.settings["do_sliding_window_passages"] = self.config.slidew
self.tokenizer_offsets.settings["respect_sent_boundaries"] = self.config.sentb
# define retrieval model
self.model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
# self.model = BGEM3FlagModel("BAAI/bge-m3", device='cpu', use_fp16=False) # Setting use_fp16 to True speeds up computation with a slight performance degradation
self.model = BGEM3FlagModel("/data1/zrr/jdz/models/bge-m3", device='cpu', use_fp16=False) # Setting use_fp16 to True speeds up computation with a slight performance degradation

def split_doc_into_passages(self, doc):
text = doc
11 changes: 9 additions & 2 deletions lmms_eval/tasks/mmsearch/utils/utils.py
@@ -19,6 +19,12 @@
from lmms_eval.tasks.mmsearch.constants import *
from lmms_eval.tasks.mmsearch.utils.web_content_utils import *

# get rank id for random seed
from accelerate import Accelerator
accelerator = Accelerator()
WORLD_SIZE = accelerator.num_processes
RANK = accelerator.process_index
random.seed(RANK)

### Proxy setting
def get_proxy_settings():
@@ -71,7 +77,7 @@ def query(self, text: str, max_results: int) -> List[Dict[str, Any]]:

for attempt in range(max_retries):
try:
time.sleep(5) # Avoid frequent requests
time.sleep(random.choice([i for i in range(5, 10+20*WORLD_SIZE, 5)])) # Avoid frequent requests and multiple rank query at the same time
response = list(self.ddgs.text(" ".join(text.strip("'").split(" ")[:100]), max_results=max_results))
return response[:max_results]
except Exception as e:
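
As a quick check on the jittered delay added above, the candidate sleep values widen with the number of ranks, so processes seeded with different `RANK` values are unlikely to hit the search API at the same moment. An illustrative computation (example values, not from the repository):

```python
import random

WORLD_SIZE, RANK = 2, 1                 # e.g., two processes, this one is rank 1
random.seed(RANK)                       # per-rank seed, mirroring utils.py

candidates = [i for i in range(5, 10 + 20 * WORLD_SIZE, 5)]
print(candidates)                       # [5, 10, 15, 20, 25, 30, 35, 40, 45]
print(random.choice(candidates))        # a rank-specific delay in seconds
```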
@@ -101,7 +107,8 @@ async def __call__(self, query: str, screenshot_dir_path: str) -> List[Dict[str,
try:
output = self.api_wrapper.query(query, max_results=self.max_results + 20) # account for error website
except Exception as e:
eval_logger.error(f"DDGSQueryRun call failed: {e}")
eval_logger.error(f"DDGSQueryRun call failed:")
eval_logger.error(f"{e}")
output = []

evidences = []
